ocr_validator_layout.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. #!/usr/bin/env python3
  2. """
  3. OCR验证工具的布局管理模块
  4. 包含标准布局、滚动布局、紧凑布局的实现
  5. """
import html
import re
from pathlib import Path
from typing import Dict, List, Optional

import plotly.graph_objects as go
import streamlit as st
from PIL import Image

from ocr_validator_utils import (
    convert_html_table_to_markdown,
    draw_bbox_on_image,
    parse_html_tables,
)
  16. class OCRLayoutManager:
  17. """OCR布局管理器"""
  18. def __init__(self, validator):
  19. self.validator = validator
  20. self.config = validator.config
  21. def render_content_section(self, layout_type: str = "standard"):
  22. """渲染内容区域 - 统一方法"""
  23. st.header("📄 OCR识别内容")
  24. # 文本选择器
  25. if self.validator.text_bbox_mapping:
  26. text_options = ["请选择文本..."] + list(self.validator.text_bbox_mapping.keys())
  27. selected_index = st.selectbox(
  28. "选择要校验的文本",
  29. range(len(text_options)),
  30. format_func=lambda x: text_options[x][:50] + "..." if len(text_options[x]) > 50 else text_options[x],
  31. key=f"{layout_type}_text_selector"
  32. )
  33. if selected_index > 0:
  34. st.session_state.selected_text = text_options[selected_index]
  35. else:
  36. st.warning("没有找到可点击的文本")
  37. def render_md_content(self, layout_type: str):
  38. """渲染Markdown内容 - 统一方法"""
  39. if not self.validator.md_content:
  40. return None, None
  41. # 搜索功能
  42. search_term = st.text_input(
  43. "🔍 搜索文本内容",
  44. placeholder="输入关键词搜索...",
  45. key=f"{layout_type}_search"
  46. )
  47. display_content = self.validator.md_content
  48. if search_term:
  49. lines = display_content.split('\n')
  50. filtered_lines = [line for line in lines if search_term.lower() in line.lower()]
  51. display_content = '\n'.join(filtered_lines)
  52. if filtered_lines:
  53. st.success(f"找到 {len(filtered_lines)} 行包含 '{search_term}'")
  54. else:
  55. st.warning(f"未找到包含 '{search_term}' 的内容")
  56. # 渲染方式选择
  57. render_mode = st.radio(
  58. "选择渲染方式",
  59. ["HTML渲染", "Markdown渲染", "DataFrame表格", "原始文本"],
  60. horizontal=True,
  61. key=f"{layout_type}_render_mode"
  62. )
  63. return display_content, render_mode
  64. def render_content_by_mode(self, content: str, render_mode: str, font_size: int, layout_type: str):
  65. """根据渲染模式显示内容 - 统一方法"""
  66. if content is None or render_mode is None:
  67. return
  68. if render_mode == "HTML渲染":
  69. content_style = f"""
  70. <style>
  71. .{layout_type}-content-display {{
  72. font-size: {font_size}px !important;
  73. line-height: 1.4;
  74. color: #333333 !important;
  75. background-color: #fafafa !important;
  76. padding: 10px;
  77. border-radius: 5px;
  78. border: 1px solid #ddd;
  79. }}
  80. </style>
  81. """
  82. st.markdown(content_style, unsafe_allow_html=True)
  83. st.markdown(f'<div class="{layout_type}-content-display">{content}</div>', unsafe_allow_html=True)
  84. elif render_mode == "Markdown渲染":
  85. converted_content = convert_html_table_to_markdown(content)
  86. content_style = f"""
  87. <style>
  88. .{layout_type}-content-display {{
  89. font-size: {font_size}px !important;
  90. line-height: 1.4;
  91. color: #333333 !important;
  92. background-color: #fafafa !important;
  93. padding: 10px;
  94. border-radius: 5px;
  95. border: 1px solid #ddd;
  96. }}
  97. </style>
  98. """
  99. st.markdown(content_style, unsafe_allow_html=True)
  100. st.markdown(f'<div class="{layout_type}-content-display">{converted_content}</div>', unsafe_allow_html=True)
  101. elif render_mode == "DataFrame表格":
  102. if '<table' in content.lower():
  103. self.validator.display_html_table_as_dataframe(content)
  104. else:
  105. st.info("当前内容中没有检测到HTML表格")
  106. st.markdown(content, unsafe_allow_html=True)
  107. else: # 原始文本
  108. st.text_area(
  109. "MD内容预览",
  110. content,
  111. height=300,
  112. key=f"{layout_type}_text_area"
  113. )
  114. # 布局实现
  115. def create_standard_layout(self, font_size: int = 10, zoom_level: float = 1.0):
  116. """创建标准布局"""
  117. if zoom_level is None:
  118. zoom_level = self.config['styles']['layout']['default_zoom']
  119. # 主要内容区域
  120. layout = self.config['styles']['layout']
  121. left_col, right_col = st.columns([layout['content_width'], layout['sidebar_width']])
  122. with left_col:
  123. self.render_content_section("standard")
  124. # 显示内容
  125. if self.validator.md_content:
  126. display_content, render_mode = self.render_md_content("standard")
  127. self.render_content_by_mode(display_content, render_mode, font_size, "standard")
  128. with right_col:
  129. self.create_aligned_image_display(zoom_level, "compact")
  130. def create_compact_layout(self, font_size: int = 10, zoom_level: float = 1.0):
  131. """创建紧凑的对比布局"""
  132. # 主要内容区域
  133. layout = self.config['styles']['layout']
  134. left_col, right_col = st.columns([layout['content_width'], layout['sidebar_width']])
  135. with left_col:
  136. self.render_content_section("compact")
  137. # 只保留一个内容区域高度选择
  138. container_height = st.selectbox(
  139. "选择内容区域高度",
  140. [400, 600, 800, 1000, 1200],
  141. index=2,
  142. key="compact_content_height"
  143. )
  144. # 快速定位文本选择器(使用不同的key)
  145. if self.validator.text_bbox_mapping:
  146. text_options = ["请选择文本..."] + list(self.validator.text_bbox_mapping.keys())
  147. selected_index = st.selectbox(
  148. "快速定位文本",
  149. range(len(text_options)),
  150. format_func=lambda x: text_options[x][:30] + "..." if len(text_options[x]) > 30 else text_options[x],
  151. key="compact_quick_text_selector" # 使用不同的key
  152. )
  153. if selected_index > 0:
  154. st.session_state.selected_text = text_options[selected_index]
  155. # 自定义CSS样式
  156. st.markdown(f"""
  157. <style>
  158. .compact-content {{
  159. height: {container_height}px;
  160. overflow-y: auto;
  161. font-size: {font_size}px !important;
  162. line-height: 1.4;
  163. border: 1px solid #ddd;
  164. padding: 10px;
  165. background-color: #fafafa !important;
  166. font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace;
  167. color: #333333 !important;
  168. }}
  169. .highlight-text {{
  170. background-color: #ffeb3b !important;
  171. padding: 2px 4px;
  172. border-radius: 3px;
  173. cursor: pointer;
  174. color: #333333 !important;
  175. }}
  176. .selected-highlight {{
  177. background-color: #4caf50 !important;
  178. color: white !important;
  179. }}
  180. </style>
  181. """, unsafe_allow_html=True)
  182. # 处理并显示OCR内容
  183. if self.validator.md_content:
  184. # 高亮可点击文本
  185. highlighted_content = self.validator.md_content
  186. for text in self.validator.text_bbox_mapping.keys():
  187. if len(text) > 2: # 避免高亮过短的文本
  188. css_class = "highlight-text selected-highlight" if text == st.session_state.selected_text else "highlight-text"
  189. highlighted_content = highlighted_content.replace(
  190. text,
  191. f'<span class="{css_class}" title="{text[:50]}...">{text}</span>'
  192. )
  193. st.markdown(
  194. f'<div class="compact-content">{highlighted_content}</div>',
  195. unsafe_allow_html=True
  196. )
  197. with right_col:
  198. # 修复的对齐图片显示
  199. self.create_aligned_image_display(zoom_level, "compact")
  200. def create_aligned_image_display(self, zoom_level: float = 1.0, layout_type: str = "aligned"):
  201. """创建与左侧对齐的图片显示"""
  202. # 精确对齐CSS
  203. st.markdown(f"""
  204. <style>
  205. .aligned-image-container-{layout_type} {{
  206. margin-top: -70px;
  207. padding-top: 0px;
  208. }}
  209. .aligned-image-container-{layout_type} h1 {{
  210. margin-top: 0px !important;
  211. padding-top: 0px !important;
  212. }}
  213. </style>
  214. """, unsafe_allow_html=True)
  215. st.markdown(f'<div class="aligned-image-container-{layout_type}">', unsafe_allow_html=True)
  216. st.header("🖼️ 原图标注")
  217. # 图片缩放控制
  218. col1, col2 = st.columns(2)
  219. with col1:
  220. current_zoom = st.slider("图片缩放", 0.3, 2.0, zoom_level, 0.1, key=f"{layout_type}_zoom_level")
  221. with col2:
  222. show_all_boxes = st.checkbox("显示所有框", value=False, key=f"{layout_type}_show_all_boxes")
  223. if self.validator.image_path and Path(self.validator.image_path).exists():
  224. try:
  225. image = Image.open(self.validator.image_path)
  226. # 根据缩放级别调整图片大小
  227. new_width = int(image.width * current_zoom)
  228. new_height = int(image.height * current_zoom)
  229. resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
  230. # 在固定容器中显示图片
  231. selected_bbox = None
  232. if st.session_state.selected_text and st.session_state.selected_text in self.validator.text_bbox_mapping:
  233. info = self.validator.text_bbox_mapping[st.session_state.selected_text][0]
  234. # 根据缩放级别调整bbox坐标
  235. bbox = info['bbox']
  236. selected_bbox = [int(coord * current_zoom) for coord in bbox]
  237. # 创建交互式图片 - 确保从顶部开始显示
  238. fig = self.create_resized_interactive_plot(resized_image, selected_bbox, current_zoom, show_all_boxes)
  239. st.plotly_chart(fig, use_container_width=True, key=f"{layout_type}_plot")
  240. # 显示选中文本的详细信息
  241. if st.session_state.selected_text and st.session_state.selected_text in self.validator.text_bbox_mapping:
  242. st.subheader("📍 选中文本详情")
  243. info = self.validator.text_bbox_mapping[st.session_state.selected_text][0]
  244. bbox = info['bbox']
  245. info_col1, info_col2 = st.columns(2)
  246. with info_col1:
  247. st.write(f"**文本内容:** {st.session_state.selected_text[:30]}...")
  248. st.write(f"**类别:** {info['category']}")
  249. with info_col2:
  250. st.write(f"**位置:** [{', '.join(map(str, bbox))}]")
  251. if len(bbox) >= 4:
  252. st.write(f"**大小:** {bbox[2] - bbox[0]} x {bbox[3] - bbox[1]} px")
  253. # 错误标记功能
  254. col1, col2 = st.columns(2)
  255. with col1:
  256. if st.button("❌ 标记为错误", key=f"{layout_type}_mark_error"):
  257. st.session_state.marked_errors.add(st.session_state.selected_text)
  258. st.rerun()
  259. with col2:
  260. if st.button("✅ 取消错误标记", key=f"{layout_type}_unmark_error"):
  261. st.session_state.marked_errors.discard(st.session_state.selected_text)
  262. st.rerun()
  263. except Exception as e:
  264. st.error(f"❌ 图片处理失败: {e}")
  265. else:
  266. st.error("未找到对应的图片文件")
  267. if self.validator.image_path:
  268. st.write(f"期望路径: {self.validator.image_path}")
  269. st.markdown('</div>', unsafe_allow_html=True)
  270. def create_resized_interactive_plot(self, image: Image.Image, selected_bbox: Optional[List[int]], zoom_level: float, show_all_boxes: bool) -> go.Figure:
  271. """创建可调整大小的交互式图片 - 优化显示和定位"""
  272. fig = go.Figure()
  273. fig.add_layout_image(
  274. dict(
  275. source=image,
  276. xref="x", yref="y",
  277. x=0, y=0, # 改为从底部开始,这样图片会从顶部显示
  278. sizex=image.width, sizey=image.height,
  279. sizing="stretch", opacity=1.0, layer="below"
  280. )
  281. )
  282. # 显示所有bbox(如果开启)
  283. if show_all_boxes:
  284. for text, info_list in self.validator.text_bbox_mapping.items():
  285. for info in info_list:
  286. bbox = info['bbox']
  287. if len(bbox) >= 4:
  288. x1, y1, x2, y2 = [coord * zoom_level for coord in bbox[:4]]
  289. color = "rgba(0, 100, 200, 0.2)"
  290. if text in self.validator.marked_errors:
  291. color = "rgba(255, 0, 0, 0.3)"
  292. fig.add_shape(
  293. type="rect",
  294. x0=x1, y0=y1, # 调整坐标系,不再翻转
  295. x1=x2, y1=y2,
  296. line=dict(color=color.replace('0.2', '0.8').replace('0.3', '1.0'), width=1),
  297. fillcolor=color,
  298. )
  299. # 高亮显示选中的bbox
  300. if selected_bbox and len(selected_bbox) >= 4:
  301. x1, y1, x2, y2 = selected_bbox[:4]
  302. fig.add_shape(
  303. type="rect",
  304. x0=x1, y0=y1, # 调整坐标系
  305. x1=x2, y1=y2,
  306. line=dict(color="red", width=2),
  307. fillcolor="rgba(255, 0, 0, 0.3)",
  308. )
  309. # 设置坐标轴范围 - 让图片从顶部开始显示
  310. fig.update_xaxes(visible=False, range=[0, image.width])
  311. fig.update_yaxes(visible=False, range=[image.height, 0], scaleanchor="x") # 翻转Y轴让图片从顶部开始
  312. # 计算更大的显示尺寸
  313. aspect_ratio = image.width / image.height
  314. display_height = min(1000, max(600, image.height)) # 动态调整高度
  315. display_width = int(display_height * aspect_ratio)
  316. fig.update_layout(
  317. width=display_width,
  318. height=display_height,
  319. margin=dict(l=0, r=0, t=0, b=0),
  320. showlegend=False,
  321. plot_bgcolor='white',
  322. # 设置初始视图让图片从顶部开始显示
  323. xaxis=dict(
  324. range=[0, image.width],
  325. constrain="domain"
  326. ),
  327. yaxis=dict(
  328. range=[image.height, 0], # 翻转范围
  329. constrain="domain",
  330. scaleanchor="x",
  331. scaleratio=1
  332. )
  333. )
  334. return fig