
Fix some module names for new pipelines (#2627)

* add the new architecture of pipelines

* add the new architecture of pipelines

* add explanatory note

* add explanatory note

* fix some modules name
dyning committed 11 months ago
parent commit 694e30d48a

+ 2 - 2
api_examples/pipelines/test_pp_chatocrv3.py

@@ -22,8 +22,8 @@ key_list = ["驾驶室准乘人数"]
 
 visual_predict_res = pipeline.visual_predict(
     img_path,
-    use_doc_orientation_classify=False,
-    use_doc_unwarping=False,
+    use_doc_orientation_classify=True,
+    use_doc_unwarping=True,
     use_common_ocr=True,
     use_seal_recognition=True,
     use_table_recognition=True,

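Note: this example now enables both document-preprocessing steps by default. For context, a minimal sketch of the surrounding script follows (the create_pipeline call and image path are assumptions following PaddleX conventions, not shown in this hunk). One caveat visible further down in this commit: visual_predict's parameter is renamed to use_general_ocr, while this example still passes use_common_ocr, which is absorbed by **kwargs.

```python
# Minimal sketch of the surrounding test script; the pipeline entry point
# and the image path are assumed, not part of this hunk.
from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="PP-ChatOCRv3-doc")
img_path = "vehicle_certificate-1.png"  # hypothetical input image
key_list = ["驾驶室准乘人数"]

visual_predict_res = pipeline.visual_predict(
    img_path,
    use_doc_orientation_classify=True,  # switched on by this commit
    use_doc_unwarping=True,             # switched on by this commit
    use_common_ocr=True,  # stale name; swallowed by **kwargs after the rename
    use_seal_recognition=True,
    use_table_recognition=True,
)
```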
+ 2 - 2
paddlex/configs/pipelines/OCR.yaml

@@ -6,7 +6,7 @@ pipeline_name: OCR
 ##############################################
 
 input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_001.png
-text_type: common
+text_type: general
 
 SubModules:
   TextDetection:
@@ -29,7 +29,7 @@ SubModules:
 
 # SubModules:
 #   TextDetection:
-#     module_name: text_detection
+#     module_name: seal_text_detection
 #     model_name: PP-OCRv4_mobile_seal_det
 #     model_dir: null
 #     batch_size: 1    

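The text_type literal renamed here is matched verbatim in OCRPipeline (see the ocr/pipeline.py hunk below), so custom configs must move from "common" to "general" together with this commit. A hedged usage sketch, assuming the public create_pipeline entry point resolves this YAML:

```python
from paddlex import create_pipeline

# Runs the general-text OCR pipeline; the demo image name mirrors the
# input URL in the YAML above. res.print()/res.save_to_img follow the
# usual PaddleX result helpers and are assumptions here.
ocr = create_pipeline(pipeline="OCR")
for res in ocr.predict("general_ocr_001.png"):
    res.print()
    res.save_to_img("./output/")
```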
+ 8 - 9
paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml

@@ -10,7 +10,6 @@ SubModules:
     ak: "api_key" # Set this to a real API key
     sk: "secret_key"  # Set this to a real secret key
 
-
   LLM_Retriever:
     module_name: retriever
     model_name: ernie-3.5
@@ -54,18 +53,18 @@ SubPipelines:
   LayoutParser:
     pipeline_name: layout_parsing
     use_doc_preprocessor: True
-    use_common_ocr: True
+    use_general_ocr: True
     use_seal_recognition: True
     use_table_recognition: True
 
     SubModules:
       LayoutDetection:
-        module_name: object_detection
+        module_name: layout_detection
         model_name: RT-DETR-H_layout_3cls
         model_dir: null
         batch_size: 1
-      TableStructurePredictor:
-        module_name: table_recognition
+      TableStructureRecognition:
+        module_name: table_structure_recognition
         model_name: SLANet_plus
         model_dir: null
         batch_size: 1
@@ -77,7 +76,7 @@ SubPipelines:
         use_doc_unwarping: True
         SubModules:
           DocOrientationClassify:
-            module_name: image_classification
+            module_name: doc_text_orientation
             model_name: PP-LCNet_x1_0_doc_ori
             model_dir: null
             batch_size: 1
@@ -87,9 +86,9 @@ SubPipelines:
             model_dir: null
             batch_size: 1
 
-      CommonOCR:
+      GeneralOCR:
         pipeline_name: OCR
-        text_type: common
+        text_type: general
         SubModules:
           TextDetection:
             module_name: text_detection
@@ -107,7 +106,7 @@ SubPipelines:
         text_type: seal
         SubModules:
           TextDetection:
-            module_name: text_detection
+            module_name: seal_text_detection
             model_name: PP-OCRv4_server_seal_det
             model_dir: null
             batch_size: 1    

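For anyone maintaining a custom config, the renames applied across the YAML files in this commit reduce to the following mapping (collected from the hunks in this diff; no new behavior):

```python
# Old -> new identifiers introduced by this commit, for cross-checking
# custom pipeline configs. Parenthesized notes mark which usage is meant.
RENAMES = {
    "text_type: common": "text_type: general",
    "use_common_ocr": "use_general_ocr",
    "CommonOCR": "GeneralOCR",
    "TableStructurePredictor": "TableStructureRecognition",
    "module_name: table_recognition": "module_name: table_structure_recognition",
    "module_name: object_detection (layout)": "module_name: layout_detection",
    "module_name: image_classification (doc orientation)": "module_name: doc_text_orientation",
    "module_name: text_detection (seal)": "module_name: seal_text_detection",
}
```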
+ 1 - 1
paddlex/configs/pipelines/doc_preprocessor.yaml

@@ -7,7 +7,7 @@ use_doc_unwarping: True
 
 SubModules:
   DocOrientationClassify:
-    module_name: image_classification
+    module_name: doc_text_orientation
     model_name: PP-LCNet_x1_0_doc_ori
     model_dir: null
     batch_size: 1

+ 9 - 11
paddlex/configs/pipelines/layout_parsing.yaml

@@ -2,21 +2,19 @@
 pipeline_name: layout_parsing
 input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/demo_paper.png
 use_doc_preprocessor: True
-use_common_ocr: True
+use_general_ocr: True
 use_seal_recognition: True
 use_table_recognition: True
 
 SubModules:
   LayoutDetection:
-    module_name: object_detection
+    module_name: layout_detection
     model_name: RT-DETR-H_layout_3cls
     model_dir: null
     batch_size: 1
-  ##############################################
-  ####### 【TODO】表格识别的 module_name 需要确认,是否是table_recognition
-  ##############################################
-  TableStructurePredictor:
-    module_name: table_recognition
+
+  TableStructureRecognition:
+    module_name: table_structure_recognition
     model_name: SLANet_plus
     model_dir: null
     batch_size: 1
@@ -28,7 +26,7 @@ SubPipelines:
     use_doc_unwarping: True
     SubModules:
       DocOrientationClassify:
-        module_name: image_classification
+        module_name: doc_text_orientation
         model_name: PP-LCNet_x1_0_doc_ori
         model_dir: null
         batch_size: 1
@@ -37,9 +35,9 @@ SubPipelines:
         model_name: UVDoc
         model_dir: null
         batch_size: 1
-  CommonOCR:
+  GeneralOCR:
     pipeline_name: OCR
-    text_type: common
+    text_type: general
     SubModules:
       TextDetection:
         module_name: text_detection
@@ -56,7 +54,7 @@ SubPipelines:
     text_type: seal
     SubModules:
       TextDetection:
-        module_name: text_detection
+        module_name: seal_text_detection
         model_name: PP-OCRv4_server_seal_det
         model_dir: null
         batch_size: 1    

+ 2 - 0
paddlex/inference/pipelines_new/__init__.py

@@ -134,6 +134,8 @@ def create_pipeline(
         pp_option=pp_option,
         use_hpip=use_hpip,
         hpi_params=hpi_params,
+        *args,
+        **kwargs,
     )
     return pipeline
 

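Forwarding *args/**kwargs through create_pipeline (and through BasePipeline.__init__ below) is what lets pipeline-specific options reach a pipeline class. A sketch of the intended use, assuming the rest of the create_pipeline signature is unchanged; use_layout_parsing is the new option added to PP_ChatOCRv3_doc_Pipeline later in this commit:

```python
from paddlex import create_pipeline

# use_layout_parsing travels through **kwargs to the pipeline constructor.
# With False, the layout-parsing subpipeline is never built: visual_predict
# then raises, but the chat bot and retriever are still initialized.
pipeline = create_pipeline(
    pipeline="PP-ChatOCRv3-doc",
    use_layout_parsing=False,
)
```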
+ 2 - 0
paddlex/inference/pipelines_new/base.py

@@ -38,6 +38,8 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
         pp_option: PaddlePredictorOption = None,
         use_hpip: bool = False,
         hpi_params: Optional[Dict[str, Any]] = None,
+        *args,
+        **kwargs,
     ) -> None:
         """
         Initializes the class with specified parameters.

+ 11 - 11
paddlex/inference/pipelines_new/components/prompt_engeering/generate_kie_prompt.py

@@ -116,27 +116,27 @@ class GenerateKIEPrompt(BaseGeneratePrompt):
 
         prompt = f"""{task_description}{rules_str}{output_format}{few_shot_demo_text_content}{few_shot_demo_key_value_list}"""
         if self.task_type == "table_kie_prompt":
-            # prompt += f"""\n结合上面,下面正式开始:\
-            #     表格内容:```{text_content}```\
-            #     \n问题列表:{key_list}。""".replace(
-            #     "    ", ""
-            # )
             prompt += f"""\n结合上面,下面正式开始:\
                 表格内容:```{text_content}```\
-                \n关键词列表:{key_list}。""".replace(
+                \n问题列表:{key_list}。""".replace(
                 "    ", ""
             )
-        elif self.task_type == "text_kie_prompt":
-            # prompt += f"""\n结合上面的例子,下面正式开始:\
-            #     OCR文字:```{text_content}```\
-            #     \n问题列表:{key_list}。""".replace(
+            # prompt += f"""\n结合上面,下面正式开始:\
+            #     表格内容:```{text_content}```\
+            #     \n关键词列表:{key_list}。""".replace(
             #     "    ", ""
             # )
+        elif self.task_type == "text_kie_prompt":
             prompt += f"""\n结合上面的例子,下面正式开始:\
                 OCR文字:```{text_content}```\
-                \n关键词列表:{key_list}。""".replace(
+                \n问题列表:{key_list}。""".replace(
                 "    ", ""
             )
+            # prompt += f"""\n结合上面的例子,下面正式开始:\
+            #     OCR文字:```{text_content}```\
+            #     \n关键词列表:{key_list}。""".replace(
+            #     "    ", ""
+            # )
         else:
             raise ValueError(f"{self.task_type} is currently not supported.")
         return prompt

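Beyond swapping which variant is active (关键词列表, "keyword list", becomes 问题列表, "question list", matching the retrieval prompt change below), the builder's .replace("    ", "") idiom deserves a comment: the trailing backslashes join the physical lines of the f-string, and the replace call then strips the four-space indentation the string inherits from the source. A standalone demonstration with hypothetical sample values:

````python
# The backslash-newline pairs join the lines; .replace("    ", "") removes
# the leftover four-space indentation runs from the joined string.
text_content = "示例表格"
key_list = ["驾驶室准乘人数"]

prompt = f"""\n结合上面,下面正式开始:\
    表格内容:```{text_content}```\
    \n问题列表:{key_list}。""".replace(
    "    ", ""
)
print(prompt)
# (leading blank line)
# 结合上面,下面正式开始:表格内容:```示例表格```
# 问题列表:['驾驶室准乘人数']。
````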
+ 20 - 18
paddlex/inference/pipelines_new/layout_parsing/pipeline.py

@@ -89,12 +89,12 @@ class LayoutParsingPipeline(BasePipeline):
                 doc_preprocessor_config
             )
 
-        self.use_common_ocr = False
-        if "use_common_ocr" in config:
-            self.use_common_ocr = config["use_common_ocr"]
-        if self.use_common_ocr:
-            common_ocr_config = config["SubPipelines"]["CommonOCR"]
-            self.common_ocr_pipeline = self.create_pipeline(common_ocr_config)
+        self.use_general_ocr = False
+        if "use_general_ocr" in config:
+            self.use_general_ocr = config["use_general_ocr"]
+        if self.use_general_ocr:
+            general_ocr_config = config["SubPipelines"]["GeneralOCR"]
+            self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
 
         self.use_seal_recognition = False
         if "use_seal_recognition" in config:
@@ -107,11 +107,11 @@ class LayoutParsingPipeline(BasePipeline):
         if "use_table_recognition" in config:
             self.use_table_recognition = config["use_table_recognition"]
         if self.use_table_recognition:
-            table_structure_config = config["SubModules"]["TableStructurePredictor"]
+            table_structure_config = config["SubModules"]["TableStructureRecognition"]
             self.table_structure_model = self.create_model(table_structure_config)
-            if not self.use_common_ocr:
-                common_ocr_config = config["SubPipelines"]["OCR"]
-                self.common_ocr_pipeline = self.create_pipeline(common_ocr_config)
+            if not self.use_general_ocr:
+                general_ocr_config = config["SubPipelines"]["GeneralOCR"]
+                self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
         return
 
     def get_text_paragraphs_ocr_res(
@@ -151,9 +151,9 @@ class LayoutParsingPipeline(BasePipeline):
             )
             return False
 
-        if input_params["use_common_ocr"] and not self.use_common_ocr:
+        if input_params["use_general_ocr"] and not self.use_general_ocr:
             logging.error(
-                "Set use_common_ocr, but the models for common OCR are not initialized."
+                "Set use_general_ocr, but the models for general OCR are not initialized."
             )
             return False
 
@@ -176,7 +176,7 @@ class LayoutParsingPipeline(BasePipeline):
         input: str | list[str] | np.ndarray | list[np.ndarray],
         use_doc_orientation_classify: bool = False,
         use_doc_unwarping: bool = False,
-        use_common_ocr: bool = True,
+        use_general_ocr: bool = True,
         use_seal_recognition: bool = True,
         use_table_recognition: bool = True,
         **kwargs
@@ -188,7 +188,7 @@ class LayoutParsingPipeline(BasePipeline):
             input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) to be processed.
             use_doc_orientation_classify (bool): Whether to use document orientation classification.
             use_doc_unwarping (bool): Whether to use document unwarping.
-            use_common_ocr (bool): Whether to use common OCR.
+            use_general_ocr (bool): Whether to use general OCR.
             use_seal_recognition (bool): Whether to use seal recognition.
             use_table_recognition (bool): Whether to use table recognition.
             **kwargs: Additional keyword arguments.
@@ -206,7 +206,7 @@ class LayoutParsingPipeline(BasePipeline):
             "use_doc_preprocessor": self.use_doc_preprocessor,
             "use_doc_orientation_classify": use_doc_orientation_classify,
             "use_doc_unwarping": use_doc_unwarping,
-            "use_common_ocr": use_common_ocr,
+            "use_general_ocr": use_general_ocr,
             "use_seal_recognition": use_seal_recognition,
             "use_table_recognition": use_table_recognition,
         }
@@ -245,8 +245,10 @@ class LayoutParsingPipeline(BasePipeline):
             ########## [TODO]RT-DETR 检测结果有重复
             layout_det_res = next(self.layout_det_model(doc_preprocessor_image))
 
-            if input_params["use_common_ocr"] or input_params["use_table_recognition"]:
-                overall_ocr_res = next(self.common_ocr_pipeline(doc_preprocessor_image))
+            if input_params["use_general_ocr"] or input_params["use_table_recognition"]:
+                overall_ocr_res = next(
+                    self.general_ocr_pipeline(doc_preprocessor_image)
+                )
                 overall_ocr_res["img_id"] = img_id
                 dt_boxes = convert_points_to_boxes(overall_ocr_res["dt_polys"])
                 overall_ocr_res["dt_boxes"] = dt_boxes
@@ -254,7 +256,7 @@ class LayoutParsingPipeline(BasePipeline):
                 overall_ocr_res = {}
 
             text_paragraphs_ocr_res = {}
-            if input_params["use_common_ocr"]:
+            if input_params["use_general_ocr"]:
                 text_paragraphs_ocr_res = self.get_text_paragraphs_ocr_res(
                     overall_ocr_res, layout_det_res
                 )

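With the flag renamed end to end, a caller-side sketch (the entry point and the result-saving helper are assumed from PaddleX conventions; the predict signature itself is visible in the hunk above):

```python
from paddlex import create_pipeline

# use_general_ocr replaces use_common_ocr; the other flags are unchanged.
pipeline = create_pipeline(pipeline="layout_parsing")
for res in pipeline.predict(
    "demo_paper.png",
    use_general_ocr=True,
    use_seal_recognition=True,
    use_table_recognition=True,
):
    res.save_to_img("./output/")  # assumed helper, cf. the result.py hunk
```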
+ 1 - 1
paddlex/inference/pipelines_new/layout_parsing/result.py

@@ -158,7 +158,7 @@ class LayoutParsingResult(dict):
             save_img_path = save_path + "/doc_preprocessor_result.jpg"
             self["doc_preprocessor_res"].save_to_img(save_img_path)
 
-        if input_params["use_common_ocr"]:
+        if input_params["use_general_ocr"]:
             save_img_path = save_path + "/text_paragraphs_ocr_result.jpg"
             self["text_paragraphs_ocr_res"].save_to_img(save_img_path)
 

+ 1 - 1
paddlex/inference/pipelines_new/ocr/pipeline.py

@@ -59,7 +59,7 @@ class OCRPipeline(BasePipeline):
 
         self.text_type = config["text_type"]
 
-        if self.text_type == "common":
+        if self.text_type == "general":
             self._sort_boxes = SortQuadBoxes()
             self._crop_by_polys = CropByPolys(det_box_type="quad")
         elif self.text_type == "seal":

+ 43 - 26
paddlex/inference/pipelines_new/pp_chatocrv3_doc/pipeline.py

@@ -47,6 +47,7 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
         pp_option: PaddlePredictorOption = None,
         use_hpip: bool = False,
         hpi_params: Optional[Dict[str, Any]] = None,
+        use_layout_parsing: bool = True,
     ) -> None:
         """Initializes the pp-chatocrv3-doc pipeline.
 
@@ -62,6 +63,8 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
         )
 
+        self.use_layout_parsing = use_layout_parsing
+
         self.inintial_predictor(config)
 
         self.img_reader = ReadImage(format="BGR")
@@ -78,8 +81,10 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
         Returns:
             None
         """
-        layout_parsing_config = config["SubPipelines"]["LayoutParser"]
-        self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
+
+        if self.use_layout_parsing:
+            layout_parsing_config = config["SubPipelines"]["LayoutParser"]
+            self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
 
         from .. import create_chat_bot
 
@@ -152,7 +157,7 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
         input: str | list[str] | np.ndarray | list[np.ndarray],
         use_doc_orientation_classify: bool = False,  # Whether to use document orientation classification
         use_doc_unwarping: bool = False,  # Whether to use document unwarping
-        use_common_ocr: bool = True,  # Whether to use common OCR
+        use_general_ocr: bool = True,  # Whether to use general OCR
         use_seal_recognition: bool = True,  # Whether to use seal recognition
         use_table_recognition: bool = True,  # Whether to use table recognition
         **kwargs,
@@ -160,14 +165,14 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
         """
         This function takes an input image or a list of images and performs various visual
         prediction tasks such as document orientation classification, document unwarping,
-        common OCR, seal recognition, and table recognition based on the provided flags.
+        general OCR, seal recognition, and table recognition based on the provided flags.
 
         Args:
             input (str | list[str] | np.ndarray | list[np.ndarray]): Input image path, list of image paths,
                                                                         numpy array of an image, or list of numpy arrays.
             use_doc_orientation_classify (bool): Flag to use document orientation classification.
             use_doc_unwarping (bool): Flag to use document unwarping.
-            use_common_ocr (bool): Flag to use common OCR.
+            use_general_ocr (bool): Flag to use general OCR.
             use_seal_recognition (bool): Flag to use seal recognition.
             use_table_recognition (bool): Flag to use table recognition.
             **kwargs: Additional keyword arguments.
@@ -176,6 +181,9 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             dict: A dictionary containing the layout parsing result and visual information.
         """
 
+        if not self.use_layout_parsing:
+            raise ValueError("The models for layout parsing are not initialized.")
+
         if not isinstance(input, list):
             input_list = [input]
         else:
@@ -195,7 +203,7 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
                     image_array,
                     use_doc_orientation_classify=use_doc_orientation_classify,
                     use_doc_unwarping=use_doc_unwarping,
-                    use_common_ocr=use_common_ocr,
+                    use_general_ocr=use_general_ocr,
                     use_seal_recognition=use_seal_recognition,
                     use_table_recognition=use_table_recognition,
                 )
@@ -264,6 +272,8 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
         all_table_html_list = []
         for single_visual_info in visual_info_list:
             normal_text_dict = single_visual_info["normal_text_dict"]
+            for key in normal_text_dict:
+                normal_text_dict[key] = normal_text_dict[key].replace("\n", "")
             table_text_list = single_visual_info["table_text_list"]
             table_html_list = single_visual_info["table_html_list"]
             all_normal_text_list.append(normal_text_dict)
@@ -308,6 +318,9 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             if len(table_html) > min_characters - self.table_structure_len_max:
                 all_items += [f"table:{table_text}\n"]
 
+            # if len(table_html) > min_characters - self.table_structure_len_max:
+            #     all_items += [f"table:{table_text}\n"]
+
         all_text_str = "".join(all_items)
 
         if len(all_text_str) > min_characters:
@@ -413,7 +426,10 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
             )
             return
 
+        # print(prompt, llm_result)
+
         llm_result = self.fix_llm_result_format(llm_result)
+
         for key, value in llm_result.items():
             if value not in failed_results and key in key_list:
                 key_list.remove(key)
@@ -477,27 +493,10 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
         final_results = {}
         failed_results = ["大模型调用失败", "未知", "未找到关键信息", "None", ""]
 
-        for table_html, table_text in zip(all_table_html_list, all_table_text_list):
-            if len(table_html) <= min_characters - self.table_structure_len_max:
-                for table_info in [table_html, table_text]:
-                    if len(key_list) > 0:
-                        prompt = self.table_pe.generate_prompt(
-                            table_info,
-                            key_list,
-                            task_description=table_task_description,
-                            output_format=table_output_format,
-                            rules_str=table_rules_str,
-                            few_shot_demo_text_content=table_few_shot_demo_text_content,
-                            few_shot_demo_key_value_list=table_few_shot_demo_key_value_list,
-                        )
-
-                        self.generate_and_merge_chat_results(
-                            prompt, key_list, final_results, failed_results
-                        )
-
         if len(key_list) > 0:
             if use_vector_retrieval and vector_info is not None:
-                question_key_list = [f"抽取关键信息:{key}" for key in key_list]
+                # question_key_list = [f"抽取关键信息:{key}" for key in key_list]
+                question_key_list = [f"待回答问题:{key}" for key in key_list]
                 vector = vector_info["vector"]
                 if not vector_info["flag_too_short_text"]:
                     related_text = self.retriever.similarity_retrieval(
@@ -530,11 +529,29 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
                     few_shot_demo_text_content=text_few_shot_demo_text_content,
                     few_shot_demo_key_value_list=text_few_shot_demo_key_value_list,
                 )
-                # print(prompt)
                 self.generate_and_merge_chat_results(
                     prompt, key_list, final_results, failed_results
                 )
 
+        if len(key_list) > 0:
+            for table_html, table_text in zip(all_table_html_list, all_table_text_list):
+                if len(table_html) <= min_characters - self.table_structure_len_max:
+                    for table_info in [table_html]:
+                        if len(key_list) > 0:
+                            prompt = self.table_pe.generate_prompt(
+                                table_info,
+                                key_list,
+                                task_description=table_task_description,
+                                output_format=table_output_format,
+                                rules_str=table_rules_str,
+                                few_shot_demo_text_content=table_few_shot_demo_text_content,
+                                few_shot_demo_key_value_list=table_few_shot_demo_key_value_list,
+                            )
+
+                            self.generate_and_merge_chat_results(
+                                prompt, key_list, final_results, failed_results
+                            )
+
         return {"chat_res": final_results}
 
     def predict(self, *args, **kwargs) -> None:
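Two behavioral changes ride along with the renames in this last file: extracted values have newlines stripped before vector building, and chat() now queries the general text corpus first (retrieval questions rephrased from 抽取关键信息, "extract key information", to 待回答问题, "question to answer") and only falls back to table HTML for keys still unanswered; table_text is no longer sent as a second table prompt. An end-to-end sketch (visual_predict and the {"chat_res": ...} return value are visible in this diff; the chat signature details are assumptions based on the surrounding API):

```python
from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="PP-ChatOCRv3-doc")

# Collect visual info per page, then ask for the keys; text prompts run
# first, table prompts only for keys that remain unanswered.
visual_info_list = []
for res in pipeline.visual_predict("vehicle_certificate-1.png"):
    visual_info_list.append(res["visual_info"])

chat_res = pipeline.chat(
    key_list=["驾驶室准乘人数"],
    visual_info=visual_info_list,
)
print(chat_res["chat_res"])
```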