omnidocbench_eval.py

# zhch/omnidocbench_eval_fixed.py
import json
import time
from pathlib import Path
from typing import List, Dict, Any, Tuple
import cv2
import numpy as np
from paddlex import create_pipeline
import os
import glob
import traceback


class OmniDocBenchEvaluator:
    """
    OmniDocBench evaluator (fixed version) that produces results in the benchmark's format.

    An explicit config file can also be used, e.g.
    pipeline_config_path = "paddlex/configs/pipelines/PP-StructureV3.yaml"
    """

    def __init__(self, pipeline_config_path: str = "PP-StructureV3"):
        """
        Initialize the evaluator.

        Args:
            pipeline_config_path: PaddleX pipeline name or config file path
        """
        self.pipeline = create_pipeline(pipeline=pipeline_config_path)
        self.category_mapping = self._get_category_mapping()

    def _get_category_mapping(self) -> Dict[str, str]:
        """Mapping from PaddleX categories to OmniDocBench categories."""
        return {
            # PaddleX -> OmniDocBench category mapping
            'title': 'title',
            'text': 'text_block',
            'figure': 'figure',
            'figure_caption': 'figure_caption',
            'table': 'table',
            'table_caption': 'table_caption',
            'equation': 'equation_isolated',
            'header': 'header',
            'footer': 'footer',
            'reference': 'reference',
            'seal': 'abandon',  # seals are usually treated as the "abandon" class
            'number': 'page_number',
            # add further mappings here as needed
        }

    def evaluate_single_image(self, image_path: str,
                              use_gpu: bool = True,
                              **kwargs) -> Dict[str, Any]:
        """
        Evaluate a single image.

        Args:
            image_path: image path
            use_gpu: whether to run on GPU
            **kwargs: extra pipeline arguments

        Returns:
            A result dict in OmniDocBench format, or None on failure.
        """
        print(f"Processing image: {image_path}")
        # Read the image to get its size
        image = cv2.imread(image_path)
        if image is None:
            print(f"Failed to read image: {image_path}")
            return None
        height, width = image.shape[:2]
        # Run the PaddleX pipeline
        start_time = time.time()
        try:
            output = list(self.pipeline.predict(
                input=image_path,
                device="gpu" if use_gpu else "cpu",
                use_doc_orientation_classify=True,
                use_doc_unwarping=False,
                use_seal_recognition=True,
                use_chart_recognition=True,
                use_table_recognition=True,
                use_formula_recognition=True,
                **kwargs
            ))
        except Exception as e:
            print(f"Error while processing image {image_path}: {str(e)}")
            traceback.print_exc()
            return None
        process_time = time.time() - start_time
        print(f"Processing took {process_time:.2f} s")
        # Convert to OmniDocBench format
        result = self._convert_to_omnidocbench_format(
            output, image_path, width, height
        )
        return result

    def _convert_to_omnidocbench_format(self,
                                        paddlex_output: List,
                                        image_path: str,
                                        width: int,
                                        height: int) -> Dict[str, Any]:
        """
        Convert PaddleX output to OmniDocBench format.

        Args:
            paddlex_output: list of PaddleX result objects
            image_path: image path
            width: image width
            height: image height

        Returns:
            A result dict in OmniDocBench format.
        """
        layout_dets = []
        anno_id_counter = 0
        # Walk through the PaddleX output
        for res in paddlex_output:
            res_json = res.json.get('res', {})
            # Layout information lives in parsing_res_list
            parsing_list = res_json.get('parsing_res_list', [])
            for item in parsing_list:
                # Extract bounding box and category
                bbox = item.get('block_bbox', [])
                category = item.get('block_label', 'text_block')
                content = item.get('block_content', '')
                # Convert bbox format: [x1, y1, x2, y2] -> [x1, y1, x2, y1, x2, y2, x1, y2]
                if len(bbox) == 4:
                    x1, y1, x2, y2 = bbox
                    poly = [x1, y1, x2, y1, x2, y2, x1, y2]
                else:
                    poly = bbox
                # Map the category; unknown labels fall back to 'text_block'
                omni_category = self.category_mapping.get(category, 'text_block')
                # Build the layout detection entry
                layout_det = {
                    "category_type": omni_category,
                    "poly": poly,
                    "ignore": False,
                    "order": anno_id_counter,
                    "anno_id": anno_id_counter,
                }
                # Attach the recognized content
                if content and content.strip():
                    if omni_category == 'table':
                        # Table content is stored as HTML
                        layout_det["html"] = content
                    else:
                        # Everything else is stored as plain text
                        layout_det["text"] = content.strip()
                # Attach span-level annotations extracted from the OCR result
                layout_det["line_with_spans"] = self._extract_spans_from_ocr(
                    res_json, bbox, omni_category
                )
                # Attach attribute labels
                layout_det["attribute"] = self._extract_attributes(item, omni_category)
                layout_dets.append(layout_det)
                anno_id_counter += 1
        # Assemble the full record
        result = {
            "layout_dets": layout_dets,
            "page_info": {
                "page_no": 0,
                "height": height,
                "width": width,
                "image_path": Path(image_path).name,
                "page_attribute": self._extract_page_attributes(paddlex_output)
            },
            "extra": {
                "relation": []  # relation info; would have to be extracted case by case
            }
        }
        return result
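
    # Illustrative shape of one record produced above (values are placeholders,
    # not real output):
    #
    # {
    #   "layout_dets": [
    #     {"category_type": "text_block",
    #      "poly": [100, 200, 500, 200, 500, 260, 100, 260],
    #      "ignore": False, "order": 0, "anno_id": 0,
    #      "text": "...", "line_with_spans": [...], "attribute": {...}}
    #   ],
    #   "page_info": {"page_no": 0, "height": 1600, "width": 1200,
    #                 "image_path": "page.png", "page_attribute": {...}},
    #   "extra": {"relation": []}
    # }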
    def _extract_spans_from_ocr(self, res, block_bbox: List, category: str) -> List[Dict]:
        """Extract span-level annotations from the OCR result."""
        spans = []
        # If an OCR result is present, pick out the text lines belonging to this block
        ocr_res = res.get('overall_ocr_res', None)
        if ocr_res:
            texts = ocr_res.get('rec_texts', [])
            boxes = ocr_res.get('rec_boxes', [])
            scores = ocr_res.get('rec_scores', [1.0] * len(texts))
            # Keep only the OCR results that fall inside the current block
            if len(block_bbox) == 4:
                x1, y1, x2, y2 = block_bbox
                for text, box, score in zip(texts, boxes, scores):
                    if len(box) >= 4:
                        # Check whether the OCR box lies inside the block
                        ocr_x1, ocr_y1, ocr_x2, ocr_y2 = box[:4]
                        # Simple containment check
                        if (ocr_x1 >= x1 and ocr_y1 >= y1 and
                                ocr_x2 <= x2 and ocr_y2 <= y2):
                            span = {
                                "category_type": "text_span",
                                "poly": [ocr_x1, ocr_y1, ocr_x2, ocr_y1,
                                         ocr_x2, ocr_y2, ocr_x1, ocr_y2],
                                "ignore": False,
                                "text": text,
                            }
                            # Very low-confidence spans are marked as ignored
                            if score < 0.5:
                                span["ignore"] = True
                            spans.append(span)
        return spans
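
    # A hedged alternative to the strict containment test above (not wired in):
    # OCR lines that cross the block border by a few pixels are dropped by the
    # exact check; an area-overlap ratio with a threshold (e.g. 0.7) is a softer
    # criterion. Sketch only; the threshold and helper name are assumptions.
    @staticmethod
    def _overlap_ratio(inner_box: List[float], outer_box: List[float]) -> float:
        """Fraction of inner_box's area that lies inside outer_box."""
        ix1, iy1, ix2, iy2 = inner_box[:4]
        ox1, oy1, ox2, oy2 = outer_box[:4]
        inter_w = max(0.0, min(ix2, ox2) - max(ix1, ox1))
        inter_h = max(0.0, min(iy2, oy2) - max(iy1, oy1))
        inner_area = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
        return (inter_w * inter_h) / inner_area if inner_area > 0 else 0.0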
    def _extract_attributes(self, item: Dict, category: str) -> Dict:
        """Extract attribute labels for a block."""
        attributes = {}
        # Different categories carry different attributes
        if category == 'table':
            # Table attributes
            attributes.update({
                "table_layout": "vertical",              # would need to be detected per table
                "with_span": False,                      # whether merged cells are present
                "line": "full_line",                     # border style; would need to be detected
                "language": "table_simplified_chinese",  # would need language detection
                "include_equation": False,
                "include_backgroud": False,
                "table_vertical": False
            })
            # Check the table content for merged cells
            content = item.get('block_content', '')
            if 'colspan' in content or 'rowspan' in content:
                attributes["with_span"] = True
        elif category in ['text_block', 'title']:
            # Text attributes
            attributes.update({
                "text_language": "text_simplified_chinese",
                "text_background": "white",
                "text_rotate": "normal"
            })
        elif 'equation' in category:
            # Formula attributes
            attributes.update({
                "formula_type": "print"
            })
        return attributes

    def _extract_page_attributes(self, paddlex_output) -> Dict:
        """Extract page-level attributes."""
        return {
            "data_source": "research_report",  # would need to be determined per document
            "language": "simplified_chinese",
            "layout": "single_column",
            "watermark": False,
            "fuzzy_scan": False,
            "colorful_backgroud": False
        }
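
    # Illustrative only: the page attributes above are hard-coded. A rough
    # language guess could instead be derived from the recognized text, e.g.
    # from the share of CJK characters; the 0.3 threshold is an assumption.
    @staticmethod
    def _guess_page_language(texts: List[str]) -> str:
        """Guess 'simplified_chinese' vs 'english' from recognized text lines."""
        joined = "".join(texts)
        if not joined:
            return "simplified_chinese"
        cjk = sum(1 for ch in joined if '\u4e00' <= ch <= '\u9fff')
        return "simplified_chinese" if cjk / len(joined) > 0.3 else "english"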
    def load_existing_result(self, result_path: str) -> Dict[str, Any]:
        """
        Load an existing PaddleX result file and convert it.

        Args:
            result_path: path to the PaddleX result JSON file

        Returns:
            A result dict in OmniDocBench format, or None on failure.
        """
        if not Path(result_path).exists():
            print(f"Result file does not exist: {result_path}")
            return None
        try:
            with open(result_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Pull the image info out of the result file
            input_path = data.get('input_path', '')
            # Read the image to get its size
            if input_path and Path(input_path).exists():
                image = cv2.imread(input_path)
                if image is None:
                    print(f"Failed to read image: {input_path}")
                    height, width = 1600, 1200
                    image_name = "unknown.png"
                else:
                    height, width = image.shape[:2]
                    image_name = Path(input_path).name
            else:
                # Fall back to defaults if the image path does not exist
                height, width = 1600, 1200
                image_name = "unknown.png"
            # Convert the format
            result = self._convert_paddlex_result_to_omnidocbench(
                data, image_name, width, height
            )
            return result
        except Exception as e:
            print(f"Error while loading result file {result_path}: {str(e)}")
            traceback.print_exc()
            return None

    def load_omnidocbench_dataset(self, dataset_dir: str) -> Dict[str, Any]:
        """
        Load the OmniDocBench dataset.

        Args:
            dataset_dir: dataset directory

        Returns:
            The loaded dataset as a dict.
        """
        dataset = {}
        # Walk the dataset directory
        for file_path in Path(dataset_dir).rglob('*.json'):
            if "pred" in file_path.name or "result" in file_path.name:
                continue
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                # Use the file stem as the sample ID
                file_id = file_path.stem
                dataset[file_id] = {
                    "ground_truth": data,
                    "predictions": None,
                    "image_path": self._find_image_file(file_path.parent, file_id)
                }
            except Exception as e:
                print(f"Error while loading file {file_path}: {str(e)}")
        return dataset

    def _find_image_file(self, search_dir: Path, file_id: str) -> str:
        """Find the image file that matches a sample ID."""
        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']
        for ext in image_extensions:
            image_path = search_dir / f"{file_id}{ext}"
            if image_path.exists():
                return str(image_path)
        # If nothing matches here, try the subdirectories
        for subdir in search_dir.iterdir():
            if subdir.is_dir():
                result = self._find_image_file(subdir, file_id)
                if result:
                    return result
        return ""

    def generate_predictions(self, dataset: Dict[str, Any]) -> Dict[str, Any]:
        """
        Generate predictions for the dataset.

        Args:
            dataset: dataset dict

        Returns:
            The dataset dict with predictions filled in.
        """
        for file_id, item in dataset.items():
            if item["image_path"] and Path(item["image_path"]).exists():
                try:
                    result = self.evaluate_single_image(item["image_path"])
                    dataset[file_id]["predictions"] = result
                    print(f"Processed file: {file_id}")
                except Exception as e:
                    print(f"Error while processing file {file_id}: {str(e)}")
            else:
                print(f"Image file does not exist: {item['image_path']}")
        return dataset

    def evaluate_dataset(self, dataset: Dict[str, Any]) -> Dict[str, float]:
        """
        Evaluate the predictions over the dataset.

        Args:
            dataset: dataset dict containing predictions and ground truth

        Returns:
            A dict of evaluation metrics.
        """
        metrics = {
            "precision": 0.0,
            "recall": 0.0,
            "f1_score": 0.0,
            "iou": 0.0
        }
        # The real evaluation logic goes here; this is only a skeleton
        total_precision = 0.0
        total_recall = 0.0
        total_f1 = 0.0
        total_iou = 0.0
        count = 0
        for file_id, item in dataset.items():
            if item["predictions"] is None:
                continue
            # Per-sample metrics
            sample_metrics = self._calculate_sample_metrics(item["ground_truth"], item["predictions"])
            total_precision += sample_metrics["precision"]
            total_recall += sample_metrics["recall"]
            total_f1 += sample_metrics["f1_score"]
            total_iou += sample_metrics["iou"]
            count += 1
        if count > 0:
            metrics["precision"] = total_precision / count
            metrics["recall"] = total_recall / count
            metrics["f1_score"] = total_f1 / count
            metrics["iou"] = total_iou / count
        return metrics

    def _calculate_sample_metrics(self, ground_truth: Dict, prediction: Dict) -> Dict[str, float]:
        """
        Compute evaluation metrics for a single sample.

        Args:
            ground_truth: ground-truth annotation
            prediction: predicted result

        Returns:
            A dict of evaluation metrics.
        """
        metrics = {
            "precision": 0.0,
            "recall": 0.0,
            "f1_score": 0.0,
            "iou": 0.0
        }
        # The real evaluation logic goes here; this is only a skeleton
        gt_layouts = ground_truth.get("layout_dets", [])
        pred_layouts = prediction.get("layout_dets", [])
        # Simple matching over category sets
        gt_categories = set(item.get("category_type", "") for item in gt_layouts)
        pred_categories = set(item.get("category_type", "") for item in pred_layouts)
        # Intersection over union of the category sets
        intersection = len(gt_categories.intersection(pred_categories))
        union = len(gt_categories.union(pred_categories))
        if union > 0:
            metrics["iou"] = intersection / union
        # Precision, recall and F1 over the category sets
        if len(pred_categories) > 0:
            metrics["precision"] = intersection / len(pred_categories)
        if len(gt_categories) > 0:
            metrics["recall"] = intersection / len(gt_categories)
        if metrics["precision"] + metrics["recall"] > 0:
            metrics["f1_score"] = 2 * (metrics["precision"] * metrics["recall"]) / (metrics["precision"] + metrics["recall"])
        return metrics
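
    # Illustrative sketch (not called anywhere): the category-set overlap above
    # ignores geometry. A box-level match based on IoU between the 8-value
    # polygons would be closer to a real layout-detection metric; the greedy
    # matching and any IoU threshold built on top of this would still be
    # assumptions, not the official OmniDocBench protocol.
    @staticmethod
    def _poly_iou(poly_a: List[float], poly_b: List[float]) -> float:
        """Axis-aligned IoU of two polygons given as [x1, y1, ..., x4, y4]."""
        ax1, ay1, ax2, ay2 = min(poly_a[0::2]), min(poly_a[1::2]), max(poly_a[0::2]), max(poly_a[1::2])
        bx1, by1, bx2, by2 = min(poly_b[0::2]), min(poly_b[1::2]), max(poly_b[0::2]), max(poly_b[1::2])
        inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
        inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
        inter = inter_w * inter_h
        union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
        return inter / union if union > 0 else 0.0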
    def _convert_paddlex_result_to_omnidocbench(self,
                                                paddlex_result: Dict,
                                                image_name: str,
                                                width: int,
                                                height: int) -> Dict[str, Any]:
        """
        Convert an existing PaddleX result to OmniDocBench format.
        """
        layout_dets = []
        anno_id_counter = 0
        # Layout information lives in parsing_res_list
        parsing_list = paddlex_result.get('parsing_res_list', [])
        for item in parsing_list:
            # Extract bounding box and category
            bbox = item.get('block_bbox', [])
            category = item.get('block_label', 'text_block')
            content = item.get('block_content', '')
            # Convert bbox format
            if len(bbox) == 4:
                x1, y1, x2, y2 = bbox
                poly = [x1, y1, x2, y1, x2, y2, x1, y2]
            else:
                poly = bbox
            # Map the category
            omni_category = self.category_mapping.get(category, 'text_block')
            # Build the layout detection entry
            layout_det = {
                "category_type": omni_category,
                "poly": poly,
                "ignore": False,
                "order": anno_id_counter,
                "anno_id": anno_id_counter,
            }
            # Attach the content
            if content and content.strip():
                if omni_category == 'table':
                    layout_det["html"] = content
                else:
                    layout_det["text"] = content.strip()
            # Attach attributes
            layout_det["attribute"] = self._extract_attributes(item, omni_category)
            layout_det["line_with_spans"] = []  # simplified: no span extraction here
            layout_dets.append(layout_det)
            anno_id_counter += 1
        # Assemble the full record
        result = {
            "layout_dets": layout_dets,
            "page_info": {
                "page_no": 0,
                "height": height,
                "width": width,
                "image_path": image_name,
                "page_attribute": {
                    "data_source": "research_report",
                    "language": "simplified_chinese",
                    "layout": "single_column",
                    "watermark": False,
                    "fuzzy_scan": False,
                    "colorful_backgroud": False
                }
            },
            "extra": {
                "relation": []
            }
        }
        return result


def convert_existing_results():
    """Convert an existing PaddleX result file."""
    try:
        evaluator = OmniDocBenchEvaluator()
        # Example: convert a single result file
        result_file = "./sample_data/single_pipeline_output/PP-StructureV3-zhch/300674-母公司现金流量表-扫描_res.json"
        if Path(result_file).exists():
            print(f"Converting result file: {result_file}")
            omnidocbench_result = evaluator.load_existing_result(result_file)
            if omnidocbench_result is None:
                print(f"Conversion produced no result: {result_file}")
                return
            # Save the converted result
            output_file = "./omnidocbench_converted_result.json"
            try:
                with open(output_file, 'w', encoding='utf-8') as f:
                    json.dump([omnidocbench_result], f, ensure_ascii=False, indent=2)
                print(f"Conversion finished, result saved to: {output_file}")
                print(f"Number of detected layout elements: {len(omnidocbench_result['layout_dets'])}")
                # List the detected elements
                for i, item in enumerate(omnidocbench_result['layout_dets']):
                    print(f"  {i+1}. {item['category_type']}: {item.get('text', item.get('html', ''))[:50]}...")
            except Exception as e:
                print(f"Error while saving result to {output_file}: {str(e)}")
                traceback.print_exc()
        else:
            print(f"Result file does not exist: {result_file}")
    except Exception as e:
        print(f"Fatal error during conversion: {str(e)}")
        traceback.print_exc()


def process_omnidocbench_dataset():
    """Process the OmniDocBench dataset."""
    try:
        # Initialize the evaluator
        evaluator = OmniDocBenchEvaluator()
        # Dataset paths
        dataset_path = "/home/ubuntu/zhch/OmniDocBench/OpenDataLab___OmniDocBench"
        result_dir = "/home/ubuntu/zhch/PaddleX/zhch/OmniDocBench_Result"
        # Make sure the result directory exists
        os.makedirs(result_dir, exist_ok=True)
        # Find all image files
        image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff']
        image_files = []
        for ext in image_extensions:
            image_files.extend(glob.glob(os.path.join(dataset_path, '**', ext), recursive=True))
        print(f"Found {len(image_files)} image files")
        if not image_files:
            print("No image files found, aborting")
            return
        # Collect all results
        all_results = []
        # Process each image
        for i, image_path in enumerate(image_files[:10]):  # only the first 10 files, for testing
            try:
                print(f"Progress: {i+1}/{len(image_files[:10])}")
                # Process a single image
                result = evaluator.evaluate_single_image(image_path)
                if result is not None:
                    all_results.append(result)
            except Exception as e:
                print(f"Error while processing file {image_path}: {str(e)}")
                traceback.print_exc()
                continue
        # Save the results
        output_file = os.path.join(result_dir, "OmniDocBench-PPStructureV3.json")
        try:
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(all_results, f, ensure_ascii=False, indent=2)
            print(f"Done, results saved to: {output_file}")
            print(f"Processed {len(all_results)} files")
        except Exception as e:
            print(f"Error while saving the result file: {str(e)}")
            traceback.print_exc()
    except Exception as e:
        print(f"Fatal error while processing the dataset: {str(e)}")
        traceback.print_exc()


if __name__ == "__main__":
    # convert_existing_results()
    process_omnidocbench_dataset()
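
# Example usage from another script (illustrative; paths are placeholders and
# the module name depends on the actual file name):
#
#   from omnidocbench_eval import OmniDocBenchEvaluator
#
#   evaluator = OmniDocBenchEvaluator("PP-StructureV3")
#   record = evaluator.evaluate_single_image("path/to/page.png", use_gpu=False)
#   if record is not None:
#       print(len(record["layout_dets"]))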