ocr_validator_utils.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879
  1. """
  2. OCR验证工具的工具函数模块
  3. 包含数据处理、图像处理、统计分析等功能
  4. """
  5. import json
  6. import pandas as pd
  7. import numpy as np
  8. from pathlib import Path
  9. from PIL import Image, ImageDraw
  10. from typing import Dict, List, Optional, Tuple, Union
  11. import re
  12. import yaml
  13. import sys
  14. from ocr_validator_file_utils import process_all_images_in_content
  15. def load_config(config_path: str = "config.yaml") -> Dict:
  16. """加载配置文件"""
  17. try:
  18. with open(config_path, 'r', encoding='utf-8') as f:
  19. return yaml.safe_load(f)
  20. except Exception as e:
  21. print(f"加载配置文件失败: {e}")
  22. import traceback
  23. traceback.print_exc()
  24. # 退出
  25. sys.exit(1)
  26. def rotate_image_and_coordinates(
  27. image: Image.Image,
  28. angle: float,
  29. coordinates_list: List[List[int]],
  30. rotate_coordinates: bool = True
  31. ) -> Tuple[Image.Image, List[List[int]]]:
  32. """
  33. 根据角度旋转图像和坐标 - 修正版本
  34. Args:
  35. image: 原始图像
  36. angle: 旋转角度(度数)
  37. coordinates_list: 坐标列表,每个坐标为[x1, y1, x2, y2]格式
  38. rotate_coordinates: 是否需要旋转坐标(针对不同OCR工具的处理方式)
  39. Returns:
  40. rotated_image: 旋转后的图像
  41. rotated_coordinates: 处理后的坐标列表
  42. """
  43. if angle == 0:
  44. return image, coordinates_list
  45. # 标准化旋转角度
  46. if angle == 270:
  47. rotation_angle = -90 # 顺时针90度
  48. elif angle == 90:
  49. rotation_angle = 90 # 逆时针90度
  50. elif angle == 180:
  51. rotation_angle = 180 # 180度
  52. else:
  53. rotation_angle = angle
  54. # 旋转图像
  55. rotated_image = image.rotate(rotation_angle, expand=True)
  56. # 如果不需要旋转坐标,直接返回原坐标
  57. if not rotate_coordinates:
  58. return rotated_image, coordinates_list
  59. # 获取原始和旋转后的图像尺寸
  60. orig_width, orig_height = image.size
  61. new_width, new_height = rotated_image.size
  62. # 计算旋转后的坐标
  63. rotated_coordinates = []
  64. for coord in coordinates_list:
  65. if len(coord) < 4:
  66. rotated_coordinates.append(coord)
  67. continue
  68. x1, y1, x2, y2 = coord[:4]
  69. # 验证原始坐标是否有效
  70. if x1 < 0 or y1 < 0 or x2 <= x1 or y2 <= y1:
  71. print(f"警告: 无效坐标 {coord}")
  72. rotated_coordinates.append([0, 0, 50, 50]) # 使用默认坐标
  73. continue
  74. # 根据旋转角度变换坐标
  75. if rotation_angle == -90: # 顺时针90度 (270度逆时针)
  76. # 变换公式: (x, y) -> (orig_height - y, x)
  77. new_x1 = orig_height - y2 # 这里是y2
  78. new_y1 = x1
  79. new_x2 = orig_height - y1 # 这里是y1
  80. new_y2 = x2
  81. elif rotation_angle == 90: # 逆时针90度
  82. # 变换公式: (x, y) -> (y, orig_width - x)
  83. new_x1 = y1
  84. new_y1 = orig_width - x2 # 这里是x2
  85. new_x2 = y2
  86. new_y2 = orig_width - x1 # 这里是x1
  87. elif rotation_angle == 180: # 180度
  88. # 变换公式: (x, y) -> (orig_width - x, orig_height - y)
  89. new_x1 = orig_width - x2
  90. new_y1 = orig_height - y2
  91. new_x2 = orig_width - x1
  92. new_y2 = orig_height - y1
  93. else: # 任意角度算法 - 修正版本
  94. # 将角度转换为弧度
  95. angle_rad = np.radians(rotation_angle)
  96. cos_angle = np.cos(angle_rad)
  97. sin_angle = np.sin(angle_rad)
  98. # 原图像中心点
  99. orig_center_x = orig_width / 2
  100. orig_center_y = orig_height / 2
  101. # 旋转后图像中心点
  102. new_center_x = new_width / 2
  103. new_center_y = new_height / 2
  104. # 将bbox的四个角点转换为相对于原图像中心的坐标
  105. corners = [
  106. (x1 - orig_center_x, y1 - orig_center_y), # 左上角
  107. (x2 - orig_center_x, y1 - orig_center_y), # 右上角
  108. (x2 - orig_center_x, y2 - orig_center_y), # 右下角
  109. (x1 - orig_center_x, y2 - orig_center_y) # 左下角
  110. ]
  111. # 应用修正后的旋转矩阵变换每个角点
  112. rotated_corners = []
  113. for x, y in corners:
  114. # 修正后的旋转矩阵: [cos(θ) sin(θ)] [x]
  115. # [-sin(θ) cos(θ)] [y]
  116. rotated_x = x * cos_angle + y * sin_angle
  117. rotated_y = -x * sin_angle + y * cos_angle
  118. # 转换回绝对坐标(相对于新图像)
  119. abs_x = rotated_x + new_center_x
  120. abs_y = rotated_y + new_center_y
  121. rotated_corners.append((abs_x, abs_y))
  122. # 从旋转后的四个角点计算新的边界框
  123. x_coords = [corner[0] for corner in rotated_corners]
  124. y_coords = [corner[1] for corner in rotated_corners]
  125. new_x1 = int(min(x_coords))
  126. new_y1 = int(min(y_coords))
  127. new_x2 = int(max(x_coords))
  128. new_y2 = int(max(y_coords))
  129. # 确保坐标在有效范围内
  130. new_x1 = max(0, min(new_width, new_x1))
  131. new_y1 = max(0, min(new_height, new_y1))
  132. new_x2 = max(0, min(new_width, new_x2))
  133. new_y2 = max(0, min(new_height, new_y2))
  134. # 确保x1 < x2, y1 < y2
  135. if new_x1 > new_x2:
  136. new_x1, new_x2 = new_x2, new_x1
  137. if new_y1 > new_y2:
  138. new_y1, new_y2 = new_y2, new_y1
  139. rotated_coordinates.append([new_x1, new_y1, new_x2, new_y2])
  140. return rotated_image, rotated_coordinates
  141. def parse_dots_ocr_data(data: List, config: Dict, tool_name: str) -> List[Dict]:
  142. """解析Dots OCR格式的数据"""
  143. tool_config = config['ocr']['tools'][tool_name]
  144. parsed_data = []
  145. for item in data:
  146. if not isinstance(item, dict):
  147. continue
  148. # 提取字段
  149. text = item.get(tool_config['text_field'], '')
  150. bbox = item.get(tool_config['bbox_field'], [])
  151. category = item.get(tool_config['category_field'], 'Text')
  152. confidence = item.get(tool_config.get('confidence_field', 'confidence'),
  153. config['ocr']['default_confidence'])
  154. if text and bbox and len(bbox) >= 4:
  155. parsed_data.append({
  156. 'text': str(text).strip(),
  157. 'bbox': bbox[:4], # 确保只取前4个坐标
  158. 'category': category,
  159. 'confidence': confidence,
  160. 'source_tool': tool_name
  161. })
  162. return parsed_data
  163. def parse_ppstructv3_data(data: Dict, config: Dict) -> List[Dict]:
  164. """解析PPStructV3格式的数据"""
  165. tool_config = config['ocr']['tools']['ppstructv3']
  166. parsed_data = []
  167. parsing_results = data.get(tool_config['parsing_results_field'], [])
  168. if not isinstance(parsing_results, list):
  169. return parsed_data
  170. for item in parsing_results:
  171. if not isinstance(item, dict):
  172. continue
  173. text = item.get(tool_config['text_field'], '')
  174. bbox = item.get(tool_config['bbox_field'], [])
  175. category = item.get(tool_config['category_field'], 'text')
  176. confidence = item.get(
  177. tool_config.get('confidence_field', 'confidence'),
  178. config['ocr']['default_confidence']
  179. )
  180. if text and bbox and len(bbox) >= 4:
  181. parsed_data.append({
  182. 'text': str(text).strip(),
  183. 'bbox': bbox[:4],
  184. 'category': category,
  185. 'confidence': confidence,
  186. 'source_tool': 'ppstructv3'
  187. })
  188. rec_texts = get_nested_value(data, tool_config.get('rec_texts_field', ''))
  189. rec_boxes = get_nested_value(data, tool_config.get('rec_boxes_field', ''))
  190. if isinstance(rec_texts, list) and isinstance(rec_boxes, list):
  191. for i, (text, box) in enumerate(zip(rec_texts, rec_boxes)):
  192. if text and isinstance(box, list) and len(box) >= 4:
  193. parsed_data.append({
  194. 'text': str(text).strip(),
  195. 'bbox': box[:4],
  196. 'category': 'OCR_Text',
  197. 'source_tool': 'ppstructv3_ocr'
  198. })
  199. return parsed_data
  200. def parse_table_recognition_v2_data(data: Dict, config: Dict) -> List[Dict]:
  201. tool_config = config['ocr']['tools']['table_recognition_v2']
  202. parsed_data = []
  203. tables = data.get(tool_config['parsing_results_field'], [])
  204. if not isinstance(tables, list):
  205. return parsed_data
  206. for item in tables:
  207. if not isinstance(item, dict):
  208. continue
  209. html_text = item.get(tool_config['text_field'], '')
  210. # 计算表格整体bbox
  211. cell_boxes_raw = item.get(tool_config['bbox_field'], [])
  212. if cell_boxes_raw:
  213. x1_list = [box[0] for box in cell_boxes_raw]
  214. y1_list = [box[1] for box in cell_boxes_raw]
  215. x2_list = [box[2] for box in cell_boxes_raw]
  216. y2_list = [box[3] for box in cell_boxes_raw]
  217. table_bbox = [
  218. float(min(x1_list)),
  219. float(min(y1_list)),
  220. float(max(x2_list)),
  221. float(max(y2_list))
  222. ]
  223. else:
  224. table_bbox = [0.0, 0.0, 0.0, 0.0]
  225. parsed_data.append({
  226. 'text': str(html_text).strip(),
  227. 'bbox': table_bbox,
  228. 'category': item.get(tool_config.get('category_field', ''), 'table'),
  229. 'confidence': item.get(tool_config.get('confidence_field', ''), config['ocr']['default_confidence']),
  230. 'source_tool': 'table_recognition_v2',
  231. })
  232. rec_texts = get_nested_value(item, tool_config.get('rec_texts_field', ''))
  233. rec_boxes = get_nested_value(item, tool_config.get('rec_boxes_field', ''))
  234. if isinstance(rec_texts, list) and isinstance(rec_boxes, list):
  235. for i, (text, box) in enumerate(zip(rec_texts, rec_boxes)):
  236. if text and isinstance(box, list) and len(box) >= 4:
  237. parsed_data.append({
  238. 'text': str(text).strip(),
  239. 'bbox': box[:4],
  240. 'category': 'OCR_Text',
  241. 'source_tool': 'table_recognition_v2'
  242. })
  243. return parsed_data
  244. def parse_mineru_data(data: List, config: Dict, tool_name="mineru") -> List[Dict]:
  245. """解析MinerU格式的数据"""
  246. tool_config = config['ocr']['tools'][tool_name]
  247. parsed_data = []
  248. if not isinstance(data, list):
  249. return parsed_data
  250. for item in data:
  251. if not isinstance(item, dict):
  252. continue
  253. text = item.get(tool_config['text_field'], '')
  254. bbox = item.get(tool_config['bbox_field'], [])
  255. category = item.get(tool_config['category_field'], 'Text')
  256. confidence = item.get(tool_config.get('confidence_field', 'confidence'),
  257. config['ocr']['default_confidence'])
  258. # 处理文本类型
  259. if category == 'text':
  260. if text and bbox and len(bbox) >= 4:
  261. parsed_data.append({
  262. 'text': str(text).strip(),
  263. 'bbox': bbox[:4],
  264. 'category': category,
  265. 'confidence': confidence,
  266. 'source_tool': tool_name,
  267. 'text_level': item.get('text_level', 0) # 保留文本层级信息
  268. })
  269. # 处理表格类型
  270. elif category == 'table':
  271. table_html = item.get(tool_config.get('table_body_field', 'table_body'), '')
  272. img_path = item.get(tool_config.get('img_path_field', 'img_path'), '')
  273. if bbox and len(bbox) >= 4:
  274. parsed_data.append({
  275. 'text': table_html,
  276. 'bbox': bbox[:4],
  277. 'category': 'table',
  278. 'confidence': confidence,
  279. 'source_tool': tool_name,
  280. 'img_path': img_path,
  281. 'table_body': table_html
  282. })
  283. table_cells = item.get(tool_config.get('table_cells_field', 'table_cells'), [])
  284. for cell in table_cells:
  285. cell_text = cell.get('text', '')
  286. cell_bbox = cell.get('bbox', [])
  287. if cell_text and cell_bbox and len(cell_bbox) >= 4:
  288. parsed_data.append({
  289. 'text': str(cell_text).strip(),
  290. 'bbox': cell_bbox[:4],
  291. 'row': cell.get('row', -1),
  292. 'col': cell.get('col', -1),
  293. 'category': 'table_cell',
  294. 'confidence': cell.get('score', 0.0),
  295. 'source_tool': tool_name,
  296. })
  297. # 处理图片类型
  298. elif category == 'image':
  299. img_path = item.get(tool_config.get('img_path_field', 'img_path'), '')
  300. if bbox and len(bbox) >= 4:
  301. parsed_data.append({
  302. 'text': '[Image]',
  303. 'bbox': bbox[:4],
  304. 'category': 'image',
  305. 'confidence': confidence,
  306. 'source_tool': tool_name,
  307. 'img_path': img_path
  308. })
  309. elif category in ['list']:
  310. # 处理列表和标题类型
  311. list_items = item.get('list_items', [])
  312. sub_type = item.get('sub_type', 'unordered') # 有序或无序
  313. for list_item in list_items:
  314. if list_item and bbox and len(bbox) >= 4:
  315. parsed_data.append({
  316. 'text': str(list_item).strip(),
  317. 'bbox': bbox[:4],
  318. 'category': category,
  319. 'sub_type': sub_type,
  320. 'confidence': confidence,
  321. 'source_tool': tool_name
  322. })
  323. else:
  324. # 其他类型,按文本处理, header, table_cell, ...
  325. if text and bbox and len(bbox) >= 4:
  326. parsed_data.append({
  327. 'text': str(text).strip(),
  328. 'bbox': bbox[:4],
  329. 'category': category,
  330. 'confidence': confidence,
  331. 'source_tool': tool_name
  332. })
  333. return parsed_data
  334. def detect_mineru_structure(data: Union[List, Dict]) -> bool:
  335. """检测是否为MinerU数据结构"""
  336. if not isinstance(data, list) or len(data) == 0:
  337. return False
  338. # 检查第一个元素是否包含MinerU特征字段
  339. first_item = data[0] if data else {}
  340. if not isinstance(first_item, dict):
  341. return False
  342. # MinerU特征:包含type字段,且值为text/table/image之一
  343. has_type = 'type' in first_item
  344. has_bbox = 'bbox' in first_item
  345. has_text = 'text' in first_item
  346. if has_type and has_bbox and has_text:
  347. item_type = first_item.get('type', '')
  348. return item_type in ['text', 'table', 'image']
  349. return False
  350. def detect_ocr_tool_type(data: Union[List, Dict], config: Dict) -> str:
  351. """
  352. 自动检测OCR工具类型 - 增强版
  353. Args:
  354. data: OCR数据(可能是列表或字典)
  355. config: 配置字典
  356. Returns:
  357. 工具类型字符串
  358. """
  359. if not config['ocr']['auto_detection']['enabled']:
  360. return 'mineru' # 默认类型
  361. rules = config['ocr']['auto_detection']['rules']
  362. # 按优先级排序
  363. sorted_rules = sorted(rules, key=lambda x: x.get('priority', 999))
  364. for rule in sorted_rules:
  365. tool_type = rule['tool_type']
  366. conditions = rule.get('conditions', [])
  367. # 检查所有条件是否满足
  368. if _check_all_conditions(data, conditions):
  369. return tool_type
  370. # 如果所有规则都不匹配,返回默认类型
  371. return 'dots_ocr'
  372. def _check_all_conditions(data: Union[List, Dict], conditions: List[Dict]) -> bool:
  373. """
  374. 检查所有条件是否满足
  375. Args:
  376. data: 数据
  377. conditions: 条件列表
  378. Returns:
  379. 是否所有条件都满足
  380. """
  381. for condition in conditions:
  382. condition_type = condition.get('type', '')
  383. if condition_type == 'field_exists':
  384. # 检查字段存在
  385. field = condition.get('field', '')
  386. if not _check_field_exists(data, field):
  387. return False
  388. elif condition_type == 'field_not_exists':
  389. # 检查字段不存在
  390. field = condition.get('field', '')
  391. if _check_field_exists(data, field):
  392. return False
  393. elif condition_type == 'json_structure':
  394. # 检查JSON结构类型
  395. expected_structure = condition.get('structure', '')
  396. if expected_structure == 'array' and not isinstance(data, list):
  397. return False
  398. elif expected_structure == 'object' and not isinstance(data, dict):
  399. return False
  400. elif condition_type == 'field_value':
  401. # 检查字段值
  402. field = condition.get('field', '')
  403. expected_value = condition.get('value')
  404. actual_value = _get_field_value(data, field)
  405. if actual_value != expected_value:
  406. return False
  407. elif condition_type == 'field_contains':
  408. # 检查字段包含某个值
  409. field = condition.get('field', '')
  410. expected_values = condition.get('values', [])
  411. actual_value = _get_field_value(data, field)
  412. if actual_value not in expected_values:
  413. return False
  414. return True
  415. def _check_field_exists(data: Union[List, Dict], field_path: str) -> bool:
  416. """
  417. 检查字段是否存在(支持嵌套路径)
  418. Args:
  419. data: 数据
  420. field_path: 字段路径(支持点分隔,如 "doc_preprocessor_res.angle")
  421. Returns:
  422. 字段是否存在
  423. """
  424. if not field_path:
  425. return False
  426. # 处理数组情况:检查第一个元素
  427. if isinstance(data, list):
  428. if not data:
  429. return False
  430. data = data[0]
  431. # 处理嵌套字段路径
  432. fields = field_path.split('.')
  433. current = data
  434. for field in fields:
  435. if isinstance(current, dict) and field in current:
  436. current = current[field]
  437. else:
  438. return False
  439. return True
  440. def _get_field_value(data: Union[List, Dict], field_path: str):
  441. """
  442. 获取字段值(支持嵌套路径)
  443. Args:
  444. data: 数据
  445. field_path: 字段路径
  446. Returns:
  447. 字段值,如果不存在返回 None
  448. """
  449. if not field_path:
  450. return None
  451. # 处理数组情况:检查第一个元素
  452. if isinstance(data, list):
  453. if not data:
  454. return None
  455. data = data[0]
  456. # 处理嵌套字段路径
  457. fields = field_path.split('.')
  458. current = data
  459. for field in fields:
  460. if isinstance(current, dict) and field in current:
  461. current = current[field]
  462. else:
  463. return None
  464. return current
  465. def normalize_ocr_data(raw_data: Union[List, Dict], config: Dict) -> List[Dict]:
  466. """标准化OCR数据 - 支持多种工具"""
  467. tool_type = detect_ocr_tool_type(raw_data, config)
  468. if tool_type == 'dots_ocr':
  469. return parse_dots_ocr_data(raw_data, config, tool_type)
  470. elif tool_type == 'ppstructv3':
  471. return parse_ppstructv3_data(raw_data, config)
  472. elif tool_type == 'table_recognition_v2':
  473. return parse_table_recognition_v2_data(raw_data, config)
  474. elif tool_type == 'mineru':
  475. return parse_mineru_data(raw_data, config, tool_type)
  476. else:
  477. raise ValueError(f"不支持的OCR工具类型: {tool_type}")
  478. def get_rotation_angle_from_ppstructv3(data: Dict) -> float:
  479. """从PPStructV3数据中获取旋转角度"""
  480. if 'doc_preprocessor_res' in data:
  481. doc_res = data['doc_preprocessor_res']
  482. if isinstance(doc_res, dict) and 'angle' in doc_res:
  483. return float(doc_res['angle'])
  484. return 0.0
  485. # 修改 load_ocr_data_file 函数
  486. def load_ocr_data_file(json_path: str, config: Dict) -> Tuple[List, str, str]:
  487. """加载OCR数据文件 - 支持多数据源配置"""
  488. json_file = Path(json_path)
  489. if not json_file.exists():
  490. raise FileNotFoundError(f"找不到JSON文件: {json_path}")
  491. # 加载JSON数据
  492. try:
  493. with open(json_file, 'r', encoding='utf-8') as f:
  494. raw_data = json.load(f)
  495. # 统一数据格式
  496. ocr_data = normalize_ocr_data(raw_data, config)
  497. # 检查是否需要处理图像旋转
  498. rotation_angle = 0.0
  499. if isinstance(raw_data, dict):
  500. rotation_angle = get_rotation_angle_from_ppstructv3(raw_data)
  501. # 如果有旋转角度,记录下来供后续图像处理使用
  502. if rotation_angle != 0:
  503. for item in ocr_data:
  504. item['rotation_angle'] = rotation_angle
  505. except Exception as e:
  506. raise Exception(f"加载JSON文件失败: {e}")
  507. # 加载MD文件
  508. md_file = json_file.with_suffix('.md')
  509. md_content = ""
  510. if md_file.exists():
  511. with open(md_file, 'r', encoding='utf-8') as f:
  512. md_content = f.read()
  513. # ✅ 关键修改:处理MD内容中的所有图片引用
  514. md_content = process_all_images_in_content(md_content, str(json_file))
  515. # 查找对应的图片文件
  516. image_path = find_corresponding_image(json_file, config)
  517. return ocr_data, md_content, image_path
  518. def find_corresponding_image(json_file: Path, config: Dict) -> str:
  519. """查找对应的图片文件 - 支持多数据源"""
  520. # 从配置中获取图片目录
  521. src_img_dir = config.get('paths', {}).get('src_img_dir', '')
  522. if not src_img_dir:
  523. # 如果没有配置图片目录,尝试在JSON文件同级目录查找
  524. src_img_dir = json_file.parent
  525. src_img_path = Path(src_img_dir)
  526. # 支持多种图片格式
  527. image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
  528. for ext in image_extensions:
  529. image_file = src_img_path / f"{json_file.stem}{ext}"
  530. if image_file.exists():
  531. return str(image_file)
  532. # 如果找不到,返回空字符串
  533. return ""
  534. def process_ocr_data(ocr_data: List, config: Dict) -> Dict[str, List]:
  535. """处理OCR数据,建立文本到bbox的映射"""
  536. text_bbox_mapping = {}
  537. exclude_texts = config['ocr']['exclude_texts']
  538. min_text_length = config['ocr']['min_text_length']
  539. if not isinstance(ocr_data, list):
  540. return text_bbox_mapping
  541. for i, item in enumerate(ocr_data):
  542. if not isinstance(item, dict):
  543. continue
  544. text = str(item['text']).strip()
  545. if text and text not in exclude_texts and len(text) >= min_text_length:
  546. bbox = item['bbox']
  547. if isinstance(bbox, list) and len(bbox) == 4:
  548. if text not in text_bbox_mapping:
  549. text_bbox_mapping[text] = []
  550. text_bbox_mapping[text].append({
  551. 'bbox': bbox,
  552. 'category': item.get('category', 'Text'),
  553. 'index': i,
  554. 'confidence': item.get('confidence', config['ocr']['default_confidence']),
  555. 'source_tool': item.get('source_tool', 'unknown'),
  556. 'rotation_angle': item.get('rotation_angle', 0.0) # 添加旋转角度信息
  557. })
  558. return text_bbox_mapping
  559. def find_available_ocr_files(ocr_out_dir: str) -> List[str]:
  560. """查找可用的OCR文件"""
  561. available_files = []
  562. # 搜索多个可能的目录
  563. search_dirs = [
  564. Path(ocr_out_dir),
  565. ]
  566. for search_dir in search_dirs:
  567. if search_dir.exists():
  568. # 递归搜索JSON文件
  569. for json_file in search_dir.rglob("*.json"):
  570. if re.match(r'.*_page_\d+\.json$', json_file.name, re.IGNORECASE):
  571. available_files.append(str(json_file))
  572. # 去重并排序
  573. # available_files = sorted(list(set(available_files)))
  574. # 解析文件名并提取页码信息
  575. file_info = []
  576. for file_path in available_files:
  577. file_name = Path(file_path).stem
  578. # 提取页码 (例如从 "2023年度报告母公司_page_001" 中提取 "001")
  579. if 'page_' in file_name:
  580. try:
  581. page_part = file_name.split('page_')[-1]
  582. page_num = int(page_part)
  583. file_info.append({
  584. 'path': file_path,
  585. 'page': page_num,
  586. 'display_name': f"第{page_num}页"
  587. })
  588. except ValueError:
  589. # 如果无法解析页码,使用文件名
  590. file_info.append({
  591. 'path': file_path,
  592. 'page': len(file_info) + 1,
  593. 'display_name': Path(file_path).stem
  594. })
  595. else:
  596. # 对于没有page_的文件,按顺序编号
  597. file_info.append({
  598. 'path': file_path,
  599. 'page': len(file_info) + 1,
  600. 'display_name': Path(file_path).stem
  601. })
  602. # 按页码排序
  603. file_info.sort(key=lambda x: x['page'])
  604. return file_info
  605. def get_ocr_tool_info(ocr_data: List) -> Dict:
  606. """获取OCR工具信息统计"""
  607. tool_counts = {}
  608. for item in ocr_data:
  609. if isinstance(item, dict):
  610. source_tool = item.get('source_tool', 'unknown')
  611. tool_counts[source_tool] = tool_counts.get(source_tool, 0) + 1
  612. return tool_counts
  613. def get_ocr_statistics(ocr_data: List, text_bbox_mapping: Dict, marked_errors: set) -> Dict:
  614. """获取OCR数据统计信息"""
  615. if not isinstance(ocr_data, list) or not ocr_data:
  616. return {
  617. 'total_texts': 0, 'clickable_texts': 0, 'marked_errors': 0,
  618. 'categories': {}, 'accuracy_rate': 0, 'tool_info': {}
  619. }
  620. total_texts = len(ocr_data)
  621. clickable_texts = len(text_bbox_mapping)
  622. marked_errors_count = len(marked_errors)
  623. # 按类别统计
  624. categories = {}
  625. for item in ocr_data:
  626. if isinstance(item, dict):
  627. category = item.get('category', 'Unknown')
  628. categories[category] = categories.get(category, 0) + 1
  629. # OCR工具信息统计
  630. tool_info = get_ocr_tool_info(ocr_data)
  631. accuracy_rate = (clickable_texts - marked_errors_count) / clickable_texts * 100 if clickable_texts > 0 else 0
  632. return {
  633. 'total_texts': total_texts,
  634. 'clickable_texts': clickable_texts,
  635. 'marked_errors': marked_errors_count,
  636. 'categories': categories,
  637. 'accuracy_rate': accuracy_rate,
  638. 'tool_info': tool_info
  639. }
  640. def group_texts_by_category(text_bbox_mapping: Dict[str, List]) -> Dict[str, List[str]]:
  641. """按类别对文本进行分组"""
  642. categories = {}
  643. for text, info_list in text_bbox_mapping.items():
  644. category = info_list[0]['category']
  645. if category not in categories:
  646. categories[category] = []
  647. categories[category].append(text)
  648. return categories
  649. def get_ocr_tool_rotation_config(ocr_data: List, config: Dict) -> Dict:
  650. """获取OCR工具的旋转配置"""
  651. if not ocr_data or not isinstance(ocr_data, list):
  652. # 默认配置
  653. return {
  654. 'coordinates_are_pre_rotated': False
  655. }
  656. # 从第一个OCR数据项获取工具类型
  657. first_item = ocr_data[0] if ocr_data else {}
  658. source_tool = first_item.get('source_tool', 'dots_ocr')
  659. # 获取工具配置
  660. tools_config = config.get('ocr', {}).get('tools', {})
  661. if source_tool in tools_config:
  662. tool_config = tools_config[source_tool]
  663. return tool_config.get('rotation', {
  664. 'coordinates_are_pre_rotated': False
  665. })
  666. else:
  667. # 默认配置
  668. return {
  669. 'coordinates_are_pre_rotated': False
  670. }
  671. # ocr_validator_utils.py
  672. def find_available_ocr_files_multi_source(config: Dict) -> Dict[str, List[Dict]]:
  673. """查找多个数据源的OCR文件"""
  674. all_sources = {}
  675. for source in config.get('data_sources', []):
  676. source_name = source['name']
  677. ocr_tool = source['ocr_tool']
  678. source_key = f"{source_name}_{ocr_tool}" # 创建唯一标识
  679. ocr_out_dir = source['ocr_out_dir']
  680. if Path(ocr_out_dir).exists():
  681. files = find_available_ocr_files(ocr_out_dir)
  682. # 为每个文件添加数据源信息
  683. for file_info in files:
  684. file_info.update({
  685. 'source_name': source_name,
  686. 'ocr_tool': ocr_tool,
  687. 'description': source.get('description', ''),
  688. 'src_img_dir': source.get('src_img_dir', ''),
  689. 'ocr_out_dir': ocr_out_dir
  690. })
  691. all_sources[source_key] = {
  692. 'files': files,
  693. 'config': source
  694. }
  695. print(f"📁 找到数据源: {source_key} - {len(files)} 个文件")
  696. return all_sources
  697. def get_data_source_display_name(source_config: Dict) -> str:
  698. """生成数据源的显示名称"""
  699. name = source_config['name']
  700. tool = source_config['ocr_tool']
  701. description = source_config.get('description', '')
  702. # 获取工具的友好名称
  703. tool_name_map = {
  704. 'dots_ocr': 'Dots OCR',
  705. 'ppstructv3': 'PPStructV3',
  706. 'table_recognition_v2': 'Table Recognition V2',
  707. 'mineru': 'MinerU VLM-2.5.3'
  708. }
  709. tool_display = tool_name_map.get(tool, tool)
  710. return f"{name} ({tool_display})"
  711. def get_nested_value(data: Dict, path: str, default=None):
  712. if not path:
  713. return default
  714. keys = path.split('.')
  715. value = data
  716. for key in keys:
  717. if isinstance(value, dict) and key in value:
  718. value = value[key]
  719. else:
  720. return default
  721. return value