
add the new architecture of pipelines

dyning committed 11 months ago
parent commit 639ad156b0
37 files changed, 598 additions and 418 deletions
  1. +16 -3   api_examples/pipelines/test_doc_preprocessor.py
  2. +18 -3   api_examples/pipelines/test_layout_parsing.py
  3. +13 -2   api_examples/pipelines/test_ocr.py
  4. +15 -2   api_examples/pipelines/test_pp_chatocrv3.py
  5. +16 -5   api_examples/pipelines/test_table_recognition.py
  6. +1 -1    paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml
  7. +1 -1    paddlex/configs/pipelines/layout_parsing.yaml
  8. +1 -0    paddlex/inference/__init__.py
  9. +8 -3    paddlex/inference/pipelines_new/__init__.py
  10. +24 -21  paddlex/inference/pipelines_new/base.py
  11. +4 -2    paddlex/inference/pipelines_new/components/base.py
  12. +3 -1    paddlex/inference/pipelines_new/components/chat_server/base.py
  13. +19 -23  paddlex/inference/pipelines_new/components/chat_server/ernie_bot_chat.py
  14. +0 -1    paddlex/inference/pipelines_new/components/common/__init__.py
  15. +2 -1    paddlex/inference/pipelines_new/components/common/crop_image_regions.py
  16. +1 -0    paddlex/inference/pipelines_new/components/common/seal_det_warp.py
  17. +2 -0    paddlex/inference/pipelines_new/components/common/sort_boxes.py
  18. +1 -1    paddlex/inference/pipelines_new/components/prompt_engeering/__init__.py
  19. +3 -1    paddlex/inference/pipelines_new/components/prompt_engeering/base.py
  20. +27 -22  paddlex/inference/pipelines_new/components/prompt_engeering/generate_kie_prompt.py
  21. +6 -3    paddlex/inference/pipelines_new/components/retriever/base.py
  22. +20 -17  paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py
  23. +1 -1    paddlex/inference/pipelines_new/doc_preprocessor/__init__.py
  24. +51 -37  paddlex/inference/pipelines_new/doc_preprocessor/pipeline.py
  25. +2 -1    paddlex/inference/pipelines_new/doc_preprocessor/result.py
  26. +1 -1    paddlex/inference/pipelines_new/layout_parsing/__init__.py
  27. +108 -84 paddlex/inference/pipelines_new/layout_parsing/pipeline.py
  28. +31 -25  paddlex/inference/pipelines_new/layout_parsing/result.py
  29. +34 -28  paddlex/inference/pipelines_new/layout_parsing/table_recognition_post_processing.py
  30. +19 -17  paddlex/inference/pipelines_new/layout_parsing/utils.py
  31. +1 -1    paddlex/inference/pipelines_new/ocr/__init__.py
  32. +28 -19  paddlex/inference/pipelines_new/ocr/pipeline.py
  33. +3 -2    paddlex/inference/pipelines_new/ocr/result.py
  34. +1 -1    paddlex/inference/pipelines_new/pp_chatocrv3_doc/__init__.py
  35. +110 -86 paddlex/inference/pipelines_new/pp_chatocrv3_doc/pipeline.py
  36. +4 -2    paddlex/inference/pipelines_new/pp_chatocrv3_doc/result.py
  37. +3 -0    paddlex/utils/fonts/__init__.py

+ 16 - 3
api_examples/pipelines/test_doc_preprocessor.py

@@ -1,3 +1,16 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 from paddlex import create_pipeline
 
@@ -6,9 +19,9 @@ pipeline = create_pipeline(pipeline="doc_preprocessor")
 test_img_path = "./test_imgs/img_rot180_demo.jpg"
 # test_img_path = "./test_imgs/doc_distort_test.jpg"
 
-output = pipeline.predict(test_img_path, 
-    use_doc_orientation_classify=True,
-    use_doc_unwarping=True)
+output = pipeline.predict(
+    test_img_path, use_doc_orientation_classify=True, use_doc_unwarping=True
+)
 
 for res in output:
     print(res)

+ 18 - 3
api_examples/pipelines/test_layout_parsing.py

@@ -1,14 +1,29 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 from paddlex import create_pipeline
 
 pipeline = create_pipeline(pipeline="layout_parsing")
 
-output = pipeline.predict("./test_imgs/test_layout_parsing.jpg",
+output = pipeline.predict(
+    "./test_imgs/test_layout_parsing.jpg",
     use_doc_orientation_classify=True,
     use_doc_unwarping=True,
     use_common_ocr=True,
     use_seal_recognition=True,
-    use_table_recognition=True)
+    use_table_recognition=True,
+)
 
 # output = pipeline("./test_imgs/demo_paper.png")
 # output = pipeline("./test_imgs/table_recognition.jpg")
@@ -16,4 +31,4 @@ output = pipeline.predict("./test_imgs/test_layout_parsing.jpg",
 # output = pipeline.predict("./test_imgs/img_rot180_demo.jpg")
 for res in output:
     # print(res)
-    res.save_results("./output")
+    res.save_results("./output")

+ 13 - 2
api_examples/pipelines/test_ocr.py

@@ -1,3 +1,16 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 from paddlex import create_pipeline
 
@@ -9,5 +22,3 @@ output = pipeline.predict("./test_imgs/seal_text_det.png")
 for res in output:
     print(res)
     res.save_to_img("./output")
-
-

+ 15 - 2
api_examples/pipelines/test_pp_chatocrv3.py

@@ -1,3 +1,16 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 from paddlex import create_pipeline
 
@@ -10,9 +23,9 @@ pipeline = create_pipeline(pipeline="PP-ChatOCRv3-doc")
 # key_list = ['3.2的标题']
 
 img_path = "./test_demo_imgs/seal_text_det.png"
-key_list = ['印章上公司']
+key_list = ["印章上公司"]
 
-# visual_predict_res = pipeline.visual_predict(img_path, 
+# visual_predict_res = pipeline.visual_predict(img_path,
 #     use_doc_orientation_classify=True,
 #     use_doc_unwarping=True,
 #     use_common_ocr=True,

+ 16 - 5
api_examples/pipelines/test_table_recognition.py

@@ -1,3 +1,16 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 from paddlex import create_pipeline
 
@@ -6,8 +19,6 @@ pipeline = create_pipeline(pipeline="table_recognition")
 output = pipeline("./test_imgs/table_recognition.jpg")
 for res in output:
     print(res)
-    res.save_to_img("./output/") ## 保存img格式结果
-    res.save_to_xlsx("./output/") ## 保存表格格式结果
-    res.save_to_html("./output/") ## 保存html结果
-
-
+    res.save_to_img("./output/")  ## 保存img格式结果
+    res.save_to_xlsx("./output/")  ## 保存表格格式结果
+    res.save_to_html("./output/")  ## 保存html结果

+ 1 - 1
paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml

@@ -106,4 +106,4 @@ SubPipelines:
           TextRecognition:
             model_name: PP-OCRv4_server_rec
             model_dir: null
-            batch_size: 1  
+            batch_size: 1  

+ 1 - 1
paddlex/configs/pipelines/layout_parsing.yaml

@@ -53,4 +53,4 @@ SubPipelines:
       TextRecognition:
         model_name: PP-OCRv4_server_rec
         model_dir: null
-        batch_size: 1
+        batch_size: 1

+ 1 - 0
paddlex/inference/__init__.py

@@ -14,6 +14,7 @@
 
 from .models import create_predictor
 from ..utils import flags
+
 if flags.USE_NEW_INFERENCE:
     from .pipelines_new import create_pipeline
 else:

+ 8 - 3
paddlex/inference/pipelines_new/__init__.py

@@ -45,14 +45,18 @@ from .doc_preprocessor import DocPreprocessorPipeline
 from .layout_parsing import LayoutParsingPipeline
 from .pp_chatocrv3_doc import PP_ChatOCRv3_doc_Pipeline
 
+
 def get_pipeline_path(pipeline_name):
     pipeline_path = (
-        Path(__file__).parent.parent.parent / "configs/pipelines" / f"{pipeline_name}.yaml"
+        Path(__file__).parent.parent.parent
+        / "configs/pipelines"
+        / f"{pipeline_name}.yaml"
     ).resolve()
     if not Path(pipeline_path).exists():
         return None
     return pipeline_path
 
+
 def load_pipeline_config(pipeline_name: str) -> Dict[str, Any]:
     if not Path(pipeline_name).exists():
         pipeline_path = get_pipeline_path(pipeline_name)
@@ -65,6 +69,7 @@ def load_pipeline_config(pipeline_name: str) -> Dict[str, Any]:
     config = parse_config(pipeline_path)
     return config
 
+
 def create_pipeline(
     pipeline: str,
     device=None,
@@ -90,6 +95,6 @@ def create_pipeline(
         device=device,
         pp_option=pp_option,
         use_hpip=use_hpip,
-        hpi_params=hpi_params)
+        hpi_params=hpi_params,
+    )
     return pipeline
-    
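As the hunks above show, `load_pipeline_config` accepts either a registered pipeline name (resolved to `paddlex/configs/pipelines/<name>.yaml` by `get_pipeline_path`) or a path to an existing YAML file. A minimal sketch of both call forms; the local config path is a placeholder:

    from paddlex import create_pipeline

    # By registered name: resolves paddlex/configs/pipelines/layout_parsing.yaml.
    pipeline = create_pipeline(pipeline="layout_parsing")

    # By explicit config path (placeholder file), e.g. for a customized copy.
    pipeline = create_pipeline(pipeline="./my_layout_parsing.yaml")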

+ 24 - 21
paddlex/inference/pipelines_new/base.py

@@ -23,16 +23,19 @@ from .components.chat_server.base import BaseChat
 from .components.retriever.base import BaseRetriever
 from .components.prompt_engeering.base import BaseGeneratePrompt
 
+
 class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
     """Base Pipeline"""
 
     __is_base = True
 
-    def __init__(self,
-        device=None, 
-        pp_option=None, 
-        use_hpip: bool = False, 
-        hpi_params: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(
+        self,
+        device=None,
+        pp_option=None,
+        use_hpip: bool = False,
+        hpi_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
         super().__init__()
         self.device = device
         self.pp_option = pp_option
@@ -41,22 +44,21 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
 
     @abstractmethod
     def predict(self, input, **kwargs):
-        raise NotImplementedError(
-            "The method `predict` has not been implemented yet."
-        )
-    
+        raise NotImplementedError("The method `predict` has not been implemented yet.")
+
     def create_model(self, config):
 
-        model_dir = config['model_dir']
+        model_dir = config["model_dir"]
         if model_dir == None:
-            model_dir = config['model_name']
+            model_dir = config["model_name"]
 
         model = create_predictor(
             model_dir,
             device=self.device,
             pp_option=self.pp_option,
             use_hpip=self.use_hpip,
-            hpi_params=self.hpi_params)
+            hpi_params=self.hpi_params,
+        )
 
         ########### [TODO]支持初始化传参能力
         if "batch_size" in config:
@@ -66,29 +68,30 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
         return model
 
     def create_pipeline(self, config):
-        pipeline_name = config['pipeline_name']
+        pipeline_name = config["pipeline_name"]
         pipeline = BasePipeline.get(pipeline_name)(
             config=config,
             device=self.device,
             pp_option=self.pp_option,
             use_hpip=self.use_hpip,
-            hpi_params=self.hpi_params)
-        return pipeline 
+            hpi_params=self.hpi_params,
+        )
+        return pipeline
 
     def create_chat_bot(self, config):
-        model_name = config['model_name']
+        model_name = config["model_name"]
         chat_bot = BaseChat.get(model_name)(config)
-        return chat_bot     
+        return chat_bot
 
     def create_retriever(self, config):
-        model_name = config['model_name']
+        model_name = config["model_name"]
         retriever = BaseRetriever.get(model_name)(config)
-        return retriever    
+        return retriever
 
     def create_prompt_engeering(self, config):
-        task_type = config['task_type']
+        task_type = config["task_type"]
         pe = BaseGeneratePrompt.get(task_type)(config)
-        return pe       
+        return pe
 
     def __call__(self, input, **kwargs):
         return self.predict(input, **kwargs)
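The reworked `BasePipeline` doubles as a factory hub: `create_model`, `create_pipeline`, `create_chat_bot`, `create_retriever`, and `create_prompt_engeering` all resolve their targets by name through the auto-registration metaclass. A hypothetical minimal subclass, sketched only to show the pattern; the `entities` name and the `SubModules` key are illustrative, not part of this commit:

    class ToyPipeline(BasePipeline):
        """Hypothetical subclass; auto-registered under its `entities` name."""

        entities = "toy_pipeline"

        def __init__(self, config, device=None, pp_option=None,
                     use_hpip=False, hpi_params=None):
            super().__init__(device=device, pp_option=pp_option,
                             use_hpip=use_hpip, hpi_params=hpi_params)
            # create_model reads "model_dir" (falling back to "model_name")
            # and an optional "batch_size", as implemented above.
            self.model = self.create_model(config["SubModules"]["Detection"])

        def predict(self, input, **kwargs):
            yield from self.model(input)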

+ 4 - 2
paddlex/inference/pipelines_new/components/base.py

@@ -21,6 +21,7 @@ from ....utils.func_register import FuncRegister
 from ...utils.io import ImageReader, ImageWriter
 from .utils.mixin import JsonMixin, ImgMixin, StrMixin
 
+
 class BaseComponent(ABC, metaclass=AutoRegisterABCMetaClass):
     """Base Component"""
 
@@ -32,7 +33,8 @@ class BaseComponent(ABC, metaclass=AutoRegisterABCMetaClass):
     @abstractmethod
     def __call__(self):
         raise NotImplementedError(
-            "The component method `__call__` has not been implemented yet.")
+            "The component method `__call__` has not been implemented yet."
+        )
 
 
 class BaseResult(dict, StrMixin, JsonMixin):
@@ -50,10 +52,10 @@ class BaseResult(dict, StrMixin, JsonMixin):
             else:
                 func()
 
+
 class CVResult(BaseResult, ImgMixin):
     def __init__(self, data):
         super().__init__(data)
         ImgMixin.__init__(self, "pillow")
         self._img_reader = ImageReader(backend="pillow")
         self._img_writer = ImageWriter(backend="pillow")
-

+ 3 - 1
paddlex/inference/pipelines_new/components/chat_server/base.py

@@ -17,6 +17,7 @@ from .....utils.subclass_register import AutoRegisterABCMetaClass
 
 import inspect
 
+
 class BaseChat(ABC, metaclass=AutoRegisterABCMetaClass):
     """Base Chat"""
 
@@ -28,4 +29,5 @@ class BaseChat(ABC, metaclass=AutoRegisterABCMetaClass):
     @abstractmethod
     def generate_chat_results(self):
         raise NotImplementedError(
-            "The method `generate_chat_results` has not been implemented yet.")
+            "The method `generate_chat_results` has not been implemented yet."
+        )

+ 19 - 23
paddlex/inference/pipelines_new/components/chat_server/ernie_bot_chat.py

@@ -16,6 +16,7 @@ from .....utils import logging
 from .base import BaseChat
 import erniebot
 
+
 class ErnieBotChat(BaseChat):
     """Ernie Bot Chat"""
 
@@ -32,11 +33,11 @@ class ErnieBotChat(BaseChat):
 
     def __init__(self, config):
         super().__init__()
-        model_name = config.get('model_name', None)
-        api_type = config.get('api_type', None)
-        ak = config.get('ak', None)
-        sk = config.get('sk', None)
-        access_token = config.get('access_token', None)
+        model_name = config.get("model_name", None)
+        api_type = config.get("api_type", None)
+        ak = config.get("ak", None)
+        sk = config.get("sk", None)
+        access_token = config.get("access_token", None)
 
         if model_name not in self.entities:
             raise ValueError(f"model_name must be in {self.entities} of ErnieBotChat.")
@@ -46,13 +47,13 @@ class ErnieBotChat(BaseChat):
 
         if api_type == "aistudio" and access_token is None:
             raise ValueError("access_token cannot be empty when api_type is aistudio.")
-            
+
         if api_type == "qianfan" and (ak is None or sk is None):
-            raise ValueError("ak and sk cannot be empty when api_type is qianfan.")            
+            raise ValueError("ak and sk cannot be empty when api_type is qianfan.")
 
         self.model_name = model_name
         self.config = config
-        
+
     def generate_chat_results(self, prompt, temperature=0.001, max_retries=1):
         """
         args:
@@ -60,14 +61,14 @@ class ErnieBotChat(BaseChat):
         """
         try:
             cur_config = {
-                "api_type": self.config['api_type'],
-                "max_retries": max_retries
+                "api_type": self.config["api_type"],
+                "max_retries": max_retries,
             }
-            if self.config['api_type'] == "aistudio":
-                cur_config['access_token'] = self.config['access_token']
-            elif self.config['api_type'] == "qianfan":
-                cur_config['ak'] = self.config['ak']
-                cur_config['sk'] = self.config['sk']
+            if self.config["api_type"] == "aistudio":
+                cur_config["access_token"] = self.config["access_token"]
+            elif self.config["api_type"] == "qianfan":
+                cur_config["ak"] = self.config["ak"]
+                cur_config["sk"] = self.config["sk"]
             chat_completion = erniebot.ChatCompletion.create(
                 _config_=cur_config,
                 model=self.model_name,
@@ -78,18 +79,13 @@ class ErnieBotChat(BaseChat):
             return llm_result
         except Exception as e:
             if len(e.args) < 1:
-                self.ERROR_MASSAGE = (
-                    "暂无权限访问ErnieBot服务,请检查访问令牌。"
-                )
+                self.ERROR_MASSAGE = "暂无权限访问ErnieBot服务,请检查访问令牌。"
             elif (
                 e.args[-1]
                 == "暂无权限使用,请在 AI Studio 正确获取访问令牌(access token)使用"
             ):
-                self.ERROR_MASSAGE = (
-                    "暂无权限访问ErnieBot服务,请检查访问令牌。"
-                )
+                self.ERROR_MASSAGE = "暂无权限访问ErnieBot服务,请检查访问令牌。"
             else:
                 logging.error(e)
                 self.ERROR_MASSAGE = "大模型调用失败"
-        return None 
-        
+        return None
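The validation above pins down the expected config shape: `model_name` must be one of the registered entities, and `api_type` decides which credentials are required. A hedged usage sketch; the model name is assumed from the retriever's entities list further down, and every credential is a placeholder:

    chat_config = {
        "model_name": "ernie-3.5",             # must appear in ErnieBotChat.entities
        "api_type": "aistudio",                # "aistudio" requires access_token
        "access_token": "<your-access-token>",
        # For api_type "qianfan", provide "ak" and "sk" instead.
    }
    chat_bot = ErnieBotChat(chat_config)
    result = chat_bot.generate_chat_results("提取印章上的公司名称", temperature=0.001)
    if result is None:
        print(chat_bot.ERROR_MASSAGE)  # attribute name as spelled in the source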

+ 0 - 1
paddlex/inference/pipelines_new/components/common/__init__.py

@@ -14,4 +14,3 @@
 
 from .sort_boxes import SortQuadBoxes
 from .crop_image_regions import CropByPolys, CropByBoxes
-

+ 2 - 1
paddlex/inference/pipelines_new/components/common/crop_image_regions.py

@@ -21,6 +21,7 @@ from .seal_det_warp import AutoRectifier
 from shapely.geometry import Polygon
 from numpy.linalg import norm
 
+
 class CropByPolys(BaseComponent):
     """Crop Image by Polys"""
 
@@ -472,4 +473,4 @@ class CropByBoxes(BaseComponent):
             xmin, ymin, xmax, ymax = [int(i) for i in box]
             img_crop = img[ymin:ymax, xmin:xmax].copy()
             output_list.append({"img": img_crop, "box": box, "label": label})
-        return output_list
+        return output_list

+ 1 - 0
paddlex/inference/pipelines_new/components/common/seal_det_warp.py

@@ -21,6 +21,7 @@ import time
 
 from .....utils import logging
 
+
 def Homography(
     image,
     img_points,

+ 2 - 0
paddlex/inference/pipelines_new/components/common/sort_boxes.py

@@ -15,10 +15,12 @@
 from ..base import BaseComponent
 import numpy as np
 
+
 class SortQuadBoxes(BaseComponent):
     """SortQuadBoxes Component"""
 
     entities = "SortQuadBoxes"
+
     def __init__(self):
         super().__init__()
 

+ 1 - 1
paddlex/inference/pipelines_new/components/prompt_engeering/__init__.py

@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .generate_kie_prompt import GenerateKIEPrompt
+from .generate_kie_prompt import GenerateKIEPrompt

+ 3 - 1
paddlex/inference/pipelines_new/components/prompt_engeering/base.py

@@ -17,6 +17,7 @@ from .....utils.subclass_register import AutoRegisterABCMetaClass
 
 import inspect
 
+
 class BaseGeneratePrompt(ABC, metaclass=AutoRegisterABCMetaClass):
     """Base Chat"""
 
@@ -28,4 +29,5 @@ class BaseGeneratePrompt(ABC, metaclass=AutoRegisterABCMetaClass):
     @abstractmethod
     def generate_prompt(self):
         raise NotImplementedError(
-            "The method `generate_prompt` has not been implemented yet.")
+            "The method `generate_prompt` has not been implemented yet."
+        )

+ 27 - 22
paddlex/inference/pipelines_new/components/prompt_engeering/generate_kie_prompt.py

@@ -14,41 +14,41 @@
 
 from .base import BaseGeneratePrompt
 
+
 class GenerateKIEPrompt(BaseGeneratePrompt):
     """Generate KIE Prompt"""
 
-    entities = [
-        "text_kie_prompt",
-        "table_kie_prompt"
-    ]
+    entities = ["text_kie_prompt", "table_kie_prompt"]
 
     def __init__(self, config):
         super().__init__()
 
-        task_type = config.get('task_type', "")
-        task_description = config.get('task_description', "")  
-        output_format = config.get('output_format', "")  
-        rules_str = config.get('rules_str', "")  
-        few_shot_demo_text_content = config.get('few_shot_demo_text_content', "")  
-        few_shot_demo_key_value_list = config.get('few_shot_demo_key_value_list', "")
+        task_type = config.get("task_type", "")
+        task_description = config.get("task_description", "")
+        output_format = config.get("output_format", "")
+        rules_str = config.get("rules_str", "")
+        few_shot_demo_text_content = config.get("few_shot_demo_text_content", "")
+        few_shot_demo_key_value_list = config.get("few_shot_demo_key_value_list", "")
 
         if task_description is None:
             task_description = ""
-        
+
         if output_format is None:
             output_format = ""
-        
+
         if rules_str is None:
             rules_str = ""
-        
+
         if few_shot_demo_text_content is None:
             few_shot_demo_text_content = ""
-        
+
         if few_shot_demo_key_value_list is None:
             few_shot_demo_key_value_list = ""
 
         if task_type not in self.entities:
-            raise ValueError(f"task type must be in {self.entities} of GenerateKIEPrompt.")
+            raise ValueError(
+                f"task type must be in {self.entities} of GenerateKIEPrompt."
+            )
 
         self.task_type = task_type
         self.task_description = task_description
@@ -56,14 +56,17 @@ class GenerateKIEPrompt(BaseGeneratePrompt):
         self.rules_str = rules_str
         self.few_shot_demo_text_content = few_shot_demo_text_content
         self.few_shot_demo_key_value_list = few_shot_demo_key_value_list
-        
-    def generate_prompt(self, text_content,
+
+    def generate_prompt(
+        self,
+        text_content,
         key_list,
         task_description=None,
         output_format=None,
         rules_str=None,
         few_shot_demo_text_content=None,
-        few_shot_demo_key_value_list=None):
+        few_shot_demo_key_value_list=None,
+    ):
         """
         args:
         return:
@@ -80,7 +83,7 @@ class GenerateKIEPrompt(BaseGeneratePrompt):
 
         if few_shot_demo_text_content is None:
             few_shot_demo_text_content = self.few_shot_demo_text_content
-            
+
         if few_shot_demo_key_value_list is None:
             few_shot_demo_key_value_list = self.few_shot_demo_key_value_list
 
@@ -89,12 +92,14 @@ class GenerateKIEPrompt(BaseGeneratePrompt):
             prompt += f"""\n结合上面,下面正式开始:\
                 表格内容:```{text_content}```\
                 关键词列表:{key_list}。""".replace(
-                "    ", "")
+                "    ", ""
+            )
         elif self.task_type == "text_kie_prompt":
             prompt += f"""\n结合上面的例子,下面正式开始:\
                 OCR文字:```{text_content}```\
                 关键词列表:{key_list}。""".replace(
-                "    ", "")
+                "    ", ""
+            )
         else:
-            raise ValueError(f"{self.task_type} is currently not supported.") 
+            raise ValueError(f"{self.task_type} is currently not supported.")
         return prompt
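`GenerateKIEPrompt` is looked up by `task_type` through `create_prompt_engeering` in `base.py`, and every optional config field falls back to an empty string, as the constructor above shows. A sketch under that contract; the key list is borrowed from `test_pp_chatocrv3.py` and the OCR text is a placeholder:

    pe_config = {
        "task_type": "text_kie_prompt",  # or "table_kie_prompt"
        "task_description": "",
        "output_format": "",
        "rules_str": "",
        "few_shot_demo_text_content": "",
        "few_shot_demo_key_value_list": "",
    }
    pe = GenerateKIEPrompt(pe_config)
    prompt = pe.generate_prompt(
        text_content="OCR识别出的全部文字",
        key_list=["印章上公司"],
    )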

+ 6 - 3
paddlex/inference/pipelines_new/components/retriever/base.py

@@ -18,6 +18,7 @@ from .....utils.subclass_register import AutoRegisterABCMetaClass
 import inspect
 import base64
 
+
 class BaseRetriever(ABC, metaclass=AutoRegisterABCMetaClass):
     """Base Retriever"""
 
@@ -31,12 +32,14 @@ class BaseRetriever(ABC, metaclass=AutoRegisterABCMetaClass):
     @abstractmethod
     def generate_vector_database(self):
         raise NotImplementedError(
-            "The method `generate_vector_database` has not been implemented yet.")
+            "The method `generate_vector_database` has not been implemented yet."
+        )
 
     @abstractmethod
     def similarity_retrieval(self):
         raise NotImplementedError(
-            "The method `similarity_retrieval` has not been implemented yet.")
+            "The method `similarity_retrieval` has not been implemented yet."
+        )
 
     def is_vector_store(self, s):
         return s.startswith(self.VECTOR_STORE_PREFIX)
@@ -47,4 +50,4 @@ class BaseRetriever(ABC, metaclass=AutoRegisterABCMetaClass):
         )
 
     def decode_vector_store(self, vector_store_str):
-        return base64.b64decode(vector_store_str[len(self.VECTOR_STORE_PREFIX):])
+        return base64.b64decode(vector_store_str[len(self.VECTOR_STORE_PREFIX) :])

+ 20 - 17
paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py

@@ -25,6 +25,7 @@ from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings
 
 import time
 
+
 class ErnieBotRetriever(BaseRetriever):
     """Ernie Bot Retriever"""
 
@@ -38,17 +39,17 @@ class ErnieBotRetriever(BaseRetriever):
         "ernie-speed-128k",
         "ernie-char-8k",
     ]
-    
+
     def __init__(self, config):
 
         super().__init__()
 
-        model_name = config.get('model_name', None)
-        api_type = config.get('api_type', None)
-        ak = config.get('ak', None)
-        sk = config.get('sk', None)
-        access_token = config.get('access_token', None)
-        
+        model_name = config.get("model_name", None)
+        api_type = config.get("api_type", None)
+        ak = config.get("ak", None)
+        sk = config.get("sk", None)
+        access_token = config.get("access_token", None)
+
         if model_name not in self.entities:
             raise ValueError(f"model_name must be in {self.entities} of ErnieBotChat.")
 
@@ -57,17 +58,20 @@ class ErnieBotRetriever(BaseRetriever):
 
         if api_type == "aistudio" and access_token is None:
             raise ValueError("access_token cannot be empty when api_type is aistudio.")
-            
+
         if api_type == "qianfan" and (ak is None or sk is None):
-            raise ValueError("ak and sk cannot be empty when api_type is qianfan.")            
+            raise ValueError("ak and sk cannot be empty when api_type is qianfan.")
 
         self.model_name = model_name
         self.config = config
-        
-    def generate_vector_database(self, text_list, 
+
+    def generate_vector_database(
+        self,
+        text_list,
         block_size=300,
         separators=["\t", "\n", "。", "\n\n", ""],
-        sleep_time=0.5):
+        sleep_time=0.5,
+    ):
         """
         args:
         return:
@@ -112,7 +116,7 @@ class ErnieBotRetriever(BaseRetriever):
     def encode_vector_store_to_bytes(self, vectorstore):
         vectorstore = self.encode_vector_store(vectorstore.serialize_to_bytes())
         return vectorstore
-    
+
     def decode_vector_store_from_bytes(self, vectorstore):
         if not self.is_vector_store(vectorstore):
             raise ValueError("The retrieved vectorstore is not for PaddleX.")
@@ -127,10 +131,10 @@ class ErnieBotRetriever(BaseRetriever):
             embeddings = QianfanEmbeddingsEndpoint(qianfan_ak=ak, qianfan_sk=sk)
         else:
             raise ValueError(f"Unsupported api_type: {api_type}")
-        vectorstore = vectorstores.FAISS.deserialize_from_bytes(
-            self.decode_vector_store(vector), embeddings
+        vector = vectorstores.FAISS.deserialize_from_bytes(
+            self.decode_vector_store(vectorstore), embeddings
         )
-        return vectorstore
+        return vector
 
     def similarity_retrieval(self, query_text_list, vectorstore, sleep_time=0.5):
         # 根据提问匹配上下文
@@ -145,4 +149,3 @@ class ErnieBotRetriever(BaseRetriever):
         C = list(set(C))
         all_C = " ".join(C)
         return all_C
-        
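This hunk also fixes a genuine bug: the old `decode_vector_store_from_bytes` deserialized an undefined `vector` variable and returned the still-encoded input. A hedged end-to-end sketch of the retriever API as it now stands; credentials and texts are placeholders:

    retriever = ErnieBotRetriever({
        "model_name": "ernie-3.5",
        "api_type": "qianfan",
        "ak": "<your-ak>",
        "sk": "<your-sk>",
    })
    vectorstore = retriever.generate_vector_database(["合同编号:123", "甲方:某公司"])
    encoded = retriever.encode_vector_store_to_bytes(vectorstore)  # base64 + prefix
    restored = retriever.decode_vector_store_from_bytes(encoded)   # FAISS store again
    context = retriever.similarity_retrieval(["印章上公司"], restored)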

+ 1 - 1
paddlex/inference/pipelines_new/doc_preprocessor/__init__.py

@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .pipeline import DocPreprocessorPipeline
+from .pipeline import DocPreprocessorPipeline

+ 51 - 37
paddlex/inference/pipelines_new/doc_preprocessor/pipeline.py

@@ -20,35 +20,40 @@ from .result import DocPreprocessorResult
 ########## [TODO]后续需要更新路径
 from ...components.transforms import ReadImage
 
+
 class DocPreprocessorPipeline(BasePipeline):
     """Doc Preprocessor Pipeline"""
 
     entities = "doc_preprocessor"
-    def __init__(self,
-        config,        
+
+    def __init__(
+        self,
+        config,
         device=None,
-        pp_option=None, 
+        pp_option=None,
         use_hpip: bool = False,
-        hpi_params: Optional[Dict[str, Any]] = None):
-        super().__init__(device=device, pp_option=pp_option, 
-            use_hpip=use_hpip, hpi_params=hpi_params)
-        
+        hpi_params: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(
+            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
+        )
+
         self.use_doc_orientation_classify = True
-        if 'use_doc_orientation_classify' in config:
-            self.use_doc_orientation_classify = config['use_doc_orientation_classify']
+        if "use_doc_orientation_classify" in config:
+            self.use_doc_orientation_classify = config["use_doc_orientation_classify"]
 
         self.use_doc_unwarping = True
-        if 'use_doc_unwarping' in config:
-            self.use_doc_unwarping = config['use_doc_unwarping']
-        
+        if "use_doc_unwarping" in config:
+            self.use_doc_unwarping = config["use_doc_unwarping"]
+
         if self.use_doc_orientation_classify:
-            doc_ori_classify_config = config['SubModules']["DocOrientationClassify"]
+            doc_ori_classify_config = config["SubModules"]["DocOrientationClassify"]
             self.doc_ori_classify_model = self.create_model(doc_ori_classify_config)
 
         if self.use_doc_unwarping:
-            doc_unwarping_config = config['SubModules']["DocUnwarping"]
+            doc_unwarping_config = config["SubModules"]["DocUnwarping"]
             self.doc_unwarping_model = self.create_model(doc_unwarping_config)
-        
+
         self.img_reader = ReadImage(format="BGR")
 
     def rotate_image(self, image_array, rotate_angle):
@@ -59,42 +64,49 @@ class DocPreprocessorPipeline(BasePipeline):
         return rotate(image_array, rotate_angle, reshape=True)
 
     def check_input_params(self, input_params):
-        
-        if input_params['use_doc_orientation_classify'] and \
-            not self.use_doc_orientation_classify:
-            raise ValueError("The model for doc orientation classify is not initialized.")
 
+        if (
+            input_params["use_doc_orientation_classify"]
+            and not self.use_doc_orientation_classify
+        ):
+            raise ValueError(
+                "The model for doc orientation classify is not initialized."
+            )
 
-        if input_params['use_doc_unwarping'] and \
-            not self.use_doc_unwarping:
+        if input_params["use_doc_unwarping"] and not self.use_doc_unwarping:
             raise ValueError("The model for doc unwarping is not initialized.")
-            
-        return 
 
-    def predict(self, input, 
+        return
+
+    def predict(
+        self,
+        input,
         use_doc_orientation_classify=True,
         use_doc_unwarping=False,
-        **kwargs):
+        **kwargs
+    ):
 
         if not isinstance(input, list):
             input_list = [input]
         else:
             input_list = input
 
-        input_params = {"use_doc_orientation_classify":use_doc_orientation_classify,
-            "use_doc_unwarping":use_doc_unwarping}
+        input_params = {
+            "use_doc_orientation_classify": use_doc_orientation_classify,
+            "use_doc_unwarping": use_doc_unwarping,
+        }
         self.check_input_params(input_params)
 
         img_id = 1
         for input in input_list:
             if isinstance(input, str):
-                image_array = next(self.img_reader(input))[0]['img']
+                image_array = next(self.img_reader(input))[0]["img"]
             else:
                 image_array = input
 
             assert len(image_array.shape) == 3
 
-            if input_params['use_doc_orientation_classify']:
+            if input_params["use_doc_orientation_classify"]:
                 pred = next(self.doc_ori_classify_model(image_array))
                 angle = int(pred["label_names"][0])
                 rot_img = self.rotate_image(image_array, angle)
@@ -102,16 +114,18 @@ class DocPreprocessorPipeline(BasePipeline):
                 angle = -1
                 rot_img = image_array
 
-            if input_params['use_doc_unwarping']:
-                output_img = next(self.doc_unwarping_model(rot_img))['doctr_img']
+            if input_params["use_doc_unwarping"]:
+                output_img = next(self.doc_unwarping_model(rot_img))["doctr_img"]
             else:
                 output_img = rot_img
 
-            single_img_res = {"input_image":image_array,
-                "input_params":input_params,
-                "angle":angle, 
-                "rot_img":rot_img, 
-                "output_img":output_img,
-                "img_id":img_id}
+            single_img_res = {
+                "input_image": image_array,
+                "input_params": input_params,
+                "angle": angle,
+                "rot_img": rot_img,
+                "output_img": output_img,
+                "img_id": img_id,
+            }
             img_id += 1
             yield DocPreprocessorResult(single_img_res)
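Each yielded `DocPreprocessorResult` carries the keys assembled above; downstream, `LayoutParsingPipeline` consumes only `output_img`. Reading it directly looks like this (the image path is reused from `api_examples`):

    for res in pipeline.predict(
        "./test_imgs/img_rot180_demo.jpg",
        use_doc_orientation_classify=True,
        use_doc_unwarping=False,
    ):
        corrected = res["output_img"]       # np.ndarray after rotation/unwarping
        print(res["img_id"], res["angle"])  # angle is -1 when classification is skipped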

+ 2 - 1
paddlex/inference/pipelines_new/doc_preprocessor/result.py

@@ -22,6 +22,7 @@ from PIL import Image, ImageDraw, ImageFont
 from ....utils.fonts import PINGFANG_FONT_FILE_PATH, create_font
 from ..components import CVResult
 
+
 class DocPreprocessorResult(CVResult):
 
     def save_to_img(self, save_path, *args, **kwargs):
@@ -48,4 +49,4 @@ class DocPreprocessorResult(CVResult):
             txt = txt_list[tno]
             font = create_font(txt, (w, 20), PINGFANG_FONT_FILE_PATH)
             draw_text.text([10 + w * tno, h + 2], txt, fill=(0, 0, 0), font=font)
-        return img_show
+        return img_show

+ 1 - 1
paddlex/inference/pipelines_new/layout_parsing/__init__.py

@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .pipeline import LayoutParsingPipeline
+from .pipeline import LayoutParsingPipeline

+ 108 - 84
paddlex/inference/pipelines_new/layout_parsing/pipeline.py

@@ -25,181 +25,205 @@ from .result import LayoutParsingResult
 ########## [TODO]后续需要更新路径
 from ...components.transforms import ReadImage
 
+
 class LayoutParsingPipeline(BasePipeline):
     """Layout Parsing Pipeline"""
 
     entities = "layout_parsing"
-    def __init__(self,
-        config,        
+
+    def __init__(
+        self,
+        config,
         device=None,
-        pp_option=None, 
+        pp_option=None,
         use_hpip: bool = False,
-        hpi_params: Optional[Dict[str, Any]] = None):
-        super().__init__(device=device, pp_option=pp_option, 
-            use_hpip=use_hpip, hpi_params=hpi_params)
-        
+        hpi_params: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(
+            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
+        )
+
         self.inintial_predictor(config)
 
         self.img_reader = ReadImage(format="BGR")
 
         self._crop_by_boxes = CropByBoxes()
-        
 
     def inintial_predictor(self, config):
-        layout_det_config = config['SubModules']["LayoutDetection"]
+        layout_det_config = config["SubModules"]["LayoutDetection"]
         self.layout_det_model = self.create_model(layout_det_config)
 
         self.use_doc_preprocessor = False
-        if 'use_doc_preprocessor' in config:
-            self.use_doc_preprocessor = config['use_doc_preprocessor']
+        if "use_doc_preprocessor" in config:
+            self.use_doc_preprocessor = config["use_doc_preprocessor"]
         if self.use_doc_preprocessor:
-            doc_preprocessor_config = config['SubPipelines']['DocPreprocessor']
-            self.doc_preprocessor_pipeline = self.create_pipeline(doc_preprocessor_config)
-        
+            doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
+            self.doc_preprocessor_pipeline = self.create_pipeline(
+                doc_preprocessor_config
+            )
+
         self.use_common_ocr = False
         if "use_common_ocr" in config:
-            self.use_common_ocr = config['use_common_ocr']
+            self.use_common_ocr = config["use_common_ocr"]
         if self.use_common_ocr:
-            common_ocr_config = config['SubPipelines']['CommonOCR']
+            common_ocr_config = config["SubPipelines"]["CommonOCR"]
             self.common_ocr_pipeline = self.create_pipeline(common_ocr_config)
-        
+
         self.use_seal_recognition = False
         if "use_seal_recognition" in config:
-            self.use_seal_recognition = config['use_seal_recognition']
+            self.use_seal_recognition = config["use_seal_recognition"]
         if self.use_seal_recognition:
-            seal_ocr_config = config['SubPipelines']['SealOCR']
-            self.seal_ocr_pipeline = self.create_pipeline(seal_ocr_config)            
-        
+            seal_ocr_config = config["SubPipelines"]["SealOCR"]
+            self.seal_ocr_pipeline = self.create_pipeline(seal_ocr_config)
+
         self.use_table_recognition = False
         if "use_table_recognition" in config:
-            self.use_table_recognition = config['use_table_recognition']
+            self.use_table_recognition = config["use_table_recognition"]
         if self.use_table_recognition:
-            table_structure_config = config['SubModules']['TableStructurePredictor']
+            table_structure_config = config["SubModules"]["TableStructurePredictor"]
             self.table_structure_model = self.create_model(table_structure_config)
             if not self.use_common_ocr:
-                common_ocr_config = config['SubPipelines']['OCR']
+                common_ocr_config = config["SubPipelines"]["OCR"]
                 self.common_ocr_pipeline = self.create_pipeline(common_ocr_config)
-        return 
+        return
 
     def get_text_paragraphs_ocr_res(self, overall_ocr_res, layout_det_res):
-        '''get ocr res of the text paragraphs'''
+        """get ocr res of the text paragraphs"""
         object_boxes = []
-        for box_info in layout_det_res['boxes']:
-            if box_info['label'].lower() in ['image', 'formula', 'table', 'seal']:
-                object_boxes.append(box_info['coordinate'])
+        for box_info in layout_det_res["boxes"]:
+            if box_info["label"].lower() in ["image", "formula", "table", "seal"]:
+                object_boxes.append(box_info["coordinate"])
         object_boxes = np.array(object_boxes)
         return get_sub_regions_ocr_res(overall_ocr_res, object_boxes, flag_within=False)
 
     def check_input_params(self, input_params):
 
-        if input_params['use_doc_preprocessor'] and not self.use_doc_preprocessor:
+        if input_params["use_doc_preprocessor"] and not self.use_doc_preprocessor:
             raise ValueError("The models for doc preprocessor are not initialized.")
 
-        if input_params['use_common_ocr'] and not self.use_common_ocr:
+        if input_params["use_common_ocr"] and not self.use_common_ocr:
             raise ValueError("The models for common OCR are not initialized.")
 
-        if input_params['use_seal_recognition'] and not self.use_seal_recognition:
+        if input_params["use_seal_recognition"] and not self.use_seal_recognition:
             raise ValueError("The models for seal recognition are not initialized.")
 
-        if input_params['use_table_recognition'] and not self.use_table_recognition:
+        if input_params["use_table_recognition"] and not self.use_table_recognition:
             raise ValueError("The models for table recognition are not initialized.")
 
         return
 
-    def predict(self, input, 
+    def predict(
+        self,
+        input,
         use_doc_orientation_classify=True,
         use_doc_unwarping=True,
         use_common_ocr=True,
         use_seal_recognition=True,
         use_table_recognition=True,
-        **kwargs):
+        **kwargs
+    ):
 
         if not isinstance(input, list):
             input_list = [input]
         else:
             input_list = input
-        
-        input_params = {"use_doc_preprocessor":self.use_doc_preprocessor,
-            "use_doc_orientation_classify":use_doc_orientation_classify,
-            "use_doc_unwarping":use_doc_unwarping,
-            "use_common_ocr":use_common_ocr,
-            "use_seal_recognition":use_seal_recognition,
-            "use_table_recognition":use_table_recognition}
-            
+
+        input_params = {
+            "use_doc_preprocessor": self.use_doc_preprocessor,
+            "use_doc_orientation_classify": use_doc_orientation_classify,
+            "use_doc_unwarping": use_doc_unwarping,
+            "use_common_ocr": use_common_ocr,
+            "use_seal_recognition": use_seal_recognition,
+            "use_table_recognition": use_table_recognition,
+        }
+
         if use_doc_orientation_classify or use_doc_unwarping:
-            input_params['use_doc_preprocessor'] = True
+            input_params["use_doc_preprocessor"] = True
 
         self.check_input_params(input_params)
 
         img_id = 1
         for input in input_list:
             if isinstance(input, str):
-                image_array = next(self.img_reader(input))[0]['img']
+                image_array = next(self.img_reader(input))[0]["img"]
             else:
                 image_array = input
 
             assert len(image_array.shape) == 3
 
-            if input_params['use_doc_preprocessor']:
-                doc_preprocessor_res = next(self.doc_preprocessor_pipeline(
-                    image_array, 
-                    use_doc_orientation_classify=use_doc_orientation_classify,
-                    use_doc_unwarping=use_doc_unwarping))
-                doc_preprocessor_image = doc_preprocessor_res['output_img']
-                doc_preprocessor_res['img_id'] = img_id
+            if input_params["use_doc_preprocessor"]:
+                doc_preprocessor_res = next(
+                    self.doc_preprocessor_pipeline(
+                        image_array,
+                        use_doc_orientation_classify=use_doc_orientation_classify,
+                        use_doc_unwarping=use_doc_unwarping,
+                    )
+                )
+                doc_preprocessor_image = doc_preprocessor_res["output_img"]
+                doc_preprocessor_res["img_id"] = img_id
             else:
                 doc_preprocessor_res = {}
                 doc_preprocessor_image = image_array
-            
 
             ########## [TODO]RT-DETR 检测结果有重复
             layout_det_res = next(self.layout_det_model(doc_preprocessor_image))
 
-            if input_params['use_common_ocr'] or input_params['use_table_recognition']:
+            if input_params["use_common_ocr"] or input_params["use_table_recognition"]:
                 overall_ocr_res = next(self.common_ocr_pipeline(doc_preprocessor_image))
-                overall_ocr_res['img_id'] = img_id
-                dt_boxes = convert_points_to_boxes(overall_ocr_res['dt_polys'])
-                overall_ocr_res['dt_boxes'] = dt_boxes
+                overall_ocr_res["img_id"] = img_id
+                dt_boxes = convert_points_to_boxes(overall_ocr_res["dt_polys"])
+                overall_ocr_res["dt_boxes"] = dt_boxes
             else:
                 overall_ocr_res = {}
-            
+
             text_paragraphs_ocr_res = {}
-            if input_params['use_common_ocr']:
+            if input_params["use_common_ocr"]:
                 text_paragraphs_ocr_res = self.get_text_paragraphs_ocr_res(
-                    overall_ocr_res, layout_det_res)
-                text_paragraphs_ocr_res['img_id'] = img_id
-            
+                    overall_ocr_res, layout_det_res
+                )
+                text_paragraphs_ocr_res["img_id"] = img_id
+
             table_res_list = []
-            if input_params['use_table_recognition']:
+            if input_params["use_table_recognition"]:
                 table_region_id = 1
-                for box_info in layout_det_res['boxes']:
-                    if box_info['label'].lower() in ['table']:
-                        crop_img_info = self._crop_by_boxes(doc_preprocessor_image, [box_info])
+                for box_info in layout_det_res["boxes"]:
+                    if box_info["label"].lower() in ["table"]:
+                        crop_img_info = self._crop_by_boxes(
+                            doc_preprocessor_image, [box_info]
+                        )
                         crop_img_info = crop_img_info[0]
-                        table_structure_pred = next(self.table_structure_model(
-                            crop_img_info['img']))
+                        table_structure_pred = next(
+                            self.table_structure_model(crop_img_info["img"])
+                        )
                         table_recognition_res = get_table_recognition_res(
-                            crop_img_info, table_structure_pred, overall_ocr_res)
-                        table_recognition_res['table_region_id'] = table_region_id
+                            crop_img_info, table_structure_pred, overall_ocr_res
+                        )
+                        table_recognition_res["table_region_id"] = table_region_id
                         table_region_id += 1
                         table_res_list.append(table_recognition_res)
-            
+
             seal_res_list = []
-            if input_params['use_seal_recognition']:
+            if input_params["use_seal_recognition"]:
                 seal_region_id = 1
-                for box_info in layout_det_res['boxes']:
-                    if box_info['label'].lower() in ['seal']:
-                        crop_img_info = self._crop_by_boxes(doc_preprocessor_image, [box_info])
+                for box_info in layout_det_res["boxes"]:
+                    if box_info["label"].lower() in ["seal"]:
+                        crop_img_info = self._crop_by_boxes(
+                            doc_preprocessor_image, [box_info]
+                        )
                         crop_img_info = crop_img_info[0]
-                        seal_ocr_res = next(self.seal_ocr_pipeline(crop_img_info['img']))
-                        seal_ocr_res['seal_region_id'] = seal_region_id
+                        seal_ocr_res = next(
+                            self.seal_ocr_pipeline(crop_img_info["img"])
+                        )
+                        seal_ocr_res["seal_region_id"] = seal_region_id
                         seal_region_id += 1
                         seal_res_list.append(seal_ocr_res)
-            
-            single_img_res = {"layout_det_res":layout_det_res,
-                "doc_preprocessor_res":doc_preprocessor_res,
-                "text_paragraphs_ocr_res":text_paragraphs_ocr_res,
-                "table_res_list":table_res_list,
-                "seal_res_list":seal_res_list,
-                "input_params":input_params}
+
+            single_img_res = {
+                "layout_det_res": layout_det_res,
+                "doc_preprocessor_res": doc_preprocessor_res,
+                "text_paragraphs_ocr_res": text_paragraphs_ocr_res,
+                "table_res_list": table_res_list,
+                "seal_res_list": seal_res_list,
+                "input_params": input_params,
+            }
             yield LayoutParsingResult(single_img_res)
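The per-image result bundles every sub-result; besides `save_results` (defined in `result.py` below), the pieces remain accessible by key. A small sketch, reusing the test image from `api_examples`:

    for res in pipeline.predict("./test_imgs/test_layout_parsing.jpg"):
        for table_res in res["table_res_list"]:
            print(table_res["pred_html"])  # HTML rebuilt by the table matcher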

+ 31 - 25
paddlex/inference/pipelines_new/layout_parsing/result.py

@@ -23,6 +23,7 @@ from PIL import Image, ImageDraw, ImageFont
 from ....utils.fonts import PINGFANG_FONT_FILE_PATH
 from ..components import CVResult, HtmlMixin, XlsxMixin
 
+
 class TableRecognitionResult(CVResult, HtmlMixin, XlsxMixin):
     def __init__(self, data):
         super().__init__(data)
@@ -31,7 +32,7 @@ class TableRecognitionResult(CVResult, HtmlMixin, XlsxMixin):
 
     def save_to_html(self, save_path, *args, **kwargs):
         if not str(save_path).lower().endswith(".html"):
-            save_path = save_path + "/res_table_%d.html" % self['table_region_id']
+            save_path = save_path + "/res_table_%d.html" % self["table_region_id"]
         super().save_to_html(save_path, *args, **kwargs)
 
     def _to_html(self):
@@ -39,7 +40,7 @@ class TableRecognitionResult(CVResult, HtmlMixin, XlsxMixin):
 
     def save_to_xlsx(self, save_path, *args, **kwargs):
         if not str(save_path).lower().endswith(".xlsx"):
-            save_path = save_path + "/res_table_%d.xlsx" % self['table_region_id']
+            save_path = save_path + "/res_table_%d.xlsx" % self["table_region_id"]
         super().save_to_xlsx(save_path, *args, **kwargs)
 
     def _to_xlsx(self):
@@ -47,51 +48,56 @@ class TableRecognitionResult(CVResult, HtmlMixin, XlsxMixin):
 
     def save_to_img(self, save_path, *args, **kwargs):
         if not str(save_path).lower().endswith((".jpg", ".png")):
-            ocr_save_path = save_path + "/res_table_ocr_%d.jpg" % self['table_region_id']
-            save_path = save_path + "/res_table_cell_%d.jpg" % self['table_region_id']
-        self['table_ocr_pred'].save_to_img(ocr_save_path)
+            ocr_save_path = (
+                save_path + "/res_table_ocr_%d.jpg" % self["table_region_id"]
+            )
+            save_path = save_path + "/res_table_cell_%d.jpg" % self["table_region_id"]
+        self["table_ocr_pred"].save_to_img(ocr_save_path)
         super().save_to_img(save_path, *args, **kwargs)
 
     def _to_img(self):
-        input_img = self['table_ocr_pred']['input_img'].copy()
-        cell_box_list = self['cell_box_list']
+        input_img = self["table_ocr_pred"]["input_img"].copy()
+        cell_box_list = self["cell_box_list"]
         for box in cell_box_list:
             x1, y1, x2, y2 = [int(pos) for pos in box]
             cv2.rectangle(input_img, (x1, y1), (x2, y2), (255, 0, 0), 2)
         return input_img
 
+
 class LayoutParsingResult(dict):
     def __init__(self, data):
         super().__init__(data)
-    
+
     def save_results(self, save_path):
         if not os.path.isdir(save_path):
             raise ValueError("The save path should be a dir.")
 
-        layout_det_res = self['layout_det_res']
+        layout_det_res = self["layout_det_res"]
         save_img_path = save_path + "/layout_det_result.jpg"
         layout_det_res.save_to_img(save_img_path)
 
-        input_params = self['input_params']
-        if input_params['use_doc_preprocessor']:
+        input_params = self["input_params"]
+        if input_params["use_doc_preprocessor"]:
             save_img_path = save_path + "/doc_preprocessor_result.jpg"
-            self['doc_preprocessor_res'].save_to_img(save_img_path)
-        
-        if input_params['use_common_ocr']:
+            self["doc_preprocessor_res"].save_to_img(save_img_path)
+
+        if input_params["use_common_ocr"]:
             save_img_path = save_path + "/text_paragraphs_ocr_result.jpg"
-            self['text_paragraphs_ocr_res'].save_to_img(save_img_path)
+            self["text_paragraphs_ocr_res"].save_to_img(save_img_path)
 
-        if input_params['use_table_recognition']:
-            for tno in range(len(self['table_res_list'])):
-                table_res = self['table_res_list'][tno]
+        if input_params["use_table_recognition"]:
+            for tno in range(len(self["table_res_list"])):
+                table_res = self["table_res_list"][tno]
                 table_res.save_to_img(save_path)
                 table_res.save_to_html(save_path)
                 table_res.save_to_xlsx(save_path)
-        
-        if input_params['use_seal_recognition']:
-            for sno in range(len(self['seal_res_list'])):
-                seal_res = self['seal_res_list'][sno]
-                save_img_path = save_path + "/seal_%d_recognition_result.jpg" % seal_res['seal_region_id']
-                seal_res.save_to_img(save_img_path)          
-        return
 
+        if input_params["use_seal_recognition"]:
+            for sno in range(len(self["seal_res_list"])):
+                seal_res = self["seal_res_list"][sno]
+                save_img_path = (
+                    save_path
+                    + "/seal_%d_recognition_result.jpg" % seal_res["seal_region_id"]
+                )
+                seal_res.save_to_img(save_img_path)
+        return

+ 34 - 28
paddlex/inference/pipelines_new/layout_parsing/table_recognition_post_processing.py

@@ -16,6 +16,7 @@ from .utils import convert_points_to_boxes, get_sub_regions_ocr_res
 import numpy as np
 from .result import TableRecognitionResult
 
+
 def get_ori_image_coordinate(x, y, box_list):
     """
     get the original coordinate from Cropped image to Original image.
@@ -35,19 +36,24 @@ def get_ori_image_coordinate(x, y, box_list):
     ori_box_list = offset + box_list
     return ori_box_list
 
-def convert_table_structure_pred_bbox(table_structure_pred, 
-    crop_start_point, img_shape):
 
-    cell_points_list = table_structure_pred['bbox']
-    ori_cell_points_list = get_ori_image_coordinate(crop_start_point[0], 
-        crop_start_point[1], cell_points_list)
+def convert_table_structure_pred_bbox(
+    table_structure_pred, crop_start_point, img_shape
+):
+
+    cell_points_list = table_structure_pred["bbox"]
+    ori_cell_points_list = get_ori_image_coordinate(
+        crop_start_point[0], crop_start_point[1], cell_points_list
+    )
     ori_cell_points_list = np.reshape(ori_cell_points_list, (-1, 4, 2))
     cell_box_list = convert_points_to_boxes(ori_cell_points_list)
     img_height, img_width = img_shape
-    cell_box_list = np.clip(cell_box_list, 0, 
-        [img_width, img_height, img_width, img_height])
-    table_structure_pred['cell_box_list'] = cell_box_list
-    return 
+    cell_box_list = np.clip(
+        cell_box_list, 0, [img_width, img_height, img_width, img_height]
+    )
+    table_structure_pred["cell_box_list"] = cell_box_list
+    return
+
 
 def distance(box_1, box_2):
     """
@@ -67,6 +73,7 @@ def distance(box_1, box_2):
     dis_3 = abs(x4 - x2) + abs(y4 - y2)
     return dis + min(dis_2, dis_3)
 
+
 def compute_iou(rec1, rec2):
     """
     computing IoU
@@ -96,6 +103,7 @@ def compute_iou(rec1, rec2):
         intersect = (right_line - left_line) * (bottom_line - top_line)
         return (intersect / (sum_area - intersect)) * 1.0
 
+
 def match_table_and_ocr(cell_box_list, ocr_dt_boxes):
     """
     match table and ocr
@@ -112,18 +120,19 @@ def match_table_and_ocr(cell_box_list, ocr_dt_boxes):
         ocr_box = ocr_box.astype(np.float32)
         distances = []
         for j, table_box in enumerate(cell_box_list):
-            distances.append((distance(table_box, ocr_box), 
-                1.0 - compute_iou(table_box, ocr_box)))  # compute iou and l1 distance
+            distances.append(
+                (distance(table_box, ocr_box), 1.0 - compute_iou(table_box, ocr_box))
+            )  # compute iou and l1 distance
         sorted_distances = distances.copy()
         # select det box by iou and l1 distance
-        sorted_distances = sorted(
-            sorted_distances, key=lambda item: (item[1], item[0]))
+        sorted_distances = sorted(sorted_distances, key=lambda item: (item[1], item[0]))
         if distances.index(sorted_distances[0]) not in matched.keys():
             matched[distances.index(sorted_distances[0])] = [i]
         else:
             matched[distances.index(sorted_distances[0])].append(i)
     return matched
 
+
 def get_html_result(matched_index, ocr_contents, pred_structures):
     pred_html = []
     td_index = 0
@@ -155,10 +164,7 @@ def get_html_result(matched_index, ocr_contents, pred_structures):
                             content = content[:-4]
                         if len(content) == 0:
                             continue
-                        if (
-                            i != len(matched_index[td_index]) - 1
-                            and " " != content[-1]
-                        ):
+                        if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
                             content += " "
                     pred_html.extend(content)
                 if b_with:
@@ -175,18 +181,18 @@ def get_html_result(matched_index, ocr_contents, pred_structures):
     html += "".join(end_structure)
     return html
 
+
 def get_table_recognition_res(crop_img_info, table_structure_pred, overall_ocr_res):
-    '''get_table_recognition_res'''
+    """get_table_recognition_res"""
 
-    table_box = np.array([crop_img_info['box']])
+    table_box = np.array([crop_img_info["box"]])
     table_ocr_pred = get_sub_regions_ocr_res(overall_ocr_res, table_box)
 
     crop_start_point = [table_box[0][0], table_box[0][1]]
-    img_shape = overall_ocr_res['input_img'].shape[0:2]
+    img_shape = overall_ocr_res["input_img"].shape[0:2]
+
+    convert_table_structure_pred_bbox(table_structure_pred, crop_start_point, img_shape)
 
-    convert_table_structure_pred_bbox(table_structure_pred, 
-        crop_start_point, img_shape)
-    
     structures = table_structure_pred["structure"]
     cell_box_list = table_structure_pred["cell_box_list"]
     ocr_dt_boxes = table_ocr_pred["dt_boxes"]
@@ -195,9 +201,9 @@ def get_table_recognition_res(crop_img_info, table_structure_pred, overall_ocr_r
     matched_index = match_table_and_ocr(cell_box_list, ocr_dt_boxes)
     pred_html = get_html_result(matched_index, ocr_text_res, structures)
 
-    single_img_res = {"cell_box_list":cell_box_list, 
-        "table_ocr_pred":table_ocr_pred,
-        "pred_html":pred_html}
+    single_img_res = {
+        "cell_box_list": cell_box_list,
+        "table_ocr_pred": table_ocr_pred,
+        "pred_html": pred_html,
+    }
     return TableRecognitionResult(single_img_res)
-
-

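The matching rule implemented by `distance`, `compute_iou`, and `match_table_and_ocr` above can be exercised in isolation: each OCR box goes to the table cell that maximizes IoU, with an L1 corner distance as tie-breaker (sort key `(1 - IoU, distance)`). A self-contained sketch with made-up boxes, using a simplified analogue of `distance()`:

    import numpy as np

    def iou(a, b):
        # intersection-over-union of two [x1, y1, x2, y2] boxes
        x1, y1 = max(a[0], b[0]), max(a[1], b[1])
        x2, y2 = min(a[2], b[2]), min(a[3], b[3])
        inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
        union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
        return inter / union if union > 0 else 0.0

    def corner_l1(a, b):
        # summed L1 distance over the box coordinates (simplified distance())
        return sum(abs(float(a[i]) - float(b[i])) for i in range(4))

    cells = np.array([[0, 0, 50, 20], [50, 0, 100, 20]], dtype=float)  # two cells
    ocr_box = np.array([52, 2, 95, 18], dtype=float)  # one recognized text box

    best = min(
        range(len(cells)),
        key=lambda j: (1.0 - iou(cells[j], ocr_box), corner_l1(cells[j], ocr_box)),
    )
    print(best)  # -> 1: the text box is matched to the second cell
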
+ 19 - 17
paddlex/inference/pipelines_new/layout_parsing/utils.py

@@ -12,14 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__all__ = [
-    "convert_points_to_boxes",
-    "get_sub_regions_ocr_res"
-]
+__all__ = ["convert_points_to_boxes", "get_sub_regions_ocr_res"]
 
 import numpy as np
 import copy
 
+
 def convert_points_to_boxes(dt_polys):
     if len(dt_polys) > 0:
         dt_polys_tmp = dt_polys.copy()
@@ -34,8 +32,9 @@ def convert_points_to_boxes(dt_polys):
         dt_boxes = np.array([])
     return dt_boxes
 
+
 def get_overlap_boxes_idx(src_boxes, ref_boxes):
-    '''get overlap boxes idx''' 
+    """get overlap boxes idx"""
     match_idx_list = []
     src_boxes_num = len(src_boxes)
     if src_boxes_num > 0 and len(ref_boxes) > 0:
@@ -48,9 +47,10 @@ def get_overlap_boxes_idx(src_boxes, ref_boxes):
             pub_w = x2 - x1
             pub_h = y2 - y1
             match_idx = np.where((pub_w > 3) & (pub_h > 3))[0]
-            match_idx_list.extend(match_idx)                                   
+            match_idx_list.extend(match_idx)
     return match_idx_list
 
+
 def get_sub_regions_ocr_res(overall_ocr_res, object_boxes, flag_within=True):
     """
     :param flag_within: True (within the object regions), False (outside the object regions)
@@ -58,14 +58,14 @@ def get_sub_regions_ocr_res(overall_ocr_res, object_boxes, flag_within=True):
     """
 
     sub_regions_ocr_res = copy.deepcopy(overall_ocr_res)
-    sub_regions_ocr_res['input_img'] = overall_ocr_res['input_img']
-    sub_regions_ocr_res['img_id'] = -1
-    sub_regions_ocr_res['dt_polys'] = []
-    sub_regions_ocr_res['rec_text'] = []
-    sub_regions_ocr_res['rec_score'] = []
-    sub_regions_ocr_res['dt_boxes'] = []
+    sub_regions_ocr_res["input_img"] = overall_ocr_res["input_img"]
+    sub_regions_ocr_res["img_id"] = -1
+    sub_regions_ocr_res["dt_polys"] = []
+    sub_regions_ocr_res["rec_text"] = []
+    sub_regions_ocr_res["rec_score"] = []
+    sub_regions_ocr_res["dt_boxes"] = []
 
-    overall_text_boxes = overall_ocr_res['dt_boxes']
+    overall_text_boxes = overall_ocr_res["dt_boxes"]
     match_idx_list = get_overlap_boxes_idx(overall_text_boxes, object_boxes)
     match_idx_list = list(set(match_idx_list))
     for box_no in range(len(overall_text_boxes)):
@@ -80,8 +80,10 @@ def get_sub_regions_ocr_res(overall_ocr_res, object_boxes, flag_within=True):
             else:
                 flag_match = False
         if flag_match:
-            sub_regions_ocr_res['dt_polys'].append(overall_ocr_res['dt_polys'][box_no])
-            sub_regions_ocr_res['rec_text'].append(overall_ocr_res['rec_text'][box_no])
-            sub_regions_ocr_res['rec_score'].append(overall_ocr_res['rec_score'][box_no])
-            sub_regions_ocr_res['dt_boxes'].append(overall_ocr_res['dt_boxes'][box_no])
+            sub_regions_ocr_res["dt_polys"].append(overall_ocr_res["dt_polys"][box_no])
+            sub_regions_ocr_res["rec_text"].append(overall_ocr_res["rec_text"][box_no])
+            sub_regions_ocr_res["rec_score"].append(
+                overall_ocr_res["rec_score"][box_no]
+            )
+            sub_regions_ocr_res["dt_boxes"].append(overall_ocr_res["dt_boxes"][box_no])
     return sub_regions_ocr_res

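For intuition, the helpers above reduce each detected 4-point polygon to an axis-aligned bounding box (presumably via a per-polygon min/max; the middle of `convert_points_to_boxes` is elided in this hunk) and keep a text box only when its overlap with some reference region is more than 3 px wide and tall. A toy numpy sketch of both steps:

    import numpy as np

    dt_polys = np.array([[[10, 10], [60, 12], [58, 30], [9, 28]]])  # one skewed quad
    dt_boxes = np.concatenate([dt_polys.min(axis=1), dt_polys.max(axis=1)], axis=1)
    print(dt_boxes)  # [[ 9 10 60 30]] -- axis-aligned [x1, y1, x2, y2]

    ref = np.array([0, 0, 40, 40])  # one layout region, e.g. a table box
    x1 = np.maximum(dt_boxes[:, 0], ref[0])
    y1 = np.maximum(dt_boxes[:, 1], ref[1])
    x2 = np.minimum(dt_boxes[:, 2], ref[2])
    y2 = np.minimum(dt_boxes[:, 3], ref[3])
    keep = (x2 - x1 > 3) & (y2 - y1 > 3)  # the >3 px rule in get_overlap_boxes_idx
    print(keep)  # [ True]
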
+ 1 - 1
paddlex/inference/pipelines_new/ocr/__init__.py

@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .pipeline import OCRPipeline
+from .pipeline import OCRPipeline

+ 28 - 19
paddlex/inference/pipelines_new/ocr/pipeline.py

@@ -20,33 +20,38 @@ from .result import OCRResult
 ########## [TODO] The import path needs to be updated later
 from ...components.transforms import ReadImage
 
+
 class OCRPipeline(BasePipeline):
     """OCR Pipeline"""
 
     entities = "OCR"
-    def __init__(self,
-        config,        
+
+    def __init__(
+        self,
+        config,
         device=None,
-        pp_option=None, 
+        pp_option=None,
         use_hpip: bool = False,
-        hpi_params: Optional[Dict[str, Any]] = None):
-        super().__init__(device=device, pp_option=pp_option, 
-            use_hpip=use_hpip, hpi_params=hpi_params)
-        
-        text_det_model_config = config['SubModules']["TextDetection"]
+        hpi_params: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(
+            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
+        )
+
+        text_det_model_config = config["SubModules"]["TextDetection"]
         self.text_det_model = self.create_model(text_det_model_config)
 
-        text_rec_model_config = config['SubModules']["TextRecognition"]
+        text_rec_model_config = config["SubModules"]["TextRecognition"]
         self.text_rec_model = self.create_model(text_rec_model_config)
 
-        self.text_type = config['text_type']
+        self.text_type = config["text_type"]
 
         self._sort_quad_boxes = SortQuadBoxes()
 
         if self.text_type == "common":
-            self._crop_by_polys = CropByPolys(det_box_type = "quad")
+            self._crop_by_polys = CropByPolys(det_box_type="quad")
         elif self.text_type == "seal":
-            self._crop_by_polys = CropByPolys(det_box_type = "poly")
+            self._crop_by_polys = CropByPolys(det_box_type="poly")
         else:
             raise ValueError("Unsupported text type {}".format(self.text_type))
 
@@ -60,7 +65,7 @@ class OCRPipeline(BasePipeline):
         img_id = 1
         for input in input_list:
             if isinstance(input, str):
-                image_array = next(self.img_reader(input))[0]['img']
+                image_array = next(self.img_reader(input))[0]["img"]
             else:
                 image_array = input
 
@@ -68,16 +73,20 @@ class OCRPipeline(BasePipeline):
 
             det_res = next(self.text_det_model(image_array))
 
-            dt_polys = det_res['dt_polys']
-            dt_scores = det_res['dt_scores']
+            dt_polys = det_res["dt_polys"]
+            dt_scores = det_res["dt_scores"]
 
             ########## [TODO] Confirm the filtering thresholds of the detection and recognition modules
 
             if self.text_type == "common":
                 dt_polys = self._sort_quad_boxes(dt_polys)
-                
-            single_img_res = {'input_img':image_array, 'dt_polys':dt_polys, \
-                "img_id":img_id, "text_type":self.text_type}
+
+            single_img_res = {
+                "input_img": image_array,
+                "dt_polys": dt_polys,
+                "img_id": img_id,
+                "text_type": self.text_type,
+            }
             img_id += 1
             single_img_res["rec_text"] = []
             single_img_res["rec_score"] = []
@@ -86,7 +95,7 @@ class OCRPipeline(BasePipeline):
 
                 ########## [TODO] update in future
                 for sub_img in all_subs_of_img:
-                    sub_img['input'] = sub_img['img']
+                    sub_img["input"] = sub_img["img"]
                 ##########
 
                 for rec_res in self.text_rec_model(all_subs_of_img):

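A hedged usage sketch for the refactored pipeline. The `SubModules` and `text_type` keys mirror what `__init__` reads above; the concrete model fields and the `predict` entry point are assumptions (the enclosing method name is not visible in this hunk, though sibling pipelines in this change are driven through `predict`):

    from paddlex.inference.pipelines_new.ocr.pipeline import OCRPipeline

    config = {
        "text_type": "common",  # "seal" switches cropping to polygon mode
        "SubModules": {  # model fields below are assumed, not from this diff
            "TextDetection": {"model_name": "PP-OCRv4_mobile_det"},
            "TextRecognition": {"model_name": "PP-OCRv4_mobile_rec"},
        },
    }
    pipeline = OCRPipeline(config)
    for res in pipeline.predict("sample.jpg"):
        print(res["rec_text"], res["rec_score"])
        res.save_to_img("./output/ocr_result.jpg")
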
+ 3 - 2
paddlex/inference/pipelines_new/ocr/result.py

@@ -22,6 +22,7 @@ from PIL import Image, ImageDraw, ImageFont
 from ....utils.fonts import PINGFANG_FONT_FILE_PATH
 from ..components import CVResult
 
+
 class OCRResult(CVResult):
     def save_to_img(self, save_path, *args, **kwargs):
         if not str(save_path).lower().endswith((".jpg", ".png")):
@@ -61,7 +62,7 @@ class OCRResult(CVResult):
         boxes = self["dt_polys"]
         txts = self["rec_text"]
         scores = self["rec_score"]
-        image = self['input_img']
+        image = self["input_img"]
         h, w = image.shape[0:2]
         image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         img_left = Image.fromarray(image_rgb)
@@ -157,4 +158,4 @@ def create_font(txt, sz, font_path):
     if length > sz[0]:
         font_size = int(font_size * sz[0] / length)
         font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
-    return font
+    return font

+ 1 - 1
paddlex/inference/pipelines_new/pp_chatocrv3_doc/__init__.py

@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .pipeline import PP_ChatOCRv3_doc_Pipeline
+from .pipeline import PP_ChatOCRv3_doc_Pipeline

+ 110 - 86
paddlex/inference/pipelines_new/pp_chatocrv3_doc/pipeline.py

@@ -26,81 +26,91 @@ from ...components.transforms import ReadImage
 
 import json
 
+from ....utils import logging
+
+
 class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
     """PP-ChatOCRv3-doc Pipeline"""
 
     entities = "PP-ChatOCRv3-doc"
-    def __init__(self,
+
+    def __init__(
+        self,
         config,
         device=None,
-        pp_option=None, 
+        pp_option=None,
         use_hpip: bool = False,
-        hpi_params: Optional[Dict[str, Any]] = None):
-        super().__init__(device=device, pp_option=pp_option, 
-            use_hpip=use_hpip, hpi_params=hpi_params)
-        
+        hpi_params: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(
+            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
+        )
+
         self.inintial_predictor(config)
 
         self.img_reader = ReadImage(format="BGR")
-        
+
     def inintial_predictor(self, config):
         # layout_parsing_config = config['SubPipelines']["LayoutParser"]
         # self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
 
-        chat_bot_config = config['SubModules']['LLM_Chat']
+        chat_bot_config = config["SubModules"]["LLM_Chat"]
         self.chat_bot = self.create_chat_bot(chat_bot_config)
 
-        retriever_config = config['SubModules']['LLM_Retriever']
+        retriever_config = config["SubModules"]["LLM_Retriever"]
         self.retriever = self.create_retriever(retriever_config)
 
-        text_pe_config = config['SubModules']['PromptEngneering']['KIE_CommonText']
+        text_pe_config = config["SubModules"]["PromptEngneering"]["KIE_CommonText"]
         self.text_pe = self.create_prompt_engeering(text_pe_config)
-        
-        table_pe_config = config['SubModules']['PromptEngneering']['KIE_Table']
+
+        table_pe_config = config["SubModules"]["PromptEngneering"]["KIE_Table"]
         self.table_pe = self.create_prompt_engeering(table_pe_config)
 
-        return 
+        return
 
     def decode_visual_result(self, layout_parsing_result):
-        text_paragraphs_ocr_res = layout_parsing_result['text_paragraphs_ocr_res']
-        seal_res_list = layout_parsing_result['seal_res_list']
+        text_paragraphs_ocr_res = layout_parsing_result["text_paragraphs_ocr_res"]
+        seal_res_list = layout_parsing_result["seal_res_list"]
         normal_text_dict = {}
         layout_type = "text"
-        for text in text_paragraphs_ocr_res['rec_text']:
+        for text in text_paragraphs_ocr_res["rec_text"]:
             if layout_type not in normal_text_dict:
                 normal_text_dict[layout_type] = text
             else:
                 normal_text_dict[layout_type] += f"\n {text}"
-        
+
         layout_type = "seal"
         for seal_res in seal_res_list:
-            for text in seal_res['rec_text']:
+            for text in seal_res["rec_text"]:
                 if layout_type not in normal_text_dict:
                     normal_text_dict[layout_type] = text
                 else:
                     normal_text_dict[layout_type] += f"\n {text}"
 
-        table_res_list = layout_parsing_result['table_res_list']
+        table_res_list = layout_parsing_result["table_res_list"]
         table_text_list = []
         table_html_list = []
         for table_res in table_res_list:
-            table_html_list.append(table_res['pred_html'])
-            single_table_text = " ".join(table_res["table_ocr_pred"]['rec_text'])
+            table_html_list.append(table_res["pred_html"])
+            single_table_text = " ".join(table_res["table_ocr_pred"]["rec_text"])
             table_text_list.append(single_table_text)
 
         visual_info = {}
-        visual_info['normal_text_dict'] = normal_text_dict
-        visual_info['table_text_list'] = table_text_list
-        visual_info['table_html_list'] = table_html_list
+        visual_info["normal_text_dict"] = normal_text_dict
+        visual_info["table_text_list"] = table_text_list
+        visual_info["table_html_list"] = table_html_list
         return VisualInfoResult(visual_info)
 
-    def visual_predict(self, input,
+    def visual_predict(
+        self,
+        input,
         use_doc_orientation_classify=True,
         use_doc_unwarping=True,
         use_common_ocr=True,
         use_seal_recognition=True,
         use_table_recognition=True,
-        **kwargs):
+        **kwargs,
+    ):
 
         if not isinstance(input, list):
             input_list = [input]
@@ -110,24 +120,29 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
         img_id = 1
         for input in input_list:
             if isinstance(input, str):
-                image_array = next(self.img_reader(input))[0]['img']
+                image_array = next(self.img_reader(input))[0]["img"]
             else:
                 image_array = input
 
             assert len(image_array.shape) == 3
 
-            layout_parsing_result = next(self.layout_parsing_pipeline.predict(
-                image_array,
-                use_doc_orientation_classify=use_doc_orientation_classify,
-                use_doc_unwarping=use_doc_unwarping,
-                use_common_ocr=use_common_ocr,
-                use_seal_recognition=use_seal_recognition,
-                use_table_recognition=use_table_recognition))
-            
+            layout_parsing_result = next(
+                self.layout_parsing_pipeline.predict(
+                    image_array,
+                    use_doc_orientation_classify=use_doc_orientation_classify,
+                    use_doc_unwarping=use_doc_unwarping,
+                    use_common_ocr=use_common_ocr,
+                    use_seal_recognition=use_seal_recognition,
+                    use_table_recognition=use_table_recognition,
+                )
+            )
+
             visual_info = self.decode_visual_result(layout_parsing_result)
 
-            visual_predict_res = {"layout_parsing_result":layout_parsing_result,
-                "visual_info":visual_info}
+            visual_predict_res = {
+                "layout_parsing_result": layout_parsing_result,
+                "visual_info": visual_info,
+            }
             yield visual_predict_res
 
     def save_visual_info_list(self, visual_info, save_path):
@@ -139,7 +154,7 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
         with open(save_path, "w") as fout:
             fout.write(json.dumps(visual_info_list, ensure_ascii=False) + "\n")
         return
-    
+
     def load_visual_info_list(self, data_path):
         with open(data_path, "r") as fin:
             data = fin.readline()
@@ -151,27 +166,27 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
         all_table_text_list = []
         all_table_html_list = []
         for single_visual_info in visual_info_list:
-            normal_text_dict = single_visual_info['normal_text_dict']
-            table_text_list = single_visual_info['table_text_list']
-            table_html_list = single_visual_info['table_html_list']
+            normal_text_dict = single_visual_info["normal_text_dict"]
+            table_text_list = single_visual_info["table_text_list"]
+            table_html_list = single_visual_info["table_html_list"]
             all_normal_text_list.append(normal_text_dict)
             all_table_text_list.extend(table_text_list)
             all_table_html_list.extend(table_html_list)
         return all_normal_text_list, all_table_text_list, all_table_html_list
 
-    def build_vector(self, visual_info, 
-        min_characters=3500,
-        llm_request_interval=1.0):
+    def build_vector(self, visual_info, min_characters=3500, llm_request_interval=1.0):
 
         if not isinstance(visual_info, list):
             visual_info_list = [visual_info]
         else:
             visual_info_list = visual_info
-        
+
         all_visual_info = self.merge_visual_info_list(visual_info_list)
         all_normal_text_list, all_table_text_list, all_table_html_list = all_visual_info
 
-        all_normal_text_str = "".join(["\n".join(e.values()) for e in all_normal_text_list])
+        all_normal_text_str = "".join(
+            ["\n".join(e.values()) for e in all_normal_text_list]
+        )
         vector_info = {}
 
         all_items = []
@@ -180,12 +195,11 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
                 all_items += [f"{type}:{text}"]
 
         if len(all_normal_text_str) > min_characters:
-            vector_info['flag_too_short_text'] = False
-            vector_info['vector'] = self.retriever.generate_vector_database(
-                all_items)
+            vector_info["flag_too_short_text"] = False
+            vector_info["vector"] = self.retriever.generate_vector_database(all_items)
         else:
-            vector_info['flag_too_short_text'] = True  
-            vector_info['vector'] = all_items
+            vector_info["flag_too_short_text"] = True
+            vector_info["vector"] = all_items
         return vector_info
 
     def format_key(self, key_list):
@@ -238,12 +252,13 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             matches = re.findall(pattern, str(results))
             if len(matches) > 0:
                 llm_result = {k: v for k, v in matches}
-                return llm_result 
+                return llm_result
             else:
-                return {}     
+                return {}
 
-    def generate_and_merge_chat_results(self, prompt, key_list,
-        final_results, failed_results):
+    def generate_and_merge_chat_results(
+        self, prompt, key_list, final_results, failed_results
+    ):
 
         llm_result = self.chat_bot.generate_chat_results(prompt)
         llm_result = self.fix_llm_result_format(llm_result)
@@ -252,22 +267,24 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             if value not in failed_results and key in key_list:
                 key_list.remove(key)
                 final_results[key] = value
-        return 
-        
+        return
 
-    def chat(self, visual_info, 
-        key_list, 
+    def chat(
+        self,
+        visual_info,
+        key_list,
         vector_info,
         text_task_description=None,
         text_output_format=None,
         text_rules_str=None,
         text_few_shot_demo_text_content=None,
-        text_few_shot_demo_key_value_list=None,        
+        text_few_shot_demo_key_value_list=None,
         table_task_description=None,
         table_output_format=None,
         table_rules_str=None,
         table_few_shot_demo_text_content=None,
-        table_few_shot_demo_key_value_list=None):
+        table_few_shot_demo_key_value_list=None,
+    ):
 
         key_list = self.format_key(key_list)
         if len(key_list) == 0:
@@ -277,7 +294,7 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             visual_info_list = [visual_info]
         else:
             visual_info_list = visual_info
-        
+
         all_visual_info = self.merge_visual_info_list(visual_info_list)
         all_normal_text_list, all_table_text_list, all_table_html_list = all_visual_info
 
@@ -289,36 +306,43 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
                 if len(key_list) == 0:
                     continue
 
-                prompt = self.table_pe.generate_prompt(table_info, 
-                    key_list, 
+                prompt = self.table_pe.generate_prompt(
+                    table_info,
+                    key_list,
                     task_description=table_task_description,
-                    output_format=table_output_format, 
-                    rules_str=table_rules_str, 
-                    few_shot_demo_text_content=table_few_shot_demo_text_content, 
-                    few_shot_demo_key_value_list=table_few_shot_demo_key_value_list)
-
-                self.generate_and_merge_chat_results(prompt, 
-                    key_list, final_results, failed_results)
-        
+                    output_format=table_output_format,
+                    rules_str=table_rules_str,
+                    few_shot_demo_text_content=table_few_shot_demo_text_content,
+                    few_shot_demo_key_value_list=table_few_shot_demo_key_value_list,
+                )
+
+                self.generate_and_merge_chat_results(
+                    prompt, key_list, final_results, failed_results
+                )
+
         if len(key_list) > 0:
             question_key_list = [f"抽取关键信息:{key}" for key in key_list]
-            vector = vector_info['vector']
-            if not vector_info['flag_too_short_text']:
+            vector = vector_info["vector"]
+            if not vector_info["flag_too_short_text"]:
                 related_text = self.retriever.similarity_retrieval(
-                    question_key_list, vector)
+                    question_key_list, vector
+                )
             else:
                 related_text = " ".join(vector)
-            
-            prompt = self.text_pe.generate_prompt(related_text, 
-                key_list, 
+
+            prompt = self.text_pe.generate_prompt(
+                related_text,
+                key_list,
                 task_description=text_task_description,
-                output_format=text_output_format, 
-                rules_str=text_rules_str, 
-                few_shot_demo_text_content=text_few_shot_demo_text_content, 
-                few_shot_demo_key_value_list=text_few_shot_demo_key_value_list)
-            
-            self.generate_and_merge_chat_results(prompt, 
-                key_list, final_results, failed_results)
+                output_format=text_output_format,
+                rules_str=text_rules_str,
+                few_shot_demo_text_content=text_few_shot_demo_text_content,
+                few_shot_demo_key_value_list=text_few_shot_demo_key_value_list,
+            )
+
+            self.generate_and_merge_chat_results(
+                prompt, key_list, final_results, failed_results
+            )
 
         return final_results
 
@@ -326,4 +350,4 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
         logging.error(
             "PP-ChatOCRv3-doc Pipeline do not support to call `predict()` directly! Please invoke `visual_predict`, `build_vector`, `chat` sequentially to obtain the result."
         )
-        return
+        return

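The error message in the final hunk spells out the intended calling order, so a minimal end-to-end sketch follows. Construction details and the extraction keys are illustrative, the config is assumed to match paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml from this change, and note that `inintial_predictor` still has the layout-parsing sub-pipeline commented out:

    # Sketch of the visual_predict -> build_vector -> chat sequence that the
    # error message above prescribes; `config` loading is assumed.
    pipeline = PP_ChatOCRv3_doc_Pipeline(config)

    visual_info_list = []
    for res in pipeline.visual_predict("contract_page.jpg"):
        visual_info_list.append(res["visual_info"])

    vector_info = pipeline.build_vector(visual_info_list, min_characters=3500)
    # Chinese keys, e.g. "甲方"/"乙方" (Party A / Party B), suit the ERNIE prompts.
    answers = pipeline.chat(visual_info_list, ["甲方", "乙方"], vector_info)
    print(answers)  # {key: extracted value, ...} after fix_llm_result_format
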
+ 4 - 2
paddlex/inference/pipelines_new/pp_chatocrv3_doc/result.py

@@ -22,11 +22,13 @@ from PIL import Image, ImageDraw, ImageFont
 from ....utils.fonts import PINGFANG_FONT_FILE_PATH
 from ..components import BaseResult
 
+
 class VisualInfoResult(BaseResult):
     """VisualInfoResult"""
-    
+
     pass
 
+
 # class VectorResult(BaseResult, Base64Mixin):
 #     """VisualInfoResult"""
 
@@ -41,4 +43,4 @@ class VisualInfoResult(BaseResult):
 
 
 # class ChatResult(BaseResult):
-#     """VisualInfoResult"""
+#     """VisualInfoResult"""

+ 3 - 0
paddlex/utils/fonts/__init__.py

@@ -18,10 +18,12 @@ from pathlib import Path
 import PIL
 from PIL import ImageFont
 
+
 def get_pingfang_file_path():
     """get pingfang font file path"""
     return (Path(__file__).parent / "PingFang-SC-Regular.ttf").resolve().as_posix()
 
+
 def create_font(txt, sz, font_path):
     """create font"""
     font_size = int(sz[1] * 0.8)
@@ -36,4 +38,5 @@ def create_font(txt, sz, font_path):
         font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
     return font
 
+
 PINGFANG_FONT_FILE_PATH = get_pingfang_file_path()
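
The sizing rule in `create_font` (start at 0.8 × region height, then shrink proportionally until the rendered text fits the region width) can be exercised directly; a small sketch using the exported font path:

    from PIL import Image, ImageDraw

    from paddlex.utils.fonts import PINGFANG_FONT_FILE_PATH, create_font

    # A 200 x 30 px region: the initial size is int(30 * 0.8) = 24 pt,
    # reduced if the rendered text would overflow 200 px.
    font = create_font("识别结果示例", (200, 30), PINGFANG_FONT_FILE_PATH)

    img = Image.new("RGB", (200, 30), "white")
    ImageDraw.Draw(img).text((0, 0), "识别结果示例", fill="black", font=font)
    img.save("font_demo.png")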