Преглед на файлове

add formula recognition to layout parsing pipeline (#2861)

changdazhou преди 10 месеца
родител
ревизия
4c8efc3ca0

+ 2 - 5
docs/pipeline_usage/tutorials/ocr_pipelines/formula_recognition.md

@@ -255,12 +255,9 @@ for res in output:
 ```
 ### 2.3 公式识别产线可视化
 如果您需要对公式识别产线进行可视化,需要运行如下命令来对LaTeX渲染环境进行安装:
-```python
-apt-get install sudo
+```shell
 sudo apt-get update
-sudo apt-get install texlive
-sudo apt-get install texlive-latex-base
-sudo apt-get install texlive-latex-extra
+sudo apt-get install texlive texlive-latex-base texlive-latex-extra -y
 ```
 之后,使用 `save_to_img` 方法对可视化图片进行保存。具体命令如下:
 ```python

+ 13 - 1
paddlex/configs/pipelines/layout_parsing.yaml

@@ -5,11 +5,12 @@ use_doc_preprocessor: True
 use_general_ocr: True
 use_seal_recognition: True
 use_table_recognition: True
+use_formula_recognition: True
 
 SubModules:
   LayoutDetection:
     module_name: layout_detection
-    model_name: RT-DETR-H_layout_3cls
+    model_name: RT-DETR-H_layout_17cls
     model_dir: null
 
 SubPipelines:
@@ -87,3 +88,14 @@ SubPipelines:
             model_dir: null
             batch_size: 1
             score_thresh: 0
+    
+  FormulaRecognition:
+    pipeline_name: formula_recognition
+    use_layout_detection: False
+    use_doc_preprocessor: False
+    SubModules:
+      FormulaRecognition:
+        module_name: formula_recognition
+        model_name: PP-FormulaNet-L
+        model_dir: null
+        batch_size: 5

+ 4 - 1
paddlex/inference/pipelines_new/formula_recognition/pipeline.py

@@ -113,7 +113,10 @@ class FormulaRecognitionPipeline(BasePipeline):
         if use_doc_orientation_classify is None and use_doc_unwarping is None:
             use_doc_preprocessor = self.use_doc_preprocessor
         else:
-            use_doc_preprocessor = True
+            if use_doc_orientation_classify is True or use_doc_unwarping is True:
+                use_doc_preprocessor = True
+            else:
+                use_doc_preprocessor = False
 
         if use_layout_detection is None:
             use_layout_detection = self.use_layout_detection

+ 35 - 1
paddlex/inference/pipelines_new/layout_parsing/pipeline.py

@@ -25,7 +25,6 @@ from ...utils.pp_option import PaddlePredictorOption
 from ...common.reader import ReadImage
 from ...common.batch_sampler import ImageBatchSampler
 from ..ocr.result import OCRResult
-from ..doc_preprocessor.result import DocPreprocessorResult
 
 # [TODO] 待更新models_new到models
 from ...models_new.object_detection.result import DetResult
@@ -78,6 +77,7 @@ class LayoutParsingPipeline(BasePipeline):
         self.use_general_ocr = config.get("use_general_ocr", True)
         self.use_table_recognition = config.get("use_table_recognition", True)
         self.use_seal_recognition = config.get("use_seal_recognition", True)
+        self.use_formula_recognition = config.get("use_formula_recognition", True)
 
         if self.use_doc_preprocessor:
             doc_preprocessor_config = config.get("SubPipelines", {}).get(
@@ -125,6 +125,17 @@ class LayoutParsingPipeline(BasePipeline):
                 table_recognition_config
             )
 
+        if self.use_formula_recognition:
+            formula_recognition_config = config.get("SubPipelines", {}).get(
+                "FormulaRecognition",
+                {
+                    "pipeline_config_error": "config error for formula_recognition_pipeline!"
+                },
+            )
+            self.formula_recognition_pipeline = self.create_pipeline(
+                formula_recognition_config
+            )
+
         return
 
     def get_text_paragraphs_ocr_res(
@@ -191,6 +202,7 @@ class LayoutParsingPipeline(BasePipeline):
         use_general_ocr: Optional[bool],
         use_seal_recognition: Optional[bool],
         use_table_recognition: Optional[bool],
+        use_formula_recognition: Optional[bool],
     ) -> dict:
         """
         Get the model settings based on the provided parameters or default values.
@@ -222,11 +234,15 @@ class LayoutParsingPipeline(BasePipeline):
         if use_table_recognition is None:
             use_table_recognition = self.use_table_recognition
 
+        if use_formula_recognition is None:
+            use_formula_recognition = self.use_formula_recognition
+
         return dict(
             use_doc_preprocessor=use_doc_preprocessor,
             use_general_ocr=use_general_ocr,
             use_seal_recognition=use_seal_recognition,
             use_table_recognition=use_table_recognition,
+            use_formula_recognition=use_formula_recognition,
         )
 
     def predict(
@@ -237,6 +253,7 @@ class LayoutParsingPipeline(BasePipeline):
         use_general_ocr: Optional[bool] = None,
         use_seal_recognition: Optional[bool] = None,
         use_table_recognition: Optional[bool] = None,
+        use_formula_recognition: Optional[bool] = None,
         text_det_limit_side_len: Optional[int] = None,
         text_det_limit_type: Optional[str] = None,
         text_det_thresh: Optional[float] = None,
@@ -273,6 +290,7 @@ class LayoutParsingPipeline(BasePipeline):
             use_general_ocr,
             use_seal_recognition,
             use_table_recognition,
+            use_formula_recognition,
         )
 
         if not self.check_model_settings_valid(model_settings):
@@ -363,6 +381,21 @@ class LayoutParsingPipeline(BasePipeline):
             else:
                 seal_res_list = []
 
+            if model_settings["use_formula_recognition"]:
+                formula_res_all = next(
+                    self.formula_recognition_pipeline(
+                        doc_preprocessor_image,
+                        use_layout_detection=False,
+                        use_doc_orientation_classify=False,
+                        use_doc_unwarping=False,
+                        layout_det_res=layout_det_res,
+                    )
+                )
+                formula_res_list = formula_res_all["formula_res_list"]
+                print(formula_res_list)
+            else:
+                formula_res_list = []
+
             single_img_res = {
                 "input_path": input_path,
                 "doc_preprocessor_res": doc_preprocessor_res,
@@ -371,6 +404,7 @@ class LayoutParsingPipeline(BasePipeline):
                 "text_paragraphs_ocr_res": text_paragraphs_ocr_res,
                 "table_res_list": table_res_list,
                 "seal_res_list": seal_res_list,
+                "formula_res_list": formula_res_list,
                 "model_settings": model_settings,
             }
             yield LayoutParsingResult(single_img_res)

+ 29 - 0
paddlex/inference/pipelines_new/layout_parsing/result.py

@@ -66,6 +66,18 @@ class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
                 sub_seal_res_dict = seal_res.img
                 key = f"seal_res_region{seal_region_id}"
                 res_img_dict[key] = sub_seal_res_dict["ocr_res_img"]
+
+        if (
+            model_settings["use_formula_recognition"]
+            and len(self["formula_res_list"]) > 0
+        ):
+            for sno in range(len(self["formula_res_list"])):
+                formula_res = self["formula_res_list"][sno]
+                formula_region_id = formula_res["formula_region_id"]
+                sub_formula_res_dict = formula_res.img
+                key = f"formula_res_region{formula_region_id}"
+                res_img_dict[key] = sub_formula_res_dict
+
         return res_img_dict
 
     def _to_str(self, *args, **kwargs) -> Dict[str, str]:
@@ -106,6 +118,15 @@ class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
             for sno in range(len(self["seal_res_list"])):
                 seal_res = self["seal_res_list"][sno]
                 data["seal_res_list"].append(seal_res.str["res"])
+        if (
+            model_settings["use_formula_recognition"]
+            and len(self["formula_res_list"]) > 0
+        ):
+            data["formula_res_list"] = []
+            for sno in range(len(self["formula_res_list"])):
+                formula_res = self["formula_res_list"][sno]
+                data["formula_res_list"].append(formula_res.str["res"])
+
         return StrMixin._to_str(data, *args, **kwargs)
 
     def _to_json(self, *args, **kwargs) -> Dict[str, str]:
@@ -147,6 +168,14 @@ class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
             for sno in range(len(self["seal_res_list"])):
                 seal_res = self["seal_res_list"][sno]
                 data["seal_res_list"].append(seal_res.json["res"])
+        if (
+            model_settings["use_formula_recognition"]
+            and len(self["formula_res_list"]) > 0
+        ):
+            data["formula_res_list"] = []
+            for sno in range(len(self["formula_res_list"])):
+                formula_res = self["formula_res_list"][sno]
+                data["formula_res_list"].append(formula_res.json["res"])
         return JsonMixin._to_json(data, *args, **kwargs)
 
     def _to_html(self) -> Dict[str, str]: