| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879 |
- """
- OCR验证工具的工具函数模块
- 包含数据处理、图像处理、统计分析等功能
- """
- import json
- import pandas as pd
- import numpy as np
- from pathlib import Path
- from PIL import Image, ImageDraw
- from typing import Dict, List, Optional, Tuple, Union
- import re
- import yaml
- import sys
- from ocr_validator_file_utils import process_all_images_in_content
def load_config(config_path: str = "config.yaml") -> Dict:
    """Load the YAML configuration file.

    Args:
        config_path: Path of the YAML file to read (UTF-8).

    Returns:
        Whatever ``yaml.safe_load`` produces for the file content.

    Note:
        This function never propagates an ``Exception``: any failure while
        opening or parsing the file is printed, its traceback is dumped, and
        the process is terminated with ``sys.exit(1)``.
    """
    import traceback

    try:
        with open(config_path, mode='r', encoding='utf-8') as stream:
            loaded = yaml.safe_load(stream)
    except Exception as exc:
        print(f"加载配置文件失败: {exc}")
        traceback.print_exc()
        # A missing or broken configuration is fatal for the tool.
        sys.exit(1)
    return loaded
def rotate_image_and_coordinates(
    image: Image.Image,
    angle: float,
    coordinates_list: List[List[int]],
    rotate_coordinates: bool = True
) -> Tuple[Image.Image, List[List[int]]]:
    """Rotate an image and, optionally, remap bounding boxes onto the result.

    Args:
        image: Source image.
        angle: Rotation in degrees. ``270`` is handed to ``Image.rotate`` as
            ``-90``; every other value is passed through unchanged.
        coordinates_list: Boxes given as ``[x1, y1, x2, y2]``.
        rotate_coordinates: When false, only the image is rotated and
            ``coordinates_list`` is returned untouched (for OCR tools whose
            coordinates must not be transformed).

    Returns:
        ``(rotated_image, boxes)``. For ``angle == 0`` both inputs are
        returned as-is. Entries with fewer than four values are copied
        through unchanged; boxes that fail validation are replaced by the
        placeholder ``[0, 0, 50, 50]`` after printing a warning.
    """
    if angle == 0:
        return image, coordinates_list

    # 270 degrees is expressed as a quarter turn in the opposite direction.
    pil_angle = -90 if angle == 270 else angle

    rotated_image = image.rotate(pil_angle, expand=True)
    if not rotate_coordinates:
        return rotated_image, coordinates_list

    src_w, src_h = image.size
    dst_w, dst_h = rotated_image.size

    remapped = []
    for box in coordinates_list:
        if len(box) < 4:
            remapped.append(box)
            continue

        left, top, right, bottom = box[:4]

        # Negative origins or empty/inverted extents cannot be transformed.
        is_invalid = left < 0 or top < 0 or right <= left or bottom <= top
        if is_invalid:
            print(f"警告: 无效坐标 {box}")
            remapped.append([0, 0, 50, 50])
            continue

        if pil_angle == -90:
            # (x, y) -> (src_h - y, x); y2 yields the new left edge.
            moved = (src_h - bottom, left, src_h - top, right)
        elif pil_angle == 90:
            # (x, y) -> (y, src_w - x); x2 yields the new top edge.
            moved = (top, src_w - right, bottom, src_w - left)
        elif pil_angle == 180:
            # (x, y) -> (src_w - x, src_h - y)
            moved = (src_w - right, src_h - bottom, src_w - left, src_h - top)
        else:
            # Arbitrary angle: rotate the four corners about the source
            # centre, shift them to the centre of the expanded canvas and
            # take the axis-aligned hull.
            theta = np.radians(pil_angle)
            cos_t = np.cos(theta)
            sin_t = np.sin(theta)
            src_cx, src_cy = src_w / 2, src_h / 2
            dst_cx, dst_cy = dst_w / 2, dst_h / 2

            xs = []
            ys = []
            corners = ((left, top), (right, top), (right, bottom), (left, bottom))
            for corner_x, corner_y in corners:
                dx = corner_x - src_cx
                dy = corner_y - src_cy
                # Matrix used: [ cos  sin]
                #              [-sin  cos]
                xs.append(dx * cos_t + dy * sin_t + dst_cx)
                ys.append(-dx * sin_t + dy * cos_t + dst_cy)

            moved = (int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys)))

        # Clamp every edge to the bounds of the rotated image.
        out_x1 = max(0, min(dst_w, moved[0]))
        out_y1 = max(0, min(dst_h, moved[1]))
        out_x2 = max(0, min(dst_w, moved[2]))
        out_y2 = max(0, min(dst_h, moved[3]))

        # Keep the box normalised: x1 <= x2 and y1 <= y2.
        out_x1, out_x2 = (out_x2, out_x1) if out_x1 > out_x2 else (out_x1, out_x2)
        out_y1, out_y2 = (out_y2, out_y1) if out_y1 > out_y2 else (out_y1, out_y2)

        remapped.append([out_x1, out_y1, out_x2, out_y2])

    return rotated_image, remapped
def parse_dots_ocr_data(data: List, config: Dict, tool_name: str) -> List[Dict]:
    """Convert Dots OCR output into the normalized record list.

    Args:
        data: Iterable of raw items; anything that is not a dict is skipped.
        config: Configuration; field names are read from
            ``config['ocr']['tools'][tool_name]`` and the fallback
            confidence from ``config['ocr']['default_confidence']``.
        tool_name: Key of the tool section, also stored as ``source_tool``.

    Returns:
        One record (``text``, ``bbox``, ``category``, ``confidence``,
        ``source_tool``) per item that has non-empty text and a bbox with at
        least four values.
    """
    tool_config = config['ocr']['tools'][tool_name]
    records = []

    for entry in data:
        if not isinstance(entry, dict):
            continue

        raw_text = entry.get(tool_config['text_field'], '')
        box = entry.get(tool_config['bbox_field'], [])
        kind = entry.get(tool_config['category_field'], 'Text')
        confidence_key = tool_config.get('confidence_field', 'confidence')
        score = entry.get(confidence_key, config['ocr']['default_confidence'])

        usable = raw_text and box and len(box) >= 4
        if not usable:
            continue

        records.append({
            'text': str(raw_text).strip(),
            'bbox': box[:4],  # only the first four coordinates are kept
            'category': kind,
            'confidence': score,
            'source_tool': tool_name
        })

    return records
def parse_ppstructv3_data(data: Dict, config: Dict) -> List[Dict]:
    """Convert PPStructV3 output into the normalized record list.

    Two sections of ``data`` are read, both located through
    ``config['ocr']['tools']['ppstructv3']``:

    * the layout parsing results (``parsing_results_field``), giving one
      record per item with non-empty text and a bbox of at least four values;
    * the raw OCR lines (``rec_texts_field`` / ``rec_boxes_field``, dotted
      paths), giving ``OCR_Text`` records tagged ``ppstructv3_ocr`` that
      carry no ``confidence`` key.

    Returns:
        The collected records. If the parsing-results value is not a list,
        an empty list is returned and the OCR lines are not inspected.
    """
    tool_config = config['ocr']['tools']['ppstructv3']
    records = []

    layout_items = data.get(tool_config['parsing_results_field'], [])
    if not isinstance(layout_items, list):
        return records

    for entry in layout_items:
        if not isinstance(entry, dict):
            continue

        raw_text = entry.get(tool_config['text_field'], '')
        box = entry.get(tool_config['bbox_field'], [])
        kind = entry.get(tool_config['category_field'], 'text')
        confidence_key = tool_config.get('confidence_field', 'confidence')
        score = entry.get(confidence_key, config['ocr']['default_confidence'])

        if not (raw_text and box and len(box) >= 4):
            continue

        records.append({
            'text': str(raw_text).strip(),
            'bbox': box[:4],
            'category': kind,
            'confidence': score,
            'source_tool': 'ppstructv3'
        })

    # Line-level OCR output: texts and boxes are parallel lists.
    line_texts = get_nested_value(data, tool_config.get('rec_texts_field', ''))
    line_boxes = get_nested_value(data, tool_config.get('rec_boxes_field', ''))
    if isinstance(line_texts, list) and isinstance(line_boxes, list):
        records.extend(
            {
                'text': str(line_text).strip(),
                'bbox': line_box[:4],
                'category': 'OCR_Text',
                'source_tool': 'ppstructv3_ocr'
            }
            for line_text, line_box in zip(line_texts, line_boxes)
            if line_text and isinstance(line_box, list) and len(line_box) >= 4
        )

    return records
def parse_table_recognition_v2_data(data: Dict, config: Dict) -> List[Dict]:
    """Convert Table Recognition V2 output into the normalized record list.

    Field names come from ``config['ocr']['tools']['table_recognition_v2']``.
    For every table dict found under ``parsing_results_field`` this emits:

    * one table record whose ``text`` is the value of ``text_field`` and
      whose ``bbox`` is the hull of all cell boxes (``[0.0, 0.0, 0.0, 0.0]``
      when the table has no cell boxes);
    * one ``OCR_Text`` record per text/box pair found under the dotted paths
      ``rec_texts_field`` / ``rec_boxes_field`` inside that table.

    Returns:
        The collected records; empty when the tables value is not a list.
    """
    tool_config = config['ocr']['tools']['table_recognition_v2']
    records = []

    table_items = data.get(tool_config['parsing_results_field'], [])
    if not isinstance(table_items, list):
        return records

    for table in table_items:
        if not isinstance(table, dict):
            continue

        html = table.get(tool_config['text_field'], '')

        # The table outline is the hull of all of its cell boxes.
        cells = table.get(tool_config['bbox_field'], [])
        if cells:
            lefts = [cell[0] for cell in cells]
            tops = [cell[1] for cell in cells]
            rights = [cell[2] for cell in cells]
            bottoms = [cell[3] for cell in cells]
            outline = [
                float(min(lefts)),
                float(min(tops)),
                float(max(rights)),
                float(max(bottoms))
            ]
        else:
            outline = [0.0, 0.0, 0.0, 0.0]

        table_text = str(html).strip()
        kind = table.get(tool_config.get('category_field', ''), 'table')
        score = table.get(
            tool_config.get('confidence_field', ''),
            config['ocr']['default_confidence']
        )
        records.append({
            'text': table_text,
            'bbox': outline,
            'category': kind,
            'confidence': score,
            'source_tool': 'table_recognition_v2',
        })

        # Line-level OCR output of this table: parallel text/box lists.
        line_texts = get_nested_value(table, tool_config.get('rec_texts_field', ''))
        line_boxes = get_nested_value(table, tool_config.get('rec_boxes_field', ''))
        if isinstance(line_texts, list) and isinstance(line_boxes, list):
            records.extend(
                {
                    'text': str(line_text).strip(),
                    'bbox': line_box[:4],
                    'category': 'OCR_Text',
                    'source_tool': 'table_recognition_v2'
                }
                for line_text, line_box in zip(line_texts, line_boxes)
                if line_text and isinstance(line_box, list) and len(line_box) >= 4
            )

    return records
def _as_list(value) -> List:
    """Return *value* when it is a list, otherwise an empty list."""
    return value if isinstance(value, list) else []


def parse_mineru_data(data: List, config: Dict, tool_name="mineru") -> List[Dict]:
    """Convert MinerU output into the normalized record list.

    Args:
        data: Raw MinerU items; a non-list yields an empty result and
            non-dict items are skipped.
        config: Configuration; field names are read from
            ``config['ocr']['tools'][tool_name]`` and the fallback
            confidence from ``config['ocr']['default_confidence']``.
        tool_name: Key of the tool section, also stored as ``source_tool``.

    Returns:
        Records built per item category:

        * ``text``: one record including ``text_level``;
        * ``table``: one record carrying the table HTML and image path,
          followed by one ``table_cell`` record per valid cell;
        * ``image``: one ``[Image]`` placeholder record;
        * ``list``: one record per non-empty list item (sharing the item's
          bbox);
        * anything else: handled like plain text.

        Items whose bbox is not a list/tuple of at least four values produce
        nothing. Malformed optional sections (``table_cells`` or
        ``list_items`` that are null or not lists, cells that are not dicts)
        are ignored instead of raising.
    """
    tool_config = config['ocr']['tools'][tool_name]
    parsed_data = []

    if not isinstance(data, list):
        return parsed_data

    for item in data:
        if not isinstance(item, dict):
            continue

        text = item.get(tool_config['text_field'], '')
        bbox = item.get(tool_config['bbox_field'], [])
        category = item.get(tool_config['category_field'], 'Text')
        confidence = item.get(
            tool_config.get('confidence_field', 'confidence'),
            config['ocr']['default_confidence'],
        )

        # Every record needs a usable region; checking the type first keeps
        # a scalar bbox from blowing up in len().
        if not (isinstance(bbox, (list, tuple)) and len(bbox) >= 4):
            continue
        region = bbox[:4]

        if category == 'text':
            if text:
                parsed_data.append({
                    'text': str(text).strip(),
                    'bbox': region,
                    'category': category,
                    'confidence': confidence,
                    'source_tool': tool_name,
                    'text_level': item.get('text_level', 0)  # heading level
                })

        elif category == 'table':
            table_html = item.get(tool_config.get('table_body_field', 'table_body'), '')
            img_path = item.get(tool_config.get('img_path_field', 'img_path'), '')
            parsed_data.append({
                'text': table_html,
                'bbox': region,
                'category': 'table',
                'confidence': confidence,
                'source_tool': tool_name,
                'img_path': img_path,
                'table_body': table_html
            })

            # The cells key may be absent, null or of the wrong type.
            table_cells = _as_list(
                item.get(tool_config.get('table_cells_field', 'table_cells'))
            )
            for cell in table_cells:
                if not isinstance(cell, dict):
                    continue
                cell_text = cell.get('text', '')
                cell_bbox = cell.get('bbox', [])
                if not (cell_text and isinstance(cell_bbox, (list, tuple))
                        and len(cell_bbox) >= 4):
                    continue
                parsed_data.append({
                    'text': str(cell_text).strip(),
                    'bbox': cell_bbox[:4],
                    'row': cell.get('row', -1),
                    'col': cell.get('col', -1),
                    'category': 'table_cell',
                    'confidence': cell.get('score', 0.0),
                    'source_tool': tool_name,
                })

        elif category == 'image':
            img_path = item.get(tool_config.get('img_path_field', 'img_path'), '')
            parsed_data.append({
                'text': '[Image]',
                'bbox': region,
                'category': 'image',
                'confidence': confidence,
                'source_tool': tool_name,
                'img_path': img_path
            })

        elif category == 'list':
            sub_type = item.get('sub_type', 'unordered')  # ordered / unordered
            for list_item in _as_list(item.get('list_items')):
                if not list_item:
                    continue
                parsed_data.append({
                    'text': str(list_item).strip(),
                    'bbox': region,
                    'category': category,
                    'sub_type': sub_type,
                    'confidence': confidence,
                    'source_tool': tool_name
                })

        else:
            # Remaining categories (header, table_cell, ...) are plain text.
            if text:
                parsed_data.append({
                    'text': str(text).strip(),
                    'bbox': region,
                    'category': category,
                    'confidence': confidence,
                    'source_tool': tool_name
                })

    return parsed_data
def detect_mineru_structure(data: Union[List, Dict]) -> bool:
    """Tell whether ``data`` looks like MinerU output.

    Only the first element is inspected: ``data`` must be a non-empty list
    whose first element is a dict containing the keys ``type``, ``bbox`` and
    ``text``, with ``type`` equal to ``text``, ``table`` or ``image``.
    """
    if not isinstance(data, list) or len(data) == 0:
        return False

    head = data[0]
    if not isinstance(head, dict):
        return False

    # All three signature keys must be present before the type is judged.
    if not all(key in head for key in ('type', 'bbox', 'text')):
        return False

    return head.get('type', '') in ('text', 'table', 'image')
def detect_ocr_tool_type(data: Union[List, Dict], config: Dict) -> str:
    """Determine which OCR tool produced ``data``.

    Args:
        data: Raw OCR payload (list or dict).
        config: Configuration; ``config['ocr']['auto_detection']`` supplies
            the ``enabled`` switch and the detection ``rules``.

    Returns:
        ``'mineru'`` when auto detection is disabled. Otherwise the
        ``tool_type`` of the first rule, in ascending ``priority`` (missing
        priority counts as 999), whose conditions all hold, or
        ``'dots_ocr'`` when no rule matches.
    """
    if not config['ocr']['auto_detection']['enabled']:
        return 'mineru'  # fallback used when detection is switched off

    ordered_rules = list(config['ocr']['auto_detection']['rules'])
    ordered_rules.sort(key=lambda rule: rule.get('priority', 999))

    for rule in ordered_rules:
        candidate = rule['tool_type']
        if _check_all_conditions(data, rule.get('conditions', [])):
            return candidate

    # No rule matched.
    return 'dots_ocr'
- def _check_all_conditions(data: Union[List, Dict], conditions: List[Dict]) -> bool:
- """
- 检查所有条件是否满足
-
- Args:
- data: 数据
- conditions: 条件列表
-
- Returns:
- 是否所有条件都满足
- """
- for condition in conditions:
- condition_type = condition.get('type', '')
-
- if condition_type == 'field_exists':
- # 检查字段存在
- field = condition.get('field', '')
- if not _check_field_exists(data, field):
- return False
-
- elif condition_type == 'field_not_exists':
- # 检查字段不存在
- field = condition.get('field', '')
- if _check_field_exists(data, field):
- return False
-
- elif condition_type == 'json_structure':
- # 检查JSON结构类型
- expected_structure = condition.get('structure', '')
- if expected_structure == 'array' and not isinstance(data, list):
- return False
- elif expected_structure == 'object' and not isinstance(data, dict):
- return False
-
- elif condition_type == 'field_value':
- # 检查字段值
- field = condition.get('field', '')
- expected_value = condition.get('value')
- actual_value = _get_field_value(data, field)
- if actual_value != expected_value:
- return False
-
- elif condition_type == 'field_contains':
- # 检查字段包含某个值
- field = condition.get('field', '')
- expected_values = condition.get('values', [])
- actual_value = _get_field_value(data, field)
- if actual_value not in expected_values:
- return False
-
- return True
- def _check_field_exists(data: Union[List, Dict], field_path: str) -> bool:
- """
- 检查字段是否存在(支持嵌套路径)
-
- Args:
- data: 数据
- field_path: 字段路径(支持点分隔,如 "doc_preprocessor_res.angle")
-
- Returns:
- 字段是否存在
- """
- if not field_path:
- return False
-
- # 处理数组情况:检查第一个元素
- if isinstance(data, list):
- if not data:
- return False
- data = data[0]
-
- # 处理嵌套字段路径
- fields = field_path.split('.')
- current = data
-
- for field in fields:
- if isinstance(current, dict) and field in current:
- current = current[field]
- else:
- return False
-
- return True
- def _get_field_value(data: Union[List, Dict], field_path: str):
- """
- 获取字段值(支持嵌套路径)
-
- Args:
- data: 数据
- field_path: 字段路径
-
- Returns:
- 字段值,如果不存在返回 None
- """
- if not field_path:
- return None
-
- # 处理数组情况:检查第一个元素
- if isinstance(data, list):
- if not data:
- return None
- data = data[0]
-
- # 处理嵌套字段路径
- fields = field_path.split('.')
- current = data
-
- for field in fields:
- if isinstance(current, dict) and field in current:
- current = current[field]
- else:
- return None
-
- return current
def normalize_ocr_data(raw_data: Union[List, Dict], config: Dict) -> List[Dict]:
    """Normalize raw OCR output into the common record list.

    The producing tool is determined with ``detect_ocr_tool_type`` and the
    payload is handed to the matching parser.

    Raises:
        ValueError: If the detected tool type has no parser.
    """
    detected_tool = detect_ocr_tool_type(raw_data, config)

    if detected_tool == 'dots_ocr':
        return parse_dots_ocr_data(raw_data, config, detected_tool)
    if detected_tool == 'ppstructv3':
        return parse_ppstructv3_data(raw_data, config)
    if detected_tool == 'table_recognition_v2':
        return parse_table_recognition_v2_data(raw_data, config)
    if detected_tool == 'mineru':
        return parse_mineru_data(raw_data, config, detected_tool)

    raise ValueError(f"不支持的OCR工具类型: {detected_tool}")
def get_rotation_angle_from_ppstructv3(data: Dict) -> float:
    """Read the page rotation angle from a PPStructV3 payload.

    Returns:
        ``float(data['doc_preprocessor_res']['angle'])`` when that entry
        exists and ``doc_preprocessor_res`` is a dict, otherwise ``0.0``.
    """
    if 'doc_preprocessor_res' not in data:
        return 0.0

    preprocess_info = data['doc_preprocessor_res']
    if not isinstance(preprocess_info, dict) or 'angle' not in preprocess_info:
        return 0.0

    return float(preprocess_info['angle'])
- # 修改 load_ocr_data_file 函数
def load_ocr_data_file(json_path: str, config: Dict) -> Tuple[List, str, str]:
    """Load one OCR result file together with its markdown and source image.

    Args:
        json_path: Path of the OCR JSON file.
        config: Tool configuration, forwarded to the normalizer and to the
            image lookup.

    Returns:
        ``(ocr_data, md_content, image_path)``: the normalized OCR records,
        the content of the sibling ``.md`` file after image references were
        processed (``""`` when there is no such file), and the path of the
        matching image (``""`` when none is found).

    Raises:
        FileNotFoundError: If ``json_path`` does not exist.
        Exception: If reading, parsing or normalizing the JSON fails; the
            underlying error is attached as ``__cause__``.
    """
    json_file = Path(json_path)

    if not json_file.exists():
        raise FileNotFoundError(f"找不到JSON文件: {json_path}")

    try:
        with open(json_file, 'r', encoding='utf-8') as f:
            raw_data = json.load(f)

        # Bring the tool-specific payload into the common record format.
        ocr_data = normalize_ocr_data(raw_data, config)

        # Only dict payloads can carry the PPStructV3 rotation entry.
        rotation_angle = 0.0
        if isinstance(raw_data, dict):
            rotation_angle = get_rotation_angle_from_ppstructv3(raw_data)

        # Tag every record so later image handling can apply the rotation.
        if rotation_angle != 0:
            for item in ocr_data:
                item['rotation_angle'] = rotation_angle

    except Exception as e:
        # Chain explicitly so the real cause stays visible in the traceback.
        raise Exception(f"加载JSON文件失败: {e}") from e

    # The markdown rendering lives next to the JSON file.
    md_file = json_file.with_suffix('.md')
    md_content = ""
    if md_file.exists():
        with open(md_file, 'r', encoding='utf-8') as f:
            md_content = f.read()

        # Process every image reference contained in the markdown.
        md_content = process_all_images_in_content(md_content, str(json_file))

    image_path = find_corresponding_image(json_file, config)

    return ocr_data, md_content, image_path
def find_corresponding_image(json_file: Path, config: Dict) -> str:
    """Locate the image that belongs to an OCR JSON file.

    The image is searched in ``config['paths']['src_img_dir']`` or, when that
    is not configured, next to the JSON file. It must share the JSON file's
    stem and use one of the extensions ``.jpg``, ``.jpeg``, ``.png``,
    ``.bmp``, ``.tiff``, ``.tif`` (tried in that order).

    Returns:
        The path of the first existing candidate, or ``""`` if none exists.
    """
    configured_dir = config.get('paths', {}).get('src_img_dir', '')
    # Without a configured directory, look beside the JSON file.
    image_dir = Path(configured_dir if configured_dir else json_file.parent)

    candidates = (
        image_dir / f"{json_file.stem}{suffix}"
        for suffix in ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif')
    )
    return next((str(path) for path in candidates if path.exists()), "")
def process_ocr_data(ocr_data: List, config: Dict) -> Dict[str, List]:
    """Index normalized OCR records by their text.

    Args:
        ocr_data: Normalized records; a non-list yields an empty mapping and
            non-dict entries are skipped.
        config: Configuration; ``config['ocr']`` supplies ``exclude_texts``,
            ``min_text_length`` and ``default_confidence``.

    Returns:
        Mapping from stripped text to the list of its occurrences. Each
        occurrence holds ``bbox``, ``category``, ``index`` (position in
        ``ocr_data``), ``confidence``, ``source_tool`` and
        ``rotation_angle``. Only records whose bbox is a list of exactly
        four values are indexed.
    """
    mapping = {}
    excluded = config['ocr']['exclude_texts']
    shortest = config['ocr']['min_text_length']

    if not isinstance(ocr_data, list):
        return mapping

    for position, entry in enumerate(ocr_data):
        if not isinstance(entry, dict):
            continue

        text = str(entry['text']).strip()
        keep = text and text not in excluded and len(text) >= shortest
        if not keep:
            continue

        box = entry['bbox']
        if not (isinstance(box, list) and len(box) == 4):
            continue

        mapping.setdefault(text, []).append({
            'bbox': box,
            'category': entry.get('category', 'Text'),
            'index': position,
            'confidence': entry.get('confidence', config['ocr']['default_confidence']),
            'source_tool': entry.get('source_tool', 'unknown'),
            'rotation_angle': entry.get('rotation_angle', 0.0)  # page rotation, if any
        })

    return mapping
def find_available_ocr_files(ocr_out_dir: str) -> List[Dict]:
    """Collect the per-page OCR result files below a directory.

    The directory is searched recursively for JSON files whose name ends in
    ``_page_<digits>.json`` (case-insensitive).

    Args:
        ocr_out_dir: Root directory of the OCR output.

    Returns:
        One dict per file with ``path`` (string path), ``page`` (page number
        parsed from the file name) and ``display_name`` (label built from the
        page number), ordered by page number and then by path. The list is
        empty when the directory does not exist.
    """
    # The page number is captured by the same pattern that selects the file,
    # so "_PAGE_" spellings are parsed as well instead of being renumbered.
    page_pattern = re.compile(r'.*_page_(\d+)\.json$', re.IGNORECASE)

    search_dir = Path(ocr_out_dir)
    if not search_dir.exists():
        return []

    file_info = []
    for json_file in search_dir.rglob("*.json"):
        matched = page_pattern.match(json_file.name)
        if matched is None:
            continue
        page_num = int(matched.group(1))
        file_info.append({
            'path': str(json_file),
            'page': page_num,
            'display_name': f"第{page_num}页"
        })

    # The path is a tie-breaker: the traversal order of rglob is not defined.
    file_info.sort(key=lambda info: (info['page'], info['path']))
    return file_info
def get_ocr_tool_info(ocr_data: List) -> Dict:
    """Count OCR records per producing tool.

    Returns:
        Mapping from each record's ``source_tool`` (``'unknown'`` when the
        key is absent) to the number of records; non-dict entries are
        ignored.
    """
    tally = {}
    tools = (
        entry.get('source_tool', 'unknown')
        for entry in ocr_data
        if isinstance(entry, dict)
    )
    for tool in tools:
        tally[tool] = tally.get(tool, 0) + 1
    return tally
def get_ocr_statistics(ocr_data: List, text_bbox_mapping: Dict, marked_errors: set) -> Dict:
    """Summarize the OCR records of one page.

    Args:
        ocr_data: Normalized OCR records.
        text_bbox_mapping: Text-to-occurrences index; its size is the number
            of clickable texts.
        marked_errors: Texts flagged as wrong.

    Returns:
        Dict with ``total_texts``, ``clickable_texts``, ``marked_errors``,
        ``categories`` (record count per category), ``accuracy_rate``
        (percentage of clickable texts not marked as errors, ``0`` when
        nothing is clickable) and ``tool_info`` (record count per tool). An
        all-zero summary is returned when ``ocr_data`` is not a non-empty
        list.
    """
    if not isinstance(ocr_data, list) or not ocr_data:
        return {
            'total_texts': 0, 'clickable_texts': 0, 'marked_errors': 0,
            'categories': {}, 'accuracy_rate': 0, 'tool_info': {}
        }

    record_count = len(ocr_data)
    clickable = len(text_bbox_mapping)
    error_count = len(marked_errors)

    per_category = {}
    for entry in ocr_data:
        if not isinstance(entry, dict):
            continue
        label = entry.get('category', 'Unknown')
        per_category[label] = per_category.get(label, 0) + 1

    per_tool = get_ocr_tool_info(ocr_data)

    if clickable > 0:
        accuracy = (clickable - error_count) / clickable * 100
    else:
        accuracy = 0

    return {
        'total_texts': record_count,
        'clickable_texts': clickable,
        'marked_errors': error_count,
        'categories': per_category,
        'accuracy_rate': accuracy,
        'tool_info': per_tool
    }
def group_texts_by_category(text_bbox_mapping: Dict[str, List]) -> Dict[str, List[str]]:
    """Group indexed texts by category.

    The category of a text is taken from its first occurrence.

    Returns:
        Mapping from category to the texts of that category, in the
        iteration order of ``text_bbox_mapping``.
    """
    grouped = {}
    for text, occurrences in text_bbox_mapping.items():
        grouped.setdefault(occurrences[0]['category'], []).append(text)
    return grouped
def get_ocr_tool_rotation_config(ocr_data: List, config: Dict) -> Dict:
    """Look up the rotation settings of the tool that produced ``ocr_data``.

    The tool is read from the first record's ``source_tool`` (``'dots_ocr'``
    when absent) and its settings from
    ``config['ocr']['tools'][tool]['rotation']``.

    Returns:
        The configured rotation dict, or
        ``{'coordinates_are_pre_rotated': False}`` when ``ocr_data`` is not
        a non-empty list, the tool is not configured, or it has no
        ``rotation`` entry.
    """
    fallback = {'coordinates_are_pre_rotated': False}

    if not ocr_data or not isinstance(ocr_data, list):
        return fallback

    # The first record is taken as representative for the whole page.
    tool = ocr_data[0].get('source_tool', 'dots_ocr')
    configured_tools = config.get('ocr', {}).get('tools', {})

    if tool not in configured_tools:
        return fallback

    return configured_tools[tool].get('rotation', fallback)
- # ocr_validator_utils.py
def find_available_ocr_files_multi_source(config: Dict) -> Dict[str, Dict]:
    """Collect the OCR result files of every configured data source.

    Args:
        config: Configuration; ``config['data_sources']`` is a list of
            source dicts with ``name``, ``ocr_tool``, ``ocr_out_dir`` and
            optionally ``description`` and ``src_img_dir``. A missing or
            empty (null) entry means "no sources".

    Returns:
        Mapping from ``"<name>_<ocr_tool>"`` to
        ``{'files': [...], 'config': <source dict>}``. Every file entry is
        extended with the source's name, tool, description, image directory
        and output directory. Sources whose output directory does not exist
        are left out.
    """
    all_sources = {}

    # "data_sources:" with no value parses to None, hence the "or []".
    for source in config.get('data_sources') or []:
        source_name = source['name']
        ocr_tool = source['ocr_tool']
        source_key = f"{source_name}_{ocr_tool}"  # unique per source and tool

        ocr_out_dir = source['ocr_out_dir']
        if not Path(ocr_out_dir).exists():
            continue

        files = find_available_ocr_files(ocr_out_dir)

        # Attach the data-source context to every file entry.
        for file_info in files:
            file_info.update({
                'source_name': source_name,
                'ocr_tool': ocr_tool,
                'description': source.get('description', ''),
                'src_img_dir': source.get('src_img_dir', ''),
                'ocr_out_dir': ocr_out_dir
            })

        all_sources[source_key] = {
            'files': files,
            'config': source
        }

        print(f"📁 找到数据源: {source_key} - {len(files)} 个文件")

    return all_sources
def get_data_source_display_name(source_config: Dict) -> str:
    """Build the display label of a data source.

    Args:
        source_config: Source dict providing ``name`` and ``ocr_tool``.

    Returns:
        ``"<name> (<tool label>)"``, where the tool label is a friendly name
        for the known tools and the raw ``ocr_tool`` value otherwise.
    """
    name = source_config['name']
    tool = source_config['ocr_tool']

    # Friendly labels for the tools that have one.
    tool_name_map = {
        'dots_ocr': 'Dots OCR',
        'ppstructv3': 'PPStructV3',
        'table_recognition_v2': 'Table Recognition V2',
        'mineru': 'MinerU VLM-2.5.3'
    }

    tool_display = tool_name_map.get(tool, tool)
    return f"{name} ({tool_display})"
def get_nested_value(data: Dict, path: str, default=None):
    """Fetch a value from nested dicts by a dot-separated key path.

    Args:
        data: Outermost dict.
        path: Key path such as ``"overall_ocr_res.rec_texts"``.
        default: Value returned when the lookup fails.

    Returns:
        The value at ``path``, or ``default`` when ``path`` is empty, a key
        is missing, or an intermediate value is not a dict.
    """
    if not path:
        return default

    node = data
    for part in path.split('.'):
        if not isinstance(node, dict) or part not in node:
            return default
        node = node[part]
    return node
|