streamlit_ocr_validator.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639
  1. #!/usr/bin/env python3
  2. """
  3. 基于Streamlit的OCR可视化校验工具(修复版)
  4. 提供丰富的交互组件和更好的用户体验
  5. """
  6. import streamlit as st
  7. import json
  8. import pandas as pd
  9. from pathlib import Path
  10. import numpy as np
  11. from PIL import Image, ImageDraw, ImageFont
  12. import cv2
  13. import base64
  14. from typing import Dict, List, Optional, Tuple
  15. import plotly.express as px
  16. import plotly.graph_objects as go
  17. from plotly.subplots import make_subplots
  18. # 设置页面配置
  19. st.set_page_config(
  20. page_title="OCR可视化校验工具",
  21. page_icon="🔍",
  22. layout="wide",
  23. initial_sidebar_state="expanded"
  24. )
  25. # 自定义CSS样式
  26. st.markdown("""
  27. <style>
  28. .main > div {
  29. padding-top: 2rem;
  30. }
  31. .stSelectbox > div > div > div {
  32. background-color: #f0f2f6;
  33. }
  34. .clickable-text {
  35. background-color: #e1f5fe;
  36. padding: 2px 6px;
  37. border-radius: 4px;
  38. border: 1px solid #0288d1;
  39. cursor: pointer;
  40. margin: 2px;
  41. display: inline-block;
  42. }
  43. .selected-text {
  44. background-color: #fff3e0;
  45. border-color: #ff9800;
  46. font-weight: bold;
  47. }
  48. .error-text {
  49. background-color: #ffebee;
  50. border-color: #f44336;
  51. color: #d32f2f;
  52. }
  53. .stats-container {
  54. background-color: #f8f9fa;
  55. padding: 1rem;
  56. border-radius: 8px;
  57. border-left: 4px solid #28a745;
  58. }
  59. </style>
  60. """, unsafe_allow_html=True)
  61. class StreamlitOCRValidator:
  62. def __init__(self):
  63. self.ocr_data = []
  64. self.md_content = ""
  65. self.image_path = ""
  66. self.text_bbox_mapping = {}
  67. self.selected_text = None
  68. self.marked_errors = set()
  69. def load_ocr_data(self, json_path: str, md_path: Optional[str] = None, image_path: Optional[str] = None):
  70. """加载OCR相关数据"""
  71. json_file = Path(json_path)
  72. # 加载JSON数据
  73. try:
  74. with open(json_file, 'r', encoding='utf-8') as f:
  75. data = json.load(f)
  76. # 确保数据是列表格式
  77. if isinstance(data, list):
  78. self.ocr_data = data
  79. elif isinstance(data, dict) and 'results' in data:
  80. self.ocr_data = data['results']
  81. else:
  82. st.error(f"❌ 不支持的JSON格式: {json_path}")
  83. return
  84. except Exception as e:
  85. st.error(f"❌ 加载JSON文件失败: {e}")
  86. return
  87. # 推断MD文件路径
  88. if md_path is None:
  89. md_file = json_file.with_suffix('.md')
  90. else:
  91. md_file = Path(md_path)
  92. if md_file.exists():
  93. with open(md_file, 'r', encoding='utf-8') as f:
  94. self.md_content = f.read()
  95. # 推断图片路径
  96. if image_path is None:
  97. image_name = json_file.stem
  98. sample_data_dir = Path("./sample_data")
  99. image_candidates = [
  100. sample_data_dir / f"{image_name}.png",
  101. sample_data_dir / f"{image_name}.jpg",
  102. json_file.parent / f"{image_name}.png",
  103. json_file.parent / f"{image_name}.jpg",
  104. ]
  105. for candidate in image_candidates:
  106. if candidate.exists():
  107. self.image_path = str(candidate)
  108. break
  109. else:
  110. self.image_path = image_path
  111. # 处理数据
  112. self.process_data()
  113. def process_data(self):
  114. """处理OCR数据,建立文本到bbox的映射"""
  115. self.text_bbox_mapping = {}
  116. # 确保 ocr_data 是列表
  117. if not isinstance(self.ocr_data, list):
  118. st.warning("⚠️ OCR数据格式不正确,期望列表格式")
  119. return
  120. for i, item in enumerate(self.ocr_data):
  121. # 确保 item 是字典类型
  122. if not isinstance(item, dict):
  123. continue
  124. if 'text' in item and 'bbox' in item:
  125. text = str(item['text']).strip()
  126. if text and text not in ['Picture', '']:
  127. bbox = item['bbox']
  128. # 确保bbox是4个数字的列表
  129. if isinstance(bbox, list) and len(bbox) == 4:
  130. if text not in self.text_bbox_mapping:
  131. self.text_bbox_mapping[text] = []
  132. self.text_bbox_mapping[text].append({
  133. 'bbox': bbox,
  134. 'category': item.get('category', 'Text'),
  135. 'index': i,
  136. 'confidence': item.get('confidence', 1.0)
  137. })
  138. def draw_bbox_on_image(self, image: Image.Image, bbox: List[int], color: str = "red", width: int = 3) -> Image.Image:
  139. """在图片上绘制bbox框"""
  140. img_copy = image.copy()
  141. draw = ImageDraw.Draw(img_copy)
  142. x1, y1, x2, y2 = bbox
  143. # 绘制矩形框
  144. draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
  145. # 添加半透明填充
  146. overlay = Image.new('RGBA', img_copy.size, (0, 0, 0, 0))
  147. overlay_draw = ImageDraw.Draw(overlay)
  148. if color == "red":
  149. fill_color = (255, 0, 0, 30)
  150. elif color == "blue":
  151. fill_color = (0, 0, 255, 30)
  152. elif color == "green":
  153. fill_color = (0, 255, 0, 30)
  154. else:
  155. fill_color = (255, 255, 0, 30)
  156. overlay_draw.rectangle([x1, y1, x2, y2], fill=fill_color)
  157. img_copy = Image.alpha_composite(img_copy.convert('RGBA'), overlay).convert('RGB')
  158. return img_copy
  159. def create_interactive_plot(self, image: Image.Image, selected_bbox: Optional[List[int]] = None) -> go.Figure:
  160. """创建交互式图片显示"""
  161. fig = go.Figure()
  162. # 添加图片
  163. fig.add_layout_image(
  164. dict(
  165. source=image,
  166. xref="x",
  167. yref="y",
  168. x=0,
  169. y=image.height,
  170. sizex=image.width,
  171. sizey=image.height,
  172. sizing="stretch",
  173. opacity=1.0,
  174. layer="below"
  175. )
  176. )
  177. # 添加所有bbox(浅色显示)
  178. for text, info_list in self.text_bbox_mapping.items():
  179. for info in info_list:
  180. bbox = info['bbox']
  181. if len(bbox) >= 4: # 确保bbox有足够的坐标
  182. x1, y1, x2, y2 = bbox[:4]
  183. color = "rgba(0, 100, 200, 0.2)" # 默认浅蓝色
  184. if text in self.marked_errors:
  185. color = "rgba(255, 0, 0, 0.3)" # 错误标记为红色
  186. fig.add_shape(
  187. type="rect",
  188. x0=x1, y0=image.height-y2,
  189. x1=x2, y1=image.height-y1,
  190. line=dict(color=color.replace('0.2', '0.8').replace('0.3', '1.0'), width=1),
  191. fillcolor=color,
  192. )
  193. # 高亮显示选中的bbox
  194. if selected_bbox and len(selected_bbox) >= 4:
  195. x1, y1, x2, y2 = selected_bbox[:4]
  196. fig.add_shape(
  197. type="rect",
  198. x0=x1, y0=image.height-y2,
  199. x1=x2, y1=image.height-y1,
  200. line=dict(color="red", width=3),
  201. fillcolor="rgba(255, 0, 0, 0.2)",
  202. )
  203. # 设置布局
  204. fig.update_xaxes(
  205. visible=False,
  206. range=[0, image.width]
  207. )
  208. fig.update_yaxes(
  209. visible=False,
  210. range=[0, image.height],
  211. scaleanchor="x"
  212. )
  213. fig.update_layout(
  214. width=800,
  215. height=600,
  216. margin=dict(l=0, r=0, t=0, b=0),
  217. xaxis_showgrid=False,
  218. yaxis_showgrid=False,
  219. plot_bgcolor='white'
  220. )
  221. return fig
  222. def get_statistics(self) -> Dict:
  223. """获取统计信息"""
  224. # 先确保 ocr_data 不为空且是列表
  225. if not isinstance(self.ocr_data, list) or not self.ocr_data:
  226. return {
  227. 'total_texts': 0,
  228. 'clickable_texts': 0,
  229. 'marked_errors': 0,
  230. 'categories': {},
  231. 'accuracy_rate': 0
  232. }
  233. total_texts = len(self.ocr_data)
  234. clickable_texts = len(self.text_bbox_mapping)
  235. marked_errors = len(self.marked_errors)
  236. # 按类别统计 - 添加类型检查
  237. categories = {}
  238. for item in self.ocr_data:
  239. # 确保 item 是字典类型
  240. if isinstance(item, dict):
  241. category = item.get('category', 'Unknown')
  242. elif isinstance(item, str):
  243. category = 'Text' # 字符串类型默认为 Text 类别
  244. else:
  245. category = 'Unknown'
  246. categories[category] = categories.get(category, 0) + 1
  247. return {
  248. 'total_texts': total_texts,
  249. 'clickable_texts': clickable_texts,
  250. 'marked_errors': marked_errors,
  251. 'categories': categories,
  252. 'accuracy_rate': (clickable_texts - marked_errors) / clickable_texts * 100 if clickable_texts > 0 else 0
  253. }
  254. def convert_html_table_to_markdown(self, content: str) -> str:
  255. """将HTML表格转换为Markdown表格格式"""
  256. import re
  257. from html import unescape
  258. # 简单的HTML表格到Markdown转换
  259. def replace_table(match):
  260. table_html = match.group(0)
  261. # 提取所有行
  262. rows = re.findall(r'<tr>(.*?)</tr>', table_html, re.DOTALL | re.IGNORECASE)
  263. if not rows:
  264. return table_html # 如果没有找到行,返回原始内容
  265. markdown_rows = []
  266. for i, row in enumerate(rows):
  267. # 提取单元格
  268. cells = re.findall(r'<td[^>]*>(.*?)</td>', row, re.DOTALL | re.IGNORECASE)
  269. if cells:
  270. # 清理单元格内容
  271. clean_cells = []
  272. for cell in cells:
  273. # 移除HTML标签,保留文本
  274. cell_text = re.sub(r'<[^>]+>', '', cell).strip()
  275. cell_text = unescape(cell_text) # 解码HTML实体
  276. clean_cells.append(cell_text)
  277. # 构建Markdown行
  278. markdown_row = '| ' + ' | '.join(clean_cells) + ' |'
  279. markdown_rows.append(markdown_row)
  280. # 在第一行后添加分隔符
  281. if i == 0:
  282. separator = '| ' + ' | '.join(['---'] * len(clean_cells)) + ' |'
  283. markdown_rows.append(separator)
  284. return '\n'.join(markdown_rows) if markdown_rows else table_html
  285. # 替换所有HTML表格
  286. converted = re.sub(r'<table[^>]*>.*?</table>', replace_table, content, flags=re.DOTALL | re.IGNORECASE)
  287. return converted
  288. def render_markdown_with_options(self, markdown_content: str, table_format: str = "grid", escape_html: bool = True):
  289. """自定义Markdown渲染方法,支持多种选项"""
  290. import markdown
  291. # 处理HTML表格
  292. if escape_html:
  293. markdown_content = self.convert_html_table_to_markdown(markdown_content)
  294. # 渲染Markdown
  295. html_content = markdown.markdown(markdown_content)
  296. # 根据选项包裹在特定的HTML结构中
  297. if table_format == "grid":
  298. # 网格布局
  299. wrapped_content = f"""
  300. <div class="markdown-grid">
  301. {html_content}
  302. </div>
  303. """
  304. elif table_format == "list":
  305. # 列表布局
  306. wrapped_content = f"""
  307. <div class="markdown-list">
  308. {html_content}
  309. </div>
  310. """
  311. else:
  312. # 默认直接返回
  313. wrapped_content = html_content
  314. return wrapped_content
  315. def display_html_table_as_dataframe(self, html_content: str):
  316. """将HTML表格解析为DataFrame显示"""
  317. import pandas as pd
  318. from io import StringIO
  319. try:
  320. # 使用pandas直接读取HTML表格
  321. tables = pd.read_html(StringIO(html_content))
  322. if tables:
  323. for i, table in enumerate(tables):
  324. st.subheader(f"表格 {i+1}")
  325. st.dataframe(table, use_container_width=True)
  326. except Exception as e:
  327. st.error(f"表格解析失败: {e}")
  328. # 回退到HTML渲染
  329. st.markdown(html_content, unsafe_allow_html=True)
  330. def main():
  331. """主应用"""
  332. st.title("🔍 OCR可视化校验工具")
  333. st.markdown("---")
  334. # 初始化session state
  335. if 'validator' not in st.session_state:
  336. st.session_state.validator = StreamlitOCRValidator()
  337. if 'selected_text' not in st.session_state:
  338. st.session_state.selected_text = None
  339. if 'marked_errors' not in st.session_state:
  340. st.session_state.marked_errors = set()
  341. # 同步标记的错误到validator
  342. st.session_state.validator.marked_errors = st.session_state.marked_errors
  343. # 侧边栏 - 文件选择和控制
  344. with st.sidebar:
  345. st.header("📁 文件选择")
  346. # 查找可用的OCR文件
  347. output_dir = Path("output")
  348. available_files = []
  349. if output_dir.exists():
  350. for json_file in output_dir.rglob("*.json"):
  351. available_files.append(str(json_file))
  352. if available_files:
  353. selected_file = st.selectbox(
  354. "选择OCR结果文件",
  355. available_files,
  356. index=0
  357. )
  358. if st.button("🔄 加载文件", type="primary") and selected_file:
  359. try:
  360. st.session_state.validator.load_ocr_data(selected_file)
  361. st.success("✅ 文件加载成功!")
  362. st.rerun() # 重新运行应用以更新界面
  363. except Exception as e:
  364. st.error(f"❌ 加载失败: {e}")
  365. else:
  366. st.warning("未找到OCR结果文件")
  367. st.info("请确保output目录下有OCR结果文件")
  368. st.markdown("---")
  369. # 控制面板
  370. st.header("🎛️ 控制面板")
  371. if st.button("🧹 清除选择"):
  372. st.session_state.selected_text = None
  373. st.rerun()
  374. if st.button("❌ 清除错误标记"):
  375. st.session_state.marked_errors = set()
  376. st.rerun()
  377. # 显示调试信息
  378. if st.checkbox("🔧 调试信息"):
  379. st.write("**当前状态:**")
  380. st.write(f"- OCR数据项数: {len(st.session_state.validator.ocr_data)}")
  381. st.write(f"- 可点击文本: {len(st.session_state.validator.text_bbox_mapping)}")
  382. st.write(f"- 选中文本: {st.session_state.selected_text}")
  383. st.write(f"- 标记错误数: {len(st.session_state.marked_errors)}")
  384. if st.session_state.validator.ocr_data:
  385. st.write("**数据类型检查:**")
  386. sample_item = st.session_state.validator.ocr_data[0] if st.session_state.validator.ocr_data else None
  387. st.write(f"- 第一项类型: {type(sample_item)}")
  388. if isinstance(sample_item, dict):
  389. st.write(f"- 第一项键: {list(sample_item.keys())}")
  390. # 主内容区域
  391. if not st.session_state.validator.ocr_data:
  392. st.info("👈 请在左侧选择并加载OCR结果文件")
  393. return
  394. # 显示统计信息
  395. try:
  396. stats = st.session_state.validator.get_statistics()
  397. col1, col2, col3, col4 = st.columns(4)
  398. with col1:
  399. st.metric("📊 总文本块", stats['total_texts'])
  400. with col2:
  401. st.metric("🔗 可点击文本", stats['clickable_texts'])
  402. with col3:
  403. st.metric("❌ 标记错误", stats['marked_errors'])
  404. with col4:
  405. st.metric("✅ 准确率", f"{stats['accuracy_rate']:.1f}%")
  406. st.markdown("---")
  407. except Exception as e:
  408. st.error(f"❌ 统计信息计算失败: {e}")
  409. return
  410. # 主要布局 - 左右分栏
  411. left_col, right_col = st.columns([1, 1])
  412. # 左侧 - OCR文本内容
  413. with left_col:
  414. st.header("📄 OCR识别内容")
  415. # 文本选择器
  416. if st.session_state.validator.text_bbox_mapping:
  417. text_options = ["请选择文本..."] + list(st.session_state.validator.text_bbox_mapping.keys())
  418. selected_index = st.selectbox(
  419. "选择要校验的文本",
  420. range(len(text_options)),
  421. format_func=lambda x: text_options[x],
  422. key="text_selector"
  423. )
  424. if selected_index > 0:
  425. st.session_state.selected_text = text_options[selected_index]
  426. else:
  427. st.warning("没有找到可点击的文本")
  428. # 显示MD内容(可搜索和过滤)
  429. if st.session_state.validator.md_content:
  430. search_term = st.text_input("🔍 搜索文本内容", placeholder="输入关键词搜索...")
  431. display_content = st.session_state.validator.md_content
  432. if search_term:
  433. lines = display_content.split('\n')
  434. filtered_lines = [line for line in lines if search_term.lower() in line.lower()]
  435. display_content = '\n'.join(filtered_lines)
  436. if filtered_lines:
  437. st.success(f"找到 {len(filtered_lines)} 行包含 '{search_term}'")
  438. else:
  439. st.warning(f"未找到包含 '{search_term}' 的内容")
  440. # 渲染方式选择
  441. render_mode = st.radio(
  442. "选择渲染方式",
  443. ["HTML渲染", "Markdown渲染", "DataFrame表格", "原始文本"], # 添加DataFrame选项
  444. horizontal=True
  445. )
  446. if render_mode == "HTML渲染":
  447. # 使用unsafe_allow_html=True来渲染HTML表格
  448. st.markdown(display_content, unsafe_allow_html=True)
  449. elif render_mode == "Markdown渲染":
  450. # 转换HTML表格为Markdown格式
  451. converted_content = st.session_state.validator.convert_html_table_to_markdown(display_content)
  452. st.markdown(converted_content)
  453. elif render_mode == "DataFrame表格":
  454. # 新增:使用DataFrame显示表格
  455. if '<table>' in display_content.lower():
  456. st.session_state.validator.display_html_table_as_dataframe(display_content)
  457. else:
  458. st.info("当前内容中没有检测到HTML表格")
  459. st.markdown(display_content)
  460. else:
  461. # 原始文本显示
  462. st.text_area(
  463. "MD内容预览",
  464. display_content,
  465. height=300,
  466. help="OCR识别的文本内容"
  467. )
  468. # 可点击文本列表
  469. st.subheader("🎯 可点击文本列表")
  470. if st.session_state.validator.text_bbox_mapping:
  471. for text, info_list in st.session_state.validator.text_bbox_mapping.items():
  472. info = info_list[0] # 使用第一个bbox信息
  473. # 确定显示样式
  474. is_selected = (text == st.session_state.selected_text)
  475. is_error = (text in st.session_state.marked_errors)
  476. # 创建按钮行
  477. button_col, error_col = st.columns([4, 1])
  478. with button_col:
  479. button_type = "primary" if is_selected else "secondary"
  480. if st.button(f"📍 {text}", key=f"btn_{text}", type=button_type):
  481. st.session_state.selected_text = text
  482. st.rerun()
  483. with error_col:
  484. if is_error:
  485. if st.button("✅", key=f"fix_{text}", help="取消错误标记"):
  486. st.session_state.marked_errors.discard(text)
  487. st.rerun()
  488. else:
  489. if st.button("❌", key=f"error_{text}", help="标记为错误"):
  490. st.session_state.marked_errors.add(text)
  491. st.rerun()
  492. else:
  493. st.info("没有可点击的文本项目")
  494. # 右侧 - 图像显示
  495. with right_col:
  496. st.header("🖼️ 原图标注")
  497. if st.session_state.validator.image_path and Path(st.session_state.validator.image_path).exists():
  498. try:
  499. # 加载图片
  500. image = Image.open(st.session_state.validator.image_path)
  501. # 创建交互式图片
  502. selected_bbox = None
  503. if st.session_state.selected_text and st.session_state.selected_text in st.session_state.validator.text_bbox_mapping:
  504. info = st.session_state.validator.text_bbox_mapping[st.session_state.selected_text][0]
  505. selected_bbox = info['bbox']
  506. fig = st.session_state.validator.create_interactive_plot(image, selected_bbox)
  507. st.plotly_chart(fig, use_container_width=True)
  508. # 显示选中文本的详细信息
  509. if st.session_state.selected_text:
  510. st.subheader("📍 选中文本详情")
  511. if st.session_state.selected_text in st.session_state.validator.text_bbox_mapping:
  512. info = st.session_state.validator.text_bbox_mapping[st.session_state.selected_text][0]
  513. bbox = info['bbox']
  514. info_col1, info_col2 = st.columns(2)
  515. with info_col1:
  516. st.write(f"**文本内容:** {st.session_state.selected_text}")
  517. st.write(f"**类别:** {info['category']}")
  518. st.write(f"**置信度:** {info.get('confidence', 'N/A')}")
  519. with info_col2:
  520. st.write(f"**位置:** [{', '.join(map(str, bbox))}]")
  521. if len(bbox) >= 4:
  522. st.write(f"**宽度:** {bbox[2] - bbox[0]} px")
  523. st.write(f"**高度:** {bbox[3] - bbox[1]} px")
  524. # 标记状态
  525. is_error = st.session_state.selected_text in st.session_state.marked_errors
  526. if is_error:
  527. st.error("⚠️ 此文本已标记为错误")
  528. else:
  529. st.success("✅ 此文本未标记错误")
  530. except Exception as e:
  531. st.error(f"❌ 图片处理失败: {e}")
  532. else:
  533. st.error("未找到对应的图片文件")
  534. if st.session_state.validator.image_path:
  535. st.write(f"期望路径: {st.session_state.validator.image_path}")
  536. if __name__ == "__main__":
  537. main()