doclayout_yolo.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. from doclayout_yolo import YOLOv10
  2. from tqdm import tqdm
  3. class DocLayoutYOLOModel(object):
  4. def __init__(self, weight, device):
  5. self.model = YOLOv10(weight)
  6. self.device = device
  7. def predict(self, image):
  8. layout_res = []
  9. doclayout_yolo_res = self.model.predict(
  10. image,
  11. imgsz=1280,
  12. conf=0.10,
  13. iou=0.45,
  14. verbose=False, device=self.device
  15. )[0]
  16. for xyxy, conf, cla in zip(
  17. doclayout_yolo_res.boxes.xyxy.cpu(),
  18. doclayout_yolo_res.boxes.conf.cpu(),
  19. doclayout_yolo_res.boxes.cls.cpu(),
  20. ):
  21. xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
  22. new_item = {
  23. "category_id": int(cla.item()),
  24. "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  25. "score": round(float(conf.item()), 3),
  26. }
  27. layout_res.append(new_item)
  28. return layout_res
  29. def batch_predict(self, images: list, batch_size: int) -> list:
  30. images_layout_res = []
  31. # for index in range(0, len(images), batch_size):
  32. for index in tqdm(range(0, len(images), batch_size), desc="Layout Predict"):
  33. doclayout_yolo_res = [
  34. image_res.cpu()
  35. for image_res in self.model.predict(
  36. images[index : index + batch_size],
  37. imgsz=1280,
  38. conf=0.10,
  39. iou=0.45,
  40. verbose=False,
  41. device=self.device,
  42. )
  43. ]
  44. for image_res in doclayout_yolo_res:
  45. layout_res = []
  46. for xyxy, conf, cla in zip(
  47. image_res.boxes.xyxy,
  48. image_res.boxes.conf,
  49. image_res.boxes.cls,
  50. ):
  51. xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
  52. new_item = {
  53. "category_id": int(cla.item()),
  54. "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  55. "score": round(float(conf.item()), 3),
  56. }
  57. layout_res.append(new_item)
  58. images_layout_res.append(layout_res)
  59. return images_layout_res