浏览代码

fix default device

gaotingquan 1 年之前
父节点
当前提交
361c48cccd
共有 2 个文件被更改,包括 4 次插入3 次删除
  1. 3 2
      paddlex/inference/utils/pp_option.py
  2. 1 1
      paddlex/model.py

+ 3 - 2
paddlex/inference/utils/pp_option.py

@@ -53,10 +53,11 @@ class PaddlePredictorOption(object):
 
     def _get_default_config(self):
         """get default config"""
+        device_type, device_id = parse_device(get_default_device())
         return {
             "run_mode": "paddle",
-            "device": get_default_device(),
-            "device_id": 0,
+            "device": device_type,
+            "device_id": device_id[0],
             "min_subgraph_size": 3,
             "shape_info_filename": None,
             "trt_calib_mode": False,

+ 1 - 1
paddlex/model.py

@@ -80,7 +80,7 @@ class _ModelBasedConfig(_BaseModel):
         kernel_option = predict_kwargs.pop("kernel_option", {})
         kernel_option.update({"device": device})
 
-        pp_option = PaddlePredictorOption(**kernel_option)
+        pp_option = PaddlePredictorOption(self._model_name, **kernel_option)
         predictor = create_predictor(model, pp_option=pp_option)
         assert "input" in predict_kwargs
         return predict_kwargs, predictor