doclayoutyolo.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import os
  2. from typing import List, Dict, Union
  3. from doclayout_yolo import YOLOv10
  4. from tqdm import tqdm
  5. import numpy as np
  6. from PIL import Image, ImageDraw
  7. from mineru.utils.enum_class import ModelPath
  8. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  9. class DocLayoutYOLOModel:
  10. def __init__(
  11. self,
  12. weight: str,
  13. device: str = "cuda",
  14. imgsz: int = 1280,
  15. conf: float = 0.1,
  16. iou: float = 0.45,
  17. ):
  18. self.model = YOLOv10(weight).to(device)
  19. self.device = device
  20. self.imgsz = imgsz
  21. self.conf = conf
  22. self.iou = iou
  23. def _parse_prediction(self, prediction) -> List[Dict]:
  24. layout_res = []
  25. # 容错处理
  26. if not hasattr(prediction, "boxes") or prediction.boxes is None:
  27. return layout_res
  28. for xyxy, conf, cls in zip(
  29. prediction.boxes.xyxy.cpu(),
  30. prediction.boxes.conf.cpu(),
  31. prediction.boxes.cls.cpu(),
  32. ):
  33. coords = list(map(int, xyxy.tolist()))
  34. xmin, ymin, xmax, ymax = coords
  35. layout_res.append({
  36. "category_id": int(cls.item()),
  37. "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  38. "score": round(float(conf.item()), 3),
  39. })
  40. return layout_res
  41. def predict(self, image: Union[np.ndarray, Image.Image]) -> List[Dict]:
  42. prediction = self.model.predict(
  43. image,
  44. imgsz=self.imgsz,
  45. conf=self.conf,
  46. iou=self.iou,
  47. verbose=False
  48. )[0]
  49. return self._parse_prediction(prediction)
  50. def batch_predict(
  51. self,
  52. images: List[Union[np.ndarray, Image.Image]],
  53. batch_size: int = 4
  54. ) -> List[List[Dict]]:
  55. results = []
  56. with tqdm(total=len(images), desc="Layout Predict") as pbar:
  57. for idx in range(0, len(images), batch_size):
  58. batch = images[idx: idx + batch_size]
  59. if batch_size == 1:
  60. conf = 0.9 * self.conf
  61. else:
  62. conf = self.conf
  63. predictions = self.model.predict(
  64. batch,
  65. imgsz=self.imgsz,
  66. conf=conf,
  67. iou=self.iou,
  68. verbose=False,
  69. )
  70. for pred in predictions:
  71. results.append(self._parse_prediction(pred))
  72. pbar.update(len(batch))
  73. return results
  74. def visualize(
  75. self,
  76. image: Union[np.ndarray, Image.Image],
  77. results: List
  78. ) -> Image.Image:
  79. if isinstance(image, np.ndarray):
  80. image = Image.fromarray(image)
  81. draw = ImageDraw.Draw(image)
  82. for res in results:
  83. poly = res['poly']
  84. xmin, ymin, xmax, ymax = poly[0], poly[1], poly[4], poly[5]
  85. print(
  86. f"Detected box: {xmin}, {ymin}, {xmax}, {ymax}, Category ID: {res['category_id']}, Score: {res['score']}")
  87. # 使用PIL在图像上画框
  88. draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=2)
  89. # 在框旁边画置信度
  90. draw.text((xmax + 10, ymin + 10), f"{res['score']:.2f}", fill="red", font_size=22)
  91. return image
  92. if __name__ == '__main__':
  93. image_path = r"C:\Users\zhaoxiaomeng\Downloads\下载1.jpg"
  94. doclayout_yolo_weights = os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
  95. device = 'cuda'
  96. model = DocLayoutYOLOModel(
  97. weight=doclayout_yolo_weights,
  98. device=device,
  99. )
  100. image = Image.open(image_path)
  101. results = model.predict(image)
  102. image = model.visualize(image, results)
  103. image.show() # 显示图像