dyning 11 months ago
Parent
Commit
dc481d60e1
35 changed files with 1497 additions and 431 deletions
  1. +2 -2     api_examples/pipelines/test_doc_preprocessor.py
  2. +5 -5     api_examples/pipelines/test_layout_parsing.py
  3. +2 -2     api_examples/pipelines/test_ocr.py
  4. +24 -19   api_examples/pipelines/test_pp_chatocrv3.py
  5. +4 -0     paddlex/configs/pipelines/OCR.yaml
  6. +65 -24   paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml
  7. +2 -0     paddlex/configs/pipelines/doc_preprocessor.yaml
  8. +11 -0    paddlex/configs/pipelines/layout_parsing.yaml
  9. +109 -13  paddlex/inference/pipelines_new/__init__.py
  10. +64 -27  paddlex/inference/pipelines_new/base.py
  11. +6 -4    paddlex/inference/pipelines_new/components/__init__.py
  12. +9 -2    paddlex/inference/pipelines_new/components/chat_server/base.py
  13. +25 -4   paddlex/inference/pipelines_new/components/chat_server/ernie_bot_chat.py
  14. +2 -1    paddlex/inference/pipelines_new/components/common/__init__.py
  15. +36 -0   paddlex/inference/pipelines_new/components/common/base_operator.py
  16. +28 -23  paddlex/inference/pipelines_new/components/common/base_result.py
  17. +124 -42 paddlex/inference/pipelines_new/components/common/crop_image_regions.py
  18. +2 -0    paddlex/inference/pipelines_new/components/common/seal_det_warp.py
  19. +38 -6   paddlex/inference/pipelines_new/components/common/sort_boxes.py
  20. +4 -1    paddlex/inference/pipelines_new/components/prompt_engeering/base.py
  21. +53 -16  paddlex/inference/pipelines_new/components/prompt_engeering/generate_kie_prompt.py
  22. +39 -3   paddlex/inference/pipelines_new/components/retriever/base.py
  23. +80 -13  paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py
  24. +2 -0    paddlex/inference/pipelines_new/components/utils/mixin.py
  25. +69 -16  paddlex/inference/pipelines_new/doc_preprocessor/pipeline.py
  26. +25 -6   paddlex/inference/pipelines_new/doc_preprocessor/result.py
  27. +99 -22  paddlex/inference/pipelines_new/layout_parsing/pipeline.py
  28. +90 -13  paddlex/inference/pipelines_new/layout_parsing/result.py
  29. +47 -10  paddlex/inference/pipelines_new/layout_parsing/table_recognition_post_processing.py
  30. +38 -7   paddlex/inference/pipelines_new/layout_parsing/utils.py
  31. +36 -12  paddlex/inference/pipelines_new/ocr/pipeline.py
  32. +50 -22  paddlex/inference/pipelines_new/ocr/result.py
  33. +286 -95 paddlex/inference/pipelines_new/pp_chatocrv3_doc/pipeline.py
  34. +0 -17   paddlex/inference/pipelines_new/pp_chatocrv3_doc/result.py
  35. +21 -4   paddlex/utils/fonts/__init__.py

+ 2 - 2
api_examples/pipelines/test_doc_preprocessor.py

@@ -16,8 +16,8 @@ from paddlex import create_pipeline
 
 pipeline = create_pipeline(pipeline="doc_preprocessor")
 
-test_img_path = "./test_imgs/img_rot180_demo.jpg"
-# test_img_path = "./test_imgs/doc_distort_test.jpg"
+test_img_path = "./test_demo_imgs/img_rot180_demo.jpg"
+# test_img_path = "./test_demo_imgs/doc_distort_test.jpg"
 
 output = pipeline.predict(
     test_img_path, use_doc_orientation_classify=True, use_doc_unwarping=True

+ 5 - 5
api_examples/pipelines/test_layout_parsing.py

@@ -17,7 +17,7 @@ from paddlex import create_pipeline
 pipeline = create_pipeline(pipeline="layout_parsing")
 
 output = pipeline.predict(
-    "./test_imgs/test_layout_parsing.jpg",
+    "./test_demo_imgs/test_layout_parsing.jpg",
     use_doc_orientation_classify=True,
     use_doc_unwarping=True,
     use_common_ocr=True,
@@ -25,10 +25,10 @@ output = pipeline.predict(
     use_table_recognition=True,
 )
 
-# output = pipeline("./test_imgs/demo_paper.png")
-# output = pipeline("./test_imgs/table_recognition.jpg")
-# output = pipeline.predict("./test_imgs/seal_text_det.png")
-# output = pipeline.predict("./test_imgs/img_rot180_demo.jpg")
+# output = pipeline("./test_demo_imgs/demo_paper.png")
+# output = pipeline("./test_demo_imgs/table_recognition.jpg")
+# output = pipeline.predict("./test_demo_imgs/seal_text_det.png")
+# output = pipeline.predict("./test_demo_imgs/img_rot180_demo.jpg")
 for res in output:
     # print(res)
     res.save_results("./output")

+ 2 - 2
api_examples/pipelines/test_ocr.py

@@ -16,9 +16,9 @@ from paddlex import create_pipeline
 
 pipeline = create_pipeline(pipeline="OCR")
 
-# output = pipeline.predict("./test_imgs/general_ocr_002.png")
+# output = pipeline.predict("./test_demo_imgs/general_ocr_002.png")
 
-output = pipeline.predict("./test_imgs/seal_text_det.png")
+output = pipeline.predict("./test_demo_imgs/seal_text_det.png")
 for res in output:
     print(res)
     res.save_to_img("./output")

+ 24 - 19
api_examples/pipelines/test_pp_chatocrv3.py

@@ -16,35 +16,40 @@ from paddlex import create_pipeline
 
 pipeline = create_pipeline(pipeline="PP-ChatOCRv3-doc")
 
-# img_path = "./test_demo_imgs/vehicle_certificate-1.png"
-# key_list = ['驾驶室准乘人数']
+img_path = "./test_demo_imgs/vehicle_certificate-1.png"
+key_list = ["驾驶室准乘人数"]
 
 # img_path = "./test_demo_imgs/test_layout_parsing.jpg"
 # key_list = ['3.2的标题']
 
-img_path = "./test_demo_imgs/seal_text_det.png"
-key_list = ["印章上公司"]
+# img_path = "./test_demo_imgs/seal_text_det.png"
+# key_list = ["印章上公司"]
 
-# visual_predict_res = pipeline.visual_predict(img_path,
-#     use_doc_orientation_classify=True,
-#     use_doc_unwarping=True,
-#     use_common_ocr=True,
-#     use_seal_recognition=True,
-#     use_table_recognition=True)
+# img_path = "./badcase_images/circle_Aug06850_1.jpg"
+# key_list = ['印章名称', '印章编号']
+
+visual_predict_res = pipeline.visual_predict(
+    img_path,
+    use_doc_orientation_classify=False,
+    use_doc_unwarping=False,
+    use_common_ocr=True,
+    use_seal_recognition=True,
+    use_table_recognition=True,
+)
 
 # ####[TODO] Add category information
-# visual_info_list = []
-# for res in visual_predict_res:
-#     visual_info_list.append(res["visual_info"])
+visual_info_list = []
+for res in visual_predict_res:
+    # res['layout_parsing_result'].save_results("./output/")
+    # print(res["visual_info"])
+    visual_info_list.append(res["visual_info"])
 
-# pipeline.save_visual_info_list(visual_info_list, "./res_visual_info/visual_info3.json")
+# pipeline.save_visual_info_list(visual_info_list, "./res_visual_info/tmp_visual_info.json")
 
-visual_info_list = pipeline.load_visual_info_list("./res_visual_info/visual_info3.json")
+# visual_info_list = pipeline.load_visual_info_list("./res_visual_info/tmp_visual_info.json")
 
 vector_info = pipeline.build_vector(visual_info_list)
 
-print(vector_info)
-
-final_results = pipeline.chat(visual_info_list, key_list, vector_info)
+chat_result = pipeline.chat(key_list, visual_info_list, vector_info=vector_info)
 
-print(final_results)
+print(chat_result)

+ 4 - 0
paddlex/configs/pipelines/OCR.yaml

@@ -10,10 +10,12 @@ text_type: common
 
 SubModules:
   TextDetection:
+    module_name: text_detection
     model_name: PP-OCRv4_mobile_det
     model_dir: null
     batch_size: 1    
   TextRecognition:
+    module_name: text_recognition
     model_name: PP-OCRv4_mobile_rec
     model_dir: null
     batch_size: 1
@@ -27,10 +29,12 @@ SubModules:
 
 # SubModules:
 #   TextDetection:
+#     module_name: text_detection
 #     model_name: PP-OCRv4_mobile_seal_det
 #     model_dir: null
 #     batch_size: 1    
 #   TextRecognition:
+#     module_name: text_recognition
 #     model_name: PP-OCRv4_mobile_rec
 #     model_dir: null
 #     batch_size: 1

+ 65 - 24
paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml

@@ -2,50 +2,81 @@
 pipeline_name: PP-ChatOCRv3-doc
 input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/demo_paper.png
 
-use_vector_retrieval: True
-
 SubModules:
   LLM_Chat:
+    module_name: chat_bot
     model_name: ernie-3.5
     api_type: qianfan
-    # ak: "api_key" # Set this to a real API key
-    # sk: "secret_key"  # Set this to a real secret key
-    ak: 4iiqB0QfvXTAENgzUwNeDjQ7
-    sk: sHQCw4l5A6jnzbHMa0ZvDi05GT9Qz8tZ
+    ak: "api_key" # Set this to a real API key
+    sk: "secret_key"  # Set this to a real secret key
+
 
   LLM_Retriever:
+    module_name: retriever
     model_name: ernie-3.5
     api_type: qianfan
-    # ak: "api_key" # Set this to a real API key
-    # sk: "secret_key"  # Set this to a real secret key
-    ak: 4iiqB0QfvXTAENgzUwNeDjQ7
-    sk: sHQCw4l5A6jnzbHMa0ZvDi05GT9Qz8tZ
+    ak: "api_key" # Set this to a real API key
+    sk: "secret_key"  # Set this to a real secret key
 
   PromptEngneering:
     KIE_CommonText:
+      module_name: prompt_engneering
       task_type: text_kie_prompt
-      task_description: '你现在的任务是从OCR文字识别的结果中提取关键词列表中每一项对应的关键信息。
+
+      # task_description: '你现在的任务是从OCR文字识别的结果中提取关键词列表中每一项对应的关键信息。
+      #     OCR的文字识别结果使用```符号包围,包含所识别出来的文字,顺序在原始图片中从左至右、从上至下。
+      #     我指定的关键词列表使用[]符号包围。请注意OCR的文字识别结果可能存在长句子换行被切断、不合理的分词、
+      #     文字被错误合并等问题,你需要结合上下文语义进行综合判断,以抽取准确的关键信息。'
+
+      # task_description: '你现在的任务是从OCR文字识别的结果中提取关键词列表中每一项对应的关键信息。
+      #     OCR的文字识别结果使用```符号包围,包含所识别出来的文字,顺序在原始图片中从左至右、从上至下。
+      #     我指定的关键词列表使用[]符号包围。请注意OCR的文字识别结果可能存在长句子换行被切断、不合理的分词、
+      #     文字被错误合并等问题,你需要结合上下文语义进行综合判断,以抽取准确的关键信息。
+      #     提取的关键信息尽可能详细和完整,并保持格式、单位、符号和标点都与OCR结果中的内容完全一致。'
+
+      # rules_str: 
+
+      task_description: '你现在的任务是从OCR结果中提取问题列表中每一个问题的答案。
           OCR的文字识别结果使用```符号包围,包含所识别出来的文字,顺序在原始图片中从左至右、从上至下。
-          我指定的关键词列表使用[]符号包围。请注意OCR的文字识别结果可能存在长句子换行被切断、不合理的分词、
-          文字被错误合并等问题,你需要结合上下文语义进行综合判断,以抽取准确的关键信息。'
-      output_format: '在返回结果时使用JSON格式,包含多个key-value对,key值为我指定的关键词,value值为所抽取的结果。
-          如果认为OCR识别结果中没有关键词key对应的value,则将value赋值为"未知"。请只输出json格式的结果,
+          我指定的问题列表使用[]符号包围。请注意OCR的文字识别结果可能存在长句子换行被切断、不合理的分词、
+          文字被错误合并等问题,你需要结合上下文语义进行综合判断,以获取准确的答案。'
+
+      output_format: '在返回结果时使用JSON格式,包含多个key-value对,key值为我指定的问题,value值为该问题对应的答案。
+          如果认为OCR识别结果中,对于问题key,没有答案,则将value赋值为"未知"。请只输出json格式的结果,
           并做json格式校验后返回,不要包含其它多余文字!'
-      rules_str:
+
+      rules_str: '每个问题的答案用OCR结果的内容回答,可以是单词、短语或句子,针对问题回答尽可能详细和完整,
+        并保持格式、单位、符号和标点都与OCR结果中的内容完全一致。'
+
       few_shot_demo_text_content:
       few_shot_demo_key_value_list:
           
     KIE_Table:
+      module_name: prompt_engneering
       task_type: table_kie_prompt
-      task_description: '你现在的任务是从输入的html格式的表格内容中提取关键词列表中每一项对应的关键信息,
-          表格内容用```符号包围,我指定的关键词列表使用[]符号包围。你需要结合上下文语义进行综合判断,以抽取准确的关键信息。
-          在返回结果时使用JSON格式,包含多个key-value对,key值为我指定的关键词,value值为所抽取的结果。
-          如果认为输入的表格内容中没有关键词key对应的value值,则将value赋值为"未知"。
-          请只输出json格式的结果,并做json格式校验后返回,不要包含其它多余文字!'
-      output_format: '在返回结果时使用JSON格式,包含多个key-value对,key值为我指定的关键词,value值为所抽取的结果。
-          如果认为表格识别结果中没有关键词key对应的value,则将value赋值为"未知"。请只输出json格式的结果,
+
+      # task_description: '你现在的任务是从输入的表格内容中提取关键词列表中每一项对应的关键信息,
+      #     表格内容用```符号包围,我指定的关键词列表使用[]符号包围。你需要结合上下文语义进行综合判断,以抽取准确的关键信息。'
+
+      # task_description: '你现在的任务是从输入的表格内容中提取关键词列表中每一项对应的关键信息,
+      #     表格内容用```符号包围,我指定的关键词列表使用[]符号包围。你需要结合上下文语义进行综合判断,以抽取准确的关键信息。
+      #     提取的关键信息尽可能详细和完整,并保持格式、单位、符号和标点都与OCR结果中的内容完全一致。'
+
+      # output_format: '在返回结果时使用JSON格式,包含多个key-value对,key值为我指定的关键词,value值为所抽取的结果。
+      #     如果认为表格识别结果中没有关键词key对应的value,则将value赋值为"未知"。请只输出json格式的结果,
+      #     并做json格式校验后返回,不要包含其它多余文字!'
+      # rules_str:
+
+      task_description: '你现在的任务是从输入的表格内容中提取问题列表中每一个问题的答案。
+          表格内容使用```符号包围,我指定的问题列表使用[]符号包围。'
+
+      output_format: '在返回结果时使用JSON格式,包含多个key-value对,key值为我指定的问题,value值为该问题对应的答案。
+          如果认为表格内容中,对于问题key,没有答案,则将value赋值为"未知"。请只输出json格式的结果,
           并做json格式校验后返回,不要包含其它多余文字!'
-      rules_str:
+
+      rules_str: '每个问题的答案用表格中内容回答,可以是单词、短语或句子,针对问题回答尽可能详细和完整,
+        并保持格式、单位、符号和标点都与表格内容中的内容完全一致。'
+
       few_shot_demo_text_content:
       few_shot_demo_key_value_list:
 
@@ -59,10 +90,12 @@ SubPipelines:
 
     SubModules:
      LayoutDetection:
+        module_name: object_detection
        model_name: RT-DETR-H_layout_3cls
        model_dir: null
        batch_size: 1
      TableStructurePredictor:
+        module_name: table_recognition
        model_name: SLANet_plus
        model_dir: null
        batch_size: 1
@@ -74,10 +107,12 @@ SubPipelines:
        use_doc_unwarping: True
        SubModules:
          DocOrientationClassify:
+            module_name: image_classification
            model_name: PP-LCNet_x1_0_doc_ori
            model_dir: null
            batch_size: 1
          DocUnwarping:
+            module_name: image_unwarping
            model_name: UVDoc
            model_dir: null
            batch_size: 1
@@ -87,12 +122,15 @@ SubPipelines:
        text_type: common
        SubModules:
          TextDetection:
+            module_name: text_detection
            model_name: PP-OCRv4_server_det
            model_dir: null
            batch_size: 1    
          TextRecognition:
+            module_name: text_recognition
            model_name: PP-OCRv4_server_rec
            model_dir: null
+            # model_dir: /paddle/github/PaddleX/models/PP-OCRv4_server_rec_doc_infer
            batch_size: 1
 
      SealOCR:
@@ -100,10 +138,13 @@ SubPipelines:
        text_type: seal
        SubModules:
          TextDetection:
+            module_name: text_detection
            model_name: PP-OCRv4_server_seal_det
            model_dir: null
            batch_size: 1    
          TextRecognition:
+            module_name: text_recognition
            model_name: PP-OCRv4_server_rec
            model_dir: null
+            # model_dir: /paddle/github/PaddleX/models/PP-OCRv4_server_rec_doc_infer
            batch_size: 1  

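The reworked prompt blocks above feed the pipeline's chat entry point. A condensed usage sketch, mirroring api_examples/pipelines/test_pp_chatocrv3.py (same image path and key list as that script; the visual_predict flags are left at their defaults here):

from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="PP-ChatOCRv3-doc")
visual_predict_res = pipeline.visual_predict("./test_demo_imgs/vehicle_certificate-1.png")
visual_info_list = [res["visual_info"] for res in visual_predict_res]
vector_info = pipeline.build_vector(visual_info_list)
# chat() now takes key_list first, with vector_info passed as a keyword argument
print(pipeline.chat(["驾驶室准乘人数"], visual_info_list, vector_info=vector_info))
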
+ 2 - 0
paddlex/configs/pipelines/doc_preprocessor.yaml

@@ -7,10 +7,12 @@ use_doc_unwarping: True
 
 SubModules:
   DocOrientationClassify:
+    module_name: image_classification
     model_name: PP-LCNet_x1_0_doc_ori
     model_dir: null
     batch_size: 1
   DocUnwarping:
+    module_name: image_unwarping
     model_name: UVDoc
     model_dir: null
     batch_size: 1

+ 11 - 0
paddlex/configs/pipelines/layout_parsing.yaml

@@ -8,10 +8,15 @@ use_table_recognition: True
 
 SubModules:
   LayoutDetection:
+    module_name: object_detection
     model_name: RT-DETR-H_layout_3cls
     model_dir: null
     batch_size: 1
+  ##############################################
+  ####### [TODO] Confirm whether the module_name for table recognition should be table_recognition
+  ##############################################
   TableStructurePredictor:
+    module_name: table_recognition
     model_name: SLANet_plus
     model_dir: null
     batch_size: 1
@@ -23,10 +28,12 @@ SubPipelines:
     use_doc_unwarping: True
     SubModules:
       DocOrientationClassify:
+        module_name: image_classification
         model_name: PP-LCNet_x1_0_doc_ori
         model_dir: null
         batch_size: 1
       DocUnwarping:
+        module_name: image_unwarping
         model_name: UVDoc
         model_dir: null
         batch_size: 1
@@ -35,10 +42,12 @@ SubPipelines:
     text_type: common
     SubModules:
       TextDetection:
+        module_name: text_detection
         model_name: PP-OCRv4_server_det
         model_dir: null
         batch_size: 1    
       TextRecognition:
+        module_name: text_recognition
         model_name: PP-OCRv4_server_rec
         model_dir: null
         batch_size: 1
@@ -47,10 +56,12 @@ SubPipelines:
     text_type: seal
     SubModules:
       TextDetection:
+        module_name: text_detection
         model_name: PP-OCRv4_server_seal_det
         model_dir: null
         batch_size: 1    
       TextRecognition:
+        module_name: text_recognition
         model_name: PP-OCRv4_server_rec
         model_dir: null
         batch_size: 1

+ 109 - 13
paddlex/inference/pipelines_new/__init__.py

@@ -12,11 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from pathlib import Path
-from typing import Any, Dict, Optional
-from .base import BasePipeline
-from ...utils.config import parse_config
-
 # from .single_model_pipeline import (
 #     _SingleModelPipeline,
 #     ImageClassification,
@@ -40,13 +35,29 @@ from ...utils.config import parse_config
 # from .pp_shitu_v2 import ShiTuV2Pipeline
 # from .attribute_recognition import AttributeRecPipeline
 
+from pathlib import Path
+from typing import Any, Dict, Optional
+from .base import BasePipeline
+from ..utils.pp_option import PaddlePredictorOption
+from .components import BaseChat, BaseRetriever, BaseGeneratePrompt
+from ...utils.config import parse_config
+
 from .ocr import OCRPipeline
 from .doc_preprocessor import DocPreprocessorPipeline
 from .layout_parsing import LayoutParsingPipeline
 from .pp_chatocrv3_doc import PP_ChatOCRv3_doc_Pipeline
 
 
-def get_pipeline_path(pipeline_name):
+def get_pipeline_path(pipeline_name: str) -> str:
+    """
+    Get the full path of the pipeline configuration file based on the provided pipeline name.
+
+    Args:
+        pipeline_name (str): The name of the pipeline.
+
+    Returns:
+        str: The full path to the pipeline configuration file, or None if not found.
+    """
     pipeline_path = (
         Path(__file__).parent.parent.parent
         / "configs/pipelines"
@@ -58,6 +69,18 @@
 
 
 def load_pipeline_config(pipeline_name: str) -> Dict[str, Any]:
+    """
+    Load the pipeline configuration.
+
+    Args:
+        pipeline_name (str): The name of the pipeline or the path to the config file.
+
+    Returns:
+        Dict[str, Any]: The parsed pipeline configuration.
+
+    Raises:
+        Exception: If the config file of the pipeline does not exist.
+    """
     if not Path(pipeline_name).exists():
         pipeline_path = get_pipeline_path(pipeline_name)
         if pipeline_path is None:
@@ -72,24 +95,39 @@ def load_pipeline_config(pipeline_name: str) -> Dict[str, Any]:
 
 def create_pipeline(
     pipeline: str,
-    device=None,
-    pp_option=None,
+    config: Dict = None,
+    device: str = None,
+    pp_option: PaddlePredictorOption = None,
     use_hpip: bool = False,
     hpi_params: Optional[Dict[str, Any]] = None,
     *args,
     **kwargs,
 ) -> BasePipeline:
-    """build model evaluater
+    """
+    Create a pipeline instance based on the provided parameters.
+    If the input parameter config is not provided,
+    it is obtained from the default config corresponding to the pipeline name.
 
     Args:
-        pipeline (str): the pipeline name, that is name of pipeline class
+        pipeline (str): The name of the pipeline to create.
+        config (Dict, optional): The pipeline configuration; if None, the default config for the pipeline name is loaded. Defaults to None.
+        device (str, optional): The device to run the pipeline on. Defaults to None.
+        pp_option (PaddlePredictorOption, optional): The options for the PaddlePredictor. Defaults to None.
+        use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
+        hpi_params (Optional[Dict[str, Any]], optional): Additional parameters for hpip. Defaults to None.
+        *args: Additional positional arguments.
+        **kwargs: Additional keyword arguments.
 
     Returns:
-        BasePipeline: the pipeline, which is subclass of BasePipeline.
+        BasePipeline: The created pipeline instance.
     """
+
     pipeline_name = pipeline
-    config = load_pipeline_config(pipeline_name)
-    assert pipeline_name == config["pipeline_name"]
+
+    if config is None:
+        config = load_pipeline_config(pipeline_name)
+        assert pipeline_name == config["pipeline_name"]
+
     pipeline = BasePipeline.get(pipeline_name)(
         config=config,
         device=device,
@@ -98,3 +136,61 @@ def create_pipeline(
         hpi_params=hpi_params,
     )
     return pipeline
+
+
+def create_chat_bot(config: Dict, *args, **kwargs) -> BaseChat:
+    """Creates an instance of a chat bot based on the provided configuration.
+
+    Args:
+        config (Dict): Configuration settings, expected to be a dictionary with at least a 'model_name' key.
+        *args: Additional positional arguments. Not used in this function but allowed for future compatibility.
+        **kwargs: Additional keyword arguments. Not used in this function but allowed for future compatibility.
+
+    Returns:
+        BaseChat: An instance of the chat bot class corresponding to the 'model_name' in the config.
+    """
+    model_name = config["model_name"]
+    chat_bot = BaseChat.get(model_name)(config)
+    return chat_bot
+
+
+def create_retriever(
+    config: Dict,
+    *args,
+    **kwargs,
+) -> BaseRetriever:
+    """
+    Creates a retriever instance based on the provided configuration.
+
+    Args:
+        config (Dict): Configuration settings, expected to be a dictionary with at least a 'model_name' key.
+        *args: Additional positional arguments. Not used in this function but allowed for future compatibility.
+        **kwargs: Additional keyword arguments. Not used in this function but allowed for future compatibility.
+
+    Returns:
+        BaseRetriever: An instance of a retriever class corresponding to the 'model_name' in the config.
+    """
+    model_name = config["model_name"]
+    retriever = BaseRetriever.get(model_name)(config)
+    return retriever
+
+
+def create_prompt_engeering(
+    config: Dict,
+    *args,
+    **kwargs,
+) -> BaseGeneratePrompt:
+    """
+    Creates a prompt engineering instance based on the provided configuration.
+
+    Args:
+        config (Dict): Configuration settings, expected to be a dictionary with at least a 'task_type' key.
+        *args: Variable length argument list for additional positional arguments.
+        **kwargs: Arbitrary keyword arguments.
+
+    Returns:
+        BaseGeneratePrompt: An instance of a prompt engineering class corresponding to the 'task_type' in the config.
+    """
+    task_type = config["task_type"]
+    pe = BaseGeneratePrompt.get(task_type)(config)
+    return pe

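The three factory helpers above follow the same auto-registration pattern as create_pipeline. A minimal hedged sketch of calling them, with config keys borrowed from PP-ChatOCRv3-doc.yaml (the ak/sk values are placeholders, and it is assumed here that "ernie-3.5" is a registered name for both the chat and retriever classes):

from paddlex.inference.pipelines_new import (
    create_chat_bot,
    create_retriever,
    create_prompt_engeering,
)

llm_config = {
    "module_name": "chat_bot",
    "model_name": "ernie-3.5",
    "api_type": "qianfan",
    "ak": "api_key",     # placeholder, set to a real API key
    "sk": "secret_key",  # placeholder, set to a real secret key
}
chat_bot = create_chat_bot(llm_config)    # dispatched via BaseChat.get(model_name)
retriever = create_retriever(llm_config)  # dispatched via BaseRetriever.get(model_name)
pe = create_prompt_engeering({"task_type": "text_kie_prompt"})  # keyed on task_type
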
+ 64 - 27
paddlex/inference/pipelines_new/base.py

@@ -18,24 +18,36 @@ import yaml
 import codecs
 from pathlib import Path
 from typing import Any, Dict, Optional
-from ..models import create_predictor
-from .components.chat_server.base import BaseChat
-from .components.retriever.base import BaseRetriever
-from .components.prompt_engeering.base import BaseGeneratePrompt
+from ..utils.pp_option import PaddlePredictorOption
+from ..models import BasePredictor
 
 
 class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
-    """Base Pipeline"""
+    """Base class for all pipelines.
+
+    This class serves as a foundation for creating various pipelines.
+    It includes common attributes and methods that are shared among all
+    pipeline implementations.
+    """
 
     __is_base = True
 
     def __init__(
         self,
-        device=None,
-        pp_option=None,
+        device: str = None,
+        pp_option: PaddlePredictorOption = None,
         use_hpip: bool = False,
         hpi_params: Optional[Dict[str, Any]] = None,
     ) -> None:
+        """
+        Initializes the class with specified parameters.
+
+        Args:
+            device (str, optional): The device to use for prediction. Defaults to None.
+            pp_option (PaddlePredictorOption, optional): The options for PaddlePredictor. Defaults to None.
+            use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
+            hpi_params (Dict[str, Any], optional): Additional parameters for hpip. Defaults to None.
+        """
         super().__init__()
         self.device = device
         self.pp_option = pp_option
@@ -44,32 +56,62 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
 
     @abstractmethod
     def predict(self, input, **kwargs):
+        """
+        Declaration of an abstract method. Subclasses are expected to
+        provide a concrete implementation of predict.
+        Args:
+            input: The input data to predict.
+            **kwargs: Additional keyword arguments.
+        """
         raise NotImplementedError("The method `predict` has not been implemented yet.")
 
-    def create_model(self, config):
+    def create_model(self, config: Dict) -> BasePredictor:
+        """
+        Create a model instance based on the given configuration.
+
+        Args:
+            config (Dict): A dictionary containing configuration settings.
+
+        Returns:
+            BasePredictor: An instance of the model.
+        """
 
         model_dir = config["model_dir"]
         if model_dir == None:
             model_dir = config["model_name"]
 
-        model = create_predictor(
-            model_dir,
+        from ...model import create_model
+
+        model = create_model(
+            model=model_dir,
             device=self.device,
             pp_option=self.pp_option,
             use_hpip=self.use_hpip,
             hpi_params=self.hpi_params,
         )
 
-        ########### [TODO]支持初始化传参能力
+        # [TODO] Support initializing with additional parameters
         if "batch_size" in config:
             batch_size = config["batch_size"]
             model.set_predictor(batch_size=batch_size)
 
         return model
 
-    def create_pipeline(self, config):
+    def create_pipeline(self, config: Dict):
+        """
+        Creates a pipeline based on the provided configuration.
+
+        Args:
+            config (Dict): A dictionary containing the pipeline configuration.
+
+        Returns:
+            BasePipeline: An instance of the created pipeline.
+        """
+        from . import create_pipeline
+
         pipeline_name = config["pipeline_name"]
-        pipeline = BasePipeline.get(pipeline_name)(
+        pipeline = create_pipeline(
+            pipeline_name,
             config=config,
             device=self.device,
             pp_option=self.pp_option,
@@ -78,20 +120,15 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
         )
         return pipeline
 
-    def create_chat_bot(self, config):
-        model_name = config["model_name"]
-        chat_bot = BaseChat.get(model_name)(config)
-        return chat_bot
-
-    def create_retriever(self, config):
-        model_name = config["model_name"]
-        retriever = BaseRetriever.get(model_name)(config)
-        return retriever
+    def __call__(self, input, **kwargs):
+        """
+        Calls the predict method with the given input and keyword arguments.
 
-    def create_prompt_engeering(self, config):
-        task_type = config["task_type"]
-        pe = BaseGeneratePrompt.get(task_type)(config)
-        return pe
+        Args:
+            input: The input data to be predicted.
+            **kwargs: Additional keyword arguments to be passed to the predict method.
 
-    def __call__(self, input, **kwargs):
+        Returns:
+            The prediction result from the predict method.
+        """
         return self.predict(input, **kwargs)

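With the chat/retriever/prompt factories moved out, a concrete pipeline now only implements predict() and inherits model and sub-pipeline construction. A hedged sketch of a subclass (MyPipeline and its config keys are hypothetical; the entities registration hook is assumed to work as in the shipped pipelines):

class MyPipeline(BasePipeline):
    """Hypothetical pipeline illustrating the base-class contract."""

    entities = "my_pipeline"  # assumed auto-registration hook

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        # create_model() resolves model_dir/model_name and applies batch_size
        self.det_model = self.create_model(config["SubModules"]["TextDetection"])

    def predict(self, input, **kwargs):
        return self.det_model(input)
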
+ 6 - 4
paddlex/inference/pipelines_new/components/__init__.py

@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .base import BaseComponent, CVResult, BaseResult
-from .common import SortQuadBoxes
-from .common import CropByPolys
-from .common import CropByBoxes
+from .common import CVResult, BaseResult
+from .common import SortQuadBoxes, SortPolyBoxes
+from .common import CropByPolys, CropByBoxes
 from .utils.mixin import HtmlMixin, XlsxMixin
+from .chat_server.base import BaseChat
+from .retriever.base import BaseRetriever
+from .prompt_engeering.base import BaseGeneratePrompt

+ 9 - 2
paddlex/inference/pipelines_new/components/chat_server/base.py

@@ -19,15 +19,22 @@ import inspect
 
 
 class BaseChat(ABC, metaclass=AutoRegisterABCMetaClass):
-    """Base Chat"""
+    """Base class for all chat bots. This class serves as a foundation
+    for creating various chat bots.
+    """
 
     __is_base = True
 
-    def __init__(self):
+    def __init__(self) -> None:
+        """Initializes an instance of base chat."""
         super().__init__()
 
     @abstractmethod
     def generate_chat_results(self):
+        """
+        Declaration of an abstract method. Subclasses are expected to
+        provide a concrete implementation of generate_chat_results.
+        """
         raise NotImplementedError(
             "The method `generate_chat_results` has not been implemented yet."
         )

+ 25 - 4
paddlex/inference/pipelines_new/components/chat_server/ernie_bot_chat.py

@@ -15,6 +15,7 @@
 from .....utils import logging
 from .base import BaseChat
 import erniebot
+from typing import Dict
 
 
 class ErnieBotChat(BaseChat):
@@ -31,7 +32,18 @@ class ErnieBotChat(BaseChat):
         "ernie-char-8k",
     ]
 
-    def __init__(self, config):
+    def __init__(self, config: Dict) -> None:
+        """Initializes the ErnieBotChat with given configuration.
+
+        Args:
+            config (Dict): Configuration dictionary containing model_name, api_type, ak, sk, and access_token.
+
+        Raises:
+            ValueError: If model_name is not in the predefined entities,
+            api_type is not one of ['aistudio', 'qianfan'],
+            access_token is None for 'aistudio' api_type,
+            or ak and sk are None for 'qianfan' api_type.
+        """
         super().__init__()
         model_name = config.get("model_name", None)
         api_type = config.get("api_type", None)
@@ -54,10 +66,19 @@
         self.model_name = model_name
         self.config = config
 
-    def generate_chat_results(self, prompt, temperature=0.001, max_retries=1):
+    def generate_chat_results(
+        self, prompt: str, temperature: float = 0.001, max_retries: int = 1
+    ) -> Dict:
         """
-        args:
-        return:
+        Generate chat results using the specified model and configuration.
+
+        Args:
+            prompt (str): The user's input prompt.
+            temperature (float, optional): The temperature parameter for llms, defaults to 0.001.
+            max_retries (int, optional): The maximum number of retries for llms API calls, defaults to 1.
+
+        Returns:
+            Dict: The chat completion result from the model.
         """
         try:
             cur_config = {

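A hedged call sketch for the newly typed interface (placeholder credentials; the prompt text is illustrative):

config = {
    "model_name": "ernie-3.5",
    "api_type": "qianfan",
    "ak": "api_key",     # placeholder
    "sk": "secret_key",  # placeholder
}
bot = ErnieBotChat(config)
result = bot.generate_chat_results("Summarize the OCR text.", temperature=0.001, max_retries=1)
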
+ 2 - 1
paddlex/inference/pipelines_new/components/common/__init__.py

@@ -12,5 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .sort_boxes import SortQuadBoxes
+from .base_result import CVResult, BaseResult
+from .sort_boxes import SortQuadBoxes, SortPolyBoxes
 from .crop_image_regions import CropByPolys, CropByBoxes

+ 36 - 0
paddlex/inference/pipelines_new/components/common/base_operator.py

@@ -0,0 +1,36 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from .....utils.subclass_register import AutoRegisterABCMetaClass
+
+
+class BaseOperator(ABC, metaclass=AutoRegisterABCMetaClass):
+    """Base Operator"""
+
+    __is_base = True
+
+    def __init__(self):
+        """Initializes an instance of base operator."""
+        super().__init__()
+
+    @abstractmethod
+    def __call__(self):
+        """
+        Declaration of an abstract method. Subclasses are expected to
+        provide a concrete implementation of call method.
+        """
+        raise NotImplementedError(
+            "The component method `__call__` has not been implemented yet."
+        )

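BaseOperator gives pipeline components a uniform callable interface; the operators below (CropByBoxes, SortQuadBoxes, SortPolyBoxes) are migrated onto it. A minimal hedged sketch of a custom operator (UpperCaseText is hypothetical):

class UpperCaseText(BaseOperator):
    """Toy operator that upper-cases recognized text."""

    entities = "UpperCaseText"  # registered via the auto-register metaclass

    def __call__(self, text: str) -> str:
        return text.upper()
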
+ 28 - 23
paddlex/inference/pipelines_new/components/base.py → paddlex/inference/pipelines_new/components/common/base_result.py

@@ -12,39 +12,37 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from abc import ABC, abstractmethod
-from ....utils.subclass_register import AutoRegisterABCMetaClass
-
 import inspect
 
-from ....utils.func_register import FuncRegister
-from ...utils.io import ImageReader, ImageWriter
-from .utils.mixin import JsonMixin, ImgMixin, StrMixin
-
-
-class BaseComponent(ABC, metaclass=AutoRegisterABCMetaClass):
-    """Base Component"""
+from ....utils.io import ImageReader, ImageWriter
+from ..utils.mixin import JsonMixin, ImgMixin, StrMixin
+from typing import Dict
 
-    __is_base = True
 
-    def __init__(self):
-        super().__init__()
-
-    @abstractmethod
-    def __call__(self):
-        raise NotImplementedError(
-            "The component method `__call__` has not been implemented yet."
-        )
+class BaseResult(dict, StrMixin, JsonMixin):
+    """Base Result"""
 
+    def __init__(self, data: Dict) -> None:
+        """Initializes the instance with the provided data.
 
-class BaseResult(dict, StrMixin, JsonMixin):
-    def __init__(self, data):
+        Args:
+            data (Dict): The data to initialize the instance with.
+        """
         super().__init__(data)
         self._show_funcs = []
         StrMixin.__init__(self)
         JsonMixin.__init__(self)
 
-    def save_all(self, save_path):
+    def save_all(self, save_path: str) -> None:
+        """
+        Save all show functions to the specified path if they accept a save_path argument.
+
+        Args:
+            save_path (str): The path to save the functions' output.
+
+        Returns:
+            None
+        """
         for func in self._show_funcs:
             signature = inspect.signature(func)
             if "save_path" in signature.parameters:
@@ -54,7 +52,14 @@ class BaseResult(dict, StrMixin, JsonMixin):
 
 
 class CVResult(BaseResult, ImgMixin):
-    def __init__(self, data):
+    """Result For Computer Vision Tasks"""
+
+    def __init__(self, data: Dict) -> None:
+        """Initializes the instance with the given data and sets up image processing with the 'pillow' backend.
+
+        Args:
+            data (Dict): The data to initialize the instance with.
+        """
         super().__init__(data)
         ImgMixin.__init__(self, "pillow")
         self._img_reader = ImageReader(backend="pillow")

+ 124 - 42
paddlex/inference/pipelines_new/components/common/crop_image_regions.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..base import BaseComponent
+from .base_operator import BaseOperator
 import numpy as np
 from ....utils.io import ImageReader
 import copy
@@ -20,19 +20,74 @@ import cv2
 from .seal_det_warp import AutoRectifier
 from shapely.geometry import Polygon
 from numpy.linalg import norm
+from typing import Tuple
 
 
-class CropByPolys(BaseComponent):
+class CropByBoxes(BaseOperator):
+    """Crop Image by Boxes"""
+
+    entities = "CropByBoxes"
+
+    def __init__(self) -> None:
+        """Initializes the class."""
+        super().__init__()
+
+    def __call__(self, img: np.ndarray, boxes: list[dict]) -> list[dict]:
+        """
+        Process the input image and bounding boxes to produce a list of cropped images
+        with their corresponding bounding box coordinates and labels.
+
+        Args:
+            img (np.ndarray): The input image as a NumPy array.
+            boxes (list[dict]): A list of dictionaries, each containing bounding box
+                information including 'cls_id' (class ID), 'coordinate' (bounding box
+                coordinates as a list or tuple, left, top, right, bottom),
+                and optionally 'label' (label text).
+
+        Returns:
+            list[dict]: A list of dictionaries, each containing a cropped image ('img'),
+                the original bounding box coordinates ('box'), and the label ('label').
+        """
+        output_list = []
+        for bbox_info in boxes:
+            label_id = bbox_info["cls_id"]
+            box = bbox_info["coordinate"]
+            label = bbox_info.get("label", label_id)
+            xmin, ymin, xmax, ymax = [int(i) for i in box]
+            img_crop = img[ymin:ymax, xmin:xmax].copy()
+            output_list.append({"img": img_crop, "box": box, "label": label})
+        return output_list
+
+
+class CropByPolys(BaseOperator):
     """Crop Image by Polys"""
 
     entities = "CropByPolys"
 
-    def __init__(self, det_box_type="quad"):
+    def __init__(self, det_box_type: str = "quad") -> None:
+        """
+        Initializes the operator with a default detection box type.
+
+        Args:
+            det_box_type (str, optional): The type of detection box, quad or poly. Defaults to "quad".
+        """
         super().__init__()
         self.det_box_type = det_box_type
 
-    def __call__(self, img, dt_polys):
-        """__call__"""
+    def __call__(self, img: np.ndarray, dt_polys: list[list]) -> list[dict]:
+        """
+        Call method to crop images based on detection boxes.
+
+        Args:
+            img (np.ndarray): The input image.
+            dt_polys (list[list]): List of detection polygons.
+
+        Returns:
+            list[dict]: A list of dictionaries containing cropped images and their sizes.
+
+        Raises:
+            NotImplementedError: If det_box_type is not 'quad' or 'poly'.
+        """
 
         if self.det_box_type == "quad":
             dt_boxes = np.array(dt_polys)
@@ -63,8 +118,17 @@ class CropByPolys(BaseComponent):
 
         return output_list
 
-    def get_minarea_rect_crop(self, img, points):
-        """get_minarea_rect_crop"""
+    def get_minarea_rect_crop(self, img: np.ndarray, points: np.ndarray) -> np.ndarray:
+        """
+        Get the minimum area rectangle crop from the given image and points.
+
+        Args:
+            img (np.ndarray): The input image.
+            points (np.ndarray): A list of points defining the shape to be cropped.
+
+        Returns:
+            np.ndarray: The cropped image with the minimum area rectangle.
+        """
         bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
         points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
 
@@ -86,16 +150,16 @@
         crop_img = self.get_rotate_crop_image(img, np.array(box))
         return crop_img
 
-    def get_rotate_crop_image(self, img, points):
+    def get_rotate_crop_image(self, img: np.ndarray, points: list) -> np.ndarray:
         """
-        img_height, img_width = img.shape[0:2]
-        left = int(np.min(points[:, 0]))
-        right = int(np.max(points[:, 0]))
-        top = int(np.min(points[:, 1]))
-        bottom = int(np.max(points[:, 1]))
-        img_crop = img[top:bottom, left:right, :].copy()
-        points[:, 0] = points[:, 0] - left
-        points[:, 1] = points[:, 1] - top
+        Crop and rotate the input image based on the given four points to form a perspective-transformed image.
+
+        Args:
+            img (np.ndarray): The input image array.
+            points (list): A list of four 2D points defining the crop region in the image.
+
+        Returns:
+            np.ndarray: The transformed image array.
         """
         assert len(points) == 4, "shape of points must be 4*2"
         img_crop_width = int(
@@ -131,7 +195,9 @@
             dst_img = np.rot90(dst_img)
         return dst_img
 
-    def reorder_poly_edge(self, points):
+    def reorder_poly_edge(
+        self, points: np.ndarray
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
         """Get the respective points composing head edge, tail edge, top
         sideline and bottom sideline.
 
@@ -165,11 +231,25 @@
         sideline2 = pad_points[tail_inds[1] : (head_inds[1] + len(points))]
         return head_edge, tail_edge, sideline1, sideline2
 
-    def vector_slope(self, vec):
+    def vector_slope(self, vec: list) -> float:
+        """
+        Calculate the slope of a vector in 2D space.
+
+        Args:
+            vec (list): A list of two elements representing the coordinates of the vector.
+
+        Returns:
+            float: The slope of the vector.
+
+        Raises:
+            AssertionError: If the length of the vector is not equal to 2.
+        """
         assert len(vec) == 2
         return abs(vec[1] / (vec[0] + 1e-8))
 
-    def find_head_tail(self, points, orientation_thr):
+    def find_head_tail(
+        self, points: np.ndarray, orientation_thr: float
+    ) -> tuple[list, list]:
         """Find the head edge and tail edge of a text polygon.
 
         Args:
@@ -277,7 +357,17 @@
 
         return head_inds, tail_inds
 
-    def vector_angle(self, vec1, vec2):
+    def vector_angle(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
+        """
+        Calculate the angle between two vectors.
+
+        Args:
+            vec1 (ndarray): The first vector.
+            vec2 (ndarray): The second vector.
+
+        Returns:
+            float: The angle between the two vectors in radians.
+        """
         if vec1.ndim > 1:
             unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
         else:
@@ -288,7 +378,20 @@
             unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
         return np.arccos(np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
 
-    def get_minarea_rect(self, img, points):
+    def get_minarea_rect(
+        self, img: np.ndarray, points: np.ndarray
+    ) -> tuple[np.ndarray, list]:
+        """
+        Get the minimum area rectangle for the given points and crop the image accordingly.
+
+        Args:
+            img (np.ndarray): The input image.
+            points (np.ndarray): The points to compute the minimum area rectangle for.
+
+        Returns:
+            tuple[np.ndarray, list]: The cropped image,
+            and the list of points in the order of the bounding box.
+        """
        bounding_box = cv2.minAreaRect(points)
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
 
@@ -453,24 +556,3 @@
             img = np.stack((img,) * 3, axis=-1)
         img_crop, image = rectifier.run(img, new_points_list, mode="homography")
         return np.array(img_crop[0], dtype=np.uint8)
-
-
-class CropByBoxes(BaseComponent):
-    """Crop Image by Box"""
-
-    entities = "CropByBoxes"
-
-    def __init__(self):
-        super().__init__()
-
-    def __call__(self, img, boxes):
-        """__call__"""
-        output_list = []
-        for bbox_info in boxes:
-            label_id = bbox_info["cls_id"]
-            box = bbox_info["coordinate"]
-            label = bbox_info.get("label", label_id)
-            xmin, ymin, xmax, ymax = [int(i) for i in box]
-            img_crop = img[ymin:ymax, xmin:xmax].copy()
-            output_list.append({"img": img_crop, "box": box, "label": label})
-        return output_list

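A hedged usage sketch for the relocated CropByBoxes (the image path and box values are illustrative):

import cv2
from paddlex.inference.pipelines_new.components import CropByBoxes

img = cv2.imread("demo.jpg")  # hypothetical input image
boxes = [{"cls_id": 0, "coordinate": [10, 20, 110, 220], "label": "table"}]
crops = CropByBoxes()(img, boxes)
print(crops[0]["label"], crops[0]["img"].shape)  # -> table (200, 100, 3)
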
+ 2 - 0
paddlex/inference/pipelines_new/components/common/seal_det_warp.py

@@ -21,6 +21,8 @@ import time
 
 from .....utils import logging
 
+#### [TODO] need sunting to add explanatory notes
+
 
 def Homography(
     image,

+ 38 - 6
paddlex/inference/pipelines_new/components/common/sort_boxes.py

@@ -12,25 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..base import BaseComponent
+from .base_operator import BaseOperator
 import numpy as np
 
 
-class SortQuadBoxes(BaseComponent):
-    """SortQuadBoxes Component"""
+class SortQuadBoxes(BaseOperator):
+    """SortQuadBoxes Operator."""
 
     entities = "SortQuadBoxes"
 
     def __init__(self):
+        """Initializes the class."""
         super().__init__()
 
-    def __call__(self, dt_polys):
+    def __call__(self, dt_polys: list[np.ndarray]) -> np.ndarray:
         """
         Sort quad boxes in order from top to bottom, left to right
         args:
-            dt_polys(array):detected quad boxes with shape [4, 2]
+            dt_polys(ndarray):detected quad boxes with shape [4, 2]
         return:
-            sorted boxes(array) with shape [4, 2]
+            sorted boxes(ndarray) with shape [4, 2]
         """
         dt_boxes = np.array(dt_polys)
         num_boxes = dt_boxes.shape[0]
@@ -48,3 +49,34 @@
                 else:
                     break
         return _boxes
+
+
+class SortPolyBoxes(BaseOperator):
+    """SortPolyBoxes Operator."""
+
+    entities = "SortPolyBoxes"
+
+    def __init__(self):
+        """Initializes the class."""
+        super().__init__()
+
+    def __call__(self, dt_polys: list[np.ndarray]) -> np.ndarray:
+        """
+        Sort poly boxes in order from top to bottom, left to right
+        args:
+            dt_polys(ndarray):detected poly boxes with a [N, 2] np.ndarray list
+        return:
+            sorted boxes(ndarray) with [N, 2] np.ndarray list
+        """
+        num_boxes = len(dt_polys)
+        if num_boxes == 0:
+            return dt_polys
+        else:
+            y_min_list = []
+            for bno in range(num_boxes):
+                y_min_list.append(min(dt_polys[bno][:, 1]))
+            rank = np.argsort(np.array(y_min_list))
+            dt_polys_rank = []
+            for no in range(num_boxes):
+                dt_polys_rank.append(dt_polys[rank[no]])
+            return dt_polys_rank

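For reference, the new SortPolyBoxes ranks polygons by the minimum y of their points via np.argsort, which yields a top-to-bottom reading order. A minimal standalone sketch of that behavior (illustrative only, not part of this commit):

    import numpy as np

    polys = [
        np.array([[50, 40], [90, 40], [90, 60], [50, 60]]),  # lower text line
        np.array([[10, 10], [40, 10], [40, 25], [10, 25]]),  # upper text line
    ]
    rank = np.argsort(np.array([min(p[:, 1]) for p in polys]))
    sorted_polys = [polys[i] for i in rank]  # the upper line now comes first
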
+ 4 - 1
paddlex/inference/pipelines_new/components/prompt_engeering/base.py

@@ -19,15 +19,18 @@ import inspect
 
 
 class BaseGeneratePrompt(ABC, metaclass=AutoRegisterABCMetaClass):
-    """Base Chat"""
+    """Base Generate Prompt class."""
 
     __is_base = True
 
     def __init__(self):
+        """Initializes an instance of base generate prompt."""
         super().__init__()
 
    @abstractmethod
     def generate_prompt(self):
+        """Declaration of an abstract method. Subclasses are expected to
+        provide a concrete implementation of the generate_prompt method."""
         raise NotImplementedError(
             "The method `generate_prompt` has not been implemented yet."
         )

+ 53 - 16
paddlex/inference/pipelines_new/components/prompt_engeering/generate_kie_prompt.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from .base import BaseGeneratePrompt
+from typing import Dict
 
 
 class GenerateKIEPrompt(BaseGeneratePrompt):
@@ -20,7 +21,21 @@ class GenerateKIEPrompt(BaseGeneratePrompt):
 
     entities = ["text_kie_prompt", "table_kie_prompt"]
 
-    def __init__(self, config):
+    def __init__(self, config: Dict) -> None:
+        """Initializes the GenerateKIEPrompt instance with the given configuration.
+
+        Args:
+            config (Dict): A dictionary containing configuration settings.
+                - task_type (str): The type of task to generate a prompt for, from the supported entities list.
+                - task_description (str, optional): A description of the task. Defaults to an empty string.
+                - output_format (str, optional): The desired output format. Defaults to an empty string.
+                - rules_str (str, optional): A string representing rules for the task. Defaults to an empty string.
+                - few_shot_demo_text_content (str, optional): Text content for few-shot demos. Defaults to an empty string.
+                - few_shot_demo_key_value_list (str, optional): A key-value list for few-shot demos. Defaults to an empty string.
+
+        Raises:
+            ValueError: If the task type is not in the allowed entities for GenerateKIEPrompt.
+        """
         super().__init__()
 
         task_type = config.get("task_type", "")
@@ -59,19 +74,31 @@ class GenerateKIEPrompt(BaseGeneratePrompt):
 
     def generate_prompt(
         self,
-        text_content,
-        key_list,
-        task_description=None,
-        output_format=None,
-        rules_str=None,
-        few_shot_demo_text_content=None,
-        few_shot_demo_key_value_list=None,
-    ):
+        text_content: str,
+        key_list: list,
+        task_description: str = None,
+        output_format: str = None,
+        rules_str: str = None,
+        few_shot_demo_text_content: str = None,
+        few_shot_demo_key_value_list: str = None,
+    ) -> str:
+        """Generates a prompt based on the given parameters.
+
+        Args:
+            text_content (str): The main text content to be used in the prompt.
+            key_list (list): A list of keywords for information extraction.
+            task_description (str, optional): A description of the task. Defaults to None.
+            output_format (str, optional): The desired output format. Defaults to None.
+            rules_str (str, optional): A string containing rules or instructions. Defaults to None.
+            few_shot_demo_text_content (str, optional): Text content for few-shot demos. Defaults to None.
+            few_shot_demo_key_value_list (str, optional): Key-value list for few-shot demos. Defaults to None.
+
+        Returns:
+            str: The generated prompt.
+
+        Raises:
+            ValueError: If the task_type is not supported.
         """
-        args:
-        return:
-        """
-
         if task_description is None:
             task_description = self.task_description
 
@@ -87,19 +114,29 @@ class GenerateKIEPrompt(BaseGeneratePrompt):
         if few_shot_demo_key_value_list is None:
             few_shot_demo_key_value_list = self.few_shot_demo_key_value_list
 
-        prompt = f"""{task_description}{output_format}{rules_str}{few_shot_demo_text_content}{few_shot_demo_key_value_list}"""
+        prompt = f"""{task_description}{rules_str}{output_format}{few_shot_demo_text_content}{few_shot_demo_key_value_list}"""
         if self.task_type == "table_kie_prompt":
             prompt += f"""\n结合上面,下面正式开始:\
                 表格内容:```{text_content}```\
-                关键词列表:{key_list}。""".replace(
+                \n问题列表:{key_list}。""".replace(
                 "    ", ""
             )
+            # prompt += f"""\n结合上面,下面正式开始:\
+            #     表格内容:```{text_content}```\
+            #     \n关键词列表:{key_list}。""".replace(
+            #     "    ", ""
+            # )
         elif self.task_type == "text_kie_prompt":
             prompt += f"""\n结合上面的例子,下面正式开始:\
                 OCR文字:```{text_content}```\
-                关键词列表:{key_list}。""".replace(
+                \n问题列表:{key_list}。""".replace(
                 "    ", ""
             )
+            # prompt += f"""\n结合上面的例子,下面正式开始:\
+            #     OCR文字:```{text_content}```\
+            #     \n关键词列表:{key_list}。""".replace(
+            #     "    ", ""
+            # )
         else:
             raise ValueError(f"{self.task_type} is currently not supported.")
         return prompt

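A hypothetical usage sketch based on the documented config keys above (the concrete wiring through PaddleX's component registry may differ, and all values below are placeholders):

    config = {
        "task_type": "text_kie_prompt",
        "task_description": "你是一个信息抽取助手。",
        "output_format": "请以JSON格式输出结果。",
        "rules_str": "",
        "few_shot_demo_text_content": "",
        "few_shot_demo_key_value_list": "",
    }
    generator = GenerateKIEPrompt(config)
    prompt = generator.generate_prompt(
        text_content="发票号码:No12345 金额:88.00元",
        key_list=["发票号码", "金额"],
    )
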
+ 39 - 3
paddlex/inference/pipelines_new/components/retriever/base.py

@@ -27,27 +27,63 @@ class BaseRetriever(ABC, metaclass=AutoRegisterABCMetaClass):
     VECTOR_STORE_PREFIX = "PADDLEX_VECTOR_STORE"
 
     def __init__(self):
+        """Initializes an instance of base retriever."""
         super().__init__()
 
     @abstractmethod
     def generate_vector_database(self):
+        """
+        Declaration of an abstract method. Subclasses are expected to
+        provide a concrete implementation of generate_vector_database.
+        """
         raise NotImplementedError(
             "The method `generate_vector_database` has not been implemented yet."
         )
 
     @abstractmethod
     def similarity_retrieval(self):
+        """
+        Declaration of an abstract method. Subclasses are expected to
+        provide a concrete implementation of similarity_retrieval.
+        """
         raise NotImplementedError(
             "The method `similarity_retrieval` has not been implemented yet."
         )
 
-    def is_vector_store(self, s):
+    def is_vector_store(self, s: str) -> bool:
+        """
+        Check if the given string starts with the vector store prefix.
+
+        Args:
+            s (str): The input string to check.
+
+        Returns:
+            bool: True if the string starts with the vector store prefix, False otherwise.
+        """
         return s.startswith(self.VECTOR_STORE_PREFIX)
 
-    def encode_vector_store(self, vector_store_bytes):
+    def encode_vector_store(self, vector_store_bytes: bytes) -> str:
+        """
+        Encode the vector store bytes as a base64 string with the vector store prefix prepended.
+
+        Args:
+            vector_store_bytes (bytes): The bytes to encode.
+
+        Returns:
+            str: The encoded string with the prefix.
+        """
         return self.VECTOR_STORE_PREFIX + base64.b64encode(vector_store_bytes).decode(
             "ascii"
         )
 
-    def decode_vector_store(self, vector_store_str):
+    def decode_vector_store(self, vector_store_str: str) -> bytes:
+        """
+        Decodes the vector store string by removing the prefix and decoding the base64 encoded string.
+
+        Args:
+            vector_store_str (str): The vector store string with a prefix.
+
+        Returns:
+            bytes: The decoded vector store data.
+        """
         return base64.b64decode(vector_store_str[len(self.VECTOR_STORE_PREFIX) :])

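The prefix-plus-base64 scheme above round-trips as follows; a self-contained sketch using only the standard library (illustrative only, not part of this commit):

    import base64

    PREFIX = "PADDLEX_VECTOR_STORE"
    raw = b"\x00\x01 fake serialized store"
    encoded = PREFIX + base64.b64encode(raw).decode("ascii")  # encode_vector_store
    assert encoded.startswith(PREFIX)                         # is_vector_store
    assert base64.b64decode(encoded[len(PREFIX):]) == raw     # decode_vector_store
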
+ 80 - 13
paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py

@@ -25,6 +25,8 @@ from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings
 
 import time
 
+from typing import Dict
+
 
 class ErnieBotRetriever(BaseRetriever):
     """Ernie Bot Retriever"""
@@ -40,8 +42,24 @@ class ErnieBotRetriever(BaseRetriever):
         "ernie-char-8k",
     ]
 
-    def __init__(self, config):
-
+    def __init__(self, config: Dict) -> None:
+        """
+        Initializes the ErnieBotRetriever instance with the provided configuration.
+
+        Args:
+            config (Dict): A dictionary containing configuration settings.
+                - model_name (str): The name of the model to use.
+                - api_type (str): The type of API to use ('aistudio' or 'qianfan').
+                - ak (str, optional): The access key for the 'qianfan' API.
+                - sk (str, optional): The secret key for the 'qianfan' API.
+                - access_token (str, optional): The access token for the 'aistudio' API.
+
+        Raises:
+            ValueError: If model_name is not in self.entities,
+                api_type is not 'aistudio' or 'qianfan',
+                access_token is missing for the 'aistudio' API,
+                or ak and sk are missing for the 'qianfan' API.
+        """
         super().__init__()
 
         model_name = config.get("model_name", None)
@@ -65,16 +83,29 @@ class ErnieBotRetriever(BaseRetriever):
         self.model_name = model_name
         self.config = config
 
+    # Generates a vector database from a list of texts using different embeddings based on the configured API type.
+
     def generate_vector_database(
         self,
-        text_list,
-        block_size=300,
-        separators=["\t", "\n", "。", "\n\n", ""],
-        sleep_time=0.5,
-    ):
+        text_list: list[str],
+        block_size: int = 300,
+        separators: list[str] = ["\t", "\n", "。", "\n\n", ""],
+        sleep_time: float = 0.5,
+    ) -> FAISS:
         """
-        args:
-        return:
+        Generates a vector database from a list of texts.
+
+        Args:
+            text_list (list[str]): A list of texts to generate the vector database from.
+            block_size (int): The size of each chunk to split the text into.
+            separators (list[str]): A list of separators to use when splitting the text.
+            sleep_time (float): The time to sleep between embedding generations to avoid rate limiting.
+
+        Returns:
+            FAISS: The generated vector database.
+
+        Raises:
+            ValueError: If an unsupported API type is configured.
         """
         text_splitter = RecursiveCharacterTextSplitter(
             chunk_size=block_size, chunk_overlap=20, separators=separators
@@ -113,13 +144,36 @@ class ErnieBotRetriever(BaseRetriever):
 
         return vectorstore
 
-    def encode_vector_store_to_bytes(self, vectorstore):
+    def encode_vector_store_to_bytes(self, vectorstore: FAISS) -> str:
+        """
+        Serialize the vector store to bytes and encode it as a prefixed base64 string.
+
+        Args:
+            vectorstore (FAISS): The vector store to be serialized and encoded.
+
+        Returns:
+            str: The encoded vector store.
+        """
         vectorstore = self.encode_vector_store(vectorstore.serialize_to_bytes())
         return vectorstore
 
-    def decode_vector_store_from_bytes(self, vectorstore):
+    def decode_vector_store_from_bytes(self, vectorstore: str) -> FAISS:
+        """
+        Decode a vector store from bytes according to the specified API type.
+
+        Args:
+            vectorstore (str): The serialized vector store string.
+
+        Returns:
+            FAISS: Deserialized vector store object.
+
+        Raises:
+            ValueError: If the retrieved vector store is not for PaddleX,
+                or if an unsupported API type is specified.
+        """
         if not self.is_vector_store(vectorstore):
             raise ValueError("The retrieved vectorstore is not for PaddleX.")
+
         api_type = self.config["api_type"]
 
         if api_type == "aistudio":
@@ -131,13 +185,26 @@ class ErnieBotRetriever(BaseRetriever):
             embeddings = QianfanEmbeddingsEndpoint(qianfan_ak=ak, qianfan_sk=sk)
         else:
             raise ValueError(f"Unsupported api_type: {api_type}")
+
         vector = vectorstores.FAISS.deserialize_from_bytes(
             self.decode_vector_store(vectorstore), embeddings
         )
         return vector
 
-    def similarity_retrieval(self, query_text_list, vectorstore, sleep_time=0.5):
-        # 根据提问匹配上下文
+    def similarity_retrieval(
+        self, query_text_list: list[str], vectorstore: FAISS, sleep_time: float = 0.5
+    ) -> str:
+        """
+        Retrieve similar contexts based on a list of query texts.
+
+        Args:
+            query_text_list (list[str]): A list of query texts to search for similar contexts.
+            vectorstore (FAISS): The vector store in which to perform the similarity search.
+            sleep_time (float): The time to sleep between each query, in seconds. Default is 0.5.
+
+        Returns:
+            str: A concatenated string of all unique contexts found.
+        """
         C = []
         for query_text in query_text_list:
             QUESTION = query_text

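A hypothetical end-to-end sketch based on the documented config keys; real credentials are required, and "ernie-char-8k" is the only model name visible in this hunk's entities list:

    config = {
        "model_name": "ernie-char-8k",
        "api_type": "qianfan",
        "ak": "<your-qianfan-ak>",  # placeholder
        "sk": "<your-qianfan-sk>",  # placeholder
    }
    retriever = ErnieBotRetriever(config)
    store = retriever.generate_vector_database(["第一段文本", "第二段文本"])
    context = retriever.similarity_retrieval(["合同编号是什么?"], store)
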
+ 2 - 0
paddlex/inference/pipelines_new/components/utils/mixin.py

@@ -30,6 +30,8 @@ from ....utils.io import (
     TextWriter,
 )
 
+#### [TODO] need tingquan to add explanatory notes
+
 
 def _save_list_data(save_func, save_path, data, *args, **kwargs):
     save_path = Path(save_path)

+ 69 - 16
paddlex/inference/pipelines_new/doc_preprocessor/pipeline.py

@@ -16,10 +16,14 @@ from ..base import BasePipeline
 from typing import Any, Dict, Optional
 from scipy.ndimage import rotate
 from .result import DocPreprocessorResult
+from ....utils import logging
+import numpy as np
 
 ########## [TODO] The import path below needs to be updated later
 from ...components.transforms import ReadImage
 
+from ...utils.pp_option import PaddlePredictorOption
+
 
 class DocPreprocessorPipeline(BasePipeline):
     """Doc Preprocessor Pipeline"""
@@ -28,12 +32,22 @@ class DocPreprocessorPipeline(BasePipeline):
 
     def __init__(
         self,
-        config,
-        device=None,
-        pp_option=None,
+        config: Dict,
+        device: str = None,
+        pp_option: PaddlePredictorOption = None,
         use_hpip: bool = False,
         hpi_params: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> None:
+        """Initializes the doc preprocessor pipeline.
+
+        Args:
+            config (Dict): Configuration dictionary containing various settings.
+            device (str, optional): Device to run the predictions on. Defaults to None.
+            pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
+            use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
+            hpi_params (Optional[Dict[str, Any]], optional): HPIP parameters. Defaults to None.
+        """
+
         super().__init__(
             device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
         )
@@ -56,35 +70,72 @@ class DocPreprocessorPipeline(BasePipeline):
 
         self.img_reader = ReadImage(format="BGR")
 
-    def rotate_image(self, image_array, rotate_angle):
-        """rotate image"""
+    def rotate_image(self, image_array: np.ndarray, rotate_angle: float) -> np.ndarray:
+        """
+        Rotate the given image array by the specified angle.
+
+        Args:
+            image_array (np.ndarray): The input image array to be rotated.
+            rotate_angle (float): The angle in degrees by which to rotate the image.
+
+        Returns:
+            np.ndarray: The rotated image array.
+
+        Raises:
+            AssertionError: If rotate_angle is not in the range [0, 360).
+        """
         assert (
             rotate_angle >= 0 and rotate_angle < 360
         ), f"rotate_angle must be in [0, 360), but got {rotate_angle}."
         return rotate(image_array, rotate_angle, reshape=True)
 
-    def check_input_params(self, input_params):
+    def check_input_params_valid(self, input_params: Dict) -> bool:
+        """
+        Check if the input parameters are valid based on the initialized models.
+
+        Args:
+            input_params (Dict): A dictionary containing input parameters.
+
+        Returns:
+            bool: True if all required models are initialized according to input parameters, False otherwise.
+        """
 
         if (
             input_params["use_doc_orientation_classify"]
             and not self.use_doc_orientation_classify
         ):
-            raise ValueError(
-                "The model for doc orientation classify is not initialized."
+            logging.error(
+                "use_doc_orientation_classify is set, but the model for doc orientation classification is not initialized."
             )
+            return False
 
         if input_params["use_doc_unwarping"] and not self.use_doc_unwarping:
-            raise ValueError("The model for doc unwarping is not initialized.")
+            logging.error(
+                "use_doc_unwarping is set, but the model for doc unwarping is not initialized."
+            )
+            return False
 
-        return
+        return True
 
     def predict(
         self,
-        input,
-        use_doc_orientation_classify=True,
-        use_doc_unwarping=False,
+        input: str | list[str] | np.ndarray | list[np.ndarray],
+        use_doc_orientation_classify: bool = True,
+        use_doc_unwarping: bool = False,
         **kwargs
-    ):
+    ) -> DocPreprocessorResult:
+        """
+        Predict the preprocessing result for the input image or images.
+
+        Args:
+            input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) or path(s) to the images.
+            use_doc_orientation_classify (bool): Whether to use document orientation classification.
+            use_doc_unwarping (bool): Whether to use document unwarping.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            DocPreprocessorResult: A generator yielding preprocessing results.
+        """
 
         if not isinstance(input, list):
             input_list = [input]
@@ -95,7 +146,9 @@ class DocPreprocessorPipeline(BasePipeline):
             "use_doc_orientation_classify": use_doc_orientation_classify,
             "use_doc_unwarping": use_doc_unwarping,
         }
-        self.check_input_params(input_params)
+
+        if not self.check_input_params_valid(input_params):
+            yield {"error": "input params invalid"}
+            return
 
         img_id = 1
         for input in input_list:

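For reference, scipy's rotate with reshape=True (used by rotate_image above) enlarges the output canvas instead of cropping, so a 90-degree rotation swaps height and width; a standalone illustration (not part of this commit):

    import numpy as np
    from scipy.ndimage import rotate

    img = np.zeros((100, 200, 3), dtype=np.uint8)
    print(rotate(img, 90, reshape=True).shape)  # (200, 100, 3)
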
+ 25 - 6
paddlex/inference/pipelines_new/doc_preprocessor/result.py

@@ -24,19 +24,38 @@ from ..components import CVResult
 
 
 class DocPreprocessorResult(CVResult):
+    """doc preprocessor result"""
 
-    def save_to_img(self, save_path, *args, **kwargs):
+    def save_to_img(self, save_path: str, *args, **kwargs) -> None:
+        """
+        Save the image to the specified path.
+
+        Args:
+            save_path (str): The path to save the image.
+                If the path does not end with '.jpg' or '.png', it appends '/res_doc_preprocess_<img_id>.jpg'
+                to the path, where <img_id> is retrieved from the object's 'img_id' attribute.
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            None
+        """
         if not str(save_path).lower().endswith((".jpg", ".png")):
             img_id = self["img_id"]
             save_path = save_path + "/res_doc_preprocess_%d.jpg" % img_id
         super().save_to_img(save_path, *args, **kwargs)
 
-    def _to_img(self):
-        """draw doc preprocess result"""
-        image = self["input_image"]
+    def _to_img(self) -> PIL.Image:
+        """
+        Generate an image combining the original, rotated, and unwarped images.
+
+        Returns:
+            PIL.Image: A new image that displays the original, rotated, and unwarped images side by side.
+        """
+        image = self["input_image"][:, :, ::-1]
         angle = self["angle"]
-        rot_img = self["rot_img"]
-        output_img = self["output_img"]
+        rot_img = self["rot_img"][:, :, ::-1]
+        output_img = self["output_img"][:, :, ::-1]
         h, w = image.shape[0:2]
         img_show = Image.new("RGB", (w * 3, h + 25), (255, 255, 255))
         img_show.paste(Image.fromarray(image), (0, 0, w, h))

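The new [:, :, ::-1] slices convert OpenCV-style BGR arrays to the RGB channel order PIL expects before pasting; a standalone illustration (not part of this commit):

    import numpy as np

    bgr = np.array([[[255, 0, 0]]], dtype=np.uint8)  # pure blue in BGR
    print(bgr[:, :, ::-1][0, 0])                     # [  0   0 255] -> blue in RGB
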
+ 99 - 22
paddlex/inference/pipelines_new/layout_parsing/pipeline.py

@@ -22,9 +22,16 @@ from .table_recognition_post_processing import get_table_recognition_res
 
 from .result import LayoutParsingResult
 
+from ....utils import logging
+
+from ...utils.pp_option import PaddlePredictorOption
+
 ########## [TODO] The import path below needs to be updated later
 from ...components.transforms import ReadImage
 
+from ..ocr.result import OCRResult
+from ...results import DetResult
+
 
 class LayoutParsingPipeline(BasePipeline):
     """Layout Parsing Pipeline"""
@@ -33,12 +40,22 @@ class LayoutParsingPipeline(BasePipeline):
 
     def __init__(
         self,
-        config,
-        device=None,
-        pp_option=None,
+        config: Dict,
+        device: str = None,
+        pp_option: PaddlePredictorOption = None,
         use_hpip: bool = False,
         hpi_params: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> None:
+        """Initializes the layout parsing pipeline.
+
+        Args:
+            config (Dict): Configuration dictionary containing various settings.
+            device (str, optional): Device to run the predictions on. Defaults to None.
+            pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
+            use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
+            hpi_params (Optional[Dict[str, Any]], optional): HPIP parameters. Defaults to None.
+        """
+
         super().__init__(
             device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
         )
@@ -49,13 +66,23 @@ class LayoutParsingPipeline(BasePipeline):
 
         self._crop_by_boxes = CropByBoxes()
 
-    def inintial_predictor(self, config):
+    def inintial_predictor(self, config: Dict) -> None:
+        """Initializes the predictor based on the provided configuration.
+
+        Args:
+            config (Dict): A dictionary containing the configuration for the predictor.
+
+        Returns:
+            None
+        """
+
         layout_det_config = config["SubModules"]["LayoutDetection"]
         self.layout_det_model = self.create_model(layout_det_config)
 
         self.use_doc_preprocessor = False
         if "use_doc_preprocessor" in config:
             self.use_doc_preprocessor = config["use_doc_preprocessor"]
+
         if self.use_doc_preprocessor:
             doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
             self.doc_preprocessor_pipeline = self.create_pipeline(
@@ -87,41 +114,88 @@ class LayoutParsingPipeline(BasePipeline):
                 self.common_ocr_pipeline = self.create_pipeline(common_ocr_config)
         return
 
-    def get_text_paragraphs_ocr_res(self, overall_ocr_res, layout_det_res):
-        """get ocr res of the text paragraphs"""
+    def get_text_paragraphs_ocr_res(
+        self, overall_ocr_res: OCRResult, layout_det_res: DetResult
+    ) -> OCRResult:
+        """
+        Retrieves the OCR results for text paragraphs, excluding those of formulas, tables, and seals.
+
+        Args:
+            overall_ocr_res (OCRResult): The overall OCR result containing text information.
+            layout_det_res (DetResult): The detection result containing the layout information of the document.
+
+        Returns:
+            OCRResult: The OCR result for text paragraphs after excluding formulas, tables, and seals.
+        """
         object_boxes = []
         for box_info in layout_det_res["boxes"]:
-            if box_info["label"].lower() in ["image", "formula", "table", "seal"]:
+            if box_info["label"].lower() in ["formula", "table", "seal"]:
                 object_boxes.append(box_info["coordinate"])
         object_boxes = np.array(object_boxes)
         return get_sub_regions_ocr_res(overall_ocr_res, object_boxes, flag_within=False)
 
-    def check_input_params(self, input_params):
+    def check_input_params_valid(self, input_params: Dict) -> bool:
+        """
+        Check if the input parameters are valid based on the initialized models.
+
+        Args:
+            input_params (Dict): A dictionary containing input parameters.
+
+        Returns:
+            bool: True if all required models are initialized according to input parameters, False otherwise.
+        """
 
         if input_params["use_doc_preprocessor"] and not self.use_doc_preprocessor:
-            raise ValueError("The models for doc preprocessor are not initialized.")
+            logging.error(
+                "use_doc_preprocessor is set, but the models for the doc preprocessor are not initialized."
+            )
+            return False
 
         if input_params["use_common_ocr"] and not self.use_common_ocr:
-            raise ValueError("The models for common OCR are not initialized.")
+            logging.error(
+                "use_common_ocr is set, but the models for common OCR are not initialized."
+            )
+            return False
 
         if input_params["use_seal_recognition"] and not self.use_seal_recognition:
-            raise ValueError("The models for seal recognition are not initialized.")
+            logging.error(
+                "use_seal_recognition is set, but the models for seal recognition are not initialized."
+            )
+            return False
 
         if input_params["use_table_recognition"] and not self.use_table_recognition:
-            raise ValueError("The models for table recognition are not initialized.")
+            logging.error(
                "use_table_recognition is set, but the models for table recognition are not initialized."
+            )
+            return False
 
-        return
+        return True
 
     def predict(
         self,
-        input,
-        use_doc_orientation_classify=True,
-        use_doc_unwarping=True,
-        use_common_ocr=True,
-        use_seal_recognition=True,
-        use_table_recognition=True,
+        input: str | list[str] | np.ndarray | list[np.ndarray],
+        use_doc_orientation_classify: bool = False,
+        use_doc_unwarping: bool = False,
+        use_common_ocr: bool = True,
+        use_seal_recognition: bool = True,
+        use_table_recognition: bool = True,
        **kwargs
-    ):
+    ) -> LayoutParsingResult:
+        """
+        This function predicts the layout parsing result for the given input.
+
+        Args:
+            input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) to be processed.
+            use_doc_orientation_classify (bool): Whether to use document orientation classification.
+            use_doc_unwarping (bool): Whether to use document unwarping.
+            use_common_ocr (bool): Whether to use common OCR.
+            use_seal_recognition (bool): Whether to use seal recognition.
+            use_table_recognition (bool): Whether to use table recognition.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            LayoutParsingResult: The predicted layout parsing result.
+        """
 
         if not isinstance(input, list):
             input_list = [input]
@@ -139,8 +213,11 @@ class LayoutParsingPipeline(BasePipeline):
 
         if use_doc_orientation_classify or use_doc_unwarping:
             input_params["use_doc_preprocessor"] = True
+        else:
+            input_params["use_doc_preprocessor"] = False
 
-        self.check_input_params(input_params)
+        if not self.check_input_params_valid(input_params):
+            yield {"error": "input params invalid"}
+            return
 
         img_id = 1
         for input in input_list:

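Since predict() is a generator and invalid parameters now surface as an error dict rather than a raised ValueError, callers should check each yielded item; a hypothetical consumption sketch (assuming `pipeline` is an already-created layout_parsing pipeline):

    for res in pipeline.predict("./test.png", use_common_ocr=True):
        if "error" in res:
            print("skipped:", res["error"])
            continue
        res.save_results("./output")
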
+ 90 - 13
paddlex/inference/pipelines_new/layout_parsing/result.py

@@ -23,30 +23,92 @@ from PIL import Image, ImageDraw, ImageFont
 from ....utils.fonts import PINGFANG_FONT_FILE_PATH
 from ..components import CVResult, HtmlMixin, XlsxMixin
 
+from typing import Any, Dict, Optional
+
 
 class TableRecognitionResult(CVResult, HtmlMixin, XlsxMixin):
-    def __init__(self, data):
-        super().__init__(data)
-        HtmlMixin.__init__(self)
-        XlsxMixin.__init__(self)
+    """table recognition result"""
 
-    def save_to_html(self, save_path, *args, **kwargs):
+    def __init__(self, data: Dict) -> None:
+        """Initializes the object with given data and sets up mixins for HTML and XLSX processing."""
+        super().__init__(data)
+        HtmlMixin.__init__(self)  # Initializes the HTML mixin functionality
+        XlsxMixin.__init__(self)  # Initializes the XLSX mixin functionality
+
+    def save_to_html(self, save_path: str, *args, **kwargs) -> None:
+        """
+        Save the content to an HTML file.
+
+        Args:
+            save_path (str): The path to save the HTML file. If the path does not end with '.html',
+                it will append '/res_table_%d.html' % self['table_region_id'] to the path.
+            *args: Additional positional arguments to be passed to the superclass method.
+            **kwargs: Additional keyword arguments to be passed to the superclass method.
+
+        Returns:
+            None
+        """
         if not str(save_path).lower().endswith(".html"):
             save_path = save_path + "/res_table_%d.html" % self["table_region_id"]
         super().save_to_html(save_path, *args, **kwargs)
 
-    def _to_html(self):
+    def _to_html(self) -> str:
+        """Converts the prediction to its corresponding HTML representation.
+
+        Returns:
+            str: The HTML string representation of the prediction.
+        """
         return self["pred_html"]
 
-    def save_to_xlsx(self, save_path, *args, **kwargs):
+    def save_to_xlsx(self, save_path: str, *args, **kwargs) -> None:
+        """
+        Save the content to an Excel file (.xlsx).
+
+        If the save_path does not end with '.xlsx', it appends a default filename
+        based on the table_region_id attribute.
+
+        Args:
+            save_path (str): The path where the Excel file should be saved.
+            *args: Additional positional arguments passed to the superclass method.
+            **kwargs: Additional keyword arguments passed to the superclass method.
+
+        Returns:
+            None
+        """
         if not str(save_path).lower().endswith(".xlsx"):
             save_path = save_path + "/res_table_%d.xlsx" % self["table_region_id"]
         super().save_to_xlsx(save_path, *args, **kwargs)
 
-    def _to_xlsx(self):
+    def _to_xlsx(self) -> str:
+        """Returns the prediction HTML, which the XLSX writer converts into the Excel output.
+
+        Returns:
+            str: The HTML string used to build the XLSX file.
+        """
         return self["pred_html"]
 
-    def save_to_img(self, save_path, *args, **kwargs):
+    def save_to_img(self, save_path: str, *args, **kwargs) -> None:
+        """
+        Save the table and OCR result images to the specified path.
+
+        Args:
+            save_path (str): The directory path to save the images.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            None
+
+        Notes:
+            - If save_path does not end with '.jpg' or '.png', the function appends
+              '/res_table_ocr_%d.jpg' and '/res_table_cell_%d.jpg' (using table_region_id)
+              for the OCR and table cell images respectively.
+            - The OCR result image is saved first; the superclass's save_to_img method
+              then saves the table cell image.
+        """
         if not str(save_path).lower().endswith((".jpg", ".png")):
             ocr_save_path = (
                 save_path + "/res_table_ocr_%d.jpg" % self["table_region_id"]
@@ -55,7 +117,13 @@ class TableRecognitionResult(CVResult, HtmlMixin, XlsxMixin):
         self["table_ocr_pred"].save_to_img(ocr_save_path)
         super().save_to_img(save_path, *args, **kwargs)
 
-    def _to_img(self):
+    def _to_img(self) -> np.ndarray:
+        """
+        Convert the input image with table OCR predictions to an image with cell boundaries highlighted.
+
+        Returns:
+            np.ndarray: The input image with cell boundaries highlighted in red.
+        """
         input_img = self["table_ocr_pred"]["input_img"].copy()
         cell_box_list = self["cell_box_list"]
         for box in cell_box_list:
@@ -65,12 +133,21 @@ class TableRecognitionResult(CVResult, HtmlMixin, XlsxMixin):
 
 
 class LayoutParsingResult(dict):
-    def __init__(self, data):
+    """Layout Parsing Result"""
+
+    def __init__(self, data) -> None:
+        """Initializes a new instance of the class with the specified data."""
         super().__init__(data)
 
-    def save_results(self, save_path):
+    def save_results(self, save_path: str) -> None:
+        """Save the layout parsing results to the specified directory.
+
+        Args:
+            save_path (str): The directory path to save the results.
+        """
+
         if not os.path.isdir(save_path):
-            raise ValueError("The save path should be a dir.")
+            return
 
         layout_det_res = self["layout_det_res"]
         save_img_path = save_path + "/layout_det_result.jpg"

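The save_to_* methods above share one naming convention: a directory-like path gets a per-table default filename appended; a standalone illustration (not part of this commit):

    save_path = "./output"
    table_region_id = 3
    if not save_path.lower().endswith(".html"):
        save_path = save_path + "/res_table_%d.html" % table_region_id
    print(save_path)  # ./output/res_table_3.html
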
+ 47 - 10
paddlex/inference/pipelines_new/layout_parsing/table_recognition_post_processing.py

@@ -15,9 +15,11 @@
 from .utils import convert_points_to_boxes, get_sub_regions_ocr_res
 import numpy as np
 from .result import TableRecognitionResult
+from typing import Any, Dict, Optional
+from ..ocr.result import OCRResult
 
 
-def get_ori_image_coordinate(x, y, box_list):
+def get_ori_image_coordinate(x: int, y: int, box_list: list) -> list:
     """
     get the original coordinate from Cropped image to Original image.
     Args:
@@ -38,8 +40,19 @@
 
 
 def convert_table_structure_pred_bbox(
-    table_structure_pred, crop_start_point, img_shape
-):
+    table_structure_pred: Dict, crop_start_point: list, img_shape: tuple
+) -> None:
+    """
+    Convert the predicted table structure bounding boxes to the original image coordinate system.
+
+    Args:
+        table_structure_pred (Dict): A dictionary containing the predicted table structure, including bounding boxes ('bbox').
+        crop_start_point (list): A list of two integers representing the starting point (x, y) of the cropped image region.
+        img_shape (tuple): A tuple of two integers representing the shape (height, width) of the original image.
+
+    Returns:
+        None: The function modifies the 'table_structure_pred' dictionary in place by adding the 'cell_box_list' key.
+    """
 
     cell_points_list = table_structure_pred["bbox"]
     ori_cell_points_list = get_ori_image_coordinate(
@@ -55,7 +68,7 @@
     return
 
 
-def distance(box_1, box_2):
+def distance(box_1: list, box_2: list) -> float:
     """
     compute the distance between two boxes
 
@@ -64,7 +77,7 @@
         box_2 (list): second rectangle box, e.g. (x1, y1, x2, y2)
 
     Returns:
-        int: the distance between two boxes
+        float: the distance between two boxes
     """
     x1, y1, x2, y2 = box_1
     x3, y3, x4, y4 = box_2
@@ -74,7 +87,7 @@
     return dis + min(dis_2, dis_3)
 
 
-def compute_iou(rec1, rec2):
+def compute_iou(rec1: list, rec2: list) -> float:
     """
     computing IoU
     Args:
@@ -104,7 +117,7 @@
         return (intersect / (sum_area - intersect)) * 1.0
 
 
-def match_table_and_ocr(cell_box_list, ocr_dt_boxes):
+def match_table_and_ocr(cell_box_list: list, ocr_dt_boxes: list) -> dict:
     """
     match table and ocr
 
@@ -133,7 +146,20 @@
     return matched
 
 
-def get_html_result(matched_index, ocr_contents, pred_structures):
+def get_html_result(
+    matched_index: dict, ocr_contents: dict, pred_structures: list
+) -> str:
+    """
+    Generates HTML content based on the matched index, OCR contents, and predicted structures.
+
+    Args:
+        matched_index (dict): A dictionary containing matched indices.
+        ocr_contents (dict): A dictionary of OCR contents.
+        pred_structures (list): A list of predicted HTML structures.
+
+    Returns:
+        str: Generated HTML content as a string.
+    """
     pred_html = []
     td_index = 0
     head_structure = pred_structures[0:3]
@@ -182,9 +208,20 @@
     return html
 
 
-def get_table_recognition_res(crop_img_info, table_structure_pred, overall_ocr_res):
-    """get_table_recognition_res"""
+def get_table_recognition_res(
+    crop_img_info: dict, table_structure_pred: dict, overall_ocr_res: OCRResult
+) -> TableRecognitionResult:
+    """
+    Retrieve the table recognition result from cropped image info, table structure prediction, and overall OCR result.
+
+    Args:
+        crop_img_info (dict): Information about the cropped image, including the bounding box.
+        table_structure_pred (dict): Predicted table structure.
+        overall_ocr_res (OCRResult): Overall OCR result from the input image.
+
+    Returns:
+        TableRecognitionResult: An object containing the table recognition result.
+    """
     table_box = np.array([crop_img_info["box"]])
     table_ocr_pred = get_sub_regions_ocr_res(overall_ocr_res, table_box)
 
 

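For reference, compute_iou's intersect / (sum_area - intersect) is the standard intersection-over-union; a standalone sketch (box order assumed to be (x1, y1, x2, y2), which may differ from the in-repo convention):

    def iou(a, b):
        # Intersection rectangle, clamped to zero when the boxes do not overlap.
        ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
        ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
        inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
        union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
        return inter / union if union else 0.0

    print(iou((0, 0, 10, 10), (5, 5, 15, 15)))  # 25 / 175 ≈ 0.143
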
+ 38 - 7
paddlex/inference/pipelines_new/layout_parsing/utils.py

@@ -16,9 +16,23 @@ __all__ = ["convert_points_to_boxes", "get_sub_regions_ocr_res"]
 
 import numpy as np
 import copy
+from ..ocr.result import OCRResult
 
 
-def convert_points_to_boxes(dt_polys):
+def convert_points_to_boxes(dt_polys: list) -> np.ndarray:
+    """
+    Converts a list of polygons to a numpy array of bounding boxes.
+
+    Args:
+        dt_polys (list): A list of polygons, where each polygon is represented
+            as a list of (x, y) points.
+
+    Returns:
+        np.ndarray: A numpy array of bounding boxes, where each box is represented
+            as [left, top, right, bottom].
+            If the input list is empty, returns an empty numpy array.
+    """
+
     if len(dt_polys) > 0:
         dt_polys_tmp = dt_polys.copy()
         dt_polys_tmp = np.array(dt_polys_tmp)
@@ -33,8 +47,17 @@
     return dt_boxes
 
 
-def get_overlap_boxes_idx(src_boxes, ref_boxes):
-    """get overlap boxes idx"""
+def get_overlap_boxes_idx(src_boxes: np.ndarray, ref_boxes: np.ndarray) -> list:
+    """
+    Get the indices of source boxes that overlap with any reference box.
+
+    Args:
+        src_boxes (np.ndarray): A 2D numpy array of source bounding boxes.
+        ref_boxes (np.ndarray): A 2D numpy array of reference bounding boxes.
+
+    Returns:
+        list: A list of indices of source boxes that overlap with any reference box.
+    """
     match_idx_list = []
     src_boxes_num = len(src_boxes)
     if src_boxes_num > 0 and len(ref_boxes) > 0:
@@ -51,12 +74,20 @@
     return match_idx_list
 
 
-def get_sub_regions_ocr_res(overall_ocr_res, object_boxes, flag_within=True):
-    """
-    :param flag_within: True (within the object regions), False (outside the object regions)
-    :return:
+def get_sub_regions_ocr_res(
+    overall_ocr_res: OCRResult, object_boxes: list, flag_within: bool = True
+) -> OCRResult:
     """
+    Filters OCR results to include only the text boxes inside (or outside) the specified object boxes.
+
+    Args:
+        overall_ocr_res (OCRResult): The original OCR result containing all text boxes.
+        object_boxes (list): A list of bounding boxes for the objects of interest.
+        flag_within (bool): If True, only include text boxes within the object boxes. If False, exclude text boxes within the object boxes.
+
+    Returns:
+        OCRResult: A filtered OCR result containing only the relevant text boxes.
+    """
     sub_regions_ocr_res = copy.deepcopy(overall_ocr_res)
     sub_regions_ocr_res["input_img"] = overall_ocr_res["input_img"]
     sub_regions_ocr_res["img_id"] = -1

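The truncated body above suggests convert_points_to_boxes reduces each polygon to its per-axis min/max; a standalone sketch of that reduction (an assumption, since the full function body is not shown in this hunk):

    import numpy as np

    polys = np.array([[[12, 30], [48, 28], [50, 55], [10, 57]]])  # one 4-point polygon
    boxes = np.concatenate([polys.min(axis=1), polys.max(axis=1)], axis=1)
    print(boxes)  # [[10 28 50 57]] -> [left, top, right, bottom]
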
+ 36 - 12
paddlex/inference/pipelines_new/ocr/pipeline.py

@@ -14,12 +14,15 @@
 
 from ..base import BasePipeline
 from typing import Any, Dict, Optional
-from ..components import SortQuadBoxes, CropByPolys
+from ..components import SortQuadBoxes, SortPolyBoxes, CropByPolys
 from .result import OCRResult
 
 ########## [TODO] The import path below needs to be updated later
 from ...components.transforms import ReadImage
 
+from ...utils.pp_option import PaddlePredictorOption
+import numpy as np
+
 
 class OCRPipeline(BasePipeline):
     """OCR Pipeline"""
@@ -28,12 +31,22 @@ class OCRPipeline(BasePipeline):
 
     def __init__(
         self,
-        config,
-        device=None,
-        pp_option=None,
+        config: Dict,
+        device: str = None,
+        pp_option: PaddlePredictorOption = None,
         use_hpip: bool = False,
         hpi_params: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> None:
+        """
+        Initializes the class with given configurations and options.
+
+        Args:
+            config (Dict): Configuration dictionary containing model and other parameters.
+            device (str): The device to run the prediction on. Default is None.
+            pp_option (PaddlePredictorOption): Options for the PaddlePaddle predictor. Default is None.
+            use_hpip (bool): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
+            hpi_params (Optional[Dict[str, Any]]): HPIP-specific parameters. Default is None.
+        """
         super().__init__(
             device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
         )
@@ -46,22 +59,34 @@ class OCRPipeline(BasePipeline):
 
         self.text_type = config["text_type"]
 
-        self._sort_quad_boxes = SortQuadBoxes()
-
         if self.text_type == "common":
+            self._sort_boxes = SortQuadBoxes()
             self._crop_by_polys = CropByPolys(det_box_type="quad")
         elif self.text_type == "seal":
+            self._sort_boxes = SortPolyBoxes()
             self._crop_by_polys = CropByPolys(det_box_type="poly")
         else:
            raise ValueError("Unsupported text type {}".format(self.text_type))
 
         self.img_reader = ReadImage(format="BGR")
 
-    def predict(self, input, **kwargs):
+    def predict(
+        self, input: str | list[str] | np.ndarray | list[np.ndarray], **kwargs
+    ) -> OCRResult:
+        """Predicts OCR results for the given input.
+
+        Args:
+            input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) or path(s) to the images.
+            **kwargs: Additional keyword arguments that can be passed to the function.
+
+        Returns:
+            OCRResult: An iterable of OCRResult objects, each containing the predicted text and other relevant information.
+        """
         if not isinstance(input, list):
             input_list = [input]
         else:
            input_list = input
+
         img_id = 1
         for input in input_list:
             if isinstance(input, str):
@@ -76,10 +101,9 @@ class OCRPipeline(BasePipeline):
             dt_polys = det_res["dt_polys"]
             dt_scores = det_res["dt_scores"]
 
-            ########## [TODO] Need to confirm filtering thresholds for the detection and recognition modules
+            ########## [TODO] Need to confirm filtering thresholds for the detection and recognition modules
 
-            if self.text_type == "common":
-                dt_polys = self._sort_quad_boxes(dt_polys)
+            dt_polys = self._sort_boxes(dt_polys)
 
             single_img_res = {
                 "input_img": image_array,
@@ -93,7 +117,7 @@ class OCRPipeline(BasePipeline):
             if len(dt_polys) > 0:
                 all_subs_of_img = list(self._crop_by_polys(image_array, dt_polys))
 
-                ########## [TODO] Update in the future
+                ########## [TODO] Update in the future
                 for sub_img in all_subs_of_img:
                     sub_img["input"] = sub_img["img"]
                 ##########

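predict() accepts either one image (a path or array) or a list of them; the normalization shown above is simply a one-element wrap so iteration is uniform, e.g. (illustrative only, not part of this commit):

    def normalize(inputs):
        return inputs if isinstance(inputs, list) else [inputs]

    print(normalize("a.jpg"))             # ['a.jpg']
    print(normalize(["a.jpg", "b.jpg"]))  # ['a.jpg', 'b.jpg']
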
+ 50 - 22
paddlex/inference/pipelines_new/ocr/result.py

@@ -19,18 +19,41 @@ import cv2
 import PIL
 from PIL import Image, ImageDraw, ImageFont

-from ....utils.fonts import PINGFANG_FONT_FILE_PATH
+from ....utils.fonts import PINGFANG_FONT_FILE_PATH, create_font
 from ..components import CVResult


 class OCRResult(CVResult):
-    def save_to_img(self, save_path, *args, **kwargs):
+    """OCR result"""
+
+    def save_to_img(self, save_path: str, *args, **kwargs) -> None:
+        """
+        Save the image to the specified path with the appropriate extension.
+
+        If save_path does not end with '.jpg' or '.png', '/res_ocr_<img_id>.jpg' is appended
+        to the path, where <img_id> is the id of the image.
+
+        Args:
+            save_path (str): The path to save the image.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+        """
         if not str(save_path).lower().endswith((".jpg", ".png")):
             img_id = self["img_id"]
             save_path = save_path + "/res_ocr_%d.jpg" % img_id
         super().save_to_img(save_path, *args, **kwargs)
 
 
-    def get_minarea_rect(self, points):
+    def get_minarea_rect(self, points: np.ndarray) -> np.ndarray:
+        """
+        Get the minimum area rectangle for the given points using OpenCV.
+
+        Args:
+            points (np.ndarray): An array of 2D points.
+
+        Returns:
+            np.ndarray: An array of 2D points representing the corners of the minimum area rectangle
+                     in a specific order (clockwise or counterclockwise starting from the top-left corner).
+        """
         bounding_box = cv2.minAreaRect(points)
         points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
 
 
@@ -54,8 +77,14 @@ class OCRResult(CVResult):
 
 
         return box
 
 
-    def _to_img(self):
-        """draw ocr result"""
+    def _to_img(self) -> PIL.Image:
+        """
+        Converts the internal data to a PIL Image with detection and recognition results.
+
+        Returns:
+            PIL.Image: An image with detection boxes, texts, and scores blended on it.
+        """
+
         # TODO(gaotingquan): mv to postprocess
         drop_score = 0.5
 
 
@@ -105,8 +134,22 @@ class OCRResult(CVResult):
         return img_show
 
 
 
 
-def draw_box_txt_fine(img_size, box, txt, font_path):
-    """draw box text"""
+def draw_box_txt_fine(
+    img_size: tuple, box: np.ndarray, txt: str, font_path: str
+) -> np.ndarray:
+    """
+    Draws text in a box on an image with fine control over size and orientation.
+
+    Args:
+        img_size (tuple): The size of the output image (width, height).
+        box (np.ndarray): A 4x2 numpy array defining the corners of the box in (x, y) order.
+        txt (str): The text to draw inside the box.
+        font_path (str): The path to the font file to use for drawing the text.
+
+    Returns:
+        np.ndarray: An image with the text drawn in the specified box.
+    """
     box_height = int(
         math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)
     )
@@ -144,18 +187,3 @@ def draw_box_txt_fine(img_size, box, txt, font_path):
         borderValue=(255, 255, 255),
     )
     return img_right_text
-
-
-def create_font(txt, sz, font_path):
-    """create font"""
-    font_size = int(sz[1] * 0.8)
-    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
-    if int(PIL.__version__.split(".")[0]) < 10:
-        length = font.getsize(txt)[0]
-    else:
-        length = font.getlength(txt)
-
-    if length > sz[0]:
-        font_size = int(font_size * sz[0] / length)
-        font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
-    return font
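
create_font now lives in paddlex/utils/fonts (see the updated import above and the + 21 - 4 hunk below). A minimal sketch of the fitting behavior: the code treats sz as (width, height), using sz[1] for the initial size and sz[0] to bound the rendered length; the values here are illustrative:

    from paddlex.utils.fonts import PINGFANG_FONT_FILE_PATH, create_font

    # Starts at 80% of the region height (0.8 * 30 = 24) and shrinks the
    # font until the rendered text fits within the 120 px region width.
    font = create_font("示例文本", (120, 30), PINGFANG_FONT_FILE_PATH)
    print(font.size)  # 24 at most; smaller if the text would overflow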

+ 286 - 95
paddlex/inference/pipelines_new/pp_chatocrv3_doc/pipeline.py

@@ -28,6 +28,12 @@ import json
 
 
 from ....utils import logging
 
 
+from ...utils.pp_option import PaddlePredictorOption
+
+from ..layout_parsing.result import LayoutParsingResult
+
+import numpy as np
+
 
 
 class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
     """PP-ChatOCRv3-doc Pipeline"""
@@ -36,12 +42,22 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
 
 
     def __init__(
         self,
-        config,
-        device=None,
-        pp_option=None,
+        config: Dict,
+        device: str = None,
+        pp_option: PaddlePredictorOption = None,
         use_hpip: bool = False,
         hpi_params: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> None:
+        """Initializes the pp-chatocrv3-doc pipeline.
+
+        Args:
+            config (Dict): Configuration dictionary containing various settings.
+            device (str, optional): Device to run the predictions on. Defaults to None.
+            pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
+            use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
+            hpi_params (Optional[Dict[str, Any]], optional): HPIP parameters. Defaults to None.
+        """
+
         super().__init__(
             device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
         )
@@ -50,43 +66,72 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
 
 
         self.img_reader = ReadImage(format="BGR")

-    def inintial_predictor(self, config):
-        # layout_parsing_config = config['SubPipelines']["LayoutParser"]
-        # self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
+        self.table_structure_len_max = 500
+
+    def inintial_predictor(self, config: dict) -> None:
+        """
+        Initializes the predictor with the given configuration.
+
+        Args:
+            config (dict): The configuration dictionary containing the necessary
+                                parameters for initializing the predictor.
+        Returns:
+            None
+        """
+        layout_parsing_config = config["SubPipelines"]["LayoutParser"]
+        self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
+
+        from .. import create_chat_bot

         chat_bot_config = config["SubModules"]["LLM_Chat"]
-        self.chat_bot = self.create_chat_bot(chat_bot_config)
+        self.chat_bot = create_chat_bot(chat_bot_config)
+
+        from .. import create_retriever

         retriever_config = config["SubModules"]["LLM_Retriever"]
-        self.retriever = self.create_retriever(retriever_config)
+        self.retriever = create_retriever(retriever_config)
+
+        from .. import create_prompt_engeering

         text_pe_config = config["SubModules"]["PromptEngneering"]["KIE_CommonText"]
-        self.text_pe = self.create_prompt_engeering(text_pe_config)
+        self.text_pe = create_prompt_engeering(text_pe_config)

         table_pe_config = config["SubModules"]["PromptEngneering"]["KIE_Table"]
-        self.table_pe = self.create_prompt_engeering(table_pe_config)
+        self.table_pe = create_prompt_engeering(table_pe_config)

         return
 
 
-    def decode_visual_result(self, layout_parsing_result):
+    def decode_visual_result(
+        self, layout_parsing_result: LayoutParsingResult
+    ) -> VisualInfoResult:
+        """
+        Decodes the visual result from the layout parsing result.
+
+        Args:
+            layout_parsing_result (LayoutParsingResult): The result of layout parsing.
+
+        Returns:
+            VisualInfoResult: The decoded visual information.
+        """
         text_paragraphs_ocr_res = layout_parsing_result["text_paragraphs_ocr_res"]
         seal_res_list = layout_parsing_result["seal_res_list"]
         normal_text_dict = {}
-        layout_type = "text"
-        for text in text_paragraphs_ocr_res["rec_text"]:
-            if layout_type not in normal_text_dict:
-                normal_text_dict[layout_type] = text
-            else:
-                normal_text_dict[layout_type] += f"\n {text}"

-        layout_type = "seal"
         for seal_res in seal_res_list:
             for text in seal_res["rec_text"]:
+                layout_type = "印章"
                 if layout_type not in normal_text_dict:
-                    normal_text_dict[layout_type] = text
+                    normal_text_dict[layout_type] = f"{text}"
                 else:
                     normal_text_dict[layout_type] += f"\n {text}"

+        for text in text_paragraphs_ocr_res["rec_text"]:
+            layout_type = "words in text block"
+            if layout_type not in normal_text_dict:
+                normal_text_dict[layout_type] = text
+            else:
+                normal_text_dict[layout_type] += f"\n {text}"
+
         table_res_list = layout_parsing_result["table_res_list"]
         table_text_list = []
         table_html_list = []
@@ -101,16 +146,35 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
         visual_info["table_html_list"] = table_html_list
         visual_info["table_html_list"] = table_html_list
         return VisualInfoResult(visual_info)
         return VisualInfoResult(visual_info)
 
 
+    # Function to perform visual prediction on input images
     def visual_predict(
     def visual_predict(
         self,
         self,
-        input,
-        use_doc_orientation_classify=True,
-        use_doc_unwarping=True,
-        use_common_ocr=True,
-        use_seal_recognition=True,
-        use_table_recognition=True,
+        input: str | list[str] | np.ndarray | list[np.ndarray],
+        use_doc_orientation_classify: bool = False,  # Whether to use document orientation classification
+        use_doc_unwarping: bool = False,  # Whether to use document unwarping
+        use_common_ocr: bool = True,  # Whether to use common OCR
+        use_seal_recognition: bool = True,  # Whether to use seal recognition
+        use_table_recognition: bool = True,  # Whether to use table recognition
         **kwargs,
         **kwargs,
-    ):
+    ) -> dict:
+        """
+        This function takes an input image or a list of images and performs various visual
+        prediction tasks such as document orientation classification, document unwarping,
+        common OCR, seal recognition, and table recognition based on the provided flags.
+
+        Args:
+            input (str | list[str] | np.ndarray | list[np.ndarray]): Input image path, list of image paths,
+                                                                        numpy array of an image, or list of numpy arrays.
+            use_doc_orientation_classify (bool): Flag to use document orientation classification.
+            use_doc_unwarping (bool): Flag to use document unwarping.
+            use_common_ocr (bool): Flag to use common OCR.
+            use_seal_recognition (bool): Flag to use seal recognition.
+            use_table_recognition (bool): Flag to use table recognition.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            dict: A dictionary containing the layout parsing result and visual information.
+        """
 
 
         if not isinstance(input, list):
             input_list = [input]
@@ -145,7 +209,19 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             }
             yield visual_predict_res
 
 
-    def save_visual_info_list(self, visual_info, save_path):
+    def save_visual_info_list(
+        self, visual_info: VisualInfoResult, save_path: str
+    ) -> None:
+        """
+        Save the visual info list to the specified file path.
+
+        Args:
+            visual_info (VisualInfoResult): The visual info result, which can be a single object or a list of objects.
+            save_path (str): The file path to save the visual info list.
+
+        Returns:
+            None
+        """
         if not isinstance(visual_info, list):
             visual_info_list = [visual_info]
         else:
@@ -155,13 +231,34 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             fout.write(json.dumps(visual_info_list, ensure_ascii=False) + "\n")
         return
 
 
-    def load_visual_info_list(self, data_path):
+    def load_visual_info_list(self, data_path: str) -> list[VisualInfoResult]:
+        """
+        Loads visual info list from a JSON file.
+
+        Args:
+            data_path (str): The path to the JSON file containing visual info.
+
+        Returns:
+            list[VisualInfoResult]: A list of VisualInfoResult objects parsed from the JSON file.
+        """
         with open(data_path, "r") as fin:
             data = fin.readline()
             visual_info_list = json.loads(data)
         return visual_info_list
 
 
-    def merge_visual_info_list(self, visual_info_list):
+    def merge_visual_info_list(
+        self, visual_info_list: list[VisualInfoResult]
+    ) -> tuple[list, list, list]:
+        """
+        Merge visual info lists.
+
+        Args:
+            visual_info_list (list[VisualInfoResult]): A list of visual info results.
+
+        Returns:
+            tuple[list, list, list]: A tuple containing three lists, one for normal text dicts,
+                                               one for table text lists, and one for table HTML lists.
+        """
         all_normal_text_list = []
         all_table_text_list = []
         all_table_html_list = []
@@ -174,7 +271,23 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             all_table_html_list.extend(table_html_list)
         return all_normal_text_list, all_table_text_list, all_table_html_list
 
 
-    def build_vector(self, visual_info, min_characters=3500, llm_request_interval=1.0):
+    def build_vector(
+        self,
+        visual_info: VisualInfoResult,
+        min_characters: int = 3500,
+        llm_request_interval: float = 1.0,
+    ) -> dict:
+        """
+        Build a vector representation from visual information.
+
+        Args:
+            visual_info (VisualInfoResult): The visual information input, can be a single instance or a list of instances.
+            min_characters (int): Character threshold above which a vector database is built instead of keeping the raw text; defaults to 3500.
+            llm_request_interval (float): The interval between LLM requests, defaults to 1.0.
+
+        Returns:
+            dict: A dictionary containing the vector info and a flag indicating if the text is too short.
+        """
 
 
         if not isinstance(visual_info, list):
             visual_info_list = [visual_info]
@@ -184,17 +297,20 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
         all_visual_info = self.merge_visual_info_list(visual_info_list)
         all_normal_text_list, all_table_text_list, all_table_html_list = all_visual_info

-        all_normal_text_str = "".join(
-            ["\n".join(e.values()) for e in all_normal_text_list]
-        )
         vector_info = {}

         all_items = []
         for i, normal_text_dict in enumerate(all_normal_text_list):
             for type, text in normal_text_dict.items():
-                all_items += [f"{type}:{text}"]
+                all_items += [f"{type}:{text}\n"]

-        if len(all_normal_text_str) > min_characters:
+        for table_html, table_text in zip(all_table_html_list, all_table_text_list):
+            if len(table_html) > min_characters - self.table_structure_len_max:
+                all_items += [f"table:{table_text}\n"]
+
+        all_text_str = "".join(all_items)
+
+        if len(all_text_str) > min_characters:
             vector_info["flag_too_short_text"] = False
             vector_info["vector"] = self.retriever.generate_vector_database(all_items)
         else:
@@ -202,8 +318,16 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             vector_info["vector"] = all_items
             vector_info["vector"] = all_items
         return vector_info
         return vector_info
 
 
-    def format_key(self, key_list):
-        """format key"""
+    def format_key(self, key_list: str | list[str]) -> list[str]:
+        """
+        Formats the key list.
+
+        Args:
+            key_list (str|list[str]): A string or a list of strings representing the keys.
+
+        Returns:
+            list[str]: A list of formatted keys.
+        """
         if key_list == "":
         if key_list == "":
             return []
             return []
 
 
@@ -217,7 +341,16 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
 
 
         return []
 
 
-    def fix_llm_result_format(self, llm_result):
+    def fix_llm_result_format(self, llm_result: str) -> dict:
+        """
+        Fix the format of the LLM result.
+
+        Args:
+            llm_result (str): The result from the LLM (Large Language Model).
+
+        Returns:
+            dict: A fixed format dictionary from the LLM result.
+        """
         if not llm_result:
             return {}
 
 
@@ -257,12 +390,30 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
                 return {}

     def generate_and_merge_chat_results(
-        self, prompt, key_list, final_results, failed_results
-    ):
+        self, prompt: str, key_list: list, final_results: dict, failed_results: dict
+    ) -> None:
+        """
+        Generate and merge chat results into the final results dictionary.
+
+        Args:
+            prompt (str): The input prompt for the chat bot.
+            key_list (list): A list of keys to track which results to merge.
+            final_results (dict): The dictionary to store the final merged results.
+            failed_results (dict): A dictionary of failed results to avoid merging.
+
+        Returns:
+            None
+        """
 
 
         llm_result = self.chat_bot.generate_chat_results(prompt)
-        llm_result = self.fix_llm_result_format(llm_result)
+        if llm_result is None:
+            logging.warning(
+                "chat bot error: \n [prompt:]\n %s\n [result:] %s\n"
+                % (prompt, self.chat_bot.ERROR_MASSAGE)
+            )
+            return
 
 
+        llm_result = self.fix_llm_result_format(llm_result)
         for key, value in llm_result.items():
             if value not in failed_results and key in key_list:
                 key_list.remove(key)
@@ -271,24 +422,49 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
 
 
     def chat(
         self,
-        visual_info,
-        key_list,
-        vector_info,
-        text_task_description=None,
-        text_output_format=None,
-        text_rules_str=None,
-        text_few_shot_demo_text_content=None,
-        text_few_shot_demo_key_value_list=None,
-        table_task_description=None,
-        table_output_format=None,
-        table_rules_str=None,
-        table_few_shot_demo_text_content=None,
-        table_few_shot_demo_key_value_list=None,
-    ):
+        key_list: str | list[str],
+        visual_info: VisualInfoResult,
+        use_vector_retrieval: bool = True,
+        vector_info: dict = None,
+        min_characters: int = 3500,
+        text_task_description: str = None,
+        text_output_format: str = None,
+        text_rules_str: str = None,
+        text_few_shot_demo_text_content: str = None,
+        text_few_shot_demo_key_value_list: str = None,
+        table_task_description: str = None,
+        table_output_format: str = None,
+        table_rules_str: str = None,
+        table_few_shot_demo_text_content: str = None,
+        table_few_shot_demo_key_value_list: str = None,
+    ) -> dict:
+        """
+        Generates chat results based on the provided key list and visual information.
+
+        Args:
+            key_list (str | list[str]): A single key or a list of keys to extract information.
+            visual_info (VisualInfoResult): The visual information result.
+            use_vector_retrieval (bool): Whether to use vector retrieval.
+            vector_info (dict): The vector information for retrieval.
+            min_characters (int): Character threshold used to decide between direct prompting and vector retrieval.
+            text_task_description (str): The description of the text task.
+            text_output_format (str): The output format for text results.
+            text_rules_str (str): The rules for generating text results.
+            text_few_shot_demo_text_content (str): The text content for few-shot demos.
+            text_few_shot_demo_key_value_list (str): The key-value list for few-shot demos.
+            table_task_description (str): The description of the table task.
+            table_output_format (str): The output format for table results.
+            table_rules_str (str): The rules for generating table results.
+            table_few_shot_demo_text_content (str): The text content for table few-shot demos.
+            table_few_shot_demo_key_value_list (str): The key-value list for table few-shot demos.
+
+        Returns:
+            dict: A dictionary containing the chat results.
+        """
 
 
         key_list = self.format_key(key_list)
         if len(key_list) == 0:
-            return {"chat_res": "输入的key_list无效!"}
+            return {"error": "输入的key_list无效!"}
 
 
         if not isinstance(visual_info, list):
             visual_info_list = [visual_info]
@@ -301,52 +477,67 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
         final_results = {}
         failed_results = ["大模型调用失败", "未知", "未找到关键信息", "None", ""]
 
 
-        for all_table_info in [all_table_html_list, all_table_text_list]:
-            for table_info in all_table_info:
-                if len(key_list) == 0:
-                    continue
+        for table_html, table_text in zip(all_table_html_list, all_table_text_list):
+            if len(table_html) <= min_characters - self.table_structure_len_max:
+                for table_info in [table_html, table_text]:
+                    if len(key_list) > 0:
+                        prompt = self.table_pe.generate_prompt(
+                            table_info,
+                            key_list,
+                            task_description=table_task_description,
+                            output_format=table_output_format,
+                            rules_str=table_rules_str,
+                            few_shot_demo_text_content=table_few_shot_demo_text_content,
+                            few_shot_demo_key_value_list=table_few_shot_demo_key_value_list,
+                        )
+
+                        self.generate_and_merge_chat_results(
+                            prompt, key_list, final_results, failed_results
+                        )
 
 
-                prompt = self.table_pe.generate_prompt(
-                    table_info,
+        if len(key_list) > 0:
+            if use_vector_retrieval and vector_info is not None:
+                question_key_list = [f"抽取关键信息:{key}" for key in key_list]
+                vector = vector_info["vector"]
+                if not vector_info["flag_too_short_text"]:
+                    related_text = self.retriever.similarity_retrieval(
+                        question_key_list, vector
+                    )
+                else:
+                    if len(vector) > 0:
+                        related_text = "".join(vector)
+                    else:
+                        related_text = ""
+            else:
+                all_items = []
+                for i, normal_text_dict in enumerate(all_normal_text_list):
+                    for type, text in normal_text_dict.items():
+                        all_items += [f"{type}:{text}\n"]
+                related_text = "".join(all_items)
+                if len(related_text) > min_characters:
+                    logging.warning(
+                        "The input text content is too long, the large language model may truncate it."
+                    )
+
+            if len(related_text) > 0:
+                prompt = self.text_pe.generate_prompt(
+                    related_text,
                     key_list,
-                    task_description=table_task_description,
-                    output_format=table_output_format,
-                    rules_str=table_rules_str,
-                    few_shot_demo_text_content=table_few_shot_demo_text_content,
-                    few_shot_demo_key_value_list=table_few_shot_demo_key_value_list,
+                    task_description=text_task_description,
+                    output_format=text_output_format,
+                    rules_str=text_rules_str,
+                    few_shot_demo_text_content=text_few_shot_demo_text_content,
+                    few_shot_demo_key_value_list=text_few_shot_demo_key_value_list,
                 )
-
                 self.generate_and_merge_chat_results(
                     prompt, key_list, final_results, failed_results
                 )
 
 
-        if len(key_list) > 0:
-            question_key_list = [f"抽取关键信息:{key}" for key in key_list]
-            vector = vector_info["vector"]
-            if not vector_info["flag_too_short_text"]:
-                related_text = self.retriever.similarity_retrieval(
-                    question_key_list, vector
-                )
-            else:
-                related_text = " ".join(vector)
-
-            prompt = self.text_pe.generate_prompt(
-                related_text,
-                key_list,
-                task_description=text_task_description,
-                output_format=text_output_format,
-                rules_str=text_rules_str,
-                few_shot_demo_text_content=text_few_shot_demo_text_content,
-                few_shot_demo_key_value_list=text_few_shot_demo_key_value_list,
-            )
-
-            self.generate_and_merge_chat_results(
-                prompt, key_list, final_results, failed_results
-            )
-
-        return final_results
+        return {"chat_res": final_results}
 
 
-    def predict(self, *args, **kwargs):
+    def predict(self, *args, **kwargs) -> None:
         logging.error(
             "The PP-ChatOCRv3-doc pipeline does not support calling `predict()` directly! Please invoke `visual_predict`, `build_vector`, and `chat` in sequence to obtain the result."
         )
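
The calling sequence this error message refers to, as a minimal sketch (the "PP-ChatOCRv3-doc" pipeline name matches the config in this commit; the input path and keys are hypothetical, and the "visual_info" result key is assumed from visual_predict above):

    from paddlex import create_pipeline

    pipeline = create_pipeline(pipeline="PP-ChatOCRv3-doc")

    visual_info_list = []
    for res in pipeline.visual_predict("./test_demo_imgs/contract.jpg"):
        visual_info_list.append(res["visual_info"])  # key name assumed

    vector_info = pipeline.build_vector(visual_info_list, min_characters=3500)
    chat_res = pipeline.chat(
        key_list=["乙方", "联系电话"],  # hypothetical keys
        visual_info=visual_info_list,
        use_vector_retrieval=True,
        vector_info=vector_info,
    )
    print(chat_res["chat_res"])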

+ 0 - 17
paddlex/inference/pipelines_new/pp_chatocrv3_doc/result.py

@@ -27,20 +27,3 @@ class VisualInfoResult(BaseResult):
     """VisualInfoResult"""
     """VisualInfoResult"""
 
 
     pass
     pass
-
-
-# class VectorResult(BaseResult, Base64Mixin):
-#     """VisualInfoResult"""
-
-#     def _to_base64(self):
-#         return self["vector"]
-
-
-# class RetrievalResult(BaseResult):
-#     """VisualInfoResult"""
-
-#     pass
-
-
-# class ChatResult(BaseResult):
-#     """VisualInfoResult"""
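
For orientation, the payload wrapped by VisualInfoResult, as assembled by decode_visual_result in the pipeline diff above (an illustrative sketch; the values are made up and the "normal_text_dict" key name is assumed):

    visual_info = {
        "normal_text_dict": {              # key name assumed
            "印章": "合同专用章",
            "words in text block": "甲方:某某公司\n 乙方:某某某",
        },
        "table_text_list": ["姓名 电话 ..."],
        "table_html_list": ["<table>...</table>"],
    }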

+ 21 - 4
paddlex/utils/fonts/__init__.py

@@ -19,13 +19,30 @@ import PIL
 from PIL import ImageFont
 
 
 
 
-def get_pingfang_file_path():
-    """get pingfang font file path"""
+def get_pingfang_file_path() -> str:
+    """
+    Get the path of the PingFang font file.
+
+    Returns:
+        str: The path to the PingFang font file.
+    """
+
     return (Path(__file__).parent / "PingFang-SC-Regular.ttf").resolve().as_posix()
 
 
 
 
-def create_font(txt, sz, font_path):
-    """create font"""
+def create_font(txt: str, sz: tuple, font_path: str) -> ImageFont:
+    """
+    Create a font object with specified size and path, adjusted to fit within the given image region.
+
+    Args:
+        txt (str): The text to be rendered with the font.
+        sz (tuple): A tuple (width, height) of the image region; the height sets the initial font size and the width bounds the rendered text.
+        font_path (str): The path to the font file.
+
+    Returns:
+        ImageFont: An ImageFont object adjusted to fit within the given image region.
+    """
+
     font_size = int(sz[1] * 0.8)
     font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
     if int(PIL.__version__.split(".")[0]) < 10: