|
@@ -28,7 +28,12 @@ def load_model(model_dir):
|
|
|
raise Exception("There's not model.yml in {}".format(model_dir))
|
|
raise Exception("There's not model.yml in {}".format(model_dir))
|
|
|
with open(osp.join(model_dir, "model.yml")) as f:
|
|
with open(osp.join(model_dir, "model.yml")) as f:
|
|
|
info = yaml.load(f.read(), Loader=yaml.Loader)
|
|
info = yaml.load(f.read(), Loader=yaml.Loader)
|
|
|
- status = info['status']
|
|
|
|
|
|
|
+
|
|
|
|
|
+ if 'status' in info:
|
|
|
|
|
+ status = info['status']
|
|
|
|
|
+ elif 'save_method' in info:
|
|
|
|
|
+ # 兼容老版本PaddleX
|
|
|
|
|
+ status = info['save_method']
|
|
|
|
|
|
|
|
if not hasattr(paddlex.cv.models, info['Model']):
|
|
if not hasattr(paddlex.cv.models, info['Model']):
|
|
|
raise Exception("There's no attribute {} in paddlex.cv.models".format(
|
|
raise Exception("There's no attribute {} in paddlex.cv.models".format(
|
|
@@ -40,7 +45,7 @@ def load_model(model_dir):
|
|
|
model = getattr(paddlex.cv.models,
|
|
model = getattr(paddlex.cv.models,
|
|
|
info['Model'])(**info['_init_params'])
|
|
info['Model'])(**info['_init_params'])
|
|
|
if status == "Normal" or \
|
|
if status == "Normal" or \
|
|
|
- status == "Prune":
|
|
|
|
|
|
|
+ status == "Prune" or status == "fluid.save":
|
|
|
startup_prog = fluid.Program()
|
|
startup_prog = fluid.Program()
|
|
|
model.test_prog = fluid.Program()
|
|
model.test_prog = fluid.Program()
|
|
|
with fluid.program_guard(model.test_prog, startup_prog):
|
|
with fluid.program_guard(model.test_prog, startup_prog):
|
|
@@ -59,7 +64,7 @@ def load_model(model_dir):
|
|
|
fluid.io.set_program_state(model.test_prog, load_dict)
|
|
fluid.io.set_program_state(model.test_prog, load_dict)
|
|
|
|
|
|
|
|
elif status == "Infer" or \
|
|
elif status == "Infer" or \
|
|
|
- status == "Quant":
|
|
|
|
|
|
|
+ status == "Quant" or status == "fluid.save_inference_model":
|
|
|
[prog, input_names, outputs] = fluid.io.load_inference_model(
|
|
[prog, input_names, outputs] = fluid.io.load_inference_model(
|
|
|
model_dir, model.exe, params_filename='__params__')
|
|
model_dir, model.exe, params_filename='__params__')
|
|
|
model.test_prog = prog
|
|
model.test_prog = prog
|
|
@@ -77,9 +82,15 @@ def load_model(model_dir):
|
|
|
to_rgb = True
|
|
to_rgb = True
|
|
|
else:
|
|
else:
|
|
|
to_rgb = False
|
|
to_rgb = False
|
|
|
- model.test_transforms = build_transforms(model.model_type,
|
|
|
|
|
- info['Transforms'], to_rgb)
|
|
|
|
|
- model.eval_transforms = copy.deepcopy(model.test_transforms)
|
|
|
|
|
|
|
+ if 'BatchTransforms' in info:
|
|
|
|
|
+ # 兼容老版本PaddleX模型
|
|
|
|
|
+ model.test_transforms = build_transforms_v1(
|
|
|
|
|
+ model.model_type, info['Transforms'], info['BatchTransforms'])
|
|
|
|
|
+ model.eval_transforms = copy.deepcopy(model.test_transforms)
|
|
|
|
|
+ else:
|
|
|
|
|
+ model.test_transforms = build_transforms(
|
|
|
|
|
+ model.model_type, info['Transforms'], to_rgb)
|
|
|
|
|
+ model.eval_transforms = copy.deepcopy(model.test_transforms)
|
|
|
|
|
|
|
|
if '_Attributes' in info:
|
|
if '_Attributes' in info:
|
|
|
for k, v in info['_Attributes'].items():
|
|
for k, v in info['_Attributes'].items():
|
|
@@ -109,3 +120,46 @@ def build_transforms(model_type, transforms_info, to_rgb=True):
|
|
|
eval_transforms = T.Compose(transforms)
|
|
eval_transforms = T.Compose(transforms)
|
|
|
eval_transforms.to_rgb = to_rgb
|
|
eval_transforms.to_rgb = to_rgb
|
|
|
return eval_transforms
|
|
return eval_transforms
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def build_transforms_v1(model_type, transforms_info, batch_transforms_info):
|
|
|
|
|
+ """ 老版本模型加载,仅支持PaddleX前端导出的模型
|
|
|
|
|
+ """
|
|
|
|
|
+ logging.debug("Use build_transforms_v1 to reconstruct transforms")
|
|
|
|
|
+ if model_type == "classifier":
|
|
|
|
|
+ import paddlex.cv.transforms.cls_transforms as T
|
|
|
|
|
+ elif model_type == "detector":
|
|
|
|
|
+ import paddlex.cv.transforms.det_transforms as T
|
|
|
|
|
+ elif model_type == "segmenter":
|
|
|
|
|
+ import paddlex.cv.transforms.seg_transforms as T
|
|
|
|
|
+ transforms = list()
|
|
|
|
|
+ for op_info in transforms_info:
|
|
|
|
|
+ op_name = op_info[0]
|
|
|
|
|
+ op_attr = op_info[1]
|
|
|
|
|
+ if op_name == 'DecodeImage':
|
|
|
|
|
+ continue
|
|
|
|
|
+ if op_name == 'Permute':
|
|
|
|
|
+ continue
|
|
|
|
|
+ if op_name == 'ResizeByShort':
|
|
|
|
|
+ op_attr_new = dict()
|
|
|
|
|
+ if 'short_size' in op_attr:
|
|
|
|
|
+ op_attr_new['short_size'] = op_attr['short_size']
|
|
|
|
|
+ else:
|
|
|
|
|
+ op_attr_new['short_size'] = op_attr['target_size']
|
|
|
|
|
+ op_attr_new['max_size'] = op_attr.get('max_size', -1)
|
|
|
|
|
+ op_attr = op_attr_new
|
|
|
|
|
+ if op_name.startswith('Arrange'):
|
|
|
|
|
+ continue
|
|
|
|
|
+ if not hasattr(T, op_name):
|
|
|
|
|
+ raise Exception(
|
|
|
|
|
+ "There's no operator named '{}' in transforms of {}".format(
|
|
|
|
|
+ op_name, model_type))
|
|
|
|
|
+ transforms.append(getattr(T, op_name)(**op_attr))
|
|
|
|
|
+ if model_type == "detector" and len(batch_transforms_info) > 0:
|
|
|
|
|
+ op_name = batch_transforms_info[0][0]
|
|
|
|
|
+ op_attr = batch_transforms_info[0][1]
|
|
|
|
|
+ assert op_name == "PaddingMiniBatch", "Only PaddingMiniBatch transform is supported for batch transform"
|
|
|
|
|
+ padding = T.Padding(coarsest_stride=op_attr['coarsest_stride'])
|
|
|
|
|
+ transforms.append(padding)
|
|
|
|
|
+ eval_transforms = T.Compose(transforms)
|
|
|
|
|
+ return eval_transforms
|