|
@@ -48,18 +48,27 @@ class SmartLayoutRouter(BaseLayoutDetector):
|
|
|
|
|
|
|
|
def initialize(self):
|
|
def initialize(self):
|
|
|
"""初始化所有模型"""
|
|
"""初始化所有模型"""
|
|
|
- # 获取 post_process 配置(从父配置中)
|
|
|
|
|
|
|
+ # 获取 post_process / debug_options 配置(从父 smart_router 配置中)
|
|
|
post_process_config = self.config.get('post_process', {})
|
|
post_process_config = self.config.get('post_process', {})
|
|
|
-
|
|
|
|
|
|
|
+ layout_debug_options = self.config.get('debug_options', {})
|
|
|
|
|
+ if not isinstance(layout_debug_options, dict):
|
|
|
|
|
+ layout_debug_options = {}
|
|
|
|
|
+
|
|
|
|
|
+ def _merge_child_model_config(child_cfg: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
|
+ merged = child_cfg.copy()
|
|
|
|
|
+ if post_process_config:
|
|
|
|
|
+ merged['post_process'] = post_process_config
|
|
|
|
|
+ if layout_debug_options:
|
|
|
|
|
+ merged['debug_options'] = layout_debug_options.copy()
|
|
|
|
|
+ return merged
|
|
|
|
|
+
|
|
|
# 初始化主模型
|
|
# 初始化主模型
|
|
|
for model_name, model_config in self.model_configs.items():
|
|
for model_name, model_config in self.model_configs.items():
|
|
|
try:
|
|
try:
|
|
|
logger.info(f"🔧 Initializing layout model: {model_name}")
|
|
logger.info(f"🔧 Initializing layout model: {model_name}")
|
|
|
- # 将 post_process 配置添加到子模型配置中
|
|
|
|
|
- if post_process_config:
|
|
|
|
|
- model_config = model_config.copy()
|
|
|
|
|
- model_config['post_process'] = post_process_config
|
|
|
|
|
- detector = ModelFactory.create_layout_detector(model_config)
|
|
|
|
|
|
|
+ detector = ModelFactory.create_layout_detector(
|
|
|
|
|
+ _merge_child_model_config(model_config)
|
|
|
|
|
+ )
|
|
|
self.models[model_name] = detector
|
|
self.models[model_name] = detector
|
|
|
logger.info(f"✅ Model {model_name} initialized")
|
|
logger.info(f"✅ Model {model_name} initialized")
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
@@ -68,11 +77,9 @@ class SmartLayoutRouter(BaseLayoutDetector):
|
|
|
# 初始化回退模型(如果配置了)
|
|
# 初始化回退模型(如果配置了)
|
|
|
if self.fallback_config:
|
|
if self.fallback_config:
|
|
|
try:
|
|
try:
|
|
|
- # 将 post_process 配置添加到回退模型配置中
|
|
|
|
|
- fallback_config = self.fallback_config.copy()
|
|
|
|
|
- if post_process_config:
|
|
|
|
|
- fallback_config['post_process'] = post_process_config
|
|
|
|
|
- fallback_detector = ModelFactory.create_layout_detector(fallback_config)
|
|
|
|
|
|
|
+ fallback_detector = ModelFactory.create_layout_detector(
|
|
|
|
|
+ _merge_child_model_config(self.fallback_config)
|
|
|
|
|
+ )
|
|
|
self.models['fallback'] = fallback_detector
|
|
self.models['fallback'] = fallback_detector
|
|
|
logger.info("✅ Fallback model initialized")
|
|
logger.info("✅ Fallback model initialized")
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
@@ -107,6 +114,9 @@ class SmartLayoutRouter(BaseLayoutDetector):
|
|
|
model.output_dir = self.output_dir # type: ignore[attr-defined]
|
|
model.output_dir = self.output_dir # type: ignore[attr-defined]
|
|
|
if self.page_name:
|
|
if self.page_name:
|
|
|
model.page_name = self.page_name # type: ignore[attr-defined]
|
|
model.page_name = self.page_name # type: ignore[attr-defined]
|
|
|
|
|
+ parent_opts = self._layout_debug_options()
|
|
|
|
|
+ if parent_opts:
|
|
|
|
|
+ model.config['debug_options'] = parent_opts.copy()
|
|
|
|
|
|
|
|
def _detect_raw(
|
|
def _detect_raw(
|
|
|
self,
|
|
self,
|
|
@@ -521,9 +531,11 @@ class SmartLayoutRouter(BaseLayoutDetector):
|
|
|
font, font_scale, (255, 255, 255), text_thickness)
|
|
font, font_scale, (255, 255, 255), text_thickness)
|
|
|
|
|
|
|
|
# 保存对比图像
|
|
# 保存对比图像
|
|
|
- debug_dir = Path(self.output_dir) / "debug_comparison" / "layout_comparison"
|
|
|
|
|
|
|
+ from ocr_utils.module_debug_viz import resolve_module_debug_dir
|
|
|
|
|
+
|
|
|
|
|
+ debug_dir = resolve_module_debug_dir(self.output_dir, "layout_comparison")
|
|
|
debug_dir.mkdir(parents=True, exist_ok=True)
|
|
debug_dir.mkdir(parents=True, exist_ok=True)
|
|
|
- output_path = debug_dir / f"{self.page_name}_layout_comparison.jpg"
|
|
|
|
|
|
|
+ output_path = debug_dir / f"{self.page_name}_layout_comparison.png"
|
|
|
cv2.imwrite(str(output_path), vis_image)
|
|
cv2.imwrite(str(output_path), vis_image)
|
|
|
logger.info(f"📊 Saved layout comparison image: {output_path}")
|
|
logger.info(f"📊 Saved layout comparison image: {output_path}")
|
|
|
|
|
|