|
@@ -55,6 +55,13 @@ class StreamlitOCRValidator:
|
|
|
self.selected_file_index = -1
|
|
self.selected_file_index = -1
|
|
|
self.display_options = []
|
|
self.display_options = []
|
|
|
self.file_paths = []
|
|
self.file_paths = []
|
|
|
|
|
+
|
|
|
|
|
+ # ✅ 新增:交叉验证数据源
|
|
|
|
|
+ self.verify_source_key = None
|
|
|
|
|
+ self.verify_source_config = None
|
|
|
|
|
+ self.verify_file_info = []
|
|
|
|
|
+ self.verify_display_options = []
|
|
|
|
|
+ self.verify_file_paths = []
|
|
|
|
|
|
|
|
# 初始化布局管理器
|
|
# 初始化布局管理器
|
|
|
self.layout_manager = OCRLayoutManager(self)
|
|
self.layout_manager = OCRLayoutManager(self)
|
|
@@ -66,13 +73,18 @@ class StreamlitOCRValidator:
|
|
|
"""加载多数据源文件信息"""
|
|
"""加载多数据源文件信息"""
|
|
|
self.all_sources = find_available_ocr_files_multi_source(self.config)
|
|
self.all_sources = find_available_ocr_files_multi_source(self.config)
|
|
|
|
|
|
|
|
- # 如果有数据源,默认选择第一个
|
|
|
|
|
|
|
+ # 如果有数据源,默认选择第一个作为OCR源
|
|
|
if self.all_sources:
|
|
if self.all_sources:
|
|
|
- first_source_key = list(self.all_sources.keys())[0]
|
|
|
|
|
|
|
+ source_keys = list(self.all_sources.keys())
|
|
|
|
|
+ first_source_key = source_keys[0]
|
|
|
self.switch_to_source(first_source_key)
|
|
self.switch_to_source(first_source_key)
|
|
|
|
|
+
|
|
|
|
|
+ # 如果有第二个数据源,默认作为验证源
|
|
|
|
|
+ if len(source_keys) > 1:
|
|
|
|
|
+ self.switch_to_verify_source(source_keys[1])
|
|
|
|
|
|
|
|
def switch_to_source(self, source_key: str):
|
|
def switch_to_source(self, source_key: str):
|
|
|
- """切换到指定数据源"""
|
|
|
|
|
|
|
+ """切换到指定OCR数据源"""
|
|
|
if source_key in self.all_sources:
|
|
if source_key in self.all_sources:
|
|
|
self.current_source_key = source_key
|
|
self.current_source_key = source_key
|
|
|
source_data = self.all_sources[source_key]
|
|
source_data = self.all_sources[source_key]
|
|
@@ -86,11 +98,25 @@ class StreamlitOCRValidator:
|
|
|
|
|
|
|
|
# 重置文件选择
|
|
# 重置文件选择
|
|
|
self.selected_file_index = -1
|
|
self.selected_file_index = -1
|
|
|
-
|
|
|
|
|
- print(f"✅ 切换到数据源: {source_key}")
|
|
|
|
|
|
|
+ print(f"✅ 切换到OCR数据源: {source_key}")
|
|
|
else:
|
|
else:
|
|
|
print(f"⚠️ 数据源 {source_key} 没有可用文件")
|
|
print(f"⚠️ 数据源 {source_key} 没有可用文件")
|
|
|
|
|
|
|
|
|
|
def switch_to_verify_source(self, source_key: str):
    """Select the data source used as the cross-validation reference.

    Looks ``source_key`` up in ``self.all_sources`` and copies its
    ``'config'`` and ``'files'`` entries onto the ``verify_*`` attributes.
    The display-option and path lists are rebuilt on every switch, so a
    source without files leaves them empty instead of keeping the entries
    of the previously selected source.

    Args:
        source_key: Key of the desired source in ``self.all_sources``.
            Unknown keys leave the current selection untouched.
    """
    if source_key not in self.all_sources:
        print(f"⚠️ 验证数据源 {source_key} 不存在")
        return

    source_data = self.all_sources[source_key]
    self.verify_source_key = source_key
    self.verify_source_config = source_data['config']
    self.verify_file_info = source_data['files']

    # Always derive both lists from the new file info; with an empty file
    # list this resets them rather than leaving stale data behind.
    self.verify_display_options = [f"{info['display_name']}" for info in self.verify_file_info]
    self.verify_file_paths = [info['path'] for info in self.verify_file_info]

    if self.verify_file_info:
        print(f"✅ 切换到验证数据源: {source_key}")
    else:
        print(f"⚠️ 验证数据源 {source_key} 没有可用文件")
|
|
|
|
|
+
|
|
|
def setup_page_config(self):
|
|
def setup_page_config(self):
|
|
|
"""设置页面配置"""
|
|
"""设置页面配置"""
|
|
|
ui_config = self.config['ui']
|
|
ui_config = self.config['ui']
|
|
@@ -106,56 +132,91 @@ class StreamlitOCRValidator:
|
|
|
st.markdown(f"<style>{css_content}</style>", unsafe_allow_html=True)
|
|
st.markdown(f"<style>{css_content}</style>", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
def create_data_source_selector(self):
    """Render the paired data-source pickers used for cross validation.

    Two select boxes are drawn side by side: the left one chooses the OCR
    source, the right one chooses the verification source. Changing either
    selection switches the corresponding source on the validator and
    triggers a Streamlit rerun. A hint below the pickers reports whether
    the two selected sources differ.
    """
    if not self.all_sources:
        st.warning("❌ 未找到任何数据源,请检查配置文件")
        return

    # Display label -> source key, in the iteration order of all_sources.
    source_options = {
        get_data_source_display_name(entry['config']): key
        for key, entry in self.all_sources.items()
    }
    labels = list(source_options.keys())
    # Reverse table so the label of the active sources can be found directly.
    label_by_key = {key: label for label, key in source_options.items()}

    ocr_column, verify_column = st.columns(2)

    with ocr_column:
        st.markdown("#### 📊 OCR数据源")

        current_display_name = None
        if self.current_source_key:
            current_display_name = label_by_key.get(self.current_source_key)
        ocr_index = labels.index(current_display_name) if current_display_name else 0

        selected_ocr_display = st.selectbox(
            "选择OCR数据源",
            options=labels,
            index=ocr_index,
            key="ocr_source_selector",
            label_visibility="collapsed",
            help="选择要分析的OCR数据源"
        )
        selected_ocr_key = source_options[selected_ocr_display]

        # A different OCR source invalidates the current file selection.
        if selected_ocr_key != self.current_source_key:
            self.switch_to_source(selected_ocr_key)
            if 'selected_file_index' in st.session_state:
                st.session_state.selected_file_index = 0
            st.rerun()

        if self.current_source_config:
            with st.expander("📋 OCR数据源详情", expanded=False):
                st.write(f"**工具:** {self.current_source_config['ocr_tool']}")
                st.write(f"**文件数:** {len(self.file_info)}")

    with verify_column:
        st.markdown("#### 🔍 验证数据源")

        verify_display_name = None
        if self.verify_source_key:
            verify_display_name = label_by_key.get(self.verify_source_key)
        if verify_display_name:
            verify_index = labels.index(verify_display_name)
        else:
            # Without an active verification source, prefer the second
            # entry so the two pickers start on different sources.
            verify_index = 1 if len(source_options) > 1 else 0

        selected_verify_display = st.selectbox(
            "选择验证数据源",
            options=labels,
            index=verify_index,
            key="verify_source_selector",
            label_visibility="collapsed",
            help="选择用于交叉验证的数据源"
        )
        selected_verify_key = source_options[selected_verify_display]

        if selected_verify_key != self.verify_source_key:
            self.switch_to_verify_source(selected_verify_key)
            st.rerun()

        if self.verify_source_config:
            with st.expander("📋 验证数据源详情", expanded=False):
                st.write(f"**工具:** {self.verify_source_config['ocr_tool']}")
                st.write(f"**文件数:** {len(self.verify_file_info)}")

    # Comparison hint shown underneath both pickers.
    if self.current_source_key != self.verify_source_key:
        st.success(f"✅ 已选择 {selected_ocr_display} 与 {selected_verify_display} 进行交叉验证")
    else:
        st.warning("⚠️ OCR数据源和验证数据源相同,建议选择不同的数据源进行交叉验证")
|
|
|
|
|
|
|
|
def load_ocr_data(self, json_path: str, md_path: Optional[str] = None, image_path: Optional[str] = None):
|
|
def load_ocr_data(self, json_path: str, md_path: Optional[str] = None, image_path: Optional[str] = None):
|
|
|
"""加载OCR相关数据 - 支持多数据源配置"""
|
|
"""加载OCR相关数据 - 支持多数据源配置"""
|
|
@@ -456,107 +517,151 @@ class StreamlitOCRValidator:
|
|
|
|
|
|
|
|
else: # 完整显示
|
|
else: # 完整显示
|
|
|
return table
|
|
return table
|
|
|
-
|
|
|
|
|
- @st.dialog("VLM预校验", width="large", dismissible=True, on_dismiss="rerun")
|
|
|
|
|
- def vlm_pre_validation(self):
|
|
|
|
|
- """VLM预校验功能 - 封装OCR识别和结果对比"""
|
|
|
|
|
|
|
+
|
|
|
|
|
def find_verify_md_path(self, selected_file_index: int) -> Optional[Path]:
    """Locate the verification-source Markdown file for an OCR file.

    The OCR file at ``selected_file_index`` in ``self.file_info`` is
    matched against ``self.verify_file_info`` by its ``'page'`` value. The
    path of the first matching verification entry is returned with its
    suffix replaced by ``.md``.

    Args:
        selected_file_index: Index into ``self.file_info``.

    Returns:
        The ``.md`` path of the matching verification file, or ``None``
        when the index is out of range or no entry has the same page.
    """
    try:
        current_page = self.file_info[selected_file_index]['page']
    except IndexError:
        # No such OCR file (e.g. empty source): nothing to match against.
        return None

    # verify_file_info and verify_file_paths are parallel lists.
    for info, verify_path in zip(self.verify_file_info, self.verify_file_paths):
        if info['page'] == current_page:
            return Path(verify_path).with_suffix('.md')

    return None
|
|
|
|
|
+
|
|
|
|
|
@st.dialog("交叉验证", width="large", dismissible=True, on_dismiss="rerun")
def cross_validation(self):
    """Cross-validate the current OCR result against the verification source.

    Compares the Markdown file of the currently selected OCR file with the
    verification-source file for the same page (see
    ``find_verify_md_path``) by calling ``compare_ocr_results``. The
    comparison output is written under the configured pre-validation
    directory, the result is stored in
    ``st.session_state.cross_validation_result`` and a summary is shown.
    Any exception raised along the way is reported inside the dialog.
    """
    # Function-scope imports, as in the original block.
    import contextlib
    import io

    if not self.image_path or not self.md_content:
        st.error("❌ 请先加载OCR数据文件")
        return

    if self.current_source_key == self.verify_source_key:
        st.error("❌ OCR数据源和验证数据源不能相同")
        return

    # Make sure the result slot exists before anything reads it.
    if 'cross_validation_result' not in st.session_state:
        st.session_state.cross_validation_result = None

    with st.spinner("正在进行交叉验证...", show_time=True):
        status_text = st.empty()

        try:
            # Step 1: Markdown file of the currently selected OCR result.
            current_md_path = Path(self.file_paths[self.selected_file_index]).with_suffix('.md')
            if not current_md_path.exists():
                st.error("❌ 当前OCR结果的Markdown文件不存在")
                return

            status_text.text(f"📄 OCR文件: {current_md_path.name}")

            # Step 2: matching file in the verification source.
            verify_md_path = self.find_verify_md_path(self.selected_file_index)

            if not verify_md_path or not verify_md_path.exists():
                # Report the page number the lookup was keyed on, not the
                # OCR file path.
                current_page = self.file_info[self.selected_file_index]['page']
                st.error(f"❌ 未找到验证数据源中第{current_page}页的对应文件")
                return

            status_text.text(f"🔍 验证文件: {verify_md_path.name}")

            # Step 3: output directory for the comparison artefacts.
            pre_validation_dir = Path(self.config['pre_validation'].get('out_dir', './output/pre_validation/')).resolve()
            pre_validation_dir.mkdir(parents=True, exist_ok=True)

            # Step 4: run the comparison.
            status_text.text("📊 正在对比OCR结果...")

            # Base name without extension; ".json"/".md" are appended
            # when the result is recorded below.
            comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_cross_validation"

            with st.expander("🔍 交叉验证对比过程", expanded=True):
                compare_output = st.empty()

                # Capture whatever the comparison prints to stdout.
                output_buffer = io.StringIO()
                try:
                    with contextlib.redirect_stdout(output_buffer):
                        comparison_result = compare_ocr_results(
                            file1_path=str(current_md_path),
                            file2_path=str(verify_md_path),
                            output_file=str(comparison_result_path),
                            output_format='both',
                            ignore_images=True,
                            table_mode='flow_list',  # flow-list table mode
                            similarity_algorithm='ratio'
                        )
                finally:
                    # Show the captured log even if the comparison failed.
                    compare_output.code(output_buffer.getvalue(), language='text')

            status_text.text("✅ 交叉验证完成")

            st.session_state.cross_validation_result = {
                "ocr_source": get_data_source_display_name(self.current_source_config),
                "verify_source": get_data_source_display_name(self.verify_source_config),
                "ocr_file": str(current_md_path),
                "verify_file": str(verify_md_path),
                "comparison_result_json": f"{comparison_result_path}.json",
                "comparison_result_md": f"{comparison_result_path}.md",
                "comparison_result": comparison_result
            }

            # Step 5: summarise the comparison.
            self.display_comparison_results(comparison_result, detailed=False)

        except Exception as e:
            # Dialog boundary: surface the failure to the user.
            st.error(f"❌ 交叉验证失败: {e}")
            st.exception(e)
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
|
|
@st.dialog("查看交叉验证结果", width="large", dismissible=True, on_dismiss="rerun")
def show_cross_validation_results_dialog(self):
    """Show the cross-validation result for the currently selected file.

    A result cached in ``st.session_state.cross_validation_result`` is
    displayed only when it was recorded for the current file's result
    path. Otherwise, if a result JSON exists in the pre-validation
    directory, a button offers to load it; if neither is available an
    informational message is shown.
    """
    current_md_path = Path(self.file_paths[self.selected_file_index]).with_suffix('.md')
    pre_validation_dir = Path(self.config['pre_validation'].get('out_dir', './output/pre_validation/')).resolve()
    comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_cross_validation.json"

    cached = None
    if 'cross_validation_result' in st.session_state:
        cached = st.session_state.cross_validation_result

    # Reuse the cached result only if it belongs to the selected file;
    # otherwise a result computed for another file would be shown here.
    if cached and cached.get('comparison_result_json') == str(comparison_result_path):
        # Source names recorded with the result.
        col1, col2 = st.columns(2)
        with col1:
            st.info(f"**OCR数据源:** {cached['ocr_source']}")
        with col2:
            st.info(f"**验证数据源:** {cached['verify_source']}")

        self.display_comparison_results(cached['comparison_result'])

    elif comparison_result_path.exists():
        # A result file from an earlier run exists; load it on request.
        if st.button("📂 加载历史验证结果"):
            with open(comparison_result_path, "r", encoding="utf-8") as f:
                comparison_json_result = json.load(f)

            # NOTE(review): the source names come from the current
            # selection, not from the stored file - confirm they match
            # the run that produced it.
            cross_validation_result = {
                "ocr_source": get_data_source_display_name(self.current_source_config),
                "verify_source": get_data_source_display_name(self.verify_source_config),
                "ocr_file": comparison_json_result['file1_path'],
                "verify_file": comparison_json_result['file2_path'],
                "comparison_result_json": str(comparison_result_path),
                "comparison_result_md": str(comparison_result_path.with_suffix('.md')),
                "comparison_result": comparison_json_result
            }

            st.session_state.cross_validation_result = cross_validation_result
            self.display_comparison_results(comparison_json_result)
    else:
        st.info("暂无交叉验证结果,请先运行交叉验证")
|
|
|
|
|
+
|
|
|
def display_comparison_results(self, comparison_result: dict, detailed: bool = True):
|
|
def display_comparison_results(self, comparison_result: dict, detailed: bool = True):
|
|
|
"""显示对比结果摘要 - 使用DataFrame展示"""
|
|
"""显示对比结果摘要 - 使用DataFrame展示"""
|
|
|
|
|
|
|
@@ -698,9 +803,9 @@ class StreamlitOCRValidator:
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
with col2:
|
|
with col2:
|
|
|
- st.write("**VLM识别结果:**")
|
|
|
|
|
|
|
+ st.write("**验证数据源识别结果:**")
|
|
|
st.text_area(
|
|
st.text_area(
|
|
|
- "VLM识别结果详情",
|
|
|
|
|
|
|
+ "验证数据源识别结果详情",
|
|
|
value=diff['file2_value'],
|
|
value=diff['file2_value'],
|
|
|
height=200,
|
|
height=200,
|
|
|
key=f"vlm_{selected_diff_index}",
|
|
key=f"vlm_{selected_diff_index}",
|
|
@@ -875,31 +980,6 @@ class StreamlitOCRValidator:
|
|
|
else:
|
|
else:
|
|
|
st.error("❌ 发现大量差异,建议重新进行OCR识别或检查原始图片质量")
|
|
st.error("❌ 发现大量差异,建议重新进行OCR识别或检查原始图片质量")
|
|
|
|
|
|
|
|
- @st.dialog("查看预校验结果", width="large", dismissible=True, on_dismiss="rerun")
|
|
|
|
|
- def show_comparison_results_dialog(self):
|
|
|
|
|
- """显示VLM预校验结果的对话框"""
|
|
|
|
|
- current_md_path = Path(self.file_paths[self.selected_file_index]).with_suffix('.md')
|
|
|
|
|
- pre_validation_dir = Path(self.config['pre_validation'].get('out_dir', './output/pre_validation/')).resolve()
|
|
|
|
|
- comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_comparison_result.json"
|
|
|
|
|
- if 'comparison_result' in st.session_state and st.session_state.comparison_result:
|
|
|
|
|
- self.display_comparison_results(st.session_state.comparison_result['comparison_result'])
|
|
|
|
|
- elif comparison_result_path.exists():
|
|
|
|
|
- # 如果pre_validation_dir下有结果文件,提示用户加载
|
|
|
|
|
- if st.button("加载预校验结果"):
|
|
|
|
|
- with open(comparison_result_path, "r", encoding="utf-8") as f:
|
|
|
|
|
- comparison_json_result = json.load(f)
|
|
|
|
|
- comparison_result = {
|
|
|
|
|
- "image_path": self.image_path,
|
|
|
|
|
- "comparison_result_json": str(comparison_result_path),
|
|
|
|
|
- "comparison_result_md": str(comparison_result_path.with_suffix('.md')),
|
|
|
|
|
- "comparison_result": comparison_json_result
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- st.session_state.comparison_result = comparison_result
|
|
|
|
|
- self.display_comparison_results(comparison_json_result)
|
|
|
|
|
- else:
|
|
|
|
|
- st.info("暂无预校验结果,请先运行VLM预校验")
|
|
|
|
|
-
|
|
|
|
|
def create_compact_layout(self, config):
|
|
def create_compact_layout(self, config):
|
|
|
"""创建滚动凑布局"""
|
|
"""创建滚动凑布局"""
|
|
|
return self.layout_manager.create_compact_layout(config)
|
|
return self.layout_manager.create_compact_layout(config)
|
|
@@ -999,17 +1079,17 @@ def main():
|
|
|
st.rerun()
|
|
st.rerun()
|
|
|
else:
|
|
else:
|
|
|
st.warning("当前数据源中未找到OCR结果文件")
|
|
st.warning("当前数据源中未找到OCR结果文件")
|
|
|
-
|
|
|
|
|
- # VLM预校验按钮
|
|
|
|
|
- if st.button("VLM预校验", type="primary", icon=":material/compare_arrows:"):
|
|
|
|
|
|
|
+
|
|
|
|
|
+ # 交叉验证按钮
|
|
|
|
|
+ if st.button("交叉验证", type="primary", icon=":material/compare_arrows:"):
|
|
|
if validator.image_path and validator.md_content:
|
|
if validator.image_path and validator.md_content:
|
|
|
- validator.vlm_pre_validation()
|
|
|
|
|
|
|
+ validator.cross_validation()
|
|
|
else:
|
|
else:
|
|
|
message_box("❌ 请先选择OCR数据文件", "error")
|
|
message_box("❌ 请先选择OCR数据文件", "error")
|
|
|
|
|
|
|
|
# 查看预校验结果按钮
|
|
# 查看预校验结果按钮
|
|
|
- if st.button("查看预校验结果", type="secondary", icon=":material/quick_reference_all:"):
|
|
|
|
|
- validator.show_comparison_results_dialog()
|
|
|
|
|
|
|
+ if st.button("查看验证结果", type="secondary", icon=":material/quick_reference_all:"):
|
|
|
|
|
+ validator.show_cross_validation_results_dialog()
|
|
|
|
|
|
|
|
# 显示当前数据源统计信息
|
|
# 显示当前数据源统计信息
|
|
|
with st.expander("🔧 OCR工具统计信息", expanded=False):
|
|
with st.expander("🔧 OCR工具统计信息", expanded=False):
|
|
@@ -1035,7 +1115,7 @@ def main():
|
|
|
st.write("**详细信息:**", stats['tool_info'])
|
|
st.write("**详细信息:**", stats['tool_info'])
|
|
|
|
|
|
|
|
# 其余标签页保持不变...
|
|
# 其余标签页保持不变...
|
|
|
- tab1, tab2, tab3 = st.tabs(["📄 内容校验", "📄 VLM预校验识别结果", "📊 表格分析"])
|
|
|
|
|
|
|
+ tab1, tab2, tab3 = st.tabs(["📄 内容人工检查", "🔍 交叉验证结果", "📊 表格分析"])
|
|
|
|
|
|
|
|
with tab1:
|
|
with tab1:
|
|
|
validator.create_compact_layout(config)
|
|
validator.create_compact_layout(config)
|
|
@@ -1044,9 +1124,15 @@ def main():
|
|
|
# st.header("📄 VLM预校验识别结果")
|
|
# st.header("📄 VLM预校验识别结果")
|
|
|
current_md_path = Path(validator.file_paths[validator.selected_file_index]).with_suffix('.md')
|
|
current_md_path = Path(validator.file_paths[validator.selected_file_index]).with_suffix('.md')
|
|
|
pre_validation_dir = Path(validator.config['pre_validation'].get('out_dir', './output/pre_validation/')).resolve()
|
|
pre_validation_dir = Path(validator.config['pre_validation'].get('out_dir', './output/pre_validation/')).resolve()
|
|
|
- comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_comparison_result.json"
|
|
|
|
|
- pre_validation_path = pre_validation_dir / f"{current_md_path.stem}.md"
|
|
|
|
|
|
|
+ comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_cross_validation.json"
|
|
|
|
|
+ # pre_validation_path = pre_validation_dir / f"{current_md_path.stem}.md"
|
|
|
|
|
+ verify_md_path = validator.find_verify_md_path(validator.selected_file_index)
|
|
|
|
|
+
|
|
|
if comparison_result_path.exists():
|
|
if comparison_result_path.exists():
|
|
|
|
|
+ # 加载并显示验证结果
|
|
|
|
|
+ with open(comparison_result_path, "r", encoding="utf-8") as f:
|
|
|
|
|
+ comparison_result = json.load(f)
|
|
|
|
|
+
|
|
|
# 左边显示OCR结果,右边显示VLM结果
|
|
# 左边显示OCR结果,右边显示VLM结果
|
|
|
col1, col2 = st.columns([1,1])
|
|
col1, col2 = st.columns([1,1])
|
|
|
with col1:
|
|
with col1:
|
|
@@ -1059,12 +1145,16 @@ def main():
|
|
|
validator.layout_manager.render_content_by_mode(original_md_content, "HTML渲染", font_size, height, layout_type)
|
|
validator.layout_manager.render_content_by_mode(original_md_content, "HTML渲染", font_size, height, layout_type)
|
|
|
with col2:
|
|
with col2:
|
|
|
st.subheader("🤖 VLM识别结果")
|
|
st.subheader("🤖 VLM识别结果")
|
|
|
- with open(pre_validation_path, "r", encoding="utf-8") as f:
|
|
|
|
|
- pre_validation_md_content = f.read()
|
|
|
|
|
|
|
+ with open(str(verify_md_path), "r", encoding="utf-8") as f:
|
|
|
|
|
+ verify_md_content = f.read()
|
|
|
font_size = config['styles'].get('font_size', 10)
|
|
font_size = config['styles'].get('font_size', 10)
|
|
|
height = config['styles']['layout'].get('default_height', 800)
|
|
height = config['styles']['layout'].get('default_height', 800)
|
|
|
layout_type = "compact"
|
|
layout_type = "compact"
|
|
|
- validator.layout_manager.render_content_by_mode(pre_validation_md_content, "HTML渲染", font_size, height, layout_type)
|
|
|
|
|
|
|
+ validator.layout_manager.render_content_by_mode(verify_md_content, "HTML渲染", font_size, height, layout_type)
|
|
|
|
|
+
|
|
|
|
|
+ # 显示差异统计
|
|
|
|
|
+ st.markdown("---")
|
|
|
|
|
+ validator.display_comparison_results(comparison_result, detailed=True)
|
|
|
else:
|
|
else:
|
|
|
st.info("暂无预校验结果,请先运行VLM预校验")
|
|
st.info("暂无预校验结果,请先运行VLM预校验")
|
|
|
|
|
|
|
@@ -1079,35 +1169,5 @@ def main():
|
|
|
else:
|
|
else:
|
|
|
st.info("当前OCR结果中没有检测到表格数据")
|
|
st.info("当前OCR结果中没有检测到表格数据")
|
|
|
|
|
|
|
|
- # with tab4:
|
|
|
|
|
- # # 数据统计页面 - 保持原有逻辑
|
|
|
|
|
- # st.header("📈 OCR数据统计")
|
|
|
|
|
-
|
|
|
|
|
- # # 添加数据源特定的统计信息
|
|
|
|
|
- # if validator.current_source_config:
|
|
|
|
|
- # st.subheader(f"📊 {get_data_source_display_name(validator.current_source_config)} - 统计信息")
|
|
|
|
|
-
|
|
|
|
|
- # if stats['categories']:
|
|
|
|
|
- # st.subheader("📊 类别分布")
|
|
|
|
|
- # fig_pie = px.pie(
|
|
|
|
|
- # values=list(stats['categories'].values()),
|
|
|
|
|
- # names=list(stats['categories'].keys()),
|
|
|
|
|
- # title="文本类别分布"
|
|
|
|
|
- # )
|
|
|
|
|
- # st.plotly_chart(fig_pie, use_container_width=True)
|
|
|
|
|
-
|
|
|
|
|
- # # 错误率分析
|
|
|
|
|
- # st.subheader("📈 质量分析")
|
|
|
|
|
- # accuracy_data = {
|
|
|
|
|
- # '状态': ['正确', '错误'],
|
|
|
|
|
- # '数量': [stats['clickable_texts'] - stats['marked_errors'], stats['marked_errors']]
|
|
|
|
|
- # }
|
|
|
|
|
-
|
|
|
|
|
- # fig_bar = px.bar(
|
|
|
|
|
- # accuracy_data, x='状态', y='数量', title="识别质量分布",
|
|
|
|
|
- # color='状态', color_discrete_map={'正确': 'green', '错误': 'red'}
|
|
|
|
|
- # )
|
|
|
|
|
- # st.plotly_chart(fig_bar, use_container_width=True)
|
|
|
|
|
-
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
|
main()
|
|
main()
|