Просмотр исходного кода

优化图像旋转和坐标处理,新增Markdown图片引用的base64转换功能

zhch158_admin 2 месяцев назад
Родитель
Сommit
2fd2181eb9
1 измененных файлов с 118 добавлено и 7 удалено
  1. 118 7
      ocr_validator_utils.py

+ 118 - 7
ocr_validator_utils.py

@@ -13,6 +13,9 @@ from io import StringIO, BytesIO
 import re
 from html import unescape
 import yaml
+import base64
+from urllib.parse import urlparse
+import os
 
 
 def load_config(config_path: str = "config.yaml") -> Dict:
@@ -81,18 +84,24 @@ def load_css_styles(css_path: str = "styles.css") -> str:
         """
 
 
-def rotate_image_and_coordinates(image: Image.Image, angle: float, coordinates_list: List[List[int]]) -> Tuple[Image.Image, List[List[int]]]:
+def rotate_image_and_coordinates(
+    image: Image.Image, 
+    angle: float, 
+    coordinates_list: List[List[int]], 
+    rotate_coordinates: bool = True
+) -> Tuple[Image.Image, List[List[int]]]:
     """
-    根据角度旋转图像和坐标 - 修复坐标变换和图片显示
+    根据角度旋转图像和坐标
     
     Args:
         image: 原始图像
         angle: 旋转角度(度数)
         coordinates_list: 坐标列表,每个坐标为[x1, y1, x2, y2]格式
+        rotate_coordinates: 是否需要旋转坐标(针对不同OCR工具的处理方式)
     
     Returns:
         rotated_image: 旋转后的图像
-        rotated_coordinates: 旋转后的坐标列表
+        rotated_coordinates: 处理后的坐标列表
     """
     if angle == 0:
         return image, coordinates_list
@@ -110,6 +119,10 @@ def rotate_image_and_coordinates(image: Image.Image, angle: float, coordinates_l
     # 旋转图像
     rotated_image = image.rotate(rotation_angle, expand=True)
     
+    # 如果不需要旋转坐标,直接返回原坐标
+    if not rotate_coordinates:
+        return rotated_image, coordinates_list
+    
     # 获取原始和旋转后的图像尺寸
     orig_width, orig_height = image.size
     new_width, new_height = rotated_image.size
@@ -124,7 +137,13 @@ def rotate_image_and_coordinates(image: Image.Image, angle: float, coordinates_l
             
         x1, y1, x2, y2 = coord[:4]
         
-        # 根据旋转角度变换坐标 - 修复变换逻辑
+        # 验证原始坐标是否有效
+        if x1 < 0 or y1 < 0 or x2 <= x1 or y2 <= y1:
+            print(f"警告: 无效坐标 {coord}")
+            rotated_coordinates.append([0, 0, 50, 50])  # 使用默认坐标
+            continue
+        
+        # 根据旋转角度变换坐标
         if rotation_angle == -90:  # 顺时针90度 (270度逆时针)
             # 变换公式: (x, y) -> (y, orig_width - x)
             new_x1 = y1
@@ -322,6 +341,65 @@ def get_rotation_angle_from_ppstructv3(data: Dict) -> float:
     return 0.0
 
 
+def process_markdown_images(md_content: str, json_path: str) -> str:
+    """
+    处理Markdown中的图片引用,将本地图片转换为base64
+    """
+    import re
+    
+    # 匹配Markdown图片语法: ![alt](path)
+    img_pattern = r'!\[([^\]]*)\]\(([^)]+)\)'
+    
+    def replace_image(match):
+        alt_text = match.group(1)
+        img_path = match.group(2)
+        
+        # 如果已经是base64或者网络链接,直接返回
+        if img_path.startswith('data:image') or img_path.startswith('http'):
+            return match.group(0)
+        
+        # 处理相对路径
+        if not os.path.isabs(img_path):
+            # 相对于JSON文件的路径
+            json_dir = os.path.dirname(json_path)
+            full_img_path = os.path.join(json_dir, img_path)
+        else:
+            full_img_path = img_path
+        
+        # 尝试转换为base64
+        try:
+            if os.path.exists(full_img_path):
+                with open(full_img_path, 'rb') as img_file:
+                    img_data = img_file.read()
+                    
+                # 获取文件扩展名确定MIME类型
+                ext = os.path.splitext(full_img_path)[1].lower()
+                mime_type = {
+                    '.png': 'image/png',
+                    '.jpg': 'image/jpeg',
+                    '.jpeg': 'image/jpeg',
+                    '.gif': 'image/gif',
+                    '.bmp': 'image/bmp',
+                    '.webp': 'image/webp'
+                }.get(ext, 'image/jpeg')
+                
+                # 转换为base64
+                img_base64 = base64.b64encode(img_data).decode('utf-8')
+                data_url = f"data:{mime_type};base64,{img_base64}"
+                
+                return f'![{alt_text}]({data_url})'
+            else:
+                # 文件不存在,返回原始链接但添加警告
+                return f'![{alt_text} (文件不存在)]({img_path})'
+        except Exception as e:
+            # 转换失败,返回原始链接
+            return f'![{alt_text} (加载失败)]({img_path})'
+    
+    # 替换所有图片引用
+    processed_content = re.sub(img_pattern, replace_image, md_content)
+    return processed_content
+
+
 def load_ocr_data_file(json_path: str, config: Dict) -> Tuple[List, str, str]:
     """加载OCR相关数据文件"""
     json_file = Path(json_path)
@@ -354,7 +432,10 @@ def load_ocr_data_file(json_path: str, config: Dict) -> Tuple[List, str, str]:
     md_file = json_file.with_suffix('.md')
     if md_file.exists():
         with open(md_file, 'r', encoding='utf-8') as f:
-            md_content = f.read()
+            raw_md_content = f.read()
+            
+        # 处理Markdown中的图片引用
+        md_content = process_markdown_images(raw_md_content, str(json_file))
     
     # 推断图片路径
     image_name = json_file.stem
@@ -425,7 +506,9 @@ def find_available_ocr_files(ocr_out_dir: str) -> List[str]:
             # 递归搜索JSON文件
             for json_file in search_dir.rglob("*.json"):
                 available_files.append(str(json_file))
-    
+    # 去重并排序
+    available_files = sorted(list(set(available_files)))
+
     return available_files
 
 
@@ -610,4 +693,32 @@ def group_texts_by_category(text_bbox_mapping: Dict[str, List]) -> Dict[str, Lis
         if category not in categories:
             categories[category] = []
         categories[category].append(text)
-    return categories
+    return categories
+
+
+def get_ocr_tool_rotation_config(ocr_data: List, config: Dict) -> Dict:
+    """获取OCR工具的旋转配置"""
+    if not ocr_data or not isinstance(ocr_data, list):
+        # 默认配置
+        return {
+            'coordinates_need_rotation': True,
+            'coordinates_are_pre_rotated': False
+        }
+    
+    # 从第一个OCR数据项获取工具类型
+    first_item = ocr_data[0] if ocr_data else {}
+    source_tool = first_item.get('source_tool', 'dots_ocr')
+    
+    # 获取工具配置
+    tools_config = config.get('ocr', {}).get('tools', {})
+    
+    if source_tool in tools_config:
+        tool_config = tools_config[source_tool]
+        return tool_config.get('rotation', {
+            'coordinates_are_pre_rotated': False
+        })
+    else:
+        # 默认配置
+        return {
+            'coordinates_are_pre_rotated': False
+        }