瀏覽代碼

remove pycocotools dependency

jiangjiajun 5 年之前
父節點
當前提交
c22fb0531c

+ 6 - 0
paddlex/__init__.py

@@ -20,6 +20,12 @@ from . import seg
 from . import cls
 from . import slim
 
+try:
+    import pycocotools
+except:
+    print("[WARNING] pycocotools is not installed, detection model is not available now.")
+    print("[WARNING] pycocotools install: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/install.md")
+
 env_info = get_environ_info()
 load_model = cv.models.load_model
 datasets = cv.datasets

+ 2 - 1
paddlex/cv/datasets/coco.py

@@ -19,7 +19,6 @@ import random
 import numpy as np
 import paddlex.utils.logging as logging
 import paddlex as pst
-from pycocotools.coco import COCO
 from .voc import VOCDetection
 from .dataset import is_pic
 
@@ -47,6 +46,8 @@ class CocoDetection(VOCDetection):
                  buffer_size=100,
                  parallel_method='process',
                  shuffle=False):
+        from pycocotools.coco import COCO
+
         super(VOCDetection, self).__init__(
             transforms=transforms,
             num_workers=num_workers,

+ 1 - 1
paddlex/cv/datasets/voc.py

@@ -18,7 +18,6 @@ import os.path as osp
 import random
 import numpy as np
 import xml.etree.ElementTree as ET
-from pycocotools.coco import COCO
 import paddlex.utils.logging as logging
 from .dataset import Dataset
 from .dataset import is_pic
@@ -51,6 +50,7 @@ class VOCDetection(Dataset):
                  buffer_size=100,
                  parallel_method='process',
                  shuffle=False):
+        from pycocotools.coco import COCO
         super(VOCDetection, self).__init__(
             transforms=transforms,
             num_workers=num_workers,

+ 4 - 3
paddlex/cv/models/slim/visualize.py

@@ -15,9 +15,6 @@
 import os.path as osp
 import tqdm
 import numpy as np
-import matplotlib
-matplotlib.use('Agg')
-import matplotlib.pyplot as plt
 from .prune import cal_model_size
 from paddleslim.prune import load_sensitivities
 
@@ -30,6 +27,10 @@ def visualize(model, sensitivities_file, save_dir='./'):
         model (paddlex.cv.models): paddlex中的模型。
         sensitivities_file (str): 敏感度文件存储路径。
     """
+    import matplotlib
+    matplotlib.use('Agg')
+    import matplotlib.pyplot as plt
+
     program = model.test_prog
     place = model.places[0]
     fig = plt.figure()

+ 1 - 1
paddlex/cv/models/utils/visualize.py

@@ -15,7 +15,6 @@
 import os
 import cv2
 import numpy as np
-import matplotlib.pyplot as plt
 from PIL import Image, ImageDraw
 
 import paddlex.utils.logging as logging
@@ -222,6 +221,7 @@ def draw_pr_curve(eval_details_file=None,
         return mean_s
 
     def cal_pr(coco_gt, coco_dt, iou_thresh, save_dir, style='bbox'):
+        import matplotlib.pyplot as plt
         from pycocotools.cocoeval import COCOeval
         coco_dt = loadRes(coco_gt, coco_dt)
         np.linspace = fixed_linspace

+ 2 - 1
setup.py

@@ -29,7 +29,8 @@ setuptools.setup(
     packages=setuptools.find_packages(),
     setup_requires=['cython', 'numpy', 'sklearn'],
     install_requires=[
-        'pycocotools', 'pyyaml', 'colorama', 'tqdm', 'visualdl==1.3.0',
+        "pycocotools;platform_system!='Windows'", 
+        'pyyaml', 'colorama', 'tqdm', 'visualdl==1.3.0',
         'paddleslim==1.0.1', 'paddlehub>=1.6.2'
     ],
     classifiers=[