Ver Fonte

update version and fix bug for fixed_input_shape

jiangjiajun há 4 anos atrás
pai
commit
8ef1181bff
3 ficheiros alterados com 9 adições e 2 exclusões
  1. 1 1
      paddlex/__init__.py
  2. 7 0
      paddlex/cv/models/load_model.py
  3. 1 1
      setup.py

+ 1 - 1
paddlex/__init__.py

@@ -14,7 +14,7 @@
 
 from __future__ import absolute_import
 
-__version__ = '1.3.3'
+__version__ = '1.3.4'
 
 import os
 if 'FLAGS_eager_delete_tensor_gb' not in os.environ:

+ 7 - 0
paddlex/cv/models/load_model.py

@@ -54,6 +54,9 @@ def load_model(model_dir, fixed_input_shape=None):
                 logging.info("Model already has fixed_input_shape with {}".
                              format(fixed_input_shape))
                 model.fixed_input_shape = fixed_input_shape
+            else:
+                info['_Attributes']['fixed_input_shape'] = model.fixed_input_shape
+
 
     if info['Model'].count('RCNN') > 0:
         if info['_init_params']['with_fpn']:
@@ -67,6 +70,7 @@ def load_model(model_dir, fixed_input_shape=None):
                         "The second value in fixed_input_shape must be a multiple of 32, but recieved {}.".
                         format(model.fixed_input_shape[1]))
 
+
     with fluid.scope_guard(model_scope):
         if status == "Normal" or \
                 status == "Prune" or status == "fluid.save":
@@ -104,6 +108,7 @@ def load_model(model_dir, fixed_input_shape=None):
             for i, out in enumerate(outputs):
                 var_desc = test_outputs_info[i]
                 model.test_outputs[var_desc[0]] = out
+
     if 'Transforms' in info:
         transforms_mode = info.get('TransformsMode', 'RGB')
         # 固定模型的输入shape
@@ -127,6 +132,8 @@ def load_model(model_dir, fixed_input_shape=None):
             if k in model.__dict__:
                 model.__dict__[k] = v
 
+
+
     logging.info("Model[{}] loaded.".format(info['Model']))
     model.scope = model_scope
     model.trainable = False

+ 1 - 1
setup.py

@@ -19,7 +19,7 @@ long_description = "PaddlePaddle Entire Process Development Toolkit"
 
 setuptools.setup(
     name="paddlex",
-    version='1.3.3',
+    version='1.3.4',
     author="paddlex",
     author_email="paddlex@baidu.com",
     description=long_description,