浏览代码

Merge pull request #855 from will-jl944/develop_jf

fix bug if offset between cls_id and cat_id is not 1
FlyingQianMM 4 年之前
父节点
当前提交
c49b06ea46
共有 2 个文件被更改,包括 7 次插入2 次删除
  1. 1 1
      dygraph/paddlex/cv/models/detector.py
  2. 6 1
      dygraph/paddlex/cv/models/utils/det_metrics/metrics.py

+ 1 - 1
dygraph/paddlex/cv/models/detector.py

@@ -512,7 +512,7 @@ class BaseDetector(BaseModel):
                     h = ymax - ymin
                     bbox = [xmin, ymin, w, h]
                     dt_res = {
-                        'category_id': int(num_id) + 1,
+                        'category_id': int(num_id),
                         'category': category,
                         'bbox': bbox,
                         'score': score

+ 6 - 1
dygraph/paddlex/cv/models/utils/det_metrics/metrics.py

@@ -61,6 +61,11 @@ class VOCMetric(Metric):
                  classwise=False):
         self.cid2cname = {i: name for i, name in enumerate(labels)}
         self.coco_gt = coco_gt
+        self.clsid2catid = {
+            i: cat['id']
+            for i, cat in enumerate(
+                self.coco_gt.loadCats(self.coco_gt.getCatIds()))
+        }
         self.overlap_thresh = overlap_thresh
         self.map_type = map_type
         self.evaluate_difficult = evaluate_difficult
@@ -121,7 +126,7 @@ class VOCMetric(Metric):
                 bbox = [xmin, ymin, w, h]
                 coco_res = {
                     'image_id': int(inputs['im_id']),
-                    'category_id': int(l + 1),
+                    'category_id': self.clsid2catid[int(l)],
                     'bbox': bbox,
                     'score': float(s)
                 }