
add the new architecture of pipelines

dyning, 11 months ago
parent
commit
639ad156b0
37 changed files with 598 additions and 418 deletions
  1. + 16 - 3    api_examples/pipelines/test_doc_preprocessor.py
  2. + 18 - 3    api_examples/pipelines/test_layout_parsing.py
  3. + 13 - 2    api_examples/pipelines/test_ocr.py
  4. + 15 - 2    api_examples/pipelines/test_pp_chatocrv3.py
  5. + 16 - 5    api_examples/pipelines/test_table_recognition.py
  6. + 1 - 1     paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml
  7. + 1 - 1     paddlex/configs/pipelines/layout_parsing.yaml
  8. + 1 - 0     paddlex/inference/__init__.py
  9. + 8 - 3     paddlex/inference/pipelines_new/__init__.py
 10. + 24 - 21   paddlex/inference/pipelines_new/base.py
 11. + 4 - 2     paddlex/inference/pipelines_new/components/base.py
 12. + 3 - 1     paddlex/inference/pipelines_new/components/chat_server/base.py
 13. + 19 - 23   paddlex/inference/pipelines_new/components/chat_server/ernie_bot_chat.py
 14. + 0 - 1     paddlex/inference/pipelines_new/components/common/__init__.py
 15. + 2 - 1     paddlex/inference/pipelines_new/components/common/crop_image_regions.py
 16. + 1 - 0     paddlex/inference/pipelines_new/components/common/seal_det_warp.py
 17. + 2 - 0     paddlex/inference/pipelines_new/components/common/sort_boxes.py
 18. + 1 - 1     paddlex/inference/pipelines_new/components/prompt_engeering/__init__.py
 19. + 3 - 1     paddlex/inference/pipelines_new/components/prompt_engeering/base.py
 20. + 27 - 22   paddlex/inference/pipelines_new/components/prompt_engeering/generate_kie_prompt.py
 21. + 6 - 3     paddlex/inference/pipelines_new/components/retriever/base.py
 22. + 20 - 17   paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py
 23. + 1 - 1     paddlex/inference/pipelines_new/doc_preprocessor/__init__.py
 24. + 51 - 37   paddlex/inference/pipelines_new/doc_preprocessor/pipeline.py
 25. + 2 - 1     paddlex/inference/pipelines_new/doc_preprocessor/result.py
 26. + 1 - 1     paddlex/inference/pipelines_new/layout_parsing/__init__.py
 27. + 108 - 84  paddlex/inference/pipelines_new/layout_parsing/pipeline.py
 28. + 31 - 25   paddlex/inference/pipelines_new/layout_parsing/result.py
 29. + 34 - 28   paddlex/inference/pipelines_new/layout_parsing/table_recognition_post_processing.py
 30. + 19 - 17   paddlex/inference/pipelines_new/layout_parsing/utils.py
 31. + 1 - 1     paddlex/inference/pipelines_new/ocr/__init__.py
 32. + 28 - 19   paddlex/inference/pipelines_new/ocr/pipeline.py
 33. + 3 - 2     paddlex/inference/pipelines_new/ocr/result.py
 34. + 1 - 1     paddlex/inference/pipelines_new/pp_chatocrv3_doc/__init__.py
 35. + 110 - 86  paddlex/inference/pipelines_new/pp_chatocrv3_doc/pipeline.py
 36. + 4 - 2     paddlex/inference/pipelines_new/pp_chatocrv3_doc/result.py
 37. + 3 - 0     paddlex/utils/fonts/__init__.py

+ 16 - 3
api_examples/pipelines/test_doc_preprocessor.py

@@ -1,3 +1,16 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 from paddlex import create_pipeline
 
@@ -6,9 +19,9 @@ pipeline = create_pipeline(pipeline="doc_preprocessor")
 test_img_path = "./test_imgs/img_rot180_demo.jpg"
 # test_img_path = "./test_imgs/doc_distort_test.jpg"
 
-output = pipeline.predict(test_img_path, 
-    use_doc_orientation_classify=True,
-    use_doc_unwarping=True)
+output = pipeline.predict(
+    test_img_path, use_doc_orientation_classify=True, use_doc_unwarping=True
+)
 
 for res in output:
     print(res)

+ 18 - 3
api_examples/pipelines/test_layout_parsing.py

@@ -1,14 +1,29 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 from paddlex import create_pipeline
 
 pipeline = create_pipeline(pipeline="layout_parsing")
 
-output = pipeline.predict("./test_imgs/test_layout_parsing.jpg",
+output = pipeline.predict(
+    "./test_imgs/test_layout_parsing.jpg",
     use_doc_orientation_classify=True,
     use_doc_unwarping=True,
     use_common_ocr=True,
     use_seal_recognition=True,
-    use_table_recognition=True)
+    use_table_recognition=True,
+)
 
 # output = pipeline("./test_imgs/demo_paper.png")
 # output = pipeline("./test_imgs/table_recognition.jpg")
@@ -16,4 +31,4 @@ output = pipeline.predict("./test_imgs/test_layout_parsing.jpg",
 # output = pipeline.predict("./test_imgs/img_rot180_demo.jpg")
 for res in output:
     # print(res)
-    res.save_results("./output")
+    res.save_results("./output")

+ 13 - 2
api_examples/pipelines/test_ocr.py

@@ -1,3 +1,16 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 from paddlex import create_pipeline
 
@@ -9,5 +22,3 @@ output = pipeline.predict("./test_imgs/seal_text_det.png")
 for res in output:
     print(res)
     res.save_to_img("./output")
-
-

+ 15 - 2
api_examples/pipelines/test_pp_chatocrv3.py

@@ -1,3 +1,16 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 from paddlex import create_pipeline
 
@@ -10,9 +23,9 @@ pipeline = create_pipeline(pipeline="PP-ChatOCRv3-doc")
 # key_list = ['3.2的标题']
 
 img_path = "./test_demo_imgs/seal_text_det.png"
-key_list = ['印章上公司']
+key_list = ["印章上公司"]
 
-# visual_predict_res = pipeline.visual_predict(img_path, 
+# visual_predict_res = pipeline.visual_predict(img_path,
 #     use_doc_orientation_classify=True,
 #     use_doc_unwarping=True,
 #     use_common_ocr=True,

+ 16 - 5
api_examples/pipelines/test_table_recognition.py

@@ -1,3 +1,16 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 from paddlex import create_pipeline
 
@@ -6,8 +19,6 @@ pipeline = create_pipeline(pipeline="table_recognition")
 output = pipeline("./test_imgs/table_recognition.jpg")
 for res in output:
     print(res)
-    res.save_to_img("./output/") ## save the result as an image
-    res.save_to_xlsx("./output/") ## save the result as an xlsx table
-    res.save_to_html("./output/") ## save the result as html
-
-
+    res.save_to_img("./output/")  ## save the result as an image
+    res.save_to_xlsx("./output/")  ## save the result as an xlsx table
+    res.save_to_html("./output/")  ## save the result as html

+ 1 - 1
paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml

@@ -106,4 +106,4 @@ SubPipelines:
           TextRecognition:
             model_name: PP-OCRv4_server_rec
             model_dir: null
-            batch_size: 1  
+            batch_size: 1  

+ 1 - 1
paddlex/configs/pipelines/layout_parsing.yaml

@@ -53,4 +53,4 @@ SubPipelines:
       TextRecognition:
         model_name: PP-OCRv4_server_rec
         model_dir: null
-        batch_size: 1
+        batch_size: 1

+ 1 - 0
paddlex/inference/__init__.py

@@ -14,6 +14,7 @@
 
 from .models import create_predictor
 from ..utils import flags
+
 if flags.USE_NEW_INFERENCE:
     from .pipelines_new import create_pipeline
 else:

+ 8 - 3
paddlex/inference/pipelines_new/__init__.py

@@ -45,14 +45,18 @@ from .doc_preprocessor import DocPreprocessorPipeline
 from .layout_parsing import LayoutParsingPipeline
 from .pp_chatocrv3_doc import PP_ChatOCRv3_doc_Pipeline
 
+
 def get_pipeline_path(pipeline_name):
     pipeline_path = (
-        Path(__file__).parent.parent.parent / "configs/pipelines" / f"{pipeline_name}.yaml"
+        Path(__file__).parent.parent.parent
+        / "configs/pipelines"
+        / f"{pipeline_name}.yaml"
     ).resolve()
     if not Path(pipeline_path).exists():
         return None
     return pipeline_path
 
+
 def load_pipeline_config(pipeline_name: str) -> Dict[str, Any]:
     if not Path(pipeline_name).exists():
         pipeline_path = get_pipeline_path(pipeline_name)
@@ -65,6 +69,7 @@ def load_pipeline_config(pipeline_name: str) -> Dict[str, Any]:
     config = parse_config(pipeline_path)
     return config
 
+
 def create_pipeline(
     pipeline: str,
     device=None,
@@ -90,6 +95,6 @@ def create_pipeline(
         device=device,
         pp_option=pp_option,
         use_hpip=use_hpip,
-        hpi_params=hpi_params)
+        hpi_params=hpi_params,
+    )
     return pipeline
-    
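
For orientation, a hedged sketch of this factory in use ("ocr" is an assumed pipeline name; any name that get_pipeline_path can resolve to configs/pipelines/<name>.yaml works, and an existing YAML path skips the lookup entirely):

from paddlex import create_pipeline

# Resolve the config by name, build the registered pipeline class, and run it.
pipeline = create_pipeline(pipeline="ocr")
for res in pipeline.predict("./test_imgs/seal_text_det.png"):
    res.save_to_img("./output")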

+ 24 - 21
paddlex/inference/pipelines_new/base.py

@@ -23,16 +23,19 @@ from .components.chat_server.base import BaseChat
 from .components.retriever.base import BaseRetriever
 from .components.prompt_engeering.base import BaseGeneratePrompt
 
+
 class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
     """Base Pipeline"""
 
     __is_base = True
 
-    def __init__(self,
-        device=None, 
-        pp_option=None, 
-        use_hpip: bool = False, 
-        hpi_params: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(
+        self,
+        device=None,
+        pp_option=None,
+        use_hpip: bool = False,
+        hpi_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
         super().__init__()
         self.device = device
         self.pp_option = pp_option
@@ -41,22 +44,21 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
 
     @abstractmethod
     def predict(self, input, **kwargs):
-        raise NotImplementedError(
-            "The method `predict` has not been implemented yet."
-        )
-    
+        raise NotImplementedError("The method `predict` has not been implemented yet.")
+
     def create_model(self, config):
 
-        model_dir = config['model_dir']
+        model_dir = config["model_dir"]
         if model_dir == None:
-            model_dir = config['model_name']
+            model_dir = config["model_name"]
 
         model = create_predictor(
             model_dir,
             device=self.device,
             pp_option=self.pp_option,
             use_hpip=self.use_hpip,
-            hpi_params=self.hpi_params)
+            hpi_params=self.hpi_params,
+        )
 
         ########### [TODO] Support passing parameters at initialization
         if "batch_size" in config:
@@ -66,29 +68,30 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
         return model
 
     def create_pipeline(self, config):
-        pipeline_name = config['pipeline_name']
+        pipeline_name = config["pipeline_name"]
         pipeline = BasePipeline.get(pipeline_name)(
             config=config,
             device=self.device,
             pp_option=self.pp_option,
             use_hpip=self.use_hpip,
-            hpi_params=self.hpi_params)
-        return pipeline 
+            hpi_params=self.hpi_params,
+        )
+        return pipeline
 
     def create_chat_bot(self, config):
-        model_name = config['model_name']
+        model_name = config["model_name"]
         chat_bot = BaseChat.get(model_name)(config)
-        return chat_bot     
+        return chat_bot
 
     def create_retriever(self, config):
-        model_name = config['model_name']
+        model_name = config["model_name"]
         retriever = BaseRetriever.get(model_name)(config)
-        return retriever    
+        return retriever
 
     def create_prompt_engeering(self, config):
-        task_type = config['task_type']
+        task_type = config["task_type"]
         pe = BaseGeneratePrompt.get(task_type)(config)
-        return pe       
+        return pe
 
     def __call__(self, input, **kwargs):
         return self.predict(input, **kwargs)
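
To illustrate the intended extension pattern, a hypothetical subclass sketch (ToyPipeline and the "ToyModel" config key are invented for illustration): subclasses auto-register via AutoRegisterABCMetaClass under their `entities` name and build sub-models from the parsed YAML config with the helpers above.

from typing import Any, Dict, Optional

class ToyPipeline(BasePipeline):
    """Hypothetical pipeline used only to show the registration contract."""

    entities = "toy_pipeline"

    def __init__(self, config, device=None, pp_option=None,
                 use_hpip: bool = False,
                 hpi_params: Optional[Dict[str, Any]] = None):
        super().__init__(device=device, pp_option=pp_option,
                         use_hpip=use_hpip, hpi_params=hpi_params)
        # create_model reads model_name/model_dir (plus an optional batch_size)
        self.toy_model = self.create_model(config["SubModules"]["ToyModel"])

    def predict(self, input, **kwargs):
        yield from self.toy_model(input)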

+ 4 - 2
paddlex/inference/pipelines_new/components/base.py

@@ -21,6 +21,7 @@ from ....utils.func_register import FuncRegister
 from ...utils.io import ImageReader, ImageWriter
 from .utils.mixin import JsonMixin, ImgMixin, StrMixin
 
+
 class BaseComponent(ABC, metaclass=AutoRegisterABCMetaClass):
     """Base Component"""
 
@@ -32,7 +33,8 @@ class BaseComponent(ABC, metaclass=AutoRegisterABCMetaClass):
     @abstractmethod
     def __call__(self):
         raise NotImplementedError(
-            "The component method `__call__` has not been implemented yet.")
+            "The component method `__call__` has not been implemented yet."
+        )
 
 
 class BaseResult(dict, StrMixin, JsonMixin):
@@ -50,10 +52,10 @@ class BaseResult(dict, StrMixin, JsonMixin):
             else:
                 func()
 
+
 class CVResult(BaseResult, ImgMixin):
     def __init__(self, data):
         super().__init__(data)
         ImgMixin.__init__(self, "pillow")
         self._img_reader = ImageReader(backend="pillow")
         self._img_writer = ImageWriter(backend="pillow")
-

+ 3 - 1
paddlex/inference/pipelines_new/components/chat_server/base.py

@@ -17,6 +17,7 @@ from .....utils.subclass_register import AutoRegisterABCMetaClass
 
 import inspect
 
+
 class BaseChat(ABC, metaclass=AutoRegisterABCMetaClass):
     """Base Chat"""
 
@@ -28,4 +29,5 @@ class BaseChat(ABC, metaclass=AutoRegisterABCMetaClass):
     @abstractmethod
     def generate_chat_results(self):
         raise NotImplementedError(
-            "The method `generate_chat_results` has not been implemented yet.")
+            "The method `generate_chat_results` has not been implemented yet."
+        )

+ 19 - 23
paddlex/inference/pipelines_new/components/chat_server/ernie_bot_chat.py

@@ -16,6 +16,7 @@ from .....utils import logging
 from .base import BaseChat
 import erniebot
 
+
 class ErnieBotChat(BaseChat):
     """Ernie Bot Chat"""
 
@@ -32,11 +33,11 @@ class ErnieBotChat(BaseChat):
 
     def __init__(self, config):
         super().__init__()
-        model_name = config.get('model_name', None)
-        api_type = config.get('api_type', None)
-        ak = config.get('ak', None)
-        sk = config.get('sk', None)
-        access_token = config.get('access_token', None)
+        model_name = config.get("model_name", None)
+        api_type = config.get("api_type", None)
+        ak = config.get("ak", None)
+        sk = config.get("sk", None)
+        access_token = config.get("access_token", None)
 
         if model_name not in self.entities:
             raise ValueError(f"model_name must be in {self.entities} of ErnieBotChat.")
@@ -46,13 +47,13 @@ class ErnieBotChat(BaseChat):
 
         if api_type == "aistudio" and access_token is None:
             raise ValueError("access_token cannot be empty when api_type is aistudio.")
-            
+
         if api_type == "qianfan" and (ak is None or sk is None):
-            raise ValueError("ak and sk cannot be empty when api_type is qianfan.")            
+            raise ValueError("ak and sk cannot be empty when api_type is qianfan.")
 
         self.model_name = model_name
         self.config = config
-        
+
     def generate_chat_results(self, prompt, temperature=0.001, max_retries=1):
         """
         args:
@@ -60,14 +61,14 @@ class ErnieBotChat(BaseChat):
         """
         try:
             cur_config = {
-                "api_type": self.config['api_type'],
-                "max_retries": max_retries
+                "api_type": self.config["api_type"],
+                "max_retries": max_retries,
             }
-            if self.config['api_type'] == "aistudio":
-                cur_config['access_token'] = self.config['access_token']
-            elif self.config['api_type'] == "qianfan":
-                cur_config['ak'] = self.config['ak']
-                cur_config['sk'] = self.config['sk']
+            if self.config["api_type"] == "aistudio":
+                cur_config["access_token"] = self.config["access_token"]
+            elif self.config["api_type"] == "qianfan":
+                cur_config["ak"] = self.config["ak"]
+                cur_config["sk"] = self.config["sk"]
             chat_completion = erniebot.ChatCompletion.create(
                 _config_=cur_config,
                 model=self.model_name,
@@ -78,18 +79,13 @@ class ErnieBotChat(BaseChat):
             return llm_result
         except Exception as e:
             if len(e.args) < 1:
-                self.ERROR_MASSAGE = (
-                    "暂无权限访问ErnieBot服务,请检查访问令牌。"
-                )
+                self.ERROR_MASSAGE = "暂无权限访问ErnieBot服务,请检查访问令牌。"
             elif (
                 e.args[-1]
                 == "暂无权限使用,请在 AI Studio 正确获取访问令牌(access token)使用"
             ):
-                self.ERROR_MASSAGE = (
-                    "暂无权限访问ErnieBot服务,请检查访问令牌。"
-                )
+                self.ERROR_MASSAGE = "暂无权限访问ErnieBot服务,请检查访问令牌。"
             else:
                 logging.error(e)
                 self.ERROR_MASSAGE = "大模型调用失败"
-        return None 
-        
+        return None
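
A hedged usage sketch (the model name is an assumption and must appear in self.entities, which this diff elides; per the checks above, api_type "aistudio" needs access_token while "qianfan" needs ak/sk):

from paddlex.inference.pipelines_new.components.chat_server.ernie_bot_chat import (
    ErnieBotChat,
)

chat_bot = ErnieBotChat(
    {
        "model_name": "ernie-3.5",  # assumed member of self.entities
        "api_type": "aistudio",
        "access_token": "<your-access-token>",
    }
)
result = chat_bot.generate_chat_results("你好")  # returns None on failure, as above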

+ 0 - 1
paddlex/inference/pipelines_new/components/common/__init__.py

@@ -14,4 +14,3 @@
 
 from .sort_boxes import SortQuadBoxes
 from .crop_image_regions import CropByPolys, CropByBoxes
-

+ 2 - 1
paddlex/inference/pipelines_new/components/common/crop_image_regions.py

@@ -21,6 +21,7 @@ from .seal_det_warp import AutoRectifier
 from shapely.geometry import Polygon
 from numpy.linalg import norm
 
+
 class CropByPolys(BaseComponent):
     """Crop Image by Polys"""
 
@@ -472,4 +473,4 @@ class CropByBoxes(BaseComponent):
             xmin, ymin, xmax, ymax = [int(i) for i in box]
             img_crop = img[ymin:ymax, xmin:xmax].copy()
             output_list.append({"img": img_crop, "box": box, "label": label})
-        return output_list
+        return output_list

+ 1 - 0
paddlex/inference/pipelines_new/components/common/seal_det_warp.py

@@ -21,6 +21,7 @@ import time
 
 from .....utils import logging
 
+
 def Homography(
     image,
     img_points,

+ 2 - 0
paddlex/inference/pipelines_new/components/common/sort_boxes.py

@@ -15,10 +15,12 @@
 from ..base import BaseComponent
 import numpy as np
 
+
 class SortQuadBoxes(BaseComponent):
     """SortQuadBoxes Component"""
 
     entities = "SortQuadBoxes"
+
     def __init__(self):
         super().__init__()
 

+ 1 - 1
paddlex/inference/pipelines_new/components/prompt_engeering/__init__.py

@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .generate_kie_prompt import GenerateKIEPrompt
+from .generate_kie_prompt import GenerateKIEPrompt

+ 3 - 1
paddlex/inference/pipelines_new/components/prompt_engeering/base.py

@@ -17,6 +17,7 @@ from .....utils.subclass_register import AutoRegisterABCMetaClass
 
 import inspect
 
+
 class BaseGeneratePrompt(ABC, metaclass=AutoRegisterABCMetaClass):
     """Base Chat"""
 
@@ -28,4 +29,5 @@ class BaseGeneratePrompt(ABC, metaclass=AutoRegisterABCMetaClass):
     @abstractmethod
     def generate_prompt(self):
         raise NotImplementedError(
-            "The method `generate_prompt` has not been implemented yet.")
+            "The method `generate_prompt` has not been implemented yet."
+        )

+ 27 - 22
paddlex/inference/pipelines_new/components/prompt_engeering/generate_kie_prompt.py

@@ -14,41 +14,41 @@
 
 from .base import BaseGeneratePrompt
 
+
 class GenerateKIEPrompt(BaseGeneratePrompt):
     """Generate KIE Prompt"""
 
-    entities = [
-        "text_kie_prompt",
-        "table_kie_prompt"
-    ]
+    entities = ["text_kie_prompt", "table_kie_prompt"]
 
     def __init__(self, config):
         super().__init__()
 
-        task_type = config.get('task_type', "")
-        task_description = config.get('task_description', "")  
-        output_format = config.get('output_format', "")  
-        rules_str = config.get('rules_str', "")  
-        few_shot_demo_text_content = config.get('few_shot_demo_text_content', "")  
-        few_shot_demo_key_value_list = config.get('few_shot_demo_key_value_list', "")
+        task_type = config.get("task_type", "")
+        task_description = config.get("task_description", "")
+        output_format = config.get("output_format", "")
+        rules_str = config.get("rules_str", "")
+        few_shot_demo_text_content = config.get("few_shot_demo_text_content", "")
+        few_shot_demo_key_value_list = config.get("few_shot_demo_key_value_list", "")
 
         if task_description is None:
             task_description = ""
-        
+
         if output_format is None:
             output_format = ""
-        
+
         if rules_str is None:
             rules_str = ""
-        
+
         if few_shot_demo_text_content is None:
             few_shot_demo_text_content = ""
-        
+
         if few_shot_demo_key_value_list is None:
             few_shot_demo_key_value_list = ""
 
         if task_type not in self.entities:
-            raise ValueError(f"task type must be in {self.entities} of GenerateKIEPrompt.")
+            raise ValueError(
+                f"task type must be in {self.entities} of GenerateKIEPrompt."
+            )
 
         self.task_type = task_type
         self.task_description = task_description
@@ -56,14 +56,17 @@ class GenerateKIEPrompt(BaseGeneratePrompt):
         self.rules_str = rules_str
         self.few_shot_demo_text_content = few_shot_demo_text_content
         self.few_shot_demo_key_value_list = few_shot_demo_key_value_list
-        
-    def generate_prompt(self, text_content,
+
+    def generate_prompt(
+        self,
+        text_content,
         key_list,
         task_description=None,
         output_format=None,
         rules_str=None,
         few_shot_demo_text_content=None,
-        few_shot_demo_key_value_list=None):
+        few_shot_demo_key_value_list=None,
+    ):
         """
         args:
         return:
@@ -80,7 +83,7 @@ class GenerateKIEPrompt(BaseGeneratePrompt):
 
         if few_shot_demo_text_content is None:
             few_shot_demo_text_content = self.few_shot_demo_text_content
-            
+
         if few_shot_demo_key_value_list is None:
             few_shot_demo_key_value_list = self.few_shot_demo_key_value_list
 
@@ -89,12 +92,14 @@ class GenerateKIEPrompt(BaseGeneratePrompt):
             prompt += f"""\n结合上面,下面正式开始:\
                 表格内容:```{text_content}```\
                 关键词列表:{key_list}。""".replace(
-                "    ", "")
+                "    ", ""
+            )
         elif self.task_type == "text_kie_prompt":
             prompt += f"""\n结合上面的例子,下面正式开始:\
                 OCR文字:```{text_content}```\
                 关键词列表:{key_list}。""".replace(
-                "    ", "")
+                "    ", ""
+            )
         else:
-            raise ValueError(f"{self.task_type} is currently not supported.") 
+            raise ValueError(f"{self.task_type} is currently not supported.")
         return prompt
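
A small usage sketch (only task_type is required, since every other config field falls back to an empty string in __init__; the sample text and key are invented):

from paddlex.inference.pipelines_new.components.prompt_engeering import (
    GenerateKIEPrompt,
)

pe = GenerateKIEPrompt({"task_type": "text_kie_prompt"})
prompt = pe.generate_prompt(
    text_content="发票号码:No12345678",
    key_list=["发票号码"],
)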

+ 6 - 3
paddlex/inference/pipelines_new/components/retriever/base.py

@@ -18,6 +18,7 @@ from .....utils.subclass_register import AutoRegisterABCMetaClass
 import inspect
 import base64
 
+
 class BaseRetriever(ABC, metaclass=AutoRegisterABCMetaClass):
     """Base Retriever"""
 
@@ -31,12 +32,14 @@ class BaseRetriever(ABC, metaclass=AutoRegisterABCMetaClass):
     @abstractmethod
     def generate_vector_database(self):
         raise NotImplementedError(
-            "The method `generate_vector_database` has not been implemented yet.")
+            "The method `generate_vector_database` has not been implemented yet."
+        )
 
     @abstractmethod
     def similarity_retrieval(self):
         raise NotImplementedError(
-            "The method `similarity_retrieval` has not been implemented yet.")
+            "The method `similarity_retrieval` has not been implemented yet."
+        )
 
     def is_vector_store(self, s):
         return s.startswith(self.VECTOR_STORE_PREFIX)
@@ -47,4 +50,4 @@ class BaseRetriever(ABC, metaclass=AutoRegisterABCMetaClass):
         )
 
     def decode_vector_store(self, vector_store_str):
-        return base64.b64decode(vector_store_str[len(self.VECTOR_STORE_PREFIX):])
+        return base64.b64decode(vector_store_str[len(self.VECTOR_STORE_PREFIX) :])
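
The prefix convention can be pictured as follows (a sketch under assumptions: VECTOR_STORE_PREFIX, defined outside this hunk, is taken to be a plain string class attribute, and the elided encode_vector_store is taken to be the inverse of decode_vector_store):

import base64

from paddlex.inference.pipelines_new.components.retriever.base import BaseRetriever

# Build a tagged string by hand: marker + base64 payload.
payload = base64.b64encode(b"serialized-faiss-bytes").decode("ascii")
tagged = BaseRetriever.VECTOR_STORE_PREFIX + payload
# is_vector_store(tagged) -> True
# decode_vector_store(tagged) -> b"serialized-faiss-bytes"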

+ 20 - 17
paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py

@@ -25,6 +25,7 @@ from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings
 
 import time
 
+
 class ErnieBotRetriever(BaseRetriever):
     """Ernie Bot Retriever"""
 
@@ -38,17 +39,17 @@ class ErnieBotRetriever(BaseRetriever):
         "ernie-speed-128k",
         "ernie-char-8k",
     ]
-    
+
     def __init__(self, config):
 
         super().__init__()
 
-        model_name = config.get('model_name', None)
-        api_type = config.get('api_type', None)
-        ak = config.get('ak', None)
-        sk = config.get('sk', None)
-        access_token = config.get('access_token', None)
-        
+        model_name = config.get("model_name", None)
+        api_type = config.get("api_type", None)
+        ak = config.get("ak", None)
+        sk = config.get("sk", None)
+        access_token = config.get("access_token", None)
+
         if model_name not in self.entities:
             raise ValueError(f"model_name must be in {self.entities} of ErnieBotChat.")
 
@@ -57,17 +58,20 @@ class ErnieBotRetriever(BaseRetriever):
 
         if api_type == "aistudio" and access_token is None:
             raise ValueError("access_token cannot be empty when api_type is aistudio.")
-            
+
         if api_type == "qianfan" and (ak is None or sk is None):
-            raise ValueError("ak and sk cannot be empty when api_type is qianfan.")            
+            raise ValueError("ak and sk cannot be empty when api_type is qianfan.")
 
         self.model_name = model_name
         self.config = config
-        
-    def generate_vector_database(self, text_list, 
+
+    def generate_vector_database(
+        self,
+        text_list,
         block_size=300,
         separators=["\t", "\n", "。", "\n\n", ""],
-        sleep_time=0.5):
+        sleep_time=0.5,
+    ):
         """
         args:
         return:
@@ -112,7 +116,7 @@ class ErnieBotRetriever(BaseRetriever):
     def encode_vector_store_to_bytes(self, vectorstore):
         vectorstore = self.encode_vector_store(vectorstore.serialize_to_bytes())
         return vectorstore
-    
+
     def decode_vector_store_from_bytes(self, vectorstore):
         if not self.is_vector_store(vectorstore):
             raise ValueError("The retrieved vectorstore is not for PaddleX.")
@@ -127,10 +131,10 @@ class ErnieBotRetriever(BaseRetriever):
             embeddings = QianfanEmbeddingsEndpoint(qianfan_ak=ak, qianfan_sk=sk)
         else:
             raise ValueError(f"Unsupported api_type: {api_type}")
-        vectorstore = vectorstores.FAISS.deserialize_from_bytes(
-            self.decode_vector_store(vector), embeddings
+        vector = vectorstores.FAISS.deserialize_from_bytes(
+            self.decode_vector_store(vectorstore), embeddings
         )
-        return vectorstore
+        return vector
 
     def similarity_retrieval(self, query_text_list, vectorstore, sleep_time=0.5):
         # match context to the query
@@ -145,4 +149,3 @@ class ErnieBotRetriever(BaseRetriever):
         C = list(set(C))
         all_C = " ".join(C)
         return all_C
-        
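
An end-to-end usage sketch (model name taken from the entities list above; the qianfan credentials are placeholders):

from paddlex.inference.pipelines_new.components.retriever.ernie_bot_retriever import (
    ErnieBotRetriever,
)

retriever = ErnieBotRetriever(
    {
        "model_name": "ernie-speed-128k",
        "api_type": "qianfan",
        "ak": "<your-ak>",
        "sk": "<your-sk>",
    }
)
vectorstore = retriever.generate_vector_database(["第一段文本", "第二段文本"])
context = retriever.similarity_retrieval(["要检索的问题"], vectorstore)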

+ 1 - 1
paddlex/inference/pipelines_new/doc_preprocessor/__init__.py

@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .pipeline import DocPreprocessorPipeline
+from .pipeline import DocPreprocessorPipeline

+ 51 - 37
paddlex/inference/pipelines_new/doc_preprocessor/pipeline.py

@@ -20,35 +20,40 @@ from .result import DocPreprocessorResult
 ########## [TODO] This import path needs to be updated later
 from ...components.transforms import ReadImage
 
+
 class DocPreprocessorPipeline(BasePipeline):
     """Doc Preprocessor Pipeline"""
 
     entities = "doc_preprocessor"
-    def __init__(self,
-        config,        
+
+    def __init__(
+        self,
+        config,
         device=None,
-        pp_option=None, 
+        pp_option=None,
         use_hpip: bool = False,
-        hpi_params: Optional[Dict[str, Any]] = None):
-        super().__init__(device=device, pp_option=pp_option, 
-            use_hpip=use_hpip, hpi_params=hpi_params)
-        
+        hpi_params: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(
+            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
+        )
+
         self.use_doc_orientation_classify = True
-        if 'use_doc_orientation_classify' in config:
-            self.use_doc_orientation_classify = config['use_doc_orientation_classify']
+        if "use_doc_orientation_classify" in config:
+            self.use_doc_orientation_classify = config["use_doc_orientation_classify"]
 
         self.use_doc_unwarping = True
-        if 'use_doc_unwarping' in config:
-            self.use_doc_unwarping = config['use_doc_unwarping']
-        
+        if "use_doc_unwarping" in config:
+            self.use_doc_unwarping = config["use_doc_unwarping"]
+
         if self.use_doc_orientation_classify:
-            doc_ori_classify_config = config['SubModules']["DocOrientationClassify"]
+            doc_ori_classify_config = config["SubModules"]["DocOrientationClassify"]
             self.doc_ori_classify_model = self.create_model(doc_ori_classify_config)
 
         if self.use_doc_unwarping:
-            doc_unwarping_config = config['SubModules']["DocUnwarping"]
+            doc_unwarping_config = config["SubModules"]["DocUnwarping"]
             self.doc_unwarping_model = self.create_model(doc_unwarping_config)
-        
+
         self.img_reader = ReadImage(format="BGR")
 
     def rotate_image(self, image_array, rotate_angle):
@@ -59,42 +64,49 @@ class DocPreprocessorPipeline(BasePipeline):
         return rotate(image_array, rotate_angle, reshape=True)
 
     def check_input_params(self, input_params):
-        
-        if input_params['use_doc_orientation_classify'] and \
-            not self.use_doc_orientation_classify:
-            raise ValueError("The model for doc orientation classify is not initialized.")
 
+        if (
+            input_params["use_doc_orientation_classify"]
+            and not self.use_doc_orientation_classify
+        ):
+            raise ValueError(
+                "The model for doc orientation classify is not initialized."
+            )
 
-        if input_params['use_doc_unwarping'] and \
-            not self.use_doc_unwarping:
+        if input_params["use_doc_unwarping"] and not self.use_doc_unwarping:
             raise ValueError("The model for doc unwarping is not initialized.")
-            
-        return 
 
-    def predict(self, input, 
+        return
+
+    def predict(
+        self,
+        input,
         use_doc_orientation_classify=True,
         use_doc_unwarping=False,
-        **kwargs):
+        **kwargs
+    ):
 
         if not isinstance(input, list):
             input_list = [input]
         else:
             input_list = input
 
-        input_params = {"use_doc_orientation_classify":use_doc_orientation_classify,
-            "use_doc_unwarping":use_doc_unwarping}
+        input_params = {
+            "use_doc_orientation_classify": use_doc_orientation_classify,
+            "use_doc_unwarping": use_doc_unwarping,
+        }
         self.check_input_params(input_params)
 
         img_id = 1
         for input in input_list:
             if isinstance(input, str):
-                image_array = next(self.img_reader(input))[0]['img']
+                image_array = next(self.img_reader(input))[0]["img"]
             else:
                 image_array = input
 
             assert len(image_array.shape) == 3
 
-            if input_params['use_doc_orientation_classify']:
+            if input_params["use_doc_orientation_classify"]:
                 pred = next(self.doc_ori_classify_model(image_array))
                 angle = int(pred["label_names"][0])
                 rot_img = self.rotate_image(image_array, angle)
@@ -102,16 +114,18 @@ class DocPreprocessorPipeline(BasePipeline):
                 angle = -1
                 rot_img = image_array
 
-            if input_params['use_doc_unwarping']:
-                output_img = next(self.doc_unwarping_model(rot_img))['doctr_img']
+            if input_params["use_doc_unwarping"]:
+                output_img = next(self.doc_unwarping_model(rot_img))["doctr_img"]
             else:
                 output_img = rot_img
 
-            single_img_res = {"input_image":image_array,
-                "input_params":input_params,
-                "angle":angle, 
-                "rot_img":rot_img, 
-                "output_img":output_img,
-                "img_id":img_id}
+            single_img_res = {
+                "input_image": image_array,
+                "input_params": input_params,
+                "angle": angle,
+                "rot_img": rot_img,
+                "output_img": output_img,
+                "img_id": img_id,
+            }
             img_id += 1
             yield DocPreprocessorResult(single_img_res)
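
A hedged consumption sketch (keys as assembled in single_img_res above; angle is -1 whenever orientation classification is skipped):

from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="doc_preprocessor")
output = pipeline.predict(
    "./test_imgs/img_rot180_demo.jpg",
    use_doc_orientation_classify=True,
    use_doc_unwarping=False,
)
for res in output:
    print(res["angle"])            # predicted rotation angle of the page
    corrected = res["output_img"]  # ndarray after rotation (and unwarping, if enabled)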

+ 2 - 1
paddlex/inference/pipelines_new/doc_preprocessor/result.py

@@ -22,6 +22,7 @@ from PIL import Image, ImageDraw, ImageFont
 from ....utils.fonts import PINGFANG_FONT_FILE_PATH, create_font
 from ..components import CVResult
 
+
 class DocPreprocessorResult(CVResult):
 
     def save_to_img(self, save_path, *args, **kwargs):
@@ -48,4 +49,4 @@ class DocPreprocessorResult(CVResult):
             txt = txt_list[tno]
             font = create_font(txt, (w, 20), PINGFANG_FONT_FILE_PATH)
             draw_text.text([10 + w * tno, h + 2], txt, fill=(0, 0, 0), font=font)
-        return img_show
+        return img_show

+ 1 - 1
paddlex/inference/pipelines_new/layout_parsing/__init__.py

@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .pipeline import LayoutParsingPipeline
+from .pipeline import LayoutParsingPipeline

+ 108 - 84
paddlex/inference/pipelines_new/layout_parsing/pipeline.py

@@ -25,181 +25,205 @@ from .result import LayoutParsingResult
 ########## [TODO] This import path needs to be updated later
 from ...components.transforms import ReadImage
 
+
 class LayoutParsingPipeline(BasePipeline):
     """Layout Parsing Pipeline"""
 
     entities = "layout_parsing"
-    def __init__(self,
-        config,        
+
+    def __init__(
+        self,
+        config,
         device=None,
-        pp_option=None, 
+        pp_option=None,
         use_hpip: bool = False,
-        hpi_params: Optional[Dict[str, Any]] = None):
-        super().__init__(device=device, pp_option=pp_option, 
-            use_hpip=use_hpip, hpi_params=hpi_params)
-        
+        hpi_params: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(
+            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
+        )
+
         self.inintial_predictor(config)
 
         self.img_reader = ReadImage(format="BGR")
 
         self._crop_by_boxes = CropByBoxes()
-        
 
     def inintial_predictor(self, config):
-        layout_det_config = config['SubModules']["LayoutDetection"]
+        layout_det_config = config["SubModules"]["LayoutDetection"]
         self.layout_det_model = self.create_model(layout_det_config)
 
         self.use_doc_preprocessor = False
-        if 'use_doc_preprocessor' in config:
-            self.use_doc_preprocessor = config['use_doc_preprocessor']
+        if "use_doc_preprocessor" in config:
+            self.use_doc_preprocessor = config["use_doc_preprocessor"]
         if self.use_doc_preprocessor:
-            doc_preprocessor_config = config['SubPipelines']['DocPreprocessor']
-            self.doc_preprocessor_pipeline = self.create_pipeline(doc_preprocessor_config)
-        
+            doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
+            self.doc_preprocessor_pipeline = self.create_pipeline(
+                doc_preprocessor_config
+            )
+
         self.use_common_ocr = False
         if "use_common_ocr" in config:
-            self.use_common_ocr = config['use_common_ocr']
+            self.use_common_ocr = config["use_common_ocr"]
         if self.use_common_ocr:
-            common_ocr_config = config['SubPipelines']['CommonOCR']
+            common_ocr_config = config["SubPipelines"]["CommonOCR"]
             self.common_ocr_pipeline = self.create_pipeline(common_ocr_config)
-        
+
         self.use_seal_recognition = False
         if "use_seal_recognition" in config:
-            self.use_seal_recognition = config['use_seal_recognition']
+            self.use_seal_recognition = config["use_seal_recognition"]
         if self.use_seal_recognition:
-            seal_ocr_config = config['SubPipelines']['SealOCR']
-            self.seal_ocr_pipeline = self.create_pipeline(seal_ocr_config)            
-        
+            seal_ocr_config = config["SubPipelines"]["SealOCR"]
+            self.seal_ocr_pipeline = self.create_pipeline(seal_ocr_config)
+
         self.use_table_recognition = False
         if "use_table_recognition" in config:
-            self.use_table_recognition = config['use_table_recognition']
+            self.use_table_recognition = config["use_table_recognition"]
         if self.use_table_recognition:
-            table_structure_config = config['SubModules']['TableStructurePredictor']
+            table_structure_config = config["SubModules"]["TableStructurePredictor"]
             self.table_structure_model = self.create_model(table_structure_config)
             if not self.use_common_ocr:
-                common_ocr_config = config['SubPipelines']['OCR']
+                common_ocr_config = config["SubPipelines"]["OCR"]
                 self.common_ocr_pipeline = self.create_pipeline(common_ocr_config)
-        return 
+        return
 
     def get_text_paragraphs_ocr_res(self, overall_ocr_res, layout_det_res):
-        '''get ocr res of the text paragraphs'''
+        """get ocr res of the text paragraphs"""
         object_boxes = []
-        for box_info in layout_det_res['boxes']:
-            if box_info['label'].lower() in ['image', 'formula', 'table', 'seal']:
-                object_boxes.append(box_info['coordinate'])
+        for box_info in layout_det_res["boxes"]:
+            if box_info["label"].lower() in ["image", "formula", "table", "seal"]:
+                object_boxes.append(box_info["coordinate"])
         object_boxes = np.array(object_boxes)
         return get_sub_regions_ocr_res(overall_ocr_res, object_boxes, flag_within=False)
 
     def check_input_params(self, input_params):
 
-        if input_params['use_doc_preprocessor'] and not self.use_doc_preprocessor:
+        if input_params["use_doc_preprocessor"] and not self.use_doc_preprocessor:
             raise ValueError("The models for doc preprocessor are not initialized.")
 
-        if input_params['use_common_ocr'] and not self.use_common_ocr:
+        if input_params["use_common_ocr"] and not self.use_common_ocr:
             raise ValueError("The models for common OCR are not initialized.")
 
-        if input_params['use_seal_recognition'] and not self.use_seal_recognition:
+        if input_params["use_seal_recognition"] and not self.use_seal_recognition:
             raise ValueError("The models for seal recognition are not initialized.")
 
-        if input_params['use_table_recognition'] and not self.use_table_recognition:
+        if input_params["use_table_recognition"] and not self.use_table_recognition:
             raise ValueError("The models for table recognition are not initialized.")
 
         return
 
-    def predict(self, input, 
+    def predict(
+        self,
+        input,
         use_doc_orientation_classify=True,
         use_doc_unwarping=True,
         use_common_ocr=True,
         use_seal_recognition=True,
         use_table_recognition=True,
-        **kwargs):
+        **kwargs
+    ):
 
         if not isinstance(input, list):
             input_list = [input]
         else:
             input_list = input
-        
-        input_params = {"use_doc_preprocessor":self.use_doc_preprocessor,
-            "use_doc_orientation_classify":use_doc_orientation_classify,
-            "use_doc_unwarping":use_doc_unwarping,
-            "use_common_ocr":use_common_ocr,
-            "use_seal_recognition":use_seal_recognition,
-            "use_table_recognition":use_table_recognition}
-            
+
+        input_params = {
+            "use_doc_preprocessor": self.use_doc_preprocessor,
+            "use_doc_orientation_classify": use_doc_orientation_classify,
+            "use_doc_unwarping": use_doc_unwarping,
+            "use_common_ocr": use_common_ocr,
+            "use_seal_recognition": use_seal_recognition,
+            "use_table_recognition": use_table_recognition,
+        }
+
         if use_doc_orientation_classify or use_doc_unwarping:
-            input_params['use_doc_preprocessor'] = True
+            input_params["use_doc_preprocessor"] = True
 
         self.check_input_params(input_params)
 
         img_id = 1
         for input in input_list:
             if isinstance(input, str):
-                image_array = next(self.img_reader(input))[0]['img']
+                image_array = next(self.img_reader(input))[0]["img"]
             else:
                 image_array = input
 
             assert len(image_array.shape) == 3
 
-            if input_params['use_doc_preprocessor']:
-                doc_preprocessor_res = next(self.doc_preprocessor_pipeline(
-                    image_array, 
-                    use_doc_orientation_classify=use_doc_orientation_classify,
-                    use_doc_unwarping=use_doc_unwarping))
-                doc_preprocessor_image = doc_preprocessor_res['output_img']
-                doc_preprocessor_res['img_id'] = img_id
+            if input_params["use_doc_preprocessor"]:
+                doc_preprocessor_res = next(
+                    self.doc_preprocessor_pipeline(
+                        image_array,
+                        use_doc_orientation_classify=use_doc_orientation_classify,
+                        use_doc_unwarping=use_doc_unwarping,
+                    )
+                )
+                doc_preprocessor_image = doc_preprocessor_res["output_img"]
+                doc_preprocessor_res["img_id"] = img_id
             else:
                 doc_preprocessor_res = {}
                 doc_preprocessor_image = image_array
-            
 
             ########## [TODO] RT-DETR detection results contain duplicates
             layout_det_res = next(self.layout_det_model(doc_preprocessor_image))
 
-            if input_params['use_common_ocr'] or input_params['use_table_recognition']:
+            if input_params["use_common_ocr"] or input_params["use_table_recognition"]:
                 overall_ocr_res = next(self.common_ocr_pipeline(doc_preprocessor_image))
-                overall_ocr_res['img_id'] = img_id
-                dt_boxes = convert_points_to_boxes(overall_ocr_res['dt_polys'])
-                overall_ocr_res['dt_boxes'] = dt_boxes
+                overall_ocr_res["img_id"] = img_id
+                dt_boxes = convert_points_to_boxes(overall_ocr_res["dt_polys"])
+                overall_ocr_res["dt_boxes"] = dt_boxes
             else:
                 overall_ocr_res = {}
-            
+
             text_paragraphs_ocr_res = {}
-            if input_params['use_common_ocr']:
+            if input_params["use_common_ocr"]:
                 text_paragraphs_ocr_res = self.get_text_paragraphs_ocr_res(
-                    overall_ocr_res, layout_det_res)
-                text_paragraphs_ocr_res['img_id'] = img_id
-            
+                    overall_ocr_res, layout_det_res
+                )
+                text_paragraphs_ocr_res["img_id"] = img_id
+
             table_res_list = []
-            if input_params['use_table_recognition']:
+            if input_params["use_table_recognition"]:
                 table_region_id = 1
-                for box_info in layout_det_res['boxes']:
-                    if box_info['label'].lower() in ['table']:
-                        crop_img_info = self._crop_by_boxes(doc_preprocessor_image, [box_info])
+                for box_info in layout_det_res["boxes"]:
+                    if box_info["label"].lower() in ["table"]:
+                        crop_img_info = self._crop_by_boxes(
+                            doc_preprocessor_image, [box_info]
+                        )
                         crop_img_info = crop_img_info[0]
-                        table_structure_pred = next(self.table_structure_model(
-                            crop_img_info['img']))
+                        table_structure_pred = next(
+                            self.table_structure_model(crop_img_info["img"])
+                        )
                         table_recognition_res = get_table_recognition_res(
-                            crop_img_info, table_structure_pred, overall_ocr_res)
-                        table_recognition_res['table_region_id'] = table_region_id
+                            crop_img_info, table_structure_pred, overall_ocr_res
+                        )
+                        table_recognition_res["table_region_id"] = table_region_id
                         table_region_id += 1
                         table_res_list.append(table_recognition_res)
-            
+
             seal_res_list = []
-            if input_params['use_seal_recognition']:
+            if input_params["use_seal_recognition"]:
                 seal_region_id = 1
-                for box_info in layout_det_res['boxes']:
-                    if box_info['label'].lower() in ['seal']:
-                        crop_img_info = self._crop_by_boxes(doc_preprocessor_image, [box_info])
+                for box_info in layout_det_res["boxes"]:
+                    if box_info["label"].lower() in ["seal"]:
+                        crop_img_info = self._crop_by_boxes(
+                            doc_preprocessor_image, [box_info]
+                        )
                         crop_img_info = crop_img_info[0]
-                        seal_ocr_res = next(self.seal_ocr_pipeline(crop_img_info['img']))
-                        seal_ocr_res['seal_region_id'] = seal_region_id
+                        seal_ocr_res = next(
+                            self.seal_ocr_pipeline(crop_img_info["img"])
+                        )
+                        seal_ocr_res["seal_region_id"] = seal_region_id
                         seal_region_id += 1
                         seal_res_list.append(seal_ocr_res)
-            
-            single_img_res = {"layout_det_res":layout_det_res,
-                "doc_preprocessor_res":doc_preprocessor_res,
-                "text_paragraphs_ocr_res":text_paragraphs_ocr_res,
-                "table_res_list":table_res_list,
-                "seal_res_list":seal_res_list,
-                "input_params":input_params}
+
+            single_img_res = {
+                "layout_det_res": layout_det_res,
+                "doc_preprocessor_res": doc_preprocessor_res,
+                "text_paragraphs_ocr_res": text_paragraphs_ocr_res,
+                "table_res_list": table_res_list,
+                "seal_res_list": seal_res_list,
+                "input_params": input_params,
+            }
             yield LayoutParsingResult(single_img_res)
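
A sketch of consuming the composite result (keys as assembled in single_img_res above; note that requesting orientation classification or unwarping forces use_doc_preprocessor on):

from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="layout_parsing")
for res in pipeline.predict(
    "./test_imgs/test_layout_parsing.jpg",
    use_table_recognition=True,
):
    for table_res in res["table_res_list"]:
        print(table_res["pred_html"])  # HTML built by the post-processing shown below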

+ 31 - 25
paddlex/inference/pipelines_new/layout_parsing/result.py

@@ -23,6 +23,7 @@ from PIL import Image, ImageDraw, ImageFont
 from ....utils.fonts import PINGFANG_FONT_FILE_PATH
 from ..components import CVResult, HtmlMixin, XlsxMixin
 
+
 class TableRecognitionResult(CVResult, HtmlMixin, XlsxMixin):
     def __init__(self, data):
         super().__init__(data)
@@ -31,7 +32,7 @@ class TableRecognitionResult(CVResult, HtmlMixin, XlsxMixin):
 
     def save_to_html(self, save_path, *args, **kwargs):
         if not str(save_path).lower().endswith(".html"):
-            save_path = save_path + "/res_table_%d.html" % self['table_region_id']
+            save_path = save_path + "/res_table_%d.html" % self["table_region_id"]
         super().save_to_html(save_path, *args, **kwargs)
 
     def _to_html(self):
@@ -39,7 +40,7 @@ class TableRecognitionResult(CVResult, HtmlMixin, XlsxMixin):
 
     def save_to_xlsx(self, save_path, *args, **kwargs):
         if not str(save_path).lower().endswith(".xlsx"):
-            save_path = save_path + "/res_table_%d.xlsx" % self['table_region_id']
+            save_path = save_path + "/res_table_%d.xlsx" % self["table_region_id"]
         super().save_to_xlsx(save_path, *args, **kwargs)
 
     def _to_xlsx(self):
@@ -47,51 +48,56 @@ class TableRecognitionResult(CVResult, HtmlMixin, XlsxMixin):
 
     def save_to_img(self, save_path, *args, **kwargs):
         if not str(save_path).lower().endswith((".jpg", ".png")):
-            ocr_save_path = save_path + "/res_table_ocr_%d.jpg" % self['table_region_id']
-            save_path = save_path + "/res_table_cell_%d.jpg" % self['table_region_id']
-        self['table_ocr_pred'].save_to_img(ocr_save_path)
+            ocr_save_path = (
+                save_path + "/res_table_ocr_%d.jpg" % self["table_region_id"]
+            )
+            save_path = save_path + "/res_table_cell_%d.jpg" % self["table_region_id"]
+        self["table_ocr_pred"].save_to_img(ocr_save_path)
         super().save_to_img(save_path, *args, **kwargs)
 
     def _to_img(self):
-        input_img = self['table_ocr_pred']['input_img'].copy()
-        cell_box_list = self['cell_box_list']
+        input_img = self["table_ocr_pred"]["input_img"].copy()
+        cell_box_list = self["cell_box_list"]
         for box in cell_box_list:
             x1, y1, x2, y2 = [int(pos) for pos in box]
             cv2.rectangle(input_img, (x1, y1), (x2, y2), (255, 0, 0), 2)
         return input_img
 
+
 class LayoutParsingResult(dict):
     def __init__(self, data):
         super().__init__(data)
-    
+
     def save_results(self, save_path):
         if not os.path.isdir(save_path):
             raise ValueError("The save path should be a dir.")
 
-        layout_det_res = self['layout_det_res']
+        layout_det_res = self["layout_det_res"]
         save_img_path = save_path + "/layout_det_result.jpg"
         layout_det_res.save_to_img(save_img_path)
 
-        input_params = self['input_params']
-        if input_params['use_doc_preprocessor']:
+        input_params = self["input_params"]
+        if input_params["use_doc_preprocessor"]:
             save_img_path = save_path + "/doc_preprocessor_result.jpg"
-            self['doc_preprocessor_res'].save_to_img(save_img_path)
-        
-        if input_params['use_common_ocr']:
+            self["doc_preprocessor_res"].save_to_img(save_img_path)
+
+        if input_params["use_common_ocr"]:
             save_img_path = save_path + "/text_paragraphs_ocr_result.jpg"
-            self['text_paragraphs_ocr_res'].save_to_img(save_img_path)
+            self["text_paragraphs_ocr_res"].save_to_img(save_img_path)
 
-        if input_params['use_table_recognition']:
-            for tno in range(len(self['table_res_list'])):
-                table_res = self['table_res_list'][tno]
+        if input_params["use_table_recognition"]:
+            for tno in range(len(self["table_res_list"])):
+                table_res = self["table_res_list"][tno]
                 table_res.save_to_img(save_path)
                 table_res.save_to_html(save_path)
                 table_res.save_to_xlsx(save_path)
-        
-        if input_params['use_seal_recognition']:
-            for sno in range(len(self['seal_res_list'])):
-                seal_res = self['seal_res_list'][sno]
-                save_img_path = save_path + "/seal_%d_recognition_result.jpg" % seal_res['seal_region_id']
-                seal_res.save_to_img(save_img_path)          
-        return
 
+        if input_params["use_seal_recognition"]:
+            for sno in range(len(self["seal_res_list"])):
+                seal_res = self["seal_res_list"][sno]
+                save_img_path = (
+                    save_path
+                    + "/seal_%d_recognition_result.jpg" % seal_res["seal_region_id"]
+                )
+                seal_res.save_to_img(save_img_path)
+        return
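
For reference, with every feature enabled a single save_results call emits, under the target directory: layout_det_result.jpg, doc_preprocessor_result.jpg, text_paragraphs_ocr_result.jpg, one res_table_ocr_<id>.jpg / res_table_cell_<id>.jpg / res_table_<id>.html / res_table_<id>.xlsx set per table region, and one seal_<id>_recognition_result.jpg per seal region, following the path construction above.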

+ 34 - 28
paddlex/inference/pipelines_new/layout_parsing/table_recognition_post_processing.py

@@ -16,6 +16,7 @@ from .utils import convert_points_to_boxes, get_sub_regions_ocr_res
 import numpy as np
 from .result import TableRecognitionResult
 
+
 def get_ori_image_coordinate(x, y, box_list):
     """
     get the original coordinate from Cropped image to Original image.
@@ -35,19 +36,24 @@ def get_ori_image_coordinate(x, y, box_list):
     ori_box_list = offset + box_list
     return ori_box_list
 
-def convert_table_structure_pred_bbox(table_structure_pred, 
-    crop_start_point, img_shape):
 
-    cell_points_list = table_structure_pred['bbox']
-    ori_cell_points_list = get_ori_image_coordinate(crop_start_point[0], 
-        crop_start_point[1], cell_points_list)
+def convert_table_structure_pred_bbox(
+    table_structure_pred, crop_start_point, img_shape
+):
+
+    cell_points_list = table_structure_pred["bbox"]
+    ori_cell_points_list = get_ori_image_coordinate(
+        crop_start_point[0], crop_start_point[1], cell_points_list
+    )
     ori_cell_points_list = np.reshape(ori_cell_points_list, (-1, 4, 2))
     cell_box_list = convert_points_to_boxes(ori_cell_points_list)
     img_height, img_width = img_shape
-    cell_box_list = np.clip(cell_box_list, 0, 
-        [img_width, img_height, img_width, img_height])
-    table_structure_pred['cell_box_list'] = cell_box_list
-    return 
+    cell_box_list = np.clip(
+        cell_box_list, 0, [img_width, img_height, img_width, img_height]
+    )
+    table_structure_pred["cell_box_list"] = cell_box_list
+    return
+
 
 def distance(box_1, box_2):
     """
@@ -67,6 +73,7 @@ def distance(box_1, box_2):
     dis_3 = abs(x4 - x2) + abs(y4 - y2)
     return dis + min(dis_2, dis_3)
 
+
 def compute_iou(rec1, rec2):
     """
     computing IoU
@@ -96,6 +103,7 @@ def compute_iou(rec1, rec2):
         intersect = (right_line - left_line) * (bottom_line - top_line)
         return (intersect / (sum_area - intersect)) * 1.0
 
+
 def match_table_and_ocr(cell_box_list, ocr_dt_boxes):
     """
     match table and ocr
@@ -112,18 +120,19 @@ def match_table_and_ocr(cell_box_list, ocr_dt_boxes):
         ocr_box = ocr_box.astype(np.float32)
         distances = []
         for j, table_box in enumerate(cell_box_list):
-            distances.append((distance(table_box, ocr_box), 
-                1.0 - compute_iou(table_box, ocr_box)))  # compute iou and l1 distance
+            distances.append(
+                (distance(table_box, ocr_box), 1.0 - compute_iou(table_box, ocr_box))
+            )  # compute iou and l1 distance
         sorted_distances = distances.copy()
         # select det box by iou and l1 distance
-        sorted_distances = sorted(
-            sorted_distances, key=lambda item: (item[1], item[0]))
+        sorted_distances = sorted(sorted_distances, key=lambda item: (item[1], item[0]))
         if distances.index(sorted_distances[0]) not in matched.keys():
             matched[distances.index(sorted_distances[0])] = [i]
         else:
             matched[distances.index(sorted_distances[0])].append(i)
     return matched
 
+
 def get_html_result(matched_index, ocr_contents, pred_structures):
     pred_html = []
     td_index = 0
@@ -155,10 +164,7 @@ def get_html_result(matched_index, ocr_contents, pred_structures):
                             content = content[:-4]
                         if len(content) == 0:
                             continue
-                        if (
-                            i != len(matched_index[td_index]) - 1
-                            and " " != content[-1]
-                        ):
+                        if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
                             content += " "
                     pred_html.extend(content)
                 if b_with:
@@ -175,18 +181,18 @@ def get_html_result(matched_index, ocr_contents, pred_structures):
     html += "".join(end_structure)
     return html
 
+
 def get_table_recognition_res(crop_img_info, table_structure_pred, overall_ocr_res):
-    '''get_table_recognition_res'''
+    """get_table_recognition_res"""
 
-    table_box = np.array([crop_img_info['box']])
+    table_box = np.array([crop_img_info["box"]])
     table_ocr_pred = get_sub_regions_ocr_res(overall_ocr_res, table_box)
 
     crop_start_point = [table_box[0][0], table_box[0][1]]
-    img_shape = overall_ocr_res['input_img'].shape[0:2]
+    img_shape = overall_ocr_res["input_img"].shape[0:2]
+
+    convert_table_structure_pred_bbox(table_structure_pred, crop_start_point, img_shape)
 
-    convert_table_structure_pred_bbox(table_structure_pred, 
-        crop_start_point, img_shape)
-    
     structures = table_structure_pred["structure"]
     cell_box_list = table_structure_pred["cell_box_list"]
     ocr_dt_boxes = table_ocr_pred["dt_boxes"]
@@ -195,9 +201,9 @@ def get_table_recognition_res(crop_img_info, table_structure_pred, overall_ocr_res):
     matched_index = match_table_and_ocr(cell_box_list, ocr_dt_boxes)
     pred_html = get_html_result(matched_index, ocr_text_res, structures)
 
-    single_img_res = {"cell_box_list":cell_box_list, 
-        "table_ocr_pred":table_ocr_pred,
-        "pred_html":pred_html}
+    single_img_res = {
+        "cell_box_list": cell_box_list,
+        "table_ocr_pred": table_ocr_pred,
+        "pred_html": pred_html,
+    }
     return TableRecognitionResult(single_img_res)
-
-
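
The matching above assigns each OCR box to the table cell that minimizes (1 - IoU, L1 corner distance), in that priority order. A toy illustration of the expected mapping (assumes the module imports cleanly at this path; boxes are [x1, y1, x2, y2]):

import numpy as np
from paddlex.inference.pipelines_new.layout_parsing.table_recognition_post_processing import (
    match_table_and_ocr,
)

cells = np.array([[0, 0, 50, 20], [50, 0, 100, 20]], dtype=np.float32)  # two adjacent cells
ocr = np.array([[2, 2, 20, 18], [22, 2, 48, 18], [55, 2, 95, 18]], dtype=np.float32)
# Two OCR boxes fall inside the first cell, one inside the second.
print(match_table_and_ocr(cells, ocr))  # expected: {0: [0, 1], 1: [2]}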

+ 19 - 17
paddlex/inference/pipelines_new/layout_parsing/utils.py

@@ -12,14 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__all__ = [
-    "convert_points_to_boxes",
-    "get_sub_regions_ocr_res"
-]
+__all__ = ["convert_points_to_boxes", "get_sub_regions_ocr_res"]
 
 import numpy as np
 import copy
 
+
 def convert_points_to_boxes(dt_polys):
     if len(dt_polys) > 0:
         dt_polys_tmp = dt_polys.copy()
@@ -34,8 +32,9 @@ def convert_points_to_boxes(dt_polys):
         dt_boxes = np.array([])
     return dt_boxes
 
+
 def get_overlap_boxes_idx(src_boxes, ref_boxes):
-    '''get overlap boxes idx''' 
+    """get overlap boxes idx"""
     match_idx_list = []
     src_boxes_num = len(src_boxes)
     if src_boxes_num > 0 and len(ref_boxes) > 0:
@@ -48,9 +47,10 @@ def get_overlap_boxes_idx(src_boxes, ref_boxes):
             pub_w = x2 - x1
             pub_h = y2 - y1
             match_idx = np.where((pub_w > 3) & (pub_h > 3))[0]
-            match_idx_list.extend(match_idx)                                   
+            match_idx_list.extend(match_idx)
     return match_idx_list
 
+
 def get_sub_regions_ocr_res(overall_ocr_res, object_boxes, flag_within=True):
     """
     :param flag_within: True (within the object regions), False (outside the object regions)
@@ -58,14 +58,14 @@ def get_sub_regions_ocr_res(overall_ocr_res, object_boxes, flag_within=True):
     """
 
     sub_regions_ocr_res = copy.deepcopy(overall_ocr_res)
-    sub_regions_ocr_res['input_img'] = overall_ocr_res['input_img']
-    sub_regions_ocr_res['img_id'] = -1
-    sub_regions_ocr_res['dt_polys'] = []
-    sub_regions_ocr_res['rec_text'] = []
-    sub_regions_ocr_res['rec_score'] = []
-    sub_regions_ocr_res['dt_boxes'] = []
+    sub_regions_ocr_res["input_img"] = overall_ocr_res["input_img"]
+    sub_regions_ocr_res["img_id"] = -1
+    sub_regions_ocr_res["dt_polys"] = []
+    sub_regions_ocr_res["rec_text"] = []
+    sub_regions_ocr_res["rec_score"] = []
+    sub_regions_ocr_res["dt_boxes"] = []
 
-    overall_text_boxes = overall_ocr_res['dt_boxes']
+    overall_text_boxes = overall_ocr_res["dt_boxes"]
     match_idx_list = get_overlap_boxes_idx(overall_text_boxes, object_boxes)
     match_idx_list = list(set(match_idx_list))
     for box_no in range(len(overall_text_boxes)):
@@ -80,8 +80,10 @@ def get_sub_regions_ocr_res(overall_ocr_res, object_boxes, flag_within=True):
             else:
                 flag_match = False
         if flag_match:
-            sub_regions_ocr_res['dt_polys'].append(overall_ocr_res['dt_polys'][box_no])
-            sub_regions_ocr_res['rec_text'].append(overall_ocr_res['rec_text'][box_no])
-            sub_regions_ocr_res['rec_score'].append(overall_ocr_res['rec_score'][box_no])
-            sub_regions_ocr_res['dt_boxes'].append(overall_ocr_res['dt_boxes'][box_no])
+            sub_regions_ocr_res["dt_polys"].append(overall_ocr_res["dt_polys"][box_no])
+            sub_regions_ocr_res["rec_text"].append(overall_ocr_res["rec_text"][box_no])
+            sub_regions_ocr_res["rec_score"].append(
+                overall_ocr_res["rec_score"][box_no]
+            )
+            sub_regions_ocr_res["dt_boxes"].append(overall_ocr_res["dt_boxes"][box_no])
     return sub_regions_ocr_res
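
For reference, convert_points_to_boxes reduces 4-point polygons to the axis-aligned boxes that the overlap test above consumes (a small sketch; the printed value is the expected reduction):

import numpy as np
from paddlex.inference.pipelines_new.layout_parsing.utils import convert_points_to_boxes

polys = np.array([[[10, 5], [40, 6], [39, 25], [11, 24]]])  # one quad as 4 (x, y) points
print(convert_points_to_boxes(polys))  # expected: [[10  5 40 25]]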

+ 1 - 1
paddlex/inference/pipelines_new/ocr/__init__.py

@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .pipeline import OCRPipeline
+from .pipeline import OCRPipeline

+ 28 - 19
paddlex/inference/pipelines_new/ocr/pipeline.py

@@ -20,33 +20,38 @@ from .result import OCRResult
 ########## [TODO] The import path needs to be updated later
 from ...components.transforms import ReadImage
 
+
 class OCRPipeline(BasePipeline):
     """OCR Pipeline"""
 
     entities = "OCR"
-    def __init__(self,
-        config,        
+
+    def __init__(
+        self,
+        config,
         device=None,
-        pp_option=None, 
+        pp_option=None,
         use_hpip: bool = False,
-        hpi_params: Optional[Dict[str, Any]] = None):
-        super().__init__(device=device, pp_option=pp_option, 
-            use_hpip=use_hpip, hpi_params=hpi_params)
-        
-        text_det_model_config = config['SubModules']["TextDetection"]
+        hpi_params: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(
+            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
+        )
+
+        text_det_model_config = config["SubModules"]["TextDetection"]
         self.text_det_model = self.create_model(text_det_model_config)
 
-        text_rec_model_config = config['SubModules']["TextRecognition"]
+        text_rec_model_config = config["SubModules"]["TextRecognition"]
         self.text_rec_model = self.create_model(text_rec_model_config)
 
-        self.text_type = config['text_type']
+        self.text_type = config["text_type"]
 
         self._sort_quad_boxes = SortQuadBoxes()
 
         if self.text_type == "common":
-            self._crop_by_polys = CropByPolys(det_box_type = "quad")
+            self._crop_by_polys = CropByPolys(det_box_type="quad")
         elif self.text_type == "seal":
-            self._crop_by_polys = CropByPolys(det_box_type = "poly")
+            self._crop_by_polys = CropByPolys(det_box_type="poly")
         else:
             raise ValueError("Unsupported text type {}".format(self.text_type))
 
@@ -60,7 +65,7 @@ class OCRPipeline(BasePipeline):
         img_id = 1
         for input in input_list:
             if isinstance(input, str):
-                image_array = next(self.img_reader(input))[0]['img']
+                image_array = next(self.img_reader(input))[0]["img"]
             else:
                 image_array = input
 
@@ -68,16 +73,20 @@ class OCRPipeline(BasePipeline):
 
             det_res = next(self.text_det_model(image_array))
 
-            dt_polys = det_res['dt_polys']
-            dt_scores = det_res['dt_scores']
+            dt_polys = det_res["dt_polys"]
+            dt_scores = det_res["dt_scores"]
 
 ########## [TODO] Confirm the filtering thresholds of the detection and recognition modules
 
             if self.text_type == "common":
                 dt_polys = self._sort_quad_boxes(dt_polys)
-                
-            single_img_res = {'input_img':image_array, 'dt_polys':dt_polys, \
-                "img_id":img_id, "text_type":self.text_type}
+
+            single_img_res = {
+                "input_img": image_array,
+                "dt_polys": dt_polys,
+                "img_id": img_id,
+                "text_type": self.text_type,
+            }
             img_id += 1
             single_img_res["rec_text"] = []
             single_img_res["rec_score"] = []
@@ -86,7 +95,7 @@ class OCRPipeline(BasePipeline):
 
                 ########## [TODO] update in the future
                 for sub_img in all_subs_of_img:
-                    sub_img['input'] = sub_img['img']
+                    sub_img["input"] = sub_img["img"]
                 ##########
 
                 for rec_res in self.text_rec_model(all_subs_of_img):
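
A minimal usage sketch for the pipeline above (hypothetical paths; assumes the pipeline is registered under entities = "OCR" and that results expose the keys built in this loop):

from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="OCR")
for res in pipeline.predict("./general_ocr_demo.png"):
    print(res["rec_text"], res["rec_score"])
    res.save_to_img("./output/ocr_res.jpg")  # save_to_img requires a .jpg/.png suffix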

+ 3 - 2
paddlex/inference/pipelines_new/ocr/result.py

@@ -22,6 +22,7 @@ from PIL import Image, ImageDraw, ImageFont
 from ....utils.fonts import PINGFANG_FONT_FILE_PATH
 from ..components import CVResult
 
+
 class OCRResult(CVResult):
     def save_to_img(self, save_path, *args, **kwargs):
         if not str(save_path).lower().endswith((".jpg", ".png")):
@@ -61,7 +62,7 @@ class OCRResult(CVResult):
         boxes = self["dt_polys"]
         txts = self["rec_text"]
         scores = self["rec_score"]
-        image = self['input_img']
+        image = self["input_img"]
         h, w = image.shape[0:2]
         image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         img_left = Image.fromarray(image_rgb)
@@ -157,4 +158,4 @@ def create_font(txt, sz, font_path):
     if length > sz[0]:
         font_size = int(font_size * sz[0] / length)
         font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
-    return font
+    return font

+ 1 - 1
paddlex/inference/pipelines_new/pp_chatocrv3_doc/__init__.py

@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .pipeline import PP_ChatOCRv3_doc_Pipeline
+from .pipeline import PP_ChatOCRv3_doc_Pipeline

+ 110 - 86
paddlex/inference/pipelines_new/pp_chatocrv3_doc/pipeline.py

@@ -26,81 +26,91 @@ from ...components.transforms import ReadImage
 
 import json
 
+from ....utils import logging
+
+
 class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
     """PP-ChatOCRv3-doc Pipeline"""
 
     entities = "PP-ChatOCRv3-doc"
-    def __init__(self,
+
+    def __init__(
+        self,
         config,
         device=None,
-        pp_option=None, 
+        pp_option=None,
         use_hpip: bool = False,
-        hpi_params: Optional[Dict[str, Any]] = None):
-        super().__init__(device=device, pp_option=pp_option, 
-            use_hpip=use_hpip, hpi_params=hpi_params)
-        
+        hpi_params: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(
+            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
+        )
+
         self.inintial_predictor(config)
 
         self.img_reader = ReadImage(format="BGR")
-        
+
     def inintial_predictor(self, config):
         # layout_parsing_config = config['SubPipelines']["LayoutParser"]
         # self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
 
-        chat_bot_config = config['SubModules']['LLM_Chat']
+        chat_bot_config = config["SubModules"]["LLM_Chat"]
         self.chat_bot = self.create_chat_bot(chat_bot_config)
 
-        retriever_config = config['SubModules']['LLM_Retriever']
+        retriever_config = config["SubModules"]["LLM_Retriever"]
         self.retriever = self.create_retriever(retriever_config)
 
-        text_pe_config = config['SubModules']['PromptEngneering']['KIE_CommonText']
+        text_pe_config = config["SubModules"]["PromptEngneering"]["KIE_CommonText"]
         self.text_pe = self.create_prompt_engeering(text_pe_config)
-        
-        table_pe_config = config['SubModules']['PromptEngneering']['KIE_Table']
+
+        table_pe_config = config["SubModules"]["PromptEngneering"]["KIE_Table"]
         self.table_pe = self.create_prompt_engeering(table_pe_config)
 
-        return 
+        return
 
     def decode_visual_result(self, layout_parsing_result):
-        text_paragraphs_ocr_res = layout_parsing_result['text_paragraphs_ocr_res']
-        seal_res_list = layout_parsing_result['seal_res_list']
+        text_paragraphs_ocr_res = layout_parsing_result["text_paragraphs_ocr_res"]
+        seal_res_list = layout_parsing_result["seal_res_list"]
         normal_text_dict = {}
         layout_type = "text"
-        for text in text_paragraphs_ocr_res['rec_text']:
+        for text in text_paragraphs_ocr_res["rec_text"]:
             if layout_type not in normal_text_dict:
                 normal_text_dict[layout_type] = text
             else:
                 normal_text_dict[layout_type] += f"\n {text}"
-        
+
         layout_type = "seal"
         for seal_res in seal_res_list:
-            for text in seal_res['rec_text']:
+            for text in seal_res["rec_text"]:
                 if layout_type not in normal_text_dict:
                     normal_text_dict[layout_type] = text
                 else:
                     normal_text_dict[layout_type] += f"\n {text}"
 
-        table_res_list = layout_parsing_result['table_res_list']
+        table_res_list = layout_parsing_result["table_res_list"]
         table_text_list = []
         table_html_list = []
         for table_res in table_res_list:
-            table_html_list.append(table_res['pred_html'])
-            single_table_text = " ".join(table_res["table_ocr_pred"]['rec_text'])
+            table_html_list.append(table_res["pred_html"])
+            single_table_text = " ".join(table_res["table_ocr_pred"]["rec_text"])
             table_text_list.append(single_table_text)
 
         visual_info = {}
-        visual_info['normal_text_dict'] = normal_text_dict
-        visual_info['table_text_list'] = table_text_list
-        visual_info['table_html_list'] = table_html_list
+        visual_info["normal_text_dict"] = normal_text_dict
+        visual_info["table_text_list"] = table_text_list
+        visual_info["table_html_list"] = table_html_list
         return VisualInfoResult(visual_info)
 
-    def visual_predict(self, input,
+    def visual_predict(
+        self,
+        input,
         use_doc_orientation_classify=True,
         use_doc_unwarping=True,
         use_common_ocr=True,
         use_seal_recognition=True,
         use_table_recognition=True,
-        **kwargs):
+        **kwargs,
+    ):
 
         if not isinstance(input, list):
             input_list = [input]
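
For orientation, the VisualInfoResult built by decode_visual_result above has this shape (illustrative values only):

visual_info = {
    "normal_text_dict": {"text": "...", "seal": "..."},  # plain-text and seal OCR, newline-joined
    "table_text_list": ["space-joined OCR text per table"],
    "table_html_list": ["<table>...</table>"],
}
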
@@ -110,24 +120,29 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
         img_id = 1
         for input in input_list:
             if isinstance(input, str):
-                image_array = next(self.img_reader(input))[0]['img']
+                image_array = next(self.img_reader(input))[0]["img"]
             else:
                 image_array = input
 
             assert len(image_array.shape) == 3
 
-            layout_parsing_result = next(self.layout_parsing_pipeline.predict(
-                image_array,
-                use_doc_orientation_classify=use_doc_orientation_classify,
-                use_doc_unwarping=use_doc_unwarping,
-                use_common_ocr=use_common_ocr,
-                use_seal_recognition=use_seal_recognition,
-                use_table_recognition=use_table_recognition))
-            
+            layout_parsing_result = next(
+                self.layout_parsing_pipeline.predict(
+                    image_array,
+                    use_doc_orientation_classify=use_doc_orientation_classify,
+                    use_doc_unwarping=use_doc_unwarping,
+                    use_common_ocr=use_common_ocr,
+                    use_seal_recognition=use_seal_recognition,
+                    use_table_recognition=use_table_recognition,
+                )
+            )
+
             visual_info = self.decode_visual_result(layout_parsing_result)
 
-            visual_predict_res = {"layout_parsing_result":layout_parsing_result,
-                "visual_info":visual_info}
+            visual_predict_res = {
+                "layout_parsing_result": layout_parsing_result,
+                "visual_info": visual_info,
+            }
             yield visual_predict_res
 
     def save_visual_info_list(self, visual_info, save_path):
@@ -139,7 +154,7 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
         with open(save_path, "w") as fout:
             fout.write(json.dumps(visual_info_list, ensure_ascii=False) + "\n")
         return
-    
+
     def load_visual_info_list(self, data_path):
         with open(data_path, "r") as fin:
             data = fin.readline()
@@ -151,27 +166,27 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
         all_table_text_list = []
         all_table_html_list = []
         for single_visual_info in visual_info_list:
-            normal_text_dict = single_visual_info['normal_text_dict']
-            table_text_list = single_visual_info['table_text_list']
-            table_html_list = single_visual_info['table_html_list']
+            normal_text_dict = single_visual_info["normal_text_dict"]
+            table_text_list = single_visual_info["table_text_list"]
+            table_html_list = single_visual_info["table_html_list"]
             all_normal_text_list.append(normal_text_dict)
             all_table_text_list.extend(table_text_list)
             all_table_html_list.extend(table_html_list)
         return all_normal_text_list, all_table_text_list, all_table_html_list
 
-    def build_vector(self, visual_info, 
-        min_characters=3500,
-        llm_request_interval=1.0):
+    def build_vector(self, visual_info, min_characters=3500, llm_request_interval=1.0):
 
         if not isinstance(visual_info, list):
             visual_info_list = [visual_info]
         else:
             visual_info_list = visual_info
-        
+
         all_visual_info = self.merge_visual_info_list(visual_info_list)
         all_normal_text_list, all_table_text_list, all_table_html_list = all_visual_info
 
-        all_normal_text_str = "".join(["\n".join(e.values()) for e in all_normal_text_list])
+        all_normal_text_str = "".join(
+            ["\n".join(e.values()) for e in all_normal_text_list]
+        )
         vector_info = {}
 
         all_items = []
@@ -180,12 +195,11 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
                 all_items += [f"{type}:{text}"]
 
         if len(all_normal_text_str) > min_characters:
-            vector_info['flag_too_short_text'] = False
-            vector_info['vector'] = self.retriever.generate_vector_database(
-                all_items)
+            vector_info["flag_too_short_text"] = False
+            vector_info["vector"] = self.retriever.generate_vector_database(all_items)
         else:
-            vector_info['flag_too_short_text'] = True  
-            vector_info['vector'] = all_items
+            vector_info["flag_too_short_text"] = True
+            vector_info["vector"] = all_items
         return vector_info
 
     def format_key(self, key_list):
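
The resulting vector_info carries a flag that tells chat() whether retrieval is worthwhile (illustrative shape only):

vector_info = {
    "flag_too_short_text": True,         # total text below min_characters
    "vector": ["text:...", "seal:..."],  # raw "type:text" items, joined directly at chat time
}
# When the text exceeds min_characters, flag_too_short_text is False and
# "vector" holds the retriever's vector database instead.
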
@@ -238,12 +252,13 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             matches = re.findall(pattern, str(results))
             if len(matches) > 0:
                 llm_result = {k: v for k, v in matches}
-                return llm_result 
+                return llm_result
             else:
-                return {}     
+                return {}
 
-    def generate_and_merge_chat_results(self, prompt, key_list,
-        final_results, failed_results):
+    def generate_and_merge_chat_results(
+        self, prompt, key_list, final_results, failed_results
+    ):
 
         llm_result = self.chat_bot.generate_chat_results(prompt)
         llm_result = self.fix_llm_result_format(llm_result)
@@ -252,22 +267,24 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             if value not in failed_results and key in key_list:
                 key_list.remove(key)
                 final_results[key] = value
-        return 
-        
+        return
 
-    def chat(self, visual_info, 
-        key_list, 
+    def chat(
+        self,
+        visual_info,
+        key_list,
         vector_info,
         text_task_description=None,
         text_output_format=None,
         text_rules_str=None,
         text_few_shot_demo_text_content=None,
-        text_few_shot_demo_key_value_list=None,        
+        text_few_shot_demo_key_value_list=None,
         table_task_description=None,
         table_output_format=None,
         table_rules_str=None,
         table_few_shot_demo_text_content=None,
-        table_few_shot_demo_key_value_list=None):
+        table_few_shot_demo_key_value_list=None,
+    ):
 
         key_list = self.format_key(key_list)
         if len(key_list) == 0:
@@ -277,7 +294,7 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             visual_info_list = [visual_info]
         else:
             visual_info_list = visual_info
-        
+
         all_visual_info = self.merge_visual_info_list(visual_info_list)
         all_normal_text_list, all_table_text_list, all_table_html_list = all_visual_info
 
@@ -289,36 +306,43 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
                 if len(key_list) == 0:
                     continue
 
-                prompt = self.table_pe.generate_prompt(table_info, 
-                    key_list, 
+                prompt = self.table_pe.generate_prompt(
+                    table_info,
+                    key_list,
                     task_description=table_task_description,
-                    output_format=table_output_format, 
-                    rules_str=table_rules_str, 
-                    few_shot_demo_text_content=table_few_shot_demo_text_content, 
-                    few_shot_demo_key_value_list=table_few_shot_demo_key_value_list)
-
-                self.generate_and_merge_chat_results(prompt, 
-                    key_list, final_results, failed_results)
-        
+                    output_format=table_output_format,
+                    rules_str=table_rules_str,
+                    few_shot_demo_text_content=table_few_shot_demo_text_content,
+                    few_shot_demo_key_value_list=table_few_shot_demo_key_value_list,
+                )
+
+                self.generate_and_merge_chat_results(
+                    prompt, key_list, final_results, failed_results
+                )
+
         if len(key_list) > 0:
             question_key_list = [f"抽取关键信息:{key}" for key in key_list]  # prompt text: "Extract key information: {key}"
-            vector = vector_info['vector']
-            if not vector_info['flag_too_short_text']:
+            vector = vector_info["vector"]
+            if not vector_info["flag_too_short_text"]:
                 related_text = self.retriever.similarity_retrieval(
-                    question_key_list, vector)
+                    question_key_list, vector
+                )
             else:
                 related_text = " ".join(vector)
-            
-            prompt = self.text_pe.generate_prompt(related_text, 
-                key_list, 
+
+            prompt = self.text_pe.generate_prompt(
+                related_text,
+                key_list,
                 task_description=text_task_description,
-                output_format=text_output_format, 
-                rules_str=text_rules_str, 
-                few_shot_demo_text_content=text_few_shot_demo_text_content, 
-                few_shot_demo_key_value_list=text_few_shot_demo_key_value_list)
-            
-            self.generate_and_merge_chat_results(prompt, 
-                key_list, final_results, failed_results)
+                output_format=text_output_format,
+                rules_str=text_rules_str,
+                few_shot_demo_text_content=text_few_shot_demo_text_content,
+                few_shot_demo_key_value_list=text_few_shot_demo_key_value_list,
+            )
+
+            self.generate_and_merge_chat_results(
+                prompt, key_list, final_results, failed_results
+            )
 
         return final_results
 
@@ -326,4 +350,4 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
         logging.error(
             "PP-ChatOCRv3-doc Pipeline do not support to call `predict()` directly! Please invoke `visual_predict`, `build_vector`, `chat` sequentially to obtain the result."
         )
-        return
+        return
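
Putting it together, the intended call sequence is the one the error message names (a hedged sketch; the image path and keys are placeholders):

from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="PP-ChatOCRv3-doc")

visual_info_list = []
for res in pipeline.visual_predict("./contract.png"):
    visual_info_list.append(res["visual_info"])

vector_info = pipeline.build_vector(visual_info_list, min_characters=3500)
chat_result = pipeline.chat(visual_info_list, ["Party A", "contract date"], vector_info)
print(chat_result)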

+ 4 - 2
paddlex/inference/pipelines_new/pp_chatocrv3_doc/result.py

@@ -22,11 +22,13 @@ from PIL import Image, ImageDraw, ImageFont
 from ....utils.fonts import PINGFANG_FONT_FILE_PATH
 from ..components import BaseResult
 
+
 class VisualInfoResult(BaseResult):
     """VisualInfoResult"""
-    
+
     pass
 
+
 # class VectorResult(BaseResult, Base64Mixin):
 #     """VisualInfoResult"""
 
@@ -41,4 +43,4 @@ class VisualInfoResult(BaseResult):
 
 
 # class ChatResult(BaseResult):
-#     """VisualInfoResult"""
+#     """VisualInfoResult"""

+ 3 - 0
paddlex/utils/fonts/__init__.py

@@ -18,10 +18,12 @@ from pathlib import Path
 import PIL
 from PIL import ImageFont
 
+
 def get_pingfang_file_path():
     """get pingfang font file path"""
     return (Path(__file__).parent / "PingFang-SC-Regular.ttf").resolve().as_posix()
 
+
 def create_font(txt, sz, font_path):
     """create font"""
     font_size = int(sz[1] * 0.8)
@@ -36,4 +38,5 @@ def create_font(txt, sz, font_path):
         font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
     return font
 
+
 PINGFANG_FONT_FILE_PATH = get_pingfang_file_path()
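
A quick sketch of the shrink-to-fit behavior (create_font starts at 80% of the box height and scales the font down until the text fits the box width):

from PIL import Image, ImageDraw
from paddlex.utils.fonts import PINGFANG_FONT_FILE_PATH, create_font

img = Image.new("RGB", (200, 40), "white")
font = create_font("layout parsing", (200, 40), PINGFANG_FONT_FILE_PATH)
ImageDraw.Draw(img).text((0, 0), "layout parsing", fill="black", font=font)
img.save("font_demo.png")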