Ver código fonte

feat: 优化图像处理逻辑,修正方向校正返回值,更新版式检测模型名称和初始化方式

zhch158_admin 2 semanas atrás
pai
commit
a93802d89b

+ 11 - 9
zhch/universal_doc_parser/models/adapters/mineru_adapter.py

@@ -57,20 +57,19 @@ class MinerUPreprocessor(BasePreprocessor):
         if isinstance(image, Image.Image):
             image = np.array(image)
 
-        rotate_map = {0: 0, 1: 90, 2: 180, 3: 270}
-        rotate_label = 0
+        rotate_angle = 0
         processed_image = image
         
         # 方向校正
         if self.orientation_classifier is not None:
             try:
-                rotate_label = self.orientation_classifier.predict(image)
-                processed_image = self._apply_rotation(processed_image, rotate_label)
-                logger.info(f"📐 Applied rotation: {rotate_label}")
+                rotate_angle = int(self.orientation_classifier.predict(image))
+                processed_image = self._apply_rotation(processed_image, rotate_angle)
+                logger.info(f"📐 Applied rotation: {rotate_angle}")
             except Exception as e:
                 logger.error(f"⚠️ Orientation classification failed: {e}")
 
-        return processed_image, rotate_map.get(rotate_label, 0)
+        return processed_image, rotate_angle
 
 class MinerULayoutDetector(BaseLayoutDetector):
     """MinerU版式检测适配器"""
@@ -87,7 +86,7 @@ class MinerULayoutDetector(BaseLayoutDetector):
         """初始化版式检测模型"""
         try:
             # 获取模型配置
-            model_name = self.config.get('model_name', 'RT-DETR-H_layout_17cls')
+            model_name = self.config.get('model_name', AtomicModel.Layout)
             model_dir = self.config.get('model_dir')
             device = self.config.get('device', 'cpu')
             
@@ -100,9 +99,12 @@ class MinerULayoutDetector(BaseLayoutDetector):
                     device=device
                 )
             else:
-                # 使用默认模型
+                import os
+                from mineru.utils.enum_class import ModelPath
+                from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
                 self.layout_model = self.atom_model_manager.get_atom_model(
                     atom_model_name=AtomicModel.Layout,
+                    doclayout_yolo_weights=os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo),
                     device=device
                 )
             print(f"✅ Layout detector initialized: {model_name}")
@@ -130,7 +132,7 @@ class MinerULayoutDetector(BaseLayoutDetector):
             
             # 转换结果格式
             formatted_results = []
-            for result in layout_results[0]:  # 第一页结果
+            for result in layout_results:  # 第一页结果
                 # 提取坐标信息
                 poly = result.get('poly', [0, 0, 0, 0, 0, 0, 0, 0])
                 if len(poly) >= 8: