Unimernet.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. import torch
  2. from torch.utils.data import DataLoader, Dataset
  3. class MathDataset(Dataset):
  4. def __init__(self, image_paths, transform=None):
  5. self.image_paths = image_paths
  6. self.transform = transform
  7. def __len__(self):
  8. return len(self.image_paths)
  9. def __getitem__(self, idx):
  10. raw_image = self.image_paths[idx]
  11. if self.transform:
  12. image = self.transform(raw_image)
  13. return image
  14. class UnimernetModel(object):
  15. def __init__(self, weight_dir, cfg_path, _device_="cpu"):
  16. from .unimernet_hf import UnimernetModel
  17. if _device_.startswith("mps"):
  18. self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
  19. else:
  20. self.model = UnimernetModel.from_pretrained(weight_dir)
  21. self.device = _device_
  22. self.model.to(_device_)
  23. if not _device_.startswith("cpu"):
  24. self.model = self.model.to(dtype=torch.float16)
  25. self.model.eval()
  26. def predict(self, mfd_res, image):
  27. formula_list = []
  28. mf_image_list = []
  29. for xyxy, conf, cla in zip(
  30. mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()
  31. ):
  32. xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
  33. new_item = {
  34. "category_id": 13 + int(cla.item()),
  35. "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  36. "score": round(float(conf.item()), 2),
  37. "latex": "",
  38. }
  39. formula_list.append(new_item)
  40. bbox_img = image[ymin:ymax, xmin:xmax]
  41. mf_image_list.append(bbox_img)
  42. dataset = MathDataset(mf_image_list, transform=self.model.transform)
  43. dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
  44. mfr_res = []
  45. for mf_img in dataloader:
  46. mf_img = mf_img.to(dtype=self.model.dtype)
  47. mf_img = mf_img.to(self.device)
  48. with torch.no_grad():
  49. output = self.model.generate({"image": mf_img})
  50. mfr_res.extend(output["fixed_str"])
  51. for res, latex in zip(formula_list, mfr_res):
  52. res["latex"] = latex
  53. return formula_list
  54. def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
  55. images_formula_list = []
  56. mf_image_list = []
  57. backfill_list = []
  58. image_info = [] # Store (area, original_index, image) tuples
  59. # Collect images with their original indices
  60. for image_index in range(len(images_mfd_res)):
  61. mfd_res = images_mfd_res[image_index]
  62. np_array_image = images[image_index]
  63. formula_list = []
  64. for idx, (xyxy, conf, cla) in enumerate(zip(
  65. mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
  66. )):
  67. xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
  68. new_item = {
  69. "category_id": 13 + int(cla.item()),
  70. "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  71. "score": round(float(conf.item()), 2),
  72. "latex": "",
  73. }
  74. formula_list.append(new_item)
  75. bbox_img = np_array_image[ymin:ymax, xmin:xmax]
  76. area = (xmax - xmin) * (ymax - ymin)
  77. curr_idx = len(mf_image_list)
  78. image_info.append((area, curr_idx, bbox_img))
  79. mf_image_list.append(bbox_img)
  80. images_formula_list.append(formula_list)
  81. backfill_list += formula_list
  82. # Stable sort by area
  83. image_info.sort(key=lambda x: x[0]) # sort by area
  84. sorted_indices = [x[1] for x in image_info]
  85. sorted_images = [x[2] for x in image_info]
  86. # Create mapping for results
  87. index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
  88. # Create dataset with sorted images
  89. dataset = MathDataset(sorted_images, transform=self.model.transform)
  90. dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
  91. # Process batches and store results
  92. mfr_res = []
  93. for mf_img in dataloader:
  94. mf_img = mf_img.to(dtype=self.model.dtype)
  95. mf_img = mf_img.to(self.device)
  96. with torch.no_grad():
  97. output = self.model.generate({"image": mf_img})
  98. mfr_res.extend(output["fixed_str"])
  99. # Restore original order
  100. unsorted_results = [""] * len(mfr_res)
  101. for new_idx, latex in enumerate(mfr_res):
  102. original_idx = index_mapping[new_idx]
  103. unsorted_results[original_idx] = latex
  104. # Fill results back
  105. for res, latex in zip(backfill_list, unsorted_results):
  106. res["latex"] = latex
  107. return images_formula_list