Просмотр исходного кода

fix: 修复文档预处理适配器中的类型注解和优化文本框检测逻辑

zhch158_admin 2 недели назад
Родитель
Коммит
df50db1eeb
1 изменённый файл: 19 добавлений и 50 удалений
  1. 19 50
      zhch/adapters/doc_preprocessor_adapter.py

+ 19 - 50
zhch/adapters/doc_preprocessor_adapter.py

@@ -5,7 +5,7 @@
 
 import sys
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, Tuple
 import numpy as np
 import cv2
 
@@ -96,7 +96,7 @@ class EnhancedDocPreprocessor:
         print(f"   📏 Image size: {img_width}x{img_height}, aspect_ratio: {aspect_ratio:.2f}, is_portrait: {is_portrait}")
         return is_portrait
     
-    def _detect_vertical_text_boxes(self, image: np.ndarray) -> tuple[int, int]:
+    def _detect_vertical_text_boxes(self, image: np.ndarray) -> Tuple[int, int]:
         """
         检测图片中的垂直文本框
         
@@ -133,55 +133,24 @@ class EnhancedDocPreprocessor:
             total_count = len(boxes)
             
             # 🎯 处理 numpy 数组格式: shape=(N, 4, 2)
-            if isinstance(boxes, np.ndarray):
-                if len(boxes.shape) == 3 and boxes.shape[1] == 4 and boxes.shape[2] == 2:
-                    # 格式: (N, 4, 2) - 每个框有4个点,每个点有(x,y)坐标
-                    for box in boxes:
-                        # box: shape=(4, 2) - [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
-                        p1, p2, p3, p4 = box
-                        
-                        # 计算宽高
-                        width = abs(float(p2[0] - p1[0]))  # x2 - x1
-                        height = abs(float(p3[1] - p2[1]))  # y3 - y2
-                        
-                        if height == 0:
-                            continue
-                        
-                        aspect_ratio = width / height
-                        
-                        # 🎯 MinerU 的判断标准:宽高比 < 0.8 为垂直文本
-                        if aspect_ratio < 0.8:
-                            vertical_count += 1
-                else:
-                    # 其他格式,尝试遍历处理
-                    for box in boxes:
-                        if isinstance(box, np.ndarray) and len(box) >= 4:
-                            self._process_single_box(box, vertical_count)
-            else:
-                # 处理列表格式
+            if isinstance(boxes, np.ndarray) and len(boxes.shape) == 3 and boxes.shape[1] == 4 and boxes.shape[2] == 2:
+                # 格式: (N, 4, 2) - 每个框有4个点,每个点有(x,y)坐标
                 for box in boxes:
-                    if isinstance(box, (list, tuple, np.ndarray)):
-                        if len(box) >= 4:
-                            # 格式: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
-                            if isinstance(box[0], (list, tuple, np.ndarray)) and len(box[0]) >= 2:
-                                p1, p2, p3, p4 = box[:4]
-                                width = abs(float(p2[0]) - float(p1[0]))
-                                height = abs(float(p3[1]) - float(p2[1]))
-                            # 格式: [x1,y1,x2,y2,x3,y3,x4,y4]
-                            elif len(box) >= 8:
-                                width = abs(float(box[2]) - float(box[0]))
-                                height = abs(float(box[5]) - float(box[3]))
-                            else:
-                                continue
-                            
-                            if height == 0:
-                                continue
-                            
-                            aspect_ratio = width / height
-                            
-                            # 🎯 MinerU 的判断标准:宽高比 < 0.8 为垂直文本
-                            if aspect_ratio < 0.8:
-                                vertical_count += 1
+                    # box: shape=(4, 2) - [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
+                    p1, p2, p3, p4 = box
+                    
+                    # 计算宽高
+                    width = abs(float(p2[0] - p1[0]))  # x2 - x1
+                    height = abs(float(p3[1] - p2[1]))  # y3 - y2
+                    
+                    if height == 0:
+                        continue
+                    
+                    aspect_ratio = width / height
+                    
+                    # 🎯 MinerU 的判断标准:宽高比 < 0.8 为垂直文本
+                    if aspect_ratio < 0.8:
+                        vertical_count += 1
             
             print(f"   📊 OCR detection: {vertical_count}/{total_count} vertical boxes ({vertical_count/total_count:.1%} vertical)")
             return vertical_count, total_count