Sfoglia il codice sorgente

Merge pull request #96 from SunAhong1993/syf_docs

fix the voc dataset
Jason 5 anni fa
parent
commit
c2cdf8b2bc
1 file modificato con 37 aggiunte e 10 eliminazioni
  1. 37 10
      paddlex/cv/datasets/voc.py

+ 37 - 10
paddlex/cv/datasets/voc.py

@@ -17,6 +17,7 @@ import copy
 import os
 import os.path as osp
 import random
+import re
 import numpy as np
 from collections import OrderedDict
 import xml.etree.ElementTree as ET
@@ -104,23 +105,49 @@ class VOCDetection(Dataset):
                 else:
                     ct = int(tree.find('id').text)
                     im_id = np.array([int(tree.find('id').text)])
-
-                objs = tree.findall('object')
-                im_w = float(tree.find('size').find('width').text)
-                im_h = float(tree.find('size').find('height').text)
+                pattern = re.compile('<object>', re.IGNORECASE)
+                obj_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))[0][1:-1]
+                objs = tree.findall(obj_tag)
+                pattern = re.compile('<size>', re.IGNORECASE)
+                size_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))[0][1:-1]
+                size_element = tree.find(size_tag)
+                pattern = re.compile('<width>', re.IGNORECASE)
+                width_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:-1]
+                im_w = float(size_element.find(width_tag).text)
+                pattern = re.compile('<height>', re.IGNORECASE)
+                height_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:-1]
+                im_h = float(size_element.find(height_tag).text)
                 gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)
                 gt_class = np.zeros((len(objs), 1), dtype=np.int32)
                 gt_score = np.ones((len(objs), 1), dtype=np.float32)
                 is_crowd = np.zeros((len(objs), 1), dtype=np.int32)
                 difficult = np.zeros((len(objs), 1), dtype=np.int32)
                 for i, obj in enumerate(objs):
-                    cname = obj.find('name').text.strip()
+                    pattern = re.compile('<name>', re.IGNORECASE)
+                    name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
+                    cname = obj.find(name_tag).text.strip()
                     gt_class[i][0] = cname2cid[cname]
-                    _difficult = int(obj.find('difficult').text)
-                    x1 = float(obj.find('bndbox').find('xmin').text)
-                    y1 = float(obj.find('bndbox').find('ymin').text)
-                    x2 = float(obj.find('bndbox').find('xmax').text)
-                    y2 = float(obj.find('bndbox').find('ymax').text)
+                    pattern = re.compile('<difficult>', re.IGNORECASE)
+                    diff_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
+                    try:
+                        _difficult = int(obj.find(diff_tag).text)
+                    except Exception:
+                        _difficult = 0
+                    pattern = re.compile('<bndbox>', re.IGNORECASE)
+                    box_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
+                    box_element = obj.find(box_tag)
+                    pattern = re.compile('<xmin>', re.IGNORECASE)
+                    xmin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
+                    x1 = float(box_element.find(xmin_tag).text)
+                    pattern = re.compile('<ymin>', re.IGNORECASE)
+                    ymin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
+                    y1 = float(box_element.find(ymin_tag).text)
+                    pattern = re.compile('<xmax>', re.IGNORECASE)
+                    xmax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
+                    x2 = float(box_element.find(xmax_tag).text)
+                    pattern = re.compile('<ymax>', re.IGNORECASE)
+                    ymax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
+                    y2 = float(box_element.find(ymax_tag).text)
                     x1 = max(0, x1)
                     y1 = max(0, y1)
                     if im_w > 0.5 and im_h > 0.5: