DocLayoutYOLO.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. from doclayout_yolo import YOLOv10
  class DocLayoutYOLOModel(object):
      """Wrap a DocLayout-YOLO ``YOLOv10`` detector and return plain-Python layout boxes.

      Every detection is converted to a dict of the form::

          {"category_id": int,
           "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
           "score": float}

      i.e. the axis-aligned box is spelled out as its four corners
      (xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax); every coordinate
      is truncated to ``int`` and the score is rounded to 3 decimals.
      """

      # Settings forwarded to ``self.model.predict``. Declared at class level so
      # that instances built without running ``__init__`` (e.g. a subclass that
      # skips ``super().__init__``) still get the historical values.
      imgsz = 1024
      conf = 0.25
      iou = 0.45
      verbose = True

      def __init__(self, weight, device, *, imgsz=1024, conf=0.25, iou=0.45, verbose=True):
          """Load the detector.

          Args:
              weight: passed unchanged to ``YOLOv10`` (presumably a checkpoint
                  path -- confirm against callers).
              device: forwarded as ``device=`` to every ``predict`` call.
              imgsz: forwarded as ``imgsz=`` to every ``predict`` call.
              conf: forwarded as ``conf=`` to every ``predict`` call.
              iou: forwarded as ``iou=`` to every ``predict`` call.
              verbose: forwarded as ``verbose=`` to every ``predict`` call.
          """
          self.model = YOLOv10(weight)
          self.device = device
          self.imgsz = imgsz
          self.conf = conf
          self.iou = iou
          self.verbose = verbose

      def _predict_kwargs(self):
          """Return the keyword arguments shared by every ``self.model.predict`` call."""
          return {
              "imgsz": self.imgsz,
              "conf": self.conf,
              "iou": self.iou,
              "verbose": self.verbose,
              "device": self.device,
          }

      @staticmethod
      def _to_layout_items(xyxys, confs, classes):
          """Zip aligned boxes, scores and class ids into layout dicts.

          Args:
              xyxys: iterable of 4-element ``(xmin, ymin, xmax, ymax)`` rows whose
                  elements expose ``.item()`` (e.g. rows of a tensor).
              confs: iterable of scalar scores exposing ``.item()``, aligned with ``xyxys``.
              classes: iterable of scalar class ids exposing ``.item()``, aligned with ``xyxys``.

          Returns:
              One ``{"category_id", "poly", "score"}`` dict per detection, in input order.
          """
          layout_res = []
          for xyxy, conf, cla in zip(xyxys, confs, classes):
              # int() truncates toward zero.
              xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
              layout_res.append(
                  {
                      "category_id": int(cla.item()),
                      "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                      "score": round(float(conf.item()), 3),
                  }
              )
          return layout_res

      @classmethod
      def _result_to_layout(cls, result):
          """Convert one ``predict`` result (with ``.boxes.xyxy/.conf/.cls``) to layout dicts."""
          boxes = result.boxes
          # Move each tensor to CPU before its elements are read with ``.item()``.
          return cls._to_layout_items(boxes.xyxy.cpu(), boxes.conf.cpu(), boxes.cls.cpu())

      def predict(self, image):
          """Run the detector on a single image.

          Args:
              image: passed unchanged to ``self.model.predict``.

          Returns:
              list of layout dicts (format described in the class docstring),
              built from the first result returned by ``self.model.predict``.
          """
          result = self.model.predict(image, **self._predict_kwargs())[0]
          return self._result_to_layout(result)

      def batch_predict(self, images: list, batch_size: int) -> list:
          """Run the detector over ``images`` in chunks of ``batch_size``.

          Args:
              images: sequence of images; each slice of up to ``batch_size``
                  items is passed to ``self.model.predict`` in one call.
              batch_size: number of images per ``predict`` call; must be >= 1.

          Returns:
              One list of layout dicts per result returned by
              ``self.model.predict`` (presumably one per input image), in order.

          Raises:
              ValueError: if ``batch_size`` is smaller than 1.
          """
          if batch_size < 1:
              # range() would raise for 0 and silently yield nothing for negatives.
              raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
          images_layout_res = []
          for start in range(0, len(images), batch_size):
              # predict() returns a per-image list of results (``predict`` above
              # indexes it with [0]); a plain list has no ``.cpu()``, so each
              # result's tensors are moved to CPU individually instead.
              batch_results = self.model.predict(
                  images[start : start + batch_size], **self._predict_kwargs()
              )
              images_layout_res.extend(
                  self._result_to_layout(image_res) for image_res in batch_results
              )
          return images_layout_res