瀏覽代碼

fix train and inference issues for rocm (#1890)

cuicheng01 1 年之前
父節點
當前提交
f0425e9aca
共有 2 個文件被更改,包括 8 次插入 1 次刪除
  1. 4 1
      paddlex/modules/base/predictor/utils/paddle_inference_predictor.py
  2. 4 0
      paddlex/utils/device.py

+ 4 - 1
paddlex/modules/base/predictor/utils/paddle_inference_predictor.py

@@ -38,7 +38,10 @@ self._create(model_dir, model_prefix, option, delete_pass=delete_pass)
 
         if option.device == 'gpu':
             config.enable_use_gpu(200, option.device_id)
-            config.enable_new_ir(True)
+            if paddle.is_compiled_with_rocm():
+                os.environ['FLAGS_conv_workspace_size_limit'] = '2000'
+            else:
+                config.enable_new_ir(True)
         elif option.device == 'npu':
             config.enable_custom_device('npu')
             os.environ["FLAGS_npu_jit_compile"] = "0"

+ 4 - 0
paddlex/utils/device.py

@@ -15,6 +15,7 @@
 
 
 import os
+import paddle
 from .errors import raise_unsupported_device_error
 
 SUPPORTED_DEVICE_TYPE = ["cpu", "gpu", "xpu", "npu", "mlu"]
@@ -26,12 +27,15 @@ def get_device(device_cfg, using_device_number=None):
     device = device_cfg.split(":")[0]
     assert device.lower() in SUPPORTED_DEVICE_TYPE
     if device.lower() in ["gpu", "xpu", "npu", "mlu"]:
+        if device.lower() == "gpu" and paddle.is_compiled_with_rocm():
+            os.environ['FLAGS_conv_workspace_size_limit'] = '2000'
         if device.lower() == "npu":
             os.environ["FLAGS_npu_jit_compile"] = "0"
             os.environ["FLAGS_use_stride_kernel"] = "0"
             os.environ["FLAGS_allocator_strategy"] = "auto_growth"
             os.environ[
                 "CUSTOM_DEVICE_BLACK_LIST"] = "pad3d,pad3d_grad,set_value,set_value_with_tensor"
+
         if len(device_cfg.split(":")) == 2:
             device_ids = device_cfg.split(":")[1]
         else: