Parcourir la source

cat id starts from 1 in voc dataset

will-jl944 il y a 4 ans
Parent
commit
5587a9a41a
2 fichiers modifiés avec 4 ajouts et 3 suppressions
  1. 2 2
      dygraph/paddlex/cv/datasets/voc.py
  2. 2 1
      dygraph/paddlex/cv/models/detector.py

+ 2 - 2
dygraph/paddlex/cv/datasets/voc.py

@@ -89,7 +89,7 @@ class VOCDetection(Dataset):
         for k, v in cname2cid.items():
             annotations['categories'].append({
                 'supercategory': 'component',
-                'id': v,
+                'id': v + 1,
                 'name': k
             })
         ct = 0
@@ -219,7 +219,7 @@ class VOCDetection(Dataset):
                         'image_id': int(im_id[0]),
                         'bbox': [x1, y1, x2 - x1 + 1, y2 - y1 + 1],
                         'area': float((x2 - x1 + 1) * (y2 - y1 + 1)),
-                        'category_id': cname2cid[cname],
+                        'category_id': cname2cid[cname] + 1,
                         'id': ann_ct,
                         'difficult': _difficult
                     })

+ 2 - 1
dygraph/paddlex/cv/models/detector.py

@@ -512,7 +512,7 @@ class BaseDetector(BaseModel):
                     h = ymax - ymin
                     bbox = [xmin, ymin, w, h]
                     dt_res = {
-                        'category_id': int(num_id),
+                        'category_id': int(num_id) + 1,
                         'category': category,
                         'bbox': bbox,
                         'score': score
@@ -544,6 +544,7 @@ class BaseDetector(BaseModel):
                         if 'counts' in rle:
                             rle['counts'] = rle['counts'].decode("utf8")
                     sg_res = {
+                        'category_id': int(label) + 1,
                         'category': category,
                         'segmentation': rle,
                         'score': score