Kaynağa Gözat

Merge pull request #3 from PaddlePaddle/develop

20201216
LaraStuStu 4 yıl önce
ebeveyn
işleme
7c18baf44c

+ 1 - 0
paddlex/__init__.py

@@ -36,6 +36,7 @@ elif version[0] == '2':
     print(
         "[WARNING] You are using paddlepaddle(v{}) which may not compatible with paddlex(v{}), paddlepaddle==1.8.4 is strongly recommended.".
         format(paddle.__version__, __version__))
+if hasattr(paddle, 'enable_static'):
     paddle.enable_static()
 
 from .utils.utils import get_environ_info

+ 39 - 0
paddlex/cv/models/slim/post_quantization.py

@@ -24,6 +24,16 @@ import numpy as np
 import time
 
 
+def _load_variable_data(scope, var_name):
+    '''
+    Load variable value from scope
+    '''
+    var_node = scope.find_var(var_name)
+    assert var_node is not None, \
+        "Cannot find " + var_name + " in scope."
+    return np.array(var_node.get_tensor())
+
+
 class PaddleXPostTrainingQuantization(PostTrainingQuantization):
     def __init__(self,
                  executor,
@@ -284,3 +294,32 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
                         str(len(self._quantized_act_var_name)),
                         str(end - start)))
                 ct += 1
+
+    def _sample_data(self, iter):
+        '''
+        Sample the tensor data of quantized variables,
+        applied in every iteration.
+        '''
+        assert self._algo == "KL", "The algo should be KL to sample data."
+        for var_name in self._quantized_weight_var_name:
+            if var_name not in self._sampling_data:
+                var_tensor = _load_variable_data(self._scope, var_name)
+                self._sampling_data[var_name] = var_tensor
+
+        if self._is_use_cache_file:
+            for var_name in self._quantized_act_var_name:
+                var_tensor = _load_variable_data(self._scope, var_name)
+                var_tensor = var_tensor.ravel()
+                save_path = os.path.join(self._cache_dir,
+                                         var_name + "_" + str(iter) + ".npy")
+                save_dir, file_name = os.path.split(save_path)
+                if not os.path.exists(save_dir):
+                    os.mkdirs(save_dir)
+                np.save(save_path, var_tensor)
+        else:
+            for var_name in self._quantized_act_var_name:
+                if var_name not in self._sampling_data:
+                    self._sampling_data[var_name] = []
+                var_tensor = _load_variable_data(self._scope, var_name)
+                var_tensor = var_tensor.ravel()
+                self._sampling_data[var_name].append(var_tensor)

+ 8 - 5
paddlex/cv/models/utils/detection_eval.py

@@ -921,7 +921,7 @@ def coco_error_analysis(eval_details_file=None,
 
     """
 
-    from multiprocessing import Pool
+    import multiprocessing as mp
     from pycocotools.coco import COCO
     from pycocotools.cocoeval import COCOeval
 
@@ -968,10 +968,11 @@ def coco_error_analysis(eval_details_file=None,
         ps = np.vstack([ps, np.zeros((4, *ps.shape[1:]))])
         catIds = cocoGt.getCatIds()
         recThrs = cocoEval.params.recThrs
-        with Pool(processes=48) as pool:
-            args = [(k, cocoDt, cocoGt, catId, iou_type)
-                    for k, catId in enumerate(catIds)]
-            analyze_results = pool.starmap(analyze_individual_category, args)
+        thread_num = mp.cpu_count() if mp.cpu_count() < 8 else 8
+        thread_pool = mp.pool.ThreadPool(thread_num)
+        args = [(k, cocoDt, cocoGt, catId, iou_type)
+                for k, catId in enumerate(catIds)]
+        analyze_results = thread_pool.starmap(analyze_individual_category, args)
         for k, catId in enumerate(catIds):
             nm = cocoGt.loadCats(catId)[0]
             logging.info('--------------saving {}-{}---------------'.format(
@@ -996,6 +997,7 @@ def coco_error_analysis(eval_details_file=None,
             makeplot(recThrs, ps[:, :, k], res_out_dir, nm['name'], iou_type)
         makeplot(recThrs, ps, res_out_dir, 'allclass', iou_type)
 
+    np.linspace = fixed_linspace
     coco_gt = COCO()
     coco_gt.dataset = gt
     coco_gt.createIndex()
@@ -1006,4 +1008,5 @@ def coco_error_analysis(eval_details_file=None,
     if pred_mask is not None:
         coco_dt = loadRes(coco_gt, pred_mask)
         _analyze_results(coco_gt, coco_dt, res_type='segm', out_dir=save_dir)
+    np.linspace = backup_linspace
     logging.info("The analysis figures are saved in {}".format(save_dir))