Explorar o código

Merge pull request #2763 from herryqg/master

encapsulate prediction parsing logic in DocLayoutYOLOModel
Xiaomeng Zhao · hai 4 meses · pai · achega 037a3ae6c8
Modificáronse 3 ficheiros con 110 adicións e 77 borrados
  1. mineru/model/layout/doclayout_yolo.py — 58 adicións, 51 borrados
  2. mineru/model/mfd/yolo_v8.py — 44 adicións, 26 borrados
  3. signatures/version1/cla.json — 8 adicións, 0 borrados

+ 58 - 51
mineru/model/layout/doclayout_yolo.py

@@ -1,64 +1,71 @@
+from typing import List, Dict, Union
 from doclayout_yolo import YOLOv10
 from tqdm import tqdm
+import numpy as np
+from PIL import Image
 
 
-class DocLayoutYOLOModel(object):
-    def __init__(self, weight, device):
-        self.model = YOLOv10(weight)
+class DocLayoutYOLOModel:
+    def __init__(
+        self,
+        weight: str,
+        device: str = "cuda",
+        imgsz: int = 1280,
+        conf: float = 0.1,
+        iou: float = 0.45,
+    ):
+        self.model = YOLOv10(weight).to(device)
         self.device = device
+        self.imgsz = imgsz
+        self.conf = conf
+        self.iou = iou
 
-    def predict(self, image):
+    def _parse_prediction(self, prediction) -> List[Dict]:
         layout_res = []
-        doclayout_yolo_res = self.model.predict(
-            image,
-            imgsz=1280,
-            conf=0.10,
-            iou=0.45,
-            verbose=False, device=self.device
-        )[0]
-        for xyxy, conf, cla in zip(
-            doclayout_yolo_res.boxes.xyxy.cpu(),
-            doclayout_yolo_res.boxes.conf.cpu(),
-            doclayout_yolo_res.boxes.cls.cpu(),
+
+        # 容错处理
+        if not hasattr(prediction, "boxes") or prediction.boxes is None:
+            return layout_res
+
+        for xyxy, conf, cls in zip(
+            prediction.boxes.xyxy.cpu(),
+            prediction.boxes.conf.cpu(),
+            prediction.boxes.cls.cpu(),
         ):
-            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
-            new_item = {
-                "category_id": int(cla.item()),
+            coords = list(map(int, xyxy.tolist()))
+            xmin, ymin, xmax, ymax = coords
+            layout_res.append({
+                "category_id": int(cls.item()),
                 "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                 "score": round(float(conf.item()), 3),
-            }
-            layout_res.append(new_item)
+            })
         return layout_res
 
-    def batch_predict(self, images: list, batch_size: int) -> list:
-        images_layout_res = []
-        # for index in range(0, len(images), batch_size):
-        for index in tqdm(range(0, len(images), batch_size), desc="Layout Predict"):
-            doclayout_yolo_res = [
-                image_res.cpu()
-                for image_res in self.model.predict(
-                    images[index : index + batch_size],
-                    imgsz=1280,
-                    conf=0.10,
-                    iou=0.45,
-                    verbose=False,
-                    device=self.device,
-                )
-            ]
-            for image_res in doclayout_yolo_res:
-                layout_res = []
-                for xyxy, conf, cla in zip(
-                    image_res.boxes.xyxy,
-                    image_res.boxes.conf,
-                    image_res.boxes.cls,
-                ):
-                    xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
-                    new_item = {
-                        "category_id": int(cla.item()),
-                        "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-                        "score": round(float(conf.item()), 3),
-                    }
-                    layout_res.append(new_item)
-                images_layout_res.append(layout_res)
+    def predict(self, image: Union[np.ndarray, Image.Image]) -> List[Dict]:
+        prediction = self.model.predict(
+            image,
+            imgsz=self.imgsz,
+            conf=self.conf,
+            iou=self.iou,
+            verbose=False
+        )[0]
+        return self._parse_prediction(prediction)
 
-        return images_layout_res
+    def batch_predict(
+        self,
+        images: List[Union[np.ndarray, Image.Image]],
+        batch_size: int = 4
+    ) -> List[List[Dict]]:
+        results = []
+        for idx in tqdm(range(0, len(images), batch_size), desc="Layout Predict"):
+            batch = images[idx: idx + batch_size]
+            predictions = self.model.predict(
+                batch,
+                imgsz=self.imgsz,
+                conf=self.conf,
+                iou=self.iou,
+                verbose=False,
+            )
+            for pred in predictions:
+                results.append(self._parse_prediction(pred))
+        return results

+ 44 - 26
mineru/model/mfd/yolo_v8.py

@@ -1,33 +1,51 @@
+from typing import List, Union
 from tqdm import tqdm
 from ultralytics import YOLO
+import numpy as np
+from PIL import Image
 
 
-class YOLOv8MFDModel(object):
-    def __init__(self, weight, device="cpu"):
-        self.mfd_model = YOLO(weight)
+class YOLOv8MFDModel:
+    def __init__(
+        self,
+        weight: str,
+        device: str = "cpu",
+        imgsz: int = 1888,
+        conf: float = 0.25,
+        iou: float = 0.45,
+    ):
+        self.model = YOLO(weight).to(device)
         self.device = device
+        self.imgsz = imgsz
+        self.conf = conf
+        self.iou = iou
 
-    def predict(self, image):
-        mfd_res = self.mfd_model.predict(
-            image, imgsz=1888, conf=0.25, iou=0.45, verbose=False, device=self.device
-        )[0]
-        return mfd_res
+    def _run_predict(
+        self,
+        inputs: Union[np.ndarray, Image.Image, List],
+        is_batch: bool = False
+    ) -> List:
+        preds = self.model.predict(
+            inputs,
+            imgsz=self.imgsz,
+            conf=self.conf,
+            iou=self.iou,
+            verbose=False,
+            device=self.device
+        )
+        return [pred.cpu() for pred in preds] if is_batch else preds[0].cpu()
 
-    def batch_predict(self, images: list, batch_size: int) -> list:
-        images_mfd_res = []
-        # for index in range(0, len(images), batch_size):
-        for index in tqdm(range(0, len(images), batch_size), desc="MFD Predict"):
-            mfd_res = [
-                image_res.cpu()
-                for image_res in self.mfd_model.predict(
-                    images[index : index + batch_size],
-                    imgsz=1888,
-                    conf=0.25,
-                    iou=0.45,
-                    verbose=False,
-                    device=self.device,
-                )
-            ]
-            for image_res in mfd_res:
-                images_mfd_res.append(image_res)
-        return images_mfd_res
+    def predict(self, image: Union[np.ndarray, Image.Image]):
+        return self._run_predict(image)
+
+    def batch_predict(
+        self,
+        images: List[Union[np.ndarray, Image.Image]],
+        batch_size: int = 4
+    ) -> List:
+        results = []
+        for idx in tqdm(range(0, len(images), batch_size), desc="MFD Predict"):
+            batch = images[idx: idx + batch_size]
+            batch_preds = self._run_predict(batch, is_batch=True)
+            results.extend(batch_preds)
+        return results

+ 8 - 0
signatures/version1/cla.json

@@ -343,6 +343,14 @@
       "created_at": "2025-06-18T11:27:23Z",
       "repoId": 765083837,
       "pullRequestNo": 2727
+    },
+    {
+      "name": "QIN2DIM",
+      "id": 62018067,
+      "comment_id": 2992279796,
+      "created_at": "2025-06-20T17:04:59Z",
+      "repoId": 765083837,
+      "pullRequestNo": 2758
     }
   ]
 }