|
|
@@ -604,11 +604,11 @@ def process_all_images_in_content(content: str, json_path: str) -> str:
|
|
|
|
|
|
# 修改 load_ocr_data_file 函数
|
|
|
def load_ocr_data_file(json_path: str, config: Dict) -> Tuple[List, str, str]:
|
|
|
- """加载OCR相关数据文件"""
|
|
|
+ """加载OCR数据文件 - 支持多数据源配置"""
|
|
|
json_file = Path(json_path)
|
|
|
- ocr_data = []
|
|
|
- md_content = ""
|
|
|
- image_path = ""
|
|
|
+
|
|
|
+ if not json_file.exists():
|
|
|
+ raise FileNotFoundError(f"找不到JSON文件: {json_path}")
|
|
|
|
|
|
# 加载JSON数据
|
|
|
try:
|
|
|
@@ -633,36 +633,37 @@ def load_ocr_data_file(json_path: str, config: Dict) -> Tuple[List, str, str]:
|
|
|
|
|
|
# 加载MD文件
|
|
|
md_file = json_file.with_suffix('.md')
|
|
|
+ md_content = ""
|
|
|
if md_file.exists():
|
|
|
with open(md_file, 'r', encoding='utf-8') as f:
|
|
|
- raw_md_content = f.read()
|
|
|
-
|
|
|
- # 处理内容中的所有图片引用(HTML和Markdown)
|
|
|
- md_content = process_all_images_in_content(raw_md_content, str(json_file))
|
|
|
-
|
|
|
- # 推断图片路径
|
|
|
- image_name = json_file.stem
|
|
|
- src_img_dir = Path(config['paths']['src_img_dir'])
|
|
|
-
|
|
|
- image_candidates = []
|
|
|
- for ext in config['paths']['supported_image_formats']:
|
|
|
- image_candidates.extend([
|
|
|
- src_img_dir / f"{image_name}{ext}",
|
|
|
- json_file.parent / f"{image_name}{ext}",
|
|
|
- # 对于PPStructV3,可能图片名包含page信息 # 去掉page后缀的通用匹配
|
|
|
- src_img_dir / f"{image_name.split('_page_')[0]}{ext}" if '_page_' in image_name else None,
|
|
|
- ])
|
|
|
-
|
|
|
- # 移除None值
|
|
|
- image_candidates = [candidate for candidate in image_candidates if candidate is not None]
|
|
|
-
|
|
|
- for candidate in image_candidates:
|
|
|
- if candidate.exists():
|
|
|
- image_path = str(candidate)
|
|
|
- break
|
|
|
+ md_content = f.read()
|
|
|
+
|
|
|
+ # 查找对应的图片文件
|
|
|
+ image_path = find_corresponding_image(json_file, config)
|
|
|
|
|
|
return ocr_data, md_content, image_path
|
|
|
|
|
|
def find_corresponding_image(json_file: Path, config: Dict) -> str:
    """Locate the source image that belongs to an OCR JSON result file.

    The image must share the JSON file's stem. The configured image
    directory is searched first; the JSON file's own directory is used as a
    fallback, so an image stored next to the OCR output is still found when
    the configured directory does not contain it.

    Args:
        json_file: Path of the OCR JSON output file.
        config: Configuration mapping. ``config['paths']['src_img_dir']``
            names the image directory, and
            ``config['paths']['supported_image_formats']`` may list extra
            extensions (with the leading dot) to try after the built-in ones.

    Returns:
        The path of the first matching image file as a string, or ``""``
        when no image is found.
    """
    # Tolerate a missing or explicitly null 'paths' section.
    paths_cfg = config.get('paths') or {}

    # No configured image directory: look next to the JSON file.
    src_img_dir = paths_cfg.get('src_img_dir') or json_file.parent

    search_dirs = [Path(src_img_dir)]
    if json_file.parent not in search_dirs:
        search_dirs.append(json_file.parent)

    # Built-in formats keep their original priority; configured formats are
    # appended so they can only add matches, never reorder existing ones.
    extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
    for ext in paths_cfg.get('supported_image_formats') or []:
        if ext not in extensions:
            extensions.append(ext)

    for directory in search_dirs:
        for ext in extensions:
            candidate = directory / f"{json_file.stem}{ext}"
            # is_file() rather than exists(): a directory is never an image.
            if candidate.is_file():
                return str(candidate)

    # Nothing matched.
    return ""
|
|
|
|
|
|
def process_ocr_data(ocr_data: List, config: Dict) -> Dict[str, List]:
|
|
|
"""处理OCR数据,建立文本到bbox的映射"""
|
|
|
@@ -1046,4 +1047,53 @@ def detect_image_orientation_by_opencv(image_path: str) -> Dict:
|
|
|
'method': 'opencv_analysis',
|
|
|
'error': str(e),
|
|
|
'message': f'OpenCV检测过程中发生错误: {str(e)}'
|
|
|
- }
|
|
|
+ }
|
|
|
+
|
|
|
+# ocr_validator_utils.py
|
|
|
def find_available_ocr_files_multi_source(config: Dict) -> Dict[str, List[Dict]]:
    """Collect the OCR result files of every configured data source.

    Each entry of ``config['data_sources']`` must provide ``name``,
    ``ocr_tool`` and ``ocr_out_dir``. Sources whose output directory does
    not exist are skipped. Every file record returned by
    ``find_available_ocr_files`` is updated in place with the metadata of
    the source it came from.

    Args:
        config: Configuration mapping with an optional ``data_sources`` list.

    Returns:
        A dict keyed by ``"<name>_<ocr_tool>"``. Each value is a dict with
        ``'files'`` (the annotated file records) and ``'config'`` (the
        source entry itself).
        NOTE(review): the declared return annotation does not describe this
        value shape; it is left unchanged here.
    """
    all_sources = {}

    for entry in config.get('data_sources', []):
        name = entry['name']
        tool = entry['ocr_tool']
        source_key = f"{name}_{tool}"  # unique identifier of the source
        out_dir = entry['ocr_out_dir']

        # Ignore sources whose output directory is absent.
        if not Path(out_dir).exists():
            continue

        files = find_available_ocr_files(out_dir)

        # Metadata shared by every file record of this source.
        source_meta = {
            'source_name': name,
            'ocr_tool': tool,
            'description': entry.get('description', ''),
            'src_img_dir': entry.get('src_img_dir', ''),
            'ocr_out_dir': out_dir,
        }
        for file_info in files:
            file_info.update(source_meta)

        all_sources[source_key] = {'files': files, 'config': entry}
        print(f"📁 找到数据源: {source_key} - {len(files)} 个文件")

    return all_sources
|
|
|
+
|
|
|
def get_data_source_display_name(source_config: Dict) -> str:
    """Build the human-readable label of a data source.

    Args:
        source_config: Data-source entry; must contain ``'name'`` and
            ``'ocr_tool'``.

    Returns:
        ``"<name> (<tool display name>)"``. Known tool identifiers are
        mapped to a friendly name; unknown identifiers are shown verbatim.

    Raises:
        KeyError: If ``'name'`` or ``'ocr_tool'`` is missing.
    """
    # Friendly names for the known OCR tools.
    tool_name_map = {
        'dots_ocr': 'Dots OCR',
        'ppstructv3': 'PPStructV3',
    }

    name = source_config['name']
    tool = source_config['ocr_tool']

    return f"{name} ({tool_name_map.get(tool, tool)})"
|