zhouchangda 1 жил өмнө
parent
commit
0e8b6cec7d

+ 18 - 11
paddlex/inference/pipelines/ppchatocrv3/ppchatocrv3.py

@@ -114,6 +114,7 @@ class PPChatOCRPipeline(TableRecPipeline):
         self.img_reader = ReadImage()
         self.visual_info = None
         self.vector = None
+        self.visual_flag = False
 
     def _build_predictor(self):
         super()._build_predictor()
@@ -197,6 +198,7 @@ class PPChatOCRPipeline(TableRecPipeline):
         visual_info = VisualInfoResult(visual_info)
         # for local user save visual info in self
         self.visual_info = visual_info
+        self.visual_flag = True
 
         return visual_result, visual_info
 
@@ -233,14 +235,13 @@ class PPChatOCRPipeline(TableRecPipeline):
         ):
             single_img_res = {
                 "input_path": "",
-                "layout_result": {},
-                "ocr_result": {},
+                "layout_result": DetResult({}),
+                "ocr_result": OCRResult({}),
                 "table_ocr_result": [],
-                "table_result": [],
+                "table_result": StructureTableResult([]),
                 "structure_result": [],
-                "structure_result": [],
-                "oricls_result": {},
-                "uvdoc_result": {},
+                "oricls_result": TopkResult({}),
+                "uvdoc_result": DocTrResult({}),
                 "curve_result": [],
             }
             # update oricls and uvdoc result
@@ -389,7 +390,7 @@ class PPChatOCRPipeline(TableRecPipeline):
         llm_name=None,
         llm_params={},
         visual_info=None,
-        min_characters=0,
+        min_characters=3500,
         llm_request_interval=1.0,
     ):
         """get vector for ocr"""
@@ -429,6 +430,7 @@ class PPChatOCRPipeline(TableRecPipeline):
                 text_result = self.llm_api.get_vector(ocr_text, llm_request_interval)
         else:
             text_result = str(ocr_text)
+        self.visual_flag = False
 
         return VectorResult({"vector": text_result})
 
@@ -447,11 +449,16 @@ class PPChatOCRPipeline(TableRecPipeline):
 
         key_list = format_key(key_list)
 
+        is_seving = visual_info and llm_name
+
+        if self.visual_flag and not is_seving:
+            self.vector = self.get_vector_text()
+
         if not any([vector, self.vector]):
             logging.warning(
                 "The vector library is not created, and is being created automatically"
             )
-            if visual_info and llm_name:
+            if is_seving:
                 # for serving
                 vector = self.get_vector_text(
                     llm_name=llm_name, llm_params=llm_params, visual_info=visual_info
@@ -543,7 +550,7 @@ class PPChatOCRPipeline(TableRecPipeline):
         if len(key_list) > 0:
             logging.info("get result from ocr")
             if retrieval_result:
-                ocr_text = retrieval_result
+                ocr_text = retrieval_result.get("retrieval")
             elif use_vector and any([visual_info, vector]):
                 # for serving or local
                 ocr_text = self.get_retrieval_text(
@@ -552,10 +559,10 @@ class PPChatOCRPipeline(TableRecPipeline):
                     vector=vector,
                     llm_name=llm_name,
                     llm_params=llm_params,
-                )
+                )["retrieval"]
             else:
                 # for local
-                ocr_text = self.get_retrieval_text(key_list=key_list)
+                ocr_text = self.get_retrieval_text(key_list=key_list)["retrieval"]
             prompt = self.get_prompt_for_ocr(
                 ocr_text,
                 key_list,

+ 10 - 5
paddlex/inference/results/chat_ocr.py

@@ -59,20 +59,25 @@ class VisualResult(BaseResult):
 
         oricls_save_path = f"{save_path}_oricls.jpg"
         oricls_result = self["oricls_result"]
-        oricls_result.save_to_img(oricls_save_path)
+        if oricls_result:
+            oricls_result.save_to_img(oricls_save_path)
         uvdoc_save_path = f"{save_path}_uvdoc.jpg"
         uvdoc_result = self["uvdoc_result"]
-        uvdoc_result.save_to_img(uvdoc_save_path)
+        if uvdoc_result:
+            uvdoc_result.save_to_img(uvdoc_save_path)
         curve_save_path = f"{save_path}_curve.jpg"
-        for curve_result in self["curve_result"]:
+        curve_results = self["curve_result"]
+        for curve_result in curve_results:
             curve_result.save_to_img(curve_save_path)
         layout_save_path = f"{save_path}_layout.jpg"
         layout_result = self["layout_result"]
-        layout_result.save_to_img(layout_save_path)
+        if layout_result:
+            layout_result.save_to_img(layout_save_path)
         ocr_save_path = f"{save_path}_ocr.jpg"
         table_save_path = f"{save_path}_table.jpg"
         ocr_result = self["ocr_result"]
-        ocr_result.save_to_img(ocr_save_path)
+        if ocr_result:
+            ocr_result.save_to_img(ocr_save_path)
         for table_result in self["table_result"]:
             table_result.save_to_img(table_save_path)