|
|
@@ -21,12 +21,7 @@ from copy import deepcopy
|
|
|
from ...components import *
|
|
|
from ..ocr import OCRPipeline
|
|
|
from ....utils import logging
|
|
|
-from ...results import (
|
|
|
- TableResult,
|
|
|
- LayoutStructureResult,
|
|
|
- VisualInfoResult,
|
|
|
- ChatOCRResult,
|
|
|
-)
|
|
|
+from ...results import *
|
|
|
from ...components.llm import ErnieBot
|
|
|
from ...utils.io import ImageReader, PDFReader
|
|
|
from ..table_recognition import TableRecPipeline
|
|
|
@@ -119,9 +114,6 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
self.img_reader = ReadImage()
|
|
|
self.visual_info = None
|
|
|
self.vector = None
|
|
|
- self._set_predictor(
|
|
|
- oricls_batch_size, uvdoc_batch_size, curve_batch_size, device=device
|
|
|
- )
|
|
|
|
|
|
def _build_predictor(self):
|
|
|
super()._build_predictor()
|
|
|
@@ -156,9 +148,29 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
batch_size=self.uvdoc_batch_size, device=self.device
|
|
|
)
|
|
|
|
|
|
- def _set_predictor(
|
|
|
- self, curve_batch_size, oricls_batch_size, uvdoc_batch_size, device
|
|
|
+ def set_predictor(
|
|
|
+ self,
|
|
|
+ layout_batch_size=None,
|
|
|
+ text_det_batch_size=None,
|
|
|
+ text_rec_batch_size=None,
|
|
|
+ table_batch_size=None,
|
|
|
+ curve_batch_size=None,
|
|
|
+ oricls_batch_size=None,
|
|
|
+ uvdoc_batch_size=None,
|
|
|
+ device=None,
|
|
|
):
|
|
|
+ if text_det_batch_size and text_det_batch_size > 1:
|
|
|
+ logging.warning(
|
|
|
+ f"text det model only support batch_size=1 now,the setting of text_det_batch_size={text_det_batch_size} will not using! "
|
|
|
+ )
|
|
|
+ if layout_batch_size:
|
|
|
+ self.layout_predictor.set_predictor(batch_size=layout_batch_size)
|
|
|
+ if text_rec_batch_size:
|
|
|
+ self.ocr_pipeline.text_rec_model.set_predictor(
|
|
|
+ batch_size=text_rec_batch_size
|
|
|
+ )
|
|
|
+ if table_batch_size:
|
|
|
+ self.table_predictor.set_predictor(batch_size=table_batch_size)
|
|
|
if self.curve_pipeline and curve_batch_size:
|
|
|
self.curve_pipeline.text_det_model.set_predictor(
|
|
|
batch_size=curve_batch_size, device=device
|
|
|
@@ -189,12 +201,23 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
return visual_result, visual_info
|
|
|
|
|
|
def get_visual_result(self, inputs, **kwargs):
|
|
|
+ layout_batch_size = kwargs.get("layout_batch_size")
|
|
|
+ text_det_batch_size = kwargs.get("text_det_batch_size")
|
|
|
+ text_rec_batch_size = kwargs.get("text_rec_batch_size")
|
|
|
+ table_batch_size = kwargs.get("table_batch_size")
|
|
|
curve_batch_size = kwargs.get("curve_batch_size")
|
|
|
oricls_batch_size = kwargs.get("oricls_batch_size")
|
|
|
uvdoc_batch_size = kwargs.get("uvdoc_batch_size")
|
|
|
device = kwargs.get("device")
|
|
|
- self._set_predictor(
|
|
|
- curve_batch_size, oricls_batch_size, uvdoc_batch_size, device
|
|
|
+ self.set_predictor(
|
|
|
+ layout_batch_size,
|
|
|
+ text_det_batch_size,
|
|
|
+ text_rec_batch_size,
|
|
|
+ table_batch_size,
|
|
|
+ curve_batch_size,
|
|
|
+ oricls_batch_size,
|
|
|
+ uvdoc_batch_size,
|
|
|
+ device,
|
|
|
)
|
|
|
# get oricls and uvdoc results
|
|
|
img_info_list = list(self.img_reader(inputs))[0]
|
|
|
@@ -229,13 +252,13 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
single_img_res["input_path"] = layout_pred["input_path"]
|
|
|
single_img_res["layout_result"] = layout_pred
|
|
|
single_img = img_info["img"]
|
|
|
+ table_subs = []
|
|
|
+ curve_subs = []
|
|
|
+ structure_res = []
|
|
|
+ ocr_res_with_layout = []
|
|
|
if len(layout_pred["boxes"]) > 0:
|
|
|
subs_of_img = list(self._crop_by_boxes(layout_pred))
|
|
|
- # get cropped images with label "table"
|
|
|
- table_subs = []
|
|
|
- curve_subs = []
|
|
|
- structure_res = []
|
|
|
- ocr_res_with_layout = []
|
|
|
+ # get cropped images
|
|
|
for sub in subs_of_img:
|
|
|
box = sub["box"]
|
|
|
xmin, ymin, xmax, ymax = [int(i) for i in box]
|
|
|
@@ -284,7 +307,8 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
|
|
|
all_curve_res = get_ocr_res(curve_pipeline, curve_subs)
|
|
|
single_img_res["curve_result"] = all_curve_res
|
|
|
-
|
|
|
+ if isinstance(all_curve_res, dict):
|
|
|
+ all_curve_res = [all_curve_res]
|
|
|
for sub, curve_res in zip(curve_subs, all_curve_res):
|
|
|
structure_res.append(
|
|
|
{
|
|
|
@@ -325,7 +349,7 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
single_img_res["table_ocr_result"] = all_table_ocr_res
|
|
|
single_img_res["structure_result"] = structure_res
|
|
|
|
|
|
- yield ChatOCRResult(single_img_res)
|
|
|
+ yield VisualResult(single_img_res)
|
|
|
|
|
|
def decode_visual_result(self, visual_result):
|
|
|
ocr_text = []
|
|
|
@@ -375,7 +399,7 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
logging.warning("Do not use ErnieBot, will not get vector text.")
|
|
|
get_vector_flag = False
|
|
|
if not any([visual_info, self.visual_info]):
|
|
|
- return {"vector": None}
|
|
|
+ return VectorResult({"vector": None})
|
|
|
|
|
|
if visual_info:
|
|
|
# use for serving or local
|
|
|
@@ -406,7 +430,7 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
else:
|
|
|
text_result = str(ocr_text)
|
|
|
|
|
|
- return {"vector": text_result}
|
|
|
+ return VectorResult({"vector": text_result})
|
|
|
|
|
|
def get_retrieval_text(
|
|
|
self,
|
|
|
@@ -419,7 +443,7 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
):
|
|
|
|
|
|
if not any([visual_info, vector, self.visual_info, self.vector]):
|
|
|
- return {"retrieval": None}
|
|
|
+ return RetrievalResult({"retrieval": None})
|
|
|
|
|
|
key_list = format_key(key_list)
|
|
|
|
|
|
@@ -450,7 +474,7 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
vector=_vector, key_list=key_list, sleep_time=llm_request_interval
|
|
|
)
|
|
|
|
|
|
- return {"retrieval": retrieval}
|
|
|
+ return RetrievalResult({"retrieval": retrieval})
|
|
|
|
|
|
def chat(
|
|
|
self,
|
|
|
@@ -473,7 +497,9 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
if not any(
|
|
|
[vector, visual_info, retrieval_result, self.visual_info, self.vector]
|
|
|
):
|
|
|
- return {"chat_res": "请先完成图像解析再开始再对话", "prompt": ""}
|
|
|
+ return ChatResult(
|
|
|
+ {"chat_res": "请先完成图像解析再开始再对话", "prompt": ""}
|
|
|
+ )
|
|
|
key_list = format_key(key_list)
|
|
|
# first get from table, then get from text in table, last get from all ocr
|
|
|
if visual_info:
|
|
|
@@ -486,20 +512,6 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
ocr_text = _visual_info["ocr_text"]
|
|
|
html_list = _visual_info["table_html"]
|
|
|
table_text_list = _visual_info["table_text"]
|
|
|
- if retrieval_result:
|
|
|
- ocr_text = retrieval_result
|
|
|
- elif use_vector and any([visual_info, vector]):
|
|
|
- # for serving or local
|
|
|
- ocr_text = self.get_retrieval_text(
|
|
|
- key_list=key_list,
|
|
|
- visual_info=visual_info,
|
|
|
- vector=vector,
|
|
|
- llm_name=llm_name,
|
|
|
- llm_params=llm_params,
|
|
|
- )
|
|
|
- else:
|
|
|
- # for local
|
|
|
- ocr_text = self.get_retrieval_text(key_list=key_list)
|
|
|
|
|
|
prompt_res = {"ocr_prompt": "str", "table_prompt": [], "html_prompt": []}
|
|
|
|
|
|
@@ -530,6 +542,20 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
final_results[key] = value
|
|
|
if len(key_list) > 0:
|
|
|
logging.info("get result from ocr")
|
|
|
+ if retrieval_result:
|
|
|
+ ocr_text = retrieval_result
|
|
|
+ elif use_vector and any([visual_info, vector]):
|
|
|
+ # for serving or local
|
|
|
+ ocr_text = self.get_retrieval_text(
|
|
|
+ key_list=key_list,
|
|
|
+ visual_info=visual_info,
|
|
|
+ vector=vector,
|
|
|
+ llm_name=llm_name,
|
|
|
+ llm_params=llm_params,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ # for local
|
|
|
+ ocr_text = self.get_retrieval_text(key_list=key_list)
|
|
|
prompt = self.get_prompt_for_ocr(
|
|
|
ocr_text,
|
|
|
key_list,
|
|
|
@@ -545,9 +571,9 @@ class PPChatOCRPipeline(TableRecPipeline):
|
|
|
if not res and not final_results:
|
|
|
final_results = self.llm_api.ERROR_MASSAGE
|
|
|
if save_prompt:
|
|
|
- return {"chat_res": final_results, "prompt": prompt_res}
|
|
|
+ return ChatResult({"chat_res": final_results, "prompt": prompt_res})
|
|
|
else:
|
|
|
- return {"chat_res": final_results, "prompt": ""}
|
|
|
+ return ChatResult({"chat_res": final_results, "prompt": ""})
|
|
|
|
|
|
def get_llm_result(self, prompt):
|
|
|
"""get llm result and decode to dict"""
|