Browse Source

Support batched formula recognition and update layout detection parameters (#2925)

liuhongen1234567 10 months ago
parent
commit
b4b3847e98

+ 1 - 1
docs/module_usage/tutorials/ocr_modules/formula_recognition.md

@@ -73,7 +73,7 @@ for res in output:
 
 <img src="https://raw.githubusercontent.com/cuicheng01/PaddleX_doc_images/refs/heads/main/images/modules/formula_recog/general_formula_rec_001_res.png">
 
-<b> 注:如果您需要对公式识别产线进行可视化,需要运行如下命令来对LaTeX渲染环境进行安装。目前公式识别产线可视化只支持Ubuntu环境,其他环境暂不支持:</b>
+<b> 注:如果您需要对公式识别产线进行可视化,需要运行如下命令来对LaTeX渲染环境进行安装。目前公式识别产线可视化只支持Ubuntu环境,其他环境暂不支持。对于复杂公式,LaTeX 结果可能包含部分高级的表示,Markdown等环境中未必可以成功显示:</b>
 ```bash
 sudo apt-get update
 sudo apt-get install texlive texlive-latex-base texlive-latex-extra -y

+ 5 - 1
paddlex/configs/pipelines/formula_recognition.yaml

@@ -7,8 +7,12 @@ use_doc_preprocessor: True
 SubModules:
   LayoutDetection:
     module_name: layout_detection
-    model_name: RT-DETR-H_layout_17cls
+    model_name: PP-DocLayout-L
     model_dir: null
+    threshold: 0.5
+    layout_nms: True
+    layout_unclip_ratio: 1.0
+    layout_merge_bboxes_mode: "large"
     batch_size: 1
 
   FormulaRecognition:

+ 1 - 1
paddlex/inference/models_new/formula_recognition/result.py

@@ -36,7 +36,7 @@ class FormulaRecResult(BaseCVResult):
     def _to_str(self, *args, **kwargs):
         data = copy.deepcopy(self)
         data.pop("input_img")
-        _str = JsonMixin._to_str(data, *args, **kwargs)["res"].replace("\\\\", "\\")
+        _str = JsonMixin._to_str(data, *args, **kwargs)["res"]
         return {"res": _str}
 
     def _to_json(self, *args, **kwargs):

+ 45 - 7
paddlex/inference/pipelines_new/formula_recognition/pipeline.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import os, sys
-from typing import Any, Dict, Optional, Union, List
+from typing import Any, Dict, Optional, Union, List, Tuple
 import numpy as np
 import cv2
 from ..base import BasePipeline
@@ -77,6 +77,26 @@ class FormulaRecognitionPipeline(BasePipeline):
                 {"model_config_error": "config error for layout_det_model!"},
             )
             self.layout_det_model = self.create_model(layout_det_config)
+            layout_kwargs = {}
+            if (threshold := layout_det_config.get("threshold", None)) is not None:
+                layout_kwargs["threshold"] = threshold
+            if (layout_nms := layout_det_config.get("layout_nms", None)) is not None:
+                layout_kwargs["layout_nms"] = layout_nms
+            if (
+                layout_unclip_ratio := layout_det_config.get(
+                    "layout_unclip_ratio", None
+                )
+            ) is not None:
+                layout_kwargs["layout_unclip_ratio"] = layout_unclip_ratio
+            if (
+                layout_merge_bboxes_mode := layout_det_config.get(
+                    "layout_merge_bboxes_mode", None
+                )
+            ) is not None:
+                layout_kwargs["layout_merge_bboxes_mode"] = layout_merge_bboxes_mode
+            self.layout_det_model = self.create_model(
+                layout_det_config, **layout_kwargs
+            )
 
         formula_recognition_config = config.get("SubModules", {}).get(
             "FormulaRecognition",
@@ -182,6 +202,10 @@ class FormulaRecognitionPipeline(BasePipeline):
         use_doc_orientation_classify: Optional[bool] = None,
         use_doc_unwarping: Optional[bool] = None,
         layout_det_res: Optional[DetResult] = None,
+        layout_threshold: Optional[Union[float, dict]] = None,
+        layout_nms: Optional[bool] = None,
+        layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
+        layout_merge_bboxes_mode: Optional[str] = None,
         **kwargs,
     ) -> FormulaRecognitionResult:
         """
@@ -245,22 +269,36 @@ class FormulaRecognitionPipeline(BasePipeline):
                 formula_region_id += 1
             else:
                 if model_settings["use_layout_detection"]:
-                    layout_det_res = next(self.layout_det_model(doc_preprocessor_image))
+                    layout_det_res = next(
+                        self.layout_det_model(
+                            doc_preprocessor_image,
+                            threshold=layout_threshold,
+                            layout_nms=layout_nms,
+                            layout_unclip_ratio=layout_unclip_ratio,
+                            layout_merge_bboxes_mode=layout_merge_bboxes_mode,
+                        )
+                    )
+                formula_crop_img = []
                 for box_info in layout_det_res["boxes"]:
                     if box_info["label"].lower() in ["formula"]:
                         crop_img_info = self._crop_by_boxes(
                             doc_preprocessor_image, [box_info]
                         )
                         crop_img_info = crop_img_info[0]
-                        single_formula_rec_res = (
-                            self.predict_single_formula_recognition_res(
-                                crop_img_info["img"]
-                            )
-                        )
+                        formula_crop_img.append(crop_img_info["img"])
+                        single_formula_rec_res = {}
                         single_formula_rec_res["formula_region_id"] = formula_region_id
                         single_formula_rec_res["dt_polys"] = box_info["coordinate"]
                         formula_res_list.append(single_formula_rec_res)
                         formula_region_id += 1
+                for idx, formula_rec_res in enumerate(
+                    self.formula_recognition_model(formula_crop_img)
+                ):
+                    formula_region_id = formula_res_list[idx]["formula_region_id"]
+                    dt_polys = formula_res_list[idx]["dt_polys"]
+                    formula_rec_res["formula_region_id"] = formula_region_id
+                    formula_rec_res["dt_polys"] = dt_polys
+                    formula_res_list[idx] = formula_rec_res
 
             single_img_res = {
                 "input_path": input_path,