ocr_validator_utils.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738
  1. """
  2. OCR验证工具的工具函数模块
  3. 包含数据处理、图像处理、统计分析等功能
  4. """
  5. import json
  6. import pandas as pd
  7. import numpy as np
  8. from pathlib import Path
  9. from PIL import Image, ImageDraw
  10. from typing import Dict, List, Optional, Tuple, Union
  11. import re
  12. import yaml
  13. import sys
  14. from ocr_validator_file_utils import process_all_images_in_content
  15. def load_config(config_path: str = "config.yaml") -> Dict:
  16. """加载配置文件"""
  17. try:
  18. with open(config_path, 'r', encoding='utf-8') as f:
  19. return yaml.safe_load(f)
  20. except Exception as e:
  21. print(f"加载配置文件失败: {e}")
  22. import traceback
  23. traceback.print_exc()
  24. # 退出
  25. sys.exit(1)
  26. def rotate_image_and_coordinates(
  27. image: Image.Image,
  28. angle: float,
  29. coordinates_list: List[List[int]],
  30. rotate_coordinates: bool = True
  31. ) -> Tuple[Image.Image, List[List[int]]]:
  32. """
  33. 根据角度旋转图像和坐标 - 修正版本
  34. Args:
  35. image: 原始图像
  36. angle: 旋转角度(度数)
  37. coordinates_list: 坐标列表,每个坐标为[x1, y1, x2, y2]格式
  38. rotate_coordinates: 是否需要旋转坐标(针对不同OCR工具的处理方式)
  39. Returns:
  40. rotated_image: 旋转后的图像
  41. rotated_coordinates: 处理后的坐标列表
  42. """
  43. if angle == 0:
  44. return image, coordinates_list
  45. # 标准化旋转角度
  46. if angle == 270:
  47. rotation_angle = -90 # 顺时针90度
  48. elif angle == 90:
  49. rotation_angle = 90 # 逆时针90度
  50. elif angle == 180:
  51. rotation_angle = 180 # 180度
  52. else:
  53. rotation_angle = angle
  54. # 旋转图像
  55. rotated_image = image.rotate(rotation_angle, expand=True)
  56. # 如果不需要旋转坐标,直接返回原坐标
  57. if not rotate_coordinates:
  58. return rotated_image, coordinates_list
  59. # 获取原始和旋转后的图像尺寸
  60. orig_width, orig_height = image.size
  61. new_width, new_height = rotated_image.size
  62. # 计算旋转后的坐标
  63. rotated_coordinates = []
  64. for coord in coordinates_list:
  65. if len(coord) < 4:
  66. rotated_coordinates.append(coord)
  67. continue
  68. x1, y1, x2, y2 = coord[:4]
  69. # 验证原始坐标是否有效
  70. if x1 < 0 or y1 < 0 or x2 <= x1 or y2 <= y1:
  71. print(f"警告: 无效坐标 {coord}")
  72. rotated_coordinates.append([0, 0, 50, 50]) # 使用默认坐标
  73. continue
  74. # 根据旋转角度变换坐标
  75. if rotation_angle == -90: # 顺时针90度 (270度逆时针)
  76. # 变换公式: (x, y) -> (orig_height - y, x)
  77. new_x1 = orig_height - y2 # 这里是y2
  78. new_y1 = x1
  79. new_x2 = orig_height - y1 # 这里是y1
  80. new_y2 = x2
  81. elif rotation_angle == 90: # 逆时针90度
  82. # 变换公式: (x, y) -> (y, orig_width - x)
  83. new_x1 = y1
  84. new_y1 = orig_width - x2 # 这里是x2
  85. new_x2 = y2
  86. new_y2 = orig_width - x1 # 这里是x1
  87. elif rotation_angle == 180: # 180度
  88. # 变换公式: (x, y) -> (orig_width - x, orig_height - y)
  89. new_x1 = orig_width - x2
  90. new_y1 = orig_height - y2
  91. new_x2 = orig_width - x1
  92. new_y2 = orig_height - y1
  93. else: # 任意角度算法 - 修正版本
  94. # 将角度转换为弧度
  95. angle_rad = np.radians(rotation_angle)
  96. cos_angle = np.cos(angle_rad)
  97. sin_angle = np.sin(angle_rad)
  98. # 原图像中心点
  99. orig_center_x = orig_width / 2
  100. orig_center_y = orig_height / 2
  101. # 旋转后图像中心点
  102. new_center_x = new_width / 2
  103. new_center_y = new_height / 2
  104. # 将bbox的四个角点转换为相对于原图像中心的坐标
  105. corners = [
  106. (x1 - orig_center_x, y1 - orig_center_y), # 左上角
  107. (x2 - orig_center_x, y1 - orig_center_y), # 右上角
  108. (x2 - orig_center_x, y2 - orig_center_y), # 右下角
  109. (x1 - orig_center_x, y2 - orig_center_y) # 左下角
  110. ]
  111. # 应用修正后的旋转矩阵变换每个角点
  112. rotated_corners = []
  113. for x, y in corners:
  114. # 修正后的旋转矩阵: [cos(θ) sin(θ)] [x]
  115. # [-sin(θ) cos(θ)] [y]
  116. rotated_x = x * cos_angle + y * sin_angle
  117. rotated_y = -x * sin_angle + y * cos_angle
  118. # 转换回绝对坐标(相对于新图像)
  119. abs_x = rotated_x + new_center_x
  120. abs_y = rotated_y + new_center_y
  121. rotated_corners.append((abs_x, abs_y))
  122. # 从旋转后的四个角点计算新的边界框
  123. x_coords = [corner[0] for corner in rotated_corners]
  124. y_coords = [corner[1] for corner in rotated_corners]
  125. new_x1 = int(min(x_coords))
  126. new_y1 = int(min(y_coords))
  127. new_x2 = int(max(x_coords))
  128. new_y2 = int(max(y_coords))
  129. # 确保坐标在有效范围内
  130. new_x1 = max(0, min(new_width, new_x1))
  131. new_y1 = max(0, min(new_height, new_y1))
  132. new_x2 = max(0, min(new_width, new_x2))
  133. new_y2 = max(0, min(new_height, new_y2))
  134. # 确保x1 < x2, y1 < y2
  135. if new_x1 > new_x2:
  136. new_x1, new_x2 = new_x2, new_x1
  137. if new_y1 > new_y2:
  138. new_y1, new_y2 = new_y2, new_y1
  139. rotated_coordinates.append([new_x1, new_y1, new_x2, new_y2])
  140. return rotated_image, rotated_coordinates
  141. def parse_dots_ocr_data(data: List, config: Dict, tool_name: str) -> List[Dict]:
  142. """解析Dots OCR格式的数据"""
  143. tool_config = config['ocr']['tools'][tool_name]
  144. parsed_data = []
  145. for item in data:
  146. if not isinstance(item, dict):
  147. continue
  148. # 提取字段
  149. text = item.get(tool_config['text_field'], '')
  150. bbox = item.get(tool_config['bbox_field'], [])
  151. category = item.get(tool_config['category_field'], 'Text')
  152. confidence = item.get(tool_config.get('confidence_field', 'confidence'),
  153. config['ocr']['default_confidence'])
  154. if text and bbox and len(bbox) >= 4:
  155. parsed_data.append({
  156. 'text': str(text).strip(),
  157. 'bbox': bbox[:4], # 确保只取前4个坐标
  158. 'category': category,
  159. 'confidence': confidence,
  160. 'source_tool': tool_name
  161. })
  162. return parsed_data
  163. def parse_ppstructv3_data(data: Dict, config: Dict) -> List[Dict]:
  164. """解析PPStructV3格式的数据"""
  165. tool_config = config['ocr']['tools']['ppstructv3']
  166. parsed_data = []
  167. parsing_results = data.get(tool_config['parsing_results_field'], [])
  168. if not isinstance(parsing_results, list):
  169. return parsed_data
  170. for item in parsing_results:
  171. if not isinstance(item, dict):
  172. continue
  173. text = item.get(tool_config['text_field'], '')
  174. bbox = item.get(tool_config['bbox_field'], [])
  175. category = item.get(tool_config['category_field'], 'text')
  176. confidence = item.get(
  177. tool_config.get('confidence_field', 'confidence'),
  178. config['ocr']['default_confidence']
  179. )
  180. if text and bbox and len(bbox) >= 4:
  181. parsed_data.append({
  182. 'text': str(text).strip(),
  183. 'bbox': bbox[:4],
  184. 'category': category,
  185. 'confidence': confidence,
  186. 'source_tool': 'ppstructv3'
  187. })
  188. rec_texts = get_nested_value(data, tool_config.get('rec_texts_field', ''))
  189. rec_boxes = get_nested_value(data, tool_config.get('rec_boxes_field', ''))
  190. if isinstance(rec_texts, list) and isinstance(rec_boxes, list):
  191. for i, (text, box) in enumerate(zip(rec_texts, rec_boxes)):
  192. if text and isinstance(box, list) and len(box) >= 4:
  193. parsed_data.append({
  194. 'text': str(text).strip(),
  195. 'bbox': box[:4],
  196. 'category': 'OCR_Text',
  197. 'source_tool': 'ppstructv3_ocr'
  198. })
  199. return parsed_data
  200. def parse_table_recognition_v2_data(data: Dict, config: Dict) -> List[Dict]:
  201. tool_config = config['ocr']['tools']['table_recognition_v2']
  202. parsed_data = []
  203. tables = data.get(tool_config['parsing_results_field'], [])
  204. if not isinstance(tables, list):
  205. return parsed_data
  206. for item in tables:
  207. if not isinstance(item, dict):
  208. continue
  209. html_text = item.get(tool_config['text_field'], '')
  210. # 计算表格整体bbox
  211. cell_boxes_raw = item.get(tool_config['bbox_field'], [])
  212. if cell_boxes_raw:
  213. x1_list = [box[0] for box in cell_boxes_raw]
  214. y1_list = [box[1] for box in cell_boxes_raw]
  215. x2_list = [box[2] for box in cell_boxes_raw]
  216. y2_list = [box[3] for box in cell_boxes_raw]
  217. table_bbox = [
  218. float(min(x1_list)),
  219. float(min(y1_list)),
  220. float(max(x2_list)),
  221. float(max(y2_list))
  222. ]
  223. else:
  224. table_bbox = [0.0, 0.0, 0.0, 0.0]
  225. parsed_data.append({
  226. 'text': str(html_text).strip(),
  227. 'bbox': table_bbox,
  228. 'category': item.get(tool_config.get('category_field', ''), 'table'),
  229. 'confidence': item.get(tool_config.get('confidence_field', ''), config['ocr']['default_confidence']),
  230. 'source_tool': 'table_recognition_v2',
  231. })
  232. rec_texts = get_nested_value(item, tool_config.get('rec_texts_field', ''))
  233. rec_boxes = get_nested_value(item, tool_config.get('rec_boxes_field', ''))
  234. if isinstance(rec_texts, list) and isinstance(rec_boxes, list):
  235. for i, (text, box) in enumerate(zip(rec_texts, rec_boxes)):
  236. if text and isinstance(box, list) and len(box) >= 4:
  237. parsed_data.append({
  238. 'text': str(text).strip(),
  239. 'bbox': box[:4],
  240. 'category': 'OCR_Text',
  241. 'source_tool': 'table_recognition_v2'
  242. })
  243. return parsed_data
  244. def parse_mineru_data(data: List, config: Dict, tool_name="mineru") -> List[Dict]:
  245. """解析MinerU格式的数据"""
  246. tool_config = config['ocr']['tools'][tool_name]
  247. parsed_data = []
  248. if not isinstance(data, list):
  249. return parsed_data
  250. for item in data:
  251. if not isinstance(item, dict):
  252. continue
  253. text = item.get(tool_config['text_field'], '')
  254. bbox = item.get(tool_config['bbox_field'], [])
  255. category = item.get(tool_config['category_field'], 'Text')
  256. confidence = item.get(tool_config.get('confidence_field', 'confidence'),
  257. config['ocr']['default_confidence'])
  258. # 处理文本类型
  259. if category == 'text':
  260. if text and bbox and len(bbox) >= 4:
  261. parsed_data.append({
  262. 'text': str(text).strip(),
  263. 'bbox': bbox[:4],
  264. 'category': category,
  265. 'confidence': confidence,
  266. 'source_tool': tool_name,
  267. 'text_level': item.get('text_level', 0) # 保留文本层级信息
  268. })
  269. # 处理表格类型
  270. elif category == 'table':
  271. table_html = item.get(tool_config.get('table_body_field', 'table_body'), '')
  272. img_path = item.get(tool_config.get('img_path_field', 'img_path'), '')
  273. if bbox and len(bbox) >= 4:
  274. parsed_data.append({
  275. 'text': table_html,
  276. 'bbox': bbox[:4],
  277. 'category': 'table',
  278. 'confidence': confidence,
  279. 'source_tool': tool_name,
  280. 'img_path': img_path,
  281. 'table_body': table_html
  282. })
  283. elif category == 'image':
  284. img_path = item.get(tool_config.get('img_path_field', 'img_path'), '')
  285. if bbox and len(bbox) >= 4:
  286. parsed_data.append({
  287. 'text': '[Image]',
  288. 'bbox': bbox[:4],
  289. 'category': 'image',
  290. 'confidence': confidence,
  291. 'source_tool': tool_name,
  292. 'img_path': img_path
  293. })
  294. elif category == 'table_cell':
  295. if bbox and len(bbox) >= 4:
  296. parsed_data.append({
  297. 'text': str(text).strip(),
  298. 'bbox': bbox[:4],
  299. 'row': item.get('row', -1),
  300. 'col': item.get('col', -1),
  301. 'category': 'table_cell',
  302. 'confidence': confidence,
  303. 'source_tool': tool_name,
  304. })
  305. else:
  306. # 其他类型,按文本处理, header, table_cell, ...
  307. if text and bbox and len(bbox) >= 4:
  308. parsed_data.append({
  309. 'text': str(text).strip(),
  310. 'bbox': bbox[:4],
  311. 'category': category,
  312. 'confidence': confidence,
  313. 'source_tool': tool_name
  314. })
  315. return parsed_data
  316. def detect_mineru_structure(data: Union[List, Dict]) -> bool:
  317. """检测是否为MinerU数据结构"""
  318. if not isinstance(data, list) or len(data) == 0:
  319. return False
  320. # 检查第一个元素是否包含MinerU特征字段
  321. first_item = data[0] if data else {}
  322. if not isinstance(first_item, dict):
  323. return False
  324. # MinerU特征:包含type字段,且值为text/table/image之一
  325. has_type = 'type' in first_item
  326. has_bbox = 'bbox' in first_item
  327. has_text = 'text' in first_item
  328. if has_type and has_bbox and has_text:
  329. item_type = first_item.get('type', '')
  330. return item_type in ['text', 'table', 'image']
  331. return False
  332. def detect_ocr_tool_type(data: Union[List, Dict], config: Dict) -> str:
  333. """自动检测OCR工具类型 - 增强版"""
  334. if not config['ocr']['auto_detection']['enabled']:
  335. return 'dots_ocr' # 默认类型
  336. rules = config['ocr']['auto_detection']['rules']
  337. for rule in rules:
  338. # 检查字段存在性
  339. if 'field_exists' in rule:
  340. field_name = rule['field_exists']
  341. if isinstance(data, dict) and field_name in data:
  342. return rule['tool_type']
  343. elif isinstance(data, list) and data and isinstance(data[0], dict) and field_name in data[0]:
  344. # 如果是list,检查第一个元素
  345. return rule['tool_type']
  346. # 检查是否为数组
  347. if 'json_is_array' in rule:
  348. if rule['json_is_array'] and isinstance(data, list):
  349. # 进一步区分是dots_ocr还是mineru
  350. if not detect_mineru_structure(data):
  351. return rule['tool_type']
  352. # 默认返回dots_ocr
  353. return 'dots_ocr'
  354. def normalize_ocr_data(raw_data: Union[List, Dict], config: Dict) -> List[Dict]:
  355. """标准化OCR数据 - 支持多种工具"""
  356. tool_type = detect_ocr_tool_type(raw_data, config)
  357. if tool_type == 'dots_ocr':
  358. return parse_dots_ocr_data(raw_data, config, tool_type)
  359. elif tool_type == 'ppstructv3':
  360. return parse_ppstructv3_data(raw_data, config)
  361. elif tool_type == 'table_recognition_v2':
  362. return parse_table_recognition_v2_data(raw_data, config)
  363. elif tool_type == 'mineru':
  364. return parse_mineru_data(raw_data, config, tool_type)
  365. else:
  366. raise ValueError(f"不支持的OCR工具类型: {tool_type}")
  367. def get_rotation_angle_from_ppstructv3(data: Dict) -> float:
  368. """从PPStructV3数据中获取旋转角度"""
  369. if 'doc_preprocessor_res' in data:
  370. doc_res = data['doc_preprocessor_res']
  371. if isinstance(doc_res, dict) and 'angle' in doc_res:
  372. return float(doc_res['angle'])
  373. return 0.0
  374. # 修改 load_ocr_data_file 函数
  375. def load_ocr_data_file(json_path: str, config: Dict) -> Tuple[List, str, str]:
  376. """加载OCR数据文件 - 支持多数据源配置"""
  377. json_file = Path(json_path)
  378. if not json_file.exists():
  379. raise FileNotFoundError(f"找不到JSON文件: {json_path}")
  380. # 加载JSON数据
  381. try:
  382. with open(json_file, 'r', encoding='utf-8') as f:
  383. raw_data = json.load(f)
  384. # 统一数据格式
  385. ocr_data = normalize_ocr_data(raw_data, config)
  386. # 检查是否需要处理图像旋转
  387. rotation_angle = 0.0
  388. if isinstance(raw_data, dict):
  389. rotation_angle = get_rotation_angle_from_ppstructv3(raw_data)
  390. # 如果有旋转角度,记录下来供后续图像处理使用
  391. if rotation_angle != 0:
  392. for item in ocr_data:
  393. item['rotation_angle'] = rotation_angle
  394. except Exception as e:
  395. raise Exception(f"加载JSON文件失败: {e}")
  396. # 加载MD文件
  397. md_file = json_file.with_suffix('.md')
  398. md_content = ""
  399. if md_file.exists():
  400. with open(md_file, 'r', encoding='utf-8') as f:
  401. md_content = f.read()
  402. # ✅ 关键修改:处理MD内容中的所有图片引用
  403. md_content = process_all_images_in_content(md_content, str(json_file))
  404. # 查找对应的图片文件
  405. image_path = find_corresponding_image(json_file, config)
  406. return ocr_data, md_content, image_path
  407. def find_corresponding_image(json_file: Path, config: Dict) -> str:
  408. """查找对应的图片文件 - 支持多数据源"""
  409. # 从配置中获取图片目录
  410. src_img_dir = config.get('paths', {}).get('src_img_dir', '')
  411. if not src_img_dir:
  412. # 如果没有配置图片目录,尝试在JSON文件同级目录查找
  413. src_img_dir = json_file.parent
  414. src_img_path = Path(src_img_dir)
  415. # 支持多种图片格式
  416. image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
  417. for ext in image_extensions:
  418. image_file = src_img_path / f"{json_file.stem}{ext}"
  419. if image_file.exists():
  420. return str(image_file)
  421. # 如果找不到,返回空字符串
  422. return ""
  423. def process_ocr_data(ocr_data: List, config: Dict) -> Dict[str, List]:
  424. """处理OCR数据,建立文本到bbox的映射"""
  425. text_bbox_mapping = {}
  426. exclude_texts = config['ocr']['exclude_texts']
  427. min_text_length = config['ocr']['min_text_length']
  428. if not isinstance(ocr_data, list):
  429. return text_bbox_mapping
  430. for i, item in enumerate(ocr_data):
  431. if not isinstance(item, dict):
  432. continue
  433. text = str(item['text']).strip()
  434. if text and text not in exclude_texts and len(text) >= min_text_length:
  435. bbox = item['bbox']
  436. if isinstance(bbox, list) and len(bbox) == 4:
  437. if text not in text_bbox_mapping:
  438. text_bbox_mapping[text] = []
  439. text_bbox_mapping[text].append({
  440. 'bbox': bbox,
  441. 'category': item.get('category', 'Text'),
  442. 'index': i,
  443. 'confidence': item.get('confidence', config['ocr']['default_confidence']),
  444. 'source_tool': item.get('source_tool', 'unknown'),
  445. 'rotation_angle': item.get('rotation_angle', 0.0) # 添加旋转角度信息
  446. })
  447. return text_bbox_mapping
  448. def find_available_ocr_files(ocr_out_dir: str) -> List[str]:
  449. """查找可用的OCR文件"""
  450. available_files = []
  451. # 搜索多个可能的目录
  452. search_dirs = [
  453. Path(ocr_out_dir),
  454. ]
  455. for search_dir in search_dirs:
  456. if search_dir.exists():
  457. # 递归搜索JSON文件
  458. for json_file in search_dir.rglob("*.json"):
  459. if re.match(r'.*_page_\d+\.json$', json_file.name, re.IGNORECASE):
  460. available_files.append(str(json_file))
  461. # 去重并排序
  462. # available_files = sorted(list(set(available_files)))
  463. # 解析文件名并提取页码信息
  464. file_info = []
  465. for file_path in available_files:
  466. file_name = Path(file_path).stem
  467. # 提取页码 (例如从 "2023年度报告母公司_page_001" 中提取 "001")
  468. if 'page_' in file_name:
  469. try:
  470. page_part = file_name.split('page_')[-1]
  471. page_num = int(page_part)
  472. file_info.append({
  473. 'path': file_path,
  474. 'page': page_num,
  475. 'display_name': f"第{page_num}页"
  476. })
  477. except ValueError:
  478. # 如果无法解析页码,使用文件名
  479. file_info.append({
  480. 'path': file_path,
  481. 'page': len(file_info) + 1,
  482. 'display_name': Path(file_path).stem
  483. })
  484. else:
  485. # 对于没有page_的文件,按顺序编号
  486. file_info.append({
  487. 'path': file_path,
  488. 'page': len(file_info) + 1,
  489. 'display_name': Path(file_path).stem
  490. })
  491. # 按页码排序
  492. file_info.sort(key=lambda x: x['page'])
  493. return file_info
  494. def get_ocr_tool_info(ocr_data: List) -> Dict:
  495. """获取OCR工具信息统计"""
  496. tool_counts = {}
  497. for item in ocr_data:
  498. if isinstance(item, dict):
  499. source_tool = item.get('source_tool', 'unknown')
  500. tool_counts[source_tool] = tool_counts.get(source_tool, 0) + 1
  501. return tool_counts
  502. def get_ocr_statistics(ocr_data: List, text_bbox_mapping: Dict, marked_errors: set) -> Dict:
  503. """获取OCR数据统计信息"""
  504. if not isinstance(ocr_data, list) or not ocr_data:
  505. return {
  506. 'total_texts': 0, 'clickable_texts': 0, 'marked_errors': 0,
  507. 'categories': {}, 'accuracy_rate': 0, 'tool_info': {}
  508. }
  509. total_texts = len(ocr_data)
  510. clickable_texts = len(text_bbox_mapping)
  511. marked_errors_count = len(marked_errors)
  512. # 按类别统计
  513. categories = {}
  514. for item in ocr_data:
  515. if isinstance(item, dict):
  516. category = item.get('category', 'Unknown')
  517. categories[category] = categories.get(category, 0) + 1
  518. # OCR工具信息统计
  519. tool_info = get_ocr_tool_info(ocr_data)
  520. accuracy_rate = (clickable_texts - marked_errors_count) / clickable_texts * 100 if clickable_texts > 0 else 0
  521. return {
  522. 'total_texts': total_texts,
  523. 'clickable_texts': clickable_texts,
  524. 'marked_errors': marked_errors_count,
  525. 'categories': categories,
  526. 'accuracy_rate': accuracy_rate,
  527. 'tool_info': tool_info
  528. }
  529. def group_texts_by_category(text_bbox_mapping: Dict[str, List]) -> Dict[str, List[str]]:
  530. """按类别对文本进行分组"""
  531. categories = {}
  532. for text, info_list in text_bbox_mapping.items():
  533. category = info_list[0]['category']
  534. if category not in categories:
  535. categories[category] = []
  536. categories[category].append(text)
  537. return categories
  538. def get_ocr_tool_rotation_config(ocr_data: List, config: Dict) -> Dict:
  539. """获取OCR工具的旋转配置"""
  540. if not ocr_data or not isinstance(ocr_data, list):
  541. # 默认配置
  542. return {
  543. 'coordinates_are_pre_rotated': False
  544. }
  545. # 从第一个OCR数据项获取工具类型
  546. first_item = ocr_data[0] if ocr_data else {}
  547. source_tool = first_item.get('source_tool', 'dots_ocr')
  548. # 获取工具配置
  549. tools_config = config.get('ocr', {}).get('tools', {})
  550. if source_tool in tools_config:
  551. tool_config = tools_config[source_tool]
  552. return tool_config.get('rotation', {
  553. 'coordinates_are_pre_rotated': False
  554. })
  555. else:
  556. # 默认配置
  557. return {
  558. 'coordinates_are_pre_rotated': False
  559. }
  560. # ocr_validator_utils.py
  561. def find_available_ocr_files_multi_source(config: Dict) -> Dict[str, List[Dict]]:
  562. """查找多个数据源的OCR文件"""
  563. all_sources = {}
  564. for source in config.get('data_sources', []):
  565. source_name = source['name']
  566. ocr_tool = source['ocr_tool']
  567. source_key = f"{source_name}_{ocr_tool}" # 创建唯一标识
  568. ocr_out_dir = source['ocr_out_dir']
  569. if Path(ocr_out_dir).exists():
  570. files = find_available_ocr_files(ocr_out_dir)
  571. # 为每个文件添加数据源信息
  572. for file_info in files:
  573. file_info.update({
  574. 'source_name': source_name,
  575. 'ocr_tool': ocr_tool,
  576. 'description': source.get('description', ''),
  577. 'src_img_dir': source.get('src_img_dir', ''),
  578. 'ocr_out_dir': ocr_out_dir
  579. })
  580. all_sources[source_key] = {
  581. 'files': files,
  582. 'config': source
  583. }
  584. print(f"📁 找到数据源: {source_key} - {len(files)} 个文件")
  585. return all_sources
  586. def get_data_source_display_name(source_config: Dict) -> str:
  587. """生成数据源的显示名称"""
  588. name = source_config['name']
  589. tool = source_config['ocr_tool']
  590. description = source_config.get('description', '')
  591. # 获取工具的友好名称
  592. tool_name_map = {
  593. 'dots_ocr': 'Dots OCR',
  594. 'ppstructv3': 'PPStructV3',
  595. 'table_recognition_v2': 'Table Recognition V2',
  596. 'mineru': 'MinerU VLM-2.5.3'
  597. }
  598. tool_display = tool_name_map.get(tool, tool)
  599. return f"{name} ({tool_display})"
  600. def get_nested_value(data: Dict, path: str, default=None):
  601. if not path:
  602. return default
  603. keys = path.split('.')
  604. value = data
  605. for key in keys:
  606. if isinstance(value, dict) and key in value:
  607. value = value[key]
  608. else:
  609. return default
  610. return value