瀏覽代碼

[cherry-pick]mv crop formula from gen_ai_client to pipeline (#4679)

* update docs

* compatible with python3.9

* support print parsing_res_list

* mv crop formula from gen_ai_client to pipeline
changdazhou 2 周之前
父節點
當前提交
0af6510a6e

+ 0 - 43
paddlex/inference/models/doc_vlm/predictor.py

@@ -370,46 +370,6 @@ class DocVLMPredictor(BasePredictor):
         }
         return rst_dict
 
-    def crop_margin(self, img):  # input is an OpenCV image (numpy array)
-        import cv2
-
-        # If the input is a colour image, convert it to grayscale
-        if len(img.shape) == 3:
-            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
-        else:
-            gray = img.copy()
-
-        # Convert to the 0-255 range (make sure the dtype is uint8)
-        if gray.dtype != np.uint8:
-            gray = gray.astype(np.uint8)
-
-        max_val = gray.max()
-        min_val = gray.min()
-
-        if max_val == min_val:
-            return img
-
-        # Normalise and binarise (same logic as the PIL version)
-        data = (gray - min_val) / (max_val - min_val) * 255
-        data = data.astype(np.uint8)
-
-        # Build the binary image (dark regions become white, bright regions black)
-        _, binary = cv2.threshold(data, 200, 255, cv2.THRESH_BINARY_INV)
-
-        # Find the coordinates of the non-zero pixels
-        coords = cv2.findNonZero(binary)
-
-        if coords is None:  # nothing was found, so return the original image
-            return img
-
-        # Get the bounding box
-        x, y, w, h = cv2.boundingRect(coords)
-
-        # Crop the image
-        cropped = img[y : y + h, x : x + w]
-
-        return cropped
-
     def _genai_client_process(
         self,
         data,
@@ -425,9 +385,6 @@ class DocVLMPredictor(BasePredictor):
 
         def _process(item):
             image = item["image"]
-            prompt = item["query"]
-            if prompt == "Formula Recognition:":
-                image = self.crop_margin(image)
             if isinstance(image, str):
                 if image.startswith("http://") or image.startswith("https://"):
                     image_url = image

+ 2 - 0
paddlex/inference/pipelines/paddleocr_vl/pipeline.py

@@ -35,6 +35,7 @@ from ..layout_parsing.utils import gather_imgs
 from .result import PaddleOCRVLBlock, PaddleOCRVLResult
 from .uilts import (
     convert_otsl_to_html,
+    crop_margin,
     filter_overlap_boxes,
     merge_blocks,
     tokenize_figure_of_table,
@@ -243,6 +244,7 @@ class _PaddleOCRVLPipeline(BasePipeline):
                         text_prompt = "Chart Recognition:"
                     elif "formula" in block_label and block_label != "formula_number":
                         text_prompt = "Formula Recognition:"
+                        block_img = crop_margin(block_img)
                     block_imgs.append(block_img)
                     text_prompts.append(text_prompt)
                     figure_token_maps.append(figure_token_map)

+ 32 - 0
paddlex/inference/pipelines/paddleocr_vl/uilts.py

@@ -923,3 +923,35 @@ def truncate_repetitive_content(
         return most_common_line
 
     return content
+
+
+def crop_margin(img, threshold=200):
+    """Crop the light margins around the darker content of an image.
+
+    Args:
+        img: numpy image: HxW, HxWx1, or colour (read as OpenCV BGR).
+        threshold: 0-255 cut-off on the min-max normalised grey image;
+            pixels at or below it count as content.
+
+    Returns:
+        ``img`` sliced to the content bounding box, or ``img`` unchanged if
+        it is empty, uniform, or no pixel is at or below the threshold.
+    """
+    if img.size == 0:  # min()/max() raise on zero-size arrays
+        return img
+    gray = img
+    if img.ndim == 3 and img.shape[2] == 1:
+        gray = img[:, :, 0]  # cvtColor rejects single-channel input
+    elif img.ndim == 3:
+        import cv2
+        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+    # Normalise in float; a uint8 cast first would wrap or truncate other dtypes.
+    gray = gray.astype(np.float64)
+    lo, hi = gray.min(), gray.max()
+    if hi == lo:
+        return img
+    mask = ((gray - lo) / (hi - lo) * 255).astype(np.uint8) <= threshold
+    if not mask.any():
+        return img
+    rows, cols = (np.flatnonzero(mask.any(axis=axis)) for axis in (1, 0))
+    return img[rows[0] : rows[-1] + 1, cols[0] : cols[-1] + 1]