ソースを参照

Support chart-to-table recognition

gaotingquan 6 ヶ月 前
コミット
42304b3b4b

+ 5 - 0
paddlex/configs/pipelines/PP-StructureV3.yaml

@@ -85,6 +85,11 @@ SubModules:
       20: "union" # header_image
       21: "union" # footer_image
       22: "union" # aside_text
+  ChartRecognition:
+    module_name: chart_recognition
+    model_name: PP-Chart2Table
+    model_dir: null
+    batch_size: 1 
 
 SubPipelines:
   DocPreprocessor:

+ 50 - 0
paddlex/inference/pipelines/layout_parsing/pipeline_v2.py

@@ -109,6 +109,10 @@ class LayoutParsingPipelineV2(BasePipeline):
             "use_formula_recognition",
             True,
         )
+        self.use_chart_recognition = config.get(
+            "use_chart_recognition",
+            True,
+        )
 
         if self.use_doc_preprocessor:
             doc_preprocessor_config = config.get("SubPipelines", {}).get(
@@ -194,6 +198,17 @@ class LayoutParsingPipelineV2(BasePipeline):
                 formula_recognition_config,
             )
 
+        if self.use_chart_recognition:
+            chart_recognition_config = config.get("SubModules", {}).get(
+                "ChartRecognition",
+                {
+                    "model_config_error": "config error for block_region_detection_model!"
+                },
+            )
+            self.chart_recognition_model = self.create_model(
+                chart_recognition_config,
+            )
+
         return
 
     def get_text_paragraphs_ocr_res(
@@ -698,6 +713,7 @@ class LayoutParsingPipelineV2(BasePipeline):
         layout_det_res: DetResult,
         table_res_list: list,
         seal_res_list: list,
+        chart_res_list: list,
         text_rec_model: Any,
         text_rec_score_thresh: Union[float, None] = None,
     ) -> list:
@@ -731,6 +747,7 @@ class LayoutParsingPipelineV2(BasePipeline):
 
         table_index = 0
         seal_index = 0
+        chart_index = 0
         layout_parsing_blocks: List[LayoutParsingBlock] = []
 
         for box_idx, box_info in enumerate(layout_det_res["boxes"]):
@@ -747,6 +764,9 @@ class LayoutParsingPipelineV2(BasePipeline):
             elif label == "seal" and len(seal_res_list) > 0:
                 block.content = seal_res_list[seal_index]["rec_texts"]
                 seal_index += 1
+            elif label == "chart" and len(chart_res_list) > 0:
+                block.content = chart_res_list[chart_index]
+                chart_index += 1
             else:
                 if label == "formula":
                     _, ocr_idx_list = get_sub_regions_ocr_res(
@@ -809,6 +829,7 @@ class LayoutParsingPipelineV2(BasePipeline):
         overall_ocr_res: OCRResult,
         table_res_list: list,
         seal_res_list: list,
+        chart_res_list: list,
         formula_res_list: list,
         text_rec_score_thresh: Union[float, None] = None,
     ) -> list:
@@ -848,6 +869,7 @@ class LayoutParsingPipelineV2(BasePipeline):
             layout_det_res=layout_det_res,
             table_res_list=table_res_list,
             seal_res_list=seal_res_list,
+            chart_res_list=chart_res_list,
             text_rec_model=self.general_ocr_pipeline.text_rec_model,
             text_rec_score_thresh=self.general_ocr_pipeline.text_rec_score_thresh,
         )
@@ -914,6 +936,9 @@ class LayoutParsingPipelineV2(BasePipeline):
         if use_region_detection is None:
             use_region_detection = self.use_region_detection
 
+        if use_chart_recognition is None:
+            use_chart_recognition = self.use_chart_recognition
+
         return dict(
             use_doc_preprocessor=use_doc_preprocessor,
             use_general_ocr=use_general_ocr,
@@ -956,6 +981,8 @@ class LayoutParsingPipelineV2(BasePipeline):
         use_table_cells_ocr_results: bool = False,
         use_e2e_wired_table_rec_model: bool = False,
         use_e2e_wireless_table_rec_model: bool = True,
+        max_new_tokens: int = 1024,
+        no_repeat_ngram_size: int = 20,
         is_pretty_markdown: Union[bool, None] = None,
         **kwargs,
     ) -> LayoutParsingResultV2:
@@ -994,6 +1021,9 @@ class LayoutParsingPipelineV2(BasePipeline):
             use_table_cells_ocr_results (bool): whether to use OCR results with cells.
             use_e2e_wired_table_rec_model (bool): Whether to use end-to-end wired table recognition model.
             use_e2e_wireless_table_rec_model (bool): Whether to use end-to-end wireless table recognition model.
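+            max_new_tokens (int): Value forwarded as `max_new_tokens` to the chart recognition model. Defaults to 1024.
+            no_repeat_ngram_size (int): Value forwarded as `no_repeat_ngram_size` to the chart recognition model. Defaults to 20.
+            is_pretty_markdown (Union[bool, None]): Pretty-Markdown option. Defaults to None.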
             **kwargs (Any): Additional settings to extend functionality.
 
         Returns:
@@ -1185,6 +1215,24 @@ class LayoutParsingPipelineV2(BasePipeline):
             else:
                 seal_res_list = []
 
+            chart_res_list = []
+            if model_settings["use_chart_recognition"]:
+                chart_imgs_list = []
+                for bbox in layout_det_res["boxes"]:
+                    if bbox["label"] == "chart":
+                        x_min, y_min, x_max, y_max = bbox["coordinate"]
+                        chart_img = doc_preprocessor_image[
+                            int(y_min) : int(y_max), int(x_min) : int(x_max), :
+                        ]
+                        chart_imgs_list.append({"image": chart_img})
+
+                for chart_res_batch in self.chart_recognition_model(
+                    input=chart_imgs_list,
+                    max_new_tokens=max_new_tokens,
+                    no_repeat_ngram_size=no_repeat_ngram_size,
+                ):
+                    chart_res_list.append(chart_res_batch["result"])
+
             parsing_res_list = self.get_layout_parsing_res(
                 doc_preprocessor_image,
                 region_det_res=region_det_res,
@@ -1192,6 +1240,7 @@ class LayoutParsingPipelineV2(BasePipeline):
                 overall_ocr_res=overall_ocr_res,
                 table_res_list=table_res_list,
                 seal_res_list=seal_res_list,
+                chart_res_list=chart_res_list,
                 formula_res_list=formula_res_list,
                 text_rec_score_thresh=text_rec_score_thresh,
             )
@@ -1211,6 +1260,7 @@ class LayoutParsingPipelineV2(BasePipeline):
                 "overall_ocr_res": overall_ocr_res,
                 "table_res_list": table_res_list,
                 "seal_res_list": seal_res_list,
+                "chart_res_list": chart_res_list,
                 "formula_res_list": formula_res_list,
                 "parsing_res_list": parsing_res_list,
                 "imgs_in_doc": imgs_in_doc,

+ 10 - 1
paddlex/inference/pipelines/layout_parsing/result_v2.py

@@ -326,6 +326,15 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                 )
                 return "\n".join(img_tags)
 
+            def format_chart():
+                if not self["model_settings"].get("use_chart_recognition", False):
+                    return format_image()
+                lines_list = block.content.split("\n")
+                column_num = len(lines_list[0].split("|"))
+                lines_list.insert(1, "|".join(["---"] * column_num))
+                lines_list = [f"|{line}|" for line in lines_list]
+                return "\n".join(lines_list)
+
             def format_first_line(templates, format_func, spliter):
                 lines = block.content.split(spliter)
                 for idx in range(len(lines)):
@@ -420,7 +429,7 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                     "\n", "  \n"
                 ),
                 "image": lambda: format_image(),
-                "chart": lambda: format_image(),
+                "chart": lambda: format_chart(),
                 "formula": lambda: f"$${block.content}$$",
                 "table": format_table,
                 "reference": lambda: format_first_line(