
models can be initialized without building net

will-jl944 4 years ago
parent
commit
dd7d8de2ed
5 changed files with 844 additions and 660 deletions
  1. paddlex/cv/models/classifier.py (+85, -77)
  2. paddlex/cv/models/detector.py (+549, -521)
  3. paddlex/cv/models/load_model.py (+55, -44)
  4. paddlex/cv/models/segmenter.py (+36, -18)
  5. paddlex/deploy.py (+119, -0)

+ 85 - 77
paddlex/cv/models/classifier.py

@@ -55,6 +55,8 @@ class BaseClassifier(BaseModel):
         self.init_params.update(params)
         if 'lr_mult_list' in self.init_params:
             del self.init_params['lr_mult_list']
+        if 'with_net' in self.init_params:
+            del self.init_params['with_net']
         super(BaseClassifier, self).__init__('classifier')
         if not hasattr(architectures, model_name):
             raise Exception("ERROR: There's no model named {}.".format(
@@ -65,7 +67,9 @@ class BaseClassifier(BaseModel):
         self.num_classes = num_classes
         for k, v in params.items():
             setattr(self, k, v)
-        self.net = self.build_net(**params)
+        if params.get('with_net', True):
+            params.pop('with_net', None)
+            self.net = self.build_net(**params)
 
     def build_net(self, **params):
         with paddle.utils.unique_name.guard():
@@ -469,93 +473,95 @@ class BaseClassifier(BaseModel):
 
 
 class ResNet18(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(ResNet18, self).__init__(
-            model_name='ResNet18', num_classes=num_classes)
+            model_name='ResNet18', num_classes=num_classes, **params)
 
 
 class ResNet34(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(ResNet34, self).__init__(
-            model_name='ResNet34', num_classes=num_classes)
+            model_name='ResNet34', num_classes=num_classes, **params)
 
 
 class ResNet50(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(ResNet50, self).__init__(
-            model_name='ResNet50', num_classes=num_classes)
+            model_name='ResNet50', num_classes=num_classes, **params)
 
 
 class ResNet101(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(ResNet101, self).__init__(
-            model_name='ResNet101', num_classes=num_classes)
+            model_name='ResNet101', num_classes=num_classes, **params)
 
 
 class ResNet152(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(ResNet152, self).__init__(
-            model_name='ResNet152', num_classes=num_classes)
+            model_name='ResNet152', num_classes=num_classes, **params)
 
 
 class ResNet18_vd(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(ResNet18_vd, self).__init__(
-            model_name='ResNet18_vd', num_classes=num_classes)
+            model_name='ResNet18_vd', num_classes=num_classes, **params)
 
 
 class ResNet34_vd(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(ResNet34_vd, self).__init__(
-            model_name='ResNet34_vd', num_classes=num_classes)
+            model_name='ResNet34_vd', num_classes=num_classes, **params)
 
 
 class ResNet50_vd(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(ResNet50_vd, self).__init__(
-            model_name='ResNet50_vd', num_classes=num_classes)
+            model_name='ResNet50_vd', num_classes=num_classes, **params)
 
 
 class ResNet50_vd_ssld(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(ResNet50_vd_ssld, self).__init__(
             model_name='ResNet50_vd',
             num_classes=num_classes,
-            lr_mult_list=[.1, .1, .2, .2, .3])
+            lr_mult_list=[.1, .1, .2, .2, .3],
+            **params)
         self.model_name = 'ResNet50_vd_ssld'
 
 
 class ResNet101_vd(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(ResNet101_vd, self).__init__(
-            model_name='ResNet101_vd', num_classes=num_classes)
+            model_name='ResNet101_vd', num_classes=num_classes, **params)
 
 
 class ResNet101_vd_ssld(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(ResNet101_vd_ssld, self).__init__(
             model_name='ResNet101_vd',
             num_classes=num_classes,
-            lr_mult_list=[.1, .1, .2, .2, .3])
+            lr_mult_list=[.1, .1, .2, .2, .3],
+            **params)
         self.model_name = 'ResNet101_vd_ssld'
 
 
 class ResNet152_vd(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(ResNet152_vd, self).__init__(
-            model_name='ResNet152_vd', num_classes=num_classes)
+            model_name='ResNet152_vd', num_classes=num_classes, **params)
 
 
 class ResNet200_vd(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(ResNet200_vd, self).__init__(
-            model_name='ResNet200_vd', num_classes=num_classes)
+            model_name='ResNet200_vd', num_classes=num_classes, **params)
 
 
 class AlexNet(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(AlexNet, self).__init__(
-            model_name='AlexNet', num_classes=num_classes)
+            model_name='AlexNet', num_classes=num_classes, **params)
 
     def _get_test_inputs(self, image_shape):
         if image_shape is not None:
@@ -580,13 +586,13 @@ class AlexNet(BaseClassifier):
 
 
 class DarkNet53(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(DarkNet53, self).__init__(
-            model_name='DarkNet53', num_classes=num_classes)
+            model_name='DarkNet53', num_classes=num_classes, **params)
 
 
 class MobileNetV1(BaseClassifier):
-    def __init__(self, num_classes=1000, scale=1.0):
+    def __init__(self, num_classes=1000, scale=1.0, **params):
         supported_scale = [.25, .5, .75, 1.0]
         if scale not in supported_scale:
             logging.warning("scale={} is not supported by MobileNetV1, "
@@ -598,11 +604,11 @@ class MobileNetV1(BaseClassifier):
             model_name = 'MobileNetV1_x' + str(scale).replace('.', '_')
         self.scale = scale
         super(MobileNetV1, self).__init__(
-            model_name=model_name, num_classes=num_classes)
+            model_name=model_name, num_classes=num_classes, **params)
 
 
 class MobileNetV2(BaseClassifier):
-    def __init__(self, num_classes=1000, scale=1.0):
+    def __init__(self, num_classes=1000, scale=1.0, **params):
         supported_scale = [.25, .5, .75, 1.0, 1.5, 2.0]
         if scale not in supported_scale:
             logging.warning("scale={} is not supported by MobileNetV2, "
@@ -613,11 +619,11 @@ class MobileNetV2(BaseClassifier):
         else:
             model_name = 'MobileNetV2_x' + str(scale).replace('.', '_')
         super(MobileNetV2, self).__init__(
-            model_name=model_name, num_classes=num_classes)
+            model_name=model_name, num_classes=num_classes, **params)
 
 
 class MobileNetV3_small(BaseClassifier):
-    def __init__(self, num_classes=1000, scale=1.0):
+    def __init__(self, num_classes=1000, scale=1.0, **params):
         supported_scale = [.35, .5, .75, 1.0, 1.25]
         if scale not in supported_scale:
             logging.warning("scale={} is not supported by MobileNetV3_small, "
@@ -626,11 +632,11 @@ class MobileNetV3_small(BaseClassifier):
         model_name = 'MobileNetV3_small_x' + str(float(scale)).replace('.',
                                                                        '_')
         super(MobileNetV3_small, self).__init__(
-            model_name=model_name, num_classes=num_classes)
+            model_name=model_name, num_classes=num_classes, **params)
 
 
 class MobileNetV3_small_ssld(BaseClassifier):
-    def __init__(self, num_classes=1000, scale=1.0):
+    def __init__(self, num_classes=1000, scale=1.0, **params):
         supported_scale = [.35, 1.0]
         if scale not in supported_scale:
             logging.warning(
@@ -640,12 +646,12 @@ class MobileNetV3_small_ssld(BaseClassifier):
         model_name = 'MobileNetV3_small_x' + str(float(scale)).replace('.',
                                                                        '_')
         super(MobileNetV3_small_ssld, self).__init__(
-            model_name=model_name, num_classes=num_classes)
+            model_name=model_name, num_classes=num_classes, **params)
         self.model_name = model_name + '_ssld'
 
 
 class MobileNetV3_large(BaseClassifier):
-    def __init__(self, num_classes=1000, scale=1.0):
+    def __init__(self, num_classes=1000, scale=1.0, **params):
         supported_scale = [.35, .5, .75, 1.0, 1.25]
         if scale not in supported_scale:
             logging.warning("scale={} is not supported by MobileNetV3_large, "
@@ -654,108 +660,110 @@ class MobileNetV3_large(BaseClassifier):
         model_name = 'MobileNetV3_large_x' + str(float(scale)).replace('.',
                                                                        '_')
         super(MobileNetV3_large, self).__init__(
-            model_name=model_name, num_classes=num_classes)
+            model_name=model_name, num_classes=num_classes, **params)
 
 
 class MobileNetV3_large_ssld(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(MobileNetV3_large_ssld, self).__init__(
-            model_name='MobileNetV3_large_x1_0', num_classes=num_classes)
+            model_name='MobileNetV3_large_x1_0',
+            num_classes=num_classes,
+            **params)
         self.model_name = 'MobileNetV3_large_x1_0_ssld'
 
 
 class DenseNet121(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(DenseNet121, self).__init__(
-            model_name='DenseNet121', num_classes=num_classes)
+            model_name='DenseNet121', num_classes=num_classes, **params)
 
 
 class DenseNet161(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(DenseNet161, self).__init__(
-            model_name='DenseNet161', num_classes=num_classes)
+            model_name='DenseNet161', num_classes=num_classes, **params)
 
 
 class DenseNet169(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(DenseNet169, self).__init__(
-            model_name='DenseNet169', num_classes=num_classes)
+            model_name='DenseNet169', num_classes=num_classes, **params)
 
 
 class DenseNet201(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(DenseNet201, self).__init__(
-            model_name='DenseNet201', num_classes=num_classes)
+            model_name='DenseNet201', num_classes=num_classes, **params)
 
 
 class DenseNet264(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(DenseNet264, self).__init__(
-            model_name='DenseNet264', num_classes=num_classes)
+            model_name='DenseNet264', num_classes=num_classes, **params)
 
 
 class HRNet_W18_C(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(HRNet_W18_C, self).__init__(
-            model_name='HRNet_W18_C', num_classes=num_classes)
+            model_name='HRNet_W18_C', num_classes=num_classes, **params)
 
 
 class HRNet_W30_C(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(HRNet_W30_C, self).__init__(
-            model_name='HRNet_W30_C', num_classes=num_classes)
+            model_name='HRNet_W30_C', num_classes=num_classes, **params)
 
 
 class HRNet_W32_C(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(HRNet_W32_C, self).__init__(
-            model_name='HRNet_W32_C', num_classes=num_classes)
+            model_name='HRNet_W32_C', num_classes=num_classes, **params)
 
 
 class HRNet_W40_C(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(HRNet_W40_C, self).__init__(
-            model_name='HRNet_W40_C', num_classes=num_classes)
+            model_name='HRNet_W40_C', num_classes=num_classes, **params)
 
 
 class HRNet_W44_C(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(HRNet_W44_C, self).__init__(
-            model_name='HRNet_W44_C', num_classes=num_classes)
+            model_name='HRNet_W44_C', num_classes=num_classes, **params)
 
 
 class HRNet_W48_C(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(HRNet_W48_C, self).__init__(
-            model_name='HRNet_W48_C', num_classes=num_classes)
+            model_name='HRNet_W48_C', num_classes=num_classes, **params)
 
 
 class HRNet_W64_C(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(HRNet_W64_C, self).__init__(
-            model_name='HRNet_W64_C', num_classes=num_classes)
+            model_name='HRNet_W64_C', num_classes=num_classes, **params)
 
 
 class Xception41(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(Xception41, self).__init__(
-            model_name='Xception41', num_classes=num_classes)
+            model_name='Xception41', num_classes=num_classes, **params)
 
 
 class Xception65(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(Xception65, self).__init__(
-            model_name='Xception65', num_classes=num_classes)
+            model_name='Xception65', num_classes=num_classes, **params)
 
 
 class Xception71(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(Xception71, self).__init__(
-            model_name='Xception71', num_classes=num_classes)
+            model_name='Xception71', num_classes=num_classes, **params)
 
 
 class ShuffleNetV2(BaseClassifier):
-    def __init__(self, num_classes=1000, scale=1.0):
+    def __init__(self, num_classes=1000, scale=1.0, **params):
         supported_scale = [.25, .33, .5, 1.0, 1.5, 2.0]
         if scale not in supported_scale:
             logging.warning("scale={} is not supported by ShuffleNetV2, "
@@ -763,7 +771,7 @@ class ShuffleNetV2(BaseClassifier):
             scale = 1.0
         model_name = 'ShuffleNetV2_x' + str(float(scale)).replace('.', '_')
         super(ShuffleNetV2, self).__init__(
-            model_name=model_name, num_classes=num_classes)
+            model_name=model_name, num_classes=num_classes, **params)
 
     def _get_test_inputs(self, image_shape):
         if image_shape is not None:
@@ -788,9 +796,9 @@ class ShuffleNetV2(BaseClassifier):
 
 
 class ShuffleNetV2_swish(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, **params):
         super(ShuffleNetV2_swish, self).__init__(
-            model_name='ShuffleNetV2_x1_5', num_classes=num_classes)
+            model_name='ShuffleNetV2_x1_5', num_classes=num_classes, **params)
 
     def _get_test_inputs(self, image_shape):
         if image_shape is not None:

The file diff is too large to display.
+ 549 - 521
paddlex/cv/models/detector.py


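The classifier and detector constructors above now forward **params to the base class, so the new with_net switch reaches BaseClassifier. A short usage sketch of the resulting behaviour (not part of the commit; it assumes the paddlex package from this branch is importable and that the model classes are exposed on paddlex.cv.models, which is how load_model looks them up):

    # Minimal sketch of the with_net flag, assuming paddlex from this commit.
    import paddlex as pdx

    # Default behaviour: build_net() runs inside __init__ and self.net is assigned.
    model = pdx.cv.models.ResNet50_vd(num_classes=10)

    # with_net=False is forwarded through **params, build_net() is skipped and no
    # dygraph network or parameters are created. Deployment code can take this
    # path when the exported inference program is executed by the Paddle
    # Inference predictor instead of the dygraph net.
    model_meta_only = pdx.cv.models.ResNet50_vd(num_classes=10, with_net=False)
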
+ 55 - 44
paddlex/cv/models/load_model.py

@@ -45,7 +45,7 @@ def load_rcnn_inference_model(model_dir):
     return net_state_dict
 
 
-def load_model(model_dir):
+def load_model(model_dir, **params):
     """
     Load saved model from a given directory.
     Args:
@@ -69,6 +69,10 @@ def load_model(model_dir):
             format(paddlex.__version__, version))
 
     status = model_info['status']
+    with_net = params.get('with_net', True)
+    if not with_net:
+        assert status == 'Infer', \
+            "Only exported inference models can be deployed, current model status is {}".format(status)
 
     if not hasattr(paddlex.cv.models, model_info['Model']):
         raise Exception("There's no attribute {} in paddlex.cv.models".format(
@@ -76,51 +80,58 @@ def load_model(model_dir):
     if 'model_name' in model_info['_init_params']:
         del model_info['_init_params']['model_name']
 
-    with paddle.utils.unique_name.guard():
+    model_info['_init_params'].update({'with_net': with_net})
+
+    if with_net:
+        with paddle.utils.unique_name.guard():
+            model = getattr(paddlex.cv.models, model_info['Model'])(
+                **model_info['_init_params'])
+
+            if status == 'Pruned' or osp.exists(
+                    osp.join(model_dir, "prune.yml")):
+                with open(osp.join(model_dir, "prune.yml")) as f:
+                    pruning_info = yaml.load(f.read(), Loader=yaml.Loader)
+                    inputs = pruning_info['pruner_inputs']
+                    if model.model_type == 'detector':
+                        inputs = [{
+                            k: paddle.to_tensor(v)
+                            for k, v in inputs.items()
+                        }]
+                        model.net.eval()
+                    model.pruner = getattr(paddleslim, pruning_info['pruner'])(
+                        model.net, inputs=inputs)
+                    model.pruning_ratios = pruning_info['pruning_ratios']
+                    model.pruner.prune_vars(
+                        ratios=model.pruning_ratios,
+                        axis=paddleslim.dygraph.prune.filter_pruner.FILTER_DIM)
+
+            if status == 'Quantized':
+                with open(osp.join(model_dir, "quant.yml")) as f:
+                    quant_info = yaml.load(f.read(), Loader=yaml.Loader)
+                    model.quant_config = quant_info['quant_config']
+                    model.quantizer = paddleslim.QAT(model.quant_config)
+                    model.quantizer.quantize(model.net)
+
+            if status == 'Infer':
+                if model_info['Model'] in ['FasterRCNN', 'MaskRCNN']:
+                    net_state_dict = load_rcnn_inference_model(model_dir)
+                else:
+                    net_state_dict = paddle.load(osp.join(model_dir, 'model'))
+            else:
+                net_state_dict = paddle.load(
+                    osp.join(model_dir, 'model.pdparams'))
+            model.net.set_state_dict(net_state_dict)
+    else:
         model = getattr(paddlex.cv.models, model_info['Model'])(
             **model_info['_init_params'])
 
-        if 'Transforms' in model_info:
-            model.test_transforms = build_transforms(model_info['Transforms'])
-
-        if '_Attributes' in model_info:
-            for k, v in model_info['_Attributes'].items():
-                if k in model.__dict__:
-                    model.__dict__[k] = v
-
-        if status == 'Pruned' or osp.exists(osp.join(model_dir, "prune.yml")):
-            with open(osp.join(model_dir, "prune.yml")) as f:
-                pruning_info = yaml.load(f.read(), Loader=yaml.Loader)
-                inputs = pruning_info['pruner_inputs']
-                if model.model_type == 'detector':
-                    inputs = [{
-                        k: paddle.to_tensor(v)
-                        for k, v in inputs.items()
-                    }]
-                    model.net.eval()
-                model.pruner = getattr(paddleslim, pruning_info['pruner'])(
-                    model.net, inputs=inputs)
-                model.pruning_ratios = pruning_info['pruning_ratios']
-                model.pruner.prune_vars(
-                    ratios=model.pruning_ratios,
-                    axis=paddleslim.dygraph.prune.filter_pruner.FILTER_DIM)
-
-        if status == 'Quantized':
-            with open(osp.join(model_dir, "quant.yml")) as f:
-                quant_info = yaml.load(f.read(), Loader=yaml.Loader)
-                model.quant_config = quant_info['quant_config']
-                model.quantizer = paddleslim.QAT(model.quant_config)
-                model.quantizer.quantize(model.net)
-
-        if status == 'Infer':
-            if model_info['Model'] in ['FasterRCNN', 'MaskRCNN']:
-                net_state_dict = load_rcnn_inference_model(model_dir)
-            else:
-                net_state_dict = paddle.load(osp.join(model_dir, 'model'))
-        else:
-            net_state_dict = paddle.load(osp.join(model_dir, 'model.pdparams'))
-        model.net.set_state_dict(net_state_dict)
+    if 'Transforms' in model_info:
+        model.test_transforms = build_transforms(model_info['Transforms'])
 
-        logging.info("Model[{}] loaded.".format(model_info['Model']))
-        model.status = status
+    if '_Attributes' in model_info:
+        for k, v in model_info['_Attributes'].items():
+            if k in model.__dict__:
+                model.__dict__[k] = v
+    logging.info("Model[{}] loaded.".format(model_info['Model']))
+    model.status = status
     return model

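A corresponding sketch for the two load_model paths (the directory name is hypothetical; './inference_model' stands for a model exported with PaddleX and is not part of the commit):

    import paddlex as pdx

    # Full load: rebuilds the dygraph net under unique_name.guard(), restores its
    # state dict and re-applies any pruning or quantization configuration.
    model = pdx.load_model('./inference_model')

    # Deployment load: with_net=False skips building the net entirely; only the
    # init params, test_transforms and _Attributes are restored. The function
    # asserts status == 'Infer', so only exported inference models qualify.
    model_meta_only = pdx.load_model('./inference_model', with_net=False)
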
+ 36 - 18
paddlex/cv/models/segmenter.py

@@ -39,6 +39,8 @@ class BaseSegmenter(BaseModel):
                  use_mixed_loss=False,
                  **params):
         self.init_params = locals()
+        if 'with_net' in self.init_params:
+            del self.init_params['with_net']
         super(BaseSegmenter, self).__init__('segmenter')
         if not hasattr(paddleseg.models, model_name):
             raise Exception("ERROR: There's no model named {}.".format(
@@ -48,7 +50,9 @@ class BaseSegmenter(BaseModel):
         self.use_mixed_loss = use_mixed_loss
         self.losses = None
         self.labels = None
-        self.net = self.build_net(**params)
+        if params.get('with_net', True):
+            params.pop('with_net', None)
+            self.net = self.build_net(**params)
         self.find_unused_parameters = True
 
     def build_net(self, **params):
@@ -582,8 +586,12 @@ class UNet(BaseSegmenter):
                  num_classes=2,
                  use_mixed_loss=False,
                  use_deconv=False,
-                 align_corners=False):
-        params = {'use_deconv': use_deconv, 'align_corners': align_corners}
+                 align_corners=False,
+                 **params):
+        params.update({
+            'use_deconv': use_deconv,
+            'align_corners': align_corners
+        })
         super(UNet, self).__init__(
             model_name='UNet',
             num_classes=num_classes,
@@ -600,22 +608,26 @@ class DeepLabV3P(BaseSegmenter):
                  backbone_indices=(0, 3),
                  aspp_ratios=(1, 12, 24, 36),
                  aspp_out_channels=256,
-                 align_corners=False):
+                 align_corners=False,
+                 **params):
         self.backbone_name = backbone
         if backbone not in ['ResNet50_vd', 'ResNet101_vd']:
             raise ValueError(
                 "backbone: {} is not supported. Please choose one of "
                 "('ResNet50_vd', 'ResNet101_vd')".format(backbone))
-        with DisablePrint():
-            backbone = getattr(paddleseg.models, backbone)(
-                output_stride=output_stride)
-        params = {
+        if params.get('with_net', True):
+            with DisablePrint():
+                backbone = getattr(paddleseg.models, backbone)(
+                    output_stride=output_stride)
+        else:
+            backbone = None
+        params.update({
             'backbone': backbone,
             'backbone_indices': backbone_indices,
             'aspp_ratios': aspp_ratios,
             'aspp_out_channels': aspp_out_channels,
             'align_corners': align_corners
-        }
+        })
         super(DeepLabV3P, self).__init__(
             model_name='DeepLabV3P',
             num_classes=num_classes,
@@ -627,8 +639,9 @@ class FastSCNN(BaseSegmenter):
     def __init__(self,
                  num_classes=2,
                  use_mixed_loss=False,
-                 align_corners=False):
-        params = {'align_corners': align_corners}
+                 align_corners=False,
+                 **params):
+        params.update({'align_corners': align_corners})
         super(FastSCNN, self).__init__(
             model_name='FastSCNN',
             num_classes=num_classes,
@@ -641,17 +654,21 @@ class HRNet(BaseSegmenter):
                  num_classes=2,
                  width=48,
                  use_mixed_loss=False,
-                 align_corners=False):
+                 align_corners=False,
+                 **params):
         if width not in (18, 48):
             raise ValueError(
                 "width={} is not supported, please choose from [18, 48]".
                 format(width))
         self.backbone_name = 'HRNet_W{}'.format(width)
-        with DisablePrint():
-            backbone = getattr(paddleseg.models, self.backbone_name)(
-                align_corners=align_corners)
+        if params.get('with_net', True):
+            with DisablePrint():
+                backbone = getattr(paddleseg.models, self.backbone_name)(
+                    align_corners=align_corners)
+        else:
+            backbone = None
 
-        params = {'backbone': backbone, 'align_corners': align_corners}
+        params.update({'backbone': backbone, 'align_corners': align_corners})
         super(HRNet, self).__init__(
             model_name='FCN',
             num_classes=num_classes,
@@ -664,8 +681,9 @@ class BiSeNetV2(BaseSegmenter):
     def __init__(self,
                  num_classes=2,
                  use_mixed_loss=False,
-                 align_corners=False):
-        params = {'align_corners': align_corners}
+                 align_corners=False,
+                 **params):
+        params.update({'align_corners': align_corners})
         super(BiSeNetV2, self).__init__(
             model_name='BiSeNetV2',
             num_classes=num_classes,

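The segmenters follow the same pattern, with the extra twist that the backbone is only instantiated when a net is actually wanted. A hedged sketch (assuming DeepLabV3P is exposed on paddlex.cv.models, as load_model's getattr lookup implies):

    import paddlex as pdx

    # Default: the ResNet50_vd backbone and the DeepLabV3P net are both built.
    seg = pdx.cv.models.DeepLabV3P(num_classes=2, backbone='ResNet50_vd')

    # with_net=False: the backbone stays None and build_net() is skipped, so the
    # wrapper is created without allocating any network parameters.
    seg_meta_only = pdx.cv.models.DeepLabV3P(
        num_classes=2, backbone='ResNet50_vd', with_net=False)
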
+ 119 - 0
paddlex/deploy.py

@@ -0,0 +1,119 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+import numpy as np
+import yaml
+from paddle.inference import Config
+from paddle.inference import create_predictor
+from paddle.inference import PrecisionType
+from paddlex.cv.transforms import build_transforms
+from paddlex.utils import logging
+
+
+class Predictor(object):
+    def __init__(self,
+                 model_dir,
+                 use_gpu=True,
+                 gpu_id=0,
+                 cpu_thread_num=1,
+                 use_mkl=True,
+                 mkl_thread_num=4,
+                 use_trt=False,
+                 use_glog=False,
+                 memory_optimize=True,
+                 max_trt_batch_size=1,
+                 trt_precision_mode='float32'):
+        """ Create a Paddle Predictor.
+
+            Args:
+                model_dir: model directory (must be an exported deployment or quantized model)
+                use_gpu: whether to use the GPU, defaults to True
+                gpu_id: ID of the GPU to use, defaults to 0
+                cpu_thread_num: number of threads used for CPU inference, defaults to 1
+                use_mkl: whether to use the MKL-DNN library (CPU only), defaults to True
+                mkl_thread_num: number of MKL-DNN threads, defaults to 4
+                use_trt: whether to use TensorRT, defaults to False
+                use_glog: whether to enable glog logging, defaults to False
+                memory_optimize: whether to enable memory optimization, defaults to True
+                max_trt_batch_size: maximum batch size configured when using TensorRT, defaults to 1
+                trt_precision_mode: precision used when running with TensorRT, defaults to 'float32'
+        """
+        if not osp.isdir(model_dir):
+            logging.error(
+                "{} is not a valid model directory.".format(model_dir),
+                exit=True)
+        # store the model directory so that create_predictor() can locate
+        # model.pdmodel and model.pdiparams
+        self.model_dir = model_dir
+
+        if trt_precision_mode == 'float32':
+            trt_precision_mode = PrecisionType.Float32
+        elif trt_precision_mode == 'float16':
+            trt_precision_mode = PrecisionType.Float16
+        else:
+            logging.error(
+                "TensorRT precision mode {} is invalid. Supported modes are float32 and float16."
+                .format(trt_precision_mode),
+                exit=True)
+
+    def create_predictor(self,
+                         use_gpu=True,
+                         gpu_id=0,
+                         cpu_thread_num=1,
+                         use_mkl=True,
+                         mkl_thread_num=4,
+                         use_trt=False,
+                         use_glog=False,
+                         memory_optimize=True,
+                         max_trt_batch_size=1,
+                         trt_precision_mode=PrecisionType.Float32):
+        config = Config(
+            prog_file=osp.join(self.model_dir, 'model.pdmodel'),
+            params_file=osp.join(self.model_dir, 'model.pdiparams'))
+
+        if use_gpu:
+            # set the initial GPU memory size (in MB) and the device ID
+            config.enable_use_gpu(100, gpu_id)
+            config.switch_ir_optim(True)
+            if use_trt:
+                config.enable_tensorrt_engine(
+                    workspace_size=1 << 10,
+                    max_batch_size=max_trt_batch_size,
+                    min_subgraph_size=3,
+                    precision_mode=trt_precision_mode,
+                    use_static=False,
+                    use_calib_mode=False)
+        else:
+            config.disable_gpu()
+            config.set_cpu_math_library_num_threads(cpu_thread_num)
+            if use_mkl:
+                try:
+                    # cache 10 different shapes for mkldnn to avoid memory leak
+                    config.set_mkldnn_cache_capacity(10)
+                    config.enable_mkldnn()
+                    config.set_cpu_math_library_num_threads(mkl_thread_num)
+                except Exception as e:
+                    logging.warning(
+                        "The current environment does not support `mkldnn`, so disable mkldnn."
+                    )
+                    pass
+
+        if use_glog:
+            config.enable_glog_info()
+        else:
+            config.disable_glog_info()
+        if memory_optimize:
+            config.enable_memory_optim()
+        config.switch_use_feed_fetch_ops(False)
+        predictor = create_predictor(config)
+        return predictor

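How the new Predictor is driven is only partly visible in this file; the sketch below exercises just the constructor and create_predictor shown above, with './inference_model' standing in for a directory that holds an exported model.pdmodel / model.pdiparams pair (the path and argument values are illustrative, not part of the commit):

    from paddlex.deploy import Predictor

    # CPU configuration: MKL-DNN enabled with four math-library threads.
    predictor_wrapper = Predictor(
        './inference_model', use_gpu=False, cpu_thread_num=4, mkl_thread_num=4)

    # Builds a paddle.inference predictor from model.pdmodel / model.pdiparams
    # using the CPU/MKL-DNN (or GPU/TensorRT) options passed in.
    paddle_predictor = predictor_wrapper.create_predictor(
        use_gpu=False, cpu_thread_num=4, mkl_thread_num=4)
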
Some files were not shown because too many files have changed in this diff.