add the new architecture of pipelines

dyning · 11 months ago · commit 831b6d21a2
45 changed files with 4248 additions and 1 deletion
  1. api_examples/pipelines/test_doc_preprocessor.py (+15, -0)
  2. api_examples/pipelines/test_layout_parsing.py (+19, -0)
  3. api_examples/pipelines/test_ocr.py (+13, -0)
  4. api_examples/pipelines/test_pp_chatocrv3.py (+37, -0)
  5. api_examples/pipelines/test_table_recognition.py (+13, -0)
  6. paddlex/configs/pipelines/OCR.yaml (+36, -0)
  7. paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml (+109, -0)
  8. paddlex/configs/pipelines/doc_preprocessor.yaml (+16, -0)
  9. paddlex/configs/pipelines/layout_parsing.yaml (+56, -0)
  10. paddlex/inference/__init__.py (+5, -1)
  11. paddlex/inference/pipelines_new/__init__.py (+95, -0)
  12. paddlex/inference/pipelines_new/base.py (+94, -0)
  13. paddlex/inference/pipelines_new/components/__init__.py (+19, -0)
  14. paddlex/inference/pipelines_new/components/base.py (+59, -0)
  15. paddlex/inference/pipelines_new/components/chat_server/__init__.py (+15, -0)
  16. paddlex/inference/pipelines_new/components/chat_server/base.py (+31, -0)
  17. paddlex/inference/pipelines_new/components/chat_server/ernie_bot_chat.py (+95, -0)
  18. paddlex/inference/pipelines_new/components/common/__init__.py (+17, -0)
  19. paddlex/inference/pipelines_new/components/common/crop_image_regions.py (+475, -0)
  20. paddlex/inference/pipelines_new/components/common/seal_det_warp.py (+939, -0)
  21. paddlex/inference/pipelines_new/components/common/sort_boxes.py (+48, -0)
  22. paddlex/inference/pipelines_new/components/prompt_engeering/__init__.py (+15, -0)
  23. paddlex/inference/pipelines_new/components/prompt_engeering/base.py (+31, -0)
  24. paddlex/inference/pipelines_new/components/prompt_engeering/generate_kie_prompt.py (+100, -0)
  25. paddlex/inference/pipelines_new/components/retriever/__init__.py (+15, -0)
  26. paddlex/inference/pipelines_new/components/retriever/base.py (+50, -0)
  27. paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py (+148, -0)
  28. paddlex/inference/pipelines_new/components/utils/__init__.py (+13, -0)
  29. paddlex/inference/pipelines_new/components/utils/mixin.py (+204, -0)
  30. paddlex/inference/pipelines_new/doc_preprocessor/__init__.py (+15, -0)
  31. paddlex/inference/pipelines_new/doc_preprocessor/pipeline.py (+117, -0)
  32. paddlex/inference/pipelines_new/doc_preprocessor/result.py (+51, -0)
  33. paddlex/inference/pipelines_new/layout_parsing/__init__.py (+15, -0)
  34. paddlex/inference/pipelines_new/layout_parsing/pipeline.py (+205, -0)
  35. paddlex/inference/pipelines_new/layout_parsing/result.py (+97, -0)
  36. paddlex/inference/pipelines_new/layout_parsing/table_recognition_post_processing.py (+203, -0)
  37. paddlex/inference/pipelines_new/layout_parsing/utils.py (+87, -0)
  38. paddlex/inference/pipelines_new/ocr/__init__.py (+15, -0)
  39. paddlex/inference/pipelines_new/ocr/pipeline.py (+96, -0)
  40. paddlex/inference/pipelines_new/ocr/result.py (+160, -0)
  41. paddlex/inference/pipelines_new/pp_chatocrv3_doc/__init__.py (+15, -0)
  42. paddlex/inference/pipelines_new/pp_chatocrv3_doc/pipeline.py (+329, -0)
  43. paddlex/inference/pipelines_new/pp_chatocrv3_doc/result.py (+44, -0)
  44. paddlex/utils/flags.py (+2, -0)
  45. paddlex/utils/fonts/__init__.py (+15, -0)
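
Taken together, the api_examples listed above drive every new pipeline through the same entry point; a minimal usage sketch distilled from them (the input path is a placeholder) is:

    from paddlex import create_pipeline

    pipeline = create_pipeline(pipeline="OCR")        # name resolves to paddlex/configs/pipelines/OCR.yaml
    output = pipeline.predict("path/to/image.png")    # placeholder input path
    for res in output:
        print(res)
        res.save_to_img("./output")

The per-file diffs follow.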

+ 15 - 0
api_examples/pipelines/test_doc_preprocessor.py

@@ -0,0 +1,15 @@
+
+from paddlex import create_pipeline
+
+pipeline = create_pipeline(pipeline="doc_preprocessor")
+
+test_img_path = "./test_imgs/img_rot180_demo.jpg"
+# test_img_path = "./test_imgs/doc_distort_test.jpg"
+
+output = pipeline.predict(test_img_path, 
+    use_doc_orientation_classify=True,
+    use_doc_unwarping=True)
+
+for res in output:
+    print(res)
+    res.save_to_img("./output")

+ 19 - 0
api_examples/pipelines/test_layout_parsing.py

@@ -0,0 +1,19 @@
+
+from paddlex import create_pipeline
+
+pipeline = create_pipeline(pipeline="layout_parsing")
+
+output = pipeline.predict("./test_imgs/test_layout_parsing.jpg",
+    use_doc_orientation_classify=True,
+    use_doc_unwarping=True,
+    use_common_ocr=True,
+    use_seal_recognition=True,
+    use_table_recognition=True)
+
+# output = pipeline("./test_imgs/demo_paper.png")
+# output = pipeline("./test_imgs/table_recognition.jpg")
+# output = pipeline.predict("./test_imgs/seal_text_det.png")
+# output = pipeline.predict("./test_imgs/img_rot180_demo.jpg")
+for res in output:
+    # print(res)
+    res.save_results("./output")

+ 13 - 0
api_examples/pipelines/test_ocr.py

@@ -0,0 +1,13 @@
+
+from paddlex import create_pipeline
+
+pipeline = create_pipeline(pipeline="OCR")
+
+# output = pipeline.predict("./test_imgs/general_ocr_002.png")
+
+output = pipeline.predict("./test_imgs/seal_text_det.png")
+for res in output:
+    print(res)
+    res.save_to_img("./output")
+
+

+ 37 - 0
api_examples/pipelines/test_pp_chatocrv3.py

@@ -0,0 +1,37 @@
+
+from paddlex import create_pipeline
+
+pipeline = create_pipeline(pipeline="PP-ChatOCRv3-doc")
+
+# img_path = "./test_demo_imgs/vehicle_certificate-1.png"
+# key_list = ['驾驶室准乘人数']
+
+# img_path = "./test_demo_imgs/test_layout_parsing.jpg"
+# key_list = ['3.2的标题']
+
+img_path = "./test_demo_imgs/seal_text_det.png"
+key_list = ['印章上公司']
+
+# visual_predict_res = pipeline.visual_predict(img_path, 
+#     use_doc_orientation_classify=True,
+#     use_doc_unwarping=True,
+#     use_common_ocr=True,
+#     use_seal_recognition=True,
+#     use_table_recognition=True)
+
+# ####[TODO] add category information
+# visual_info_list = []
+# for res in visual_predict_res:
+#     visual_info_list.append(res["visual_info"])
+
+# pipeline.save_visual_info_list(visual_info_list, "./res_visual_info/visual_info3.json")
+
+visual_info_list = pipeline.load_visual_info_list("./res_visual_info/visual_info3.json")
+
+vector_info = pipeline.build_vector(visual_info_list)
+
+print(vector_info)
+
+final_results = pipeline.chat(visual_info_list, key_list, vector_info)
+
+print(final_results)

+ 13 - 0
api_examples/pipelines/test_table_recognition.py

@@ -0,0 +1,13 @@
+
+from paddlex import create_pipeline
+
+pipeline = create_pipeline(pipeline="table_recognition")
+
+output = pipeline("./test_imgs/table_recognition.jpg")
+for res in output:
+    print(res)
+    res.save_to_img("./output/") ## save the visualized image result
+    res.save_to_xlsx("./output/") ## save the table result as xlsx
+    res.save_to_html("./output/") ## save the table result as html
+
+

+ 36 - 0
paddlex/configs/pipelines/OCR.yaml

@@ -0,0 +1,36 @@
+
+pipeline_name: OCR
+
+##############################################
+####### Config for Common OCR
+##############################################
+
+input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_001.png
+text_type: common
+
+SubModules:
+  TextDetection:
+    model_name: PP-OCRv4_mobile_det
+    model_dir: null
+    batch_size: 1    
+  TextRecognition:
+    model_name: PP-OCRv4_mobile_rec
+    model_dir: null
+    batch_size: 1
+
+##############################################
+####### Config for Seal OCR
+##############################################
+
+# input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/seal_text_det.png
+# text_type: seal
+
+# SubModules:
+#   TextDetection:
+#     model_name: PP-OCRv4_mobile_seal_det
+#     model_dir: null
+#     batch_size: 1    
+#   TextRecognition:
+#     model_name: PP-OCRv4_mobile_rec
+#     model_dir: null
+#     batch_size: 1

+ 109 - 0
paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml

@@ -0,0 +1,109 @@
+
+pipeline_name: PP-ChatOCRv3-doc
+input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/demo_paper.png
+
+use_vector_retrieval: True
+
+SubModules:
+  LLM_Chat:
+    model_name: ernie-3.5
+    api_type: qianfan
+    # ak: "api_key" # Set this to a real API key
+    # sk: "secret_key"  # Set this to a real secret key
+    ak: 4iiqB0QfvXTAENgzUwNeDjQ7
+    sk: sHQCw4l5A6jnzbHMa0ZvDi05GT9Qz8tZ
+
+  LLM_Retriever:
+    model_name: ernie-3.5
+    api_type: qianfan
+    # ak: "api_key" # Set this to a real API key
+    # sk: "secret_key"  # Set this to a real secret key
+    ak: 4iiqB0QfvXTAENgzUwNeDjQ7
+    sk: sHQCw4l5A6jnzbHMa0ZvDi05GT9Qz8tZ
+
+  PromptEngneering:
+    KIE_CommonText:
+      task_type: text_kie_prompt
+      task_description: '你现在的任务是从OCR文字识别的结果中提取关键词列表中每一项对应的关键信息。
+          OCR的文字识别结果使用```符号包围,包含所识别出来的文字,顺序在原始图片中从左至右、从上至下。
+          我指定的关键词列表使用[]符号包围。请注意OCR的文字识别结果可能存在长句子换行被切断、不合理的分词、
+          文字被错误合并等问题,你需要结合上下文语义进行综合判断,以抽取准确的关键信息。'
+      output_format: '在返回结果时使用JSON格式,包含多个key-value对,key值为我指定的关键词,value值为所抽取的结果。
+          如果认为OCR识别结果中没有关键词key对应的value,则将value赋值为"未知"。请只输出json格式的结果,
+          并做json格式校验后返回,不要包含其它多余文字!'
+      rules_str:
+      few_shot_demo_text_content:
+      few_shot_demo_key_value_list:
+          
+    KIE_Table:
+      task_type: table_kie_prompt
+      task_description: '你现在的任务是从输入的html格式的表格内容中提取关键词列表中每一项对应的关键信息,
+          表格内容用```符号包围,我指定的关键词列表使用[]符号包围。你需要结合上下文语义进行综合判断,以抽取准确的关键信息。
+          在返回结果时使用JSON格式,包含多个key-value对,key值为我指定的关键词,value值为所抽取的结果。
+          如果认为输入的表格内容中没有关键词key对应的value值,则将value赋值为"未知"。
+          请只输出json格式的结果,并做json格式校验后返回,不要包含其它多余文字!'
+      output_format: '在返回结果时使用JSON格式,包含多个key-value对,key值为我指定的关键词,value值为所抽取的结果。
+          如果认为表格识别结果中没有关键词key对应的value,则将value赋值为"未知"。请只输出json格式的结果,
+          并做json格式校验后返回,不要包含其它多余文字!'
+      rules_str:
+      few_shot_demo_text_content:
+      few_shot_demo_key_value_list:
+
+SubPipelines:
+  LayoutParser:
+    pipeline_name: layout_parsing
+    use_doc_preprocessor: True
+    use_common_ocr: True
+    use_seal_recognition: True
+    use_table_recognition: True
+
+    SubModules:
+      LayoutDetection:
+        model_name: RT-DETR-H_layout_3cls
+        model_dir: null
+        batch_size: 1
+      TableStructurePredictor:
+        model_name: SLANet_plus
+        model_dir: null
+        batch_size: 1
+
+    SubPipelines:
+      DocPreprocessor:
+        pipeline_name: doc_preprocessor
+        use_doc_orientation_classify: True
+        use_doc_unwarping: True
+        SubModules:
+          DocOrientationClassify:
+            model_name: PP-LCNet_x1_0_doc_ori
+            model_dir: null
+            batch_size: 1
+          DocUnwarping:
+            model_name: UVDoc
+            model_dir: null
+            batch_size: 1
+
+      CommonOCR:
+        pipeline_name: OCR
+        text_type: common
+        SubModules:
+          TextDetection:
+            model_name: PP-OCRv4_server_det
+            model_dir: null
+            batch_size: 1    
+          TextRecognition:
+            model_name: PP-OCRv4_server_rec
+            model_dir: null
+            batch_size: 1
+
+      SealOCR:
+        pipeline_name: OCR
+        text_type: seal
+        SubModules:
+          TextDetection:
+            model_name: PP-OCRv4_server_seal_det
+            model_dir: null
+            batch_size: 1    
+          TextRecognition:
+            model_name: PP-OCRv4_server_rec
+            model_dir: null
+            batch_size: 1  

+ 16 - 0
paddlex/configs/pipelines/doc_preprocessor.yaml

@@ -0,0 +1,16 @@
+
+pipeline_name: doc_preprocessor
+input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/img_rot180_demo.jpg
+#input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/doc_test.jpg
+use_doc_orientation_classify: True
+use_doc_unwarping: True
+
+SubModules:
+  DocOrientationClassify:
+    model_name: PP-LCNet_x1_0_doc_ori
+    model_dir: null
+    batch_size: 1
+  DocUnwarping:
+    model_name: UVDoc
+    model_dir: null
+    batch_size: 1

+ 56 - 0
paddlex/configs/pipelines/layout_parsing.yaml

@@ -0,0 +1,56 @@
+
+pipeline_name: layout_parsing
+input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/demo_paper.png
+use_doc_preprocessor: True
+use_common_ocr: True
+use_seal_recognition: True
+use_table_recognition: True
+
+SubModules:
+  LayoutDetection:
+    model_name: RT-DETR-H_layout_3cls
+    model_dir: null
+    batch_size: 1
+  TableStructurePredictor:
+    model_name: SLANet_plus
+    model_dir: null
+    batch_size: 1
+
+SubPipelines:
+  DocPreprocessor:
+    pipeline_name: doc_preprocessor
+    use_doc_orientation_classify: True
+    use_doc_unwarping: True
+    SubModules:
+      DocOrientationClassify:
+        model_name: PP-LCNet_x1_0_doc_ori
+        model_dir: null
+        batch_size: 1
+      DocUnwarping:
+        model_name: UVDoc
+        model_dir: null
+        batch_size: 1
+  CommonOCR:
+    pipeline_name: OCR
+    text_type: common
+    SubModules:
+      TextDetection:
+        model_name: PP-OCRv4_server_det
+        model_dir: null
+        batch_size: 1    
+      TextRecognition:
+        model_name: PP-OCRv4_server_rec
+        model_dir: null
+        batch_size: 1
+  SealOCR:
+    pipeline_name: OCR
+    text_type: seal
+    SubModules:
+      TextDetection:
+        model_name: PP-OCRv4_server_seal_det
+        model_dir: null
+        batch_size: 1    
+      TextRecognition:
+        model_name: PP-OCRv4_server_rec
+        model_dir: null
+        batch_size: 1

+ 5 - 1
paddlex/inference/__init__.py

@@ -13,5 +13,9 @@
 # limitations under the License.
 
 from .models import create_predictor
-from .pipelines import create_pipeline
+from ..utils import flags
+if flags.USE_NEW_INFERENCE:
+    from .pipelines_new import create_pipeline
+else:
+    from .pipelines import create_pipeline
 from .utils.pp_option import PaddlePredictorOption
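
This conditional import selects the new pipeline stack when flags.USE_NEW_INFERENCE is set. The two lines added to paddlex/utils/flags.py are not shown in this diff, so the sketch below assumes the flag is driven by an environment variable; the variable name is hypothetical.

    import os
    os.environ["PADDLE_PDX_USE_NEW_INFERENCE"] = "1"  # hypothetical variable name; see paddlex/utils/flags.py

    from paddlex import create_pipeline               # resolves to pipelines_new.create_pipeline when the flag is set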

+ 95 - 0
paddlex/inference/pipelines_new/__init__.py

@@ -0,0 +1,95 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+from typing import Any, Dict, Optional
+from .base import BasePipeline
+from ...utils.config import parse_config
+
+# from .single_model_pipeline import (
+#     _SingleModelPipeline,
+#     ImageClassification,
+#     ObjectDetection,
+#     InstanceSegmentation,
+#     SemanticSegmentation,
+#     TSFc,
+#     TSAd,
+#     TSCls,
+#     MultiLableImageClas,
+#     SmallObjDet,
+#     AnomalyDetection,
+# )
+# from .ocr import OCRPipeline
+# from .formula_recognition import FormulaRecognitionPipeline
+# from .table_recognition import TableRecPipeline
+# from .face_recognition import FaceRecPipeline
+# from .seal_recognition import SealOCRPipeline
+# from .ppchatocrv3 import PPChatOCRPipeline
+# from .layout_parsing import LayoutParsingPipeline
+# from .pp_shitu_v2 import ShiTuV2Pipeline
+# from .attribute_recognition import AttributeRecPipeline
+
+from .ocr import OCRPipeline
+from .doc_preprocessor import DocPreprocessorPipeline
+from .layout_parsing import LayoutParsingPipeline
+from .pp_chatocrv3_doc import PP_ChatOCRv3_doc_Pipeline
+
+def get_pipeline_path(pipeline_name):
+    pipeline_path = (
+        Path(__file__).parent.parent.parent / "configs/pipelines" / f"{pipeline_name}.yaml"
+    ).resolve()
+    if not Path(pipeline_path).exists():
+        return None
+    return pipeline_path
+
+def load_pipeline_config(pipeline_name: str) -> Dict[str, Any]:
+    if not Path(pipeline_name).exists():
+        pipeline_path = get_pipeline_path(pipeline_name)
+        if pipeline_path is None:
+            raise Exception(
+                f"The pipeline ({pipeline_name}) does not exist! Please use a pipeline name or a config file path!"
+            )
+    else:
+        pipeline_path = pipeline_name
+    config = parse_config(pipeline_path)
+    return config
+
+def create_pipeline(
+    pipeline: str,
+    device=None,
+    pp_option=None,
+    use_hpip: bool = False,
+    hpi_params: Optional[Dict[str, Any]] = None,
+    *args,
+    **kwargs,
+) -> BasePipeline:
+    """build model evaluater
+
+    Args:
+        pipeline (str): the pipeline name, that is name of pipeline class
+
+    Returns:
+        BasePipeline: the pipeline, which is subclass of BasePipeline.
+    """
+    pipeline_name = pipeline
+    config = load_pipeline_config(pipeline_name)
+    assert pipeline_name == config["pipeline_name"]
+    pipeline = BasePipeline.get(pipeline_name)(
+        config=config,
+        device=device,
+        pp_option=pp_option,
+        use_hpip=use_hpip,
+        hpi_params=hpi_params)
+    return pipeline
+    
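
For reference, the config loader introduced here can also be used on its own; a small sketch, assuming the OCR config shipped in this commit:

    from paddlex.inference.pipelines_new import load_pipeline_config

    config = load_pipeline_config("OCR")   # resolved from configs/pipelines/OCR.yaml
    print(config["pipeline_name"], list(config.get("SubModules", {})))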

+ 94 - 0
paddlex/inference/pipelines_new/base.py

@@ -0,0 +1,94 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from ...utils.subclass_register import AutoRegisterABCMetaClass
+import yaml
+import codecs
+from pathlib import Path
+from typing import Any, Dict, Optional
+from ..models import create_predictor
+from .components.chat_server.base import BaseChat
+from .components.retriever.base import BaseRetriever
+from .components.prompt_engeering.base import BaseGeneratePrompt
+
+class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
+    """Base Pipeline"""
+
+    __is_base = True
+
+    def __init__(self,
+        device=None, 
+        pp_option=None, 
+        use_hpip: bool = False, 
+        hpi_params: Optional[Dict[str, Any]] = None) -> None:
+        super().__init__()
+        self.device = device
+        self.pp_option = pp_option
+        self.use_hpip = use_hpip
+        self.hpi_params = hpi_params
+
+    @abstractmethod
+    def predict(self, input, **kwargs):
+        raise NotImplementedError(
+            "The method `predict` has not been implemented yet."
+        )
+    
+    def create_model(self, config):
+
+        model_dir = config['model_dir']
+        if model_dir is None:
+            model_dir = config['model_name']
+
+        model = create_predictor(
+            model_dir,
+            device=self.device,
+            pp_option=self.pp_option,
+            use_hpip=self.use_hpip,
+            hpi_params=self.hpi_params)
+
+        ########### [TODO] support passing initialization arguments
+        if "batch_size" in config:
+            batch_size = config["batch_size"]
+            model.set_predictor(batch_size=batch_size)
+
+        return model
+
+    def create_pipeline(self, config):
+        pipeline_name = config['pipeline_name']
+        pipeline = BasePipeline.get(pipeline_name)(
+            config=config,
+            device=self.device,
+            pp_option=self.pp_option,
+            use_hpip=self.use_hpip,
+            hpi_params=self.hpi_params)
+        return pipeline 
+
+    def create_chat_bot(self, config):
+        model_name = config['model_name']
+        chat_bot = BaseChat.get(model_name)(config)
+        return chat_bot     
+
+    def create_retriever(self, config):
+        model_name = config['model_name']
+        retriever = BaseRetriever.get(model_name)(config)
+        return retriever    
+
+    def create_prompt_engeering(self, config):
+        task_type = config['task_type']
+        pe = BaseGeneratePrompt.get(task_type)(config)
+        return pe       
+
+    def __call__(self, input, **kwargs):
+        return self.predict(input, **kwargs)
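
Concrete pipelines plug into this base via AutoRegisterABCMetaClass. A purely illustrative subclass sketch (the class name, registry key, and any config keys beyond SubModules/model_name/batch_size are hypothetical) could look like:

    class DemoOCRPipeline(BasePipeline):
        entities = "demo_ocr"   # hypothetical registry name; must match pipeline_name in the YAML

        def __init__(self, config, device=None, pp_option=None, use_hpip=False, hpi_params=None):
            super().__init__(device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params)
            # each SubModules entry carries model_name / model_dir / batch_size
            self.text_det_model = self.create_model(config["SubModules"]["TextDetection"])

        def predict(self, input, **kwargs):
            yield from self.text_det_model(input)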

+ 19 - 0
paddlex/inference/pipelines_new/components/__init__.py

@@ -0,0 +1,19 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base import BaseComponent, CVResult, BaseResult
+from .common import SortQuadBoxes
+from .common import CropByPolys
+from .common import CropByBoxes
+from .utils.mixin import HtmlMixin, XlsxMixin

+ 59 - 0
paddlex/inference/pipelines_new/components/base.py

@@ -0,0 +1,59 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from ....utils.subclass_register import AutoRegisterABCMetaClass
+
+import inspect
+
+from ....utils.func_register import FuncRegister
+from ...utils.io import ImageReader, ImageWriter
+from .utils.mixin import JsonMixin, ImgMixin, StrMixin
+
+class BaseComponent(ABC, metaclass=AutoRegisterABCMetaClass):
+    """Base Component"""
+
+    __is_base = True
+
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    def __call__(self):
+        raise NotImplementedError(
+            "The component method `__call__` has not been implemented yet.")
+
+
+class BaseResult(dict, StrMixin, JsonMixin):
+    def __init__(self, data):
+        super().__init__(data)
+        self._show_funcs = []
+        StrMixin.__init__(self)
+        JsonMixin.__init__(self)
+
+    def save_all(self, save_path):
+        for func in self._show_funcs:
+            signature = inspect.signature(func)
+            if "save_path" in signature.parameters:
+                func(save_path=save_path)
+            else:
+                func()
+
+class CVResult(BaseResult, ImgMixin):
+    def __init__(self, data):
+        super().__init__(data)
+        ImgMixin.__init__(self, "pillow")
+        self._img_reader = ImageReader(backend="pillow")
+        self._img_writer = ImageWriter(backend="pillow")
+

+ 15 - 0
paddlex/inference/pipelines_new/components/chat_server/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .ernie_bot_chat import ErnieBotChat

+ 31 - 0
paddlex/inference/pipelines_new/components/chat_server/base.py

@@ -0,0 +1,31 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from .....utils.subclass_register import AutoRegisterABCMetaClass
+
+import inspect
+
+class BaseChat(ABC, metaclass=AutoRegisterABCMetaClass):
+    """Base Chat"""
+
+    __is_base = True
+
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    def generate_chat_results(self):
+        raise NotImplementedError(
+            "The method `generate_chat_results` has not been implemented yet.")

+ 95 - 0
paddlex/inference/pipelines_new/components/chat_server/ernie_bot_chat.py

@@ -0,0 +1,95 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .....utils import logging
+from .base import BaseChat
+import erniebot
+
+class ErnieBotChat(BaseChat):
+    """Ernie Bot Chat"""
+
+    entities = [
+        "ernie-4.0",
+        "ernie-3.5",
+        "ernie-3.5-8k",
+        "ernie-lite",
+        "ernie-tiny-8k",
+        "ernie-speed",
+        "ernie-speed-128k",
+        "ernie-char-8k",
+    ]
+
+    def __init__(self, config):
+        super().__init__()
+        model_name = config.get('model_name', None)
+        api_type = config.get('api_type', None)
+        ak = config.get('ak', None)
+        sk = config.get('sk', None)
+        access_token = config.get('access_token', None)
+
+        if model_name not in self.entities:
+            raise ValueError(f"model_name must be in {self.entities} of ErnieBotChat.")
+
+        if api_type not in ["aistudio", "qianfan"]:
+            raise ValueError("api_type must be one of ['aistudio', 'qianfan']")
+
+        if api_type == "aistudio" and access_token is None:
+            raise ValueError("access_token cannot be empty when api_type is aistudio.")
+            
+        if api_type == "qianfan" and (ak is None or sk is None):
+            raise ValueError("ak and sk cannot be empty when api_type is qianfan.")            
+
+        self.model_name = model_name
+        self.config = config
+        
+    def generate_chat_results(self, prompt, temperature=0.001, max_retries=1):
+        """
+        args: prompt (str), temperature (float), max_retries (int)
+        return: the chat completion text, or None if the request fails
+        """
+        try:
+            cur_config = {
+                "api_type": self.config['api_type'],
+                "max_retries": max_retries
+            }
+            if self.config['api_type'] == "aistudio":
+                cur_config['access_token'] = self.config['access_token']
+            elif self.config['api_type'] == "qianfan":
+                cur_config['ak'] = self.config['ak']
+                cur_config['sk'] = self.config['sk']
+            chat_completion = erniebot.ChatCompletion.create(
+                _config_=cur_config,
+                model=self.model_name,
+                messages=[{"role": "user", "content": prompt}],
+                temperature=float(temperature),
+            )
+            llm_result = chat_completion.get_result()
+            return llm_result
+        except Exception as e:
+            if len(e.args) < 1:
+                self.ERROR_MASSAGE = (
+                    "暂无权限访问ErnieBot服务,请检查访问令牌。"
+                )
+            elif (
+                e.args[-1]
+                == "暂无权限使用,请在 AI Studio 正确获取访问令牌(access token)使用"
+            ):
+                self.ERROR_MASSAGE = (
+                    "暂无权限访问ErnieBot服务,请检查访问令牌。"
+                )
+            else:
+                logging.error(e)
+                self.ERROR_MASSAGE = "大模型调用失败"
+        return None 
+        
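
Inside the pipeline this component is created through BasePipeline.create_chat_bot(config); constructing it directly looks like the sketch below (credentials are placeholders):

    from paddlex.inference.pipelines_new.components.chat_server import ErnieBotChat

    chat_bot = ErnieBotChat({
        "model_name": "ernie-3.5",
        "api_type": "qianfan",
        "ak": "your_api_key",      # placeholder
        "sk": "your_secret_key",   # placeholder
    })
    answer = chat_bot.generate_chat_results("Summarize PaddleX in one sentence.")
    print(answer)                  # the completion text, or None if the request failed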

+ 17 - 0
paddlex/inference/pipelines_new/components/common/__init__.py

@@ -0,0 +1,17 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .sort_boxes import SortQuadBoxes
+from .crop_image_regions import CropByPolys, CropByBoxes
+

+ 475 - 0
paddlex/inference/pipelines_new/components/common/crop_image_regions.py

@@ -0,0 +1,475 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BaseComponent
+import numpy as np
+from ....utils.io import ImageReader
+import copy
+import cv2
+from .seal_det_warp import AutoRectifier
+from shapely.geometry import Polygon
+from numpy.linalg import norm
+
+class CropByPolys(BaseComponent):
+    """Crop Image by Polys"""
+
+    entities = "CropByPolys"
+
+    def __init__(self, det_box_type="quad"):
+        super().__init__()
+        self.det_box_type = det_box_type
+
+    def __call__(self, img, dt_polys):
+        """__call__"""
+
+        if self.det_box_type == "quad":
+            dt_boxes = np.array(dt_polys)
+            output_list = []
+            for bno in range(len(dt_boxes)):
+                tmp_box = copy.deepcopy(dt_boxes[bno])
+                img_crop = self.get_minarea_rect_crop(img, tmp_box)
+                output_list.append(
+                    {
+                        "img": img_crop,
+                        "img_size": [img_crop.shape[1], img_crop.shape[0]],
+                    }
+                )
+        elif self.det_box_type == "poly":
+            output_list = []
+            dt_boxes = dt_polys
+            for bno in range(len(dt_boxes)):
+                tmp_box = copy.deepcopy(dt_boxes[bno])
+                img_crop = self.get_poly_rect_crop(img.copy(), tmp_box)
+                output_list.append(
+                    {
+                        "img": img_crop,
+                        "img_size": [img_crop.shape[1], img_crop.shape[0]],
+                    }
+                )
+        else:
+            raise NotImplementedError
+
+        return output_list
+
+    def get_minarea_rect_crop(self, img, points):
+        """get_minarea_rect_crop"""
+        bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
+        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+        index_a, index_b, index_c, index_d = 0, 1, 2, 3
+        if points[1][1] > points[0][1]:
+            index_a = 0
+            index_d = 1
+        else:
+            index_a = 1
+            index_d = 0
+        if points[3][1] > points[2][1]:
+            index_b = 2
+            index_c = 3
+        else:
+            index_b = 3
+            index_c = 2
+
+        box = [points[index_a], points[index_b], points[index_c], points[index_d]]
+        crop_img = self.get_rotate_crop_image(img, np.array(box))
+        return crop_img
+
+    def get_rotate_crop_image(self, img, points):
+        """
+        img_height, img_width = img.shape[0:2]
+        left = int(np.min(points[:, 0]))
+        right = int(np.max(points[:, 0]))
+        top = int(np.min(points[:, 1]))
+        bottom = int(np.max(points[:, 1]))
+        img_crop = img[top:bottom, left:right, :].copy()
+        points[:, 0] = points[:, 0] - left
+        points[:, 1] = points[:, 1] - top
+        """
+        assert len(points) == 4, "shape of points must be 4*2"
+        img_crop_width = int(
+            max(
+                np.linalg.norm(points[0] - points[1]),
+                np.linalg.norm(points[2] - points[3]),
+            )
+        )
+        img_crop_height = int(
+            max(
+                np.linalg.norm(points[0] - points[3]),
+                np.linalg.norm(points[1] - points[2]),
+            )
+        )
+        pts_std = np.float32(
+            [
+                [0, 0],
+                [img_crop_width, 0],
+                [img_crop_width, img_crop_height],
+                [0, img_crop_height],
+            ]
+        )
+        M = cv2.getPerspectiveTransform(points, pts_std)
+        dst_img = cv2.warpPerspective(
+            img,
+            M,
+            (img_crop_width, img_crop_height),
+            borderMode=cv2.BORDER_REPLICATE,
+            flags=cv2.INTER_CUBIC,
+        )
+        dst_img_height, dst_img_width = dst_img.shape[0:2]
+        if dst_img_height * 1.0 / dst_img_width >= 1.5:
+            dst_img = np.rot90(dst_img)
+        return dst_img
+
+    def reorder_poly_edge(self, points):
+        """Get the respective points composing head edge, tail edge, top
+        sideline and bottom sideline.
+
+        Args:
+            points (ndarray): The points composing a text polygon.
+
+        Returns:
+            head_edge (ndarray): The two points composing the head edge of text
+                polygon.
+            tail_edge (ndarray): The two points composing the tail edge of text
+                polygon.
+            top_sideline (ndarray): The points composing top curved sideline of
+                text polygon.
+            bot_sideline (ndarray): The points composing bottom curved sideline
+                of text polygon.
+        """
+
+        assert points.ndim == 2
+        assert points.shape[0] >= 4
+        assert points.shape[1] == 2
+
+        orientation_thr = 2.0  # an empirical hyperparameter
+
+        head_inds, tail_inds = self.find_head_tail(points, orientation_thr)
+        head_edge, tail_edge = points[head_inds], points[tail_inds]
+
+        pad_points = np.vstack([points, points])
+        if tail_inds[1] < 1:
+            tail_inds[1] = len(points)
+        sideline1 = pad_points[head_inds[1] : tail_inds[1]]
+        sideline2 = pad_points[tail_inds[1] : (head_inds[1] + len(points))]
+        return head_edge, tail_edge, sideline1, sideline2
+
+    def vector_slope(self, vec):
+        assert len(vec) == 2
+        return abs(vec[1] / (vec[0] + 1e-8))
+
+    def find_head_tail(self, points, orientation_thr):
+        """Find the head edge and tail edge of a text polygon.
+
+        Args:
+            points (ndarray): The points composing a text polygon.
+            orientation_thr (float): The threshold for distinguishing between
+                head edge and tail edge among the horizontal and vertical edges
+                of a quadrangle.
+
+        Returns:
+            head_inds (list): The indexes of two points composing head edge.
+            tail_inds (list): The indexes of two points composing tail edge.
+        """
+
+        assert points.ndim == 2
+        assert points.shape[0] >= 4
+        assert points.shape[1] == 2
+        assert isinstance(orientation_thr, float)
+
+        if len(points) > 4:
+            pad_points = np.vstack([points, points[0]])
+            edge_vec = pad_points[1:] - pad_points[:-1]
+
+            theta_sum = []
+            adjacent_vec_theta = []
+            for i, edge_vec1 in enumerate(edge_vec):
+                adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
+                adjacent_edge_vec = edge_vec[adjacent_ind]
+                temp_theta_sum = np.sum(self.vector_angle(edge_vec1, adjacent_edge_vec))
+                temp_adjacent_theta = self.vector_angle(
+                    adjacent_edge_vec[0], adjacent_edge_vec[1]
+                )
+                theta_sum.append(temp_theta_sum)
+                adjacent_vec_theta.append(temp_adjacent_theta)
+            theta_sum_score = np.array(theta_sum) / np.pi
+            adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
+            poly_center = np.mean(points, axis=0)
+            edge_dist = np.maximum(
+                norm(pad_points[1:] - poly_center, axis=-1),
+                norm(pad_points[:-1] - poly_center, axis=-1),
+            )
+            dist_score = edge_dist / np.max(edge_dist)
+            position_score = np.zeros(len(edge_vec))
+            score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
+            score += 0.35 * dist_score
+            if len(points) % 2 == 0:
+                position_score[(len(score) // 2 - 1)] += 1
+                position_score[-1] += 1
+            score += 0.1 * position_score
+            pad_score = np.concatenate([score, score])
+            score_matrix = np.zeros((len(score), len(score) - 3))
+            x = np.arange(len(score) - 3) / float(len(score) - 4)
+            gaussian = (
+                1.0
+                / (np.sqrt(2.0 * np.pi) * 0.5)
+                * np.exp(-np.power((x - 0.5) / 0.5, 2.0) / 2)
+            )
+            gaussian = gaussian / np.max(gaussian)
+            for i in range(len(score)):
+                score_matrix[i, :] = (
+                    score[i]
+                    + pad_score[(i + 2) : (i + len(score) - 1)] * gaussian * 0.3
+                )
+
+            head_start, tail_increment = np.unravel_index(
+                score_matrix.argmax(), score_matrix.shape
+            )
+            tail_start = (head_start + tail_increment + 2) % len(points)
+            head_end = (head_start + 1) % len(points)
+            tail_end = (tail_start + 1) % len(points)
+
+            if head_end > tail_end:
+                head_start, tail_start = tail_start, head_start
+                head_end, tail_end = tail_end, head_end
+            head_inds = [head_start, head_end]
+            tail_inds = [tail_start, tail_end]
+        else:
+            if self.vector_slope(points[1] - points[0]) + self.vector_slope(
+                points[3] - points[2]
+            ) < self.vector_slope(points[2] - points[1]) + self.vector_slope(
+                points[0] - points[3]
+            ):
+                horizontal_edge_inds = [[0, 1], [2, 3]]
+                vertical_edge_inds = [[3, 0], [1, 2]]
+            else:
+                horizontal_edge_inds = [[3, 0], [1, 2]]
+                vertical_edge_inds = [[0, 1], [2, 3]]
+
+            vertical_len_sum = norm(
+                points[vertical_edge_inds[0][0]] - points[vertical_edge_inds[0][1]]
+            ) + norm(
+                points[vertical_edge_inds[1][0]] - points[vertical_edge_inds[1][1]]
+            )
+            horizontal_len_sum = norm(
+                points[horizontal_edge_inds[0][0]] - points[horizontal_edge_inds[0][1]]
+            ) + norm(
+                points[horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1][1]]
+            )
+
+            if vertical_len_sum > horizontal_len_sum * orientation_thr:
+                head_inds = horizontal_edge_inds[0]
+                tail_inds = horizontal_edge_inds[1]
+            else:
+                head_inds = vertical_edge_inds[0]
+                tail_inds = vertical_edge_inds[1]
+
+        return head_inds, tail_inds
+
+    def vector_angle(self, vec1, vec2):
+        if vec1.ndim > 1:
+            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
+        else:
+            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
+        if vec2.ndim > 1:
+            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
+        else:
+            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
+        return np.arccos(np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
+
+    def get_minarea_rect(self, img, points):
+        bounding_box = cv2.minAreaRect(points)
+        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+        index_a, index_b, index_c, index_d = 0, 1, 2, 3
+        if points[1][1] > points[0][1]:
+            index_a = 0
+            index_d = 1
+        else:
+            index_a = 1
+            index_d = 0
+        if points[3][1] > points[2][1]:
+            index_b = 2
+            index_c = 3
+        else:
+            index_b = 3
+            index_c = 2
+
+        box = [points[index_a], points[index_b], points[index_c], points[index_d]]
+        crop_img = self.get_rotate_crop_image(img, np.array(box))
+        return crop_img, box
+
+    def sample_points_on_bbox_bp(self, line, n=50):
+        """Resample n points on a line.
+
+        Args:
+            line (ndarray): The points composing a line.
+            n (int): The resampled points number.
+
+        Returns:
+            resampled_line (ndarray): The points composing the resampled line.
+        """
+        # sanity-check the input arguments
+        assert line.ndim == 2
+        assert line.shape[0] >= 2
+        assert line.shape[1] == 2
+        assert isinstance(n, int)
+        assert n > 0
+
+        length_list = [norm(line[i + 1] - line[i]) for i in range(len(line) - 1)]
+        total_length = sum(length_list)
+        length_cumsum = np.cumsum([0.0] + length_list)
+        delta_length = total_length / (float(n) + 1e-8)
+        current_edge_ind = 0
+        resampled_line = [line[0]]
+
+        for i in range(1, n):
+            current_line_len = i * delta_length
+            while (
+                current_edge_ind + 1 < len(length_cumsum)
+                and current_line_len >= length_cumsum[current_edge_ind + 1]
+            ):
+                current_edge_ind += 1
+            current_edge_end_shift = current_line_len - length_cumsum[current_edge_ind]
+            if current_edge_ind >= len(length_list):
+                break
+            end_shift_ratio = current_edge_end_shift / length_list[current_edge_ind]
+            current_point = (
+                line[current_edge_ind]
+                + (line[current_edge_ind + 1] - line[current_edge_ind])
+                * end_shift_ratio
+            )
+            resampled_line.append(current_point)
+        resampled_line.append(line[-1])
+        resampled_line = np.array(resampled_line)
+        return resampled_line
+
+    def sample_points_on_bbox(self, line, n=50):
+        """Resample n points on a line.
+
+        Args:
+            line (ndarray): The points composing a line.
+            n (int): The resampled points number.
+
+        Returns:
+            resampled_line (ndarray): The points composing the resampled line.
+        """
+        assert line.ndim == 2
+        assert line.shape[0] >= 2
+        assert line.shape[1] == 2
+        assert isinstance(n, int)
+        assert n > 0
+
+        length_list = [norm(line[i + 1] - line[i]) for i in range(len(line) - 1)]
+        total_length = sum(length_list)
+        mean_length = total_length / (len(length_list) + 1e-8)
+        group = [[0]]
+        for i in range(len(length_list)):
+            point_id = i + 1
+            if length_list[i] < 0.9 * mean_length:
+                for g in group:
+                    if i in g:
+                        g.append(point_id)
+                        break
+            else:
+                g = [point_id]
+                group.append(g)
+
+        top_tail_len = norm(line[0] - line[-1])
+        if top_tail_len < 0.9 * mean_length:
+            group[0].extend(g)
+            group.remove(g)
+        mean_positions = []
+        for indices in group:
+            x_sum = 0
+            y_sum = 0
+            for index in indices:
+                x, y = line[index]
+                x_sum += x
+                y_sum += y
+            num_points = len(indices)
+            mean_x = x_sum / num_points
+            mean_y = y_sum / num_points
+            mean_positions.append((mean_x, mean_y))
+        resampled_line = np.array(mean_positions)
+        return resampled_line
+
+    def get_poly_rect_crop(self, img, points):
+        """
+        Rectify and crop an irregular or curved text region using its polygon.
+        args: img: the image, ndarray
+        points: polygon coordinates of shape N*2, ndarray
+        return: the rectified cropped image, ndarray
+        """
+        points = np.array(points).astype(np.int32).reshape(-1, 2)
+        temp_crop_img, temp_box = self.get_minarea_rect(img, points)
+
+        # compute the IoU between the minimum-area rectangle and the polygon
+        def get_union(pD, pG):
+            return Polygon(pD).union(Polygon(pG)).area
+
+        def get_intersection_over_union(pD, pG):
+            return get_intersection(pD, pG) / (get_union(pD, pG) + 1e-10)
+
+        def get_intersection(pD, pG):
+            return Polygon(pD).intersection(Polygon(pG)).area
+
+        cal_IoU = get_intersection_over_union(points, temp_box)
+
+        if cal_IoU >= 0.7:
+            points = self.sample_points_on_bbox_bp(points, 31)
+            return temp_crop_img
+
+        points_sample = self.sample_points_on_bbox(points)
+        points_sample = points_sample.astype(np.int32)
+        head_edge, tail_edge, top_line, bot_line = self.reorder_poly_edge(points_sample)
+
+        resample_top_line = self.sample_points_on_bbox_bp(top_line, 15)
+        resample_bot_line = self.sample_points_on_bbox_bp(bot_line, 15)
+
+        sideline_mean_shift = np.mean(resample_top_line, axis=0) - np.mean(
+            resample_bot_line, axis=0
+        )
+        if sideline_mean_shift[1] > 0:
+            resample_bot_line, resample_top_line = resample_top_line, resample_bot_line
+        rectifier = AutoRectifier()
+        new_points = np.concatenate([resample_top_line, resample_bot_line])
+        new_points_list = list(new_points.astype(np.float32).reshape(1, -1).tolist())
+
+        if len(img.shape) == 2:
+            img = np.stack((img,) * 3, axis=-1)
+        img_crop, image = rectifier.run(img, new_points_list, mode="homography")
+        return np.array(img_crop[0], dtype=np.uint8)
+
+
+class CropByBoxes(BaseComponent):
+    """Crop Image by Box"""
+
+    entities = "CropByBoxes"
+
+    def __init__(self):
+        super().__init__()
+
+    def __call__(self, img, boxes):
+        """__call__"""
+        output_list = []
+        for bbox_info in boxes:
+            label_id = bbox_info["cls_id"]
+            box = bbox_info["coordinate"]
+            label = bbox_info.get("label", label_id)
+            xmin, ymin, xmax, ymax = [int(i) for i in box]
+            img_crop = img[ymin:ymax, xmin:xmax].copy()
+            output_list.append({"img": img_crop, "box": box, "label": label})
+        return output_list
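
A usage sketch for these two croppers (the image path and coordinates are placeholders): CropByPolys rectifies quadrilateral or curved text regions, while CropByBoxes cuts out axis-aligned layout boxes.

    import cv2
    import numpy as np

    from paddlex.inference.pipelines_new.components import CropByPolys, CropByBoxes

    img = cv2.imread("demo.jpg")                                   # placeholder image
    quad = np.array([[10, 10], [200, 12], [198, 60], [8, 58]], dtype=np.float32)

    text_crops = CropByPolys(det_box_type="quad")(img, [quad])     # [{"img": ..., "img_size": [w, h]}]
    layout_crops = CropByBoxes()(img, [{"cls_id": 0, "label": "table", "coordinate": [8, 8, 200, 60]}])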

+ 939 - 0
paddlex/inference/pipelines_new/components/common/seal_det_warp.py

@@ -0,0 +1,939 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os, sys
+import numpy as np
+from numpy import cos, sin, arctan, sqrt
+import cv2
+import copy
+import time
+
+from .....utils import logging
+
+def Homography(
+    image,
+    img_points,
+    world_width,
+    world_height,
+    interpolation=cv2.INTER_CUBIC,
+    ratio_width=1.0,
+    ratio_height=1.0,
+):
+    _points = np.array(img_points).reshape(-1, 2).astype(np.float32)
+
+    expand_x = int(0.5 * world_width * (ratio_width - 1))
+    expand_y = int(0.5 * world_height * (ratio_height - 1))
+
+    pt_lefttop = [expand_x, expand_y]
+    pt_righttop = [expand_x + world_width, expand_y]
+    pt_leftbottom = [expand_x + world_width, expand_y + world_height]
+    pt_rightbottom = [expand_x, expand_y + world_height]
+
+    pts_std = np.float32([pt_lefttop, pt_righttop, pt_leftbottom, pt_rightbottom])
+
+    img_crop_width = int(world_width * ratio_width)
+    img_crop_height = int(world_height * ratio_height)
+
+    M = cv2.getPerspectiveTransform(_points, pts_std)
+
+    dst_img = cv2.warpPerspective(
+        image,
+        M,
+        (img_crop_width, img_crop_height),
+        borderMode=cv2.BORDER_CONSTANT,  # BORDER_CONSTANT BORDER_REPLICATE
+        flags=interpolation,
+    )
+
+    return dst_img
+
+
+class PlanB:
+    def __call__(
+        self,
+        image,
+        points,
+        curveTextRectifier,
+        interpolation=cv2.INTER_LINEAR,
+        ratio_width=1.0,
+        ratio_height=1.0,
+        loss_thresh=5.0,
+        square=False,
+    ):
+        """
+        Plan B using sub-image when it failed in original image
+        :param image:
+        :param points:
+        :param curveTextRectifier: CurveTextRectifier
+        :param interpolation: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        :param ratio_width:  roi_image width expansion. It should not be smaller than 1.0
+        :param ratio_height: roi_image height expansion. It should not be smaller than 1.0
+        :param loss_thresh: if loss greater than loss_thresh --> get_rotate_crop_image
+        :param square: crop square image or not. True or False. The default is False
+        :return:
+        """
+        h, w = image.shape[:2]
+        _points = np.array(points).reshape(-1, 2).astype(np.float32)
+        x_min = int(np.min(_points[:, 0]))
+        y_min = int(np.min(_points[:, 1]))
+        x_max = int(np.max(_points[:, 0]))
+        y_max = int(np.max(_points[:, 1]))
+        dx = x_max - x_min
+        dy = y_max - y_min
+        max_d = max(dx, dy)
+        mean_pt = np.mean(_points, 0)
+
+        expand_x = (ratio_width - 1.0) * 0.5 * max_d
+        expand_y = (ratio_height - 1.0) * 0.5 * max_d
+
+        if square:
+            x_min = np.clip(int(mean_pt[0] - max_d - expand_x), 0, w - 1)
+            y_min = np.clip(int(mean_pt[1] - max_d - expand_y), 0, h - 1)
+            x_max = np.clip(int(mean_pt[0] + max_d + expand_x), 0, w - 1)
+            y_max = np.clip(int(mean_pt[1] + max_d + expand_y), 0, h - 1)
+        else:
+            x_min = np.clip(int(x_min - expand_x), 0, w - 1)
+            y_min = np.clip(int(y_min - expand_y), 0, h - 1)
+            x_max = np.clip(int(x_max + expand_x), 0, w - 1)
+            y_max = np.clip(int(y_max + expand_y), 0, h - 1)
+
+        new_image = image[y_min:y_max, x_min:x_max, :].copy()
+        new_points = _points.copy()
+        new_points[:, 0] -= x_min
+        new_points[:, 1] -= y_min
+
+        dst_img, loss = curveTextRectifier(
+            new_image,
+            new_points,
+            interpolation,
+            ratio_width,
+            ratio_height,
+            mode="calibration",
+        )
+
+        return dst_img, loss
+
+
+class CurveTextRectifier:
+    """
+    spatial transformer via monocular vision
+    """
+
+    def __init__(self):
+        self.get_virtual_camera_parameter()
+
+    def get_virtual_camera_parameter(self):
+        vcam_thz = 0
+        vcam_thx1 = 180
+        vcam_thy = 180
+        vcam_thx2 = 0
+
+        vcam_x = 0
+        vcam_y = 0
+        vcam_z = 100
+
+        radian = np.pi / 180
+
+        angle_z = radian * vcam_thz
+        angle_x1 = radian * vcam_thx1
+        angle_y = radian * vcam_thy
+        angle_x2 = radian * vcam_thx2
+
+        optic_x = vcam_x
+        optic_y = vcam_y
+        optic_z = vcam_z
+
+        fu = 100
+        fv = 100
+
+        matT = np.zeros((4, 4))
+        matT[0, 0] = cos(angle_z) * cos(angle_y) - sin(angle_z) * sin(angle_x1) * sin(
+            angle_y
+        )
+        matT[0, 1] = cos(angle_z) * sin(angle_y) * sin(angle_x2) - sin(angle_z) * (
+            cos(angle_x1) * cos(angle_x2) - sin(angle_x1) * cos(angle_y) * sin(angle_x2)
+        )
+        matT[0, 2] = cos(angle_z) * sin(angle_y) * cos(angle_x2) + sin(angle_z) * (
+            cos(angle_x1) * sin(angle_x2) + sin(angle_x1) * cos(angle_y) * cos(angle_x2)
+        )
+        matT[0, 3] = optic_x
+        matT[1, 0] = sin(angle_z) * cos(angle_y) + cos(angle_z) * sin(angle_x1) * sin(
+            angle_y
+        )
+        matT[1, 1] = sin(angle_z) * sin(angle_y) * sin(angle_x2) + cos(angle_z) * (
+            cos(angle_x1) * cos(angle_x2) - sin(angle_x1) * cos(angle_y) * sin(angle_x2)
+        )
+        matT[1, 2] = sin(angle_z) * sin(angle_y) * cos(angle_x2) - cos(angle_z) * (
+            cos(angle_x1) * sin(angle_x2) + sin(angle_x1) * cos(angle_y) * cos(angle_x2)
+        )
+        matT[1, 3] = optic_y
+        matT[2, 0] = -cos(angle_x1) * sin(angle_y)
+        matT[2, 1] = cos(angle_x1) * cos(angle_y) * sin(angle_x2) + sin(angle_x1) * cos(
+            angle_x2
+        )
+        matT[2, 2] = cos(angle_x1) * cos(angle_y) * cos(angle_x2) - sin(angle_x1) * sin(
+            angle_x2
+        )
+        matT[2, 3] = optic_z
+        matT[3, 0] = 0
+        matT[3, 1] = 0
+        matT[3, 2] = 0
+        matT[3, 3] = 1
+
+        matS = np.zeros((4, 4))
+        matS[2, 3] = 0.5
+        matS[3, 2] = 0.5
+
+        self.ifu = 1 / fu
+        self.ifv = 1 / fv
+
+        self.matT = matT
+        self.matS = matS
+        self.K = np.dot(matT.T, matS)
+        self.K = np.dot(self.K, matT)
+
+    def vertical_text_process(self, points, org_size):
+        """
+        change the point sequence and process vertical text
+        :param points:
+        :param org_size:
+        :return:
+        """
+        org_w, org_h = org_size
+        _points = np.array(points).reshape(-1).tolist()
+        _points = np.array(_points[2:] + _points[:2]).reshape(-1, 2)
+
+        # convert to horizontal points
+        adjusted_points = np.zeros(_points.shape, dtype=np.float32)
+        adjusted_points[:, 0] = _points[:, 1]
+        adjusted_points[:, 1] = org_h - _points[:, 0] - 1
+
+        _image_coord, _world_coord, _new_image_size = self.horizontal_text_process(
+            adjusted_points
+        )
+
+        # # convert to vertical points back
+        image_coord = _points.reshape(1, -1, 2)
+        world_coord = np.zeros(_world_coord.shape, dtype=np.float32)
+        world_coord[:, :, 0] = 0 - _world_coord[:, :, 1]
+        world_coord[:, :, 1] = _world_coord[:, :, 0]
+        world_coord[:, :, 2] = _world_coord[:, :, 2]
+        new_image_size = (_new_image_size[1], _new_image_size[0])
+
+        return image_coord, world_coord, new_image_size
+
+    def horizontal_text_process(self, points):
+        """
+        get image coordinate and world coordinate
+        :param points:
+        :return:
+        """
+        poly = np.array(points).reshape(-1)
+
+        dx_list = []
+        dy_list = []
+        for i in range(1, len(poly) // 2):
+            xdx = poly[i * 2] - poly[(i - 1) * 2]
+            xdy = poly[i * 2 + 1] - poly[(i - 1) * 2 + 1]
+            d = sqrt(xdx**2 + xdy**2)
+            dx_list.append(d)
+
+        for i in range(0, len(poly) // 4):
+            ydx = poly[i * 2] - poly[len(poly) - 1 - (i * 2 + 1)]
+            ydy = poly[i * 2 + 1] - poly[len(poly) - 1 - (i * 2)]
+            d = sqrt(ydx**2 + ydy**2)
+            dy_list.append(d)
+
+        dx_list = [
+            (dx_list[i] + dx_list[len(dx_list) - 1 - i]) / 2
+            for i in range(len(dx_list) // 2)
+        ]
+
+        height = np.around(np.mean(dy_list))
+
+        rect_coord = [0, 0]
+        for i in range(0, len(poly) // 4 - 1):
+            x = rect_coord[-2]
+            x += dx_list[i]
+            y = 0
+            rect_coord.append(x)
+            rect_coord.append(y)
+
+        rect_coord_half = copy.deepcopy(rect_coord)
+        for i in range(0, len(poly) // 4):
+            x = rect_coord_half[len(rect_coord_half) - 2 * i - 2]
+            y = height
+            rect_coord.append(x)
+            rect_coord.append(y)
+
+        np_rect_coord = np.array(rect_coord).reshape(-1, 2)
+        x_min = np.min(np_rect_coord[:, 0])
+        y_min = np.min(np_rect_coord[:, 1])
+        x_max = np.max(np_rect_coord[:, 0])
+        y_max = np.max(np_rect_coord[:, 1])
+        new_image_size = (int(x_max - x_min + 0.5), int(y_max - y_min + 0.5))
+        x_mean = (x_max - x_min) / 2
+        y_mean = (y_max - y_min) / 2
+        np_rect_coord[:, 0] -= x_mean
+        np_rect_coord[:, 1] -= y_mean
+        rect_coord = np_rect_coord.reshape(-1).tolist()
+
+        rect_coord = np.array(rect_coord).reshape(-1, 2)
+        world_coord = np.ones((len(rect_coord), 3)) * 0
+
+        world_coord[:, :2] = rect_coord
+
+        image_coord = np.array(poly).reshape(1, -1, 2)
+        world_coord = world_coord.reshape(1, -1, 3)
+
+        return image_coord, world_coord, new_image_size
+
+    def horizontal_text_estimate(self, points):
+        """
+        horizontal or vertical text
+        :param points:
+        :return:
+        """
+        pts = np.array(points).reshape(-1, 2)
+        x_min = int(np.min(pts[:, 0]))
+        y_min = int(np.min(pts[:, 1]))
+        x_max = int(np.max(pts[:, 0]))
+        y_max = int(np.max(pts[:, 1]))
+        x = x_max - x_min
+        y = y_max - y_min
+        is_horizontal_text = True
+        if y / x > 1.5:  # vertical text condition
+            is_horizontal_text = False
+        return is_horizontal_text
+
+    def virtual_camera_to_world(self, size):
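+        """
+        Back-project every pixel of the virtual camera onto the world surface:
+        a ray is cast through each pixel and intersected with the quadric self.K
+        (built from matT and matS) by solving a quadratic in the ray length t.
+        :param size: (width, height) of the virtual image
+        :return: P, an array of shape [height, width, 3] of world coordinates
+        """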
+        ifu, ifv = self.ifu, self.ifv
+        K, matT = self.K, self.matT
+
+        ppu = size[0] / 2 + 1e-6
+        ppv = size[1] / 2 + 1e-6
+
+        P = np.zeros((size[1], size[0], 3))
+
+        lu = np.array([i for i in range(size[0])])
+        lv = np.array([i for i in range(size[1])])
+        u, v = np.meshgrid(lu, lv)
+
+        yp = (v - ppv) * ifv
+        xp = (u - ppu) * ifu
+        angle_a = arctan(sqrt(xp * xp + yp * yp))
+        angle_b = arctan(yp / xp)
+
+        D0 = sin(angle_a) * cos(angle_b)
+        D1 = sin(angle_a) * sin(angle_b)
+        D2 = cos(angle_a)
+
+        D0[xp <= 0] = -D0[xp <= 0]
+        D1[xp <= 0] = -D1[xp <= 0]
+
+        ratio_a = (
+            K[0, 0] * D0 * D0
+            + K[1, 1] * D1 * D1
+            + K[2, 2] * D2 * D2
+            + (K[0, 1] + K[1, 0]) * D0 * D1
+            + (K[0, 2] + K[2, 0]) * D0 * D2
+            + (K[1, 2] + K[2, 1]) * D1 * D2
+        )
+        ratio_b = (
+            (K[0, 3] + K[3, 0]) * D0
+            + (K[1, 3] + K[3, 1]) * D1
+            + (K[2, 3] + K[3, 2]) * D2
+        )
+        ratio_c = K[3, 3] * np.ones(ratio_b.shape)
+
+        delta = ratio_b * ratio_b - 4 * ratio_a * ratio_c
+        t = np.zeros(delta.shape)
+        t[ratio_a == 0] = -ratio_c[ratio_a == 0] / ratio_b[ratio_a == 0]
+        t[ratio_a != 0] = (-ratio_b[ratio_a != 0] + sqrt(delta[ratio_a != 0])) / (
+            2 * ratio_a[ratio_a != 0]
+        )
+        t[delta < 0] = 0
+
+        P[:, :, 0] = matT[0, 3] + t * (
+            matT[0, 0] * D0 + matT[0, 1] * D1 + matT[0, 2] * D2
+        )
+        P[:, :, 1] = matT[1, 3] + t * (
+            matT[1, 0] * D0 + matT[1, 1] * D1 + matT[1, 2] * D2
+        )
+        P[:, :, 2] = matT[2, 3] + t * (
+            matT[2, 0] * D0 + matT[2, 1] * D1 + matT[2, 2] * D2
+        )
+
+        return P
+
+    def world_to_image(self, image_size, world, intrinsic, distCoeffs, rotation, tvec):
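+        """
+        Project world coordinates into the original image: apply the extrinsic
+        rotation/translation, a rational radial + tangential + thin-prism (+ tilted)
+        distortion model, and the pinhole intrinsics, producing a dense pixel map
+        of shape [height, width, 2]. Points behind the camera (c3 <= 0) are zeroed.
+        """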
+        r11 = rotation[0, 0]
+        r12 = rotation[0, 1]
+        r13 = rotation[0, 2]
+        r21 = rotation[1, 0]
+        r22 = rotation[1, 1]
+        r23 = rotation[1, 2]
+        r31 = rotation[2, 0]
+        r32 = rotation[2, 1]
+        r33 = rotation[2, 2]
+
+        t1 = tvec[0]
+        t2 = tvec[1]
+        t3 = tvec[2]
+
+        k1 = distCoeffs[0]
+        k2 = distCoeffs[1]
+        p1 = distCoeffs[2]
+        p2 = distCoeffs[3]
+        k3 = distCoeffs[4]
+        k4 = distCoeffs[5]
+        k5 = distCoeffs[6]
+        k6 = distCoeffs[7]
+
+        if len(distCoeffs) > 8:
+            s1 = distCoeffs[8]
+            s2 = distCoeffs[9]
+            s3 = distCoeffs[10]
+            s4 = distCoeffs[11]
+        else:
+            s1 = s2 = s3 = s4 = 0
+
+        if len(distCoeffs) > 12:
+            tx = distCoeffs[12]
+            ty = distCoeffs[13]
+        else:
+            tx = ty = 0
+
+        fu = intrinsic[0, 0]
+        fv = intrinsic[1, 1]
+        ppu = intrinsic[0, 2]
+        ppv = intrinsic[1, 2]
+
+        cos_tx = cos(tx)
+        cos_ty = cos(ty)
+        sin_tx = sin(tx)
+        sin_ty = sin(ty)
+
+        tao11 = cos_ty * cos_tx * cos_ty + sin_ty * cos_tx * sin_ty
+        tao12 = cos_ty * cos_tx * sin_ty * sin_tx - sin_ty * cos_tx * cos_ty * sin_tx
+        tao13 = -cos_ty * cos_tx * sin_ty * cos_tx + sin_ty * cos_tx * cos_ty * cos_tx
+        tao21 = -sin_tx * sin_ty
+        tao22 = cos_ty * cos_tx * cos_tx + sin_tx * cos_ty * sin_tx
+        tao23 = cos_ty * cos_tx * sin_tx - sin_tx * cos_ty * cos_tx
+
+        P = np.zeros((image_size[1], image_size[0], 2))
+
+        c3 = r31 * world[:, :, 0] + r32 * world[:, :, 1] + r33 * world[:, :, 2] + t3
+        c1 = r11 * world[:, :, 0] + r12 * world[:, :, 1] + r13 * world[:, :, 2] + t1
+        c2 = r21 * world[:, :, 0] + r22 * world[:, :, 1] + r23 * world[:, :, 2] + t2
+
+        x1 = c1 / c3
+        y1 = c2 / c3
+        x12 = x1 * x1
+        y12 = y1 * y1
+        x1y1 = 2 * x1 * y1
+        r2 = x12 + y12
+        r4 = r2 * r2
+        r6 = r2 * r4
+
+        radial_distortion = (1 + k1 * r2 + k2 * r4 + k3 * r6) / (
+            1 + k4 * r2 + k5 * r4 + k6 * r6
+        )
+        x2 = (
+            x1 * radial_distortion + p1 * x1y1 + p2 * (r2 + 2 * x12) + s1 * r2 + s2 * r4
+        )
+        y2 = (
+            y1 * radial_distortion + p2 * x1y1 + p1 * (r2 + 2 * y12) + s3 * r2 + s4 * r4
+        )
+
+        x3 = tao11 * x2 + tao12 * y2 + tao13
+        y3 = tao21 * x2 + tao22 * y2 + tao23
+
+        P[:, :, 0] = fu * x3 + ppu
+        P[:, :, 1] = fv * y3 + ppv
+        P[c3 <= 0] = 0
+
+        return P
+
+    def spatial_transform(
+        self, image_data, new_image_size, mtx, dist, rvecs, tvecs, interpolation
+    ):
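+        """
+        Rectify the region by composing the virtual-camera -> world -> original-image
+        mapping and resampling the input with cv2.remap.
+        """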
+        rotation, _ = cv2.Rodrigues(rvecs)
+        world_map = self.virtual_camera_to_world(new_image_size)
+        image_map = self.world_to_image(
+            new_image_size, world_map, mtx, dist, rotation, tvecs
+        )
+        image_map = image_map.astype(np.float32)
+        dst = cv2.remap(
+            image_data, image_map[:, :, 0], image_map[:, :, 1], interpolation
+        )
+        return dst
+
+    def calibrate(self, org_size, image_coord, world_coord):
+        """
+        calibration
+        :param org_size:
+        :param image_coord:
+        :param world_coord:
+        :return:
+        """
+        # flag = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_TILTED_MODEL  | cv2.CALIB_THIN_PRISM_MODEL
+        flag = cv2.CALIB_RATIONAL_MODEL
+        flag2 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_TILTED_MODEL
+        flag3 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_THIN_PRISM_MODEL
+        flag4 = (
+            cv2.CALIB_RATIONAL_MODEL
+            | cv2.CALIB_ZERO_TANGENT_DIST
+            | cv2.CALIB_FIX_ASPECT_RATIO
+        )
+        flag5 = (
+            cv2.CALIB_RATIONAL_MODEL
+            | cv2.CALIB_TILTED_MODEL
+            | cv2.CALIB_ZERO_TANGENT_DIST
+        )
+        flag6 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_FIX_ASPECT_RATIO
+        flag_list = [flag2, flag3, flag4, flag5, flag6]
+
+        ret, mtx, dist, rvecs, tvecs = cv2.calibrateCamera(
+            world_coord.astype(np.float32),
+            image_coord.astype(np.float32),
+            org_size,
+            None,
+            None,
+            flags=flag,
+        )
+        if ret > 2:
+            # strategies
+            min_ret = ret
+            for i, flag in enumerate(flag_list):
+                _ret, _mtx, _dist, _rvecs, _tvecs = cv2.calibrateCamera(
+                    world_coord.astype(np.float32),
+                    image_coord.astype(np.float32),
+                    org_size,
+                    None,
+                    None,
+                    flags=flag,
+                )
+                if _ret < min_ret:
+                    min_ret = _ret
+                    ret, mtx, dist, rvecs, tvecs = _ret, _mtx, _dist, _rvecs, _tvecs
+
+        return ret, mtx, dist, rvecs, tvecs
+
+    def dc_homo(
+        self,
+        img,
+        img_points,
+        obj_points,
+        is_horizontal_text,
+        interpolation=cv2.INTER_LINEAR,
+        ratio_width=1.0,
+        ratio_height=1.0,
+    ):
+        """
+        divide and conquer: homography
+        # ratio_width and ratio_height must be 1.0 here
+        """
+        _img_points = img_points.reshape(-1, 2)
+        _obj_points = obj_points.reshape(-1, 3)
+
+        homo_img_list = []
+        width_list = []
+        height_list = []
+        # divide and conquer
+        for i in range(len(_img_points) // 2 - 1):
+            new_img_points = np.zeros((4, 2)).astype(np.float32)
+            new_obj_points = np.zeros((4, 2)).astype(np.float32)
+
+            new_img_points[0:2, :] = _img_points[i : (i + 2), :2]
+            new_img_points[2:4, :] = _img_points[::-1, :][i : (i + 2), :2][::-1, :]
+
+            new_obj_points[0:2, :] = _obj_points[i : (i + 2), :2]
+            new_obj_points[2:4, :] = _obj_points[::-1, :][i : (i + 2), :2][::-1, :]
+
+            if is_horizontal_text:
+                world_width = np.abs(new_obj_points[1, 0] - new_obj_points[0, 0])
+                world_height = np.abs(new_obj_points[3, 1] - new_obj_points[0, 1])
+            else:
+                world_width = np.abs(new_obj_points[1, 1] - new_obj_points[0, 1])
+                world_height = np.abs(new_obj_points[3, 0] - new_obj_points[0, 0])
+
+            homo_img = Homography(
+                img,
+                new_img_points,
+                world_width,
+                world_height,
+                interpolation=interpolation,
+                ratio_width=ratio_width,
+                ratio_height=ratio_height,
+            )
+
+            homo_img_list.append(homo_img)
+            _h, _w = homo_img.shape[:2]
+            width_list.append(_w)
+            height_list.append(_h)
+
+        # stitching
+        rectified_image = np.zeros((np.max(height_list), sum(width_list), 3)).astype(
+            np.uint8
+        )
+
+        st = 0
+        for homo_img, w, h in zip(homo_img_list, width_list, height_list):
+            rectified_image[:h, st : st + w, :] = homo_img
+            st += w
+
+        if not is_horizontal_text:
+            # vertical rotation
+            rectified_image = np.rot90(rectified_image, 3)
+
+        return rectified_image
+
+    def Homography(
+        self,
+        image,
+        img_points,
+        world_width,
+        world_height,
+        interpolation=cv2.INTER_CUBIC,
+        ratio_width=1.0,
+        ratio_height=1.0,
+    ):
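+        """
+        Warp the quadrilateral defined by img_points onto an axis-aligned rectangle
+        of size world_width x world_height (optionally expanded by ratio_width /
+        ratio_height) using a perspective transform.
+        """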
+        _points = np.array(img_points).reshape(-1, 2).astype(np.float32)
+
+        expand_x = int(0.5 * world_width * (ratio_width - 1))
+        expand_y = int(0.5 * world_height * (ratio_height - 1))
+
+        pt_lefttop = [expand_x, expand_y]
+        pt_righttop = [expand_x + world_width, expand_y]
+        pt_leftbottom = [expand_x + world_width, expand_y + world_height]
+        pt_rightbottom = [expand_x, expand_y + world_height]
+
+        pts_std = np.float32([pt_lefttop, pt_righttop, pt_leftbottom, pt_rightbottom])
+
+        img_crop_width = int(world_width * ratio_width)
+        img_crop_height = int(world_height * ratio_height)
+
+        M = cv2.getPerspectiveTransform(_points, pts_std)
+
+        dst_img = cv2.warpPerspective(
+            image,
+            M,
+            (img_crop_width, img_crop_height),
+            borderMode=cv2.BORDER_CONSTANT,  # BORDER_CONSTANT BORDER_REPLICATE
+            flags=interpolation,
+        )
+
+        return dst_img
+
+    def __call__(
+        self,
+        image_data,
+        points,
+        interpolation=cv2.INTER_LINEAR,
+        ratio_width=1.0,
+        ratio_height=1.0,
+        mode="calibration",
+    ):
+        """
+        spatial transform for a poly text
+        :param image_data:
+        :param points: [x1,y1,x2,y2,x3,y3,...], clockwise order, (x1,y1) must be the top-left of first char.
+        :param interpolation: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        :param ratio_width:  roi_image width expansion. It should not be smaller than 1.0
+        :param ratio_height: roi_image height expansion. It should not be smaller than 1.0
+        :param mode: 'calibration' or 'homography'. when homography, ratio_width and ratio_height must be 1.0
+        :return:
+        """
+        org_h, org_w = image_data.shape[:2]
+        org_size = (org_w, org_h)
+        self.image = image_data
+
+        is_horizontal_text = self.horizontal_text_estimate(points)
+        if is_horizontal_text:
+            image_coord, world_coord, new_image_size = self.horizontal_text_process(
+                points
+            )
+        else:
+            image_coord, world_coord, new_image_size = self.vertical_text_process(
+                points, org_size
+            )
+
+        if mode.lower() == "calibration":
+            ret, mtx, dist, rvecs, tvecs = self.calibrate(
+                org_size, image_coord, world_coord
+            )
+
+            st_size = (
+                int(new_image_size[0] * ratio_width),
+                int(new_image_size[1] * ratio_height),
+            )
+            dst = self.spatial_transform(
+                image_data, st_size, mtx, dist[0], rvecs[0], tvecs[0], interpolation
+            )
+        elif mode.lower() == "homography":
+            # ratio_width and ratio_height must be 1.0 here; ret is manually set to a small value (0.01) since homography reports no calibration loss
+            ret = 0.01
+            dst = self.dc_homo(
+                image_data,
+                image_coord,
+                world_coord,
+                is_horizontal_text,
+                interpolation=interpolation,
+                ratio_width=1.0,
+                ratio_height=1.0,
+            )
+        else:
+            raise ValueError(
+                'mode must be ["calibration", "homography"], but got {}'.format(mode)
+            )
+
+        return dst, ret
+
+
+class AutoRectifier:
+    def __init__(self):
+        self.npoints = 10
+        self.curveTextRectifier = CurveTextRectifier()
+
+    @staticmethod
+    def get_rotate_crop_image(
+        img, points, interpolation=cv2.INTER_CUBIC, ratio_width=1.0, ratio_height=1.0
+    ):
+        """
+        crop the region by its bounding box, or warp a quad to a rectangle with a homography
+        :param img:
+        :param points:
+        :param interpolation:
+        :param ratio_width:
+        :param ratio_height:
+        :return:
+        """
+        h, w = img.shape[:2]
+        _points = np.array(points).reshape(-1, 2).astype(np.float32)
+
+        if len(_points) != 4:
+            x_min = int(np.min(_points[:, 0]))
+            y_min = int(np.min(_points[:, 1]))
+            x_max = int(np.max(_points[:, 0]))
+            y_max = int(np.max(_points[:, 1]))
+            dx = x_max - x_min
+            dy = y_max - y_min
+            expand_x = int(0.5 * dx * (ratio_width - 1))
+            expand_y = int(0.5 * dy * (ratio_height - 1))
+            x_min = np.clip(int(x_min - expand_x), 0, w - 1)
+            y_min = np.clip(int(y_min - expand_y), 0, h - 1)
+            x_max = np.clip(int(x_max + expand_x), 0, w - 1)
+            y_max = np.clip(int(y_max + expand_y), 0, h - 1)
+
+            dst_img = img[y_min:y_max, x_min:x_max, :].copy()
+        else:
+            img_crop_width = int(
+                max(
+                    np.linalg.norm(_points[0] - _points[1]),
+                    np.linalg.norm(_points[2] - _points[3]),
+                )
+            )
+            img_crop_height = int(
+                max(
+                    np.linalg.norm(_points[0] - _points[3]),
+                    np.linalg.norm(_points[1] - _points[2]),
+                )
+            )
+
+            dst_img = Homography(
+                img,
+                _points,
+                img_crop_width,
+                img_crop_height,
+                interpolation,
+                ratio_width,
+                ratio_height,
+            )
+
+        return dst_img
+
+    def visualize(self, image_data, points_list):
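+        """Draw every polygon on a copy of the image; the first vertex of each
+        polygon is highlighted in a different color from the rest."""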
+        visualization = image_data.copy()
+
+        for box in points_list:
+            box = np.array(box).reshape(-1, 2).astype(np.int32)
+            cv2.drawContours(
+                visualization, [np.array(box).reshape((-1, 1, 2))], -1, (0, 0, 255), 2
+            )
+            for i, p in enumerate(box):
+                if i != 0:
+                    cv2.circle(
+                        visualization,
+                        tuple(p),
+                        radius=1,
+                        color=(255, 0, 0),
+                        thickness=2,
+                    )
+                else:
+                    cv2.circle(
+                        visualization,
+                        tuple(p),
+                        radius=1,
+                        color=(255, 255, 0),
+                        thickness=2,
+                    )
+        return visualization
+
+    def __call__(
+        self,
+        image_data,
+        points,
+        interpolation=cv2.INTER_LINEAR,
+        ratio_width=1.0,
+        ratio_height=1.0,
+        loss_thresh=5.0,
+        mode="calibration",
+    ):
+        """
+        rectification with multiple strategies for a poly text
+        :param image_data:
+        :param points: [x1,y1,x2,y2,x3,y3,...], clockwise order, (x1,y1) must be the top-left of first char.
+        :param interpolation: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        :param ratio_width:  roi_image width expansion. It should not be smaller than 1.0
+        :param ratio_height: roi_image height expansion. It should not be smaller than 1.0
+        :param loss_thresh: if loss greater than loss_thresh --> get_rotate_crop_image
+        :param mode: 'calibration' or 'homography'. when homography, ratio_width and ratio_height must be 1.0
+        :return:
+        """
+        _points = np.array(points).reshape(-1, 2)
+        if len(_points) >= self.npoints and len(_points) % 2 == 0:
+            try:
+                curveTextRectifier = CurveTextRectifier()
+
+                dst_img, loss = curveTextRectifier(
+                    image_data, points, interpolation, ratio_width, ratio_height, mode
+                )
+                if loss >= 2:
+                    # for robustness:
+                    # a large loss means the region cannot be reconstructed correctly, so try other strategies
+                    img_list, loss_list = [dst_img], [loss]
+                    _dst_img, _loss = PlanB()(
+                        image_data,
+                        points,
+                        curveTextRectifier,
+                        interpolation,
+                        ratio_width,
+                        ratio_height,
+                        loss_thresh=loss_thresh,
+                        square=True,
+                    )
+                    img_list += [_dst_img]
+                    loss_list += [_loss]
+
+                    _dst_img, _loss = PlanB()(
+                        image_data,
+                        points,
+                        curveTextRectifier,
+                        interpolation,
+                        ratio_width,
+                        ratio_height,
+                        loss_thresh=loss_thresh,
+                        square=False,
+                    )
+                    img_list += [_dst_img]
+                    loss_list += [_loss]
+
+                    min_loss = min(loss_list)
+                    dst_img = img_list[loss_list.index(min_loss)]
+
+                    if min_loss >= loss_thresh:
+                        logging.warning(
+                            "calibration loss: {} is too large for the spatial transformer; falling back to get_rotate_crop_image".format(
+                                loss
+                            )
+                        )
+                        dst_img = self.get_rotate_crop_image(
+                            image_data, points, interpolation, ratio_width, ratio_height
+                        )
+            except Exception as e:
+                logging.warning(f"Exception caught: {e}")
+                dst_img = self.get_rotate_crop_image(
+                    image_data, points, interpolation, ratio_width, ratio_height
+                )
+        else:
+            dst_img = self.get_rotate_crop_image(
+                image_data, _points, interpolation, ratio_width, ratio_height
+            )
+
+        return dst_img
+
+    def run(
+        self,
+        image_data,
+        points_list,
+        interpolation=cv2.INTER_LINEAR,
+        ratio_width=1.0,
+        ratio_height=1.0,
+        loss_thresh=5.0,
+        mode="calibration",
+    ):
+        """
+        run for texts in an image
+        :param image_data: numpy.ndarray. The shape is [h, w, 3]
+        :param points_list: [[x1,y1,x2,y2,x3,y3,...], [x1,y1,x2,y2,x3,y3,...], ...], clockwise order, (x1,y1) must be the top-left of first char.
+        :param interpolation: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        :param ratio_width:  roi_image width expansion. It should not be smaller than 1.0
+        :param ratio_height: roi_image height expansion. It should not be smaller than 1.0
+        :param loss_thresh: if loss greater than loss_thresh --> get_rotate_crop_image
+        :param mode: 'calibration' or 'homography'. when homography, ratio_width and ratio_height must be 1.0
+        :return: res: roi-image list, visualized_image: draw polys in original image
+        """
+        if image_data is None:
+            raise ValueError("image_data cannot be None.")
+        if not isinstance(points_list, list):
+            raise ValueError("points_list must be a list.")
+        for points in points_list:
+            if not isinstance(points, list):
+                raise ValueError("each element of points_list must be a list.")
+
+        if ratio_width < 1.0 or ratio_height < 1.0:
+            raise ValueError(
+                "ratio_width and ratio_height cannot be smaller than 1, but got {}",
+                (ratio_width, ratio_height),
+            )
+
+        if mode.lower() != "calibration" and mode.lower() != "homography":
+            raise ValueError(
+                'mode must be ["calibration", "homography"], but got {}'.format(mode)
+            )
+
+        if mode.lower() == "homography" and ratio_width != 1.0 and ratio_height != 1.0:
+            raise ValueError(
+                "ratio_width and ratio_height must be 1.0 when mode is homography, but got mode:{}, ratio:({},{})".format(
+                    mode, ratio_width, ratio_height
+                )
+            )
+
+        res = []
+        for points in points_list:
+            rectified_img = self(
+                image_data,
+                points,
+                interpolation,
+                ratio_width,
+                ratio_height,
+                loss_thresh=loss_thresh,
+                mode=mode,
+            )
+            res.append(rectified_img)
+
+        # visualize
+        visualized_image = self.visualize(image_data, points_list)
+
+        return res, visualized_image

+ 48 - 0
paddlex/inference/pipelines_new/components/common/sort_boxes.py

@@ -0,0 +1,48 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BaseComponent
+import numpy as np
+
+class SortQuadBoxes(BaseComponent):
+    """SortQuadBoxes Component"""
+
+    entities = "SortQuadBoxes"
+    def __init__(self):
+        super().__init__()
+
+    def __call__(self, dt_polys):
+        """
+        Sort quad boxes in order from top to bottom, left to right
+        args:
+            dt_polys(array): detected quad boxes with shape [N, 4, 2]
+        return:
+            sorted boxes (list of arrays), each with shape [4, 2]
+        """
+        dt_boxes = np.array(dt_polys)
+        num_boxes = dt_boxes.shape[0]
+        sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
+        _boxes = list(sorted_boxes)
+
+        for i in range(num_boxes - 1):
+            for j in range(i, -1, -1):
+                if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (
+                    _boxes[j + 1][0][0] < _boxes[j][0][0]
+                ):
+                    tmp = _boxes[j]
+                    _boxes[j] = _boxes[j + 1]
+                    _boxes[j + 1] = tmp
+                else:
+                    break
+        return _boxes

+ 15 - 0
paddlex/inference/pipelines_new/components/prompt_engeering/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .generate_kie_prompt import GenerateKIEPrompt

+ 31 - 0
paddlex/inference/pipelines_new/components/prompt_engeering/base.py

@@ -0,0 +1,31 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from .....utils.subclass_register import AutoRegisterABCMetaClass
+
+import inspect
+
+class BaseGeneratePrompt(ABC, metaclass=AutoRegisterABCMetaClass):
+    """Base Chat"""
+
+    __is_base = True
+
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    def generate_prompt(self):
+        raise NotImplementedError(
+            "The method `generate_prompt` has not been implemented yet.")

+ 100 - 0
paddlex/inference/pipelines_new/components/prompt_engeering/generate_kie_prompt.py

@@ -0,0 +1,100 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base import BaseGeneratePrompt
+
+class GenerateKIEPrompt(BaseGeneratePrompt):
+    """Generate KIE Prompt"""
+
+    entities = [
+        "text_kie_prompt",
+        "table_kie_prompt"
+    ]
+
+    def __init__(self, config):
+        super().__init__()
+
+        task_type = config.get('task_type', "")
+        task_description = config.get('task_description', "")  
+        output_format = config.get('output_format', "")  
+        rules_str = config.get('rules_str', "")  
+        few_shot_demo_text_content = config.get('few_shot_demo_text_content', "")  
+        few_shot_demo_key_value_list = config.get('few_shot_demo_key_value_list', "")
+
+        if task_description is None:
+            task_description = ""
+        
+        if output_format is None:
+            output_format = ""
+        
+        if rules_str is None:
+            rules_str = ""
+        
+        if few_shot_demo_text_content is None:
+            few_shot_demo_text_content = ""
+        
+        if few_shot_demo_key_value_list is None:
+            few_shot_demo_key_value_list = ""
+
+        if task_type not in self.entities:
+            raise ValueError(f"task type must be in {self.entities} of GenerateKIEPrompt.")
+
+        self.task_type = task_type
+        self.task_description = task_description
+        self.output_format = output_format
+        self.rules_str = rules_str
+        self.few_shot_demo_text_content = few_shot_demo_text_content
+        self.few_shot_demo_key_value_list = few_shot_demo_key_value_list
+        
+    def generate_prompt(self, text_content,
+        key_list,
+        task_description=None,
+        output_format=None,
+        rules_str=None,
+        few_shot_demo_text_content=None,
+        few_shot_demo_key_value_list=None):
+        """
+        args:
+        return:
+        """
+
+        if task_description is None:
+            task_description = self.task_description
+
+        if output_format is None:
+            output_format = self.output_format
+
+        if rules_str is None:
+            rules_str = self.rules_str
+
+        if few_shot_demo_text_content is None:
+            few_shot_demo_text_content = self.few_shot_demo_text_content
+            
+        if few_shot_demo_key_value_list is None:
+            few_shot_demo_key_value_list = self.few_shot_demo_key_value_list
+
+        prompt = f"""{task_description}{output_format}{rules_str}{few_shot_demo_text_content}{few_shot_demo_key_value_list}"""
+        if self.task_type == "table_kie_prompt":
+            prompt += f"""\n结合上面,下面正式开始:\
+                表格内容:```{text_content}```\
+                关键词列表:{key_list}。""".replace(
+                "    ", "")
+        elif self.task_type == "text_kie_prompt":
+            prompt += f"""\n结合上面的例子,下面正式开始:\
+                OCR文字:```{text_content}```\
+                关键词列表:{key_list}。""".replace(
+                "    ", "")
+        else:
+            raise ValueError(f"{self.task_type} is currently not supported.") 
+        return prompt

+ 15 - 0
paddlex/inference/pipelines_new/components/retriever/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .ernie_bot_retriever import ErnieBotRetriever

+ 50 - 0
paddlex/inference/pipelines_new/components/retriever/base.py

@@ -0,0 +1,50 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from .....utils.subclass_register import AutoRegisterABCMetaClass
+
+import inspect
+import base64
+
+class BaseRetriever(ABC, metaclass=AutoRegisterABCMetaClass):
+    """Base Retriever"""
+
+    __is_base = True
+
+    VECTOR_STORE_PREFIX = "PADDLEX_VECTOR_STORE"
+
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    def generate_vector_database(self):
+        raise NotImplementedError(
+            "The method `generate_vector_database` has not been implemented yet.")
+
+    @abstractmethod
+    def similarity_retrieval(self):
+        raise NotImplementedError(
+            "The method `similarity_retrieval` has not been implemented yet.")
+
+    def is_vector_store(self, s):
+        return s.startswith(self.VECTOR_STORE_PREFIX)
+
+    def encode_vector_store(self, vector_store_bytes):
+        return self.VECTOR_STORE_PREFIX + base64.b64encode(vector_store_bytes).decode(
+            "ascii"
+        )
+
+    def decode_vector_store(self, vector_store_str):
+        return base64.b64decode(vector_store_str[len(self.VECTOR_STORE_PREFIX):])

+ 148 - 0
paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py

@@ -0,0 +1,148 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base import BaseRetriever
+import os
+
+from langchain.docstore.document import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from langchain_community.embeddings import QianfanEmbeddingsEndpoint
+from langchain_community.vectorstores import FAISS
+from langchain_community import vectorstores
+from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings
+
+import time
+
+class ErnieBotRetriever(BaseRetriever):
+    """Ernie Bot Retriever"""
+
+    entities = [
+        "ernie-4.0",
+        "ernie-3.5",
+        "ernie-3.5-8k",
+        "ernie-lite",
+        "ernie-tiny-8k",
+        "ernie-speed",
+        "ernie-speed-128k",
+        "ernie-char-8k",
+    ]
+    
+    def __init__(self, config):
+
+        super().__init__()
+
+        model_name = config.get('model_name', None)
+        api_type = config.get('api_type', None)
+        ak = config.get('ak', None)
+        sk = config.get('sk', None)
+        access_token = config.get('access_token', None)
+        
+        if model_name not in self.entities:
+            raise ValueError(f"model_name must be in {self.entities} of ErnieBotChat.")
+
+        if api_type not in ["aistudio", "qianfan"]:
+            raise ValueError("api_type must be one of ['aistudio', 'qianfan']")
+
+        if api_type == "aistudio" and access_token is None:
+            raise ValueError("access_token cannot be empty when api_type is aistudio.")
+            
+        if api_type == "qianfan" and (ak is None or sk is None):
+            raise ValueError("ak and sk cannot be empty when api_type is qianfan.")            
+
+        self.model_name = model_name
+        self.config = config
+        
+    def generate_vector_database(self, text_list, 
+        block_size=300,
+        separators=["\t", "\n", "。", "\n\n", ""],
+        sleep_time=0.5):
+        """
+        args:
+        return:
+        """
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=block_size, chunk_overlap=20, separators=separators
+        )
+        texts = text_splitter.split_text("\t".join(text_list))
+        all_splits = [Document(page_content=text) for text in texts]
+
+        api_type = self.config["api_type"]
+        if api_type == "qianfan":
+            os.environ["QIANFAN_AK"] = os.environ.get("EB_AK", self.config["ak"])
+            os.environ["QIANFAN_SK"] = os.environ.get("EB_SK", self.config["sk"])
+            user_ak = os.environ.get("EB_AK", self.config["ak"])
+            user_id = hash(user_ak)
+            vectorstore = FAISS.from_documents(
+                documents=all_splits, embedding=QianfanEmbeddingsEndpoint()
+            )
+        elif api_type == "aistudio":
+            token = self.config["access_token"]
+            vectorstore = FAISS.from_documents(
+                documents=all_splits[0:1],
+                embedding=ErnieEmbeddings(aistudio_access_token=token),
+            )
+            # ErnieEmbeddings.chunk_size = 16, so insert at most 16 documents per request
+            step = max(1, min(16, len(all_splits) - 1))
+            for shot_splits in [
+                all_splits[i : i + step] for i in range(1, len(all_splits), step)
+            ]:
+                time.sleep(sleep_time)
+                vectorstore_slice = FAISS.from_documents(
+                    documents=shot_splits,
+                    embedding=ErnieEmbeddings(aistudio_access_token=token),
+                )
+                vectorstore.merge_from(vectorstore_slice)
+        else:
+            raise ValueError(f"Unsupported api_type: {api_type}")
+
+        return vectorstore
+
+    def encode_vector_store_to_bytes(self, vectorstore):
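+        """Serialize the FAISS vector store to bytes and wrap it as a prefixed base64 string."""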
+        vectorstore = self.encode_vector_store(vectorstore.serialize_to_bytes())
+        return vectorstore
+    
+    def decode_vector_store_from_bytes(self, vectorstore):
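+        """Rebuild a FAISS vector store from its PaddleX-prefixed base64 string,
+        using embeddings that match the configured api_type."""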
+        if not self.is_vector_store(vectorstore):
+            raise ValueError("The retrieved vectorstore is not for PaddleX.")
+        api_type = self.config["api_type"]
+
+        if api_type == "aistudio":
+            access_token = self.config["access_token"]
+            embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
+        elif api_type == "qianfan":
+            ak = self.config["ak"]
+            sk = self.config["sk"]
+            embeddings = QianfanEmbeddingsEndpoint(qianfan_ak=ak, qianfan_sk=sk)
+        else:
+            raise ValueError(f"Unsupported api_type: {api_type}")
+        vectorstore = vectorstores.FAISS.deserialize_from_bytes(
+            self.decode_vector_store(vectorstore), embeddings
+        )
+        return vectorstore
+
+    def similarity_retrieval(self, query_text_list, vectorstore, sleep_time=0.5):
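+        """Retrieve the most relevant chunks for each query from the vector store
+        and merge them into a single context string."""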
+        # match the relevant context for each query
+        C = []
+        for query_text in query_text_list:
+            QUESTION = query_text
+            time.sleep(sleep_time)
+            docs = vectorstore.similarity_search_with_relevance_scores(QUESTION, k=2)
+            context = [(document.page_content, score) for document, score in docs]
+            context = sorted(context, key=lambda x: x[1])
+            C.extend([x[0] for x in context[::-1]])
+        C = list(set(C))
+        all_C = " ".join(C)
+        return all_C
+        

+ 13 - 0
paddlex/inference/pipelines_new/components/utils/__init__.py

@@ -0,0 +1,13 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

+ 204 - 0
paddlex/inference/pipelines_new/components/utils/mixin.py

@@ -0,0 +1,204 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import abstractmethod
+import json
+from pathlib import Path
+import numpy as np
+from PIL import Image
+import pandas as pd
+
+from .....utils import logging
+from ....utils.io import (
+    JsonWriter,
+    ImageReader,
+    ImageWriter,
+    CSVWriter,
+    HtmlWriter,
+    XlsxWriter,
+    TextWriter,
+)
+
+
+def _save_list_data(save_func, save_path, data, *args, **kwargs):
+    save_path = Path(save_path)
+    if data is None:
+        return
+    if isinstance(data, list):
+        for idx, single in enumerate(data):
+            save_func(
+                (
+                    save_path.parent / f"{save_path.stem}_{idx}{save_path.suffix}"
+                ).as_posix(),
+                single,
+                *args,
+                **kwargs,
+            )
+        return
+    save_func(save_path.as_posix(), data, *args, **kwargs)
+    logging.info(f"The result has been saved in {save_path}.")
+
+
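+# Note: these mixins assume the host result class is dict-like (e.g. supports
+# self["input_path"]) and defines a `_show_funcs` list that the mixin
+# constructors append their save methods to.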
+class StrMixin:
+    @property
+    def str(self):
+        return self._to_str(self)
+
+    def _to_str(self, data, json_format=False, indent=4, ensure_ascii=False):
+        if json_format:
+            return json.dumps(data.json, indent=indent, ensure_ascii=ensure_ascii)
+        else:
+            return str(data)
+
+    def print(self, json_format=False, indent=4, ensure_ascii=False):
+        str_ = self._to_str(
+            self, json_format=json_format, indent=indent, ensure_ascii=ensure_ascii
+        )
+        logging.info(str_)
+
+
+class JsonMixin:
+    def __init__(self):
+        self._json_writer = JsonWriter()
+        self._show_funcs.append(self.save_to_json)
+
+    def _to_json(self):
+        def _format_data(obj):
+            if isinstance(obj, np.float32):
+                return float(obj)
+            elif isinstance(obj, np.ndarray):
+                return [_format_data(item) for item in obj.tolist()]
+            elif isinstance(obj, pd.DataFrame):
+                return obj.to_json(orient="records", force_ascii=False)
+            elif isinstance(obj, Path):
+                return obj.as_posix()
+            elif isinstance(obj, dict):
+                return type(obj)({k: _format_data(v) for k, v in obj.items()})
+            elif isinstance(obj, (list, tuple)):
+                return [_format_data(i) for i in obj]
+            else:
+                return obj
+
+        return _format_data(self)
+
+    @property
+    def json(self):
+        return self._to_json()
+
+    def save_to_json(self, save_path, indent=4, ensure_ascii=False, *args, **kwargs):
+        if not str(save_path).endswith(".json"):
+            save_path = Path(save_path) / f"{Path(self['input_path']).stem}.json"
+        _save_list_data(
+            self._json_writer.write,
+            save_path,
+            self.json,
+            indent=indent,
+            ensure_ascii=ensure_ascii,
+            *args,
+            **kwargs,
+        )
+
+
+class Base64Mixin:
+    def __init__(self, *args, **kwargs):
+        self._base64_writer = TextWriter(*args, **kwargs)
+        self._show_funcs.append(self.save_to_base64)
+
+    @abstractmethod
+    def _to_base64(self):
+        raise NotImplementedError
+
+    @property
+    def base64(self):
+        return self._to_base64()
+
+    def save_to_base64(self, save_path, *args, **kwargs):
+        if not str(save_path).lower().endswith((".b64")):
+            fp = Path(self["input_path"])
+            save_path = Path(save_path) / f"{fp.stem}{fp.suffix}"
+        _save_list_data(
+            self._base64_writer.write, save_path, self.base64, *args, **kwargs
+        )
+
+
+class ImgMixin:
+    def __init__(self, backend="pillow", *args, **kwargs):
+        self._img_writer = ImageWriter(backend=backend, *args, **kwargs)
+        self._show_funcs.append(self.save_to_img)
+
+    @abstractmethod
+    def _to_img(self):
+        raise NotImplementedError
+
+    @property
+    def img(self):
+        image = self._to_img()
+        # The img must be a PIL.Image obj
+        if isinstance(image, np.ndarray):
+            return Image.fromarray(image)
+        return image
+
+    def save_to_img(self, save_path, *args, **kwargs):
+        if not str(save_path).lower().endswith((".jpg", ".png")):
+            fp = Path(self["input_path"])
+            save_path = Path(save_path) / f"{fp.stem}{fp.suffix}"
+        _save_list_data(self._img_writer.write, save_path, self.img, *args, **kwargs)
+
+
+class CSVMixin:
+    def __init__(self, backend="pandas", *args, **kwargs):
+        self._csv_writer = CSVWriter(backend=backend, *args, **kwargs)
+        self._show_funcs.append(self.save_to_csv)
+
+    @abstractmethod
+    def _to_csv(self):
+        raise NotImplementedError
+
+    def save_to_csv(self, save_path, *args, **kwargs):
+        if not str(save_path).endswith(".csv"):
+            save_path = Path(save_path) / f"{Path(self['input_path']).stem}.csv"
+        _save_list_data(
+            self._csv_writer.write, save_path, self._to_csv(), *args, **kwargs
+        )
+
+
+class HtmlMixin:
+    def __init__(self, *args, **kwargs):
+        self._html_writer = HtmlWriter(*args, **kwargs)
+        self._show_funcs.append(self.save_to_html)
+
+    @property
+    def html(self):
+        return self._to_html()
+
+    def _to_html(self):
+        return self["html"]
+
+    def save_to_html(self, save_path, *args, **kwargs):
+        if not str(save_path).endswith(".html"):
+            save_path = Path(save_path) / f"{Path(self['input_path']).stem}.html"
+        _save_list_data(self._html_writer.write, save_path, self.html, *args, **kwargs)
+
+
+class XlsxMixin:
+    def __init__(self, *args, **kwargs):
+        self._xlsx_writer = XlsxWriter(*args, **kwargs)
+        self._show_funcs.append(self.save_to_xlsx)
+
+    def _to_xlsx(self):
+        return self["html"]
+
+    def save_to_xlsx(self, save_path, *args, **kwargs):
+        if not str(save_path).endswith(".xlsx"):
+            save_path = Path(save_path) / f"{Path(self['input_path']).stem}.xlsx"
+        _save_list_data(self._xlsx_writer.write, save_path, self.html, *args, **kwargs)

+ 15 - 0
paddlex/inference/pipelines_new/doc_preprocessor/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pipeline import DocPreprocessorPipeline

+ 117 - 0
paddlex/inference/pipelines_new/doc_preprocessor/pipeline.py

@@ -0,0 +1,117 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BasePipeline
+from typing import Any, Dict, Optional
+from scipy.ndimage import rotate
+from .result import DocPreprocessorResult
+
+########## [TODO] Update this import path later
+from ...components.transforms import ReadImage
+
+class DocPreprocessorPipeline(BasePipeline):
+    """Doc Preprocessor Pipeline"""
+
+    entities = "doc_preprocessor"
+    def __init__(self,
+        config,        
+        device=None,
+        pp_option=None, 
+        use_hpip: bool = False,
+        hpi_params: Optional[Dict[str, Any]] = None):
+        super().__init__(device=device, pp_option=pp_option, 
+            use_hpip=use_hpip, hpi_params=hpi_params)
+        
+        self.use_doc_orientation_classify = True
+        if 'use_doc_orientation_classify' in config:
+            self.use_doc_orientation_classify = config['use_doc_orientation_classify']
+
+        self.use_doc_unwarping = True
+        if 'use_doc_unwarping' in config:
+            self.use_doc_unwarping = config['use_doc_unwarping']
+        
+        if self.use_doc_orientation_classify:
+            doc_ori_classify_config = config['SubModules']["DocOrientationClassify"]
+            self.doc_ori_classify_model = self.create_model(doc_ori_classify_config)
+
+        if self.use_doc_unwarping:
+            doc_unwarping_config = config['SubModules']["DocUnwarping"]
+            self.doc_unwarping_model = self.create_model(doc_unwarping_config)
+        
+        self.img_reader = ReadImage(format="BGR")
+
+    def rotate_image(self, image_array, rotate_angle):
+        """rotate image"""
+        assert (
+            rotate_angle >= 0 and rotate_angle < 360
+        ), "rotate_angle must in [0-360), but get {rotate_angle}."
+        return rotate(image_array, rotate_angle, reshape=True)
+
+    def check_input_params(self, input_params):
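+        """Check that every feature requested at call time has its model initialized from the config."""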
+        
+        if input_params['use_doc_orientation_classify'] and \
+            not self.use_doc_orientation_classify:
+            raise ValueError("The model for doc orientation classify is not initialized.")
+
+
+        if input_params['use_doc_unwarping'] and \
+            not self.use_doc_unwarping:
+            raise ValueError("The model for doc unwarping is not initialized.")
+            
+        return 
+
+    def predict(self, input, 
+        use_doc_orientation_classify=True,
+        use_doc_unwarping=False,
+        **kwargs):
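+        """Run optional document orientation classification and unwarping on each
+        input image (path or array), yielding a DocPreprocessorResult per image."""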
+
+        if not isinstance(input, list):
+            input_list = [input]
+        else:
+            input_list = input
+
+        input_params = {"use_doc_orientation_classify":use_doc_orientation_classify,
+            "use_doc_unwarping":use_doc_unwarping}
+        self.check_input_params(input_params)
+
+        img_id = 1
+        for input in input_list:
+            if isinstance(input, str):
+                image_array = next(self.img_reader(input))[0]['img']
+            else:
+                image_array = input
+
+            assert len(image_array.shape) == 3
+
+            if input_params['use_doc_orientation_classify']:
+                pred = next(self.doc_ori_classify_model(image_array))
+                angle = int(pred["label_names"][0])
+                rot_img = self.rotate_image(image_array, angle)
+            else:
+                angle = -1
+                rot_img = image_array
+
+            if input_params['use_doc_unwarping']:
+                output_img = next(self.doc_unwarping_model(rot_img))['doctr_img']
+            else:
+                output_img = rot_img
+
+            single_img_res = {"input_image":image_array,
+                "input_params":input_params,
+                "angle":angle, 
+                "rot_img":rot_img, 
+                "output_img":output_img,
+                "img_id":img_id}
+            img_id += 1
+            yield DocPreprocessorResult(single_img_res)

+ 51 - 0
paddlex/inference/pipelines_new/doc_preprocessor/result.py

@@ -0,0 +1,51 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import random
+import numpy as np
+import cv2
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+
+from ....utils.fonts import PINGFANG_FONT_FILE_PATH, create_font
+from ..components import CVResult
+
+class DocPreprocessorResult(CVResult):
+
+    def save_to_img(self, save_path, *args, **kwargs):
+        if not str(save_path).lower().endswith((".jpg", ".png")):
+            img_id = self["img_id"]
+            save_path = save_path + "/res_doc_preprocess_%d.jpg" % img_id
+        super().save_to_img(save_path, *args, **kwargs)
+
+    def _to_img(self):
+        """draw doc preprocess result"""
+        image = self["input_image"]
+        angle = self["angle"]
+        rot_img = self["rot_img"]
+        output_img = self["output_img"]
+        h, w = image.shape[0:2]
+        img_show = Image.new("RGB", (w * 3, h + 25), (255, 255, 255))
+        img_show.paste(Image.fromarray(image), (0, 0, w, h))
+        img_show.paste(Image.fromarray(rot_img), (w, 0, w * 2, h))
+        img_show.paste(Image.fromarray(output_img), (w * 2, 0, w * 3, h))
+
+        draw_text = ImageDraw.Draw(img_show)
+        txt_list = ["Original Image", "Rotated Image", "Unwarping Image"]
+        for tno in range(len(txt_list)):
+            txt = txt_list[tno]
+            font = create_font(txt, (w, 20), PINGFANG_FONT_FILE_PATH)
+            draw_text.text([10 + w * tno, h + 2], txt, fill=(0, 0, 0), font=font)
+        return img_show

+ 15 - 0
paddlex/inference/pipelines_new/layout_parsing/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pipeline import LayoutParsingPipeline

+ 205 - 0
paddlex/inference/pipelines_new/layout_parsing/pipeline.py

@@ -0,0 +1,205 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BasePipeline
+from typing import Any, Dict, Optional
+import numpy as np
+import cv2
+from ..components import CropByBoxes
+from .utils import convert_points_to_boxes, get_sub_regions_ocr_res
+from .table_recognition_post_processing import get_table_recognition_res
+
+from .result import LayoutParsingResult
+
+########## [TODO] Update this import path later
+from ...components.transforms import ReadImage
+
+class LayoutParsingPipeline(BasePipeline):
+    """Layout Parsing Pipeline"""
+
+    entities = "layout_parsing"
+    def __init__(self,
+        config,        
+        device=None,
+        pp_option=None, 
+        use_hpip: bool = False,
+        hpi_params: Optional[Dict[str, Any]] = None):
+        super().__init__(device=device, pp_option=pp_option, 
+            use_hpip=use_hpip, hpi_params=hpi_params)
+        
+        self.inintial_predictor(config)
+
+        self.img_reader = ReadImage(format="BGR")
+
+        self._crop_by_boxes = CropByBoxes()
+        
+
+    def inintial_predictor(self, config):
+        layout_det_config = config['SubModules']["LayoutDetection"]
+        self.layout_det_model = self.create_model(layout_det_config)
+
+        self.use_doc_preprocessor = False
+        if 'use_doc_preprocessor' in config:
+            self.use_doc_preprocessor = config['use_doc_preprocessor']
+        if self.use_doc_preprocessor:
+            doc_preprocessor_config = config['SubPipelines']['DocPreprocessor']
+            self.doc_preprocessor_pipeline = self.create_pipeline(doc_preprocessor_config)
+        
+        self.use_common_ocr = False
+        if "use_common_ocr" in config:
+            self.use_common_ocr = config['use_common_ocr']
+        if self.use_common_ocr:
+            common_ocr_config = config['SubPipelines']['CommonOCR']
+            self.common_ocr_pipeline = self.create_pipeline(common_ocr_config)
+        
+        self.use_seal_recognition = False
+        if "use_seal_recognition" in config:
+            self.use_seal_recognition = config['use_seal_recognition']
+        if self.use_seal_recognition:
+            seal_ocr_config = config['SubPipelines']['SealOCR']
+            self.seal_ocr_pipeline = self.create_pipeline(seal_ocr_config)            
+        
+        self.use_table_recognition = False
+        if "use_table_recognition" in config:
+            self.use_table_recognition = config['use_table_recognition']
+        if self.use_table_recognition:
+            table_structure_config = config['SubModules']['TableStructurePredictor']
+            self.table_structure_model = self.create_model(table_structure_config)
+            if not self.use_common_ocr:
+                common_ocr_config = config['SubPipelines']['OCR']
+                self.common_ocr_pipeline = self.create_pipeline(common_ocr_config)
+        return 
+
+    def get_text_paragraphs_ocr_res(self, overall_ocr_res, layout_det_res):
+        '''Get OCR results for plain-text paragraphs, excluding image, formula, table and seal regions.'''
+        object_boxes = []
+        for box_info in layout_det_res['boxes']:
+            if box_info['label'].lower() in ['image', 'formula', 'table', 'seal']:
+                object_boxes.append(box_info['coordinate'])
+        object_boxes = np.array(object_boxes)
+        return get_sub_regions_ocr_res(overall_ocr_res, object_boxes, flag_within=False)
+
+    def check_input_params(self, input_params):
+
+        if input_params['use_doc_preprocessor'] and not self.use_doc_preprocessor:
+            raise ValueError("The models for doc preprocessor are not initialized.")
+
+        if input_params['use_common_ocr'] and not self.use_common_ocr:
+            raise ValueError("The models for common OCR are not initialized.")
+
+        if input_params['use_seal_recognition'] and not self.use_seal_recognition:
+            raise ValueError("The models for seal recognition are not initialized.")
+
+        if input_params['use_table_recognition'] and not self.use_table_recognition:
+            raise ValueError("The models for table recognition are not initialized.")
+
+        return
+
+    def predict(self, input, 
+        use_doc_orientation_classify=True,
+        use_doc_unwarping=True,
+        use_common_ocr=True,
+        use_seal_recognition=True,
+        use_table_recognition=True,
+        **kwargs):
+
+        if not isinstance(input, list):
+            input_list = [input]
+        else:
+            input_list = input
+        
+        input_params = {"use_doc_preprocessor":self.use_doc_preprocessor,
+            "use_doc_orientation_classify":use_doc_orientation_classify,
+            "use_doc_unwarping":use_doc_unwarping,
+            "use_common_ocr":use_common_ocr,
+            "use_seal_recognition":use_seal_recognition,
+            "use_table_recognition":use_table_recognition}
+            
+        if use_doc_orientation_classify or use_doc_unwarping:
+            input_params['use_doc_preprocessor'] = True
+
+        self.check_input_params(input_params)
+
+        img_id = 1
+        for input in input_list:
+            if isinstance(input, str):
+                image_array = next(self.img_reader(input))[0]['img']
+            else:
+                image_array = input
+
+            assert len(image_array.shape) == 3
+
+            if input_params['use_doc_preprocessor']:
+                doc_preprocessor_res = next(self.doc_preprocessor_pipeline(
+                    image_array, 
+                    use_doc_orientation_classify=use_doc_orientation_classify,
+                    use_doc_unwarping=use_doc_unwarping))
+                doc_preprocessor_image = doc_preprocessor_res['output_img']
+                doc_preprocessor_res['img_id'] = img_id
+            else:
+                doc_preprocessor_res = {}
+                doc_preprocessor_image = image_array
+            
+
+            # [TODO] RT-DETR layout detection results may contain duplicate boxes.
+            layout_det_res = next(self.layout_det_model(doc_preprocessor_image))
+
+            if input_params['use_common_ocr'] or input_params['use_table_recognition']:
+                overall_ocr_res = next(self.common_ocr_pipeline(doc_preprocessor_image))
+                overall_ocr_res['img_id'] = img_id
+                dt_boxes = convert_points_to_boxes(overall_ocr_res['dt_polys'])
+                overall_ocr_res['dt_boxes'] = dt_boxes
+            else:
+                overall_ocr_res = {}
+            
+            text_paragraphs_ocr_res = {}
+            if input_params['use_common_ocr']:
+                text_paragraphs_ocr_res = self.get_text_paragraphs_ocr_res(
+                    overall_ocr_res, layout_det_res)
+                text_paragraphs_ocr_res['img_id'] = img_id
+            
+            table_res_list = []
+            if input_params['use_table_recognition']:
+                table_region_id = 1
+                for box_info in layout_det_res['boxes']:
+                    if box_info['label'].lower() in ['table']:
+                        crop_img_info = self._crop_by_boxes(doc_preprocessor_image, [box_info])
+                        crop_img_info = crop_img_info[0]
+                        table_structure_pred = next(self.table_structure_model(
+                            crop_img_info['img']))
+                        table_recognition_res = get_table_recognition_res(
+                            crop_img_info, table_structure_pred, overall_ocr_res)
+                        table_recognition_res['table_region_id'] = table_region_id
+                        table_region_id += 1
+                        table_res_list.append(table_recognition_res)
+            
+            seal_res_list = []
+            if input_params['use_seal_recognition']:
+                seal_region_id = 1
+                for box_info in layout_det_res['boxes']:
+                    if box_info['label'].lower() in ['seal']:
+                        crop_img_info = self._crop_by_boxes(doc_preprocessor_image, [box_info])
+                        crop_img_info = crop_img_info[0]
+                        seal_ocr_res = next(self.seal_ocr_pipeline(crop_img_info['img']))
+                        seal_ocr_res['seal_region_id'] = seal_region_id
+                        seal_region_id += 1
+                        seal_res_list.append(seal_ocr_res)
+            
+            single_img_res = {"layout_det_res":layout_det_res,
+                "doc_preprocessor_res":doc_preprocessor_res,
+                "text_paragraphs_ocr_res":text_paragraphs_ocr_res,
+                "table_res_list":table_res_list,
+                "seal_res_list":seal_res_list,
+                "input_params":input_params}
+            yield LayoutParsingResult(single_img_res)
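A minimal usage sketch for the pipeline above, assuming the `layout_parsing` entry from paddlex/configs/pipelines/layout_parsing.yaml is exposed through `paddlex.create_pipeline`; `./demo.png` and `./output` are illustrative paths:

from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="layout_parsing")

# The keyword switches mirror the predict() arguments defined above;
# each yielded item is a LayoutParsingResult.
for res in pipeline.predict(
    "./demo.png",
    use_doc_orientation_classify=False,
    use_doc_unwarping=False,
    use_common_ocr=True,
    use_seal_recognition=True,
    use_table_recognition=True,
):
    res.save_results("./output")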

+ 97 - 0
paddlex/inference/pipelines_new/layout_parsing/result.py

@@ -0,0 +1,97 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import random
+import numpy as np
+import cv2
+import PIL
+import os
+from PIL import Image, ImageDraw, ImageFont
+
+from ....utils.fonts import PINGFANG_FONT_FILE_PATH
+from ..components import CVResult, HtmlMixin, XlsxMixin
+
+class TableRecognitionResult(CVResult, HtmlMixin, XlsxMixin):
+    def __init__(self, data):
+        super().__init__(data)
+        HtmlMixin.__init__(self)
+        XlsxMixin.__init__(self)
+
+    def save_to_html(self, save_path, *args, **kwargs):
+        if not str(save_path).lower().endswith(".html"):
+            save_path = save_path + "/res_table_%d.html" % self['table_region_id']
+        super().save_to_html(save_path, *args, **kwargs)
+
+    def _to_html(self):
+        return self["pred_html"]
+
+    def save_to_xlsx(self, save_path, *args, **kwargs):
+        if not str(save_path).lower().endswith(".xlsx"):
+            save_path = save_path + "/res_table_%d.xlsx" % self['table_region_id']
+        super().save_to_xlsx(save_path, *args, **kwargs)
+
+    def _to_xlsx(self):
+        return self["pred_html"]
+
+    def save_to_img(self, save_path, *args, **kwargs):
+        if not str(save_path).lower().endswith((".jpg", ".png")):
+            ocr_save_path = save_path + "/res_table_ocr_%d.jpg" % self['table_region_id']
+            save_path = save_path + "/res_table_cell_%d.jpg" % self['table_region_id']
+            self['table_ocr_pred'].save_to_img(ocr_save_path)
+        super().save_to_img(save_path, *args, **kwargs)
+
+    def _to_img(self):
+        input_img = self['table_ocr_pred']['input_img'].copy()
+        cell_box_list = self['cell_box_list']
+        for box in cell_box_list:
+            x1, y1, x2, y2 = [int(pos) for pos in box]
+            cv2.rectangle(input_img, (x1, y1), (x2, y2), (255, 0, 0), 2)
+        return input_img
+
+class LayoutParsingResult(dict):
+    def __init__(self, data):
+        super().__init__(data)
+    
+    def save_results(self, save_path):
+        if not os.path.isdir(save_path):
+            raise ValueError("The save path should be a dir.")
+
+        layout_det_res = self['layout_det_res']
+        save_img_path = save_path + "/layout_det_result.jpg"
+        layout_det_res.save_to_img(save_img_path)
+
+        input_params = self['input_params']
+        if input_params['use_doc_preprocessor']:
+            save_img_path = save_path + "/doc_preprocessor_result.jpg"
+            self['doc_preprocessor_res'].save_to_img(save_img_path)
+        
+        if input_params['use_common_ocr']:
+            save_img_path = save_path + "/text_paragraphs_ocr_result.jpg"
+            self['text_paragraphs_ocr_res'].save_to_img(save_img_path)
+
+        if input_params['use_table_recognition']:
+            for tno in range(len(self['table_res_list'])):
+                table_res = self['table_res_list'][tno]
+                table_res.save_to_img(save_path)
+                table_res.save_to_html(save_path)
+                table_res.save_to_xlsx(save_path)
+        
+        if input_params['use_seal_recognition']:
+            for sno in range(len(self['seal_res_list'])):
+                seal_res = self['seal_res_list'][sno]
+                save_img_path = save_path + "/seal_%d_recognition_result.jpg" % seal_res['seal_region_id']
+                seal_res.save_to_img(save_img_path)          
+        return
+

+ 203 - 0
paddlex/inference/pipelines_new/layout_parsing/table_recognition_post_processing.py

@@ -0,0 +1,203 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .utils import convert_points_to_boxes, get_sub_regions_ocr_res
+import numpy as np
+from .result import TableRecognitionResult
+
+def get_ori_image_coordinate(x, y, box_list):
+    """
+    Map coordinates from the cropped image back to the original image.
+    Args:
+        x (int): x coordinate of cropped image
+        y (int): y coordinate of cropped image
+        box_list (list): list of table bounding boxes, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]]
+    Returns:
+        list: list of original coordinates, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]]
+    """
+    if len(box_list) == 0:
+        return box_list
+    offset = np.array([x, y] * 4)
+    box_list = np.array(box_list)
+    if box_list.shape[-1] == 2:
+        offset = offset.reshape(4, 2)
+    ori_box_list = offset + box_list
+    return ori_box_list
+
+def convert_table_structure_pred_bbox(table_structure_pred, 
+    crop_start_point, img_shape):
+
+    cell_points_list = table_structure_pred['bbox']
+    ori_cell_points_list = get_ori_image_coordinate(crop_start_point[0], 
+        crop_start_point[1], cell_points_list)
+    ori_cell_points_list = np.reshape(ori_cell_points_list, (-1, 4, 2))
+    cell_box_list = convert_points_to_boxes(ori_cell_points_list)
+    img_height, img_width = img_shape
+    cell_box_list = np.clip(cell_box_list, 0, 
+        [img_width, img_height, img_width, img_height])
+    table_structure_pred['cell_box_list'] = cell_box_list
+    return 
+
+def distance(box_1, box_2):
+    """
+    compute the distance between two boxes
+
+    Args:
+        box_1 (list): first rectangle box, e.g. (x1, y1, x2, y2)
+        box_2 (list): second rectangle box, e.g. (x1, y1, x2, y2)
+
+    Returns:
+        float: the combined L1 distance between the two boxes
+    """
+    x1, y1, x2, y2 = box_1
+    x3, y3, x4, y4 = box_2
+    dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
+    dis_2 = abs(x3 - x1) + abs(y3 - y1)
+    dis_3 = abs(x4 - x2) + abs(y4 - y2)
+    return dis + min(dis_2, dis_3)
+
+def compute_iou(rec1, rec2):
+    """
+    computing IoU
+    Args:
+        rec1 (list): (x1, y1, x2, y2)
+        rec2 (list): (x1, y1, x2, y2)
+    Returns:
+        float: Intersection over Union
+    """
+    # computing area of each rectangles
+    S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
+    S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
+
+    # computing the sum_area
+    sum_area = S_rec1 + S_rec2
+
+    # find the each edge of intersect rectangle
+    left_line = max(rec1[0], rec2[0])
+    right_line = min(rec1[2], rec2[2])
+    top_line = max(rec1[1], rec2[1])
+    bottom_line = min(rec1[3], rec2[3])
+
+    # judge if there is an intersect
+    if left_line >= right_line or top_line >= bottom_line:
+        return 0.0
+    else:
+        intersect = (right_line - left_line) * (bottom_line - top_line)
+        return (intersect / (sum_area - intersect)) * 1.0
+
+def match_table_and_ocr(cell_box_list, ocr_dt_boxes):
+    """
+    match table and ocr
+
+    Args:
+        cell_box_list (list): bbox for table cell, 2 points, [left, top, right, bottom]
+        ocr_dt_boxes (list): bbox for ocr, 2 points, [left, top, right, bottom]
+
+    Returns:
+        dict: matched dict, key is table index, value is ocr index
+    """
+    matched = {}
+    for i, ocr_box in enumerate(np.array(ocr_dt_boxes)):
+        ocr_box = ocr_box.astype(np.float32)
+        distances = []
+        for j, table_box in enumerate(cell_box_list):
+            distances.append((distance(table_box, ocr_box), 
+                1.0 - compute_iou(table_box, ocr_box)))  # compute iou and l1 distance
+        sorted_distances = distances.copy()
+        # select det box by iou and l1 distance
+        sorted_distances = sorted(
+            sorted_distances, key=lambda item: (item[1], item[0]))
+        if distances.index(sorted_distances[0]) not in matched.keys():
+            matched[distances.index(sorted_distances[0])] = [i]
+        else:
+            matched[distances.index(sorted_distances[0])].append(i)
+    return matched
+
+def get_html_result(matched_index, ocr_contents, pred_structures):
+    pred_html = []
+    td_index = 0
+    head_structure = pred_structures[0:3]
+    html = "".join(head_structure)
+    table_structure = pred_structures[3:-3]
+    for tag in table_structure:
+        if "</td>" in tag:
+            if "<td></td>" == tag:
+                pred_html.extend("<td>")
+            if td_index in matched_index.keys():
+                b_with = False
+                if (
+                    "<b>" in ocr_contents[matched_index[td_index][0]]
+                    and len(matched_index[td_index]) > 1
+                ):
+                    b_with = True
+                    pred_html.extend("<b>")
+                for i, td_index_index in enumerate(matched_index[td_index]):
+                    content = ocr_contents[td_index_index]
+                    if len(matched_index[td_index]) > 1:
+                        if len(content) == 0:
+                            continue
+                        if content[0] == " ":
+                            content = content[1:]
+                        if "<b>" in content:
+                            content = content[3:]
+                        if "</b>" in content:
+                            content = content[:-4]
+                        if len(content) == 0:
+                            continue
+                        if (
+                            i != len(matched_index[td_index]) - 1
+                            and " " != content[-1]
+                        ):
+                            content += " "
+                    pred_html.extend(content)
+                if b_with:
+                    pred_html.extend("</b>")
+            if "<td></td>" == tag:
+                pred_html.append("</td>")
+            else:
+                pred_html.append(tag)
+            td_index += 1
+        else:
+            pred_html.append(tag)
+    html += "".join(pred_html)
+    end_structure = pred_structures[-3:]
+    html += "".join(end_structure)
+    return html
+
+def get_table_recognition_res(crop_img_info, table_structure_pred, overall_ocr_res):
+    '''Assemble the table recognition result from the cropped table image, the table structure prediction and the overall OCR result.'''
+
+    table_box = np.array([crop_img_info['box']])
+    table_ocr_pred = get_sub_regions_ocr_res(overall_ocr_res, table_box)
+
+    crop_start_point = [table_box[0][0], table_box[0][1]]
+    img_shape = overall_ocr_res['input_img'].shape[0:2]
+
+    convert_table_structure_pred_bbox(table_structure_pred, 
+        crop_start_point, img_shape)
+    
+    structures = table_structure_pred["structure"]
+    cell_box_list = table_structure_pred["cell_box_list"]
+    ocr_dt_boxes = table_ocr_pred["dt_boxes"]
+    ocr_text_res = table_ocr_pred["rec_text"]
+
+    matched_index = match_table_and_ocr(cell_box_list, ocr_dt_boxes)
+    pred_html = get_html_result(matched_index, ocr_text_res, structures)
+
+    single_img_res = {"cell_box_list":cell_box_list, 
+        "table_ocr_pred":table_ocr_pred,
+        "pred_html":pred_html}
+    return TableRecognitionResult(single_img_res)
+
+
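A small worked example of the cell/OCR matching helpers above, using toy coordinates; the module path mirrors the file added in this commit:

import numpy as np
from paddlex.inference.pipelines_new.layout_parsing.table_recognition_post_processing import (
    compute_iou,
    match_table_and_ocr,
)

# Two table cells side by side and one OCR box inside each (toy coordinates).
cell_box_list = [[0, 0, 50, 20], [50, 0, 100, 20]]
ocr_dt_boxes = [[5, 2, 45, 18], [55, 2, 95, 18]]

print(compute_iou(cell_box_list[0], ocr_dt_boxes[0]))  # 0.64: large overlap with the first cell
print(compute_iou(cell_box_list[1], ocr_dt_boxes[0]))  # 0.0: no overlap with the second cell

# Each OCR box is assigned to the cell with the best (1 - IoU, L1 distance) ranking.
print(match_table_and_ocr(cell_box_list, ocr_dt_boxes))  # {0: [0], 1: [1]}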

+ 87 - 0
paddlex/inference/pipelines_new/layout_parsing/utils.py

@@ -0,0 +1,87 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = [
+    "convert_points_to_boxes",
+    "get_sub_regions_ocr_res"
+]
+
+import numpy as np
+import copy
+
+def convert_points_to_boxes(dt_polys):
+    """Convert detected polygons (N x 4 x 2 points) to axis-aligned boxes [x1, y1, x2, y2]."""
+    if len(dt_polys) > 0:
+        dt_polys_tmp = dt_polys.copy()
+        dt_polys_tmp = np.array(dt_polys_tmp)
+        boxes_left = np.min(dt_polys_tmp[:, :, 0], axis=1)
+        boxes_right = np.max(dt_polys_tmp[:, :, 0], axis=1)
+        boxes_top = np.min(dt_polys_tmp[:, :, 1], axis=1)
+        boxes_bottom = np.max(dt_polys_tmp[:, :, 1], axis=1)
+        dt_boxes = np.array([boxes_left, boxes_top, boxes_right, boxes_bottom])
+        dt_boxes = dt_boxes.T
+    else:
+        dt_boxes = np.array([])
+    return dt_boxes
+
+def get_overlap_boxes_idx(src_boxes, ref_boxes):
+    '''Return indices of src_boxes that overlap any box in ref_boxes by more than 3 pixels in both dimensions.'''
+    match_idx_list = []
+    src_boxes_num = len(src_boxes)
+    if src_boxes_num > 0 and len(ref_boxes) > 0:
+        for rno in range(len(ref_boxes)):
+            ref_box = ref_boxes[rno]
+            x1 = np.maximum(ref_box[0], src_boxes[:, 0])
+            y1 = np.maximum(ref_box[1], src_boxes[:, 1])
+            x2 = np.minimum(ref_box[2], src_boxes[:, 2])
+            y2 = np.minimum(ref_box[3], src_boxes[:, 3])
+            pub_w = x2 - x1
+            pub_h = y2 - y1
+            match_idx = np.where((pub_w > 3) & (pub_h > 3))[0]
+            match_idx_list.extend(match_idx)                                   
+    return match_idx_list
+
+def get_sub_regions_ocr_res(overall_ocr_res, object_boxes, flag_within=True):
+    """
+    :param flag_within: True (within the object regions), False (outside the object regions)
+    :return:
+    """
+
+    sub_regions_ocr_res = copy.deepcopy(overall_ocr_res)
+    sub_regions_ocr_res['input_img'] = overall_ocr_res['input_img']
+    sub_regions_ocr_res['img_id'] = -1
+    sub_regions_ocr_res['dt_polys'] = []
+    sub_regions_ocr_res['rec_text'] = []
+    sub_regions_ocr_res['rec_score'] = []
+    sub_regions_ocr_res['dt_boxes'] = []
+
+    overall_text_boxes = overall_ocr_res['dt_boxes']
+    match_idx_list = get_overlap_boxes_idx(overall_text_boxes, object_boxes)
+    match_idx_list = list(set(match_idx_list))
+    for box_no in range(len(overall_text_boxes)):
+        if flag_within:
+            if box_no in match_idx_list:
+                flag_match = True
+            else:
+                flag_match = False
+        else:
+            if box_no not in match_idx_list:
+                flag_match = True
+            else:
+                flag_match = False
+        if flag_match:
+            sub_regions_ocr_res['dt_polys'].append(overall_ocr_res['dt_polys'][box_no])
+            sub_regions_ocr_res['rec_text'].append(overall_ocr_res['rec_text'][box_no])
+            sub_regions_ocr_res['rec_score'].append(overall_ocr_res['rec_score'][box_no])
+            sub_regions_ocr_res['dt_boxes'].append(overall_ocr_res['dt_boxes'][box_no])
+    return sub_regions_ocr_res
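A quick worked example of `convert_points_to_boxes`, which collapses each detected polygon to its axis-aligned bounding box (toy coordinates):

import numpy as np
from paddlex.inference.pipelines_new.layout_parsing.utils import convert_points_to_boxes

# One slightly rotated quadrilateral (4 points in x/y order).
dt_polys = np.array([[[10, 20], [60, 22], [58, 50], [9, 48]]])
print(convert_points_to_boxes(dt_polys))  # [[ 9 20 60 50]]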

+ 15 - 0
paddlex/inference/pipelines_new/ocr/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pipeline import OCRPipeline

+ 96 - 0
paddlex/inference/pipelines_new/ocr/pipeline.py

@@ -0,0 +1,96 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BasePipeline
+from typing import Any, Dict, Optional
+from ..components import SortQuadBoxes, CropByPolys
+from .result import OCRResult
+
+# [TODO] The following import path needs to be updated later.
+from ...components.transforms import ReadImage
+
+class OCRPipeline(BasePipeline):
+    """OCR Pipeline"""
+
+    entities = "OCR"
+    def __init__(self,
+        config,        
+        device=None,
+        pp_option=None, 
+        use_hpip: bool = False,
+        hpi_params: Optional[Dict[str, Any]] = None):
+        super().__init__(device=device, pp_option=pp_option, 
+            use_hpip=use_hpip, hpi_params=hpi_params)
+        
+        text_det_model_config = config['SubModules']["TextDetection"]
+        self.text_det_model = self.create_model(text_det_model_config)
+
+        text_rec_model_config = config['SubModules']["TextRecognition"]
+        self.text_rec_model = self.create_model(text_rec_model_config)
+
+        self.text_type = config['text_type']
+
+        self._sort_quad_boxes = SortQuadBoxes()
+
+        if self.text_type == "common":
+            self._crop_by_polys = CropByPolys(det_box_type = "quad")
+        elif self.text_type == "seal":
+            self._crop_by_polys = CropByPolys(det_box_type = "poly")
+        else:
+            raise ValueError("Unsupported text type {}".format(self.text_type))
+
+        self.img_reader = ReadImage(format="BGR")
+
+    def predict(self, input, **kwargs):
+        if not isinstance(input, list):
+            input_list = [input]
+        else:
+            input_list = input
+        img_id = 1
+        for input in input_list:
+            if isinstance(input, str):
+                image_array = next(self.img_reader(input))[0]['img']
+            else:
+                image_array = input
+
+            assert len(image_array.shape) == 3
+
+            det_res = next(self.text_det_model(image_array))
+
+            dt_polys = det_res['dt_polys']
+            dt_scores = det_res['dt_scores']
+
+            # [TODO] Confirm the filtering thresholds used by the detection and recognition modules.
+
+            if self.text_type == "common":
+                dt_polys = self._sort_quad_boxes(dt_polys)
+                
+            single_img_res = {'input_img':image_array, 'dt_polys':dt_polys, \
+                "img_id":img_id, "text_type":self.text_type}
+            img_id += 1
+            single_img_res["rec_text"] = []
+            single_img_res["rec_score"] = []
+            if len(dt_polys) > 0:
+                all_subs_of_img = list(self._crop_by_polys(image_array, dt_polys))
+
+                # [TODO] Update in the future: temporary adaptation of the input key for the text recognizer.
+                for sub_img in all_subs_of_img:
+                    sub_img['input'] = sub_img['img']
+                ##########
+
+                for rec_res in self.text_rec_model(all_subs_of_img):
+                    single_img_res["rec_text"].append(rec_res["rec_text"])
+                    single_img_res["rec_score"].append(rec_res["rec_score"])
+
+            yield OCRResult(single_img_res)
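A minimal usage sketch for the OCR pipeline above, assuming the "OCR" entry from paddlex/configs/pipelines/OCR.yaml is exposed through `paddlex.create_pipeline`; the image path and output directory are illustrative:

from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="OCR")

for res in pipeline.predict("./demo.png"):
    # Each OCRResult pairs detected polygons with recognized text and scores.
    for poly, text, score in zip(res["dt_polys"], res["rec_text"], res["rec_score"]):
        print(text, score)
    res.save_to_img("./output")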

+ 160 - 0
paddlex/inference/pipelines_new/ocr/result.py

@@ -0,0 +1,160 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import random
+import numpy as np
+import cv2
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+
+from ....utils.fonts import PINGFANG_FONT_FILE_PATH
+from ..components import CVResult
+
+class OCRResult(CVResult):
+    def save_to_img(self, save_path, *args, **kwargs):
+        if not str(save_path).lower().endswith((".jpg", ".png")):
+            img_id = self["img_id"]
+            save_path = save_path + "/res_ocr_%d.jpg" % img_id
+        super().save_to_img(save_path, *args, **kwargs)
+
+    def get_minarea_rect(self, points):
+        bounding_box = cv2.minAreaRect(points)
+        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+        index_a, index_b, index_c, index_d = 0, 1, 2, 3
+        if points[1][1] > points[0][1]:
+            index_a = 0
+            index_d = 1
+        else:
+            index_a = 1
+            index_d = 0
+        if points[3][1] > points[2][1]:
+            index_b = 2
+            index_c = 3
+        else:
+            index_b = 3
+            index_c = 2
+
+        box = np.array(
+            [points[index_a], points[index_b], points[index_c], points[index_d]]
+        ).astype(np.int32)
+
+        return box
+
+    def _to_img(self):
+        """draw ocr result"""
+        # TODO(gaotingquan): mv to postprocess
+        drop_score = 0.5
+
+        boxes = self["dt_polys"]
+        txts = self["rec_text"]
+        scores = self["rec_score"]
+        image = self['input_img']
+        h, w = image.shape[0:2]
+        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        img_left = Image.fromarray(image_rgb)
+        img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
+        random.seed(0)
+        draw_left = ImageDraw.Draw(img_left)
+        if txts is None or len(txts) != len(boxes):
+            txts = [None] * len(boxes)
+        for idx, (box, txt) in enumerate(zip(boxes, txts)):
+            try:
+                if scores is not None and scores[idx] < drop_score:
+                    continue
+                color = (
+                    random.randint(0, 255),
+                    random.randint(0, 255),
+                    random.randint(0, 255),
+                )
+                box = np.array(box)
+                if len(box) > 4:
+                    pts = [(x, y) for x, y in box.tolist()]
+                    draw_left.polygon(pts, outline=color, width=8)
+                    box = self.get_minarea_rect(box)
+                    height = int(0.5 * (max(box[:, 1]) - min(box[:, 1])))
+                    box[:2, 1] = np.mean(box[:, 1])
+                    box[2:, 1] = np.mean(box[:, 1]) + min(20, height)
+                draw_left.polygon(box, fill=color)
+                img_right_text = draw_box_txt_fine(
+                    (w, h), box, txt, PINGFANG_FONT_FILE_PATH
+                )
+                pts = np.array(box, np.int32).reshape((-1, 1, 2))
+                cv2.polylines(img_right_text, [pts], True, color, 1)
+                img_right = cv2.bitwise_and(img_right, img_right_text)
+            except Exception:
+                continue
+
+        img_left = Image.blend(Image.fromarray(image_rgb), img_left, 0.5)
+        img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
+        img_show.paste(img_left, (0, 0, w, h))
+        img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
+        return img_show
+
+
+def draw_box_txt_fine(img_size, box, txt, font_path):
+    """draw box text"""
+    box_height = int(
+        math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)
+    )
+    box_width = int(
+        math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2)
+    )
+
+    if box_height > 2 * box_width and box_height > 30:
+        img_text = Image.new("RGB", (box_height, box_width), (255, 255, 255))
+        draw_text = ImageDraw.Draw(img_text)
+        if txt:
+            font = create_font(txt, (box_height, box_width), font_path)
+            draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
+        img_text = img_text.transpose(Image.ROTATE_270)
+    else:
+        img_text = Image.new("RGB", (box_width, box_height), (255, 255, 255))
+        draw_text = ImageDraw.Draw(img_text)
+        if txt:
+            font = create_font(txt, (box_width, box_height), font_path)
+            draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
+
+    pts1 = np.float32(
+        [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]]
+    )
+    pts2 = np.array(box, dtype=np.float32)
+    M = cv2.getPerspectiveTransform(pts1, pts2)
+
+    img_text = np.array(img_text, dtype=np.uint8)
+    img_right_text = cv2.warpPerspective(
+        img_text,
+        M,
+        img_size,
+        flags=cv2.INTER_NEAREST,
+        borderMode=cv2.BORDER_CONSTANT,
+        borderValue=(255, 255, 255),
+    )
+    return img_right_text
+
+
+def create_font(txt, sz, font_path):
+    """create font"""
+    font_size = int(sz[1] * 0.8)
+    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+    if int(PIL.__version__.split(".")[0]) < 10:
+        length = font.getsize(txt)[0]
+    else:
+        length = font.getlength(txt)
+
+    if length > sz[0]:
+        font_size = int(font_size * sz[0] / length)
+        font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+    return font
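A small sketch of the `draw_box_txt_fine` helper above, which renders a recognized string and warps it onto the detected quadrilateral; the box coordinates and text are illustrative:

import numpy as np
from paddlex.utils.fonts import PINGFANG_FONT_FILE_PATH
from paddlex.inference.pipelines_new.ocr.result import draw_box_txt_fine

# Render "hello" into a 200x100 canvas, warped onto one detected quad.
box = np.array([[10, 10], [150, 12], [148, 40], [8, 38]], dtype=np.float32)
patch = draw_box_txt_fine((200, 100), box, "hello", PINGFANG_FONT_FILE_PATH)
print(patch.shape)  # (100, 200, 3): a white canvas with the text warped into the box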

+ 15 - 0
paddlex/inference/pipelines_new/pp_chatocrv3_doc/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pipeline import PP_ChatOCRv3_doc_Pipeline

+ 329 - 0
paddlex/inference/pipelines_new/pp_chatocrv3_doc/pipeline.py

@@ -0,0 +1,329 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BasePipeline
+
+from typing import Any, Dict, Optional
+
+# import numpy as np
+# import cv2
+from .result import VisualInfoResult
+import re
+
+# [TODO] The following import path needs to be updated later.
+from ...components.transforms import ReadImage
+
+import json
+import logging
+
+class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
+    """PP-ChatOCRv3-doc Pipeline"""
+
+    entities = "PP-ChatOCRv3-doc"
+    def __init__(self,
+        config,
+        device=None,
+        pp_option=None, 
+        use_hpip: bool = False,
+        hpi_params: Optional[Dict[str, Any]] = None):
+        super().__init__(device=device, pp_option=pp_option, 
+            use_hpip=use_hpip, hpi_params=hpi_params)
+        
+        self.inintial_predictor(config)
+
+        self.img_reader = ReadImage(format="BGR")
+        
+    def inintial_predictor(self, config):
+        layout_parsing_config = config['SubPipelines']["LayoutParser"]
+        self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
+
+        chat_bot_config = config['SubModules']['LLM_Chat']
+        self.chat_bot = self.create_chat_bot(chat_bot_config)
+
+        retriever_config = config['SubModules']['LLM_Retriever']
+        self.retriever = self.create_retriever(retriever_config)
+
+        text_pe_config = config['SubModules']['PromptEngneering']['KIE_CommonText']
+        self.text_pe = self.create_prompt_engeering(text_pe_config)
+        
+        table_pe_config = config['SubModules']['PromptEngneering']['KIE_Table']
+        self.table_pe = self.create_prompt_engeering(table_pe_config)
+
+        return 
+
+    def decode_visual_result(self, layout_parsing_result):
+        text_paragraphs_ocr_res = layout_parsing_result['text_paragraphs_ocr_res']
+        seal_res_list = layout_parsing_result['seal_res_list']
+        normal_text_dict = {}
+        layout_type = "text"
+        for text in text_paragraphs_ocr_res['rec_text']:
+            if layout_type not in normal_text_dict:
+                normal_text_dict[layout_type] = text
+            else:
+                normal_text_dict[layout_type] += f"\n {text}"
+        
+        layout_type = "seal"
+        for seal_res in seal_res_list:
+            for text in seal_res['rec_text']:
+                if layout_type not in normal_text_dict:
+                    normal_text_dict[layout_type] = text
+                else:
+                    normal_text_dict[layout_type] += f"\n {text}"
+
+        table_res_list = layout_parsing_result['table_res_list']
+        table_text_list = []
+        table_html_list = []
+        for table_res in table_res_list:
+            table_html_list.append(table_res['pred_html'])
+            single_table_text = " ".join(table_res["table_ocr_pred"]['rec_text'])
+            table_text_list.append(single_table_text)
+
+        visual_info = {}
+        visual_info['normal_text_dict'] = normal_text_dict
+        visual_info['table_text_list'] = table_text_list
+        visual_info['table_html_list'] = table_html_list
+        return VisualInfoResult(visual_info)
+
+    def visual_predict(self, input,
+        use_doc_orientation_classify=True,
+        use_doc_unwarping=True,
+        use_common_ocr=True,
+        use_seal_recognition=True,
+        use_table_recognition=True,
+        **kwargs):
+
+        if not isinstance(input, list):
+            input_list = [input]
+        else:
+            input_list = input
+
+        img_id = 1
+        for input in input_list:
+            if isinstance(input, str):
+                image_array = next(self.img_reader(input))[0]['img']
+            else:
+                image_array = input
+
+            assert len(image_array.shape) == 3
+
+            layout_parsing_result = next(self.layout_parsing_pipeline.predict(
+                image_array,
+                use_doc_orientation_classify=use_doc_orientation_classify,
+                use_doc_unwarping=use_doc_unwarping,
+                use_common_ocr=use_common_ocr,
+                use_seal_recognition=use_seal_recognition,
+                use_table_recognition=use_table_recognition))
+            
+            visual_info = self.decode_visual_result(layout_parsing_result)
+
+            visual_predict_res = {"layout_parsing_result":layout_parsing_result,
+                "visual_info":visual_info}
+            yield visual_predict_res
+
+    def save_visual_info_list(self, visual_info, save_path):
+        if not isinstance(visual_info, list):
+            visual_info_list = [visual_info]
+        else:
+            visual_info_list = visual_info
+
+        with open(save_path, "w") as fout:
+            fout.write(json.dumps(visual_info_list, ensure_ascii=False) + "\n")
+        return
+    
+    def load_visual_info_list(self, data_path):
+        with open(data_path, "r") as fin:
+            data = fin.readline()
+            visual_info_list = json.loads(data)
+        return visual_info_list
+
+    def merge_visual_info_list(self, visual_info_list):
+        all_normal_text_list = []
+        all_table_text_list = []
+        all_table_html_list = []
+        for single_visual_info in visual_info_list:
+            normal_text_dict = single_visual_info['normal_text_dict']
+            table_text_list = single_visual_info['table_text_list']
+            table_html_list = single_visual_info['table_html_list']
+            all_normal_text_list.append(normal_text_dict)
+            all_table_text_list.extend(table_text_list)
+            all_table_html_list.extend(table_html_list)
+        return all_normal_text_list, all_table_text_list, all_table_html_list
+
+    def build_vector(self, visual_info, 
+        min_characters=3500,
+        llm_request_interval=1.0):
+
+        if not isinstance(visual_info, list):
+            visual_info_list = [visual_info]
+        else:
+            visual_info_list = visual_info
+        
+        all_visual_info = self.merge_visual_info_list(visual_info_list)
+        all_normal_text_list, all_table_text_list, all_table_html_list = all_visual_info
+
+        all_normal_text_str = "".join(["\n".join(e.values()) for e in all_normal_text_list])
+        vector_info = {}
+
+        all_items = []
+        for i, normal_text_dict in enumerate(all_normal_text_list):
+            for type, text in normal_text_dict.items():
+                all_items += [f"{type}:{text}"]
+
+        if len(all_normal_text_str) > min_characters:
+            vector_info['flag_too_short_text'] = False
+            vector_info['vector'] = self.retriever.generate_vector_database(
+                all_items)
+        else:
+            vector_info['flag_too_short_text'] = True  
+            vector_info['vector'] = all_items
+        return vector_info
+
+    def format_key(self, key_list):
+        """format key"""
+        if key_list == "":
+            return []
+
+        if isinstance(key_list, list):
+            return key_list
+
+        if isinstance(key_list, str):
+            key_list = re.sub(r"[\t\n\r\f\v]", "", key_list)
+            key_list = key_list.replace(",", ",").split(",")
+            return key_list
+
+        return []
+
+    def fix_llm_result_format(self, llm_result):
+        if not llm_result:
+            return {}
+
+        if "json" in llm_result or "```" in llm_result:
+            llm_result = (
+                llm_result.replace("```", "").replace("json", "").replace("/n", "")
+            )
+            llm_result = llm_result.replace("[", "").replace("]", "")
+
+        try:
+            llm_result = json.loads(llm_result)
+            llm_result_final = {}
+            for key in llm_result:
+                value = llm_result[key]
+                if isinstance(value, list):
+                    if len(value) > 0:
+                        llm_result_final[key] = value[0]
+                else:
+                    llm_result_final[key] = value
+            return llm_result_final
+
+        except Exception:
+            results = (
+                llm_result.replace("\n", "")
+                .replace("    ", "")
+                .replace("{", "")
+                .replace("}", "")
+            )
+            if not results.endswith('"'):
+                results = results + '"'
+            pattern = r'"(.*?)": "([^"]*)"'
+            matches = re.findall(pattern, str(results))
+            if len(matches) > 0:
+                llm_result = {k: v for k, v in matches}
+                return llm_result 
+            else:
+                return {}     
+
+    def generate_and_merge_chat_results(self, prompt, key_list,
+        final_results, failed_results):
+
+        llm_result = self.chat_bot.generate_chat_results(prompt)
+        llm_result = self.fix_llm_result_format(llm_result)
+
+        for key, value in llm_result.items():
+            if value not in failed_results and key in key_list:
+                key_list.remove(key)
+                final_results[key] = value
+        return 
+        
+
+    def chat(self, visual_info, 
+        key_list, 
+        vector_info,
+        text_task_description=None,
+        text_output_format=None,
+        text_rules_str=None,
+        text_few_shot_demo_text_content=None,
+        text_few_shot_demo_key_value_list=None,        
+        table_task_description=None,
+        table_output_format=None,
+        table_rules_str=None,
+        table_few_shot_demo_text_content=None,
+        table_few_shot_demo_key_value_list=None):
+
+        key_list = self.format_key(key_list)
+        if len(key_list) == 0:
+            return {"chat_res": "输入的key_list无效!"}
+
+        if not isinstance(visual_info, list):
+            visual_info_list = [visual_info]
+        else:
+            visual_info_list = visual_info
+        
+        all_visual_info = self.merge_visual_info_list(visual_info_list)
+        all_normal_text_list, all_table_text_list, all_table_html_list = all_visual_info
+
+        final_results = {}
+        failed_results = ["大模型调用失败", "未知", "未找到关键信息", "None", ""]
+
+        for all_table_info in [all_table_html_list, all_table_text_list]:
+            for table_info in all_table_info:
+                if len(key_list) == 0:
+                    continue
+
+                prompt = self.table_pe.generate_prompt(table_info, 
+                    key_list, 
+                    task_description=table_task_description,
+                    output_format=table_output_format, 
+                    rules_str=table_rules_str, 
+                    few_shot_demo_text_content=table_few_shot_demo_text_content, 
+                    few_shot_demo_key_value_list=table_few_shot_demo_key_value_list)
+
+                self.generate_and_merge_chat_results(prompt, 
+                    key_list, final_results, failed_results)
+        
+        if len(key_list) > 0:
+            question_key_list = [f"抽取关键信息:{key}" for key in key_list]
+            vector = vector_info['vector']
+            if not vector_info['flag_too_short_text']:
+                related_text = self.retriever.similarity_retrieval(
+                    question_key_list, vector)
+            else:
+                related_text = " ".join(vector)
+            
+            prompt = self.text_pe.generate_prompt(related_text, 
+                key_list, 
+                task_description=text_task_description,
+                output_format=text_output_format, 
+                rules_str=text_rules_str, 
+                few_shot_demo_text_content=text_few_shot_demo_text_content, 
+                few_shot_demo_key_value_list=text_few_shot_demo_key_value_list)
+            
+            self.generate_and_merge_chat_results(prompt, 
+                key_list, final_results, failed_results)
+
+        return final_results
+
+    def predict(self, *args, **kwargs):
+        logging.error(
+            "PP-ChatOCRv3-doc Pipeline do not support to call `predict()` directly! Please invoke `visual_predict`, `build_vector`, `chat` sequentially to obtain the result."
+        )
+        return
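A hedged end-to-end sketch of the three-stage workflow implemented above (visual analysis, vector building, key-information chat), assuming the "PP-ChatOCRv3-doc" config with valid ERNIE Bot credentials is registered; the file names and extraction keys are illustrative:

from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="PP-ChatOCRv3-doc")

visual_info_list = []
for res in pipeline.visual_predict("./contract.png"):
    visual_info_list.append(res["visual_info"])

# Optional round trip to disk between the visual and chat stages.
pipeline.save_visual_info_list(visual_info_list, "./visual_info.json")
visual_info_list = pipeline.load_visual_info_list("./visual_info.json")

vector_info = pipeline.build_vector(visual_info_list)
chat_result = pipeline.chat(visual_info_list, "乙方,手机号", vector_info)
print(chat_result)  # e.g. {"乙方": "...", "手机号": "..."}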

+ 44 - 0
paddlex/inference/pipelines_new/pp_chatocrv3_doc/result.py

@@ -0,0 +1,44 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import random
+import numpy as np
+import cv2
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+
+from ....utils.fonts import PINGFANG_FONT_FILE_PATH
+from ..components import BaseResult
+
+class VisualInfoResult(BaseResult):
+    """VisualInfoResult"""
+    
+    pass
+
+# class VectorResult(BaseResult, Base64Mixin):
+#     """VisualInfoResult"""
+
+#     def _to_base64(self):
+#         return self["vector"]
+
+
+# class RetrievalResult(BaseResult):
+#     """VisualInfoResult"""
+
+#     pass
+
+
+# class ChatResult(BaseResult):
+#     """VisualInfoResult"""

+ 2 - 0
paddlex/utils/flags.py

@@ -26,6 +26,7 @@ __all__ = [
     "INFER_BENCHMARK_OUTPUT",
     "INFER_BENCHMARK_DATA_SIZE",
     "FLAGS_json_format_model",
+    "USE_NEW_INFERENCE",
 ]
 
 
@@ -46,6 +47,7 @@ DRY_RUN = get_flag_from_env_var("PADDLE_PDX_DRY_RUN", False)
 CHECK_OPTS = get_flag_from_env_var("PADDLE_PDX_CHECK_OPTS", False)
 EAGER_INITIALIZATION = get_flag_from_env_var("PADDLE_PDX_EAGER_INIT", True)
 FLAGS_json_format_model = get_flag_from_env_var("FLAGS_json_format_model", None)
+USE_NEW_INFERENCE = get_flag_from_env_var("USE_NEW_INFERENCE", False)
 
 # Inference Benchmark
 INFER_BENCHMARK = get_flag_from_env_var("PADDLE_PDX_INFER_BENCHMARK", None)
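A short sketch of how the new flag might be enabled; the accepted truthy spelling (e.g. "True" vs "1") depends on `get_flag_from_env_var` and is an assumption here:

import os

# The flag is read once when paddlex.utils.flags is imported, so set it first.
os.environ["USE_NEW_INFERENCE"] = "True"

from paddlex.utils.flags import USE_NEW_INFERENCE
print(USE_NEW_INFERENCE)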

+ 15 - 0
paddlex/utils/fonts/__init__.py

@@ -15,10 +15,25 @@
 
 from pathlib import Path
 
+import PIL
+from PIL import ImageFont
 
 def get_pingfang_file_path():
     """get pingfang font file path"""
     return (Path(__file__).parent / "PingFang-SC-Regular.ttf").resolve().as_posix()
 
+def create_font(txt, sz, font_path):
+    """create font"""
+    font_size = int(sz[1] * 0.8)
+    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+    if int(PIL.__version__.split(".")[0]) < 10:
+        length = font.getsize(txt)[0]
+    else:
+        length = font.getlength(txt)
+
+    if length > sz[0]:
+        font_size = int(font_size * sz[0] / length)
+        font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+    return font
 
 PINGFANG_FONT_FILE_PATH = get_pingfang_file_path()
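A small usage sketch of the `create_font` helper added above, which starts at 80% of the region height and shrinks the font further if the rendered text would overflow the region width; the text and region size are illustrative:

from PIL import Image, ImageDraw
from paddlex.utils.fonts import PINGFANG_FONT_FILE_PATH, create_font

# Fit the label into a 200x30 region and draw it onto a white canvas.
font = create_font("PaddleX layout parsing", (200, 30), PINGFANG_FONT_FILE_PATH)
canvas = Image.new("RGB", (200, 30), (255, 255, 255))
ImageDraw.Draw(canvas).text((0, 0), "PaddleX layout parsing", fill=(0, 0, 0), font=font)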