Kaynağa Gözat

fix the voc dataset

sunyanfang01 5 yıl önce
ebeveyn
işleme
3db718af2b
1 değiştirilmiş dosya ile 34 ekleme ve 10 silme
  1. 34 10
      paddlex/cv/datasets/voc.py

+ 34 - 10
paddlex/cv/datasets/voc.py

@@ -16,6 +16,7 @@ from __future__ import absolute_import
 import copy
 import os.path as osp
 import random
+import re
 import numpy as np
 from collections import OrderedDict
 import xml.etree.ElementTree as ET
@@ -103,23 +104,46 @@ class VOCDetection(Dataset):
                 else:
                     ct = int(tree.find('id').text)
                     im_id = np.array([int(tree.find('id').text)])
-
-                objs = tree.findall('object')
-                im_w = float(tree.find('size').find('width').text)
-                im_h = float(tree.find('size').find('height').text)
+                pattern = re.compile('<object>', re.IGNORECASE)
+                obj_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))[0][1:-1]
+                objs = tree.findall(obj_tag)
+                pattern = re.compile('<size>', re.IGNORECASE)
+                size_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))[0][1:-1]
+                size_element = tree.find(size_tag)
+                pattern = re.compile('<width>', re.IGNORECASE)
+                width_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:-1]
+                im_w = float(size_element.find(width_tag).text)
+                pattern = re.compile('<height>', re.IGNORECASE)
+                height_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:-1]
+                im_h = float(size_element.find(height_tag).text)
                 gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)
                 gt_class = np.zeros((len(objs), 1), dtype=np.int32)
                 gt_score = np.ones((len(objs), 1), dtype=np.float32)
                 is_crowd = np.zeros((len(objs), 1), dtype=np.int32)
                 difficult = np.zeros((len(objs), 1), dtype=np.int32)
                 for i, obj in enumerate(objs):
-                    cname = obj.find('name').text
+                    pattern = re.compile('<name>', re.IGNORECASE)
+                    name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
+                    cname = obj.find(name_tag).text
                     gt_class[i][0] = cname2cid[cname]
-                    _difficult = int(obj.find('difficult').text)
-                    x1 = float(obj.find('bndbox').find('xmin').text)
-                    y1 = float(obj.find('bndbox').find('ymin').text)
-                    x2 = float(obj.find('bndbox').find('xmax').text)
-                    y2 = float(obj.find('bndbox').find('ymax').text)
+                    pattern = re.compile('<difficult>', re.IGNORECASE)
+                    diff_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
+                    _difficult = int(obj.find(diff_tag).text)
+                    pattern = re.compile('<bndbox>', re.IGNORECASE)
+                    box_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
+                    box_element = obj.find(box_tag)
+                    pattern = re.compile('<xmin>', re.IGNORECASE)
+                    xmin_element = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
+                    x1 = float(box_element.find(xmin_element).text)
+                    pattern = re.compile('<ymin>', re.IGNORECASE)
+                    ymin_element = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
+                    y1 = float(box_element.find(ymin_element).text)
+                    pattern = re.compile('<xmax>', re.IGNORECASE)
+                    xmax_element = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
+                    x2 = float(box_element.find(xmax_element).text)
+                    pattern = re.compile('<ymax>', re.IGNORECASE)
+                    ymax_element = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
+                    y2 = float(box_element.find(ymax_element).text)
                     x1 = max(0, x1)
                     y1 = max(0, y1)
                     if im_w > 0.5 and im_h > 0.5: