
[Fix] Fix bugs and polish docs (#3859)

* Polish docs

* Set PIR-TRT to default

* Fix bugs

* Fix doc

* Fix trt eager load bug

* Set alias for doc understanding endpoint
Lin Manhui · 7 months ago
Parent
Current commit
69f4ebf469

+ 1 - 1
docs/installation/installation.en.md

@@ -255,7 +255,7 @@ PaddleX currently provides the following dependency groups:
 | Dependency Group | Corresponding Features |
 | - | - |
 | `base` | All basic features of PaddleX. |
-| `cv` | Basic features of computer vision pipelines (excluding multimodal pipelines). |
+| `cv` | Basic features of computer vision pipelines. |
 | `multimodal` | Basic features of multimodal pipelines. |
 | `ie` | Basic features of information extraction pipelines. |
 | `ocr` | Basic features of OCR-related pipelines. |

+ 1 - 1
docs/installation/installation.md

@@ -255,7 +255,7 @@ PaddleX 目前提供如下依赖组:
 | 依赖组名称 | 对应的功能 |
 | - | - |
 | `base` | PaddleX 的所有基础功能。 |
-| `cv` | 除多模态产线外的 CV 产线的基础功能。 |
+| `cv` | CV 产线的基础功能。 |
 | `multimodal` | 多模态产线的基础功能。 |
 | `ie` | 信息抽取产线的基础功能。 |
 | `ocr` | OCR 类产线的基础功能。 |

+ 2 - 2
docs/module_usage/instructions/model_python_API.en.md

@@ -108,9 +108,9 @@ PaddleX supports modifying the inference configuration through `PaddlePredictorO
 * `cpu_threads`: Number of CPU threads for the acceleration library, only valid when the inference device is 'cpu'.
   * Supports setting an `int` type for the number of CPU threads for the acceleration library during CPU inference.
   * Return value: `int` type, the currently set number of threads for the acceleration library.
-* `trt_dynamic_shapes`: TensorRT dynamic shapes, only effective when `run_mode` is set to 'trt_fp32' or 'trt_fp16'.
+* `trt_dynamic_shapes`: TensorRT dynamic shape configuration, only effective when `run_mode` is set to 'trt_fp32' or 'trt_fp16'.
   * Supports setting a value of type `dict` or `None`. If it is a `dict`, the keys are the input tensor names and the values are two-level nested lists formatted as `[{minimum shape}, {optimal shape}, {maximum shape}]`, for example `[[1, 2], [1, 2], [2, 2]]`.
-  * Return value: `dict` type or `None`, the current TensorRT dynamic shape settings.
+  * Return value: `dict` type or `None`, the current TensorRT dynamic shape configuration.
 * `trt_dynamic_shape_input_data`: For TensorRT usage, this parameter provides the fill data for the input tensors used to build the engine, and it is only valid when `run_mode` is set to 'trt_fp32' or 'trt_fp16'.
   * Supports setting a value of type `dict` or `None`. If it is a `dict`, the keys are the input tensor names and the values are two-level nested lists formatted as `[{fill data corresponding to the minimum shape}, {fill data corresponding to the optimal shape}, {fill data corresponding to the maximum shape}]`, for example `[[1.0, 1.0], [1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]`. The data are floating point numbers filled in row-major order.
   * Return value: `dict` type or `None`, the currently set input tensor fill data.
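For reference, a minimal sketch of setting the options documented above; the import path, the no-argument constructor, and the input tensor name `x` are assumptions that may differ by PaddleX version, not something introduced by this change.

```python
# A minimal sketch, assuming PaddlePredictorOption can be constructed without
# arguments and that the attributes documented above are settable properties.
from paddlex.inference.utils.pp_option import PaddlePredictorOption

option = PaddlePredictorOption()
option.run_mode = "trt_fp16"  # TensorRT backend with FP16 precision

# Keys are input tensor names; values are [minimum shape, optimal shape, maximum shape].
option.trt_dynamic_shapes = {
    "x": [[1, 3, 224, 224], [1, 3, 224, 224], [8, 3, 224, 224]],
}

print(option.trt_dynamic_shapes)  # -> the configured dict, or None if never set
```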

+ 2 - 2
docs/module_usage/instructions/model_python_API.md

@@ -110,9 +110,9 @@ PaddleX 支持通过`PaddlePredictorOption`修改推理配置,相关API如下
 * `cpu_threads`:cpu 加速库计算线程数,仅当推理设备使用 cpu 时有效;
   * 支持设置 `int` 类型,cpu 推理时加速库计算线程数;
   * 返回值:`int` 类型,当前设置的加速库计算线程数。
-* `trt_dynamic_shapes`:TensorRT 动态形状,仅当 `run_mode` 为 'trt_fp32' 或 'trt_fp16' 时有效;
+* `trt_dynamic_shapes`:TensorRT 动态形状配置,仅当 `run_mode` 为 'trt_fp32' 或 'trt_fp16' 时有效;
   * 支持设置:`dict` 类型或 `None`,如果为 `dict`,键为输入张量名称,值为一个两级嵌套列表:`[{最小形状}, {优化形状}, {最大形状}]`,例如 `[[1, 2], [1, 2], [2, 2]]`;
-  * 返回值:`dict` 类型或 `None`,当前设置的 TensorRT 动态形状。
+  * 返回值:`dict` 类型或 `None`,当前设置的 TensorRT 动态形状配置。
 * `trt_dynamic_shape_input_data`:使用 TensorRT 时,为用于构建引擎的输入张量填充的数据,仅当 `run_mode` 为 'trt_fp32' 或 'trt_fp16' 时有效;
   * 支持设置:`dict` 类型或 `None`,如果为 `dict`,键为输入张量名称,值为一个两级嵌套列表:`[{最小形状对应的填充数据}, {优化形状对应的填充数据}, {最大形状对应的填充数据}]`,例如 `[[1.0, 1.0], [1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]`,数据为浮点数,按照行优先顺序填充;
   * 返回值:`dict` 类型或 `None`,当前设置的输入张量填充数据。

+ 5 - 5
docs/pipeline_deploy/high_performance_inference.en.md

@@ -24,7 +24,7 @@ In real production environments, many applications impose strict performance met
 
 Before using the high-performance inference plugin, please ensure that you have completed the PaddleX installation according to the [PaddleX Local Installation Tutorial](../installation/installation.en.md) and have run the quick inference using the PaddleX pipeline command line or the PaddleX pipeline Python script as described in the usage instructions.
 
-High-performance inference supports handling **PaddlePaddle static graph models (`.pdmodel`, `.json`)** and **ONNX format models (`.onnx`)**. For ONNX format models, it is recommended to convert them using the [Paddle2ONNX Plugin](./paddle2onnx.en.md). If multiple model formats are present in the model directory, PaddleX will automatically choose the appropriate one as needed.
+The high-performance inference plugin supports handling multiple model formats, including **PaddlePaddle static graph (`.pdmodel`, `.json`)**, **ONNX (`.onnx`)**, and **Huawei OM (`.om`)**, among others. For ONNX models, it is recommended to convert them using the [Paddle2ONNX Plugin](./paddle2onnx.en.md). If multiple model formats are present in the model directory, PaddleX will automatically choose the appropriate one as needed, and automatic model conversion may be performed.
 
 ### 1.1 Installing the High-Performance Inference Plugin
 
@@ -60,7 +60,7 @@ Currently, the supported processor architectures, operating systems, device type
   </tr>
 </table>
 
-#### 1.1.1 Installing the High-Performance Inference Plugin in a Docker Container (Highly Recommended):
+#### 1.1.1 Installing the High-Performance Inference Plugin in a Docker Container (Highly Recommended)
 
 Refer to [Get PaddleX based on Docker](../installation/installation.en.md#21-obtaining-paddlex-based-on-docker) to start a PaddleX container using Docker. After starting the container, execute the following commands according to your device type to install the high-performance inference plugin:
 
@@ -90,7 +90,7 @@ In the official PaddleX Docker image, TensorRT is installed by default. The high
 
 **Please note that the aforementioned Docker image refers to the official PaddleX image described in [Get PaddleX via Docker](../installation/installation.en.md#21-get-paddlex-based-on-docker), rather than the PaddlePaddle official image described in [PaddlePaddle Local Installation Tutorial](../installation/paddlepaddle_install.en.md#installing-paddlepaddle-via-docker). For the latter, please refer to the local installation instructions for the high-performance inference plugin.**
 
-#### 1.1.2 Installing the High-Performance Inference Plugin Locally:
+#### 1.1.2 Installing the High-Performance Inference Plugin Locally
 
 **To install the CPU version of the high-performance inference plugin:**
 
@@ -134,7 +134,7 @@ Please refer to the [Ascend NPU High-Performance Inference Tutorial](../practica
 
 1. **Currently, the official PaddleX only provides precompiled packages for CUDA 11.8 + cuDNN 8.9**; support for CUDA 12 is in progress.
 2. Only one version of the high-performance inference plugin should exist in the same environment.
-3. For Windows systems, it is currently recommended to install and use the high-performance inference plugin within a Docker container.
+3. For Windows systems, it is currently recommended to install and use the high-performance inference plugin within a Docker container or a [WSL](https://learn.microsoft.com/en-us/windows/wsl/install) environment.
 
 ### 1.2 Enabling the High-Performance Inference Plugin
 
@@ -314,7 +314,7 @@ The available configuration items for `backend_config` vary for different backen
     <td>
       <code>precision</code> (<code>str</code>): The precision used, either <code>"fp16"</code> or <code>"fp32"</code>. The default is <code>"fp32"</code>.
       <br />
-      <code>dynamic_shapes</code> (<code>dict</code>): Dynamic shapes. Dynamic shapes include the minimum shape, optimal shape, and maximum shape, which represent TensorRT’s ability to delay specifying some or all tensor dimensions until runtime. The format is: <code>{input tensor name}: [{minimum shape}, {optimization shape}, {maximum shape}]</code>. For more details, please refer to the <a href="https://docs.nvidia.com/deeplearning/tensorrt/latest/inference-library/work-dynamic-shapes.html">TensorRT official documentation</a>.
+      <code>dynamic_shapes</code> (<code>dict</code>): Dynamic shape configuration that specifies, for each input, its minimum shape, optimization shape, and maximum shape. The format is: <code>{input tensor name}: [{minimum shape}, {optimization shape}, {maximum shape}]</code>. Dynamic shapes are TensorRT’s ability to defer specifying some or all tensor dimensions until runtime. For more information, see the <a href="https://docs.nvidia.com/deeplearning/tensorrt/latest/inference-library/work-dynamic-shapes.html">TensorRT official documentation</a>.
     </td>
   </tr>
   <tr>
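To make the `backend_config` semantics above concrete, here is a minimal sketch of passing a TensorRT configuration when creating a model with the plugin enabled; the model name, input tensor name, and shapes are placeholders, and the `use_hpip`/`hpi_config` keyword names are assumed from the surrounding usage rather than defined by this change.

```python
# A minimal sketch, assuming create_model accepts use_hpip and an hpi_config dict;
# "ResNet18" and the input tensor name "x" are placeholders.
from paddlex import create_model

model = create_model(
    model_name="ResNet18",
    device="gpu",
    use_hpip=True,  # enable the high-performance inference plugin
    hpi_config={
        "backend": "tensorrt",
        "backend_config": {
            "precision": "fp16",
            # {input tensor name}: [{minimum shape}, {optimization shape}, {maximum shape}]
            "dynamic_shapes": {
                "x": [[1, 3, 224, 224], [1, 3, 224, 224], [8, 3, 224, 224]],
            },
        },
    },
)
output = model.predict("test.jpg")
```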

+ 6 - 6
docs/pipeline_deploy/high_performance_inference.md

@@ -24,7 +24,7 @@ comments: true
 
 使用高性能推理插件前,请确保您已经按照 [PaddleX本地安装教程](../installation/installation.md) 完成了PaddleX的安装,且按照PaddleX产线命令行使用说明或PaddleX产线Python脚本使用说明跑通了产线的快速推理。
 
-高性能推理支持处理 **PaddlePaddle 静态图模型( `.pdmodel`、 `.json` )** 和 **ONNX 格式模型( `.onnx` )**。对于 ONNX 格式模型,建议使用 [Paddle2ONNX 插件](./paddle2onnx.md) 转换得到。如果模型目录中存在多种格式的模型,PaddleX 会根据需要自动选择。
+高性能推理插件支持处理 **PaddlePaddle 静态图(`.pdmodel`、 `.json`)**、**ONNX(`.onnx`)**、**华为 OM(`.om`)** 等多种模型格式。对于 ONNX 模型,建议使用 [Paddle2ONNX 插件](./paddle2onnx.md) 转换得到。如果模型目录中存在多种格式的模型,PaddleX 会根据需要自动选择,并可能进行自动模型转换。
 
 ### 1.1 安装高性能推理插件
 
@@ -60,7 +60,7 @@ comments: true
   </tr>
 </table>
 
-#### 1.1.1 在 Docker 容器中安装高性能推理插件(强烈推荐)：
+#### 1.1.1 在 Docker 容器中安装高性能推理插件(强烈推荐)
 
 参考 [基于Docker获取PaddleX](../installation/installation.md#21-基于docker获取paddlex) 使用 Docker 启动 PaddleX 容器。启动容器后,根据设备类型,执行如下指令,安装高性能推理插件:
 
@@ -90,7 +90,7 @@ PaddleX 官方 Docker 镜像中默认安装了 TensorRT,高性能推理插件
 
 **请注意,以上提到的镜像指的是 [基于Docker获取PaddleX](../installation/installation.md#21-基于docker获取paddlex) 中描述的 PaddleX 官方镜像,而非 [飞桨PaddlePaddle本地安装教程](../installation/paddlepaddle_install.md#基于-docker-安装飞桨) 中描述的飞桨框架官方镜像。对于后者,请参考高性能推理插件本地安装说明。**
 
-#### 1.1.2 本地安装高性能推理插件：
+#### 1.1.2 本地安装高性能推理插件
 
 **安装 CPU 版本的高性能推理插件:**
 
@@ -136,7 +136,7 @@ paddlex --install hpi-gpu
 
 2. 同一环境中只应该存在一个版本的高性能推理插件。
 
-3. 对于 Windows 系统,目前建议在 Docker 容器中安装和使用高性能推理插件。
+3. 对于 Windows 系统,目前建议在 Docker 容器或者 [WSL](https://learn.microsoft.com/zh-cn/windows/wsl/install) 环境中安装和使用高性能推理插件。
 
 ### 1.2 启用高性能推理插件
 
@@ -287,7 +287,7 @@ output = model.predict("https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/
   </tr>
   <tr>
     <td><code>om</code></td>
-    <td>华为昇腾NPU定制的离线模型格式对应的推理引擎,针对硬件进行了深度优化,减少算子计算时间和调度时间,能够有效提升推理性能。</td>
+    <td>华为昇腾 NPU 定制的离线模型格式对应的推理引擎,针对硬件进行了深度优化,减少算子计算时间和调度时间,能够有效提升推理性能。</td>
     <td>NPU</td>
   </tr>
 </table>
@@ -316,7 +316,7 @@ output = model.predict("https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/
     <td>
       <code>precision</code>(<code>str</code>):使用的精度,<code>"fp16"</code>或<code>"fp32"</code>。默认为<code>"fp32"</code>。
       <br />
-      <code>dynamic_shapes</code>(<code>dict</code>):动态形状。动态形状包含最小形状、最优形状以及最大形状,是 TensorRT 延迟指定部分或全部张量维度直到运行时的能力。格式为:<code>{输入张量名称}: [{最小形状}, {优化形状}, {最大形状}]</code>。更多介绍请参考 <a href="https://docs.nvidia.com/deeplearning/tensorrt/latest/inference-library/work-dynamic-shapes.html">TensorRT 官方文档</a>。
+      <code>dynamic_shapes</code>(<code>dict</code>):动态形状配置,指定每个输入对应的最小形状、优化形状以及最大形状。格式为:<code>{输入张量名称}: [{最小形状}, {优化形状}, {最大形状}]</code>。动态形状是 TensorRT 延迟指定部分或全部张量维度直到运行时的能力,更多介绍请参考 <a href="https://docs.nvidia.com/deeplearning/tensorrt/latest/inference-library/work-dynamic-shapes.html">TensorRT 官方文档</a>。
     </td>
   </tr>
   <tr>

+ 2 - 2
docs/pipeline_usage/instructions/pipeline_python_API.en.md

@@ -106,9 +106,9 @@ PaddleX supports modifying the inference configuration through `PaddlePredictorO
 * `cpu_threads`: Number of CPU threads for the acceleration library, only valid when the inference device is 'cpu'.
   * Supports setting an `int` type for the number of CPU threads for the acceleration library during CPU inference.
   * Return value: `int` type, the currently set number of threads for the acceleration library.
-* `trt_dynamic_shapes`: TensorRT dynamic shapes, only effective when `run_mode` is set to 'trt_fp32' or 'trt_fp16'.
+* `trt_dynamic_shapes`: TensorRT dynamic shape configuration, only effective when `run_mode` is set to 'trt_fp32' or 'trt_fp16'.
   * Supports setting a value of type `dict` or `None`. If it is a `dict`, the keys are the input tensor names and the values are two-level nested lists formatted as `[{minimum shape}, {optimal shape}, {maximum shape}]`, for example `[[1, 2], [1, 2], [2, 2]]`.
-  * Return value: `dict` type or `None`, the current TensorRT dynamic shape settings.
+  * Return value: `dict` type or `None`, the current TensorRT dynamic shape configuration.
 * `trt_dynamic_shape_input_data`: For TensorRT usage, this parameter provides the fill data for the input tensors used to build the engine, and it is only valid when `run_mode` is set to 'trt_fp32' or 'trt_fp16'.
   * Supports setting a value of type `dict` or `None`. If it is a `dict`, the keys are the input tensor names and the values are two-level nested lists formatted as `[{fill data corresponding to the minimum shape}, {fill data corresponding to the optimal shape}, {fill data corresponding to the maximum shape}]`, for example `[[1.0, 1.0], [1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]`. The data are floating point numbers filled in row-major order.
   * Return value: `dict` type or `None`, the currently set input tensor fill data.

+ 2 - 2
docs/pipeline_usage/instructions/pipeline_python_API.md

@@ -107,9 +107,9 @@ PaddleX 支持通过`PaddlePredictorOption`修改推理配置,相关API如下
 * `cpu_threads`:cpu 加速库计算线程数,仅当推理设备使用 cpu 时有效;
   * 支持设置 `int` 类型,cpu 推理时加速库计算线程数;
   * 返回值:`int` 类型,当前设置的加速库计算线程数。
-* `trt_dynamic_shapes`:TensorRT 动态形状,仅当 `run_mode` 为 'trt_fp32' 或 'trt_fp16' 时有效;
+* `trt_dynamic_shapes`:TensorRT 动态形状配置,仅当 `run_mode` 为 'trt_fp32' 或 'trt_fp16' 时有效;
   * 支持设置:`dict` 类型或 `None`,如果为 `dict`,键为输入张量名称,值为一个两级嵌套列表:`[{最小形状}, {优化形状}, {最大形状}]`,例如 `[[1, 2], [1, 2], [2, 2]]`;
-  * 返回值:`dict` 类型或 `None`,当前设置的 TensorRT 动态形状。
+  * 返回值:`dict` 类型或 `None`,当前的 TensorRT 动态形状配置。
 * `trt_dynamic_shape_input_data`:使用 TensorRT 时,为用于构建引擎的输入张量填充的数据,仅当 `run_mode` 为 'trt_fp32' 或 'trt_fp16' 时有效;
   * 支持设置:`dict` 类型或 `None`,如果为 `dict`,键为输入张量名称,值为一个两级嵌套列表:`[{最小形状对应的填充数据}, {优化形状对应的填充数据}, {最大形状对应的填充数据}]`,例如 `[[1.0, 1.0], [1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]`,数据为浮点数,按照行优先顺序填充;
   * 返回值:`dict` 类型或 `None`,当前设置的输入张量填充数据。

+ 37 - 27
paddlex/inference/models/common/static_infer.py

@@ -14,8 +14,9 @@
 
 import abc
 import subprocess
+from os import PathLike
 from pathlib import Path
-from typing import List, Sequence
+from typing import List, Sequence, Union
 
 import numpy as np
 
@@ -307,12 +308,12 @@ class StaticInfer(metaclass=abc.ABCMeta):
 class PaddleInfer(StaticInfer):
     def __init__(
         self,
-        model_dir: str,
+        model_dir: Union[str, PathLike],
         model_file_prefix: str,
         option: PaddlePredictorOption,
     ) -> None:
         super().__init__()
-        self.model_dir = model_dir
+        self.model_dir = Path(model_dir)
         self.model_file_prefix = model_file_prefix
         self._option = option
         self.predictor = self._create()
@@ -491,6 +492,8 @@ class PaddleInfer(StaticInfer):
         import paddle.inference
 
         if USE_PIR_TRT:
+            if self._option.trt_dynamic_shapes is None:
+                raise RuntimeError("No dynamic shape information provided")
             trt_save_path = cache_dir / "trt" / self.model_file_prefix
             _convert_trt(
                 self._option.trt_cfg_setting,
@@ -520,6 +523,8 @@ class PaddleInfer(StaticInfer):
                     getattr(config, func_name)(**args)
 
             if self._option.trt_use_dynamic_shapes:
+                if self._option.trt_dynamic_shapes is None:
+                    raise RuntimeError("No dynamic shape information provided")
                 if self._option.trt_collect_shape_range_info:
                     # NOTE: We always use a shape range info file.
                     if self._option.trt_shape_range_info_path is not None:
@@ -567,20 +572,17 @@ class PaddleInfer(StaticInfer):
                         self._option.trt_allow_rebuild_at_runtime,
                     )
                 else:
-                    if self._option.trt_dynamic_shapes is not None:
-                        min_shapes, opt_shapes, max_shapes = {}, {}, {}
-                        for (
-                            key,
-                            shapes,
-                        ) in self._option.trt_dynamic_shapes.items():
-                            min_shapes[key] = shapes[0]
-                            opt_shapes[key] = shapes[1]
-                            max_shapes[key] = shapes[2]
-                            config.set_trt_dynamic_shape_info(
-                                min_shapes, max_shapes, opt_shapes
-                            )
-                    else:
-                        raise RuntimeError("No dynamic shape information provided")
+                    min_shapes, opt_shapes, max_shapes = {}, {}, {}
+                    for (
+                        key,
+                        shapes,
+                    ) in self._option.trt_dynamic_shapes.items():
+                        min_shapes[key] = shapes[0]
+                        opt_shapes[key] = shapes[1]
+                        max_shapes[key] = shapes[2]
+                        config.set_trt_dynamic_shape_info(
+                            min_shapes, max_shapes, opt_shapes
+                        )
 
         return config
 
@@ -605,12 +607,12 @@ class MultiBackendInfer(object):
 class HPInfer(StaticInfer):
     def __init__(
         self,
-        model_dir: str,
+        model_dir: Union[str, PathLike],
         model_file_prefix: str,
         config: HPIConfig,
     ) -> None:
         super().__init__()
-        self._model_dir = model_dir
+        self._model_dir = Path(model_dir)
         self._model_file_prefix = model_file_prefix
         self._config = config
         backend, backend_config = self._determine_backend_and_config()
@@ -627,7 +629,7 @@ class HPInfer(StaticInfer):
             ]
 
     @property
-    def model_dir(self) -> str:
+    def model_dir(self) -> Path:
         return self._model_dir
 
     @property
@@ -695,7 +697,11 @@ class HPInfer(StaticInfer):
         }
         # TODO: This is probably redundant. Can we reuse the code in the
         # predictor class?
-        paddle_info = self._config.hpi_info.backend_configs.paddle_infer
+        paddle_info = None
+        if self._config.hpi_info:
+            hpi_info = self._config.hpi_info
+            if hpi_info.backend_configs:
+                paddle_info = hpi_info.backend_configs.paddle_infer
         if paddle_info is not None:
             if (
                 kwargs.get("trt_dynamic_shapes") is None
@@ -736,7 +742,7 @@ class HPInfer(StaticInfer):
                 f"Unsupported device type {repr(self._config.device_type)}"
             )
 
-        model_paths = get_model_paths(self.model_dir, self.model_file_prefix)
+        model_paths = get_model_paths(self._model_dir, self.model_file_prefix)
         if backend in ("openvino", "onnxruntime", "tensorrt"):
             # XXX: This introduces side effects.
             if "onnx" not in model_paths:
@@ -753,9 +759,9 @@ class HPInfer(StaticInfer):
                                 "paddlex",
                                 "--paddle2onnx",
                                 "--paddle_model_dir",
-                                self._model_dir,
+                                str(self._model_dir),
                                 "--onnx_model_dir",
-                                self._model_dir,
+                                str(self._model_dir),
                             ],
                             capture_output=True,
                             check=True,
@@ -766,7 +772,7 @@ class HPInfer(StaticInfer):
                             f"PaddlePaddle-to-ONNX conversion failed:\n{e.stderr}"
                         ) from e
                     model_paths = get_model_paths(
-                        self.model_dir, self.model_file_prefix
+                        self._model_dir, self.model_file_prefix
                     )
                     assert "onnx" in model_paths
                 else:
@@ -792,7 +798,11 @@ class HPInfer(StaticInfer):
                 backend_config.get("use_dynamic_shapes", True)
                 and backend_config.get("dynamic_shapes") is None
             ):
-                trt_info = self._config.hpi_info.backend_configs.tensorrt
+                trt_info = None
+                if self._config.hpi_info:
+                    hpi_info = self._config.hpi_info
+                    if hpi_info.backend_configs:
+                        trt_info = hpi_info.backend_configs.tensorrt
                 if trt_info is not None and trt_info.dynamic_shapes is not None:
                     trt_dynamic_shapes = trt_info.dynamic_shapes
                     logging.debug(
@@ -804,7 +814,7 @@ class HPInfer(StaticInfer):
                     }
             backend_config = TensorRTConfig.model_validate(backend_config)
             ui_option.use_trt_backend()
-            cache_dir = self.model_dir / CACHE_DIR / "tensorrt"
+            cache_dir = self._model_dir / CACHE_DIR / "tensorrt"
             cache_dir.mkdir(parents=True, exist_ok=True)
             ui_option.trt_option.serialize_file = str(cache_dir / "trt_serialized.trt")
             if backend_config.precision == "fp16":

+ 5 - 0
paddlex/inference/serving/basic_serving/_pipeline_apps/doc_understanding.py

@@ -45,6 +45,11 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> "FastAPI":
 
     @primary_operation(
         app,
+        "/chat/completions",
+        "inferA",
+    )
+    @primary_operation(
+        app,
         INFER_ENDPOINT,
         "infer",
     )

+ 1 - 1
paddlex/inference/serving/schemas/doc_understanding.py

@@ -29,7 +29,7 @@ __all__ = [
     "PRIMARY_OPERATIONS",
 ]
 
-INFER_ENDPOINT: Final[str] = "/chat/completions"
+INFER_ENDPOINT: Final[str] = "/document-understanding"
 
 
 class ContentType(str, Enum):
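With the two changes above, the serving app exposes the renamed primary endpoint while keeping the old path as an alias. Below is a hedged sketch of exercising both routes; the host, port, and request body are illustrative assumptions (the body loosely follows the chat-completions style implied by the original path), not something defined in this diff.

```python
# A minimal sketch, assuming the server runs locally on port 8080 and accepts a
# chat-completions-style body; the field names are placeholders, not the schema.
import requests

payload = {
    "model": "doc-understanding",
    "messages": [{"role": "user", "content": "What is this document about?"}],
}

# Both paths should now reach the same operation: the renamed primary endpoint
# and the "/chat/completions" alias registered above.
for path in ("/document-understanding", "/chat/completions"):
    resp = requests.post(f"http://localhost:8080{path}", json=payload)
    print(path, resp.status_code)
```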

+ 4 - 2
paddlex/inference/utils/hpi.py

@@ -177,8 +177,10 @@ def suggest_inference_backend_and_config(
     ]
 
     # XXX
-    if not ctypes.util.find_library("nvinfer") or (
-        USE_PIR_TRT and importlib.util.find_spec("tensorrt") is None
+    if not (
+        USE_PIR_TRT
+        and importlib.util.find_spec("tensorrt")
+        and ctypes.util.find_library("nvinfer")
     ):
         if "paddle_tensorrt" in supported_pseudo_backends:
             supported_pseudo_backends.remove("paddle_tensorrt")

+ 7 - 12
paddlex/inference/utils/model_paths.py

@@ -16,7 +16,7 @@ from os import PathLike
 from pathlib import Path
 from typing import Tuple, TypedDict, Union
 
-from ...utils.flags import FLAGS_json_format_model
+from ...constants import MODEL_FILE_PREFIX
 
 
 class ModelPaths(TypedDict, total=False):
@@ -26,21 +26,16 @@ class ModelPaths(TypedDict, total=False):
 
 
 def get_model_paths(
-    model_dir: Union[str, PathLike], model_file_prefix: str
+    model_dir: Union[str, PathLike],
+    model_file_prefix: str = MODEL_FILE_PREFIX,
 ) -> ModelPaths:
     model_dir = Path(model_dir)
     model_paths: ModelPaths = {}
     pd_model_path = None
-    if FLAGS_json_format_model:
-        if (model_dir / f"{model_file_prefix}.json").exists():
-            pd_model_path = model_dir / f"{model_file_prefix}.json"
-        elif (model_dir / f"{model_file_prefix}.pdmodel").exists():
-            pd_model_path = model_dir / f"{model_file_prefix}.pdmodel"
-    else:
-        if (model_dir / f"{model_file_prefix}.json").exists():
-            pd_model_path = model_dir / f"{model_file_prefix}.json"
-        elif (model_dir / f"{model_file_prefix}.pdmodel").exists():
-            pd_model_path = model_dir / f"{model_file_prefix}.pdmodel"
+    if (model_dir / f"{model_file_prefix}.json").exists():
+        pd_model_path = model_dir / f"{model_file_prefix}.json"
+    elif (model_dir / f"{model_file_prefix}.pdmodel").exists():
+        pd_model_path = model_dir / f"{model_file_prefix}.pdmodel"
     if pd_model_path and (model_dir / f"{model_file_prefix}.pdiparams").exists():
         model_paths["paddle"] = (
             pd_model_path,

+ 1 - 1
paddlex/inference/utils/pp_option.py

@@ -76,7 +76,7 @@ class PaddlePredictorOption(object):
             self._cfg.setdefault(k, v)
 
         # for trt
-        if self.run_mode in TRT_PRECISION_MAP:
+        if self.run_mode in ("trt_int8", "trt_fp32", "trt_fp16"):
             trt_cfg_setting = TRT_CFG_SETTING[self.model_name]
             if USE_PIR_TRT:
                 trt_cfg_setting["precision_mode"] = TRT_PRECISION_MAP[self.run_mode]

+ 12 - 20
paddlex/paddlex_cli.py

@@ -24,6 +24,7 @@ from pathlib import Path
 from . import create_pipeline
 from .constants import MODEL_FILE_PREFIX
 from .inference.pipelines import load_pipeline_config
+from .inference.utils.model_paths import get_model_paths
 from .repo_manager import get_all_supported_repo_names, setup
 from .utils import logging
 from .utils.deps import (
@@ -32,7 +33,6 @@ from .utils.deps import (
     require_paddle2onnx_plugin,
 )
 from .utils.env import get_cuda_version
-from .utils.flags import FLAGS_json_format_model
 from .utils.install import install_packages
 from .utils.interactive_get_pipeline import interactive_get_pipeline
 from .utils.pipeline_arguments import PIPELINE_ARGUMENTS
@@ -348,25 +348,20 @@ def serve(pipeline, *, device, use_hpip, hpi_config, host, port):
 def paddle_to_onnx(paddle_model_dir, onnx_model_dir, *, opset_version):
     require_paddle2onnx_plugin()
 
-    PD_MODEL_FILE_PREFIX = MODEL_FILE_PREFIX
-    PD_PARAMS_FILENAME = f"{MODEL_FILE_PREFIX}.pdiparams"
     ONNX_MODEL_FILENAME = f"{MODEL_FILE_PREFIX}.onnx"
     CONFIG_FILENAME = f"{MODEL_FILE_PREFIX}.yml"
     ADDITIONAL_FILENAMES = ["scaler.pkl"]
 
-    def _check_input_dir(input_dir, pd_model_file_ext):
+    def _check_input_dir(input_dir):
         if input_dir is None:
             sys.exit("Input directory must be specified")
         if not input_dir.exists():
             sys.exit(f"{input_dir} does not exist")
         if not input_dir.is_dir():
             sys.exit(f"{input_dir} is not a directory")
-        model_path = (input_dir / PD_MODEL_FILE_PREFIX).with_suffix(pd_model_file_ext)
-        if not model_path.exists():
-            sys.exit(f"{model_path} does not exist")
-        params_path = input_dir / PD_PARAMS_FILENAME
-        if not params_path.exists():
-            sys.exit(f"{params_path} does not exist")
+        model_paths = get_model_paths(input_dir)
+        if "paddle" not in model_paths:
+            sys.exit("PaddlePaddle model does not exist")
         config_path = input_dir / CONFIG_FILENAME
         if not config_path.exists():
             sys.exit(f"{config_path} does not exist")
@@ -375,17 +370,18 @@ def paddle_to_onnx(paddle_model_dir, onnx_model_dir, *, opset_version):
         if shutil.which("paddle2onnx") is None:
             sys.exit("Paddle2ONNX is not available. Please install the plugin first.")
 
-    def _run_paddle2onnx(input_dir, pd_model_file_ext, output_dir, opset_version):
+    def _run_paddle2onnx(input_dir, output_dir, opset_version):
+        model_paths = get_model_paths(input_dir)
         logging.info("Paddle2ONNX conversion starting...")
         # XXX: To circumvent Paddle2ONNX's bug
         cmd = [
             "paddle2onnx",
             "--model_dir",
-            str(input_dir),
+            str(model_paths["paddle"][0].parent),
             "--model_filename",
-            str(Path(PD_MODEL_FILE_PREFIX).with_suffix(pd_model_file_ext)),
+            str(model_paths["paddle"][0].name),
             "--params_filename",
-            PD_PARAMS_FILENAME,
+            str(model_paths["paddle"][1].name),
             "--save_file",
             str(output_dir / ONNX_MODEL_FILENAME),
             "--opset_version",
@@ -418,13 +414,9 @@ def paddle_to_onnx(paddle_model_dir, onnx_model_dir, *, opset_version):
     onnx_model_dir = Path(onnx_model_dir)
     logging.info(f"Input dir: {paddle_model_dir}")
     logging.info(f"Output dir: {onnx_model_dir}")
-    pd_model_file_ext = ".json"
-    if not FLAGS_json_format_model:
-        if not (paddle_model_dir / f"{PD_MODEL_FILE_PREFIX}.json").exists():
-            pd_model_file_ext = ".pdmodel"
-    _check_input_dir(paddle_model_dir, pd_model_file_ext)
+    _check_input_dir(paddle_model_dir)
     _check_paddle2onnx()
-    _run_paddle2onnx(paddle_model_dir, pd_model_file_ext, onnx_model_dir, opset_version)
+    _run_paddle2onnx(paddle_model_dir, onnx_model_dir, opset_version)
     if not (onnx_model_dir.exists() and onnx_model_dir.samefile(paddle_model_dir)):
         _copy_config_file(paddle_model_dir, onnx_model_dir)
         _copy_additional_files(paddle_model_dir, onnx_model_dir)

+ 1 - 1
paddlex/utils/flags.py

@@ -49,7 +49,7 @@ DRY_RUN = get_flag_from_env_var("PADDLE_PDX_DRY_RUN", False)
 CHECK_OPTS = get_flag_from_env_var("PADDLE_PDX_CHECK_OPTS", False)
 EAGER_INITIALIZATION = get_flag_from_env_var("PADDLE_PDX_EAGER_INIT", True)
 FLAGS_json_format_model = get_flag_from_env_var("FLAGS_json_format_model", True)
-USE_PIR_TRT = get_flag_from_env_var("PADDLE_PDX_USE_PIR_TRT", False)
+USE_PIR_TRT = get_flag_from_env_var("PADDLE_PDX_USE_PIR_TRT", True)
 DISABLE_DEV_MODEL_WL = get_flag_from_env_var("PADDLE_PDX_DISABLE_DEV_MODEL_WL", False)
 DISABLE_CINN_MODEL_WL = get_flag_from_env_var("PADDLE_PDX_DISABLE_CINN_MODEL_WL", False)
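Since `USE_PIR_TRT` now defaults to `True`, a quick sketch of opting back out via the environment variable read above; the accepted falsy spelling is an assumption to check against `get_flag_from_env_var`.

```python
# A minimal sketch: the flag is read when paddlex.utils.flags is first imported,
# so the environment variable must be set before any paddlex import.
import os

os.environ["PADDLE_PDX_USE_PIR_TRT"] = "0"  # assumed falsy spelling

from paddlex.utils.flags import USE_PIR_TRT

print(USE_PIR_TRT)  # expected: False
```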