Unimernet.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import argparse
  2. import os
  3. import re
  4. import torch
  5. import unimernet.tasks as tasks
  6. from torch.utils.data import DataLoader, Dataset
  7. from torchvision import transforms
  8. from unimernet.common.config import Config
  9. from unimernet.processors import load_processor
  10. class MathDataset(Dataset):
  11. def __init__(self, image_paths, transform=None):
  12. self.image_paths = image_paths
  13. self.transform = transform
  14. def __len__(self):
  15. return len(self.image_paths)
  16. def __getitem__(self, idx):
  17. # if not pil image, then convert to pil image
  18. if isinstance(self.image_paths[idx], str):
  19. raw_image = Image.open(self.image_paths[idx])
  20. else:
  21. raw_image = self.image_paths[idx]
  22. if self.transform:
  23. image = self.transform(raw_image)
  24. return image
  25. def latex_rm_whitespace(s: str):
  26. """Remove unnecessary whitespace from LaTeX code."""
  27. text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
  28. letter = "[a-zA-Z]"
  29. noletter = "[\W_^\d]"
  30. names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
  31. s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
  32. news = s
  33. while True:
  34. s = news
  35. news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
  36. news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
  37. news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
  38. if news == s:
  39. break
  40. return s
  41. class UnimernetModel(object):
  42. def __init__(self, weight_dir, cfg_path, _device_="cpu"):
  43. args = argparse.Namespace(cfg_path=cfg_path, options=None)
  44. cfg = Config(args)
  45. cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
  46. cfg.config.model.model_config.model_name = weight_dir
  47. cfg.config.model.tokenizer_config.path = weight_dir
  48. task = tasks.setup_task(cfg)
  49. self.model = task.build_model(cfg)
  50. self.device = _device_
  51. self.model.to(_device_)
  52. self.model.eval()
  53. vis_processor = load_processor(
  54. "formula_image_eval",
  55. cfg.config.datasets.formula_rec_eval.vis_processor.eval,
  56. )
  57. self.mfr_transform = transforms.Compose(
  58. [
  59. vis_processor,
  60. ]
  61. )
  62. def predict(self, mfd_res, image):
  63. formula_list = []
  64. mf_image_list = []
  65. for xyxy, conf, cla in zip(
  66. mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()
  67. ):
  68. xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
  69. new_item = {
  70. "category_id": 13 + int(cla.item()),
  71. "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  72. "score": round(float(conf.item()), 2),
  73. "latex": "",
  74. }
  75. formula_list.append(new_item)
  76. pil_img = Image.fromarray(image)
  77. bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
  78. mf_image_list.append(bbox_img)
  79. dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
  80. dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
  81. mfr_res = []
  82. for mf_img in dataloader:
  83. mf_img = mf_img.to(self.device)
  84. with torch.no_grad():
  85. output = self.model.generate({"image": mf_img})
  86. mfr_res.extend(output["pred_str"])
  87. for res, latex in zip(formula_list, mfr_res):
  88. res["latex"] = latex_rm_whitespace(latex)
  89. return formula_list
  90. def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
  91. images_formula_list = []
  92. mf_image_list = []
  93. backfill_list = []
  94. image_info = [] # Store (area, original_index, image) tuples
  95. # Collect images with their original indices
  96. for image_index in range(len(images_mfd_res)):
  97. mfd_res = images_mfd_res[image_index]
  98. np_array_image = images[image_index]
  99. formula_list = []
  100. for idx, (xyxy, conf, cla) in enumerate(zip(
  101. mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
  102. )):
  103. xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
  104. new_item = {
  105. "category_id": 13 + int(cla.item()),
  106. "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  107. "score": round(float(conf.item()), 2),
  108. "latex": "",
  109. }
  110. formula_list.append(new_item)
  111. bbox_img = np_array_image[ymin:ymax, xmin:xmax]
  112. area = (xmax - xmin) * (ymax - ymin)
  113. curr_idx = len(mf_image_list)
  114. image_info.append((area, curr_idx, bbox_img))
  115. mf_image_list.append(bbox_img)
  116. images_formula_list.append(formula_list)
  117. backfill_list += formula_list
  118. # Stable sort by area
  119. image_info.sort(key=lambda x: x[0]) # sort by area
  120. sorted_indices = [x[1] for x in image_info]
  121. sorted_images = [x[2] for x in image_info]
  122. # Create mapping for results
  123. index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
  124. # Create dataset with sorted images
  125. dataset = MathDataset(sorted_images, transform=self.mfr_transform)
  126. dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
  127. # Process batches and store results
  128. mfr_res = []
  129. for mf_img in dataloader:
  130. mf_img = mf_img.to(self.device)
  131. with torch.no_grad():
  132. output = self.model.generate({"image": mf_img})
  133. mfr_res.extend(output["pred_str"])
  134. # Restore original order
  135. unsorted_results = [""] * len(mfr_res)
  136. for new_idx, latex in enumerate(mfr_res):
  137. original_idx = index_mapping[new_idx]
  138. unsorted_results[original_idx] = latex_rm_whitespace(latex)
  139. # Fill results back
  140. for res, latex in zip(backfill_list, unsorted_results):
  141. res["latex"] = latex
  142. return images_formula_list