# Unimernet.py (5.2 KB)
  1. import torch
  2. from torch.utils.data import DataLoader, Dataset
  3. from tqdm import tqdm
  4. class MathDataset(Dataset):
  5. def __init__(self, image_paths, transform=None):
  6. self.image_paths = image_paths
  7. self.transform = transform
  8. def __len__(self):
  9. return len(self.image_paths)
  10. def __getitem__(self, idx):
  11. raw_image = self.image_paths[idx]
  12. if self.transform:
  13. image = self.transform(raw_image)
  14. return image
  15. class UnimernetModel(object):
  16. def __init__(self, weight_dir, _device_="cpu"):
  17. from .unimernet_hf import UnimernetModel
  18. if _device_.startswith("mps") or _device_.startswith("npu"):
  19. self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
  20. else:
  21. self.model = UnimernetModel.from_pretrained(weight_dir)
  22. self.device = _device_
  23. self.model.to(_device_)
  24. if not _device_.startswith("cpu"):
  25. self.model = self.model.to(dtype=torch.float16)
  26. self.model.eval()
  27. def predict(self, mfd_res, image):
  28. formula_list = []
  29. mf_image_list = []
  30. for xyxy, conf, cla in zip(
  31. mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()
  32. ):
  33. xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
  34. new_item = {
  35. "category_id": 13 + int(cla.item()),
  36. "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  37. "score": round(float(conf.item()), 2),
  38. "latex": "",
  39. }
  40. formula_list.append(new_item)
  41. bbox_img = image[ymin:ymax, xmin:xmax]
  42. mf_image_list.append(bbox_img)
  43. dataset = MathDataset(mf_image_list, transform=self.model.transform)
  44. dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
  45. mfr_res = []
  46. for mf_img in dataloader:
  47. mf_img = mf_img.to(dtype=self.model.dtype)
  48. mf_img = mf_img.to(self.device)
  49. with torch.no_grad():
  50. output = self.model.generate({"image": mf_img})
  51. mfr_res.extend(output["fixed_str"])
  52. for res, latex in zip(formula_list, mfr_res):
  53. res["latex"] = latex
  54. return formula_list
  55. def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
  56. images_formula_list = []
  57. mf_image_list = []
  58. backfill_list = []
  59. image_info = [] # Store (area, original_index, image) tuples
  60. # Collect images with their original indices
  61. for image_index in range(len(images_mfd_res)):
  62. mfd_res = images_mfd_res[image_index]
  63. pil_img = images[image_index]
  64. formula_list = []
  65. for idx, (xyxy, conf, cla) in enumerate(zip(
  66. mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
  67. )):
  68. xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
  69. new_item = {
  70. "category_id": 13 + int(cla.item()),
  71. "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  72. "score": round(float(conf.item()), 2),
  73. "latex": "",
  74. }
  75. formula_list.append(new_item)
  76. bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
  77. area = (xmax - xmin) * (ymax - ymin)
  78. curr_idx = len(mf_image_list)
  79. image_info.append((area, curr_idx, bbox_img))
  80. mf_image_list.append(bbox_img)
  81. images_formula_list.append(formula_list)
  82. backfill_list += formula_list
  83. # Stable sort by area
  84. image_info.sort(key=lambda x: x[0]) # sort by area
  85. sorted_indices = [x[1] for x in image_info]
  86. sorted_images = [x[2] for x in image_info]
  87. # Create mapping for results
  88. index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
  89. # Create dataset with sorted images
  90. dataset = MathDataset(sorted_images, transform=self.model.transform)
  91. dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
  92. # Process batches and store results
  93. mfr_res = []
  94. # for mf_img in dataloader:
  95. with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar:
  96. for index, mf_img in enumerate(dataloader):
  97. mf_img = mf_img.to(dtype=self.model.dtype)
  98. mf_img = mf_img.to(self.device)
  99. with torch.no_grad():
  100. output = self.model.generate({"image": mf_img})
  101. mfr_res.extend(output["fixed_str"])
  102. # 更新进度条,每次增加batch_size,但要注意最后一个batch可能不足batch_size
  103. current_batch_size = min(batch_size, len(sorted_images) - index * batch_size)
  104. pbar.update(current_batch_size)
  105. # Restore original order
  106. unsorted_results = [""] * len(mfr_res)
  107. for new_idx, latex in enumerate(mfr_res):
  108. original_idx = index_mapping[new_idx]
  109. unsorted_results[original_idx] = latex
  110. # Fill results back
  111. for res, latex in zip(backfill_list, unsorted_results):
  112. res["latex"] = latex
  113. return images_formula_list