# doclayout_yolo.py (2.2 KB)

from typing import List, Dict, Union

import numpy as np
from doclayout_yolo import YOLOv10
from PIL import Image
from tqdm import tqdm


  6. class DocLayoutYOLOModel:
  7. def __init__(
  8. self,
  9. weight: str,
  10. device: str = "cuda",
  11. imgsz: int = 1280,
  12. conf: float = 0.1,
  13. iou: float = 0.45,
  14. ):
  15. self.model = YOLOv10(weight).to(device)
  16. self.device = device
  17. self.imgsz = imgsz
  18. self.conf = conf
  19. self.iou = iou
  20. def _parse_prediction(self, prediction) -> List[Dict]:
  21. layout_res = []
  22. # 容错处理
  23. if not hasattr(prediction, "boxes") or prediction.boxes is None:
  24. return layout_res
  25. for xyxy, conf, cls in zip(
  26. prediction.boxes.xyxy.cpu(),
  27. prediction.boxes.conf.cpu(),
  28. prediction.boxes.cls.cpu(),
  29. ):
  30. coords = list(map(int, xyxy.tolist()))
  31. xmin, ymin, xmax, ymax = coords
  32. layout_res.append({
  33. "category_id": int(cls.item()),
  34. "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  35. "score": round(float(conf.item()), 3),
  36. })
  37. return layout_res
  38. def predict(self, image: Union[np.ndarray, Image.Image]) -> List[Dict]:
  39. prediction = self.model.predict(
  40. image,
  41. imgsz=self.imgsz,
  42. conf=self.conf,
  43. iou=self.iou,
  44. verbose=False
  45. )[0]
  46. return self._parse_prediction(prediction)
  47. def batch_predict(
  48. self,
  49. images: List[Union[np.ndarray, Image.Image]],
  50. batch_size: int = 4
  51. ) -> List[List[Dict]]:
  52. results = []
  53. with tqdm(total=len(images), desc="Layout Predict") as pbar:
  54. for idx in range(0, len(images), batch_size):
  55. batch = images[idx: idx + batch_size]
  56. predictions = self.model.predict(
  57. batch,
  58. imgsz=self.imgsz,
  59. conf=self.conf,
  60. iou=self.iou,
  61. verbose=False,
  62. )
  63. for pred in predictions:
  64. results.append(self._parse_prediction(pred))
  65. pbar.update(len(batch))
  66. return results