| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849 |
- # zhch/table_mode_selector.py
- import cv2
- import numpy as np
- from paddlex import create_pipeline, create_model
class TableModeSelector:
    """Extract heuristic features from a table image for mode selection.

    Combines a PaddleX table classifier (wired vs. wireless) with OpenCV-based
    measurements: edge/text density, a merged-cell estimate, ruling-line
    regularity and edge strength.
    """

    # Minimum length (pixels) of a horizontal/vertical stroke that survives the
    # morphological opening and is therefore treated as a ruling line.
    LINE_KERNEL_LENGTH = 40

    def __init__(self):
        # Layout model (the same one used in the pipeline config); callers use
        # it to locate table regions on the page.
        self.layout_model = create_model(model_name="PP-DocLayout_plus-L")
        # Table classifier used for the wired/wireless pre-analysis.
        self.table_cls_model = create_model(model_name="PP-LCNet_x1_0_table_cls")

    def analyze_table_features(self, table_image):
        """Return a feature dict for *table_image*.

        Keys: 'table_type', the keys produced by analyze_complexity() and
        analyze_regularity(), and 'border_clarity'.
        """
        features = {'table_type': self.get_table_type(table_image)}
        features.update(self.analyze_complexity(table_image))
        features.update(self.analyze_regularity(table_image))
        features['border_clarity'] = self.analyze_border_clarity(table_image)
        return features

    def get_table_type(self, image):
        """Classify the table as wired/wireless with the PaddleX classifier.

        Uses the top-scoring label when the output carries 'class_ids',
        'scores' and 'label_names', handles a few other output layouts, and
        falls back to detect_table_type_by_lines() when the output is not
        dict-like or the classifier raises.
        """
        try:
            result = next(self.table_cls_model.predict(image))

            # Debug output: show the actual result type.
            print(f"表格分类模型输出类型: {type(result).__name__}")

            if hasattr(result, 'keys') or isinstance(result, dict):
                # Standard PaddleX top-k output.
                if 'class_ids' in result and 'scores' in result and 'label_names' in result:
                    scores = result['scores']
                    label_names = result['label_names']

                    max_score_idx = np.argmax(scores)
                    best_label = label_names[max_score_idx]
                    best_score = scores[max_score_idx]

                    print(f"分类结果: {best_label} (置信度: {best_score:.4f})")
                    return best_label

                if 'class_ids' in result:
                    # class_ids may be a scalar, a list or a (possibly 0-d)
                    # numpy array; ravel() handles all of them uniformly.
                    class_id = int(np.ravel(result['class_ids'])[0])
                    return 'wired_table' if class_id == 0 else 'wireless_table'

                if 'label_names' in result:
                    label_names = result['label_names']
                    # len() instead of truthiness so numpy arrays work too.
                    if label_names is not None and len(label_names) > 0:
                        return label_names[0]
                    return 'wired_table'

                # Legacy field names.
                for key in ('label', 'class_name', 'prediction'):
                    if key in result:
                        return result[key]

                # Last resort: stringify the first value.
                first_key = list(result.keys())[0]
                return str(result[first_key])

            print("使用备用的线条检测方法判断表格类型")
            return self.detect_table_type_by_lines(image)

        except Exception as e:
            print(f"表格分类出错: {e},使用备用方法")
            return self.detect_table_type_by_lines(image)

    def detect_table_type_by_lines(self, image):
        """Fallback classifier: more than 10 Hough lines means 'wired_table'."""
        edges = cv2.Canny(self._to_gray(image), 50, 150)
        lines = cv2.HoughLines(edges, 1, np.pi / 180, threshold=100)

        if lines is not None and len(lines) > 10:
            print("检测到较多直线,判断为有线表格")
            return 'wired_table'
        print("检测到较少直线,判断为无线表格")
        return 'wireless_table'

    def analyze_complexity(self, image):
        """Return edge density, merged-cell ratio, text density and size (megapixels)."""
        h, w = image.shape[:2]
        edges = cv2.Canny(self._to_gray(image), 50, 150)
        return {
            # Fraction of pixels that are Canny edges.
            'line_density': np.sum(edges > 0) / (h * w),
            'merged_cells_ratio': self.detect_merged_cells(image),
            'text_density': self.analyze_text_density(image),
            # Image area in megapixels.
            'size_complexity': (h * w) / (1000 * 1000),
        }

    def detect_merged_cells(self, image, full_span_ratio=0.9):
        """Estimate the share of ruling lines interrupted by merged cells (0..1).

        Heuristic: without merged cells every ruling line spans the whole grid;
        a merged cell interrupts the lines that would cross it. See
        _partial_line_ratio() for the exact measure.
        """
        horizontal, vertical = self._extract_line_masks(image)
        return self._partial_line_ratio(horizontal, vertical, full_span_ratio)

    def analyze_text_density(self, image):
        """Return the fraction of pixels Otsu thresholding classifies as dark."""
        gray = self._to_gray(image)
        _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        # Dark pixels (value 0) approximate text and line strokes.
        return np.sum(binary == 0) / binary.size

    def analyze_regularity(self, image):
        """Return spacing regularity of horizontal, vertical and all ruling lines."""
        horizontal, vertical = self._extract_line_masks(image)
        h_regularity = self.calculate_line_regularity(horizontal, axis='horizontal')
        v_regularity = self.calculate_line_regularity(vertical, axis='vertical')
        return {
            'horizontal_regularity': h_regularity,
            'vertical_regularity': v_regularity,
            'overall_regularity': (h_regularity + v_regularity) / 2,
        }

    def calculate_line_regularity(self, lines_image, axis='horizontal'):
        """Score how evenly spaced the lines in a line mask are.

        Rows (axis='horizontal') or columns (otherwise) whose projection
        exceeds 30% of the maximum are grouped into contiguous runs; each run
        is one line, located at its centre. The score is
        1 - std(spacing) / mean(spacing), clipped to [0, 1]. Returns 0.5 when
        fewer than two lines are found.
        """
        projection = np.asarray(lines_image, dtype=np.float64).sum(
            axis=1 if axis == 'horizontal' else 0
        )
        if projection.size == 0:
            return 0.5

        threshold = projection.max() * 0.3
        # Grouping runs (instead of a strict local-maximum test) keeps lines
        # that are several pixels thick, whose projection is a flat plateau.
        centers = [(start + end - 1) / 2.0 for start, end in self._find_runs(projection > threshold)]
        if len(centers) < 2:
            return 0.5  # neutral default

        intervals = np.diff(centers)
        mean_interval = intervals.mean()
        if mean_interval == 0:
            return 0.5

        regularity = 1.0 - min(1.0, intervals.std() / mean_interval)
        return max(0.0, regularity)

    def analyze_border_clarity(self, image):
        """Return the mean Sobel gradient magnitude divided by 255 (not clamped)."""
        gray = self._to_gray(image)
        sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        edge_magnitude = np.sqrt(sobelx ** 2 + sobely ** 2)
        return np.mean(edge_magnitude) / 255.0

    @staticmethod
    def _to_gray(image):
        """Return a single-channel view of a gray, BGR or BGRA image."""
        if image.ndim == 2:
            return image
        if image.shape[2] == 1:
            return image[:, :, 0]
        if image.shape[2] == 4:
            return cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)
        return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    def _extract_line_masks(self, image):
        """Return (horizontal, vertical) binary masks of the ruling lines.

        The image is inverted and Otsu-binarised so dark strokes become white
        foreground. Opening the raw grayscale image would keep the white
        background and mark nearly every pixel as "line". A morphological
        opening with a long, thin kernel then keeps only straight runs at
        least LINE_KERNEL_LENGTH pixels long.
        """
        gray = self._to_gray(image)
        _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        length = self.LINE_KERNEL_LENGTH
        horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (length, 1))
        vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, length))
        horizontal = cv2.morphologyEx(binary, cv2.MORPH_OPEN, horizontal_kernel)
        vertical = cv2.morphologyEx(binary, cv2.MORPH_OPEN, vertical_kernel)
        return horizontal, vertical

    @classmethod
    def _partial_line_ratio(cls, horizontal, vertical, full_span_ratio=0.9):
        """Fraction of ruling lines shorter than full_span_ratio * longest line.

        Horizontal lines are measured per row of *horizontal*, vertical lines
        per column of *vertical*. The longest line of each orientation is the
        reference, so margins around the table do not count as interruptions.
        Returns 0.0 when no lines are present.
        """
        partial = 0
        total = 0
        for lengths in (np.count_nonzero(horizontal, axis=1),
                        np.count_nonzero(vertical, axis=0)):
            runs = cls._find_runs(lengths > 0)
            if not runs:
                continue
            # A thick line spans several rows/columns; use its longest one.
            run_lengths = [lengths[start:end].max() for start, end in runs]
            longest = max(run_lengths)
            total += len(run_lengths)
            partial += sum(1 for length in run_lengths if length < full_span_ratio * longest)
        return partial / total if total else 0.0

    @staticmethod
    def _find_runs(flags):
        """Return (start, end) pairs (end exclusive) of contiguous True runs."""
        flags = np.asarray(flags, dtype=bool)
        if flags.size == 0:
            return []
        padded = np.concatenate(([False], flags, [False]))
        changes = np.flatnonzero(padded[1:] != padded[:-1])
        return list(zip(changes[0::2].tolist(), changes[1::2].tolist()))
class TableModeDecisionEngine:
    """Rule-based selection of a table-processing mode from table features.

    Each mode has a list of conditions and a weight. A mode's score is the
    fraction of its conditions that hold, multiplied by its weight. The
    highest-scoring mode wins.
    """

    # Comparison operators usable in rule conditions. For 'in', the value is
    # also compared as a string, because classifier outputs may be ints or
    # numpy scalars while the rules list ids such as '0' / '1'.
    _COMPARATORS = {
        '>': lambda value, threshold: value > threshold,
        '<': lambda value, threshold: value < threshold,
        '==': lambda value, threshold: value == threshold,
        '>=': lambda value, threshold: value >= threshold,
        '<=': lambda value, threshold: value <= threshold,
        'in': lambda value, threshold: value in threshold or str(value) in threshold,
    }

    def __init__(self):
        self.rules = self.load_decision_rules()

    def load_decision_rules(self):
        """Return the rule table: mode name -> conditions, weight, description.

        A condition is (feature, operator, threshold) or ('OR', [conditions]).
        """
        return {
            'wired_html_mode': {
                'conditions': [
                    ('table_type', 'in', ['wired_table', 'wired', '0']),  # several label formats
                    ('border_clarity', '>', 0.6),
                    ('merged_cells_ratio', '>', 0.3),
                    ('overall_regularity', '<', 0.7),
                    ('size_complexity', '>', 0.5)
                ],
                'weight': 0.9,
                'description': '复杂有线表格,几何匹配更准确'
            },
            'wired_e2e_mode': {
                'conditions': [
                    ('table_type', 'in', ['wired_table', 'wired', '0']),
                    ('overall_regularity', '>', 0.8),
                    ('merged_cells_ratio', '<', 0.2),
                    ('text_density', '>', 0.3)
                ],
                'weight': 0.8,
                'description': '规整有线表格,端到端效果好'
            },
            'wireless_e2e_mode': {
                'conditions': [
                    ('table_type', 'in', ['wireless_table', 'wireless', '1']),
                    ('line_density', '<', 0.1),
                    ('text_density', '>', 0.2)
                ],
                'weight': 0.85,
                'description': '无线表格,端到端预测最适合'
            },
            'regular_mode': {
                'conditions': [
                    ('size_complexity', '>', 1.0),
                    ('OR', [
                        ('border_clarity', '<', 0.4),
                        ('overall_regularity', '<', 0.5)
                    ])
                ],
                'weight': 0.7,
                'description': '复杂场景,需要多模型协同'
            }
        }

    def check_single_condition(self, features, condition):
        """Return True if (feature, operator, threshold) holds for *features*.

        A missing feature, an unknown operator, or a value that cannot be
        compared (e.g. None with '>') counts as not satisfied.
        """
        feature_name, op, threshold = condition
        if feature_name not in features:
            return False

        compare = self._COMPARATORS.get(op)
        if compare is None:
            return False

        try:
            return bool(compare(features[feature_name], threshold))
        except (TypeError, ValueError):
            return False

    def evaluate_conditions(self, features, conditions):
        """Return the fraction of *conditions* satisfied (0 when there are none).

        An ('OR', [...]) entry counts as one condition, satisfied when any
        sub-condition holds.
        """
        if not conditions:
            return 0

        satisfied = 0
        for condition in conditions:
            if condition[0] == 'OR':
                holds = any(self.check_single_condition(features, sub_cond)
                            for sub_cond in condition[1])
            else:
                holds = self.check_single_condition(features, condition)
            satisfied += 1 if holds else 0

        return satisfied / len(conditions)

    def select_best_mode(self, features):
        """Return (mode_name, {'score', 'description'}) for the best-scoring mode.

        Ties go to the mode that appears first in the rule table.
        """
        mode_scores = {
            mode_name: {
                'score': self.evaluate_conditions(features, rule['conditions']) * rule['weight'],
                'description': rule['description'],
            }
            for mode_name, rule in self.rules.items()
        }
        best_name, best_info = max(mode_scores.items(), key=lambda item: item[1]['score'])
        return best_name, best_info
class IntelligentTableProcessor:
    """Choose a table-recognition mode for an image and run it through PaddleX.

    Features come from TableModeSelector, the mode from
    TableModeDecisionEngine, and recognition from one pipeline built lazily
    from ``config_path`` and shared by all modes. When a mode fails, a
    placeholder result with ``status == 'mock'`` is returned; it carries the
    error message under ``'error'``.
    """

    # Per-mode keyword arguments forwarded to the pipeline's predict().
    # NOTE(review): names follow the PaddleX 3.x PP-StructureV3 predict()
    # parameters; confirm them against the installed PaddleX version.
    MODE_PREDICT_KWARGS = {
        'wired_html_mode': {
            'use_wired_table_cells_trans_to_html': True,
            'use_e2e_wired_table_rec_model': False,
        },
        'wired_e2e_mode': {
            'use_wired_table_cells_trans_to_html': False,
            'use_e2e_wired_table_rec_model': True,
        },
        'wireless_e2e_mode': {
            'use_e2e_wireless_table_rec_model': True,
        },
        'regular_mode': {},
    }

    # Placeholder bbox used whenever no real table geometry is available.
    _DEFAULT_BBOX = (0, 0, 100, 100)

    def __init__(self, config_path="./PP-StructureV3-zhch.yaml"):
        self.selector = TableModeSelector()
        self.decision_engine = TableModeDecisionEngine()
        self.config_path = config_path
        # The pipeline is created on first use (see _get_pipeline) rather
        # than here, so construction does not depend on the config loading.
        self.pp_structure = None

    def execute_with_mode(self, image_path, mode, optimized_config=None):
        """Run recognition in *mode*; fall back to plain OCR if it raises.

        Args:
            image_path: Image path or decoded image, passed to the pipeline.
            mode: Mode name chosen by the decision engine.
            optimized_config: Reserved for per-mode overrides; only logged.
        """
        try:
            print(f"正在使用 {mode} 模式处理表格...")
            print(f"优化配置: {optimized_config}")
            return self.create_and_run_pipeline(image_path, mode, optimized_config)
        except Exception as e:
            print(f"执行 {mode} 模式时出错: {e}")
            print("回退到基础处理模式")
            return self.fallback_processing(image_path)

    def create_and_run_pipeline(self, image_path, mode, optimized_config):
        """Dispatch to the run_* method for *mode* (unknown modes use OCR)."""
        runners = {
            'wired_html_mode': self.run_wired_html_mode,
            'wired_e2e_mode': self.run_wired_e2e_mode,
            'wireless_e2e_mode': self.run_wireless_e2e_mode,
            'regular_mode': self.run_regular_mode,
        }
        runner = runners.get(mode)
        if runner is None:
            print(f"未知模式: {mode},使用默认处理")
            return self.fallback_processing(image_path)
        return runner(image_path, optimized_config)

    def run_wired_html_mode(self, image_path, config):
        """Wired tables: convert detected cells to HTML (``config`` unused)."""
        return self._run_mode(image_path, 'wired_html_mode',
                              "执行有线表格转HTML模式...", "有线表格HTML模式执行失败")

    def run_wired_e2e_mode(self, image_path, config):
        """Wired tables: end-to-end structure model (``config`` unused)."""
        return self._run_mode(image_path, 'wired_e2e_mode',
                              "执行有线表格端到端模式...", "有线表格端到端模式执行失败")

    def run_wireless_e2e_mode(self, image_path, config):
        """Wireless tables: end-to-end structure model (``config`` unused)."""
        return self._run_mode(image_path, 'wireless_e2e_mode',
                              "执行无线表格端到端模式...", "无线表格端到端模式执行失败")

    def run_regular_mode(self, image_path, config):
        """Full pipeline with its default settings (``config`` unused)."""
        return self._run_mode(image_path, 'regular_mode',
                              "执行常规模式...", "常规模式执行失败")

    def _get_pipeline(self):
        """Return the shared pipeline, creating it from config_path on first use."""
        if self.pp_structure is None:
            self.pp_structure = create_pipeline(pipeline=self.config_path)
        return self.pp_structure

    def _run_mode(self, image_path, mode, start_message, failure_prefix):
        """Run the shared pipeline with *mode*'s predict kwargs and format the output.

        Any exception is printed and converted into a mock result carrying it.
        """
        print(start_message)
        try:
            predict_kwargs = self.MODE_PREDICT_KWARGS.get(mode, {})
            raw_result = list(self._get_pipeline().predict(image_path, **predict_kwargs))
            return self.format_result(raw_result, mode=mode)
        except Exception as e:
            print(f"{failure_prefix}: {e}")
            return self.create_mock_result(mode=mode, error=e)

    @staticmethod
    def _has_field(item, key):
        """True if *item* has *key* as an attribute or contains it as a key."""
        if hasattr(item, key):
            return True
        try:
            return key in item
        except TypeError:
            return False

    @staticmethod
    def _get_field(item, key, default=None):
        """Read *key* via item.get(), then as an attribute; *default* if both are None."""
        getter = getattr(item, 'get', None)
        value = getter(key) if callable(getter) else None
        if value is None:
            value = getattr(item, key, None)
        return default if value is None else value

    def format_result(self, raw_result, mode):
        """Convert raw pipeline output into a summary dict of recognised tables.

        Returns {'mode', 'status', 'raw_output', 'table_count', 'tables'}.
        Tables from every result item are collected, numbered consecutively,
        and 'table_count' equals len(tables). Empty input or a formatting
        error yields a mock result.
        """
        try:
            if not raw_result:
                return self.create_mock_result(mode)

            tables = []
            formatted_result = {
                'mode': mode,
                'status': 'success',
                'raw_output': raw_result,
                'table_count': 0,
                'tables': tables,
            }

            for item in raw_result:
                print(f"处理结果项: {type(item)}")

                # PP-StructureV3 layout: per-table results in table_res_list.
                if self._has_field(item, 'table_res_list'):
                    for i, table_item in enumerate(self._get_field(item, 'table_res_list', [])):
                        table_id = len(tables)
                        table_region_id = table_item.get('table_region_id', i)
                        cells = table_item.get('cell_box_list')
                        cell_count = 0 if cells is None else len(cells)

                        tables.append({
                            'table_id': table_id,
                            'table_region_id': table_region_id,
                            'html': table_item.get('pred_html', 'HTML不可用'),
                            # Overall table bbox derived from its cells.
                            'bbox': self.calculate_table_bbox_from_cells(cells),
                            'cell_count': cell_count,
                            'neighbor_texts': table_item.get('neighbor_texts', ''),
                        })
                        print(f"提取表格 {table_id}: region_id={table_region_id}, cells={cell_count}")

                # Parsed layout blocks, some of which may be tables.
                elif self._has_field(item, 'parsing_res_list'):
                    for parsing_item in self._get_field(item, 'parsing_res_list', []):
                        if getattr(parsing_item, 'label', None) == 'table':
                            tables.append({
                                'table_id': len(tables),
                                'html': getattr(parsing_item, 'html', 'HTML不可用'),
                                'bbox': getattr(parsing_item, 'bbox', list(self._DEFAULT_BBOX)),
                                'source': 'parsing_res',
                            })

                # Older table_recognition_res layout.
                elif self._has_field(item, 'table_recognition_res'):
                    for table in self._get_field(item, 'table_recognition_res', []):
                        tables.append({
                            'table_id': len(tables),
                            'html': getattr(table, 'html', 'HTML不可用'),
                            'bbox': getattr(table, 'bbox', list(self._DEFAULT_BBOX)),
                        })

            formatted_result['table_count'] = len(tables)
            return formatted_result

        except Exception as e:
            print(f"结果格式化失败: {e}")
            import traceback
            traceback.print_exc()
            return self.create_mock_result(mode, error=e)

    def calculate_table_bbox_from_cells(self, cell_box_list):
        """Return [x_min, y_min, x_max, y_max] enclosing all cells.

        A cell may be a list/tuple/numpy array or an object with a ``bbox``
        attribute. Eight values are read as a quadrilateral (x1, y1, ..., x4,
        y4); otherwise the first four values are read as (x1, y1, x2, y2).
        Returns [0, 0, 100, 100] when there are no usable cells or on error.
        """
        default_bbox = list(self._DEFAULT_BBOX)
        try:
            # len() rather than truthiness so numpy arrays are accepted.
            if cell_box_list is None or len(cell_box_list) == 0:
                return default_bbox

            xs, ys = [], []
            for cell in cell_box_list:
                if isinstance(cell, (list, tuple, np.ndarray)):
                    coords = cell
                elif hasattr(cell, 'bbox'):
                    coords = cell.bbox
                else:
                    continue

                coords = np.asarray(coords, dtype=float).ravel()
                if coords.size == 8:
                    xs.extend(coords[0::2])
                    ys.extend(coords[1::2])
                elif coords.size >= 4:
                    xs.extend((coords[0], coords[2]))
                    ys.extend((coords[1], coords[3]))

            if not xs:
                return default_bbox
            return [int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))]

        except Exception as e:
            print(f"计算表格边界框失败: {e}")
            return default_bbox

    def create_mock_result(self, mode, error=None):
        """Build a placeholder result (status 'mock') used in tests and failures.

        Args:
            mode: Mode name recorded in the result.
            error: Optional exception or message; stored as 'error' when given
                so callers can see why the real pipeline result is missing.
        """
        result = {
            'mode': mode,
            'status': 'mock',
            'message': f'{mode} 模式执行完成(模拟结果)',
            'table_count': 1,
            'tables': [{
                'table_id': 0,
                'html': f'<table><tr><td>模拟{mode}结果</td></tr></table>',
                # Fixed placeholder coordinates, not derived from the input.
                'bbox': [237, 201, 1416, 2044]
            }]
        }
        if error is not None:
            result['error'] = str(error)
        return result

    def fallback_processing(self, image_path):
        """Run the plain 'OCR' pipeline; return a status dict either way."""
        print("使用基础OCR处理...")
        try:
            ocr_pipeline = create_pipeline(pipeline="OCR")
            result = list(ocr_pipeline.predict(image_path))
            return {
                'mode': 'fallback_ocr',
                'status': 'success',
                'raw_output': result,
                'message': '使用基础OCR处理'
            }
        except Exception as e:
            print(f"回退处理也失败: {e}")
            return {
                'mode': 'error',
                'status': 'failed',
                'message': f'所有处理方法都失败: {e}'
            }

    @staticmethod
    def _read_image(image_path):
        """Load *image_path* as a BGR image or raise FileNotFoundError."""
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None for paths it cannot open, which on some
            # platforms includes non-ASCII paths; decoding the raw bytes does
            # not depend on how the path is encoded.
            try:
                data = np.fromfile(image_path, dtype=np.uint8)
            except OSError:
                data = None
            if data is not None and data.size > 0:
                image = cv2.imdecode(data, cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError(f"无法读取图像: {image_path}")
        return image

    @staticmethod
    def _clip_bbox(coordinate, image_shape):
        """Convert the first four coordinates to ints clipped to the image bounds.

        Clipping prevents negative coordinates from wrapping numpy slices.
        """
        h, w = image_shape[:2]
        x1, y1, x2, y2 = [int(coord) for coord in coordinate[:4]]
        return [max(0, min(x1, w)), max(0, min(y1, h)),
                max(0, min(x2, w)), max(0, min(y2, h))]

    def _detect_table_boxes(self, image_path, original_image):
        """Return (box_index, bbox, score, crop) for each layout box labelled 'table'.

        box_index is the box's position in its layout result's 'boxes' list.
        Boxes are clipped to the image, and boxes that end up empty are skipped.
        """
        detections = []
        for layout_result in self.selector.layout_model.predict(image_path):
            for box_index, box_info in enumerate(layout_result['boxes']):
                if box_info['label'] != 'table':
                    continue
                x1, y1, x2, y2 = self._clip_bbox(box_info['coordinate'], original_image.shape)
                if x2 <= x1 or y2 <= y1:
                    print(f"跳过空的表格区域: bbox=[{x1}, {y1}, {x2}, {y2}]")
                    continue
                crop = original_image[y1:y2, x1:x2]
                detections.append((box_index, [x1, y1, x2, y2], box_info['score'], crop))
        return detections

    def extract_all_table_regions(self, image_path):
        """Return every detected table region and save each crop for debugging.

        Each entry has 'table_id' (the layout box index), 'image', 'bbox' and
        'score'; crops are written to ./debug_table_<table_id>.jpg.
        """
        original_image = self._read_image(image_path)

        all_tables = []
        for box_index, bbox, score, table_image in self._detect_table_boxes(image_path, original_image):
            all_tables.append({
                'table_id': box_index,
                'image': table_image,
                'bbox': bbox,
                'score': score
            })

            cv2.imwrite(f'./debug_table_{box_index}.jpg', table_image)
            x1, y1, x2, y2 = bbox
            print(f"表格 {box_index}: bbox=[{x1}, {y1}, {x2}, {y2}], score={score:.4f}")

        return all_tables

    def extract_table_region(self, image_path):
        """Return the crop of the highest-scoring table, or the whole image if none."""
        original_image = self._read_image(image_path)

        table_regions = []
        for _, bbox, score, table_image in self._detect_table_boxes(image_path, original_image):
            table_regions.append({'image': table_image, 'bbox': bbox, 'score': score})
            x1, y1, x2, y2 = bbox
            print(f"检测到表格区域: bbox=[{x1}, {y1}, {x2}, {y2}], score={score:.4f}")

        if not table_regions:
            print("未检测到表格区域,使用整个图像")
            return original_image

        best_table = max(table_regions, key=lambda region: region['score'])
        return best_table['image']

    def process_table_intelligently(self, image_path, use_layout_model=True):
        """Crop the table, analyse it, choose a mode and run recognition.

        Args:
            image_path: Path of the page image.
            use_layout_model: Crop the highest-scoring layout 'table' region
                first; when False the whole image is analysed.

        Returns:
            A dict with 'result', 'selected_mode', 'mode_description',
            'confidence_score', 'table_features' and 'table_region_shape'.
            On failure, 'result' is None, 'selected_mode' is 'error' and
            'error' holds the message.
        """
        try:
            # 1. Locate the table region.
            if use_layout_model:
                table_image = self.extract_table_region(image_path)
            else:
                table_image = self._read_image(image_path)

            if table_image is None or table_image.size == 0:
                print("表格区域提取失败,使用原图")
                table_image = self._read_image(image_path)

            # Keep a copy of the analysed region for debugging.
            cv2.imwrite('./debug_table_region.jpg', table_image)
            print("表格区域已保存到: ./debug_table_region.jpg")
            print(f"表格区域尺寸: {table_image.shape}")

            # 2. Extract features.
            features = self.selector.analyze_table_features(table_image)
            print(f"表格特征分析: {features}")

            # 3. Choose the best mode.
            best_mode, mode_info = self.decision_engine.select_best_mode(features)
            print(f"选择模式: {best_mode}, 分数: {mode_info['score']:.3f}")

            # 4. Run recognition on the cropped region, not the full page.
            result = self.execute_with_mode(table_image, best_mode, optimized_config=None)

            return {
                'result': result,
                'selected_mode': best_mode,
                'mode_description': mode_info['description'],
                'confidence_score': mode_info['score'],
                'table_features': features,
                'table_region_shape': table_image.shape
            }

        except Exception as e:
            print(f"智能处理过程出错: {e}")
            import traceback
            traceback.print_exc()

            return {
                'result': None,
                'selected_mode': 'error',
                'mode_description': f'处理失败: {e}',
                'confidence_score': 0.0,
                'table_features': {},
                'error': str(e)
            }
def demo_intelligent_table_processing(
    image_path="./sample_data/600916_中国黄金_2002年报_83_94_2.png",
    config_path="./PP-StructureV3-zhch.yaml",
):
    """Run IntelligentTableProcessor on one image and print a report.

    Each table's HTML is written to ./table_<i>_result.html.

    Args:
        image_path: Page image to process (defaults to the sample financial table).
        config_path: Pipeline config passed to IntelligentTableProcessor.

    Returns:
        The dict from process_table_intelligently(), or None if the demo raised.
    """
    try:
        processor = IntelligentTableProcessor(config_path)

        result = processor.process_table_intelligently(
            image_path,
            use_layout_model=True
        )

        print("\n" + "=" * 50)
        print("智能表格处理结果:")
        print("=" * 50)
        print(f"选择的模式: {result['selected_mode']}")
        print(f"选择原因: {result['mode_description']}")
        print(f"置信度分数: {result['confidence_score']:.3f}")

        if 'table_region_shape' in result:
            print(f"表格区域尺寸: {result['table_region_shape']}")

        print("\n表格特征分析:")
        for key, value in result.get('table_features', {}).items():
            if isinstance(value, float):
                print(f"  {key}: {value:.4f}")
            else:
                print(f"  {key}: {value}")

        if result['result']:
            process_result = result['result']
            print("\n处理结果:")
            print(f"  模式: {process_result.get('mode', 'unknown')}")
            print(f"  状态: {process_result.get('status', 'unknown')}")
            print(f"  表格数量: {process_result.get('table_count', 0)}")

            for i, table in enumerate(process_result.get('tables') or []):
                print(f"\n  表格 {i}:")
                print(f"    bbox: {table.get('bbox', 'N/A')}")
                print(f"    单元格数量: {table.get('cell_count', 'N/A')}")
                print(f"    区域ID: {table.get('table_region_id', 'N/A')}")

                # 'html' may be present but None (e.g. from getattr defaults).
                html_content = table.get('html') or ''
                if len(html_content) > 200:
                    html_preview = html_content[:200] + "..."
                else:
                    html_preview = html_content
                print(f"    HTML预览: {html_preview}")

                html_filename = f"./table_{i}_result.html"
                try:
                    with open(html_filename, 'w', encoding='utf-8') as f:
                        f.write(html_content)
                    print(f"    完整HTML已保存到: {html_filename}")
                except Exception as e:
                    print(f"    保存HTML失败: {e}")

        # Advice based on the mode score.
        if result['confidence_score'] > 0.8:
            print("\n✅ 高置信度,推荐使用该模式")
        elif result['confidence_score'] > 0.6:
            print("\n⚠️ 中等置信度,可能需要人工验证")
        else:
            print("\n❌ 低置信度,建议尝试其他模式或人工处理")

        return result

    except Exception as e:
        print(f"演示程序出错: {e}")
        import traceback
        traceback.print_exc()
        return None


if __name__ == "__main__":
    demo_intelligent_table_processing()
|