DocLayoutYOLO.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. from doclayout_yolo import YOLOv10
  2. class DocLayoutYOLOModel(object):
  3. def __init__(self, weight, device):
  4. self.model = YOLOv10(weight)
  5. self.device = device
  6. def predict(self, image):
  7. layout_res = []
  8. doclayout_yolo_res = self.model.predict(
  9. image,
  10. imgsz=1280,
  11. conf=0.10,
  12. iou=0.45,
  13. verbose=False, device=self.device
  14. )[0]
  15. for xyxy, conf, cla in zip(
  16. doclayout_yolo_res.boxes.xyxy.cpu(),
  17. doclayout_yolo_res.boxes.conf.cpu(),
  18. doclayout_yolo_res.boxes.cls.cpu(),
  19. ):
  20. xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
  21. new_item = {
  22. "category_id": int(cla.item()),
  23. "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  24. "score": round(float(conf.item()), 3),
  25. }
  26. layout_res.append(new_item)
  27. return layout_res
  28. def batch_predict(self, images: list, batch_size: int) -> list:
  29. images_layout_res = []
  30. for index in range(0, len(images), batch_size):
  31. doclayout_yolo_res = [
  32. image_res.cpu()
  33. for image_res in self.model.predict(
  34. images[index : index + batch_size],
  35. imgsz=1280,
  36. conf=0.10,
  37. iou=0.45,
  38. verbose=False,
  39. device=self.device,
  40. )
  41. ]
  42. for image_res in doclayout_yolo_res:
  43. layout_res = []
  44. for xyxy, conf, cla in zip(
  45. image_res.boxes.xyxy,
  46. image_res.boxes.conf,
  47. image_res.boxes.cls,
  48. ):
  49. xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
  50. new_item = {
  51. "category_id": int(cla.item()),
  52. "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  53. "score": round(float(conf.item()), 3),
  54. }
  55. layout_res.append(new_item)
  56. images_layout_res.append(layout_res)
  57. return images_layout_res