|
|
@@ -8,7 +8,7 @@ import pandas as pd
|
|
|
import numpy as np
|
|
|
from pathlib import Path
|
|
|
from PIL import Image, ImageDraw
|
|
|
-from typing import Dict, List, Optional, Tuple
|
|
|
+from typing import Dict, List, Optional, Tuple, Union
|
|
|
from io import StringIO, BytesIO
|
|
|
import re
|
|
|
from html import unescape
|
|
|
@@ -41,10 +41,29 @@ def get_default_config() -> Dict:
|
|
|
'sidebar_state': 'expanded', 'default_font_size': 'medium', 'default_layout': '标准布局'
|
|
|
},
|
|
|
'paths': {
|
|
|
- 'output_dir': 'output', 'sample_data_dir': './sample_data',
|
|
|
+ 'ocr_out_dir': './sample_data', 'src_img_dir': './sample_data',
|
|
|
'supported_image_formats': ['.png', '.jpg', '.jpeg']
|
|
|
},
|
|
|
- 'ocr': {'min_text_length': 2, 'default_confidence': 1.0, 'exclude_texts': ['Picture', '']}
|
|
|
+ 'ocr': {
|
|
|
+ 'min_text_length': 2, 'default_confidence': 1.0, 'exclude_texts': ['Picture', ''],
|
|
|
+ 'tools': {
|
|
|
+ 'dots_ocr': {
|
|
|
+ 'name': 'Dots OCR', 'json_structure': 'array',
|
|
|
+ 'text_field': 'text', 'bbox_field': 'bbox', 'category_field': 'category'
|
|
|
+ },
|
|
|
+ 'ppstructv3': {
|
|
|
+ 'name': 'PPStructV3', 'json_structure': 'object', 'parsing_results_field': 'parsing_res_list',
|
|
|
+ 'text_field': 'block_content', 'bbox_field': 'block_bbox', 'category_field': 'block_label'
|
|
|
+ }
|
|
|
+ },
|
|
|
+ 'auto_detection': {
|
|
|
+ 'enabled': True,
|
|
|
+ 'rules': [
|
|
|
+ {'field_exists': 'parsing_res_list', 'tool_type': 'ppstructv3'},
|
|
|
+ {'json_is_array': True, 'tool_type': 'dots_ocr'}
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -62,6 +81,247 @@ def load_css_styles(css_path: str = "styles.css") -> str:
|
|
|
"""
|
|
|
|
|
|
|
|
|
+def rotate_image_and_coordinates(image: Image.Image, angle: float, coordinates_list: List[List[int]]) -> Tuple[Image.Image, List[List[int]]]:
|
|
|
+ """
|
|
|
+ 根据角度旋转图像和坐标 - 修复坐标变换和图片显示
|
|
|
+
|
|
|
+ Args:
|
|
|
+ image: 原始图像
|
|
|
+ angle: 旋转角度(度数)
|
|
|
+ coordinates_list: 坐标列表,每个坐标为[x1, y1, x2, y2]格式
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ rotated_image: 旋转后的图像
|
|
|
+ rotated_coordinates: 旋转后的坐标列表
|
|
|
+ """
|
|
|
+ if angle == 0:
|
|
|
+ return image, coordinates_list
|
|
|
+
|
|
|
+ # 标准化旋转角度
|
|
|
+ if angle == 270:
|
|
|
+ rotation_angle = -90 # 顺时针90度
|
|
|
+ elif angle == 90:
|
|
|
+ rotation_angle = 90 # 逆时针90度
|
|
|
+ elif angle == 180:
|
|
|
+ rotation_angle = 180 # 180度
|
|
|
+ else:
|
|
|
+ rotation_angle = angle
|
|
|
+
|
|
|
+ # 旋转图像
|
|
|
+ rotated_image = image.rotate(rotation_angle, expand=True)
|
|
|
+
|
|
|
+ # 获取原始和旋转后的图像尺寸
|
|
|
+ orig_width, orig_height = image.size
|
|
|
+ new_width, new_height = rotated_image.size
|
|
|
+
|
|
|
+ # 计算旋转后的坐标
|
|
|
+ rotated_coordinates = []
|
|
|
+
|
|
|
+ for coord in coordinates_list:
|
|
|
+ if len(coord) < 4:
|
|
|
+ rotated_coordinates.append(coord)
|
|
|
+ continue
|
|
|
+
|
|
|
+ x1, y1, x2, y2 = coord[:4]
|
|
|
+
|
|
|
+ # 根据旋转角度变换坐标 - 修复变换逻辑
|
|
|
+ if rotation_angle == -90: # 顺时针90度 (270度逆时针)
|
|
|
+ # 变换公式: (x, y) -> (y, orig_width - x)
|
|
|
+ new_x1 = y1
|
|
|
+ new_y1 = orig_width - x2
|
|
|
+ new_x2 = y2
|
|
|
+ new_y2 = orig_width - x1
|
|
|
+
|
|
|
+ elif rotation_angle == 90: # 逆时针90度
|
|
|
+ # 变换公式: (x, y) -> (orig_height - y, x)
|
|
|
+ new_x1 = orig_height - y2
|
|
|
+ new_y1 = x1
|
|
|
+ new_x2 = orig_height - y1
|
|
|
+ new_y2 = x2
|
|
|
+
|
|
|
+ elif rotation_angle == 180: # 180度
|
|
|
+ # 变换公式: (x, y) -> (orig_width - x, orig_height - y)
|
|
|
+ new_x1 = orig_width - x2
|
|
|
+ new_y1 = orig_height - y2
|
|
|
+ new_x2 = orig_width - x1
|
|
|
+ new_y2 = orig_height - y1
|
|
|
+
|
|
|
+ else:
|
|
|
+ # 对于其他角度,使用通用的旋转矩阵
|
|
|
+ center_x, center_y = orig_width / 2, orig_height / 2
|
|
|
+ new_center_x, new_center_y = new_width / 2, new_height / 2
|
|
|
+
|
|
|
+ angle_rad = np.radians(rotation_angle)
|
|
|
+ cos_angle = np.cos(angle_rad)
|
|
|
+ sin_angle = np.sin(angle_rad)
|
|
|
+
|
|
|
+ # 旋转四个角点
|
|
|
+ corners = [
|
|
|
+ (x1 - center_x, y1 - center_y),
|
|
|
+ (x2 - center_x, y1 - center_y),
|
|
|
+ (x2 - center_x, y2 - center_y),
|
|
|
+ (x1 - center_x, y2 - center_y)
|
|
|
+ ]
|
|
|
+
|
|
|
+ rotated_corners = []
|
|
|
+ for x, y in corners:
|
|
|
+ new_x = x * cos_angle - y * sin_angle
|
|
|
+ new_y = x * sin_angle + y * cos_angle
|
|
|
+ rotated_corners.append((new_x + new_center_x, new_y + new_center_y))
|
|
|
+
|
|
|
+ # 计算边界框
|
|
|
+ x_coords = [corner[0] for corner in rotated_corners]
|
|
|
+ y_coords = [corner[1] for corner in rotated_corners]
|
|
|
+
|
|
|
+ new_x1 = int(min(x_coords))
|
|
|
+ new_y1 = int(min(y_coords))
|
|
|
+ new_x2 = int(max(x_coords))
|
|
|
+ new_y2 = int(max(y_coords))
|
|
|
+
|
|
|
+ # 确保坐标在有效范围内
|
|
|
+ new_x1 = max(0, min(new_width, new_x1))
|
|
|
+ new_y1 = max(0, min(new_height, new_y1))
|
|
|
+ new_x2 = max(0, min(new_width, new_x2))
|
|
|
+ new_y2 = max(0, min(new_height, new_y2))
|
|
|
+
|
|
|
+ # 确保x1 < x2, y1 < y2
|
|
|
+ if new_x1 > new_x2:
|
|
|
+ new_x1, new_x2 = new_x2, new_x1
|
|
|
+ if new_y1 > new_y2:
|
|
|
+ new_y1, new_y2 = new_y2, new_y1
|
|
|
+
|
|
|
+ rotated_coordinates.append([new_x1, new_y1, new_x2, new_y2])
|
|
|
+
|
|
|
+ return rotated_image, rotated_coordinates
|
|
|
+
|
|
|
+
|
|
|
+def detect_ocr_tool_type(data: Union[List, Dict], config: Dict) -> str:
|
|
|
+ """自动检测OCR工具类型"""
|
|
|
+ if not config['ocr']['auto_detection']['enabled']:
|
|
|
+ return 'dots_ocr' # 默认类型
|
|
|
+
|
|
|
+ rules = config['ocr']['auto_detection']['rules']
|
|
|
+
|
|
|
+ for rule in rules:
|
|
|
+ if 'field_exists' in rule:
|
|
|
+ field_name = rule['field_exists']
|
|
|
+ if isinstance(data, dict) and field_name in data:
|
|
|
+ return rule['tool_type']
|
|
|
+
|
|
|
+ if 'json_is_array' in rule:
|
|
|
+ if rule['json_is_array'] and isinstance(data, list):
|
|
|
+ return rule['tool_type']
|
|
|
+
|
|
|
+ # 默认返回dots_ocr
|
|
|
+ return 'dots_ocr'
|
|
|
+
|
|
|
+
|
|
|
+def parse_dots_ocr_data(data: List, config: Dict) -> List[Dict]:
|
|
|
+ """解析Dots OCR格式的数据"""
|
|
|
+ tool_config = config['ocr']['tools']['dots_ocr']
|
|
|
+ parsed_data = []
|
|
|
+
|
|
|
+ for item in data:
|
|
|
+ if not isinstance(item, dict):
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 提取字段
|
|
|
+ text = item.get(tool_config['text_field'], '')
|
|
|
+ bbox = item.get(tool_config['bbox_field'], [])
|
|
|
+ category = item.get(tool_config['category_field'], 'Text')
|
|
|
+ confidence = item.get(tool_config.get('confidence_field', 'confidence'),
|
|
|
+ config['ocr']['default_confidence'])
|
|
|
+
|
|
|
+ if text and bbox and len(bbox) >= 4:
|
|
|
+ parsed_data.append({
|
|
|
+ 'text': str(text).strip(),
|
|
|
+ 'bbox': bbox[:4], # 确保只取前4个坐标
|
|
|
+ 'category': category,
|
|
|
+ 'confidence': confidence,
|
|
|
+ 'source_tool': 'dots_ocr'
|
|
|
+ })
|
|
|
+
|
|
|
+ return parsed_data
|
|
|
+
|
|
|
+
|
|
|
+def parse_ppstructv3_data(data: Dict, config: Dict) -> List[Dict]:
|
|
|
+ """解析PPStructV3格式的数据"""
|
|
|
+ tool_config = config['ocr']['tools']['ppstructv3']
|
|
|
+ parsed_data = []
|
|
|
+
|
|
|
+ # 获取解析结果列表
|
|
|
+ parsing_results_field = tool_config['parsing_results_field']
|
|
|
+ if parsing_results_field not in data:
|
|
|
+ return parsed_data
|
|
|
+
|
|
|
+ parsing_results = data[parsing_results_field]
|
|
|
+ if not isinstance(parsing_results, list):
|
|
|
+ return parsed_data
|
|
|
+
|
|
|
+ for item in parsing_results:
|
|
|
+ if not isinstance(item, dict):
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 提取字段
|
|
|
+ text = item.get(tool_config['text_field'], '')
|
|
|
+ bbox = item.get(tool_config['bbox_field'], [])
|
|
|
+ category = item.get(tool_config['category_field'], 'text')
|
|
|
+ confidence = item.get(tool_config.get('confidence_field', 'confidence'),
|
|
|
+ config['ocr']['default_confidence'])
|
|
|
+
|
|
|
+ if text and bbox and len(bbox) >= 4:
|
|
|
+ parsed_data.append({
|
|
|
+ 'text': str(text).strip(),
|
|
|
+ 'bbox': bbox[:4], # 确保只取前4个坐标
|
|
|
+ 'category': category,
|
|
|
+ 'confidence': confidence,
|
|
|
+ 'source_tool': 'ppstructv3'
|
|
|
+ })
|
|
|
+
|
|
|
+ # 如果有OCR文本识别结果,也添加进来
|
|
|
+ if 'overall_ocr_res' in data:
|
|
|
+ ocr_res = data['overall_ocr_res']
|
|
|
+ if isinstance(ocr_res, dict) and 'rec_texts' in ocr_res and 'rec_boxes' in ocr_res:
|
|
|
+ texts = ocr_res['rec_texts']
|
|
|
+ boxes = ocr_res['rec_boxes']
|
|
|
+ scores = ocr_res.get('rec_scores', [])
|
|
|
+
|
|
|
+ for i, (text, box) in enumerate(zip(texts, boxes)):
|
|
|
+ if text and len(box) >= 4:
|
|
|
+ confidence = scores[i] if i < len(scores) else config['ocr']['default_confidence']
|
|
|
+ parsed_data.append({
|
|
|
+ 'text': str(text).strip(),
|
|
|
+ 'bbox': box[:4],
|
|
|
+ 'category': 'OCR_Text',
|
|
|
+ 'confidence': confidence,
|
|
|
+ 'source_tool': 'ppstructv3_ocr'
|
|
|
+ })
|
|
|
+
|
|
|
+ return parsed_data
|
|
|
+
|
|
|
+
|
|
|
+def normalize_ocr_data(raw_data: Union[List, Dict], config: Dict) -> List[Dict]:
|
|
|
+ """统一不同OCR工具的数据格式"""
|
|
|
+ # 自动检测OCR工具类型
|
|
|
+ tool_type = detect_ocr_tool_type(raw_data, config)
|
|
|
+
|
|
|
+ if tool_type == 'dots_ocr':
|
|
|
+ return parse_dots_ocr_data(raw_data, config)
|
|
|
+ elif tool_type == 'ppstructv3':
|
|
|
+ return parse_ppstructv3_data(raw_data, config)
|
|
|
+ else:
|
|
|
+ raise ValueError(f"不支持的OCR工具类型: {tool_type}")
|
|
|
+
|
|
|
+
|
|
|
+def get_rotation_angle_from_ppstructv3(data: Dict) -> float:
|
|
|
+ """从PPStructV3数据中获取旋转角度"""
|
|
|
+ if 'doc_preprocessor_res' in data:
|
|
|
+ doc_res = data['doc_preprocessor_res']
|
|
|
+ if isinstance(doc_res, dict) and 'angle' in doc_res:
|
|
|
+ return float(doc_res['angle'])
|
|
|
+ return 0.0
|
|
|
+
|
|
|
+
|
|
|
def load_ocr_data_file(json_path: str, config: Dict) -> Tuple[List, str, str]:
|
|
|
"""加载OCR相关数据文件"""
|
|
|
json_file = Path(json_path)
|
|
|
@@ -72,13 +332,21 @@ def load_ocr_data_file(json_path: str, config: Dict) -> Tuple[List, str, str]:
|
|
|
# 加载JSON数据
|
|
|
try:
|
|
|
with open(json_file, 'r', encoding='utf-8') as f:
|
|
|
- data = json.load(f)
|
|
|
- if isinstance(data, list):
|
|
|
- ocr_data = data
|
|
|
- elif isinstance(data, dict) and 'results' in data:
|
|
|
- ocr_data = data['results']
|
|
|
- else:
|
|
|
- raise ValueError(f"不支持的JSON格式: {json_path}")
|
|
|
+ raw_data = json.load(f)
|
|
|
+
|
|
|
+ # 统一数据格式
|
|
|
+ ocr_data = normalize_ocr_data(raw_data, config)
|
|
|
+
|
|
|
+ # 检查是否需要处理图像旋转
|
|
|
+ rotation_angle = 0.0
|
|
|
+ if isinstance(raw_data, dict):
|
|
|
+ rotation_angle = get_rotation_angle_from_ppstructv3(raw_data)
|
|
|
+
|
|
|
+ # 如果有旋转角度,记录下来供后续图像处理使用
|
|
|
+ if rotation_angle != 0:
|
|
|
+ for item in ocr_data:
|
|
|
+ item['rotation_angle'] = rotation_angle
|
|
|
+
|
|
|
except Exception as e:
|
|
|
raise Exception(f"加载JSON文件失败: {e}")
|
|
|
|
|
|
@@ -90,15 +358,20 @@ def load_ocr_data_file(json_path: str, config: Dict) -> Tuple[List, str, str]:
|
|
|
|
|
|
# 推断图片路径
|
|
|
image_name = json_file.stem
|
|
|
- sample_data_dir = Path(config['paths']['sample_data_dir'])
|
|
|
+ src_img_dir = Path(config['paths']['src_img_dir'])
|
|
|
|
|
|
image_candidates = []
|
|
|
for ext in config['paths']['supported_image_formats']:
|
|
|
image_candidates.extend([
|
|
|
- sample_data_dir / f"{image_name}{ext}",
|
|
|
+ src_img_dir / f"{image_name}{ext}",
|
|
|
json_file.parent / f"{image_name}{ext}",
|
|
|
+ # 对于PPStructV3,可能图片名包含page信息 # 去掉page后缀的通用匹配
|
|
|
+ src_img_dir / f"{image_name.split('_page_')[0]}{ext}" if '_page_' in image_name else None,
|
|
|
])
|
|
|
|
|
|
+ # 移除None值
|
|
|
+ image_candidates = [candidate for candidate in image_candidates if candidate is not None]
|
|
|
+
|
|
|
for candidate in image_candidates:
|
|
|
if candidate.exists():
|
|
|
image_path = str(candidate)
|
|
|
@@ -120,23 +393,53 @@ def process_ocr_data(ocr_data: List, config: Dict) -> Dict[str, List]:
|
|
|
if not isinstance(item, dict):
|
|
|
continue
|
|
|
|
|
|
- if 'text' in item and 'bbox' in item:
|
|
|
- text = str(item['text']).strip()
|
|
|
- if text and text not in exclude_texts and len(text) >= min_text_length:
|
|
|
- bbox = item['bbox']
|
|
|
- if isinstance(bbox, list) and len(bbox) == 4:
|
|
|
- if text not in text_bbox_mapping:
|
|
|
- text_bbox_mapping[text] = []
|
|
|
- text_bbox_mapping[text].append({
|
|
|
- 'bbox': bbox,
|
|
|
- 'category': item.get('category', 'Text'),
|
|
|
- 'index': i,
|
|
|
- 'confidence': item.get('confidence', config['ocr']['default_confidence'])
|
|
|
- })
|
|
|
+ text = str(item['text']).strip()
|
|
|
+ if text and text not in exclude_texts and len(text) >= min_text_length:
|
|
|
+ bbox = item['bbox']
|
|
|
+ if isinstance(bbox, list) and len(bbox) == 4:
|
|
|
+ if text not in text_bbox_mapping:
|
|
|
+ text_bbox_mapping[text] = []
|
|
|
+ text_bbox_mapping[text].append({
|
|
|
+ 'bbox': bbox,
|
|
|
+ 'category': item.get('category', 'Text'),
|
|
|
+ 'index': i,
|
|
|
+ 'confidence': item.get('confidence', config['ocr']['default_confidence']),
|
|
|
+ 'source_tool': item.get('source_tool', 'unknown'),
|
|
|
+ 'rotation_angle': item.get('rotation_angle', 0.0) # 添加旋转角度信息
|
|
|
+ })
|
|
|
|
|
|
return text_bbox_mapping
|
|
|
|
|
|
|
|
|
+def find_available_ocr_files(ocr_out_dir: str) -> List[str]:
|
|
|
+ """查找可用的OCR文件"""
|
|
|
+ available_files = []
|
|
|
+
|
|
|
+ # 搜索多个可能的目录
|
|
|
+ search_dirs = [
|
|
|
+ Path(ocr_out_dir),
|
|
|
+ ]
|
|
|
+
|
|
|
+ for search_dir in search_dirs:
|
|
|
+ if search_dir.exists():
|
|
|
+ # 递归搜索JSON文件
|
|
|
+ for json_file in search_dir.rglob("*.json"):
|
|
|
+ available_files.append(str(json_file))
|
|
|
+
|
|
|
+ return available_files
|
|
|
+
|
|
|
+
|
|
|
+def get_ocr_tool_info(ocr_data: List) -> Dict:
|
|
|
+ """获取OCR工具信息统计"""
|
|
|
+ tool_counts = {}
|
|
|
+ for item in ocr_data:
|
|
|
+ if isinstance(item, dict):
|
|
|
+ source_tool = item.get('source_tool', 'unknown')
|
|
|
+ tool_counts[source_tool] = tool_counts.get(source_tool, 0) + 1
|
|
|
+
|
|
|
+ return tool_counts
|
|
|
+
|
|
|
+
|
|
|
def draw_bbox_on_image(image: Image.Image, bbox: List[int], color: str = "red", width: int = 3) -> Image.Image:
|
|
|
"""在图片上绘制bbox框"""
|
|
|
img_copy = image.copy()
|
|
|
@@ -169,7 +472,7 @@ def get_ocr_statistics(ocr_data: List, text_bbox_mapping: Dict, marked_errors: s
|
|
|
if not isinstance(ocr_data, list) or not ocr_data:
|
|
|
return {
|
|
|
'total_texts': 0, 'clickable_texts': 0, 'marked_errors': 0,
|
|
|
- 'categories': {}, 'accuracy_rate': 0
|
|
|
+ 'categories': {}, 'accuracy_rate': 0, 'tool_info': {}
|
|
|
}
|
|
|
|
|
|
total_texts = len(ocr_data)
|
|
|
@@ -181,12 +484,10 @@ def get_ocr_statistics(ocr_data: List, text_bbox_mapping: Dict, marked_errors: s
|
|
|
for item in ocr_data:
|
|
|
if isinstance(item, dict):
|
|
|
category = item.get('category', 'Unknown')
|
|
|
- elif isinstance(item, str):
|
|
|
- category = 'Text'
|
|
|
- else:
|
|
|
- category = 'Unknown'
|
|
|
-
|
|
|
- categories[category] = categories.get(category, 0) + 1
|
|
|
+ categories[category] = categories.get(category, 0) + 1
|
|
|
+
|
|
|
+ # OCR工具信息统计
|
|
|
+ tool_info = get_ocr_tool_info(ocr_data)
|
|
|
|
|
|
accuracy_rate = (clickable_texts - marked_errors_count) / clickable_texts * 100 if clickable_texts > 0 else 0
|
|
|
|
|
|
@@ -195,7 +496,8 @@ def get_ocr_statistics(ocr_data: List, text_bbox_mapping: Dict, marked_errors: s
|
|
|
'clickable_texts': clickable_texts,
|
|
|
'marked_errors': marked_errors_count,
|
|
|
'categories': categories,
|
|
|
- 'accuracy_rate': accuracy_rate
|
|
|
+ 'accuracy_rate': accuracy_rate,
|
|
|
+ 'tool_info': tool_info
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -246,18 +548,6 @@ def parse_html_tables(html_content: str) -> List[pd.DataFrame]:
|
|
|
return []
|
|
|
|
|
|
|
|
|
-def find_available_ocr_files(output_dir: str) -> List[str]:
|
|
|
- """查找可用的OCR文件"""
|
|
|
- available_files = []
|
|
|
- output_path = Path(output_dir)
|
|
|
-
|
|
|
- if output_path.exists():
|
|
|
- for json_file in output_path.rglob("*.json"):
|
|
|
- available_files.append(str(json_file))
|
|
|
-
|
|
|
- return available_files
|
|
|
-
|
|
|
-
|
|
|
def create_dynamic_css(config: Dict, font_size_key: str, height: int) -> str:
|
|
|
"""根据配置动态创建CSS样式"""
|
|
|
colors = config['styles']['colors']
|