
add build_transforms_v1 for old-version PaddleX models

jiangjiajun, 5 years ago
Parent commit: 186683aaf8
1 changed file with 60 additions and 6 deletions:
    paddlex/cv/models/load_model.py

paddlex/cv/models/load_model.py  (+60, -6)

@@ -28,7 +28,12 @@ def load_model(model_dir):
         raise Exception("There's not model.yml in {}".format(model_dir))
     with open(osp.join(model_dir, "model.yml")) as f:
         info = yaml.load(f.read(), Loader=yaml.Loader)
-    status = info['status']
+
+    if 'status' in info:
+        status = info['status']
+    elif 'save_method' in info:
+        # For compatibility with older PaddleX versions
+        status = info['save_method']
 
     if not hasattr(paddlex.cv.models, info['Model']):
         raise Exception("There's no attribute {} in paddlex.cv.models".format(
@@ -40,7 +45,7 @@ def load_model(model_dir):
         model = getattr(paddlex.cv.models,
                         info['Model'])(**info['_init_params'])
     if status == "Normal" or \
-            status == "Prune":
+            status == "Prune" or status == "fluid.save":
         startup_prog = fluid.Program()
         model.test_prog = fluid.Program()
         with fluid.program_guard(model.test_prog, startup_prog):
@@ -59,7 +64,7 @@ def load_model(model_dir):
         fluid.io.set_program_state(model.test_prog, load_dict)
 
     elif status == "Infer" or \
-            status == "Quant":
+            status == "Quant" or status == "fluid.save_inference_model":
         [prog, input_names, outputs] = fluid.io.load_inference_model(
             model_dir, model.exe, params_filename='__params__')
         model.test_prog = prog
@@ -77,9 +82,15 @@ def load_model(model_dir):
             to_rgb = True
         else:
             to_rgb = False
-        model.test_transforms = build_transforms(model.model_type,
-                                                 info['Transforms'], to_rgb)
-        model.eval_transforms = copy.deepcopy(model.test_transforms)
+        if 'BatchTransforms' in info:
+            # For compatibility with models from older PaddleX versions
+            model.test_transforms = build_transforms_v1(
+                model.model_type, info['Transforms'], info['BatchTransforms'])
+            model.eval_transforms = copy.deepcopy(model.test_transforms)
+        else:
+            model.test_transforms = build_transforms(
+                model.model_type, info['Transforms'], to_rgb)
+            model.eval_transforms = copy.deepcopy(model.test_transforms)
 
     if '_Attributes' in info:
         for k, v in info['_Attributes'].items():
@@ -109,3 +120,46 @@ def build_transforms(model_type, transforms_info, to_rgb=True):
     eval_transforms = T.Compose(transforms)
     eval_transforms.to_rgb = to_rgb
     return eval_transforms
+
+
+def build_transforms_v1(model_type, transforms_info, batch_transforms_info):
+    """ 老版本模型加载,仅支持PaddleX前端导出的模型
+    """
+    logging.debug("Use build_transforms_v1 to reconstruct transforms")
+    if model_type == "classifier":
+        import paddlex.cv.transforms.cls_transforms as T
+    elif model_type == "detector":
+        import paddlex.cv.transforms.det_transforms as T
+    elif model_type == "segmenter":
+        import paddlex.cv.transforms.seg_transforms as T
+    transforms = list()
+    for op_info in transforms_info:
+        op_name = op_info[0]
+        op_attr = op_info[1]
+        if op_name == 'DecodeImage':
+            continue
+        if op_name == 'Permute':
+            continue
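+        # Older exports may name the short edge 'target_size'; map it to 'short_size'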
+        if op_name == 'ResizeByShort':
+            op_attr_new = dict()
+            if 'short_size' in op_attr:
+                op_attr_new['short_size'] = op_attr['short_size']
+            else:
+                op_attr_new['short_size'] = op_attr['target_size']
+            op_attr_new['max_size'] = op_attr.get('max_size', -1)
+            op_attr = op_attr_new
+        if op_name.startswith('Arrange'):
+            continue
+        if not hasattr(T, op_name):
+            raise Exception(
+                "There's no operator named '{}' in transforms of {}".format(
+                    op_name, model_type))
+        transforms.append(getattr(T, op_name)(**op_attr))
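+    # Old detector exports store padding as a 'PaddingMiniBatch' batch transform;
+    # map it to the current Padding op with the same coarsest_stride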
+    if model_type == "detector" and len(batch_transforms_info) > 0:
+        op_name = batch_transforms_info[0][0]
+        op_attr = batch_transforms_info[0][1]
+        assert op_name == "PaddingMiniBatch", "Only PaddingMiniBatch transform is supported for batch transform"
+        padding = T.Padding(coarsest_stride=op_attr['coarsest_stride'])
+        transforms.append(padding)
+    eval_transforms = T.Compose(transforms)
+    return eval_transforms
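
With this change, models exported by older PaddleX versions load through the same entry point as current ones: their model.yml carries 'save_method' instead of 'status' and may include a 'BatchTransforms' section, which routes through build_transforms_v1 above. A minimal usage sketch, assuming the package-level paddlex.load_model() entry point of PaddleX 1.x and an illustrative model directory path:

    import paddlex

    # Hypothetical directory exported by an older PaddleX frontend; its model.yml
    # uses 'save_method' and 'BatchTransforms', so the compatibility branches
    # added in this commit are taken.
    model = paddlex.load_model('./old_paddlex_model')
    result = model.predict('test.jpg')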