浏览代码

优化load_ocr_data_file函数以支持多数据源配置,新增find_corresponding_image函数查找对应图片文件,添加find_available_ocr_files_multi_source函数以查找多个数据源的OCR文件

zhch158_admin 1 月之前
父节点
当前提交
5097154aca
共有 1 个文件被更改,包括 80 次插入30 次删除
  1. 80 30
      ocr_validator_utils.py

+ 80 - 30
ocr_validator_utils.py

@@ -604,11 +604,11 @@ def process_all_images_in_content(content: str, json_path: str) -> str:
 
 # 修改 load_ocr_data_file 函数
 def load_ocr_data_file(json_path: str, config: Dict) -> Tuple[List, str, str]:
-    """加载OCR相关数据文件"""
+    """加载OCR数据文件 - 支持多数据源配置"""
     json_file = Path(json_path)
-    ocr_data = []
-    md_content = ""
-    image_path = ""
+    
+    if not json_file.exists():
+        raise FileNotFoundError(f"找不到JSON文件: {json_path}")
     
     # 加载JSON数据
     try:
@@ -633,36 +633,37 @@ def load_ocr_data_file(json_path: str, config: Dict) -> Tuple[List, str, str]:
     
     # 加载MD文件
     md_file = json_file.with_suffix('.md')
+    md_content = ""
     if md_file.exists():
         with open(md_file, 'r', encoding='utf-8') as f:
-            raw_md_content = f.read()
-            
-        # 处理内容中的所有图片引用(HTML和Markdown)
-        md_content = process_all_images_in_content(raw_md_content, str(json_file))
-    
-    # 推断图片路径
-    image_name = json_file.stem
-    src_img_dir = Path(config['paths']['src_img_dir'])
-    
-    image_candidates = []
-    for ext in config['paths']['supported_image_formats']:
-        image_candidates.extend([
-            src_img_dir / f"{image_name}{ext}",
-            json_file.parent / f"{image_name}{ext}",
-            # 对于PPStructV3,可能图片名包含page信息 # 去掉page后缀的通用匹配
-            src_img_dir / f"{image_name.split('_page_')[0]}{ext}" if '_page_' in image_name else None,
-        ])
-    
-    # 移除None值
-    image_candidates = [candidate for candidate in image_candidates if candidate is not None]
-    
-    for candidate in image_candidates:
-        if candidate.exists():
-            image_path = str(candidate)
-            break
+            md_content = f.read()
+    
+    # 查找对应的图片文件
+    image_path = find_corresponding_image(json_file, config)
     
     return ocr_data, md_content, image_path
 
+def find_corresponding_image(json_file: Path, config: Dict) -> str:
+    """查找对应的图片文件 - 支持多数据源"""
+    # 从配置中获取图片目录
+    src_img_dir = config.get('paths', {}).get('src_img_dir', '')
+    
+    if not src_img_dir:
+        # 如果没有配置图片目录,尝试在JSON文件同级目录查找
+        src_img_dir = json_file.parent
+    
+    src_img_path = Path(src_img_dir)
+    
+    # 支持多种图片格式
+    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
+    
+    for ext in image_extensions:
+        image_file = src_img_path / f"{json_file.stem}{ext}"
+        if image_file.exists():
+            return str(image_file)
+    
+    # 如果找不到,返回空字符串
+    return ""
 
 def process_ocr_data(ocr_data: List, config: Dict) -> Dict[str, List]:
     """处理OCR数据,建立文本到bbox的映射"""
@@ -1046,4 +1047,53 @@ def detect_image_orientation_by_opencv(image_path: str) -> Dict:
             'method': 'opencv_analysis',
             'error': str(e),
             'message': f'OpenCV检测过程中发生错误: {str(e)}'
-        }
+        }
+
+# ocr_validator_utils.py
+def find_available_ocr_files_multi_source(config: Dict) -> Dict[str, List[Dict]]:
+    """查找多个数据源的OCR文件"""
+    all_sources = {}
+    
+    for source in config.get('data_sources', []):
+        source_name = source['name']
+        ocr_tool = source['ocr_tool']
+        source_key = f"{source_name}_{ocr_tool}"  # 创建唯一标识
+        
+        ocr_out_dir = source['ocr_out_dir']
+        
+        if Path(ocr_out_dir).exists():
+            files = find_available_ocr_files(ocr_out_dir)
+            
+            # 为每个文件添加数据源信息
+            for file_info in files:
+                file_info.update({
+                    'source_name': source_name,
+                    'ocr_tool': ocr_tool,
+                    'description': source.get('description', ''),
+                    'src_img_dir': source.get('src_img_dir', ''),
+                    'ocr_out_dir': ocr_out_dir
+                })
+            
+            all_sources[source_key] = {
+                'files': files,
+                'config': source
+            }
+            
+            print(f"📁 找到数据源: {source_key} - {len(files)} 个文件")
+    
+    return all_sources
+
+def get_data_source_display_name(source_config: Dict) -> str:
+    """生成数据源的显示名称"""
+    name = source_config['name']
+    tool = source_config['ocr_tool']
+    description = source_config.get('description', '')
+    
+    # 获取工具的友好名称
+    tool_name_map = {
+        'dots_ocr': 'Dots OCR',
+        'ppstructv3': 'PPStructV3'
+    }
+    
+    tool_display = tool_name_map.get(tool, tool)
+    return f"{name} ({tool_display})"