Quellcode durchsuchen

fix: add version check for PyTorch to prevent errors in batch prediction

myhloli vor 2 Monaten
Ursprung
Commit
aa39e61fef
1 geänderte Dateien mit 6 neuen und 0 gelöschten Zeilen
  1. 6 0
      mineru/model/ori_cls/paddle_ori_cls.py

+ 6 - 0
mineru/model/ori_cls/paddle_ori_cls.py

@@ -174,6 +174,12 @@ class PaddleOrientationClsModel:
     def batch_predict(
         self, imgs: List[Dict], det_batch_size: int, batch_size: int = 16
     ) -> None:
+
+        import torch
+        from packaging import version
+        if version.parse(torch.__version__) >= version.parse("2.8.0"):
+            return None
+
         """
         批量预测传入的包含图片信息列表的旋转信息,并且将旋转过的图片正确地旋转回来
         """