Browse Source

disable mkl-dnn for maskrcnn

will-jl944 4 năm trước cách đây
mục cha
commit
ad5c61f28a
1 tập tin đã thay đổi với 13 bổ sung8 xóa
  1. 13 8
      paddlex/deploy.py

+ 13 - 8
paddlex/deploy.py

@@ -114,16 +114,21 @@ class Predictor(object):
             config.disable_gpu()
             config.set_cpu_math_library_num_threads(cpu_thread_num)
             if use_mkl:
-                try:
-                    # cache 10 different shapes for mkldnn to avoid memory leak
-                    config.set_mkldnn_cache_capacity(10)
-                    config.enable_mkldnn()
-                    config.set_cpu_math_library_num_threads(mkl_thread_num)
-                except Exception as e:
+                if self._model.__class__.__name__ == 'MaskRCNN':
                     logging.warning(
-                        "The current environment does not support `mkldnn`, so disable mkldnn."
+                        "MaskRCNN does not support MKL-DNN, MKL-DNN is forcibly disabled"
                     )
-                    pass
+                else:
+                    try:
+                        # cache 10 different shapes for mkldnn to avoid memory leak
+                        config.set_mkldnn_cache_capacity(10)
+                        config.enable_mkldnn()
+                        config.set_cpu_math_library_num_threads(mkl_thread_num)
+                    except Exception as e:
+                        logging.warning(
+                            "The current environment does not support MKL-DNN, MKL-DNN is disabled."
+                        )
+                        pass
 
         if not use_glog:
             config.disable_glog_info()