Unimernet.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. import torch
  2. from torch.utils.data import DataLoader, Dataset
  3. from tqdm import tqdm
  class MathDataset(Dataset):
      """Map-style dataset over a list of pre-cropped formula images.

      Despite its name, ``image_paths`` holds the items themselves: the callers
      in this file pass crops produced by ``image[ymin:ymax, xmin:xmax]``.
      ``__getitem__`` returns each item as-is, or passed through ``transform``.

      Args:
          image_paths: Indexable sequence of items to serve.
          transform: Optional callable applied to each item on access.
      """

      def __init__(self, image_paths, transform=None):
          self.image_paths = image_paths
          self.transform = transform

      def __len__(self):
          return len(self.image_paths)

      def __getitem__(self, idx):
          """Return item ``idx``, transformed when a ``transform`` was given."""
          image = self.image_paths[idx]
          # Bind ``image`` unconditionally: previously it was only assigned
          # inside the ``if``, so a dataset built without a transform raised
          # UnboundLocalError on every access.
          if self.transform is not None:
              image = self.transform(image)
          return image
  class UnimernetModel(object):
      """Formula-recognition wrapper around the UniMERNet model in ``.unimernet_hf``.

      Takes formula-detection results (objects exposing ``boxes.xyxy``,
      ``boxes.conf`` and ``boxes.cls`` tensors) and crops every detected box
      out of its page image. It then fills in a LaTeX string for each crop,
      taken from the ``"fixed_str"`` output of ``self.model.generate``.
      """

      def __init__(self, weight_dir, _device_="cpu"):
          """Load the model from ``weight_dir`` and place it on ``_device_``.

          Args:
              weight_dir: Passed to ``from_pretrained``.
              _device_: Torch device string, e.g. ``"cpu"``, ``"cuda"``,
                  ``"mps"`` or ``"npu"``. Any device not starting with
                  ``"cpu"`` also switches the model to float16.
          """
          # Aliased so the loaded class does not shadow this wrapper's own name.
          from .unimernet_hf import UnimernetModel as _HFUnimernetModel

          if _device_.startswith(("mps", "npu")):
              # MPS/NPU load with the eager attention implementation (presumably
              # the default attention kernel is unavailable there - TODO confirm).
              self.model = _HFUnimernetModel.from_pretrained(
                  weight_dir, attn_implementation="eager"
              )
          else:
              self.model = _HFUnimernetModel.from_pretrained(weight_dir)
          self.device = _device_
          self.model.to(_device_)
          if not _device_.startswith("cpu"):
              self.model = self.model.to(dtype=torch.float16)
          self.model.eval()

      @staticmethod
      def _crop_detections(mfd_res, image):
          """Build result dicts, crops and box areas for one page's detections.

          Returns:
              ``(formula_list, crops, areas)``, three parallel lists.

              Each dict has these keys:

              - ``category_id``: 13 + the detector class id.
              - ``poly``: the box corners top-left, top-right, bottom-right,
                bottom-left, flattened as x, y pairs.
              - ``score``: the confidence rounded to 2 decimals.
              - ``latex``: an empty string for the caller to fill in.

              Each crop is ``image[ymin:ymax, xmin:xmax]``, and each area is
              ``(xmax - xmin) * (ymax - ymin)``.
          """
          boxes = mfd_res.boxes
          formula_list, crops, areas = [], [], []
          # Copy the box tensors to host once, then read scalars from the copy,
          # instead of calling .item() element by element on device tensors.
          for xyxy, conf, cla in zip(boxes.xyxy.cpu(), boxes.conf.cpu(), boxes.cls.cpu()):
              xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
              formula_list.append(
                  {
                      "category_id": 13 + int(cla.item()),
                      "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                      "score": round(float(conf.item()), 2),
                      "latex": "",
                  }
              )
              crops.append(image[ymin:ymax, xmin:xmax])
              areas.append((xmax - xmin) * (ymax - ymin))
          return formula_list, crops, areas

      def _to_model_input(self, mf_img):
          """Cast a collated batch to the model's dtype, then move it to ``self.device``."""
          return mf_img.to(dtype=self.model.dtype).to(self.device)

      def predict(self, mfd_res, image, batch_size: int = 32):
          """Recognise every formula detected on a single page.

          Args:
              mfd_res: Detection result for ``image``.
              image: Page image, sliced as ``image[ymin:ymax, xmin:xmax]``.
              batch_size: Crops per ``generate`` call. Defaults to 32, the value
                  that was previously hard-coded.

          Returns:
              One dict per detection, in detection order, with ``latex`` set.
          """
          formula_list, mf_image_list, _ = self._crop_detections(mfd_res, image)
          dataset = MathDataset(mf_image_list, transform=self.model.transform)
          dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
          mfr_res = []
          for mf_img in dataloader:
              mf_img = self._to_model_input(mf_img)
              with torch.no_grad():
                  output = self.model.generate({"image": mf_img})
              mfr_res.extend(output["fixed_str"])
          for res, latex in zip(formula_list, mfr_res):
              res["latex"] = latex
          return formula_list

      def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
          """Recognise formulas across many pages in area-sorted batches.

          Args:
              images_mfd_res: Detection result per page.
              images: Page images, indexed in step with ``images_mfd_res``.
              batch_size: Upper bound on crops per ``generate`` call. It is
                  further capped at the largest power of two not exceeding the
                  total number of crops.

          Returns:
              One list of result dicts per page (same layout as ``predict``),
              with ``latex`` filled in.
          """
          images_formula_list = []
          backfill_list = []  # every result dict across all pages, in detection order
          mf_image_list = []  # crops, parallel to backfill_list
          area_list = []  # box areas, parallel to backfill_list
          for image_index, mfd_res in enumerate(images_mfd_res):
              formula_list, crops, areas = self._crop_detections(mfd_res, images[image_index])
              images_formula_list.append(formula_list)
              backfill_list.extend(formula_list)
              mf_image_list.extend(crops)
              area_list.extend(areas)

          # Process crops in ascending box-area order. sorted() is stable, so
          # equal areas keep detection order. Presumably this groups
          # similarly sized crops into the same batch.
          sorted_indices = sorted(range(len(mf_image_list)), key=area_list.__getitem__)
          sorted_images = [mf_image_list[i] for i in sorted_indices]
          dataset = MathDataset(sorted_images, transform=self.model.transform)

          # Cap batch_size at the largest power of two <= the number of crops.
          if sorted_images:
              batch_size = min(batch_size, 2 ** (len(sorted_images).bit_length() - 1))
          else:
              batch_size = 1
          dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)

          mfr_res = []
          with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar:
              for mf_img in dataloader:
                  mf_img = self._to_model_input(mf_img)
                  with torch.no_grad():
                      output = self.model.generate({"image": mf_img}, batch_size=batch_size)
                  mfr_res.extend(output["fixed_str"])
                  # Advance by the real batch length; the last batch may be short.
                  pbar.update(mf_img.shape[0])

          # Scatter results from sorted order back to detection order.
          unsorted_results = [""] * len(mfr_res)
          for original_idx, latex in zip(sorted_indices, mfr_res):
              unsorted_results[original_idx] = latex
          for res, latex in zip(backfill_list, unsorted_results):
              res["latex"] = latex
          return images_formula_list