# zhch/omnidocbench_eval_fixed.py
import json
import time
from pathlib import Path
from typing import List, Dict, Any, Tuple

import cv2
import numpy as np
from paddlex import create_pipeline


class OmniDocBenchEvaluator:
    """Convert PaddleX layout-parsing results into OmniDocBench-style records.

    Two entry points are offered: ``evaluate_single_image`` runs the pipeline
    on an image and converts its output, while ``load_existing_result``
    converts a result JSON file that was saved earlier.
    """

    # Sentinel that distinguishes "field is absent" from a stored ``None``.
    _MISSING = object()

    # Page size used when the source image cannot be read (height, width).
    _DEFAULT_PAGE_SIZE = (1600, 1200)

    def __init__(self, pipeline_config_path: str = "./PP-StructureV3-zhch.yaml"):
        """Create the pipeline and the category mapping.

        Args:
            pipeline_config_path: Path of the PaddleX pipeline config file.
        """
        self.pipeline = create_pipeline(pipeline=pipeline_config_path)
        self.category_mapping = self._get_category_mapping()

    def _get_category_mapping(self) -> Dict[str, str]:
        """Return the PaddleX label -> OmniDocBench category mapping."""
        return {
            'title': 'title',
            'text': 'text_block',
            'figure': 'figure',
            'figure_caption': 'figure_caption',
            'table': 'table',
            'table_caption': 'table_caption',
            'equation': 'equation_isolated',
            'header': 'header',
            'footer': 'footer',
            'reference': 'reference',
            'seal': 'abandon',  # seals are treated as a discarded category
            'number': 'page_number',
            # Add further label mappings here as needed.
        }

    @classmethod
    def _get_field(cls, obj: Any, name: str, default: Any = None) -> Any:
        """Read ``name`` from ``obj`` as an attribute or, failing that, a dict key.

        Attribute access is tried first so objects that worked with the
        previous ``hasattr`` checks keep working; dict-style results are
        supported in addition.
        """
        value = getattr(obj, name, cls._MISSING)
        if value is cls._MISSING and isinstance(obj, dict):
            value = obj.get(name, cls._MISSING)
        return default if value is cls._MISSING else value

    @staticmethod
    def _bbox_to_poly(bbox) -> List:
        """Expand ``[x1, y1, x2, y2]`` into an 8-value clockwise polygon.

        Anything that is not a 4-element box is returned unchanged; a missing
        (``None``) box becomes an empty list.
        """
        if bbox is None:
            return []
        if len(bbox) == 4:
            x1, y1, x2, y2 = bbox
            return [x1, y1, x2, y1, x2, y2, x1, y2]
        return bbox

    def _build_layout_det(self, item: Any, anno_id: int) -> Dict[str, Any]:
        """Build the part of a layout entry shared by both converters.

        The returned dict holds category, polygon, ordering fields and the
        recognised content (``html`` for tables, ``text`` otherwise). Callers
        append ``attribute`` and ``line_with_spans`` themselves.
        """
        bbox = self._get_field(item, 'block_bbox', [])
        label = self._get_field(item, 'block_label', 'text_block')
        content = self._get_field(item, 'block_content', '')

        # Unknown labels fall back to a plain text block.
        omni_category = self.category_mapping.get(label, 'text_block')

        layout_det = {
            "category_type": omni_category,
            "poly": self._bbox_to_poly(bbox),
            "ignore": False,
            "order": anno_id,
            "anno_id": anno_id,
        }

        # Only non-blank string content is stored.
        if isinstance(content, str) and content.strip():
            if omni_category == 'table':
                # Table content is kept verbatim as HTML.
                layout_det["html"] = content
            else:
                layout_det["text"] = content.strip()

        return layout_det

    def evaluate_single_image(self, image_path: str,
                              use_gpu: bool = True,
                              **kwargs) -> Dict[str, Any]:
        """Run the pipeline on one image and convert the output.

        Args:
            image_path: Path of the image to process.
            use_gpu: Run on ``"gpu"`` when true, otherwise on ``"cpu"``.
            **kwargs: Extra ``predict`` options; they override the default
                ``use_*`` switches set below.

        Returns:
            The result dict in OmniDocBench layout.

        Raises:
            ValueError: If the image cannot be read.
        """
        print(f"正在处理图像: {image_path}")

        # The image is only read to obtain the page size.
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"无法读取图像: {image_path}")
        height, width = image.shape[:2]

        # Defaults first, caller overrides second, so passing one of these
        # flags in kwargs no longer triggers a duplicate-keyword TypeError.
        predict_options = {
            "use_doc_orientation_classify": True,
            "use_doc_unwarping": False,
            "use_seal_recognition": True,
            "use_chart_recognition": True,
            "use_table_recognition": True,
            "use_formula_recognition": True,
        }
        predict_options.update(kwargs)

        start_time = time.time()
        output = list(self.pipeline.predict(
            input=image_path,
            device="gpu" if use_gpu else "cpu",
            **predict_options
        ))
        process_time = time.time() - start_time
        print(f"处理耗时: {process_time:.2f}秒")

        return self._convert_to_omnidocbench_format(
            output, image_path, width, height
        )

    def _convert_to_omnidocbench_format(self,
                                        paddlex_output: List,
                                        image_path: str,
                                        width: int,
                                        height: int) -> Dict[str, Any]:
        """Convert live pipeline output into the OmniDocBench layout.

        Args:
            paddlex_output: List of per-page results from the pipeline.
            image_path: Path of the source image; only its file name is kept.
            width: Image width.
            height: Image height.

        Returns:
            Dict with ``layout_dets``, ``page_info`` and ``extra``.
        """
        layout_dets = []

        for res in paddlex_output:
            # Results without a (non-empty) parsing list contribute nothing.
            parsing_list = self._get_field(res, 'parsing_res_list') or []

            for item in parsing_list:
                # Ids are sequential across all results of the page.
                layout_det = self._build_layout_det(item, len(layout_dets))
                omni_category = layout_det["category_type"]

                # Span-level annotations taken from the overall OCR result.
                layout_det["line_with_spans"] = self._extract_spans_from_ocr(
                    res, self._get_field(item, 'block_bbox', []), omni_category
                )
                layout_det["attribute"] = self._extract_attributes(
                    item, omni_category
                )

                layout_dets.append(layout_det)

        return {
            "layout_dets": layout_dets,
            "page_info": {
                "page_no": 0,
                "height": height,
                "width": width,
                "image_path": Path(image_path).name,
                "page_attribute": self._extract_page_attributes(paddlex_output)
            },
            "extra": {
                "relation": []  # relations are not extracted yet
            }
        }

    def _extract_spans_from_ocr(self, res, block_bbox: List, category: str) -> List[Dict]:
        """Collect OCR lines that lie entirely inside ``block_bbox``.

        Args:
            res: One pipeline result; its ``overall_ocr_res`` field is read.
            block_bbox: ``[x1, y1, x2, y2]`` of the enclosing block.
            category: OmniDocBench category of the block (currently unused).

        Returns:
            A list of ``text_span`` dicts; spans whose score is below 0.5 are
            flagged with ``ignore=True``. Empty when the block box is not a
            4-element box or no OCR data is available.
        """
        if block_bbox is None or len(block_bbox) != 4:
            return []

        ocr_res = self._get_field(res, 'overall_ocr_res')
        if ocr_res is None:
            return []

        texts = self._get_field(ocr_res, 'rec_texts')
        boxes = self._get_field(ocr_res, 'rec_boxes')
        if texts is None or boxes is None:
            return []

        scores = self._get_field(ocr_res, 'rec_scores')
        if scores is None:
            # Without scores every line is treated as fully confident.
            scores = [1.0] * len(texts)

        x1, y1, x2, y2 = block_bbox
        spans = []

        for text, box, score in zip(texts, boxes, scores):
            if len(box) < 4:
                continue
            ocr_x1, ocr_y1, ocr_x2, ocr_y2 = box[:4]

            # Plain containment test: the OCR box must not cross the block.
            inside = (ocr_x1 >= x1 and ocr_y1 >= y1 and
                      ocr_x2 <= x2 and ocr_y2 <= y2)
            if not inside:
                continue

            spans.append({
                "category_type": "text_span",
                "poly": [ocr_x1, ocr_y1, ocr_x2, ocr_y1,
                         ocr_x2, ocr_y2, ocr_x1, ocr_y2],
                # bool() keeps the flag a plain Python bool even when the
                # score is a numpy scalar, so the span stays JSON-serializable.
                "ignore": bool(score < 0.5),
                "text": text,
            })

        return spans

    def _extract_attributes(self, item: Dict, category: str) -> Dict:
        """Return the attribute dict for a block of the given category.

        Most values are fixed placeholders; only ``with_span`` of tables is
        derived from the content (presence of ``colspan``/``rowspan``).
        """
        attributes = {}

        if category == 'table':
            attributes.update({
                "table_layout": "vertical",  # placeholder, not detected
                "with_span": False,
                "line": "full_line",  # placeholder, not detected
                "language": "table_simplified_chinese",  # placeholder
                "include_equation": False,
                "include_backgroud": False,
                "table_vertical": False
            })

            # Merged cells show up as colspan/rowspan in the table HTML.
            content = self._get_field(item, 'block_content', '')
            if isinstance(content, str) and (
                    'colspan' in content or 'rowspan' in content):
                attributes["with_span"] = True

        elif category in ['text_block', 'title']:
            attributes.update({
                "text_language": "text_simplified_chinese",
                "text_background": "white",
                "text_rotate": "normal"
            })

        elif 'equation' in category:
            attributes.update({
                "formula_type": "print"
            })

        return attributes

    def _extract_page_attributes(self, paddlex_output) -> Dict:
        """Return page-level attributes (fixed placeholder values for now)."""
        return {
            "data_source": "research_report",  # placeholder, not detected
            "language": "simplified_chinese",
            "layout": "single_column",
            "watermark": False,
            "fuzzy_scan": False,
            "colorful_backgroud": False
        }

    def load_existing_result(self, result_path: str) -> Dict[str, Any]:
        """Load a saved PaddleX result JSON and convert it.

        Args:
            result_path: Path of the PaddleX result JSON file.

        Returns:
            The result dict in OmniDocBench layout. When the image referenced
            by ``input_path`` is missing, a default page size and the name
            ``unknown.png`` are used; when it exists but cannot be decoded,
            the default page size is used with the real file name.
        """
        with open(result_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        input_path = data.get('input_path', '')

        height, width = self._DEFAULT_PAGE_SIZE
        image_name = "unknown.png"

        if input_path and Path(input_path).exists():
            image_name = Path(input_path).name
            image = cv2.imread(input_path)
            # cv2.imread returns None for files it cannot decode; keep the
            # default size in that case instead of crashing on ``.shape``.
            if image is not None:
                height, width = image.shape[:2]

        return self._convert_paddlex_result_to_omnidocbench(
            data, image_name, width, height
        )

    def _convert_paddlex_result_to_omnidocbench(self,
                                                paddlex_result: Dict,
                                                image_name: str,
                                                width: int,
                                                height: int) -> Dict[str, Any]:
        """Convert an already loaded PaddleX result dict.

        Span-level annotations are not reconstructed on this path, so
        ``line_with_spans`` is always an empty list.
        """
        layout_dets = []

        for item in paddlex_result.get('parsing_res_list', []):
            layout_det = self._build_layout_det(item, len(layout_dets))
            layout_det["attribute"] = self._extract_attributes(
                item, layout_det["category_type"]
            )
            layout_det["line_with_spans"] = []  # simplified handling
            layout_dets.append(layout_det)

        return {
            "layout_dets": layout_dets,
            "page_info": {
                "page_no": 0,
                "height": height,
                "width": width,
                "image_path": image_name,
                "page_attribute": self._extract_page_attributes(paddlex_result)
            },
            "extra": {
                "relation": []
            }
        }
def convert_existing_results(
    result_file: str = "./sample_data/single_pipeline_output/PP-StructureV3-zhch/300674-母公司现金流量表-扫描_res.json",
    output_file: str = "./omnidocbench_converted_result.json",
) -> None:
    """Convert a saved PaddleX result file and write it as OmniDocBench JSON.

    Args:
        result_file: Path of the PaddleX result JSON to convert. Defaults to
            the bundled sample result.
        output_file: Path the converted result (a one-element JSON list) is
            written to.

    Prints a short summary of the detected layout elements. When
    ``result_file`` does not exist, a message is printed and nothing is
    written.
    """
    if not Path(result_file).exists():
        print(f"结果文件不存在: {result_file}")
        return

    # The evaluator builds the whole pipeline in its constructor, so it is
    # only created once the input file is known to exist.
    evaluator = OmniDocBenchEvaluator()

    print(f"正在转换结果文件: {result_file}")
    omnidocbench_result = evaluator.load_existing_result(result_file)

    # Persist the converted page wrapped in a list.
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump([omnidocbench_result], f, ensure_ascii=False, indent=2)

    layout_dets = omnidocbench_result['layout_dets']
    print(f"转换完成,结果保存至: {output_file}")
    print(f"检测到的布局元素数量: {len(layout_dets)}")

    # List every element with the first 50 characters of its content.
    for i, item in enumerate(layout_dets):
        print(f"  {i+1}. {item['category_type']}: {item.get('text', item.get('html', ''))[:50]}...")
# Script entry point: convert the bundled sample result when this file is
# executed directly; importing the module has no side effects.
if __name__ == "__main__":
    convert_existing_results()