streamlit_ocr_validator.py 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275
  1. #!/usr/bin/env python3
  2. """
  3. 基于Streamlit的OCR可视化校验工具(修复版)
  4. 提供丰富的交互组件和更好的用户体验
  5. """
  6. import streamlit as st
  7. import json
  8. import pandas as pd
  9. from pathlib import Path
  10. import numpy as np
  11. from PIL import Image, ImageDraw, ImageFont
  12. import cv2
  13. import base64
  14. from typing import Dict, List, Optional, Tuple
  15. import plotly.express as px
  16. import plotly.graph_objects as go
  17. from plotly.subplots import make_subplots
  18. from io import StringIO, BytesIO
  # Configure the page before any other st.* output is produced.
  st.set_page_config(
      page_title="OCR可视化校验工具",
      page_icon="🔍",
      layout="wide",
      initial_sidebar_state="expanded"
  )
  # Global CSS: force a light theme (white background, dark text) and define
  # the helper classes (.scrollable-content, .compact-content, .highlight-text,
  # .standard-content, ...) that the layout methods render into.
  # The sidebar rule keeps the original auto-generated class ".css-1d391kg"
  # (such generated class names change between Streamlit releases) and also
  # targets the sidebar's data-testid attribute so the rule keeps matching.
  st.markdown("""
  <style>
  /* 设置主体背景为白色 */
  .main > div {
  padding-top: 2rem;
  background-color: white !important;
  color: #333333 !important;
  }
  /* 设置整体页面背景 */
  .stApp {
  background-color: white !important;
  }
  /* 设置内容区域背景 */
  .block-container {
  background-color: white !important;
  color: #333333 !important;
  }
  /* 设置侧边栏样式 */
  .css-1d391kg,
  section[data-testid="stSidebar"] {
  background-color: #f8f9fa !important;
  }
  /* 设置选择框样式 */
  .stSelectbox > div > div > div {
  background-color: #f0f2f6 !important;
  color: #333333 !important;
  }
  /* 设置标题样式 */
  h1, h2, h3, h4, h5, h6 {
  color: #333333 !important;
  }
  /* 设置文本样式 */
  p, div, span, label {
  color: #333333 !important;
  }
  /* 可点击文本样式 */
  .clickable-text {
  background-color: #e1f5fe;
  padding: 2px 6px;
  border-radius: 4px;
  border: 1px solid #0288d1;
  cursor: pointer;
  margin: 2px;
  display: inline-block;
  color: #0288d1 !important;
  }
  .selected-text {
  background-color: #fff3e0;
  border-color: #ff9800;
  font-weight: bold;
  color: #ff9800 !important;
  }
  .error-text {
  background-color: #ffebee;
  border-color: #f44336;
  color: #d32f2f !important;
  }
  .stats-container {
  background-color: #f8f9fa;
  padding: 1rem;
  border-radius: 8px;
  border-left: 4px solid #28a745;
  color: #333333 !important;
  }
  /* 修复滚动内容区域样式 */
  .scrollable-content {
  background-color: #fafafa !important;
  color: #333333 !important;
  border: 1px solid #ddd !important;
  }
  /* 修复紧凑内容样式 */
  .compact-content {
  background-color: #fafafa !important;
  color: #333333 !important;
  border: 1px solid #ddd !important;
  }
  /* 高亮文本样式 */
  .highlight-text {
  background-color: #ffeb3b !important;
  color: #333333 !important;
  padding: 2px 4px;
  border-radius: 3px;
  cursor: pointer;
  }
  .selected-highlight {
  background-color: #4caf50 !important;
  color: white !important;
  }
  /* 标准布局内容样式 */
  .standard-content {
  background-color: #fafafa !important;
  color: #333333 !important;
  border: 1px solid #ddd !important;
  }
  </style>
  """, unsafe_allow_html=True)
  122. class StreamlitOCRValidator:
  123. def __init__(self):
  124. self.ocr_data = []
  125. self.md_content = ""
  126. self.image_path = ""
  127. self.text_bbox_mapping = {}
  128. self.selected_text = None
  129. self.marked_errors = set()
  130. def load_ocr_data(self, json_path: str, md_path: Optional[str] = None, image_path: Optional[str] = None):
  131. """加载OCR相关数据"""
  132. json_file = Path(json_path)
  133. # 加载JSON数据
  134. try:
  135. with open(json_file, 'r', encoding='utf-8') as f:
  136. data = json.load(f)
  137. # 确保数据是列表格式
  138. if isinstance(data, list):
  139. self.ocr_data = data
  140. elif isinstance(data, dict) and 'results' in data:
  141. self.ocr_data = data['results']
  142. else:
  143. st.error(f"❌ 不支持的JSON格式: {json_path}")
  144. return
  145. except Exception as e:
  146. st.error(f"❌ 加载JSON文件失败: {e}")
  147. return
  148. # 推断MD文件路径
  149. if md_path is None:
  150. md_file = json_file.with_suffix('.md')
  151. else:
  152. md_file = Path(md_path)
  153. if md_file.exists():
  154. with open(md_file, 'r', encoding='utf-8') as f:
  155. self.md_content = f.read()
  156. # 推断图片路径
  157. if image_path is None:
  158. image_name = json_file.stem
  159. sample_data_dir = Path("./sample_data")
  160. image_candidates = [
  161. sample_data_dir / f"{image_name}.png",
  162. sample_data_dir / f"{image_name}.jpg",
  163. json_file.parent / f"{image_name}.png",
  164. json_file.parent / f"{image_name}.jpg",
  165. ]
  166. for candidate in image_candidates:
  167. if candidate.exists():
  168. self.image_path = str(candidate)
  169. break
  170. else:
  171. self.image_path = image_path
  172. # 处理数据
  173. self.process_data()
  174. def process_data(self):
  175. """处理OCR数据,建立文本到bbox的映射"""
  176. self.text_bbox_mapping = {}
  177. # 确保 ocr_data 是列表
  178. if not isinstance(self.ocr_data, list):
  179. st.warning("⚠️ OCR数据格式不正确,期望列表格式")
  180. return
  181. for i, item in enumerate(self.ocr_data):
  182. # 确保 item 是字典类型
  183. if not isinstance(item, dict):
  184. continue
  185. if 'text' in item and 'bbox' in item:
  186. text = str(item['text']).strip()
  187. if text and text not in ['Picture', '']:
  188. bbox = item['bbox']
  189. # 确保bbox是4个数字的列表
  190. if isinstance(bbox, list) and len(bbox) == 4:
  191. if text not in self.text_bbox_mapping:
  192. self.text_bbox_mapping[text] = []
  193. self.text_bbox_mapping[text].append({
  194. 'bbox': bbox,
  195. 'category': item.get('category', 'Text'),
  196. 'index': i,
  197. 'confidence': item.get('confidence', 1.0)
  198. })
  def draw_bbox_on_image(self, image: Image.Image, bbox: List[int], color: str = "red", width: int = 3) -> Image.Image:
      """Return a copy of ``image`` with ``bbox`` outlined and tinted.

      Args:
          image: Source image; it is not modified.
          bbox: ``[x1, y1, x2, y2]`` in the image's pixel coordinates.
          color: Outline colour (any PIL colour spec). ``"red"``, ``"blue"``
              and ``"green"`` use fixed fill tints. Other specs are tinted
              with the same colour; specs PIL cannot parse fall back to a
              yellow tint.
          width: Outline width in pixels.

      Returns:
          A new RGB image.
      """
      from PIL import ImageColor

      fixed_fills = {
          "red": (255, 0, 0),
          "blue": (0, 0, 255),
          "green": (0, 255, 0),
      }
      if color in fixed_fills:
          rgb = fixed_fills[color]
      else:
          try:
              # Match the fill to the outline instead of always using yellow.
              rgb = tuple(ImageColor.getrgb(color)[:3])
          except (ValueError, AttributeError, TypeError):
              rgb = (255, 255, 0)
      fill_color = (*rgb, 30)  # alpha 30/255: a light translucent tint

      img_copy = image.copy()
      x1, y1, x2, y2 = bbox

      # Outline drawn directly on the copy.
      ImageDraw.Draw(img_copy).rectangle([x1, y1, x2, y2], outline=color, width=width)

      # Translucent fill composited from a separate RGBA overlay.
      overlay = Image.new('RGBA', img_copy.size, (0, 0, 0, 0))
      ImageDraw.Draw(overlay).rectangle([x1, y1, x2, y2], fill=fill_color)
      return Image.alpha_composite(img_copy.convert('RGBA'), overlay).convert('RGB')
  def create_interactive_plot(self, image: Image.Image, selected_bbox: Optional[List[int]] = None) -> go.Figure:
      """Build an 800x600 Plotly figure of ``image`` with every OCR box drawn.

      Each entry of ``self.text_bbox_mapping`` with at least four coordinates
      becomes a thin rectangle: light blue normally, red when its text is in
      ``self.marked_errors``. ``selected_bbox`` (if it has at least four
      coordinates) is drawn last with a thick red outline. Coordinates are
      used as-is, so they must be in ``image``'s pixel space.
      """
      img_w = image.width
      img_h = image.height

      def flip(top, bottom):
          # Plotly's y axis grows upwards while image rows grow downwards.
          return img_h - bottom, img_h - top

      fig = go.Figure()
      fig.add_layout_image(
          dict(
              source=image,
              xref="x",
              yref="y",
              x=0,
              y=img_h,
              sizex=img_w,
              sizey=img_h,
              sizing="stretch",
              opacity=1.0,
              layer="below"
          )
      )

      for text, entries in self.text_bbox_mapping.items():
          if text in self.marked_errors:
              fill, outline = "rgba(255, 0, 0, 0.3)", "rgba(255, 0, 0, 1.0)"
          else:
              fill, outline = "rgba(0, 100, 200, 0.2)", "rgba(0, 100, 200, 0.8)"
          for entry in entries:
              coords = entry['bbox']
              if len(coords) < 4:
                  continue
              left, top, right, bottom = coords[:4]
              y_low, y_high = flip(top, bottom)
              fig.add_shape(
                  type="rect",
                  x0=left, y0=y_low,
                  x1=right, y1=y_high,
                  line=dict(color=outline, width=1),
                  fillcolor=fill,
              )

      if selected_bbox and len(selected_bbox) >= 4:
          left, top, right, bottom = selected_bbox[:4]
          y_low, y_high = flip(top, bottom)
          fig.add_shape(
              type="rect",
              x0=left, y0=y_low,
              x1=right, y1=y_high,
              line=dict(color="red", width=3),
              fillcolor="rgba(255, 0, 0, 0.2)",
          )

      fig.update_xaxes(visible=False, range=[0, img_w])
      fig.update_yaxes(visible=False, range=[0, img_h], scaleanchor="x")
      fig.update_layout(
          width=800,
          height=600,
          margin=dict(l=0, r=0, t=0, b=0),
          xaxis_showgrid=False,
          yaxis_showgrid=False,
          plot_bgcolor='white'
      )
      return fig
  283. def get_statistics(self) -> Dict:
  284. """获取统计信息"""
  285. # 先确保 ocr_data 不为空且是列表
  286. if not isinstance(self.ocr_data, list) or not self.ocr_data:
  287. return {
  288. 'total_texts': 0,
  289. 'clickable_texts': 0,
  290. 'marked_errors': 0,
  291. 'categories': {},
  292. 'accuracy_rate': 0
  293. }
  294. total_texts = len(self.ocr_data)
  295. clickable_texts = len(self.text_bbox_mapping)
  296. marked_errors = len(self.marked_errors)
  297. # 按类别统计 - 添加类型检查
  298. categories = {}
  299. for item in self.ocr_data:
  300. # 确保 item 是字典类型
  301. if isinstance(item, dict):
  302. category = item.get('category', 'Unknown')
  303. elif isinstance(item, str):
  304. category = 'Text' # 字符串类型默认为 Text 类别
  305. else:
  306. category = 'Unknown'
  307. categories[category] = categories.get(category, 0) + 1
  308. return {
  309. 'total_texts': total_texts,
  310. 'clickable_texts': clickable_texts,
  311. 'marked_errors': marked_errors,
  312. 'categories': categories,
  313. 'accuracy_rate': (clickable_texts - marked_errors) / clickable_texts * 100 if clickable_texts > 0 else 0
  314. }
  315. def convert_html_table_to_markdown(self, content: str) -> str:
  316. """将HTML表格转换为Markdown表格格式"""
  317. import re
  318. from html import unescape
  319. # 简单的HTML表格到Markdown转换
  320. def replace_table(match):
  321. table_html = match.group(0)
  322. # 提取所有行
  323. rows = re.findall(r'<tr>(.*?)</tr>', table_html, re.DOTALL | re.IGNORECASE)
  324. if not rows:
  325. return table_html # 如果没有找到行,返回原始内容
  326. markdown_rows = []
  327. for i, row in enumerate(rows):
  328. # 提取单元格
  329. cells = re.findall(r'<td[^>]*>(.*?)</td>', row, re.DOTALL | re.IGNORECASE)
  330. if cells:
  331. # 清理单元格内容
  332. clean_cells = []
  333. for cell in cells:
  334. # 移除HTML标签,保留文本
  335. cell_text = re.sub(r'<[^>]+>', '', cell).strip()
  336. cell_text = unescape(cell_text) # 解码HTML实体
  337. clean_cells.append(cell_text)
  338. # 构建Markdown行
  339. markdown_row = '| ' + ' | '.join(clean_cells) + ' |'
  340. markdown_rows.append(markdown_row)
  341. # 在第一行后添加分隔符
  342. if i == 0:
  343. separator = '| ' + ' | '.join(['---'] * len(clean_cells)) + ' |'
  344. markdown_rows.append(separator)
  345. return '\n'.join(markdown_rows) if markdown_rows else table_html
  346. # 替换所有HTML表格
  347. converted = re.sub(r'<table[^>]*>.*?</table>', replace_table, content, flags=re.DOTALL | re.IGNORECASE)
  348. return converted
  def display_html_table_as_dataframe(self, html_content: str, enable_editing: bool = False):
      """Render each HTML table in ``html_content`` as an interactive DataFrame.

      For every table parsed by ``pandas.read_html``, the user can:
      - toggle table info and numeric summary statistics;
      - apply a literal substring filter on one column;
      - sort by one column;
      - export the rows currently shown as CSV or Excel.
      With ``enable_editing`` the table is shown in ``st.data_editor`` and the
      export uses the edited rows. On any failure, an error is shown and the
      raw HTML is rendered instead.

      Args:
          html_content: HTML text that may contain ``<table>`` elements.
          enable_editing: Show editable tables instead of read-only ones.
      """
      try:
          # pandas extracts every <table> element into its own DataFrame.
          tables = pd.read_html(StringIO(html_content))
          if not tables:
              st.warning("未找到可解析的表格")
              return

          for i, table in enumerate(tables):
              st.subheader(f"📊 表格 {i+1}")

              # Per-table option toggles.
              col1, col2, col3, col4 = st.columns(4)
              with col1:
                  show_info = st.checkbox("显示表格信息", key=f"info_{i}")
              with col2:
                  show_stats = st.checkbox("显示统计信息", key=f"stats_{i}")
              with col3:
                  enable_filter = st.checkbox("启用过滤", key=f"filter_{i}")
              with col4:
                  enable_sort = st.checkbox("启用排序", key=f"sort_{i}")

              # Filtering.
              filtered_table = table.copy()
              if enable_filter and not table.empty:
                  filter_col = st.selectbox(
                      f"选择过滤列 (表格 {i+1})",
                      options=['无'] + list(table.columns),
                      key=f"filter_col_{i}"
                  )
                  if filter_col != '无':
                      filter_value = st.text_input(
                          f"过滤值 (表格 {i+1})",
                          key=f"filter_value_{i}"
                      )
                      if filter_value:
                          # regex=False: match the typed text literally so input
                          # such as "(" or "*" cannot raise a regex error.
                          mask = table[filter_col].astype(str).str.contains(
                              filter_value, na=False, regex=False
                          )
                          filtered_table = table[mask]

              # Sorting.
              if enable_sort and not filtered_table.empty:
                  sort_col = st.selectbox(
                      f"选择排序列 (表格 {i+1})",
                      options=['无'] + list(filtered_table.columns),
                      key=f"sort_col_{i}"
                  )
                  if sort_col != '无':
                      sort_order = st.radio(
                          f"排序方式 (表格 {i+1})",
                          options=['升序', '降序'],
                          horizontal=True,
                          key=f"sort_order_{i}"
                      )
                      filtered_table = filtered_table.sort_values(
                          sort_col, ascending=(sort_order == '升序')
                      )

              # Display. shown_table is exactly what the user sees, including
              # any edits, so the export below matches the screen.
              if enable_editing:
                  shown_table = st.data_editor(
                      filtered_table,
                      use_container_width=True,
                      key=f"editor_{i}"
                  )
                  if not shown_table.equals(filtered_table):
                      st.success("✏️ 表格已编辑,可以导出修改后的数据")
              else:
                  st.dataframe(filtered_table, use_container_width=True)
                  shown_table = filtered_table

              if show_info:
                  st.write("**表格信息:**")
                  st.write(f"- 原始行数: {len(table)}")
                  st.write(f"- 过滤后行数: {len(filtered_table)}")
                  st.write(f"- 列数: {len(table.columns)}")
                  # read_html yields integer (or tuple) column labels for tables
                  # without a header row, so stringify before joining.
                  st.write(f"- 列名: {', '.join(map(str, table.columns))}")

              if show_stats:
                  st.write("**统计信息:**")
                  numeric_cols = filtered_table.select_dtypes(include=[np.number]).columns
                  if len(numeric_cols) > 0:
                      st.dataframe(filtered_table[numeric_cols].describe())
                  else:
                      st.info("表格中没有数值列")

              # Export.
              if st.button(f"📥 导出表格 {i+1}", key=f"export_{i}"):
                  st.download_button(
                      label=f"下载CSV (表格 {i+1})",
                      data=shown_table.to_csv(index=False),
                      file_name=f"table_{i+1}.csv",
                      mime="text/csv",
                      key=f"download_csv_{i}"
                  )
                  excel_buffer = BytesIO()
                  try:
                      shown_table.to_excel(excel_buffer, index=False)
                  except (ImportError, NotImplementedError, ValueError) as excel_error:
                      # e.g. no Excel writer engine installed: keep the CSV export
                      # usable instead of failing the whole table view.
                      st.warning(f"Excel导出不可用: {excel_error}")
                  else:
                      st.download_button(
                          label=f"下载Excel (表格 {i+1})",
                          data=excel_buffer.getvalue(),
                          file_name=f"table_{i+1}.xlsx",
                          mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
                          key=f"download_excel_{i}"
                      )

              st.markdown("---")
      except Exception as e:
          # UI boundary: report and fall back to rendering the raw HTML.
          st.error(f"表格解析失败: {e}")
          st.info("尝试使用HTML渲染模式查看表格")
          st.markdown(html_content, unsafe_allow_html=True)
  def create_standard_layout(self, font_size: int = 12, zoom_level: float = 1.0):
      """Render the standard two-column layout.

      Left column: text selector, keyword search and the OCR Markdown
      rendered in one of four modes at ``font_size`` px. Right column: the
      source image at an adjustable zoom with the selected text's bbox
      highlighted, the selection's details and error-marking buttons. The
      selection lives in ``st.session_state.selected_text`` and error marks
      in ``st.session_state.marked_errors``.

      Args:
          font_size: Font size (px) of the rendered OCR content.
          zoom_level: Initial value of the image zoom slider.
      """
      left_col, right_col = st.columns([0.7, 1])
      with left_col:
          self._render_standard_content_panel(font_size)
      with right_col:
          self._render_standard_image_panel(zoom_level)

  def _render_standard_content_panel(self, font_size: int):
      """Left column of the standard layout: selector, search and OCR content."""
      st.header("📄 OCR识别内容")

      # Text selector.
      if self.text_bbox_mapping:
          text_options = ["请选择文本..."] + list(self.text_bbox_mapping.keys())
          selected_index = st.selectbox(
              "选择要校验的文本",
              range(len(text_options)),
              format_func=lambda x: text_options[x][:50] + "..." if len(text_options[x]) > 50 else text_options[x],
              key="standard_text_selector"
          )
          if selected_index > 0:
              st.session_state.selected_text = text_options[selected_index]
      else:
          st.warning("没有找到可点击的文本")

      if not self.md_content:
          return

      # Keyword search keeps only the matching lines (case-insensitive).
      search_term = st.text_input("🔍 搜索文本内容", placeholder="输入关键词搜索...", key="standard_search")
      display_content = self.md_content
      if search_term:
          needle = search_term.lower()
          filtered_lines = [line for line in display_content.split('\n') if needle in line.lower()]
          display_content = '\n'.join(filtered_lines)
          if filtered_lines:
              st.success(f"找到 {len(filtered_lines)} 行包含 '{search_term}'")
          else:
              st.warning(f"未找到包含 '{search_term}' 的内容")

      render_mode = st.radio(
          "选择渲染方式",
          ["HTML渲染", "Markdown渲染", "DataFrame表格", "原始文本"],
          horizontal=True,
          key="standard_render_mode"
      )

      # Apply the requested font size to the content container.
      content_style = (
          "\n<style>\n"
          ".standard-content-display {\n"
          f"font-size: {font_size}px !important;\n"
          "line-height: 1.4;\n"
          "color: #333333 !important;\n"
          "background-color: #fafafa !important;\n"
          "padding: 10px;\n"
          "border-radius: 5px;\n"
          "border: 1px solid #ddd;\n"
          "}\n"
          "</style>\n"
      )
      st.markdown(content_style, unsafe_allow_html=True)

      if render_mode == "HTML渲染":
          st.markdown(f'<div class="standard-content-display">{display_content}</div>', unsafe_allow_html=True)
      elif render_mode == "Markdown渲染":
          converted_content = self.convert_html_table_to_markdown(display_content)
          # The blank lines end the surrounding HTML block so the converted
          # text is parsed as Markdown (tables, headings) instead of being
          # passed through as raw HTML text.
          st.markdown(
              f'<div class="standard-content-display">\n\n{converted_content}\n\n</div>',
              unsafe_allow_html=True
          )
      elif render_mode == "DataFrame表格":
          if '<table' in display_content.lower():
              self.display_html_table_as_dataframe(display_content)
          else:
              st.info("当前内容中没有检测到HTML表格")
              st.markdown(f'<div class="standard-content-display">{display_content}</div>', unsafe_allow_html=True)
      else:
          st.text_area(
              "MD内容预览",
              display_content,
              height=300,
              help="OCR识别的文本内容",
              key="standard_text_area"
          )

  def _render_standard_image_panel(self, zoom_level: float):
      """Right column of the standard layout: annotated image, details, error marking."""
      st.header("🖼️ 原图标注")

      # Image zoom controls.
      col1, col2 = st.columns(2)
      with col1:
          current_zoom = st.slider("图片缩放", 0.3, 2.0, zoom_level, 0.1, key="standard_zoom_level")
      with col2:
          show_all_boxes = st.checkbox("显示所有框", value=False, key="standard_show_all_boxes")

      if not (self.image_path and Path(self.image_path).exists()):
          st.error("未找到对应的图片文件")
          if self.image_path:
              st.write(f"期望路径: {self.image_path}")
          return

      try:
          # Copy the pixels and close the file immediately instead of keeping
          # the handle open until garbage collection.
          with Image.open(self.image_path) as source:
              image = source.copy()
          if current_zoom != 1.0:
              new_width = int(image.width * current_zoom)
              new_height = int(image.height * current_zoom)
              image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

          selected_text = st.session_state.get("selected_text")
          info = None
          if selected_text and selected_text in self.text_bbox_mapping:
              info = self.text_bbox_mapping[selected_text][0]

          # Scale the selected bbox into the resized image's pixel space.
          selected_bbox = None
          if info is not None:
              selected_bbox = info['bbox']
              if current_zoom != 1.0:
                  selected_bbox = [int(coord * current_zoom) for coord in selected_bbox]

          fig = self._build_standard_figure(image, selected_bbox, current_zoom, show_all_boxes)
          st.plotly_chart(fig, use_container_width=True, key="standard_plot")

          if selected_text:
              st.subheader("📍 选中文本详情")
              if info is not None:
                  original_bbox = info['bbox']
                  preview = selected_text if len(selected_text) <= 30 else selected_text[:30] + "..."
                  info_col1, info_col2 = st.columns(2)
                  with info_col1:
                      st.write(f"**文本内容:** {preview}")
                      st.write(f"**类别:** {info['category']}")
                      st.write(f"**置信度:** {info.get('confidence', 'N/A')}")
                  with info_col2:
                      st.write(f"**位置:** [{', '.join(map(str, original_bbox))}]")
                      if len(original_bbox) >= 4:
                          st.write(f"**宽度:** {original_bbox[2] - original_bbox[0]} px")
                          st.write(f"**高度:** {original_bbox[3] - original_bbox[1]} px")

                  # Error marks are kept in st.session_state.marked_errors.
                  marked = st.session_state.setdefault("marked_errors", set())
                  mark_col, unmark_col = st.columns(2)
                  with mark_col:
                      if st.button("❌ 标记为错误", key="mark_error_standard"):
                          marked.add(selected_text)
                          st.rerun()
                  with unmark_col:
                      if st.button("✅ 取消错误标记", key="unmark_error_standard"):
                          marked.discard(selected_text)
                          st.rerun()

                  if selected_text in marked:
                      st.error("⚠️ 此文本已标记为错误")
                  else:
                      st.success("✅ 此文本未标记错误")
      except Exception as e:
          st.error(f"❌ 图片处理失败: {e}")

  def _build_standard_figure(self, image: Image.Image, selected_bbox: Optional[List[int]],
                             zoom: float, show_all_boxes: bool) -> go.Figure:
      """Build the standard layout's figure for an image already resized by ``zoom``.

      With ``show_all_boxes``, every OCR bbox is multiplied by ``zoom`` so it
      lines up with the resized image. ``selected_bbox`` must already be
      scaled. Styling matches ``create_interactive_plot``; the selected box
      is drawn alone when ``show_all_boxes`` is off.
      """
      img_h = image.height
      fig = go.Figure()
      fig.add_layout_image(
          dict(
              source=image,
              xref="x", yref="y",
              x=0, y=img_h,
              sizex=image.width, sizey=img_h,
              sizing="stretch", opacity=1.0, layer="below"
          )
      )

      if show_all_boxes:
          # NOTE(review): box colouring reads self.marked_errors while the
          # buttons update st.session_state.marked_errors; confirm the caller
          # keeps the two in sync.
          for text, info_list in self.text_bbox_mapping.items():
              if text in self.marked_errors:
                  fill, outline = "rgba(255, 0, 0, 0.3)", "rgba(255, 0, 0, 1.0)"
              else:
                  fill, outline = "rgba(0, 100, 200, 0.2)", "rgba(0, 100, 200, 0.8)"
              for info in info_list:
                  bbox = info['bbox']
                  if len(bbox) < 4:
                      continue
                  x1, y1, x2, y2 = (coord * zoom for coord in bbox[:4])
                  fig.add_shape(
                      type="rect",
                      x0=x1, y0=img_h - y2,
                      x1=x2, y1=img_h - y1,
                      line=dict(color=outline, width=1),
                      fillcolor=fill,
                  )

      if selected_bbox and len(selected_bbox) >= 4:
          x1, y1, x2, y2 = selected_bbox[:4]
          fig.add_shape(
              type="rect",
              x0=x1, y0=img_h - y2,
              x1=x2, y1=img_h - y1,
              line=dict(color="red", width=3),
              fillcolor="rgba(255, 0, 0, 0.2)",
          )

      fig.update_xaxes(visible=False, range=[0, image.width])
      fig.update_yaxes(visible=False, range=[0, img_h], scaleanchor="x")
      fig.update_layout(
          width=800, height=600,
          margin=dict(l=0, r=0, t=0, b=0),
          xaxis_showgrid=False, yaxis_showgrid=False,
          plot_bgcolor='white'
      )
      return fig
  def create_split_layout_with_fixed_image(self, font_size: int = 12, zoom_level: float = 1.0):
      """Two-column layout: scrollable OCR text left, image panel right.

      The left column offers a text selector (stored in
      ``st.session_state.selected_text``), a height picker for the scroll
      area, and the Markdown content with newlines turned into ``<br>``,
      shown at ``font_size`` px. The right column is delegated to
      ``create_fixed_image_display``.
      """
      left_col, right_col = st.columns([0.7, 1])

      with left_col:
          st.header("📄 OCR识别内容")

          if self.text_bbox_mapping:
              choices = ["请选择文本..."] + list(self.text_bbox_mapping.keys())

              def short_label(idx):
                  # Truncate long entries so the dropdown stays readable.
                  full = choices[idx]
                  return full if len(full) <= 50 else full[:50] + "..."

              picked = st.selectbox(
                  "选择要校验的文本",
                  range(len(choices)),
                  format_func=short_label,
                  key="split_text_selector"
              )
              if picked > 0:
                  st.session_state.selected_text = choices[picked]

          height_px = st.selectbox(
              "选择内容区域高度",
              [400, 600, 800, 1000, 1200],
              index=2,
              key="split_content_height"
          )

          # Scroll container styling, sized by the chosen height and font.
          css_lines = [
              "<style>",
              ".scrollable-content {",
              f"height: {height_px}px;",
              "overflow-y: auto;",
              "overflow-x: hidden;",
              "padding: 10px;",
              "border: 1px solid #ddd;",
              "border-radius: 5px;",
              "background-color: #fafafa !important;",
              f"font-size: {font_size}px !important;",
              "line-height: 1.4;",
              "color: #333333 !important;",
              "}",
              ".scrollable-content::-webkit-scrollbar {",
              "width: 8px;",
              "}",
              ".scrollable-content::-webkit-scrollbar-track {",
              "background: #f1f1f1;",
              "border-radius: 4px;",
              "}",
              ".scrollable-content::-webkit-scrollbar-thumb {",
              "background: #888;",
              "border-radius: 4px;",
              "}",
              ".scrollable-content::-webkit-scrollbar-thumb:hover {",
              "background: #555;",
              "}",
              "</style>",
          ]
          st.markdown("\n" + "\n".join(css_lines) + "\n", unsafe_allow_html=True)

          if self.md_content:
              # Newlines become <br> so the whole text stays inside one HTML block.
              html_body = self.md_content.replace(chr(10), '<br>')
              st.markdown(
                  "\n" + '<div class="scrollable-content">\n' + html_body + "\n</div>\n",
                  unsafe_allow_html=True
              )

      with right_col:
          self.create_fixed_image_display(zoom_level)
  def create_fixed_image_display(self, zoom_level: float = 1.0):
      """Render the image panel used by the split layout.

      Shows a zoom slider and a "show all boxes" toggle. Below them it draws
      the source image resized by the chosen zoom, with the selected text's
      bbox scaled to match and highlighted, then the selection's category,
      bbox and size. A missing image or a processing error is reported
      through Streamlit.

      Args:
          zoom_level: Initial value of the zoom slider.
      """
      st.header("🖼️ 原图标注")

      # Image zoom controls.
      col1, col2 = st.columns(2)
      with col1:
          current_zoom = st.slider("图片缩放", 0.3, 2.0, zoom_level, 0.1, key="fixed_zoom_level")
      with col2:
          show_all_boxes = st.checkbox("显示所有框", value=False, key="fixed_show_all_boxes")

      if not (self.image_path and Path(self.image_path).exists()):
          st.error("未找到对应的图片文件")
          if self.image_path:
              st.write(f"期望路径: {self.image_path}")
          return

      try:
          # Resize inside the with-block so the file handle is closed promptly.
          with Image.open(self.image_path) as image:
              new_size = (int(image.width * current_zoom), int(image.height * current_zoom))
              resized_image = image.resize(new_size, Image.Resampling.LANCZOS)

          selected_text = st.session_state.get("selected_text")
          info = None
          if selected_text and selected_text in self.text_bbox_mapping:
              info = self.text_bbox_mapping[selected_text][0]

          # Scale the selected bbox into the resized image's pixel space.
          selected_bbox = None
          if info is not None:
              selected_bbox = [int(coord * current_zoom) for coord in info['bbox']]

          fig = self.create_resized_interactive_plot(resized_image, selected_bbox, current_zoom, show_all_boxes)
          st.plotly_chart(fig, use_container_width=True, key="fixed_plot")

          # Details of the selected text.
          if info is not None:
              st.subheader("📍 选中文本详情")
              bbox = info['bbox']
              preview = selected_text if len(selected_text) <= 30 else selected_text[:30] + "..."
              info_col1, info_col2 = st.columns(2)
              with info_col1:
                  st.write(f"**文本内容:** {preview}")
                  st.write(f"**类别:** {info['category']}")
              with info_col2:
                  st.write(f"**位置:** [{', '.join(map(str, bbox))}]")
                  if len(bbox) >= 4:
                      st.write(f"**大小:** {bbox[2] - bbox[0]} x {bbox[3] - bbox[1]} px")
      except Exception as e:
          st.error(f"❌ 图片处理失败: {e}")
  def create_resized_interactive_plot(self, image: Image.Image, selected_bbox: Optional[List[int]], zoom_level: float, show_all_boxes: bool) -> go.Figure:
      """Build a Plotly figure sized to ``image`` (an already-zoomed image).

      With ``show_all_boxes``, every OCR bbox with at least four coordinates
      is multiplied by ``zoom_level`` and drawn: red for texts in
      ``self.marked_errors``, light blue otherwise. ``selected_bbox`` is
      expected in the zoomed pixel space already and is drawn with a red
      outline. The figure's width and height equal the image's.
      """
      height = image.height
      width = image.width

      fig = go.Figure()
      fig.add_layout_image(
          dict(
              source=image,
              xref="x", yref="y",
              x=0, y=height,
              sizex=width, sizey=height,
              sizing="stretch", opacity=1.0, layer="below"
          )
      )

      if show_all_boxes:
          for text, entries in self.text_bbox_mapping.items():
              flagged = text in self.marked_errors
              fill = "rgba(255, 0, 0, 0.3)" if flagged else "rgba(0, 100, 200, 0.2)"
              outline = "rgba(255, 0, 0, 1.0)" if flagged else "rgba(0, 100, 200, 0.8)"
              for entry in entries:
                  coords = entry['bbox']
                  if len(coords) < 4:
                      continue
                  left, top, right, bottom = (c * zoom_level for c in coords[:4])
                  # Flip y: Plotly's axis grows upwards, image rows downwards.
                  fig.add_shape(
                      type="rect",
                      x0=left, y0=height - bottom,
                      x1=right, y1=height - top,
                      line=dict(color=outline, width=1),
                      fillcolor=fill,
                  )

      if selected_bbox and len(selected_bbox) >= 4:
          left, top, right, bottom = selected_bbox[:4]
          fig.add_shape(
              type="rect",
              x0=left, y0=height - bottom,
              x1=right, y1=height - top,
              line=dict(color="red", width=2),
              fillcolor="rgba(255, 0, 0, 0.3)",
          )

      fig.update_xaxes(visible=False, range=[0, width])
      fig.update_yaxes(visible=False, range=[0, height], scaleanchor="x")
      fig.update_layout(
          width=width,
          height=height,
          margin=dict(l=0, r=0, t=0, b=0),
          showlegend=False,
          plot_bgcolor='white'
      )
      return fig
  def create_compact_layout(self, font_size: int = 12, zoom_level: float = 1.0):
      """Render the compact comparison layout.

      Left pane: the OCR markdown with every mapped OCR text highlighted, plus a
      selectbox to pick a text. Right pane: the source image resized to
      ``400 * zoom`` px wide, with the selected text's bbox drawn in red.

      Args:
          font_size: Initial font size (px) of the OCR text pane. Values outside
              the selectable list [10, 12, 14, 16, 18] fall back to 12.
          zoom_level: Initial image zoom. Clamped into this slider's [0.3, 1.5]
              range (and coerced to float) because ``st.slider`` rejects a
              default value outside ``[min_value, max_value]``.
      """
      font_choices = [10, 12, 14, 16, 18]

      # Top control row
      control_col1, control_col2, control_col3 = st.columns([1, 1, 1])
      with control_col1:
          current_font_size = st.selectbox(
              "字体大小",
              font_choices,
              index=font_choices.index(font_size) if font_size in font_choices else 1,
              key="compact_font",
          )
      with control_col2:
          content_height = st.selectbox("内容高度", [300, 400, 500, 600], index=1, key="compact_height")
      with control_col3:
          initial_zoom = min(max(float(zoom_level), 0.3), 1.5)
          current_zoom = st.slider("图片缩放", 0.3, 1.5, initial_zoom, 0.1, key="compact_zoom")

      # Main content area: narrower text pane on the left, image on the right.
      left_col, right_col = st.columns([0.7, 1])
      selected_text = st.session_state.get("selected_text")

      with left_col:
          st.subheader("📄 OCR内容")
          if self.text_bbox_mapping:
              text_options = ["请选择文本..."] + list(self.text_bbox_mapping.keys())
              selected_index = st.selectbox(
                  "快速定位文本",
                  range(len(text_options)),
                  format_func=lambda x: text_options[x][:30] + "..." if len(text_options[x]) > 30 else text_options[x],
                  key="compact_text_selector",
              )
              # Index 0 is the placeholder entry and does not change the selection.
              if selected_index > 0:
                  st.session_state.selected_text = text_options[selected_index]
                  selected_text = st.session_state.selected_text

          # Styles for the scrollable text pane; font size and height come from the controls above.
          st.markdown(f"""
              <style>
              .compact-content {{
                  height: {content_height}px;
                  overflow-y: auto;
                  font-size: {current_font_size}px !important;
                  line-height: 1.4;
                  border: 1px solid #ddd;
                  padding: 10px;
                  background-color: #fafafa !important;
                  font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace;
                  color: #333333 !important;
              }}
              .highlight-text {{
                  background-color: #ffeb3b !important;
                  padding: 2px 4px;
                  border-radius: 3px;
                  cursor: pointer;
                  color: #333333 !important;
              }}
              .selected-highlight {{
                  background-color: #4caf50 !important;
                  color: white !important;
              }}
              </style>
              """, unsafe_allow_html=True)

          if self.md_content:
              highlighted_content = self._highlight_compact_text(
                  self.md_content, self.text_bbox_mapping.keys(), selected_text
              )
              st.markdown(
                  f'<div class="compact-content">{highlighted_content}</div>',
                  unsafe_allow_html=True
              )

      with right_col:
          st.subheader("🖼️ 图片标注")
          if self.image_path and Path(self.image_path).exists():
              try:
                  # The context manager closes the source file once the resized copy exists.
                  with Image.open(self.image_path) as image:
                      display_width = int(400 * current_zoom)
                      aspect_ratio = image.height / image.width
                      display_height = max(1, int(display_width * aspect_ratio))
                      # The bbox is assumed to be in source-image pixels, so map it
                      # by the actual resize ratio (not by the zoom factor alone).
                      bbox_scale = display_width / image.width
                      resized_image = image.resize((display_width, display_height), Image.Resampling.LANCZOS)

                  if selected_text and selected_text in self.text_bbox_mapping:
                      info = self.text_bbox_mapping[selected_text][0]
                      bbox = info['bbox']
                      scaled_bbox = [int(coord * bbox_scale) for coord in bbox]
                      annotated_image = self.draw_bbox_on_image(resized_image, scaled_bbox, "red", 3)
                      st.image(annotated_image, use_column_width=True)
                      # Show the original (unscaled) coordinates to the user.
                      st.info(f"**选中:** {selected_text[:20]}...\n**位置:** [{', '.join(map(str, bbox))}]")
                  else:
                      st.image(resized_image, use_column_width=True)
              except Exception as e:
                  st.error(f"❌ 图片处理失败: {e}")
          else:
              st.error("未找到对应的图片文件")

  @staticmethod
  def _highlight_compact_text(content, texts, selected_text):
      """Return ``content`` with each occurrence of ``texts`` wrapped in a highlight span.

      All texts are matched in one left-to-right regex pass, trying the longest
      text first at each position. Markup inserted for one match is therefore
      never re-scanned by a later replacement. Texts of length <= 2 are skipped
      to avoid highlighting noise. The text equal to ``selected_text`` also gets
      the ``selected-highlight`` class. The ``title`` tooltip is HTML-escaped so
      quotes or angle brackets in OCR text cannot break the attribute.
      """
      import html
      import re

      candidates = sorted((t for t in texts if len(t) > 2), key=len, reverse=True)
      if not candidates:
          return content
      pattern = re.compile("|".join(re.escape(t) for t in candidates))

      def _wrap(match):
          text = match.group(0)
          css_class = "highlight-text selected-highlight" if text == selected_text else "highlight-text"
          tooltip = html.escape(text[:50], quote=True)
          return f'<span class="{css_class}" title="{tooltip}...">{text}</span>'

      return pattern.sub(_wrap, content)
  def main():
      """Entry point of the Streamlit OCR validation app.

      Initializes session state and renders the sidebar (OCR result file
      picker, clear-selection / clear-marks buttons). It then renders a
      summary-metrics row and four tabs: content validation, table analysis,
      statistics charts and quick navigation. Returns early when no OCR data
      is loaded or when statistics cannot be computed.
      """
      st.title("🔍 OCR可视化校验工具")
      st.markdown("---")

      # Session-state defaults; they persist across Streamlit reruns.
      if 'validator' not in st.session_state:
          st.session_state.validator = StreamlitOCRValidator()
      if 'selected_text' not in st.session_state:
          st.session_state.selected_text = None
      if 'marked_errors' not in st.session_state:
          st.session_state.marked_errors = set()
      validator = st.session_state.validator
      # Point the validator at the session's error-mark set.
      validator.marked_errors = st.session_state.marked_errors

      # Sidebar: file selection and controls
      with st.sidebar:
          st.header("📁 文件选择")
          output_dir = Path("output")
          # Sorted so the selectbox order (and its index-0 default) is stable between runs.
          available_files = sorted(str(p) for p in output_dir.rglob("*.json")) if output_dir.exists() else []

          if available_files:
              selected_file = st.selectbox(
                  "选择OCR结果文件",
                  available_files,
                  index=0
              )
              if st.button("🔄 加载文件", type="primary") and selected_file:
                  try:
                      validator.load_ocr_data(selected_file)
                  except Exception as e:
                      st.error(f"❌ 加载失败: {e}")
                  else:
                      st.success("✅ 文件加载成功!")
                      # Outside the try: the handler above only covers the load itself.
                      st.rerun()
          else:
              st.warning("未找到OCR结果文件")
              st.info("请确保output目录下有OCR结果文件")

          st.markdown("---")
          st.header("🎛️ 控制面板")
          if st.button("🧹 清除选择"):
              st.session_state.selected_text = None
              st.rerun()
          if st.button("❌ 清除错误标记"):
              st.session_state.marked_errors = set()
              st.rerun()

      # Main content area
      if not validator.ocr_data:
          st.info("👈 请在左侧选择并加载OCR结果文件")
          return

      try:
          stats = validator.get_statistics()
          col1, col2, col3, col4 = st.columns(4)
          with col1:
              st.metric("📊 总文本块", stats['total_texts'])
          with col2:
              st.metric("🔗 可点击文本", stats['clickable_texts'])
          with col3:
              st.metric("❌ 标记错误", stats['marked_errors'])
          with col4:
              st.metric("✅ 准确率", f"{stats['accuracy_rate']:.1f}%")
          st.markdown("---")
      except Exception as e:
          st.error(f"❌ 统计信息计算失败: {e}")
          return

      tab1, tab2, tab3, tab4 = st.tabs(["📄 内容校验", "📊 表格分析", "📈 数据统计", "🚀 快速导航"])

      with tab1:
          control_col1, control_col2, control_col3, control_col4 = st.columns(4)
          with control_col1:
              layout_mode = st.selectbox(
                  "布局模式",
                  ["标准布局", "滚动布局", "紧凑布局"],
                  key="layout_mode"
              )
          with control_col2:
              if layout_mode != "标准布局":
                  # NOTE(review): the chosen height is not passed to any layout
                  # method below, so this control currently has no effect.
                  st.selectbox("内容高度", [400, 600, 800], index=1, key="content_height_select")
          with control_col3:
              font_size = st.selectbox("字体大小", [10, 12, 14, 16], index=1, key="font_size_select")
          with control_col4:
              zoom_level = st.slider("图片缩放", 0.3, 2.0, 1.0, 0.1, key="zoom_level_select")

          if layout_mode == "滚动布局":
              validator.create_split_layout_with_fixed_image(font_size, zoom_level)
          elif layout_mode == "紧凑布局":
              validator.create_compact_layout(font_size, zoom_level)
          else:
              validator.create_standard_layout(font_size, zoom_level)

      with tab2:
          st.header("📊 表格数据分析")
          md_content = validator.md_content
          if not md_content:
              st.warning("请先加载OCR数据")
          elif '<table' not in md_content.lower():
              st.info("当前OCR结果中没有检测到表格数据")
          else:
              col1, col2 = st.columns([2, 1])
              with col1:
                  st.subheader("🔍 表格数据预览")
                  validator.display_html_table_as_dataframe(md_content)
              with col2:
                  st.subheader("⚙️ 表格操作")
                  if st.button("📥 导出表格数据", type="primary"):
                      try:
                          tables = pd.read_html(StringIO(md_content))
                          if tables:
                              output = BytesIO()
                              with pd.ExcelWriter(output, engine='openpyxl') as writer:
                                  for i, table in enumerate(tables):
                                      table.to_excel(writer, sheet_name=f'Table_{i+1}', index=False)
                              st.download_button(
                                  label="📥 下载Excel文件",
                                  data=output.getvalue(),
                                  file_name="ocr_tables.xlsx",
                                  mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
                              )
                      except Exception as e:
                          st.error(f"导出失败: {e}")
                  if st.button("🔍 表格统计分析"):
                      try:
                          tables = pd.read_html(StringIO(md_content))
                          if tables:
                              st.write("**表格统计信息:**")
                              for i, table in enumerate(tables):
                                  st.write(f"表格 {i+1}:")
                                  st.write(f"- 行数: {len(table)}")
                                  st.write(f"- 列数: {len(table.columns)}")
                                  st.write(f"- 数值列数: {len(table.select_dtypes(include=[np.number]).columns)}")
                      except Exception as e:
                          st.error(f"统计分析失败: {e}")

      with tab3:
          st.header("📈 OCR数据统计")
          if stats:
              if stats['categories']:
                  st.subheader("📊 类别分布")
                  fig_pie = px.pie(
                      values=list(stats['categories'].values()),
                      names=list(stats['categories'].keys()),
                      title="文本类别分布"
                  )
                  st.plotly_chart(fig_pie, use_container_width=True)

              st.subheader("📈 质量分析")
              # Clamp at 0 so a mark count larger than the clickable count cannot yield a negative bar.
              correct_count = max(stats['clickable_texts'] - stats['marked_errors'], 0)
              accuracy_data = {
                  '状态': ['正确', '错误'],
                  '数量': [correct_count, stats['marked_errors']]
              }
              fig_bar = px.bar(
                  accuracy_data,
                  x='状态',
                  y='数量',
                  title="识别质量分布",
                  color='状态',
                  color_discrete_map={'正确': 'green', '错误': 'red'}
              )
              st.plotly_chart(fig_bar, use_container_width=True)

      with tab4:
          st.header("🚀 快速导航")
          if not validator.text_bbox_mapping:
              st.info("没有可用的文本项进行导航")
          else:
              # Group texts by the category of their first mapping entry.
              categories = {}
              for text, info_list in validator.text_bbox_mapping.items():
                  categories.setdefault(info_list[0]['category'], []).append(text)

              for category, texts in categories.items():
                  with st.expander(f"{category} ({len(texts)}项)", expanded=False):
                      cols = st.columns(3)  # three buttons per row
                      for i, text in enumerate(texts):
                          with cols[i % 3]:
                              display_text = text[:15] + "..." if len(text) > 15 else text
                              if st.button(display_text, key=f"nav_{category}_{i}"):
                                  st.session_state.selected_text = text
                                  st.rerun()

  if __name__ == "__main__":
      main()