
modify layout parsing and pp-chatocr to support different versions (#2751)

* add the new architecture of pipelines

* add the new architecture of pipelines

* add explanatory note

* add explanatory note

* fix some module names

* add pipelines of single module, seal recognition and table recognition.

* support taking PDF and original image

* add PP-ChatOCRv4 and support PDF

* add PP-ChatOCRv4 and support PDF

* modify layout parsing and pp-chatocr to support different versions

* modify layout parsing and pp-chatocr to support different versions

* modify layout parsing and pp-chatocr to support different versions
dyning 10 months ago
parent
commit
298cef7270
50 changed files with 2493 additions and 576 deletions
  1. + 21 - 4  api_examples/pipelines/test_doc_preprocessor.py
  2. + 3 - 0  api_examples/pipelines/test_image_classification.py
  3. + 30 - 3  api_examples/pipelines/test_layout_parsing.py
  4. + 2 - 2  api_examples/pipelines/test_ocr.py
  5. + 14 - 4  api_examples/pipelines/test_pp_chatocrv3.py
  6. + 91 - 0  api_examples/pipelines/test_pp_chatocrv4.py
  7. + 6 - 0  api_examples/pipelines/test_seal_recognition.py
  8. + 5 - 0  api_examples/pipelines/test_table_recognition.py
  9. + 0 - 22  paddlex/configs/pipelines/OCR.yaml
  10. + 41 - 21  paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml
  11. + 163 - 0  paddlex/configs/pipelines/PP-ChatOCRv4-doc.yaml
  12. + 30 - 17  paddlex/configs/pipelines/layout_parsing.yaml
  13. + 1 - 0  paddlex/configs/pipelines/seal_recognition.yaml
  14. + 2 - 0  paddlex/configs/pipelines/table_recognition.yaml
  15. + 0 - 2  paddlex/inference/common/result/base_cv_result.py
  16. + 8 - 0  paddlex/inference/common/result/mixin.py
  17. + 3 - 2  paddlex/inference/pipelines_new/__init__.py
  18. + 3 - 3  paddlex/inference/pipelines_new/base.py
  19. + 1 - 2  paddlex/inference/pipelines_new/components/chat_server/base.py
  20. + 52 - 2  paddlex/inference/pipelines_new/components/chat_server/ernie_bot_chat.py
  21. + 1 - 2  paddlex/inference/pipelines_new/components/common/base_result.py
  22. + 5 - 5  paddlex/inference/pipelines_new/components/common/crop_image_regions.py
  23. + 2 - 3  paddlex/inference/pipelines_new/components/common/seal_det_warp.py
  24. + 1 - 1  paddlex/inference/pipelines_new/components/common/sort_boxes.py
  25. + 1 - 0  paddlex/inference/pipelines_new/components/prompt_engeering/__init__.py
  26. + 1 - 2  paddlex/inference/pipelines_new/components/prompt_engeering/base.py
  27. + 127 - 0  paddlex/inference/pipelines_new/components/prompt_engeering/generate_ensemble_prompt.py
  28. + 23 - 17  paddlex/inference/pipelines_new/components/prompt_engeering/generate_kie_prompt.py
  29. + 1 - 2  paddlex/inference/pipelines_new/components/retriever/base.py
  30. + 12 - 11  paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py
  31. + 9 - 21  paddlex/inference/pipelines_new/doc_preprocessor/pipeline.py
  32. + 18 - 11  paddlex/inference/pipelines_new/doc_preprocessor/result.py
  33. + 2 - 2  paddlex/inference/pipelines_new/image_classification/pipeline.py
  34. + 104 - 115  paddlex/inference/pipelines_new/layout_parsing/pipeline.py
  35. + 29 - 125  paddlex/inference/pipelines_new/layout_parsing/result.py
  36. + 4 - 3  paddlex/inference/pipelines_new/ocr/pipeline.py
  37. + 4 - 4  paddlex/inference/pipelines_new/ocr/result.py
  38. + 16 - 0  paddlex/inference/pipelines_new/pp_chatocr/__init__.py
  39. + 106 - 0  paddlex/inference/pipelines_new/pp_chatocr/pipeline_base.py
  40. + 142 - 155  paddlex/inference/pipelines_new/pp_chatocr/pipeline_v3.py
  41. + 635 - 0  paddlex/inference/pipelines_new/pp_chatocr/pipeline_v4.py
  42. + 1 - 2  paddlex/inference/pipelines_new/pp_chatocr/result.py
  43. + 1 - 1  paddlex/inference/pipelines_new/seal_recognition/__init__.py
  44. + 228 - 0  paddlex/inference/pipelines_new/seal_recognition/pipeline.py
  45. + 55 - 0  paddlex/inference/pipelines_new/seal_recognition/result.py
  46. + 15 - 0  paddlex/inference/pipelines_new/table_recognition/__init__.py
  47. + 310 - 0  paddlex/inference/pipelines_new/table_recognition/pipeline.py
  48. + 111 - 0  paddlex/inference/pipelines_new/table_recognition/result.py
  49. + 9 - 10  paddlex/inference/pipelines_new/table_recognition/table_recognition_post_processing.py
  50. + 44 - 0  paddlex/inference/pipelines_new/table_recognition/utils.py

+ 21 - 4
api_examples/pipelines/test_doc_preprocessor.py

@@ -16,13 +16,30 @@ from paddlex import create_pipeline
 
 pipeline = create_pipeline(pipeline="doc_preprocessor")
 
-test_img_path = "./test_samples/img_rot180_demo.jpg"
-# test_img_path = "./test_samples/doc_distort_test.jpg"
-
 output = pipeline.predict(
-    test_img_path, use_doc_orientation_classify=True, use_doc_unwarping=True
+    "./test_samples/img_rot180_demo.jpg",
+    use_doc_orientation_classify=True,
+    use_doc_unwarping=False,
 )
 
+# output = pipeline.predict(
+#     "./test_samples/doc_distort_test.jpg",
+#     use_doc_orientation_classify=False,
+#     use_doc_unwarping=True
+# )
+
+# output = pipeline.predict(
+#     "./test_samples/doc_distort_test.jpg",
+#     use_doc_orientation_classify=True,
+#     use_doc_unwarping=True
+# )
+
+# output = pipeline.predict(
+#     "./test_samples/test_doc_processer.pdf",
+#     use_doc_orientation_classify=True,
+#     use_doc_unwarping=False
+# )
+
 for res in output:
     print(res)
     res.save_to_img("./output")

+ 3 - 0
api_examples/pipelines/test_image_classification.py

@@ -17,6 +17,9 @@ from paddlex import create_pipeline
 pipeline = create_pipeline(pipeline="image_classification")
 
 output = pipeline.predict("./test_samples/general_image_classification_001.jpg")
+
+# output = pipeline.predict("./test_samples/财报1.pdf")
+
 for res in output:
     print(res)
     res.print()  ## 打印预测的结构化输出

+ 30 - 3
api_examples/pipelines/test_layout_parsing.py

@@ -17,14 +17,41 @@ from paddlex import create_pipeline
 pipeline = create_pipeline(pipeline="layout_parsing")
 
 output = pipeline.predict(
-    "./test_samples/test_layout_parsing.jpg",
-    use_doc_orientation_classify=True,
-    use_doc_unwarping=True,
+    "./test_samples/demo_paper.png",
+    use_doc_orientation_classify=False,
+    use_doc_unwarping=False,
     use_common_ocr=True,
     use_seal_recognition=True,
     use_table_recognition=True,
 )
 
+# output = pipeline.predict(
+#     "./test_samples/layout.jpg",
+#     use_doc_orientation_classify=False,
+#     use_doc_unwarping=False,
+#     use_common_ocr=True,
+#     use_seal_recognition=True,
+#     use_table_recognition=True,
+# )
+
+# output = pipeline.predict(
+#     "./test_samples/test_layout_parsing.jpg",
+#     use_doc_orientation_classify=True,
+#     use_doc_unwarping=True,
+#     use_common_ocr=True,
+#     use_seal_recognition=True,
+#     use_table_recognition=True,
+# )
+
+# output = pipeline.predict(
+#     "./test_samples/财报1.pdf",
+#     use_doc_orientation_classify=False,
+#     use_doc_unwarping=False,
+#     use_common_ocr=True,
+#     use_seal_recognition=True,
+#     use_table_recognition=True,
+# )
+
 for res in output:
     print(res)
     res.save_results("./output")

+ 2 - 2
api_examples/pipelines/test_ocr.py

@@ -16,9 +16,9 @@ from paddlex import create_pipeline
 
 pipeline = create_pipeline(pipeline="OCR")
 
-# output = pipeline.predict("./test_samples/general_ocr_002.png")
+output = pipeline.predict("./test_samples/general_ocr_002.png")
 
-output = pipeline.predict("./test_samples/seal_text_det.png")
+# output = pipeline.predict("./test_samples/财报1.pdf")
 for res in output:
     print(res)
     res.save_to_img("./output")

+ 14 - 4
api_examples/pipelines/test_pp_chatocrv3.py

@@ -19,11 +19,13 @@ pipeline = create_pipeline(pipeline="PP-ChatOCRv3-doc")
 img_path = "./test_samples/vehicle_certificate-1.png"
 key_list = ["驾驶室准乘人数"]
 
+# img_path = "./test_samples/财报1.pdf"
+# key_list = ['公司全称是什么']
 
 visual_predict_res = pipeline.visual_predict(
     img_path,
-    use_doc_orientation_classify=True,
-    use_doc_unwarping=True,
+    use_doc_orientation_classify=False,
+    use_doc_unwarping=False,
     use_common_ocr=True,
     use_seal_recognition=True,
     use_table_recognition=True,
@@ -36,12 +38,20 @@ for res in visual_predict_res:
     # print(res["visual_info"])
     visual_info_list.append(res["visual_info"])
 
-# pipeline.save_visual_info_list(visual_info_list, "./res_visual_info/tmp_visual_info.json")
+pipeline.save_visual_info_list(
+    visual_info_list, "./res_visual_info/tmp_visual_info.json"
+)
 
-# visual_info_list = pipeline.load_visual_info_list("./res_visual_info/tmp_visual_info.json")
+visual_info_list = pipeline.load_visual_info_list(
+    "./res_visual_info/tmp_visual_info.json"
+)
 
 vector_info = pipeline.build_vector(visual_info_list)
 
+pipeline.save_vector(vector_info, "./res_visual_info/tmp_vector_info.json")
+
+vector_info = pipeline.load_vector("./res_visual_info/tmp_vector_info.json")
+
 chat_result = pipeline.chat(key_list, visual_info_list, vector_info=vector_info)
 
 print(chat_result)

+ 91 - 0
api_examples/pipelines/test_pp_chatocrv4.py

@@ -0,0 +1,91 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddlex import create_pipeline
+
+pipeline = create_pipeline(pipeline="PP-ChatOCRv4-doc")
+
+img_path = "./test_samples/研报2_11.jpg"
+key_list = ["三位一体养老生态系统包含哪些"]
+
+# img_path = "./test_samples/财报1.pdf"
+# key_list = ['公司全称是什么']
+
+
+def load_mllm_results():
+    """load mllm results"""
+    import json
+
+    predict_file_path = "/paddle/icode/baidu/paddlex_closed/evaluation/pipelines/ppchatocr/backend_predict_files/predict_mix_doc_v1_2B-1209.json"
+    mllm_predict_dict = {}
+    with open(predict_file_path, "r") as fin:
+        predict_infos_list = json.load(fin)
+        for predict_infos in predict_infos_list:
+            img_name = predict_infos["image_path"]
+            predict_info_list = predict_infos["predict_info_list"]
+            for predict_info in predict_info_list:
+                key = img_name + "_" + predict_info["question"]
+                mllm_predict_dict[key] = predict_info
+    return mllm_predict_dict
+
+
+mllm_predict_dict_all = load_mllm_results()
+
+visual_predict_res = pipeline.visual_predict(
+    img_path,
+    use_doc_orientation_classify=False,
+    use_doc_unwarping=False,
+    use_common_ocr=True,
+    use_seal_recognition=True,
+    use_table_recognition=True,
+)
+
+# ####[TODO] 增加类别信息
+visual_info_list = []
+for res in visual_predict_res:
+    # res['layout_parsing_result'].save_results("./output/")
+    # print(res["visual_info"])
+    visual_info_list.append(res["visual_info"])
+
+pipeline.save_visual_info_list(
+    visual_info_list, "./res_visual_info/tmp_visual_info.json"
+)
+
+visual_info_list = pipeline.load_visual_info_list(
+    "./res_visual_info/tmp_visual_info.json"
+)
+
+vector_info = pipeline.build_vector(visual_info_list)
+
+pipeline.save_vector(vector_info, "./res_visual_info/tmp_vector_info.json")
+
+vector_info = pipeline.load_vector("./res_visual_info/tmp_vector_info.json")
+
+mllm_predict_dict = {}
+image_name = img_path.split("/")[-1]
+for key in key_list:
+    mllm_predict_key = image_name + "_" + key
+    mllm_result = ""
+    if mllm_predict_key in mllm_predict_dict_all:
+        mllm_result = mllm_predict_dict_all[mllm_predict_key]["predicts"]
+    mllm_predict_dict[key] = mllm_result
+
+chat_result = pipeline.chat(
+    key_list,
+    visual_info_list,
+    vector_info=vector_info,
+    mllm_predict_dict=mllm_predict_dict,
+)
+
+print(chat_result)

+ 6 - 0
api_examples/pipelines/test_seal_recognition.py

@@ -16,6 +16,12 @@ from paddlex import create_pipeline
 
 pipeline = create_pipeline(pipeline="seal_recognition")
 output = pipeline.predict("./test_samples/seal_text_det.png")
+
+# output = pipeline.predict("./test_samples/seal_text_det.png",
+#     use_layout_detection=False)
+
+# output = pipeline.predict("./test_samples/财报1.pdf")
+
 for res in output:
     print(res)
     res.save_results("./output")

+ 5 - 0
api_examples/pipelines/test_table_recognition.py

@@ -17,6 +17,11 @@ from paddlex import create_pipeline
 pipeline = create_pipeline(pipeline="table_recognition")
 
 output = pipeline("./test_samples/table_recognition.jpg")
+
+# output = pipeline("./test_samples/table_recognition.jpg",
+#     use_layout_detection=False)
+
+# output = pipeline("./test_samples/财报1.pdf")
 for res in output:
     print(res)
     res.save_results("./output/")

+ 0 - 22
paddlex/configs/pipelines/OCR.yaml

@@ -1,10 +1,6 @@
 
 pipeline_name: OCR
 
-##############################################
-####### Config for Common OCR
-##############################################
-
 text_type: general
 
 SubModules:
@@ -18,21 +14,3 @@ SubModules:
     model_name: PP-OCRv4_mobile_rec
     model_dir: null
     batch_size: 1
-
-##############################################
-####### Config for Seal OCR
-##############################################
-
-# text_type: seal
-
-# SubModules:
-#   TextDetection:
-#     module_name: seal_text_detection
-#     model_name: PP-OCRv4_mobile_seal_det
-#     model_dir: null
-#     batch_size: 1    
-#   TextRecognition:
-#     module_name: text_recognition
-#     model_name: PP-OCRv4_mobile_rec
-#     model_dir: null
-#     batch_size: 1

+ 41 - 21
paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml

@@ -9,6 +9,7 @@ SubModules:
     ak: "api_key" # Set this to a real API key
     sk: "secret_key"  # Set this to a real secret key
 
+
   LLM_Retriever:
     module_name: retriever
     model_name: ernie-3.5
@@ -16,41 +17,48 @@ SubModules:
     ak: "api_key" # Set this to a real API key
     sk: "secret_key"  # Set this to a real secret key
 
+
   PromptEngneering:
     KIE_CommonText:
       module_name: prompt_engneering
-      task_type: text_kie_prompt
+      task_type: text_kie_prompt_v1
 
       task_description: '你现在的任务是从OCR文字识别的结果中提取关键词列表中每一项对应的关键信息。
           OCR的文字识别结果使用```符号包围,包含所识别出来的文字,顺序在原始图片中从左至右、从上至下。
           我指定的关键词列表使用[]符号包围。请注意OCR的文字识别结果可能存在长句子换行被切断、不合理的分词、
           文字被错误合并等问题,你需要结合上下文语义进行综合判断,以抽取准确的关键信息。'
 
+      rules_str:
+
       output_format: '在返回结果时使用JSON格式,包含多个key-value对,key值为我指定的问题,value值为该问题对应的答案。
           如果认为OCR识别结果中,对于问题key,没有答案,则将value赋值为"未知"。请只输出json格式的结果,
           并做json格式校验后返回,不要包含其它多余文字!'
 
-      rules_str:
       few_shot_demo_text_content:
+
       few_shot_demo_key_value_list:
           
     KIE_Table:
       module_name: prompt_engneering
-      task_type: table_kie_prompt
+      task_type: table_kie_prompt_v1
 
       task_description: '你现在的任务是从输入的表格内容中提取关键词列表中每一项对应的关键信息,
           表格内容用```符号包围,我指定的关键词列表使用[]符号包围。你需要结合上下文语义进行综合判断,以抽取准确的关键信息。'
+      
+      rules_str:
 
       output_format: '在返回结果时使用JSON格式,包含多个key-value对,key值为我指定的关键词,value值为所抽取的结果。
           如果认为表格识别结果中没有关键词key对应的value,则将value赋值为"未知"。请只输出json格式的结果,
           并做json格式校验后返回,不要包含其它多余文字!'
-      rules_str:
+      
       few_shot_demo_text_content:
+
       few_shot_demo_key_value_list:
 
 SubPipelines:
   LayoutParser:
     pipeline_name: layout_parsing
+    
     use_doc_preprocessor: True
     use_general_ocr: True
     use_seal_recognition: True
@@ -62,11 +70,6 @@ SubPipelines:
         model_name: RT-DETR-H_layout_3cls
         model_dir: null
         batch_size: 1
-      TableStructureRecognition:
-        module_name: table_structure_recognition
-        model_name: SLANet_plus
-        model_dir: null
-        batch_size: 1
 
     SubPipelines:
       DocPreprocessor:
@@ -100,17 +103,34 @@ SubPipelines:
             model_dir: null
             batch_size: 1
 
-      SealOCR:
-        pipeline_name: OCR
-        text_type: seal
+      TableRecognition:
+        pipeline_name: table_recognition
+        use_layout_detection: False
+        use_doc_preprocessor: False
+        use_ocr_model: False
         SubModules:
-          TextDetection:
-            module_name: seal_text_detection
-            model_name: PP-OCRv4_server_seal_det
-            model_dir: null
-            batch_size: 1    
-          TextRecognition:
-            module_name: text_recognition
-            model_name: PP-OCRv4_server_rec
+          TableStructureRecognition:
+            module_name: table_structure_recognition
+            model_name: SLANet_plus
             model_dir: null
-            batch_size: 1  
+            batch_size: 1
+
+      SealRecognition:
+        pipeline_name: seal_recognition
+        use_layout_detection: False
+        use_doc_preprocessor: False
+        SubPipelines:
+          SealOCR:
+            pipeline_name: OCR
+            text_type: seal
+            SubModules:
+              TextDetection:
+                module_name: seal_text_detection
+                model_name: PP-OCRv4_server_seal_det
+                model_dir: null
+                batch_size: 1    
+              TextRecognition:
+                module_name: text_recognition
+                model_name: PP-OCRv4_server_rec
+                model_dir: null
+                batch_size: 1

+ 163 - 0
paddlex/configs/pipelines/PP-ChatOCRv4-doc.yaml

@@ -0,0 +1,163 @@
+
+pipeline_name: PP-ChatOCRv4-doc
+
+use_mllm_predict: True
+
+SubModules:
+  LLM_Chat:
+    module_name: chat_bot
+    model_name: ernie-3.5
+    api_type: qianfan
+    ak: "api_key" # Set this to a real API key
+    sk: "secret_key"  # Set this to a real secret key     
+
+  LLM_Retriever:
+    module_name: retriever
+    model_name: ernie-3.5
+    api_type: qianfan
+    ak: "api_key" # Set this to a real API key
+    sk: "secret_key"  # Set this to a real secret key
+
+  PromptEngneering:
+    KIE_CommonText:
+      module_name: prompt_engneering
+
+      task_type: text_kie_prompt_v2
+
+      task_description: '你现在的任务是从OCR结果中提取问题列表中每一个问题的答案。
+          OCR的文字识别结果使用```符号包围,包含所识别出来的文字,顺序在原始图片中从左至右、从上至下。
+          我指定的问题列表使用[]符号包围。请注意OCR的文字识别结果可能存在长句子换行被切断、不合理的分词、
+          文字被错误合并等问题,你需要结合上下文语义进行综合判断,以获取准确的答案。'
+
+      output_format: '在返回结果时使用JSON格式,包含多个key-value对,key值为我指定的问题,value值为该问题对应的答案。
+          如果认为OCR识别结果中,对于问题key,没有答案,则将value赋值为"未知"。请只输出json格式的结果,
+          并做json格式校验后返回,不要包含其它多余文字!'
+
+      rules_str: '请依次确认满足下面要求。(1)每个问题的答案用OCR结果的内容回答,针对问题回答尽可能详细和完整,
+          并保持格式、数字、正负号、单位、符号和标点都与OCR结果中的内容完全一致。
+          (2)如果答案的句末有标点符号,请添加标点符号。
+          (3)对于答案中的数字,如果可以推断出单位,请补充相应的单位。'
+
+      few_shot_demo_text_content:
+
+      few_shot_demo_key_value_list:
+          
+    KIE_Table:
+      module_name: prompt_engneering
+
+      task_type: table_kie_prompt_v2
+
+      task_description: '你现在的任务是从输入的html格式的表格内容中提取问题列表中每一个问题的答案。
+          表格内容使用```符号包围,我指定的问题列表使用[]符号包围。'
+
+      output_format: '在返回结果时使用JSON格式,包含多个key-value对,key值为我指定的问题,value值为该问题对应的答案。
+          如果认为表格内容中,对于问题key,没有答案,则将value赋值为"未知"。请只输出json格式的结果,
+          并做json格式校验后返回,不要包含其它多余文字!'
+
+      rules_str: '请依次确认满足下面要求。(1)每个问题的答案用表格内容回答,针对问题回答尽可能详细和完整,
+          并保持格式、数字、正负号、单位、符号和标点都与表格内容完全一致。
+          (2)对于答案中的数字,如果可以推断出单位,请补充相应的单位。
+          (3)如果答案是百分比,请添加百分号。'
+
+      few_shot_demo_text_content:
+
+      few_shot_demo_key_value_list:
+    
+    Ensemble:
+      module_name: prompt_engneering
+      
+      task_type: ensemble_prompt
+
+      task_description: '你现在的任务是,对于一个问题,对比方法A和方法B的结果,选择更准确的一个回答。
+        问题用```符号包围。'
+
+      output_format: '请返回JSON格式的结果,包含多个key-value对,key值为我指定的问题,
+        value值为`方法A`或`方法B`。如果对于问题key,没有找到答案,将value赋值为"未知"。
+        请只输出json格式的结果,并做json格式校验后返回,不要包含其它多余文字!'
+
+      rules_str: '对于涉及数字的问题,请返回与问题描述最相关且数字表述正确的答案。
+        请特别注意数字中的标点使用是否合理。'
+
+      few_shot_demo_text_content:
+
+      few_shot_demo_key_value_list:
+
+SubPipelines:
+  LayoutParser:
+    pipeline_name: layout_parsing
+    
+    use_doc_preprocessor: True
+    use_general_ocr: True
+    use_seal_recognition: True
+    use_table_recognition: True
+
+    SubModules:
+      LayoutDetection:
+        module_name: layout_detection
+        model_name: RT-DETR-H_layout_3cls
+        model_dir: null
+        batch_size: 1
+
+    SubPipelines:
+      DocPreprocessor:
+        pipeline_name: doc_preprocessor
+        use_doc_orientation_classify: True
+        use_doc_unwarping: True
+        SubModules:
+          DocOrientationClassify:
+            module_name: doc_text_orientation
+            model_name: PP-LCNet_x1_0_doc_ori
+            model_dir: null
+            batch_size: 1
+          DocUnwarping:
+            module_name: image_unwarping
+            model_name: UVDoc
+            model_dir: null
+            batch_size: 1
+
+      GeneralOCR:
+        pipeline_name: OCR
+        text_type: general
+        SubModules:
+          TextDetection:
+            module_name: text_detection
+            model_name: PP-OCRv4_server_det
+            model_dir: null
+            batch_size: 1    
+          TextRecognition:
+            module_name: text_recognition
+            model_name: PP-OCRv4_server_rec
+            model_dir: null
+            batch_size: 1
+
+      TableRecognition:
+        pipeline_name: table_recognition
+        use_layout_detection: False
+        use_doc_preprocessor: False
+        use_ocr_model: False
+        SubModules:
+          TableStructureRecognition:
+            module_name: table_structure_recognition
+            model_name: SLANet_plus
+            model_dir: null
+            batch_size: 1
+
+      SealRecognition:
+        pipeline_name: seal_recognition
+        use_layout_detection: False
+        use_doc_preprocessor: False
+        SubPipelines:
+          SealOCR:
+            pipeline_name: OCR
+            text_type: seal
+            SubModules:
+              TextDetection:
+                module_name: seal_text_detection
+                model_name: PP-OCRv4_server_seal_det
+                model_dir: null
+                batch_size: 1    
+              TextRecognition:
+                module_name: text_recognition
+                model_name: PP-OCRv4_server_rec
+                model_dir: null
+                batch_size: 1
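
For reference, a condensed sketch of how this new config is exercised, following test_pp_chatocrv4.py above. The image path and question are the ones from that test; the mllm_predict_dict is hand-built here rather than loaded from a multimodal model's prediction file (a simplification for illustration), and the ak/sk placeholders in the YAML must be replaced with real credentials first.

from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="PP-ChatOCRv4-doc")

key_list = ["三位一体养老生态系统包含哪些"]

# Visual analysis: layout parsing + OCR + seal/table recognition.
visual_info_list = []
for res in pipeline.visual_predict(
    "./test_samples/研报2_11.jpg",
    use_doc_orientation_classify=False,
    use_doc_unwarping=False,
    use_common_ocr=True,
    use_seal_recognition=True,
    use_table_recognition=True,
):
    visual_info_list.append(res["visual_info"])

# Build the retrieval index over the collected visual info.
vector_info = pipeline.build_vector(visual_info_list)

# Answers from a separately run multimodal LLM, keyed by question.
# Left empty here; normally filled as in the loader shown in test_pp_chatocrv4.py.
mllm_predict_dict = {key: "" for key in key_list}

chat_result = pipeline.chat(
    key_list,
    visual_info_list,
    vector_info=vector_info,
    mllm_predict_dict=mllm_predict_dict,
)
print(chat_result)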

+ 30 - 17
paddlex/configs/pipelines/layout_parsing.yaml

@@ -13,12 +13,6 @@ SubModules:
     model_dir: null
     batch_size: 1
 
-  TableStructureRecognition:
-    module_name: table_structure_recognition
-    model_name: SLANet_plus
-    model_dir: null
-    batch_size: 1
-
 SubPipelines:
   DocPreprocessor:
     pipeline_name: doc_preprocessor
@@ -35,6 +29,7 @@ SubPipelines:
         model_name: UVDoc
         model_dir: null
         batch_size: 1
+
   GeneralOCR:
     pipeline_name: OCR
     text_type: general
@@ -49,17 +44,35 @@ SubPipelines:
         model_name: PP-OCRv4_server_rec
         model_dir: null
         batch_size: 1
-  SealOCR:
-    pipeline_name: OCR
-    text_type: seal
+
+  TableRecognition:
+    pipeline_name: table_recognition
+    use_layout_detection: False
+    use_doc_preprocessor: False
+    use_ocr_model: False
     SubModules:
-      TextDetection:
-        module_name: seal_text_detection
-        model_name: PP-OCRv4_server_seal_det
-        model_dir: null
-        batch_size: 1    
-      TextRecognition:
-        module_name: text_recognition
-        model_name: PP-OCRv4_server_rec
+      TableStructureRecognition:
+        module_name: table_structure_recognition
+        model_name: SLANet_plus
         model_dir: null
         batch_size: 1
+
+  SealRecognition:
+    pipeline_name: seal_recognition
+    use_layout_detection: False
+    use_doc_preprocessor: False
+    SubPipelines:
+      SealOCR:
+        pipeline_name: OCR
+        text_type: seal
+        SubModules:
+          TextDetection:
+            module_name: seal_text_detection
+            model_name: PP-OCRv4_server_seal_det
+            model_dir: null
+            batch_size: 1    
+          TextRecognition:
+            module_name: text_recognition
+            model_name: PP-OCRv4_server_rec
+            model_dir: null
+            batch_size: 1
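
A minimal call that exercises the reorganized sub-pipelines, mirroring test_layout_parsing.py above; each predict-time flag maps onto one of the SubPipelines configured here.

from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="layout_parsing")
output = pipeline.predict(
    "./test_samples/demo_paper.png",
    use_doc_orientation_classify=False,  # DocPreprocessor sub-pipeline
    use_doc_unwarping=False,
    use_common_ocr=True,                 # GeneralOCR sub-pipeline
    use_seal_recognition=True,           # SealRecognition sub-pipeline
    use_table_recognition=True,          # TableRecognition sub-pipeline
)
for res in output:
    res.save_results("./output")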

+ 1 - 0
paddlex/configs/pipelines/seal_recognition.yaml

@@ -1,6 +1,7 @@
 
 pipeline_name: seal_recognition
 
+use_layout_detection: True
 use_doc_preprocessor: True
 
 SubModules:
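
The new use_layout_detection switch can also be overridden per call, as in test_seal_recognition.py above; skipping layout detection makes sense when the input is already a cropped seal region.

from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="seal_recognition")
# The sample image is already a cropped seal, so layout detection is skipped.
output = pipeline.predict("./test_samples/seal_text_det.png", use_layout_detection=False)
for res in output:
    res.save_results("./output")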

+ 2 - 0
paddlex/configs/pipelines/table_recognition.yaml

@@ -1,7 +1,9 @@
 
 pipeline_name: table_recognition
 
+use_layout_detection: True
 use_doc_preprocessor: True
+use_ocr_model: True
 
 SubModules:
   LayoutDetection:
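
The standalone table recognition pipeline takes the same per-call switch (see test_table_recognition.py above); the new use_ocr_model flag, set here in the YAML, appears to control whether the pipeline loads its own OCR models or expects OCR results from a parent pipeline, as in the layout_parsing config above.

from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="table_recognition")
# A pre-cropped table image, so the layout detection stage can be skipped.
output = pipeline("./test_samples/table_recognition.jpg", use_layout_detection=False)
for res in output:
    res.save_results("./output/")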

+ 0 - 2
paddlex/inference/common/result/base_cv_result.py

@@ -20,8 +20,6 @@ from ...utils.io import ImageWriter
 class BaseCVResult(BaseResult, StrMixin, JsonMixin, ImgMixin):
     """Base class for computer vision results."""
 
-    INPUT_IMG_KEY = "input_img"
-
     def __init__(self, data: dict) -> None:
         """
         Initialize the BaseCVResult.

+ 8 - 0
paddlex/inference/common/result/mixin.py

@@ -249,6 +249,8 @@ class Base64Mixin:
         if not str(save_path).lower().endswith((".b64")):
             fp = Path(self["input_path"])
             save_path = Path(save_path) / f"{fp.stem}{fp.suffix}"
+        else:
+            save_path = Path(save_path)
         self._base64_writer.write(save_path.as_posix(), self.base64, *args, **kwargs)
 
 
@@ -353,6 +355,8 @@ class CSVMixin:
         """
         if not str(save_path).endswith(".csv"):
             save_path = Path(save_path) / f"{Path(self['input_path']).stem}.csv"
+        else:
+            save_path = Path(save_path)
         self._csv_writer.write(save_path.as_posix(), self.csv, *args, **kwargs)
 
 
@@ -398,6 +402,8 @@ class HtmlMixin:
         """
         if not str(save_path).endswith(".html"):
             save_path = Path(save_path) / f"{Path(self['input_path']).stem}.html"
+        else:
+            save_path = Path(save_path)
         self._html_writer.write(save_path.as_posix(), self.html, *args, **kwargs)
 
 
@@ -443,6 +449,8 @@ class XlsxMixin:
         """
         if not str(save_path).endswith(".xlsx"):
             save_path = Path(save_path) / f"{Path(self['input_path']).stem}.xlsx"
+        else:
+            save_path = Path(save_path)
         self._xlsx_writer.write(save_path.as_posix(), self.xlsx, *args, **kwargs)
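
A standalone sketch (not part of the commit) of the path-resolution rule these save methods now share: a save_path that already ends with the target extension is used verbatim, anything else is treated as a directory and the file name is derived from the result's input_path. The helper name resolve_save_path is hypothetical.

from pathlib import Path

def resolve_save_path(save_path: str, input_path: str, suffix: str) -> Path:
    """Hypothetical mirror of the rule now used by the CSV/HTML/XLSX/Base64 mixins."""
    if str(save_path).endswith(suffix):
        # Explicit file path: use it as-is (this is the newly added branch).
        return Path(save_path)
    # Otherwise treat save_path as a directory and derive the file name.
    return Path(save_path) / f"{Path(input_path).stem}{suffix}"

print(resolve_save_path("./output", "table_recognition.jpg", ".html"))
# output/table_recognition.html
print(resolve_save_path("./output/result.html", "table_recognition.jpg", ".html"))
# output/result.html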
 
 

+ 3 - 2
paddlex/inference/pipelines_new/__init__.py

@@ -18,12 +18,13 @@ from .base import BasePipeline
 from ..utils.pp_option import PaddlePredictorOption
 from .components import BaseChat, BaseRetriever, BaseGeneratePrompt
 from ...utils.config import parse_config
-
 from .ocr import OCRPipeline
 from .doc_preprocessor import DocPreprocessorPipeline
 from .layout_parsing import LayoutParsingPipeline
-from .pp_chatocrv3_doc import PP_ChatOCRv3_doc_Pipeline
+from .pp_chatocr import PP_ChatOCRv3_Pipeline, PP_ChatOCRv4_Pipeline
 from .image_classification import ImageClassificationPipeline
+from .seal_recognition import SealRecognitionPipeline
+from .table_recognition import TableRecognitionPipeline
 
 
 def get_pipeline_path(pipeline_name: str) -> str:

+ 3 - 3
paddlex/inference/pipelines_new/base.py

@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from pathlib import Path
+from typing import Any, Dict, Optional
 from abc import ABC, abstractmethod
-from ...utils.subclass_register import AutoRegisterABCMetaClass
 import yaml
 import codecs
-from pathlib import Path
-from typing import Any, Dict, Optional
+from ...utils.subclass_register import AutoRegisterABCMetaClass
 from ..utils.pp_option import PaddlePredictorOption
 from ..models import BasePredictor
 

+ 1 - 2
paddlex/inference/pipelines_new/components/chat_server/base.py

@@ -13,9 +13,8 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
-from .....utils.subclass_register import AutoRegisterABCMetaClass
-
 import inspect
+from .....utils.subclass_register import AutoRegisterABCMetaClass
 
 
 class BaseChat(ABC, metaclass=AutoRegisterABCMetaClass):

+ 52 - 2
paddlex/inference/pipelines_new/components/chat_server/ernie_bot_chat.py

@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Dict
+import re
+import json
+import erniebot
 from .....utils import logging
 from .base import BaseChat
-import erniebot
-from typing import Dict
 
 
 class ErnieBotChat(BaseChat):
@@ -110,3 +112,51 @@ class ErnieBotChat(BaseChat):
                 logging.error(e)
                 self.ERROR_MASSAGE = "大模型调用失败"
         return None
+
+    def fix_llm_result_format(self, llm_result: str) -> dict:
+        """
+        Fix the format of the LLM result.
+
+        Args:
+            llm_result (str): The result from the LLM (Large Language Model).
+
+        Returns:
+            dict: A fixed format dictionary from the LLM result.
+        """
+        if not llm_result:
+            return {}
+
+        if "json" in llm_result or "```" in llm_result:
+            llm_result = (
+                llm_result.replace("```", "").replace("json", "").replace("/n", "")
+            )
+            llm_result = llm_result.replace("[", "").replace("]", "")
+
+        try:
+            llm_result = json.loads(llm_result)
+            llm_result_final = {}
+            for key in llm_result:
+                value = llm_result[key]
+                if isinstance(value, list):
+                    if len(value) > 0:
+                        llm_result_final[key] = value[0]
+                else:
+                    llm_result_final[key] = value
+            return llm_result_final
+
+        except:
+            results = (
+                llm_result.replace("\n", "")
+                .replace("    ", "")
+                .replace("{", "")
+                .replace("}", "")
+            )
+            if not results.endswith('"'):
+                results = results + '"'
+            pattern = r'"(.*?)": "([^"]*)"'
+            matches = re.findall(pattern, str(results))
+            if len(matches) > 0:
+                llm_result = {k: v for k, v in matches}
+                return llm_result
+            else:
+                return {}
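
A standalone re-sketch, for illustration only, of what the new fix_llm_result_format method does to a typical ERNIE Bot reply wrapped in a ```json fence; the real method is the one added to ErnieBotChat above.

import json
import re

def fix_llm_result_format(llm_result: str) -> dict:
    """Simplified copy of ErnieBotChat.fix_llm_result_format for illustration."""
    if not llm_result:
        return {}
    # Strip markdown fences and list brackets the model sometimes adds.
    if "json" in llm_result or "```" in llm_result:
        llm_result = llm_result.replace("```", "").replace("json", "").replace("/n", "")
        llm_result = llm_result.replace("[", "").replace("]", "")
    try:
        data = json.loads(llm_result)
        fixed = {}
        for key, value in data.items():
            if isinstance(value, list):
                if value:
                    fixed[key] = value[0]
            else:
                fixed[key] = value
        return fixed
    except json.JSONDecodeError:
        # Fall back to pulling "key": "value" pairs out of the raw text.
        flat = llm_result.replace("\n", "").replace("{", "").replace("}", "")
        if not flat.endswith('"'):
            flat = flat + '"'
        return dict(re.findall(r'"(.*?)": "([^"]*)"', flat))

print(fix_llm_result_format('```json\n{"公司全称是什么": ["某某股份有限公司"]}\n```'))
# {'公司全称是什么': '某某股份有限公司'}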

+ 1 - 2
paddlex/inference/pipelines_new/components/common/base_result.py

@@ -13,10 +13,9 @@
 # limitations under the License.
 
 import inspect
-
+from typing import Dict
 from ....utils.io import ImageReader, ImageWriter
 from ..utils.mixin import JsonMixin, ImgMixin, StrMixin
-from typing import Dict
 
 
 class BaseResult(dict, StrMixin, JsonMixin):

+ 5 - 5
paddlex/inference/pipelines_new/components/common/crop_image_regions.py

@@ -12,15 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .base_operator import BaseOperator
-import numpy as np
-from ....utils.io import ImageReader
+from typing import Tuple
 import copy
+import numpy as np
 import cv2
-from .seal_det_warp import AutoRectifier
 from shapely.geometry import Polygon
 from numpy.linalg import norm
-from typing import Tuple
+from .base_operator import BaseOperator
+from ....utils.io import ImageReader
+from .seal_det_warp import AutoRectifier
 
 
 class CropByBoxes(BaseOperator):

+ 2 - 3
paddlex/inference/pipelines_new/components/common/seal_det_warp.py

@@ -13,12 +13,11 @@
 # limitations under the License.
 
 import os, sys
+import copy
+import time
 import numpy as np
 from numpy import cos, sin, arctan, sqrt
 import cv2
-import copy
-import time
-
 from .....utils import logging
 
 #### [TODO] need sunting to add explanatory notes

+ 1 - 1
paddlex/inference/pipelines_new/components/common/sort_boxes.py

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .base_operator import BaseOperator
 import numpy as np
+from .base_operator import BaseOperator
 
 
 class SortQuadBoxes(BaseOperator):

+ 1 - 0
paddlex/inference/pipelines_new/components/prompt_engeering/__init__.py

@@ -13,3 +13,4 @@
 # limitations under the License.
 
 from .generate_kie_prompt import GenerateKIEPrompt
+from .generate_ensemble_prompt import GenerateEnsemblePrompt

+ 1 - 2
paddlex/inference/pipelines_new/components/prompt_engeering/base.py

@@ -12,11 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 from abc import ABC, abstractmethod
 from .....utils.subclass_register import AutoRegisterABCMetaClass
 
-import inspect
-
 
 class BaseGeneratePrompt(ABC, metaclass=AutoRegisterABCMetaClass):
     """Base Generate Prompt class."""

+ 127 - 0
paddlex/inference/pipelines_new/components/prompt_engeering/generate_ensemble_prompt.py

@@ -0,0 +1,127 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict
+from .base import BaseGeneratePrompt
+
+
+class GenerateEnsemblePrompt(BaseGeneratePrompt):
+    """Generate Ensemble Prompt"""
+
+    entities = ["ensemble_prompt"]
+
+    def __init__(self, config: Dict) -> None:
+        """Initializes the GenerateEnsemblePrompt instance with the given configuration.
+
+        Args:
+            config (Dict): A dictionary containing configuration settings.
+                - task_type (str): The type of task to generate a prompt for, in the support entities list.
+                - task_description (str, optional): A description of the task. Defaults to an empty string.
+                - output_format (str, optional): The desired output format. Defaults to an empty string.
+                - rules_str (str, optional): A string representing rules for the task. Defaults to an empty string.
+                - few_shot_demo_text_content (str, optional): Text content for few-shot demos. Defaults to an empty string.
+                - few_shot_demo_key_value_list (str, optional): A key-value list for few-shot demos. Defaults to an empty string.
+
+        Raises:
+            ValueError: If the task type is not in the allowed entities for GenerateKIEPrompt.
+        """
+        super().__init__()
+
+        task_type = config.get("task_type", "")
+        task_description = config.get("task_description", "")
+        output_format = config.get("output_format", "")
+        rules_str = config.get("rules_str", "")
+        few_shot_demo_text_content = config.get("few_shot_demo_text_content", "")
+        few_shot_demo_key_value_list = config.get("few_shot_demo_key_value_list", "")
+
+        if task_description is None:
+            task_description = ""
+
+        if output_format is None:
+            output_format = ""
+
+        if rules_str is None:
+            rules_str = ""
+
+        if few_shot_demo_text_content is None:
+            few_shot_demo_text_content = ""
+
+        if few_shot_demo_key_value_list is None:
+            few_shot_demo_key_value_list = ""
+
+        if task_type not in self.entities:
+            raise ValueError(
+                f"task type must be in {self.entities} of GenerateEnsemblePrompt."
+            )
+
+        self.task_type = task_type
+        self.task_description = task_description
+        self.output_format = output_format
+        self.rules_str = rules_str
+        self.few_shot_demo_text_content = few_shot_demo_text_content
+        self.few_shot_demo_key_value_list = few_shot_demo_key_value_list
+
+    def generate_prompt(
+        self,
+        key: str,
+        result_methodA: str,
+        result_methodB: str,
+        task_description: str = None,
+        output_format: str = None,
+        rules_str: str = None,
+        few_shot_demo_text_content: str = None,
+        few_shot_demo_key_value_list: str = None,
+    ) -> str:
+        """Generates a prompt based on the given parameters.
+        Args:
+            key (str): the input question.
+            result_methodA (str): the result of method A.
+            result_methodB (str): the result of method B.
+            task_description (str, optional): A description of the task. Defaults to None.
+            output_format (str, optional): The desired output format. Defaults to None.
+            rules_str (str, optional): A string containing rules or instructions. Defaults to None.
+            few_shot_demo_text_content (str, optional): Text content for few-shot demos. Defaults to None.
+            few_shot_demo_key_value_list (str, optional): Key-value list for few-shot demos. Defaults to None.
+        Returns:
+            str: The generated prompt.
+
+        Raises:
+            ValueError: If the task_type is not supported.
+        """
+        if task_description is None:
+            task_description = self.task_description
+
+        if output_format is None:
+            output_format = self.output_format
+
+        if rules_str is None:
+            rules_str = self.rules_str
+
+        if few_shot_demo_text_content is None:
+            few_shot_demo_text_content = self.few_shot_demo_text_content
+
+        if few_shot_demo_key_value_list is None:
+            few_shot_demo_key_value_list = self.few_shot_demo_key_value_list
+
+        prompt = f"""{task_description}{rules_str}{output_format}{few_shot_demo_text_content}{few_shot_demo_key_value_list}"""
+        task_type = self.task_type
+        if task_type == "ensemble_prompt":
+            prompt += f"""下面正式开始:
+                \n问题:```{key}```
+                \n方法A的结果:{result_methodA}
+                \n方法B的结果:{result_methodB}
+                """
+        else:
+            raise ValueError(f"{self.task_type} is currently not supported.")
+        return prompt
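
A usage sketch for the new ensemble prompt, assuming this commit's package layout (the class is re-exported from prompt_engeering/__init__.py as shown above). The config strings are abbreviated versions of the Ensemble block in PP-ChatOCRv4-doc.yaml; the question and candidate answers are illustrative.

from paddlex.inference.pipelines_new.components.prompt_engeering import (
    GenerateEnsemblePrompt,
)

config = {
    "task_type": "ensemble_prompt",
    "task_description": "你现在的任务是,对于一个问题,对比方法A和方法B的结果,选择更准确的一个回答。问题用```符号包围。",
    "output_format": "请返回JSON格式的结果,key值为我指定的问题,value值为`方法A`或`方法B`。",
    "rules_str": "对于涉及数字的问题,请返回与问题描述最相关且数字表述正确的答案。",
}

ensemble = GenerateEnsemblePrompt(config)
prompt = ensemble.generate_prompt(
    key="驾驶室准乘人数",
    result_methodA="5人",   # e.g. the retrieval + LLM answer
    result_methodB="五人",  # e.g. the multimodal model's answer
)
print(prompt)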

+ 23 - 17
paddlex/inference/pipelines_new/components/prompt_engeering/generate_kie_prompt.py

@@ -12,14 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .base import BaseGeneratePrompt
 from typing import Dict
+from .base import BaseGeneratePrompt
+from .....utils import logging
 
 
 class GenerateKIEPrompt(BaseGeneratePrompt):
     """Generate KIE Prompt"""
 
-    entities = ["text_kie_prompt", "table_kie_prompt"]
+    entities = [
+        "text_kie_prompt_v1",
+        "table_kie_prompt_v1",
+        "text_kie_prompt_v2",
+        "table_kie_prompt_v2",
+    ]
 
     def __init__(self, config: Dict) -> None:
         """Initializes the GenerateKIEPrompt instance with the given configuration.
@@ -92,7 +98,6 @@ class GenerateKIEPrompt(BaseGeneratePrompt):
             rules_str (str, optional): A string containing rules or instructions. Defaults to None.
             few_shot_demo_text_content (str, optional): Text content for few-shot demos. Defaults to None.
             few_shot_demo_key_value_list (str, optional): Key-value list for few-shot demos. Defaults to None.
-
         Returns:
             str: The generated prompt.
 
@@ -115,28 +120,29 @@ class GenerateKIEPrompt(BaseGeneratePrompt):
             few_shot_demo_key_value_list = self.few_shot_demo_key_value_list
 
         prompt = f"""{task_description}{rules_str}{output_format}{few_shot_demo_text_content}{few_shot_demo_key_value_list}"""
-        if self.task_type == "table_kie_prompt":
+        task_type = self.task_type
+        if task_type == "table_kie_prompt_v1":
+            prompt += f"""\n结合上面,下面正式开始:\
+                表格内容:```{text_content}```\
+                关键词列表:[{key_list}]。""".replace(
+                "    ", ""
+            )
+        elif task_type == "text_kie_prompt_v1":
+            prompt += f"""\n结合上面的例子,下面正式开始:\
+                OCR文字:```{text_content}```\
+                关键词列表:[{key_list}]。""".replace(
+                "    ", ""
+            )
+        elif task_type == "table_kie_prompt_v2":
             prompt += f"""\n结合上面,下面正式开始:\
                 表格内容:```{text_content}```\
                 \n问题列表:{key_list}。""".replace(
                 "    ", ""
             )
-            # prompt += f"""\n结合上面,下面正式开始:\
-            #     表格内容:```{text_content}```\
-            #     \n关键词列表:{key_list}。""".replace(
-            #     "    ", ""
-            # )
-        elif self.task_type == "text_kie_prompt":
+        elif task_type == "text_kie_prompt_v2":
             prompt += f"""\n结合上面的例子,下面正式开始:\
                 OCR文字:```{text_content}```\
                 \n问题列表:{key_list}。""".replace(
                 "    ", ""
             )
-            # prompt += f"""\n结合上面的例子,下面正式开始:\
-            #     OCR文字:```{text_content}```\
-            #     \n关键词列表:{key_list}。""".replace(
-            #     "    ", ""
-            # )
-        else:
-            raise ValueError(f"{self.task_type} is currently not supported.")
         return prompt
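
A small sketch of picking one of the renamed task types; the v2 prompts phrase the keys as a question list instead of a keyword list. The keyword argument names for generate_prompt (text_content, key_list) are taken from the variables used in the method body above and should be treated as an assumption; the config strings are abbreviated.

from paddlex.inference.pipelines_new.components.prompt_engeering import GenerateKIEPrompt

config = {
    "task_type": "text_kie_prompt_v2",  # formerly "text_kie_prompt"
    "task_description": "你现在的任务是从OCR结果中提取问题列表中每一个问题的答案。",
    "output_format": "在返回结果时使用JSON格式。",
    "rules_str": "",
}

kie = GenerateKIEPrompt(config)
# Assumed parameter names, mirroring the variables used in generate_prompt above.
prompt = kie.generate_prompt(
    text_content="驾驶室准乘人数 5人",
    key_list=["驾驶室准乘人数"],
)
print(prompt)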

+ 1 - 2
paddlex/inference/pipelines_new/components/retriever/base.py

@@ -13,10 +13,9 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
-from .....utils.subclass_register import AutoRegisterABCMetaClass
-
 import inspect
 import base64
+from .....utils.subclass_register import AutoRegisterABCMetaClass
 
 
 class BaseRetriever(ABC, metaclass=AutoRegisterABCMetaClass):

+ 12 - 11
paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py

@@ -12,20 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .base import BaseRetriever
+from typing import Dict
+import time
 import os
-
 from langchain.docstore.document import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-
 from langchain_community.embeddings import QianfanEmbeddingsEndpoint
 from langchain_community.vectorstores import FAISS
 from langchain_community import vectorstores
 from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings
-
-import time
-
-from typing import Dict
+from .base import BaseRetriever
 
 
 class ErnieBotRetriever(BaseRetriever):
@@ -112,7 +108,6 @@ class ErnieBotRetriever(BaseRetriever):
         )
         texts = text_splitter.split_text("\t".join(text_list))
         all_splits = [Document(page_content=text) for text in texts]
-
         api_type = self.config["api_type"]
         if api_type == "qianfan":
             os.environ["QIANFAN_AK"] = os.environ.get("EB_AK", self.config["ak"])
@@ -192,7 +187,12 @@ class ErnieBotRetriever(BaseRetriever):
         return vector
 
     def similarity_retrieval(
-        self, query_text_list: list[str], vectorstore: FAISS, sleep_time: float = 0.5
+        self,
+        query_text_list: list[str],
+        vectorstore: FAISS,
+        sleep_time: float = 0.5,
+        topk: int = 2,
+        min_characters: int = 3500,
     ) -> str:
         """
         Retrieve similar contexts based on a list of query texts.
@@ -201,7 +201,8 @@ class ErnieBotRetriever(BaseRetriever):
             query_text_list (list[str]): A list of query texts to search for similar contexts.
             vectorstore (FAISS): The vector store where to perform the similarity search.
             sleep_time (float): The time to sleep between each query, in seconds. Default is 0.5.
-
+            topk (int): The number of results to retrieve per query. Default is 2.
+            min_characters (int): The minimum number of characters required for text processing, defaults to 3500.
         Returns:
             str: A concatenated string of all unique contexts found.
         """
@@ -209,7 +210,7 @@ class ErnieBotRetriever(BaseRetriever):
         for query_text in query_text_list:
             QUESTION = query_text
             time.sleep(sleep_time)
-            docs = vectorstore.similarity_search_with_relevance_scores(QUESTION, k=2)
+            docs = vectorstore.similarity_search_with_relevance_scores(QUESTION, k=topk)
             context = [(document.page_content, score) for document, score in docs]
             context = sorted(context, key=lambda x: x[1])
             C.extend([x[0] for x in context[::-1]])

+ 9 - 21
paddlex/inference/pipelines_new/doc_preprocessor/pipeline.py

@@ -12,16 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..base import BasePipeline
 from typing import Any, Dict, Optional
 from scipy.ndimage import rotate
+import numpy as np
+from ..base import BasePipeline
 from .result import DocPreprocessorResult
 from ....utils import logging
-import numpy as np
-
-########## [TODO]后续需要更新路径
-from ...components.transforms import ReadImage
-
+from ...common.reader import ReadImage
+from ...common.batch_sampler import ImageBatchSampler
 from ...utils.pp_option import PaddlePredictorOption
 
 
@@ -68,6 +66,7 @@ class DocPreprocessorPipeline(BasePipeline):
             doc_unwarping_config = config["SubModules"]["DocUnwarping"]
             self.doc_unwarping_model = self.create_model(doc_unwarping_config)
 
+        self.batch_sampler = ImageBatchSampler(batch_size=1)
         self.img_reader = ReadImage(format="BGR")
 
     def rotate_image(self, image_array: np.ndarray, rotate_angle: float) -> np.ndarray:
@@ -128,7 +127,7 @@ class DocPreprocessorPipeline(BasePipeline):
         Predict the preprocessing result for the input image or images.
 
         Args:
-            input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) or path(s) to the images.
+            input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) or path(s) to the images or pdfs.
             use_doc_orientation_classify (bool): Whether to use document orientation classification.
             use_doc_unwarping (bool): Whether to use document unwarping.
             **kwargs: Additional keyword arguments.
@@ -137,11 +136,6 @@ class DocPreprocessorPipeline(BasePipeline):
             DocPreprocessorResult: A generator yielding preprocessing results.
         """
 
-        if not isinstance(input, list):
-            input_list = [input]
-        else:
-            input_list = input
-
         input_params = {
             "use_doc_orientation_classify": use_doc_orientation_classify,
             "use_doc_unwarping": use_doc_unwarping,
@@ -150,14 +144,8 @@ class DocPreprocessorPipeline(BasePipeline):
         if not self.check_input_params_valid(input_params):
             yield {"error": "input params invalid"}
 
-        img_id = 1
-        for input in input_list:
-            if isinstance(input, str):
-                image_array = next(self.img_reader(input))[0]["img"]
-            else:
-                image_array = input
-
-            assert len(image_array.shape) == 3
+        for img_id, batch_data in enumerate(self.batch_sampler(input)):
+            image_array = self.img_reader(batch_data)[0]
 
             if input_params["use_doc_orientation_classify"]:
                 pred = next(self.doc_ori_classify_model(image_array))
@@ -172,6 +160,7 @@ class DocPreprocessorPipeline(BasePipeline):
             else:
                 output_img = rot_img
 
+            img_id += 1
             single_img_res = {
                 "input_image": image_array,
                 "input_params": input_params,
@@ -180,5 +169,4 @@ class DocPreprocessorPipeline(BasePipeline):
                 "output_img": output_img,
                 "img_id": img_id,
             }
-            img_id += 1
             yield DocPreprocessorResult(single_img_res)

+ 18 - 11
paddlex/inference/pipelines_new/doc_preprocessor/result.py

@@ -14,16 +14,16 @@
 
 import math
 import random
+from pathlib import Path
 import numpy as np
 import cv2
 import PIL
 from PIL import Image, ImageDraw, ImageFont
-
 from ....utils.fonts import PINGFANG_FONT_FILE_PATH, create_font
-from ..components import CVResult
+from ...common.result import BaseCVResult
 
 
-class DocPreprocessorResult(CVResult):
+class DocPreprocessorResult(BaseCVResult):
     """doc preprocessor result"""
 
     def save_to_img(self, save_path: str, *args, **kwargs) -> None:
@@ -42,7 +42,7 @@ class DocPreprocessorResult(CVResult):
         """
         if not str(save_path).lower().endswith((".jpg", ".png")):
             img_id = self["img_id"]
-            save_path = save_path + "/res_doc_preprocess_%d.jpg" % img_id
+            save_path = Path(save_path) / f"res_doc_preprocess_{img_id}.jpg"
         super().save_to_img(save_path, *args, **kwargs)
 
     def _to_img(self) -> PIL.Image:
@@ -56,16 +56,23 @@ class DocPreprocessorResult(CVResult):
         angle = self["angle"]
         rot_img = self["rot_img"][:, :, ::-1]
         output_img = self["output_img"][:, :, ::-1]
-        h, w = image.shape[0:2]
-        img_show = Image.new("RGB", (w * 3, h + 25), (255, 255, 255))
-        img_show.paste(Image.fromarray(image), (0, 0, w, h))
-        img_show.paste(Image.fromarray(rot_img), (w, 0, w * 2, h))
-        img_show.paste(Image.fromarray(output_img), (w * 2, 0, w * 3, h))
+        h1, w1 = image.shape[0:2]
+        h2, w2 = rot_img.shape[0:2]
+        h3, w3 = output_img.shape[0:2]
+        h = max(max(h1, h2), h3)
+        img_show = Image.new("RGB", (w1 + w2 + w3, h + 25), (255, 255, 255))
+        img_show.paste(Image.fromarray(image), (0, 0, w1, h1))
+        img_show.paste(Image.fromarray(rot_img), (w1, 0, w1 + w2, h2))
+        img_show.paste(Image.fromarray(output_img), (w1 + w2, 0, w1 + w2 + w3, h3))
 
         draw_text = ImageDraw.Draw(img_show)
         txt_list = ["Original Image", "Rotated Image", "Unwarping Image"]
+        region_w_list = [w1, w2, w3]
+        beg_w_list = [0, w1, w1 + w2]
         for tno in range(len(txt_list)):
             txt = txt_list[tno]
-            font = create_font(txt, (w, 20), PINGFANG_FONT_FILE_PATH)
-            draw_text.text([10 + w * tno, h + 2], txt, fill=(0, 0, 0), font=font)
+            font = create_font(txt, (region_w_list[tno], 20), PINGFANG_FONT_FILE_PATH)
+            draw_text.text(
+                [10 + beg_w_list[tno], h + 2], txt, fill=(0, 0, 0), font=font
+            )
         return img_show

+ 2 - 2
paddlex/inference/pipelines_new/image_classification/pipeline.py

@@ -14,13 +14,13 @@
 
 from typing import Any, Dict, Optional
 import numpy as np
-
 from ...common.reader import ReadImage
 from ...common.batch_sampler import ImageBatchSampler
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
+
+# [TODO] 待更新models_new到models
 from ...models_new.image_classification.result import TopkResult
-from ...results import TopkResult
 
 
 class ImageClassificationPipeline(BasePipeline):

+ 104 - 115
paddlex/inference/pipelines_new/layout_parsing/pipeline.py

@@ -12,31 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..base import BasePipeline
 from typing import Any, Dict, Optional
+import os, sys
 import numpy as np
 import cv2
-from ..components import CropByBoxes
+from ..base import BasePipeline
 from .utils import convert_points_to_boxes, get_sub_regions_ocr_res
-from .table_recognition_post_processing import get_table_recognition_res
-
 from .result import LayoutParsingResult
-
 from ....utils import logging
-
 from ...utils.pp_option import PaddlePredictorOption
-
-########## [TODO]后续需要更新路径
-from ...components.transforms import ReadImage
-
+from ...common.reader import ReadImage
+from ...common.batch_sampler import ImageBatchSampler
 from ..ocr.result import OCRResult
-from ...results import DetResult
+from ..doc_preprocessor.result import DocPreprocessorResult
+
+# [TODO] 待更新models_new到models
+from ...models_new.object_detection.result import DetResult
 
 
 class LayoutParsingPipeline(BasePipeline):
     """Layout Parsing Pipeline"""
 
-    entities = ["layout_parsing", "seal_recognition", "table_recognition"]
+    entities = ["layout_parsing"]
 
     def __init__(
         self,
@@ -62,9 +59,9 @@ class LayoutParsingPipeline(BasePipeline):
 
         self.inintial_predictor(config)
 
-        self.img_reader = ReadImage(format="BGR")
+        self.batch_sampler = ImageBatchSampler(batch_size=1)
 
-        self._crop_by_boxes = CropByBoxes()
+        self.img_reader = ReadImage(format="BGR")
 
     def set_used_models_flag(self, config: Dict) -> None:
         """
@@ -88,19 +85,14 @@ class LayoutParsingPipeline(BasePipeline):
         if "use_doc_preprocessor" in config:
             self.use_doc_preprocessor = config["use_doc_preprocessor"]
 
-        if pipeline_name == "layout_parsing":
-            if "use_general_ocr" in config:
-                self.use_general_ocr = config["use_general_ocr"]
-            if "use_seal_recognition" in config:
-                self.use_seal_recognition = config["use_seal_recognition"]
-            if "use_table_recognition" in config:
-                self.use_table_recognition = config["use_table_recognition"]
+        if "use_general_ocr" in config:
+            self.use_general_ocr = config["use_general_ocr"]
 
-        elif pipeline_name == "seal_recognition":
-            self.use_seal_recognition = True
+        if "use_seal_recognition" in config:
+            self.use_seal_recognition = config["use_seal_recognition"]
 
-        elif pipeline_name == "table_recognition":
-            self.use_table_recognition = True
+        if "use_table_recognition" in config:
+            self.use_table_recognition = config["use_table_recognition"]
 
     def inintial_predictor(self, config: Dict) -> None:
         """Initializes the predictor based on the provided configuration.
@@ -123,20 +115,22 @@ class LayoutParsingPipeline(BasePipeline):
                 doc_preprocessor_config
             )
 
-        if self.use_general_ocr:
+        if self.use_general_ocr or self.use_table_recognition:
             general_ocr_config = config["SubPipelines"]["GeneralOCR"]
             self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
 
         if self.use_seal_recognition:
-            seal_ocr_config = config["SubPipelines"]["SealOCR"]
-            self.seal_ocr_pipeline = self.create_pipeline(seal_ocr_config)
+            seal_recognition_config = config["SubPipelines"]["SealRecognition"]
+            self.seal_recognition_pipeline = self.create_pipeline(
+                seal_recognition_config
+            )
 
         if self.use_table_recognition:
-            table_structure_config = config["SubModules"]["TableStructureRecognition"]
-            self.table_structure_model = self.create_model(table_structure_config)
-            if not self.use_general_ocr:
-                general_ocr_config = config["SubPipelines"]["GeneralOCR"]
-                self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
+            table_recognition_config = config["SubPipelines"]["TableRecognition"]
+            self.table_recognition_pipeline = self.create_pipeline(
+                table_recognition_config
+            )
+
         return
 
     def get_text_paragraphs_ocr_res(
@@ -196,23 +190,50 @@ class LayoutParsingPipeline(BasePipeline):
 
         return True
 
-    def convert_input_params(self, input_params: Dict) -> None:
+    def predict_doc_preprocessor_res(
+        self, image_array: np.ndarray, input_params: dict
+    ) -> tuple[DocPreprocessorResult, np.ndarray]:
         """
-        Convert input parameters based on the pipeline name.
+        Preprocess the document image based on input parameters.
 
         Args:
-            input_params (Dict): The input parameters dictionary.
+            image_array (np.ndarray): The input image array.
+            input_params (dict): Dictionary containing preprocessing parameters.
 
         Returns:
-            None
+            tuple[DocPreprocessorResult, np.ndarray]: A tuple containing the preprocessing
+                                              result dictionary and the processed image array.
         """
-        if self.pipeline_name == "seal_recognition":
-            input_params["use_general_ocr"] = False
-            input_params["use_table_recognition"] = False
-        elif self.pipeline_name == "table_recognition":
-            input_params["use_general_ocr"] = False
-            input_params["use_seal_recognition"] = False
-        return
+        if input_params["use_doc_preprocessor"]:
+            use_doc_orientation_classify = input_params["use_doc_orientation_classify"]
+            use_doc_unwarping = input_params["use_doc_unwarping"]
+            doc_preprocessor_res = next(
+                self.doc_preprocessor_pipeline(
+                    image_array,
+                    use_doc_orientation_classify=use_doc_orientation_classify,
+                    use_doc_unwarping=use_doc_unwarping,
+                )
+            )
+            doc_preprocessor_image = doc_preprocessor_res["output_img"]
+        else:
+            doc_preprocessor_res = {}
+            doc_preprocessor_image = image_array
+        return doc_preprocessor_res, doc_preprocessor_image
+
+    def predict_overall_ocr_res(self, image_array: np.ndarray) -> OCRResult:
+        """
+        Predict the overall OCR result for the given image array.
+
+        Args:
+            image_array (np.ndarray): The input image array to perform OCR on.
+
+        Returns:
+            OCRResult: The predicted OCR result with updated dt_boxes.
+        """
+        overall_ocr_res = next(self.general_ocr_pipeline(image_array))
+        dt_boxes = convert_points_to_boxes(overall_ocr_res["dt_polys"])
+        overall_ocr_res["dt_boxes"] = dt_boxes
+        return overall_ocr_res
 
     def predict(
         self,
@@ -228,7 +249,7 @@ class LayoutParsingPipeline(BasePipeline):
         This function predicts the layout parsing result for the given input.
 
         Args:
-            input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) to be processed.
+            input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) or PDF(s) to be processed.
             use_doc_orientation_classify (bool): Whether to use document orientation classification.
             use_doc_unwarping (bool): Whether to use document unwarping.
             use_general_ocr (bool): Whether to use general OCR.
@@ -240,11 +261,6 @@ class LayoutParsingPipeline(BasePipeline):
             LayoutParsingResult: The predicted layout parsing result.
         """
 
-        if not isinstance(input, list):
-            input_list = [input]
-        else:
-            input_list = input
-
         input_params = {
             "use_doc_preprocessor": self.use_doc_preprocessor,
             "use_doc_orientation_classify": use_doc_orientation_classify,
@@ -254,100 +270,73 @@ class LayoutParsingPipeline(BasePipeline):
             "use_table_recognition": use_table_recognition,
         }
 
-        self.convert_input_params(input_params)
-
         if use_doc_orientation_classify or use_doc_unwarping:
             input_params["use_doc_preprocessor"] = True
         else:
             input_params["use_doc_preprocessor"] = False
 
         if not self.check_input_params_valid(input_params):
-            yield {"error": "input params invalid"}
+            yield None
 
-        img_id = 1
-        for input in input_list:
-            if isinstance(input, str):
-                image_array = next(self.img_reader(input))[0]["img"]
-            else:
-                image_array = input
-
-            assert len(image_array.shape) == 3
+        for img_id, batch_data in enumerate(self.batch_sampler(input)):
+            image_array = self.img_reader(batch_data)[0]
+            img_id += 1
 
-            if input_params["use_doc_preprocessor"]:
-                doc_preprocessor_res = next(
-                    self.doc_preprocessor_pipeline(
-                        image_array,
-                        use_doc_orientation_classify=use_doc_orientation_classify,
-                        use_doc_unwarping=use_doc_unwarping,
-                    )
-                )
-                doc_preprocessor_image = doc_preprocessor_res["output_img"]
-                doc_preprocessor_res["img_id"] = img_id
-            else:
-                doc_preprocessor_res = {}
-                doc_preprocessor_image = image_array
+            doc_preprocessor_res, doc_preprocessor_image = (
+                self.predict_doc_preprocessor_res(image_array, input_params)
+            )
 
-            ########## [TODO] RT-DETR detection results contain duplicates
             layout_det_res = next(self.layout_det_model(doc_preprocessor_image))
 
             if input_params["use_general_ocr"] or input_params["use_table_recognition"]:
-                overall_ocr_res = next(
-                    self.general_ocr_pipeline(doc_preprocessor_image)
-                )
-                overall_ocr_res["img_id"] = img_id
-                dt_boxes = convert_points_to_boxes(overall_ocr_res["dt_polys"])
-                overall_ocr_res["dt_boxes"] = dt_boxes
+                overall_ocr_res = self.predict_overall_ocr_res(doc_preprocessor_image)
             else:
                 overall_ocr_res = {}
 
-            text_paragraphs_ocr_res = {}
             if input_params["use_general_ocr"]:
                 text_paragraphs_ocr_res = self.get_text_paragraphs_ocr_res(
                     overall_ocr_res, layout_det_res
                 )
-                text_paragraphs_ocr_res["img_id"] = img_id
+            else:
+                text_paragraphs_ocr_res = {}
 
-            table_res_list = []
             if input_params["use_table_recognition"]:
-                table_region_id = 1
-                for box_info in layout_det_res["boxes"]:
-                    if box_info["label"].lower() in ["table"]:
-                        crop_img_info = self._crop_by_boxes(
-                            doc_preprocessor_image, [box_info]
-                        )
-                        crop_img_info = crop_img_info[0]
-                        table_structure_pred = next(
-                            self.table_structure_model(crop_img_info["img"])
-                        )
-                        table_recognition_res = get_table_recognition_res(
-                            crop_img_info, table_structure_pred, overall_ocr_res
-                        )
-                        table_recognition_res["table_region_id"] = table_region_id
-                        table_region_id += 1
-                        table_res_list.append(table_recognition_res)
-
-            seal_res_list = []
+                table_res_list = next(
+                    self.table_recognition_pipeline(
+                        doc_preprocessor_image,
+                        use_layout_detection=False,
+                        use_doc_orientation_classify=False,
+                        use_doc_unwarping=False,
+                        overall_ocr_res=overall_ocr_res,
+                        layout_det_res=layout_det_res,
+                    )
+                )
+                table_res_list = table_res_list["table_res_list"]
+            else:
+                table_res_list = []
+
             if input_params["use_seal_recognition"]:
-                seal_region_id = 1
-                for box_info in layout_det_res["boxes"]:
-                    if box_info["label"].lower() in ["seal"]:
-                        crop_img_info = self._crop_by_boxes(
-                            doc_preprocessor_image, [box_info]
-                        )
-                        crop_img_info = crop_img_info[0]
-                        seal_ocr_res = next(
-                            self.seal_ocr_pipeline(crop_img_info["img"])
-                        )
-                        seal_ocr_res["seal_region_id"] = seal_region_id
-                        seal_region_id += 1
-                        seal_res_list.append(seal_ocr_res)
+                seal_res_list = next(
+                    self.seal_recognition_pipeline(
+                        doc_preprocessor_image,
+                        use_layout_detection=False,
+                        use_doc_orientation_classify=False,
+                        use_doc_unwarping=False,
+                        layout_det_res=layout_det_res,
+                    )
+                )
+                seal_res_list = seal_res_list["seal_res_list"]
+            else:
+                seal_res_list = []
 
             single_img_res = {
                 "layout_det_res": layout_det_res,
                 "doc_preprocessor_res": doc_preprocessor_res,
+                "overall_ocr_res": overall_ocr_res,
                 "text_paragraphs_ocr_res": text_paragraphs_ocr_res,
                 "table_res_list": table_res_list,
                 "seal_res_list": seal_res_list,
                 "input_params": input_params,
+                "img_id": img_id,
             }
             yield LayoutParsingResult(single_img_res)
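
For reference, a minimal usage sketch of the refactored layout parsing pipeline, mirroring the flags handled by predict() above. The paddlex.create_pipeline entry point and the "layout_parsing" registration name are assumptions, and the input path is hypothetical.

# Minimal sketch (assumed entry point and pipeline name; hypothetical input).
from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="layout_parsing")
for res in pipeline.predict(
    "demo_report.pdf",                   # image or PDF input
    use_doc_orientation_classify=False,
    use_doc_unwarping=False,
    use_general_ocr=True,
    use_seal_recognition=True,
    use_table_recognition=True,
):
    res.save_to_img("./output/")         # per-image, per-region result files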

+ 29 - 125
paddlex/inference/pipelines_new/layout_parsing/result.py

@@ -12,124 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import math
-import random
-import numpy as np
-import cv2
-import PIL
 import os
-from PIL import Image, ImageDraw, ImageFont
-
-from ....utils.fonts import PINGFANG_FONT_FILE_PATH
-from ..components import CVResult, HtmlMixin, XlsxMixin
-
-from typing import Any, Dict, Optional
-
-
-class TableRecognitionResult(CVResult, HtmlMixin, XlsxMixin):
-    """table recognition result"""
-
-    def __init__(self, data: Dict) -> None:
-        """Initializes the object with given data and sets up mixins for HTML and XLSX processing."""
-        super().__init__(data)
-        HtmlMixin.__init__(self)  # Initializes the HTML mixin functionality
-        XlsxMixin.__init__(self)  # Initializes the XLSX mixin functionality
-
-    def save_to_html(self, save_path: str, *args, **kwargs) -> None:
-        """
-        Save the content to an HTML file.
-
-        Args:
-            save_path (str): The path to save the HTML file. If the path does not end with '.html',
-                          it will append '/res_table_%d.html' % self['table_region_id'] to the path.
-            *args: Additional positional arguments to be passed to the superclass method.
-            **kwargs: Additional keyword arguments to be passed to the superclass method.
-
-        Returns:
-            None
-        """
-        if not str(save_path).lower().endswith(".html"):
-            save_path = save_path + "/res_table_%d.html" % self["table_region_id"]
-        super().save_to_html(save_path, *args, **kwargs)
-
-    def _to_html(self) -> str:
-        """Converts the prediction to its corresponding HTML representation.
-
-        Returns:
-            str: The HTML string representation of the prediction.
-        """
-        return self["pred_html"]
-
-    def save_to_xlsx(self, save_path: str, *args, **kwargs) -> None:
-        """
-        Save the content to an Excel file (.xlsx).
-
-        If the save_path does not end with '.xlsx', it appends a default filename
-        based on the table_region_id attribute.
-
-        Args:
-            save_path (str): The path where the Excel file should be saved.
-            *args: Additional positional arguments passed to the superclass method.
-            **kwargs: Additional keyword arguments passed to the superclass method.
-
-        Returns:
-            None
-        """
-        if not str(save_path).lower().endswith(".xlsx"):
-            save_path = save_path + "/res_table_%d.xlsx" % self["table_region_id"]
-        super().save_to_xlsx(save_path, *args, **kwargs)
-
-    def _to_xlsx(self) -> str:
-        """Converts the prediction HTML to an XLSX file path.
-
-        Returns:
-            str: The path to the XLSX file containing the prediction data.
-        """
-        return self["pred_html"]
-
-    def save_to_img(self, save_path: str, *args, **kwargs) -> None:
-        """
-        Save the table and OCR result images to the specified path.
-
-        Args:
-            save_path (str): The directory path to save the images.
-            *args: Additional positional arguments.
-            **kwargs: Additional keyword arguments.
-
-        Returns:
-            None
-
-        Raises:
-            No specific exceptions are raised.
-
-        Notes:
-            - If save_path does not end with '.jpg' or '.png', the function appends '_res_table_cell_%d.jpg' and '_res_table_ocr_%d.jpg' to save_path
-              with table_region_id respectively for table cell and OCR images.
-            - The OCR result image is saved first with '_res_table_ocr_%d.jpg'.
-            - Then the table image is saved with '_res_table_cell_%d.jpg'.
-            - Calls the superclass's save_to_img method to save the table image.
-        """
-        if not str(save_path).lower().endswith((".jpg", ".png")):
-            ocr_save_path = (
-                save_path + "/res_table_ocr_%d.jpg" % self["table_region_id"]
-            )
-            save_path = save_path + "/res_table_cell_%d.jpg" % self["table_region_id"]
-        self["table_ocr_pred"].save_to_img(ocr_save_path)
-        super().save_to_img(save_path, *args, **kwargs)
-
-    def _to_img(self) -> np.ndarray:
-        """
-        Convert the input image with table OCR predictions to an image with cell boundaries highlighted.
-
-        Returns:
-            np.ndarray: The input image with cell boundaries highlighted in red.
-        """
-        input_img = self["table_ocr_pred"]["input_img"].copy()
-        cell_box_list = self["cell_box_list"]
-        for box in cell_box_list:
-            x1, y1, x2, y2 = [int(pos) for pos in box]
-            cv2.rectangle(input_img, (x1, y1), (x2, y2), (255, 0, 0), 2)
-        return input_img
+from pathlib import Path
 
 
 class LayoutParsingResult(dict):
@@ -149,32 +33,52 @@ class LayoutParsingResult(dict):
         if not os.path.isdir(save_path):
             return
 
+        img_id = self["img_id"]
         layout_det_res = self["layout_det_res"]
-        save_img_path = save_path + "/layout_det_result.jpg"
+        save_img_path = Path(save_path) / f"layout_det_result_img{img_id}.jpg"
         layout_det_res.save_to_img(save_img_path)
 
         input_params = self["input_params"]
         if input_params["use_doc_preprocessor"]:
-            save_img_path = save_path + "/doc_preprocessor_result.jpg"
+            save_img_path = Path(save_path) / f"doc_preprocessor_result_img{img_id}.jpg"
             self["doc_preprocessor_res"].save_to_img(save_img_path)
 
         if input_params["use_general_ocr"]:
-            save_img_path = save_path + "/text_paragraphs_ocr_result.jpg"
+            save_img_path = (
+                Path(save_path) / f"text_paragraphs_ocr_result_img{img_id}.jpg"
+            )
             self["text_paragraphs_ocr_res"].save_to_img(save_img_path)
 
+        if input_params["use_general_ocr"] or input_params["use_table_recognition"]:
+            save_img_path = Path(save_path) / f"overall_ocr_result_img{img_id}.jpg"
+            self["overall_ocr_res"].save_to_img(save_img_path)
+
         if input_params["use_table_recognition"]:
             for tno in range(len(self["table_res_list"])):
                 table_res = self["table_res_list"][tno]
-                table_res.save_to_img(save_path)
-                table_res.save_to_html(save_path)
-                table_res.save_to_xlsx(save_path)
+                table_region_id = table_res["table_region_id"]
+                save_img_path = (
+                    Path(save_path)
+                    / f"table_res_cell_img{img_id}_region{table_region_id}.jpg"
+                )
+                table_res.save_to_img(save_img_path)
+                save_html_path = (
+                    Path(save_path)
+                    / f"table_res_img{img_id}_region{table_region_id}.html"
+                )
+                table_res.save_to_html(save_html_path)
+                save_xlsx_path = (
+                    Path(save_path)
+                    / f"table_res_img{img_id}_region{table_region_id}.xlsx"
+                )
+                table_res.save_to_xlsx(save_xlsx_path)
 
         if input_params["use_seal_recognition"]:
             for sno in range(len(self["seal_res_list"])):
                 seal_res = self["seal_res_list"][sno]
+                seal_region_id = seal_res["seal_region_id"]
                 save_img_path = (
-                    save_path
-                    + "/seal_%d_recognition_result.jpg" % seal_res["seal_region_id"]
+                    Path(save_path) / f"seal_res_img{img_id}_region{seal_region_id}.jpg"
                 )
                 seal_res.save_to_img(save_img_path)
         return
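
A short sketch of what the rewritten save_to_img above is expected to write for one page, continuing the usage sketch after the pipeline diff (res is one LayoutParsingResult yielded by predict()); the output directory is hypothetical and the region counts are illustrative.

# Saving one LayoutParsingResult to a directory (hypothetical path).
res.save_to_img("./output/")
# Expected files for img_id == 1, one table region and one seal region
# (names follow the f-strings above; some appear only when the matching flag is on):
#   layout_det_result_img1.jpg
#   doc_preprocessor_result_img1.jpg
#   overall_ocr_result_img1.jpg
#   text_paragraphs_ocr_result_img1.jpg
#   table_res_cell_img1_region1.jpg
#   table_res_img1_region1.html
#   table_res_img1_region1.xlsx
#   seal_res_img1_region1.jpg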

+ 4 - 3
paddlex/inference/pipelines_new/ocr/pipeline.py

@@ -14,7 +14,6 @@
 
 from typing import Any, Dict, Optional
 import numpy as np
-
 from ...common.reader import ReadImage
 from ...common.batch_sampler import ImageBatchSampler
 from ...utils.pp_option import PaddlePredictorOption
@@ -76,7 +75,7 @@ class OCRPipeline(BasePipeline):
         """Predicts OCR results for the given input.
 
         Args:
-            input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) or path(s) to the images.
+            input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s), or path(s) to the image or PDF files.
             **kwargs: Additional keyword arguments that can be passed to the function.
 
         Returns:
@@ -94,13 +93,15 @@ class OCRPipeline(BasePipeline):
 
             dt_polys = self._sort_boxes(dt_polys)
 
+            img_id += 1
+
             single_img_res = {
                 "input_img": raw_img,
                 "dt_polys": dt_polys,
                 "img_id": img_id,
                 "text_type": self.text_type,
             }
-            img_id += 1
+
             single_img_res["rec_text"] = []
             single_img_res["rec_score"] = []
             if len(dt_polys) > 0:

+ 4 - 4
paddlex/inference/pipelines_new/ocr/result.py

@@ -12,18 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from pathlib import Path
 import math
 import random
 import numpy as np
 import cv2
 import PIL
 from PIL import Image, ImageDraw, ImageFont
-
 from ....utils.fonts import PINGFANG_FONT_FILE_PATH, create_font
-from ..components import CVResult
+from ...common.result import BaseCVResult
 
 
-class OCRResult(CVResult):
+class OCRResult(BaseCVResult):
     """OCR result"""
 
     def save_to_img(self, save_path: str, *args, **kwargs) -> None:
@@ -40,7 +40,7 @@ class OCRResult(CVResult):
         """
         if not str(save_path).lower().endswith((".jpg", ".png")):
             img_id = self["img_id"]
-            save_path = save_path + "/res_ocr_%d.jpg" % img_id
+            save_path = Path(save_path) / f"res_ocr_{img_id}.jpg"
         super().save_to_img(save_path, *args, **kwargs)
 
     def get_minarea_rect(self, points: np.ndarray) -> np.ndarray:
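
A brief usage sketch matching the updated OCR pipeline: predict() now also accepts PDF paths, and save_to_img is expected to write one res_ocr_<img_id>.jpg per page when given a directory. The create_pipeline entry point and the "OCR" registration name are assumptions; the input path is hypothetical.

# Assumed entry point and pipeline name; hypothetical input file.
from paddlex import create_pipeline

ocr = create_pipeline(pipeline="OCR")
for res in ocr.predict("scanned_doc.pdf"):   # each page yields one OCRResult
    res.save_to_img("./output/")             # writes ./output/res_ocr_<img_id>.jpg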

+ 16 - 0
paddlex/inference/pipelines_new/pp_chatocr/__init__.py

@@ -0,0 +1,16 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pipeline_v3 import PP_ChatOCRv3_Pipeline
+from .pipeline_v4 import PP_ChatOCRv4_Pipeline

+ 106 - 0
paddlex/inference/pipelines_new/pp_chatocr/pipeline_base.py

@@ -0,0 +1,106 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional
+from ..base import BasePipeline
+from ....utils import logging
+from ...utils.pp_option import PaddlePredictorOption
+
+
+class PP_ChatOCR_Pipeline(BasePipeline):
+    """PP-ChatOCR Pipeline"""
+
+    def __init__(
+        self,
+        device: str = None,
+        pp_option: PaddlePredictorOption = None,
+        use_hpip: bool = False,
+        hpi_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """Initializes the pp-chatocrv3-doc pipeline.
+
+        Args:
+            config (Dict): Configuration dictionary containing various settings.
+            device (str, optional): Device to run the predictions on. Defaults to None.
+            pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
+            use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
+            hpi_params (Optional[Dict[str, Any]], optional): HPIP parameters. Defaults to None.
+        """
+
+        super().__init__(
+            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
+        )
+
+    def visual_predict(self):
+        """
+        This function takes an input image or a list of images and performs various visual
+        prediction tasks such as document orientation classification, document unwarping,
+        general OCR, seal recognition, and table recognition based on the provided flags.
+        """
+
+        raise NotImplementedError(
+            "The method `visual_predict` has not been implemented yet."
+        )
+
+    def save_visual_info_list(self):
+        """
+        Save the visual info list to the specified file path.
+        """
+        raise NotImplementedError(
+            "The method `save_visual_info_list` has not been implemented yet."
+        )
+
+    def load_visual_info_list(self):
+        """
+        Loads visual info list from a file.
+        """
+        raise NotImplementedError(
+            "The method `load_visual_info_list` has not been implemented yet."
+        )
+
+    def build_vector(self):
+        """
+        Build a vector representation from visual information.
+        """
+        raise NotImplementedError(
+            "The method `build_vector` has not been implemented yet."
+        )
+
+    def save_vector(self):
+        """
+        Save the vector information to a specified path.
+        """
+        raise NotImplementedError(
+            "The method `save_vector` has not been implemented yet."
+        )
+
+    def load_vector(self):
+        """
+        Loads vector information from a file.
+        """
+        raise NotImplementedError(
+            "The method `load_vector` has not been implemented yet."
+        )
+
+    def chat(self):
+        """
+        Generates chat results based on the provided key list and visual information.
+        """
+        raise NotImplementedError("The method `chat` has not been implemented yet.")
+
+    def predict(self, *args, **kwargs) -> None:
+        logging.error(
+            "PP-ChatOCR Pipeline do not support to call `predict()` directly! Please invoke `visual_predict`, `build_vector`, `chat` sequentially to obtain the result."
+        )
+        return
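
A sketch of the call sequence the base class enforces (predict() above deliberately refuses to run): visual_predict, then build_vector, then chat. The PP-ChatOCRv4-doc registration name, the chat() arguments, the input file, and the key list are assumptions drawn from the subclass code later in this diff.

# Assumed entry point, pipeline name, inputs and keys; a sketch only.
from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="PP-ChatOCRv4-doc")

visual_info_list = []
for res in pipeline.visual_predict("contract.pdf"):
    visual_info_list.append(res["visual_info"])   # per-page visual info

vector_info = pipeline.build_vector(visual_info_list)
chat_result = pipeline.chat(
    ["Party B", "Contract amount"],               # hypothetical key list
    visual_info_list,
    vector_info=vector_info,
)
print(chat_result["chat_res"])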

+ 142 - 155
paddlex/inference/pipelines_new/pp_chatocrv3_doc/pipeline.py → paddlex/inference/pipelines_new/pp_chatocr/pipeline_v3.py

@@ -12,33 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..base import BasePipeline
-
 from typing import Any, Dict, Optional
-
-# import numpy as np
-# import cv2
-from .result import VisualInfoResult
 import re
-
-########## [TODO]后续需要更新路径
-from ...components.transforms import ReadImage
-
 import json
-
+import numpy as np
+import copy
+from .pipeline_base import PP_ChatOCR_Pipeline
+from .result import VisualInfoResult
+from ...common.reader import ReadImage
+from ...common.batch_sampler import ImageBatchSampler
 from ....utils import logging
-
 from ...utils.pp_option import PaddlePredictorOption
-
 from ..layout_parsing.result import LayoutParsingResult
 
-import numpy as np
-
 
-class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
-    """PP-ChatOCRv3-doc Pipeline"""
+class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
+    """PP-ChatOCR Pipeline"""
 
-    entities = "PP-ChatOCRv3-doc"
+    entities = ["PP-ChatOCRv3-doc"]
 
     def __init__(
         self,
@@ -47,7 +38,6 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
         pp_option: PaddlePredictorOption = None,
         use_hpip: bool = False,
         hpi_params: Optional[Dict[str, Any]] = None,
-        use_layout_parsing: bool = True,
     ) -> None:
         """Initializes the pp-chatocrv3-doc pipeline.
 
@@ -57,16 +47,18 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
             use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
             hpi_params (Optional[Dict[str, Any]], optional): HPIP parameters. Defaults to None.
         """
 
         super().__init__(
             device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
         )
 
-        self.use_layout_parsing = use_layout_parsing
+        self.pipeline_name = config["pipeline_name"]
 
         self.inintial_predictor(config)
 
+        self.batch_sampler = ImageBatchSampler(batch_size=1)
         self.img_reader = ReadImage(format="BGR")
 
         self.table_structure_len_max = 500
@@ -82,9 +74,8 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             None
         """
 
-        if self.use_layout_parsing:
-            layout_parsing_config = config["SubPipelines"]["LayoutParser"]
-            self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
+        layout_parsing_config = config["SubPipelines"]["LayoutParser"]
+        self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
 
         from .. import create_chat_bot
 
@@ -103,7 +94,6 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
 
         table_pe_config = config["SubModules"]["PromptEngneering"]["KIE_Table"]
         self.table_pe = create_prompt_engeering(table_pe_config)
-
         return
 
     def decode_visual_result(
@@ -181,33 +171,14 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             dict: A dictionary containing the layout parsing result and visual information.
         """
 
-        if not self.use_layout_parsing:
-            raise ValueError("The models for layout parsing are not initialized.")
-
-        if not isinstance(input, list):
-            input_list = [input]
-        else:
-            input_list = input
-
-        img_id = 1
-        for input in input_list:
-            if isinstance(input, str):
-                image_array = next(self.img_reader(input))[0]["img"]
-            else:
-                image_array = input
-
-            assert len(image_array.shape) == 3
-
-            layout_parsing_result = next(
-                self.layout_parsing_pipeline.predict(
-                    image_array,
-                    use_doc_orientation_classify=use_doc_orientation_classify,
-                    use_doc_unwarping=use_doc_unwarping,
-                    use_general_ocr=use_general_ocr,
-                    use_seal_recognition=use_seal_recognition,
-                    use_table_recognition=use_table_recognition,
-                )
-            )
+        for layout_parsing_result in self.layout_parsing_pipeline.predict(
+            input,
+            use_doc_orientation_classify=use_doc_orientation_classify,
+            use_doc_unwarping=use_doc_unwarping,
+            use_general_ocr=use_general_ocr,
+            use_seal_recognition=use_seal_recognition,
+            use_table_recognition=use_table_recognition,
+        ):
 
             visual_info = self.decode_visual_result(layout_parsing_result)
 
@@ -264,8 +235,8 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             visual_info_list (list[VisualInfoResult]): A list of visual info results.
 
         Returns:
-            tuple[list, list, list]: A tuple containing three lists, one for normal text dicts,
-                                               one for table text lists, and one for table HTML lists.
+            tuple[list, list, list]: A tuple containing three lists, one for normal text dicts,
+                                               one for table text lists, and one for table HTML lists.
         """
         all_normal_text_list = []
         all_table_text_list = []
@@ -279,7 +250,7 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             all_normal_text_list.append(normal_text_dict)
             all_table_text_list.extend(table_text_list)
             all_table_html_list.extend(table_html_list)
-        return all_normal_text_list, all_table_text_list, all_table_html_list
+        return (all_normal_text_list, all_table_text_list, all_table_html_list)
 
     def build_vector(
         self,
@@ -305,7 +276,12 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             visual_info_list = visual_info
 
         all_visual_info = self.merge_visual_info_list(visual_info_list)
-        all_normal_text_list, all_table_text_list, all_table_html_list = all_visual_info
+
+        (
+            all_normal_text_list,
+            all_table_text_list,
+            all_table_html_list,
+        ) = all_visual_info
 
         vector_info = {}
 
@@ -316,10 +292,7 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
 
         for table_html, table_text in zip(all_table_html_list, all_table_text_list):
             if len(table_html) > min_characters - self.table_structure_len_max:
-                all_items += [f"table:{table_text}\n"]
-
-            # if len(table_html) > min_characters - self.table_structure_len_max:
-            #     all_items += [f"table:{table_text}\n"]
+                all_items += [f"table:{table_text}"]
 
         all_text_str = "".join(all_items)
 
@@ -331,6 +304,37 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             vector_info["vector"] = all_items
         return vector_info
 
+    def save_vector(self, vector_info: dict, save_path: str) -> None:
+        if "flag_too_short_text" not in vector_info or "vector" not in vector_info:
+            logging.error("Invalid vector info.")
+            return
+        save_vector_info = {}
+        save_vector_info["flag_too_short_text"] = vector_info["flag_too_short_text"]
+        if not vector_info["flag_too_short_text"]:
+            save_vector_info["vector"] = self.retriever.encode_vector_store_to_bytes(
+                vector_info["vector"]
+            )
+        else:
+            save_vector_info["vector"] = vector_info["vector"]
+
+        with open(save_path, "w") as fout:
+            fout.write(json.dumps(save_vector_info, ensure_ascii=False) + "\n")
+        return
+
+    def load_vector(self, data_path: str) -> dict:
+        vector_info = None
+        with open(data_path, "r") as fin:
+            data = fin.readline()
+            vector_info = json.loads(data)
+            if "flag_too_short_text" not in vector_info or "vector" not in vector_info:
+                logging.error("Invalid vector info.")
+                return {}
+            if not vector_info["flag_too_short_text"]:
+                vector_info["vector"] = self.retriever.decode_vector_store_from_bytes(
+                    vector_info["vector"]
+                )
+        return vector_info
+
     def format_key(self, key_list: str | list[str]) -> list[str]:
         """
         Formats the key list.
@@ -354,56 +358,8 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
 
         return []
 
-    def fix_llm_result_format(self, llm_result: str) -> dict:
-        """
-        Fix the format of the LLM result.
-
-        Args:
-            llm_result (str): The result from the LLM (Large Language Model).
-
-        Returns:
-            dict: A fixed format dictionary from the LLM result.
-        """
-        if not llm_result:
-            return {}
-
-        if "json" in llm_result or "```" in llm_result:
-            llm_result = (
-                llm_result.replace("```", "").replace("json", "").replace("/n", "")
-            )
-            llm_result = llm_result.replace("[", "").replace("]", "")
-
-        try:
-            llm_result = json.loads(llm_result)
-            llm_result_final = {}
-            for key in llm_result:
-                value = llm_result[key]
-                if isinstance(value, list):
-                    if len(value) > 0:
-                        llm_result_final[key] = value[0]
-                else:
-                    llm_result_final[key] = value
-            return llm_result_final
-
-        except:
-            results = (
-                llm_result.replace("\n", "")
-                .replace("    ", "")
-                .replace("{", "")
-                .replace("}", "")
-            )
-            if not results.endswith('"'):
-                results = results + '"'
-            pattern = r'"(.*?)": "([^"]*)"'
-            matches = re.findall(pattern, str(results))
-            if len(matches) > 0:
-                llm_result = {k: v for k, v in matches}
-                return llm_result
-            else:
-                return {}
-
     def generate_and_merge_chat_results(
-        self, prompt: str, key_list: list, final_results: dict, failed_results: dict
+        self, prompt: str, key_list: list, final_results: dict, failed_results: list
     ) -> None:
         """
         Generate and merge chat results into the final results dictionary.
@@ -412,7 +368,7 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             prompt (str): The input prompt for the chat bot.
             key_list (list): A list of keys to track which results to merge.
             final_results (dict): The dictionary to store the final merged results.
-            failed_results (dict): A dictionary of failed results to avoid merging.
+            failed_results (list): A list of failed results to avoid merging.
 
         Returns:
             None
@@ -420,15 +376,13 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
 
         llm_result = self.chat_bot.generate_chat_results(prompt)
         if llm_result is None:
-            logging.warning(
+            logging.error(
                 "chat bot error: \n [prompt:]\n %s\n [result:] %s\n"
                 % (prompt, self.chat_bot.ERROR_MASSAGE)
             )
             return
 
-        # print(prompt, llm_result)
-
-        llm_result = self.fix_llm_result_format(llm_result)
+        llm_result = self.chat_bot.fix_llm_result_format(llm_result)
 
         for key, value in llm_result.items():
             if value not in failed_results and key in key_list:
@@ -436,6 +390,52 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
                 final_results[key] = value
         return
 
+    def get_related_normal_text(
+        self,
+        use_vector_retrieval: bool,
+        vector_info: dict,
+        key_list: list[str],
+        all_normal_text_list: list,
+        min_characters: int,
+    ) -> str:
+        """
+        Retrieve related normal text based on vector retrieval or all normal text list.
+
+        Args:
+            use_vector_retrieval (bool): Whether to use vector retrieval.
+            vector_info (dict): Dictionary containing vector information.
+            key_list (list[str]): List of keys to generate question keys.
+            all_normal_text_list (list): List of normal text.
+            min_characters (int): Minimum number of characters required for the output.
+
+        Returns:
+            str: Related normal text.
+        """
+
+        if use_vector_retrieval and vector_info is not None:
+            question_key_list = [f"{key}" for key in key_list]
+            vector = vector_info["vector"]
+            if not vector_info["flag_too_short_text"]:
+                related_text = self.retriever.similarity_retrieval(
+                    question_key_list, vector
+                )
+            else:
+                if len(vector) > 0:
+                    related_text = "".join(vector)
+                else:
+                    related_text = ""
+        else:
+            all_items = []
+            for i, normal_text_dict in enumerate(all_normal_text_list):
+                for type, text in normal_text_dict.items():
+                    all_items += [f"{type}:{text}\n"]
+            related_text = "".join(all_items)
+            if len(related_text) > min_characters:
+                logging.warning(
+                    "The input text content is too long, the large language model may truncate it."
+                )
+        return related_text
+
     def chat(
         self,
         key_list: str | list[str],
@@ -473,12 +473,12 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             table_rules_str (str): The rules for generating table results.
             table_few_shot_demo_text_content (str): The text content for table few-shot demos.
             table_few_shot_demo_key_value_list (str): The key-value list for table few-shot demos.
-
         Returns:
             dict: A dictionary containing the chat results.
         """
 
         key_list = self.format_key(key_list)
+        key_list_ori = key_list.copy()
         if len(key_list) == 0:
             return {"error": "输入的key_list无效!"}
 
@@ -488,52 +488,17 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             visual_info_list = visual_info
 
         all_visual_info = self.merge_visual_info_list(visual_info_list)
-        all_normal_text_list, all_table_text_list, all_table_html_list = all_visual_info
+
+        (
+            all_normal_text_list,
+            all_table_text_list,
+            all_table_html_list,
+        ) = all_visual_info
 
         final_results = {}
         failed_results = ["大模型调用失败", "未知", "未找到关键信息", "None", ""]
 
         if len(key_list) > 0:
-            if use_vector_retrieval and vector_info is not None:
-                # question_key_list = [f"抽取关键信息:{key}" for key in key_list]
-                question_key_list = [f"待回答问题:{key}" for key in key_list]
-                vector = vector_info["vector"]
-                if not vector_info["flag_too_short_text"]:
-                    related_text = self.retriever.similarity_retrieval(
-                        question_key_list, vector
-                    )
-                    # print(question_key_list, related_text)
-                else:
-                    if len(vector) > 0:
-                        related_text = "".join(vector)
-                    else:
-                        related_text = ""
-            else:
-                all_items = []
-                for i, normal_text_dict in enumerate(all_normal_text_list):
-                    for type, text in normal_text_dict.items():
-                        all_items += [f"{type}:{text}\n"]
-                related_text = "".join(all_items)
-                if len(related_text) > min_characters:
-                    logging.warning(
-                        "The input text content is too long, the large language model may truncate it."
-                    )
-
-            if len(related_text) > 0:
-                prompt = self.text_pe.generate_prompt(
-                    related_text,
-                    key_list,
-                    task_description=text_task_description,
-                    output_format=text_output_format,
-                    rules_str=text_rules_str,
-                    few_shot_demo_text_content=text_few_shot_demo_text_content,
-                    few_shot_demo_key_value_list=text_few_shot_demo_key_value_list,
-                )
-                self.generate_and_merge_chat_results(
-                    prompt, key_list, final_results, failed_results
-                )
-
-        if len(key_list) > 0:
             for table_html, table_text in zip(all_table_html_list, all_table_text_list):
                 if len(table_html) <= min_characters - self.table_structure_len_max:
                     for table_info in [table_html]:
@@ -552,6 +517,28 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
                                 prompt, key_list, final_results, failed_results
                             )
 
+        if len(key_list) > 0:
+            related_text = self.get_related_normal_text(
+                use_vector_retrieval,
+                vector_info,
+                key_list,
+                all_normal_text_list,
+                min_characters,
+            )
+
+            if len(related_text) > 0:
+                prompt = self.text_pe.generate_prompt(
+                    related_text,
+                    key_list,
+                    task_description=text_task_description,
+                    output_format=text_output_format,
+                    rules_str=text_rules_str,
+                    few_shot_demo_text_content=text_few_shot_demo_text_content,
+                    few_shot_demo_key_value_list=text_few_shot_demo_key_value_list,
+                )
+                self.generate_and_merge_chat_results(
+                    prompt, key_list, final_results, failed_results
+                )
         return {"chat_res": final_results}
 
     def predict(self, *args, **kwargs) -> None:
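
The new save_vector/load_vector helpers above persist the vector store as a single JSON line, byte-encoding it through the retriever when it is a real vector database. A round-trip sketch, continuing the usage sketch after the base-class diff (pipeline and visual_info_list as defined there; the path is hypothetical):

# Round-trip sketch for the new vector persistence helpers.
vector_info = pipeline.build_vector(visual_info_list, min_characters=3500)
pipeline.save_vector(vector_info, "vector_store.json")    # one JSON line on disk

restored = pipeline.load_vector("vector_store.json")      # decodes the stored vector
chat_result = pipeline.chat(
    ["Invoice number"],                                    # hypothetical key
    visual_info_list,
    vector_info=restored,
)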

+ 635 - 0
paddlex/inference/pipelines_new/pp_chatocr/pipeline_v4.py

@@ -0,0 +1,635 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional
+import re
+import json
+import numpy as np
+import copy
+from .pipeline_base import PP_ChatOCR_Pipeline
+from .result import VisualInfoResult
+from ...common.reader import ReadImage
+from ...common.batch_sampler import ImageBatchSampler
+from ....utils import logging
+from ...utils.pp_option import PaddlePredictorOption
+from ..layout_parsing.result import LayoutParsingResult
+
+
+class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
+    """PP-ChatOCRv4 Pipeline"""
+
+    entities = ["PP-ChatOCRv4-doc"]
+
+    def __init__(
+        self,
+        config: Dict,
+        device: str = None,
+        pp_option: PaddlePredictorOption = None,
+        use_hpip: bool = False,
+        hpi_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """Initializes the pp-chatocrv3-doc pipeline.
+
+        Args:
+            config (Dict): Configuration dictionary containing various settings.
+            device (str, optional): Device to run the predictions on. Defaults to None.
+            pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
+            use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
+            hpi_params (Optional[Dict[str, Any]], optional): HPIP parameters. Defaults to None.
+        """
+
+        super().__init__(
+            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
+        )
+
+        self.pipeline_name = config["pipeline_name"]
+
+        self.inintial_predictor(config)
+
+        self.batch_sampler = ImageBatchSampler(batch_size=1)
+        self.img_reader = ReadImage(format="BGR")
+
+        self.table_structure_len_max = 500
+
+    def inintial_predictor(self, config: dict) -> None:
+        """
+        Initializes the predictor with the given configuration.
+
+        Args:
+            config (dict): The configuration dictionary containing the necessary
+                                parameters for initializing the predictor.
+        Returns:
+            None
+        """
+
+        layout_parsing_config = config["SubPipelines"]["LayoutParser"]
+        self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
+
+        from .. import create_chat_bot
+
+        chat_bot_config = config["SubModules"]["LLM_Chat"]
+        self.chat_bot = create_chat_bot(chat_bot_config)
+
+        from .. import create_retriever
+
+        retriever_config = config["SubModules"]["LLM_Retriever"]
+        self.retriever = create_retriever(retriever_config)
+
+        from .. import create_prompt_engeering
+
+        text_pe_config = config["SubModules"]["PromptEngneering"]["KIE_CommonText"]
+        self.text_pe = create_prompt_engeering(text_pe_config)
+
+        table_pe_config = config["SubModules"]["PromptEngneering"]["KIE_Table"]
+        self.table_pe = create_prompt_engeering(table_pe_config)
+
+        self.use_mllm_predict = False
+        if "use_mllm_predict" in config:
+            self.use_mllm_predict = config["use_mllm_predict"]
+        if self.use_mllm_predict:
+            ensemble_pe_config = config["SubModules"]["PromptEngneering"]["Ensemble"]
+            self.ensemble_pe = create_prompt_engeering(ensemble_pe_config)
+        return
+
+    def decode_visual_result(
+        self, layout_parsing_result: LayoutParsingResult
+    ) -> VisualInfoResult:
+        """
+        Decodes the visual result from the layout parsing result.
+
+        Args:
+            layout_parsing_result (LayoutParsingResult): The result of layout parsing.
+
+        Returns:
+            VisualInfoResult: The decoded visual information.
+        """
+        text_paragraphs_ocr_res = layout_parsing_result["text_paragraphs_ocr_res"]
+        seal_res_list = layout_parsing_result["seal_res_list"]
+        normal_text_dict = {}
+
+        for seal_res in seal_res_list:
+            for text in seal_res["rec_text"]:
+                layout_type = "印章"
+                if layout_type not in normal_text_dict:
+                    normal_text_dict[layout_type] = f"{text}"
+                else:
+                    normal_text_dict[layout_type] += f"\n {text}"
+
+        for text in text_paragraphs_ocr_res["rec_text"]:
+            layout_type = "words in text block"
+            if layout_type not in normal_text_dict:
+                normal_text_dict[layout_type] = text
+            else:
+                normal_text_dict[layout_type] += f"\n {text}"
+
+        table_res_list = layout_parsing_result["table_res_list"]
+        table_text_list = []
+        table_html_list = []
+        table_nei_text_list = []
+        for table_res in table_res_list:
+            table_html_list.append(table_res["pred_html"])
+            single_table_text = " ".join(table_res["table_ocr_pred"]["rec_text"])
+            table_text_list.append(single_table_text)
+            table_nei_text_list.append(table_res["neighbor_text"])
+
+        visual_info = {}
+        visual_info["normal_text_dict"] = normal_text_dict
+        visual_info["table_text_list"] = table_text_list
+        visual_info["table_html_list"] = table_html_list
+        visual_info["table_nei_text_list"] = table_nei_text_list
+        return VisualInfoResult(visual_info)
+
+    # Function to perform visual prediction on input images
+    def visual_predict(
+        self,
+        input: str | list[str] | np.ndarray | list[np.ndarray],
+        use_doc_orientation_classify: bool = False,  # Whether to use document orientation classification
+        use_doc_unwarping: bool = False,  # Whether to use document unwarping
+        use_general_ocr: bool = True,  # Whether to use general OCR
+        use_seal_recognition: bool = True,  # Whether to use seal recognition
+        use_table_recognition: bool = True,  # Whether to use table recognition
+        **kwargs,
+    ) -> dict:
+        """
+        This function takes an input image or a list of images and performs various visual
+        prediction tasks such as document orientation classification, document unwarping,
+        general OCR, seal recognition, and table recognition based on the provided flags.
+
+        Args:
+            input (str | list[str] | np.ndarray | list[np.ndarray]): Input image path, list of image paths,
+                                                                        numpy array of an image, or list of numpy arrays.
+            use_doc_orientation_classify (bool): Flag to use document orientation classification.
+            use_doc_unwarping (bool): Flag to use document unwarping.
+            use_general_ocr (bool): Flag to use general OCR.
+            use_seal_recognition (bool): Flag to use seal recognition.
+            use_table_recognition (bool): Flag to use table recognition.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            dict: A dictionary containing the layout parsing result and visual information.
+        """
+
+        for layout_parsing_result in self.layout_parsing_pipeline.predict(
+            input,
+            use_doc_orientation_classify=use_doc_orientation_classify,
+            use_doc_unwarping=use_doc_unwarping,
+            use_general_ocr=use_general_ocr,
+            use_seal_recognition=use_seal_recognition,
+            use_table_recognition=use_table_recognition,
+        ):
+
+            visual_info = self.decode_visual_result(layout_parsing_result)
+
+            visual_predict_res = {
+                "layout_parsing_result": layout_parsing_result,
+                "visual_info": visual_info,
+            }
+            yield visual_predict_res
+
+    def save_visual_info_list(
+        self, visual_info: VisualInfoResult, save_path: str
+    ) -> None:
+        """
+        Save the visual info list to the specified file path.
+
+        Args:
+            visual_info (VisualInfoResult): The visual info result, which can be a single object or a list of objects.
+            save_path (str): The file path to save the visual info list.
+
+        Returns:
+            None
+        """
+        if not isinstance(visual_info, list):
+            visual_info_list = [visual_info]
+        else:
+            visual_info_list = visual_info
+
+        with open(save_path, "w") as fout:
+            fout.write(json.dumps(visual_info_list, ensure_ascii=False) + "\n")
+        return
+
+    def load_visual_info_list(self, data_path: str) -> list[VisualInfoResult]:
+        """
+        Loads visual info list from a JSON file.
+
+        Args:
+            data_path (str): The path to the JSON file containing visual info.
+
+        Returns:
+            list[VisualInfoResult]: A list of VisualInfoResult objects parsed from the JSON file.
+        """
+        with open(data_path, "r") as fin:
+            data = fin.readline()
+            visual_info_list = json.loads(data)
+        return visual_info_list
+
+    def merge_visual_info_list(
+        self, visual_info_list: list[VisualInfoResult]
+    ) -> tuple[list, list, list, list]:
+        """
+        Merge visual info lists.
+
+        Args:
+            visual_info_list (list[VisualInfoResult]): A list of visual info results.
+
+        Returns:
+            tuple[list, list, list, list]: A tuple containing four lists, one for normal text dicts,
+                                               one for table text lists, one for table HTML lists,
+                                               and one for table neighbor texts.
+        """
+        all_normal_text_list = []
+        all_table_text_list = []
+        all_table_html_list = []
+        all_table_nei_text_list = []
+        for single_visual_info in visual_info_list:
+            normal_text_dict = single_visual_info["normal_text_dict"]
+            for key in normal_text_dict:
+                normal_text_dict[key] = normal_text_dict[key].replace("\n", "")
+            table_text_list = single_visual_info["table_text_list"]
+            table_html_list = single_visual_info["table_html_list"]
+            table_nei_text_list = single_visual_info["table_nei_text_list"]
+            all_normal_text_list.append(normal_text_dict)
+            all_table_text_list.extend(table_text_list)
+            all_table_html_list.extend(table_html_list)
+            all_table_nei_text_list.extend(table_nei_text_list)
+        return (
+            all_normal_text_list,
+            all_table_text_list,
+            all_table_html_list,
+            all_table_nei_text_list,
+        )
+
+    def build_vector(
+        self,
+        visual_info: VisualInfoResult,
+        min_characters: int = 3500,
+        llm_request_interval: float = 1.0,
+    ) -> dict:
+        """
+        Build a vector representation from visual information.
+
+        Args:
+            visual_info (VisualInfoResult): The visual information input, which can be a single instance or a list of instances.
+            min_characters (int): The minimum number of characters required for text processing, defaults to 3500.
+            llm_request_interval (float): The interval between LLM requests, defaults to 1.0.
+
+        Returns:
+            dict: A dictionary containing the vector info and a flag indicating if the text is too short.
+        """
+
+        if not isinstance(visual_info, list):
+            visual_info_list = [visual_info]
+        else:
+            visual_info_list = visual_info
+
+        all_visual_info = self.merge_visual_info_list(visual_info_list)
+        (
+            all_normal_text_list,
+            all_table_text_list,
+            all_table_html_list,
+            all_table_nei_text_list,
+        ) = all_visual_info
+
+        vector_info = {}
+
+        all_items = []
+        for i, normal_text_dict in enumerate(all_normal_text_list):
+            for type, text in normal_text_dict.items():
+                all_items += [f"{type}:{text}\n"]
+
+        for table_html, table_text, table_nei_text in zip(
+            all_table_html_list, all_table_text_list, all_table_nei_text_list
+        ):
+            if len(table_html) > min_characters - self.table_structure_len_max:
+                all_items += [f"table:{table_text}\t{table_nei_text}"]
+
+        all_text_str = "".join(all_items)
+
+        if len(all_text_str) > min_characters:
+            vector_info["flag_too_short_text"] = False
+            vector_info["vector"] = self.retriever.generate_vector_database(all_items)
+        else:
+            vector_info["flag_too_short_text"] = True
+            vector_info["vector"] = all_items
+        return vector_info
+
+    def save_vector(self, vector_info: dict, save_path: str) -> None:
+        if "flag_too_short_text" not in vector_info or "vector" not in vector_info:
+            logging.error("Invalid vector info.")
+            return
+        save_vector_info = {}
+        save_vector_info["flag_too_short_text"] = vector_info["flag_too_short_text"]
+        if not vector_info["flag_too_short_text"]:
+            save_vector_info["vector"] = self.retriever.encode_vector_store_to_bytes(
+                vector_info["vector"]
+            )
+        else:
+            save_vector_info["vector"] = vector_info["vector"]
+
+        with open(save_path, "w") as fout:
+            fout.write(json.dumps(save_vector_info, ensure_ascii=False) + "\n")
+        return
+
+    def load_vector(self, data_path: str) -> dict:
+        vector_info = None
+        with open(data_path, "r") as fin:
+            data = fin.readline()
+            vector_info = json.loads(data)
+            if "flag_too_short_text" not in vector_info or "vector" not in vector_info:
+                logging.error("Invalid vector info.")
+                return {}
+            if not vector_info["flag_too_short_text"]:
+                vector_info["vector"] = self.retriever.decode_vector_store_from_bytes(
+                    vector_info["vector"]
+                )
+        return vector_info
+
+    def format_key(self, key_list: str | list[str]) -> list[str]:
+        """
+        Formats the key list.
+
+        Args:
+            key_list (str|list[str]): A string or a list of strings representing the keys.
+
+        Returns:
+            list[str]: A list of formatted keys.
+        """
+        if key_list == "":
+            return []
+
+        if isinstance(key_list, list):
+            key_list = [key.replace("\xa0", " ") for key in key_list]
+            return key_list
+
+        if isinstance(key_list, str):
+            key_list = re.sub(r"[\t\n\r\f\v]", "", key_list)
+        key_list = key_list.replace("，", ",").split(",")
+            return key_list
+
+        return []
+
+    def generate_and_merge_chat_results(
+        self, prompt: str, key_list: list, final_results: dict, failed_results: list
+    ) -> None:
+        """
+        Generate and merge chat results into the final results dictionary.
+
+        Args:
+            prompt (str): The input prompt for the chat bot.
+            key_list (list): A list of keys to track which results to merge.
+            final_results (dict): The dictionary to store the final merged results.
+            failed_results (list): A list of failed results to avoid merging.
+
+        Returns:
+            None
+        """
+
+        llm_result = self.chat_bot.generate_chat_results(prompt)
+        if llm_result is None:
+            logging.error(
+                "chat bot error: \n [prompt:]\n %s\n [result:] %s\n"
+                % (prompt, self.chat_bot.ERROR_MASSAGE)
+            )
+            return
+
+        llm_result = self.chat_bot.fix_llm_result_format(llm_result)
+
+        for key, value in llm_result.items():
+            if value not in failed_results and key in key_list:
+                key_list.remove(key)
+                final_results[key] = value
+        return
+
+    def get_related_normal_text(
+        self,
+        use_vector_retrieval: bool,
+        vector_info: dict,
+        key_list: list[str],
+        all_normal_text_list: list,
+        min_characters: int,
+    ) -> str:
+        """
+        Retrieve related normal text based on vector retrieval or all normal text list.
+
+        Args:
+            use_vector_retrieval (bool): Whether to use vector retrieval.
+            vector_info (dict): Dictionary containing vector information.
+            key_list (list[str]): List of keys to generate question keys.
+            all_normal_text_list (list): List of normal text.
+            min_characters (int): The minimum number of characters required for text processing.
+
+        Returns:
+            str: Related normal text.
+        """
+
+        if use_vector_retrieval and vector_info is not None:
+            question_key_list = [f"{key}" for key in key_list]
+            vector = vector_info["vector"]
+            if not vector_info["flag_too_short_text"]:
+                related_text = self.retriever.similarity_retrieval(
+                    question_key_list, vector, topk=5, min_characters=min_characters
+                )
+            else:
+                if len(vector) > 0:
+                    related_text = "".join(vector)
+                else:
+                    related_text = ""
+        else:
+            all_items = []
+            for normal_text_dict in all_normal_text_list:
+                for label, text in normal_text_dict.items():
+                    all_items += [f"{label}:{text}\n"]
+            related_text = "".join(all_items)
+            if len(related_text) > min_characters:
+                logging.warning(
+                    "The input text content is too long, the large language model may truncate it."
+                )
+        return related_text
+
+    def ensemble_ocr_llm_mllm(
+        self, key_list: list[str], ocr_llm_predict_dict: dict, mllm_predict_dict: dict
+    ) -> dict:
+        """
+        Ensemble the OCR+LLM and MLLM predictions based on the given key list.
+
+        Args:
+            key_list (list[str]): List of keys to retrieve predictions.
+            ocr_llm_predict_dict (dict): Dictionary containing OCR LLM predictions.
+            mllm_predict_dict (dict): Dictionary containing mLLM predictions.
+
+        Returns:
+            dict: A dictionary with final predictions.
+        """
+        final_predict_dict = {}
+
+        for key in key_list:
+            predict = ""
+            ocr_llm_predict = ""
+            mllm_predict = ""
+            if key in ocr_llm_predict_dict:
+                ocr_llm_predict = ocr_llm_predict_dict[key]
+            if key in mllm_predict_dict:
+                mllm_predict = mllm_predict_dict[key]
+            if ocr_llm_predict != "" and mllm_predict != "":
+                prompt = self.ensemble_pe.generate_prompt(
+                    key, ocr_llm_predict, mllm_predict
+                )
+                llm_result = self.chat_bot.generate_chat_results(prompt)
+                if llm_result is not None:
+                    llm_result = self.chat_bot.fix_llm_result_format(llm_result)
+                if llm_result is not None and key in llm_result:
+                    tmp = llm_result[key]
+                    if "B" in tmp:
+                        predict = mllm_predict
+                    else:
+                        predict = ocr_llm_predict
+                else:
+                    predict = ocr_llm_predict
+            elif key in ocr_llm_predict_dict:
+                predict = ocr_llm_predict_dict[key]
+            elif key in mllm_predict_dict:
+                predict = mllm_predict_dict[key]
+
+            if predict != "":
+                final_predict_dict[key] = predict
+        return final_predict_dict
+
+    def chat(
+        self,
+        key_list: str | list[str],
+        visual_info: VisualInfoResult,
+        use_vector_retrieval: bool = True,
+        vector_info: dict = None,
+        min_characters: int = 3500,
+        text_task_description: str = None,
+        text_output_format: str = None,
+        text_rules_str: str = None,
+        text_few_shot_demo_text_content: str = None,
+        text_few_shot_demo_key_value_list: str = None,
+        table_task_description: str = None,
+        table_output_format: str = None,
+        table_rules_str: str = None,
+        table_few_shot_demo_text_content: str = None,
+        table_few_shot_demo_key_value_list: str = None,
+        mllm_predict_dict: dict = None,
+    ) -> dict:
+        """
+        Generates chat results based on the provided key list and visual information.
+
+        Args:
+            key_list (str | list[str]): A single key or a list of keys to extract information.
+            visual_info (VisualInfoResult): The visual information result.
+            use_vector_retrieval (bool): Whether to use vector retrieval.
+            vector_info (dict): The vector information for retrieval.
+            min_characters (int): The minimum number of characters required for text processing, defaults to 3500.
+            text_task_description (str): The description of the text task.
+            text_output_format (str): The output format for text results.
+            text_rules_str (str): The rules for generating text results.
+            text_few_shot_demo_text_content (str): The text content for few-shot demos.
+            text_few_shot_demo_key_value_list (str): The key-value list for few-shot demos.
+            table_task_description (str): The description of the table task.
+            table_output_format (str): The output format for table results.
+            table_rules_str (str): The rules for generating table results.
+            table_few_shot_demo_text_content (str): The text content for table few-shot demos.
+            table_few_shot_demo_key_value_list (str): The key-value list for table few-shot demos.
+            mllm_predict_dict (dict): The dictionary of MLLM predictions.
+        Returns:
+            dict: A dictionary containing the chat results.
+        """
+
+        key_list = self.format_key(key_list)
+        key_list_ori = key_list.copy()
+        if len(key_list) == 0:
+            return {"error": "输入的key_list无效!"}
+
+        if not isinstance(visual_info, list):
+            visual_info_list = [visual_info]
+        else:
+            visual_info_list = visual_info
+
+        all_visual_info = self.merge_visual_info_list(visual_info_list)
+        (
+            all_normal_text_list,
+            all_table_text_list,
+            all_table_html_list,
+            all_table_nei_text_list,
+        ) = all_visual_info
+
+        final_results = {}
+        failed_results = ["大模型调用失败", "未知", "未找到关键信息", "None", ""]
+
+        if len(key_list) > 0:
+            related_text = self.get_related_normal_text(
+                use_vector_retrieval,
+                vector_info,
+                key_list,
+                all_normal_text_list,
+                min_characters,
+            )
+
+            if len(related_text) > 0:
+                prompt = self.text_pe.generate_prompt(
+                    related_text,
+                    key_list,
+                    task_description=text_task_description,
+                    output_format=text_output_format,
+                    rules_str=text_rules_str,
+                    few_shot_demo_text_content=text_few_shot_demo_text_content,
+                    few_shot_demo_key_value_list=text_few_shot_demo_key_value_list,
+                )
+                self.generate_and_merge_chat_results(
+                    prompt, key_list, final_results, failed_results
+                )
+
+        if len(key_list) > 0:
+            for table_html, table_text, table_nei_text in zip(
+                all_table_html_list, all_table_text_list, all_table_nei_text_list
+            ):
+                if len(table_html) <= min_characters - self.table_structure_len_max:
+                    if len(key_list) > 0:
+                        table_info = table_html
+                        if len(table_nei_text) > 0:
+                            table_info = (
+                                table_info + "\n 表格周围文字:" + table_nei_text
+                            )
+
+                        prompt = self.table_pe.generate_prompt(
+                            table_info,
+                            key_list,
+                            task_description=table_task_description,
+                            output_format=table_output_format,
+                            rules_str=table_rules_str,
+                            few_shot_demo_text_content=table_few_shot_demo_text_content,
+                            few_shot_demo_key_value_list=table_few_shot_demo_key_value_list,
+                        )
+
+                        self.generate_and_merge_chat_results(
+                            prompt, key_list, final_results, failed_results
+                        )
+
+        if self.use_mllm_predict:
+            final_predict_dict = self.ensemble_ocr_llm_mllm(
+                key_list_ori, final_results, mllm_predict_dict
+            )
+        else:
+            final_predict_dict = final_results
+        return {"chat_res": final_predict_dict}
+
+    def predict(self, *args, **kwargs) -> None:
+        logging.error(
+            "PP-ChatOCRv4-doc Pipeline do not support to call `predict()` directly! Please invoke `visual_predict`, `build_vector`, `chat` sequentially to obtain the result."
+        )
+        return
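
As the error message in `predict()` indicates, this pipeline is driven by calling `visual_predict`, `build_vector`, and `chat` in sequence. Below is only an illustrative sketch of that flow; it assumes the pipeline is created through `paddlex.create_pipeline` as in the accompanying api_examples scripts, and the input path, key names, and the `visual_info` field accessed on each result are assumptions rather than guaranteed API details:

    from paddlex import create_pipeline

    pipeline = create_pipeline(pipeline="PP-ChatOCRv4-doc")

    # 1. Visual analysis (layout parsing, OCR, table recognition) over an image or PDF.
    visual_info_list = []
    for res in pipeline.visual_predict("./demo_document.pdf"):
        visual_info_list.append(res["visual_info"])

    # 2. Build the retrieval store; it falls back to raw text when the document is short.
    vector_info = pipeline.build_vector(visual_info_list, min_characters=3500)

    # 3. Extract the requested keys; answers are returned under "chat_res".
    chat_result = pipeline.chat(
        key_list=["乙方", "手机号"],
        visual_info=visual_info_list,
        use_vector_retrieval=True,
        vector_info=vector_info,
    )
    print(chat_result["chat_res"])

Note that the retrieval and chat steps go through the chat bot and retriever components configured in the pipeline YAML (ERNIE Bot by default), so the corresponding API credentials need to be set there.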

+ 1 - 2
paddlex/inference/pipelines_new/pp_chatocrv3_doc/result.py → paddlex/inference/pipelines_new/pp_chatocr/result.py

@@ -18,9 +18,8 @@ import numpy as np
 import cv2
 import PIL
 from PIL import Image, ImageDraw, ImageFont
-
 from ....utils.fonts import PINGFANG_FONT_FILE_PATH
-from ..components import BaseResult
+from ...common.result import BaseResult
 
 
 class VisualInfoResult(BaseResult):

+ 1 - 1
paddlex/inference/pipelines_new/pp_chatocrv3_doc/__init__.py → paddlex/inference/pipelines_new/seal_recognition/__init__.py

@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .pipeline import PP_ChatOCRv3_doc_Pipeline
+from .pipeline import SealRecognitionPipeline

+ 228 - 0
paddlex/inference/pipelines_new/seal_recognition/pipeline.py

@@ -0,0 +1,228 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os, sys
+from typing import Any, Dict, Optional
+import numpy as np
+import cv2
+from ..base import BasePipeline
+from ..components import CropByBoxes
+from .result import SealRecognitionResult
+from ....utils import logging
+from ...utils.pp_option import PaddlePredictorOption
+from ...common.reader import ReadImage
+from ...common.batch_sampler import ImageBatchSampler
+from ..doc_preprocessor.result import DocPreprocessorResult
+
+# [TODO] Update models_new to models
+from ...models_new.object_detection.result import DetResult
+
+
+class SealRecognitionPipeline(BasePipeline):
+    """Seal Recognition Pipeline"""
+
+    entities = ["seal_recognition"]
+
+    def __init__(
+        self,
+        config: Dict,
+        device: str = None,
+        pp_option: PaddlePredictorOption = None,
+        use_hpip: bool = False,
+        hpi_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """Initializes the seal recognition pipeline.
+
+        Args:
+            config (Dict): Configuration dictionary containing various settings.
+            device (str, optional): Device to run the predictions on. Defaults to None.
+            pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
+            use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
+            hpi_params (Optional[Dict[str, Any]], optional): HPIP parameters. Defaults to None.
+        """
+
+        super().__init__(
+            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
+        )
+
+        self.use_doc_preprocessor = False
+        if "use_doc_preprocessor" in config:
+            self.use_doc_preprocessor = config["use_doc_preprocessor"]
+
+        if self.use_doc_preprocessor:
+            doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
+            self.doc_preprocessor_pipeline = self.create_pipeline(
+                doc_preprocessor_config
+            )
+
+        self.use_layout_detection = True
+        if "use_layout_detection" in config:
+            self.use_layout_detection = config["use_layout_detection"]
+
+        if self.use_layout_detection:
+            layout_det_config = config["SubModules"]["LayoutDetection"]
+            self.layout_det_model = self.create_model(layout_det_config)
+
+        seal_ocr_config = config["SubPipelines"]["SealOCR"]
+        self.seal_ocr_pipeline = self.create_pipeline(seal_ocr_config)
+
+        self._crop_by_boxes = CropByBoxes()
+
+        self.batch_sampler = ImageBatchSampler(batch_size=1)
+
+        self.img_reader = ReadImage(format="BGR")
+
+    def check_input_params_valid(
+        self, input_params: Dict, layout_det_res: DetResult
+    ) -> bool:
+        """
+        Check if the input parameters are valid based on the initialized models.
+
+        Args:
+            input_params (Dict): A dictionary containing input parameters.
+            layout_det_res (DetResult): Layout detection result.
+
+        Returns:
+            bool: True if all required models are initialized according to input parameters, False otherwise.
+        """
+
+        if input_params["use_doc_preprocessor"] and not self.use_doc_preprocessor:
+            logging.error(
+                "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized."
+            )
+            return False
+
+        if input_params["use_layout_detection"]:
+            if layout_det_res is not None:
+                logging.error(
+                    "The layout detection model has already been initialized, please set use_layout_detection=False"
+                )
+                return False
+
+            if not self.use_layout_detection:
+                logging.error(
+                    "Set use_layout_detection, but the models for layout detection are not initialized."
+                )
+                return False
+        return True
+
+    def predict_doc_preprocessor_res(
+        self, image_array: np.ndarray, input_params: dict
+    ) -> tuple[DocPreprocessorResult, np.ndarray]:
+        """
+        Preprocess the document image based on input parameters.
+
+        Args:
+            image_array (np.ndarray): The input image array.
+            input_params (dict): Dictionary containing preprocessing parameters.
+
+        Returns:
+            tuple[DocPreprocessorResult, np.ndarray]: A tuple containing the preprocessing
+                                              result dictionary and the processed image array.
+        """
+        if input_params["use_doc_preprocessor"]:
+            use_doc_orientation_classify = input_params["use_doc_orientation_classify"]
+            use_doc_unwarping = input_params["use_doc_unwarping"]
+            doc_preprocessor_res = next(
+                self.doc_preprocessor_pipeline(
+                    image_array,
+                    use_doc_orientation_classify=use_doc_orientation_classify,
+                    use_doc_unwarping=use_doc_unwarping,
+                )
+            )
+            doc_preprocessor_image = doc_preprocessor_res["output_img"]
+        else:
+            doc_preprocessor_res = {}
+            doc_preprocessor_image = image_array
+        return doc_preprocessor_res, doc_preprocessor_image
+
+    def predict(
+        self,
+        input: str | list[str] | np.ndarray | list[np.ndarray],
+        use_layout_detection: bool = True,
+        use_doc_orientation_classify: bool = False,
+        use_doc_unwarping: bool = False,
+        layout_det_res: DetResult = None,
+        **kwargs
+    ) -> SealRecognitionResult:
+        """
+        This function predicts the seal recognition result for the given input.
+
+        Args:
+            input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) or PDF(s) to be processed.
+            use_layout_detection (bool): Whether to use layout detection.
+            use_doc_orientation_classify (bool): Whether to use document orientation classification.
+            use_doc_unwarping (bool): Whether to use document unwarping.
+            layout_det_res (DetResult): The layout detection result.
+                It will be used if it is not None and use_layout_detection is False.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            SealRecognitionResult: The predicted seal recognition result.
+        """
+
+        input_params = {
+            "use_layout_detection": use_layout_detection,
+            "use_doc_preprocessor": self.use_doc_preprocessor,
+            "use_doc_orientation_classify": use_doc_orientation_classify,
+            "use_doc_unwarping": use_doc_unwarping,
+        }
+
+        if use_doc_orientation_classify or use_doc_unwarping:
+            input_params["use_doc_preprocessor"] = True
+        else:
+            input_params["use_doc_preprocessor"] = False
+
+        if not self.check_input_params_valid(input_params, layout_det_res):
+            yield None
+            return
+
+        for img_id, batch_data in enumerate(self.batch_sampler(input)):
+            image_array = self.img_reader(batch_data)[0]
+            img_id += 1
+
+            doc_preprocessor_res, doc_preprocessor_image = (
+                self.predict_doc_preprocessor_res(image_array, input_params)
+            )
+
+            seal_res_list = []
+            seal_region_id = 1
+            if not input_params["use_layout_detection"] and layout_det_res is None:
+                layout_det_res = {}
+                seal_ocr_res = next(self.seal_ocr_pipeline(doc_preprocessor_image))
+                seal_ocr_res["seal_region_id"] = seal_region_id
+                seal_res_list.append(seal_ocr_res)
+                seal_region_id += 1
+            else:
+                if input_params["use_layout_detection"]:
+                    layout_det_res = next(self.layout_det_model(doc_preprocessor_image))
+
+                for box_info in layout_det_res["boxes"]:
+                    if box_info["label"].lower() in ["seal"]:
+                        crop_img_info = self._crop_by_boxes(
+                            doc_preprocessor_image, [box_info]
+                        )
+                        crop_img_info = crop_img_info[0]
+                        seal_ocr_res = next(
+                            self.seal_ocr_pipeline(crop_img_info["img"])
+                        )
+                        seal_ocr_res["seal_region_id"] = seal_region_id
+                        seal_res_list.append(seal_ocr_res)
+                        seal_region_id += 1
+
+            single_img_res = {
+                "layout_det_res": layout_det_res,
+                "doc_preprocessor_res": doc_preprocessor_res,
+                "seal_res_list": seal_res_list,
+                "input_params": input_params,
+                "img_id": img_id,
+            }
+            yield SealRecognitionResult(single_img_res)
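
For reference, a minimal sketch of driving this pipeline end to end, assuming it is instantiated via `paddlex.create_pipeline` under the `seal_recognition` entity name declared above; the input path and output directory are placeholders:

    from paddlex import create_pipeline

    pipeline = create_pipeline(pipeline="seal_recognition")

    # predict() is a generator yielding one SealRecognitionResult per image / PDF page.
    for res in pipeline.predict(
        "./seal_demo.png",
        use_doc_orientation_classify=False,
        use_doc_unwarping=False,
    ):
        # Writes the layout detection, doc preprocessing, and per-seal OCR visualizations.
        res.save_results("./output/")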

+ 55 - 0
paddlex/inference/pipelines_new/seal_recognition/result.py

@@ -0,0 +1,55 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from pathlib import Path
+
+
+class SealRecognitionResult(dict):
+    """Seal Recognition Result"""
+
+    def __init__(self, data) -> None:
+        """Initializes a new instance of the class with the specified data."""
+        super().__init__(data)
+
+    def save_results(self, save_path: str) -> None:
+        """Save the layout parsing results to the specified directory.
+
+        Args:
+            save_path (str): The directory path to save the results.
+        """
+
+        if not os.path.isdir(save_path):
+            return
+
+        img_id = self["img_id"]
+        layout_det_res = self["layout_det_res"]
+        if len(layout_det_res) > 0:
+            save_img_path = Path(save_path) / f"layout_det_result_img{img_id}.jpg"
+            layout_det_res.save_to_img(save_img_path)
+
+        input_params = self["input_params"]
+        if input_params["use_doc_preprocessor"]:
+            save_img_path = Path(save_path) / f"doc_preprocessor_result_img{img_id}.jpg"
+            self["doc_preprocessor_res"].save_to_img(save_img_path)
+
+        for sno in range(len(self["seal_res_list"])):
+            seal_res = self["seal_res_list"][sno]
+            seal_region_id = seal_res["seal_region_id"]
+            save_img_path = (
+                Path(save_path) / f"seal_res_img{img_id}_region{seal_region_id}.jpg"
+            )
+            seal_res.save_to_img(save_img_path)
+
+        return

+ 15 - 0
paddlex/inference/pipelines_new/table_recognition/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pipeline import TableRecognitionPipeline

+ 310 - 0
paddlex/inference/pipelines_new/table_recognition/pipeline.py

@@ -0,0 +1,310 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os, sys
+from typing import Any, Dict, Optional
+import numpy as np
+import cv2
+from ..base import BasePipeline
+from ..components import CropByBoxes
+from ..layout_parsing.utils import convert_points_to_boxes
+from .utils import get_neighbor_boxes_idx
+from .table_recognition_post_processing import get_table_recognition_res
+from .result import SingleTableRecognitionResult, TableRecognitionResult
+from ....utils import logging
+from ...utils.pp_option import PaddlePredictorOption
+from ...common.reader import ReadImage
+from ...common.batch_sampler import ImageBatchSampler
+from ..ocr.result import OCRResult
+from ..doc_preprocessor.result import DocPreprocessorResult
+
+# [TODO] Update models_new to models
+from ...models_new.object_detection.result import DetResult
+
+
+class TableRecognitionPipeline(BasePipeline):
+    """Table Recognition Pipeline"""
+
+    entities = ["table_recognition"]
+
+    def __init__(
+        self,
+        config: Dict,
+        device: str = None,
+        pp_option: PaddlePredictorOption = None,
+        use_hpip: bool = False,
+        hpi_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """Initializes the layout parsing pipeline.
+
+        Args:
+            config (Dict): Configuration dictionary containing various settings.
+            device (str, optional): Device to run the predictions on. Defaults to None.
+            pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
+            use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
+            hpi_params (Optional[Dict[str, Any]], optional): HPIP parameters. Defaults to None.
+        """
+
+        super().__init__(
+            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
+        )
+
+        self.use_doc_preprocessor = False
+        if "use_doc_preprocessor" in config:
+            self.use_doc_preprocessor = config["use_doc_preprocessor"]
+
+        if self.use_doc_preprocessor:
+            doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
+            self.doc_preprocessor_pipeline = self.create_pipeline(
+                doc_preprocessor_config
+            )
+
+        self.use_layout_detection = True
+        if "use_layout_detection" in config:
+            self.use_layout_detection = config["use_layout_detection"]
+
+        if self.use_layout_detection:
+            layout_det_config = config["SubModules"]["LayoutDetection"]
+            self.layout_det_model = self.create_model(layout_det_config)
+
+        table_structure_config = config["SubModules"]["TableStructureRecognition"]
+        self.table_structure_model = self.create_model(table_structure_config)
+
+        self.use_ocr_model = True
+        if "use_ocr_model" in config:
+            self.use_ocr_model = config["use_ocr_model"]
+        if self.use_ocr_model:
+            general_ocr_config = config["SubPipelines"]["GeneralOCR"]
+            self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
+
+        self._crop_by_boxes = CropByBoxes()
+
+        self.batch_sampler = ImageBatchSampler(batch_size=1)
+        self.img_reader = ReadImage(format="BGR")
+
+    def check_input_params_valid(
+        self, input_params: Dict, overall_ocr_res: OCRResult, layout_det_res: DetResult
+    ) -> bool:
+        """
+        Check if the input parameters are valid based on the initialized models.
+
+        Args:
+            input_params (Dict): A dictionary containing input parameters.
+            overall_ocr_res (OCRResult): Overall OCR result obtained after running the OCR pipeline.
+                The overall OCR result with convert_points_to_boxes information.
+            layout_det_res (DetResult): The layout detection result.
+        Returns:
+            bool: True if all required models are initialized according to input parameters, False otherwise.
+        """
+
+        if input_params["use_doc_preprocessor"] and not self.use_doc_preprocessor:
+            logging.error(
+                "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized."
+            )
+            return False
+
+        if input_params["use_layout_detection"]:
+            if layout_det_res is not None:
+                logging.error(
+                    "The layout detection model has already been initialized, please set use_layout_detection=False"
+                )
+                return False
+
+            if not self.use_layout_detection:
+                logging.error(
+                    "Set use_layout_detection, but the models for layout detection are not initialized."
+                )
+                return False
+
+        if input_params["use_ocr_model"]:
+            if overall_ocr_res is not None:
+                logging.error(
+                    "The OCR models have already been initialized, please set use_ocr_model=False"
+                )
+                return False
+
+            if not self.use_ocr_model:
+                logging.error(
+                    "Set use_ocr_model, but the models for OCR are not initialized."
+                )
+                return False
+
+        return True
+
+    def predict_doc_preprocessor_res(
+        self, image_array: np.ndarray, input_params: dict
+    ) -> tuple[DocPreprocessorResult, np.ndarray]:
+        """
+        Preprocess the document image based on input parameters.
+
+        Args:
+            image_array (np.ndarray): The input image array.
+            input_params (dict): Dictionary containing preprocessing parameters.
+
+        Returns:
+            tuple[DocPreprocessorResult, np.ndarray]: A tuple containing the preprocessing
+                                              result dictionary and the processed image array.
+        """
+        if input_params["use_doc_preprocessor"]:
+            use_doc_orientation_classify = input_params["use_doc_orientation_classify"]
+            use_doc_unwarping = input_params["use_doc_unwarping"]
+            doc_preprocessor_res = next(
+                self.doc_preprocessor_pipeline(
+                    image_array,
+                    use_doc_orientation_classify=use_doc_orientation_classify,
+                    use_doc_unwarping=use_doc_unwarping,
+                )
+            )
+            doc_preprocessor_image = doc_preprocessor_res["output_img"]
+        else:
+            doc_preprocessor_res = {}
+            doc_preprocessor_image = image_array
+        return doc_preprocessor_res, doc_preprocessor_image
+
+    def predict_single_table_recognition_res(
+        self,
+        image_array: np.ndarray,
+        overall_ocr_res: OCRResult,
+        table_box: list,
+        flag_find_nei_text: bool = True,
+    ) -> SingleTableRecognitionResult:
+        """
+        Predict table recognition results from an image array, layout detection results, and OCR results.
+
+        Args:
+            image_array (np.ndarray): The input image represented as a numpy array.
+            overall_ocr_res (OCRResult): Overall OCR result obtained after running the OCR pipeline.
+                The overall OCR results containing text recognition information.
+            table_box (list): The table box coordinates.
+            flag_find_nei_text (bool): Whether to find neighboring text.
+        Returns:
+            SingleTableRecognitionResult: single table recognition result.
+        """
+        table_structure_pred = next(self.table_structure_model(image_array))
+        single_table_recognition_res = get_table_recognition_res(
+            table_box, table_structure_pred, overall_ocr_res
+        )
+        neighbor_text = ""
+        if flag_find_nei_text:
+            match_idx_list = get_neighbor_boxes_idx(
+                overall_ocr_res["dt_boxes"], table_box
+            )
+            if len(match_idx_list) > 0:
+                for idx in match_idx_list:
+                    neighbor_text += overall_ocr_res["rec_text"][idx] + "; "
+        single_table_recognition_res["neighbor_text"] = neighbor_text
+        return single_table_recognition_res
+
+    def predict(
+        self,
+        input: str | list[str] | np.ndarray | list[np.ndarray],
+        use_layout_detection: bool = True,
+        use_doc_orientation_classify: bool = False,
+        use_doc_unwarping: bool = False,
+        overall_ocr_res: OCRResult = None,
+        layout_det_res: DetResult = None,
+        **kwargs
+    ) -> TableRecognitionResult:
+        """
+        This function predicts the table recognition result for the given input.
+
+        Args:
+            input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) or PDF(s) to be processed.
+            use_layout_detection (bool): Whether to use layout detection.
+            use_doc_orientation_classify (bool): Whether to use document orientation classification.
+            use_doc_unwarping (bool): Whether to use document unwarping.
+            overall_ocr_res (OCRResult): The overall OCR result with convert_points_to_boxes information.
+                It will be used if it is not None and use_ocr_model is False.
+            layout_det_res (DetResult): The layout detection result.
+                It will be used if it is not None and use_layout_detection is False.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            TableRecognitionResult: The predicted table recognition result.
+        """
+
+        input_params = {
+            "use_layout_detection": use_layout_detection,
+            "use_doc_preprocessor": self.use_doc_preprocessor,
+            "use_doc_orientation_classify": use_doc_orientation_classify,
+            "use_doc_unwarping": use_doc_unwarping,
+            "use_ocr_model": self.use_ocr_model,
+        }
+
+        if use_doc_orientation_classify or use_doc_unwarping:
+            input_params["use_doc_preprocessor"] = True
+        else:
+            input_params["use_doc_preprocessor"] = False
+
+        if not self.check_input_params_valid(
+            input_params, overall_ocr_res, layout_det_res
+        ):
+            yield None
+            return
+
+        for img_id, batch_data in enumerate(self.batch_sampler(input)):
+            image_array = self.img_reader(batch_data)[0]
+            img_id += 1
+
+            doc_preprocessor_res, doc_preprocessor_image = (
+                self.predict_doc_preprocessor_res(image_array, input_params)
+            )
+
+            if self.use_ocr_model:
+                overall_ocr_res = next(
+                    self.general_ocr_pipeline(doc_preprocessor_image)
+                )
+                dt_boxes = convert_points_to_boxes(overall_ocr_res["dt_polys"])
+                overall_ocr_res["dt_boxes"] = dt_boxes
+
+            table_res_list = []
+            table_region_id = 1
+            if not input_params["use_layout_detection"] and layout_det_res is None:
+                layout_det_res = {}
+                img_height, img_width = doc_preprocessor_image.shape[:2]
+                table_box = [0, 0, img_width - 1, img_height - 1]
+                single_table_rec_res = self.predict_single_table_recognition_res(
+                    doc_preprocessor_image,
+                    overall_ocr_res,
+                    table_box,
+                    flag_find_nei_text=False,
+                )
+                single_table_rec_res["table_region_id"] = table_region_id
+                table_res_list.append(single_table_rec_res)
+                table_region_id += 1
+            else:
+                if input_params["use_layout_detection"]:
+                    layout_det_res = next(self.layout_det_model(doc_preprocessor_image))
+                for box_info in layout_det_res["boxes"]:
+                    if box_info["label"].lower() in ["table"]:
+                        crop_img_info = self._crop_by_boxes(
+                            doc_preprocessor_image, [box_info]
+                        )
+                        crop_img_info = crop_img_info[0]
+                        table_box = crop_img_info["box"]
+                        single_table_rec_res = (
+                            self.predict_single_table_recognition_res(
+                                crop_img_info["img"], overall_ocr_res, table_box
+                            )
+                        )
+                        single_table_rec_res["table_region_id"] = table_region_id
+                        table_res_list.append(single_table_rec_res)
+                        table_region_id += 1
+
+            single_img_res = {
+                "layout_det_res": layout_det_res,
+                "doc_preprocessor_res": doc_preprocessor_res,
+                "overall_ocr_res": overall_ocr_res,
+                "table_res_list": table_res_list,
+                "input_params": input_params,
+                "img_id": img_id,
+            }
+            yield TableRecognitionResult(single_img_res)
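
In the same spirit, a hedged sketch of running the table recognition pipeline; the pipeline name, input path, and output directory mirror the seal example above and are placeholders:

    from paddlex import create_pipeline

    pipeline = create_pipeline(pipeline="table_recognition")

    for res in pipeline.predict("./table_demo.jpg"):
        # Saves the cell visualization plus the HTML and XLSX export for each table region.
        res.save_results("./output/")

When `use_layout_detection` or `use_ocr_model` is disabled in the pipeline config, a precomputed `layout_det_res` or `overall_ocr_res` can be passed to `predict()` instead, as the parameter validation above enforces.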

+ 111 - 0
paddlex/inference/pipelines_new/table_recognition/result.py

@@ -0,0 +1,111 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Dict
+from pathlib import Path
+import numpy as np
+import cv2
+from ...common.result import BaseCVResult, HtmlMixin, XlsxMixin
+
+
+class SingleTableRecognitionResult(BaseCVResult, HtmlMixin, XlsxMixin):
+    """table recognition result"""
+
+    def __init__(self, data: Dict) -> None:
+        """Initializes the object with given data and sets up mixins for HTML and XLSX processing."""
+        super().__init__(data)
+        HtmlMixin.__init__(self)  # Initializes the HTML mixin functionality
+        XlsxMixin.__init__(self)  # Initializes the XLSX mixin functionality
+
+    def _to_html(self) -> str:
+        """Converts the prediction to its corresponding HTML representation.
+
+        Returns:
+            str: The HTML string representation of the prediction.
+        """
+        return self["pred_html"]
+
+    def _to_xlsx(self) -> str:
+        """Converts the prediction HTML to an XLSX file path.
+
+        Returns:
+            str: The path to the XLSX file containing the prediction data.
+        """
+        return self["pred_html"]
+
+    def _to_img(self) -> np.ndarray:
+        """
+        Convert the input image with table OCR predictions to an image with cell boundaries highlighted.
+
+        Returns:
+            np.ndarray: The input image with the predicted cell boundaries drawn on it.
+        """
+        input_img = self["table_ocr_pred"]["input_img"].copy()
+        cell_box_list = self["cell_box_list"]
+        for box in cell_box_list:
+            x1, y1, x2, y2 = [int(pos) for pos in box]
+            cv2.rectangle(input_img, (x1, y1), (x2, y2), (255, 0, 0), 2)
+        return input_img
+
+
+class TableRecognitionResult(dict):
+    """Layout Parsing Result"""
+
+    def __init__(self, data) -> None:
+        """Initializes a new instance of the class with the specified data."""
+        super().__init__(data)
+
+    def save_results(self, save_path: str) -> None:
+        """Save the table recognition results to the specified directory.
+
+        Args:
+            save_path (str): The directory path to save the results.
+        """
+
+        if not os.path.isdir(save_path):
+            return
+
+        img_id = self["img_id"]
+        layout_det_res = self["layout_det_res"]
+        if len(layout_det_res) > 0:
+            save_img_path = Path(save_path) / f"layout_det_result_img{img_id}.jpg"
+            layout_det_res.save_to_img(save_img_path)
+
+        input_params = self["input_params"]
+        if input_params["use_doc_preprocessor"]:
+            save_img_path = Path(save_path) / f"doc_preprocessor_result_img{img_id}.jpg"
+            self["doc_preprocessor_res"].save_to_img(save_img_path)
+
+        save_img_path = Path(save_path) / f"overall_ocr_result_img{img_id}.jpg"
+        self["overall_ocr_res"].save_to_img(save_img_path)
+
+        for tno in range(len(self["table_res_list"])):
+            table_res = self["table_res_list"][tno]
+            table_region_id = table_res["table_region_id"]
+            save_img_path = (
+                Path(save_path)
+                / f"table_res_cell_img{img_id}_region{table_region_id}.jpg"
+            )
+            table_res.save_to_img(save_img_path)
+            save_html_path = (
+                Path(save_path) / f"table_res_img{img_id}_region{table_region_id}.html"
+            )
+            table_res.save_to_html(save_html_path)
+            save_xlsx_path = (
+                Path(save_path) / f"table_res_img{img_id}_region{table_region_id}.xlsx"
+            )
+            table_res.save_to_xlsx(save_xlsx_path)
+
+        return

+ 9 - 10
paddlex/inference/pipelines_new/layout_parsing/table_recognition_post_processing.py → paddlex/inference/pipelines_new/table_recognition/table_recognition_post_processing.py

@@ -11,11 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from .utils import convert_points_to_boxes, get_sub_regions_ocr_res
-import numpy as np
-from .result import TableRecognitionResult
 from typing import Any, Dict, Optional
+import numpy as np
+from ..layout_parsing.utils import convert_points_to_boxes, get_sub_regions_ocr_res
+from .result import SingleTableRecognitionResult
 from ..ocr.result import OCRResult
 
 
@@ -209,20 +208,20 @@ def get_html_result(
 
 
 def get_table_recognition_res(
-    crop_img_info: dict, table_structure_pred: dict, overall_ocr_res: OCRResult
-) -> TableRecognitionResult:
+    table_box: list, table_structure_pred: dict, overall_ocr_res: OCRResult
+) -> SingleTableRecognitionResult:
     """
     Retrieve table recognition result from cropped image info, table structure prediction, and overall OCR result.
 
     Args:
-        crop_img_info (dict): Information about the cropped image, including the bounding box.
+        table_box (list): The bounding box of the table region, given as [x1, y1, x2, y2].
         table_structure_pred (dict): Predicted table structure.
         overall_ocr_res (OCRResult): Overall OCR result from the input image.
 
     Returns:
-        TableRecognitionResult: An object containing the table recognition result.
+        SingleTableRecognitionResult: An object containing the single table recognition result.
     """
-    table_box = np.array([crop_img_info["box"]])
+    table_box = np.array([table_box])
     table_ocr_pred = get_sub_regions_ocr_res(overall_ocr_res, table_box)
 
     crop_start_point = [table_box[0][0], table_box[0][1]]
@@ -243,4 +242,4 @@ def get_table_recognition_res(
         "table_ocr_pred": table_ocr_pred,
         "pred_html": pred_html,
     }
-    return TableRecognitionResult(single_img_res)
+    return SingleTableRecognitionResult(single_img_res)

+ 44 - 0
paddlex/inference/pipelines_new/table_recognition/utils.py

@@ -0,0 +1,44 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ["get_neighbor_boxes_idx"]
+
+import numpy as np
+
+
+def get_neighbor_boxes_idx(src_boxes: np.ndarray, ref_box: np.ndarray) -> list:
+    """
+    Retrieve indices of source boxes that are neighbors to the reference box.
+
+    Parameters:
+    src_boxes (np.ndarray): An array of bounding boxes with shape (N, 4),
+                            where N is the number of boxes and each box is represented
+                            by [x1, y1, x2, y2].
+    ref_box (np.ndarray): A single bounding box represented by [x1, y1, x2, y2].
+
+    Returns:
+    list: A list of indices of the source boxes that are close to the
+          reference box based on the intersection area.
+    """
+    match_idx_list = []
+    if len(src_boxes) > 0:
+        x1 = np.maximum(ref_box[0], src_boxes[:, 0])
+        y1 = np.maximum(ref_box[1], src_boxes[:, 1])
+        x2 = np.minimum(ref_box[2], src_boxes[:, 2])
+        y2 = np.minimum(ref_box[3], src_boxes[:, 3])
+        pub_w = x2 - x1
+        pub_h = y2 - y1
+        match_idx = np.where((pub_w > 0) & (pub_h < 3) & (pub_h > -15))[0]
+        match_idx_list.extend(match_idx)
+    return match_idx_list
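
A small self-check of the heuristic above: a source box counts as a neighbor when it overlaps the reference box horizontally (pub_w > 0) and its vertical intersection falls within the (-15, 3) pixel band. The box values below are made up for illustration:

    import numpy as np

    src_boxes = np.array(
        [
            [10, 100, 200, 120],  # ends ~2 px above the reference box -> neighbor
            [10, 400, 200, 420],  # ~20 px below the reference box -> not a neighbor
        ]
    )
    ref_box = np.array([0, 122, 300, 380])

    print(get_neighbor_boxes_idx(src_boxes, ref_box))  # the only neighbor is index 0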