Răsfoiți Sursa

feat: 统一OCR数据解析接口,支持多种工具类型并优化数据提取流程

zhch158_admin 2 zile în urmă
părinte
comite
7e26b885b4
1 a modificat fișierele cu 229 adăugiri și 108 ștergeri
  1. 229 108
      table_line_generator/table_line_generator.py

+ 229 - 108
table_line_generator/table_line_generator.py

@@ -15,7 +15,7 @@ from bs4 import BeautifulSoup
 class TableLineGenerator:
     """表格线生成器"""
     
-    def __init__(self, image: Union[str, Image.Image], ocr_data: List[Dict]):
+    def __init__(self, image: Union[str, Image.Image], ocr_data: Dict):
         """
         初始化表格线生成器
         
@@ -45,16 +45,34 @@ class TableLineGenerator:
 
 
     @staticmethod
-    def parse_mineru_table_result(mineru_result: Union[Dict, List], use_table_body: bool = True) -> Tuple[List[int], Dict]:
+    def parse_ocr_data(ocr_result: Dict, tool: str = "ppstructv3") -> Tuple[List[int], Dict]:
         """
-        解析 MinerU 格式的结果,自动提取 table 并计算行列分割线
+        统一的 OCR 数据解析接口(第一步:仅读取数据)
         
         Args:
-            mineru_result: MinerU 的完整 JSON 结果(可以是 dict 或 list)
-            use_table_body: 是否使用 table_body 来确定准确的行列数
+            ocr_result: OCR 识别结果(完整 JSON)
+            tool: 工具类型 ("ppstructv3" / "mineru")
+        
+        Returns:
+            (table_bbox, ocr_data): 表格边界框,以及包含 'table_bbox' 和 'text_boxes'(文本框列表)的数据字典
+        """
+        if tool.lower() == "mineru":
+            return TableLineGenerator._parse_mineru_data(ocr_result)
+        elif tool.lower() in ["ppstructv3", "ppstructure"]:
+            return TableLineGenerator._parse_ppstructure_data(ocr_result)
+        else:
+            raise ValueError(f"不支持的工具类型: {tool}")
+    
+    @staticmethod
+    def _parse_mineru_data(mineru_result: Union[Dict, List]) -> Tuple[List[int], Dict]:
+        """
+        解析 MinerU 格式数据(仅提取数据,不分析结构)
         
+        Args:
+            mineru_result: MinerU 的完整 JSON 结果
+            
         Returns:
-            (table_bbox, structure): 表格边界框和结构信息
+            (table_bbox, ocr_data): 表格边界框,以及包含 'table_bbox'、'actual_rows'、'actual_cols'、'text_boxes' 的数据字典
         """
         # 🔑 提取 table 数据
         table_data = _extract_table_data(mineru_result)
@@ -71,86 +89,21 @@ class TableLineGenerator:
             raise ValueError("table_cells 为空")
         
         # 🔑 优先使用 table_body 确定准确的行列数
-        if use_table_body and 'table_body' in table_data:
+        if 'table_body' in table_data:
             actual_rows, actual_cols = _parse_table_body_structure(table_data['table_body'])
             print(f"📋 从 table_body 解析: {actual_rows} 行 × {actual_cols} 列")
         else:
             # 回退:从 table_cells 推断
             actual_rows = max(cell.get('row', 0) for cell in table_cells if 'row' in cell)
             actual_cols = max(cell.get('col', 0) for cell in table_cells if 'col' in cell)
-            print(f"📋 从 table_cells 推断: {actual_rows} 行 × {actual_cols} 列")
-        
-        # 🔑 按行列索引分组单元格
-        cells_by_row = {}
-        cells_by_col = {}
-        
-        for cell in table_cells:
-            if 'row' not in cell or 'col' not in cell or 'bbox' not in cell:
-                continue
-            
-            row = cell['row']
-            col = cell['col']
-            bbox = cell['bbox']  # [x1, y1, x2, y2]
-            
-            # 仅保留在有效范围内的单元格
-            if row <= actual_rows and col <= actual_cols:
-                if row not in cells_by_row:
-                    cells_by_row[row] = []
-                cells_by_row[row].append(bbox)
-                
-                if col not in cells_by_col:
-                    cells_by_col[col] = []
-                cells_by_col[col].append(bbox)
-        
-        # 🔑 计算每行的 y 边界(考虑折行)
-        row_boundaries = {}
-        for row_num in range(1, actual_rows + 1):
-            if row_num in cells_by_row:
-                bboxes = cells_by_row[row_num]
-                y_min = min(bbox[1] for bbox in bboxes)
-                y_max = max(bbox[3] for bbox in bboxes)
-                row_boundaries[row_num] = (y_min, y_max)
-        
-        # 🔑 分析行间距,识别记录边界
-        horizontal_lines = _calculate_horizontal_lines_with_spacing(row_boundaries)
-        
-        # 🔑 计算竖线(考虑列间距)
-        col_boundaries = {}
-        for col_num in range(1, actual_cols + 1):
-            if col_num in cells_by_col:
-                bboxes = cells_by_col[col_num]
-                x_min = min(bbox[0] for bbox in bboxes)
-                x_max = max(bbox[2] for bbox in bboxes)
-                col_boundaries[col_num] = (x_min, x_max)
-        
-        vertical_lines = _calculate_vertical_lines_with_spacing(col_boundaries)
+            print(f"📋 从 table_cells 推断: {actual_rows} 行 × {actual_cols} 列")        
+        if not table_data or 'table_cells' not in table_data:
+            raise ValueError("未找到有效的 MinerU 表格数据")
         
-        # 🔑 生成行区间
-        rows = []
-        for row_num in sorted(row_boundaries.keys()):
-            y_min, y_max = row_boundaries[row_num]
-            rows.append({
-                'y_start': y_min,
-                'y_end': y_max,
-                'bboxes': cells_by_row.get(row_num, []),
-                'row_index': row_num
-            })
-        
-        # 🔑 生成列区间
-        columns = []
-        for col_num in sorted(col_boundaries.keys()):
-            x_min, x_max = col_boundaries[col_num]
-            columns.append({
-                'x_start': x_min,
-                'x_end': x_max,
-                'col_index': col_num
-            })
+        table_cells = table_data['table_cells']
         
         # 🔑 计算表格边界框
-        all_bboxes = [
-            cell['bbox'] for cell in table_cells 
-            if 'bbox' in cell and cell.get('row', 0) <= actual_rows and cell.get('col', 0) <= actual_cols
-        ]
+        all_bboxes = [cell['bbox'] for cell in table_cells if 'bbox' in cell]
         
         if all_bboxes:
             x_min = min(bbox[0] for bbox in all_bboxes)
@@ -161,31 +114,30 @@ class TableLineGenerator:
         else:
             table_bbox = table_data.get('bbox', [0, 0, 2000, 2000])
         
-        # 🔑 返回结构信息
-        structure = {
-            'rows': rows,
-            'columns': columns,
-            'horizontal_lines': horizontal_lines,
-            'vertical_lines': vertical_lines,
-            'row_height': int(np.median([r['y_end'] - r['y_start'] for r in rows])) if rows else 0,
-            'col_widths': [c['x_end'] - c['x_start'] for c in columns],
+        # 按位置排序(从上到下,从左到右)
+        table_cells.sort(key=lambda x: (x['bbox'][1], x['bbox'][0]))
+        # 🔑 转换为统一的 ocr_data 格式
+        ocr_data = {
             'table_bbox': table_bbox,
-            'total_rows': actual_rows,
-            'total_cols': actual_cols
+            'actual_rows': actual_rows,
+            'actual_cols': actual_cols,
+            'text_boxes': table_cells
         }
         
-        return table_bbox, structure
-    
+        print(f"📊 MinerU 数据解析完成: {len(table_cells)} 个文本框")
+        
+        return table_bbox, ocr_data
+
     @staticmethod
-    def parse_ppstructure_result(ocr_result: Dict) -> Tuple[List[int], List[Dict]]:
+    def _parse_ppstructure_data(ocr_result: Dict) -> Tuple[List[int], Dict]:
         """
-        解析 PPStructure V3 的 OCR 结果
+        解析 PPStructure V3 格式数据
         
         Args:
             ocr_result: PPStructure V3 的完整 JSON 结果
         
         Returns:
-            (table_bbox, text_boxes): 表格边界框和文本框列表
+            (table_bbox, ocr_data): 表格边界框,以及包含 'table_bbox' 和 'text_boxes'(文本框列表)的数据字典
         """
         # 1. 从 parsing_res_list 中找到 table 区域
         table_bbox = None
@@ -198,7 +150,7 @@ class TableLineGenerator:
         if not table_bbox:
             raise ValueError("未找到表格区域 (block_label='table')")
         
-        # 2. 从 overall_ocr_res 中提取文本框(使用 rec_boxes)
+        # 2. 从 overall_ocr_res 中提取文本框
         text_boxes = []
         if 'overall_ocr_res' in ocr_result:
             rec_boxes = ocr_result['overall_ocr_res'].get('rec_boxes', [])
@@ -207,7 +159,6 @@ class TableLineGenerator:
             # 过滤出表格区域内的文本框
             for i, bbox in enumerate(rec_boxes):
                 if len(bbox) >= 4:
-                    # bbox 格式: [x1, y1, x2, y2]
                     x1, y1, x2, y2 = bbox[:4]
                     
                     # 判断文本框是否在表格区域内
@@ -217,32 +168,177 @@ class TableLineGenerator:
                             'bbox': [int(x1), int(y1), int(x2), int(y2)],
                             'text': rec_texts[i] if i < len(rec_texts) else ''
                         })
-            # 对text_boxes从上到下,从左到右排序
-            text_boxes.sort(key=lambda x: (x['bbox'][1], x['bbox'][0]))
         
-        return table_bbox, text_boxes
+        # 按位置排序
+        text_boxes.sort(key=lambda x: (x['bbox'][1], x['bbox'][0]))
+        
+        print(f"📊 PPStructure 数据解析完成: {len(text_boxes)} 个文本框")
+        ocr_data = {
+            'table_bbox': table_bbox,
+            'text_boxes': text_boxes
+        }
         
+        return table_bbox, ocr_data
+    
+    # ==================== 统一接口:第二步 - 分析结构 ====================
+    
     def analyze_table_structure(self, 
                                y_tolerance: int = 5,
                                x_tolerance: int = 10,
-                               min_row_height: int = 20) -> Dict:
+                               min_row_height: int = 20,
+                               method: str = "auto",
+                               ) -> Dict:
         """
-        分析表格结构(行列分布)
+        分析表格结构(支持多种算法)
         
         Args:
             y_tolerance: Y轴聚类容差(像素)
             x_tolerance: X轴聚类容差(像素)
             min_row_height: 最小行高(像素)
+            method: 分析方法 ("auto" / "cluster" / "mineru");"auto" 在文本框带有
+                row/col 索引时选用 "mineru",否则选用 "cluster"
+        
+        Returns:
+            表格结构信息
+        """
+        if not self.ocr_data:
+            return {}
+        
+        # 🔑 自动选择方法
+        if method == "auto":
+            # 根据数据特征自动选择
+            has_cell_index = any('row' in item and 'col' in item for item in self.ocr_data.get('text_boxes', []))
+            method = "mineru" if has_cell_index else "cluster"
+            print(f"🤖 自动选择分析方法: {method}")
+        
+        # 🔑 根据方法选择算法
+        if method == "mineru":
+            return self._analyze_by_cell_index()
+        else:
+            return self._analyze_by_clustering(y_tolerance, x_tolerance, min_row_height)
+
+    def _analyze_by_cell_index(self) -> Dict:
+        """
+        基于单元格的 row/col 索引分析(MinerU 专用)
+        
+        行列数读取自 self.ocr_data 的 'actual_rows' / 'actual_cols',
+        单元格读取自 self.ocr_data 的 'text_boxes'(本方法除 self 外不接收参数)
         
         Returns:
             表格结构信息
         """
         if not self.ocr_data:
             return {}
+
+        # 🔑 确定实际行列数
+        actual_rows = self.ocr_data.get('actual_rows', 0)
+        actual_cols = self.ocr_data.get('actual_cols', 0)
+        print(f"📋 检测到: {actual_rows} 行 × {actual_cols} 列")
+
+        ocr_data = self.ocr_data.get('text_boxes', [])
         
+        # 🔑 按行列索引分组单元格
+        cells_by_row = {}
+        cells_by_col = {}
+        
+        for item in ocr_data:
+            if 'row' not in item or 'col' not in item:
+                continue
+            
+            row = item['row']
+            col = item['col']
+            bbox = item['bbox']
+            
+            if row <= actual_rows and col <= actual_cols:
+                if row not in cells_by_row:
+                    cells_by_row[row] = []
+                cells_by_row[row].append(bbox)
+                
+                if col not in cells_by_col:
+                    cells_by_col[col] = []
+                cells_by_col[col].append(bbox)
+        
+        # 🔑 计算每行的 y 边界
+        row_boundaries = {}
+        for row_num in range(1, actual_rows + 1):
+            if row_num in cells_by_row:
+                bboxes = cells_by_row[row_num]
+                y_min = min(bbox[1] for bbox in bboxes)
+                y_max = max(bbox[3] for bbox in bboxes)
+                row_boundaries[row_num] = (y_min, y_max)
+        
+        # 🔑 计算横线
+        horizontal_lines = _calculate_horizontal_lines_with_spacing(row_boundaries)
+        
+        # 🔑 计算每列的 x 边界
+        col_boundaries = {}
+        for col_num in range(1, actual_cols + 1):
+            if col_num in cells_by_col:
+                bboxes = cells_by_col[col_num]
+                x_min = min(bbox[0] for bbox in bboxes)
+                x_max = max(bbox[2] for bbox in bboxes)
+                col_boundaries[col_num] = (x_min, x_max)
+        
+        # 🔑 计算竖线
+        vertical_lines = _calculate_vertical_lines_with_spacing(col_boundaries)
+        
+        # 🔑 生成行区间
+        self.rows = []
+        for row_num in sorted(row_boundaries.keys()):
+            y_min, y_max = row_boundaries[row_num]
+            self.rows.append({
+                'y_start': y_min,
+                'y_end': y_max,
+                'bboxes': cells_by_row.get(row_num, []),
+                'row_index': row_num
+            })
+        
+        # 🔑 生成列区间
+        self.columns = []
+        for col_num in sorted(col_boundaries.keys()):
+            x_min, x_max = col_boundaries[col_num]
+            self.columns.append({
+                'x_start': x_min,
+                'x_end': x_max,
+                'col_index': col_num
+            })
+        
+        # 计算行高和列宽
+        self.row_height = int(np.median([r['y_end'] - r['y_start'] for r in self.rows])) if self.rows else 0
+        self.col_widths = [c['x_end'] - c['x_start'] for c in self.columns]
+        
+        return {
+            'rows': self.rows,
+            'columns': self.columns,
+            'horizontal_lines': horizontal_lines,
+            'vertical_lines': vertical_lines,
+            'row_height': self.row_height,
+            'col_widths': self.col_widths,
+            'table_bbox': self._get_table_bbox(),
+            'total_rows': actual_rows,
+            'total_cols': actual_cols,
+            'method': 'mineru'
+        }
+    
+    def _analyze_by_clustering(self, y_tolerance: int, x_tolerance: int, min_row_height: int) -> Dict:
+        """
+        基于坐标聚类分析(通用方法)
+        
+        Args:
+            y_tolerance: Y轴聚类容差
+            x_tolerance: X轴聚类容差
+            min_row_height: 最小行高
+        
+        Returns:
+            表格结构信息
+        """
+        if not self.ocr_data:
+            return {}
+
+        ocr_data = self.ocr_data.get('text_boxes', [])
         # 1. 提取所有bbox的Y坐标(用于行检测)
         y_coords = []
-        for item in self.ocr_data:
+        for item in ocr_data:
             bbox = item.get('bbox', [])
             if len(bbox) >= 4:
                 y1, y2 = bbox[1], bbox[3]
@@ -251,10 +347,10 @@ class TableLineGenerator:
         # 按Y坐标排序
         y_coords.sort(key=lambda x: x[0])
         
-        # 2. 聚类检测行(基于Y坐标相近的bbox)
+        # 2. 聚类检测行
         self.rows = self._cluster_rows(y_coords, y_tolerance, min_row_height)
         
-        # 3. 计算标准行高(中位数)
+        # 3. 计算标准行高
         row_heights = [row['y_end'] - row['y_start'] for row in self.rows]
         self.row_height = int(np.median(row_heights)) if row_heights else 30
         
@@ -266,20 +362,20 @@ class TableLineGenerator:
                 x1, x2 = bbox[0], bbox[2]
                 x_coords.append((x1, x2))
         
-        # 5. 聚类检测列(基于X坐标相近的bbox)
+        # 5. 聚类检测列
         self.columns = self._cluster_columns(x_coords, x_tolerance)
         
-        # 6. 计算列宽
+        # 6. 计算列宽
         self.col_widths = [col['x_end'] - col['x_start'] for col in self.columns]
         
-        # 7. 生成横线坐标列表
+        # 7. 生成横线坐标
         horizontal_lines = []
         for row in self.rows:
             horizontal_lines.append(row['y_start'])
         if self.rows:
             horizontal_lines.append(self.rows[-1]['y_end'])
         
-        # 8. 生成竖线坐标列表
+        # 8. 生成竖线坐标
         vertical_lines = []
         for col in self.columns:
             vertical_lines.append(col['x_start'])
@@ -293,9 +389,34 @@ class TableLineGenerator:
             'vertical_lines': vertical_lines,
             'row_height': self.row_height,
             'col_widths': self.col_widths,
-            'table_bbox': self._get_table_bbox()
+            'table_bbox': self._get_table_bbox(),
+            'method': 'cluster'
         }
-    
+
+    @staticmethod
+    def parse_mineru_table_result(mineru_result: Union[Dict, List], use_table_body: bool = True) -> Tuple[List[int], Dict]:
+        """
+        [已弃用] 建议使用 parse_ocr_data() + analyze_table_structure()
+        
+        保留此方法仅为提示迁移:调用时会发出 DeprecationWarning,随后抛出 NotImplementedError
+        """
+        import warnings
+        warnings.warn(
+            "parse_mineru_table_result() 已弃用,请使用 "
+            "parse_ocr_data() + analyze_table_structure()",
+            DeprecationWarning
+        )
+        raise NotImplementedError( "parse_mineru_table_result() 已弃用,请使用 " "parse_ocr_data() + analyze_table_structure()")
+
+    @staticmethod
+    def parse_ppstructure_result(ocr_result: Dict) -> Tuple[List[int], Dict]:
+        """
+        [推荐] 解析 PPStructure V3 的 OCR 结果
+        
+        这是第一步操作,建议继续使用;返回 (table_bbox, ocr_data),其中 ocr_data 为包含 'table_bbox' 和 'text_boxes' 的字典
+        """
+        return TableLineGenerator._parse_ppstructure_data(ocr_result)
+        
     def _cluster_rows(self, y_coords: List[Tuple], tolerance: int, min_height: int) -> List[Dict]:
         """聚类检测行"""
         if not y_coords: