Эх сурвалжийг харах

support inline formula embedding & update reference format (#2907)

shuai.liu 10 сар өмнө
parent
commit
d52120387e

+ 41 - 16
paddlex/inference/pipelines_new/layout_parsing/pipeline_v2.py

@@ -336,6 +336,24 @@ class LayoutParsingPipelineV2(BasePipeline):
                 self.layout_det_model(doc_preprocessor_image),
             )
 
+            if model_settings["use_formula_recognition"]:
+                formula_res_all = next(
+                    self.formula_recognition_pipeline(
+                        doc_preprocessor_image,
+                        use_layout_detection=False,
+                        use_doc_orientation_classify=False,
+                        use_doc_unwarping=False,
+                        layout_det_res=layout_det_res,
+                    ),
+                )
+                formula_res_list = formula_res_all["formula_res_list"]
+            else:
+                formula_res_list = []
+
+            for formula_res in formula_res_list:
+                x_min, y_min, x_max, y_max = list(map(int, formula_res["dt_polys"]))
+                doc_preprocessor_image[y_min:y_max, x_min:x_max, :] = 255.0
+
             if (
                 model_settings["use_general_ocr"]
                 or model_settings["use_table_recognition"]
@@ -351,6 +369,24 @@ class LayoutParsingPipelineV2(BasePipeline):
                         text_rec_score_thresh=text_rec_score_thresh,
                     ),
                 )
+
+                for formula_res in formula_res_list:
+                    x_min, y_min, x_max, y_max = list(map(int, formula_res["dt_polys"]))
+                    poly_points = [
+                        (x_min, y_min),
+                        (x_max, y_min),
+                        (x_max, y_max),
+                        (x_min, y_max),
+                    ]
+                    overall_ocr_res["dt_polys"].append(poly_points)
+                    overall_ocr_res["rec_texts"].append(
+                        f"${formula_res['rec_formula']}$"
+                    )
+                    overall_ocr_res["rec_boxes"] = np.vstack(
+                        (overall_ocr_res["rec_boxes"], [formula_res["dt_polys"]])
+                    )
+                    overall_ocr_res["rec_polys"].append(poly_points)
+                    overall_ocr_res["rec_scores"].append(1)
             else:
                 overall_ocr_res = {}
 
@@ -398,22 +434,11 @@ class LayoutParsingPipelineV2(BasePipeline):
             else:
                 seal_res_list = []
 
-            if model_settings["use_formula_recognition"]:
-                formula_res_all = next(
-                    self.formula_recognition_pipeline(
-                        doc_preprocessor_image,
-                        use_layout_detection=False,
-                        use_doc_orientation_classify=False,
-                        use_doc_unwarping=False,
-                        layout_det_res=layout_det_res,
-                    ),
-                )
-                formula_res_list = formula_res_all["formula_res_list"]
-            else:
-                formula_res_list = []
-
-            for table_res in table_res_list:
-                table_res["layout_bbox"] = table_res["cell_box_list"][0]
+            for formula_res in formula_res_list:
+                x_min, y_min, x_max, y_max = list(map(int, formula_res["dt_polys"]))
+                doc_preprocessor_image[y_min:y_max, x_min:x_max, :] = formula_res[
+                    "input_img"
+                ]
 
             structure_res = get_structure_res(
                 overall_ocr_res,

+ 2 - 2
paddlex/inference/pipelines_new/layout_parsing/result_v2.py

@@ -395,11 +395,11 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
                 return "\n".join(img_tags)
 
             def format_reference():
-                pattern = r"\[\d+\]"
+                pattern = r"\s*\[\s*\d+\s*\]\s*"
                 res = re.sub(
                     pattern,
                     lambda match: "\n" + match.group(),
-                    sub_block["reference"],
+                    sub_block["reference"].replace("\n", ""),
                 )
                 return "\n" + res
 

+ 7 - 4
paddlex/inference/pipelines_new/layout_parsing/utils.py

@@ -191,7 +191,7 @@ def _sort_box_by_y_projection(layout_bbox, ocr_res, line_height_threshold=0.7):
         first_span = line[0]
         end_span = line[-1]
         if first_span[0][0] - x_min > 20:
-            first_span[1] = "\n " + first_span[1]
+            first_span[1] = "\n" + first_span[1]
         if x_max - end_span[0][2] > 20:
             end_span[1] = end_span[1] + "\n"
 
@@ -235,13 +235,12 @@ def get_structure_res(
         layout_bbox = box_info["coordinate"]
         label = box_info["label"]
         rec_res = {"boxes": [], "rec_texts": [], "flag": False}
-        drop_index = []
         seg_start_flag = True
         seg_end_flag = True
 
         if label == "table":
             for i, table_res in enumerate(table_res_list):
-                if calculate_iou(layout_bbox, table_res["layout_bbox"]) > 0.5:
+                if calculate_iou(layout_bbox, table_res["cell_box_list"][0]) > 0.5:
                     structure_boxes.append(
                         {
                             "label": label,
@@ -262,7 +261,6 @@ def get_structure_res(
                         overall_ocr_res["rec_texts"][box_no],
                     )
                     rec_res["flag"] = True
-                    drop_index.append(box_no)
 
             if rec_res["flag"]:
                 rec_res = _sort_box_by_y_projection(layout_bbox, rec_res, 0.7)
@@ -272,6 +270,11 @@ def get_structure_res(
                     seg_start_flag = False
                 if layout_bbox[2] - rec_res_end_bbox[2] < 20:
                     seg_end_flag = False
+                if label == "formula":
+                    rec_res["rec_texts"] = [
+                        rec_res_text.replace("$", "")
+                        for rec_res_text in rec_res["rec_texts"]
+                    ]
 
             if label in ["chart", "image"]:
                 structure_boxes.append(