yolo_v8.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. from typing import List, Union
  2. from tqdm import tqdm
  3. from ultralytics import YOLO
  4. import numpy as np
  5. from PIL import Image
  6. class YOLOv8MFDModel:
  7. def __init__(
  8. self,
  9. weight: str,
  10. device: str = "cpu",
  11. imgsz: int = 1888,
  12. conf: float = 0.25,
  13. iou: float = 0.45,
  14. ):
  15. self.model = YOLO(weight).to(device)
  16. self.device = device
  17. self.imgsz = imgsz
  18. self.conf = conf
  19. self.iou = iou
  20. def _run_predict(
  21. self,
  22. inputs: Union[np.ndarray, Image.Image, List],
  23. is_batch: bool = False
  24. ) -> List:
  25. preds = self.model.predict(
  26. inputs,
  27. imgsz=self.imgsz,
  28. conf=self.conf,
  29. iou=self.iou,
  30. verbose=False,
  31. device=self.device
  32. )
  33. return [pred.cpu() for pred in preds] if is_batch else preds[0].cpu()
  34. def predict(self, image: Union[np.ndarray, Image.Image]):
  35. return self._run_predict(image)
  36. def batch_predict(
  37. self,
  38. images: List[Union[np.ndarray, Image.Image]],
  39. batch_size: int = 4
  40. ) -> List:
  41. results = []
  42. with tqdm(total=len(images), desc="MFD Predict") as pbar:
  43. for idx in range(0, len(images), batch_size):
  44. batch = images[idx: idx + batch_size]
  45. batch_preds = self._run_predict(batch, is_batch=True)
  46. results.extend(batch_preds)
  47. pbar.update(len(batch))
  48. return results