# ocr_validator_utils.py
"""
OCR验证工具的工具函数模块
包含数据处理、图像处理、统计分析等功能
"""
import json
import re
from collections import Counter
from html import unescape
from io import BytesIO, StringIO
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import yaml
from PIL import Image, ImageDraw
  15. def load_config(config_path: str = "config.yaml") -> Dict:
  16. """加载配置文件"""
  17. try:
  18. with open(config_path, 'r', encoding='utf-8') as f:
  19. return yaml.safe_load(f)
  20. except Exception as e:
  21. # 返回默认配置
  22. return get_default_config()
  23. def get_default_config() -> Dict:
  24. """获取默认配置"""
  25. return {
  26. 'styles': {
  27. 'font_sizes': {'small': 10, 'medium': 12, 'large': 14, 'extra_large': 16},
  28. 'colors': {
  29. 'primary': '#0288d1', 'secondary': '#ff9800', 'success': '#4caf50',
  30. 'error': '#f44336', 'warning': '#ff9800', 'background': '#fafafa', 'text': '#333333'
  31. },
  32. 'layout': {'default_zoom': 1.0, 'default_height': 600, 'sidebar_width': 0.3, 'content_width': 0.7}
  33. },
  34. 'ui': {
  35. 'page_title': 'OCR可视化校验工具', 'page_icon': '🔍', 'layout': 'wide',
  36. 'sidebar_state': 'expanded', 'default_font_size': 'medium', 'default_layout': '标准布局'
  37. },
  38. 'paths': {
  39. 'output_dir': 'output', 'sample_data_dir': './sample_data',
  40. 'supported_image_formats': ['.png', '.jpg', '.jpeg']
  41. },
  42. 'ocr': {'min_text_length': 2, 'default_confidence': 1.0, 'exclude_texts': ['Picture', '']}
  43. }
  44. def load_css_styles(css_path: str = "styles.css") -> str:
  45. """加载CSS样式文件"""
  46. try:
  47. with open(css_path, 'r', encoding='utf-8') as f:
  48. return f.read()
  49. except Exception:
  50. # 返回基本样式
  51. return """
  52. .main > div { background-color: white !important; color: #333333 !important; }
  53. .stApp { background-color: white !important; }
  54. .block-container { background-color: white !important; color: #333333 !important; }
  55. """
  56. def load_ocr_data_file(json_path: str, config: Dict) -> Tuple[List, str, str]:
  57. """加载OCR相关数据文件"""
  58. json_file = Path(json_path)
  59. ocr_data = []
  60. md_content = ""
  61. image_path = ""
  62. # 加载JSON数据
  63. try:
  64. with open(json_file, 'r', encoding='utf-8') as f:
  65. data = json.load(f)
  66. if isinstance(data, list):
  67. ocr_data = data
  68. elif isinstance(data, dict) and 'results' in data:
  69. ocr_data = data['results']
  70. else:
  71. raise ValueError(f"不支持的JSON格式: {json_path}")
  72. except Exception as e:
  73. raise Exception(f"加载JSON文件失败: {e}")
  74. # 加载MD文件
  75. md_file = json_file.with_suffix('.md')
  76. if md_file.exists():
  77. with open(md_file, 'r', encoding='utf-8') as f:
  78. md_content = f.read()
  79. # 推断图片路径
  80. image_name = json_file.stem
  81. sample_data_dir = Path(config['paths']['sample_data_dir'])
  82. image_candidates = []
  83. for ext in config['paths']['supported_image_formats']:
  84. image_candidates.extend([
  85. sample_data_dir / f"{image_name}{ext}",
  86. json_file.parent / f"{image_name}{ext}",
  87. ])
  88. for candidate in image_candidates:
  89. if candidate.exists():
  90. image_path = str(candidate)
  91. break
  92. return ocr_data, md_content, image_path
  93. def process_ocr_data(ocr_data: List, config: Dict) -> Dict[str, List]:
  94. """处理OCR数据,建立文本到bbox的映射"""
  95. text_bbox_mapping = {}
  96. exclude_texts = config['ocr']['exclude_texts']
  97. min_text_length = config['ocr']['min_text_length']
  98. if not isinstance(ocr_data, list):
  99. return text_bbox_mapping
  100. for i, item in enumerate(ocr_data):
  101. if not isinstance(item, dict):
  102. continue
  103. if 'text' in item and 'bbox' in item:
  104. text = str(item['text']).strip()
  105. if text and text not in exclude_texts and len(text) >= min_text_length:
  106. bbox = item['bbox']
  107. if isinstance(bbox, list) and len(bbox) == 4:
  108. if text not in text_bbox_mapping:
  109. text_bbox_mapping[text] = []
  110. text_bbox_mapping[text].append({
  111. 'bbox': bbox,
  112. 'category': item.get('category', 'Text'),
  113. 'index': i,
  114. 'confidence': item.get('confidence', config['ocr']['default_confidence'])
  115. })
  116. return text_bbox_mapping
  117. def draw_bbox_on_image(image: Image.Image, bbox: List[int], color: str = "red", width: int = 3) -> Image.Image:
  118. """在图片上绘制bbox框"""
  119. img_copy = image.copy()
  120. draw = ImageDraw.Draw(img_copy)
  121. x1, y1, x2, y2 = bbox
  122. # 绘制矩形框
  123. draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
  124. # 添加半透明填充
  125. overlay = Image.new('RGBA', img_copy.size, (0, 0, 0, 0))
  126. overlay_draw = ImageDraw.Draw(overlay)
  127. color_map = {
  128. "red": (255, 0, 0, 30),
  129. "blue": (0, 0, 255, 30),
  130. "green": (0, 255, 0, 30)
  131. }
  132. fill_color = color_map.get(color, (255, 255, 0, 30))
  133. overlay_draw.rectangle([x1, y1, x2, y2], fill=fill_color)
  134. img_copy = Image.alpha_composite(img_copy.convert('RGBA'), overlay).convert('RGB')
  135. return img_copy
  136. def get_ocr_statistics(ocr_data: List, text_bbox_mapping: Dict, marked_errors: set) -> Dict:
  137. """获取OCR数据统计信息"""
  138. if not isinstance(ocr_data, list) or not ocr_data:
  139. return {
  140. 'total_texts': 0, 'clickable_texts': 0, 'marked_errors': 0,
  141. 'categories': {}, 'accuracy_rate': 0
  142. }
  143. total_texts = len(ocr_data)
  144. clickable_texts = len(text_bbox_mapping)
  145. marked_errors_count = len(marked_errors)
  146. # 按类别统计
  147. categories = {}
  148. for item in ocr_data:
  149. if isinstance(item, dict):
  150. category = item.get('category', 'Unknown')
  151. elif isinstance(item, str):
  152. category = 'Text'
  153. else:
  154. category = 'Unknown'
  155. categories[category] = categories.get(category, 0) + 1
  156. accuracy_rate = (clickable_texts - marked_errors_count) / clickable_texts * 100 if clickable_texts > 0 else 0
  157. return {
  158. 'total_texts': total_texts,
  159. 'clickable_texts': clickable_texts,
  160. 'marked_errors': marked_errors_count,
  161. 'categories': categories,
  162. 'accuracy_rate': accuracy_rate
  163. }
  164. def convert_html_table_to_markdown(content: str) -> str:
  165. """将HTML表格转换为Markdown表格格式"""
  166. def replace_table(match):
  167. table_html = match.group(0)
  168. # 提取所有行
  169. rows = re.findall(r'<tr>(.*?)</tr>', table_html, re.DOTALL | re.IGNORECASE)
  170. if not rows:
  171. return table_html
  172. markdown_rows = []
  173. for i, row in enumerate(rows):
  174. # 提取单元格
  175. cells = re.findall(r'<td[^>]*>(.*?)</td>', row, re.DOTALL | re.IGNORECASE)
  176. if cells:
  177. # 清理单元格内容
  178. clean_cells = []
  179. for cell in cells:
  180. cell_text = re.sub(r'<[^>]+>', '', cell).strip()
  181. cell_text = unescape(cell_text)
  182. clean_cells.append(cell_text)
  183. # 构建Markdown行
  184. markdown_row = '| ' + ' | '.join(clean_cells) + ' |'
  185. markdown_rows.append(markdown_row)
  186. # 在第一行后添加分隔符
  187. if i == 0:
  188. separator = '| ' + ' | '.join(['---'] * len(clean_cells)) + ' |'
  189. markdown_rows.append(separator)
  190. return '\n'.join(markdown_rows) if markdown_rows else table_html
  191. # 替换所有HTML表格
  192. converted = re.sub(r'<table[^>]*>.*?</table>', replace_table, content, flags=re.DOTALL | re.IGNORECASE)
  193. return converted
  194. def parse_html_tables(html_content: str) -> List[pd.DataFrame]:
  195. """解析HTML内容中的表格为DataFrame列表"""
  196. try:
  197. tables = pd.read_html(StringIO(html_content))
  198. return tables if tables else []
  199. except Exception:
  200. return []
  201. def find_available_ocr_files(output_dir: str) -> List[str]:
  202. """查找可用的OCR文件"""
  203. available_files = []
  204. output_path = Path(output_dir)
  205. if output_path.exists():
  206. for json_file in output_path.rglob("*.json"):
  207. available_files.append(str(json_file))
  208. return available_files
  209. def create_dynamic_css(config: Dict, font_size_key: str, height: int) -> str:
  210. """根据配置动态创建CSS样式"""
  211. colors = config['styles']['colors']
  212. font_size = config['styles']['font_sizes'][font_size_key]
  213. return f"""
  214. <style>
  215. .dynamic-content {{
  216. height: {height}px;
  217. font-size: {font_size}px !important;
  218. line-height: 1.4;
  219. background-color: {colors['background']} !important;
  220. color: {colors['text']} !important;
  221. border: 1px solid #ddd;
  222. padding: 10px;
  223. border-radius: 5px;
  224. }}
  225. .highlight-selected {{
  226. background-color: {colors['success']} !important;
  227. color: white !important;
  228. }}
  229. .highlight-error {{
  230. background-color: {colors['error']} !important;
  231. color: white !important;
  232. }}
  233. </style>
  234. """
  235. def export_tables_to_excel(tables: List[pd.DataFrame], filename: str = "ocr_tables.xlsx") -> BytesIO:
  236. """导出表格数据到Excel"""
  237. output = BytesIO()
  238. with pd.ExcelWriter(output, engine='openpyxl') as writer:
  239. for i, table in enumerate(tables):
  240. table.to_excel(writer, sheet_name=f'Table_{i+1}', index=False)
  241. return output
  242. def get_table_statistics(tables: List[pd.DataFrame]) -> List[Dict]:
  243. """获取表格统计信息"""
  244. stats = []
  245. for i, table in enumerate(tables):
  246. numeric_cols = len(table.select_dtypes(include=[np.number]).columns)
  247. stats.append({
  248. 'table_index': i + 1,
  249. 'rows': len(table),
  250. 'columns': len(table.columns),
  251. 'numeric_columns': numeric_cols
  252. })
  253. return stats
  254. def group_texts_by_category(text_bbox_mapping: Dict[str, List]) -> Dict[str, List[str]]:
  255. """按类别对文本进行分组"""
  256. categories = {}
  257. for text, info_list in text_bbox_mapping.items():
  258. category = info_list[0]['category']
  259. if category not in categories:
  260. categories[category] = []
  261. categories[category].append(text)
  262. return categories