|
|
@@ -1,8 +1,12 @@
|
|
|
+import os
|
|
|
from typing import List, Union
|
|
|
from tqdm import tqdm
|
|
|
from ultralytics import YOLO
|
|
|
import numpy as np
|
|
|
-from PIL import Image
|
|
|
+from PIL import Image, ImageDraw
|
|
|
+
|
|
|
+from mineru.utils.enum_class import ModelPath
|
|
|
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
|
|
|
|
|
|
|
|
|
class YOLOv8MFDModel:
|
|
|
@@ -50,4 +54,53 @@ class YOLOv8MFDModel:
|
|
|
batch_preds = self._run_predict(batch, is_batch=True)
|
|
|
results.extend(batch_preds)
|
|
|
pbar.update(len(batch))
|
|
|
- return results
|
|
|
+ return results
|
|
|
+
|
|
|
+ def visualize(
|
|
|
+ self,
|
|
|
+ image: Union[np.ndarray, Image.Image],
|
|
|
+ results: List
|
|
|
+ ) -> Image.Image:
|
|
|
+
|
|
|
+ if isinstance(image, np.ndarray):
|
|
|
+ image = Image.fromarray(image)
|
|
|
+
|
|
|
+ formula_list = []
|
|
|
+ for xyxy, conf, cla in zip(
|
|
|
+ results.boxes.xyxy.cpu(), results.boxes.conf.cpu(), results.boxes.cls.cpu()
|
|
|
+ ):
|
|
|
+ xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
|
|
+ new_item = {
|
|
|
+ "category_id": 13 + int(cla.item()),
|
|
|
+ "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
|
|
+ "score": round(float(conf.item()), 2),
|
|
|
+ }
|
|
|
+ formula_list.append(new_item)
|
|
|
+
|
|
|
+ draw = ImageDraw.Draw(image)
|
|
|
+ for res in formula_list:
|
|
|
+ poly = res['poly']
|
|
|
+ xmin, ymin, xmax, ymax = poly[0], poly[1], poly[4], poly[5]
|
|
|
+ print(
|
|
|
+ f"Detected box: {xmin}, {ymin}, {xmax}, {ymax}, Category ID: {res['category_id']}, Score: {res['score']}")
|
|
|
+ # 使用PIL在图像上画框
|
|
|
+ draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=2)
|
|
|
+ # 在框旁边画置信度
|
|
|
+ draw.text((xmax + 10, ymin + 10), f"{res['score']:.2f}", fill="red")
|
|
|
+ return image
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ image_path = r"C:\Users\zhaoxiaomeng\Downloads\下载1.jpg"
|
|
|
+ yolo_v8_mfd_weights = os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd),
|
|
|
+ ModelPath.yolo_v8_mfd)
|
|
|
+ device = 'cuda'
|
|
|
+ model = YOLOv8MFDModel(
|
|
|
+ weight=yolo_v8_mfd_weights,
|
|
|
+ device=device,
|
|
|
+ )
|
|
|
+ image = Image.open(image_path)
|
|
|
+ results = model.predict(image)
|
|
|
+
|
|
|
+ image = model.visualize(image, results)
|
|
|
+
|
|
|
+ image.show() # 显示图像
|