# doclayout_yolo.py

from typing import Dict, List, Union

import numpy as np
# NOTE(review): if this file is itself saved as a top-level ``doclayout_yolo.py``
# on sys.path, it may shadow the ``doclayout_yolo`` package imported below --
# confirm the module name/location.
from doclayout_yolo import YOLOv10
from PIL import Image
from tqdm import tqdm


  6. class DocLayoutYOLOModel:
  7. def __init__(
  8. self,
  9. weight: str,
  10. device: str = "cuda",
  11. imgsz: int = 1280,
  12. conf: float = 0.1,
  13. iou: float = 0.45,
  14. ):
  15. self.model = YOLOv10(weight).to(device)
  16. self.device = device
  17. self.imgsz = imgsz
  18. self.conf = conf
  19. self.iou = iou
  20. def _parse_prediction(self, prediction) -> List[Dict]:
  21. layout_res = []
  22. # 容错处理
  23. if not hasattr(prediction, "boxes") or prediction.boxes is None:
  24. return layout_res
  25. for xyxy, conf, cls in zip(
  26. prediction.boxes.xyxy.cpu(),
  27. prediction.boxes.conf.cpu(),
  28. prediction.boxes.cls.cpu(),
  29. ):
  30. coords = list(map(int, xyxy.tolist()))
  31. xmin, ymin, xmax, ymax = coords
  32. layout_res.append({
  33. "category_id": int(cls.item()),
  34. "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  35. "score": round(float(conf.item()), 3),
  36. })
  37. return layout_res
  38. def predict(self, image: Union[np.ndarray, Image.Image]) -> List[Dict]:
  39. prediction = self.model.predict(
  40. image,
  41. imgsz=self.imgsz,
  42. conf=self.conf,
  43. iou=self.iou,
  44. verbose=False
  45. )[0]
  46. return self._parse_prediction(prediction)
  47. def batch_predict(
  48. self,
  49. images: List[Union[np.ndarray, Image.Image]],
  50. batch_size: int = 4
  51. ) -> List[List[Dict]]:
  52. results = []
  53. for idx in tqdm(range(0, len(images), batch_size), desc="Layout Predict"):
  54. batch = images[idx: idx + batch_size]
  55. predictions = self.model.predict(
  56. batch,
  57. imgsz=self.imgsz,
  58. conf=self.conf,
  59. iou=self.iou,
  60. verbose=False,
  61. )
  62. for pred in predictions:
  63. results.append(self._parse_prediction(pred))
  64. return results