|
|
@@ -109,6 +109,10 @@ class LayoutParsingPipelineV2(BasePipeline):
|
|
|
"use_formula_recognition",
|
|
|
True,
|
|
|
)
|
|
|
+ self.use_chart_recognition = config.get(
|
|
|
+ "use_chart_recognition",
|
|
|
+ True,
|
|
|
+ )
|
|
|
|
|
|
if self.use_doc_preprocessor:
|
|
|
doc_preprocessor_config = config.get("SubPipelines", {}).get(
|
|
|
@@ -194,6 +198,17 @@ class LayoutParsingPipelineV2(BasePipeline):
|
|
|
formula_recognition_config,
|
|
|
)
|
|
|
|
|
|
+ if self.use_chart_recognition:
|
|
|
+ chart_recognition_config = config.get("SubModules", {}).get(
|
|
|
+ "ChartRecognition",
|
|
|
+ {
|
|
|
+ "model_config_error": "config error for block_region_detection_model!"
|
|
|
+ },
|
|
|
+ )
|
|
|
+ self.chart_recognition_model = self.create_model(
|
|
|
+ chart_recognition_config,
|
|
|
+ )
|
|
|
+
|
|
|
return
|
|
|
|
|
|
def get_text_paragraphs_ocr_res(
|
|
|
@@ -698,6 +713,7 @@ class LayoutParsingPipelineV2(BasePipeline):
|
|
|
layout_det_res: DetResult,
|
|
|
table_res_list: list,
|
|
|
seal_res_list: list,
|
|
|
+ chart_res_list: list,
|
|
|
text_rec_model: Any,
|
|
|
text_rec_score_thresh: Union[float, None] = None,
|
|
|
) -> list:
|
|
|
@@ -731,6 +747,7 @@ class LayoutParsingPipelineV2(BasePipeline):
|
|
|
|
|
|
table_index = 0
|
|
|
seal_index = 0
|
|
|
+ chart_index = 0
|
|
|
layout_parsing_blocks: List[LayoutParsingBlock] = []
|
|
|
|
|
|
for box_idx, box_info in enumerate(layout_det_res["boxes"]):
|
|
|
@@ -747,6 +764,9 @@ class LayoutParsingPipelineV2(BasePipeline):
|
|
|
elif label == "seal" and len(seal_res_list) > 0:
|
|
|
block.content = seal_res_list[seal_index]["rec_texts"]
|
|
|
seal_index += 1
|
|
|
+ elif label == "chart" and len(chart_res_list) > 0:
|
|
|
+ block.content = chart_res_list[chart_index]
|
|
|
+ chart_index += 1
|
|
|
else:
|
|
|
if label == "formula":
|
|
|
_, ocr_idx_list = get_sub_regions_ocr_res(
|
|
|
@@ -809,6 +829,7 @@ class LayoutParsingPipelineV2(BasePipeline):
|
|
|
overall_ocr_res: OCRResult,
|
|
|
table_res_list: list,
|
|
|
seal_res_list: list,
|
|
|
+ chart_res_list: list,
|
|
|
formula_res_list: list,
|
|
|
text_rec_score_thresh: Union[float, None] = None,
|
|
|
) -> list:
|
|
|
@@ -848,6 +869,7 @@ class LayoutParsingPipelineV2(BasePipeline):
|
|
|
layout_det_res=layout_det_res,
|
|
|
table_res_list=table_res_list,
|
|
|
seal_res_list=seal_res_list,
|
|
|
+ chart_res_list=chart_res_list,
|
|
|
text_rec_model=self.general_ocr_pipeline.text_rec_model,
|
|
|
text_rec_score_thresh=self.general_ocr_pipeline.text_rec_score_thresh,
|
|
|
)
|
|
|
@@ -914,6 +936,9 @@ class LayoutParsingPipelineV2(BasePipeline):
|
|
|
if use_region_detection is None:
|
|
|
use_region_detection = self.use_region_detection
|
|
|
|
|
|
+ if use_chart_recognition is None:
|
|
|
+ use_chart_recognition = self.use_chart_recognition
|
|
|
+
|
|
|
return dict(
|
|
|
use_doc_preprocessor=use_doc_preprocessor,
|
|
|
use_general_ocr=use_general_ocr,
|
|
|
@@ -956,6 +981,8 @@ class LayoutParsingPipelineV2(BasePipeline):
|
|
|
use_table_cells_ocr_results: bool = False,
|
|
|
use_e2e_wired_table_rec_model: bool = False,
|
|
|
use_e2e_wireless_table_rec_model: bool = True,
|
|
|
+ max_new_tokens: int = 1024,
|
|
|
+ no_repeat_ngram_size: int = 20,
|
|
|
is_pretty_markdown: Union[bool, None] = None,
|
|
|
**kwargs,
|
|
|
) -> LayoutParsingResultV2:
|
|
|
@@ -994,6 +1021,9 @@ class LayoutParsingPipelineV2(BasePipeline):
|
|
|
use_table_cells_ocr_results (bool): whether to use OCR results with cells.
|
|
|
use_e2e_wired_table_rec_model (bool): Whether to use end-to-end wired table recognition model.
|
|
|
use_e2e_wireless_table_rec_model (bool): Whether to use end-to-end wireless table recognition model.
|
|
|
+ max_new_tokens: int = 1024,
|
|
|
+ no_repeat_ngram_size: int = 20,
|
|
|
+ is_pretty_markdown,
|
|
|
**kwargs (Any): Additional settings to extend functionality.
|
|
|
|
|
|
Returns:
|
|
|
@@ -1185,6 +1215,24 @@ class LayoutParsingPipelineV2(BasePipeline):
|
|
|
else:
|
|
|
seal_res_list = []
|
|
|
|
|
|
+ chart_res_list = []
|
|
|
+ if model_settings["use_chart_recognition"]:
|
|
|
+ chart_imgs_list = []
|
|
|
+ for bbox in layout_det_res["boxes"]:
|
|
|
+ if bbox["label"] == "chart":
|
|
|
+ x_min, y_min, x_max, y_max = bbox["coordinate"]
|
|
|
+ chart_img = doc_preprocessor_image[
|
|
|
+ int(y_min) : int(y_max), int(x_min) : int(x_max), :
|
|
|
+ ]
|
|
|
+ chart_imgs_list.append({"image": chart_img})
|
|
|
+
|
|
|
+ for chart_res_batch in self.chart_recognition_model(
|
|
|
+ input=chart_imgs_list,
|
|
|
+ max_new_tokens=max_new_tokens,
|
|
|
+ no_repeat_ngram_size=no_repeat_ngram_size,
|
|
|
+ ):
|
|
|
+ chart_res_list.append(chart_res_batch["result"])
|
|
|
+
|
|
|
parsing_res_list = self.get_layout_parsing_res(
|
|
|
doc_preprocessor_image,
|
|
|
region_det_res=region_det_res,
|
|
|
@@ -1192,6 +1240,7 @@ class LayoutParsingPipelineV2(BasePipeline):
|
|
|
overall_ocr_res=overall_ocr_res,
|
|
|
table_res_list=table_res_list,
|
|
|
seal_res_list=seal_res_list,
|
|
|
+ chart_res_list=chart_res_list,
|
|
|
formula_res_list=formula_res_list,
|
|
|
text_rec_score_thresh=text_rec_score_thresh,
|
|
|
)
|
|
|
@@ -1211,6 +1260,7 @@ class LayoutParsingPipelineV2(BasePipeline):
|
|
|
"overall_ocr_res": overall_ocr_res,
|
|
|
"table_res_list": table_res_list,
|
|
|
"seal_res_list": seal_res_list,
|
|
|
+ "chart_res_list": chart_res_list,
|
|
|
"formula_res_list": formula_res_list,
|
|
|
"parsing_res_list": parsing_res_list,
|
|
|
"imgs_in_doc": imgs_in_doc,
|