Преглед на файлове

feat(pipeline_manager): 添加场景名称设置功能,并同步到布局路由器

zhch158_admin преди 1 седмица
родител
ревизия
87c5b916fb
променени са 2 файла, в които са добавени 52 реда и са изтрити 1 реда
  1. 40 1
      ocr_tools/universal_doc_parser/core/layout_model_router.py
  2. 12 0
      ocr_tools/universal_doc_parser/core/pipeline_manager_v2.py

+ 40 - 1
ocr_tools/universal_doc_parser/core/layout_model_router.py

@@ -30,12 +30,15 @@ class SmartLayoutRouter(BaseLayoutDetector):
     
     def __init__(self, config: Dict[str, Any]):
         super().__init__(config)
-        self.strategy = config.get('strategy', 'ocr_eval')  # ocr_eval, auto
+        self.strategy = config.get('strategy', 'ocr_eval')  # ocr_eval, auto, scene
         self.models = {}
         self.model_configs = config.get('models', {})
         self.fallback_config = config.get('fallback_model', None)
         self.evaluator = OCRBasedLayoutEvaluator()
         self.ocr_recognizer = None  # 用于在ocr_eval策略中获取OCR结果
+        self.scene_name = config.get('scene_name', None)
+        self.scene_strategy = config.get('scene_strategy', {})
+        self.default_model = config.get('default_model', None)
         # 调试模式支持
         self.debug_mode = config.get('debug_mode', False)
         self.output_dir = config.get('output_dir', None)
@@ -90,6 +93,10 @@ class SmartLayoutRouter(BaseLayoutDetector):
     def set_ocr_recognizer(self, ocr_recognizer):
         """Set the OCR recognizer (used by the ocr_eval strategy)."""
         self.ocr_recognizer = ocr_recognizer
+
+    def set_scene_name(self, scene_name: Optional[str]):
+        """Set the scene name (used by the scene strategy); the value is stored as given, None included."""
+        self.scene_name = scene_name
     
     def _detect_raw(
         self, 
@@ -137,8 +144,40 @@ class SmartLayoutRouter(BaseLayoutDetector):
             return self._ocr_eval_detect(image, ocr_spans)
         elif self.strategy == 'auto':
             return self._auto_select_detect(image)
+        elif self.strategy == 'scene':
+            return self._scene_select_detect(image)
         else:
             raise ValueError(f"Unknown strategy: {self.strategy}")
+
+    def _scene_select_detect(
+        self,
+        image: Union[np.ndarray, Image.Image]
+    ) -> List[Dict[str, Any]]:
+        """
+        Scene strategy: run the model that ``scene_strategy`` maps to ``scene_name``.
+
+        Skips ocr_eval. Falls back to ``default_model``, then to the first loaded model;
+        raises RuntimeError when no model is loaded (previously a bare StopIteration).
+        """
+        if not self.models:
+            raise RuntimeError("Scene strategy requires at least one loaded layout model")
+        first_model = next(iter(self.models))
+        selected_model = None
+        if self.scene_name:
+            # A rule is a bare model name or a dict with a 'model' key; a null mapping counts as empty.
+            scene_rule = (self.scene_strategy or {}).get(self.scene_name)
+            if isinstance(scene_rule, str):
+                selected_model = scene_rule
+            elif isinstance(scene_rule, dict):
+                selected_model = scene_rule.get('model')
+
+        selected_model = selected_model or self.default_model or first_model
+        if selected_model not in self.models:
+            logger.warning(f"⚠️ Scene strategy model not available: {selected_model}, using first model")
+            selected_model = first_model
+
+        logger.info(f"🎯 Scene strategy selected model: {selected_model} (scene: {self.scene_name})")
+        return self.models[selected_model].detect(image)
     
     def _ocr_eval_detect(
         self, 

+ 12 - 0
ocr_tools/universal_doc_parser/core/pipeline_manager_v2.py

@@ -117,6 +117,15 @@ class EnhancedDocPipeline:
         self._init_element_processors()
         
         logger.info(f"✅ Pipeline initialized for scene: {self.scene_name}")
+
+    def set_scene_name(self, scene_name: Optional[str]):
+        """Set the scene name and propagate it to the layout detector (router)."""
+        if not scene_name:  # None or empty: no-op, the current scene is kept
+            return
+        self.scene_name = scene_name
+        if hasattr(self.layout_detector, 'set_scene_name'):  # only detectors exposing the setter are updated
+            self.layout_detector.set_scene_name(scene_name)
+        logger.info(f"🔄 Scene updated in pipeline: {scene_name}")
     
     def _ensure_vl_recognizer(self):
         """懒加载 VL 识别器(仅在需要时初始化,且只初始化一次)"""
@@ -155,6 +164,9 @@ class EnhancedDocPipeline:
             self.layout_detector = ModelFactory.create_layout_detector(
                 self.config['layout_detection']
             )
+
+            if hasattr(self.layout_detector, 'set_scene_name'):
+                self.layout_detector.set_scene_name(self.scene_name)
             
             # 如果是智能路由器且使用ocr_eval策略,需要设置OCR识别器
             if hasattr(self.layout_detector, 'set_ocr_recognizer'):