# zhch/table_mode_selector.py
"""Heuristic selection of a table-recognition mode for PaddleX pipelines.

The module crops a table out of a page image, measures a handful of simple
image features on it, scores a set of hand-written rules against those
features and finally runs the recognition mode whose rule scored highest.
"""
import operator
import traceback

import cv2
import numpy as np
from paddlex import create_model, create_pipeline

# Sentinel distinguishing "key absent" from "key present with value None".
_MISSING = object()

# Comparison operators understood by the decision rules.
_COMPARATORS = {
    '>': operator.gt,
    '<': operator.lt,
    '==': operator.eq,
    '>=': operator.ge,
    '<=': operator.le,
    'in': lambda value, options: value in options,
}

# Placeholder bounding box used whenever no real box can be derived.
_DEFAULT_BBOX = (0, 0, 100, 100)


def _to_gray(image):
    """Return a single-channel version of *image*.

    Three-channel input is converted with ``COLOR_BGR2GRAY``. Images that
    are already 2-D are returned unchanged and four-channel images are
    converted with ``COLOR_BGRA2GRAY``, so callers do not fail on inputs
    that are not plain BGR.
    """
    if image.ndim == 2:
        return image
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)
    return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)


def _has_field(obj, name):
    """Return True when *obj* exposes *name* as an attribute or as a key."""
    if hasattr(obj, name):
        return True
    try:
        return name in obj
    except TypeError:
        # obj is not a container at all.
        return False


def _get_field(obj, name, default=None):
    """Read *name* from *obj*, trying key access first, then attributes."""
    getter = getattr(obj, 'get', None)
    if callable(getter):
        value = getter(name, _MISSING)
        if value is not _MISSING:
            return value
    return getattr(obj, name, default)


class TableModeSelector:
    """Computes the image features that drive the mode decision."""

    def __init__(self):
        # Layout model, used by IntelligentTableProcessor to locate tables.
        self.layout_model = create_model(model_name="PP-DocLayout_plus-L")
        # Table classification model used for the pre-analysis.
        self.table_cls_model = create_model(model_name="PP-LCNet_x1_0_table_cls")

    def analyze_table_features(self, table_image):
        """Analyse a table image and return its features as a dict.

        Args:
            table_image: Image array of the table region.

        Returns:
            Dict with ``table_type``, the keys produced by
            ``analyze_complexity`` and ``analyze_regularity``, and
            ``border_clarity``.
        """
        features = {'table_type': self.get_table_type(table_image)}
        features.update(self.analyze_complexity(table_image))
        features.update(self.analyze_regularity(table_image))
        features['border_clarity'] = self.analyze_border_clarity(table_image)
        return features

    def get_table_type(self, image):
        """Return the table type predicted for *image*.

        The classification model is tried first. When its output has an
        unexpected shape, or when it raises, the line-detection heuristic in
        ``detect_table_type_by_lines`` is used instead.
        """
        try:
            result = next(self.table_cls_model.predict(image))
            # Debug output showing the concrete result type.
            print(f"表格分类模型输出类型: {type(result).__name__}")

            if hasattr(result, 'keys') or isinstance(result, dict):
                # Result carrying ids, scores and label names together.
                if all(key in result for key in ('class_ids', 'scores', 'label_names')):
                    scores = result['scores']
                    label_names = result['label_names']
                    best_idx = np.argmax(scores)
                    best_label = label_names[best_idx]
                    best_score = scores[best_idx]
                    print(f"分类结果: {best_label} (置信度: {best_score:.4f})")
                    return best_label

                # Other result shapes that may be encountered.
                if 'class_ids' in result:
                    class_ids = result['class_ids']
                    if hasattr(class_ids, '__len__') and len(class_ids) > 0:
                        class_id = int(class_ids[0])
                    else:
                        class_id = int(class_ids)
                    return 'wired_table' if class_id == 0 else 'wireless_table'
                if 'label_names' in result:
                    label_names = result['label_names']
                    return label_names[0] if label_names else 'wired_table'

                # Legacy field names.
                for legacy_key in ('label', 'class_name', 'prediction'):
                    if legacy_key in result:
                        return result[legacy_key]

                # Last resort: stringify the first available value.
                first_key = list(result.keys())[0]
                return str(result[first_key])

            # The result is not mapping-like: use the heuristic.
            print("使用备用的线条检测方法判断表格类型")
            return self.detect_table_type_by_lines(image)
        except Exception as e:
            print(f"表格分类出错: {e},使用备用方法")
            return self.detect_table_type_by_lines(image)

    def detect_table_type_by_lines(self, image):
        """Guess the table type from the number of Hough lines (fallback).

        Returns ``'wired_table'`` when more than 10 lines are detected,
        ``'wireless_table'`` otherwise.
        """
        edges = cv2.Canny(_to_gray(image), 50, 150)
        lines = cv2.HoughLines(edges, 1, np.pi / 180, threshold=100)

        if lines is not None and len(lines) > 10:
            print("检测到较多直线,判断为有线表格")
            return 'wired_table'
        print("检测到较少直线,判断为无线表格")
        return 'wireless_table'

    def analyze_complexity(self, image):
        """Return the complexity features of *image*.

        Returns:
            Dict with ``line_density`` (fraction of Canny edge pixels),
            ``merged_cells_ratio``, ``text_density`` and ``size_complexity``
            (pixel count divided by one million).
        """
        h, w = image.shape[:2]
        edges = cv2.Canny(_to_gray(image), 50, 150)

        return {
            'line_density': np.sum(edges > 0) / (h * w),
            'merged_cells_ratio': self.detect_merged_cells(image),
            'text_density': self.analyze_text_density(image),
            'size_complexity': (h * w) / (1000 * 1000),
        }

    def _open_lines(self, gray):
        """Apply a 40x1 and a 1x40 morphological opening to *gray*.

        Returns:
            Tuple ``(horizontal, vertical)`` of the two opened images.
        """
        # NOTE(review): the opening runs on the raw grayscale image, so on a
        # light page with dark rules the rules are not the foreground.
        # Binarising with inversion first may be what was intended; confirm
        # before changing, since every rule threshold depends on these values.
        horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40, 1))
        vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 40))
        horizontal = cv2.morphologyEx(gray, cv2.MORPH_OPEN, horizontal_kernel)
        vertical = cv2.morphologyEx(gray, cv2.MORPH_OPEN, vertical_kernel)
        return horizontal, vertical

    def detect_merged_cells(self, image):
        """Estimate the merged-cell ratio (simplified heuristic).

        Returns:
            ``1 - 2 * min(horizontal coverage, vertical coverage)`` clamped
            to ``[0, 1]``, where coverage is the fraction of non-zero pixels
            in the opened image.
        """
        horizontal, vertical = self._open_lines(_to_gray(image))
        h_coverage = np.sum(horizontal > 0) / horizontal.size
        v_coverage = np.sum(vertical > 0) / vertical.size

        merged_ratio = 1.0 - min(h_coverage, v_coverage) * 2
        return max(0.0, min(1.0, merged_ratio))

    def analyze_text_density(self, image):
        """Return the fraction of pixels that are black after Otsu thresholding."""
        _, binary = cv2.threshold(
            _to_gray(image), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )
        # Black pixels are taken as text.
        return np.sum(binary == 0) / binary.size

    def analyze_regularity(self, image):
        """Return how evenly spaced the horizontal and vertical structures are.

        Returns:
            Dict with ``horizontal_regularity``, ``vertical_regularity`` and
            their mean as ``overall_regularity``.
        """
        horizontal, vertical = self._open_lines(_to_gray(image))
        h_regularity = self.calculate_line_regularity(horizontal, axis='horizontal')
        v_regularity = self.calculate_line_regularity(vertical, axis='vertical')

        return {
            'horizontal_regularity': h_regularity,
            'vertical_regularity': v_regularity,
            'overall_regularity': (h_regularity + v_regularity) / 2,
        }

    def calculate_line_regularity(self, lines_image, axis='horizontal'):
        """Score the spacing regularity of the peaks of a projection profile.

        Args:
            lines_image: 2-D image to project.
            axis: ``'horizontal'`` sums each row, anything else sums each
                column.

        Returns:
            ``1 - std/mean`` of the distances between consecutive peaks,
            clamped to ``[0, 1]``; 0.5 when fewer than two peaks are found.
        """
        projection = np.sum(lines_image, axis=1 if axis == 'horizontal' else 0)
        if projection.size == 0:
            # Nothing to project: keep the neutral default instead of
            # letting np.max fail on an empty array.
            return 0.5

        # A peak is an interior sample above 30% of the maximum that is
        # strictly greater than both of its neighbours.
        threshold = np.max(projection) * 0.3
        inner = projection[1:-1]
        is_peak = (inner > threshold) & (inner > projection[:-2]) & (inner > projection[2:])
        peaks = np.nonzero(is_peak)[0] + 1

        if len(peaks) < 2:
            return 0.5  # default, medium regularity

        intervals = np.diff(peaks)
        mean_interval = np.mean(intervals)
        std_interval = np.std(intervals)
        if mean_interval == 0:
            return 0.5

        # The larger the value, the more regular the spacing.
        regularity = 1.0 - min(1.0, std_interval / mean_interval)
        return max(0.0, regularity)

    def analyze_border_clarity(self, image):
        """Return the mean Sobel gradient magnitude divided by 255."""
        gray = _to_gray(image)
        sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        edge_magnitude = np.sqrt(sobelx**2 + sobely**2)
        return np.mean(edge_magnitude) / 255.0


class TableModeDecisionEngine:
    """Scores the hand-written mode rules against a feature dict."""

    def __init__(self):
        self.rules = self.load_decision_rules()

    def load_decision_rules(self):
        """Return the rule table.

        Each rule has a list of ``conditions``, a ``weight`` and a
        ``description``. A condition is either
        ``(feature_name, operator, threshold)`` or ``('OR', [conditions])``.
        """
        return {
            'wired_html_mode': {
                'conditions': [
                    # Several spellings of the table type are accepted.
                    ('table_type', 'in', ['wired_table', 'wired', '0']),
                    ('border_clarity', '>', 0.6),
                    ('merged_cells_ratio', '>', 0.3),
                    ('overall_regularity', '<', 0.7),
                    ('size_complexity', '>', 0.5),
                ],
                'weight': 0.9,
                'description': '复杂有线表格,几何匹配更准确',
            },
            'wired_e2e_mode': {
                'conditions': [
                    ('table_type', 'in', ['wired_table', 'wired', '0']),
                    ('overall_regularity', '>', 0.8),
                    ('merged_cells_ratio', '<', 0.2),
                    ('text_density', '>', 0.3),
                ],
                'weight': 0.8,
                'description': '规整有线表格,端到端效果好',
            },
            'wireless_e2e_mode': {
                'conditions': [
                    ('table_type', 'in', ['wireless_table', 'wireless', '1']),
                    ('line_density', '<', 0.1),
                    ('text_density', '>', 0.2),
                ],
                'weight': 0.85,
                'description': '无线表格,端到端预测最适合',
            },
            'regular_mode': {
                'conditions': [
                    ('size_complexity', '>', 1.0),
                    ('OR', [
                        ('border_clarity', '<', 0.4),
                        ('overall_regularity', '<', 0.5),
                    ]),
                ],
                'weight': 0.7,
                'description': '复杂场景,需要多模型协同',
            },
        }

    def check_single_condition(self, features, condition):
        """Return whether one ``(feature_name, operator, threshold)`` holds.

        A missing feature or an unknown operator evaluates to False.
        """
        feature_name, op, threshold = condition
        if feature_name not in features:
            return False
        compare = _COMPARATORS.get(op)
        if compare is None:
            return False
        return compare(features[feature_name], threshold)

    def evaluate_conditions(self, features, conditions):
        """Return the fraction of *conditions* satisfied by *features*.

        An ``('OR', [...])`` entry counts as one condition and is satisfied
        when any of its sub-conditions holds. Returns 0 for an empty list.
        """
        if not conditions:
            return 0

        satisfied = 0
        for condition in conditions:
            if condition[0] == 'OR':
                hit = any(
                    self.check_single_condition(features, sub_cond)
                    for sub_cond in condition[1]
                )
            else:
                hit = self.check_single_condition(features, condition)
            if hit:
                satisfied += 1
        return satisfied / len(conditions)

    def select_best_mode(self, features):
        """Return ``(mode_name, info)`` for the highest-scoring rule.

        ``info`` holds ``score`` (condition fraction times rule weight) and
        ``description``. Ties go to the rule defined first.
        """
        mode_scores = {
            mode_name: {
                'score': self.evaluate_conditions(features, rule['conditions']) * rule['weight'],
                'description': rule['description'],
            }
            for mode_name, rule in self.rules.items()
        }
        best_name, best_info = max(mode_scores.items(), key=lambda kv: kv[1]['score'])
        return best_name, best_info


class IntelligentTableProcessor:
    """Ties table extraction, feature analysis, decision and execution together."""

    def __init__(self, config_path="./PP-StructureV3-zhch.yaml"):
        self.selector = TableModeSelector()
        self.decision_engine = TableModeDecisionEngine()
        # Pipelines are created lazily to avoid configuration problems at
        # construction time.
        self.config_path = config_path
        self.pp_structure = None
        self._table_pipeline = None

    def execute_with_mode(self, image_path, mode, optimized_config=None):
        """Run table recognition in *mode*, falling back to plain OCR on error.

        Args:
            image_path: Input handed to the pipeline's ``predict``;
                ``process_table_intelligently`` passes the cropped image
                array here.
            mode: Name of the mode to run.
            optimized_config: Only printed; not otherwise used.
        """
        try:
            print(f"正在使用 {mode} 模式处理表格...")
            print(f"优化配置: {optimized_config}")
            return self.create_and_run_pipeline(image_path, mode, optimized_config)
        except Exception as e:
            print(f"执行 {mode} 模式时出错: {e}")
            print("回退到基础处理模式")
            return self.fallback_processing(image_path)

    def create_and_run_pipeline(self, image_path, mode, optimized_config):
        """Dispatch to the runner for *mode*; unknown modes use the fallback."""
        runners = {
            'wired_html_mode': self.run_wired_html_mode,
            'wired_e2e_mode': self.run_wired_e2e_mode,
            'wireless_e2e_mode': self.run_wireless_e2e_mode,
            'regular_mode': self.run_regular_mode,
        }
        runner = runners.get(mode)
        if runner is None:
            print(f"未知模式: {mode},使用默认处理")
            return self.fallback_processing(image_path)
        return runner(image_path, optimized_config)

    def _run_table_pipeline(self, image_path, mode, failure_prefix, **predict_kwargs):
        """Run the table pipeline with *predict_kwargs* and format its output.

        The pipeline is built on first use and then reused, the same way
        ``run_regular_mode`` reuses ``self.pp_structure``. Any failure is
        printed with *failure_prefix* and replaced by a mock result.
        """
        try:
            pipeline = getattr(self, '_table_pipeline', None)
            if pipeline is None:
                pipeline = create_pipeline(pipeline=self.config_path, model_dir=None)
                self._table_pipeline = pipeline
            # NOTE: the keyword names passed to predict() must match the
            # actual PaddleX API of the installed version.
            result = list(pipeline.predict(image_path, **predict_kwargs))
            return self.format_result(result, mode=mode)
        except Exception as e:
            print(f"{failure_prefix}执行失败: {e}")
            return self.create_mock_result(mode=mode)

    def run_wired_html_mode(self, image_path, config):
        """Run the wired-table-to-HTML mode. *config* is unused."""
        print("执行有线表格转HTML模式...")
        return self._run_table_pipeline(
            image_path,
            'wired_html_mode',
            "有线表格HTML模式",
            use_wired_table_html_mode=True,
            use_wired_table_e2e_mode=False,
        )

    def run_wired_e2e_mode(self, image_path, config):
        """Run the wired-table end-to-end mode. *config* is unused."""
        print("执行有线表格端到端模式...")
        return self._run_table_pipeline(
            image_path,
            'wired_e2e_mode',
            "有线表格端到端模式",
            use_wired_table_html_mode=False,
            use_wired_table_e2e_mode=True,
        )

    def run_wireless_e2e_mode(self, image_path, config):
        """Run the wireless-table end-to-end mode. *config* is unused."""
        print("执行无线表格端到端模式...")
        return self._run_table_pipeline(
            image_path,
            'wireless_e2e_mode',
            "无线表格端到端模式",
            use_wireless_table_e2e_mode=True,
        )

    def run_regular_mode(self, image_path, config):
        """Run the full pipeline from ``config_path``. *config* is unused."""
        print("执行常规模式...")
        try:
            if self.pp_structure is None:
                self.pp_structure = create_pipeline(pipeline=self.config_path)
            result = list(self.pp_structure.predict(image_path))
            return self.format_result(result, mode='regular_mode')
        except Exception as e:
            print(f"常规模式执行失败: {e}")
            return self.create_mock_result(mode='regular_mode')

    def format_result(self, raw_result, mode):
        """Convert raw pipeline output into the module's result dict.

        For every item the first field that is present among
        ``table_res_list``, ``parsing_res_list`` and
        ``table_recognition_res`` is used. Items may expose those fields as
        keys or as attributes.

        Returns:
            Dict with ``mode``, ``status``, ``raw_output``, ``table_count``
            and ``tables``. ``table_count`` always equals ``len(tables)``
            and every ``table_id`` is unique across all items. An empty
            *raw_result* or any error yields ``create_mock_result(mode)``.
        """
        try:
            if not raw_result:
                return self.create_mock_result(mode)

            tables = []
            formatted_result = {
                'mode': mode,
                'status': 'success',
                'raw_output': raw_result,
                'table_count': 0,
                'tables': tables,
            }

            for item in raw_result:
                print(f"处理结果项: {type(item)}")
                if _has_field(item, 'table_res_list'):
                    self._collect_table_res(_get_field(item, 'table_res_list', []), tables)
                elif _has_field(item, 'parsing_res_list'):
                    self._collect_parsing_res(_get_field(item, 'parsing_res_list', []), tables)
                elif _has_field(item, 'table_recognition_res'):
                    # Older result layout.
                    self._collect_legacy_res(_get_field(item, 'table_recognition_res', None), tables)

            # Derived from what was collected so the count stays right when
            # the pipeline returns more than one item.
            formatted_result['table_count'] = len(tables)
            return formatted_result
        except Exception as e:
            print(f"结果格式化失败: {e}")
            traceback.print_exc()
            return self.create_mock_result(mode)

    def _collect_table_res(self, table_res_list, tables):
        """Append one entry to *tables* per element of ``table_res_list``."""
        if not table_res_list:
            return
        for i, table_item in enumerate(table_res_list):
            table_region_id = _get_field(table_item, 'table_region_id', i)
            cells = _get_field(table_item, 'cell_box_list', None)
            cell_count = len(cells) if cells is not None else 0

            # The overall box is derived from the cell boxes when available.
            bbox = list(_DEFAULT_BBOX)
            if cell_count > 0:
                bbox = self.calculate_table_bbox_from_cells(cells)

            tables.append({
                'table_id': len(tables),
                'table_region_id': table_region_id,
                'html': _get_field(table_item, 'pred_html', 'HTML不可用'),
                'bbox': bbox,
                'cell_count': cell_count,
                'neighbor_texts': _get_field(table_item, 'neighbor_texts', ''),
            })
            print(f"提取表格 {i}: region_id={table_region_id}, cells={cell_count}")

    def _collect_parsing_res(self, parsing_res_list, tables):
        """Append the ``parsing_res_list`` elements labelled ``'table'``."""
        for parsing_item in parsing_res_list or []:
            if _get_field(parsing_item, 'label') != 'table':
                continue
            tables.append({
                'table_id': len(tables),
                'html': _get_field(parsing_item, 'html', 'HTML不可用'),
                'bbox': _get_field(parsing_item, 'bbox', list(_DEFAULT_BBOX)),
                'source': 'parsing_res',
            })

    def _collect_legacy_res(self, table_res, tables):
        """Append one entry per element of ``table_recognition_res``."""
        if not table_res:
            return
        for table in table_res:
            tables.append({
                'table_id': len(tables),
                'html': _get_field(table, 'html', 'HTML不可用'),
                'bbox': _get_field(table, 'bbox', list(_DEFAULT_BBOX)),
            })

    def calculate_table_bbox_from_cells(self, cell_box_list):
        """Return the box ``[x1, y1, x2, y2]`` enclosing all cell boxes.

        A cell is a list, tuple or NumPy array whose first four values are
        read as ``x1, y1, x2, y2``, or an object with such a ``bbox``
        attribute. Cells in any other form are skipped.

        Returns:
            The enclosing box with integer coordinates, or
            ``[0, 0, 100, 100]`` when there is no usable cell or an error
            occurs.
        """
        try:
            if cell_box_list is None or len(cell_box_list) == 0:
                return list(_DEFAULT_BBOX)

            xs, ys = [], []
            for cell in cell_box_list:
                if isinstance(cell, (list, tuple, np.ndarray)) and len(cell) >= 4:
                    coords = cell
                elif hasattr(cell, 'bbox') and len(cell.bbox) >= 4:
                    coords = cell.bbox
                else:
                    continue
                x1, y1, x2, y2 = coords[:4]
                xs.extend((x1, x2))
                ys.extend((y1, y2))

            if not xs:
                return list(_DEFAULT_BBOX)
            return [int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))]
        except Exception as e:
            print(f"计算表格边界框失败: {e}")
            return list(_DEFAULT_BBOX)

    def create_mock_result(self, mode):
        """Return a placeholder result (used for tests and error fallback)."""
        return {
            'mode': mode,
            'status': 'mock',
            'message': f'{mode} 模式执行完成(模拟结果)',
            'table_count': 1,
            'tables': [{
                'table_id': 0,
                # NOTE(review): this literal was split across physical lines
                # in the source, which is a syntax error; its markup was
                # presumably lost. It is rebuilt here as a one-cell table -
                # confirm the intended markup.
                'html': f'<table><tr><td>模拟{mode}结果</td></tr></table>',
                'bbox': [237, 201, 1416, 2044],
            }],
        }

    def fallback_processing(self, image_path):
        """Process the input with the plain ``OCR`` pipeline.

        Returns:
            A ``fallback_ocr`` result dict, or an ``error`` dict when the
            OCR pipeline fails as well.
        """
        print("使用基础OCR处理...")
        try:
            ocr_pipeline = create_pipeline(pipeline="OCR")
            result = list(ocr_pipeline.predict(image_path))
            return {
                'mode': 'fallback_ocr',
                'status': 'success',
                'raw_output': result,
                'message': '使用基础OCR处理',
            }
        except Exception as e:
            print(f"回退处理也失败: {e}")
            return {
                'mode': 'error',
                'status': 'failed',
                'message': f'所有处理方法都失败: {e}',
            }

    def _detect_table_regions(self, image_path):
        """Read the image and crop every layout box labelled ``'table'``.

        Box coordinates are clamped to the image bounds, because negative
        values would wrap around in NumPy slicing, and boxes that end up
        empty are skipped.

        Returns:
            Tuple ``(image, regions)``. ``image`` is None when the file
            cannot be read, in which case ``regions`` is empty and the
            layout model is not run. Each region is a dict with
            ``table_id`` (index of the box within its layout result),
            ``image``, ``bbox`` and ``score``.
        """
        original_image = cv2.imread(image_path)
        if original_image is None:
            print(f"无法读取图像: {image_path}")
            return None, []

        height, width = original_image.shape[:2]
        layout_results = list(self.selector.layout_model.predict(image_path))

        regions = []
        for layout_result in layout_results:
            for i, box_info in enumerate(layout_result['boxes']):
                if box_info['label'] != 'table':
                    continue
                x1, y1, x2, y2 = [int(coord) for coord in box_info['coordinate']]
                x1, x2 = max(0, min(x1, width)), max(0, min(x2, width))
                y1, y2 = max(0, min(y1, height)), max(0, min(y2, height))
                if x2 <= x1 or y2 <= y1:
                    continue
                regions.append({
                    'table_id': i,
                    'image': original_image[y1:y2, x1:x2],
                    'bbox': [x1, y1, x2, y2],
                    'score': box_info['score'],
                })
        return original_image, regions

    def extract_all_table_regions(self, image_path):
        """Return every detected table region of the image.

        Each crop is also written to ``./debug_table_<table_id>.jpg``.

        Returns:
            List of dicts with ``table_id``, ``image``, ``bbox`` and
            ``score``; empty when the image cannot be read or holds no
            table.
        """
        _, all_tables = self._detect_table_regions(image_path)
        for table_info in all_tables:
            i = table_info['table_id']
            x1, y1, x2, y2 = table_info['bbox']
            # Keep each table crop on disk for debugging.
            cv2.imwrite(f'./debug_table_{i}.jpg', table_info['image'])
            print(f"表格 {i}: bbox=[{x1}, {y1}, {x2}, {y2}], score={table_info['score']:.4f}")
        return all_tables

    def extract_table_region(self, image_path):
        """Return the crop of the highest-scoring table in the image.

        Returns:
            The crop as an image array, the whole image when no table is
            detected, or None when the image cannot be read.
        """
        original_image, table_regions = self._detect_table_regions(image_path)
        if original_image is None:
            return None

        for region in table_regions:
            x1, y1, x2, y2 = region['bbox']
            print(f"检测到表格区域: bbox=[{x1}, {y1}, {x2}, {y2}], score={region['score']:.4f}")

        if not table_regions:
            print("未检测到表格区域,使用整个图像")
            return original_image

        best_table = max(table_regions, key=lambda region: region['score'])
        return best_table['image']

    def process_table_intelligently(self, image_path, use_layout_model=True):
        """Extract the table, analyse it, pick a mode and run it.

        Args:
            image_path: Path of the page image.
            use_layout_model: When True the table is cropped with the layout
                model; otherwise the whole image is used.

        Returns:
            Dict with ``result``, ``selected_mode``, ``mode_description``,
            ``confidence_score``, ``table_features`` and
            ``table_region_shape``. On failure ``selected_mode`` is
            ``'error'``, ``result`` is None and ``error`` holds the message.
        """
        try:
            # 1. Extract the table region.
            if use_layout_model:
                table_image = self.extract_table_region(image_path)
            else:
                table_image = cv2.imread(image_path)

            if table_image is None or table_image.size == 0:
                print("表格区域提取失败,使用原图")
                table_image = cv2.imread(image_path)
            if table_image is None or table_image.size == 0:
                # Fail with a clear message instead of an OpenCV assertion
                # raised further down.
                raise ValueError(f"无法读取图像: {image_path}")

            # Keep the region on disk for debugging.
            cv2.imwrite('./debug_table_region.jpg', table_image)
            print("表格区域已保存到: ./debug_table_region.jpg")
            print(f"表格区域尺寸: {table_image.shape}")

            # 2. Analyse the table features.
            features = self.selector.analyze_table_features(table_image)
            print(f"表格特征分析: {features}")

            # 3. Select the best mode.
            best_mode, mode_info = self.decision_engine.select_best_mode(features)
            print(f"选择模式: {best_mode}, 分数: {mode_info['score']:.3f}")

            # 4. Run the recognition on the cropped region. No per-mode
            #    configuration tuning is applied.
            result = self.execute_with_mode(table_image, best_mode, optimized_config=None)

            return {
                'result': result,
                'selected_mode': best_mode,
                'mode_description': mode_info['description'],
                'confidence_score': mode_info['score'],
                'table_features': features,
                'table_region_shape': table_image.shape,
            }
        except Exception as e:
            print(f"智能处理过程出错: {e}")
            traceback.print_exc()
            return {
                'result': None,
                'selected_mode': 'error',
                'mode_description': f'处理失败: {e}',
                'confidence_score': 0.0,
                'table_features': {},
                'error': str(e),
            }


def demo_intelligent_table_processing():
    """Run the processor on a sample image and print a report.

    The HTML of every recognised table is also written to
    ``./table_<index>_result.html``.

    Returns:
        The dict returned by ``process_table_intelligently``, or None when
        the demo itself fails.
    """
    try:
        processor = IntelligentTableProcessor("./PP-StructureV3-zhch.yaml")

        # Sample: a complex financial table.
        result = processor.process_table_intelligently(
            "./sample_data/600916_中国黄金_2002年报_83_94_2.png",
            use_layout_model=True,
        )

        print("\n" + "=" * 50)
        print("智能表格处理结果:")
        print("=" * 50)
        print(f"选择的模式: {result['selected_mode']}")
        print(f"选择原因: {result['mode_description']}")
        print(f"置信度分数: {result['confidence_score']:.3f}")

        if 'table_region_shape' in result:
            print(f"表格区域尺寸: {result['table_region_shape']}")

        print("\n表格特征分析:")
        for key, value in result.get('table_features', {}).items():
            if isinstance(value, float):
                print(f" {key}: {value:.4f}")
            else:
                print(f" {key}: {value}")

        if result['result']:
            process_result = result['result']
            print("\n处理结果:")
            print(f" 模式: {process_result.get('mode', 'unknown')}")
            print(f" 状态: {process_result.get('status', 'unknown')}")
            print(f" 表格数量: {process_result.get('table_count', 0)}")

            if process_result.get('tables'):
                for i, table in enumerate(process_result['tables']):
                    print(f"\n 表格 {i}:")
                    print(f" bbox: {table.get('bbox', 'N/A')}")
                    print(f" 单元格数量: {table.get('cell_count', 'N/A')}")
                    print(f" 区域ID: {table.get('table_region_id', 'N/A')}")

                    html_content = table.get('html', '')
                    if len(html_content) > 200:
                        html_preview = html_content[:200] + "..."
                    else:
                        html_preview = html_content
                    print(f" HTML预览: {html_preview}")

                    # Store the complete HTML next to the script.
                    html_filename = f"./table_{i}_result.html"
                    try:
                        with open(html_filename, 'w', encoding='utf-8') as f:
                            f.write(html_content)
                        print(f" 完整HTML已保存到: {html_filename}")
                    except Exception as e:
                        print(f" 保存HTML失败: {e}")

        # Advice based on the confidence score.
        if result['confidence_score'] > 0.8:
            print("\n✅ 高置信度,推荐使用该模式")
        elif result['confidence_score'] > 0.6:
            print("\n⚠️ 中等置信度,可能需要人工验证")
        else:
            print("\n❌ 低置信度,建议尝试其他模式或人工处理")

        return result
    except Exception as e:
        print(f"演示程序出错: {e}")
        traceback.print_exc()
        return None


if __name__ == "__main__":
    demo_intelligent_table_processing()