소스 검색

fix formula to ocr_rec bug (#3999)

changdazhou 6 달 전
부모
커밋
fe92f3f5ea

+ 6 - 3
paddlex/inference/pipelines/layout_parsing/pipeline_v2.py

@@ -1149,9 +1149,12 @@ class _LayoutParsingPipelineV2(BasePipeline):
                             (x_min, y_max),
                         ]
                         table_contents_for_img["dt_polys"].append(poly_points)
-                        table_contents_for_img["rec_texts"].append(
-                            f"${formula_res['rec_formula']}$"
-                        )
+                        rec_formula = formula_res["rec_formula"]
+                        if not rec_formula.startswith("$") or not rec_formula.endswith(
+                            "$"
+                        ):
+                            rec_formula = f"${rec_formula}$"
+                        table_contents_for_img["rec_texts"].append(f"{rec_formula}")
                         if table_contents_for_img["rec_boxes"].size == 0:
                             table_contents_for_img["rec_boxes"] = np.array(
                                 [formula_res["dt_polys"]]

+ 2 - 2
paddlex/inference/pipelines/layout_parsing/result_v2.py

@@ -709,14 +709,14 @@ class LayoutParsingRegion:
 
     def calculate_bbox_metrics(self, image_shape):
         x1, y1, x2, y2 = self.bbox
-        width = x2 - x1
         image_height, image_width = image_shape
+        width = x2 - x1
         x_center, y_center = (x1 + x2) / 2, (y1 + y2) / 2
         self.euclidean_distance = math.sqrt(((x1) ** 2 + (y1) ** 2))
         self.center_euclidean_distance = math.sqrt(((x_center) ** 2 + (y_center) ** 2))
         self.angle_rad = math.atan2(y_center, x_center)
         self.weighted_distance = (
-            y1 + width + (x1 // (image_width // 10)) * (image_width // 10) * 1.5
+            y2 + width + (x1 // (image_width // 10)) * (image_width // 10) * 1.5
         )
 
     def sort_normal_blocks(self, blocks):

+ 6 - 8
paddlex/inference/pipelines/layout_parsing/utils.py

@@ -442,10 +442,12 @@ def format_line(
 
     for span in line:
         if span[2] == "formula" and block_label != "formula":
-            if len(line) > 1:
-                span[1] = f"${span[1]}$"
-            else:
-                span[1] = f"\n${span[1]}$"
+            formula_rec = span[1]
+            if not formula_rec.startswith("$") and not formula_rec.endswith("$"):
+                if len(line) > 1:
+                    span[1] = f"${span[1]}$"
+                else:
+                    span[1] = f"\n${span[1]}$"
 
     line_text = ""
     for span in line:
@@ -881,10 +883,6 @@ def convert_formula_res_to_ocr_format(formula_res_list: List, ocr_res: dict):
         ]
         ocr_res["dt_polys"].append(poly_points)
         formula_res_text: str = formula_res["rec_formula"]
-        if formula_res_text.startswith("$$") and formula_res_text.endswith("$$"):
-            formula_res_text = formula_res_text[2:-2]
-        elif formula_res_text.startswith("$") and formula_res_text.endswith("$"):
-            formula_res_text = formula_res_text[1:-1]
         ocr_res["rec_texts"].append(formula_res_text)
         if ocr_res["rec_boxes"].size == 0:
             ocr_res["rec_boxes"] = np.array(formula_res["dt_polys"])