omnidocbench_eval.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. # zhch/omnidocbench_eval_fixed.py
  2. import json
  3. import time
  4. from pathlib import Path
  5. from typing import List, Dict, Any, Tuple
  6. import cv2
  7. import numpy as np
  8. from paddlex import create_pipeline
  9. class OmniDocBenchEvaluator:
  10. """OmniDocBench评估器(修正版),用于生成符合评测格式的结果"""
  11. def __init__(self, pipeline_config_path: str = "./PP-StructureV3-zhch.yaml"):
  12. """
  13. 初始化评估器
  14. Args:
  15. pipeline_config_path: PaddleX pipeline配置文件路径
  16. """
  17. self.pipeline = create_pipeline(pipeline=pipeline_config_path)
  18. self.category_mapping = self._get_category_mapping()
  19. def _get_category_mapping(self) -> Dict[str, str]:
  20. """获取PaddleX类别到OmniDocBench类别的映射"""
  21. return {
  22. # PaddleX -> OmniDocBench 类别映射
  23. 'title': 'title',
  24. 'text': 'text_block',
  25. 'figure': 'figure',
  26. 'figure_caption': 'figure_caption',
  27. 'table': 'table',
  28. 'table_caption': 'table_caption',
  29. 'equation': 'equation_isolated',
  30. 'header': 'header',
  31. 'footer': 'footer',
  32. 'reference': 'reference',
  33. 'seal': 'abandon', # 印章通常作为舍弃类
  34. 'number': 'page_number',
  35. # 添加更多映射关系
  36. }
  37. def evaluate_single_image(self, image_path: str,
  38. use_gpu: bool = True,
  39. **kwargs) -> Dict[str, Any]:
  40. """
  41. 评估单张图像
  42. Args:
  43. image_path: 图像路径
  44. use_gpu: 是否使用GPU
  45. **kwargs: 其他pipeline参数
  46. Returns:
  47. 符合OmniDocBench格式的结果字典
  48. """
  49. print(f"正在处理图像: {image_path}")
  50. # 读取图像获取尺寸信息
  51. image = cv2.imread(image_path)
  52. height, width = image.shape[:2]
  53. # 运行PaddleX pipeline
  54. start_time = time.time()
  55. output = list(self.pipeline.predict(
  56. input=image_path,
  57. device="gpu" if use_gpu else "cpu",
  58. use_doc_orientation_classify=True,
  59. use_doc_unwarping=False,
  60. use_seal_recognition=True,
  61. use_chart_recognition=True,
  62. use_table_recognition=True,
  63. use_formula_recognition=True,
  64. **kwargs
  65. ))
  66. process_time = time.time() - start_time
  67. print(f"处理耗时: {process_time:.2f}秒")
  68. # 转换为OmniDocBench格式
  69. result = self._convert_to_omnidocbench_format(
  70. output, image_path, width, height
  71. )
  72. return result
  73. def _convert_to_omnidocbench_format(self,
  74. paddlex_output: List,
  75. image_path: str,
  76. width: int,
  77. height: int) -> Dict[str, Any]:
  78. """
  79. 将PaddleX输出转换为OmniDocBench格式
  80. Args:
  81. paddlex_output: PaddleX的输出结果列表
  82. image_path: 图像路径
  83. width: 图像宽度
  84. height: 图像高度
  85. Returns:
  86. OmniDocBench格式的结果
  87. """
  88. layout_dets = []
  89. anno_id_counter = 0
  90. # 处理PaddleX的输出
  91. for res in paddlex_output:
  92. # 从parsing_res_list中提取布局信息
  93. if hasattr(res, 'parsing_res_list') and res.parsing_res_list:
  94. parsing_list = res.parsing_res_list
  95. for item in parsing_list:
  96. # 提取边界框和类别
  97. bbox = item.get('block_bbox', [])
  98. category = item.get('block_label', 'text_block')
  99. content = item.get('block_content', '')
  100. # 转换bbox格式 [x1, y1, x2, y2] -> [x1, y1, x2, y1, x2, y2, x1, y2]
  101. if len(bbox) == 4:
  102. x1, y1, x2, y2 = bbox
  103. poly = [x1, y1, x2, y1, x2, y2, x1, y2]
  104. else:
  105. poly = bbox
  106. # 映射类别
  107. omni_category = self.category_mapping.get(category, 'text_block')
  108. # 创建layout检测结果
  109. layout_det = {
  110. "category_type": omni_category,
  111. "poly": poly,
  112. "ignore": False,
  113. "order": anno_id_counter,
  114. "anno_id": anno_id_counter,
  115. }
  116. # 添加文本识别结果
  117. if content and content.strip():
  118. if omni_category == 'table':
  119. # 表格内容作为HTML存储
  120. layout_det["html"] = content
  121. else:
  122. # 其他类型作为文本存储
  123. layout_det["text"] = content.strip()
  124. # 添加span级别的标注(从OCR结果中提取)
  125. layout_det["line_with_spans"] = self._extract_spans_from_ocr(
  126. res, bbox, omni_category
  127. )
  128. # 添加属性标签
  129. layout_det["attribute"] = self._extract_attributes(item, omni_category)
  130. layout_dets.append(layout_det)
  131. anno_id_counter += 1
  132. # 构建完整结果
  133. result = {
  134. "layout_dets": layout_dets,
  135. "page_info": {
  136. "page_no": 0,
  137. "height": height,
  138. "width": width,
  139. "image_path": Path(image_path).name,
  140. "page_attribute": self._extract_page_attributes(paddlex_output)
  141. },
  142. "extra": {
  143. "relation": [] # 关系信息,需要根据具体情况提取
  144. }
  145. }
  146. return result
  147. def _extract_spans_from_ocr(self, res, block_bbox: List, category: str) -> List[Dict]:
  148. """从OCR结果中提取span级别的标注"""
  149. spans = []
  150. # 如果有OCR结果,提取相关的文本行
  151. if hasattr(res, 'overall_ocr_res') and res.overall_ocr_res:
  152. ocr_res = res.overall_ocr_res
  153. if hasattr(ocr_res, 'rec_texts') and hasattr(ocr_res, 'rec_boxes'):
  154. texts = ocr_res.rec_texts
  155. boxes = ocr_res.rec_boxes
  156. scores = getattr(ocr_res, 'rec_scores', [1.0] * len(texts))
  157. # 检查哪些OCR结果在当前block内
  158. if len(block_bbox) == 4:
  159. x1, y1, x2, y2 = block_bbox
  160. for i, (text, box, score) in enumerate(zip(texts, boxes, scores)):
  161. if len(box) >= 4:
  162. # 检查OCR框是否在block内
  163. ocr_x1, ocr_y1, ocr_x2, ocr_y2 = box[:4]
  164. # 简单的包含检查
  165. if (ocr_x1 >= x1 and ocr_y1 >= y1 and
  166. ocr_x2 <= x2 and ocr_y2 <= y2):
  167. span = {
  168. "category_type": "text_span",
  169. "poly": [ocr_x1, ocr_y1, ocr_x2, ocr_y1,
  170. ocr_x2, ocr_y2, ocr_x1, ocr_y2],
  171. "ignore": False,
  172. "text": text,
  173. }
  174. # 如果置信度太低,可能需要忽略
  175. if score < 0.5:
  176. span["ignore"] = True
  177. spans.append(span)
  178. return spans
  179. def _extract_attributes(self, item: Dict, category: str) -> Dict:
  180. """提取属性标签"""
  181. attributes = {}
  182. # 根据类别提取不同的属性
  183. if category == 'table':
  184. # 表格属性
  185. attributes.update({
  186. "table_layout": "vertical", # 需要根据实际情况判断
  187. "with_span": False, # 需要检查是否有合并单元格
  188. "line": "full_line", # 需要检查线框类型
  189. "language": "table_simplified_chinese", # 需要语言检测
  190. "include_equation": False,
  191. "include_backgroud": False,
  192. "table_vertical": False
  193. })
  194. # 检查表格内容是否有合并单元格
  195. content = item.get('block_content', '')
  196. if 'colspan' in content or 'rowspan' in content:
  197. attributes["with_span"] = True
  198. elif category in ['text_block', 'title']:
  199. # 文本属性
  200. attributes.update({
  201. "text_language": "text_simplified_chinese",
  202. "text_background": "white",
  203. "text_rotate": "normal"
  204. })
  205. elif 'equation' in category:
  206. # 公式属性
  207. attributes.update({
  208. "formula_type": "print"
  209. })
  210. return attributes
  211. def _extract_page_attributes(self, paddlex_output) -> Dict:
  212. """提取页面级别的属性"""
  213. return {
  214. "data_source": "research_report", # 需要根据实际情况判断
  215. "language": "simplified_chinese",
  216. "layout": "single_column",
  217. "watermark": False,
  218. "fuzzy_scan": False,
  219. "colorful_backgroud": False
  220. }
  221. def load_existing_result(self, result_path: str) -> Dict[str, Any]:
  222. """
  223. 从已有的PaddleX结果文件加载数据进行转换
  224. Args:
  225. result_path: PaddleX结果JSON文件路径
  226. Returns:
  227. OmniDocBench格式的结果字典
  228. """
  229. with open(result_path, 'r', encoding='utf-8') as f:
  230. data = json.load(f)
  231. # 从结果文件中提取图像信息
  232. input_path = data.get('input_path', '')
  233. # 读取图像获取尺寸
  234. if input_path and Path(input_path).exists():
  235. image = cv2.imread(input_path)
  236. height, width = image.shape[:2]
  237. image_name = Path(input_path).name
  238. else:
  239. # 如果图像路径不存在,使用默认值
  240. height, width = 1600, 1200
  241. image_name = "unknown.png"
  242. # 转换格式
  243. result = self._convert_paddlex_result_to_omnidocbench(
  244. data, image_name, width, height
  245. )
  246. return result
  247. def _convert_paddlex_result_to_omnidocbench(self,
  248. paddlex_result: Dict,
  249. image_name: str,
  250. width: int,
  251. height: int) -> Dict[str, Any]:
  252. """
  253. 将已有的PaddleX结果转换为OmniDocBench格式
  254. """
  255. layout_dets = []
  256. anno_id_counter = 0
  257. # 从parsing_res_list中提取布局信息
  258. parsing_list = paddlex_result.get('parsing_res_list', [])
  259. for item in parsing_list:
  260. # 提取边界框和类别
  261. bbox = item.get('block_bbox', [])
  262. category = item.get('block_label', 'text_block')
  263. content = item.get('block_content', '')
  264. # 转换bbox格式
  265. if len(bbox) == 4:
  266. x1, y1, x2, y2 = bbox
  267. poly = [x1, y1, x2, y1, x2, y2, x1, y2]
  268. else:
  269. poly = bbox
  270. # 映射类别
  271. omni_category = self.category_mapping.get(category, 'text_block')
  272. # 创建layout检测结果
  273. layout_det = {
  274. "category_type": omni_category,
  275. "poly": poly,
  276. "ignore": False,
  277. "order": anno_id_counter,
  278. "anno_id": anno_id_counter,
  279. }
  280. # 添加内容
  281. if content and content.strip():
  282. if omni_category == 'table':
  283. layout_det["html"] = content
  284. else:
  285. layout_det["text"] = content.strip()
  286. # 添加属性
  287. layout_det["attribute"] = self._extract_attributes(item, omni_category)
  288. layout_det["line_with_spans"] = [] # 简化处理
  289. layout_dets.append(layout_det)
  290. anno_id_counter += 1
  291. # 构建完整结果
  292. result = {
  293. "layout_dets": layout_dets,
  294. "page_info": {
  295. "page_no": 0,
  296. "height": height,
  297. "width": width,
  298. "image_path": image_name,
  299. "page_attribute": {
  300. "data_source": "research_report",
  301. "language": "simplified_chinese",
  302. "layout": "single_column",
  303. "watermark": False,
  304. "fuzzy_scan": False,
  305. "colorful_backgroud": False
  306. }
  307. },
  308. "extra": {
  309. "relation": []
  310. }
  311. }
  312. return result
  313. def convert_existing_results():
  314. """转换已有的PaddleX结果"""
  315. evaluator = OmniDocBenchEvaluator()
  316. # 示例:转换单个结果文件
  317. result_file = "./sample_data/single_pipeline_output/PP-StructureV3-zhch/300674-母公司现金流量表-扫描_res.json"
  318. if Path(result_file).exists():
  319. print(f"正在转换结果文件: {result_file}")
  320. omnidocbench_result = evaluator.load_existing_result(result_file)
  321. # 保存转换后的结果
  322. output_file = "./omnidocbench_converted_result.json"
  323. with open(output_file, 'w', encoding='utf-8') as f:
  324. json.dump([omnidocbench_result], f, ensure_ascii=False, indent=2)
  325. print(f"转换完成,结果保存至: {output_file}")
  326. print(f"检测到的布局元素数量: {len(omnidocbench_result['layout_dets'])}")
  327. # 显示检测到的元素
  328. for i, item in enumerate(omnidocbench_result['layout_dets']):
  329. print(f" {i+1}. {item['category_type']}: {item.get('text', item.get('html', ''))[:50]}...")
  330. else:
  331. print(f"结果文件不存在: {result_file}")
  332. if __name__ == "__main__":
  333. convert_existing_results()