Selaa lähdekoodia

feat: 添加文本检测条件以判断是否需要旋转图像,并更新测试代码以保存旋转后的图像

zhch158_admin 2 viikkoa sitten
vanhempi
commit
d45d09c9f7
1 muutettua tiedostoa jossa 26 lisäystä ja 9 poistoa
  1. 26 9
      zhch/unified_pytorch_models/orientation_classifier_v2.py

+ 26 - 9
zhch/unified_pytorch_models/orientation_classifier_v2.py

@@ -217,14 +217,15 @@ class OrientationClassifierV2:
             return result
         
         # 2. 使用文本检测判断是否旋转
-        is_rotated, vertical_count = self._detect_vertical_text(img)
-        result.vertical_text_count = vertical_count
-        
-        if not is_rotated:
-            if return_debug:
-                print(f"   ⏭️  No rotation needed (vertical_texts={vertical_count})")
-            return result
-        
+        if self.text_detector:
+            is_rotated, vertical_count = self._detect_vertical_text(img)
+            result.vertical_text_count = vertical_count
+            
+            if not is_rotated:
+                if return_debug:
+                    print(f"   ⏭️  No rotation needed (vertical_texts={vertical_count})")
+                return result
+            
         # 3. 使用分类模型预测旋转角度
         input_tensor = self._preprocess(img)
         
@@ -254,4 +255,20 @@ class OrientationClassifierV2:
         elif angle == "270":
             return cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
         else:
-            return img
+            return img
+
+if __name__ == "__main__":
+    # 测试代码
+    model_path = "/Users/zhch158/workspace/repository.git/PaddleX/zhch/unified_pytorch_models/Layout/PP-LCNet_x1_0_doc_ori.onnx"  # 替换为实际模型路径
+    classifier = OrientationClassifierV2(model_path=model_path, use_gpu=False)
+    
+    test_image_path = "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_003.png"  # 替换为实际测试图像路径
+    output_image_path = Path(f"/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/PP-LCNet_x1_0_doc_ori.onnx/{Path(test_image_path).name}.jpg")
+    img = cv2.imread(test_image_path)
+    
+    result = classifier.predict(img, return_debug=True)
+    print(result)
+    
+    if result.needs_rotation:
+        rotated_img = classifier.rotate_image(img, result.rotation_angle)
+        cv2.imwrite(output_image_path, rotated_img)