FlyingQianMM 5 жил өмнө
parent
commit
6fda0a8c25

+ 2 - 5
paddlex/cv/datasets/coco.py

@@ -100,7 +100,7 @@ class CocoDetection(VOCDetection):
             gt_score = np.ones((num_bbox, 1), dtype=np.float32)
             is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
             difficult = np.zeros((num_bbox, 1), dtype=np.int32)
-            gt_poly = None
+            gt_poly = [None] * num_bbox
 
             for i, box in enumerate(bboxes):
                 catid = box['category_id']
@@ -108,8 +108,6 @@ class CocoDetection(VOCDetection):
                 gt_bbox[i, :] = box['clean_bbox']
                 is_crowd[i][0] = box['iscrowd']
                 if 'segmentation' in box:
-                    if gt_poly is None:
-                        gt_poly = [None] * num_bbox
                     gt_poly[i] = box['segmentation']
 
             im_info = {
@@ -121,10 +119,9 @@ class CocoDetection(VOCDetection):
                 'gt_class': gt_class,
                 'gt_bbox': gt_bbox,
                 'gt_score': gt_score,
+                'gt_poly': gt_poly,
                 'difficult': difficult
             }
-            if gt_poly is not None:
-                label_info['gt_poly'] = gt_poly
 
             coco_rec = (im_info, label_info)
             self.file_list.append([im_fname, coco_rec])

+ 23 - 11
paddlex/cv/datasets/voc.py

@@ -106,16 +106,20 @@ class VOCDetection(Dataset):
                     ct = int(tree.find('id').text)
                     im_id = np.array([int(tree.find('id').text)])
                 pattern = re.compile('<object>', re.IGNORECASE)
-                obj_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))[0][1:-1]
+                obj_tag = pattern.findall(
+                    str(ET.tostringlist(tree.getroot())))[0][1:-1]
                 objs = tree.findall(obj_tag)
                 pattern = re.compile('<size>', re.IGNORECASE)
-                size_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))[0][1:-1]
+                size_tag = pattern.findall(
+                    str(ET.tostringlist(tree.getroot())))[0][1:-1]
                 size_element = tree.find(size_tag)
                 pattern = re.compile('<width>', re.IGNORECASE)
-                width_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:-1]
+                width_tag = pattern.findall(
+                    str(ET.tostringlist(size_element)))[0][1:-1]
                 im_w = float(size_element.find(width_tag).text)
                 pattern = re.compile('<height>', re.IGNORECASE)
-                height_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:-1]
+                height_tag = pattern.findall(
+                    str(ET.tostringlist(size_element)))[0][1:-1]
                 im_h = float(size_element.find(height_tag).text)
                 gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)
                 gt_class = np.zeros((len(objs), 1), dtype=np.int32)
@@ -124,29 +128,36 @@ class VOCDetection(Dataset):
                 difficult = np.zeros((len(objs), 1), dtype=np.int32)
                 for i, obj in enumerate(objs):
                     pattern = re.compile('<name>', re.IGNORECASE)
-                    name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
+                    name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][
+                        1:-1]
                     cname = obj.find(name_tag).text.strip()
                     gt_class[i][0] = cname2cid[cname]
                     pattern = re.compile('<difficult>', re.IGNORECASE)
-                    diff_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
+                    diff_tag = pattern.findall(str(ET.tostringlist(obj)))[0][
+                        1:-1]
                     try:
                         _difficult = int(obj.find(diff_tag).text)
                     except Exception:
                         _difficult = 0
                     pattern = re.compile('<bndbox>', re.IGNORECASE)
-                    box_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
+                    box_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:
+                                                                            -1]
                     box_element = obj.find(box_tag)
                     pattern = re.compile('<xmin>', re.IGNORECASE)
-                    xmin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
+                    xmin_tag = pattern.findall(
+                        str(ET.tostringlist(box_element)))[0][1:-1]
                     x1 = float(box_element.find(xmin_tag).text)
                     pattern = re.compile('<ymin>', re.IGNORECASE)
-                    ymin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
+                    ymin_tag = pattern.findall(
+                        str(ET.tostringlist(box_element)))[0][1:-1]
                     y1 = float(box_element.find(ymin_tag).text)
                     pattern = re.compile('<xmax>', re.IGNORECASE)
-                    xmax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
+                    xmax_tag = pattern.findall(
+                        str(ET.tostringlist(box_element)))[0][1:-1]
                     x2 = float(box_element.find(xmax_tag).text)
                     pattern = re.compile('<ymax>', re.IGNORECASE)
-                    ymax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
+                    ymax_tag = pattern.findall(
+                        str(ET.tostringlist(box_element)))[0][1:-1]
                     y2 = float(box_element.find(ymax_tag).text)
                     x1 = max(0, x1)
                     y1 = max(0, y1)
@@ -176,6 +187,7 @@ class VOCDetection(Dataset):
                     'gt_class': gt_class,
                     'gt_bbox': gt_bbox,
                     'gt_score': gt_score,
+                    'gt_poly': [],
                     'difficult': difficult
                 }
                 voc_rec = (im_info, label_info)