Explorar o código

feat: 添加无图片模式以仅分析表格结构,优化行列边界计算逻辑

zhch158_admin hai 1 día
pai
achega
74c95e92f5
Modificáronse 1 ficheiros con 83 adicións e 19 borrados
  1. 83 19
      table_line_generator/table_line_generator.py

+ 83 - 19
table_line_generator/table_line_generator.py

@@ -15,15 +15,19 @@ from bs4 import BeautifulSoup
 class TableLineGenerator:
     """表格线生成器"""
     
-    def __init__(self, image: Union[str, Image.Image], ocr_data: Dict):
+    def __init__(self, image: Union[str, Image.Image, None], ocr_data: Dict):
         """
         初始化表格线生成器
         
         Args:
-            image: 图片路径(str) 或 PIL.Image 对象
+            image: 图片路径(str) 或 PIL.Image 对象,或 None(仅分析结构时)
             ocr_data: OCR识别结果(包含bbox)
         """
-        if isinstance(image, str):
+        if image is None:
+            # 🆕 无图片模式:仅用于结构分析
+            self.image_path = None
+            self.image = None
+        elif isinstance(image, str):
             self.image_path = image
             self.image = Image.open(image)
         elif isinstance(image, Image.Image):
@@ -31,7 +35,7 @@ class TableLineGenerator:
             self.image = image
         else:
             raise TypeError(
-                f"image 参数必须是 str (路径) 或 PIL.Image.Image 对象,"
+                f"image 参数必须是 str (路径)、PIL.Image.Image 对象或 None,"
                 f"实际类型: {type(image)}"
             )
         
@@ -221,9 +225,6 @@ class TableLineGenerator:
         """
         基于单元格的 row/col 索引分析(MinerU 专用)
         
-        Args:
-            use_table_body: 是否使用 table_body 确定准确的行列数
-        
         Returns:
             表格结构信息
         """
@@ -265,20 +266,40 @@ class TableLineGenerator:
                 bboxes = cells_by_row[row_num]
                 y_min = min(bbox[1] for bbox in bboxes)
                 y_max = max(bbox[3] for bbox in bboxes)
-                row_boundaries[row_num] = (y_min, y_max)
-        
-        # 🔑 计算横线
+                row_boundaries[row_num] = (y_min, y_max)        
+
+        # 🔑 计算横线(现在使用的是过滤后的数据)
         horizontal_lines = _calculate_horizontal_lines_with_spacing(row_boundaries)
         
-        # 🔑 计算每的 x 边界
+        # 🔑 列边界计算(同样需要过滤异常值)
         col_boundaries = {}
         for col_num in range(1, actual_cols + 1):
             if col_num in cells_by_col:
                 bboxes = cells_by_col[col_num]
-                x_min = min(bbox[0] for bbox in bboxes)
-                x_max = max(bbox[2] for bbox in bboxes)
-                col_boundaries[col_num] = (x_min, x_max)
-        
+                
+                # 🎯 过滤 x 方向的异常值(使用 IQR)
+                if len(bboxes) > 1:
+                    x_centers = [(bbox[0] + bbox[2]) / 2 for bbox in bboxes]
+                    x_center_q1 = np.percentile(x_centers, 25)
+                    x_center_q3 = np.percentile(x_centers, 75)
+                    x_center_iqr = x_center_q3 - x_center_q1
+                    x_center_median = np.median(x_centers)
+                    
+                    # 允许偏移 3 倍 IQR 或至少 100px
+                    x_threshold = max(3 * x_center_iqr, 100)
+                    
+                    valid_bboxes = [
+                        bbox for bbox in bboxes
+                        if abs((bbox[0] + bbox[2]) / 2 - x_center_median) <= x_threshold
+                    ]
+                else:
+                    valid_bboxes = bboxes
+                
+                if valid_bboxes:
+                    x_min = min(bbox[0] for bbox in valid_bboxes)
+                    x_max = max(bbox[2] for bbox in valid_bboxes)
+                    col_boundaries[col_num] = (x_min, x_max)
+    
         # 🔑 计算竖线
         vertical_lines = _calculate_vertical_lines_with_spacing(col_boundaries)
         
@@ -317,7 +338,9 @@ class TableLineGenerator:
             'table_bbox': self._get_table_bbox(),
             'total_rows': actual_rows,
             'total_cols': actual_cols,
-            'method': 'mineru'
+            'mode': 'hybrid',  # ✅ 添加 mode 字段
+            'modified_h_lines': [],  # ✅ 添加修改记录字段
+            'modified_v_lines': []   # ✅ 添加修改记录字段
         }
     
     def _analyze_by_clustering(self, y_tolerance: int, x_tolerance: int, min_row_height: int) -> Dict:
@@ -356,7 +379,7 @@ class TableLineGenerator:
         
         # 4. 提取所有bbox的X坐标(用于列检测)
         x_coords = []
-        for item in self.ocr_data:
+        for item in ocr_data:
             bbox = item.get('bbox', [])
             if len(bbox) >= 4:
                 x1, x2 = bbox[0], bbox[2]
@@ -390,7 +413,9 @@ class TableLineGenerator:
             'row_height': self.row_height,
             'col_widths': self.col_widths,
             'table_bbox': self._get_table_bbox(),
-            'method': 'cluster'
+            'mode': 'fixed',  # ✅ 添加 mode 字段
+            'modified_h_lines': [],  # ✅ 添加修改记录字段
+            'modified_v_lines': []   # ✅ 添加修改记录字段
         }
 
     @staticmethod
@@ -498,6 +523,12 @@ class TableLineGenerator:
                             line_color: Tuple[int, int, int] = (0, 0, 255),
                             line_width: int = 2) -> Image.Image:
         """在原图上绘制表格线"""
+        if self.image is None:
+            raise ValueError(
+                "无图片模式下不能调用 generate_table_lines(),"
+                "请在初始化时提供图片"
+            )
+        
         img_with_lines = self.image.copy()
         draw = ImageDraw.Draw(img_with_lines)
         
@@ -526,6 +557,38 @@ class TableLineGenerator:
         
         return img_with_lines
 
+    @staticmethod
+    def analyze_structure_only(
+        ocr_data: Dict,
+        y_tolerance: int = 5,
+        x_tolerance: int = 10,
+        min_row_height: int = 20,
+        method: str = "auto"
+    ) -> Dict:
+        """
+        仅分析表格结构(无需图片)
+        
+        Args:
+            ocr_data: OCR识别结果
+            y_tolerance: Y轴聚类容差(像素)
+            x_tolerance: X轴聚类容差(像素)
+            min_row_height: 最小行高(像素)
+            method: 分析方法 ("auto" / "cluster" / "mineru")
+        
+        Returns:
+            表格结构信息
+        """
+        # 🔑 创建无图片模式的生成器
+        temp_generator = TableLineGenerator(None, ocr_data)
+        
+        # 🔑 分析结构
+        return temp_generator.analyze_table_structure(
+            y_tolerance=y_tolerance,
+            x_tolerance=x_tolerance,
+            min_row_height=min_row_height,
+            method=method
+        )
+
 
 def _calculate_horizontal_lines_with_spacing(row_boundaries: Dict[int, Tuple[int, int]]) -> List[int]:
     """
@@ -593,7 +656,8 @@ def _calculate_horizontal_lines_with_spacing(row_boundaries: Dict[int, Tuple[int
                 horizontal_lines.append(separator_y)
             else:
                 # 重叠或紧贴:在当前行的下边界画线
-                horizontal_lines.append(y_max)
+                separator_y = int(next_y_min) - max(int(gap / 4), 2)
+                horizontal_lines.append(separator_y)
         else:
             # 最后一行的下边界
             horizontal_lines.append(y_max)