Browse Source

make directory of quantization cache file for inference model

FlyingQianMM 4 years ago
parent
commit
efea8117b3
1 changed files with 39 additions and 0 deletions
  1. 39 0
      paddlex/cv/models/slim/post_quantization.py

+ 39 - 0
paddlex/cv/models/slim/post_quantization.py

@@ -24,6 +24,16 @@ import numpy as np
 import time
 import time
 
 
 
 
+def _load_variable_data(scope, var_name):
+    '''
+    Load variable value from scope
+    '''
+    var_node = scope.find_var(var_name)
+    assert var_node is not None, \
+        "Cannot find " + var_name + " in scope."
+    return np.array(var_node.get_tensor())
+
+
 class PaddleXPostTrainingQuantization(PostTrainingQuantization):
 class PaddleXPostTrainingQuantization(PostTrainingQuantization):
     def __init__(self,
     def __init__(self,
                  executor,
                  executor,
@@ -284,3 +294,32 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
                         str(len(self._quantized_act_var_name)),
                         str(len(self._quantized_act_var_name)),
                         str(end - start)))
                         str(end - start)))
                 ct += 1
                 ct += 1
+
+    def _sample_data(self, iter):
+        '''
+        Sample the tensor data of quantized variables,
+        applied in every iteration.
+        '''
+        assert self._algo == "KL", "The algo should be KL to sample data."
+        for var_name in self._quantized_weight_var_name:
+            if var_name not in self._sampling_data:
+                var_tensor = _load_variable_data(self._scope, var_name)
+                self._sampling_data[var_name] = var_tensor
+
+        if self._is_use_cache_file:
+            for var_name in self._quantized_act_var_name:
+                var_tensor = _load_variable_data(self._scope, var_name)
+                var_tensor = var_tensor.ravel()
+                save_path = os.path.join(self._cache_dir,
+                                         var_name + "_" + str(iter) + ".npy")
+                save_dir, file_name = os.path.split(save_path)
+                if not os.path.exists(save_dir):
+                    os.mkdirs(save_dir)
+                np.save(save_path, var_tensor)
+        else:
+            for var_name in self._quantized_act_var_name:
+                if var_name not in self._sampling_data:
+                    self._sampling_data[var_name] = []
+                var_tensor = _load_variable_data(self._scope, var_name)
+                var_tensor = var_tensor.ravel()
+                self._sampling_data[var_name].append(var_tensor)