Jelajahi Sumber

fix input params and save_to_img for pp-chatocr (#2874)

* fix input params and save_to_img for layout parsing

* fix input params and save_to_img for pp-chatocr

* fix input params and save_to_img for pp-chatocr

* speed batch predict for ocr pipeline

---------

Co-authored-by: cuicheng01 <45199522+cuicheng01@users.noreply.github.com>
dyning 10 bulan lalu
induk
melakukan
5acb306b1f

+ 9 - 4
api_examples/pipelines/test_pp_chatocrv3.py

@@ -31,12 +31,17 @@ visual_predict_res = pipeline.visual_predict(
     use_table_recognition=True,
 )
 
-# ####[TODO] 增加类别信息
 visual_info_list = []
 for res in visual_predict_res:
-    # res['layout_parsing_result'].save_results("./output/")
-    # print(res["visual_info"])
     visual_info_list.append(res["visual_info"])
+    layout_parsing_result = res["layout_parsing_result"]
+    print(layout_parsing_result)
+    layout_parsing_result.print()
+    layout_parsing_result.save_to_img("./output")
+    layout_parsing_result.save_to_json("./output")
+    layout_parsing_result.save_to_xlsx("./output")
+    layout_parsing_result.save_to_html("./output")
+
 
 pipeline.save_visual_info_list(
     visual_info_list, "./res_visual_info/tmp_visual_info.json"
@@ -46,7 +51,7 @@ visual_info_list = pipeline.load_visual_info_list(
     "./res_visual_info/tmp_visual_info.json"
 )
 
-vector_info = pipeline.build_vector(visual_info_list)
+vector_info = pipeline.build_vector(visual_info_list, flag_save_bytes_vector=True)
 
 pipeline.save_vector(vector_info, "./res_visual_info/tmp_vector_info.json")
 

+ 8 - 4
api_examples/pipelines/test_pp_chatocrv4.py

@@ -51,12 +51,16 @@ visual_predict_res = pipeline.visual_predict(
     use_table_recognition=True,
 )
 
-# ####[TODO] 增加类别信息
 visual_info_list = []
 for res in visual_predict_res:
-    # res['layout_parsing_result'].save_results("./output/")
-    # print(res["visual_info"])
     visual_info_list.append(res["visual_info"])
+    layout_parsing_result = res["layout_parsing_result"]
+    print(layout_parsing_result)
+    layout_parsing_result.print()
+    layout_parsing_result.save_to_img("./output")
+    layout_parsing_result.save_to_json("./output")
+    layout_parsing_result.save_to_xlsx("./output")
+    layout_parsing_result.save_to_html("./output")
 
 pipeline.save_visual_info_list(
     visual_info_list, "./res_visual_info/tmp_visual_info.json"
@@ -66,7 +70,7 @@ visual_info_list = pipeline.load_visual_info_list(
     "./res_visual_info/tmp_visual_info.json"
 )
 
-vector_info = pipeline.build_vector(visual_info_list)
+vector_info = pipeline.build_vector(visual_info_list, flag_save_bytes_vector=True)
 
 pipeline.save_vector(vector_info, "./res_visual_info/tmp_vector_info.json")
 

+ 21 - 8
paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml

@@ -1,7 +1,7 @@
 
 pipeline_name: PP-ChatOCRv3-doc
 
-use_layout_parser: False
+use_layout_parser: True
 
 SubModules:
   LLM_Chat:
@@ -18,6 +18,7 @@ SubModules:
     ak: "api_key" # Set this to a real API key
     sk: "secret_key"  # Set this to a real secret key
 
+
   PromptEngneering:
     KIE_CommonText:
       module_name: prompt_engneering
@@ -58,18 +59,18 @@ SubModules:
 SubPipelines:
   LayoutParser:
     pipeline_name: layout_parsing
-    
+
     use_doc_preprocessor: True
     use_general_ocr: True
     use_seal_recognition: True
     use_table_recognition: True
+    use_formula_recognition: False
 
     SubModules:
       LayoutDetection:
         module_name: layout_detection
         model_name: RT-DETR-H_layout_3cls
         model_dir: null
-        batch_size: 1
 
     SubPipelines:
       DocPreprocessor:
@@ -81,27 +82,33 @@ SubPipelines:
             module_name: doc_text_orientation
             model_name: PP-LCNet_x1_0_doc_ori
             model_dir: null
-            batch_size: 1
           DocUnwarping:
             module_name: image_unwarping
             model_name: UVDoc
             model_dir: null
-            batch_size: 1
 
       GeneralOCR:
         pipeline_name: OCR
         text_type: general
+        use_doc_preprocessor: False
+        use_textline_orientation: False
         SubModules:
           TextDetection:
             module_name: text_detection
             model_name: PP-OCRv4_server_det
             model_dir: null
-            batch_size: 1    
+            limit_side_len: 960
+            limit_type: max
+            thresh: 0.3
+            box_thresh: 0.6
+            unclip_ratio: 2.0
+            
           TextRecognition:
             module_name: text_recognition
             model_name: PP-OCRv4_server_rec
             model_dir: null
             batch_size: 1
+            score_thresh: 0
 
       TableRecognition:
         pipeline_name: table_recognition
@@ -113,7 +120,6 @@ SubPipelines:
             module_name: table_structure_recognition
             model_name: SLANet_plus
             model_dir: null
-            batch_size: 1
 
       SealRecognition:
         pipeline_name: seal_recognition
@@ -123,14 +129,21 @@ SubPipelines:
           SealOCR:
             pipeline_name: OCR
             text_type: seal
+            use_doc_preprocessor: False
+            use_textline_orientation: False
             SubModules:
               TextDetection:
                 module_name: seal_text_detection
                 model_name: PP-OCRv4_server_seal_det
                 model_dir: null
-                batch_size: 1    
+                limit_side_len: 736
+                limit_type: min
+                thresh: 0.2
+                box_thresh: 0.6
+                unclip_ratio: 0.5
               TextRecognition:
                 module_name: text_recognition
                 model_name: PP-OCRv4_server_rec
                 model_dir: null
                 batch_size: 1
+                score_thresh: 0

+ 21 - 9
paddlex/configs/pipelines/PP-ChatOCRv4-doc.yaml

@@ -1,7 +1,7 @@
 
 pipeline_name: PP-ChatOCRv4-doc
 
-use_layout_parser: False
+use_layout_parser: True
 
 use_mllm_predict: True
 
@@ -11,7 +11,7 @@ SubModules:
     model_name: ernie-3.5
     api_type: qianfan
     ak: "api_key" # Set this to a real API key
-    sk: "secret_key"  # Set this to a real secret key     
+    sk: "secret_key"  # Set this to a real secret key
 
   LLM_Retriever:
     module_name: retriever
@@ -94,18 +94,18 @@ SubModules:
 SubPipelines:
   LayoutParser:
     pipeline_name: layout_parsing
-    
+
     use_doc_preprocessor: True
     use_general_ocr: True
     use_seal_recognition: True
     use_table_recognition: True
+    use_formula_recognition: False
 
     SubModules:
       LayoutDetection:
         module_name: layout_detection
         model_name: RT-DETR-H_layout_3cls
         model_dir: null
-        batch_size: 1
 
     SubPipelines:
       DocPreprocessor:
@@ -117,27 +117,33 @@ SubPipelines:
             module_name: doc_text_orientation
             model_name: PP-LCNet_x1_0_doc_ori
             model_dir: null
-            batch_size: 1
           DocUnwarping:
             module_name: image_unwarping
             model_name: UVDoc
             model_dir: null
-            batch_size: 1
 
       GeneralOCR:
         pipeline_name: OCR
         text_type: general
+        use_doc_preprocessor: False
+        use_textline_orientation: False
         SubModules:
           TextDetection:
             module_name: text_detection
             model_name: PP-OCRv4_server_det
             model_dir: null
-            batch_size: 1    
+            limit_side_len: 960
+            limit_type: max
+            thresh: 0.3
+            box_thresh: 0.6
+            unclip_ratio: 2.0
+            
           TextRecognition:
             module_name: text_recognition
             model_name: PP-OCRv4_server_rec
             model_dir: null
             batch_size: 1
+            score_thresh: 0
 
       TableRecognition:
         pipeline_name: table_recognition
@@ -149,7 +155,6 @@ SubPipelines:
             module_name: table_structure_recognition
             model_name: SLANet_plus
             model_dir: null
-            batch_size: 1
 
       SealRecognition:
         pipeline_name: seal_recognition
@@ -159,14 +164,21 @@ SubPipelines:
           SealOCR:
             pipeline_name: OCR
             text_type: seal
+            use_doc_preprocessor: False
+            use_textline_orientation: False
             SubModules:
               TextDetection:
                 module_name: seal_text_detection
                 model_name: PP-OCRv4_server_seal_det
                 model_dir: null
-                batch_size: 1    
+                limit_side_len: 736
+                limit_type: min
+                thresh: 0.2
+                box_thresh: 0.6
+                unclip_ratio: 0.5
               TextRecognition:
                 module_name: text_recognition
                 model_name: PP-OCRv4_server_rec
                 model_dir: null
                 batch_size: 1
+                score_thresh: 0

+ 8 - 1
paddlex/inference/pipelines_new/__init__.py

@@ -83,7 +83,7 @@ def load_pipeline_config(pipeline_name: str) -> Dict[str, Any]:
     Raises:
         Exception: If the config file of pipeline does not exist.
     """
-    if not Path(pipeline_name).exists():
+    if not (pipeline_name.endswith(".yml") or pipeline_name.endswith(".yaml")):
         pipeline_path = get_pipeline_path(pipeline_name)
         if pipeline_path is None:
             raise Exception(
@@ -150,6 +150,9 @@ def create_chat_bot(config: Dict, *args, **kwargs) -> BaseChat:
     Returns:
         BaseChat: An instance of the chat bot class corresponding to the 'model_name' in the config.
     """
+    if "chat_bot_config_error" in config:
+        raise ValueError(config["chat_bot_config_error"])
+
     api_type = config["api_type"]
     chat_bot = BaseChat.get(api_type)(config)
     return chat_bot
@@ -171,6 +174,8 @@ def create_retriever(
     Returns:
         BaseRetriever: An instance of a retriever class corresponding to the 'model_name' in the config.
     """
+    if "retriever_config_error" in config:
+        raise ValueError(config["retriever_config_error"])
     api_type = config["api_type"]
     retriever = BaseRetriever.get(api_type)(config)
     return retriever
@@ -192,6 +197,8 @@ def create_prompt_engeering(
     Returns:
         BaseGeneratePrompt: An instance of a prompt engineering class corresponding to the 'task_type' in the config.
     """
+    if "pe_config_error" in config:
+        raise ValueError(config["pe_config_error"])
     task_type = config["task_type"]
     pe = BaseGeneratePrompt.get(task_type)(config)
     return pe

+ 0 - 1
paddlex/inference/pipelines_new/base.py

@@ -89,7 +89,6 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
             device=self.device,
             pp_option=self.pp_option,
             use_hpip=self.use_hpip,
-            hpi_params=hpi_params,
             **kwargs,
         )
         return model

+ 1 - 1
paddlex/inference/pipelines_new/doc_preprocessor/pipeline.py

@@ -165,7 +165,7 @@ class DocPreprocessorPipeline(BasePipeline):
         for img_id, batch_data in enumerate(self.batch_sampler(input)):
             if not isinstance(batch_data[0], str):
                 # TODO: add support for input_path with ndarray and pdf inputs
-                input_path = f"{img_id}"
+                input_path = f"{img_id}.jpg"
             else:
                 input_path = batch_data[0]
 

+ 1 - 1
paddlex/inference/pipelines_new/layout_parsing/pipeline.py

@@ -295,7 +295,7 @@ class LayoutParsingPipeline(BasePipeline):
         for img_id, batch_data in enumerate(self.batch_sampler(input)):
             if not isinstance(batch_data[0], str):
                 # TODO: add support for input_path with ndarray and pdf inputs
-                input_path = f"{img_id}"
+                input_path = f"{img_id}.jpg"
             else:
                 input_path = batch_data[0]
 

+ 19 - 5
paddlex/inference/pipelines_new/ocr/pipeline.py

@@ -315,7 +315,7 @@ class OCRPipeline(BasePipeline):
         for img_id, batch_data in enumerate(self.batch_sampler(input)):
             if not isinstance(batch_data[0], str):
                 # TODO: add support for input_path with ndarray and pdf inputs
-                input_path = f"{img_id}"
+                input_path = f"{img_id}.jpg"
             else:
                 input_path = batch_data[0]
 
@@ -373,9 +373,24 @@ class OCRPipeline(BasePipeline):
                     angles = [-1] * len(all_subs_of_img)
                 single_img_res["textline_orientation_angles"] = angles
 
-                rno = -1
-                for rec_res in self.text_rec_model(all_subs_of_img):
-                    rno += 1
+                sub_img_info_list = [
+                    {
+                        "sub_img_id": img_id,
+                        "sub_img_ratio": sub_img.shape[1] / float(sub_img.shape[0]),
+                    }
+                    for img_id, sub_img in enumerate(all_subs_of_img)
+                ]
+                sorted_subs_info = sorted(
+                    sub_img_info_list, key=lambda x: x["sub_img_ratio"]
+                )
+                sorted_subs_of_img = [
+                    all_subs_of_img[x["sub_img_id"]] for x in sorted_subs_info
+                ]
+                for idx, rec_res in enumerate(self.text_rec_model(sorted_subs_of_img)):
+                    sub_img_id = sorted_subs_info[idx]["sub_img_id"]
+                    sub_img_info_list[sub_img_id]["rec_res"] = rec_res
+                for sno in range(len(sub_img_info_list)):
+                    rec_res = sub_img_info_list[sno]["rec_res"]
                     if rec_res["rec_score"] >= text_rec_score_thresh:
                         single_img_res["rec_texts"].append(rec_res["rec_text"])
                         single_img_res["rec_scores"].append(rec_res["rec_score"])
@@ -385,5 +400,4 @@ class OCRPipeline(BasePipeline):
                 single_img_res["rec_boxes"] = rec_boxes
             else:
                 single_img_res["rec_boxes"] = np.array([])
-                
             yield OCRResult(single_img_res)

+ 92 - 53
paddlex/inference/pipelines_new/pp_chatocr/pipeline_v3.py

@@ -18,7 +18,6 @@ import json
 import numpy as np
 import copy
 from .pipeline_base import PP_ChatOCR_Pipeline
-from .result import VisualInfoResult
 from ...common.reader import ReadImage
 from ...common.batch_sampler import ImageBatchSampler
 from ....utils import logging
@@ -70,36 +69,54 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
             None
         """
 
-        self.use_layout_parser = True
-        if "use_layout_parser" in config:
-            self.use_layout_parser = config["use_layout_parser"]
-
+        self.use_layout_parser = config.get("use_layout_parser", True)
         if self.use_layout_parser:
-            layout_parsing_config = config["SubPipelines"]["LayoutParser"]
+            layout_parsing_config = config.get("SubPipelines", {}).get(
+                "LayoutParser",
+                {"pipeline_config_error": "config error for layout_parsing_pipeline!"},
+            )
             self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
 
         from .. import create_chat_bot
 
-        chat_bot_config = config["SubModules"]["LLM_Chat"]
+        chat_bot_config = config.get("SubModules", {}).get(
+            "LLM_Chat",
+            {"chat_bot_config_error": "config error for llm chat bot!"},
+        )
         self.chat_bot = create_chat_bot(chat_bot_config)
 
         from .. import create_retriever
 
-        retriever_config = config["SubModules"]["LLM_Retriever"]
+        retriever_config = config.get("SubModules", {}).get(
+            "LLM_Retriever",
+            {"retriever_config_error": "config error for llm retriever!"},
+        )
         self.retriever = create_retriever(retriever_config)
 
         from .. import create_prompt_engeering
 
-        text_pe_config = config["SubModules"]["PromptEngneering"]["KIE_CommonText"]
+        text_pe_config = (
+            config.get("SubModules", {})
+            .get("PromptEngneering", {})
+            .get(
+                "KIE_CommonText",
+                {"pe_config_error": "config error for text_pe!"},
+            )
+        )
         self.text_pe = create_prompt_engeering(text_pe_config)
 
-        table_pe_config = config["SubModules"]["PromptEngneering"]["KIE_Table"]
+        table_pe_config = (
+            config.get("SubModules", {})
+            .get("PromptEngneering", {})
+            .get(
+                "KIE_Table",
+                {"pe_config_error": "config error for table_pe!"},
+            )
+        )
         self.table_pe = create_prompt_engeering(table_pe_config)
         return
 
-    def decode_visual_result(
-        self, layout_parsing_result: LayoutParsingResult
-    ) -> VisualInfoResult:
+    def decode_visual_result(self, layout_parsing_result: LayoutParsingResult) -> dict:
         """
         Decodes the visual result from the layout parsing result.
 
@@ -107,21 +124,21 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
             layout_parsing_result (LayoutParsingResult): The result of layout parsing.
 
         Returns:
-            VisualInfoResult: The decoded visual information.
+            dict: The decoded visual information.
         """
         text_paragraphs_ocr_res = layout_parsing_result["text_paragraphs_ocr_res"]
         seal_res_list = layout_parsing_result["seal_res_list"]
         normal_text_dict = {}
 
         for seal_res in seal_res_list:
-            for text in seal_res["rec_text"]:
+            for text in seal_res["rec_texts"]:
                 layout_type = "印章"
                 if layout_type not in normal_text_dict:
                     normal_text_dict[layout_type] = f"{text}"
                 else:
                     normal_text_dict[layout_type] += f"\n {text}"
 
-        for text in text_paragraphs_ocr_res["rec_text"]:
+        for text in text_paragraphs_ocr_res["rec_texts"]:
             layout_type = "words in text block"
             if layout_type not in normal_text_dict:
                 normal_text_dict[layout_type] = text
@@ -133,24 +150,36 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
         table_html_list = []
         for table_res in table_res_list:
             table_html_list.append(table_res["pred_html"])
-            single_table_text = " ".join(table_res["table_ocr_pred"]["rec_text"])
+            single_table_text = " ".join(table_res["table_ocr_pred"]["rec_texts"])
             table_text_list.append(single_table_text)
 
         visual_info = {}
         visual_info["normal_text_dict"] = normal_text_dict
         visual_info["table_text_list"] = table_text_list
         visual_info["table_html_list"] = table_html_list
-        return VisualInfoResult(visual_info)
+        return visual_info
 
     # Function to perform visual prediction on input images
     def visual_predict(
         self,
         input: str | list[str] | np.ndarray | list[np.ndarray],
-        use_doc_orientation_classify: bool = False,  # Whether to use document orientation classification
-        use_doc_unwarping: bool = False,  # Whether to use document unwarping
-        use_general_ocr: bool = True,  # Whether to use general OCR
-        use_seal_recognition: bool = True,  # Whether to use seal recognition
-        use_table_recognition: bool = True,  # Whether to use table recognition
+        use_doc_orientation_classify: Optional[bool] = None,
+        use_doc_unwarping: Optional[bool] = None,
+        use_general_ocr: Optional[bool] = None,
+        use_seal_recognition: Optional[bool] = None,
+        use_table_recognition: Optional[bool] = None,
+        text_det_limit_side_len: Optional[int] = None,
+        text_det_limit_type: Optional[str] = None,
+        text_det_thresh: Optional[float] = None,
+        text_det_box_thresh: Optional[float] = None,
+        text_det_unclip_ratio: Optional[float] = None,
+        text_rec_score_thresh: Optional[float] = None,
+        seal_det_limit_side_len: Optional[int] = None,
+        seal_det_limit_type: Optional[str] = None,
+        seal_det_thresh: Optional[float] = None,
+        seal_det_box_thresh: Optional[float] = None,
+        seal_det_unclip_ratio: Optional[float] = None,
+        seal_rec_score_thresh: Optional[float] = None,
         **kwargs,
     ) -> dict:
         """
@@ -174,7 +203,7 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
 
         if self.use_layout_parser == False:
             logging.error("The models for layout parser are not initialized.")
-            yield None
+            yield {"error": "The models for layout parser are not initialized."}
 
         for layout_parsing_result in self.layout_parsing_pipeline.predict(
             input,
@@ -183,6 +212,18 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
             use_general_ocr=use_general_ocr,
             use_seal_recognition=use_seal_recognition,
             use_table_recognition=use_table_recognition,
+            text_det_limit_side_len=text_det_limit_side_len,
+            text_det_limit_type=text_det_limit_type,
+            text_det_thresh=text_det_thresh,
+            text_det_box_thresh=text_det_box_thresh,
+            text_det_unclip_ratio=text_det_unclip_ratio,
+            text_rec_score_thresh=text_rec_score_thresh,
+            seal_det_box_thresh=seal_det_box_thresh,
+            seal_det_limit_side_len=seal_det_limit_side_len,
+            seal_det_limit_type=seal_det_limit_type,
+            seal_det_thresh=seal_det_thresh,
+            seal_det_unclip_ratio=seal_det_unclip_ratio,
+            seal_rec_score_thresh=seal_rec_score_thresh,
         ):
 
             visual_info = self.decode_visual_result(layout_parsing_result)
@@ -193,14 +234,12 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
             }
             yield visual_predict_res
 
-    def save_visual_info_list(
-        self, visual_info: VisualInfoResult, save_path: str
-    ) -> None:
+    def save_visual_info_list(self, visual_info: dict, save_path: str) -> None:
         """
         Save the visual info list to the specified file path.
 
         Args:
-            visual_info (VisualInfoResult): The visual info result, which can be a single object or a list of objects.
+            visual_info (dict): The visual info result, which can be a single object or a list of objects.
             save_path (str): The file path to save the visual info list.
 
         Returns:
@@ -215,7 +254,7 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
             fout.write(json.dumps(visual_info_list, ensure_ascii=False) + "\n")
         return
 
-    def load_visual_info_list(self, data_path: str) -> list[VisualInfoResult]:
+    def load_visual_info_list(self, data_path: str) -> list[dict]:
         """
         Loads visual info list from a JSON file.
 
@@ -223,7 +262,7 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
             data_path (str): The path to the JSON file containing visual info.
 
         Returns:
-            list[VisualInfoResult]: A list of VisualInfoResult objects parsed from the JSON file.
+            list[dict]: A list of dict objects parsed from the JSON file.
         """
         with open(data_path, "r") as fin:
             data = fin.readline()
@@ -231,13 +270,13 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
         return visual_info_list
 
     def merge_visual_info_list(
-        self, visual_info_list: list[VisualInfoResult]
+        self, visual_info_list: list[dict]
     ) -> tuple[list, list, list]:
         """
         Merge visual info lists.
 
         Args:
-            visual_info_list (list[VisualInfoResult]): A list of visual info results.
+            visual_info_list (list[dict]): A list of visual info results.
 
         Returns:
             tuple[list, list, list]: A tuple containing three lists, one for normal text dicts,
@@ -259,17 +298,19 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
 
     def build_vector(
         self,
-        visual_info: VisualInfoResult,
+        visual_info: dict,
         min_characters: int = 3500,
         llm_request_interval: float = 1.0,
+        flag_save_bytes_vector: bool = False,
     ) -> dict:
         """
         Build a vector representation from visual information.
 
         Args:
-            visual_info (VisualInfoResult): The visual information input, can be a single instance or a list of instances.
+            visual_info (dict): The visual information input, can be a single instance or a list of instances.
             min_characters (int): The minimum number of characters required for text processing, defaults to 3500.
             llm_request_interval (float): The interval between LLM requests, defaults to 1.0.
+            flag_save_bytes_vector (bool): Whether to save the vector as bytes, defaults to False.
 
         Returns:
             dict: A dictionary containing the vector info and a flag indicating if the text is too short.
@@ -300,30 +341,23 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
                 all_items += [f"table:{table_text}"]
 
         all_text_str = "".join(all_items)
-
+        vector_info["flag_save_bytes_vector"] = False
         if len(all_text_str) > min_characters:
             vector_info["flag_too_short_text"] = False
             vector_info["vector"] = self.retriever.generate_vector_database(all_items)
+            if flag_save_bytes_vector:
+                vector_info["vector"] = self.retriever.encode_vector_store_to_bytes(
+                    vector_info["vector"]
+                )
+                vector_info["flag_save_bytes_vector"] = True
         else:
             vector_info["flag_too_short_text"] = True
             vector_info["vector"] = all_items
         return vector_info
 
     def save_vector(self, vector_info: dict, save_path: str) -> None:
-        if "flag_too_short_text" not in vector_info or "vector" not in vector_info:
-            logging.error("Invalid vector info.")
-            return
-        save_vector_info = {}
-        save_vector_info["flag_too_short_text"] = vector_info["flag_too_short_text"]
-        if not vector_info["flag_too_short_text"]:
-            save_vector_info["vector"] = self.retriever.encode_vector_store_to_bytes(
-                vector_info["vector"]
-            )
-        else:
-            save_vector_info["vector"] = vector_info["vector"]
-
         with open(save_path, "w") as fout:
-            fout.write(json.dumps(save_vector_info, ensure_ascii=False) + "\n")
+            fout.write(json.dumps(vector_info, ensure_ascii=False) + "\n")
         return
 
     def load_vector(self, data_path: str) -> dict:
@@ -331,10 +365,15 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
         with open(data_path, "r") as fin:
             data = fin.readline()
             vector_info = json.loads(data)
-            if "flag_too_short_text" not in vector_info or "vector" not in vector_info:
+            if (
+                "flag_too_short_text" not in vector_info
+                or "flag_save_bytes_vector" not in vector_info
+                or "vector" not in vector_info
+            ):
                 logging.error("Invalid vector info.")
-                return {}
-            if not vector_info["flag_too_short_text"]:
+                return {"error": "Invalid vector info when load vector!"}
+
+            if vector_info["flag_save_bytes_vector"]:
                 vector_info["vector"] = self.retriever.decode_vector_store_from_bytes(
                     vector_info["vector"]
                 )
@@ -444,7 +483,7 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
     def chat(
         self,
         key_list: str | list[str],
-        visual_info: VisualInfoResult,
+        visual_info: list[dict],
         use_vector_retrieval: bool = True,
         vector_info: dict = None,
         min_characters: int = 3500,
@@ -464,7 +503,7 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
 
         Args:
             key_list (str | list[str]): A single key or a list of keys to extract information.
-            visual_info (VisualInfoResult): The visual information result.
+            visual_info (list[dict]): The visual information results, one dict per input.
             use_vector_retrieval (bool): Whether to use vector retrieval.
             vector_info (dict): The vector information for retrieval.
             min_characters (int): The minimum number of characters required.

+ 106 - 59
paddlex/inference/pipelines_new/pp_chatocr/pipeline_v4.py

@@ -20,7 +20,6 @@ import base64
 import numpy as np
 import copy
 from .pipeline_base import PP_ChatOCR_Pipeline
-from .result import VisualInfoResult
 from ...common.reader import ReadImage
 from ...common.batch_sampler import ImageBatchSampler
 from ....utils import logging
@@ -72,45 +71,71 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
             None
         """
 
-        self.use_layout_parser = True
-        if "use_layout_parser" in config:
-            self.use_layout_parser = config["use_layout_parser"]
-
+        self.use_layout_parser = config.get("use_layout_parser", True)
         if self.use_layout_parser:
-            layout_parsing_config = config["SubPipelines"]["LayoutParser"]
+            layout_parsing_config = config.get("SubPipelines", {}).get(
+                "LayoutParser",
+                {"pipeline_config_error": "config error for layout_parsing_pipeline!"},
+            )
             self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
 
         from .. import create_chat_bot
 
-        chat_bot_config = config["SubModules"]["LLM_Chat"]
+        chat_bot_config = config.get("SubModules", {}).get(
+            "LLM_Chat",
+            {"chat_bot_config_error": "config error for llm chat bot!"},
+        )
         self.chat_bot = create_chat_bot(chat_bot_config)
 
         from .. import create_retriever
 
-        retriever_config = config["SubModules"]["LLM_Retriever"]
+        retriever_config = config.get("SubModules", {}).get(
+            "LLM_Retriever",
+            {"retriever_config_error": "config error for llm retriever!"},
+        )
         self.retriever = create_retriever(retriever_config)
 
         from .. import create_prompt_engeering
 
-        text_pe_config = config["SubModules"]["PromptEngneering"]["KIE_CommonText"]
+        text_pe_config = (
+            config.get("SubModules", {})
+            .get("PromptEngneering", {})
+            .get(
+                "KIE_CommonText",
+                {"pe_config_error": "config error for text_pe!"},
+            )
+        )
         self.text_pe = create_prompt_engeering(text_pe_config)
 
-        table_pe_config = config["SubModules"]["PromptEngneering"]["KIE_Table"]
+        table_pe_config = (
+            config.get("SubModules", {})
+            .get("PromptEngneering", {})
+            .get(
+                "KIE_Table",
+                {"pe_config_error": "config error for table_pe!"},
+            )
+        )
         self.table_pe = create_prompt_engeering(table_pe_config)
 
-        self.use_mllm_predict = False
-        if "use_mllm_predict" in config:
-            self.use_mllm_predict = config["use_mllm_predict"]
+        self.use_mllm_predict = config.get("use_mllm_predict", True)
         if self.use_mllm_predict:
-            mllm_chat_bot_config = config["SubModules"]["MLLM_Chat"]
+            mllm_chat_bot_config = config.get("SubModules", {}).get(
+                "MLLM_Chat",
+                {"mllm_chat_bot_config": "config error for mllm chat bot!"},
+            )
             self.mllm_chat_bot = create_chat_bot(mllm_chat_bot_config)
-            ensemble_pe_config = config["SubModules"]["PromptEngneering"]["Ensemble"]
+            ensemble_pe_config = (
+                config.get("SubModules", {})
+                .get("PromptEngneering", {})
+                .get(
+                    "Ensemble",
+                    {"pe_config_error": "config error for ensemble_pe!"},
+                )
+            )
             self.ensemble_pe = create_prompt_engeering(ensemble_pe_config)
         return
 
-    def decode_visual_result(
-        self, layout_parsing_result: LayoutParsingResult
-    ) -> VisualInfoResult:
+    def decode_visual_result(self, layout_parsing_result: LayoutParsingResult) -> dict:
         """
         Decodes the visual result from the layout parsing result.
 
@@ -118,21 +143,21 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
             layout_parsing_result (LayoutParsingResult): The result of layout parsing.
 
         Returns:
-            VisualInfoResult: The decoded visual information.
+            dict: The decoded visual information.
         """
         text_paragraphs_ocr_res = layout_parsing_result["text_paragraphs_ocr_res"]
         seal_res_list = layout_parsing_result["seal_res_list"]
         normal_text_dict = {}
 
         for seal_res in seal_res_list:
-            for text in seal_res["rec_text"]:
+            for text in seal_res["rec_texts"]:
                 layout_type = "印章"
                 if layout_type not in normal_text_dict:
                     normal_text_dict[layout_type] = f"{text}"
                 else:
                     normal_text_dict[layout_type] += f"\n {text}"
 
-        for text in text_paragraphs_ocr_res["rec_text"]:
+        for text in text_paragraphs_ocr_res["rec_texts"]:
             layout_type = "words in text block"
             if layout_type not in normal_text_dict:
                 normal_text_dict[layout_type] = text
@@ -145,26 +170,38 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         table_nei_text_list = []
         for table_res in table_res_list:
             table_html_list.append(table_res["pred_html"])
-            single_table_text = " ".join(table_res["table_ocr_pred"]["rec_text"])
+            single_table_text = " ".join(table_res["table_ocr_pred"]["rec_texts"])
             table_text_list.append(single_table_text)
-            table_nei_text_list.append(table_res["neighbor_text"])
+            table_nei_text_list.append(table_res["neighbor_texts"])
 
         visual_info = {}
         visual_info["normal_text_dict"] = normal_text_dict
         visual_info["table_text_list"] = table_text_list
         visual_info["table_html_list"] = table_html_list
         visual_info["table_nei_text_list"] = table_nei_text_list
-        return VisualInfoResult(visual_info)
+        return visual_info
 
     # Function to perform visual prediction on input images
     def visual_predict(
         self,
         input: str | list[str] | np.ndarray | list[np.ndarray],
-        use_doc_orientation_classify: bool = False,  # Whether to use document orientation classification
-        use_doc_unwarping: bool = False,  # Whether to use document unwarping
-        use_general_ocr: bool = True,  # Whether to use general OCR
-        use_seal_recognition: bool = True,  # Whether to use seal recognition
-        use_table_recognition: bool = True,  # Whether to use table recognition
+        use_doc_orientation_classify: Optional[bool] = None,
+        use_doc_unwarping: Optional[bool] = None,
+        use_general_ocr: Optional[bool] = None,
+        use_seal_recognition: Optional[bool] = None,
+        use_table_recognition: Optional[bool] = None,
+        text_det_limit_side_len: Optional[int] = None,
+        text_det_limit_type: Optional[str] = None,
+        text_det_thresh: Optional[float] = None,
+        text_det_box_thresh: Optional[float] = None,
+        text_det_unclip_ratio: Optional[float] = None,
+        text_rec_score_thresh: Optional[float] = None,
+        seal_det_limit_side_len: Optional[int] = None,
+        seal_det_limit_type: Optional[str] = None,
+        seal_det_thresh: Optional[float] = None,
+        seal_det_box_thresh: Optional[float] = None,
+        seal_det_unclip_ratio: Optional[float] = None,
+        seal_rec_score_thresh: Optional[float] = None,
         **kwargs,
     ) -> dict:
         """
@@ -187,7 +224,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         """
         if self.use_layout_parser == False:
             logging.error("The models for layout parser are not initialized.")
-            yield None
+            yield {"error": "The models for layout parser are not initialized."}
 
         for layout_parsing_result in self.layout_parsing_pipeline.predict(
             input,
@@ -196,6 +233,18 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
             use_general_ocr=use_general_ocr,
             use_seal_recognition=use_seal_recognition,
             use_table_recognition=use_table_recognition,
+            text_det_limit_side_len=text_det_limit_side_len,
+            text_det_limit_type=text_det_limit_type,
+            text_det_thresh=text_det_thresh,
+            text_det_box_thresh=text_det_box_thresh,
+            text_det_unclip_ratio=text_det_unclip_ratio,
+            text_rec_score_thresh=text_rec_score_thresh,
+            seal_det_box_thresh=seal_det_box_thresh,
+            seal_det_limit_side_len=seal_det_limit_side_len,
+            seal_det_limit_type=seal_det_limit_type,
+            seal_det_thresh=seal_det_thresh,
+            seal_det_unclip_ratio=seal_det_unclip_ratio,
+            seal_rec_score_thresh=seal_rec_score_thresh,
         ):
 
             visual_info = self.decode_visual_result(layout_parsing_result)
@@ -206,14 +255,12 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
             }
             yield visual_predict_res
 
-    def save_visual_info_list(
-        self, visual_info: VisualInfoResult, save_path: str
-    ) -> None:
+    def save_visual_info_list(self, visual_info: dict, save_path: str) -> None:
         """
         Save the visual info list to the specified file path.
 
         Args:
-            visual_info (VisualInfoResult): The visual info result, which can be a single object or a list of objects.
+            visual_info (dict): The visual info result, which can be a single object or a list of objects.
             save_path (str): The file path to save the visual info list.
 
         Returns:
@@ -228,7 +275,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
             fout.write(json.dumps(visual_info_list, ensure_ascii=False) + "\n")
         return
 
-    def load_visual_info_list(self, data_path: str) -> list[VisualInfoResult]:
+    def load_visual_info_list(self, data_path: str) -> list[dict]:
         """
         Loads visual info list from a JSON file.
 
@@ -236,7 +283,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
             data_path (str): The path to the JSON file containing visual info.
 
         Returns:
-            list[VisualInfoResult]: A list of VisualInfoResult objects parsed from the JSON file.
+            list[dict]: A list of dict objects parsed from the JSON file.
         """
         with open(data_path, "r") as fin:
             data = fin.readline()
@@ -244,13 +291,13 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         return visual_info_list
 
     def merge_visual_info_list(
-        self, visual_info_list: list[VisualInfoResult]
+        self, visual_info_list: list[dict]
     ) -> tuple[list, list, list, list]:
         """
         Merge visual info lists.
 
         Args:
-            visual_info_list (list[VisualInfoResult]): A list of visual info results.
+            visual_info_list (list[dict]): A list of visual info results.
 
         Returns:
             tuple[list, list, list, list]: A tuple containing four lists, one for normal text dicts,
@@ -281,17 +328,19 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
 
     def build_vector(
         self,
-        visual_info: VisualInfoResult,
+        visual_info: dict,
         min_characters: int = 3500,
         llm_request_interval: float = 1.0,
+        flag_save_bytes_vector: bool = False,
     ) -> dict:
         """
         Build a vector representation from visual information.
 
         Args:
-            visual_info (VisualInfoResult): The visual information input, can be a single instance or a list of instances.
+            visual_info (dict): The visual information input, can be a single instance or a list of instances.
             min_characters (int): The minimum number of characters required for text processing, defaults to 3500.
             llm_request_interval (float): The interval between LLM requests, defaults to 1.0.
+            flag_save_bytes_vector (bool): Whether to save the vector as bytes, defaults to False.
 
         Returns:
             dict: A dictionary containing the vector info and a flag indicating if the text is too short.
@@ -324,30 +373,23 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
                 all_items += [f"table:{table_text}\t{table_nei_text}"]
 
         all_text_str = "".join(all_items)
-
+        vector_info["flag_save_bytes_vector"] = False
         if len(all_text_str) > min_characters:
             vector_info["flag_too_short_text"] = False
             vector_info["vector"] = self.retriever.generate_vector_database(all_items)
+            if flag_save_bytes_vector:
+                vector_info["vector"] = self.retriever.encode_vector_store_to_bytes(
+                    vector_info["vector"]
+                )
+                vector_info["flag_save_bytes_vector"] = True
         else:
             vector_info["flag_too_short_text"] = True
             vector_info["vector"] = all_items
         return vector_info
 
     def save_vector(self, vector_info: dict, save_path: str) -> None:
-        if "flag_too_short_text" not in vector_info or "vector" not in vector_info:
-            logging.error("Invalid vector info.")
-            return
-        save_vector_info = {}
-        save_vector_info["flag_too_short_text"] = vector_info["flag_too_short_text"]
-        if not vector_info["flag_too_short_text"]:
-            save_vector_info["vector"] = self.retriever.encode_vector_store_to_bytes(
-                vector_info["vector"]
-            )
-        else:
-            save_vector_info["vector"] = vector_info["vector"]
-
         with open(save_path, "w") as fout:
-            fout.write(json.dumps(save_vector_info, ensure_ascii=False) + "\n")
+            fout.write(json.dumps(vector_info, ensure_ascii=False) + "\n")
         return
 
     def load_vector(self, data_path: str) -> dict:
@@ -355,10 +397,15 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         with open(data_path, "r") as fin:
             data = fin.readline()
             vector_info = json.loads(data)
-            if "flag_too_short_text" not in vector_info or "vector" not in vector_info:
+            if (
+                "flag_too_short_text" not in vector_info
+                or "flag_save_bytes_vector" not in vector_info
+                or "vector" not in vector_info
+            ):
                 logging.error("Invalid vector info.")
-                return
-            if not vector_info["flag_too_short_text"]:
+                return {"error": "Invalid vector info when load vector!"}
+
+            if vector_info["flag_save_bytes_vector"]:
                 vector_info["vector"] = self.retriever.decode_vector_store_from_bytes(
                     vector_info["vector"]
                 )
@@ -558,7 +605,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
     def chat(
         self,
         key_list: str | list[str],
-        visual_info: VisualInfoResult,
+        visual_info: dict,
         use_vector_retrieval: bool = True,
         vector_info: dict = None,
         min_characters: int = 3500,
@@ -580,7 +627,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
 
         Args:
             key_list (str | list[str]): A single key or a list of keys to extract information.
-            visual_info (VisualInfoResult): The visual information result.
+            visual_info (dict): The visual information result.
             use_vector_retrieval (bool): Whether to use vector retrieval.
             vector_info (dict): The vector information for retrieval.
             min_characters (int): The minimum number of characters required for text processing, defaults to 3500.

+ 0 - 28
paddlex/inference/pipelines_new/pp_chatocr/result.py

@@ -1,28 +0,0 @@
-# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import math
-import random
-import numpy as np
-import cv2
-import PIL
-from PIL import Image, ImageDraw, ImageFont
-from ....utils.fonts import PINGFANG_FONT_FILE_PATH
-from ...common.result import BaseResult
-
-
-class VisualInfoResult(BaseResult):
-    """VisualInfoResult"""
-
-    pass

+ 1 - 1
paddlex/inference/pipelines_new/seal_recognition/pipeline.py

@@ -175,7 +175,7 @@ class SealRecognitionPipeline(BasePipeline):
         for img_id, batch_data in enumerate(self.batch_sampler(input)):
             if not isinstance(batch_data[0], str):
                 # TODO: add support input_pth for ndarray and pdf
-                input_path = f"{img_id}"
+                input_path = f"{img_id}.jpg"
             else:
                 input_path = batch_data[0]
 

+ 1 - 1
paddlex/inference/pipelines_new/table_recognition/pipeline.py

@@ -249,7 +249,7 @@ class TableRecognitionPipeline(BasePipeline):
             if len(match_idx_list) > 0:
                 for idx in match_idx_list:
                     neighbor_text += overall_ocr_res["rec_texts"][idx] + "; "
-        single_table_recognition_res["neighbor_text"] = neighbor_text
+        single_table_recognition_res["neighbor_texts"] = neighbor_text
         return single_table_recognition_res
 
     def predict(