# DocLayoutYOLO.py

from doclayout_yolo import YOLOv10
  2. class DocLayoutYOLOModel(object):
  3. def __init__(self, weight, device):
  4. self.model = YOLOv10(weight)
  5. self.device = device
  6. def predict(self, image):
  7. layout_res = []
  8. doclayout_yolo_res = self.model.predict(
  9. image, imgsz=1024, conf=0.25, iou=0.45, verbose=False, device=self.device
  10. )[0]
  11. for xyxy, conf, cla in zip(
  12. doclayout_yolo_res.boxes.xyxy.cpu(),
  13. doclayout_yolo_res.boxes.conf.cpu(),
  14. doclayout_yolo_res.boxes.cls.cpu(),
  15. ):
  16. xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
  17. new_item = {
  18. "category_id": int(cla.item()),
  19. "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  20. "score": round(float(conf.item()), 3),
  21. }
  22. layout_res.append(new_item)
  23. return layout_res
  24. def batch_predict(self, images: list, batch_size: int) -> list:
  25. images_layout_res = []
  26. for index in range(0, len(images), batch_size):
  27. doclayout_yolo_res = [
  28. image_res.cpu()
  29. for image_res in self.model.predict(
  30. images[index : index + batch_size],
  31. imgsz=1024,
  32. conf=0.25,
  33. iou=0.45,
  34. verbose=False,
  35. device=self.device,
  36. )
  37. ]
  38. for image_res in doclayout_yolo_res:
  39. layout_res = []
  40. for xyxy, conf, cla in zip(
  41. image_res.boxes.xyxy,
  42. image_res.boxes.conf,
  43. image_res.boxes.cls,
  44. ):
  45. xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
  46. new_item = {
  47. "category_id": int(cla.item()),
  48. "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  49. "score": round(float(conf.item()), 3),
  50. }
  51. layout_res.append(new_item)
  52. images_layout_res.append(layout_res)
  53. return images_layout_res