Quellcode durchsuchen

add ppyolo in deploy

FlyingQianMM vor 5 Jahren
Ursprung
Commit
9df43751f5
3 geänderte Dateien mit 9 neuen und 6 gelöschten Zeilen
  1. 7 5
      paddlex/deploy.py
  2. 1 0
      requirements.txt
  3. 1 1
      setup.py

+ 7 - 5
paddlex/deploy.py

@@ -19,7 +19,9 @@ import yaml
 import paddlex
 import paddle.fluid as fluid
 from paddlex.cv.transforms import build_transforms
-from paddlex.cv.models import BaseClassifier, YOLOv3, FasterRCNN, MaskRCNN, DeepLabv3p
+from paddlex.cv.models import BaseClassifier
+from paddlex.cv.models import PPYOLO, FasterRCNN, MaskRCNN
+from paddlex.cv.models import DeepLabv3p
 
 
 class Predictor:
@@ -129,8 +131,8 @@ class Predictor:
                 thread_num=thread_num)
             res['image'] = im
         elif self.model_type == "detector":
-            if self.model_name == "YOLOv3":
-                im, im_size = YOLOv3._preprocess(
+            if self.model_name in ["PPYOLO", "YOLOv3"]:
+                im, im_size = PPYOLO._preprocess(
                     image,
                     self.transforms,
                     self.model_type,
@@ -190,8 +192,8 @@ class Predictor:
             res = {'bbox': (results[0][0], offset_to_lengths(results[0][1])), }
             res['im_id'] = (np.array(
                 [[i] for i in range(batch_size)]).astype('int32'), [[]])
-            if self.model_name == "YOLOv3":
-                preds = YOLOv3._postprocess(res, batch_size, self.num_classes,
+            if self.model_name in ["PPYOLO", "YOLOv3"]:
+                preds = PPYOLO._postprocess(res, batch_size, self.num_classes,
                                             self.labels)
             elif self.model_name == "FasterRCNN":
                 preds = FasterRCNN._postprocess(res, batch_size,

+ 1 - 0
requirements.txt

@@ -8,3 +8,4 @@ paddleslim == 1.0.1
 shapely
 x2paddle
 paddlepaddle-gpu
+opencv-python

+ 1 - 1
setup.py

@@ -31,7 +31,7 @@ setuptools.setup(
     install_requires=[
         "pycocotools;platform_system!='Windows'", 'pyyaml', 'colorama', 'tqdm',
         'paddleslim==1.0.1', 'visualdl>=2.0.0b', 'paddlehub>=1.6.2',
-        'shapely>=1.7.0'
+        'shapely>=1.7.0', "opencv-python"
     ],
     classifiers=[
         "Programming Language :: Python :: 3",