Browse Source

fix: add version check for PyTorch to prevent errors in batch prediction

myhloli 2 tháng trước cách đây
mục cha
commit
aa39e61fef
1 tập tin đã thay đổi với 6 bổ sung0 xóa
  1. 6 0
      mineru/model/ori_cls/paddle_ori_cls.py

+ 6 - 0
mineru/model/ori_cls/paddle_ori_cls.py

@@ -174,6 +174,12 @@ class PaddleOrientationClsModel:
     def batch_predict(
         self, imgs: List[Dict], det_batch_size: int, batch_size: int = 16
     ) -> None:
+
+        import torch
+        from packaging import version
+        if version.parse(torch.__version__) >= version.parse("2.8.0"):
+            return None
+
         """
         批量预测传入的包含图片信息列表的旋转信息,并且将旋转过的图片正确地旋转回来
         """