|
|
@@ -54,6 +54,9 @@ def load_model(model_dir, fixed_input_shape=None):
|
|
|
logging.info("Model already has fixed_input_shape with {}".
|
|
|
format(fixed_input_shape))
|
|
|
model.fixed_input_shape = fixed_input_shape
|
|
|
+ else:
|
|
|
+ info['_Attributes']['fixed_input_shape'] = model.fixed_input_shape
|
|
|
+
|
|
|
|
|
|
if info['Model'].count('RCNN') > 0:
|
|
|
if info['_init_params']['with_fpn']:
|
|
|
@@ -67,6 +70,7 @@ def load_model(model_dir, fixed_input_shape=None):
|
|
|
"The second value in fixed_input_shape must be a multiple of 32, but recieved {}.".
|
|
|
format(model.fixed_input_shape[1]))
|
|
|
|
|
|
+
|
|
|
with fluid.scope_guard(model_scope):
|
|
|
if status == "Normal" or \
|
|
|
status == "Prune" or status == "fluid.save":
|
|
|
@@ -104,6 +108,7 @@ def load_model(model_dir, fixed_input_shape=None):
|
|
|
for i, out in enumerate(outputs):
|
|
|
var_desc = test_outputs_info[i]
|
|
|
model.test_outputs[var_desc[0]] = out
|
|
|
+
|
|
|
if 'Transforms' in info:
|
|
|
transforms_mode = info.get('TransformsMode', 'RGB')
|
|
|
# 固定模型的输入shape
|
|
|
@@ -127,6 +132,8 @@ def load_model(model_dir, fixed_input_shape=None):
|
|
|
if k in model.__dict__:
|
|
|
model.__dict__[k] = v
|
|
|
|
|
|
+
|
|
|
+
|
|
|
logging.info("Model[{}] loaded.".format(info['Model']))
|
|
|
model.scope = model_scope
|
|
|
model.trainable = False
|