support pp-chatocrv3 pipeline

zhouchangda · 1 year ago
commit b7b63665d3

+ 22 - 0
paddlex/inference/components/llm/__init__.py

@@ -0,0 +1,22 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .base import BaseLLM
+from .erniebot import ErnieBot
+
+
+def create_llm_api(model_name: str, params={}) -> BaseLLM:
+    return BaseLLM.get(model_name)(
+        model_name=model_name,
+        params=params,
+    )

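For reference, a minimal sketch of how this factory is intended to be called; the model name must be one of the entities registered by a BaseLLM subclass (the lookup via BaseLLM.get() follows PaddleX's auto-registration convention), and the token value is a placeholder:

    from paddlex.inference.components.llm import create_llm_api

    # "ernie-3.5" resolves to the ErnieBot subclass through the registry
    llm = create_llm_api("ernie-3.5", params={"access_token": "your-token"})
    print(type(llm).__name__)  # ErnieBot
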
+ 65 - 0
paddlex/inference/components/llm/base.py

@@ -0,0 +1,65 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+from ..base import BaseComponent
+from ....utils.subclass_register import AutoRegisterABCMetaClass
+
+__all__ = ["BaseLLM"]
+
+
+class BaseLLM(BaseComponent, metaclass=AutoRegisterABCMetaClass):
+    __is_base = True
+
+    ERROR_MESSAGE = ""
+    VECTOR_STORE_PREFIX = "PADDLEX_VECTOR_STORE"
+
+    def __init__(self):
+        super().__init__()
+
+    def pre_process(self, inputs):
+        return inputs
+
+    def post_process(self, outputs):
+        return outputs
+
+    def pred(self, inputs):
+        raise NotImplementedError("The method `pred` has not been implemented yet.")
+
+    def get_vector(self):
+        raise NotImplementedError(
+            "The method `get_vector` has not been implemented yet."
+        )
+
+    def calculate_similar(self):
+        raise NotImplementedError(
+            "The method `calculate_similar` has not been implemented yet."
+        )
+
+    def apply(self, inputs):
+        pre_process_results = self.pre_process(inputs)
+        pred_results = self.pred(pre_process_results)
+        post_process_results = self.post_process(pred_results)
+        return post_process_results
+
+    def is_vector_store(self, s):
+        return s.startswith(self.VECTOR_STORE_PREFIX)
+
+    def encode_vector_store(self, vector_store_bytes):
+        return self.VECTOR_STORE_PREFIX + base64.b64encode(vector_store_bytes).decode(
+            "ascii"
+        )
+
+    def decode_vector_store(self, vector_store_str):
+        return base64.b64decode(vector_store_str[len(self.VECTOR_STORE_PREFIX) :])

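The PADDLEX_VECTOR_STORE prefix tags a serialized FAISS index so that it can travel through JSON-friendly channels as a base64 string and still be told apart from plain text. A self-contained sketch of the same round trip the three helpers implement (the byte string is a stand-in for FAISS.serialize_to_bytes()):

    import base64

    PREFIX = "PADDLEX_VECTOR_STORE"
    store_bytes = b"\x00\x01faiss-bytes"  # stand-in for a serialized FAISS store
    encoded = PREFIX + base64.b64encode(store_bytes).decode("ascii")  # encode_vector_store
    assert encoded.startswith(PREFIX)                                 # is_vector_store
    assert base64.b64decode(encoded[len(PREFIX):]) == store_bytes     # decode_vector_store
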
+ 208 - 0
paddlex/inference/components/llm/erniebot.py

@@ -0,0 +1,208 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+import json
+import erniebot
+
+from pathlib import Path
+from .base import BaseLLM
+from ....utils import logging
+from ....utils.func_register import FuncRegister
+
+from langchain.docstore.document import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from langchain_community.embeddings import QianfanEmbeddingsEndpoint
+from langchain_community.vectorstores import FAISS
+from langchain_community import vectorstores
+from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings
+
+__all__ = ["ErnieBot"]
+
+
+class ErnieBot(BaseLLM):
+
+    INPUT_KEYS = ["prompts"]
+    OUTPUT_KEYS = ["cls_res"]
+    DEAULT_INPUTS = {"prompts": "prompts"}
+    DEAULT_OUTPUTS = {"cls_pred": "cls_pred"}
+    API_TYPE = "aistudio"
+
+    entities = [
+        "ernie-4.0",
+        "ernie-3.5",
+        "ernie-3.5-8k",
+        "ernie-lite",
+        "ernie-tiny-8k",
+        "ernie-speed",
+        "ernie-speed-128k",
+        "ernie-char-8k",
+    ]
+
+    _FUNC_MAP = {}
+    register = FuncRegister(_FUNC_MAP)
+
+    def __init__(self, model_name="ernie-4.0", params={}):
+        super().__init__()
+        access_token = params.get("access_token")
+        ak = params.get("ak")
+        sk = params.get("sk")
+        api_type = params.get("api_type")
+        max_retries = params.get("max_retries")
+        assert model_name in self.entities, f"model_name must be in {self.entities}"
+        assert access_token or (ak and sk), "access_token or both ak and sk must be set"
+        self.model_name = model_name
+        self.config = {
+            "api_type": api_type,
+            "max_retries": max_retries,
+        }
+        if access_token:
+            self.config["access_token"] = access_token
+        else:
+            self.config["ak"] = ak
+            self.config["sk"] = sk
+
+    def pred(self, prompt, temperature=0.001):
+        """
+        llm predict
+        """
+        try:
+            chat_completion = erniebot.ChatCompletion.create(
+                _config_=self.config,
+                model=self.model_name,
+                messages=[{"role": "user", "content": prompt}],
+                temperature=float(temperature),
+            )
+            llm_result = chat_completion.get_result()
+            return llm_result
+        except Exception as e:
+            if len(e.args) < 1:
+                self.ERROR_MESSAGE = (
+                    "当前选择后端为AI Studio,千帆调用失败,请检查token"
+                )
+            elif (
+                e.args[-1]
+                == "暂无权限使用,请在 AI Studio 正确获取访问令牌(access token)使用"
+            ):
+                self.ERROR_MESSAGE = (
+                    "当前选择后端为AI Studio,请正确获取访问令牌(access token)使用"
+                )
+            elif e.args[-1] == "the max length of current question is 4800":
+                self.ERROR_MESSAGE = "大模型调用失败"
+            else:
+                logging.error(e)
+                self.ERROR_MESSAGE = "大模型调用失败"
+        return None
+
+    def get_vector(
+        self,
+        ocr_result,
+        sleep_time=0.5,
+        block_size=300,
+        separators=["\t", "\n", "。", "\n\n", ""],
+    ):
+        """get summary prompt"""
+
+        all_items = []
+        for i, ocr_res in enumerate(ocr_result):
+            for res_type, text in ocr_res.items():
+                all_items.append(f"第{i}页{res_type}:{text}")
+
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=block_size, chunk_overlap=20, separators=separators
+        )
+        texts = text_splitter.split_text("\t".join(all_items))
+
+        all_splits = [Document(page_content=text) for text in texts]
+
+        api_type = self.config["api_type"]
+        if api_type == "qianfan":
+            os.environ["QIANFAN_AK"] = os.environ.get("EB_AK", self.config["ak"])
+            os.environ["QIANFAN_SK"] = os.environ.get("EB_SK", self.config["sk"])
+            vectorstore = FAISS.from_documents(
+                documents=all_splits, embedding=QianfanEmbeddingsEndpoint()
+            )
+
+        elif api_type == "aistudio":
+            token = self.config["access_token"]
+            vectorstore = FAISS.from_documents(
+                documents=all_splits[0:1],
+                embedding=ErnieEmbeddings(aistudio_access_token=token),
+            )
+
+            #### ErnieEmbeddings.chunk_size = 16
+            # keep step >= 1 so the range() below never gets a zero step
+            step = max(1, min(16, len(all_splits) - 1))
+            for shot_splits in [
+                all_splits[i : i + step] for i in range(1, len(all_splits), step)
+            ]:
+                time.sleep(sleep_time)
+                vectorstore_slice = FAISS.from_documents(
+                    documents=shot_splits,
+                    embedding=ErnieEmbeddings(aistudio_access_token=token),
+                )
+                vectorstore.merge_from(vectorstore_slice)
+        else:
+            raise ValueError(f"Unsupported api_type: {api_type}")
+
+        vectorstore = self.encode_vector_store(vectorstore.serialize_to_bytes())
+        return vectorstore
+
+    def calculate_similar(self, vector, key_list, llm_params=None, sleep_time=0.5):
+        """calculate similarity between keys and doc"""
+        if not self.is_vector_store(vector):
+            # plain text was passed through (the vector store was never
+            # built), so return it unchanged instead of failing to deserialize
+            return vector
+        # XXX: The initialization parameters are hard-coded.
+        if llm_params:
+            api_type = llm_params.get("api_type")
+            access_token = llm_params.get("access_token")
+            ak = llm_params.get("ak")
+            sk = llm_params.get("sk")
+        else:
+            api_type = self.config["api_type"]
+            access_token = self.config.get("access_token")
+            ak = self.config.get("ak")
+            sk = self.config.get("sk")
+        if api_type == "aistudio":
+            embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
+        elif api_type == "qianfan":
+            embeddings = QianfanEmbeddingsEndpoint(qianfan_ak=ak, qianfan_sk=sk)
+        else:
+            raise ValueError(f"Unsupported api_type: {api_type}")
+
+        vectorstore = vectorstores.FAISS.deserialize_from_bytes(
+            self.decode_vector_store(vector), embeddings
+        )
+
+        # match retrieval context for each question
+        Q = []
+        C = []
+        for key in key_list:
+            QUESTION = f"抽取关键信息:{key}"
+            # c_str = ""
+            Q.append(QUESTION)
+            time.sleep(sleep_time)
+            docs = vectorstore.similarity_search_with_relevance_scores(QUESTION, k=2)
+            context = [(document.page_content, score) for document, score in docs]
+            context = sorted(context, key=lambda x: x[1])
+            C.extend([x[0] for x in context[::-1]])
+
+        C = list(set(C))
+        all_C = " ".join(C)
+
+        summary_prompt = all_C
+
+        return summary_prompt

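Taken together, the retrieval flow is: get_vector splits the OCR text into chunks, embeds them, and returns the FAISS store as a tagged base64 string; calculate_similar deserializes that store and pulls the top-k context chunks per key. A hedged sketch of the sequence (placeholder token and sample text; network access to the embedding and chat services is assumed):

    bot = ErnieBot(
        model_name="ernie-3.5",
        params={"api_type": "aistudio", "access_token": "your-token"},
    )
    ocr_pages = [{"words in text block": "甲方:ABC公司 乙方:DEF公司"}]
    vector = bot.get_vector(ocr_pages)                # tagged base64 FAISS store
    context = bot.calculate_similar(vector, ["甲方"])  # top-k chunks per key
    answer = bot.pred(f"上下文:```{context}``` 抽取关键信息:甲方")
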
+ 5 - 2
paddlex/inference/components/task_related/text_det.py

@@ -270,8 +270,11 @@ class DBPostProcess(BaseComponent):
                 continue
             box = box.reshape(-1, 2)
 
-            _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
-            if sside < self.min_size + 2:
+            if len(box) > 0:
+                _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
+                if sside < self.min_size + 2:
+                    continue
+            else:
                 continue
 
             box = np.array(box)

+ 1 - 1
paddlex/inference/components/transforms/image/common.py

@@ -92,7 +92,7 @@ class ReadImage(_BaseRead):
     def apply(self, img):
         """apply"""
         if not isinstance(img, str):
-            with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as temp_file:
+            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
                 img_path = Path(temp_file.name)
                 self._writer.write(img_path, img)
                 yield [

+ 1 - 0
paddlex/inference/pipelines/__init__.py

@@ -32,6 +32,7 @@ from .single_model_pipeline import (
 )
 from .ocr import OCRPipeline
 from .table_recognition import TableRecPipeline
+from .ppchatocrv3 import PPChatOCRPipeline
 
 
 def create_pipeline(

+ 24 - 11
paddlex/inference/pipelines/ocr.py

@@ -15,6 +15,7 @@
 from ..components import SortBoxes, CropByPolys
 from ..results import OCRResult
 from .base import BasePipeline
+from ...utils import logging
 
 
 class OCRPipeline(BasePipeline):
@@ -22,15 +23,22 @@ class OCRPipeline(BasePipeline):
 
     entities = "OCR"
 
-    def __init__(self, det_model, rec_model, batch_size=1, predictor_kwargs=None):
+    def __init__(
+        self,
+        text_det_model,
+        text_rec_model,
+        text_det_batch_size=1,
+        text_rec_batch_size=1,
+        predictor_kwargs=None,
+    ):
         super().__init__(predictor_kwargs=predictor_kwargs)
-        self._build_predictor(det_model, rec_model)
-        self.set_predictor(batch_size)
+        self._build_predictor(text_det_model, text_rec_model)
+        self.set_predictor(text_det_batch_size, text_rec_batch_size)
 
-    def _build_predictor(self, det_model, rec_model):
-        self.det_model = self._create_model(det_model)
-        self.rec_model = self._create_model(rec_model)
-        self.is_curve = self.det_model.model_name in [
+    def _build_predictor(self, text_det_model, text_rec_model):
+        self.text_det_model = self._create_model(text_det_model)
+        self.text_rec_model = self._create_model(text_rec_model)
+        self.is_curve = self.text_det_model.model_name in [
             "PP-OCRv4_mobile_seal_det",
             "PP-OCRv4_server_seal_det",
         ]
@@ -39,12 +47,17 @@ class OCRPipeline(BasePipeline):
             det_box_type="poly" if self.is_curve else "quad"
         )
 
-    def set_predictor(self, batch_size):
-        self.rec_model.set_predictor(batch_size=batch_size)
+    def set_predictor(self, text_det_batch_size=None, text_rec_batch_size=None):
+        if text_det_batch_size and text_det_batch_size > 1:
+            logging.warning(
+                f"text det model only support batch_size=1 now,the setting of text_det_batch_size={text_det_batch_size} will not using! "
+            )
+        if text_rec_batch_size:
+            self.text_rec_model.set_predictor(batch_size=text_rec_batch_size)
 
     def predict(self, input, **kwargs):
         device = kwargs.get("device", "gpu")
-        for det_res in self.det_model(
+        for det_res in self.text_det_model(
             input, batch_size=kwargs.get("det_batch_size", 1), device=device
         ):
             single_img_res = (
@@ -54,7 +67,7 @@ class OCRPipeline(BasePipeline):
             single_img_res["rec_score"] = []
             if len(single_img_res["dt_polys"]) > 0:
                 all_subs_of_img = list(self._crop_by_polys(single_img_res))
-                for rec_res in self.rec_model(
+                for rec_res in self.text_rec_model(
                     all_subs_of_img,
                     batch_size=kwargs.get("rec_batch_size", 1),
                     device=device,

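A minimal usage sketch of the renamed interface; the model names are placeholders from the PaddleX OCR model list, not fixed by this diff:

    pipeline = OCRPipeline(
        text_det_model="PP-OCRv4_mobile_det",
        text_rec_model="PP-OCRv4_mobile_rec",
        text_rec_batch_size=8,
    )
    # each yielded result carries the detected polygons, texts, and scores
    for res in pipeline.predict("demo.png"):
        print(res["rec_text"], res["rec_score"])
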
+ 15 - 0
paddlex/inference/pipelines/ppchatocrv3/__init__.py

@@ -0,0 +1,15 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .ppchatocrv3 import PPChatOCRPipeline

+ 14 - 0
paddlex/inference/pipelines/ppchatocrv3/ch_prompt.yaml

@@ -0,0 +1,14 @@
+kie_common_prompt:
+  task_description: '你现在的任务是从OCR文字识别的结果中提取关键词列表中每一项对应的关键信息。
+      OCR的文字识别结果使用```符号包围,包含所识别出来的文字,顺序在原始图片中从左至右、从上至下。
+      我指定的关键词列表使用[]符号包围。请注意OCR的文字识别结果可能存在长句子换行被切断、不合理的分词、
+      文字被错误合并等问题,你需要结合上下文语义进行综合判断,以抽取准确的关键信息。'
+  output_format: '在返回结果时使用JSON格式,包含多个key-value对,key值为我指定的关键词,value值为所抽取的结果。
+      如果认为OCR识别结果中没有关键词key对应的value,则将value赋值为"未知"。请只输出json格式的结果,
+      并做json格式校验后返回,不要包含其它多余文字!'
+kie_table_prompt:
+  task_description: '你现在的任务是从输入的html格式的表格内容中提取关键词列表中每一项对应的关键信息,
+      表格内容用```符号包围,我指定的关键词列表使用[]符号包围。你需要结合上下文语义进行综合判断,以抽取准确的关键信息。
+      在返回结果时使用JSON格式,包含多个key-value对,key值为我指定的关键词,value值为所抽取的结果。
+      如果认为输入的表格内容中没有关键词key对应的value值,则将value赋值为"未知"。
+      请只输出json格式的结果,并做json格式校验后返回,不要包含其它多余文字!'

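When a custom user_prompt_yaml is supplied instead, get_kie_prompt (in ppchatocrv3.py below) reads four keys from it. Loaded with read_yaml_file, the file should therefore yield a dict of this shape; only the key names come from the code, the values here are placeholders:

    user_prompt_dict = {
        "task_description": "...",         # replaces the built-in task description
        "rules_str": "...",                # extra extraction rules
        "few_shot_demo_str": "...",        # few-shot input examples
        "few_shot_demo_key_value": "...",  # few-shot expected outputs
    }
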
+ 690 - 0
paddlex/inference/pipelines/ppchatocrv3/ppchatocrv3.py

@@ -0,0 +1,690 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+import json
+import numpy as np
+from .utils import *
+from copy import deepcopy
+from ...components import *
+from ..ocr import OCRPipeline
+from ....utils import logging
+from ...results import (
+    TableResult,
+    LayoutStructureResult,
+    VisualInfoResult,
+    ChatOCRResult,
+)
+from ...utils.io import ImageReader, PDFReader
+from ..table_recognition import TableRecPipeline
+from ...components.llm import create_llm_api, ErnieBot
+from ....utils.file_interface import read_yaml_file
+from ..table_recognition.utils import convert_4point2rect, get_ori_coordinate_for_table
+
+PROMPT_FILE = os.path.join(os.path.dirname(__file__), "ch_prompt.yaml")
+
+
+class PPChatOCRPipeline(TableRecPipeline):
+    """PP-ChatOCRv3 Pileline"""
+
+    entities = "chatocrv3"
+
+    def __init__(
+        self,
+        layout_model,
+        text_det_model,
+        text_rec_model,
+        table_model,
+        oricls_model=None,
+        uvdoc_model=None,
+        curve_model=None,
+        llm_name="ernie-3.5",
+        llm_params={},
+        task_prompt_yaml=None,
+        user_prompt_yaml=None,
+        layout_batch_size=1,
+        text_det_batch_size=1,
+        text_rec_batch_size=1,
+        table_batch_size=1,
+        uvdoc_batch_size=1,
+        curve_batch_size=1,
+        oricls_batch_size=1,
+        recovery=True,
+        device="gpu",
+        predictor_kwargs=None,
+    ):
+        self.layout_model = layout_model
+        self.text_det_model = text_det_model
+        self.text_rec_model = text_rec_model
+        self.table_model = table_model
+        self.oricls_model = oricls_model
+        self.uvdoc_model = uvdoc_model
+        self.curve_model = curve_model
+        self.llm_name = llm_name
+        self.llm_params = llm_params
+        self.task_prompt_yaml = task_prompt_yaml
+        self.user_prompt_yaml = user_prompt_yaml
+        self.layout_batch_size = layout_batch_size
+        self.text_det_batch_size = text_det_batch_size
+        self.text_rec_batch_size = text_rec_batch_size
+        self.table_batch_size = table_batch_size
+        self.uvdoc_batch_size = uvdoc_batch_size
+        self.curve_batch_size = curve_batch_size
+        self.oricls_batch_size = oricls_batch_size
+        self.recovery = recovery
+        self.device = device
+        self.predictor_kwargs = predictor_kwargs
+        super().__init__(
+            layout_model=layout_model,
+            text_det_model=text_det_model,
+            text_rec_model=text_rec_model,
+            table_model=table_model,
+            layout_batch_size=layout_batch_size,
+            text_det_batch_size=text_det_batch_size,
+            text_rec_batch_size=text_rec_batch_size,
+            table_batch_size=table_batch_size,
+            predictor_kwargs=predictor_kwargs,
+        )
+        self.llm_api = create_llm_api(
+            llm_name,
+            llm_params,
+        )
+        self.cropper = CropByBoxes()
+        # get base prompt from yaml info
+        if task_prompt_yaml:
+            self.task_prompt_dict = read_yaml_file(task_prompt_yaml)
+        else:
+            self.task_prompt_dict = read_yaml_file(PROMPT_FILE)
+        # get user prompt from yaml info
+        if user_prompt_yaml:
+            self.user_prompt_dict = read_yaml_file(user_prompt_yaml)
+        else:
+            self.user_prompt_dict = None
+        self.recovery = recovery
+        self.img_reader = ReadImage()
+        self.pdf_reader = PDFReader()
+        self.visual_info = None
+        self.vector = None
+        self._set_predictor(
+            curve_batch_size, oricls_batch_size, uvdoc_batch_size, device=device
+        )
+
+    def _build_predictor(self):
+        super()._build_predictor()
+        if self.curve_model:
+            self.curve_pipeline = OCRPipeline(
+                text_det_model=self.curve_model,
+                text_rec_model=self.text_rec_model,
+                predictor_kwargs=self.predictor_kwargs,
+            )
+        else:
+            self.curve_pipeline = None
+        if self.oricls_model:
+            self.oricls_predictor = self._create_model(self.oricls_model)
+        else:
+            self.oricls_predictor = None
+        if self.uvdoc_model:
+            self.uvdoc_predictor = self._create_model(self.uvdoc_model)
+        else:
+            self.uvdoc_predictor = None
+        if self.curve_pipeline and self.curve_batch_size:
+            self.curve_pipeline.text_det_model.set_predictor(
+                batch_size=self.curve_batch_size, device=self.device
+            )
+        if self.oricls_predictor and self.oricls_batch_size:
+            self.oricls_predictor.set_predictor(
+                batch_size=self.oricls_batch_size, device=self.device
+            )
+        if self.uvdoc_predictor and self.uvdoc_batch_size:
+            self.uvdoc_predictor.set_predictor(
+                batch_size=self.uvdoc_batch_size, device=self.device
+            )
+
+    def _set_predictor(
+        self, curve_batch_size, oricls_batch_size, uvdoc_batch_size, device
+    ):
+        if self.curve_pipeline and curve_batch_size:
+            self.curve_pipeline.text_det_model.set_predictor(
+                batch_size=curve_batch_size, device=device
+            )
+        if self.oricls_predictor and oricls_batch_size:
+            self.oricls_predictor.set_predictor(
+                batch_size=oricls_batch_size, device=device
+            )
+        if self.uvdoc_predictor and uvdoc_batch_size:
+            self.uvdoc_predictor.set_predictor(
+                batch_size=uvdoc_batch_size, device=device
+            )
+
+    def predict(self, input, **kwargs):
+        visual_info = {"ocr_text": [], "table_html": [], "table_text": []}
+        # get all visual result
+        visual_result = list(self.get_visual_result(input, **kwargs))
+        # decode visual result to get table_html, table_text, ocr_text
+        ocr_text, table_text, table_html = self.decode_visual_result(visual_result)
+
+        visual_info["ocr_text"] = ocr_text
+        visual_info["table_html"] = table_html
+        visual_info["table_text"] = table_text
+        visual_info = VisualInfoResult(visual_info)
+        # for local user save visual info in self
+        self.visual_info = visual_info
+
+        return visual_result, visual_info
+
+    def get_visual_result(self, inputs, **kwargs):
+        curve_batch_size = kwargs.get("curve_batch_size")
+        oricls_batch_size = kwargs.get("oricls_batch_size")
+        uvdoc_batch_size = kwargs.get("uvdoc_batch_size")
+        device = kwargs.get("device")
+        self._set_predictor(
+            curve_batch_size, oricls_batch_size, uvdoc_batch_size, device
+        )
+        input_imgs = []
+        img_list = []
+        for file in inputs:
+            if isinstance(file, str) and file.endswith(".pdf"):
+                img_list = self.pdf_reader.read(file)
+                for page, img in enumerate(img_list):
+                    input_imgs.append(
+                        {
+                            "input_path": f"{Path(file).parent}/{Path(file).stem}_{page}.jpg",
+                            "img": img,
+                        }
+                    )
+            else:
+                for imgs in self.img_reader(file):
+                    input_imgs.extend(imgs)
+        # get oricls and uvdoc results
+        oricls_results = []
+        if self.oricls_predictor and kwargs.get("use_oricls_model", True):
+            img_list = [img["img"] for img in input_imgs]
+            oricls_results = get_oricls_results(
+                input_imgs, self.oricls_predictor, img_list
+            )
+        uvdoc_results = []
+        if self.uvdoc_predictor and kwargs.get("use_uvdoc_model", True):
+            img_list = [img["img"] for img in input_imgs]
+            uvdoc_results = get_uvdoc_results(
+                input_imgs, self.uvdoc_predictor, img_list
+            )
+        img_list = [img["img"] for img in input_imgs]
+        for idx, (input_img, layout_pred) in enumerate(
+            zip(input_imgs, self.layout_predictor(img_list))
+        ):
+            single_img_res = {
+                "input_path": "",
+                "layout_result": {},
+                "ocr_result": {},
+                "table_ocr_result": [],
+                "table_result": [],
+                "structure_result": [],
+                "structure_result": [],
+                "oricls_result": {},
+                "uvdoc_result": {},
+                "curve_result": [],
+            }
+            # update oricls and uvdoc result
+            if oricls_results:
+                single_img_res["oricls_result"] = oricls_results[idx]
+            if uvdoc_results:
+                single_img_res["uvdoc_result"] = uvdoc_results[idx]
+            # update layout result
+            single_img_res["input_path"] = layout_pred["input_path"]
+            single_img_res["layout_result"] = layout_pred
+            single_img = input_img["img"]
+            if len(layout_pred["boxes"]) > 0:
+                subs_of_img = list(self._crop_by_boxes(layout_pred))
+                # get cropped images with label "table"
+                table_subs = []
+                curve_subs = []
+                structure_res = []
+                ocr_res_with_layout = []
+                for sub in subs_of_img:
+                    box = sub["box"]
+                    xmin, ymin, xmax, ymax = [int(i) for i in box]
+                    mask_flag = True
+                    if sub["label"].lower() == "table":
+                        table_subs.append(sub)
+                    elif sub["label"].lower() == "seal":
+                        curve_subs.append(sub)
+                    else:
+                        if self.recovery and kwargs.get("recovery", True):
+                            # TODO: Why use the entire image?
+                            wht_im = (
+                                np.ones(single_img.shape, dtype=single_img.dtype) * 255
+                            )
+                            wht_im[ymin:ymax, xmin:xmax, :] = sub["img"]
+                            sub_ocr_res = get_ocr_res(self.ocr_pipeline, wht_im)
+                        else:
+                            sub_ocr_res = get_ocr_res(self.ocr_pipeline, sub)
+                            sub_ocr_res["dt_polys"] = get_ori_coordinate_for_table(
+                                xmin, ymin, sub_ocr_res["dt_polys"]
+                            )
+                        layout_label = sub["label"].lower()
+                        if sub_ocr_res and sub["label"].lower() in [
+                            "image",
+                            "figure",
+                            "img",
+                            "fig",
+                        ]:
+                            mask_flag = False
+                        else:
+                            ocr_res_with_layout.append(sub_ocr_res)
+                            structure_res.append(
+                                {
+                                    "layout_bbox": box,
+                                    f"{layout_label}": "\n".join(
+                                        sub_ocr_res["rec_text"]
+                                    ),
+                                }
+                            )
+                    if mask_flag:
+                        single_img[ymin:ymax, xmin:xmax, :] = 255
+
+            curve_pipeline = self.ocr_pipeline
+            if self.curve_pipeline and kwargs.get("use_curve_model", True):
+                curve_pipeline = self.curve_pipeline
+
+            all_curve_res = get_ocr_res(curve_pipeline, curve_subs)
+            if isinstance(all_curve_res, dict):
+                # get_ocr_res unwraps single-element results; rewrap for the zip below
+                all_curve_res = [all_curve_res]
+            single_img_res["curve_result"] = all_curve_res
+
+            for sub, curve_res in zip(curve_subs, all_curve_res):
+                structure_res.append(
+                    {
+                        "layout_bbox": sub["box"],
+                        "印章": "".join(curve_res["rec_text"]),
+                    }
+                )
+
+            ocr_res = get_ocr_res(self.ocr_pipeline, single_img)
+            ocr_res["input_path"] = layout_pred["input_path"]
+            all_table_res, _ = self.get_table_result(table_subs)
+            for poly_idx, single_dt_poly in enumerate(ocr_res["dt_polys"]):
+                structure_res.append(
+                    {
+                        "layout_bbox": convert_4point2rect(single_dt_poly),
+                        "words in text block": ocr_res["rec_text"][idx],
+                    }
+                )
+            # update ocr result
+            for layout_ocr_res in ocr_res_with_layout:
+                ocr_res["dt_polys"].extend(layout_ocr_res["dt_polys"])
+                ocr_res["rec_text"].extend(layout_ocr_res["rec_text"])
+                ocr_res["input_path"] = single_img_res["input_path"]
+
+            # get table text from html
+            structure_res_table, all_table_ocr_res = get_table_text_from_html(
+                all_table_res
+            )
+            structure_res.extend(structure_res_table)
+
+            # sort the layout result by the left top point of the box
+            structure_res = sorted_layout_boxes(structure_res, w=single_img.shape[1])
+            structure_res = [LayoutStructureResult(item) for item in structure_res]
+
+            single_img_res["table_result"] = all_table_res
+            single_img_res["ocr_result"] = ocr_res
+            single_img_res["table_ocr_result"] = all_table_ocr_res
+            single_img_res["structure_result"] = structure_res
+
+            yield ChatOCRResult(single_img_res)
+
+    def decode_visual_result(self, visual_result):
+        ocr_text = []
+        table_text_list = []
+        table_html = []
+        for single_img_pred in visual_result:
+            layout_res = single_img_pred["structure_result"]
+            layout_res_copy = deepcopy(layout_res)
+            # layout_res is like [{"layout_bbox": [x1, y1, x2, y2], "layout": "single", "words in text block": "xxx"}, {"layout_bbox": [x1, y1, x2, y2], "layout": "double", "印章": "xxx"}]
+            ocr_res = {}
+            for block in layout_res_copy:
+                block.pop("layout_bbox")
+                block.pop("layout")
+                for layout_type, text in block.items():
+                    if text == "":
+                        continue
+                    # Table results are used separately
+                    if layout_type == "table":
+                        continue
+                    if layout_type not in ocr_res:
+                        ocr_res[layout_type] = text
+                    else:
+                        ocr_res[layout_type] += f"\n {text}"
+
+            single_table_text = " ".join(single_img_pred["table_ocr_result"])
+            for table_pred in single_img_pred["table_result"]:
+                html = table_pred["html"]
+                table_html.append(html)
+            if ocr_res:
+                ocr_text.append(ocr_res)
+            table_text_list.append(single_table_text)
+
+        return ocr_text, table_text_list, table_html
+
+    def get_vector_text(
+        self,
+        llm_name=None,
+        llm_params={},
+        visual_info=None,
+        min_characters=0,
+        llm_request_interval=1.0,
+    ):
+        """get vector for ocr"""
+        if isinstance(self.llm_api, ErnieBot):
+            get_vector_flag = True
+        else:
+            logging.warning("Do not use ErnieBot, will not get vector text.")
+            get_vector_flag = False
+        if not any([visual_info, self.visual_info]):
+            return {"vector": None}
+
+        if visual_info:
+            # use for serving or local
+            _visual_info = visual_info
+        else:
+            # use for local
+            _visual_info = self.visual_info
+
+        ocr_text = _visual_info["ocr_text"]
+        html_list = _visual_info["table_html"]
+        table_text_list = _visual_info["table_text"]
+
+        # when a table's html is too long for the prompt, use its plain text instead
+        for html, table_text_rec in zip(html_list, table_text_list):
+            if len(html) > 3000:
+                ocr_text.append({"table": table_text_rec})
+
+        ocr_all_result = "".join(["\n".join(e.values()) for e in ocr_text])
+
+        if len(ocr_all_result) > min_characters and get_vector_flag:
+            if visual_info and llm_name:
+                # for serving or local
+                llm_api = create_llm_api(llm_name, llm_params)
+                text_result = llm_api.get_vector(ocr_text, llm_request_interval)
+            else:
+                # for local
+                text_result = self.llm_api.get_vector(ocr_text, llm_request_interval)
+        else:
+            text_result = str(ocr_text)
+
+        return {"vector": text_result}
+
+    def get_retrieval_text(
+        self,
+        key_list,
+        visual_info=None,
+        vector=None,
+        llm_name=None,
+        llm_params={},
+        llm_request_interval=0.1,
+    ):
+
+        if not any([visual_info, vector, self.visual_info, self.vector]):
+            return {"retrieval": None}
+
+        key_list = format_key(key_list)
+
+        if not any([vector, self.vector]):
+            logging.warning(
+                "The vector library is not created, and is being created automatically"
+            )
+            if visual_info and llm_name:
+                # for serving
+                vector = self.get_vector_text(
+                    llm_name=llm_name, llm_params=llm_params, visual_info=visual_info
+                )
+            else:
+                self.vector = self.get_vector_text()
+
+        if vector and llm_name:
+            _vector = vector["vector"]
+            llm_api = create_llm_api(llm_name, llm_params)
+            retrieval = llm_api.calculate_similar(
+                vector=_vector,
+                key_list=key_list,
+                llm_params=llm_params,
+                sleep_time=llm_request_interval,
+            )
+        else:
+            _vector = self.vector["vector"]
+            retrieval = self.llm_api.calculate_similar(
+                vector=_vector, key_list=key_list, sleep_time=llm_request_interval
+            )
+
+        return {"retrieval": retrieval}
+
+    def chat(
+        self,
+        key_list,
+        vector=None,
+        visual_info=None,
+        retrieval_result=None,
+        user_task_description="",
+        rules="",
+        few_shot="",
+        use_vector=True,
+        save_prompt=False,
+        llm_name="ernie-3.5",
+        llm_params={},
+    ):
+        """
+        chat with key
+
+        """
+        if not any(
+            [vector, visual_info, retrieval_result, self.visual_info, self.vector]
+        ):
+            return {"chat_res": "请先完成图像解析再开始再对话", "prompt": ""}
+        key_list = format_key(key_list)
+        # try the table html first, then the table text, and finally the full ocr text
+        if visual_info:
+            # use for serving or local
+            _visual_info = visual_info
+        else:
+            # use for local
+            _visual_info = self.visual_info
+
+        ocr_text = _visual_info["ocr_text"]
+        html_list = _visual_info["table_html"]
+        table_text_list = _visual_info["table_text"]
+        if retrieval_result:
+            ocr_text = retrieval_result
+        elif use_vector and any([visual_info, vector]):
+            # for serving or local
+            ocr_text = self.get_retrieval_text(
+                key_list=key_list,
+                visual_info=visual_info,
+                vector=vector,
+                llm_name=llm_name,
+                llm_params=llm_params,
+            )
+        else:
+            # for local
+            ocr_text = self.get_retrieval_text(key_list=key_list)
+
+        prompt_res = {"ocr_prompt": "str", "table_prompt": [], "html_prompt": []}
+
+        final_results = {}
+        res = None
+        failed_results = ["大模型调用失败", "未知", "未找到关键信息", "None", ""]
+        if html_list:
+            prompt_list = self.get_prompt_for_table(
+                html_list, key_list, rules, few_shot
+            )
+            prompt_res["html_prompt"] = prompt_list
+            for prompt, table_text in zip(prompt_list, table_text_list):
+                logging.debug(prompt)
+                res = self.get_llm_result(prompt)
+                # TODO: why use one html but the whole table_text in next step
+                if not res or list(res.values())[0] in failed_results:
+                    logging.info(
+                        "table html is too long, falling back to the table's ocr text"
+                    )
+                    prompt = self.get_prompt_for_ocr(
+                        table_text, key_list, rules, few_shot, user_task_description
+                    )
+                    logging.debug(prompt)
+                    prompt_res["table_prompt"].append(prompt)
+                    res = self.get_llm_result(prompt)
+                for key, value in (res or {}).items():
+                    if value not in failed_results and key in key_list:
+                        key_list.remove(key)
+                        final_results[key] = value
+        if len(key_list) > 0:
+            logging.info("get result from ocr")
+            prompt = self.get_prompt_for_ocr(
+                ocr_text,
+                key_list,
+                rules,
+                few_shot,
+                user_task_description,
+            )
+            logging.debug(prompt)
+            prompt_res["ocr_prompt"] = prompt
+            res = self.get_llm_result(prompt)
+            final_results.update(res)
+        if not res and not final_results:
+            final_results = self.llm_api.ERROR_MESSAGE
+        if save_prompt:
+            return {"chat_res": final_results, "prompt": prompt_res}
+        else:
+            return {"chat_res": final_results, "prompt": ""}
+
+    def get_llm_result(self, prompt):
+        """get llm result and decode to dict"""
+        llm_result = self.llm_api.pred(prompt)
+        # when the llm pred failed, return None
+        if not llm_result:
+            return None
+
+        if "json" in llm_result or "```" in llm_result:
+            llm_result = (
+                llm_result.replace("```", "").replace("json", "").replace("\n", "")
+            )
+
+            llm_result = llm_result.replace("[", "").replace("]", "")
+        try:
+            llm_result = json.loads(llm_result)
+            llm_result_final = {}
+            for key in llm_result:
+                value = llm_result[key]
+                if isinstance(value, list):
+                    if len(value) > 0:
+                        llm_result_final[key] = value[0]
+                else:
+                    llm_result_final[key] = value
+            return llm_result_final
+        except Exception:
+            # fall back to regex extraction when the output is not valid JSON
+            results = (
+                llm_result.replace("\n", "")
+                .replace("    ", "")
+                .replace("{", "")
+                .replace("}", "")
+            )
+            if not results.endswith('"'):
+                results = results + '"'
+            pattern = r'"(.*?)": "([^"]*)"'
+            matches = re.findall(pattern, str(results))
+            llm_result = {k: v for k, v in matches}
+            return llm_result
+
+    def get_prompt_for_table(self, table_result, key_list, rules="", few_shot=""):
+        """get prompt for table"""
+        prompt_key_information = []
+        merge_table = ""
+        for idx, result in enumerate(table_result):
+            # always merge the current table so none is dropped; flush once the
+            # batch grows past ~2000 characters or the last table is reached
+            merge_table += result
+            if len(merge_table) > 2000 or idx == len(table_result) - 1:
+                single_prompt = self.get_kie_prompt(
+                    merge_table,
+                    key_list,
+                    rules_str=rules,
+                    few_shot_demo_str=few_shot,
+                    prompt_type="table",
+                )
+                prompt_key_information.append(single_prompt)
+                merge_table = ""
+        return prompt_key_information
+
+    def get_prompt_for_ocr(
+        self,
+        ocr_result,
+        key_list,
+        rules="",
+        few_shot="",
+        user_task_description="",
+    ):
+        """get prompt for ocr"""
+
+        prompt_key_information = self.get_kie_prompt(
+            ocr_result, key_list, user_task_description, rules, few_shot
+        )
+        return prompt_key_information
+
+    def get_kie_prompt(
+        self,
+        text_result,
+        key_list,
+        user_task_description="",
+        rules_str="",
+        few_shot_demo_str="",
+        prompt_type="common",
+    ):
+        """get_kie_prompt"""
+
+        if prompt_type == "table":
+            task_description = self.task_prompt_dict["kie_table_prompt"][
+                "task_description"
+            ]
+        else:
+            task_description = self.task_prompt_dict["kie_common_prompt"][
+                "task_description"
+            ]
+            output_format = self.task_prompt_dict["kie_common_prompt"]["output_format"]
+            if len(user_task_description) > 0:
+                task_description = user_task_description
+            task_description = task_description + output_format
+
+        few_shot_demo_key_value = ""
+
+        if self.user_prompt_dict:
+            logging.info("======= common use custom ========")
+            task_description = self.user_prompt_dict["task_description"]
+            rules_str = self.user_prompt_dict["rules_str"]
+            few_shot_demo_str = self.user_prompt_dict["few_shot_demo_str"]
+            few_shot_demo_key_value = self.user_prompt_dict["few_shot_demo_key_value"]
+
+        prompt = f"""{task_description}{rules_str}{few_shot_demo_str}{few_shot_demo_key_value}"""
+
+        if prompt_type == "table":
+            prompt += f"""\n结合上面,下面正式开始:\
+                表格内容:```{text_result}```\
+                关键词列表:[{key_list}]。""".replace(
+                "    ", ""
+            )
+        else:
+            prompt += f"""\n结合上面的例子,下面正式开始:\
+                OCR文字:```{text_result}```\
+                关键词列表:[{key_list}]。""".replace(
+                "    ", ""
+            )
+
+        return prompt

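A hedged end-to-end sketch of the intended call sequence; the model names are placeholders from the PaddleX model zoo and the token is a stand-in:

    pipeline = PPChatOCRPipeline(
        layout_model="RT-DETR-H_layout_3cls",
        text_det_model="PP-OCRv4_server_det",
        text_rec_model="PP-OCRv4_server_rec",
        table_model="SLANet_plus",
        llm_name="ernie-3.5",
        llm_params={"api_type": "aistudio", "access_token": "your-token"},
    )
    # visual stage: layout analysis + OCR + table recognition
    visual_result, visual_info = pipeline.predict(["contract.pdf"])
    # optional retrieval stage: build the vector store from the OCR text
    pipeline.vector = pipeline.get_vector_text()
    # chat stage: key information extraction via the LLM
    print(pipeline.chat(key_list="甲方,乙方")["chat_res"])
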
+ 184 - 0
paddlex/inference/pipelines/ppchatocrv3/utils.py

@@ -0,0 +1,184 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+import numpy as np
+from pathlib import Path
+from scipy.ndimage import rotate
+
+
+def get_ocr_res(pipeline, input):
+    """get ocr res"""
+    ocr_res_list = []
+    if isinstance(input, list):
+        img = [im["img"] for im in input]
+    elif isinstance(input, dict):
+        img = input["img"]
+    else:
+        img = input
+    for ocr_res in pipeline(img):
+        ocr_res_list.append(ocr_res)
+    if len(ocr_res_list) == 1:
+        return ocr_res_list[0]
+    else:
+        return ocr_res_list
+
+
+def get_oricls_results(inputs, predictor, img_list):
+    results = []
+    for input, pred in zip(inputs, predictor(img_list)):
+        results.append(pred)
+        angle = int(pred["label_names"][0])
+        input["img"] = rotate_image(input["img"], angle)
+    return results
+
+
+def get_uvdoc_results(inputs, predictor, img_list):
+    results = []
+    for input, pred in zip(inputs, predictor(img_list)):
+        results.append(pred)
+        input["img"] = np.array(pred["doctr_img"], dtype=np.uint8)
+    return results
+
+
+def get_predictor_res(predictor, input):
+    """get ocr res"""
+    result_list = []
+    if isinstance(input, list):
+        img = [im["img"] for im in input]
+    elif isinstance(input, dict):
+        img = input["img"]
+    else:
+        img = input
+    for res in predictor(img):
+        result_list.append(res)
+    if len(result_list) == 1:
+        return result_list[0]
+    else:
+        return result_list
+
+
+def rotate_image(image_array, rotate_angle):
+    """rotate image"""
+    assert (
+        0 <= rotate_angle < 360
+    ), f"rotate_angle must be in [0, 360), but got {rotate_angle}."
+    return rotate(image_array, rotate_angle, reshape=True)
+
+
+def get_table_text_from_html(all_table_res):
+    all_table_ocr_res = []
+    structure_res = []
+    for table_res in all_table_res:
+        table_list = []
+        table_lines = re.findall("<tr>(.*?)</tr>", table_res["html"])
+        single_table_ocr_res = []
+        for td_line in table_lines:
+            table_list.extend(re.findall("<td.*?>(.*?)</td>", td_line))
+        for text in table_list:
+            text = text.replace(" ", "")
+            single_table_ocr_res.append(text)
+        all_table_ocr_res.append(" ".join(single_table_ocr_res))
+        structure_res.append(
+            {
+                "layout_bbox": table_res["layout_bbox"],
+                "table": table_res["html"],
+            }
+        )
+    return structure_res, all_table_ocr_res
+
+
+def format_key(key_list):
+    """format key"""
+    if key_list == "":
+        return "未内置默认字段,请输入确定的key"
+    if isinstance(key_list, list):
+        return key_list
+    key_list = re.sub(r"[\t\n\r\f\v]", "", key_list)
+    key_list = key_list.replace(",", ",").split(",")
+    return key_list
+
+
+def sorted_layout_boxes(res, w):
+    """
+    Sort text boxes in order from top to bottom, left to right
+    args:
+        res(list):ppstructure results
+        w(int):image width
+    return:
+        sorted results(list)
+    """
+    num_boxes = len(res)
+    if num_boxes == 1:
+        res[0]["layout"] = "single"
+        return res
+
+    # sort primarily on the y axis, then on the x axis
+    sorted_boxes = sorted(res, key=lambda x: (x["layout_bbox"][1], x["layout_bbox"][0]))
+    _boxes = list(sorted_boxes)
+
+    new_res = []
+    res_left = []
+    res_mid = []
+    res_right = []
+    i = 0
+
+    while True:
+        if i >= num_boxes:
+            break
+        # Check whether the box sits in the middle column of a three-column layout
+        if (
+            _boxes[i]["layout_bbox"][0] > w / 4
+            and _boxes[i]["layout_bbox"][0] + _boxes[i]["layout_bbox"][2] < 3 * w / 4
+        ):
+            _boxes[i]["layout"] = "double"
+            res_mid.append(_boxes[i])
+            i += 1
+        # Check that the bbox is on the left
+        elif (
+            _boxes[i]["layout_bbox"][0] < w / 4
+            and _boxes[i]["layout_bbox"][0] + _boxes[i]["layout_bbox"][2] < 3 * w / 5
+        ):
+            _boxes[i]["layout"] = "double"
+            res_left.append(_boxes[i])
+            i += 1
+        elif (
+            _boxes[i]["layout_bbox"][0] > 2 * w / 5
+            and _boxes[i]["layout_bbox"][0] + _boxes[i]["layout_bbox"][2] < w
+        ):
+            _boxes[i]["layout"] = "double"
+            res_right.append(_boxes[i])
+            i += 1
+        else:
+            new_res += res_left
+            new_res += res_right
+            _boxes[i]["layout"] = "single"
+            new_res.append(_boxes[i])
+            res_left = []
+            res_right = []
+            i += 1
+
+    res_left = sorted(res_left, key=lambda x: (x["layout_bbox"][1]))
+    res_mid = sorted(res_mid, key=lambda x: (x["layout_bbox"][1]))
+    res_right = sorted(res_right, key=lambda x: (x["layout_bbox"][1]))
+
+    if res_left:
+        new_res += res_left
+    if res_mid:
+        new_res += res_mid
+    if res_right:
+        new_res += res_right
+
+    return new_res

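format_key above accepts either a ready-made list or a comma-separated string with Chinese or ASCII commas, which is why chat() can take key_list="甲方,乙方" directly. A quick check of that behavior:

    from paddlex.inference.pipelines.ppchatocrv3.utils import format_key

    assert format_key(["甲方"]) == ["甲方"]
    assert format_key("甲方,乙方") == ["甲方", "乙方"]  # Chinese comma
    assert format_key("a, b") == ["a", " b"]           # ASCII comma; spaces are kept
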
+ 40 - 19
paddlex/inference/pipelines/table_recognition/table_recognition.py

@@ -13,11 +13,12 @@
 # limitations under the License.
 
 import numpy as np
+from .utils import *
 from ..base import BasePipeline
 from ..ocr import OCRPipeline
+from ....utils import logging
 from ...components import CropByBoxes
 from ...results import OCRResult, TableResult, StructureTableResult
-from .utils import *
 
 
 class TableRecPipeline(BasePipeline):
@@ -32,38 +33,58 @@ class TableRecPipeline(BasePipeline):
         text_rec_model,
         table_model,
         layout_batch_size=1,
+        text_det_batch_size=1,
         text_rec_batch_size=1,
         table_batch_size=1,
         predictor_kwargs=None,
     ):
+        self.layout_model = layout_model
+        self.text_det_model = text_det_model
+        self.text_rec_model = text_rec_model
+        self.table_model = table_model
+        self.layout_batch_size = layout_batch_size
+        self.text_det_batch_size = text_det_batch_size
+        self.text_rec_batch_size = text_rec_batch_size
+        self.table_batch_size = table_batch_size
         super().__init__(predictor_kwargs=predictor_kwargs)
-        self._build_predictor(
-            layout_model, text_det_model, text_rec_model, table_model, predictor_kwargs
-        )
-        self.set_predictor(layout_batch_size, text_rec_batch_size, table_batch_size)
+        self._build_predictor()
 
     def _build_predictor(
         self,
-        layout_model,
-        text_det_model,
-        text_rec_model,
-        table_model,
-        predictor_kwargs,
     ):
-        self.layout_predictor = self._create_model(model=layout_model)
+        self.layout_predictor = self._create_model(model=self.layout_model)
         self.ocr_pipeline = OCRPipeline(
-            text_det_model,
-            text_rec_model,
-            predictor_kwargs=predictor_kwargs,
+            self.text_det_model,
+            self.text_rec_model,
+            predictor_kwargs=self.predictor_kwargs,
         )
-        self.table_predictor = self._create_model(model=table_model)
+        self.table_predictor = self._create_model(model=self.table_model)
         self._crop_by_boxes = CropByBoxes()
         self._match = TableMatch(filter_ocr_result=False)
+        self.layout_predictor.set_predictor(batch_size=self.layout_batch_size)
+        self.ocr_pipeline.text_rec_model.set_predictor(
+            batch_size=self.text_rec_batch_size
+        )
+        self.table_predictor.set_predictor(batch_size=self.table_batch_size)
 
-    def set_predictor(self, layout_batch_size, text_rec_batch_size, table_batch_size):
-        self.layout_predictor.set_predictor(batch_size=layout_batch_size)
-        self.ocr_pipeline.rec_model.set_predictor(batch_size=text_rec_batch_size)
-        self.table_predictor.set_predictor(batch_size=table_batch_size)
+    def set_predictor(
+        self,
+        layout_batch_size=None,
+        text_det_batch_size=None,
+        text_rec_batch_size=None,
+        table_batch_size=None,
+    ):
+        if text_det_batch_size and text_det_batch_size > 1:
+            logging.warning(
+                f"text det model only support batch_size=1 now,the setting of text_det_batch_size={text_det_batch_size} will not using! "
+            )
+        if layout_batch_size:
+            self.layout_predictor.set_predictor(batch_size=layout_batch_size)
+        if text_rec_batch_size:
+            self.ocr_pipeline.text_rec_model.set_predictor(batch_size=text_rec_batch_size)
+        if table_batch_size:
+            self.table_predictor.set_predictor(batch_size=table_batch_size)
 
     def predict(self, x):
         for layout_pred, ocr_pred in zip(

+ 1 - 0
paddlex/inference/results/__init__.py

@@ -23,3 +23,4 @@ from .seg import SegResult
 from .instance_seg import InstanceSegResult
 from .ts import TSFcResult, TSAdResult, TSClsResult
 from .warp import DocTrResult
+from .chat_ocr import *

+ 33 - 0
paddlex/inference/results/chat_ocr.py

@@ -0,0 +1,33 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base import BaseResult
+
+
+class LayoutStructureResult(BaseResult):
+    """LayoutStructureResult"""
+
+    pass
+
+
+class VisualInfoResult(BaseResult):
+    """VisualInfoResult"""
+
+    pass
+
+
+class ChatOCRResult(BaseResult):
+    """VisualInfoResult"""
+
+    pass

+ 24 - 21
paddlex/inference/results/ocr.py

@@ -68,28 +68,31 @@ class OCRResult(CVResult):
         if txts is None or len(txts) != len(boxes):
             txts = [None] * len(boxes)
         for idx, (box, txt) in enumerate(zip(boxes, txts)):
-            if scores is not None and scores[idx] < drop_score:
+            try:
+                if scores is not None and scores[idx] < drop_score:
+                    continue
+                color = (
+                    random.randint(0, 255),
+                    random.randint(0, 255),
+                    random.randint(0, 255),
+                )
+                box = np.array(box)
+                if len(box) > 4:
+                    pts = [(x, y) for x, y in box.tolist()]
+                    draw_left.polygon(pts, outline=color, width=8)
+                    box = self.get_minarea_rect(box)
+                    height = int(0.5 * (max(box[:, 1]) - min(box[:, 1])))
+                    box[:2, 1] = np.mean(box[:, 1])
+                    box[2:, 1] = np.mean(box[:, 1]) + min(20, height)
+                draw_left.polygon(box, fill=color)
+                img_right_text = draw_box_txt_fine(
+                    (w, h), box, txt, PINGFANG_FONT_FILE_PATH
+                )
+                pts = np.array(box, np.int32).reshape((-1, 1, 2))
+                cv2.polylines(img_right_text, [pts], True, color, 1)
+                img_right = cv2.bitwise_and(img_right, img_right_text)
+            except Exception:
+                # skip boxes that fail to draw rather than aborting the visualization
                 continue
-            color = (
-                random.randint(0, 255),
-                random.randint(0, 255),
-                random.randint(0, 255),
-            )
-            box = np.array(box)
-            if len(box) > 4:
-                pts = [(x, y) for x, y in box.tolist()]
-                draw_left.polygon(pts, outline=color, width=8)
-                box = self.get_minarea_rect(box)
-                height = int(0.5 * (max(box[:, 1]) - min(box[:, 1])))
-                box[:2, 1] = np.mean(box[:, 1])
-                box[2:, 1] = np.mean(box[:, 1]) + min(20, height)
-            draw_left.polygon(box, fill=color)
-            img_right_text = draw_box_txt_fine(
-                (w, h), box, txt, PINGFANG_FONT_FILE_PATH
-            )
-            pts = np.array(box, np.int32).reshape((-1, 1, 2))
-            cv2.polylines(img_right_text, [pts], True, color, 1)
-            img_right = cv2.bitwise_and(img_right, img_right_text)
 
         img_left = Image.blend(image, img_left, 0.5)
         img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))

+ 18 - 0
paddlex/inference/results/table_rec.py

@@ -68,6 +68,24 @@ class StructureTableResult(TableRecResult, XlsxMixin):
 class TableResult(BaseResult):
     """TableResult"""
 
+    def save_to_html(self, save_path):
+        if not save_path.lower().endswith(("html")):
+            input_path = self["input_path"]
+            save_path = Path(save_path) / f"{Path(input_path).stem}"
+        else:
+            save_path = Path(save_path).stem
+        for table_result in self["table_result"]:
+            table_result.save_to_html(save_path)
+
+    def save_to_xlsx(self, save_path):
+        if not save_path.lower().endswith(("xlsx")):
+            input_path = self["input_path"]
+            save_path = Path(save_path) / f"{Path(input_path).stem}"
+        else:
+            save_path = Path(save_path).stem
+        for table_result in self["table_result"]:
+            table_result.save_to_xlsx(save_path)
+
     def save_to_img(self, save_path):
         if not save_path.lower().endswith((".jpg", ".png")):
             input_path = self["input_path"]

+ 5 - 5
paddlex/inference/results/utils/mixin.py

@@ -90,7 +90,7 @@ class JsonMixin:
         return self._to_json()
 
     def save_to_json(self, save_path, indent=4, ensure_ascii=False, *args, **kwargs):
-        if not save_path.endswith(".json"):
+        if not str(save_path).endswith(".json"):
             save_path = Path(save_path) / f"{Path(self['input_path']).stem}.json"
         _save_list_data(
             self._json_writer.write,
@@ -121,7 +121,7 @@ class ImgMixin:
         return image
 
     def save_to_img(self, save_path, *args, **kwargs):
-        if not save_path.lower().endswith((".jpg", ".png")):
+        if not str(save_path).lower().endswith((".jpg", ".png")):
             fp = Path(self["input_path"])
             save_path = Path(save_path) / f"{fp.stem}.{fp.suffix}"
         _save_list_data(self._img_writer.write, save_path, self.img, *args, **kwargs)
@@ -137,7 +137,7 @@ class CSVMixin:
         raise NotImplementedError
 
     def save_to_csv(self, save_path, *args, **kwargs):
-        if not save_path.endswith(".csv"):
+        if not str(save_path).endswith(".csv"):
             save_path = Path(save_path) / f"{Path(self['input_path']).stem}.csv"
         _save_list_data(
             self._csv_writer.write, save_path, self._to_csv(), *args, **kwargs
@@ -157,7 +157,7 @@ class HtmlMixin:
         return self["html"]
 
     def save_to_html(self, save_path, *args, **kwargs):
-        if not save_path.endswith(".html"):
+        if not str(save_path).endswith(".html"):
             save_path = Path(save_path) / f"{Path(self['input_path']).stem}.html"
         _save_list_data(self._html_writer.write, save_path, self.html, *args, **kwargs)
 
@@ -171,6 +171,6 @@ class XlsxMixin:
         return self["html"]
 
     def save_to_xlsx(self, save_path, *args, **kwargs):
-        if not save_path.endswith(".xlsx"):
+        if not str(save_path).endswith(".xlsx"):
             save_path = Path(save_path) / f"{Path(self['input_path']).stem}.xlsx"
         _save_list_data(self._xlsx_writer.write, save_path, self.html, *args, **kwargs)

+ 7 - 13
paddlex/inference/results/warp.py

@@ -17,21 +17,15 @@ import copy
 import json
 
 from ...utils import logging
-from .base import BaseResult
+from .base import CVResult
 
 
-class DocTrResult(BaseResult):
-    def __init__(self, data):
-        super().__init__(data)
-        self._img_writer.set_backend("opencv")
+class DocTrResult(CVResult):
 
-    def _get_res_img(self):
-        doctr_img = np.array(self["doctr_img"])
-        return doctr_img
+    def _to_img(self):
+        return np.array(self["doctr_img"])
 
-    def print(self, json_format=True, indent=4, ensure_ascii=False):
+    def _to_str(self, json_format=True, indent=4, ensure_ascii=False):
         str_ = copy.deepcopy(self)
-        del str_["doctr_img"]
-        if json_format:
-            str_ = json.dumps(str_, indent=indent, ensure_ascii=ensure_ascii)
-        logging.info(str_)
+        str_.pop("doctr_img")
+        return str_