| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671 |
- # zhch/omnidocbench_eval_fixed.py
- import json
- import time
- from pathlib import Path
- from typing import List, Dict, Any, Tuple
- import cv2
- import numpy as np
- from paddlex import create_pipeline
- import os
- import glob
- import traceback
class OmniDocBenchEvaluator:
    """
    OmniDocBench evaluator (fixed version): runs a PaddleX pipeline and converts
    its output into the OmniDocBench result format.

    Example pipeline config path: paddlex/configs/pipelines/PP-StructureV3.yaml
    """

    def __init__(self, pipeline_config_path: str = "PP-StructureV3"):
        """
        Create the evaluator and its PaddleX pipeline.

        Args:
            pipeline_config_path: PaddleX pipeline name or config file path,
                passed to ``create_pipeline(pipeline=...)``.
        """
        self.pipeline = create_pipeline(pipeline=pipeline_config_path)
        self.category_mapping = self._get_category_mapping()

    def _get_category_mapping(self) -> Dict[str, str]:
        """
        Return the PaddleX block label -> OmniDocBench category mapping.

        Labels missing from this mapping are converted to ``text_block`` by the
        converters.
        """
        return {
            # Generic labels (original mapping).
            'title': 'title',
            'text': 'text_block',
            'figure': 'figure',
            'figure_caption': 'figure_caption',
            'table': 'table',
            'table_caption': 'table_caption',
            'equation': 'equation_isolated',
            'header': 'header',
            'footer': 'footer',
            'reference': 'reference',
            'seal': 'abandon',  # seals are usually treated as the discard class
            'number': 'page_number',
            # PP-StructureV3 layout labels. Without these they would all fall
            # back to text_block. NOTE(review): confirm against the layout
            # model's label list.
            'doc_title': 'title',
            'paragraph_title': 'title',
            'image': 'figure',
            'chart': 'figure',
            'figure_title': 'figure_caption',
            'table_title': 'table_caption',
            'formula': 'equation_isolated',
            'footnote': 'page_footnote',
            'algorithm': 'code_txt',
        }

    def _build_predict_kwargs(self, use_gpu: bool, overrides: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build the keyword arguments for ``self.pipeline.predict``.

        The defaults below are applied first and then updated with
        ``overrides``. A caller-supplied option therefore replaces the default
        instead of colliding with it as a duplicate keyword argument.

        Args:
            use_gpu: Selects ``device="gpu"`` (True) or ``"cpu"`` (False).
            overrides: Extra pipeline options from the caller.

        Returns:
            A new dict of keyword arguments (``input`` is not included).
        """
        predict_kwargs = {
            # NOTE(review): some PaddleX versions take `device` in
            # create_pipeline rather than predict(); confirm for your version.
            "device": "gpu" if use_gpu else "cpu",
            "use_doc_orientation_classify": True,
            "use_doc_unwarping": False,
            "use_seal_recognition": True,
            "use_chart_recognition": True,
            "use_table_recognition": True,
            "use_formula_recognition": True,
        }
        predict_kwargs.update(overrides)
        return predict_kwargs

    def evaluate_single_image(self, image_path: str,
                              use_gpu: bool = True,
                              **kwargs) -> Dict[str, Any]:
        """
        Run the pipeline on one image and convert the result.

        Args:
            image_path: Path of the image to process.
            use_gpu: Whether to run on GPU.
            **kwargs: Additional pipeline options. These override the defaults
                from ``_build_predict_kwargs`` (e.g. ``use_doc_unwarping=True``).

        Returns:
            An OmniDocBench-format page dict, or None if the image cannot be
            read or the pipeline raises.
        """
        print(f"正在处理图像: {image_path}")

        # The image is read only to obtain its dimensions for page_info.
        image = cv2.imread(image_path)
        if image is None:
            print(f"无法读取图像: {image_path}")
            return None

        height, width = image.shape[:2]

        predict_kwargs = self._build_predict_kwargs(use_gpu, kwargs)
        predict_kwargs["input"] = image_path  # the requested image always wins

        start_time = time.time()
        try:
            output = list(self.pipeline.predict(**predict_kwargs))
        except Exception as e:
            print(f"处理图像 {image_path} 时发生错误: {str(e)}")
            traceback.print_exc()
            return None

        process_time = time.time() - start_time
        print(f"处理耗时: {process_time:.2f}秒")

        return self._convert_to_omnidocbench_format(
            output, image_path, width, height
        )

    def _build_layout_det(self, item: Dict, anno_id: int, res_json: Dict) -> Dict[str, Any]:
        """
        Convert one ``parsing_res_list`` entry into an OmniDocBench layout_det.

        Args:
            item: A PaddleX parsing entry, read via ``block_bbox``,
                ``block_label`` and ``block_content``.
            anno_id: Value used for both ``anno_id`` and ``order``.
            res_json: The page-level PaddleX result dict. It is searched for
                ``overall_ocr_res`` to build the span list.

        Returns:
            The layout_det dict.
        """
        bbox = item.get('block_bbox') or []
        category = item.get('block_label', 'text_block')
        content = item.get('block_content') or ''

        # [x1, y1, x2, y2] -> 4-point polygon [x1, y1, x2, y1, x2, y2, x1, y2];
        # any other shape is passed through unchanged.
        if len(bbox) == 4:
            x1, y1, x2, y2 = bbox
            poly = [x1, y1, x2, y1, x2, y2, x1, y2]
        else:
            poly = bbox

        omni_category = self.category_mapping.get(category, 'text_block')

        layout_det = {
            "category_type": omni_category,
            "poly": poly,
            "ignore": False,
            "order": anno_id,
            "anno_id": anno_id,
        }

        if content.strip():
            if omni_category == 'table':
                # Table content is kept verbatim as HTML.
                layout_det["html"] = content
            else:
                layout_det["text"] = content.strip()

        layout_det["line_with_spans"] = self._extract_spans_from_ocr(
            res_json, bbox, omni_category
        )
        layout_det["attribute"] = self._extract_attributes(item, omni_category)
        return layout_det

    def _assemble_page(self, layout_dets: List[Dict], image_name: str,
                       width: int, height: int, source: Any) -> Dict[str, Any]:
        """Wrap converted layout_dets in the OmniDocBench page structure."""
        return {
            "layout_dets": layout_dets,
            "page_info": {
                "page_no": 0,
                "height": height,
                "width": width,
                "image_path": image_name,
                "page_attribute": self._extract_page_attributes(source),
            },
            "extra": {
                "relation": []  # block relations are not extracted yet
            },
        }

    def _convert_to_omnidocbench_format(self,
                                        paddlex_output: List,
                                        image_path: str,
                                        width: int,
                                        height: int) -> Dict[str, Any]:
        """
        Convert live PaddleX pipeline results into one OmniDocBench page.

        Args:
            paddlex_output: Result objects from ``pipeline.predict``. Each must
                expose ``.json`` with the payload under ``"res"``.
            image_path: Source image path; only its file name is stored.
            width: Image width in pixels.
            height: Image height in pixels.

        Returns:
            The OmniDocBench page dict. Blocks from every result are
            concatenated into one page with consecutive ``anno_id`` values.
        """
        layout_dets = []
        for res in paddlex_output:
            res_json = res.json.get('res', {})
            for item in res_json.get('parsing_res_list', []):
                layout_dets.append(
                    self._build_layout_det(item, len(layout_dets), res_json)
                )

        return self._assemble_page(
            layout_dets, Path(image_path).name, width, height, paddlex_output
        )

    def _extract_spans_from_ocr(self, res, block_bbox: List, category: str) -> List[Dict]:
        """
        Collect OCR lines that lie fully inside a block as ``text_span`` entries.

        Args:
            res: Page-level PaddleX result dict. Its optional
                ``overall_ocr_res`` holds ``rec_texts``, ``rec_boxes`` and,
                optionally, ``rec_scores``.
            block_bbox: Block box as [x1, y1, x2, y2]. Any other length yields
                no spans.
            category: OmniDocBench category of the block (currently unused).

        Returns:
            Span dicts. Spans whose recognition score is below 0.5 are marked
            ``ignore=True``.
        """
        spans = []

        ocr_res = res.get('overall_ocr_res', None)
        if not ocr_res or len(block_bbox) != 4:
            return spans

        texts = ocr_res.get('rec_texts', [])
        boxes = ocr_res.get('rec_boxes', [])
        # Without recognition scores every line is treated as fully confident.
        scores = ocr_res.get('rec_scores', [1.0] * len(texts))

        x1, y1, x2, y2 = block_bbox
        for text, box, score in zip(texts, boxes, scores):
            if len(box) < 4:
                continue
            ocr_x1, ocr_y1, ocr_x2, ocr_y2 = box[:4]

            # Simple containment test: the OCR box must lie entirely inside the block.
            if (ocr_x1 >= x1 and ocr_y1 >= y1 and
                    ocr_x2 <= x2 and ocr_y2 <= y2):
                spans.append({
                    "category_type": "text_span",
                    "poly": [ocr_x1, ocr_y1, ocr_x2, ocr_y1,
                             ocr_x2, ocr_y2, ocr_x1, ocr_y2],
                    # bool() keeps the value JSON-serialisable even if scores are numpy floats.
                    "ignore": bool(score < 0.5),
                    "text": text,
                })

        return spans

    def _extract_attributes(self, item: Dict, category: str) -> Dict:
        """
        Build the OmniDocBench ``attribute`` dict for one block.

        Most values are fixed placeholders. Only the table ``with_span`` flag is
        derived from the content (presence of ``colspan``/``rowspan``).
        """
        attributes = {}

        if category == 'table':
            attributes.update({
                "table_layout": "vertical",  # placeholder; not detected
                "with_span": False,  # set below from the HTML content
                "line": "full_line",  # placeholder; ruling style not detected
                "language": "table_simplified_chinese",  # placeholder; no language detection
                "include_equation": False,
                # Key spelling kept as-is; presumably mirrors the OmniDocBench schema — verify before renaming.
                "include_backgroud": False,
                "table_vertical": False,
            })

            # Merged cells show up as colspan/rowspan in the table HTML.
            content = item.get('block_content') or ''
            if 'colspan' in content or 'rowspan' in content:
                attributes["with_span"] = True

        elif category in ['text_block', 'title']:
            attributes.update({
                "text_language": "text_simplified_chinese",
                "text_background": "white",
                "text_rotate": "normal",
            })

        elif 'equation' in category:
            attributes.update({
                "formula_type": "print",
            })

        return attributes

    def _extract_page_attributes(self, paddlex_output) -> Dict:
        """
        Return the page-level attributes.

        These are fixed placeholders; ``paddlex_output`` is accepted but not
        inspected.
        """
        return {
            "data_source": "research_report",  # placeholder; not detected
            "language": "simplified_chinese",
            "layout": "single_column",
            "watermark": False,
            "fuzzy_scan": False,
            "colorful_backgroud": False,
        }

    def load_existing_result(self, result_path: str) -> Dict[str, Any]:
        """
        Convert a saved PaddleX result JSON file into OmniDocBench format.

        Args:
            result_path: Path of the PaddleX result JSON. A top-level
                ``{"res": {...}}`` wrapper is accepted as well as the bare payload.

        Returns:
            The OmniDocBench page dict, or None if the file is missing or
            cannot be parsed. If ``input_path`` in the JSON does not point to a
            readable image, the size defaults to 1200x1600 (width x height) and
            the name to ``unknown.png``.
        """
        if not Path(result_path).exists():
            print(f"结果文件不存在: {result_path}")
            return None

        try:
            with open(result_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # Accept a raw `res.json` dump as well as the bare payload.
            if isinstance(data, dict) and isinstance(data.get('res'), dict):
                data = data['res']

            input_path = data.get('input_path', '')

            # Fallback size/name when the source image is unavailable.
            height, width = 1600, 1200
            image_name = "unknown.png"
            if input_path and Path(input_path).exists():
                image = cv2.imread(input_path)
                if image is None:
                    print(f"无法读取图像: {input_path}")
                else:
                    height, width = image.shape[:2]
                    image_name = Path(input_path).name

            return self._convert_paddlex_result_to_omnidocbench(
                data, image_name, width, height
            )

        except Exception as e:
            print(f"加载结果文件 {result_path} 时发生错误: {str(e)}")
            traceback.print_exc()
            return None

    def load_omnidocbench_dataset(self, dataset_dir: str) -> Dict[str, Any]:
        """
        Load per-file ground-truth JSON files from a dataset directory.

        JSON files whose name contains "pred" or "result" are skipped. Files
        are visited in sorted order so the resulting dict order is stable.

        Args:
            dataset_dir: Root directory, searched recursively.

        Returns:
            ``{file_stem: {"ground_truth": ..., "predictions": None,
            "image_path": str}}``. ``image_path`` is "" when no image is found.
        """
        dataset = {}

        for file_path in sorted(Path(dataset_dir).rglob('*.json')):
            if "pred" in file_path.name or "result" in file_path.name:
                continue

            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)

                file_id = file_path.stem
                dataset[file_id] = {
                    "ground_truth": data,
                    "predictions": None,
                    "image_path": self._find_image_file(file_path.parent, file_id)
                }

            except Exception as e:
                print(f"加载文件 {file_path} 出错: {str(e)}")

        return dataset

    def _find_image_file(self, search_dir: Path, file_id: str) -> str:
        """
        Find ``<file_id>.<ext>`` in ``search_dir`` or, recursively, its subdirectories.

        Returns:
            The first match as a string, or "" if none exists. The directory
            itself is checked before subdirectories, which are visited in
            sorted order.
        """
        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']

        for ext in image_extensions:
            image_path = search_dir / f"{file_id}{ext}"
            if image_path.exists():
                return str(image_path)

        for subdir in sorted(search_dir.iterdir()):
            if subdir.is_dir():
                result = self._find_image_file(subdir, file_id)
                if result:
                    return result

        return ""

    def generate_predictions(self, dataset: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run the pipeline for every dataset entry that has an existing image.

        Args:
            dataset: Dict produced by ``load_omnidocbench_dataset``. It is
                updated in place.

        Returns:
            The same dict, with ``predictions`` filled where processing succeeded.
        """
        for file_id, item in dataset.items():
            if item["image_path"] and Path(item["image_path"]).exists():
                try:
                    result = self.evaluate_single_image(item["image_path"])
                    dataset[file_id]["predictions"] = result
                    print(f"成功处理文件: {file_id}")
                except Exception as e:
                    print(f"处理文件 {file_id} 出错: {str(e)}")
            else:
                print(f"图像文件不存在: {item['image_path']}")

        return dataset

    def evaluate_dataset(self, dataset: Dict[str, Any]) -> Dict[str, float]:
        """
        Average per-sample metrics over entries that have predictions.

        Args:
            dataset: Dict with ``ground_truth`` and ``predictions`` per entry.

        Returns:
            Mean ``precision``, ``recall``, ``f1_score`` and ``iou``. All are
            0.0 when no entry has predictions.
        """
        metric_names = ("precision", "recall", "f1_score", "iou")
        totals = dict.fromkeys(metric_names, 0.0)
        count = 0

        for item in dataset.values():
            if item["predictions"] is None:
                continue
            sample_metrics = self._calculate_sample_metrics(
                item["ground_truth"], item["predictions"]
            )
            for name in metric_names:
                totals[name] += sample_metrics[name]
            count += 1

        if count == 0:
            return dict.fromkeys(metric_names, 0.0)
        return {name: totals[name] / count for name in metric_names}

    def _calculate_sample_metrics(self, ground_truth: Dict, prediction: Dict) -> Dict[str, float]:
        """
        Compute set-level category metrics for one sample.

        This is a placeholder metric: it compares the *sets* of category types
        present in the ground truth and the prediction, not the boxes.

        Args:
            ground_truth: Dict with a ``layout_dets`` list.
            prediction: Dict with a ``layout_dets`` list.

        Returns:
            ``precision``, ``recall``, ``f1_score`` and ``iou``. Each is 0.0
            when its denominator is zero.
        """
        metrics = {
            "precision": 0.0,
            "recall": 0.0,
            "f1_score": 0.0,
            "iou": 0.0
        }

        gt_layouts = ground_truth.get("layout_dets", [])
        pred_layouts = prediction.get("layout_dets", [])

        gt_categories = {item.get("category_type", "") for item in gt_layouts}
        pred_categories = {item.get("category_type", "") for item in pred_layouts}

        intersection = len(gt_categories & pred_categories)
        union = len(gt_categories | pred_categories)

        if union > 0:
            metrics["iou"] = intersection / union

        if pred_categories:
            metrics["precision"] = intersection / len(pred_categories)

        if gt_categories:
            metrics["recall"] = intersection / len(gt_categories)

        if metrics["precision"] + metrics["recall"] > 0:
            metrics["f1_score"] = 2 * (metrics["precision"] * metrics["recall"]) / (metrics["precision"] + metrics["recall"])

        return metrics

    def _convert_paddlex_result_to_omnidocbench(self,
                                                paddlex_result: Dict,
                                                image_name: str,
                                                width: int,
                                                height: int) -> Dict[str, Any]:
        """
        Convert a saved PaddleX result payload into one OmniDocBench page.

        It uses the same per-block conversion as the live path, so OCR spans
        are extracted whenever the payload contains ``overall_ocr_res``.

        Args:
            paddlex_result: Parsed PaddleX payload (``parsing_res_list``, ...).
            image_name: File name stored in ``page_info.image_path``.
            width: Page width in pixels.
            height: Page height in pixels.
        """
        layout_dets = [
            self._build_layout_det(item, anno_id, paddlex_result)
            for anno_id, item in enumerate(paddlex_result.get('parsing_res_list', []))
        ]
        return self._assemble_page(layout_dets, image_name, width, height, paddlex_result)
def convert_existing_results(
    result_file: str = "./sample_data/single_pipeline_output/PP-StructureV3-zhch/300674-母公司现金流量表-扫描_res.json",
    output_file: str = "./omnidocbench_converted_result.json",
) -> None:
    """
    Convert one saved PaddleX result JSON into an OmniDocBench result file.

    The converted page is written as a one-element JSON list. A short preview
    of the detected blocks is printed. Errors are printed rather than raised.

    Args:
        result_file: Saved PaddleX result JSON to convert.
        output_file: Destination JSON path.
    """
    try:
        # Check the input before building the evaluator: construction loads the
        # whole PaddleX pipeline, which is wasted work if there is nothing to convert.
        if not Path(result_file).exists():
            print(f"结果文件不存在: {result_file}")
            return

        evaluator = OmniDocBenchEvaluator()

        print(f"正在转换结果文件: {result_file}")

        omnidocbench_result = evaluator.load_existing_result(result_file)
        if omnidocbench_result is None:
            print(f"转换结果为空: {result_file}")
            return

        try:
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump([omnidocbench_result], f, ensure_ascii=False, indent=2)

            print(f"转换完成,结果保存至: {output_file}")
            print(f"检测到的布局元素数量: {len(omnidocbench_result['layout_dets'])}")

            # Preview: category plus the first 50 characters of text/HTML.
            for i, item in enumerate(omnidocbench_result['layout_dets']):
                print(f"  {i+1}. {item['category_type']}: {item.get('text', item.get('html', ''))[:50]}...")
        except Exception as e:
            print(f"保存结果到文件 {output_file} 时发生错误: {str(e)}")
            traceback.print_exc()

    except Exception as e:
        print(f"转换过程中发生致命错误: {str(e)}")
        traceback.print_exc()
def process_omnidocbench_dataset(
    dataset_path: str = "/home/ubuntu/zhch/OmniDocBench/OpenDataLab___OmniDocBench",
    result_dir: str = "/home/ubuntu/zhch/PaddleX/zhch/OmniDocBench_Result",
    max_files: int = 10,
) -> None:
    """
    Run the pipeline over the images of an OmniDocBench dataset directory.

    Results are written as a JSON list to
    ``<result_dir>/OmniDocBench-PPStructureV3.json``. Errors are printed
    rather than raised.

    Args:
        dataset_path: Root directory, searched recursively for images.
        result_dir: Output directory; created if missing.
        max_files: Process at most this many images, taken from the sorted file
            list. The default of 10 is a test limit; pass None to process all.
    """
    try:
        os.makedirs(result_dir, exist_ok=True)

        # Sorted so the selected subset is stable across runs; glob order is unspecified.
        image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff']
        image_files = sorted(
            path
            for ext in image_extensions
            for path in glob.glob(os.path.join(dataset_path, '**', ext), recursive=True)
        )

        print(f"找到 {len(image_files)} 个图像文件")

        if not image_files:
            print("未找到任何图像文件,程序终止")
            return

        # Built only once there is work to do: this loads the PaddleX pipeline.
        evaluator = OmniDocBenchEvaluator()

        selected = image_files if max_files is None else image_files[:max_files]
        all_results = []

        for i, image_path in enumerate(selected):
            try:
                print(f"处理进度: {i+1}/{len(selected)}")

                result = evaluator.evaluate_single_image(image_path)
                if result is not None:
                    all_results.append(result)

            except Exception as e:
                print(f"处理文件 {image_path} 时出错: {str(e)}")
                traceback.print_exc()
                continue

        output_file = os.path.join(result_dir, "OmniDocBench-PPStructureV3.json")
        try:
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(all_results, f, ensure_ascii=False, indent=2)

            print(f"处理完成,结果保存至: {output_file}")
            print(f"共处理 {len(all_results)} 个文件")
        except Exception as e:
            print(f"保存结果文件时发生错误: {str(e)}")
            traceback.print_exc()

    except Exception as e:
        print(f"处理数据集时发生致命错误: {str(e)}")
        traceback.print_exc()
if __name__ == "__main__":
    # Script entry point: run the pipeline over the OmniDocBench images.
    # To convert a saved PaddleX result instead, call convert_existing_results().
    process_omnidocbench_dataset()
|