yolo_v8.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import os
  2. from typing import List, Union
  3. from tqdm import tqdm
  4. from ultralytics import YOLO
  5. import numpy as np
  6. from PIL import Image, ImageDraw
  7. from mineru.utils.enum_class import ModelPath
  8. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  9. class YOLOv8MFDModel:
  10. def __init__(
  11. self,
  12. weight: str,
  13. device: str = "cpu",
  14. imgsz: int = 1888,
  15. conf: float = 0.25,
  16. iou: float = 0.45,
  17. ):
  18. self.model = YOLO(weight).to(device)
  19. self.device = device
  20. self.imgsz = imgsz
  21. self.conf = conf
  22. self.iou = iou
  23. def _run_predict(
  24. self,
  25. inputs: Union[np.ndarray, Image.Image, List],
  26. is_batch: bool = False
  27. ) -> List:
  28. preds = self.model.predict(
  29. inputs,
  30. imgsz=self.imgsz,
  31. conf=self.conf,
  32. iou=self.iou,
  33. verbose=False,
  34. device=self.device
  35. )
  36. return [pred.cpu() for pred in preds] if is_batch else preds[0].cpu()
  37. def predict(self, image: Union[np.ndarray, Image.Image]):
  38. return self._run_predict(image)
  39. def batch_predict(
  40. self,
  41. images: List[Union[np.ndarray, Image.Image]],
  42. batch_size: int = 4
  43. ) -> List:
  44. results = []
  45. with tqdm(total=len(images), desc="MFD Predict") as pbar:
  46. for idx in range(0, len(images), batch_size):
  47. batch = images[idx: idx + batch_size]
  48. batch_preds = self._run_predict(batch, is_batch=True)
  49. results.extend(batch_preds)
  50. pbar.update(len(batch))
  51. return results
  52. def visualize(
  53. self,
  54. image: Union[np.ndarray, Image.Image],
  55. results: List
  56. ) -> Image.Image:
  57. if isinstance(image, np.ndarray):
  58. image = Image.fromarray(image)
  59. formula_list = []
  60. for xyxy, conf, cla in zip(
  61. results.boxes.xyxy.cpu(), results.boxes.conf.cpu(), results.boxes.cls.cpu()
  62. ):
  63. xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
  64. new_item = {
  65. "category_id": 13 + int(cla.item()),
  66. "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  67. "score": round(float(conf.item()), 2),
  68. }
  69. formula_list.append(new_item)
  70. draw = ImageDraw.Draw(image)
  71. for res in formula_list:
  72. poly = res['poly']
  73. xmin, ymin, xmax, ymax = poly[0], poly[1], poly[4], poly[5]
  74. print(
  75. f"Detected box: {xmin}, {ymin}, {xmax}, {ymax}, Category ID: {res['category_id']}, Score: {res['score']}")
  76. # 使用PIL在图像上画框
  77. draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=2)
  78. # 在框旁边画置信度
  79. draw.text((xmax + 10, ymin + 10), f"{res['score']:.2f}", fill="red")
  80. return image
  81. if __name__ == '__main__':
  82. image_path = r"C:\Users\zhaoxiaomeng\Downloads\下载1.jpg"
  83. yolo_v8_mfd_weights = os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd),
  84. ModelPath.yolo_v8_mfd)
  85. device = 'cuda'
  86. model = YOLOv8MFDModel(
  87. weight=yolo_v8_mfd_weights,
  88. device=device,
  89. )
  90. image = Image.open(image_path)
  91. results = model.predict(image)
  92. image = model.visualize(image, results)
  93. image.show() # 显示图像