Просмотр исходного кода

feat(增强布局路由器): 在SmartLayoutRouter中优化模型配置合并逻辑,添加debug_options支持,更新调试输出目录路径,提升调试过程的灵活性和准确性。

zhch158_admin 5 дней назад
Родитель
Сommit
bcc8a748b5
1 измененных файлов с 26 добавлено и 14 удалено
  1. 26 14
      ocr_tools/universal_doc_parser/core/layout_model_router.py

+ 26 - 14
ocr_tools/universal_doc_parser/core/layout_model_router.py

@@ -48,18 +48,27 @@ class SmartLayoutRouter(BaseLayoutDetector):
         
     def initialize(self):
         """初始化所有模型"""
-        # 获取 post_process 配置(从父配置中)
+        # 获取 post_process / debug_options 配置(从父 smart_router 配置中)
         post_process_config = self.config.get('post_process', {})
-        
+        layout_debug_options = self.config.get('debug_options', {})
+        if not isinstance(layout_debug_options, dict):
+            layout_debug_options = {}
+
+        def _merge_child_model_config(child_cfg: Dict[str, Any]) -> Dict[str, Any]:
+            merged = child_cfg.copy()
+            if post_process_config:
+                merged['post_process'] = post_process_config
+            if layout_debug_options:
+                merged['debug_options'] = layout_debug_options.copy()
+            return merged
+
         # 初始化主模型
         for model_name, model_config in self.model_configs.items():
             try:
                 logger.info(f"🔧 Initializing layout model: {model_name}")
-                # 将 post_process 配置添加到子模型配置中
-                if post_process_config:
-                    model_config = model_config.copy()
-                    model_config['post_process'] = post_process_config
-                detector = ModelFactory.create_layout_detector(model_config)
+                detector = ModelFactory.create_layout_detector(
+                    _merge_child_model_config(model_config)
+                )
                 self.models[model_name] = detector
                 logger.info(f"✅ Model {model_name} initialized")
             except Exception as e:
@@ -68,11 +77,9 @@ class SmartLayoutRouter(BaseLayoutDetector):
         # 初始化回退模型(如果配置了)
         if self.fallback_config:
             try:
-                # 将 post_process 配置添加到回退模型配置中
-                fallback_config = self.fallback_config.copy()
-                if post_process_config:
-                    fallback_config['post_process'] = post_process_config
-                fallback_detector = ModelFactory.create_layout_detector(fallback_config)
+                fallback_detector = ModelFactory.create_layout_detector(
+                    _merge_child_model_config(self.fallback_config)
+                )
                 self.models['fallback'] = fallback_detector
                 logger.info("✅ Fallback model initialized")
             except Exception as e:
@@ -107,6 +114,9 @@ class SmartLayoutRouter(BaseLayoutDetector):
             model.output_dir = self.output_dir  # type: ignore[attr-defined]
         if self.page_name:
             model.page_name = self.page_name  # type: ignore[attr-defined]
+        parent_opts = self._layout_debug_options()
+        if parent_opts:
+            model.config['debug_options'] = parent_opts.copy()
     
     def _detect_raw(
         self, 
@@ -521,9 +531,11 @@ class SmartLayoutRouter(BaseLayoutDetector):
                                   font, font_scale, (255, 255, 255), text_thickness)
             
             # 保存对比图像
-            debug_dir = Path(self.output_dir) / "debug_comparison" / "layout_comparison"
+            from ocr_utils.module_debug_viz import resolve_module_debug_dir
+
+            debug_dir = resolve_module_debug_dir(self.output_dir, "layout_comparison")
             debug_dir.mkdir(parents=True, exist_ok=True)
-            output_path = debug_dir / f"{self.page_name}_layout_comparison.jpg"
+            output_path = debug_dir / f"{self.page_name}_layout_comparison.png"
             cv2.imwrite(str(output_path), vis_image)
             logger.info(f"📊 Saved layout comparison image: {output_path}")