
add the new architecture of pipelines

dyning authored 11 months ago
parent commit 831b6d21a2
45 files changed, 4248 additions and 1 deletion
  1. api_examples/pipelines/test_doc_preprocessor.py (+15, -0)
  2. api_examples/pipelines/test_layout_parsing.py (+19, -0)
  3. api_examples/pipelines/test_ocr.py (+13, -0)
  4. api_examples/pipelines/test_pp_chatocrv3.py (+37, -0)
  5. api_examples/pipelines/test_table_recognition.py (+13, -0)
  6. paddlex/configs/pipelines/OCR.yaml (+36, -0)
  7. paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml (+109, -0)
  8. paddlex/configs/pipelines/doc_preprocessor.yaml (+16, -0)
  9. paddlex/configs/pipelines/layout_parsing.yaml (+56, -0)
  10. paddlex/inference/__init__.py (+5, -1)
  11. paddlex/inference/pipelines_new/__init__.py (+95, -0)
  12. paddlex/inference/pipelines_new/base.py (+94, -0)
  13. paddlex/inference/pipelines_new/components/__init__.py (+19, -0)
  14. paddlex/inference/pipelines_new/components/base.py (+59, -0)
  15. paddlex/inference/pipelines_new/components/chat_server/__init__.py (+15, -0)
  16. paddlex/inference/pipelines_new/components/chat_server/base.py (+31, -0)
  17. paddlex/inference/pipelines_new/components/chat_server/ernie_bot_chat.py (+95, -0)
  18. paddlex/inference/pipelines_new/components/common/__init__.py (+17, -0)
  19. paddlex/inference/pipelines_new/components/common/crop_image_regions.py (+475, -0)
  20. paddlex/inference/pipelines_new/components/common/seal_det_warp.py (+939, -0)
  21. paddlex/inference/pipelines_new/components/common/sort_boxes.py (+48, -0)
  22. paddlex/inference/pipelines_new/components/prompt_engeering/__init__.py (+15, -0)
  23. paddlex/inference/pipelines_new/components/prompt_engeering/base.py (+31, -0)
  24. paddlex/inference/pipelines_new/components/prompt_engeering/generate_kie_prompt.py (+100, -0)
  25. paddlex/inference/pipelines_new/components/retriever/__init__.py (+15, -0)
  26. paddlex/inference/pipelines_new/components/retriever/base.py (+50, -0)
  27. paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py (+148, -0)
  28. paddlex/inference/pipelines_new/components/utils/__init__.py (+13, -0)
  29. paddlex/inference/pipelines_new/components/utils/mixin.py (+204, -0)
  30. paddlex/inference/pipelines_new/doc_preprocessor/__init__.py (+15, -0)
  31. paddlex/inference/pipelines_new/doc_preprocessor/pipeline.py (+117, -0)
  32. paddlex/inference/pipelines_new/doc_preprocessor/result.py (+51, -0)
  33. paddlex/inference/pipelines_new/layout_parsing/__init__.py (+15, -0)
  34. paddlex/inference/pipelines_new/layout_parsing/pipeline.py (+205, -0)
  35. paddlex/inference/pipelines_new/layout_parsing/result.py (+97, -0)
  36. paddlex/inference/pipelines_new/layout_parsing/table_recognition_post_processing.py (+203, -0)
  37. paddlex/inference/pipelines_new/layout_parsing/utils.py (+87, -0)
  38. paddlex/inference/pipelines_new/ocr/__init__.py (+15, -0)
  39. paddlex/inference/pipelines_new/ocr/pipeline.py (+96, -0)
  40. paddlex/inference/pipelines_new/ocr/result.py (+160, -0)
  41. paddlex/inference/pipelines_new/pp_chatocrv3_doc/__init__.py (+15, -0)
  42. paddlex/inference/pipelines_new/pp_chatocrv3_doc/pipeline.py (+329, -0)
  43. paddlex/inference/pipelines_new/pp_chatocrv3_doc/result.py (+44, -0)
  44. paddlex/utils/flags.py (+2, -0)
  45. paddlex/utils/fonts/__init__.py (+15, -0)

+ 15 - 0
api_examples/pipelines/test_doc_preprocessor.py

@@ -0,0 +1,15 @@
+
+from paddlex import create_pipeline
+
+pipeline = create_pipeline(pipeline="doc_preprocessor")
+
+test_img_path = "./test_imgs/img_rot180_demo.jpg"
+# test_img_path = "./test_imgs/doc_distort_test.jpg"
+
+output = pipeline.predict(test_img_path, 
+    use_doc_orientation_classify=True,
+    use_doc_unwarping=True)
+
+for res in output:
+    print(res)
+    res.save_to_img("./output")

+ 19 - 0
api_examples/pipelines/test_layout_parsing.py

@@ -0,0 +1,19 @@
+
+from paddlex import create_pipeline
+
+pipeline = create_pipeline(pipeline="layout_parsing")
+
+output = pipeline.predict("./test_imgs/test_layout_parsing.jpg",
+    use_doc_orientation_classify=True,
+    use_doc_unwarping=True,
+    use_common_ocr=True,
+    use_seal_recognition=True,
+    use_table_recognition=True)
+
+# output = pipeline("./test_imgs/demo_paper.png")
+# output = pipeline("./test_imgs/table_recognition.jpg")
+# output = pipeline.predict("./test_imgs/seal_text_det.png")
+# output = pipeline.predict("./test_imgs/img_rot180_demo.jpg")
+for res in output:
+    # print(res)
+    res.save_results("./output")

+ 13 - 0
api_examples/pipelines/test_ocr.py

@@ -0,0 +1,13 @@
+
+from paddlex import create_pipeline
+
+pipeline = create_pipeline(pipeline="OCR")
+
+# output = pipeline.predict("./test_imgs/general_ocr_002.png")
+
+output = pipeline.predict("./test_imgs/seal_text_det.png")
+for res in output:
+    print(res)
+    res.save_to_img("./output")
+
+

+ 37 - 0
api_examples/pipelines/test_pp_chatocrv3.py

@@ -0,0 +1,37 @@
+
+from paddlex import create_pipeline
+
+pipeline = create_pipeline(pipeline="PP-ChatOCRv3-doc")
+
+# img_path = "./test_demo_imgs/vehicle_certificate-1.png"
+# key_list = ['驾驶室准乘人数']
+
+# img_path = "./test_demo_imgs/test_layout_parsing.jpg"
+# key_list = ['3.2的标题']
+
+img_path = "./test_demo_imgs/seal_text_det.png"
+key_list = ['印章上公司']
+
+# visual_predict_res = pipeline.visual_predict(img_path, 
+#     use_doc_orientation_classify=True,
+#     use_doc_unwarping=True,
+#     use_common_ocr=True,
+#     use_seal_recognition=True,
+#     use_table_recognition=True)
+
+# ####[TODO] Add category information
+# visual_info_list = []
+# for res in visual_predict_res:
+#     visual_info_list.append(res["visual_info"])
+
+# pipeline.save_visual_info_list(visual_info_list, "./res_visual_info/visual_info3.json")
+
+visual_info_list = pipeline.load_visual_info_list("./res_visual_info/visual_info3.json")
+
+vector_info = pipeline.build_vector(visual_info_list)
+
+print(vector_info)
+
+final_results = pipeline.chat(visual_info_list, key_list, vector_info)
+
+print(final_results)
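
For reference, a minimal end-to-end sketch of the same flow with the visual stage enabled (the commented-out half above), using the same image and key list; the method names come from the pipeline added in this commit, while the file paths are illustrative:

from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="PP-ChatOCRv3-doc")

img_path = "./test_demo_imgs/seal_text_det.png"  # illustrative path
key_list = ["印章上公司"]

# Run the visual analysis once and collect the per-page visual info.
visual_predict_res = pipeline.visual_predict(
    img_path,
    use_doc_orientation_classify=True,
    use_doc_unwarping=True,
    use_common_ocr=True,
    use_seal_recognition=True,
    use_table_recognition=True,
)
visual_info_list = [res["visual_info"] for res in visual_predict_res]

# Persist the visual info so later runs can skip the visual stage.
pipeline.save_visual_info_list(visual_info_list, "./res_visual_info/visual_info.json")

# Build the retrieval vectors and query the LLM for the requested keys.
vector_info = pipeline.build_vector(visual_info_list)
final_results = pipeline.chat(visual_info_list, key_list, vector_info)
print(final_results)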

+ 13 - 0
api_examples/pipelines/test_table_recognition.py

@@ -0,0 +1,13 @@
+
+from paddlex import create_pipeline
+
+pipeline = create_pipeline(pipeline="table_recognition")
+
+output = pipeline("./test_imgs/table_recognition.jpg")
+for res in output:
+    print(res)
+    res.save_to_img("./output/") ## save the result as an image
+    res.save_to_xlsx("./output/") ## save the table result as an xlsx file
+    res.save_to_html("./output/") ## save the result as HTML
+
+

+ 36 - 0
paddlex/configs/pipelines/OCR.yaml

@@ -0,0 +1,36 @@
+
+pipeline_name: OCR
+
+##############################################
+####### Config for Common OCR
+##############################################
+
+input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_001.png
+text_type: common
+
+SubModules:
+  TextDetection:
+    model_name: PP-OCRv4_mobile_det
+    model_dir: null
+    batch_size: 1    
+  TextRecognition:
+    model_name: PP-OCRv4_mobile_rec
+    model_dir: null
+    batch_size: 1
+
+##############################################
+####### Config for Seal OCR
+##############################################
+
+# input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/seal_text_det.png
+# text_type: seal
+
+# SubModules:
+#   TextDetection:
+#     model_name: PP-OCRv4_mobile_seal_det
+#     model_dir: null
+#     batch_size: 1    
+#   TextRecognition:
+#     model_name: PP-OCRv4_mobile_rec
+#     model_dir: null
+#     batch_size: 1

+ 109 - 0
paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml

@@ -0,0 +1,109 @@
+
+pipeline_name: PP-ChatOCRv3-doc
+input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/demo_paper.png
+
+use_vector_retrieval: True
+
+SubModules:
+  LLM_Chat:
+    model_name: ernie-3.5
+    api_type: qianfan
+    # ak: "api_key" # Set this to a real API key
+    # sk: "secret_key"  # Set this to a real secret key
+    ak: 4iiqB0QfvXTAENgzUwNeDjQ7
+    sk: sHQCw4l5A6jnzbHMa0ZvDi05GT9Qz8tZ
+
+  LLM_Retriever:
+    model_name: ernie-3.5
+    api_type: qianfan
+    # ak: "api_key" # Set this to a real API key
+    # sk: "secret_key"  # Set this to a real secret key
+    ak: 4iiqB0QfvXTAENgzUwNeDjQ7
+    sk: sHQCw4l5A6jnzbHMa0ZvDi05GT9Qz8tZ
+
+  PromptEngneering:
+    KIE_CommonText:
+      task_type: text_kie_prompt
+      task_description: '你现在的任务是从OCR文字识别的结果中提取关键词列表中每一项对应的关键信息。
+          OCR的文字识别结果使用```符号包围,包含所识别出来的文字,顺序在原始图片中从左至右、从上至下。
+          我指定的关键词列表使用[]符号包围。请注意OCR的文字识别结果可能存在长句子换行被切断、不合理的分词、
+          文字被错误合并等问题,你需要结合上下文语义进行综合判断,以抽取准确的关键信息。'
+      output_format: '在返回结果时使用JSON格式,包含多个key-value对,key值为我指定的关键词,value值为所抽取的结果。
+          如果认为OCR识别结果中没有关键词key对应的value,则将value赋值为"未知"。请只输出json格式的结果,
+          并做json格式校验后返回,不要包含其它多余文字!'
+      rules_str:
+      few_shot_demo_text_content:
+      few_shot_demo_key_value_list:
+          
+    KIE_Table:
+      task_type: table_kie_prompt
+      task_description: '你现在的任务是从输入的html格式的表格内容中提取关键词列表中每一项对应的关键信息,
+          表格内容用```符号包围,我指定的关键词列表使用[]符号包围。你需要结合上下文语义进行综合判断,以抽取准确的关键信息。
+          在返回结果时使用JSON格式,包含多个key-value对,key值为我指定的关键词,value值为所抽取的结果。
+          如果认为输入的表格内容中没有关键词key对应的value值,则将value赋值为"未知"。
+          请只输出json格式的结果,并做json格式校验后返回,不要包含其它多余文字!'
+      output_format: '在返回结果时使用JSON格式,包含多个key-value对,key值为我指定的关键词,value值为所抽取的结果。
+          如果认为表格识别结果中没有关键词key对应的value,则将value赋值为"未知"。请只输出json格式的结果,
+          并做json格式校验后返回,不要包含其它多余文字!'
+      rules_str:
+      few_shot_demo_text_content:
+      few_shot_demo_key_value_list:
+
+SubPipelines:
+  LayoutParser:
+    pipeline_name: layout_parsing
+    use_doc_preprocessor: True
+    use_common_ocr: True
+    use_seal_recognition: True
+    use_table_recognition: True
+
+    SubModules:
+      LayoutDetection:
+        model_name: RT-DETR-H_layout_3cls
+        model_dir: null
+        batch_size: 1
+      TableStructurePredictor:
+        model_name: SLANet_plus
+        model_dir: null
+        batch_size: 1
+
+    SubPipelines:
+      DocPreprocessor:
+        pipeline_name: doc_preprocessor
+        use_doc_orientation_classify: True
+        use_doc_unwarping: True
+        SubModules:
+          DocOrientationClassify:
+            model_name: PP-LCNet_x1_0_doc_ori
+            model_dir: null
+            batch_size: 1
+          DocUnwarping:
+            model_name: UVDoc
+            model_dir: null
+            batch_size: 1
+
+      CommonOCR:
+        pipeline_name: OCR
+        text_type: common
+        SubModules:
+          TextDetection:
+            model_name: PP-OCRv4_server_det
+            model_dir: null
+            batch_size: 1    
+          TextRecognition:
+            model_name: PP-OCRv4_server_rec
+            model_dir: null
+            batch_size: 1
+
+      SealOCR:
+        pipeline_name: OCR
+        text_type: seal
+        SubModules:
+          TextDetection:
+            model_name: PP-OCRv4_server_seal_det
+            model_dir: null
+            batch_size: 1    
+          TextRecognition:
+            model_name: PP-OCRv4_server_rec
+            model_dir: null
+            batch_size: 1  

+ 16 - 0
paddlex/configs/pipelines/doc_preprocessor.yaml

@@ -0,0 +1,16 @@
+
+pipeline_name: doc_preprocessor
+input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/img_rot180_demo.jpg
+#input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/doc_test.jpg
+use_doc_orientation_classify: True
+use_doc_unwarping: True
+
+SubModules:
+  DocOrientationClassify:
+    model_name: PP-LCNet_x1_0_doc_ori
+    model_dir: null
+    batch_size: 1
+  DocUnwarping:
+    model_name: UVDoc
+    model_dir: null
+    batch_size: 1

+ 56 - 0
paddlex/configs/pipelines/layout_parsing.yaml

@@ -0,0 +1,56 @@
+
+pipeline_name: layout_parsing
+input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/demo_paper.png
+use_doc_preprocessor: True
+use_common_ocr: True
+use_seal_recognition: True
+use_table_recognition: True
+
+SubModules:
+  LayoutDetection:
+    model_name: RT-DETR-H_layout_3cls
+    model_dir: null
+    batch_size: 1
+  TableStructurePredictor:
+    model_name: SLANet_plus
+    model_dir: null
+    batch_size: 1
+
+SubPipelines:
+  DocPreprocessor:
+    pipeline_name: doc_preprocessor
+    use_doc_orientation_classify: True
+    use_doc_unwarping: True
+    SubModules:
+      DocOrientationClassify:
+        model_name: PP-LCNet_x1_0_doc_ori
+        model_dir: null
+        batch_size: 1
+      DocUnwarping:
+        model_name: UVDoc
+        model_dir: null
+        batch_size: 1
+  CommonOCR:
+    pipeline_name: OCR
+    text_type: common
+    SubModules:
+      TextDetection:
+        model_name: PP-OCRv4_server_det
+        model_dir: null
+        batch_size: 1    
+      TextRecognition:
+        model_name: PP-OCRv4_server_rec
+        model_dir: null
+        batch_size: 1
+  SealOCR:
+    pipeline_name: OCR
+    text_type: seal
+    SubModules:
+      TextDetection:
+        model_name: PP-OCRv4_server_seal_det
+        model_dir: null
+        batch_size: 1    
+      TextRecognition:
+        model_name: PP-OCRv4_server_rec
+        model_dir: null
+        batch_size: 1

+ 5 - 1
paddlex/inference/__init__.py

@@ -13,5 +13,9 @@
 # limitations under the License.
 
 from .models import create_predictor
-from .pipelines import create_pipeline
+from ..utils import flags
+if flags.USE_NEW_INFERENCE:
+    from .pipelines_new import create_pipeline
+else:
+    from .pipelines import create_pipeline
 from .utils.pp_option import PaddlePredictorOption
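
A small sketch of opting into the new architecture from user code. The switch itself is flags.USE_NEW_INFERENCE as shown above; how that flag gets populated lives in paddlex/utils/flags.py, whose two added lines are not shown here, so the environment-variable name below is an assumption:

import os

# Hypothetical variable name; the real mechanism is defined in paddlex/utils/flags.py.
os.environ["PADDLEX_USE_NEW_INFERENCE"] = "1"

# With the flag on, this import resolves to pipelines_new.create_pipeline.
from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="OCR")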

+ 95 - 0
paddlex/inference/pipelines_new/__init__.py

@@ -0,0 +1,95 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+from typing import Any, Dict, Optional
+from .base import BasePipeline
+from ...utils.config import parse_config
+
+# from .single_model_pipeline import (
+#     _SingleModelPipeline,
+#     ImageClassification,
+#     ObjectDetection,
+#     InstanceSegmentation,
+#     SemanticSegmentation,
+#     TSFc,
+#     TSAd,
+#     TSCls,
+#     MultiLableImageClas,
+#     SmallObjDet,
+#     AnomalyDetection,
+# )
+# from .ocr import OCRPipeline
+# from .formula_recognition import FormulaRecognitionPipeline
+# from .table_recognition import TableRecPipeline
+# from .face_recognition import FaceRecPipeline
+# from .seal_recognition import SealOCRPipeline
+# from .ppchatocrv3 import PPChatOCRPipeline
+# from .layout_parsing import LayoutParsingPipeline
+# from .pp_shitu_v2 import ShiTuV2Pipeline
+# from .attribute_recognition import AttributeRecPipeline
+
+from .ocr import OCRPipeline
+from .doc_preprocessor import DocPreprocessorPipeline
+from .layout_parsing import LayoutParsingPipeline
+from .pp_chatocrv3_doc import PP_ChatOCRv3_doc_Pipeline
+
+def get_pipeline_path(pipeline_name):
+    pipeline_path = (
+        Path(__file__).parent.parent.parent / "configs/pipelines" / f"{pipeline_name}.yaml"
+    ).resolve()
+    if not Path(pipeline_path).exists():
+        return None
+    return pipeline_path
+
+def load_pipeline_config(pipeline_name: str) -> Dict[str, Any]:
+    if not Path(pipeline_name).exists():
+        pipeline_path = get_pipeline_path(pipeline_name)
+        if pipeline_path is None:
+            raise Exception(
+                f"The pipeline ({pipeline_name}) does not exist! Please use a pipeline name or a config file path!"
+            )
+    else:
+        pipeline_path = pipeline_name
+    config = parse_config(pipeline_path)
+    return config
+
+def create_pipeline(
+    pipeline: str,
+    device=None,
+    pp_option=None,
+    use_hpip: bool = False,
+    hpi_params: Optional[Dict[str, Any]] = None,
+    *args,
+    **kwargs,
+) -> BasePipeline:
+    """build model evaluater
+
+    Args:
+        pipeline (str): the pipeline name, that is name of pipeline class
+
+    Returns:
+        BasePipeline: the pipeline, which is subclass of BasePipeline.
+    """
+    pipeline_name = pipeline
+    config = load_pipeline_config(pipeline_name)
+    assert pipeline_name == config["pipeline_name"]
+    pipeline = BasePipeline.get(pipeline_name)(
+        config=config,
+        device=device,
+        pp_option=pp_option,
+        use_hpip=use_hpip,
+        hpi_params=hpi_params)
+    return pipeline
+    
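
A usage sketch of the factory above (illustrative, not part of the commit): the pipeline name is resolved to the bundled YAML under paddlex/configs/pipelines and the registered pipeline class is instantiated from it. This assumes the optional dependencies of the imported pipelines (e.g. erniebot) are installed, and the image path is made up:

from paddlex.inference.pipelines_new import load_pipeline_config, create_pipeline

config = load_pipeline_config("OCR")            # resolves paddlex/configs/pipelines/OCR.yaml
print(config["pipeline_name"])                  # "OCR"
print(config["SubModules"]["TextDetection"])    # per-module settings from the YAML

pipeline = create_pipeline(pipeline="OCR")      # instantiates the registered OCR pipeline
for res in pipeline.predict("./test_imgs/general_ocr_002.png"):  # illustrative path
    print(res)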

+ 94 - 0
paddlex/inference/pipelines_new/base.py

@@ -0,0 +1,94 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from ...utils.subclass_register import AutoRegisterABCMetaClass
+import yaml
+import codecs
+from pathlib import Path
+from typing import Any, Dict, Optional
+from ..models import create_predictor
+from .components.chat_server.base import BaseChat
+from .components.retriever.base import BaseRetriever
+from .components.prompt_engeering.base import BaseGeneratePrompt
+
+class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
+    """Base Pipeline"""
+
+    __is_base = True
+
+    def __init__(self,
+        device=None, 
+        pp_option=None, 
+        use_hpip: bool = False, 
+        hpi_params: Optional[Dict[str, Any]] = None) -> None:
+        super().__init__()
+        self.device = device
+        self.pp_option = pp_option
+        self.use_hpip = use_hpip
+        self.hpi_params = hpi_params
+
+    @abstractmethod
+    def predict(self, input, **kwargs):
+        raise NotImplementedError(
+            "The method `predict` has not been implemented yet."
+        )
+    
+    def create_model(self, config):
+
+        model_dir = config['model_dir']
+        if model_dir == None:
+            model_dir = config['model_name']
+
+        model = create_predictor(
+            model_dir,
+            device=self.device,
+            pp_option=self.pp_option,
+            use_hpip=self.use_hpip,
+            hpi_params=self.hpi_params)
+
+        ########### [TODO] Support passing parameters at initialization
+        if "batch_size" in config:
+            batch_size = config["batch_size"]
+            model.set_predictor(batch_size=batch_size)
+
+        return model
+
+    def create_pipeline(self, config):
+        pipeline_name = config['pipeline_name']
+        pipeline = BasePipeline.get(pipeline_name)(
+            config=config,
+            device=self.device,
+            pp_option=self.pp_option,
+            use_hpip=self.use_hpip,
+            hpi_params=self.hpi_params)
+        return pipeline 
+
+    def create_chat_bot(self, config):
+        model_name = config['model_name']
+        chat_bot = BaseChat.get(model_name)(config)
+        return chat_bot     
+
+    def create_retriever(self, config):
+        model_name = config['model_name']
+        retriever = BaseRetriever.get(model_name)(config)
+        return retriever    
+
+    def create_prompt_engeering(self, config):
+        task_type = config['task_type']
+        pe = BaseGeneratePrompt.get(task_type)(config)
+        return pe       
+
+    def __call__(self, input, **kwargs):
+        return self.predict(input, **kwargs)
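
To make the contract concrete, a minimal sketch of a pipeline built on these helpers (illustrative only; the concrete pipelines added by this commit live in the per-pipeline packages, and the entities registration attribute is assumed to work as it does for the components):

from paddlex.inference.pipelines_new.base import BasePipeline

class DemoPipeline(BasePipeline):
    entities = "demo_pipeline"  # assumed registration key, mirroring the components

    def __init__(self, config, device=None, pp_option=None, use_hpip=False, hpi_params=None):
        super().__init__(device=device, pp_option=pp_option,
                         use_hpip=use_hpip, hpi_params=hpi_params)
        # Each SubModules entry becomes a predictor, each SubPipelines entry a nested pipeline.
        self.layout_model = self.create_model(config["SubModules"]["LayoutDetection"])
        self.doc_preprocessor = self.create_pipeline(config["SubPipelines"]["DocPreprocessor"])

    def predict(self, input, **kwargs):
        # A real pipeline would feed the preprocessed pages into the detector here.
        yield from self.layout_model(input)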

+ 19 - 0
paddlex/inference/pipelines_new/components/__init__.py

@@ -0,0 +1,19 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base import BaseComponent, CVResult, BaseResult
+from .common import SortQuadBoxes
+from .common import CropByPolys
+from .common import CropByBoxes
+from .utils.mixin import HtmlMixin, XlsxMixin

+ 59 - 0
paddlex/inference/pipelines_new/components/base.py

@@ -0,0 +1,59 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from ....utils.subclass_register import AutoRegisterABCMetaClass
+
+import inspect
+
+from ....utils.func_register import FuncRegister
+from ...utils.io import ImageReader, ImageWriter
+from .utils.mixin import JsonMixin, ImgMixin, StrMixin
+
+class BaseComponent(ABC, metaclass=AutoRegisterABCMetaClass):
+    """Base Component"""
+
+    __is_base = True
+
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    def __call__(self):
+        raise NotImplementedError(
+            "The component method `__call__` has not been implemented yet.")
+
+
+class BaseResult(dict, StrMixin, JsonMixin):
+    def __init__(self, data):
+        super().__init__(data)
+        self._show_funcs = []
+        StrMixin.__init__(self)
+        JsonMixin.__init__(self)
+
+    def save_all(self, save_path):
+        for func in self._show_funcs:
+            signature = inspect.signature(func)
+            if "save_path" in signature.parameters:
+                func(save_path=save_path)
+            else:
+                func()
+
+class CVResult(BaseResult, ImgMixin):
+    def __init__(self, data):
+        super().__init__(data)
+        ImgMixin.__init__(self, "pillow")
+        self._img_reader = ImageReader(backend="pillow")
+        self._img_writer = ImageWriter(backend="pillow")
+
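
A small sketch of how save_all dispatches to registered writers (illustrative; _save_summary is a made-up writer, not an API of the mixins):

import os
from paddlex.inference.pipelines_new.components import BaseResult

class DemoResult(BaseResult):
    def __init__(self, data):
        super().__init__(data)
        self._show_funcs.append(self._save_summary)  # registered so save_all() will call it

    def _save_summary(self, save_path):
        os.makedirs(save_path, exist_ok=True)
        with open(os.path.join(save_path, "summary.txt"), "w") as f:
            f.write(str(self))

res = DemoResult({"text": "hello"})
res.save_all(save_path="./output")  # save_path is forwarded because _save_summary accepts it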

+ 15 - 0
paddlex/inference/pipelines_new/components/chat_server/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .ernie_bot_chat import ErnieBotChat

+ 31 - 0
paddlex/inference/pipelines_new/components/chat_server/base.py

@@ -0,0 +1,31 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from .....utils.subclass_register import AutoRegisterABCMetaClass
+
+import inspect
+
+class BaseChat(ABC, metaclass=AutoRegisterABCMetaClass):
+    """Base Chat"""
+
+    __is_base = True
+
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    def generate_chat_results(self):
+        raise NotImplementedError(
+            "The method `generate_chat_results` has not been implemented yet.")

+ 95 - 0
paddlex/inference/pipelines_new/components/chat_server/ernie_bot_chat.py

@@ -0,0 +1,95 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .....utils import logging
+from .base import BaseChat
+import erniebot
+
+class ErnieBotChat(BaseChat):
+    """Ernie Bot Chat"""
+
+    entities = [
+        "ernie-4.0",
+        "ernie-3.5",
+        "ernie-3.5-8k",
+        "ernie-lite",
+        "ernie-tiny-8k",
+        "ernie-speed",
+        "ernie-speed-128k",
+        "ernie-char-8k",
+    ]
+
+    def __init__(self, config):
+        super().__init__()
+        model_name = config.get('model_name', None)
+        api_type = config.get('api_type', None)
+        ak = config.get('ak', None)
+        sk = config.get('sk', None)
+        access_token = config.get('access_token', None)
+
+        if model_name not in self.entities:
+            raise ValueError(f"model_name must be in {self.entities} of ErnieBotChat.")
+
+        if api_type not in ["aistudio", "qianfan"]:
+            raise ValueError("api_type must be one of ['aistudio', 'qianfan']")
+
+        if api_type == "aistudio" and access_token is None:
+            raise ValueError("access_token cannot be empty when api_type is aistudio.")
+            
+        if api_type == "qianfan" and (ak is None or sk is None):
+            raise ValueError("ak and sk cannot be empty when api_type is qianfan.")            
+
+        self.model_name = model_name
+        self.config = config
+        
+    def generate_chat_results(self, prompt, temperature=0.001, max_retries=1):
+        """
+        args:
+        return:
+        """
+        try:
+            cur_config = {
+                "api_type": self.config['api_type'],
+                "max_retries": max_retries
+            }
+            if self.config['api_type'] == "aistudio":
+                cur_config['access_token'] = self.config['access_token']
+            elif self.config['api_type'] == "qianfan":
+                cur_config['ak'] = self.config['ak']
+                cur_config['sk'] = self.config['sk']
+            chat_completion = erniebot.ChatCompletion.create(
+                _config_=cur_config,
+                model=self.model_name,
+                messages=[{"role": "user", "content": prompt}],
+                temperature=float(temperature),
+            )
+            llm_result = chat_completion.get_result()
+            return llm_result
+        except Exception as e:
+            if len(e.args) < 1:
+                self.ERROR_MASSAGE = (
+                    "暂无权限访问ErnieBot服务,请检查访问令牌。"
+                )
+            elif (
+                e.args[-1]
+                == "暂无权限使用,请在 AI Studio 正确获取访问令牌(access token)使用"
+            ):
+                self.ERROR_MASSAGE = (
+                    "暂无权限访问ErnieBot服务,请检查访问令牌。"
+                )
+            else:
+                logging.error(e)
+                self.ERROR_MASSAGE = "大模型调用失败"
+        return None 
+        
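
A hedged usage sketch of this component on its own (not part of the commit); the config mirrors the LLM_Chat block of PP-ChatOCRv3-doc.yaml and the credentials are placeholders:

from paddlex.inference.pipelines_new.components.chat_server import ErnieBotChat

chat_bot = ErnieBotChat({
    "model_name": "ernie-3.5",
    "api_type": "qianfan",
    "ak": "your_api_key",      # placeholder, use a real Qianfan API key
    "sk": "your_secret_key",   # placeholder, use a real Qianfan secret key
})
answer = chat_bot.generate_chat_results("用一句话介绍OCR。")
print(answer)  # None is returned when the call fails (see the except branch above)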

+ 17 - 0
paddlex/inference/pipelines_new/components/common/__init__.py

@@ -0,0 +1,17 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .sort_boxes import SortQuadBoxes
+from .crop_image_regions import CropByPolys, CropByBoxes
+

+ 475 - 0
paddlex/inference/pipelines_new/components/common/crop_image_regions.py

@@ -0,0 +1,475 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BaseComponent
+import numpy as np
+from ....utils.io import ImageReader
+import copy
+import cv2
+from .seal_det_warp import AutoRectifier
+from shapely.geometry import Polygon
+from numpy.linalg import norm
+
+class CropByPolys(BaseComponent):
+    """Crop Image by Polys"""
+
+    entities = "CropByPolys"
+
+    def __init__(self, det_box_type="quad"):
+        super().__init__()
+        self.det_box_type = det_box_type
+
+    def __call__(self, img, dt_polys):
+        """__call__"""
+
+        if self.det_box_type == "quad":
+            dt_boxes = np.array(dt_polys)
+            output_list = []
+            for bno in range(len(dt_boxes)):
+                tmp_box = copy.deepcopy(dt_boxes[bno])
+                img_crop = self.get_minarea_rect_crop(img, tmp_box)
+                output_list.append(
+                    {
+                        "img": img_crop,
+                        "img_size": [img_crop.shape[1], img_crop.shape[0]],
+                    }
+                )
+        elif self.det_box_type == "poly":
+            output_list = []
+            dt_boxes = dt_polys
+            for bno in range(len(dt_boxes)):
+                tmp_box = copy.deepcopy(dt_boxes[bno])
+                img_crop = self.get_poly_rect_crop(img.copy(), tmp_box)
+                output_list.append(
+                    {
+                        "img": img_crop,
+                        "img_size": [img_crop.shape[1], img_crop.shape[0]],
+                    }
+                )
+        else:
+            raise NotImplementedError
+
+        return output_list
+
+    def get_minarea_rect_crop(self, img, points):
+        """get_minarea_rect_crop"""
+        bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
+        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+        index_a, index_b, index_c, index_d = 0, 1, 2, 3
+        if points[1][1] > points[0][1]:
+            index_a = 0
+            index_d = 1
+        else:
+            index_a = 1
+            index_d = 0
+        if points[3][1] > points[2][1]:
+            index_b = 2
+            index_c = 3
+        else:
+            index_b = 3
+            index_c = 2
+
+        box = [points[index_a], points[index_b], points[index_c], points[index_d]]
+        crop_img = self.get_rotate_crop_image(img, np.array(box))
+        return crop_img
+
+    def get_rotate_crop_image(self, img, points):
+        """
+        img_height, img_width = img.shape[0:2]
+        left = int(np.min(points[:, 0]))
+        right = int(np.max(points[:, 0]))
+        top = int(np.min(points[:, 1]))
+        bottom = int(np.max(points[:, 1]))
+        img_crop = img[top:bottom, left:right, :].copy()
+        points[:, 0] = points[:, 0] - left
+        points[:, 1] = points[:, 1] - top
+        """
+        assert len(points) == 4, "shape of points must be 4*2"
+        img_crop_width = int(
+            max(
+                np.linalg.norm(points[0] - points[1]),
+                np.linalg.norm(points[2] - points[3]),
+            )
+        )
+        img_crop_height = int(
+            max(
+                np.linalg.norm(points[0] - points[3]),
+                np.linalg.norm(points[1] - points[2]),
+            )
+        )
+        pts_std = np.float32(
+            [
+                [0, 0],
+                [img_crop_width, 0],
+                [img_crop_width, img_crop_height],
+                [0, img_crop_height],
+            ]
+        )
+        M = cv2.getPerspectiveTransform(points, pts_std)
+        dst_img = cv2.warpPerspective(
+            img,
+            M,
+            (img_crop_width, img_crop_height),
+            borderMode=cv2.BORDER_REPLICATE,
+            flags=cv2.INTER_CUBIC,
+        )
+        dst_img_height, dst_img_width = dst_img.shape[0:2]
+        if dst_img_height * 1.0 / dst_img_width >= 1.5:
+            dst_img = np.rot90(dst_img)
+        return dst_img
+
+    def reorder_poly_edge(self, points):
+        """Get the respective points composing head edge, tail edge, top
+        sideline and bottom sideline.
+
+        Args:
+            points (ndarray): The points composing a text polygon.
+
+        Returns:
+            head_edge (ndarray): The two points composing the head edge of text
+                polygon.
+            tail_edge (ndarray): The two points composing the tail edge of text
+                polygon.
+            top_sideline (ndarray): The points composing top curved sideline of
+                text polygon.
+            bot_sideline (ndarray): The points composing bottom curved sideline
+                of text polygon.
+        """
+
+        assert points.ndim == 2
+        assert points.shape[0] >= 4
+        assert points.shape[1] == 2
+
+        orientation_thr = 2.0  # an empirical hyperparameter
+
+        head_inds, tail_inds = self.find_head_tail(points, orientation_thr)
+        head_edge, tail_edge = points[head_inds], points[tail_inds]
+
+        pad_points = np.vstack([points, points])
+        if tail_inds[1] < 1:
+            tail_inds[1] = len(points)
+        sideline1 = pad_points[head_inds[1] : tail_inds[1]]
+        sideline2 = pad_points[tail_inds[1] : (head_inds[1] + len(points))]
+        return head_edge, tail_edge, sideline1, sideline2
+
+    def vector_slope(self, vec):
+        assert len(vec) == 2
+        return abs(vec[1] / (vec[0] + 1e-8))
+
+    def find_head_tail(self, points, orientation_thr):
+        """Find the head edge and tail edge of a text polygon.
+
+        Args:
+            points (ndarray): The points composing a text polygon.
+            orientation_thr (float): The threshold for distinguishing between
+                head edge and tail edge among the horizontal and vertical edges
+                of a quadrangle.
+
+        Returns:
+            head_inds (list): The indexes of two points composing head edge.
+            tail_inds (list): The indexes of two points composing tail edge.
+        """
+
+        assert points.ndim == 2
+        assert points.shape[0] >= 4
+        assert points.shape[1] == 2
+        assert isinstance(orientation_thr, float)
+
+        if len(points) > 4:
+            pad_points = np.vstack([points, points[0]])
+            edge_vec = pad_points[1:] - pad_points[:-1]
+
+            theta_sum = []
+            adjacent_vec_theta = []
+            for i, edge_vec1 in enumerate(edge_vec):
+                adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
+                adjacent_edge_vec = edge_vec[adjacent_ind]
+                temp_theta_sum = np.sum(self.vector_angle(edge_vec1, adjacent_edge_vec))
+                temp_adjacent_theta = self.vector_angle(
+                    adjacent_edge_vec[0], adjacent_edge_vec[1]
+                )
+                theta_sum.append(temp_theta_sum)
+                adjacent_vec_theta.append(temp_adjacent_theta)
+            theta_sum_score = np.array(theta_sum) / np.pi
+            adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
+            poly_center = np.mean(points, axis=0)
+            edge_dist = np.maximum(
+                norm(pad_points[1:] - poly_center, axis=-1),
+                norm(pad_points[:-1] - poly_center, axis=-1),
+            )
+            dist_score = edge_dist / np.max(edge_dist)
+            position_score = np.zeros(len(edge_vec))
+            score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
+            score += 0.35 * dist_score
+            if len(points) % 2 == 0:
+                position_score[(len(score) // 2 - 1)] += 1
+                position_score[-1] += 1
+            score += 0.1 * position_score
+            pad_score = np.concatenate([score, score])
+            score_matrix = np.zeros((len(score), len(score) - 3))
+            x = np.arange(len(score) - 3) / float(len(score) - 4)
+            gaussian = (
+                1.0
+                / (np.sqrt(2.0 * np.pi) * 0.5)
+                * np.exp(-np.power((x - 0.5) / 0.5, 2.0) / 2)
+            )
+            gaussian = gaussian / np.max(gaussian)
+            for i in range(len(score)):
+                score_matrix[i, :] = (
+                    score[i]
+                    + pad_score[(i + 2) : (i + len(score) - 1)] * gaussian * 0.3
+                )
+
+            head_start, tail_increment = np.unravel_index(
+                score_matrix.argmax(), score_matrix.shape
+            )
+            tail_start = (head_start + tail_increment + 2) % len(points)
+            head_end = (head_start + 1) % len(points)
+            tail_end = (tail_start + 1) % len(points)
+
+            if head_end > tail_end:
+                head_start, tail_start = tail_start, head_start
+                head_end, tail_end = tail_end, head_end
+            head_inds = [head_start, head_end]
+            tail_inds = [tail_start, tail_end]
+        else:
+            if self.vector_slope(points[1] - points[0]) + self.vector_slope(
+                points[3] - points[2]
+            ) < self.vector_slope(points[2] - points[1]) + self.vector_slope(
+                points[0] - points[3]
+            ):
+                horizontal_edge_inds = [[0, 1], [2, 3]]
+                vertical_edge_inds = [[3, 0], [1, 2]]
+            else:
+                horizontal_edge_inds = [[3, 0], [1, 2]]
+                vertical_edge_inds = [[0, 1], [2, 3]]
+
+            vertical_len_sum = norm(
+                points[vertical_edge_inds[0][0]] - points[vertical_edge_inds[0][1]]
+            ) + norm(
+                points[vertical_edge_inds[1][0]] - points[vertical_edge_inds[1][1]]
+            )
+            horizontal_len_sum = norm(
+                points[horizontal_edge_inds[0][0]] - points[horizontal_edge_inds[0][1]]
+            ) + norm(
+                points[horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1][1]]
+            )
+
+            if vertical_len_sum > horizontal_len_sum * orientation_thr:
+                head_inds = horizontal_edge_inds[0]
+                tail_inds = horizontal_edge_inds[1]
+            else:
+                head_inds = vertical_edge_inds[0]
+                tail_inds = vertical_edge_inds[1]
+
+        return head_inds, tail_inds
+
+    def vector_angle(self, vec1, vec2):
+        if vec1.ndim > 1:
+            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
+        else:
+            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
+        if vec2.ndim > 1:
+            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
+        else:
+            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
+        return np.arccos(np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
+
+    def get_minarea_rect(self, img, points):
+        bounding_box = cv2.minAreaRect(points)
+        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+        index_a, index_b, index_c, index_d = 0, 1, 2, 3
+        if points[1][1] > points[0][1]:
+            index_a = 0
+            index_d = 1
+        else:
+            index_a = 1
+            index_d = 0
+        if points[3][1] > points[2][1]:
+            index_b = 2
+            index_c = 3
+        else:
+            index_b = 3
+            index_c = 2
+
+        box = [points[index_a], points[index_b], points[index_c], points[index_d]]
+        crop_img = self.get_rotate_crop_image(img, np.array(box))
+        return crop_img, box
+
+    def sample_points_on_bbox_bp(self, line, n=50):
+        """Resample n points on a line.
+
+        Args:
+            line (ndarray): The points composing a line.
+            n (int): The resampled points number.
+
+        Returns:
+            resampled_line (ndarray): The points composing the resampled line.
+        """
+        from numpy.linalg import norm
+
+        # sanity-check the input arguments
+        assert line.ndim == 2
+        assert line.shape[0] >= 2
+        assert line.shape[1] == 2
+        assert isinstance(n, int)
+        assert n > 0
+
+        length_list = [norm(line[i + 1] - line[i]) for i in range(len(line) - 1)]
+        total_length = sum(length_list)
+        length_cumsum = np.cumsum([0.0] + length_list)
+        delta_length = total_length / (float(n) + 1e-8)
+        current_edge_ind = 0
+        resampled_line = [line[0]]
+
+        for i in range(1, n):
+            current_line_len = i * delta_length
+            while (
+                current_edge_ind + 1 < len(length_cumsum)
+                and current_line_len >= length_cumsum[current_edge_ind + 1]
+            ):
+                current_edge_ind += 1
+            current_edge_end_shift = current_line_len - length_cumsum[current_edge_ind]
+            if current_edge_ind >= len(length_list):
+                break
+            end_shift_ratio = current_edge_end_shift / length_list[current_edge_ind]
+            current_point = (
+                line[current_edge_ind]
+                + (line[current_edge_ind + 1] - line[current_edge_ind])
+                * end_shift_ratio
+            )
+            resampled_line.append(current_point)
+        resampled_line.append(line[-1])
+        resampled_line = np.array(resampled_line)
+        return resampled_line
+
+    def sample_points_on_bbox(self, line, n=50):
+        """Resample n points on a line.
+
+        Args:
+            line (ndarray): The points composing a line.
+            n (int): The resampled points number.
+
+        Returns:
+            resampled_line (ndarray): The points composing the resampled line.
+        """
+        assert line.ndim == 2
+        assert line.shape[0] >= 2
+        assert line.shape[1] == 2
+        assert isinstance(n, int)
+        assert n > 0
+
+        length_list = [norm(line[i + 1] - line[i]) for i in range(len(line) - 1)]
+        total_length = sum(length_list)
+        mean_length = total_length / (len(length_list) + 1e-8)
+        group = [[0]]
+        for i in range(len(length_list)):
+            point_id = i + 1
+            if length_list[i] < 0.9 * mean_length:
+                for g in group:
+                    if i in g:
+                        g.append(point_id)
+                        break
+            else:
+                g = [point_id]
+                group.append(g)
+
+        top_tail_len = norm(line[0] - line[-1])
+        if top_tail_len < 0.9 * mean_length:
+            group[0].extend(g)
+            group.remove(g)
+        mean_positions = []
+        for indices in group:
+            x_sum = 0
+            y_sum = 0
+            for index in indices:
+                x, y = line[index]
+                x_sum += x
+                y_sum += y
+            num_points = len(indices)
+            mean_x = x_sum / num_points
+            mean_y = y_sum / num_points
+            mean_positions.append((mean_x, mean_y))
+        resampled_line = np.array(mean_positions)
+        return resampled_line
+
+    def get_poly_rect_crop(self, img, points):
+        """
+        修改该函数,实现使用polygon,对不规则、弯曲文本的矫正以及crop
+        args: img: 图片 ndarrary格式
+        points: polygon格式的多点坐标 N*2 shape, ndarray格式
+        return: 矫正后的图片 ndarray格式
+        """
+        points = np.array(points).astype(np.int32).reshape(-1, 2)
+        temp_crop_img, temp_box = self.get_minarea_rect(img, points)
+
+        # compute the IoU between the minimum-area rectangle and the polygon
+        def get_union(pD, pG):
+            return Polygon(pD).union(Polygon(pG)).area
+
+        def get_intersection_over_union(pD, pG):
+            return get_intersection(pD, pG) / (get_union(pD, pG) + 1e-10)
+
+        def get_intersection(pD, pG):
+            return Polygon(pD).intersection(Polygon(pG)).area
+
+        cal_IoU = get_intersection_over_union(points, temp_box)
+
+        if cal_IoU >= 0.7:
+            points = self.sample_points_on_bbox_bp(points, 31)
+            return temp_crop_img
+
+        points_sample = self.sample_points_on_bbox(points)
+        points_sample = points_sample.astype(np.int32)
+        head_edge, tail_edge, top_line, bot_line = self.reorder_poly_edge(points_sample)
+
+        resample_top_line = self.sample_points_on_bbox_bp(top_line, 15)
+        resample_bot_line = self.sample_points_on_bbox_bp(bot_line, 15)
+
+        sideline_mean_shift = np.mean(resample_top_line, axis=0) - np.mean(
+            resample_bot_line, axis=0
+        )
+        if sideline_mean_shift[1] > 0:
+            resample_bot_line, resample_top_line = resample_top_line, resample_bot_line
+        rectifier = AutoRectifier()
+        new_points = np.concatenate([resample_top_line, resample_bot_line])
+        new_points_list = list(new_points.astype(np.float32).reshape(1, -1).tolist())
+
+        if len(img.shape) == 2:
+            img = np.stack((img,) * 3, axis=-1)
+        img_crop, image = rectifier.run(img, new_points_list, mode="homography")
+        return np.array(img_crop[0], dtype=np.uint8)
+
+
+class CropByBoxes(BaseComponent):
+    """Crop Image by Box"""
+
+    entities = "CropByBoxes"
+
+    def __init__(self):
+        super().__init__()
+
+    def __call__(self, img, boxes):
+        """__call__"""
+        output_list = []
+        for bbox_info in boxes:
+            label_id = bbox_info["cls_id"]
+            box = bbox_info["coordinate"]
+            label = bbox_info.get("label", label_id)
+            xmin, ymin, xmax, ymax = [int(i) for i in box]
+            img_crop = img[ymin:ymax, xmin:xmax].copy()
+            output_list.append({"img": img_crop, "box": box, "label": label})
+        return output_list
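
A usage sketch of the two crop components (illustrative, not part of the commit); the image path and boxes are made up, and the inputs mirror what the detection models in this commit produce:

import cv2
from paddlex.inference.pipelines_new.components import CropByPolys, CropByBoxes

img = cv2.imread("./test_imgs/general_ocr_002.png")  # illustrative path

# Quadrilateral text boxes (4 x (x, y) points each), e.g. from text detection.
dt_polys = [[[60, 40], [220, 40], [220, 90], [60, 90]]]
crops = CropByPolys(det_box_type="quad")(img, dt_polys)
print(crops[0]["img_size"])  # [width, height] of the rectified crop

# Axis-aligned layout boxes, e.g. from layout detection.
boxes = [{"cls_id": 0, "label": "table", "coordinate": [50, 30, 400, 300]}]
regions = CropByBoxes()(img, boxes)
print(regions[0]["label"])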

+ 939 - 0
paddlex/inference/pipelines_new/components/common/seal_det_warp.py

@@ -0,0 +1,939 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os, sys
+import numpy as np
+from numpy import cos, sin, arctan, sqrt
+import cv2
+import copy
+import time
+
+from .....utils import logging
+
+def Homography(
+    image,
+    img_points,
+    world_width,
+    world_height,
+    interpolation=cv2.INTER_CUBIC,
+    ratio_width=1.0,
+    ratio_height=1.0,
+):
+    _points = np.array(img_points).reshape(-1, 2).astype(np.float32)
+
+    expand_x = int(0.5 * world_width * (ratio_width - 1))
+    expand_y = int(0.5 * world_height * (ratio_height - 1))
+
+    pt_lefttop = [expand_x, expand_y]
+    pt_righttop = [expand_x + world_width, expand_y]
+    pt_leftbottom = [expand_x + world_width, expand_y + world_height]
+    pt_rightbottom = [expand_x, expand_y + world_height]
+
+    pts_std = np.float32([pt_lefttop, pt_righttop, pt_leftbottom, pt_rightbottom])
+
+    img_crop_width = int(world_width * ratio_width)
+    img_crop_height = int(world_height * ratio_height)
+
+    M = cv2.getPerspectiveTransform(_points, pts_std)
+
+    dst_img = cv2.warpPerspective(
+        image,
+        M,
+        (img_crop_width, img_crop_height),
+        borderMode=cv2.BORDER_CONSTANT,  # BORDER_CONSTANT BORDER_REPLICATE
+        flags=interpolation,
+    )
+
+    return dst_img
+
+
+class PlanB:
+    def __call__(
+        self,
+        image,
+        points,
+        curveTextRectifier,
+        interpolation=cv2.INTER_LINEAR,
+        ratio_width=1.0,
+        ratio_height=1.0,
+        loss_thresh=5.0,
+        square=False,
+    ):
+        """
+        Plan B using sub-image when it failed in original image
+        :param image:
+        :param points:
+        :param curveTextRectifier: CurveTextRectifier
+        :param interpolation: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        :param ratio_width:  roi_image width expansion. It should not be smaller than 1.0
+        :param ratio_height: roi_image height expansion. It should not be smaller than 1.0
+        :param loss_thresh: if loss greater than loss_thresh --> get_rotate_crop_image
+        :param square: crop square image or not. True or False. The default is False
+        :return:
+        """
+        h, w = image.shape[:2]
+        _points = np.array(points).reshape(-1, 2).astype(np.float32)
+        x_min = int(np.min(_points[:, 0]))
+        y_min = int(np.min(_points[:, 1]))
+        x_max = int(np.max(_points[:, 0]))
+        y_max = int(np.max(_points[:, 1]))
+        dx = x_max - x_min
+        dy = y_max - y_min
+        max_d = max(dx, dy)
+        mean_pt = np.mean(_points, 0)
+
+        expand_x = (ratio_width - 1.0) * 0.5 * max_d
+        expand_y = (ratio_height - 1.0) * 0.5 * max_d
+
+        if square:
+            x_min = np.clip(int(mean_pt[0] - max_d - expand_x), 0, w - 1)
+            y_min = np.clip(int(mean_pt[1] - max_d - expand_y), 0, h - 1)
+            x_max = np.clip(int(mean_pt[0] + max_d + expand_x), 0, w - 1)
+            y_max = np.clip(int(mean_pt[1] + max_d + expand_y), 0, h - 1)
+        else:
+            x_min = np.clip(int(x_min - expand_x), 0, w - 1)
+            y_min = np.clip(int(y_min - expand_y), 0, h - 1)
+            x_max = np.clip(int(x_max + expand_x), 0, w - 1)
+            y_max = np.clip(int(y_max + expand_y), 0, h - 1)
+
+        new_image = image[y_min:y_max, x_min:x_max, :].copy()
+        new_points = _points.copy()
+        new_points[:, 0] -= x_min
+        new_points[:, 1] -= y_min
+
+        dst_img, loss = curveTextRectifier(
+            new_image,
+            new_points,
+            interpolation,
+            ratio_width,
+            ratio_height,
+            mode="calibration",
+        )
+
+        return dst_img, loss
+
+
+class CurveTextRectifier:
+    """
+    spatial transformer via monocular vision
+    """
+
+    def __init__(self):
+        self.get_virtual_camera_parameter()
+
+    def get_virtual_camera_parameter(self):
+        vcam_thz = 0
+        vcam_thx1 = 180
+        vcam_thy = 180
+        vcam_thx2 = 0
+
+        vcam_x = 0
+        vcam_y = 0
+        vcam_z = 100
+
+        radian = np.pi / 180
+
+        angle_z = radian * vcam_thz
+        angle_x1 = radian * vcam_thx1
+        angle_y = radian * vcam_thy
+        angle_x2 = radian * vcam_thx2
+
+        optic_x = vcam_x
+        optic_y = vcam_y
+        optic_z = vcam_z
+
+        fu = 100
+        fv = 100
+
+        matT = np.zeros((4, 4))
+        matT[0, 0] = cos(angle_z) * cos(angle_y) - sin(angle_z) * sin(angle_x1) * sin(
+            angle_y
+        )
+        matT[0, 1] = cos(angle_z) * sin(angle_y) * sin(angle_x2) - sin(angle_z) * (
+            cos(angle_x1) * cos(angle_x2) - sin(angle_x1) * cos(angle_y) * sin(angle_x2)
+        )
+        matT[0, 2] = cos(angle_z) * sin(angle_y) * cos(angle_x2) + sin(angle_z) * (
+            cos(angle_x1) * sin(angle_x2) + sin(angle_x1) * cos(angle_y) * cos(angle_x2)
+        )
+        matT[0, 3] = optic_x
+        matT[1, 0] = sin(angle_z) * cos(angle_y) + cos(angle_z) * sin(angle_x1) * sin(
+            angle_y
+        )
+        matT[1, 1] = sin(angle_z) * sin(angle_y) * sin(angle_x2) + cos(angle_z) * (
+            cos(angle_x1) * cos(angle_x2) - sin(angle_x1) * cos(angle_y) * sin(angle_x2)
+        )
+        matT[1, 2] = sin(angle_z) * sin(angle_y) * cos(angle_x2) - cos(angle_z) * (
+            cos(angle_x1) * sin(angle_x2) + sin(angle_x1) * cos(angle_y) * cos(angle_x2)
+        )
+        matT[1, 3] = optic_y
+        matT[2, 0] = -cos(angle_x1) * sin(angle_y)
+        matT[2, 1] = cos(angle_x1) * cos(angle_y) * sin(angle_x2) + sin(angle_x1) * cos(
+            angle_x2
+        )
+        matT[2, 2] = cos(angle_x1) * cos(angle_y) * cos(angle_x2) - sin(angle_x1) * sin(
+            angle_x2
+        )
+        matT[2, 3] = optic_z
+        matT[3, 0] = 0
+        matT[3, 1] = 0
+        matT[3, 2] = 0
+        matT[3, 3] = 1
+
+        matS = np.zeros((4, 4))
+        matS[2, 3] = 0.5
+        matS[3, 2] = 0.5
+
+        self.ifu = 1 / fu
+        self.ifv = 1 / fv
+
+        self.matT = matT
+        self.matS = matS
+        self.K = np.dot(matT.T, matS)
+        self.K = np.dot(self.K, matT)
+
+    def vertical_text_process(self, points, org_size):
+        """
+        change sequence amd process
+        :param points:
+        :param org_size:
+        :return:
+        """
+        org_w, org_h = org_size
+        _points = np.array(points).reshape(-1).tolist()
+        _points = np.array(_points[2:] + _points[:2]).reshape(-1, 2)
+
+        # convert to horizontal points
+        adjusted_points = np.zeros(_points.shape, dtype=np.float32)
+        adjusted_points[:, 0] = _points[:, 1]
+        adjusted_points[:, 1] = org_h - _points[:, 0] - 1
+
+        _image_coord, _world_coord, _new_image_size = self.horizontal_text_process(
+            adjusted_points
+        )
+
+        # # convert to vertical points back
+        image_coord = _points.reshape(1, -1, 2)
+        world_coord = np.zeros(_world_coord.shape, dtype=np.float32)
+        world_coord[:, :, 0] = 0 - _world_coord[:, :, 1]
+        world_coord[:, :, 1] = _world_coord[:, :, 0]
+        world_coord[:, :, 2] = _world_coord[:, :, 2]
+        new_image_size = (_new_image_size[1], _new_image_size[0])
+
+        return image_coord, world_coord, new_image_size
+
+    def horizontal_text_process(self, points):
+        """
+        get image coordinate and world coordinate
+        :param points:
+        :return:
+        """
+        poly = np.array(points).reshape(-1)
+
+        dx_list = []
+        dy_list = []
+        for i in range(1, len(poly) // 2):
+            xdx = poly[i * 2] - poly[(i - 1) * 2]
+            xdy = poly[i * 2 + 1] - poly[(i - 1) * 2 + 1]
+            d = sqrt(xdx**2 + xdy**2)
+            dx_list.append(d)
+
+        for i in range(0, len(poly) // 4):
+            ydx = poly[i * 2] - poly[len(poly) - 1 - (i * 2 + 1)]
+            ydy = poly[i * 2 + 1] - poly[len(poly) - 1 - (i * 2)]
+            d = sqrt(ydx**2 + ydy**2)
+            dy_list.append(d)
+
+        dx_list = [
+            (dx_list[i] + dx_list[len(dx_list) - 1 - i]) / 2
+            for i in range(len(dx_list) // 2)
+        ]
+
+        height = np.around(np.mean(dy_list))
+
+        rect_coord = [0, 0]
+        for i in range(0, len(poly) // 4 - 1):
+            x = rect_coord[-2]
+            x += dx_list[i]
+            y = 0
+            rect_coord.append(x)
+            rect_coord.append(y)
+
+        rect_coord_half = copy.deepcopy(rect_coord)
+        for i in range(0, len(poly) // 4):
+            x = rect_coord_half[len(rect_coord_half) - 2 * i - 2]
+            y = height
+            rect_coord.append(x)
+            rect_coord.append(y)
+
+        np_rect_coord = np.array(rect_coord).reshape(-1, 2)
+        x_min = np.min(np_rect_coord[:, 0])
+        y_min = np.min(np_rect_coord[:, 1])
+        x_max = np.max(np_rect_coord[:, 0])
+        y_max = np.max(np_rect_coord[:, 1])
+        new_image_size = (int(x_max - x_min + 0.5), int(y_max - y_min + 0.5))
+        x_mean = (x_max - x_min) / 2
+        y_mean = (y_max - y_min) / 2
+        np_rect_coord[:, 0] -= x_mean
+        np_rect_coord[:, 1] -= y_mean
+        rect_coord = np_rect_coord.reshape(-1).tolist()
+
+        rect_coord = np.array(rect_coord).reshape(-1, 2)
+        world_coord = np.zeros((len(rect_coord), 3))
+
+        world_coord[:, :2] = rect_coord
+
+        image_coord = np.array(poly).reshape(1, -1, 2)
+        world_coord = world_coord.reshape(1, -1, 3)
+
+        return image_coord, world_coord, new_image_size
+
+    def horizontal_text_estimate(self, points):
+        """
+        horizontal or vertical text
+        :param points:
+        :return:
+        """
+        pts = np.array(points).reshape(-1, 2)
+        x_min = int(np.min(pts[:, 0]))
+        y_min = int(np.min(pts[:, 1]))
+        x_max = int(np.max(pts[:, 0]))
+        y_max = int(np.max(pts[:, 1]))
+        x = x_max - x_min
+        y = y_max - y_min
+        is_horizontal_text = True
+        if y / x > 1.5:  # vertical text condition
+            is_horizontal_text = False
+        return is_horizontal_text
+
+    def virtual_camera_to_world(self, size):
+        ifu, ifv = self.ifu, self.ifv
+        K, matT = self.K, self.matT
+
+        ppu = size[0] / 2 + 1e-6
+        ppv = size[1] / 2 + 1e-6
+
+        P = np.zeros((size[1], size[0], 3))
+
+        lu = np.array([i for i in range(size[0])])
+        lv = np.array([i for i in range(size[1])])
+        u, v = np.meshgrid(lu, lv)
+
+        yp = (v - ppv) * ifv
+        xp = (u - ppu) * ifu
+        angle_a = arctan(sqrt(xp * xp + yp * yp))
+        angle_b = arctan(yp / xp)
+
+        D0 = sin(angle_a) * cos(angle_b)
+        D1 = sin(angle_a) * sin(angle_b)
+        D2 = cos(angle_a)
+
+        D0[xp <= 0] = -D0[xp <= 0]
+        D1[xp <= 0] = -D1[xp <= 0]
+
+        ratio_a = (
+            K[0, 0] * D0 * D0
+            + K[1, 1] * D1 * D1
+            + K[2, 2] * D2 * D2
+            + (K[0, 1] + K[1, 0]) * D0 * D1
+            + (K[0, 2] + K[2, 0]) * D0 * D2
+            + (K[1, 2] + K[2, 1]) * D1 * D2
+        )
+        ratio_b = (
+            (K[0, 3] + K[3, 0]) * D0
+            + (K[1, 3] + K[3, 1]) * D1
+            + (K[2, 3] + K[3, 2]) * D2
+        )
+        ratio_c = K[3, 3] * np.ones(ratio_b.shape)
+
+        delta = ratio_b * ratio_b - 4 * ratio_a * ratio_c
+        t = np.zeros(delta.shape)
+        t[ratio_a == 0] = -ratio_c[ratio_a == 0] / ratio_b[ratio_a == 0]
+        t[ratio_a != 0] = (-ratio_b[ratio_a != 0] + sqrt(delta[ratio_a != 0])) / (
+            2 * ratio_a[ratio_a != 0]
+        )
+        t[delta < 0] = 0
+
+        P[:, :, 0] = matT[0, 3] + t * (
+            matT[0, 0] * D0 + matT[0, 1] * D1 + matT[0, 2] * D2
+        )
+        P[:, :, 1] = matT[1, 3] + t * (
+            matT[1, 0] * D0 + matT[1, 1] * D1 + matT[1, 2] * D2
+        )
+        P[:, :, 2] = matT[2, 3] + t * (
+            matT[2, 0] * D0 + matT[2, 1] * D1 + matT[2, 2] * D2
+        )
+
+        return P
+
+    def world_to_image(self, image_size, world, intrinsic, distCoeffs, rotation, tvec):
+        r11 = rotation[0, 0]
+        r12 = rotation[0, 1]
+        r13 = rotation[0, 2]
+        r21 = rotation[1, 0]
+        r22 = rotation[1, 1]
+        r23 = rotation[1, 2]
+        r31 = rotation[2, 0]
+        r32 = rotation[2, 1]
+        r33 = rotation[2, 2]
+
+        t1 = tvec[0]
+        t2 = tvec[1]
+        t3 = tvec[2]
+
+        k1 = distCoeffs[0]
+        k2 = distCoeffs[1]
+        p1 = distCoeffs[2]
+        p2 = distCoeffs[3]
+        k3 = distCoeffs[4]
+        k4 = distCoeffs[5]
+        k5 = distCoeffs[6]
+        k6 = distCoeffs[7]
+
+        if len(distCoeffs) > 8:
+            s1 = distCoeffs[8]
+            s2 = distCoeffs[9]
+            s3 = distCoeffs[10]
+            s4 = distCoeffs[11]
+        else:
+            s1 = s2 = s3 = s4 = 0
+
+        if len(distCoeffs) > 12:
+            tx = distCoeffs[12]
+            ty = distCoeffs[13]
+        else:
+            tx = ty = 0
+
+        fu = intrinsic[0, 0]
+        fv = intrinsic[1, 1]
+        ppu = intrinsic[0, 2]
+        ppv = intrinsic[1, 2]
+
+        cos_tx = cos(tx)
+        cos_ty = cos(ty)
+        sin_tx = sin(tx)
+        sin_ty = sin(ty)
+
+        tao11 = cos_ty * cos_tx * cos_ty + sin_ty * cos_tx * sin_ty
+        tao12 = cos_ty * cos_tx * sin_ty * sin_tx - sin_ty * cos_tx * cos_ty * sin_tx
+        tao13 = -cos_ty * cos_tx * sin_ty * cos_tx + sin_ty * cos_tx * cos_ty * cos_tx
+        tao21 = -sin_tx * sin_ty
+        tao22 = cos_ty * cos_tx * cos_tx + sin_tx * cos_ty * sin_tx
+        tao23 = cos_ty * cos_tx * sin_tx - sin_tx * cos_ty * cos_tx
+
+        P = np.zeros((image_size[1], image_size[0], 2))
+
+        c3 = r31 * world[:, :, 0] + r32 * world[:, :, 1] + r33 * world[:, :, 2] + t3
+        c1 = r11 * world[:, :, 0] + r12 * world[:, :, 1] + r13 * world[:, :, 2] + t1
+        c2 = r21 * world[:, :, 0] + r22 * world[:, :, 1] + r23 * world[:, :, 2] + t2
+
+        x1 = c1 / c3
+        y1 = c2 / c3
+        x12 = x1 * x1
+        y12 = y1 * y1
+        x1y1 = 2 * x1 * y1
+        r2 = x12 + y12
+        r4 = r2 * r2
+        r6 = r2 * r4
+
+        radial_distortion = (1 + k1 * r2 + k2 * r4 + k3 * r6) / (
+            1 + k4 * r2 + k5 * r4 + k6 * r6
+        )
+        x2 = (
+            x1 * radial_distortion + p1 * x1y1 + p2 * (r2 + 2 * x12) + s1 * r2 + s2 * r4
+        )
+        y2 = (
+            y1 * radial_distortion + p2 * x1y1 + p1 * (r2 + 2 * y12) + s3 * r2 + s4 * r4
+        )
+
+        x3 = tao11 * x2 + tao12 * y2 + tao13
+        y3 = tao21 * x2 + tao22 * y2 + tao23
+
+        P[:, :, 0] = fu * x3 + ppu
+        P[:, :, 1] = fv * y3 + ppv
+        P[c3 <= 0] = 0
+
+        return P
+
+    def spatial_transform(
+        self, image_data, new_image_size, mtx, dist, rvecs, tvecs, interpolation
+    ):
+        rotation, _ = cv2.Rodrigues(rvecs)
+        world_map = self.virtual_camera_to_world(new_image_size)
+        image_map = self.world_to_image(
+            new_image_size, world_map, mtx, dist, rotation, tvecs
+        )
+        image_map = image_map.astype(np.float32)
+        dst = cv2.remap(
+            image_data, image_map[:, :, 0], image_map[:, :, 1], interpolation
+        )
+        return dst
+
+    def calibrate(self, org_size, image_coord, world_coord):
+        """
+        calibration
+        :param org_size:
+        :param image_coord:
+        :param world_coord:
+        :return:
+        """
+        # flag = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_TILTED_MODEL  | cv2.CALIB_THIN_PRISM_MODEL
+        flag = cv2.CALIB_RATIONAL_MODEL
+        flag2 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_TILTED_MODEL
+        flag3 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_THIN_PRISM_MODEL
+        flag4 = (
+            cv2.CALIB_RATIONAL_MODEL
+            | cv2.CALIB_ZERO_TANGENT_DIST
+            | cv2.CALIB_FIX_ASPECT_RATIO
+        )
+        flag5 = (
+            cv2.CALIB_RATIONAL_MODEL
+            | cv2.CALIB_TILTED_MODEL
+            | cv2.CALIB_ZERO_TANGENT_DIST
+        )
+        flag6 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_FIX_ASPECT_RATIO
+        flag_list = [flag2, flag3, flag4, flag5, flag6]
+
+        ret, mtx, dist, rvecs, tvecs = cv2.calibrateCamera(
+            world_coord.astype(np.float32),
+            image_coord.astype(np.float32),
+            org_size,
+            None,
+            None,
+            flags=flag,
+        )
+        if ret > 2:
+            # strategies
+            min_ret = ret
+            for i, flag in enumerate(flag_list):
+                _ret, _mtx, _dist, _rvecs, _tvecs = cv2.calibrateCamera(
+                    world_coord.astype(np.float32),
+                    image_coord.astype(np.float32),
+                    org_size,
+                    None,
+                    None,
+                    flags=flag,
+                )
+                if _ret < min_ret:
+                    min_ret = _ret
+                    ret, mtx, dist, rvecs, tvecs = _ret, _mtx, _dist, _rvecs, _tvecs
+
+        return ret, mtx, dist, rvecs, tvecs
+
+    def dc_homo(
+        self,
+        img,
+        img_points,
+        obj_points,
+        is_horizontal_text,
+        interpolation=cv2.INTER_LINEAR,
+        ratio_width=1.0,
+        ratio_height=1.0,
+    ):
+        """
+        divide and conquer: homography
+        # ratio_width and ratio_height must be 1.0 here
+        """
+        _img_points = img_points.reshape(-1, 2)
+        _obj_points = obj_points.reshape(-1, 3)
+
+        homo_img_list = []
+        width_list = []
+        height_list = []
+        # divide and conquer
+        for i in range(len(_img_points) // 2 - 1):
+            new_img_points = np.zeros((4, 2)).astype(np.float32)
+            new_obj_points = np.zeros((4, 2)).astype(np.float32)
+
+            new_img_points[0:2, :] = _img_points[i : (i + 2), :2]
+            new_img_points[2:4, :] = _img_points[::-1, :][i : (i + 2), :2][::-1, :]
+
+            new_obj_points[0:2, :] = _obj_points[i : (i + 2), :2]
+            new_obj_points[2:4, :] = _obj_points[::-1, :][i : (i + 2), :2][::-1, :]
+
+            if is_horizontal_text:
+                world_width = np.abs(new_obj_points[1, 0] - new_obj_points[0, 0])
+                world_height = np.abs(new_obj_points[3, 1] - new_obj_points[0, 1])
+            else:
+                world_width = np.abs(new_obj_points[1, 1] - new_obj_points[0, 1])
+                world_height = np.abs(new_obj_points[3, 0] - new_obj_points[0, 0])
+
+            homo_img = Homography(
+                img,
+                new_img_points,
+                world_width,
+                world_height,
+                interpolation=interpolation,
+                ratio_width=ratio_width,
+                ratio_height=ratio_height,
+            )
+
+            homo_img_list.append(homo_img)
+            _h, _w = homo_img.shape[:2]
+            width_list.append(_w)
+            height_list.append(_h)
+
+        # stitching
+        rectified_image = np.zeros((np.max(height_list), sum(width_list), 3)).astype(
+            np.uint8
+        )
+
+        st = 0
+        for homo_img, w, h in zip(homo_img_list, width_list, height_list):
+            rectified_image[:h, st : st + w, :] = homo_img
+            st += w
+
+        if not is_horizontal_text:
+            # vertical rotation
+            rectified_image = np.rot90(rectified_image, 3)
+
+        return rectified_image
+
+    def Homography(
+        self,
+        image,
+        img_points,
+        world_width,
+        world_height,
+        interpolation=cv2.INTER_CUBIC,
+        ratio_width=1.0,
+        ratio_height=1.0,
+    ):
+        _points = np.array(img_points).reshape(-1, 2).astype(np.float32)
+
+        expand_x = int(0.5 * world_width * (ratio_width - 1))
+        expand_y = int(0.5 * world_height * (ratio_height - 1))
+
+        pt_lefttop = [expand_x, expand_y]
+        pt_righttop = [expand_x + world_width, expand_y]
+        pt_leftbottom = [expand_x + world_width, expand_y + world_height]
+        pt_rightbottom = [expand_x, expand_y + world_height]
+
+        pts_std = np.float32([pt_lefttop, pt_righttop, pt_leftbottom, pt_rightbottom])
+
+        img_crop_width = int(world_width * ratio_width)
+        img_crop_height = int(world_height * ratio_height)
+
+        M = cv2.getPerspectiveTransform(_points, pts_std)
+
+        dst_img = cv2.warpPerspective(
+            image,
+            M,
+            (img_crop_width, img_crop_height),
+            borderMode=cv2.BORDER_CONSTANT,  # BORDER_CONSTANT BORDER_REPLICATE
+            flags=interpolation,
+        )
+
+        return dst_img
+
+    def __call__(
+        self,
+        image_data,
+        points,
+        interpolation=cv2.INTER_LINEAR,
+        ratio_width=1.0,
+        ratio_height=1.0,
+        mode="calibration",
+    ):
+        """
+        spatial transform for a poly text
+        :param image_data:
+        :param points: [x1,y1,x2,y2,x3,y3,...], clockwise order, (x1,y1) must be the top-left of first char.
+        :param interpolation: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        :param ratio_width:  roi_image width expansion. It should not be smaller than 1.0
+        :param ratio_height: roi_image height expansion. It should not be smaller than 1.0
+        :param mode: 'calibration' or 'homography'. when homography, ratio_width and ratio_height must be 1.0
+        :return:
+        """
+        org_h, org_w = image_data.shape[:2]
+        org_size = (org_w, org_h)
+        self.image = image_data
+
+        is_horizontal_text = self.horizontal_text_estimate(points)
+        if is_horizontal_text:
+            image_coord, world_coord, new_image_size = self.horizontal_text_process(
+                points
+            )
+        else:
+            image_coord, world_coord, new_image_size = self.vertical_text_process(
+                points, org_size
+            )
+
+        if mode.lower() == "calibration":
+            ret, mtx, dist, rvecs, tvecs = self.calibrate(
+                org_size, image_coord, world_coord
+            )
+
+            st_size = (
+                int(new_image_size[0] * ratio_width),
+                int(new_image_size[1] * ratio_height),
+            )
+            dst = self.spatial_transform(
+                image_data, st_size, mtx, dist[0], rvecs[0], tvecs[0], interpolation
+            )
+        elif mode.lower() == "homography":
+            # ratio_width and ratio_height must be 1.0 here; ret is set to 0.01 manually since no calibration loss is computed
+            ret = 0.01
+            dst = self.dc_homo(
+                image_data,
+                image_coord,
+                world_coord,
+                is_horizontal_text,
+                interpolation=interpolation,
+                ratio_width=1.0,
+                ratio_height=1.0,
+            )
+        else:
+            raise ValueError(
+                'mode must be ["calibration", "homography"], but got {}'.format(mode)
+            )
+
+        return dst, ret
+
+
+class AutoRectifier:
+    def __init__(self):
+        self.npoints = 10
+        self.curveTextRectifier = CurveTextRectifier()
+
+    @staticmethod
+    def get_rotate_crop_image(
+        img, points, interpolation=cv2.INTER_CUBIC, ratio_width=1.0, ratio_height=1.0
+    ):
+        """
+        crop or homography
+        :param img:
+        :param points:
+        :param interpolation:
+        :param ratio_width:
+        :param ratio_height:
+        :return:
+        """
+        h, w = img.shape[:2]
+        _points = np.array(points).reshape(-1, 2).astype(np.float32)
+
+        if len(_points) != 4:
+            x_min = int(np.min(_points[:, 0]))
+            y_min = int(np.min(_points[:, 1]))
+            x_max = int(np.max(_points[:, 0]))
+            y_max = int(np.max(_points[:, 1]))
+            dx = x_max - x_min
+            dy = y_max - y_min
+            expand_x = int(0.5 * dx * (ratio_width - 1))
+            expand_y = int(0.5 * dy * (ratio_height - 1))
+            x_min = np.clip(int(x_min - expand_x), 0, w - 1)
+            y_min = np.clip(int(y_min - expand_y), 0, h - 1)
+            x_max = np.clip(int(x_max + expand_x), 0, w - 1)
+            y_max = np.clip(int(y_max + expand_y), 0, h - 1)
+
+            dst_img = img[y_min:y_max, x_min:x_max, :].copy()
+        else:
+            img_crop_width = int(
+                max(
+                    np.linalg.norm(_points[0] - _points[1]),
+                    np.linalg.norm(_points[2] - _points[3]),
+                )
+            )
+            img_crop_height = int(
+                max(
+                    np.linalg.norm(_points[0] - _points[3]),
+                    np.linalg.norm(_points[1] - _points[2]),
+                )
+            )
+
+            dst_img = Homography(
+                img,
+                _points,
+                img_crop_width,
+                img_crop_height,
+                interpolation,
+                ratio_width,
+                ratio_height,
+            )
+
+        return dst_img
+
+    def visualize(self, image_data, points_list):
+        visualization = image_data.copy()
+
+        for box in points_list:
+            box = np.array(box).reshape(-1, 2).astype(np.int32)
+            cv2.drawContours(
+                visualization, [np.array(box).reshape((-1, 1, 2))], -1, (0, 0, 255), 2
+            )
+            for i, p in enumerate(box):
+                if i != 0:
+                    cv2.circle(
+                        visualization,
+                        tuple(p),
+                        radius=1,
+                        color=(255, 0, 0),
+                        thickness=2,
+                    )
+                else:
+                    cv2.circle(
+                        visualization,
+                        tuple(p),
+                        radius=1,
+                        color=(255, 255, 0),
+                        thickness=2,
+                    )
+        return visualization
+
+    def __call__(
+        self,
+        image_data,
+        points,
+        interpolation=cv2.INTER_LINEAR,
+        ratio_width=1.0,
+        ratio_height=1.0,
+        loss_thresh=5.0,
+        mode="calibration",
+    ):
+        """
+        rectify a polygon text region, choosing among several strategies
+        :param image_data:
+        :param points: [x1,y1,x2,y2,x3,y3,...], clockwise order, (x1,y1) must be the top-left of first char.
+        :param interpolation: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        :param ratio_width:  roi_image width expansion. It should not be smaller than 1.0
+        :param ratio_height: roi_image height expansion. It should not be smaller than 1.0
+        :param loss_thresh: if loss greater than loss_thresh --> get_rotate_crop_image
+        :param mode: 'calibration' or 'homography'. when homography, ratio_width and ratio_height must be 1.0
+        :return:
+        """
+        _points = np.array(points).reshape(-1, 2)
+        if len(_points) >= self.npoints and len(_points) % 2 == 0:
+            try:
+                curveTextRectifier = CurveTextRectifier()
+
+                dst_img, loss = curveTextRectifier(
+                    image_data, points, interpolation, ratio_width, ratio_height, mode
+                )
+                if loss >= 2:
+                    # for robustness
+                    # a large loss means the region cannot be reconstructed correctly, so try other strategies
+                    img_list, loss_list = [dst_img], [loss]
+                    _dst_img, _loss = PlanB()(
+                        image_data,
+                        points,
+                        curveTextRectifier,
+                        interpolation,
+                        ratio_width,
+                        ratio_height,
+                        loss_thresh=loss_thresh,
+                        square=True,
+                    )
+                    img_list += [_dst_img]
+                    loss_list += [_loss]
+
+                    _dst_img, _loss = PlanB()(
+                        image_data,
+                        points,
+                        curveTextRectifier,
+                        interpolation,
+                        ratio_width,
+                        ratio_height,
+                        loss_thresh=loss_thresh,
+                        square=False,
+                    )
+                    img_list += [_dst_img]
+                    loss_list += [_loss]
+
+                    min_loss = min(loss_list)
+                    dst_img = img_list[loss_list.index(min_loss)]
+
+                    if min_loss >= loss_thresh:
+                        logging.warning(
+                            "calibration loss: {} is too large for the spatial transformer, so it failed. Falling back to get_rotate_crop_image.".format(
+                                min_loss
+                            )
+                        )
+                        dst_img = self.get_rotate_crop_image(
+                            image_data, points, interpolation, ratio_width, ratio_height
+                        )
+            except Exception as e:
+                logging.warning(f"Exception caught: {e}")
+                dst_img = self.get_rotate_crop_image(
+                    image_data, points, interpolation, ratio_width, ratio_height
+                )
+        else:
+            dst_img = self.get_rotate_crop_image(
+                image_data, _points, interpolation, ratio_width, ratio_height
+            )
+
+        return dst_img
+
+    def run(
+        self,
+        image_data,
+        points_list,
+        interpolation=cv2.INTER_LINEAR,
+        ratio_width=1.0,
+        ratio_height=1.0,
+        loss_thresh=5.0,
+        mode="calibration",
+    ):
+        """
+        run for texts in an image
+        :param image_data: numpy.ndarray. The shape is [h, w, 3]
+        :param points_list: [[x1,y1,x2,y2,x3,y3,...], [x1,y1,x2,y2,x3,y3,...], ...], clockwise order, (x1,y1) must be the top-left of first char.
+        :param interpolation: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        :param ratio_width:  roi_image width expansion. It should not be smaller than 1.0
+        :param ratio_height: roi_image height expansion. It should not be smaller than 1.0
+        :param loss_thresh: if loss greater than loss_thresh --> get_rotate_crop_image
+        :param mode: 'calibration' or 'homography'. when homography, ratio_width and ratio_height must be 1.0
+        :return: res: roi-image list, visualized_image: draw polys in original image
+        """
+        if image_data is None:
+            raise ValueError("image_data cannot be None.")
+        if not isinstance(points_list, list):
+            raise ValueError("points_list must be a list.")
+        for points in points_list:
+            if not isinstance(points, list):
+                raise ValueError("each element of points_list must be a list.")
+
+        if ratio_width < 1.0 or ratio_height < 1.0:
+            raise ValueError(
+                "ratio_width and ratio_height cannot be smaller than 1, but got {}".format(
+                    (ratio_width, ratio_height)
+                )
+            )
+
+        if mode.lower() != "calibration" and mode.lower() != "homography":
+            raise ValueError(
+                'mode must be ["calibration", "homography"], but got {}'.format(mode)
+            )
+
+        if mode.lower() == "homography" and (ratio_width != 1.0 or ratio_height != 1.0):
+            raise ValueError(
+                "ratio_width and ratio_height must be 1.0 when mode is homography, but got mode:{}, ratio:({},{})".format(
+                    mode, ratio_width, ratio_height
+                )
+            )
+
+        res = []
+        for points in points_list:
+            rectified_img = self(
+                image_data,
+                points,
+                interpolation,
+                ratio_width,
+                ratio_height,
+                loss_thresh=loss_thresh,
+                mode=mode,
+            )
+            res.append(rectified_img)
+
+        # visualize
+        visualized_image = self.visualize(image_data, points_list)
+
+        return res, visualized_image
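
A minimal usage sketch of the AutoRectifier added above; the image file name and the detected polygon are illustrative placeholders (at least 10 points are needed for the calibration path), and the class falls back to get_rotate_crop_image when calibration fails:

import cv2

from paddlex.inference.pipelines_new.components.common.seal_det_warp import AutoRectifier

# illustrative inputs: a BGR crop containing curved/seal text and one detected polygon
# given clockwise, starting at the top-left of the first character
image = cv2.imread("seal_text_crop.jpg")
poly = [10, 40, 60, 20, 120, 15, 180, 20, 230, 40,
        230, 90, 180, 70, 120, 65, 60, 70, 10, 90]

rectifier = AutoRectifier()
# returns the rectified ROI images and a visualization with the polygons drawn
roi_imgs, vis_img = rectifier.run(image, [poly], mode="calibration")
cv2.imwrite("rectified_0.jpg", roi_imgs[0])
cv2.imwrite("visualized.jpg", vis_img)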

+ 48 - 0
paddlex/inference/pipelines_new/components/common/sort_boxes.py

@@ -0,0 +1,48 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BaseComponent
+import numpy as np
+
+class SortQuadBoxes(BaseComponent):
+    """SortQuadBoxes Component"""
+
+    entities = "SortQuadBoxes"
+    def __init__(self):
+        super().__init__()
+
+    def __call__(self, dt_polys):
+        """
+        Sort quad boxes in order from top to bottom, left to right
+        args:
+            dt_polys(array): detected quad boxes with shape [N, 4, 2]
+        return:
+            sorted boxes (list of arrays), each with shape [4, 2]
+        """
+        dt_boxes = np.array(dt_polys)
+        num_boxes = dt_boxes.shape[0]
+        sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
+        _boxes = list(sorted_boxes)
+
+        for i in range(num_boxes - 1):
+            for j in range(i, -1, -1):
+                if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (
+                    _boxes[j + 1][0][0] < _boxes[j][0][0]
+                ):
+                    tmp = _boxes[j]
+                    _boxes[j] = _boxes[j + 1]
+                    _boxes[j + 1] = tmp
+                else:
+                    break
+        return _boxes
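
A self-contained sketch of how SortQuadBoxes orders detection results, assuming the component can be constructed with no arguments as its __init__ above suggests; the coordinates are made up:

import numpy as np

from paddlex.inference.pipelines_new.components.common.sort_boxes import SortQuadBoxes

# three quad boxes: the first two lie on the same text line (y difference < 10 px) but are
# listed right-to-left; the third box belongs to the next line
dt_polys = [
    [[200, 12], [300, 12], [300, 40], [200, 40]],
    [[20, 10], [120, 10], [120, 38], [20, 38]],
    [[20, 60], [120, 60], [120, 88], [20, 88]],
]
sorted_boxes = SortQuadBoxes()(np.array(dt_polys))
# expected reading order: left box of line 1, right box of line 1, then line 2
print([box[0].tolist() for box in sorted_boxes])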

+ 15 - 0
paddlex/inference/pipelines_new/components/prompt_engeering/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .generate_kie_prompt import GenerateKIEPrompt

+ 31 - 0
paddlex/inference/pipelines_new/components/prompt_engeering/base.py

@@ -0,0 +1,31 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from .....utils.subclass_register import AutoRegisterABCMetaClass
+
+import inspect
+
+class BaseGeneratePrompt(ABC, metaclass=AutoRegisterABCMetaClass):
+    """Base Generate Prompt"""
+
+    __is_base = True
+
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    def generate_prompt(self):
+        raise NotImplementedError(
+            "The method `generate_prompt` has not been implemented yet.")

+ 100 - 0
paddlex/inference/pipelines_new/components/prompt_engeering/generate_kie_prompt.py

@@ -0,0 +1,100 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base import BaseGeneratePrompt
+
+class GenerateKIEPrompt(BaseGeneratePrompt):
+    """Generate KIE Prompt"""
+
+    entities = [
+        "text_kie_prompt",
+        "table_kie_prompt"
+    ]
+
+    def __init__(self, config):
+        super().__init__()
+
+        task_type = config.get('task_type', "")
+        task_description = config.get('task_description', "")  
+        output_format = config.get('output_format', "")  
+        rules_str = config.get('rules_str', "")  
+        few_shot_demo_text_content = config.get('few_shot_demo_text_content', "")  
+        few_shot_demo_key_value_list = config.get('few_shot_demo_key_value_list', "")
+
+        if task_description is None:
+            task_description = ""
+        
+        if output_format is None:
+            output_format = ""
+        
+        if rules_str is None:
+            rules_str = ""
+        
+        if few_shot_demo_text_content is None:
+            few_shot_demo_text_content = ""
+        
+        if few_shot_demo_key_value_list is None:
+            few_shot_demo_key_value_list = ""
+
+        if task_type not in self.entities:
+            raise ValueError(f"task type must be in {self.entities} of GenerateKIEPrompt.")
+
+        self.task_type = task_type
+        self.task_description = task_description
+        self.output_format = output_format
+        self.rules_str = rules_str
+        self.few_shot_demo_text_content = few_shot_demo_text_content
+        self.few_shot_demo_key_value_list = few_shot_demo_key_value_list
+        
+    def generate_prompt(self, text_content,
+        key_list,
+        task_description=None,
+        output_format=None,
+        rules_str=None,
+        few_shot_demo_text_content=None,
+        few_shot_demo_key_value_list=None):
+        """
+        args:
+            text_content (str): OCR text or table content to extract information from.
+            key_list (list): key names to extract.
+            other args: optional overrides for the prompt sections configured in __init__.
+        return:
+            prompt (str): the assembled KIE prompt.
+        """
+
+        if task_description is None:
+            task_description = self.task_description
+
+        if output_format is None:
+            output_format = self.output_format
+
+        if rules_str is None:
+            rules_str = self.rules_str
+
+        if few_shot_demo_text_content is None:
+            few_shot_demo_text_content = self.few_shot_demo_text_content
+            
+        if few_shot_demo_key_value_list is None:
+            few_shot_demo_key_value_list = self.few_shot_demo_key_value_list
+
+        prompt = f"""{task_description}{output_format}{rules_str}{few_shot_demo_text_content}{few_shot_demo_key_value_list}"""
+        if self.task_type == "table_kie_prompt":
+            prompt += f"""\n结合上面,下面正式开始:\
+                表格内容:```{text_content}```\
+                关键词列表:{key_list}。""".replace(
+                "    ", "")
+        elif self.task_type == "text_kie_prompt":
+            prompt += f"""\n结合上面的例子,下面正式开始:\
+                OCR文字:```{text_content}```\
+                关键词列表:{key_list}。""".replace(
+                "    ", "")
+        else:
+            raise ValueError(f"{self.task_type} is currently not supported.") 
+        return prompt
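
A hedged sketch of configuring and calling GenerateKIEPrompt; the description, format and rule strings below are placeholders, not the defaults shipped in the pipeline YAML:

from paddlex.inference.pipelines_new.components.prompt_engeering import GenerateKIEPrompt

config = {
    "task_type": "text_kie_prompt",
    "task_description": "Extract the requested fields from the OCR text. ",
    "output_format": "Answer with a JSON object mapping each key to its value. ",
    "rules_str": "Use null when a key cannot be found. ",
    "few_shot_demo_text_content": "",
    "few_shot_demo_key_value_list": "",
}

prompt_generator = GenerateKIEPrompt(config)
prompt = prompt_generator.generate_prompt(
    text_content="Invoice No: 12345  Date: 2024-05-01  Total: 99.00",
    key_list=["Invoice No", "Date"],
)
print(prompt)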

+ 15 - 0
paddlex/inference/pipelines_new/components/retriever/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .ernie_bot_retriever import ErnieBotRetriever

+ 50 - 0
paddlex/inference/pipelines_new/components/retriever/base.py

@@ -0,0 +1,50 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from .....utils.subclass_register import AutoRegisterABCMetaClass
+
+import inspect
+import base64
+
+class BaseRetriever(ABC, metaclass=AutoRegisterABCMetaClass):
+    """Base Retriever"""
+
+    __is_base = True
+
+    VECTOR_STORE_PREFIX = "PADDLEX_VECTOR_STORE"
+
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    def generate_vector_database(self):
+        raise NotImplementedError(
+            "The method `generate_vector_database` has not been implemented yet.")
+
+    @abstractmethod
+    def similarity_retrieval(self):
+        raise NotImplementedError(
+            "The method `similarity_retrieval` has not been implemented yet.")
+
+    def is_vector_store(self, s):
+        return s.startswith(self.VECTOR_STORE_PREFIX)
+
+    def encode_vector_store(self, vector_store_bytes):
+        return self.VECTOR_STORE_PREFIX + base64.b64encode(vector_store_bytes).decode(
+            "ascii"
+        )
+
+    def decode_vector_store(self, vector_store_str):
+        return base64.b64decode(vector_store_str[len(self.VECTOR_STORE_PREFIX):])
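
The serialization helpers above can be exercised without any LLM call; a quick round trip, assuming `retriever` is an instance of any concrete BaseRetriever subclass and the byte payload stands in for FAISS.serialize_to_bytes():

raw_bytes = b"serialized-faiss-index"               # placeholder payload
encoded = retriever.encode_vector_store(raw_bytes)  # "PADDLEX_VECTOR_STORE" + base64 text
assert retriever.is_vector_store(encoded)
assert retriever.decode_vector_store(encoded) == raw_bytes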

+ 148 - 0
paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py

@@ -0,0 +1,148 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base import BaseRetriever
+import os
+
+from langchain.docstore.document import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from langchain_community.embeddings import QianfanEmbeddingsEndpoint
+from langchain_community.vectorstores import FAISS
+from langchain_community import vectorstores
+from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings
+
+import time
+
+class ErnieBotRetriever(BaseRetriever):
+    """Ernie Bot Retriever"""
+
+    entities = [
+        "ernie-4.0",
+        "ernie-3.5",
+        "ernie-3.5-8k",
+        "ernie-lite",
+        "ernie-tiny-8k",
+        "ernie-speed",
+        "ernie-speed-128k",
+        "ernie-char-8k",
+    ]
+    
+    def __init__(self, config):
+
+        super().__init__()
+
+        model_name = config.get('model_name', None)
+        api_type = config.get('api_type', None)
+        ak = config.get('ak', None)
+        sk = config.get('sk', None)
+        access_token = config.get('access_token', None)
+        
+        if model_name not in self.entities:
+            raise ValueError(f"model_name must be in {self.entities} of ErnieBotRetriever.")
+
+        if api_type not in ["aistudio", "qianfan"]:
+            raise ValueError("api_type must be one of ['aistudio', 'qianfan']")
+
+        if api_type == "aistudio" and access_token is None:
+            raise ValueError("access_token cannot be empty when api_type is aistudio.")
+            
+        if api_type == "qianfan" and (ak is None or sk is None):
+            raise ValueError("ak and sk cannot be empty when api_type is qianfan.")            
+
+        self.model_name = model_name
+        self.config = config
+        
+    def generate_vector_database(self, text_list, 
+        block_size=300,
+        separators=["\t", "\n", "。", "\n\n", ""],
+        sleep_time=0.5):
+        """
+        args:
+            text_list (list[str]): texts to build the vector database from.
+            block_size (int): chunk size used when splitting the joined text.
+            separators (list[str]): separators used by the text splitter.
+            sleep_time (float): delay between embedding requests to avoid rate limits.
+        return:
+            vectorstore: a FAISS vector store built from the split texts.
+        """
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=block_size, chunk_overlap=20, separators=separators
+        )
+        texts = text_splitter.split_text("\t".join(text_list))
+        all_splits = [Document(page_content=text) for text in texts]
+
+        api_type = self.config["api_type"]
+        if api_type == "qianfan":
+            os.environ["QIANFAN_AK"] = os.environ.get("EB_AK", self.config["ak"])
+            os.environ["QIANFAN_SK"] = os.environ.get("EB_SK", self.config["sk"])
+            user_ak = os.environ.get("EB_AK", self.config["ak"])
+            user_id = hash(user_ak)
+            vectorstore = FAISS.from_documents(
+                documents=all_splits, embedding=QianfanEmbeddingsEndpoint()
+            )
+        elif api_type == "aistudio":
+            token = self.config["access_token"]
+            vectorstore = FAISS.from_documents(
+                documents=all_splits[0:1],
+                embedding=ErnieEmbeddings(aistudio_access_token=token),
+            )
+            #### ErnieEmbeddings.chunk_size = 16
+            step = min(16, len(all_splits) - 1)
+            for shot_splits in [
+                all_splits[i : i + step] for i in range(1, len(all_splits), step)
+            ]:
+                time.sleep(sleep_time)
+                vectorstore_slice = FAISS.from_documents(
+                    documents=shot_splits,
+                    embedding=ErnieEmbeddings(aistudio_access_token=token),
+                )
+                vectorstore.merge_from(vectorstore_slice)
+        else:
+            raise ValueError(f"Unsupported api_type: {api_type}")
+
+        return vectorstore
+
+    def encode_vector_store_to_bytes(self, vectorstore):
+        vectorstore = self.encode_vector_store(vectorstore.serialize_to_bytes())
+        return vectorstore
+    
+    def decode_vector_store_from_bytes(self, vectorstore):
+        if not self.is_vector_store(vectorstore):
+            raise ValueError("The retrieved vectorstore is not for PaddleX.")
+        api_type = self.config["api_type"]
+
+        if api_type == "aistudio":
+            access_token = self.config["access_token"]
+            embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
+        elif api_type == "qianfan":
+            ak = self.config["ak"]
+            sk = self.config["sk"]
+            embeddings = QianfanEmbeddingsEndpoint(qianfan_ak=ak, qianfan_sk=sk)
+        else:
+            raise ValueError(f"Unsupported api_type: {api_type}")
+        vectorstore = vectorstores.FAISS.deserialize_from_bytes(
+            self.decode_vector_store(vectorstore), embeddings
+        )
+        return vectorstore
+
+    def similarity_retrieval(self, query_text_list, vectorstore, sleep_time=0.5):
+        # match context passages for each query
+        C = []
+        for query_text in query_text_list:
+            QUESTION = query_text
+            time.sleep(sleep_time)
+            docs = vectorstore.similarity_search_with_relevance_scores(QUESTION, k=2)
+            context = [(document.page_content, score) for document, score in docs]
+            context = sorted(context, key=lambda x: x[1])
+            C.extend([x[0] for x in context[::-1]])
+        C = list(set(C))
+        all_C = " ".join(C)
+        return all_C
+        
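
A usage sketch for ErnieBotRetriever, assuming valid Qianfan credentials and the langchain/erniebot-agent dependencies are installed; the texts, keys and credentials are placeholders:

from paddlex.inference.pipelines_new.components.retriever import ErnieBotRetriever

retriever_config = {
    "model_name": "ernie-3.5",
    "api_type": "qianfan",
    "ak": "<your-qianfan-ak>",   # placeholder credential
    "sk": "<your-qianfan-sk>",   # placeholder credential
    "access_token": None,
}
retriever = ErnieBotRetriever(retriever_config)

ocr_texts = ["Invoice No: 12345", "Date: 2024-05-01", "Total: 99.00"]
vectorstore = retriever.generate_vector_database(ocr_texts)

# optionally serialize and restore the index before retrieval
encoded = retriever.encode_vector_store_to_bytes(vectorstore)
vectorstore = retriever.decode_vector_store_from_bytes(encoded)

context = retriever.similarity_retrieval(["Invoice No"], vectorstore)
print(context)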

+ 13 - 0
paddlex/inference/pipelines_new/components/utils/__init__.py

@@ -0,0 +1,13 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

+ 204 - 0
paddlex/inference/pipelines_new/components/utils/mixin.py

@@ -0,0 +1,204 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import abstractmethod
+import json
+from pathlib import Path
+import numpy as np
+from PIL import Image
+import pandas as pd
+
+from .....utils import logging
+from ....utils.io import (
+    JsonWriter,
+    ImageReader,
+    ImageWriter,
+    CSVWriter,
+    HtmlWriter,
+    XlsxWriter,
+    TextWriter,
+)
+
+
+def _save_list_data(save_func, save_path, data, *args, **kwargs):
+    save_path = Path(save_path)
+    if data is None:
+        return
+    if isinstance(data, list):
+        for idx, single in enumerate(data):
+            save_func(
+                (
+                    save_path.parent / f"{save_path.stem}_{idx}{save_path.suffix}"
+                ).as_posix(),
+                single,
+                *args,
+                **kwargs,
+            )
+        # each list element has been saved individually above
+        return
+    save_func(save_path.as_posix(), data, *args, **kwargs)
+    logging.info(f"The result has been saved in {save_path}.")
+
+
+class StrMixin:
+    @property
+    def str(self):
+        return self._to_str(self)
+
+    def _to_str(self, data, json_format=False, indent=4, ensure_ascii=False):
+        if json_format:
+            return json.dumps(data.json, indent=indent, ensure_ascii=ensure_ascii)
+        else:
+            return str(data)
+
+    def print(self, json_format=False, indent=4, ensure_ascii=False):
+        str_ = self._to_str(
+            self, json_format=json_format, indent=indent, ensure_ascii=ensure_ascii
+        )
+        logging.info(str_)
+
+
+class JsonMixin:
+    def __init__(self):
+        self._json_writer = JsonWriter()
+        self._show_funcs.append(self.save_to_json)
+
+    def _to_json(self):
+        def _format_data(obj):
+            if isinstance(obj, np.float32):
+                return float(obj)
+            elif isinstance(obj, np.ndarray):
+                return [_format_data(item) for item in obj.tolist()]
+            elif isinstance(obj, pd.DataFrame):
+                return obj.to_json(orient="records", force_ascii=False)
+            elif isinstance(obj, Path):
+                return obj.as_posix()
+            elif isinstance(obj, dict):
+                return type(obj)({k: _format_data(v) for k, v in obj.items()})
+            elif isinstance(obj, (list, tuple)):
+                return [_format_data(i) for i in obj]
+            else:
+                return obj
+
+        return _format_data(self)
+
+    @property
+    def json(self):
+        return self._to_json()
+
+    def save_to_json(self, save_path, indent=4, ensure_ascii=False, *args, **kwargs):
+        if not str(save_path).endswith(".json"):
+            save_path = Path(save_path) / f"{Path(self['input_path']).stem}.json"
+        _save_list_data(
+            self._json_writer.write,
+            save_path,
+            self.json,
+            indent=indent,
+            ensure_ascii=ensure_ascii,
+            *args,
+            **kwargs,
+        )
+
+
+class Base64Mixin:
+    def __init__(self, *args, **kwargs):
+        self._base64_writer = TextWriter(*args, **kwargs)
+        self._show_funcs.append(self.save_to_base64)
+
+    @abstractmethod
+    def _to_base64(self):
+        raise NotImplementedError
+
+    @property
+    def base64(self):
+        return self._to_base64()
+
+    def save_to_base64(self, save_path, *args, **kwargs):
+        if not str(save_path).lower().endswith((".b64")):
+            fp = Path(self["input_path"])
+            save_path = Path(save_path) / f"{fp.stem}{fp.suffix}"
+        _save_list_data(
+            self._base64_writer.write, save_path, self.base64, *args, **kwargs
+        )
+
+
+class ImgMixin:
+    def __init__(self, backend="pillow", *args, **kwargs):
+        self._img_writer = ImageWriter(backend=backend, *args, **kwargs)
+        self._show_funcs.append(self.save_to_img)
+
+    @abstractmethod
+    def _to_img(self):
+        raise NotImplementedError
+
+    @property
+    def img(self):
+        image = self._to_img()
+        # The img must be a PIL.Image obj
+        if isinstance(image, np.ndarray):
+            return Image.fromarray(image)
+        return image
+
+    def save_to_img(self, save_path, *args, **kwargs):
+        if not str(save_path).lower().endswith((".jpg", ".png")):
+            fp = Path(self["input_path"])
+            save_path = Path(save_path) / f"{fp.stem}{fp.suffix}"
+        _save_list_data(self._img_writer.write, save_path, self.img, *args, **kwargs)
+
+
+class CSVMixin:
+    def __init__(self, backend="pandas", *args, **kwargs):
+        self._csv_writer = CSVWriter(backend=backend, *args, **kwargs)
+        self._show_funcs.append(self.save_to_csv)
+
+    @abstractmethod
+    def _to_csv(self):
+        raise NotImplementedError
+
+    def save_to_csv(self, save_path, *args, **kwargs):
+        if not str(save_path).endswith(".csv"):
+            save_path = Path(save_path) / f"{Path(self['input_path']).stem}.csv"
+        _save_list_data(
+            self._csv_writer.write, save_path, self._to_csv(), *args, **kwargs
+        )
+
+
+class HtmlMixin:
+    def __init__(self, *args, **kwargs):
+        self._html_writer = HtmlWriter(*args, **kwargs)
+        self._show_funcs.append(self.save_to_html)
+
+    @property
+    def html(self):
+        return self._to_html()
+
+    def _to_html(self):
+        return self["html"]
+
+    def save_to_html(self, save_path, *args, **kwargs):
+        if not str(save_path).endswith(".html"):
+            save_path = Path(save_path) / f"{Path(self['input_path']).stem}.html"
+        _save_list_data(self._html_writer.write, save_path, self.html, *args, **kwargs)
+
+
+class XlsxMixin:
+    def __init__(self, *args, **kwargs):
+        self._xlsx_writer = XlsxWriter(*args, **kwargs)
+        self._show_funcs.append(self.save_to_xlsx)
+
+    def _to_xlsx(self):
+        return self["html"]
+
+    def save_to_xlsx(self, save_path, *args, **kwargs):
+        if not str(save_path).endswith(".xlsx"):
+            save_path = Path(save_path) / f"{Path(self['input_path']).stem}.xlsx"
+        _save_list_data(self._xlsx_writer.write, save_path, self.html, *args, **kwargs)
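
The mixins above expect the host result class to behave like a dict and to define _show_funcs before the mixin initializers run; a minimal hypothetical composition (not a class shipped in this change) showing that contract:

from paddlex.inference.pipelines_new.components.utils.mixin import JsonMixin, StrMixin

class DemoResult(dict, StrMixin, JsonMixin):
    """Hypothetical result type used only to illustrate the mixin contract."""

    def __init__(self, data):
        dict.__init__(self, data)
        self._show_funcs = []      # populated by the mixin __init__ methods
        JsonMixin.__init__(self)

res = DemoResult({"input_path": "demo.png", "score": 0.98})
res.print(json_format=True)        # StrMixin: logs the JSON-formatted result
res.save_to_json("./output")       # JsonMixin: writes ./output/demo.json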

+ 15 - 0
paddlex/inference/pipelines_new/doc_preprocessor/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pipeline import DocPreprocessorPipeline

+ 117 - 0
paddlex/inference/pipelines_new/doc_preprocessor/pipeline.py

@@ -0,0 +1,117 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BasePipeline
+from typing import Any, Dict, Optional
+from scipy.ndimage import rotate
+from .result import DocPreprocessorResult
+
+########## [TODO] The import path below needs to be updated later
+from ...components.transforms import ReadImage
+
+class DocPreprocessorPipeline(BasePipeline):
+    """Doc Preprocessor Pipeline"""
+
+    entities = "doc_preprocessor"
+    def __init__(self,
+        config,        
+        device=None,
+        pp_option=None, 
+        use_hpip: bool = False,
+        hpi_params: Optional[Dict[str, Any]] = None):
+        super().__init__(device=device, pp_option=pp_option, 
+            use_hpip=use_hpip, hpi_params=hpi_params)
+        
+        self.use_doc_orientation_classify = True
+        if 'use_doc_orientation_classify' in config:
+            self.use_doc_orientation_classify = config['use_doc_orientation_classify']
+
+        self.use_doc_unwarping = True
+        if 'use_doc_unwarping' in config:
+            self.use_doc_unwarping = config['use_doc_unwarping']
+        
+        if self.use_doc_orientation_classify:
+            doc_ori_classify_config = config['SubModules']["DocOrientationClassify"]
+            self.doc_ori_classify_model = self.create_model(doc_ori_classify_config)
+
+        if self.use_doc_unwarping:
+            doc_unwarping_config = config['SubModules']["DocUnwarping"]
+            self.doc_unwarping_model = self.create_model(doc_unwarping_config)
+        
+        self.img_reader = ReadImage(format="BGR")
+
+    def rotate_image(self, image_array, rotate_angle):
+        """rotate image"""
+        assert (
+            0 <= rotate_angle < 360
+        ), f"rotate_angle must be in [0, 360), but got {rotate_angle}."
+        return rotate(image_array, rotate_angle, reshape=True)
+
+    def check_input_params(self, input_params):
+        
+        if input_params['use_doc_orientation_classify'] and \
+            not self.use_doc_orientation_classify:
+            raise ValueError("The model for doc orientation classify is not initialized.")
+
+
+        if input_params['use_doc_unwarping'] and \
+            not self.use_doc_unwarping:
+            raise ValueError("The model for doc unwarping is not initialized.")
+            
+        return 
+
+    def predict(self, input, 
+        use_doc_orientation_classify=True,
+        use_doc_unwarping=False,
+        **kwargs):
+
+        if not isinstance(input, list):
+            input_list = [input]
+        else:
+            input_list = input
+
+        input_params = {"use_doc_orientation_classify":use_doc_orientation_classify,
+            "use_doc_unwarping":use_doc_unwarping}
+        self.check_input_params(input_params)
+
+        img_id = 1
+        for input in input_list:
+            if isinstance(input, str):
+                image_array = next(self.img_reader(input))[0]['img']
+            else:
+                image_array = input
+
+            assert len(image_array.shape) == 3
+
+            if input_params['use_doc_orientation_classify']:
+                pred = next(self.doc_ori_classify_model(image_array))
+                angle = int(pred["label_names"][0])
+                rot_img = self.rotate_image(image_array, angle)
+            else:
+                angle = -1
+                rot_img = image_array
+
+            if input_params['use_doc_unwarping']:
+                output_img = next(self.doc_unwarping_model(rot_img))['doctr_img']
+            else:
+                output_img = rot_img
+
+            single_img_res = {"input_image":image_array,
+                "input_params":input_params,
+                "angle":angle, 
+                "rot_img":rot_img, 
+                "output_img":output_img,
+                "img_id":img_id}
+            img_id += 1
+            yield DocPreprocessorResult(single_img_res)
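
Assuming doc_pipeline is an already-constructed DocPreprocessorPipeline whose config enables both sub-models, the generator-style predict can be consumed like this; the file names are illustrative:

# `doc_pipeline` is assumed to be a DocPreprocessorPipeline built from its YAML config
for res in doc_pipeline.predict(
    ["doc_page_1.jpg", "doc_page_2.jpg"],
    use_doc_orientation_classify=True,
    use_doc_unwarping=True,
):
    print(res["angle"])            # predicted rotation angle (-1 when classification is skipped)
    res.save_to_img("./output")    # writes ./output/res_doc_preprocess_<img_id>.jpg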

+ 51 - 0
paddlex/inference/pipelines_new/doc_preprocessor/result.py

@@ -0,0 +1,51 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import random
+import numpy as np
+import cv2
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+
+from ....utils.fonts import PINGFANG_FONT_FILE_PATH, create_font
+from ..components import CVResult
+
+class DocPreprocessorResult(CVResult):
+
+    def save_to_img(self, save_path, *args, **kwargs):
+        if not str(save_path).lower().endswith((".jpg", ".png")):
+            img_id = self["img_id"]
+            save_path = save_path + "/res_doc_preprocess_%d.jpg" % img_id
+        super().save_to_img(save_path, *args, **kwargs)
+
+    def _to_img(self):
+        """draw doc preprocess result"""
+        image = self["input_image"]
+        angle = self["angle"]
+        rot_img = self["rot_img"]
+        output_img = self["output_img"]
+        h, w = image.shape[0:2]
+        img_show = Image.new("RGB", (w * 3, h + 25), (255, 255, 255))
+        img_show.paste(Image.fromarray(image), (0, 0, w, h))
+        img_show.paste(Image.fromarray(rot_img), (w, 0, w * 2, h))
+        img_show.paste(Image.fromarray(output_img), (w * 2, 0, w * 3, h))
+
+        draw_text = ImageDraw.Draw(img_show)
+        txt_list = ["Original Image", "Rotated Image", "Unwarping Image"]
+        for tno in range(len(txt_list)):
+            txt = txt_list[tno]
+            font = create_font(txt, (w, 20), PINGFANG_FONT_FILE_PATH)
+            draw_text.text([10 + w * tno, h + 2], txt, fill=(0, 0, 0), font=font)
+        return img_show

+ 15 - 0
paddlex/inference/pipelines_new/layout_parsing/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pipeline import LayoutParsingPipeline

+ 205 - 0
paddlex/inference/pipelines_new/layout_parsing/pipeline.py

@@ -0,0 +1,205 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BasePipeline
+from typing import Any, Dict, Optional
+import numpy as np
+import cv2
+from ..components import CropByBoxes
+from .utils import convert_points_to_boxes, get_sub_regions_ocr_res
+from .table_recognition_post_processing import get_table_recognition_res
+
+from .result import LayoutParsingResult
+
+########## [TODO] The import path below needs to be updated later
+from ...components.transforms import ReadImage
+
+class LayoutParsingPipeline(BasePipeline):
+    """Layout Parsing Pipeline"""
+
+    entities = "layout_parsing"
+    def __init__(self,
+        config,        
+        device=None,
+        pp_option=None, 
+        use_hpip: bool = False,
+        hpi_params: Optional[Dict[str, Any]] = None):
+        super().__init__(device=device, pp_option=pp_option, 
+            use_hpip=use_hpip, hpi_params=hpi_params)
+        
+        self.initialize_predictor(config)
+
+        self.img_reader = ReadImage(format="BGR")
+
+        self._crop_by_boxes = CropByBoxes()
+        
+
+    def initialize_predictor(self, config):
+        """Create the sub-models and sub-pipelines declared in the config."""
+        layout_det_config = config['SubModules']["LayoutDetection"]
+        self.layout_det_model = self.create_model(layout_det_config)
+
+        self.use_doc_preprocessor = False
+        if 'use_doc_preprocessor' in config:
+            self.use_doc_preprocessor = config['use_doc_preprocessor']
+        if self.use_doc_preprocessor:
+            doc_preprocessor_config = config['SubPipelines']['DocPreprocessor']
+            self.doc_preprocessor_pipeline = self.create_pipeline(doc_preprocessor_config)
+        
+        self.use_common_ocr = False
+        if "use_common_ocr" in config:
+            self.use_common_ocr = config['use_common_ocr']
+        if self.use_common_ocr:
+            common_ocr_config = config['SubPipelines']['CommonOCR']
+            self.common_ocr_pipeline = self.create_pipeline(common_ocr_config)
+        
+        self.use_seal_recognition = False
+        if "use_seal_recognition" in config:
+            self.use_seal_recognition = config['use_seal_recognition']
+        if self.use_seal_recognition:
+            seal_ocr_config = config['SubPipelines']['SealOCR']
+            self.seal_ocr_pipeline = self.create_pipeline(seal_ocr_config)            
+        
+        self.use_table_recognition = False
+        if "use_table_recognition" in config:
+            self.use_table_recognition = config['use_table_recognition']
+        if self.use_table_recognition:
+            table_structure_config = config['SubModules']['TableStructurePredictor']
+            self.table_structure_model = self.create_model(table_structure_config)
+            if not self.use_common_ocr:
+                common_ocr_config = config['SubPipelines']['OCR']
+                self.common_ocr_pipeline = self.create_pipeline(common_ocr_config)
+        return 
+
+    def get_text_paragraphs_ocr_res(self, overall_ocr_res, layout_det_res):
+        '''get ocr res of the text paragraphs'''
+        object_boxes = []
+        for box_info in layout_det_res['boxes']:
+            if box_info['label'].lower() in ['image', 'formula', 'table', 'seal']:
+                object_boxes.append(box_info['coordinate'])
+        object_boxes = np.array(object_boxes)
+        return get_sub_regions_ocr_res(overall_ocr_res, object_boxes, flag_within=False)
+
+    def check_input_params(self, input_params):
+        """Check that every requested feature has its sub-models/sub-pipelines initialized."""
+
+        if input_params['use_doc_preprocessor'] and not self.use_doc_preprocessor:
+            raise ValueError("The models for doc preprocessor are not initialized.")
+
+        if input_params['use_common_ocr'] and not self.use_common_ocr:
+            raise ValueError("The models for common OCR are not initialized.")
+
+        if input_params['use_seal_recognition'] and not self.use_seal_recognition:
+            raise ValueError("The models for seal recognition are not initialized.")
+
+        if input_params['use_table_recognition'] and not self.use_table_recognition:
+            raise ValueError("The models for table recognition are not initialized.")
+
+        return
+
+    def predict(self, input, 
+        use_doc_orientation_classify=True,
+        use_doc_unwarping=True,
+        use_common_ocr=True,
+        use_seal_recognition=True,
+        use_table_recognition=True,
+        **kwargs):
+        """Run layout parsing on one image or a list of images and yield LayoutParsingResult objects."""
+
+        if not isinstance(input, list):
+            input_list = [input]
+        else:
+            input_list = input
+        
+        input_params = {"use_doc_preprocessor":self.use_doc_preprocessor,
+            "use_doc_orientation_classify":use_doc_orientation_classify,
+            "use_doc_unwarping":use_doc_unwarping,
+            "use_common_ocr":use_common_ocr,
+            "use_seal_recognition":use_seal_recognition,
+            "use_table_recognition":use_table_recognition}
+            
+        if use_doc_orientation_classify or use_doc_unwarping:
+            input_params['use_doc_preprocessor'] = True
+
+        self.check_input_params(input_params)
+
+        img_id = 1
+        for input in input_list:
+            if isinstance(input, str):
+                image_array = next(self.img_reader(input))[0]['img']
+            else:
+                image_array = input
+
+            assert len(image_array.shape) == 3
+
+            if input_params['use_doc_preprocessor']:
+                doc_preprocessor_res = next(self.doc_preprocessor_pipeline(
+                    image_array, 
+                    use_doc_orientation_classify=use_doc_orientation_classify,
+                    use_doc_unwarping=use_doc_unwarping))
+                doc_preprocessor_image = doc_preprocessor_res['output_img']
+                doc_preprocessor_res['img_id'] = img_id
+            else:
+                doc_preprocessor_res = {}
+                doc_preprocessor_image = image_array
+            
+
+            ########## [TODO] RT-DETR layout detection results may contain duplicate boxes.
+            layout_det_res = next(self.layout_det_model(doc_preprocessor_image))
+
+            if input_params['use_common_ocr'] or input_params['use_table_recognition']:
+                overall_ocr_res = next(self.common_ocr_pipeline(doc_preprocessor_image))
+                overall_ocr_res['img_id'] = img_id
+                dt_boxes = convert_points_to_boxes(overall_ocr_res['dt_polys'])
+                overall_ocr_res['dt_boxes'] = dt_boxes
+            else:
+                overall_ocr_res = {}
+            
+            text_paragraphs_ocr_res = {}
+            if input_params['use_common_ocr']:
+                text_paragraphs_ocr_res = self.get_text_paragraphs_ocr_res(
+                    overall_ocr_res, layout_det_res)
+                text_paragraphs_ocr_res['img_id'] = img_id
+            
+            table_res_list = []
+            if input_params['use_table_recognition']:
+                table_region_id = 1
+                for box_info in layout_det_res['boxes']:
+                    if box_info['label'].lower() in ['table']:
+                        crop_img_info = self._crop_by_boxes(doc_preprocessor_image, [box_info])
+                        crop_img_info = crop_img_info[0]
+                        table_structure_pred = next(self.table_structure_model(
+                            crop_img_info['img']))
+                        table_recognition_res = get_table_recognition_res(
+                            crop_img_info, table_structure_pred, overall_ocr_res)
+                        table_recognition_res['table_region_id'] = table_region_id
+                        table_region_id += 1
+                        table_res_list.append(table_recognition_res)
+            
+            seal_res_list = []
+            if input_params['use_seal_recognition']:
+                seal_region_id = 1
+                for box_info in layout_det_res['boxes']:
+                    if box_info['label'].lower() in ['seal']:
+                        crop_img_info = self._crop_by_boxes(doc_preprocessor_image, [box_info])
+                        crop_img_info = crop_img_info[0]
+                        seal_ocr_res = next(self.seal_ocr_pipeline(crop_img_info['img']))
+                        seal_ocr_res['seal_region_id'] = seal_region_id
+                        seal_region_id += 1
+                        seal_res_list.append(seal_ocr_res)
+            
+            single_img_res = {"layout_det_res":layout_det_res,
+                "doc_preprocessor_res":doc_preprocessor_res,
+                "text_paragraphs_ocr_res":text_paragraphs_ocr_res,
+                "table_res_list":table_res_list,
+                "seal_res_list":seal_res_list,
+                "input_params":input_params}
+            img_id += 1
+            yield LayoutParsingResult(single_img_res)
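
A minimal usage sketch of the new layout parsing pipeline. It assumes the pipeline is registered under the name used by paddlex/configs/pipelines/layout_parsing.yaml and that `create_pipeline` resolves it; the input file name and output directory are illustrative:

    from paddlex import create_pipeline

    pipeline = create_pipeline(pipeline="layout_parsing")
    # Each yielded LayoutParsingResult bundles layout detection, OCR, table and seal results.
    for res in pipeline.predict("demo_doc.jpg"):
        res.save_results("./output")  # save_results expects an existing directory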

+ 97 - 0
paddlex/inference/pipelines_new/layout_parsing/result.py

@@ -0,0 +1,97 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import random
+import numpy as np
+import cv2
+import PIL
+import os
+from PIL import Image, ImageDraw, ImageFont
+
+from ....utils.fonts import PINGFANG_FONT_FILE_PATH
+from ..components import CVResult, HtmlMixin, XlsxMixin
+
+class TableRecognitionResult(CVResult, HtmlMixin, XlsxMixin):
+    """Single-table recognition result with image, HTML and XLSX export support."""
+
+    def __init__(self, data):
+        super().__init__(data)
+        HtmlMixin.__init__(self)
+        XlsxMixin.__init__(self)
+
+    def save_to_html(self, save_path, *args, **kwargs):
+        if not str(save_path).lower().endswith(".html"):
+            save_path = save_path + "/res_table_%d.html" % self['table_region_id']
+        super().save_to_html(save_path, *args, **kwargs)
+
+    def _to_html(self):
+        return self["pred_html"]
+
+    def save_to_xlsx(self, save_path, *args, **kwargs):
+        if not str(save_path).lower().endswith(".xlsx"):
+            save_path = save_path + "/res_table_%d.xlsx" % self['table_region_id']
+        super().save_to_xlsx(save_path, *args, **kwargs)
+
+    def _to_xlsx(self):
+        return self["pred_html"]
+
+    def save_to_img(self, save_path, *args, **kwargs):
+        if not str(save_path).lower().endswith((".jpg", ".png")):
+            ocr_save_path = save_path + "/res_table_ocr_%d.jpg" % self['table_region_id']
+            save_path = save_path + "/res_table_cell_%d.jpg" % self['table_region_id']
+            # Only save the table OCR visualization when a directory is given;
+            # otherwise ocr_save_path would be undefined.
+            self['table_ocr_pred'].save_to_img(ocr_save_path)
+        super().save_to_img(save_path, *args, **kwargs)
+
+    def _to_img(self):
+        input_img = self['table_ocr_pred']['input_img'].copy()
+        cell_box_list = self['cell_box_list']
+        for box in cell_box_list:
+            x1, y1, x2, y2 = [int(pos) for pos in box]
+            cv2.rectangle(input_img, (x1, y1), (x2, y2), (255, 0, 0), 2)
+        return input_img
+
+class LayoutParsingResult(dict):
+    """Aggregated layout parsing result for a single image."""
+
+    def __init__(self, data):
+        super().__init__(data)
+    
+    def save_results(self, save_path):
+        if not os.path.isdir(save_path):
+            raise ValueError("The save path should be a dir.")
+
+        layout_det_res = self['layout_det_res']
+        save_img_path = save_path + "/layout_det_result.jpg"
+        layout_det_res.save_to_img(save_img_path)
+
+        input_params = self['input_params']
+        if input_params['use_doc_preprocessor']:
+            save_img_path = save_path + "/doc_preprocessor_result.jpg"
+            self['doc_preprocessor_res'].save_to_img(save_img_path)
+        
+        if input_params['use_common_ocr']:
+            save_img_path = save_path + "/text_paragraphs_ocr_result.jpg"
+            self['text_paragraphs_ocr_res'].save_to_img(save_img_path)
+
+        if input_params['use_table_recognition']:
+            for tno in range(len(self['table_res_list'])):
+                table_res = self['table_res_list'][tno]
+                table_res.save_to_img(save_path)
+                table_res.save_to_html(save_path)
+                table_res.save_to_xlsx(save_path)
+        
+        if input_params['use_seal_recognition']:
+            for sno in range(len(self['seal_res_list'])):
+                seal_res = self['seal_res_list'][sno]
+                save_img_path = save_path + "/seal_%d_recognition_result.jpg" % seal_res['seal_region_id']
+                seal_res.save_to_img(save_img_path)          
+        return
+

+ 203 - 0
paddlex/inference/pipelines_new/layout_parsing/table_recognition_post_processing.py

@@ -0,0 +1,203 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .utils import convert_points_to_boxes, get_sub_regions_ocr_res
+import numpy as np
+from .result import TableRecognitionResult
+
+def get_ori_image_coordinate(x, y, box_list):
+    """
+    get the original coordinate from Cropped image to Original image.
+    Args:
+        x (int): x coordinate of cropped image
+        y (int): y coordinate of cropped image
+        box_list (list): list of table bounding boxes, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]]
+    Returns:
+        list: list of original coordinates, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]]
+    """
+    if len(box_list) == 0:
+        return box_list
+    offset = np.array([x, y] * 4)
+    box_list = np.array(box_list)
+    if box_list.shape[-1] == 2:
+        offset = offset.reshape(4, 2)
+    ori_box_list = offset + box_list
+    return ori_box_list
+
+def convert_table_structure_pred_bbox(table_structure_pred, 
+    crop_start_point, img_shape):
+
+    cell_points_list = table_structure_pred['bbox']
+    ori_cell_points_list = get_ori_image_coordinate(crop_start_point[0], 
+        crop_start_point[1], cell_points_list)
+    ori_cell_points_list = np.reshape(ori_cell_points_list, (-1, 4, 2))
+    cell_box_list = convert_points_to_boxes(ori_cell_points_list)
+    img_height, img_width = img_shape
+    cell_box_list = np.clip(cell_box_list, 0, 
+        [img_width, img_height, img_width, img_height])
+    table_structure_pred['cell_box_list'] = cell_box_list
+    return 
+
+def distance(box_1, box_2):
+    """
+    compute the distance between two boxes
+
+    Args:
+        box_1 (list): first rectangle box, e.g. (x1, y1, x2, y2)
+        box_2 (list): second rectangle box, e.g. (x1, y1, x2, y2)
+
+    Returns:
+        float: the weighted L1 distance between the two boxes
+    """
+    x1, y1, x2, y2 = box_1
+    x3, y3, x4, y4 = box_2
+    dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
+    dis_2 = abs(x3 - x1) + abs(y3 - y1)
+    dis_3 = abs(x4 - x2) + abs(y4 - y2)
+    return dis + min(dis_2, dis_3)
+
+def compute_iou(rec1, rec2):
+    """
+    computing IoU
+    Args:
+        rec1 (list): (x1, y1, x2, y2)
+        rec2 (list): (x1, y1, x2, y2)
+    Returns:
+        float: Intersection over Union
+    """
+    # computing area of each rectangles
+    S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
+    S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
+
+    # computing the sum_area
+    sum_area = S_rec1 + S_rec2
+
+    # find the each edge of intersect rectangle
+    left_line = max(rec1[0], rec2[0])
+    right_line = min(rec1[2], rec2[2])
+    top_line = max(rec1[1], rec2[1])
+    bottom_line = min(rec1[3], rec2[3])
+
+    # judge if there is an intersect
+    if left_line >= right_line or top_line >= bottom_line:
+        return 0.0
+    else:
+        intersect = (right_line - left_line) * (bottom_line - top_line)
+        return (intersect / (sum_area - intersect)) * 1.0
+
+def match_table_and_ocr(cell_box_list, ocr_dt_boxes):
+    """
+    match table and ocr
+
+    Args:
+        cell_box_list (list): bbox for table cell, 2 points, [left, top, right, bottom]
+        ocr_dt_boxes (list): bbox for ocr, 2 points, [left, top, right, bottom]
+
+    Returns:
+        dict: matched dict, key is the table cell index, value is the list of matched OCR box indices
+    """
+    matched = {}
+    for i, ocr_box in enumerate(np.array(ocr_dt_boxes)):
+        ocr_box = ocr_box.astype(np.float32)
+        distances = []
+        for j, table_box in enumerate(cell_box_list):
+            distances.append((distance(table_box, ocr_box), 
+                1.0 - compute_iou(table_box, ocr_box)))  # compute iou and l1 distance
+        sorted_distances = distances.copy()
+        # select det box by iou and l1 distance
+        sorted_distances = sorted(
+            sorted_distances, key=lambda item: (item[1], item[0]))
+        if distances.index(sorted_distances[0]) not in matched.keys():
+            matched[distances.index(sorted_distances[0])] = [i]
+        else:
+            matched[distances.index(sorted_distances[0])].append(i)
+    return matched
+
+def get_html_result(matched_index, ocr_contents, pred_structures):
+    """Fill the matched OCR texts into the predicted table structure tags and return the HTML string."""
+    pred_html = []
+    td_index = 0
+    head_structure = pred_structures[0:3]
+    html = "".join(head_structure)
+    table_structure = pred_structures[3:-3]
+    for tag in table_structure:
+        if "</td>" in tag:
+            if "<td></td>" == tag:
+                pred_html.extend("<td>")
+            if td_index in matched_index.keys():
+                b_with = False
+                if (
+                    "<b>" in ocr_contents[matched_index[td_index][0]]
+                    and len(matched_index[td_index]) > 1
+                ):
+                    b_with = True
+                    pred_html.extend("<b>")
+                for i, td_index_index in enumerate(matched_index[td_index]):
+                    content = ocr_contents[td_index_index]
+                    if len(matched_index[td_index]) > 1:
+                        if len(content) == 0:
+                            continue
+                        if content[0] == " ":
+                            content = content[1:]
+                        if "<b>" in content:
+                            content = content[3:]
+                        if "</b>" in content:
+                            content = content[:-4]
+                        if len(content) == 0:
+                            continue
+                        if (
+                            i != len(matched_index[td_index]) - 1
+                            and " " != content[-1]
+                        ):
+                            content += " "
+                    pred_html.extend(content)
+                if b_with:
+                    pred_html.extend("</b>")
+            if "<td></td>" == tag:
+                pred_html.append("</td>")
+            else:
+                pred_html.append(tag)
+            td_index += 1
+        else:
+            pred_html.append(tag)
+    html += "".join(pred_html)
+    end_structure = pred_structures[-3:]
+    html += "".join(end_structure)
+    return html
+
+def get_table_recognition_res(crop_img_info, table_structure_pred, overall_ocr_res):
+    '''Match the OCR results into the predicted table structure and build a TableRecognitionResult.'''
+
+    table_box = np.array([crop_img_info['box']])
+    table_ocr_pred = get_sub_regions_ocr_res(overall_ocr_res, table_box)
+
+    crop_start_point = [table_box[0][0], table_box[0][1]]
+    img_shape = overall_ocr_res['input_img'].shape[0:2]
+
+    convert_table_structure_pred_bbox(table_structure_pred, 
+        crop_start_point, img_shape)
+    
+    structures = table_structure_pred["structure"]
+    cell_box_list = table_structure_pred["cell_box_list"]
+    ocr_dt_boxes = table_ocr_pred["dt_boxes"]
+    ocr_text_res = table_ocr_pred["rec_text"]
+
+    matched_index = match_table_and_ocr(cell_box_list, ocr_dt_boxes)
+    pred_html = get_html_result(matched_index, ocr_text_res, structures)
+
+    single_img_res = {"cell_box_list":cell_box_list, 
+        "table_ocr_pred":table_ocr_pred,
+        "pred_html":pred_html}
+    return TableRecognitionResult(single_img_res)
+
+
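
A small worked example of the cell/OCR matching above, using toy boxes (all numbers are illustrative; the import path is the new module path from this commit):

    from paddlex.inference.pipelines_new.layout_parsing.table_recognition_post_processing import (
        match_table_and_ocr,
    )

    cell_box_list = [[0, 0, 100, 30], [100, 0, 200, 30]]  # two table cells, [x1, y1, x2, y2]
    ocr_dt_boxes = [[5, 5, 95, 25], [110, 5, 190, 25]]    # two detected text boxes
    # Each OCR box is assigned to the cell with the largest IoU (smallest 1 - IoU),
    # ties broken by the weighted L1 distance, giving {cell index: [OCR indices]}.
    print(match_table_and_ocr(cell_box_list, ocr_dt_boxes))  # {0: [0], 1: [1]}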

+ 87 - 0
paddlex/inference/pipelines_new/layout_parsing/utils.py

@@ -0,0 +1,87 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = [
+    "convert_points_to_boxes",
+    "get_sub_regions_ocr_res"
+]
+
+import numpy as np
+import copy
+
+def convert_points_to_boxes(dt_polys):
+    """Convert 4-point text polygons into axis-aligned [x_min, y_min, x_max, y_max] boxes."""
+    if len(dt_polys) > 0:
+        dt_polys_tmp = dt_polys.copy()
+        dt_polys_tmp = np.array(dt_polys_tmp)
+        boxes_left = np.min(dt_polys_tmp[:, :, 0], axis=1)
+        boxes_right = np.max(dt_polys_tmp[:, :, 0], axis=1)
+        boxes_top = np.min(dt_polys_tmp[:, :, 1], axis=1)
+        boxes_bottom = np.max(dt_polys_tmp[:, :, 1], axis=1)
+        dt_boxes = np.array([boxes_left, boxes_top, boxes_right, boxes_bottom])
+        dt_boxes = dt_boxes.T
+    else:
+        dt_boxes = np.array([])
+    return dt_boxes
+
+def get_overlap_boxes_idx(src_boxes, ref_boxes):
+    '''get overlap boxes idx''' 
+    match_idx_list = []
+    src_boxes_num = len(src_boxes)
+    if src_boxes_num > 0 and len(ref_boxes) > 0:
+        for rno in range(len(ref_boxes)):
+            ref_box = ref_boxes[rno]
+            x1 = np.maximum(ref_box[0], src_boxes[:, 0])
+            y1 = np.maximum(ref_box[1], src_boxes[:, 1])
+            x2 = np.minimum(ref_box[2], src_boxes[:, 2])
+            y2 = np.minimum(ref_box[3], src_boxes[:, 3])
+            pub_w = x2 - x1
+            pub_h = y2 - y1
+            match_idx = np.where((pub_w > 3) & (pub_h > 3))[0]
+            match_idx_list.extend(match_idx)                                   
+    return match_idx_list
+
+def get_sub_regions_ocr_res(overall_ocr_res, object_boxes, flag_within=True):
+    """
+    :param flag_within: True (within the object regions), False (outside the object regions)
+    :return:
+    """
+
+    sub_regions_ocr_res = copy.deepcopy(overall_ocr_res)
+    sub_regions_ocr_res['input_img'] = overall_ocr_res['input_img']
+    sub_regions_ocr_res['img_id'] = -1
+    sub_regions_ocr_res['dt_polys'] = []
+    sub_regions_ocr_res['rec_text'] = []
+    sub_regions_ocr_res['rec_score'] = []
+    sub_regions_ocr_res['dt_boxes'] = []
+
+    overall_text_boxes = overall_ocr_res['dt_boxes']
+    match_idx_list = get_overlap_boxes_idx(overall_text_boxes, object_boxes)
+    match_idx_list = list(set(match_idx_list))
+    for box_no in range(len(overall_text_boxes)):
+        # Keep this box if its presence in match_idx_list agrees with flag_within.
+        flag_match = (box_no in match_idx_list) == flag_within
+        if flag_match:
+            sub_regions_ocr_res['dt_polys'].append(overall_ocr_res['dt_polys'][box_no])
+            sub_regions_ocr_res['rec_text'].append(overall_ocr_res['rec_text'][box_no])
+            sub_regions_ocr_res['rec_score'].append(overall_ocr_res['rec_score'][box_no])
+            sub_regions_ocr_res['dt_boxes'].append(overall_ocr_res['dt_boxes'][box_no])
+    return sub_regions_ocr_res
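
A quick sketch of convert_points_to_boxes on a toy quadrilateral (values are illustrative; the import path is the new module path from this commit):

    import numpy as np
    from paddlex.inference.pipelines_new.layout_parsing.utils import convert_points_to_boxes

    dt_polys = np.array([[[10, 20], [110, 18], [112, 58], [12, 60]]])  # one quad, 4 (x, y) points
    print(convert_points_to_boxes(dt_polys))  # [[ 10  18 112  60]]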

+ 15 - 0
paddlex/inference/pipelines_new/ocr/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pipeline import OCRPipeline

+ 96 - 0
paddlex/inference/pipelines_new/ocr/pipeline.py

@@ -0,0 +1,96 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BasePipeline
+from typing import Any, Dict, Optional
+from ..components import SortQuadBoxes, CropByPolys
+from .result import OCRResult
+
+########## [TODO] The import path below needs to be updated later.
+from ...components.transforms import ReadImage
+
+class OCRPipeline(BasePipeline):
+    """OCR Pipeline"""
+
+    entities = "OCR"
+    def __init__(self,
+        config,        
+        device=None,
+        pp_option=None, 
+        use_hpip: bool = False,
+        hpi_params: Optional[Dict[str, Any]] = None):
+        super().__init__(device=device, pp_option=pp_option, 
+            use_hpip=use_hpip, hpi_params=hpi_params)
+        
+        text_det_model_config = config['SubModules']["TextDetection"]
+        self.text_det_model = self.create_model(text_det_model_config)
+
+        text_rec_model_config = config['SubModules']["TextRecognition"]
+        self.text_rec_model = self.create_model(text_rec_model_config)
+
+        self.text_type = config['text_type']
+
+        self._sort_quad_boxes = SortQuadBoxes()
+
+        if self.text_type == "common":
+            self._crop_by_polys = CropByPolys(det_box_type="quad")
+        elif self.text_type == "seal":
+            self._crop_by_polys = CropByPolys(det_box_type="poly")
+        else:
+            raise ValueError("Unsupported text type {}".format(self.text_type))
+
+        self.img_reader = ReadImage(format="BGR")
+
+    def predict(self, input, **kwargs):
+        """Run text detection and recognition on one image or a list of images and yield OCRResult objects."""
+        if not isinstance(input, list):
+            input_list = [input]
+        else:
+            input_list = input
+        img_id = 1
+        for input in input_list:
+            if isinstance(input, str):
+                image_array = next(self.img_reader(input))[0]['img']
+            else:
+                image_array = input
+
+            assert len(image_array.shape) == 3
+
+            det_res = next(self.text_det_model(image_array))
+
+            dt_polys = det_res['dt_polys']
+            dt_scores = det_res['dt_scores']
+
+            ########## [TODO] Confirm the filtering thresholds of the detection and recognition modules.
+
+            if self.text_type == "common":
+                dt_polys = self._sort_quad_boxes(dt_polys)
+                
+            single_img_res = {'input_img':image_array, 'dt_polys':dt_polys, \
+                "img_id":img_id, "text_type":self.text_type}
+            img_id += 1
+            single_img_res["rec_text"] = []
+            single_img_res["rec_score"] = []
+            if len(dt_polys) > 0:
+                all_subs_of_img = list(self._crop_by_polys(image_array, dt_polys))
+
+                ########## [TODO] Update in the future.
+                for sub_img in all_subs_of_img:
+                    sub_img['input'] = sub_img['img']
+                ##########
+
+                for rec_res in self.text_rec_model(all_subs_of_img):
+                    single_img_res["rec_text"].append(rec_res["rec_text"])
+                    single_img_res["rec_score"].append(rec_res["rec_score"])
+
+            yield OCRResult(single_img_res)
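
A minimal usage sketch of the new OCR pipeline, assuming the pipeline name from paddlex/configs/pipelines/OCR.yaml; the input file name and output directory are illustrative:

    from paddlex import create_pipeline

    pipeline = create_pipeline(pipeline="OCR")
    for res in pipeline.predict("demo_text.png"):
        print(res["rec_text"], res["rec_score"])
        res.save_to_img("./output")  # writes ./output/res_ocr_<img_id>.jpg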

+ 160 - 0
paddlex/inference/pipelines_new/ocr/result.py

@@ -0,0 +1,160 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import random
+import numpy as np
+import cv2
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+
+from ....utils.fonts import PINGFANG_FONT_FILE_PATH
+from ..components import CVResult
+
+class OCRResult(CVResult):
+    """OCR result of a single image with visualization support."""
+
+    def save_to_img(self, save_path, *args, **kwargs):
+        if not str(save_path).lower().endswith((".jpg", ".png")):
+            img_id = self["img_id"]
+            save_path = save_path + "/res_ocr_%d.jpg" % img_id
+        super().save_to_img(save_path, *args, **kwargs)
+
+    def get_minarea_rect(self, points):
+        """Return the minimum-area rectangle of the points as a 4-point quad ordered clockwise from the top-left corner."""
+        bounding_box = cv2.minAreaRect(points)
+        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+        index_a, index_b, index_c, index_d = 0, 1, 2, 3
+        if points[1][1] > points[0][1]:
+            index_a = 0
+            index_d = 1
+        else:
+            index_a = 1
+            index_d = 0
+        if points[3][1] > points[2][1]:
+            index_b = 2
+            index_c = 3
+        else:
+            index_b = 3
+            index_c = 2
+
+        box = np.array(
+            [points[index_a], points[index_b], points[index_c], points[index_d]]
+        ).astype(np.int32)
+
+        return box
+
+    def _to_img(self):
+        """draw ocr result"""
+        # TODO(gaotingquan): mv to postprocess
+        drop_score = 0.5
+
+        boxes = self["dt_polys"]
+        txts = self["rec_text"]
+        scores = self["rec_score"]
+        image = self['input_img']
+        h, w = image.shape[0:2]
+        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        img_left = Image.fromarray(image_rgb)
+        img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
+        random.seed(0)
+        draw_left = ImageDraw.Draw(img_left)
+        if txts is None or len(txts) != len(boxes):
+            txts = [None] * len(boxes)
+        for idx, (box, txt) in enumerate(zip(boxes, txts)):
+            try:
+                if scores is not None and scores[idx] < drop_score:
+                    continue
+                color = (
+                    random.randint(0, 255),
+                    random.randint(0, 255),
+                    random.randint(0, 255),
+                )
+                box = np.array(box)
+                if len(box) > 4:
+                    pts = [(x, y) for x, y in box.tolist()]
+                    draw_left.polygon(pts, outline=color, width=8)
+                    box = self.get_minarea_rect(box)
+                    height = int(0.5 * (max(box[:, 1]) - min(box[:, 1])))
+                    box[:2, 1] = np.mean(box[:, 1])
+                    box[2:, 1] = np.mean(box[:, 1]) + min(20, height)
+                draw_left.polygon(box, fill=color)
+                img_right_text = draw_box_txt_fine(
+                    (w, h), box, txt, PINGFANG_FONT_FILE_PATH
+                )
+                pts = np.array(box, np.int32).reshape((-1, 1, 2))
+                cv2.polylines(img_right_text, [pts], True, color, 1)
+                img_right = cv2.bitwise_and(img_right, img_right_text)
+            except Exception:
+                continue
+
+        img_left = Image.blend(Image.fromarray(image_rgb), img_left, 0.5)
+        img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
+        img_show.paste(img_left, (0, 0, w, h))
+        img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
+        return img_show
+
+
+def draw_box_txt_fine(img_size, box, txt, font_path):
+    """draw box text"""
+    box_height = int(
+        math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)
+    )
+    box_width = int(
+        math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2)
+    )
+
+    if box_height > 2 * box_width and box_height > 30:
+        img_text = Image.new("RGB", (box_height, box_width), (255, 255, 255))
+        draw_text = ImageDraw.Draw(img_text)
+        if txt:
+            font = create_font(txt, (box_height, box_width), font_path)
+            draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
+        img_text = img_text.transpose(Image.ROTATE_270)
+    else:
+        img_text = Image.new("RGB", (box_width, box_height), (255, 255, 255))
+        draw_text = ImageDraw.Draw(img_text)
+        if txt:
+            font = create_font(txt, (box_width, box_height), font_path)
+            draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
+
+    pts1 = np.float32(
+        [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]]
+    )
+    pts2 = np.array(box, dtype=np.float32)
+    M = cv2.getPerspectiveTransform(pts1, pts2)
+
+    img_text = np.array(img_text, dtype=np.uint8)
+    img_right_text = cv2.warpPerspective(
+        img_text,
+        M,
+        img_size,
+        flags=cv2.INTER_NEAREST,
+        borderMode=cv2.BORDER_CONSTANT,
+        borderValue=(255, 255, 255),
+    )
+    return img_right_text
+
+
+def create_font(txt, sz, font_path):
+    """create font"""
+    font_size = int(sz[1] * 0.8)
+    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+    if int(PIL.__version__.split(".")[0]) < 10:
+        length = font.getsize(txt)[0]
+    else:
+        length = font.getlength(txt)
+
+    if length > sz[0]:
+        font_size = int(font_size * sz[0] / length)
+        font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+    return font

+ 15 - 0
paddlex/inference/pipelines_new/pp_chatocrv3_doc/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pipeline import PP_ChatOCRv3_doc_Pipeline

+ 329 - 0
paddlex/inference/pipelines_new/pp_chatocrv3_doc/pipeline.py

@@ -0,0 +1,329 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BasePipeline
+
+from typing import Any, Dict, Optional
+
+# import numpy as np
+# import cv2
+from .result import VisualInfoResult
+import re
+import json
+import logging
+
+########## [TODO] The import path below needs to be updated later.
+from ...components.transforms import ReadImage
+
+class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
+    """PP-ChatOCRv3-doc Pipeline"""
+
+    entities = "PP-ChatOCRv3-doc"
+    def __init__(self,
+        config,
+        device=None,
+        pp_option=None, 
+        use_hpip: bool = False,
+        hpi_params: Optional[Dict[str, Any]] = None):
+        super().__init__(device=device, pp_option=pp_option, 
+            use_hpip=use_hpip, hpi_params=hpi_params)
+        
+        self.initialize_predictor(config)
+
+        self.img_reader = ReadImage(format="BGR")
+        
+    def initialize_predictor(self, config):
+        """Create the sub-pipelines, LLM chat bot, retriever and prompt engineering modules from the config."""
+        layout_parsing_config = config['SubPipelines']["LayoutParser"]
+        self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
+
+        chat_bot_config = config['SubModules']['LLM_Chat']
+        self.chat_bot = self.create_chat_bot(chat_bot_config)
+
+        retriever_config = config['SubModules']['LLM_Retriever']
+        self.retriever = self.create_retriever(retriever_config)
+
+        text_pe_config = config['SubModules']['PromptEngneering']['KIE_CommonText']
+        self.text_pe = self.create_prompt_engeering(text_pe_config)
+        
+        table_pe_config = config['SubModules']['PromptEngneering']['KIE_Table']
+        self.table_pe = self.create_prompt_engeering(table_pe_config)
+
+        return 
+
+    def decode_visual_result(self, layout_parsing_result):
+        """Extract normal text, seal text and table text/HTML from a layout parsing result."""
+        text_paragraphs_ocr_res = layout_parsing_result['text_paragraphs_ocr_res']
+        seal_res_list = layout_parsing_result['seal_res_list']
+        normal_text_dict = {}
+        layout_type = "text"
+        for text in text_paragraphs_ocr_res['rec_text']:
+            if layout_type not in normal_text_dict:
+                normal_text_dict[layout_type] = text
+            else:
+                normal_text_dict[layout_type] += f"\n {text}"
+        
+        layout_type = "seal"
+        for seal_res in seal_res_list:
+            for text in seal_res['rec_text']:
+                if layout_type not in normal_text_dict:
+                    normal_text_dict[layout_type] = text
+                else:
+                    normal_text_dict[layout_type] += f"\n {text}"
+
+        table_res_list = layout_parsing_result['table_res_list']
+        table_text_list = []
+        table_html_list = []
+        for table_res in table_res_list:
+            table_html_list.append(table_res['pred_html'])
+            single_table_text = " ".join(table_res["table_ocr_pred"]['rec_text'])
+            table_text_list.append(single_table_text)
+
+        visual_info = {}
+        visual_info['normal_text_dict'] = normal_text_dict
+        visual_info['table_text_list'] = table_text_list
+        visual_info['table_html_list'] = table_html_list
+        return VisualInfoResult(visual_info)
+
+    def visual_predict(self, input,
+        use_doc_orientation_classify=True,
+        use_doc_unwarping=True,
+        use_common_ocr=True,
+        use_seal_recognition=True,
+        use_table_recognition=True,
+        **kwargs):
+        """Run the layout parsing pipeline on each input image and yield its result together with the decoded visual information."""
+
+        if not isinstance(input, list):
+            input_list = [input]
+        else:
+            input_list = input
+
+        img_id = 1
+        for input in input_list:
+            if isinstance(input, str):
+                image_array = next(self.img_reader(input))[0]['img']
+            else:
+                image_array = input
+
+            assert len(image_array.shape) == 3
+
+            layout_parsing_result = next(self.layout_parsing_pipeline.predict(
+                image_array,
+                use_doc_orientation_classify=use_doc_orientation_classify,
+                use_doc_unwarping=use_doc_unwarping,
+                use_common_ocr=use_common_ocr,
+                use_seal_recognition=use_seal_recognition,
+                use_table_recognition=use_table_recognition))
+            
+            visual_info = self.decode_visual_result(layout_parsing_result)
+
+            visual_predict_res = {"layout_parsing_result":layout_parsing_result,
+                "visual_info":visual_info}
+            yield visual_predict_res
+
+    def save_visual_info_list(self, visual_info, save_path):
+        if not isinstance(visual_info, list):
+            visual_info_list = [visual_info]
+        else:
+            visual_info_list = visual_info
+
+        with open(save_path, "w") as fout:
+            fout.write(json.dumps(visual_info_list, ensure_ascii=False) + "\n")
+        return
+    
+    def load_visual_info_list(self, data_path):
+        with open(data_path, "r") as fin:
+            data = fin.readline()
+            visual_info_list = json.loads(data)
+        return visual_info_list
+
+    def merge_visual_info_list(self, visual_info_list):
+        all_normal_text_list = []
+        all_table_text_list = []
+        all_table_html_list = []
+        for single_visual_info in visual_info_list:
+            normal_text_dict = single_visual_info['normal_text_dict']
+            table_text_list = single_visual_info['table_text_list']
+            table_html_list = single_visual_info['table_html_list']
+            all_normal_text_list.append(normal_text_dict)
+            all_table_text_list.extend(table_text_list)
+            all_table_html_list.extend(table_html_list)
+        return all_normal_text_list, all_table_text_list, all_table_html_list
+
+    def build_vector(self, visual_info, 
+        min_characters=3500,
+        llm_request_interval=1.0):
+        """Build a retrieval vector database from the visual information; when the merged text is shorter than min_characters, the raw items are kept instead."""
+
+        if not isinstance(visual_info, list):
+            visual_info_list = [visual_info]
+        else:
+            visual_info_list = visual_info
+        
+        all_visual_info = self.merge_visual_info_list(visual_info_list)
+        all_normal_text_list, all_table_text_list, all_table_html_list = all_visual_info
+
+        all_normal_text_str = "".join(["\n".join(e.values()) for e in all_normal_text_list])
+        vector_info = {}
+
+        all_items = []
+        for i, normal_text_dict in enumerate(all_normal_text_list):
+            for type, text in normal_text_dict.items():
+                all_items += [f"{type}:{text}"]
+
+        if len(all_normal_text_str) > min_characters:
+            vector_info['flag_too_short_text'] = False
+            vector_info['vector'] = self.retriever.generate_vector_database(
+                all_items)
+        else:
+            vector_info['flag_too_short_text'] = True  
+            vector_info['vector'] = all_items
+        return vector_info
+
+    def format_key(self, key_list):
+        """format key"""
+        if key_list == "":
+            return []
+
+        if isinstance(key_list, list):
+            return key_list
+
+        if isinstance(key_list, str):
+            key_list = re.sub(r"[\t\n\r\f\v]", "", key_list)
+            key_list = key_list.replace(",", ",").split(",")
+            return key_list
+
+        return []
+
+    def fix_llm_result_format(self, llm_result):
+        """Normalize the raw LLM output into a flat {key: value} dict, tolerating loosely formatted JSON."""
+        if not llm_result:
+            return {}
+
+        if "json" in llm_result or "```" in llm_result:
+            llm_result = (
+                llm_result.replace("```", "").replace("json", "").replace("/n", "")
+            )
+            llm_result = llm_result.replace("[", "").replace("]", "")
+
+        try:
+            llm_result = json.loads(llm_result)
+            llm_result_final = {}
+            for key in llm_result:
+                value = llm_result[key]
+                if isinstance(value, list):
+                    if len(value) > 0:
+                        llm_result_final[key] = value[0]
+                else:
+                    llm_result_final[key] = value
+            return llm_result_final
+
+        except Exception:
+            results = (
+                llm_result.replace("\n", "")
+                .replace("    ", "")
+                .replace("{", "")
+                .replace("}", "")
+            )
+            if not results.endswith('"'):
+                results = results + '"'
+            pattern = r'"(.*?)": "([^"]*)"'
+            matches = re.findall(pattern, str(results))
+            if len(matches) > 0:
+                llm_result = {k: v for k, v in matches}
+                return llm_result 
+            else:
+                return {}     
+
+    def generate_and_merge_chat_results(self, prompt, key_list,
+        final_results, failed_results):
+        """Query the LLM with the prompt and move successfully extracted keys from key_list into final_results."""
+
+        llm_result = self.chat_bot.generate_chat_results(prompt)
+        llm_result = self.fix_llm_result_format(llm_result)
+
+        for key, value in llm_result.items():
+            if value not in failed_results and key in key_list:
+                key_list.remove(key)
+                final_results[key] = value
+        return 
+        
+
+    def chat(self, visual_info, 
+        key_list, 
+        vector_info,
+        text_task_description=None,
+        text_output_format=None,
+        text_rules_str=None,
+        text_few_shot_demo_text_content=None,
+        text_few_shot_demo_key_value_list=None,        
+        table_task_description=None,
+        table_output_format=None,
+        table_rules_str=None,
+        table_few_shot_demo_text_content=None,
+        table_few_shot_demo_key_value_list=None):
+        """Extract the values of the requested keys, querying the LLM with table information first and then with retrieved text."""
+
+        key_list = self.format_key(key_list)
+        if len(key_list) == 0:
+            return {"chat_res": "输入的key_list无效!"}
+
+        if not isinstance(visual_info, list):
+            visual_info_list = [visual_info]
+        else:
+            visual_info_list = visual_info
+        
+        all_visual_info = self.merge_visual_info_list(visual_info_list)
+        all_normal_text_list, all_table_text_list, all_table_html_list = all_visual_info
+
+        final_results = {}
+        failed_results = ["大模型调用失败", "未知", "未找到关键信息", "None", ""]
+
+        for all_table_info in [all_table_html_list, all_table_text_list]:
+            for table_info in all_table_info:
+                if len(key_list) == 0:
+                    continue
+
+                prompt = self.table_pe.generate_prompt(table_info, 
+                    key_list, 
+                    task_description=table_task_description,
+                    output_format=table_output_format, 
+                    rules_str=table_rules_str, 
+                    few_shot_demo_text_content=table_few_shot_demo_text_content, 
+                    few_shot_demo_key_value_list=table_few_shot_demo_key_value_list)
+
+                self.generate_and_merge_chat_results(prompt, 
+                    key_list, final_results, failed_results)
+        
+        if len(key_list) > 0:
+            question_key_list = [f"抽取关键信息:{key}" for key in key_list]
+            vector = vector_info['vector']
+            if not vector_info['flag_too_short_text']:
+                related_text = self.retriever.similarity_retrieval(
+                    question_key_list, vector)
+            else:
+                related_text = " ".join(vector)
+            
+            prompt = self.text_pe.generate_prompt(related_text, 
+                key_list, 
+                task_description=text_task_description,
+                output_format=text_output_format, 
+                rules_str=text_rules_str, 
+                few_shot_demo_text_content=text_few_shot_demo_text_content, 
+                few_shot_demo_key_value_list=text_few_shot_demo_key_value_list)
+            
+            self.generate_and_merge_chat_results(prompt, 
+                key_list, final_results, failed_results)
+
+        return final_results
+
+    def predict(self, *args, **kwargs):
+        logging.error(
+            "The PP-ChatOCRv3-doc pipeline does not support calling `predict()` directly! "
+            "Please invoke `visual_predict`, `build_vector` and `chat` in sequence to obtain the result."
+        )
+        return
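
A minimal sketch of the intended calling sequence, assuming the pipeline name from paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml; the input image and the extracted key are illustrative:

    from paddlex import create_pipeline

    pipeline = create_pipeline(pipeline="PP-ChatOCRv3-doc")

    visual_info_list = []
    for res in pipeline.visual_predict("contract.jpg"):
        visual_info_list.append(res["visual_info"])

    vector_info = pipeline.build_vector(visual_info_list)
    chat_result = pipeline.chat(visual_info_list, "乙方", vector_info)  # key(s) to extract
    print(chat_result)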

+ 44 - 0
paddlex/inference/pipelines_new/pp_chatocrv3_doc/result.py

@@ -0,0 +1,44 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import random
+import numpy as np
+import cv2
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+
+from ....utils.fonts import PINGFANG_FONT_FILE_PATH
+from ..components import BaseResult
+
+class VisualInfoResult(BaseResult):
+    """VisualInfoResult"""
+    
+    pass
+
+# class VectorResult(BaseResult, Base64Mixin):
+#     """VisualInfoResult"""
+
+#     def _to_base64(self):
+#         return self["vector"]
+
+
+# class RetrievalResult(BaseResult):
+#     """VisualInfoResult"""
+
+#     pass
+
+
+# class ChatResult(BaseResult):
+#     """VisualInfoResult"""

+ 2 - 0
paddlex/utils/flags.py

@@ -26,6 +26,7 @@ __all__ = [
     "INFER_BENCHMARK_OUTPUT",
     "INFER_BENCHMARK_DATA_SIZE",
     "FLAGS_json_format_model",
+    "USE_NEW_INFERENCE",
 ]
 
 
@@ -46,6 +47,7 @@ DRY_RUN = get_flag_from_env_var("PADDLE_PDX_DRY_RUN", False)
 CHECK_OPTS = get_flag_from_env_var("PADDLE_PDX_CHECK_OPTS", False)
 EAGER_INITIALIZATION = get_flag_from_env_var("PADDLE_PDX_EAGER_INIT", True)
 FLAGS_json_format_model = get_flag_from_env_var("FLAGS_json_format_model", None)
+USE_NEW_INFERENCE = get_flag_from_env_var("USE_NEW_INFERENCE", False)
 
 # Inference Benchmark
 INFER_BENCHMARK = get_flag_from_env_var("PADDLE_PDX_INFER_BENCHMARK", None)
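
The new flag is read from the environment when paddlex is imported; a hedged sketch of opting into the new pipeline architecture (the accepted truthy spelling depends on get_flag_from_env_var):

    import os

    os.environ["USE_NEW_INFERENCE"] = "True"  # must be set before importing paddlex

    from paddlex import create_pipeline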

+ 15 - 0
paddlex/utils/fonts/__init__.py

@@ -15,10 +15,25 @@
 
 from pathlib import Path
 
+import PIL
+from PIL import ImageFont
 
 def get_pingfang_file_path():
     """get pingfang font file path"""
     return (Path(__file__).parent / "PingFang-SC-Regular.ttf").resolve().as_posix()
 
+def create_font(txt, sz, font_path):
+    """create font"""
+    font_size = int(sz[1] * 0.8)
+    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+    if int(PIL.__version__.split(".")[0]) < 10:
+        length = font.getsize(txt)[0]
+    else:
+        length = font.getlength(txt)
+
+    if length > sz[0]:
+        font_size = int(font_size * sz[0] / length)
+        font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+    return font
 
 PINGFANG_FONT_FILE_PATH = get_pingfang_file_path()
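
A short sketch of the shared helper picking a font size that fits a target box (the text and box size are illustrative):

    from paddlex.utils.fonts import PINGFANG_FONT_FILE_PATH, create_font

    # Fit the text into a 200x40 pixel region; the size is reduced if the rendered text is wider than 200 px.
    font = create_font("layout parsing", (200, 40), PINGFANG_FONT_FILE_PATH)
    print(font.size)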