|
@@ -114,6 +114,7 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
self.img_reader = ReadImage()
|
|
self.img_reader = ReadImage()
|
|
|
self.visual_info = None
|
|
self.visual_info = None
|
|
|
self.vector = None
|
|
self.vector = None
|
|
|
|
|
+ self.visual_flag = False
|
|
|
|
|
|
|
|
def _build_predictor(self):
|
|
def _build_predictor(self):
|
|
|
super()._build_predictor()
|
|
super()._build_predictor()
|
|
@@ -197,6 +198,7 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
visual_info = VisualInfoResult(visual_info)
|
|
visual_info = VisualInfoResult(visual_info)
|
|
|
# for local user save visual info in self
|
|
# for local user save visual info in self
|
|
|
self.visual_info = visual_info
|
|
self.visual_info = visual_info
|
|
|
|
|
+ self.visual_flag = True
|
|
|
|
|
|
|
|
return visual_result, visual_info
|
|
return visual_result, visual_info
|
|
|
|
|
|
|
@@ -233,14 +235,13 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
):
|
|
):
|
|
|
single_img_res = {
|
|
single_img_res = {
|
|
|
"input_path": "",
|
|
"input_path": "",
|
|
|
- "layout_result": {},
|
|
|
|
|
- "ocr_result": {},
|
|
|
|
|
|
|
+ "layout_result": DetResult({}),
|
|
|
|
|
+ "ocr_result": OCRResult({}),
|
|
|
"table_ocr_result": [],
|
|
"table_ocr_result": [],
|
|
|
- "table_result": [],
|
|
|
|
|
|
|
+ "table_result": StructureTableResult([]),
|
|
|
"structure_result": [],
|
|
"structure_result": [],
|
|
|
- "structure_result": [],
|
|
|
|
|
- "oricls_result": {},
|
|
|
|
|
- "uvdoc_result": {},
|
|
|
|
|
|
|
+ "oricls_result": TopkResult({}),
|
|
|
|
|
+ "uvdoc_result": DocTrResult({}),
|
|
|
"curve_result": [],
|
|
"curve_result": [],
|
|
|
}
|
|
}
|
|
|
# update oricls and uvdoc result
|
|
# update oricls and uvdoc result
|
|
@@ -389,7 +390,7 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
llm_name=None,
|
|
llm_name=None,
|
|
|
llm_params={},
|
|
llm_params={},
|
|
|
visual_info=None,
|
|
visual_info=None,
|
|
|
- min_characters=0,
|
|
|
|
|
|
|
+ min_characters=3500,
|
|
|
llm_request_interval=1.0,
|
|
llm_request_interval=1.0,
|
|
|
):
|
|
):
|
|
|
"""get vector for ocr"""
|
|
"""get vector for ocr"""
|
|
@@ -429,6 +430,7 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
text_result = self.llm_api.get_vector(ocr_text, llm_request_interval)
|
|
text_result = self.llm_api.get_vector(ocr_text, llm_request_interval)
|
|
|
else:
|
|
else:
|
|
|
text_result = str(ocr_text)
|
|
text_result = str(ocr_text)
|
|
|
|
|
+ self.visual_flag = False
|
|
|
|
|
|
|
|
return VectorResult({"vector": text_result})
|
|
return VectorResult({"vector": text_result})
|
|
|
|
|
|
|
@@ -447,11 +449,16 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
|
|
|
|
|
key_list = format_key(key_list)
|
|
key_list = format_key(key_list)
|
|
|
|
|
|
|
|
|
|
+ is_seving = visual_info and llm_name
|
|
|
|
|
+
|
|
|
|
|
+ if self.visual_flag and not is_seving:
|
|
|
|
|
+ self.vector = self.get_vector_text()
|
|
|
|
|
+
|
|
|
if not any([vector, self.vector]):
|
|
if not any([vector, self.vector]):
|
|
|
logging.warning(
|
|
logging.warning(
|
|
|
"The vector library is not created, and is being created automatically"
|
|
"The vector library is not created, and is being created automatically"
|
|
|
)
|
|
)
|
|
|
- if visual_info and llm_name:
|
|
|
|
|
|
|
+ if is_seving:
|
|
|
# for serving
|
|
# for serving
|
|
|
vector = self.get_vector_text(
|
|
vector = self.get_vector_text(
|
|
|
llm_name=llm_name, llm_params=llm_params, visual_info=visual_info
|
|
llm_name=llm_name, llm_params=llm_params, visual_info=visual_info
|
|
@@ -543,7 +550,7 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
if len(key_list) > 0:
|
|
if len(key_list) > 0:
|
|
|
logging.info("get result from ocr")
|
|
logging.info("get result from ocr")
|
|
|
if retrieval_result:
|
|
if retrieval_result:
|
|
|
- ocr_text = retrieval_result
|
|
|
|
|
|
|
+ ocr_text = retrieval_result.get("retrieval")
|
|
|
elif use_vector and any([visual_info, vector]):
|
|
elif use_vector and any([visual_info, vector]):
|
|
|
# for serving or local
|
|
# for serving or local
|
|
|
ocr_text = self.get_retrieval_text(
|
|
ocr_text = self.get_retrieval_text(
|
|
@@ -552,10 +559,10 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
vector=vector,
|
|
vector=vector,
|
|
|
llm_name=llm_name,
|
|
llm_name=llm_name,
|
|
|
llm_params=llm_params,
|
|
llm_params=llm_params,
|
|
|
- )
|
|
|
|
|
|
|
+ )["retrieval"]
|
|
|
else:
|
|
else:
|
|
|
# for local
|
|
# for local
|
|
|
- ocr_text = self.get_retrieval_text(key_list=key_list)
|
|
|
|
|
|
|
+ ocr_text = self.get_retrieval_text(key_list=key_list)["retrieval"]
|
|
|
prompt = self.get_prompt_for_ocr(
|
|
prompt = self.get_prompt_for_ocr(
|
|
|
ocr_text,
|
|
ocr_text,
|
|
|
key_list,
|
|
key_list,
|