소스 검색

feat: 支持混合模式,优化模板应用逻辑并增强OCR数据处理

zhch158_admin 2 일 전
부모
커밋
adb6af311f
1개의 변경된 파일327개의 추가작업 그리고 102개의 파일을 삭제
  1. 327 102
      table_line_generator/table_template_applier.py

+ 327 - 102
table_line_generator/table_template_applier.py

@@ -9,63 +9,60 @@ from PIL import Image, ImageDraw
 from typing import Dict, List, Tuple
 import numpy as np
 import argparse
+import sys
+
+# 添加父目录到路径
+sys.path.insert(0, str(Path(__file__).parent))
 
 try:
+    from editor.data_processor import get_structure_from_ocr
     from table_line_generator import TableLineGenerator
 except ImportError:
+    from .editor.data_processor import get_structure_from_ocr
     from .table_line_generator import TableLineGenerator
 
 
 class TableTemplateApplier:
-    """表格模板应用器"""
+    """表格模板应用器(混合模式)"""
     
     def __init__(self, template_config_path: str):
-        """
-        初始化模板应用器
-        
-        Args:
-            template_config_path: 模板配置文件路径(人工标注的结果)
-        """
+        """初始化时只提取列信息和表头信息"""
         with open(template_config_path, 'r', encoding='utf-8') as f:
             self.template = json.load(f)
         
-        # 🎯 从标注结果提取固定参数
+        # ✅ 只提取列宽(固定)
         self.col_widths = self.template['col_widths']
         
-        # 🔧 计算数据行的标准行高(排除表头)
-        rows = self.template['rows']
-        if len(rows) > 1:
-            # 计算每行的实际高度
-            row_heights = [row['y_end'] - row['y_start'] for row in rows]
-            
-            # 🎯 假设第一行是表头,从第二行开始计算
-            data_row_heights = row_heights[1:] if len(row_heights) > 1 else row_heights
-            
-            # 使用中位数作为标准行高(更稳健)
-            self.row_height = int(np.median(data_row_heights))
-            self.header_height = row_heights[0] if row_heights else self.row_height
-            
-            print(f"📏 表头高度: {self.header_height}px")
-            print(f"📏 数据行高度: {self.row_height}px")
-            print(f"   (从 {len(data_row_heights)} 行数据中计算,中位数)")
-        else:
-            # 兜底方案
-            self.row_height = self.template.get('row_height', 60)
-            self.header_height = self.row_height
-        
-        # 🎯 计算列的相对位置(从第一列开始的偏移量)
+        # ✅ 计算列的相对位置
         self.col_offsets = [0]
         for width in self.col_widths:
             self.col_offsets.append(self.col_offsets[-1] + width)
         
-        # 🎯 提取表头的Y坐标(作为参考)
-        self.template_header_y = rows[0]['y_start'] if rows else 0
+        # ✅ 提取表头高度(通常固定)
+        rows = self.template['rows']
+        if rows:
+            self.header_height = rows[0]['y_end'] - rows[0]['y_start']
+        else:
+            self.header_height = 40
+        
+        # ✅ 计算数据行高度(用于固定行高模式)
+        if len(rows) > 1:
+            data_row_heights = [row['y_end'] - row['y_start'] for row in rows[1:]]
+            # 使用中位数作为典型行高
+            self.row_height = int(np.median(data_row_heights)) if data_row_heights else 40
+            # 兜底行高(同样使用中位数)
+            self.fallback_row_height = self.row_height
+        else:
+            # 如果只有表头,使用默认值
+            self.row_height = 40
+            self.fallback_row_height = 40
         
         print(f"\n✅ 加载模板配置:")
-        print(f"   表头高度: {self.header_height}px")
-        print(f"   数据行高度: {self.row_height}px")
         print(f"   列数: {len(self.col_widths)}")
         print(f"   列宽: {self.col_widths}")
+        print(f"   表头高度: {self.header_height}px")
+        print(f"   数据行高: {self.row_height}px (用于固定行高模式)")
+        print(f"   兜底行高: {self.fallback_row_height}px (OCR失败时使用)")
     
     def detect_table_anchor(self, ocr_data: List[Dict]) -> Tuple[int, int]:
         """
@@ -128,14 +125,14 @@ class TableTemplateApplier:
         
         return total_rows
     
-    def apply_to_image(self, 
+    def apply_template_fixed(self, 
                        image: Image.Image,
                        ocr_data: List[Dict],
                        anchor_x: int = None,
                        anchor_y: int = None,
                        num_rows: int = None,
                        line_width: int = 2,
-                       line_color: Tuple[int, int, int] = (0, 0, 0)) -> Image.Image:
+                       line_color: Tuple[int, int, int] = (0, 0, 0)) -> Tuple[Image.Image, Dict]:
         """
         将模板应用到图片
         
@@ -202,62 +199,208 @@ class TableTemplateApplier:
         y_end = horizontal_lines[-1]
         for x in vertical_lines:
             draw.line([(x, y_start), (x, y_end)], fill=line_color, width=line_width)
+
+        print(f"✅ 表格绘制完成: {len(horizontal_lines)}行 × {len(vertical_lines)-1}列")
+
+                # 🔑 生成结构信息
+        structure = self._build_structure(
+            horizontal_lines, 
+            vertical_lines, 
+            anchor_x, 
+            anchor_y,
+            mode='fixed'
+        )
         
-        return img_with_lines
-    
-    def generate_structure_for_image(self,
-                                    ocr_data: List[Dict],
-                                    anchor_x: int = None,
-                                    anchor_y: int = None,
-                                    num_rows: int = None) -> Dict:
+        return img_with_lines, structure
+    
+    def apply_template_hybrid(self,
+                             image: Image.Image,
+                             ocr_data_dict: Dict,
+                             use_ocr_rows: bool = True,
+                             anchor_x: int = None,
+                             anchor_y: int = None,
+                             y_tolerance: int = 5,
+                             line_width: int = 2,
+                             line_color: Tuple[int, int, int] = (0, 0, 0)) -> Tuple[Image.Image, Dict]:
         """
-        为新图片生成表格结构配置
+        混合模式:使用模板的列 + OCR的行
         
         Args:
-            ocr_data: OCR识别结果
+            image: 目标图片
+            ocr_data: OCR识别结果(用于检测行)
+            use_ocr_rows: 是否使用OCR检测的行(True=自适应行高)
             anchor_x: 表格起始X坐标(None=自动检测)
             anchor_y: 表头起始Y坐标(None=自动检测)
-            num_rows: 总行数(None=自动检测)
+            y_tolerance: Y轴聚类容差(像素)
+            line_width: 线条宽度
+            line_color: 线条颜色
         
         Returns:
-            表格结构配置
+            绘制了表格线的图片, 结构信息
         """
+        img_with_lines = image.copy()
+        draw = ImageDraw.Draw(img_with_lines)
+        
+        ocr_data = ocr_data_dict.get('text_boxes', [])
+        
         # 🔍 自动检测锚点
         if anchor_x is None or anchor_y is None:
             detected_x, detected_y = self.detect_table_anchor(ocr_data)
             anchor_x = anchor_x or detected_x
             anchor_y = anchor_y or detected_y
         
-        # 🔍 自动检测行数
-        if num_rows is None:
-            num_rows = self.detect_table_rows(ocr_data, anchor_y)
+        print(f"\n📍 表格锚点: ({anchor_x}, {anchor_y})")
         
-        # 🎨 生成横线坐标
-        horizontal_lines = []
-        horizontal_lines.append(anchor_y)
+        # ✅ 竖线:使用模板的列宽(固定)
+        vertical_lines = [anchor_x + offset for offset in self.col_offsets]
+        print(f"📏 竖线坐标: {vertical_lines} (使用模板,共{len(vertical_lines)}条)")
+        
+        # ✅ 横线:根据模式选择
+        if use_ocr_rows and ocr_data:
+            horizontal_lines = self._detect_rows_from_ocr(
+                ocr_data, anchor_y, y_tolerance
+            )
+            print(f"📏 横线坐标: 使用OCR检测 (共{len(horizontal_lines)}条,自适应行高)")
+        else:
+            num_rows = self.detect_table_rows(ocr_data, anchor_y) if ocr_data else 10
+            horizontal_lines = self._generate_fixed_rows(anchor_y, num_rows)
+            print(f"📏 横线坐标: 使用固定行高 (共{len(horizontal_lines)}条)")
+        
+        # 🖊️ 绘制横线
+        x_start = vertical_lines[0]
+        x_end = vertical_lines[-1]
+        for y in horizontal_lines:
+            draw.line([(x_start, y), (x_end, y)], fill=line_color, width=line_width)
+        
+        # 🖊️ 绘制竖线
+        y_start = horizontal_lines[0]
+        y_end = horizontal_lines[-1]
+        for x in vertical_lines:
+            draw.line([(x, y_start), (x, y_end)], fill=line_color, width=line_width)
+        
+        print(f"✅ 表格绘制完成: {len(horizontal_lines)}行 × {len(vertical_lines)-1}列")
+        
+        # 🔑 生成结构信息
+        structure = self._build_structure(
+            horizontal_lines, 
+            vertical_lines, 
+            anchor_x, 
+            anchor_y,
+            mode='hybrid'
+        )
+        
+        return img_with_lines, structure
+
+    def _detect_rows_from_ocr(self, 
+                              ocr_data: List[Dict], 
+                              anchor_y: int,
+                              y_tolerance: int = 5) -> List[int]:
+        """
+        从OCR结果中检测行(自适应行高)
+        复用 get_structure_from_ocr 统一接口
+        
+        Args:
+            ocr_data: OCR识别结果(MinerU 格式的 text_boxes)
+            anchor_y: 表格起始Y坐标
+            y_tolerance: Y轴聚类容差(未使用,保留参数兼容性)
+        
+        Returns:
+            横线 y 坐标列表
+        """
+        if not ocr_data:
+            return [anchor_y, anchor_y + self.header_height]
+        
+        print(f"\n🔍 OCR行检测 (使用 MinerU 算法):")
+        print(f"   有效文本框数: {len(ocr_data)}")
+        
+        # 🔑 验证是否为 MinerU 格式
+        has_cell_index = any('row' in item and 'col' in item for item in ocr_data)
+        
+        if not has_cell_index:
+            print("   ⚠️ 警告: OCR数据不包含 row/col 索引,可能不是 MinerU 格式")
+            print("   ⚠️ 混合模式需要 MinerU 格式的 JSON 文件")
+            return [anchor_y, anchor_y + self.header_height]
+        
+        # 🔑 重构原始数据格式(MinerU 需要完整的 table 结构)
+        raw_data = {
+            'type': 'table',
+            'table_cells': ocr_data
+        }
+        
+        try:
+            # ✅ 使用统一接口解析和分析(无需 dummy_image)
+            table_bbox, structure = get_structure_from_ocr(
+                raw_data, 
+                tool="mineru"
+            )
+            
+            if not structure or 'horizontal_lines' not in structure:
+                print("   ⚠️ MinerU 分析失败,使用兜底方案")
+                return [anchor_y, anchor_y + self.header_height]
+            
+            # 🔑 获取横线坐标
+            horizontal_lines = structure['horizontal_lines']
+            
+            # 🔑 调整第一条线到 anchor_y(表头顶部)
+            if horizontal_lines:
+                offset = anchor_y - horizontal_lines[0]
+                horizontal_lines = [y + offset for y in horizontal_lines]
+            
+            print(f"   检测到行数: {len(horizontal_lines) - 1}")
+            
+            # 🔑 分析行高分布
+            if len(horizontal_lines) > 1:
+                row_heights = []
+                for i in range(len(horizontal_lines) - 1):
+                    h = horizontal_lines[i+1] - horizontal_lines[i]
+                    row_heights.append(h)
+                
+                if len(row_heights) > 1:
+                    import numpy as np
+                    print(f"   行高分布: min={min(row_heights)}, "
+                          f"median={int(np.median(row_heights))}, "
+                          f"max={max(row_heights)}")
+            
+            return horizontal_lines
+            
+        except Exception as e:
+            print(f"   ⚠️ 解析失败: {e}")
+            import traceback
+            traceback.print_exc()
+            return [anchor_y, anchor_y + self.header_height]
+    
+    def _generate_fixed_rows(self, anchor_y: int, num_rows: int) -> List[int]:
+        """生成固定行高的横线(兜底方案)"""
+        horizontal_lines = [anchor_y]
+        
+        # 表头
         horizontal_lines.append(anchor_y + self.header_height)
         
+        # 数据行
         current_y = anchor_y + self.header_height
         for i in range(num_rows - 1):
-            current_y += self.row_height
+            current_y += self.fallback_row_height
             horizontal_lines.append(current_y)
         
-        # 🎨 生成竖线坐标
-        vertical_lines = []
-        for offset in self.col_offsets:
-            x = anchor_x + offset
-            vertical_lines.append(x)
-        
-        # 🎨 生成行区间
+        return horizontal_lines
+    
+    def _build_structure(self,
+                        horizontal_lines: List[int],
+                        vertical_lines: List[int],
+                        anchor_x: int,
+                        anchor_y: int,
+                        mode: str = 'fixed') -> Dict:
+        """构建表格结构信息(统一)"""
+        # 生成行区间
         rows = []
-        for i in range(num_rows):
+        for i in range(len(horizontal_lines) - 1):
             rows.append({
                 'y_start': horizontal_lines[i],
                 'y_end': horizontal_lines[i + 1],
                 'bboxes': []
             })
         
-        # 🎨 生成列区间
+        # 生成列区间
         columns = []
         for i in range(len(vertical_lines) - 1):
             columns.append({
@@ -265,30 +408,40 @@ class TableTemplateApplier:
                 'x_end': vertical_lines[i + 1]
             })
         
+        # ✅ 根据模式设置正确的 mode 值
+        if mode == 'hybrid':
+            mode_value = 'hybrid'
+        elif mode == 'fixed':
+            mode_value = 'fixed'
+        else:
+            mode_value = mode  # 保留原始值
+        
         return {
             'rows': rows,
             'columns': columns,
             'horizontal_lines': horizontal_lines,
             'vertical_lines': vertical_lines,
-            'header_height': self.header_height,
-            'row_height': self.row_height,
             'col_widths': self.col_widths,
+            'row_height': self.row_height if mode == 'fixed' else None,
             'table_bbox': [
                 vertical_lines[0],
                 horizontal_lines[0],
                 vertical_lines[-1],
                 horizontal_lines[-1]
             ],
+            'mode': mode_value,  # ✅ 确保有 mode 字段
             'anchor': {'x': anchor_x, 'y': anchor_y},
-            'num_rows': num_rows
+            'modified_h_lines': [],  # ✅ 添加修改记录字段
+            'modified_v_lines': []   # ✅ 添加修改记录字段
         }
 
-
 def apply_template_to_single_file(
     applier: TableTemplateApplier,
     image_file: Path,
     json_file: Path,
     output_dir: Path,
+    structure_suffix: str = "_structure.json",
+    use_hybrid_mode: bool = True,
     line_width: int = 2,
     line_color: Tuple[int, int, int] = (0, 0, 0)
 ) -> bool:
@@ -300,6 +453,7 @@ def apply_template_to_single_file(
         image_file: 图片文件路径
         json_file: OCR JSON文件路径
         output_dir: 输出目录
+        use_hybrid_mode: 是否使用混合模式(需要 MinerU 格式)
         line_width: 线条宽度
         line_color: 线条颜色
     
@@ -313,39 +467,79 @@ def apply_template_to_single_file(
         with open(json_file, 'r', encoding='utf-8') as f:
             raw_data = json.load(f)
         
-        # 🔧 解析OCR数据(支持PPStructure格式)
+        # 🔑 自动检测 OCR 格式
+        ocr_format = None
+        
         if 'parsing_res_list' in raw_data and 'overall_ocr_res' in raw_data:
-            table_bbox, ocr_data = TableLineGenerator.parse_ppstructure_result(raw_data)
+            # PPStructure 格式
+            ocr_format = 'ppstructure'
+        elif isinstance(raw_data, (list, dict)):
+            # 尝试提取 MinerU 格式
+            table_data = None
+            if isinstance(raw_data, list):
+                for item in raw_data:
+                    if isinstance(item, dict) and item.get('type') == 'table':
+                        table_data = item
+                        break
+            elif isinstance(raw_data, dict) and raw_data.get('type') == 'table':
+                table_data = raw_data
+            if table_data and 'table_cells' in table_data:
+                ocr_format = 'mineru'
+            else:
+                raise ValueError("未识别的 OCR 格式")
         else:
-            raise ValueError("不是PPStructure格式的OCR结果")
+            raise ValueError("未识别的 OCR 格式(仅支持 PPStructure 或 MinerU)")
+
+        table_bbox, ocr_data = TableLineGenerator.parse_ocr_data(
+            raw_data, 
+            tool=ocr_format
+        )
         
-        print(f"  ✅ 加载OCR数据: {len(ocr_data)} 个文本框")
+        text_boxes = ocr_data.get('text_boxes', [])
+        print(f"  ✅ 加载OCR数据: {len(text_boxes)} 个文本框")
+        print(f"  📋 OCR格式: {ocr_format}")
         
         # 加载图片
         image = Image.open(image_file)
         print(f"  ✅ 加载图片: {image.size}")
         
-        # 🎯 应用模板
-        img_with_lines = applier.apply_to_image(
-            image,
-            ocr_data,
-            line_width=line_width,
-            line_color=line_color
-        )
+        # 🔑 验证混合模式的格式要求
+        if use_hybrid_mode and ocr_format != 'mineru':
+            print(f"  ⚠️ 警告: 混合模式需要 MinerU 格式,当前格式为 {ocr_format}")
+            print(f"  ℹ️  自动切换到完全模板模式")
+            use_hybrid_mode = False
+        
+        # 🆕 根据模式选择处理方式
+        if use_hybrid_mode:
+            print(f"  🔧 使用混合模式 (模板列 + MinerU 行)")
+            img_with_lines, structure  = applier.apply_template_hybrid(
+                image,
+                ocr_data,
+                use_ocr_rows=True,
+                line_width=line_width,
+                line_color=line_color
+            )
+        else:
+            print(f"  🔧 使用完全模板模式 (固定行高)")
+            img_with_lines, structure = applier.apply_template_fixed(
+                image,
+                text_boxes,
+                line_width=line_width,
+                line_color=line_color
+            )
         
         # 保存图片
-        output_file = output_dir / f"{image_file.stem}_with_lines.png"
+        output_file = output_dir / f"{image_file.stem}.png"
         img_with_lines.save(output_file)
         
-        # 🆕 生成并保存结构配置
-        structure = applier.generate_structure_for_image(ocr_data)
-        structure_file = output_dir / f"{image_file.stem}_structure.json"
+        # 保存结构配置
+        structure_file = output_dir / f"{image_file.stem}{structure_suffix}"
         with open(structure_file, 'w', encoding='utf-8') as f:
             json.dump(structure, f, indent=2, ensure_ascii=False)
         
         print(f"  ✅ 保存图片: {output_file.name}")
         print(f"  ✅ 保存配置: {structure_file.name}")
-        print(f"  📊 表格: {structure['num_rows']}行 x {len(structure['columns'])}列")
+        print(f"  📊 表格: {len(structure['rows'])}行 x {len(structure['columns'])}列")
         
         return True
         
@@ -361,6 +555,8 @@ def apply_template_batch(
     image_dir: str,
     json_dir: str,
     output_dir: str,
+    structure_suffix: str = "_structure.json",
+    use_hybrid_mode: bool = False,
     line_width: int = 2,
     line_color: Tuple[int, int, int] = (0, 0, 0)
 ):
@@ -414,7 +610,7 @@ def apply_template_batch(
             continue
         
         if apply_template_to_single_file(
-            applier, image_file, json_file, output_path, 
+            applier, image_file, json_file, output_path, structure_suffix, use_hybrid_mode,
             line_width, line_color
         ):
             results.append({
@@ -454,29 +650,31 @@ def apply_template_batch(
 def main():
     """主函数"""
     parser = argparse.ArgumentParser(
-        description='应用表格模板到其他页面',
+        description='应用表格模板到其他页面(支持混合模式)',
         formatter_class=argparse.RawDescriptionHelpFormatter,
         epilog="""
 示例用法:
 
-  1. 批量处理整个目录:
+  1. 混合模式(推荐,自适应行高):
      python table_template_applier.py \\
-         --template output/康强_北京农村商业银行_page_001_structure.json \\
+         --template template.json \\
          --image-dir /path/to/images \\
          --json-dir /path/to/jsons \\
-         --output-dir /path/to/output
+         --output-dir /path/to/output \\
+         --structure-suffix _structure.json \\
+         --hybrid
 
-  2. 处理单个文件:
+  2. 完全模板模式(固定行高):
      python table_template_applier.py \\
-         --template output/康强_北京农村商业银行_page_001_structure.json \\
-         --image-file /path/to/page_002.png \\
-         --json-file /path/to/page_002.json \\
-         --output-dir /path/to/output
+         --template template.json \\
+         --image-file page.png \\
+         --json-file page.json \\
+         --output-dir /path/to/output \\
+         --structure-suffix _structure.json \\
 
-输出内容:
-  - {name}_with_lines.png: 带表格线的图片
-  - {name}_structure.json: 表格结构配置
-  - batch_results.json: 批处理统计结果
+模式说明:
+  - 混合模式(--hybrid): 列宽使用模板,行高根据OCR自适应
+  - 完全模板模式: 列宽和行高都使用模板(适合固定格式表格)
         """
     )
     
@@ -522,6 +720,12 @@ def main():
         required=True,
         help='输出目录(必需)'
     )
+    output_group.add_argument(
+        '--structure-suffix',
+        type=str,
+        default='_structure.json',
+        help='输出结构配置文件后缀(默认: _structure.json)'
+    )
     
     # 绘图参数组
     draw_group = parser.add_argument_group('绘图参数')
@@ -538,6 +742,14 @@ def main():
         help='线条颜色(默认: black)'
     )
     
+    # 🆕 新增模式参数
+    mode_group = parser.add_argument_group('模式参数')
+    mode_group.add_argument(
+        '--hybrid',
+        action='store_true',
+        help='使用混合模式(模板列 + OCR行,自适应行高,推荐)'
+    )
+    
     args = parser.parse_args()
     
     # 颜色映射
@@ -581,7 +793,9 @@ def main():
         
         success = apply_template_to_single_file(
             applier, image_file, json_file, output_path,
-            args.width, line_color
+            use_hybrid_mode=args.hybrid,  # 🆕 传递混合模式参数
+            line_width=args.width, 
+            line_color=line_color
         )
         
         if success:
@@ -610,8 +824,10 @@ def main():
             str(image_dir),
             str(json_dir),
             str(output_path),
-            args.width,
-            line_color
+            structure_suffix=args.structure_suffix,
+            use_hybrid_mode=args.hybrid,  # 🆕 传递混合模式参数
+            line_width=args.width,
+            line_color=line_color,
         )
     
     else:
@@ -633,14 +849,21 @@ if __name__ == "__main__":
         
         # 默认配置
         default_config = {
-            "template": "output/table_structures/康强_北京农村商业银行_page_001_structure.json",
+            "template": "/Users/zhch158/workspace/data/流水分析/康强_北京农村商业银行.wiredtable/康强_北京农村商业银行_page_001_structure.json",
             "image-file": "/Users/zhch158/workspace/data/流水分析/康强_北京农村商业银行/ppstructurev3_client_results/康强_北京农村商业银行/康强_北京农村商业银行_page_002.png",
             "json-file": "/Users/zhch158/workspace/data/流水分析/康强_北京农村商业银行/ppstructurev3_client_results/康强_北京农村商业银行_page_002.json",
             "output-dir": "output/batch_results",
             "width": "2",
             "color": "black"
         }
-        
+        # default_config = {
+        #     "template": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水.wiredtable/B用户_扫描流水_page_001_structure.json",
+        #     "image-file": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/mineru_vllm_results/B用户_扫描流水/B用户_扫描流水_page_002.png",
+        #     "json-file": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/mineru_vllm_results_cell_bbox/B用户_扫描流水_page_002.json",
+        #     "output-dir": "output/batch_results",
+        #     "width": "2",
+        #     "color": "black"
+        # }        
         print("⚙️  默认参数:")
         for key, value in default_config.items():
             print(f"  --{key}: {value}")
@@ -649,5 +872,7 @@ if __name__ == "__main__":
         sys.argv = [sys.argv[0]]
         for key, value in default_config.items():
             sys.argv.extend([f"--{key}", str(value)])
+        
+        sys.argv.append("--hybrid")  # 使用混合模式
     
     sys.exit(main())