"""
Utility functions for the OCR validation tool.

Covers loading/normalising output from several OCR back-ends, rotating
images together with their bounding boxes, and computing simple statistics
over parsed OCR results.
"""
import json
import re
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import yaml
from PIL import Image, ImageDraw

from ocr_validator_file_utils import process_all_images_in_content


# Sentinel used by the nested-field helpers to distinguish "absent" from None.
_MISSING = object()

# Filenames accepted as OCR page results, e.g. "report_page_001.json".
_PAGE_JSON_PATTERN = re.compile(r'.*_page_\d+\.json$', re.IGNORECASE)
# Extracts the trailing page number from a file stem, e.g. "report_page_001".
_PAGE_NUMBER_PATTERN = re.compile(r'page_(\d+)$', re.IGNORECASE)


def load_config(config_path: str = "config.yaml") -> Dict:
    """Load the YAML configuration file.

    Args:
        config_path: Path of the YAML configuration file.

    Returns:
        The object produced by ``yaml.safe_load``.

    Note:
        On any error the message and traceback are printed and the process
        exits with status 1; callers never receive an exception.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            return yaml.safe_load(f)
    except Exception as e:
        print(f"加载配置文件失败: {e}")
        import traceback
        traceback.print_exc()
        # Exit the process.
        sys.exit(1)


def rotate_image_and_coordinates(
    image: Image.Image,
    angle: float,
    coordinates_list: List[List[int]],
    rotate_coordinates: bool = True
) -> Tuple[Image.Image, List[List[int]]]:
    """Rotate an image and (optionally) its bounding boxes by ``angle``.

    The image is rotated with ``Image.rotate(..., expand=True)``, i.e.
    counter-clockwise by the given number of degrees, with the canvas grown
    to fit. An input angle of 270 is treated as -90 (clockwise 90).

    Args:
        image: Source image.
        angle: Rotation angle in degrees.
        coordinates_list: Boxes in ``[x1, y1, x2, y2]`` form. Entries with
            fewer than 4 values are passed through unchanged.
        rotate_coordinates: When False, only the image is rotated and the
            boxes are returned as-is (some OCR tools already emit boxes in
            the rotated frame).

    Returns:
        ``(rotated_image, rotated_coordinates)``. Invalid input boxes
        (negative origin or non-positive extent) are replaced with
        ``[0, 0, 50, 50]`` after printing a warning.
    """
    if angle == 0:
        return image, coordinates_list

    # Normalise the rotation angle.
    if angle == 270:
        rotation_angle = -90  # clockwise 90 degrees
    elif angle == 90:
        rotation_angle = 90  # counter-clockwise 90 degrees
    elif angle == 180:
        rotation_angle = 180
    else:
        rotation_angle = angle

    rotated_image = image.rotate(rotation_angle, expand=True)

    if not rotate_coordinates:
        return rotated_image, coordinates_list

    orig_width, orig_height = image.size
    new_width, new_height = rotated_image.size

    rotated_coordinates = []
    for coord in coordinates_list:
        if len(coord) < 4:
            rotated_coordinates.append(coord)
            continue

        x1, y1, x2, y2 = coord[:4]

        # Reject boxes that cannot be transformed meaningfully.
        if x1 < 0 or y1 < 0 or x2 <= x1 or y2 <= y1:
            print(f"警告: 无效坐标 {coord}")
            rotated_coordinates.append([0, 0, 50, 50])  # fallback box
            continue

        if rotation_angle == -90:
            # Clockwise 90: (x, y) -> (orig_height - y, x); y2 gives the new left edge.
            new_x1 = orig_height - y2
            new_y1 = x1
            new_x2 = orig_height - y1
            new_y2 = x2
        elif rotation_angle == 90:
            # Counter-clockwise 90: (x, y) -> (y, orig_width - x); x2 gives the new top edge.
            new_x1 = y1
            new_y1 = orig_width - x2
            new_x2 = y2
            new_y2 = orig_width - x1
        elif rotation_angle == 180:
            # (x, y) -> (orig_width - x, orig_height - y)
            new_x1 = orig_width - x2
            new_y1 = orig_height - y2
            new_x2 = orig_width - x1
            new_y2 = orig_height - y1
        else:
            # Arbitrary angle: rotate the four corners about the image centre
            # and take their axis-aligned bounding box in the expanded canvas.
            angle_rad = np.radians(rotation_angle)
            cos_angle = np.cos(angle_rad)
            sin_angle = np.sin(angle_rad)

            orig_center_x = orig_width / 2
            orig_center_y = orig_height / 2
            new_center_x = new_width / 2
            new_center_y = new_height / 2

            corners = [
                (x1 - orig_center_x, y1 - orig_center_y),  # top-left
                (x2 - orig_center_x, y1 - orig_center_y),  # top-right
                (x2 - orig_center_x, y2 - orig_center_y),  # bottom-right
                (x1 - orig_center_x, y2 - orig_center_y),  # bottom-left
            ]

            rotated_corners = []
            for x, y in corners:
                # Counter-clockwise rotation in image coordinates (y axis down):
                # [ cos  sin] [x]
                # [-sin  cos] [y]
                rotated_x = x * cos_angle + y * sin_angle
                rotated_y = -x * sin_angle + y * cos_angle
                rotated_corners.append((rotated_x + new_center_x, rotated_y + new_center_y))

            x_coords = [corner[0] for corner in rotated_corners]
            y_coords = [corner[1] for corner in rotated_corners]
            new_x1 = int(min(x_coords))
            new_y1 = int(min(y_coords))
            new_x2 = int(max(x_coords))
            new_y2 = int(max(y_coords))

        # Clamp into the rotated canvas.
        new_x1 = max(0, min(new_width, new_x1))
        new_y1 = max(0, min(new_height, new_y1))
        new_x2 = max(0, min(new_width, new_x2))
        new_y2 = max(0, min(new_height, new_y2))

        # Keep the (x1 <= x2, y1 <= y2) ordering.
        if new_x1 > new_x2:
            new_x1, new_x2 = new_x2, new_x1
        if new_y1 > new_y2:
            new_y1, new_y2 = new_y2, new_y1

        rotated_coordinates.append([new_x1, new_y1, new_x2, new_y2])

    return rotated_image, rotated_coordinates


def parse_dots_ocr_data(data: List, config: Dict, tool_name: str) -> List[Dict]:
    """Parse Dots OCR output (a list of item dicts) into normalised records.

    Field names are read from ``config['ocr']['tools'][tool_name]``. Items
    without text or with fewer than 4 bbox values are skipped.
    """
    tool_config = config['ocr']['tools'][tool_name]
    parsed_data = []

    for item in data:
        if not isinstance(item, dict):
            continue

        text = item.get(tool_config['text_field'], '')
        bbox = item.get(tool_config['bbox_field'], [])
        category = item.get(tool_config['category_field'], 'Text')
        confidence = item.get(tool_config.get('confidence_field', 'confidence'),
                              config['ocr']['default_confidence'])

        if text and bbox and len(bbox) >= 4:
            parsed_data.append({
                'text': str(text).strip(),
                'bbox': bbox[:4],  # keep only the first 4 coordinates
                'category': category,
                'confidence': confidence,
                'source_tool': tool_name
            })

    return parsed_data


def parse_ppstructv3_data(data: Dict, config: Dict) -> List[Dict]:
    """Parse PPStructV3 output into normalised records.

    Layout blocks come from the configured ``parsing_results_field``; the
    optional flat OCR lines (``rec_texts_field`` / ``rec_boxes_field``,
    dotted paths) are appended with category ``'OCR_Text'`` and source tool
    ``'ppstructv3_ocr'``.
    """
    tool_config = config['ocr']['tools']['ppstructv3']
    parsed_data = []

    parsing_results = data.get(tool_config['parsing_results_field'], [])
    if not isinstance(parsing_results, list):
        return parsed_data

    for item in parsing_results:
        if not isinstance(item, dict):
            continue
        text = item.get(tool_config['text_field'], '')
        bbox = item.get(tool_config['bbox_field'], [])
        category = item.get(tool_config['category_field'], 'text')
        confidence = item.get(
            tool_config.get('confidence_field', 'confidence'),
            config['ocr']['default_confidence']
        )
        if text and bbox and len(bbox) >= 4:
            parsed_data.append({
                'text': str(text).strip(),
                'bbox': bbox[:4],
                'category': category,
                'confidence': confidence,
                'source_tool': 'ppstructv3'
            })

    rec_texts = get_nested_value(data, tool_config.get('rec_texts_field', ''))
    rec_boxes = get_nested_value(data, tool_config.get('rec_boxes_field', ''))
    if isinstance(rec_texts, list) and isinstance(rec_boxes, list):
        for text, box in zip(rec_texts, rec_boxes):
            if text and isinstance(box, list) and len(box) >= 4:
                parsed_data.append({
                    'text': str(text).strip(),
                    'bbox': box[:4],
                    'category': 'OCR_Text',
                    'source_tool': 'ppstructv3_ocr'
                })

    return parsed_data


def parse_table_recognition_v2_data(data: Dict, config: Dict) -> List[Dict]:
    """Parse Table Recognition V2 output into normalised records.

    Each table yields one record whose bbox is the union of its cell boxes
    (``[0.0, 0.0, 0.0, 0.0]`` when no usable cell box exists). Per-table OCR
    lines (``rec_texts_field`` / ``rec_boxes_field``) are appended as
    ``'OCR_Text'`` records.
    """
    tool_config = config['ocr']['tools']['table_recognition_v2']
    parsed_data = []

    tables = data.get(tool_config['parsing_results_field'], [])
    if not isinstance(tables, list):
        return parsed_data

    for item in tables:
        if not isinstance(item, dict):
            continue

        html_text = item.get(tool_config['text_field'], '')

        # Table bbox = union of all well-formed cell boxes; malformed boxes
        # (non-sequence or fewer than 4 values) are ignored instead of crashing.
        cell_boxes_raw = item.get(tool_config['bbox_field'], [])
        cell_boxes = [
            box for box in (cell_boxes_raw if isinstance(cell_boxes_raw, list) else [])
            if isinstance(box, (list, tuple)) and len(box) >= 4
        ]
        if cell_boxes:
            table_bbox = [
                float(min(box[0] for box in cell_boxes)),
                float(min(box[1] for box in cell_boxes)),
                float(max(box[2] for box in cell_boxes)),
                float(max(box[3] for box in cell_boxes)),
            ]
        else:
            table_bbox = [0.0, 0.0, 0.0, 0.0]

        parsed_data.append({
            'text': str(html_text).strip(),
            'bbox': table_bbox,
            'category': item.get(tool_config.get('category_field', ''), 'table'),
            'confidence': item.get(tool_config.get('confidence_field', ''), config['ocr']['default_confidence']),
            'source_tool': 'table_recognition_v2',
        })

        rec_texts = get_nested_value(item, tool_config.get('rec_texts_field', ''))
        rec_boxes = get_nested_value(item, tool_config.get('rec_boxes_field', ''))
        if isinstance(rec_texts, list) and isinstance(rec_boxes, list):
            for text, box in zip(rec_texts, rec_boxes):
                if text and isinstance(box, list) and len(box) >= 4:
                    parsed_data.append({
                        'text': str(text).strip(),
                        'bbox': box[:4],
                        'category': 'OCR_Text',
                        'source_tool': 'table_recognition_v2'
                    })

    return parsed_data


def parse_mineru_data(data: List, config: Dict, tool_name="mineru") -> List[Dict]:
    """Parse MinerU output (a list of typed blocks) into normalised records.

    Handled block categories:
        * ``'text'``  -> one record, keeping ``text_level``.
        * ``'table'`` -> one record with the table HTML, plus one
          ``'table_cell'`` record per usable cell in ``table_cells``.
        * ``'image'`` -> one ``'[Image]'`` placeholder record.
        * ``'list'``  -> one record per list item, sharing the block bbox.
        * anything else -> treated as plain text.
    """
    tool_config = config['ocr']['tools'][tool_name]
    parsed_data = []

    if not isinstance(data, list):
        return parsed_data

    for item in data:
        if not isinstance(item, dict):
            continue

        text = item.get(tool_config['text_field'], '')
        bbox = item.get(tool_config['bbox_field'], [])
        category = item.get(tool_config['category_field'], 'Text')
        confidence = item.get(tool_config.get('confidence_field', 'confidence'),
                              config['ocr']['default_confidence'])
        has_bbox = bool(bbox) and len(bbox) >= 4

        if category == 'text':
            if text and has_bbox:
                parsed_data.append({
                    'text': str(text).strip(),
                    'bbox': bbox[:4],
                    'category': category,
                    'confidence': confidence,
                    'source_tool': tool_name,
                    'text_level': item.get('text_level', 0)  # keep heading level info
                })

        elif category == 'table':
            table_html = item.get(tool_config.get('table_body_field', 'table_body'), '')
            img_path = item.get(tool_config.get('img_path_field', 'img_path'), '')

            if has_bbox:
                parsed_data.append({
                    'text': table_html,
                    'bbox': bbox[:4],
                    'category': 'table',
                    'confidence': confidence,
                    'source_tool': tool_name,
                    'img_path': img_path,
                    'table_body': table_html
                })

            # Cells carry their own bbox; skip entries that are not dicts.
            table_cells = item.get(tool_config.get('table_cells_field', 'table_cells'), []) or []
            for cell in table_cells:
                if not isinstance(cell, dict):
                    continue
                cell_text = cell.get('text', '')
                cell_bbox = cell.get('bbox', [])
                if cell_text and cell_bbox and len(cell_bbox) >= 4:
                    parsed_data.append({
                        'text': str(cell_text).strip(),
                        'bbox': cell_bbox[:4],
                        'row': cell.get('row', -1),
                        'col': cell.get('col', -1),
                        'category': 'table_cell',
                        'confidence': cell.get('score', 0.0),
                        'source_tool': tool_name,
                    })

        elif category == 'image':
            img_path = item.get(tool_config.get('img_path_field', 'img_path'), '')
            if has_bbox:
                parsed_data.append({
                    'text': '[Image]',
                    'bbox': bbox[:4],
                    'category': 'image',
                    'confidence': confidence,
                    'source_tool': tool_name,
                    'img_path': img_path
                })

        elif category in ['list']:
            # Each list entry becomes its own record sharing the block bbox.
            list_items = item.get('list_items', [])
            sub_type = item.get('sub_type', 'unordered')  # ordered / unordered
            for list_item in list_items:
                if list_item and has_bbox:
                    parsed_data.append({
                        'text': str(list_item).strip(),
                        'bbox': bbox[:4],
                        'category': category,
                        'sub_type': sub_type,
                        'confidence': confidence,
                        'source_tool': tool_name
                    })

        else:
            # Other categories (header, table_cell, ...) are treated as text.
            if text and has_bbox:
                parsed_data.append({
                    'text': str(text).strip(),
                    'bbox': bbox[:4],
                    'category': category,
                    'confidence': confidence,
                    'source_tool': tool_name
                })

    return parsed_data


def detect_mineru_structure(data: Union[List, Dict]) -> bool:
    """Return True when ``data`` looks like MinerU output.

    Heuristic: a non-empty list whose first element is a dict containing
    ``type``, ``bbox`` and ``text`` keys, with ``type`` one of
    ``'text'``, ``'table'`` or ``'image'``.
    """
    if not isinstance(data, list) or len(data) == 0:
        return False

    first_item = data[0]
    if not isinstance(first_item, dict):
        return False

    if 'type' in first_item and 'bbox' in first_item and 'text' in first_item:
        return first_item.get('type', '') in ['text', 'table', 'image']

    return False


def detect_ocr_tool_type(data: Union[List, Dict], config: Dict) -> str:
    """Detect which OCR tool produced ``data`` using configured rules.

    Args:
        data: Raw OCR JSON (list or dict).
        config: Configuration dict; reads ``config['ocr']['auto_detection']``.

    Returns:
        ``'mineru'`` when auto-detection is disabled; otherwise the
        ``tool_type`` of the first rule (ascending ``priority``, default 999)
        whose conditions all hold, falling back to ``'dots_ocr'``.
    """
    if not config['ocr']['auto_detection']['enabled']:
        return 'mineru'  # default when detection is off

    rules = config['ocr']['auto_detection']['rules']
    sorted_rules = sorted(rules, key=lambda x: x.get('priority', 999))

    for rule in sorted_rules:
        tool_type = rule['tool_type']
        conditions = rule.get('conditions', [])
        if _check_all_conditions(data, conditions):
            return tool_type

    # No rule matched.
    return 'dots_ocr'


def _check_all_conditions(data: Union[List, Dict], conditions: List[Dict]) -> bool:
    """Return True when every detection condition holds for ``data``.

    Supported condition ``type`` values: ``field_exists``,
    ``field_not_exists``, ``json_structure`` (``array``/``object``),
    ``field_value`` and ``field_contains``. Unknown types are ignored.
    """
    for condition in conditions:
        condition_type = condition.get('type', '')

        if condition_type == 'field_exists':
            field = condition.get('field', '')
            if not _check_field_exists(data, field):
                return False

        elif condition_type == 'field_not_exists':
            field = condition.get('field', '')
            if _check_field_exists(data, field):
                return False

        elif condition_type == 'json_structure':
            expected_structure = condition.get('structure', '')
            if expected_structure == 'array' and not isinstance(data, list):
                return False
            elif expected_structure == 'object' and not isinstance(data, dict):
                return False

        elif condition_type == 'field_value':
            field = condition.get('field', '')
            expected_value = condition.get('value')
            actual_value = _get_field_value(data, field)
            if actual_value != expected_value:
                return False

        elif condition_type == 'field_contains':
            field = condition.get('field', '')
            expected_values = condition.get('values', [])
            actual_value = _get_field_value(data, field)
            if actual_value not in expected_values:
                return False

    return True


def _resolve_field_path(data: Union[List, Dict], field_path: str):
    """Walk a dotted ``field_path`` into ``data``; return ``_MISSING`` if absent.

    For a list, only its first element is inspected.
    """
    if not field_path:
        return _MISSING

    if isinstance(data, list):
        if not data:
            return _MISSING
        data = data[0]

    current = data
    for field in field_path.split('.'):
        if isinstance(current, dict) and field in current:
            current = current[field]
        else:
            return _MISSING
    return current


def _check_field_exists(data: Union[List, Dict], field_path: str) -> bool:
    """Return True if the dotted path (e.g. ``"doc_preprocessor_res.angle"``) exists.

    For list data the first element is checked.
    """
    return _resolve_field_path(data, field_path) is not _MISSING


def _get_field_value(data: Union[List, Dict], field_path: str):
    """Return the value at the dotted path, or None if it does not exist.

    For list data the first element is checked.
    """
    value = _resolve_field_path(data, field_path)
    return None if value is _MISSING else value


def normalize_ocr_data(raw_data: Union[List, Dict], config: Dict) -> List[Dict]:
    """Detect the OCR tool and dispatch to the matching parser.

    Raises:
        ValueError: If the detected tool type has no parser.
    """
    tool_type = detect_ocr_tool_type(raw_data, config)

    if tool_type == 'dots_ocr':
        return parse_dots_ocr_data(raw_data, config, tool_type)
    elif tool_type == 'ppstructv3':
        return parse_ppstructv3_data(raw_data, config)
    elif tool_type == 'table_recognition_v2':
        return parse_table_recognition_v2_data(raw_data, config)
    elif tool_type == 'mineru':
        return parse_mineru_data(raw_data, config, tool_type)
    else:
        raise ValueError(f"不支持的OCR工具类型: {tool_type}")


def get_rotation_angle_from_ppstructv3(data: Dict) -> float:
    """Return ``data['doc_preprocessor_res']['angle']`` as float, or 0.0 if absent."""
    if 'doc_preprocessor_res' in data:
        doc_res = data['doc_preprocessor_res']
        if isinstance(doc_res, dict) and 'angle' in doc_res:
            return float(doc_res['angle'])
    return 0.0


def load_ocr_data_file(json_path: str, config: Dict) -> Tuple[List, str, str]:
    """Load one OCR result file together with its Markdown and source image.

    Args:
        json_path: Path of the OCR JSON file.
        config: Configuration dict.

    Returns:
        ``(ocr_data, md_content, image_path)`` where ``md_content`` is the
        sibling ``.md`` file (image references processed) or ``""``, and
        ``image_path`` is ``""`` when no matching image is found.

    Raises:
        FileNotFoundError: If ``json_path`` does not exist.
        Exception: If reading or normalising the JSON fails (the original
            error is chained as ``__cause__``).
    """
    json_file = Path(json_path)
    if not json_file.exists():
        raise FileNotFoundError(f"找不到JSON文件: {json_path}")

    try:
        with open(json_file, 'r', encoding='utf-8') as f:
            raw_data = json.load(f)

        ocr_data = normalize_ocr_data(raw_data, config)

        # PPStructV3-style dicts may carry a page rotation angle; attach it to
        # every item so later image handling can use it.
        rotation_angle = 0.0
        if isinstance(raw_data, dict):
            rotation_angle = get_rotation_angle_from_ppstructv3(raw_data)

        if rotation_angle != 0:
            for item in ocr_data:
                item['rotation_angle'] = rotation_angle

    except Exception as e:
        raise Exception(f"加载JSON文件失败: {e}") from e

    md_file = json_file.with_suffix('.md')
    md_content = ""
    if md_file.exists():
        with open(md_file, 'r', encoding='utf-8') as f:
            md_content = f.read()
        # Rewrite all image references in the Markdown relative to the JSON file.
        md_content = process_all_images_in_content(md_content, str(json_file))

    image_path = find_corresponding_image(json_file, config)

    return ocr_data, md_content, image_path


def find_corresponding_image(json_file: Path, config: Dict) -> str:
    """Find the source image whose stem matches ``json_file``.

    Looks in ``config['paths']['src_img_dir']`` if set, otherwise next to
    the JSON file, trying several common extensions in order.

    Returns:
        The image path as a string, or ``""`` if none is found.
    """
    src_img_dir = config.get('paths', {}).get('src_img_dir', '')
    if not src_img_dir:
        src_img_dir = json_file.parent

    src_img_path = Path(src_img_dir)

    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
    for ext in image_extensions:
        image_file = src_img_path / f"{json_file.stem}{ext}"
        if image_file.exists():
            return str(image_file)

    return ""


def process_ocr_data(ocr_data: List, config: Dict) -> Dict[str, List]:
    """Build a mapping from text to every occurrence's bbox metadata.

    Texts listed in ``config['ocr']['exclude_texts']`` or shorter than
    ``config['ocr']['min_text_length']`` are skipped, as are items whose
    bbox is not a 4-element list. Malformed items without ``text``/``bbox``
    keys are skipped rather than raising ``KeyError``.
    """
    text_bbox_mapping = {}
    exclude_texts = config['ocr']['exclude_texts']
    min_text_length = config['ocr']['min_text_length']

    if not isinstance(ocr_data, list):
        return text_bbox_mapping

    for i, item in enumerate(ocr_data):
        if not isinstance(item, dict):
            continue

        text = str(item.get('text', '')).strip()
        if not text or text in exclude_texts or len(text) < min_text_length:
            continue

        bbox = item.get('bbox')
        if isinstance(bbox, list) and len(bbox) == 4:
            text_bbox_mapping.setdefault(text, []).append({
                'bbox': bbox,
                'category': item.get('category', 'Text'),
                'index': i,
                'confidence': item.get('confidence', config['ocr']['default_confidence']),
                'source_tool': item.get('source_tool', 'unknown'),
                'rotation_angle': item.get('rotation_angle', 0.0)
            })

    return text_bbox_mapping


def find_available_ocr_files(ocr_out_dir: str) -> List[Dict]:
    """Recursively find ``*_page_<N>.json`` files under ``ocr_out_dir``.

    Returns:
        A list of dicts with ``path``, ``page`` and ``display_name`` keys,
        sorted by ``page``. Files whose page number cannot be extracted get
        a sequential number and their stem as display name.
    """
    available_files = []

    search_dirs = [
        Path(ocr_out_dir),
    ]

    for search_dir in search_dirs:
        if search_dir.exists():
            for json_file in search_dir.rglob("*.json"):
                if _PAGE_JSON_PATTERN.match(json_file.name):
                    available_files.append(str(json_file))

    file_info = []
    for file_path in available_files:
        file_name = Path(file_path).stem
        # e.g. "2023年度报告母公司_page_001" -> 1 (case-insensitive, matching the filter above)
        page_match = _PAGE_NUMBER_PATTERN.search(file_name)
        if page_match:
            page_num = int(page_match.group(1))
            file_info.append({
                'path': file_path,
                'page': page_num,
                'display_name': f"第{page_num}页"
            })
        else:
            file_info.append({
                'path': file_path,
                'page': len(file_info) + 1,
                'display_name': file_name
            })

    file_info.sort(key=lambda x: x['page'])
    return file_info


def get_ocr_tool_info(ocr_data: List) -> Dict:
    """Count OCR records per ``source_tool`` (``'unknown'`` when missing)."""
    tool_counts = {}
    for item in ocr_data:
        if isinstance(item, dict):
            source_tool = item.get('source_tool', 'unknown')
            tool_counts[source_tool] = tool_counts.get(source_tool, 0) + 1
    return tool_counts


def get_ocr_statistics(ocr_data: List, text_bbox_mapping: Dict, marked_errors: set) -> Dict:
    """Summarise OCR data: counts, per-category totals, accuracy and tool info.

    ``accuracy_rate`` is the percentage of clickable (mapped) texts not
    marked as errors, or 0 when there are no clickable texts.
    """
    if not isinstance(ocr_data, list) or not ocr_data:
        return {
            'total_texts': 0,
            'clickable_texts': 0,
            'marked_errors': 0,
            'categories': {},
            'accuracy_rate': 0,
            'tool_info': {}
        }

    total_texts = len(ocr_data)
    clickable_texts = len(text_bbox_mapping)
    marked_errors_count = len(marked_errors)

    categories = {}
    for item in ocr_data:
        if isinstance(item, dict):
            category = item.get('category', 'Unknown')
            categories[category] = categories.get(category, 0) + 1

    tool_info = get_ocr_tool_info(ocr_data)

    accuracy_rate = (clickable_texts - marked_errors_count) / clickable_texts * 100 if clickable_texts > 0 else 0

    return {
        'total_texts': total_texts,
        'clickable_texts': clickable_texts,
        'marked_errors': marked_errors_count,
        'categories': categories,
        'accuracy_rate': accuracy_rate,
        'tool_info': tool_info
    }


def group_texts_by_category(text_bbox_mapping: Dict[str, List]) -> Dict[str, List[str]]:
    """Group texts by the category of their first occurrence."""
    categories = {}
    for text, info_list in text_bbox_mapping.items():
        categories.setdefault(info_list[0]['category'], []).append(text)
    return categories


def get_ocr_tool_rotation_config(ocr_data: List, config: Dict) -> Dict:
    """Return the ``rotation`` config of the tool that produced ``ocr_data``.

    The tool is taken from the first record's ``source_tool`` (default
    ``'dots_ocr'``). Falls back to ``{'coordinates_are_pre_rotated': False}``
    when data is empty/malformed or the tool has no config.
    """
    if not ocr_data or not isinstance(ocr_data, list):
        return {
            'coordinates_are_pre_rotated': False
        }

    first_item = ocr_data[0]
    if not isinstance(first_item, dict):
        return {
            'coordinates_are_pre_rotated': False
        }
    source_tool = first_item.get('source_tool', 'dots_ocr')

    tools_config = config.get('ocr', {}).get('tools', {})
    if source_tool in tools_config:
        return tools_config[source_tool].get('rotation', {
            'coordinates_are_pre_rotated': False
        })

    return {
        'coordinates_are_pre_rotated': False
    }


def find_available_ocr_files_multi_source(config: Dict) -> Dict[str, List[Dict]]:
    """Find OCR page files for every entry in ``config['data_sources']``.

    Returns:
        ``{source_name: {'files': [...], 'config': source}}`` for each source
        whose ``ocr_out_dir`` exists. Each file dict is enriched with
        ``source_name``, ``ocr_tool``, ``description``, ``src_img_dir`` and
        ``ocr_out_dir``.
    """
    all_sources = {}

    for source in config.get('data_sources', []):
        source_name = source['name']
        ocr_tool = source['ocr_tool']
        source_key = f"{source_name}"

        ocr_out_dir = source['ocr_out_dir']

        if Path(ocr_out_dir).exists():
            files = find_available_ocr_files(ocr_out_dir)

            for file_info in files:
                file_info.update({
                    'source_name': source_name,
                    'ocr_tool': ocr_tool,
                    'description': source.get('description', ''),
                    'src_img_dir': source.get('src_img_dir', ''),
                    'ocr_out_dir': ocr_out_dir
                })

            all_sources[source_key] = {
                'files': files,
                'config': source
            }

            print(f"📁 找到数据源: {source_key} - {len(files)} 个文件")

    return all_sources


def get_data_source_display_name(source_config: Dict) -> str:
    """Return ``"<name> (<friendly tool name>)"`` for a data source config."""
    name = source_config['name']
    tool = source_config['ocr_tool']

    tool_name_map = {
        'dots_ocr': 'Dots OCR',
        'ppstructv3': 'PPStructV3',
        'table_recognition_v2': 'Table Recognition V2',
        'mineru': 'MinerU VLM-2.5.3'
    }

    tool_display = tool_name_map.get(tool, tool)
    return f"{name} ({tool_display})"


def get_nested_value(data: Dict, path: str, default=None):
    """Return the value at dotted ``path`` inside nested dicts, else ``default``.

    An empty ``path`` returns ``default``.
    """
    if not path:
        return default

    value = data
    for key in path.split('.'):
        if isinstance(value, dict) and key in value:
            value = value[key]
        else:
            return default
    return value