|
|
@@ -1,64 +1,71 @@
|
|
|
+from typing import List, Dict, Union
|
|
|
from doclayout_yolo import YOLOv10
|
|
|
from tqdm import tqdm
|
|
|
+import numpy as np
|
|
|
+from PIL import Image
|
|
|
|
|
|
|
|
|
class DocLayoutYOLOModel:
    """Thin wrapper around a ``YOLOv10`` layout model that returns plain dicts.

    The wrapper stores the inference settings once and converts each
    prediction's boxes into a list of dictionaries with the keys
    ``category_id``, ``poly`` and ``score``.
    """

    def __init__(
        self,
        weight: str,
        device: str = "cuda",
        imgsz: int = 1280,
        conf: float = 0.1,
        iou: float = 0.45,
    ):
        """Load the model and remember the inference settings.

        Args:
            weight: Value passed to ``YOLOv10(...)`` (presumably a path to
                the weight file).
            device: Device the model is moved to and on which predictions
                are requested.
            imgsz: Forwarded to ``model.predict`` as ``imgsz``.
            conf: Forwarded to ``model.predict`` as ``conf``.
            iou: Forwarded to ``model.predict`` as ``iou``.
        """
        self.model = YOLOv10(weight).to(device)
        self.device = device
        self.imgsz = imgsz
        self.conf = conf
        self.iou = iou

    def _run_model(self, source):
        """Call ``self.model.predict`` on *source* with the stored settings."""
        return self.model.predict(
            source,
            imgsz=self.imgsz,
            conf=self.conf,
            iou=self.iou,
            verbose=False,
            # Request the configured device explicitly instead of relying on
            # the library's default device selection.
            device=self.device,
        )

    def _parse_prediction(self, prediction) -> List[Dict]:
        """Convert one prediction result into a list of layout dicts.

        Args:
            prediction: Object exposing ``boxes.xyxy``, ``boxes.conf`` and
                ``boxes.cls``; each must support ``.cpu()``.

        Returns:
            One dict per box with:
              * ``category_id``: the class value as ``int``.
              * ``poly``: eight ints, the corners in the order
                (xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax).
              * ``score``: the confidence rounded to three decimals.
            An empty list when the prediction has no ``boxes``.
        """
        # Fault tolerance: a result without boxes yields no layout items.
        boxes = getattr(prediction, "boxes", None)
        if boxes is None:
            return []

        # Convert each tensor to Python values in one go rather than calling
        # .item() for every coordinate, score and class.
        all_coords = boxes.xyxy.cpu().tolist()
        all_scores = boxes.conf.cpu().tolist()
        all_classes = boxes.cls.cpu().tolist()

        layout_res = []
        for coords, score, category in zip(all_coords, all_scores, all_classes):
            # int() truncates toward zero, matching the per-element conversion.
            xmin, ymin, xmax, ymax = (int(value) for value in coords)
            layout_res.append({
                "category_id": int(category),
                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                "score": round(float(score), 3),
            })
        return layout_res

    def predict(self, image: "Union[np.ndarray, Image.Image]") -> List[Dict]:
        """Run layout detection on a single image.

        Args:
            image: The image handed to ``model.predict``.

        Returns:
            The parsed boxes of the first result returned by the model, in
            the format described in ``_parse_prediction``.
        """
        return self._parse_prediction(self._run_model(image)[0])

    def batch_predict(
        self,
        images: "List[Union[np.ndarray, Image.Image]]",
        batch_size: int = 4,
    ) -> List[List[Dict]]:
        """Run layout detection on many images, ``batch_size`` at a time.

        Args:
            images: Images to process; sliced into consecutive batches.
            batch_size: Number of images per ``model.predict`` call.

        Returns:
            One list of parsed boxes per result returned by the model, in
            batch order.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        # A negative step would silently produce no batches and a zero step
        # would fail inside range(); reject both with a clear message.
        if batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )

        results = []
        for start in tqdm(range(0, len(images), batch_size), desc="Layout Predict"):
            batch = images[start: start + batch_size]
            results.extend(
                self._parse_prediction(pred) for pred in self._run_model(batch)
            )
        return results
|