浏览代码

add reference in _summarize

FlyingQianMM 4 年之前
父节点
当前提交
321d7095ba

+ 32 - 13
paddlex/cv/models/utils/det_metrics/coco_utils.py

@@ -133,7 +133,7 @@ def cocoapi_eval(anns,
         results_flatten = list(itertools.chain(*results_per_category))
         headers = ['category', 'AP'] * (num_columns // 2)
         results_2d = itertools.zip_longest(
-            *[results_flatten[i::num_columns] for i in range(num_columns)])
+            * [results_flatten[i::num_columns] for i in range(num_columns)])
         table_data = [headers]
         table_data += [result for result in results_2d]
         table = AsciiTable(table_data)
@@ -220,6 +220,21 @@ def loadRes(coco_obj, anns):
 
 
 def makeplot(rs, ps, outDir, class_name, iou_type):
+    """针对某个特定类别,绘制不同评估要求下的准确率和召回率。
+       绘制结果说明参考COCODataset官网给出分析工具说明https://cocodataset.org/#detection-eval。
+
+       Refer to https://github.com/open-mmlab/mmdetection/blob/master/tools/analysis_tools/coco_error_analysis.py#L13
+
+       Args:
+           rs (np.array): 在不同置信度阈值下计算得到的召回率。
+           ps (np.array): 在不同置信度阈值下计算得到的准确率。ps与rs相同位置下的数值为同一个置信度阈值
+               计算得到的准确率与召回率。
+           outDir (str): 图表保存的路径。
+           class_name (str): 类别名。
+           iou_type (str): iou计算方式,若为检测框,则设置为'bbox',若为像素级分割结果,则设置为'segm'。
+
+    """
+
     import matplotlib.pyplot as plt
     cs = np.vstack([
         np.ones((2, 3)),
@@ -264,21 +279,21 @@ def analyze_individual_category(k, cocoDt, cocoGt, catId, iou_type,
                                 areas=None):
     """针对某个特定类别,分析忽略亚类混淆和类别混淆时的准确率。
 
-           Refer to https://github.com/open-mmlab/mmdetection/blob/master/tools/coco_error_analysis.py
+       Refer to https://github.com/open-mmlab/mmdetection/blob/master/tools/coco_error_analysis.py
 
-           Args:
-               k (int): 待分析类别的序号。
-               cocoDt (pycocotols.coco.COCO): 按COCO类存放的预测结果。
-               cocoGt (pycocotols.coco.COCO): 按COCO类存放的真值。
-               catId (int): 待分析类别在数据集中的类别id。
-               iou_type (str): iou计算方式,若为检测框,则设置为'bbox',若为像素级分割结果,则设置为'segm'。
+       Args:
+           k (int): 待分析类别的序号。
+           cocoDt (pycocotols.coco.COCO): 按COCO类存放的预测结果。
+           cocoGt (pycocotols.coco.COCO): 按COCO类存放的真值。
+           catId (int): 待分析类别在数据集中的类别id。
+           iou_type (str): iou计算方式,若为检测框,则设置为'bbox',若为像素级分割结果,则设置为'segm'。
 
-           Returns:
-               int:
-               dict: 有关键字'ps_supercategory'和'ps_allcategory'。关键字'ps_supercategory'的键值是忽略亚类间
-                   混淆时的准确率,关键字'ps_allcategory'的键值是忽略类别间混淆时的准确率。
+       Returns:
+           int:
+           dict: 有关键字'ps_supercategory'和'ps_allcategory'。关键字'ps_supercategory'的键值是忽略亚类间
+               混淆时的准确率,关键字'ps_allcategory'的键值是忽略类别间混淆时的准确率。
 
-        """
+    """
 
     # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
     # or matplotlib.backends is imported for the first time
@@ -394,6 +409,10 @@ def coco_error_analysis(eval_details_file=None,
         raise Exception("There is no predicted mask.")
 
     def _analyze_results(cocoGt, cocoDt, res_type, out_dir):
+        """
+        Refer to
+        https://github.com/open-mmlab/mmdetection/blob/master/tools/analysis_tools/coco_error_analysis.py#L235
+        """
         directory = osp.dirname(osp.join(out_dir, ''))
         if not osp.exists(directory):
             logging.info('-------------create {}-----------------'.format(

+ 6 - 0
paddlex/cv/models/utils/visualize.py

@@ -298,6 +298,12 @@ def draw_pr_curve(eval_details_file=None,
     coco.createIndex()
 
     def _summarize(coco_gt, ap=1, iouThr=None, areaRng='all', maxDets=100):
+        """
+        This function has the same functionality as _summarize() in pycocotools.COCOeval.summarize().
+
+        Refer to
+        https://github.com/cocodataset/cocoapi/blob/8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9/PythonAPI/pycocotools/cocoeval.py#L427,
+        """
         p = coco_gt.params
         aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
         mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]

+ 10 - 1
static/paddlex/cv/models/utils/detection_eval.py

@@ -144,6 +144,11 @@ def loadRes(coco_obj, anns):
     :param   resFile (str)     : file name of result file
     :return: res (obj)         : result api object
     """
+    # This function has the same functionality as pycocotools.COCO.loadRes,
+    # except that the input anns is list of results rather than a json file.
+    # Refer to
+    # https://github.com/cocodataset/cocoapi/blob/8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9/PythonAPI/pycocotools/coco.py#L305,
+
     # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
     # or matplotlib.backends is imported for the first time
     # pycocotools import matplotlib
@@ -805,7 +810,7 @@ def makeplot(rs, ps, outDir, class_name, iou_type):
     """针对某个特定类别,绘制不同评估要求下的准确率和召回率。
        绘制结果说明参考COCODataset官网给出分析工具说明https://cocodataset.org/#detection-eval。
 
-       Refer to https://github.com/open-mmlab/mmdetection/blob/master/tools/coco_error_analysis.py
+       Refer to https://github.com/open-mmlab/mmdetection/blob/master/tools/analysis_tools/coco_error_analysis.py#L13
 
        Args:
            rs (np.array): 在不同置信度阈值下计算得到的召回率。
@@ -983,6 +988,10 @@ def coco_error_analysis(eval_details_file=None,
         raise Exception("There is no predicted mask.")
 
     def _analyze_results(cocoGt, cocoDt, res_type, out_dir):
+        """
+        Refer to
+        https://github.com/open-mmlab/mmdetection/blob/master/tools/analysis_tools/coco_error_analysis.py#L235
+        """
         directory = os.path.dirname(out_dir + '/')
         if not os.path.exists(directory):
             logging.info('-------------create {}-----------------'.format(