ocr_validator_utils.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751
  1. """
  2. OCR验证工具的工具函数模块
  3. 包含数据处理、图像处理、统计分析等功能
  4. """
  5. import json
  6. import pandas as pd
  7. import numpy as np
  8. from pathlib import Path
  9. from PIL import Image, ImageDraw
  10. from typing import Dict, List, Optional, Tuple, Union
  11. import re
  12. import yaml
  13. import sys
# Add the ocr_platform root directory to the Python path (needed to import ocr_utils).
# resolve() makes the path absolute; the original author notes that a relative
# path could otherwise lead to an IndexError when indexing ``parents``.
_file_path = Path(__file__).resolve()
ocr_platform_root = _file_path.parents[1]  # ocr_validator -> ocr_platform
if str(ocr_platform_root) not in sys.path:
    sys.path.insert(0, str(ocr_platform_root))
  20. # 从 ocr_utils 导入通用工具
  21. from ocr_utils.html_utils import process_all_images_in_content
  22. from ocr_utils.image_utils import rotate_image_and_coordinates
  23. # rotate_image_and_coordinates 已从 ocr_utils.image_utils 导入,无需重新定义
  24. def load_config(config_path: str = "config.yaml") -> Dict:
  25. """加载配置文件"""
  26. try:
  27. with open(config_path, 'r', encoding='utf-8') as f:
  28. return yaml.safe_load(f)
  29. except Exception as e:
  30. print(f"加载配置文件失败: {e}")
  31. import traceback
  32. traceback.print_exc()
  33. # 退出
  34. sys.exit(1)
  35. # rotate_image_and_coordinates 已从 ocr_utils.image_utils 导入,无需重新定义
  36. def parse_dots_ocr_data(data: List, config: Dict, tool_name: str) -> List[Dict]:
  37. """解析Dots OCR格式的数据"""
  38. tool_config = config['ocr']['tools'][tool_name]
  39. parsed_data = []
  40. for item in data:
  41. if not isinstance(item, dict):
  42. continue
  43. # 提取字段
  44. text = item.get(tool_config['text_field'], '')
  45. bbox = item.get(tool_config['bbox_field'], [])
  46. category = item.get(tool_config['category_field'], 'Text')
  47. confidence = item.get(tool_config.get('confidence_field', 'confidence'),
  48. config['ocr']['default_confidence'])
  49. if text and bbox and len(bbox) >= 4:
  50. parsed_data.append({
  51. 'text': str(text).strip(),
  52. 'bbox': bbox[:4], # 确保只取前4个坐标
  53. 'category': category,
  54. 'confidence': confidence,
  55. 'source_tool': tool_name
  56. })
  57. return parsed_data
  58. def parse_ppstructv3_data(data: Dict, config: Dict) -> List[Dict]:
  59. """解析PPStructV3格式的数据"""
  60. tool_config = config['ocr']['tools']['ppstructv3']
  61. parsed_data = []
  62. parsing_results = data.get(tool_config['parsing_results_field'], [])
  63. if not isinstance(parsing_results, list):
  64. return parsed_data
  65. for item in parsing_results:
  66. if not isinstance(item, dict):
  67. continue
  68. text = item.get(tool_config['text_field'], '')
  69. bbox = item.get(tool_config['bbox_field'], [])
  70. category = item.get(tool_config['category_field'], 'text')
  71. confidence = item.get(
  72. tool_config.get('confidence_field', 'confidence'),
  73. config['ocr']['default_confidence']
  74. )
  75. if text and bbox and len(bbox) >= 4:
  76. parsed_data.append({
  77. 'text': str(text).strip(),
  78. 'bbox': bbox[:4],
  79. 'category': category,
  80. 'confidence': confidence,
  81. 'source_tool': 'ppstructv3'
  82. })
  83. rec_texts = get_nested_value(data, tool_config.get('rec_texts_field', ''))
  84. rec_boxes = get_nested_value(data, tool_config.get('rec_boxes_field', ''))
  85. if isinstance(rec_texts, list) and isinstance(rec_boxes, list):
  86. for i, (text, box) in enumerate(zip(rec_texts, rec_boxes)):
  87. if text and isinstance(box, list) and len(box) >= 4:
  88. parsed_data.append({
  89. 'text': str(text).strip(),
  90. 'bbox': box[:4],
  91. 'category': 'OCR_Text',
  92. 'source_tool': 'ppstructv3_ocr'
  93. })
  94. return parsed_data
  95. def parse_table_recognition_v2_data(data: Dict, config: Dict) -> List[Dict]:
  96. tool_config = config['ocr']['tools']['table_recognition_v2']
  97. parsed_data = []
  98. tables = data.get(tool_config['parsing_results_field'], [])
  99. if not isinstance(tables, list):
  100. return parsed_data
  101. for item in tables:
  102. if not isinstance(item, dict):
  103. continue
  104. html_text = item.get(tool_config['text_field'], '')
  105. # 计算表格整体bbox
  106. cell_boxes_raw = item.get(tool_config['bbox_field'], [])
  107. if cell_boxes_raw:
  108. x1_list = [box[0] for box in cell_boxes_raw]
  109. y1_list = [box[1] for box in cell_boxes_raw]
  110. x2_list = [box[2] for box in cell_boxes_raw]
  111. y2_list = [box[3] for box in cell_boxes_raw]
  112. table_bbox = [
  113. float(min(x1_list)),
  114. float(min(y1_list)),
  115. float(max(x2_list)),
  116. float(max(y2_list))
  117. ]
  118. else:
  119. table_bbox = [0.0, 0.0, 0.0, 0.0]
  120. parsed_data.append({
  121. 'text': str(html_text).strip(),
  122. 'bbox': table_bbox,
  123. 'category': item.get(tool_config.get('category_field', ''), 'table'),
  124. 'confidence': item.get(tool_config.get('confidence_field', ''), config['ocr']['default_confidence']),
  125. 'source_tool': 'table_recognition_v2',
  126. })
  127. rec_texts = get_nested_value(item, tool_config.get('rec_texts_field', ''))
  128. rec_boxes = get_nested_value(item, tool_config.get('rec_boxes_field', ''))
  129. if isinstance(rec_texts, list) and isinstance(rec_boxes, list):
  130. for i, (text, box) in enumerate(zip(rec_texts, rec_boxes)):
  131. if text and isinstance(box, list) and len(box) >= 4:
  132. parsed_data.append({
  133. 'text': str(text).strip(),
  134. 'bbox': box[:4],
  135. 'category': 'OCR_Text',
  136. 'source_tool': 'table_recognition_v2'
  137. })
  138. return parsed_data
  139. def parse_mineru_data(data: List, config: Dict, tool_name="mineru") -> List[Dict]:
  140. """解析MinerU格式的数据"""
  141. tool_config = config['ocr']['tools'][tool_name]
  142. parsed_data = []
  143. if not isinstance(data, list):
  144. return parsed_data
  145. for item in data:
  146. if not isinstance(item, dict):
  147. continue
  148. text = item.get(tool_config['text_field'], '')
  149. bbox = item.get(tool_config['bbox_field'], [])
  150. category = item.get(tool_config['category_field'], 'Text')
  151. confidence = item.get(tool_config.get('confidence_field', 'confidence'),
  152. config['ocr']['default_confidence'])
  153. # 处理文本类型
  154. if category == 'text':
  155. if text and bbox and len(bbox) >= 4:
  156. parsed_data.append({
  157. 'text': str(text).strip(),
  158. 'bbox': bbox[:4],
  159. 'category': category,
  160. 'confidence': confidence,
  161. 'source_tool': tool_name,
  162. 'text_level': item.get('text_level', 0) # 保留文本层级信息
  163. })
  164. # 处理表格类型
  165. elif category == 'table':
  166. table_html = item.get(tool_config.get('table_body_field', 'table_body'), '')
  167. img_path = item.get(tool_config.get('img_path_field', 'img_path'), '')
  168. if bbox and len(bbox) >= 4:
  169. parsed_data.append({
  170. 'text': table_html,
  171. 'bbox': bbox[:4],
  172. 'category': 'table',
  173. 'confidence': confidence,
  174. 'source_tool': tool_name,
  175. 'img_path': img_path,
  176. 'table_body': table_html
  177. })
  178. table_cells = item.get(tool_config.get('table_cells_field', 'table_cells'), [])
  179. for cell in table_cells:
  180. cell_text = cell.get('text', '')
  181. cell_bbox = cell.get('bbox', [])
  182. if cell_text and cell_bbox and len(cell_bbox) >= 4:
  183. parsed_data.append({
  184. 'text': str(cell_text).strip(),
  185. 'matched_text': cell.get('matched_text', ''),
  186. 'bbox': cell_bbox[:4],
  187. 'row': cell.get('row', -1),
  188. 'col': cell.get('col', -1),
  189. 'category': 'table_cell',
  190. 'confidence': cell.get('score', 0.0),
  191. 'source_tool': tool_name,
  192. })
  193. # 处理图片类型
  194. elif category == 'image':
  195. img_path = item.get(tool_config.get('img_path_field', 'img_path'), '')
  196. if bbox and len(bbox) >= 4:
  197. parsed_data.append({
  198. 'text': '[Image]',
  199. 'bbox': bbox[:4],
  200. 'category': 'image',
  201. 'confidence': confidence,
  202. 'source_tool': tool_name,
  203. 'img_path': img_path
  204. })
  205. elif category in ['list']:
  206. # 处理列表和标题类型
  207. list_items = item.get('list_items', [])
  208. sub_type = item.get('sub_type', 'unordered') # 有序或无序
  209. for list_item in list_items:
  210. if list_item and bbox and len(bbox) >= 4:
  211. parsed_data.append({
  212. 'text': str(list_item).strip(),
  213. 'bbox': bbox[:4],
  214. 'category': category,
  215. 'sub_type': sub_type,
  216. 'confidence': confidence,
  217. 'source_tool': tool_name
  218. })
  219. else:
  220. # 其他类型,按文本处理, header, table_cell, ...
  221. if text and bbox and len(bbox) >= 4:
  222. parsed_data.append({
  223. 'text': str(text).strip(),
  224. 'bbox': bbox[:4],
  225. 'category': category,
  226. 'confidence': confidence,
  227. 'source_tool': tool_name
  228. })
  229. return parsed_data
  230. def detect_mineru_structure(data: Union[List, Dict]) -> bool:
  231. """检测是否为MinerU数据结构"""
  232. if not isinstance(data, list) or len(data) == 0:
  233. return False
  234. # 检查第一个元素是否包含MinerU特征字段
  235. first_item = data[0] if data else {}
  236. if not isinstance(first_item, dict):
  237. return False
  238. # MinerU特征:包含type字段,且值为text/table/image之一
  239. has_type = 'type' in first_item
  240. has_bbox = 'bbox' in first_item
  241. has_text = 'text' in first_item
  242. if has_type and has_bbox and has_text:
  243. item_type = first_item.get('type', '')
  244. return item_type in ['text', 'table', 'image']
  245. return False
  246. def detect_ocr_tool_type(data: Union[List, Dict], config: Dict) -> str:
  247. """
  248. 自动检测OCR工具类型 - 增强版
  249. Args:
  250. data: OCR数据(可能是列表或字典)
  251. config: 配置字典
  252. Returns:
  253. 工具类型字符串
  254. """
  255. if not config['ocr']['auto_detection']['enabled']:
  256. return 'mineru' # 默认类型
  257. rules = config['ocr']['auto_detection']['rules']
  258. # 按优先级排序
  259. sorted_rules = sorted(rules, key=lambda x: x.get('priority', 999))
  260. for rule in sorted_rules:
  261. tool_type = rule['tool_type']
  262. conditions = rule.get('conditions', [])
  263. # 检查所有条件是否满足
  264. if _check_all_conditions(data, conditions):
  265. return tool_type
  266. # 如果所有规则都不匹配,返回默认类型
  267. return 'dots_ocr'
  268. def _check_all_conditions(data: Union[List, Dict], conditions: List[Dict]) -> bool:
  269. """
  270. 检查所有条件是否满足
  271. Args:
  272. data: 数据
  273. conditions: 条件列表
  274. Returns:
  275. 是否所有条件都满足
  276. """
  277. for condition in conditions:
  278. condition_type = condition.get('type', '')
  279. if condition_type == 'field_exists':
  280. # 检查字段存在
  281. field = condition.get('field', '')
  282. if not _check_field_exists(data, field):
  283. return False
  284. elif condition_type == 'field_not_exists':
  285. # 检查字段不存在
  286. field = condition.get('field', '')
  287. if _check_field_exists(data, field):
  288. return False
  289. elif condition_type == 'json_structure':
  290. # 检查JSON结构类型
  291. expected_structure = condition.get('structure', '')
  292. if expected_structure == 'array' and not isinstance(data, list):
  293. return False
  294. elif expected_structure == 'object' and not isinstance(data, dict):
  295. return False
  296. elif condition_type == 'field_value':
  297. # 检查字段值
  298. field = condition.get('field', '')
  299. expected_value = condition.get('value')
  300. actual_value = _get_field_value(data, field)
  301. if actual_value != expected_value:
  302. return False
  303. elif condition_type == 'field_contains':
  304. # 检查字段包含某个值
  305. field = condition.get('field', '')
  306. expected_values = condition.get('values', [])
  307. actual_value = _get_field_value(data, field)
  308. if actual_value not in expected_values:
  309. return False
  310. return True
  311. def _check_field_exists(data: Union[List, Dict], field_path: str) -> bool:
  312. """
  313. 检查字段是否存在(支持嵌套路径)
  314. Args:
  315. data: 数据
  316. field_path: 字段路径(支持点分隔,如 "doc_preprocessor_res.angle")
  317. Returns:
  318. 字段是否存在
  319. """
  320. if not field_path:
  321. return False
  322. # 处理数组情况:检查第一个元素
  323. if isinstance(data, list):
  324. if not data:
  325. return False
  326. data = data[0]
  327. # 处理嵌套字段路径
  328. fields = field_path.split('.')
  329. current = data
  330. for field in fields:
  331. if isinstance(current, dict) and field in current:
  332. current = current[field]
  333. else:
  334. return False
  335. return True
  336. def _get_field_value(data: Union[List, Dict], field_path: str):
  337. """
  338. 获取字段值(支持嵌套路径)
  339. Args:
  340. data: 数据
  341. field_path: 字段路径
  342. Returns:
  343. 字段值,如果不存在返回 None
  344. """
  345. if not field_path:
  346. return None
  347. # 处理数组情况:检查第一个元素
  348. if isinstance(data, list):
  349. if not data:
  350. return None
  351. data = data[0]
  352. # 处理嵌套字段路径
  353. fields = field_path.split('.')
  354. current = data
  355. for field in fields:
  356. if isinstance(current, dict) and field in current:
  357. current = current[field]
  358. else:
  359. return None
  360. return current
  361. def normalize_ocr_data(raw_data: Union[List, Dict], config: Dict) -> List[Dict]:
  362. """标准化OCR数据 - 支持多种工具"""
  363. tool_type = detect_ocr_tool_type(raw_data, config)
  364. if tool_type == 'dots_ocr':
  365. return parse_dots_ocr_data(raw_data, config, tool_type)
  366. elif tool_type == 'ppstructv3':
  367. return parse_ppstructv3_data(raw_data, config)
  368. elif tool_type == 'table_recognition_v2':
  369. return parse_table_recognition_v2_data(raw_data, config)
  370. elif tool_type == 'mineru':
  371. return parse_mineru_data(raw_data, config, tool_type)
  372. else:
  373. raise ValueError(f"不支持的OCR工具类型: {tool_type}")
  374. def get_rotation_angle_from_ppstructv3(data: Dict) -> float:
  375. """从PPStructV3数据中获取旋转角度"""
  376. if 'doc_preprocessor_res' in data:
  377. doc_res = data['doc_preprocessor_res']
  378. if isinstance(doc_res, dict) and 'angle' in doc_res:
  379. return float(doc_res['angle'])
  380. return 0.0
  381. # 修改 load_ocr_data_file 函数
  382. def load_ocr_data_file(json_path: str, config: Dict) -> Tuple[List, str, str]:
  383. """加载OCR数据文件 - 支持多数据源配置"""
  384. json_file = Path(json_path)
  385. if not json_file.exists():
  386. raise FileNotFoundError(f"找不到JSON文件: {json_path}")
  387. # 加载JSON数据
  388. try:
  389. with open(json_file, 'r', encoding='utf-8') as f:
  390. raw_data = json.load(f)
  391. # 统一数据格式
  392. ocr_data = normalize_ocr_data(raw_data, config)
  393. # 检查是否需要处理图像旋转
  394. rotation_angle = 0.0
  395. if isinstance(raw_data, dict):
  396. rotation_angle = get_rotation_angle_from_ppstructv3(raw_data)
  397. # 如果有旋转角度,记录下来供后续图像处理使用
  398. if rotation_angle != 0:
  399. for item in ocr_data:
  400. item['rotation_angle'] = rotation_angle
  401. except Exception as e:
  402. raise Exception(f"加载JSON文件失败: {e}")
  403. # 加载MD文件
  404. md_file = json_file.with_suffix('.md')
  405. md_content = ""
  406. if md_file.exists():
  407. with open(md_file, 'r', encoding='utf-8') as f:
  408. md_content = f.read()
  409. # ✅ 关键修改:处理MD内容中的所有图片引用
  410. md_content = process_all_images_in_content(md_content, str(json_file))
  411. # 查找对应的图片文件
  412. image_path = find_corresponding_image(json_file, config)
  413. return ocr_data, md_content, image_path
  414. def find_corresponding_image(json_file: Path, config: Dict) -> str:
  415. """查找对应的图片文件 - 支持多数据源"""
  416. # 从配置中获取图片目录
  417. src_img_dir = config.get('paths', {}).get('src_img_dir', '')
  418. if not src_img_dir:
  419. # 如果没有配置图片目录,尝试在JSON文件同级目录查找
  420. src_img_dir = json_file.parent
  421. src_img_path = Path(src_img_dir)
  422. # 支持多种图片格式
  423. image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
  424. for ext in image_extensions:
  425. image_file = src_img_path / f"{json_file.stem}{ext}"
  426. if image_file.exists():
  427. return str(image_file)
  428. # 如果找不到,返回空字符串
  429. return ""
  430. def process_ocr_data(ocr_data: List, config: Dict) -> Dict[str, List]:
  431. """处理OCR数据,建立文本到bbox的映射"""
  432. text_bbox_mapping = {}
  433. exclude_texts = config['ocr']['exclude_texts']
  434. min_text_length = config['ocr']['min_text_length']
  435. if not isinstance(ocr_data, list):
  436. return text_bbox_mapping
  437. for i, item in enumerate(ocr_data):
  438. if not isinstance(item, dict):
  439. continue
  440. text = str(item['text']).strip()
  441. if text and text not in exclude_texts and len(text) >= min_text_length:
  442. bbox = item['bbox']
  443. if isinstance(bbox, list) and len(bbox) == 4:
  444. if text not in text_bbox_mapping:
  445. text_bbox_mapping[text] = []
  446. text_bbox_mapping[text].append({
  447. 'matched_text': item.get('matched_text', ''),
  448. 'bbox': bbox,
  449. 'category': item.get('category', 'Text'),
  450. 'index': i,
  451. 'confidence': item.get('confidence', config['ocr']['default_confidence']),
  452. 'source_tool': item.get('source_tool', 'unknown'),
  453. 'rotation_angle': item.get('rotation_angle', 0.0) # 添加旋转角度信息
  454. })
  455. return text_bbox_mapping
  456. def find_available_ocr_files(ocr_out_dir: str) -> List[str]:
  457. """查找可用的OCR文件"""
  458. available_files = []
  459. # 搜索多个可能的目录
  460. search_dirs = [
  461. Path(ocr_out_dir),
  462. ]
  463. for search_dir in search_dirs:
  464. if search_dir.exists():
  465. # 递归搜索JSON文件
  466. for json_file in search_dir.rglob("*.json"):
  467. if re.match(r'.*_page_\d+\.json$', json_file.name, re.IGNORECASE):
  468. available_files.append(str(json_file))
  469. # 去重并排序
  470. # available_files = sorted(list(set(available_files)))
  471. # 解析文件名并提取页码信息
  472. file_info = []
  473. for file_path in available_files:
  474. file_name = Path(file_path).stem
  475. # 提取页码 (例如从 "2023年度报告母公司_page_001" 中提取 "001")
  476. if 'page_' in file_name:
  477. try:
  478. page_part = file_name.split('page_')[-1]
  479. page_num = int(page_part)
  480. file_info.append({
  481. 'path': file_path,
  482. 'page': page_num,
  483. 'display_name': f"第{page_num}页"
  484. })
  485. except ValueError:
  486. # 如果无法解析页码,使用文件名
  487. file_info.append({
  488. 'path': file_path,
  489. 'page': len(file_info) + 1,
  490. 'display_name': Path(file_path).stem
  491. })
  492. else:
  493. # 对于没有page_的文件,按顺序编号
  494. file_info.append({
  495. 'path': file_path,
  496. 'page': len(file_info) + 1,
  497. 'display_name': Path(file_path).stem
  498. })
  499. # 按页码排序
  500. file_info.sort(key=lambda x: x['page'])
  501. return file_info
  502. def get_ocr_tool_info(ocr_data: List) -> Dict:
  503. """获取OCR工具信息统计"""
  504. tool_counts = {}
  505. for item in ocr_data:
  506. if isinstance(item, dict):
  507. source_tool = item.get('source_tool', 'unknown')
  508. tool_counts[source_tool] = tool_counts.get(source_tool, 0) + 1
  509. return tool_counts
  510. def get_ocr_statistics(ocr_data: List, text_bbox_mapping: Dict, marked_errors: set) -> Dict:
  511. """获取OCR数据统计信息"""
  512. if not isinstance(ocr_data, list) or not ocr_data:
  513. return {
  514. 'total_texts': 0, 'clickable_texts': 0, 'marked_errors': 0,
  515. 'categories': {}, 'accuracy_rate': 0, 'tool_info': {}
  516. }
  517. total_texts = len(ocr_data)
  518. clickable_texts = len(text_bbox_mapping)
  519. marked_errors_count = len(marked_errors)
  520. # 按类别统计
  521. categories = {}
  522. for item in ocr_data:
  523. if isinstance(item, dict):
  524. category = item.get('category', 'Unknown')
  525. categories[category] = categories.get(category, 0) + 1
  526. # OCR工具信息统计
  527. tool_info = get_ocr_tool_info(ocr_data)
  528. accuracy_rate = (clickable_texts - marked_errors_count) / clickable_texts * 100 if clickable_texts > 0 else 0
  529. return {
  530. 'total_texts': total_texts,
  531. 'clickable_texts': clickable_texts,
  532. 'marked_errors': marked_errors_count,
  533. 'categories': categories,
  534. 'accuracy_rate': accuracy_rate,
  535. 'tool_info': tool_info
  536. }
  537. def group_texts_by_category(text_bbox_mapping: Dict[str, List]) -> Dict[str, List[str]]:
  538. """按类别对文本进行分组"""
  539. categories = {}
  540. for text, info_list in text_bbox_mapping.items():
  541. category = info_list[0]['category']
  542. if category not in categories:
  543. categories[category] = []
  544. categories[category].append(text)
  545. return categories
  546. def get_ocr_tool_rotation_config(ocr_data: List, config: Dict) -> Dict:
  547. """获取OCR工具的旋转配置"""
  548. if not ocr_data or not isinstance(ocr_data, list):
  549. # 默认配置
  550. return {
  551. 'coordinates_are_pre_rotated': False
  552. }
  553. # 从第一个OCR数据项获取工具类型
  554. first_item = ocr_data[0] if ocr_data else {}
  555. source_tool = first_item.get('source_tool', 'dots_ocr')
  556. # 获取工具配置
  557. tools_config = config.get('ocr', {}).get('tools', {})
  558. if source_tool in tools_config:
  559. tool_config = tools_config[source_tool]
  560. return tool_config.get('rotation', {
  561. 'coordinates_are_pre_rotated': False
  562. })
  563. else:
  564. # 默认配置
  565. return {
  566. 'coordinates_are_pre_rotated': False
  567. }
  568. # ocr_validator_utils.py
  569. def find_available_ocr_files_multi_source(config: Dict) -> Dict[str, List[Dict]]:
  570. """查找多个数据源的OCR文件"""
  571. all_sources = {}
  572. for source in config.get('data_sources', []):
  573. source_name = source['name']
  574. ocr_tool = source['ocr_tool']
  575. source_key = f"{source_name}"
  576. ocr_out_dir = source['ocr_out_dir']
  577. if Path(ocr_out_dir).exists():
  578. files = find_available_ocr_files(ocr_out_dir)
  579. # 为每个文件添加数据源信息
  580. for file_info in files:
  581. file_info.update({
  582. 'source_name': source_name,
  583. 'ocr_tool': ocr_tool,
  584. 'description': source.get('description', ''),
  585. 'src_img_dir': source.get('src_img_dir', ''),
  586. 'ocr_out_dir': ocr_out_dir
  587. })
  588. all_sources[source_key] = {
  589. 'files': files,
  590. 'config': source
  591. }
  592. print(f"📁 找到数据源: {source_key} - {len(files)} 个文件")
  593. return all_sources
  594. def get_data_source_display_name(source_config: Dict) -> str:
  595. """生成数据源的显示名称"""
  596. name = source_config['name']
  597. tool = source_config['ocr_tool']
  598. description = source_config.get('description', '')
  599. # 获取工具的友好名称
  600. tool_name_map = {
  601. 'dots_ocr': 'Dots OCR',
  602. 'ppstructv3': 'PPStructV3',
  603. 'table_recognition_v2': 'Table Recognition V2',
  604. 'mineru': 'MinerU VLM-2.5.3'
  605. }
  606. tool_display = tool_name_map.get(tool, tool)
  607. return f"{name} ({tool_display})"
  608. def get_nested_value(data: Dict, path: str, default=None):
  609. if not path:
  610. return default
  611. keys = path.split('.')
  612. value = data
  613. for key in keys:
  614. if isinstance(value, dict) and key in value:
  615. value = value[key]
  616. else:
  617. return default
  618. return value