|
|
@@ -60,31 +60,33 @@ class PPChatOCRPipeline(_TableRecPipeline):
|
|
|
recovery=True,
|
|
|
device=None,
|
|
|
predictor_kwargs=None,
|
|
|
+ _build_models=True,
|
|
|
):
|
|
|
super().__init__(
|
|
|
predictor_kwargs=predictor_kwargs,
|
|
|
)
|
|
|
- self._build_predictor(
|
|
|
- layout_model=layout_model,
|
|
|
- text_det_model=text_det_model,
|
|
|
- text_rec_model=text_rec_model,
|
|
|
- table_model=table_model,
|
|
|
- doc_image_ori_cls_model=doc_image_ori_cls_model,
|
|
|
- doc_image_unwarp_model=doc_image_unwarp_model,
|
|
|
- seal_text_det_model=seal_text_det_model,
|
|
|
- llm_name=llm_name,
|
|
|
- llm_params=llm_params,
|
|
|
- )
|
|
|
- self.set_predictor(
|
|
|
- layout_batch_size=layout_batch_size,
|
|
|
- text_det_batch_size=text_det_batch_size,
|
|
|
- text_rec_batch_size=text_rec_batch_size,
|
|
|
- table_batch_size=table_batch_size,
|
|
|
- doc_image_ori_cls_batch_size=doc_image_ori_cls_batch_size,
|
|
|
- doc_image_unwarp_batch_size=doc_image_unwarp_batch_size,
|
|
|
- seal_text_det_batch_size=seal_text_det_batch_size,
|
|
|
- device=device,
|
|
|
- )
|
|
|
+ if _build_models:
|
|
|
+ self._build_predictor(
|
|
|
+ layout_model=layout_model,
|
|
|
+ text_det_model=text_det_model,
|
|
|
+ text_rec_model=text_rec_model,
|
|
|
+ table_model=table_model,
|
|
|
+ doc_image_ori_cls_model=doc_image_ori_cls_model,
|
|
|
+ doc_image_unwarp_model=doc_image_unwarp_model,
|
|
|
+ seal_text_det_model=seal_text_det_model,
|
|
|
+ llm_name=llm_name,
|
|
|
+ llm_params=llm_params,
|
|
|
+ )
|
|
|
+ self.set_predictor(
|
|
|
+ layout_batch_size=layout_batch_size,
|
|
|
+ text_det_batch_size=text_det_batch_size,
|
|
|
+ text_rec_batch_size=text_rec_batch_size,
|
|
|
+ table_batch_size=table_batch_size,
|
|
|
+ doc_image_ori_cls_batch_size=doc_image_ori_cls_batch_size,
|
|
|
+ doc_image_unwarp_batch_size=doc_image_unwarp_batch_size,
|
|
|
+ seal_text_det_batch_size=seal_text_det_batch_size,
|
|
|
+ device=device,
|
|
|
+ )
|
|
|
|
|
|
# get base prompt from yaml info
|
|
|
if task_prompt_yaml:
|
|
|
@@ -127,13 +129,13 @@ class PPChatOCRPipeline(_TableRecPipeline):
|
|
|
else:
|
|
|
self.curve_pipeline = None
|
|
|
if doc_image_ori_cls_model:
|
|
|
- self.oricls_predictor = self._create(doc_image_ori_cls_model)
|
|
|
+ self.doc_image_ori_cls_predictor = self._create(doc_image_ori_cls_model)
|
|
|
else:
|
|
|
- self.oricls_predictor = None
|
|
|
+ self.doc_image_ori_cls_predictor = None
|
|
|
if doc_image_unwarp_model:
|
|
|
- self.uvdoc_predictor = self._create(doc_image_unwarp_model)
|
|
|
+ self.doc_image_unwarp_predictor = self._create(doc_image_unwarp_model)
|
|
|
else:
|
|
|
- self.uvdoc_predictor = None
|
|
|
+ self.doc_image_unwarp_predictor = None
|
|
|
|
|
|
self.img_reader = ReadImage(format="RGB")
|
|
|
self.llm_api = create_llm_api(
|
|
|
@@ -169,19 +171,23 @@ class PPChatOCRPipeline(_TableRecPipeline):
|
|
|
self.curve_pipeline.text_det_model.set_predictor(
|
|
|
batch_size=seal_text_det_batch_size
|
|
|
)
|
|
|
- if self.oricls_predictor and doc_image_ori_cls_batch_size:
|
|
|
- self.oricls_predictor.set_predictor(batch_size=doc_image_ori_cls_batch_size)
|
|
|
- if self.uvdoc_predictor and doc_image_unwarp_batch_size:
|
|
|
- self.uvdoc_predictor.set_predictor(batch_size=doc_image_unwarp_batch_size)
|
|
|
+ if self.doc_image_ori_cls_predictor and doc_image_ori_cls_batch_size:
|
|
|
+ self.doc_image_ori_cls_predictor.set_predictor(
|
|
|
+ batch_size=doc_image_ori_cls_batch_size
|
|
|
+ )
|
|
|
+ if self.doc_image_unwarp_predictor and doc_image_unwarp_batch_size:
|
|
|
+ self.doc_image_unwarp_predictor.set_predictor(
|
|
|
+ batch_size=doc_image_unwarp_batch_size
|
|
|
+ )
|
|
|
|
|
|
if device:
|
|
|
if self.curve_pipeline:
|
|
|
self.curve_pipeline.set_predictor(device=device)
|
|
|
- if self.oricls_predictor:
|
|
|
- self.oricls_predictor.set_predictor(device=device)
|
|
|
- if self.uvdoc_predictor:
|
|
|
- self.uvdoc_predictor.set_predictor(device=device)
|
|
|
- self.layout_batch_size.set_predictor(device=device)
|
|
|
+ if self.doc_image_ori_cls_predictor:
|
|
|
+ self.doc_image_ori_cls_predictor.set_predictor(device=device)
|
|
|
+ if self.doc_image_unwarp_predictor:
|
|
|
+ self.doc_image_unwarp_predictor.set_predictor(device=device)
|
|
|
+ self.layout_predictor.set_predictor(device=device)
|
|
|
self.ocr_pipeline.set_predictor(device=device)
|
|
|
|
|
|
def predict(self, *args, **kwargs):
|
|
|
@@ -200,10 +206,6 @@ class PPChatOCRPipeline(_TableRecPipeline):
|
|
|
**kwargs,
|
|
|
):
|
|
|
self.set_predictor(**kwargs)
|
|
|
- if self.uvdoc_predictor and uvdoc_batch_size:
|
|
|
- self.uvdoc_predictor.set_predictor(
|
|
|
- batch_size=uvdoc_batch_size, device=device
|
|
|
- )
|
|
|
|
|
|
visual_info = {"ocr_text": [], "table_html": [], "table_text": []}
|
|
|
# get all visual result
|
|
|
@@ -237,14 +239,18 @@ class PPChatOCRPipeline(_TableRecPipeline):
|
|
|
use_seal_text_det_model=True,
|
|
|
recovery=True,
|
|
|
):
|
|
|
- # get oricls and uvdoc results
|
|
|
+ # get oricls and unwarp results
|
|
|
img_info_list = list(self.img_reader(inputs))[0]
|
|
|
oricls_results = []
|
|
|
- if self.oricls_predictor and use_doc_image_ori_cls_model:
|
|
|
- oricls_results = get_oriclas_results(img_info_list, self.oricls_predictor)
|
|
|
- uvdoc_results = []
|
|
|
- if self.uvdoc_predictor and use_doc_image_unwarp_model:
|
|
|
- uvdoc_results = get_uvdoc_results(img_info_list, self.uvdoc_predictor)
|
|
|
+ if self.doc_image_ori_cls_predictor and use_doc_image_ori_cls_model:
|
|
|
+ oricls_results = get_oriclas_results(
|
|
|
+ img_info_list, self.doc_image_ori_cls_predictor
|
|
|
+ )
|
|
|
+ unwarp_results = []
|
|
|
+ if self.doc_image_unwarp_predictor and use_doc_image_unwarp_model:
|
|
|
+ unwarp_results = get_unwarp_results(
|
|
|
+ img_info_list, self.doc_image_unwarp_predictor
|
|
|
+ )
|
|
|
img_list = [img_info["img"] for img_info in img_info_list]
|
|
|
for idx, (img_info, layout_pred) in enumerate(
|
|
|
zip(img_info_list, self.layout_predictor(img_list))
|
|
|
@@ -257,14 +263,14 @@ class PPChatOCRPipeline(_TableRecPipeline):
|
|
|
"table_result": StructureTableResult([]),
|
|
|
"structure_result": [],
|
|
|
"oricls_result": TopkResult({}),
|
|
|
- "uvdoc_result": DocTrResult({}),
|
|
|
+ "unwarp_result": DocTrResult({}),
|
|
|
"curve_result": [],
|
|
|
}
|
|
|
- # update oricls and uvdoc result
|
|
|
+ # update oricls and unwarp results
|
|
|
if oricls_results:
|
|
|
single_img_res["oricls_result"] = oricls_results[idx]
|
|
|
- if uvdoc_results:
|
|
|
- single_img_res["uvdoc_result"] = uvdoc_results[idx]
|
|
|
+ if unwarp_results:
|
|
|
+ single_img_res["unwarp_result"] = unwarp_results[idx]
|
|
|
# update layout result
|
|
|
single_img_res["input_path"] = layout_pred["input_path"]
|
|
|
single_img_res["layout_result"] = layout_pred
|