Quellcode durchsuchen

remove requirement of field for voc dataset

jiangjiajun vor 4 Jahren
Ursprung
Commit
c4d02a9610
3 geänderte Dateien mit 17 neuen und 12 gelöschten Zeilen
  1. 1 1
      paddlex/__init__.py
  2. 15 10
      paddlex/cv/datasets/voc.py
  3. 1 1
      setup.py

+ 1 - 1
paddlex/__init__.py

@@ -14,7 +14,7 @@
 
 from __future__ import absolute_import
 
-__version__ = '1.3.7'
+__version__ = '1.3.8'
 
 import os
 if 'FLAGS_eager_delete_tensor_gb' not in os.environ:

+ 15 - 10
paddlex/cv/datasets/voc.py

@@ -131,16 +131,21 @@ class VOCDetection(Dataset):
                 objs = tree.findall(obj_tag)
                 pattern = re.compile('<size>', re.IGNORECASE)
                 size_tag = pattern.findall(
-                    str(ET.tostringlist(tree.getroot())))[0][1:-1]
-                size_element = tree.find(size_tag)
-                pattern = re.compile('<width>', re.IGNORECASE)
-                width_tag = pattern.findall(
-                    str(ET.tostringlist(size_element)))[0][1:-1]
-                im_w = float(size_element.find(width_tag).text)
-                pattern = re.compile('<height>', re.IGNORECASE)
-                height_tag = pattern.findall(
-                    str(ET.tostringlist(size_element)))[0][1:-1]
-                im_h = float(size_element.find(height_tag).text)
+                    str(ET.tostringlist(tree.getroot())))
+                if len(size_tag) > 0:
+                    size_tag = size_tag[0][1:-1]
+                    size_element = tree.find(size_tag)
+                    pattern = re.compile('<width>', re.IGNORECASE)
+                    width_tag = pattern.findall(
+                        str(ET.tostringlist(size_element)))[0][1:-1]
+                    im_w = float(size_element.find(width_tag).text)
+                    pattern = re.compile('<height>', re.IGNORECASE)
+                    height_tag = pattern.findall(
+                        str(ET.tostringlist(size_element)))[0][1:-1]
+                    im_h = float(size_element.find(height_tag).text)
+                else:
+                    im_w = 0
+                    im_h = 0
                 gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)
                 gt_class = np.zeros((len(objs), 1), dtype=np.int32)
                 gt_score = np.ones((len(objs), 1), dtype=np.float32)

+ 1 - 1
setup.py

@@ -19,7 +19,7 @@ long_description = "PaddlePaddle Entire Process Development Toolkit"
 
 setuptools.setup(
     name="paddlex",
-    version='1.3.7',
+    version='1.3.8',
     author="paddlex",
     author_email="paddlex@baidu.com",
     description=long_description,