Explorar o código

feat: 移除 PytorchPaddleOCR 主类的导入

zhch158_admin hai 2 semanas
pai
achega
bf859eb9b7

+ 108 - 66
zhch/unified_pytorch_models/vendor/pytorch_paddle.py → zhch/unified_pytorch_models/pytorch_paddle.py

@@ -13,12 +13,9 @@ import numpy as np
 import yaml
 from loguru import logger
 import argparse
+from typing import Optional
 
-# ✅ 修改导入
-try:
-    from .device_utils import get_device
-except ImportError:
-    from device_utils import get_device
+from vendor import get_device
 
 # 当作为脚本运行时,添加父目录到 Python 路径
 current_dir = Path(__file__).resolve().parent
@@ -26,17 +23,12 @@ parent_dir = current_dir.parent
 if str(parent_dir) not in sys.path:
     sys.path.insert(0, str(parent_dir))
 
-# ✅ 根据运行模式选择导入方式
-try:
-    # 作为模块导入时(from vendor.pytorch_paddle import ...)
-    from .ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
-    from .infer.predict_system import TextSystem
-    from .infer import pytorchocr_utility as utility
-except ImportError:
-    # 作为脚本运行时(python pytorch_paddle.py)
-    from vendor.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
-    from vendor.infer.predict_system import TextSystem
-    from vendor.infer import pytorchocr_utility as utility
+# ✅ 导入方向分类器
+from orientation_classifier_v2 import OrientationClassifierV2
+
+from vendor import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
+from vendor.infer import TextSystem
+from vendor.infer import pytorchocr_utility as utility
 
 latin_lang = [
     "af", "az", "bs", "cs", "cy", "da", "de", "es", "et", "fr", "ga", "hr", "hu", 
@@ -108,7 +100,14 @@ class PytorchPaddleOCR(TextSystem):
         # 获取语言设置
         self.lang = kwargs.get('lang', 'ch')
         self.enable_merge_det_boxes = kwargs.get("enable_merge_det_boxes", True)
-
+        
+        # ✅ 新增:方向分类器配置
+        self.use_orientation_cls = kwargs.get("use_orientation_cls", False)
+        self.orientation_model_path = kwargs.get(
+            "orientation_model_path",
+            None
+        )
+        
         # 自动检测设备
         device = kwargs.get('device', get_device())
         
@@ -130,7 +129,7 @@ class PytorchPaddleOCR(TextSystem):
             self.lang = 'devanagari'
 
         # ✅ 读取模型配置
-        models_config_path = root_dir / 'pytorchocr' / 'utils' / 'resources' / 'models_config.yml'
+        models_config_path = root_dir / 'vendor' / 'pytorchocr' / 'utils' / 'resources' / 'models_config.yml'
         
         if not models_config_path.exists():
             raise FileNotFoundError(f"❌ Config file not found: {models_config_path}")
@@ -178,7 +177,7 @@ class PytorchPaddleOCR(TextSystem):
 
         # ✅ 字典路径
         if 'rec_char_dict_path' not in kwargs:
-            dict_path = root_dir / 'pytorchocr' / 'utils' / 'resources' / 'dict' / dict_file
+            dict_path = root_dir / 'vendor' / 'pytorchocr' / 'utils' / 'resources' / 'dict' / dict_file
             
             if not dict_path.exists():
                 logger.error(f"❌ Dictionary file not found: {dict_path}")
@@ -204,6 +203,55 @@ class PytorchPaddleOCR(TextSystem):
         # ✅ 初始化 TextSystem
         super().__init__(final_args)
         
+        # ✅ 初始化方向分类器(在 TextSystem 之后,这样可以使用 self.text_detector)
+        self.orientation_classifier = None
+        if self.use_orientation_cls and self.orientation_model_path:
+            try:
+                logger.info(f"🔄 Initializing orientation classifier...")
+                logger.info(f"   Model: {self.orientation_model_path}")
+                
+                # ✅ 创建一个简单的检测器适配器
+                class TextDetectorAdapter:
+                    """Adapter exposing the wrapped text detector through an `ocr(img, det, rec)` method, the interface handed to OrientationClassifierV2 below."""
+                    def __init__(self, text_detector):
+                        self.text_detector = text_detector
+                    
+                    def ocr(self, img, det=True, rec=False):
+                        """Run text detection only (`rec` is accepted but never used); returns None when `det` is false, otherwise a one-element list holding the boxes."""
+                        if not det:
+                            return None
+                        
+                        try:
+                            # Call the wrapped detector; its second return value is discarded.
+                            dt_boxes, _ = self.text_detector(img)
+                            
+                            if dt_boxes is None or len(dt_boxes) == 0:
+                                return [[]]
+                            
+                            # Result shape: one outer list wrapping all boxes, i.e. [[box1, box2, ...]].
+                            return [dt_boxes.tolist()]
+                        except Exception as e:
+                            logger.warning(f"Detection failed in adapter: {e}")
+                            return [[]]
+                
+                # ✅ 传入检测器适配器
+                detector_adapter = TextDetectorAdapter(self.text_detector)
+                
+                self.orientation_classifier = OrientationClassifierV2(
+                    model_path=self.orientation_model_path,
+                    text_detector=detector_adapter,  # ✅ 传入检测器
+                    aspect_ratio_threshold=kwargs.get("aspect_ratio_threshold", 1.2),
+                    vertical_text_ratio=kwargs.get("vertical_text_ratio", 0.28),
+                    vertical_text_min_count=kwargs.get("vertical_text_min_count", 3),
+                    use_gpu=kwargs.get("device", "cpu") != "cpu"
+                )
+                
+                logger.info(f"✅ Orientation classifier initialized")
+                
+            except Exception as e:
+                logger.warning(f"⚠️  Failed to initialize orientation classifier: {e}")
+                self.orientation_classifier = None
+        
         # ✅ 验证并修复字符集
         logger.info(f"🔍 Validating text recognizer...")
         
@@ -284,7 +332,9 @@ class PytorchPaddleOCR(TextSystem):
             logger.error('When input is a list of images, det must be False')
             exit(1)
         
+        # np.ndarray: BGR 格式图像
         img = check_img(img)
+        
         imgs = [img]
         
         with warnings.catch_warnings():
@@ -398,12 +448,11 @@ class PytorchPaddleOCR(TextSystem):
                 filter_rec_res.append(rec_result)
 
         return filter_boxes, filter_rec_res
-
     def visualize(
         self,
         img: np.ndarray,
         ocr_results: list,
-        output_path: str = None,
+        output_path: Optional[str] = None,
         show_text: bool = True,
         show_confidence: bool = True,
         font_scale: float = 0.5,
@@ -532,27 +581,28 @@ if __name__ == '__main__':
     logger.remove()
     logger.add(sys.stderr, level="INFO")
  
-    print("🚀 Testing PytorchPaddleOCR with Visualization...")
+    print("🚀 Testing PytorchPaddleOCR with Orientation Classifier...")
     
     try:
-        ocr = PytorchPaddleOCR(lang='ch', device='cpu')
+        # ✅ 启用方向分类器
+        ocr = PytorchPaddleOCR(
+            lang='ch',
+            device='cpu',
+            use_orientation_cls=True,  # ✅ 启用
+            orientation_model_path="/Users/zhch158/workspace/repository.git/PaddleX/zhch/unified_pytorch_models/Layout/PP-LCNet_x1_0_doc_ori.onnx"
+        )
         
         # ✅ 验证字符集
         if hasattr(ocr.text_recognizer, 'postprocess_op'):
             char_count = len(ocr.text_recognizer.postprocess_op.character)
             print(f"\n📊 Character set loaded: {char_count} characters")
-            if char_count > 0:
-                print(f"   Sample chars: {ocr.text_recognizer.postprocess_op.character[:20]}")
-            else:
-                print(f"   ❌ ERROR: Character set is empty!")
-                sys.exit(1)
         
         # 测试图像列表
         test_images = [
-            {
-                'path': "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_001.png",
-                'output': "/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/pytorch_paddle_ocr/page_001_ocr_vis.png"
-            },
+            # {
+            #     'path': "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_001.png",
+            #     'output': "/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/pytorch_paddle_ocr/page_001_ocr_vis.png"
+            # },
             {
                 'path': "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_003.png",
                 'output': "/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/pytorch_paddle_ocr/page_003_ocr_vis.png"
@@ -567,55 +617,47 @@ if __name__ == '__main__':
             print(f"📄 Processing: {Path(img_path).name}")
             print(f"{'='*60}")
             
-            if not Path(img_path).exists():
-                print(f"❌ Image not found: {img_path}")
-                continue
-            
             img = cv2.imread(img_path)
-            
-            if img is None:
-                print(f"❌ Failed to load image")
-                continue
-            
             print(f"📖 Original image: {img.shape}")
             
-            # 执行 OCR
+            # np.ndarray: BGR 格式图像
+            img = check_img(img)
+            
+            # ✅ 使用方向分类器(如果启用)
+            if ocr.orientation_classifier is not None and isinstance(img, np.ndarray):
+                logger.info(f"🔄 Checking image orientation...")
+                
+                # 预测方向
+                orientation_result = ocr.orientation_classifier.predict(img, return_debug=True)
+                
+                logger.info(f"   Orientation result: {orientation_result}")
+                
+                # 如果需要旋转
+                if orientation_result.needs_rotation:
+                    logger.info(f"🔄 Rotating image by {orientation_result.rotation_angle}°...")
+                    img = ocr.orientation_classifier.rotate_image(img, orientation_result.rotation_angle)
+                    logger.info(f"   New shape: {img.shape}")
+        
+            # Run OCR on the image; any orientation correction was already applied explicitly above.
             results = ocr.ocr(img, det=True, rec=True)
             
             if results and results[0]:
-                print(f"\n✅ OCR completed! Found {len(results[0])} text regions")
-                
-                # 打印前 10 个结果
-                print(f"\n📝 First 10 results:")
+                print(f"\n✅ Found {len(results[0])} text regions:")
                 for i, (box, (text, conf)) in enumerate(results[0][:10], 1):
                     print(f"  [{i}] {text} (conf={conf:.3f})")
                 
-                if len(results[0]) > 10:
-                    print(f"  ... and {len(results[0]) - 10} more")
-                
                 # 可视化
-                print(f"\n🎨 Visualizing results...")
-                img_vis = ocr.visualize(
-                    img, 
-                    results, 
-                    output_path=output_path,
-                    show_text=True,
-                    show_confidence=True
-                )
+                img_vis = ocr.visualize(img, results, output_path=output_path, show_text=False, show_confidence=False)
                 
-                # 统计信息
-                total_boxes = len(results[0])
-                avg_conf = sum(conf for _, (_, conf) in results[0]) / total_boxes
-                high_conf = sum(1 for _, (_, conf) in results[0] if conf >= 0.9)
-                low_conf = sum(1 for _, (_, conf) in results[0] if conf < 0.7)
+                # 统计
+                total = len(results[0])
                 non_empty = sum(1 for _, (text, _) in results[0] if text)
+                avg_conf = sum(conf for _, (_, conf) in results[0]) / total
                 
                 print(f"\n📊 Statistics:")
-                print(f"  Total boxes: {total_boxes}")
+                print(f"  Total: {total}")
                 print(f"  Non-empty: {non_empty}")
-                print(f"  Average confidence: {avg_conf:.3f}")
-                print(f"  High confidence (≥0.9): {high_conf}")
-                print(f"  Low confidence (<0.7): {low_conf}")
+                print(f"  Avg confidence: {avg_conf:.3f}")
             else:
                 print(f"⚠️  No results found")
             

+ 0 - 3
zhch/unified_pytorch_models/vendor/__init__.py

@@ -16,9 +16,6 @@ from .ocr_utils import (
     get_rotate_crop_image
 )
 
-# PytorchPaddleOCR 主类
-from .pytorch_paddle import PytorchPaddleOCR
-
 __all__ = [
     # 设备工具
     'get_device',