|
@@ -30,12 +30,15 @@ class SmartLayoutRouter(BaseLayoutDetector):
|
|
|
|
|
|
|
|
def __init__(self, config: Dict[str, Any]):
|
|
def __init__(self, config: Dict[str, Any]):
|
|
|
super().__init__(config)
|
|
super().__init__(config)
|
|
|
- self.strategy = config.get('strategy', 'ocr_eval') # ocr_eval, auto
|
|
|
|
|
|
|
+ self.strategy = config.get('strategy', 'ocr_eval') # ocr_eval, auto, scene
|
|
|
self.models = {}
|
|
self.models = {}
|
|
|
self.model_configs = config.get('models', {})
|
|
self.model_configs = config.get('models', {})
|
|
|
self.fallback_config = config.get('fallback_model', None)
|
|
self.fallback_config = config.get('fallback_model', None)
|
|
|
self.evaluator = OCRBasedLayoutEvaluator()
|
|
self.evaluator = OCRBasedLayoutEvaluator()
|
|
|
self.ocr_recognizer = None # 用于在ocr_eval策略中获取OCR结果
|
|
self.ocr_recognizer = None # 用于在ocr_eval策略中获取OCR结果
|
|
|
|
|
+ self.scene_name = config.get('scene_name', None)
|
|
|
|
|
+ self.scene_strategy = config.get('scene_strategy', {})
|
|
|
|
|
+ self.default_model = config.get('default_model', None)
|
|
|
# 调试模式支持
|
|
# 调试模式支持
|
|
|
self.debug_mode = config.get('debug_mode', False)
|
|
self.debug_mode = config.get('debug_mode', False)
|
|
|
self.output_dir = config.get('output_dir', None)
|
|
self.output_dir = config.get('output_dir', None)
|
|
@@ -90,6 +93,10 @@ class SmartLayoutRouter(BaseLayoutDetector):
|
|
|
def set_ocr_recognizer(self, ocr_recognizer):
    """
    Attach the OCR recognizer used by the ``ocr_eval`` strategy.

    Args:
        ocr_recognizer: Recognizer object to store on this router; it is
            kept as-is in ``self.ocr_recognizer``.
    """
    self.ocr_recognizer = ocr_recognizer
|
|
|
|
|
+
|
|
|
|
|
def set_scene_name(self, scene_name: Optional[str]):
    """
    Record the scene name consulted by the ``scene`` strategy.

    Args:
        scene_name: Name of the current scene, or ``None`` to clear it.
            Stored unchanged in ``self.scene_name``.
    """
    self.scene_name = scene_name
|
|
|
|
|
|
|
|
def _detect_raw(
|
|
def _detect_raw(
|
|
|
self,
|
|
self,
|
|
@@ -137,8 +144,40 @@ class SmartLayoutRouter(BaseLayoutDetector):
|
|
|
return self._ocr_eval_detect(image, ocr_spans)
|
|
return self._ocr_eval_detect(image, ocr_spans)
|
|
|
elif self.strategy == 'auto':
|
|
elif self.strategy == 'auto':
|
|
|
return self._auto_select_detect(image)
|
|
return self._auto_select_detect(image)
|
|
|
|
|
+ elif self.strategy == 'scene':
|
|
|
|
|
+ return self._scene_select_detect(image)
|
|
|
else:
|
|
else:
|
|
|
raise ValueError(f"Unknown strategy: {self.strategy}")
|
|
raise ValueError(f"Unknown strategy: {self.strategy}")
|
|
|
|
|
+
|
|
|
|
|
def _scene_select_detect(
    self,
    image: Union[np.ndarray, Image.Image]
) -> List[Dict[str, Any]]:
    """
    Scene strategy: choose a model from ``scene_strategy`` and run it directly.

    Selection order:
      1. The rule stored under ``self.scene_name`` in ``self.scene_strategy``
         (either a model-name string, or a dict carrying a ``'model'`` key).
      2. ``self.default_model``.
      3. The first entry of ``self.models``.

    If the chosen name is not a key of ``self.models``, a warning is logged
    and the first entry of ``self.models`` is used instead.

    Note: no ocr_eval is performed; the selected model is used as-is.

    Args:
        image: Input image, forwarded unchanged to the selected model's
            ``detect`` method.

    Returns:
        The result of ``detect(image)`` on the selected model.

    Raises:
        RuntimeError: If ``self.models`` is empty, so there is nothing to
            select or fall back to.
    """
    if not self.models:
        # Without this guard the fallback below would call next() on an
        # empty iterator and leak a bare StopIteration to the caller.
        raise RuntimeError(
            "Scene strategy cannot run: no layout models are available"
        )

    selected_model = None
    if self.scene_name:
        # scene_strategy is None when the config key exists but is null.
        scene_rule = (self.scene_strategy or {}).get(self.scene_name)
        if isinstance(scene_rule, str):
            selected_model = scene_rule
        elif isinstance(scene_rule, dict):
            selected_model = scene_rule.get('model')

    if not selected_model:
        selected_model = self.default_model

    if not selected_model:
        selected_model = next(iter(self.models))

    if selected_model not in self.models:
        logger.warning(f"⚠️ Scene strategy model not available: {selected_model}, using first model")
        selected_model = next(iter(self.models))

    logger.info(f"🎯 Scene strategy selected model: {selected_model} (scene: {self.scene_name})")
    return self.models[selected_model].detect(image)
|
|
|
|
|
|
|
|
def _ocr_eval_detect(
|
|
def _ocr_eval_detect(
|
|
|
self,
|
|
self,
|