|
|
@@ -0,0 +1,739 @@
|
|
|
+# zhch/table_mode_selector.py
|
|
|
+import cv2
|
|
|
+import numpy as np
|
|
|
+from paddlex import create_pipeline, create_model
|
|
|
+
|
|
|
class TableModeSelector:
    """Compute heuristic features of a table image for mode selection.

    ``analyze_table_features`` is the entry point; every other public method
    computes one group of features and can also be called on its own.
    All numeric features are returned as plain Python floats.
    """

    # Length (px) of the 1-px-thick structuring elements used to isolate long
    # horizontal / vertical runs.
    LINE_KERNEL_LENGTH = 40
    # A projection bin must exceed this fraction of the maximum to be a peak.
    PEAK_THRESHOLD_RATIO = 0.3

    def __init__(self):
        # Layout model; none of this class's own methods use it.
        self.layout_model = create_model(model_name="PP-DocLayout_plus-L")
        # Classifier used by get_table_type for the pre-analysis.
        self.table_cls_model = create_model(model_name="PP-LCNet_x1_0_table_cls")

    @staticmethod
    def _to_gray(image):
        """Return a single-channel image; 2-D input is passed through as is."""
        if image.ndim == 2:
            return image
        return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    def _extract_line_maps(self, gray):
        """Return (horizontal, vertical) morphological openings of *gray*.

        NOTE(review): the opening is applied to the raw grayscale image, not
        to a binarised/inverted one, so on light-background scans it looks
        like bright runs (not dark rulings) are what survives — confirm that
        this is the intended heuristic.
        """
        length = self.LINE_KERNEL_LENGTH
        horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (length, 1))
        vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, length))
        horizontal = cv2.morphologyEx(gray, cv2.MORPH_OPEN, horizontal_kernel)
        vertical = cv2.morphologyEx(gray, cv2.MORPH_OPEN, vertical_kernel)
        return horizontal, vertical

    def analyze_table_features(self, table_image):
        """Analyse *table_image* and return all features in one flat dict.

        Keys: ``table_type``, the keys of ``analyze_complexity`` and
        ``analyze_regularity``, and ``border_clarity``.
        """
        features = {}

        # 1. Table type (classifier, with a line-detection fallback).
        features['table_type'] = self.get_table_type(table_image)

        # 2. Complexity metrics.
        features.update(self.analyze_complexity(table_image))

        # 3. Structural regularity.
        features.update(self.analyze_regularity(table_image))

        # 4. Border clarity.
        features['border_clarity'] = self.analyze_border_clarity(table_image)

        return features

    def get_table_type(self, image):
        """Return the table type label for *image*.

        The classifier output is inspected for several possible field
        layouts. If none matches, or anything raises, the result of
        ``detect_table_type_by_lines`` is returned instead.
        """
        try:
            result = next(self.table_cls_model.predict(image))

            # Debug output showing the actual result type.
            print(f"表格分类模型输出类型: {type(result).__name__}")

            if hasattr(result, 'keys'):
                # Full top-k layout: pick the label with the highest score.
                if 'class_ids' in result and 'scores' in result and 'label_names' in result:
                    scores = result['scores']
                    label_names = result['label_names']

                    max_score_idx = np.argmax(scores)
                    best_label = label_names[max_score_idx]
                    best_score = scores[max_score_idx]

                    print(f"分类结果: {best_label} (置信度: {best_score:.4f})")
                    return best_label

                if 'class_ids' in result:
                    class_ids = result['class_ids']
                    if hasattr(class_ids, '__len__') and len(class_ids) > 0:
                        class_id = int(class_ids[0])
                    else:
                        class_id = int(class_ids)
                    # NOTE(review): id 0 is treated as "wired" — confirm
                    # against the classifier's label map.
                    return 'wired_table' if class_id == 0 else 'wireless_table'

                if 'label_names' in result:
                    label_names = result['label_names']
                    # len() instead of truthiness so array-like values work.
                    if label_names is not None and len(label_names) > 0:
                        return label_names[0]
                    return 'wired_table'

                # Legacy single-value field names.
                for key in ('label', 'class_name', 'prediction'):
                    if key in result:
                        return result[key]

                # Last resort: stringify the first available value. An empty
                # mapping falls through to the line-based fallback below.
                keys = list(result.keys())
                if keys:
                    return str(result[keys[0]])

            print("使用备用的线条检测方法判断表格类型")
            return self.detect_table_type_by_lines(image)

        except Exception as e:
            # Deliberate best-effort: any classifier problem degrades to the
            # heuristic instead of aborting the analysis.
            print(f"表格分类出错: {e},使用备用方法")
            return self.detect_table_type_by_lines(image)

    def detect_table_type_by_lines(self, image):
        """Fallback: classify by the number of Hough lines found in *image*.

        More than 10 detected lines yields ``'wired_table'``, otherwise
        ``'wireless_table'``.
        """
        gray = self._to_gray(image)
        edges = cv2.Canny(gray, 50, 150)

        lines = cv2.HoughLines(edges, 1, np.pi/180, threshold=100)

        if lines is not None and len(lines) > 10:
            print("检测到较多直线,判断为有线表格")
            return 'wired_table'

        print("检测到较少直线,判断为无线表格")
        return 'wireless_table'

    def analyze_complexity(self, image):
        """Return complexity metrics of *image*.

        Keys: ``line_density`` (fraction of Canny edge pixels),
        ``merged_cells_ratio``, ``text_density`` and ``size_complexity``
        (pixel count divided by one million).
        """
        h, w = image.shape[:2]

        gray = self._to_gray(image)
        edges = cv2.Canny(gray, 50, 150)
        line_density = float(np.count_nonzero(edges > 0)) / (h * w)

        return {
            'line_density': line_density,
            'merged_cells_ratio': self.detect_merged_cells(image),
            'text_density': self.analyze_text_density(image),
            'size_complexity': (h * w) / (1000 * 1000)
        }

    def detect_merged_cells(self, image):
        """Estimate the merged-cell ratio of *image* (simplified heuristic).

        Computes ``1 - 2 * min(h_coverage, v_coverage)`` clamped to [0, 1],
        where each coverage is the non-zero fraction of the corresponding
        morphological opening.
        """
        gray = self._to_gray(image)
        horizontal_lines, vertical_lines = self._extract_line_maps(gray)

        h_coverage = float(np.count_nonzero(horizontal_lines > 0)) / horizontal_lines.size
        v_coverage = float(np.count_nonzero(vertical_lines > 0)) / vertical_lines.size

        merged_ratio = 1.0 - min(h_coverage, v_coverage) * 2
        return max(0.0, min(1.0, merged_ratio))

    def analyze_text_density(self, image):
        """Return the fraction of pixels that are black after Otsu binarisation."""
        gray = self._to_gray(image)

        _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Black pixels are counted as text.
        text_pixels = np.count_nonzero(binary == 0)
        return float(text_pixels) / binary.size

    def analyze_regularity(self, image):
        """Return horizontal, vertical and averaged regularity scores of *image*."""
        gray = self._to_gray(image)
        horizontal_lines, vertical_lines = self._extract_line_maps(gray)

        h_regularity = self.calculate_line_regularity(horizontal_lines, axis='horizontal')
        v_regularity = self.calculate_line_regularity(vertical_lines, axis='vertical')

        return {
            'horizontal_regularity': h_regularity,
            'vertical_regularity': v_regularity,
            'overall_regularity': (h_regularity + v_regularity) / 2
        }

    def calculate_line_regularity(self, lines_image, axis='horizontal'):
        """Score how evenly spaced the projection peaks of *lines_image* are.

        Args:
            lines_image: 2-D array; rows are summed for ``'horizontal'``,
                columns for any other *axis* value.
            axis: ``'horizontal'`` or ``'vertical'``.

        Returns:
            ``1 - std/mean`` of the gaps between consecutive peaks, clamped
            to [0, 1]. ``0.5`` is returned when fewer than two peaks exist
            (including empty input).
        """
        projection = np.sum(lines_image, axis=1 if axis == 'horizontal' else 0)

        # Empty input has no maximum; report the neutral score instead of
        # letting np.max raise.
        if projection.size == 0:
            return 0.5

        threshold = projection.max() * self.PEAK_THRESHOLD_RATIO

        # Interior local maxima above the threshold. The right-hand test is
        # ">=" so a flat-topped peak (a ruling several pixels thick) is
        # counted once, at its rising edge, instead of being missed.
        interior = projection[1:-1]
        is_peak = (
            (interior > threshold)
            & (interior > projection[:-2])
            & (interior >= projection[2:])
        )
        peaks = np.flatnonzero(is_peak) + 1

        if peaks.size < 2:
            return 0.5  # neutral default

        # Peak indices are strictly increasing, so every gap is >= 1 and the
        # mean can never be zero.
        gaps = np.diff(peaks)
        spread = float(gaps.std()) / float(gaps.mean())

        return max(0.0, 1.0 - min(1.0, spread))

    def analyze_border_clarity(self, image):
        """Return the mean Sobel gradient magnitude of *image* divided by 255."""
        gray = self._to_gray(image)

        sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        edge_magnitude = np.sqrt(sobelx**2 + sobely**2)

        return float(np.mean(edge_magnitude)) / 255.0
|
|
|
+
|
|
|
class TableModeDecisionEngine:
    """Rule-based scorer that maps a feature dict to a processing mode."""

    def __init__(self):
        self.rules = self.load_decision_rules()

    def load_decision_rules(self):
        """Build and return the rule table.

        Each entry maps a mode name to ``conditions`` (a list of
        ``(feature, operator, threshold)`` tuples, or ``('OR', [...])``
        groups), a ``weight`` and a ``description``.
        """
        def wired_type():
            # Several spellings of the wired label are accepted.
            return ('table_type', 'in', ['wired_table', 'wired', '0'])

        def wireless_type():
            return ('table_type', 'in', ['wireless_table', 'wireless', '1'])

        def rule(conditions, weight, description):
            return {
                'conditions': conditions,
                'weight': weight,
                'description': description,
            }

        rules = {}
        rules['wired_html_mode'] = rule(
            [
                wired_type(),
                ('border_clarity', '>', 0.6),
                ('merged_cells_ratio', '>', 0.3),
                ('overall_regularity', '<', 0.7),
                ('size_complexity', '>', 0.5),
            ],
            0.9,
            '复杂有线表格,几何匹配更准确',
        )
        rules['wired_e2e_mode'] = rule(
            [
                wired_type(),
                ('overall_regularity', '>', 0.8),
                ('merged_cells_ratio', '<', 0.2),
                ('text_density', '>', 0.3),
            ],
            0.8,
            '规整有线表格,端到端效果好',
        )
        rules['wireless_e2e_mode'] = rule(
            [
                wireless_type(),
                ('line_density', '<', 0.1),
                ('text_density', '>', 0.2),
            ],
            0.85,
            '无线表格,端到端预测最适合',
        )
        rules['regular_mode'] = rule(
            [
                ('size_complexity', '>', 1.0),
                ('OR', [
                    ('border_clarity', '<', 0.4),
                    ('overall_regularity', '<', 0.5),
                ]),
            ],
            0.7,
            '复杂场景,需要多模型协同',
        )
        return rules

    def check_single_condition(self, features, condition):
        """Evaluate one ``(feature, operator, threshold)`` triple.

        Returns False when the feature is missing or the operator is not one
        of ``>``, ``<``, ``==``, ``>=``, ``<=``, ``in``.
        """
        name, op, bound = condition

        if name not in features:
            return False

        actual = features[name]

        # Lazily evaluated comparisons keyed by operator symbol; for 'in' the
        # bound is the collection of accepted values.
        comparisons = {
            '>': lambda: actual > bound,
            '<': lambda: actual < bound,
            '==': lambda: actual == bound,
            '>=': lambda: actual >= bound,
            '<=': lambda: actual <= bound,
            'in': lambda: actual in bound,
        }

        compare = comparisons.get(op)
        if compare is None:
            return False
        return compare()

    def evaluate_conditions(self, features, conditions):
        """Return the fraction of *conditions* satisfied by *features*.

        An ``('OR', [...])`` group counts as a single condition that holds
        when any member holds. Returns 0 when there are no conditions.
        """
        satisfied = 0
        seen = 0

        for entry in conditions:
            seen += 1
            if entry[0] == 'OR':
                holds = any(
                    self.check_single_condition(features, alternative)
                    for alternative in entry[1]
                )
            else:
                holds = self.check_single_condition(features, entry)
            if holds:
                satisfied += 1

        if seen > 0:
            return satisfied / seen
        return 0

    def select_best_mode(self, features):
        """Score every rule and return ``(mode_name, info)`` for the best one.

        ``info`` holds the weighted ``score`` and the rule ``description``.
        """
        scored = {
            name: {
                'score': self.evaluate_conditions(features, spec['conditions']) * spec['weight'],
                'description': spec['description'],
            }
            for name, spec in self.rules.items()
        }

        # Highest weighted score wins; on ties the earliest rule is kept.
        winner = max(scored, key=lambda name: scored[name]['score'])
        return winner, scored[winner]
|
|
|
+
|
|
|
class IntelligentTableProcessor:
    """Analyse a table image, pick a recognition mode and run it.

    ``process_table_intelligently`` is the entry point. Failures in the
    individual stages are reported through printed messages and through the
    ``status`` / ``selected_mode`` fields of the returned dicts rather than
    by raising.
    """

    def __init__(self, config_path="./PP-StructureV3-zhch.yaml"):
        self.selector = TableModeSelector()
        self.decision_engine = TableModeDecisionEngine()
        # The full pipeline is not created here; run_regular_mode builds it
        # on first use and caches it in self.pp_structure.
        self.config_path = config_path
        self.pp_structure = None

    @staticmethod
    def _lookup(obj, name, default=None):
        """Read *name* from *obj* as a mapping key, else as an attribute."""
        from collections.abc import Mapping

        if isinstance(obj, Mapping) and name in obj:
            return obj[name]
        return getattr(obj, name, default)

    def execute_with_mode(self, image_path, mode, optimized_config=None):
        """Run recognition in *mode*; on any error use ``fallback_processing``."""
        try:
            print(f"正在使用 {mode} 模式处理表格...")
            print(f"优化配置: {optimized_config}")

            return self.create_and_run_pipeline(image_path, mode, optimized_config)

        except Exception as e:
            print(f"执行 {mode} 模式时出错: {e}")
            print("回退到基础处理模式")
            return self.fallback_processing(image_path)

    def create_and_run_pipeline(self, image_path, mode, optimized_config):
        """Dispatch to the ``run_*`` method matching *mode*.

        Unknown modes are handled by ``fallback_processing``.
        """
        runners = {
            'wired_html_mode': self.run_wired_html_mode,
            'wired_e2e_mode': self.run_wired_e2e_mode,
            'wireless_e2e_mode': self.run_wireless_e2e_mode,
            'regular_mode': self.run_regular_mode,
        }
        runner = runners.get(mode)
        if runner is None:
            print(f"未知模式: {mode},使用默认处理")
            return self.fallback_processing(image_path)
        return runner(image_path, optimized_config)

    def _run_table_pipeline(self, image_path, mode, failure_label, **predict_kwargs):
        """Create a pipeline from ``config_path`` and predict with the given flags.

        Any exception is reported using *failure_label* and replaced by a
        mock result for *mode*.

        NOTE(review): the flag names passed by the callers were marked in the
        original code as needing adjustment to the real PaddleX API — verify
        them against the installed PaddleX version.
        """
        try:
            from paddlex import create_pipeline

            table_pipeline = create_pipeline(
                pipeline=self.config_path,
                model_dir=None
            )
            result = list(table_pipeline.predict(image_path, **predict_kwargs))
            return self.format_result(result, mode=mode)

        except Exception as e:
            print(f"{failure_label}执行失败: {e}")
            return self.create_mock_result(mode=mode)

    def run_wired_html_mode(self, image_path, config):
        """Run the wired-table-to-HTML mode (*config* is currently unused)."""
        print("执行有线表格转HTML模式...")
        return self._run_table_pipeline(
            image_path,
            'wired_html_mode',
            "有线表格HTML模式",
            use_wired_table_html_mode=True,
            use_wired_table_e2e_mode=False,
        )

    def run_wired_e2e_mode(self, image_path, config):
        """Run the wired-table end-to-end mode (*config* is currently unused)."""
        print("执行有线表格端到端模式...")
        return self._run_table_pipeline(
            image_path,
            'wired_e2e_mode',
            "有线表格端到端模式",
            use_wired_table_html_mode=False,
            use_wired_table_e2e_mode=True,
        )

    def run_wireless_e2e_mode(self, image_path, config):
        """Run the wireless-table end-to-end mode (*config* is currently unused)."""
        print("执行无线表格端到端模式...")
        return self._run_table_pipeline(
            image_path,
            'wireless_e2e_mode',
            "无线表格端到端模式",
            use_wireless_table_e2e_mode=True,
        )

    def run_regular_mode(self, image_path, config):
        """Run the full pipeline with default flags (*config* is currently unused).

        The pipeline is created once and cached in ``self.pp_structure``.
        """
        print("执行常规模式...")

        try:
            if self.pp_structure is None:
                from paddlex import create_pipeline
                self.pp_structure = create_pipeline(
                    pipeline=self.config_path
                )

            result = list(self.pp_structure.predict(image_path))

            return self.format_result(result, mode='regular_mode')

        except Exception as e:
            print(f"常规模式执行失败: {e}")
            return self.create_mock_result(mode='regular_mode')

    def format_result(self, raw_result, mode):
        """Normalise pipeline output into a result dict.

        Args:
            raw_result: sequence of result items; each may expose
                ``table_recognition_res`` as a mapping key or an attribute.
            mode: mode name stored in the returned dict.

        Returns:
            A dict with ``mode``, ``status``, ``raw_output``, ``table_count``
            and ``tables``. An empty *raw_result* or any formatting error
            yields ``create_mock_result(mode)`` instead.
        """
        try:
            if not raw_result:
                return self.create_mock_result(mode)

            formatted_result = {
                'mode': mode,
                'status': 'success',
                'raw_output': raw_result,
                'table_count': 0,
                'tables': []
            }
            tables = formatted_result['tables']

            for item in raw_result:
                # Look the field up without evaluating a missing attribute:
                # plain dicts and attribute-only objects are both supported.
                table_res = self._lookup(item, 'table_recognition_res')
                if not table_res:
                    continue
                for table in table_res:
                    tables.append({
                        # Running index so ids stay unique across items.
                        'table_id': len(tables),
                        'html': self._lookup(table, 'html', 'HTML不可用'),
                        'bbox': self._lookup(table, 'bbox', [0, 0, 100, 100])
                    })

            # Total over all items, consistent with the 'tables' list.
            formatted_result['table_count'] = len(tables)
            return formatted_result

        except Exception as e:
            print(f"结果格式化失败: {e}")
            return self.create_mock_result(mode)

    def create_mock_result(self, mode):
        """Return a placeholder result (``status='mock'``) for tests and error paths."""
        return {
            'mode': mode,
            'status': 'mock',
            'message': f'{mode} 模式执行完成(模拟结果)',
            'table_count': 1,
            'tables': [{
                'table_id': 0,
                'html': f'<table><tr><td>模拟{mode}结果</td></tr></table>',
                'bbox': [237, 201, 1416, 2044]
            }]
        }

    def fallback_processing(self, image_path):
        """Process *image_path* with the plain OCR pipeline.

        Returns a ``status='failed'`` dict if that raises as well.
        """
        print("使用基础OCR处理...")

        try:
            from paddlex import create_pipeline

            ocr_pipeline = create_pipeline(pipeline="OCR")
            result = list(ocr_pipeline.predict(image_path))

            return {
                'mode': 'fallback_ocr',
                'status': 'success',
                'raw_output': result,
                'message': '使用基础OCR处理'
            }

        except Exception as e:
            print(f"回退处理也失败: {e}")
            return {
                'mode': 'error',
                'status': 'failed',
                'message': f'所有处理方法都失败: {e}'
            }

    def _collect_table_crops(self, image_path):
        """Return ``(image, crops)`` for every layout box labelled ``'table'``.

        ``image`` is None (and ``crops`` empty) when the file cannot be read.
        Box coordinates are clamped to the image so negative values cannot
        wrap around in slicing; boxes that become empty are skipped.
        """
        original_image = cv2.imread(image_path)
        if original_image is None:
            print(f"无法读取图像: {image_path}")
            return None, []

        height, width = original_image.shape[:2]
        layout_results = list(self.selector.layout_model.predict(image_path))

        crops = []
        for layout_result in layout_results:
            for index, box_info in enumerate(layout_result['boxes']):
                if box_info['label'] != 'table':
                    continue

                x1, y1, x2, y2 = [int(coord) for coord in box_info['coordinate']]
                x1, y1 = max(0, x1), max(0, y1)
                x2, y2 = min(width, x2), min(height, y2)
                if x2 <= x1 or y2 <= y1:
                    continue

                crops.append({
                    'table_id': index,
                    'image': original_image[y1:y2, x1:x2],
                    'bbox': [x1, y1, x2, y2],
                    'score': box_info['score']
                })

        return original_image, crops

    def extract_all_table_regions(self, image_path):
        """Return every detected table region of *image_path*.

        Each entry has ``table_id`` (index of the box in the layout result),
        ``image``, ``bbox`` and ``score``. Every crop is also written to
        ``./debug_table_<id>.jpg``.
        """
        _, all_tables = self._collect_table_crops(image_path)

        for table_info in all_tables:
            i = table_info['table_id']
            x1, y1, x2, y2 = table_info['bbox']
            # Save each region for debugging.
            cv2.imwrite(f'./debug_table_{i}.jpg', table_info['image'])
            print(f"表格 {i}: bbox=[{x1}, {y1}, {x2}, {y2}], score={table_info['score']:.4f}")

        return all_tables

    def extract_table_region(self, image_path):
        """Return the crop of the highest-scoring table in *image_path*.

        Returns the whole image when no table is detected, and None when the
        image cannot be read.
        """
        original_image, table_regions = self._collect_table_crops(image_path)
        if original_image is None:
            return None

        for region in table_regions:
            x1, y1, x2, y2 = region['bbox']
            print(f"检测到表格区域: bbox=[{x1}, {y1}, {x2}, {y2}], score={region['score']:.4f}")

        if not table_regions:
            print("未检测到表格区域,使用整个图像")
            return original_image

        best_table = max(table_regions, key=lambda region: region['score'])
        return best_table['image']

    def process_table_intelligently(self, image_path, use_layout_model=True):
        """Analyse, select a mode for and process the table in *image_path*.

        Args:
            image_path: path of the input image.
            use_layout_model: crop the best table region first when True,
                otherwise analyse the whole image.

        Returns:
            A dict with ``result``, ``selected_mode``, ``mode_description``,
            ``confidence_score``, ``table_features`` and
            ``table_region_shape``. On failure ``selected_mode`` is
            ``'error'``, ``result`` is None and ``error`` holds the message.
        """
        try:
            # 1. Obtain the region to analyse.
            if use_layout_model:
                table_image = self.extract_table_region(image_path)
            else:
                table_image = cv2.imread(image_path)

            if table_image is None or table_image.size == 0:
                print("表格区域提取失败,使用原图")
                table_image = cv2.imread(image_path)

            # Fail with a clear message instead of passing None to imwrite.
            if table_image is None or table_image.size == 0:
                raise ValueError(f"无法读取图像: {image_path}")

            # Keep the analysed region on disk for debugging.
            cv2.imwrite('./debug_table_region.jpg', table_image)
            print("表格区域已保存到: ./debug_table_region.jpg")
            print(f"表格区域尺寸: {table_image.shape}")

            # 2. Feature analysis on the cropped region.
            features = self.selector.analyze_table_features(table_image)
            print(f"表格特征分析: {features}")

            # 3. Mode selection.
            best_mode, mode_info = self.decision_engine.select_best_mode(features)
            print(f"选择模式: {best_mode}, 分数: {mode_info['score']:.3f}")

            # 4. Per-mode config optimisation is not implemented; None is
            #    passed through.

            # 5. Execution runs on the original file, not on the crop.
            result = self.execute_with_mode(image_path, best_mode, optimized_config=None)

            return {
                'result': result,
                'selected_mode': best_mode,
                'mode_description': mode_info['description'],
                'confidence_score': mode_info['score'],
                'table_features': features,
                'table_region_shape': table_image.shape
            }

        except Exception as e:
            print(f"智能处理过程出错: {e}")
            import traceback
            traceback.print_exc()

            return {
                'result': None,
                'selected_mode': 'error',
                'mode_description': f'处理失败: {e}',
                'confidence_score': 0.0,
                'table_features': {},
                'error': str(e)
            }
|
|
|
+
|
|
|
+# 修改demo函数,更好地处理结果
|
|
|
def demo_intelligent_table_processing():
    """Run the processor on the bundled sample image and print a report.

    Returns the dict produced by ``process_table_intelligently``, or None if
    anything raises (the traceback is printed in that case).
    """
    try:
        processor = IntelligentTableProcessor("./PP-StructureV3-zhch.yaml")

        # Sample: a complex financial statement table.
        summary = processor.process_table_intelligently(
            "./sample_data/600916_中国黄金_2002年报_83_94_2.png",
            use_layout_model=True
        )

        banner = "=" * 50
        print("\n" + banner)
        print("智能表格处理结果:")
        print(banner)
        print(f"选择的模式: {summary['selected_mode']}")
        print(f"选择原因: {summary['mode_description']}")
        print(f"置信度分数: {summary['confidence_score']:.3f}")

        if 'table_region_shape' in summary:
            print(f"表格区域尺寸: {summary['table_region_shape']}")

        print("\n表格特征分析:")
        for name, feature in summary.get('table_features', {}).items():
            # Floats are shown with four decimals, everything else verbatim.
            shown = f"{feature:.4f}" if isinstance(feature, float) else f"{feature}"
            print(f" {name}: {shown}")

        # Details of the recognition output, when there is one.
        outcome = summary['result']
        if outcome:
            print("\n处理结果:")
            print(f" 模式: {outcome.get('mode', 'unknown')}")
            print(f" 状态: {outcome.get('status', 'unknown')}")
            print(f" 表格数量: {outcome.get('table_count', 0)}")

            if outcome.get('tables'):
                for idx, entry in enumerate(outcome['tables']):
                    print(f" 表格 {idx}: bbox={entry.get('bbox', 'N/A')}")
                    snippet = entry.get('html', '')[:100]
                    print(f" HTML预览: {snippet}...")

        # Advice derived from the confidence score.
        confidence = summary['confidence_score']
        if confidence > 0.8:
            advice = "\n✅ 高置信度,推荐使用该模式"
        elif confidence > 0.6:
            advice = "\n⚠️ 中等置信度,可能需要人工验证"
        else:
            advice = "\n❌ 低置信度,建议尝试其他模式或人工处理"
        print(advice)

        return summary

    except Exception as e:
        print(f"演示程序出错: {e}")
        import traceback
        traceback.print_exc()
        return None
|
|
|
+
|
|
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo_intelligent_table_processing()
|