Bläddra i källkod

Refactor image rotation handling in batch_analyze.py and paddle_ori_cls.py for improved compatibility with torch versions

myhloli 1 månad sedan
förälder
incheckning
6d5d1cf26b
3 ändrade filer med 37 tillägg och 27 borttagningar
  1. 10 3
      mineru/backend/pipeline/batch_analyze.py
  2. 24 21
      mineru/model/ori_cls/paddle_ori_cls.py
  3. 3 3
      pyproject.toml

+ 10 - 3
mineru/backend/pipeline/batch_analyze.py

@@ -116,9 +116,16 @@ class BatchAnalyze:
                 atom_model_name=AtomicModel.ImgOrientationCls,
             )
             try:
-                img_orientation_cls_model.batch_predict(table_res_list_all_page,
-                                                        det_batch_size=self.batch_ratio * OCR_DET_BASE_BATCH_SIZE,
-                                                        batch_size=TABLE_ORI_CLS_BATCH_SIZE)
+                import torch
+                from packaging import version
+                if version.parse(torch.__version__) >= version.parse("2.8.0"):
+                    for table_res in table_res_list_all_page:
+                        rotate_label = img_orientation_cls_model.predict(table_res['table_img'])
+                        img_orientation_cls_model.img_rotate(table_res, rotate_label)
+                else:
+                    img_orientation_cls_model.batch_predict(table_res_list_all_page,
+                                                            det_batch_size=self.batch_ratio * OCR_DET_BASE_BATCH_SIZE,
+                                                            batch_size=TABLE_ORI_CLS_BATCH_SIZE)
             except Exception as e:
                 logger.warning(
                     f"Image orientation classification failed: {e}, using original image"

+ 24 - 21
mineru/model/ori_cls/paddle_ori_cls.py

@@ -255,25 +255,28 @@ class PaddleOrientationClsModel:
                     results = self.sess.run(None, {"x": x})
                     for img_info, res in zip(rotated_imgs, results[0]):
                         label = self.labels[np.argmax(res)]
-                        if label == "270":
-                            img_info["table_img"] = cv2.rotate(
-                                np.asarray(img_info["table_img"]),
-                                cv2.ROTATE_90_CLOCKWISE,
-                            )
-                            img_info["wired_table_img"] = cv2.rotate(
-                                np.asarray(img_info["wired_table_img"]),
-                                cv2.ROTATE_90_CLOCKWISE,
-                            )
-                        elif label == "90":
-                            img_info["table_img"] = cv2.rotate(
-                                np.asarray(img_info["table_img"]),
-                                cv2.ROTATE_90_COUNTERCLOCKWISE,
-                            )
-                            img_info["wired_table_img"] = cv2.rotate(
-                                np.asarray(img_info["wired_table_img"]),
-                                cv2.ROTATE_90_COUNTERCLOCKWISE,
-                            )
-                        else:
-                            # 180度和0度不做处理
-                            pass
+                        self.img_rotate(img_info, label)
                         pbar.update(1)
+
+    def img_rotate(self, img_info, label):
+        """Rotate the table images stored in ``img_info`` according to ``label``.
+
+        The mapping is modified in place; nothing is returned.
+
+        Args:
+            img_info: Mapping that holds the images under the keys
+                ``"table_img"`` and ``"wired_table_img"``. Both keys are read
+                whenever a rotation is applied.
+            label: Orientation label as a string. Only ``"270"`` and ``"90"``
+                trigger a rotation; any other value leaves ``img_info``
+                untouched.
+        """
+        # NOTE(review): presumably the label is the angle by which the image is
+        # currently rotated, so each branch applies the turn that brings it
+        # back upright -- confirm against the classifier's label definition.
+        if label == "270":
+            # Turn both images 90 degrees clockwise. np.asarray hands cv2 an
+            # ndarray regardless of how the image is stored in the mapping.
+            img_info["table_img"] = cv2.rotate(
+                np.asarray(img_info["table_img"]),
+                cv2.ROTATE_90_CLOCKWISE,
+            )
+            img_info["wired_table_img"] = cv2.rotate(
+                np.asarray(img_info["wired_table_img"]),
+                cv2.ROTATE_90_CLOCKWISE,
+            )
+        elif label == "90":
+            # Turn both images 90 degrees counterclockwise.
+            img_info["table_img"] = cv2.rotate(
+                np.asarray(img_info["table_img"]),
+                cv2.ROTATE_90_COUNTERCLOCKWISE,
+            )
+            img_info["wired_table_img"] = cv2.rotate(
+                np.asarray(img_info["wired_table_img"]),
+                cv2.ROTATE_90_COUNTERCLOCKWISE,
+            )
+        else:
+            # 180 and 0 degrees are deliberately left unprocessed.
+            pass

+ 3 - 3
pyproject.toml

@@ -51,12 +51,12 @@ test = [
     "fuzzywuzzy"
 ]
 vlm = [
-    "torch>=2.6.0,<2.8.0",
+    "torch>=2.6.0,<3",
     "transformers>=4.51.1,<5.0.0",
     "accelerate>=1.5.1",
 ]
 vllm = [
-    "vllm==0.10.1.1",
+    "vllm>=0.10.1.1",
 ]
 pipeline = [
     "matplotlib>=3.10,<4",
@@ -68,7 +68,7 @@ pipeline = [
     "shapely>=2.0.7,<3",
     "pyclipper>=1.3.0,<2",
     "omegaconf>=2.3.0,<3",
-    "torch>=2.6.0,<2.8.0",
+    "torch>=2.6.0,<3",
     "torchvision",
     "transformers>=4.49.0,!=4.51.0,<5.0.0",
     "onnxruntime>1.17.0",