test_doclayoutyolo.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. import os
  2. from typing import List, Dict, Union
  3. from doclayout_yolo import YOLOv10
  4. from tqdm import tqdm
  5. import numpy as np
  6. from PIL import Image, ImageDraw
  7. from mineru.utils.enum_class import ModelPath
  8. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  9. class DocLayoutYOLOModel:
  10. def __init__(
  11. self,
  12. weight: str,
  13. device: str = "cuda",
  14. imgsz: int = 1280,
  15. conf: float = 0.1,
  16. iou: float = 0.45,
  17. ):
  18. self.model = YOLOv10(weight).to(device)
  19. self.device = device
  20. self.imgsz = imgsz
  21. self.conf = conf
  22. self.iou = iou
  23. def _parse_prediction(self, prediction) -> List[Dict]:
  24. layout_res = []
  25. # 容错处理
  26. if not hasattr(prediction, "boxes") or prediction.boxes is None:
  27. return layout_res
  28. for xyxy, conf, cls in zip(
  29. prediction.boxes.xyxy.cpu(),
  30. prediction.boxes.conf.cpu(),
  31. prediction.boxes.cls.cpu(),
  32. ):
  33. coords = list(map(int, xyxy.tolist()))
  34. xmin, ymin, xmax, ymax = coords
  35. layout_res.append({
  36. "category_id": int(cls.item()),
  37. "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  38. "score": round(float(conf.item()), 3),
  39. })
  40. return layout_res
  41. def predict(self, image: Union[np.ndarray, Image.Image]) -> List[Dict]:
  42. prediction = self.model.predict(
  43. image,
  44. imgsz=self.imgsz,
  45. conf=self.conf,
  46. iou=self.iou,
  47. verbose=False
  48. )[0]
  49. return self._parse_prediction(prediction)
  50. def batch_predict(
  51. self,
  52. images: List[Union[np.ndarray, Image.Image]],
  53. batch_size: int = 4
  54. ) -> List[List[Dict]]:
  55. results = []
  56. with tqdm(total=len(images), desc="Layout Predict") as pbar:
  57. for idx in range(0, len(images), batch_size):
  58. batch = images[idx: idx + batch_size]
  59. if batch_size == 1:
  60. conf = 0.9 * self.conf
  61. else:
  62. conf = self.conf
  63. predictions = self.model.predict(
  64. batch,
  65. imgsz=self.imgsz,
  66. conf=conf,
  67. iou=self.iou,
  68. verbose=False,
  69. )
  70. for pred in predictions:
  71. results.append(self._parse_prediction(pred))
  72. pbar.update(len(batch))
  73. return results
  74. # DocLayout-YOLO 类别映射
  75. CATEGORY_NAMES = {
  76. 0: "title",
  77. 1: "text",
  78. 2: "abandon",
  79. 3: "figure",
  80. 4: "figure_caption",
  81. 5: "table",
  82. 6: "table_caption",
  83. 7: "table_footnote",
  84. 8: "isolate_formula",
  85. 9: "formula_caption",
  86. }
  87. # 不同类别使用不同颜色
  88. CATEGORY_COLORS = {
  89. 0: "red", # title
  90. 1: "blue", # text
  91. 2: "gray", # abandon
  92. 3: "green", # figure
  93. 4: "lightgreen", # figure_caption
  94. 5: "orange", # table
  95. 6: "yellow", # table_caption
  96. 7: "pink", # table_footnote
  97. 8: "purple", # isolate_formula
  98. 9: "cyan", # formula_caption
  99. }
  100. def visualize(
  101. self,
  102. image: Union[np.ndarray, Image.Image],
  103. results: List
  104. ) -> Image.Image:
  105. if isinstance(image, np.ndarray):
  106. image = Image.fromarray(image)
  107. draw = ImageDraw.Draw(image)
  108. for res in results:
  109. poly = res['poly']
  110. xmin, ymin, xmax, ymax = poly[0], poly[1], poly[4], poly[5]
  111. category_id = res['category_id']
  112. category_name = self.CATEGORY_NAMES.get(category_id, f"unknown_{category_id}")
  113. color = self.CATEGORY_COLORS.get(category_id, "red")
  114. print(
  115. f"Detected box: {xmin}, {ymin}, {xmax}, {ymax}, Category: {category_name}({category_id}), Score: {res['score']}")
  116. # 使用PIL在图像上画框
  117. draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=2)
  118. # 在框旁边画类别名和置信度
  119. label = f"{category_name} {res['score']:.2f}"
  120. draw.text((xmin, ymin - 25), label, fill=color, font_size=20)
  121. return image
  122. if __name__ == '__main__':
  123. image_path = "./2023年度报告母公司_page_003_270.png"
  124. doclayout_yolo_weights = os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
  125. device = 'cpu'
  126. model = DocLayoutYOLOModel(
  127. weight=doclayout_yolo_weights,
  128. device=device,
  129. )
  130. image = Image.open(image_path)
  131. results = model.predict(image)
  132. image = model.visualize(image, results)
  133. image.show() # 显示图像