Forráskód Böngészése

feat(zhch): 实现 OmniDocBench 数据集处理和评估功能

- 新增加载 OmniDocBench 数据集的方法
- 实现数据集预测和评估功能
- 优化现有结果转换流程
- 添加异常处理和日志记录
zhch158_admin 3 hónapja
szülő
commit
d1bdb0b621
2 módosított fájl, 396 hozzáadás és 535 törlés
  1. 1 407
      zhch/PPStructureV3-OmniDocBench.md
  2. 395 128
      zhch/omnidocbench_eval.py

+ 1 - 407
zhch/PPStructureV3-OmniDocBench.md

@@ -1,413 +1,7 @@
 正在收集工作区信息……正在筛选到最相关的信息……根据您的工作空间和 OmniDocBench 说明文档,我来帮您创建一个评估脚本,按照 OmniDocBench 要求的数据格式生成测试结果。
 
 ## 评估脚本实现
-
-```python
-# zhch/omnidocbench_eval.py
-import json
-import time
-from pathlib import Path
-from typing import List, Dict, Any, Tuple
-import cv2
-import numpy as np
-from paddlex import create_pipeline
-
-class OmniDocBenchEvaluator:
-    """OmniDocBench评估器(修正版),用于生成符合评测格式的结果"""
-    
-    def __init__(self, pipeline_config_path: str = "./PP-StructureV3-zhch.yaml"):
-        """
-        初始化评估器
-        
-        Args:
-            pipeline_config_path: PaddleX pipeline配置文件路径
-        """
-        self.pipeline = create_pipeline(pipeline=pipeline_config_path)
-        self.category_mapping = self._get_category_mapping()
-        
-    def _get_category_mapping(self) -> Dict[str, str]:
-        """获取PaddleX类别到OmniDocBench类别的映射"""
-        return {
-            # PaddleX -> OmniDocBench 类别映射
-            'title': 'title',
-            'text': 'text_block',
-            'figure': 'figure',
-            'figure_caption': 'figure_caption',
-            'table': 'table',
-            'table_caption': 'table_caption',
-            'equation': 'equation_isolated',
-            'header': 'header',
-            'footer': 'footer',
-            'reference': 'reference',
-            'seal': 'abandon',  # 印章通常作为舍弃类
-            'number': 'page_number',
-            # 添加更多映射关系
-        }
-    
-    def evaluate_single_image(self, image_path: str, 
-                            use_gpu: bool = True,
-                            **kwargs) -> Dict[str, Any]:
-        """
-        评估单张图像
-        
-        Args:
-            image_path: 图像路径
-            use_gpu: 是否使用GPU
-            **kwargs: 其他pipeline参数
-            
-        Returns:
-            符合OmniDocBench格式的结果字典
-        """
-        print(f"正在处理图像: {image_path}")
-        
-        # 读取图像获取尺寸信息
-        image = cv2.imread(image_path)
-        height, width = image.shape[:2]
-        
-        # 运行PaddleX pipeline
-        start_time = time.time()
-        
-        output = list(self.pipeline.predict(
-            input=image_path,
-            device="gpu" if use_gpu else "cpu",
-            use_doc_orientation_classify=True,
-            use_doc_unwarping=False,
-            use_seal_recognition=True,
-            use_chart_recognition=True,
-            use_table_recognition=True,
-            use_formula_recognition=True,
-            **kwargs
-        ))
-        
-        process_time = time.time() - start_time
-        print(f"处理耗时: {process_time:.2f}秒")
-        
-        # 转换为OmniDocBench格式
-        result = self._convert_to_omnidocbench_format(
-            output, image_path, width, height
-        )
-        
-        return result
-    
-    def _convert_to_omnidocbench_format(self, 
-                                      paddlex_output: List, 
-                                      image_path: str,
-                                      width: int, 
-                                      height: int) -> Dict[str, Any]:
-        """
-        将PaddleX输出转换为OmniDocBench格式
-        
-        Args:
-            paddlex_output: PaddleX的输出结果列表
-            image_path: 图像路径
-            width: 图像宽度
-            height: 图像高度
-            
-        Returns:
-            OmniDocBench格式的结果
-        """
-        layout_dets = []
-        anno_id_counter = 0
-        
-        # 处理PaddleX的输出
-        for res in paddlex_output:
-            # 从parsing_res_list中提取布局信息
-            if hasattr(res, 'parsing_res_list') and res.parsing_res_list:
-                parsing_list = res.parsing_res_list
-                
-                for item in parsing_list:
-                    # 提取边界框和类别
-                    bbox = item.get('block_bbox', [])
-                    category = item.get('block_label', 'text_block')
-                    content = item.get('block_content', '')
-                    
-                    # 转换bbox格式 [x1, y1, x2, y2] -> [x1, y1, x2, y1, x2, y2, x1, y2]
-                    if len(bbox) == 4:
-                        x1, y1, x2, y2 = bbox
-                        poly = [x1, y1, x2, y1, x2, y2, x1, y2]
-                    else:
-                        poly = bbox
-                    
-                    # 映射类别
-                    omni_category = self.category_mapping.get(category, 'text_block')
-                    
-                    # 创建layout检测结果
-                    layout_det = {
-                        "category_type": omni_category,
-                        "poly": poly,
-                        "ignore": False,
-                        "order": anno_id_counter,
-                        "anno_id": anno_id_counter,
-                    }
-                    
-                    # 添加文本识别结果
-                    if content and content.strip():
-                        if omni_category == 'table':
-                            # 表格内容作为HTML存储
-                            layout_det["html"] = content
-                        else:
-                            # 其他类型作为文本存储
-                            layout_det["text"] = content.strip()
-                    
-                    # 添加span级别的标注(从OCR结果中提取)
-                    layout_det["line_with_spans"] = self._extract_spans_from_ocr(
-                        res, bbox, omni_category
-                    )
-                    
-                    # 添加属性标签
-                    layout_det["attribute"] = self._extract_attributes(item, omni_category)
-                    
-                    layout_dets.append(layout_det)
-                    anno_id_counter += 1
-        
-        # 构建完整结果
-        result = {
-            "layout_dets": layout_dets,
-            "page_info": {
-                "page_no": 0,
-                "height": height,
-                "width": width,
-                "image_path": Path(image_path).name,
-                "page_attribute": self._extract_page_attributes(paddlex_output)
-            },
-            "extra": {
-                "relation": []  # 关系信息,需要根据具体情况提取
-            }
-        }
-        
-        return result
-    
-    def _extract_spans_from_ocr(self, res, block_bbox: List, category: str) -> List[Dict]:
-        """从OCR结果中提取span级别的标注"""
-        spans = []
-        
-        # 如果有OCR结果,提取相关的文本行
-        if hasattr(res, 'overall_ocr_res') and res.overall_ocr_res:
-            ocr_res = res.overall_ocr_res
-            
-            if hasattr(ocr_res, 'rec_texts') and hasattr(ocr_res, 'rec_boxes'):
-                texts = ocr_res.rec_texts
-                boxes = ocr_res.rec_boxes
-                scores = getattr(ocr_res, 'rec_scores', [1.0] * len(texts))
-                
-                # 检查哪些OCR结果在当前block内
-                if len(block_bbox) == 4:
-                    x1, y1, x2, y2 = block_bbox
-                    
-                    for i, (text, box, score) in enumerate(zip(texts, boxes, scores)):
-                        if len(box) >= 4:
-                            # 检查OCR框是否在block内
-                            ocr_x1, ocr_y1, ocr_x2, ocr_y2 = box[:4]
-                            
-                            # 简单的包含检查
-                            if (ocr_x1 >= x1 and ocr_y1 >= y1 and 
-                                ocr_x2 <= x2 and ocr_y2 <= y2):
-                                
-                                span = {
-                                    "category_type": "text_span",
-                                    "poly": [ocr_x1, ocr_y1, ocr_x2, ocr_y1, 
-                                            ocr_x2, ocr_y2, ocr_x1, ocr_y2],
-                                    "ignore": False,
-                                    "text": text,
-                                }
-                                
-                                # 如果置信度太低,可能需要忽略
-                                if score < 0.5:
-                                    span["ignore"] = True
-                                
-                                spans.append(span)
-        
-        return spans
-    
-    def _extract_attributes(self, item: Dict, category: str) -> Dict:
-        """提取属性标签"""
-        attributes = {}
-        
-        # 根据类别提取不同的属性
-        if category == 'table':
-            # 表格属性
-            attributes.update({
-                "table_layout": "vertical",  # 需要根据实际情况判断
-                "with_span": False,          # 需要检查是否有合并单元格
-                "line": "full_line",         # 需要检查线框类型
-                "language": "table_simplified_chinese",  # 需要语言检测
-                "include_equation": False,
-                "include_backgroud": False,
-                "table_vertical": False
-            })
-            
-            # 检查表格内容是否有合并单元格
-            content = item.get('block_content', '')
-            if 'colspan' in content or 'rowspan' in content:
-                attributes["with_span"] = True
-                
-        elif category in ['text_block', 'title']:
-            # 文本属性
-            attributes.update({
-                "text_language": "text_simplified_chinese",
-                "text_background": "white",
-                "text_rotate": "normal"
-            })
-            
-        elif 'equation' in category:
-            # 公式属性
-            attributes.update({
-                "formula_type": "print"
-            })
-        
-        return attributes
-    
-    def _extract_page_attributes(self, paddlex_output) -> Dict:
-        """提取页面级别的属性"""
-        return {
-            "data_source": "research_report",  # 需要根据实际情况判断
-            "language": "simplified_chinese",
-            "layout": "single_column",
-            "watermark": False,
-            "fuzzy_scan": False,
-            "colorful_backgroud": False
-        }
-    
-    def load_existing_result(self, result_path: str) -> Dict[str, Any]:
-        """
-        从已有的PaddleX结果文件加载数据进行转换
-        
-        Args:
-            result_path: PaddleX结果JSON文件路径
-            
-        Returns:
-            OmniDocBench格式的结果字典
-        """
-        with open(result_path, 'r', encoding='utf-8') as f:
-            data = json.load(f)
-        
-        # 从结果文件中提取图像信息
-        input_path = data.get('input_path', '')
-        
-        # 读取图像获取尺寸
-        if input_path and Path(input_path).exists():
-            image = cv2.imread(input_path)
-            height, width = image.shape[:2]
-            image_name = Path(input_path).name
-        else:
-            # 如果图像路径不存在,使用默认值
-            height, width = 1600, 1200
-            image_name = "unknown.png"
-        
-        # 转换格式
-        result = self._convert_paddlex_result_to_omnidocbench(
-            data, image_name, width, height
-        )
-        
-        return result
-    
-    def _convert_paddlex_result_to_omnidocbench(self, 
-                                              paddlex_result: Dict,
-                                              image_name: str,
-                                              width: int, 
-                                              height: int) -> Dict[str, Any]:
-        """
-        将已有的PaddleX结果转换为OmniDocBench格式
-        """
-        layout_dets = []
-        anno_id_counter = 0
-        
-        # 从parsing_res_list中提取布局信息
-        parsing_list = paddlex_result.get('parsing_res_list', [])
-        
-        for item in parsing_list:
-            # 提取边界框和类别
-            bbox = item.get('block_bbox', [])
-            category = item.get('block_label', 'text_block')
-            content = item.get('block_content', '')
-            
-            # 转换bbox格式
-            if len(bbox) == 4:
-                x1, y1, x2, y2 = bbox
-                poly = [x1, y1, x2, y1, x2, y2, x1, y2]
-            else:
-                poly = bbox
-            
-            # 映射类别
-            omni_category = self.category_mapping.get(category, 'text_block')
-            
-            # 创建layout检测结果
-            layout_det = {
-                "category_type": omni_category,
-                "poly": poly,
-                "ignore": False,
-                "order": anno_id_counter,
-                "anno_id": anno_id_counter,
-            }
-            
-            # 添加内容
-            if content and content.strip():
-                if omni_category == 'table':
-                    layout_det["html"] = content
-                else:
-                    layout_det["text"] = content.strip()
-            
-            # 添加属性
-            layout_det["attribute"] = self._extract_attributes(item, omni_category)
-            layout_det["line_with_spans"] = []  # 简化处理
-            
-            layout_dets.append(layout_det)
-            anno_id_counter += 1
-        
-        # 构建完整结果
-        result = {
-            "layout_dets": layout_dets,
-            "page_info": {
-                "page_no": 0,
-                "height": height,
-                "width": width,
-                "image_path": image_name,
-                "page_attribute": {
-                    "data_source": "research_report",
-                    "language": "simplified_chinese",
-                    "layout": "single_column",
-                    "watermark": False,
-                    "fuzzy_scan": False,
-                    "colorful_backgroud": False
-                }
-            },
-            "extra": {
-                "relation": []
-            }
-        }
-        
-        return result
-
-def convert_existing_results():
-    """转换已有的PaddleX结果"""
-    evaluator = OmniDocBenchEvaluator()
-    
-    # 示例:转换单个结果文件
-    result_file = "./sample_data/single_pipeline_output/PP-StructureV3-zhch/300674-母公司现金流量表-扫描_res.json"
-    
-    if Path(result_file).exists():
-        print(f"正在转换结果文件: {result_file}")
-        
-        omnidocbench_result = evaluator.load_existing_result(result_file)
-        
-        # 保存转换后的结果
-        output_file = "./omnidocbench_converted_result.json"
-        with open(output_file, 'w', encoding='utf-8') as f:
-            json.dump([omnidocbench_result], f, ensure_ascii=False, indent=2)
-        
-        print(f"转换完成,结果保存至: {output_file}")
-        print(f"检测到的布局元素数量: {len(omnidocbench_result['layout_dets'])}")
-        
-        # 显示检测到的元素
-        for i, item in enumerate(omnidocbench_result['layout_dets']):
-            print(f"  {i+1}. {item['category_type']}: {item.get('text', item.get('html', ''))[:50]}...")
-    
-    else:
-        print(f"结果文件不存在: {result_file}")
-
-if __name__ == "__main__":
-    convert_existing_results()
-```
+zhch/omnidocbench_eval.py
 
 ## 使用方法
 

+ 395 - 128
zhch/omnidocbench_eval.py

@@ -6,11 +6,17 @@ from typing import List, Dict, Any, Tuple
 import cv2
 import numpy as np
 from paddlex import create_pipeline
+import os
+import glob
+import traceback
 
 class OmniDocBenchEvaluator:
-    """OmniDocBench评估器(修正版),用于生成符合评测格式的结果"""
+    """
+    OmniDocBench评估器(修正版),用于生成符合评测格式的结果
+    pipeline_config_path = "paddlex/configs/pipelines/PP-StructureV3.yaml"
+    """
     
-    def __init__(self, pipeline_config_path: str = "./PP-StructureV3-zhch.yaml"):
+    def __init__(self, pipeline_config_path: str = "PP-StructureV3"):
         """
         初始化评估器
         
@@ -57,22 +63,31 @@ class OmniDocBenchEvaluator:
         
         # 读取图像获取尺寸信息
         image = cv2.imread(image_path)
+        if image is None:
+            print(f"无法读取图像: {image_path}")
+            return None
+            
         height, width = image.shape[:2]
         
         # 运行PaddleX pipeline
         start_time = time.time()
         
-        output = list(self.pipeline.predict(
-            input=image_path,
-            device="gpu" if use_gpu else "cpu",
-            use_doc_orientation_classify=True,
-            use_doc_unwarping=False,
-            use_seal_recognition=True,
-            use_chart_recognition=True,
-            use_table_recognition=True,
-            use_formula_recognition=True,
-            **kwargs
-        ))
+        try:
+            output = list(self.pipeline.predict(
+                input=image_path,
+                device="gpu" if use_gpu else "cpu",
+                use_doc_orientation_classify=True,
+                use_doc_unwarping=False,
+                use_seal_recognition=True,
+                use_chart_recognition=True,
+                use_table_recognition=True,
+                use_formula_recognition=True,
+                **kwargs
+            ))
+        except Exception as e:
+            print(f"处理图像 {image_path} 时发生错误: {str(e)}")
+            traceback.print_exc()
+            return None
         
         process_time = time.time() - start_time
         print(f"处理耗时: {process_time:.2f}秒")
@@ -106,54 +121,53 @@ class OmniDocBenchEvaluator:
         
         # 处理PaddleX的输出
         for res in paddlex_output:
+            res_json = res.json.get('res', {})
             # 从parsing_res_list中提取布局信息
-            if hasattr(res, 'parsing_res_list') and res.parsing_res_list:
-                parsing_list = res.parsing_res_list
+            parsing_list = res_json.get('parsing_res_list', [])
+            for item in parsing_list:
+                # 提取边界框和类别
+                bbox = item.get('block_bbox', [])
+                category = item.get('block_label', 'text_block')
+                content = item.get('block_content', '')
+                
+                # 转换bbox格式 [x1, y1, x2, y2] -> [x1, y1, x2, y1, x2, y2, x1, y2]
+                if len(bbox) == 4:
+                    x1, y1, x2, y2 = bbox
+                    poly = [x1, y1, x2, y1, x2, y2, x1, y2]
+                else:
+                    poly = bbox
+                
+                # 映射类别
+                omni_category = self.category_mapping.get(category, 'text_block')
+                
+                # 创建layout检测结果
+                layout_det = {
+                    "category_type": omni_category,
+                    "poly": poly,
+                    "ignore": False,
+                    "order": anno_id_counter,
+                    "anno_id": anno_id_counter,
+                }
                 
-                for item in parsing_list:
-                    # 提取边界框和类别
-                    bbox = item.get('block_bbox', [])
-                    category = item.get('block_label', 'text_block')
-                    content = item.get('block_content', '')
-                    
-                    # 转换bbox格式 [x1, y1, x2, y2] -> [x1, y1, x2, y1, x2, y2, x1, y2]
-                    if len(bbox) == 4:
-                        x1, y1, x2, y2 = bbox
-                        poly = [x1, y1, x2, y1, x2, y2, x1, y2]
+                # 添加文本识别结果
+                if content and content.strip():
+                    if omni_category == 'table':
+                        # 表格内容作为HTML存储
+                        layout_det["html"] = content
                     else:
-                        poly = bbox
-                    
-                    # 映射类别
-                    omni_category = self.category_mapping.get(category, 'text_block')
-                    
-                    # 创建layout检测结果
-                    layout_det = {
-                        "category_type": omni_category,
-                        "poly": poly,
-                        "ignore": False,
-                        "order": anno_id_counter,
-                        "anno_id": anno_id_counter,
-                    }
-                    
-                    # 添加文本识别结果
-                    if content and content.strip():
-                        if omni_category == 'table':
-                            # 表格内容作为HTML存储
-                            layout_det["html"] = content
-                        else:
-                            # 其他类型作为文本存储
-                            layout_det["text"] = content.strip()
-                    
-                    # 添加span级别的标注(从OCR结果中提取)
-                    layout_det["line_with_spans"] = self._extract_spans_from_ocr(
-                        res, bbox, omni_category
-                    )
-                    
-                    # 添加属性标签
-                    layout_det["attribute"] = self._extract_attributes(item, omni_category)
-                    
-                    layout_dets.append(layout_det)
-                    anno_id_counter += 1
+                        # 其他类型作为文本存储
+                        layout_det["text"] = content.strip()
+                
+                # 添加span级别的标注(从OCR结果中提取)
+                layout_det["line_with_spans"] = self._extract_spans_from_ocr(
+                    res_json, bbox, omni_category
+                )
+                
+                # 添加属性标签
+                layout_det["attribute"] = self._extract_attributes(item, omni_category)
+                
+                layout_dets.append(layout_det)
+                anno_id_counter += 1
         
         # 构建完整结果
         result = {
@@ -177,40 +191,38 @@ class OmniDocBenchEvaluator:
         spans = []
         
         # 如果有OCR结果,提取相关的文本行
-        if hasattr(res, 'overall_ocr_res') and res.overall_ocr_res:
-            ocr_res = res.overall_ocr_res
+        ocr_res = res.get('overall_ocr_res', None)
+        if ocr_res:
+            texts = ocr_res.get('rec_texts', [])
+            boxes = ocr_res.get('rec_boxes', [])
+            scores = ocr_res.get('rec_scores', [1.0] * len(texts)) if 'rec_scores' in ocr_res else [1.0] * len(texts)
             
-            if hasattr(ocr_res, 'rec_texts') and hasattr(ocr_res, 'rec_boxes'):
-                texts = ocr_res.rec_texts
-                boxes = ocr_res.rec_boxes
-                scores = getattr(ocr_res, 'rec_scores', [1.0] * len(texts))
+            # 检查哪些OCR结果在当前block内
+            if len(block_bbox) == 4:
+                x1, y1, x2, y2 = block_bbox
                 
-                # 检查哪些OCR结果在当前block内
-                if len(block_bbox) == 4:
-                    x1, y1, x2, y2 = block_bbox
-                    
-                    for i, (text, box, score) in enumerate(zip(texts, boxes, scores)):
-                        if len(box) >= 4:
-                            # 检查OCR框是否在block内
-                            ocr_x1, ocr_y1, ocr_x2, ocr_y2 = box[:4]
+                for i, (text, box, score) in enumerate(zip(texts, boxes, scores)):
+                    if len(box) >= 4:
+                        # 检查OCR框是否在block内
+                        ocr_x1, ocr_y1, ocr_x2, ocr_y2 = box[:4]
+                        
+                        # 简单的包含检查
+                        if (ocr_x1 >= x1 and ocr_y1 >= y1 and 
+                            ocr_x2 <= x2 and ocr_y2 <= y2):
                             
-                            # 简单的包含检查
-                            if (ocr_x1 >= x1 and ocr_y1 >= y1 and 
-                                ocr_x2 <= x2 and ocr_y2 <= y2):
-                                
-                                span = {
-                                    "category_type": "text_span",
-                                    "poly": [ocr_x1, ocr_y1, ocr_x2, ocr_y1, 
-                                            ocr_x2, ocr_y2, ocr_x1, ocr_y2],
-                                    "ignore": False,
-                                    "text": text,
-                                }
-                                
-                                # 如果置信度太低,可能需要忽略
-                                if score < 0.5:
-                                    span["ignore"] = True
-                                
-                                spans.append(span)
+                            span = {
+                                "category_type": "text_span",
+                                "poly": [ocr_x1, ocr_y1, ocr_x2, ocr_y1, 
+                                        ocr_x2, ocr_y2, ocr_x1, ocr_y2],
+                                "ignore": False,
+                                "text": text,
+                            }
+                            
+                            # 如果置信度太低,可能需要忽略
+                            if score < 0.5:
+                                span["ignore"] = True
+                            
+                            spans.append(span)
         
         return spans
     
@@ -273,28 +285,210 @@ class OmniDocBenchEvaluator:
         Returns:
             OmniDocBench格式的结果字典
         """
-        with open(result_path, 'r', encoding='utf-8') as f:
-            data = json.load(f)
+        if not Path(result_path).exists():
+            print(f"结果文件不存在: {result_path}")
+            return None
+            
+        try:
+            with open(result_path, 'r', encoding='utf-8') as f:
+                data = json.load(f)
+            
+            # 从结果文件中提取图像信息
+            input_path = data.get('input_path', '')
+            
+            # 读取图像获取尺寸
+            if input_path and Path(input_path).exists():
+                image = cv2.imread(input_path)
+                if image is None:
+                    print(f"无法读取图像: {input_path}")
+                    height, width = 1600, 1200
+                    image_name = "unknown.png"
+                else:
+                    height, width = image.shape[:2]
+                    image_name = Path(input_path).name
+            else:
+                # 如果图像路径不存在,使用默认值
+                height, width = 1600, 1200
+                image_name = "unknown.png"
+            
+            # 转换格式
+            result = self._convert_paddlex_result_to_omnidocbench(
+                data, image_name, width, height
+            )
+            
+            return result
+            
+        except Exception as e:
+            print(f"加载结果文件 {result_path} 时发生错误: {str(e)}")
+            traceback.print_exc()
+            return None
+    
+    def load_omnidocbench_dataset(self, dataset_dir: str) -> Dict[str, Any]:
+        """
+        加载OmniDocBench数据集
+        
+        Args:
+            dataset_dir: 数据集目录路径
+            
+        Returns:
+            加载的数据集字典
+        """
+        dataset = {}
+        
+        # 遍历数据集目录
+        for file_path in Path(dataset_dir).rglob('*.json'):
+            if "pred" in file_path.name or "result" in file_path.name:
+                continue
+                
+            try:
+                with open(file_path, 'r', encoding='utf-8') as f:
+                    data = json.load(f)
+                
+                # 提取文件ID
+                file_id = file_path.stem
+                dataset[file_id] = {
+                    "ground_truth": data,
+                    "predictions": None,
+                    "image_path": self._find_image_file(file_path.parent, file_id)
+                }
+                
+            except Exception as e:
+                print(f"加载文件 {file_path} 出错: {str(e)}")
         
-        # 从结果文件中提取图像信息
-        input_path = data.get('input_path', '')
+        return dataset
+    
+    def _find_image_file(self, search_dir: Path, file_id: str) -> str:
+        """查找对应的图像文件"""
+        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']
         
-        # 读取图像获取尺寸
-        if input_path and Path(input_path).exists():
-            image = cv2.imread(input_path)
-            height, width = image.shape[:2]
-            image_name = Path(input_path).name
-        else:
-            # 如果图像路径不存在,使用默认值
-            height, width = 1600, 1200
-            image_name = "unknown.png"
+        for ext in image_extensions:
+            image_path = search_dir / f"{file_id}{ext}"
+            if image_path.exists():
+                return str(image_path)
         
-        # 转换格式
-        result = self._convert_paddlex_result_to_omnidocbench(
-            data, image_name, width, height
-        )
+        # 如果找不到图像文件,尝试在子目录中查找
+        for subdir in search_dir.iterdir():
+            if subdir.is_dir():
+                result = self._find_image_file(subdir, file_id)
+                if result:
+                    return result
         
-        return result
+        return ""
+    
+    def generate_predictions(self, dataset: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        为数据集生成预测
+        
+        Args:
+            dataset: 数据集字典
+            
+        Returns:
+            包含预测结果的数据集字典
+        """
+        for file_id, item in dataset.items():
+            if item["image_path"] and Path(item["image_path"]).exists():
+                try:
+                    result = self.evaluate_single_image(item["image_path"])
+                    dataset[file_id]["predictions"] = result
+                    print(f"成功处理文件: {file_id}")
+                except Exception as e:
+                    print(f"处理文件 {file_id} 出错: {str(e)}")
+            else:
+                print(f"图像文件不存在: {item['image_path']}")
+        
+        return dataset
+    
+    def evaluate_dataset(self, dataset: Dict[str, Any]) -> Dict[str, float]:
+        """
+        评估数据集的预测结果
+        
+        Args:
+            dataset: 包含预测和真实标签的数据集字典
+            
+        Returns:
+            包含评估指标的字典
+        """
+        metrics = {
+            "precision": 0.0,
+            "recall": 0.0,
+            "f1_score": 0.0,
+            "iou": 0.0
+        }
+        
+        # 实现具体的评估逻辑
+        # 这里只是一个示例框架
+        total_precision = 0.0
+        total_recall = 0.0
+        total_f1 = 0.0
+        total_iou = 0.0
+        count = 0
+        
+        for file_id, item in dataset.items():
+            if item["predictions"] is None:
+                continue
+                
+            # 计算单个样本的评估指标
+            sample_metrics = self._calculate_sample_metrics(item["ground_truth"], item["predictions"])
+            
+            total_precision += sample_metrics["precision"]
+            total_recall += sample_metrics["recall"]
+            total_f1 += sample_metrics["f1_score"]
+            total_iou += sample_metrics["iou"]
+            count += 1
+        
+        if count > 0:
+            metrics["precision"] = total_precision / count
+            metrics["recall"] = total_recall / count
+            metrics["f1_score"] = total_f1 / count
+            metrics["iou"] = total_iou / count
+        
+        return metrics
+    
+    def _calculate_sample_metrics(self, ground_truth: Dict, prediction: Dict) -> Dict[str, float]:
+        """
+        计算单个样本的评估指标
+        
+        Args:
+            ground_truth: 真实标签数据
+            prediction: 预测结果
+            
+        Returns:
+            包含评估指标的字典
+        """
+        metrics = {
+            "precision": 0.0,
+            "recall": 0.0,
+            "f1_score": 0.0,
+            "iou": 0.0
+        }
+        
+        # 实现具体的评估逻辑
+        # 这里只是一个示例框架
+        gt_layouts = ground_truth.get("layout_dets", [])
+        pred_layouts = prediction.get("layout_dets", [])
+        
+        # 简单的类别匹配计算
+        gt_categories = set(item.get("category_type", "") for item in gt_layouts)
+        pred_categories = set(item.get("category_type", "") for item in pred_layouts)
+        
+        # 计算交并集
+        intersection = len(gt_categories.intersection(pred_categories))
+        union = len(gt_categories.union(pred_categories))
+        
+        if union > 0:
+            metrics["iou"] = intersection / union
+        
+        # 计算精确度、召回率、F1分数
+        if len(pred_categories) > 0:
+            metrics["precision"] = intersection / len(pred_categories)
+        
+        if len(gt_categories) > 0:
+            metrics["recall"] = intersection / len(gt_categories)
+        
+        if metrics["precision"] + metrics["recall"] > 0:
+            metrics["f1_score"] = 2 * (metrics["precision"] * metrics["recall"]) / (metrics["precision"] + metrics["recall"])
+        
+        return metrics
     
     def _convert_paddlex_result_to_omnidocbench(self, 
                                               paddlex_result: Dict,
@@ -375,30 +569,103 @@ class OmniDocBenchEvaluator:
 
 def convert_existing_results():
     """Convert an existing PaddleX result."""
-    evaluator = OmniDocBenchEvaluator()
-    
-    # Example: convert a single result file
-    result_file = "./sample_data/single_pipeline_output/PP-StructureV3-zhch/300674-母公司现金流量表-扫描_res.json"
-    
-    if Path(result_file).exists():
-        print(f"正在转换结果文件: {result_file}")
+    # Outer guard: any exception raised below is printed together with its
+    # traceback instead of propagating to the caller.
+    try:
+        evaluator = OmniDocBenchEvaluator()
         
-        omnidocbench_result = evaluator.load_existing_result(result_file)
+        # Example: convert a single result file
+        result_file = "./sample_data/single_pipeline_output/PP-StructureV3-zhch/300674-母公司现金流量表-扫描_res.json"
         
-        # Save the converted result
-        output_file = "./omnidocbench_converted_result.json"
-        with open(output_file, 'w', encoding='utf-8') as f:
-            json.dump([omnidocbench_result], f, ensure_ascii=False, indent=2)
+        if Path(result_file).exists():
+            print(f"正在转换结果文件: {result_file}")
+            
+            omnidocbench_result = evaluator.load_existing_result(result_file)
+            # A None result is reported and nothing is written.
+            if omnidocbench_result is None:
+                print(f"转换结果为空: {result_file}")
+                return
+                
+            # Save the converted result
+            output_file = "./omnidocbench_converted_result.json"
+            try:
+                with open(output_file, 'w', encoding='utf-8') as f:
+                    # The single result is wrapped in a one-element list.
+                    json.dump([omnidocbench_result], f, ensure_ascii=False, indent=2)
+                
+                print(f"转换完成,结果保存至: {output_file}")
+                print(f"检测到的布局元素数量: {len(omnidocbench_result['layout_dets'])}")
+                
+                # Show the detected elements: category plus the first 50
+                # characters of 'text', falling back to 'html', then ''.
+                for i, item in enumerate(omnidocbench_result['layout_dets']):
+                    print(f"  {i+1}. {item['category_type']}: {item.get('text', item.get('html', ''))[:50]}...")
+            except Exception as e:
+                print(f"保存结果到文件 {output_file} 时发生错误: {str(e)}")
+                traceback.print_exc()
+        else:
+            print(f"结果文件不存在: {result_file}")
+            
+    except Exception as e:
+        print(f"转换过程中发生致命错误: {str(e)}")
+        traceback.print_exc()
+
+def process_omnidocbench_dataset():
+    """Process the OmniDocBench dataset.
+
+    Recursively collects image files under a hard-coded dataset directory,
+    runs ``evaluator.evaluate_single_image`` on the first 10 of them, and
+    dumps the non-None results as one JSON list into the result directory.
+    Errors are printed with their traceback rather than raised.
+    """
+    try:
+        # Initialize the evaluator
+        evaluator = OmniDocBenchEvaluator()
         
-        print(f"转换完成,结果保存至: {output_file}")
-        print(f"检测到的布局元素数量: {len(omnidocbench_result['layout_dets'])}")
+        # Dataset and output paths (hard-coded absolute paths)
+        dataset_path = "/home/ubuntu/zhch/OmniDocBench/OpenDataLab___OmniDocBench"
+        result_dir = "/home/ubuntu/zhch/PaddleX/zhch/OmniDocBench_Result"
         
-        # Show the detected elements
-        for i, item in enumerate(omnidocbench_result['layout_dets']):
-            print(f"  {i+1}. {item['category_type']}: {item.get('text', item.get('html', ''))[:50]}...")
-    
-    else:
-        print(f"结果文件不存在: {result_file}")
+        # Make sure the result directory exists
+        os.makedirs(result_dir, exist_ok=True)
+        
+        # Find all image files
+        # NOTE(review): the patterns are lowercase only; whether e.g. ".JPG"
+        # also matches depends on the platform's case sensitivity - confirm.
+        image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff']
+        image_files = []
+        
+        for ext in image_extensions:
+            image_files.extend(glob.glob(os.path.join(dataset_path, '**', ext), recursive=True))
+        
+        print(f"找到 {len(image_files)} 个图像文件")
+        
+        if not image_files:
+            print("未找到任何图像文件,程序终止")
+            return
+        
+        # Collect all results
+        all_results = []
+        
+        # Process each image
+        for i, image_path in enumerate(image_files[:10]):  # only the first 10 files are processed, for testing
+            try:
+                print(f"处理进度: {i+1}/{len(image_files[:10])}")
+                
+                # Process a single image; None results are dropped
+                result = evaluator.evaluate_single_image(image_path)
+                if result is not None:
+                    all_results.append(result)
+                
+            except Exception as e:
+                # A failure on one image is reported and the loop moves on.
+                print(f"处理文件 {image_path} 时出错: {str(e)}")
+                traceback.print_exc()
+                continue
+        
+        # Save the results
+        output_file = os.path.join(result_dir, "OmniDocBench-PPStructureV3.json")
+        try:
+            with open(output_file, 'w', encoding='utf-8') as f:
+                json.dump(all_results, f, ensure_ascii=False, indent=2)
+            
+            print(f"处理完成,结果保存至: {output_file}")
+            print(f"共处理 {len(all_results)} 个文件")
+        except Exception as e:
+            print(f"保存结果文件时发生错误: {str(e)}")
+            traceback.print_exc()
+            
+    except Exception as e:
+        print(f"处理数据集时发生致命错误: {str(e)}")
+        traceback.print_exc()
+
 
 if __name__ == "__main__":
-    convert_existing_results()
+    # Script entry point: the dataset-processing flow is the active one; the
+    # single-file conversion call below is left commented out.
+    # convert_existing_results()
+    process_omnidocbench_dataset()