Unimernet.py

import argparse
import os
import re

import torch
import unimernet.tasks as tasks
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from unimernet.common.config import Config
from unimernet.processors import load_processor


class MathDataset(Dataset):
    """Dataset over a list of image paths or PIL images, for batched formula recognition."""

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # If the entry is a path, load it; otherwise it is already a PIL image.
        if isinstance(self.image_paths[idx], str):
            raw_image = Image.open(self.image_paths[idx])
        else:
            raw_image = self.image_paths[idx]
        if self.transform:
            raw_image = self.transform(raw_image)
        return raw_image


def latex_rm_whitespace(s: str):
    """Remove unnecessary whitespace from LaTeX code."""
    text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
    letter = "[a-zA-Z]"
    noletter = r"[\W_^\d]"
    # First strip spaces inside \operatorname/\mathrm/\text/\mathbf arguments,
    # then repeatedly remove whitespace around non-letter tokens until stable.
    names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
    news = s
    while True:
        s = news
        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
        news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
        if news == s:
            break
    return s
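
# Illustrative example (not from the original file):
#   latex_rm_whitespace(r"\mathrm {x} + 1")  ->  r"\mathrm{x}+1"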


class UnimernetModel(object):
    def __init__(self, weight_dir, cfg_path, _device_="cpu"):
        args = argparse.Namespace(cfg_path=cfg_path, options=None)
        cfg = Config(args)
        # Point the config at the local weights and tokenizer before building the model.
        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
        cfg.config.model.model_config.model_name = weight_dir
        cfg.config.model.tokenizer_config.path = weight_dir
        task = tasks.setup_task(cfg)
        self.model = task.build_model(cfg)
        self.device = _device_
        self.model.to(_device_)
        self.model.eval()
        vis_processor = load_processor(
            "formula_image_eval",
            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
        )
        self.mfr_transform = transforms.Compose([vis_processor])

    def predict(self, mfd_res, image):
        formula_list = []
        mf_image_list = []
        pil_img = Image.fromarray(image)
        # One entry per detected formula box; crop each region for recognition.
        for xyxy, conf, cla in zip(
            mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()
        ):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            new_item = {
                "category_id": 13 + int(cla.item()),
                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                "score": round(float(conf.item()), 2),
                "latex": "",
            }
            formula_list.append(new_item)
            bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
            mf_image_list.append(bbox_img)

        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
        dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
        mfr_res = []
        for mf_img in dataloader:
            mf_img = mf_img.to(self.device)
            with torch.no_grad():
                output = self.model.generate({"image": mf_img})
            mfr_res.extend(output["pred_str"])
        # Back-fill the recognized LaTeX into the detection entries (same order).
        for res, latex in zip(formula_list, mfr_res):
            res["latex"] = latex_rm_whitespace(latex)
        return formula_list

    def batch_predict(
        self, images_mfd_res: list, images: list, batch_size: int = 64
    ) -> list:
        images_formula_list = []
        mf_image_list = []
        backfill_list = []
        for image_index in range(len(images_mfd_res)):
            mfd_res = images_mfd_res[image_index]
            pil_img = Image.fromarray(images[image_index])
            formula_list = []
            for xyxy, conf, cla in zip(
                mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
            ):
                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
                new_item = {
                    "category_id": 13 + int(cla.item()),
                    "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                    "score": round(float(conf.item()), 2),
                    "latex": "",
                }
                formula_list.append(new_item)
                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
                mf_image_list.append(bbox_img)

            images_formula_list.append(formula_list)
            # Flat view across pages, in the same order as mf_image_list, so the
            # recognition results can be written back after batched inference.
            backfill_list += formula_list

        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
        mfr_res = []
        for mf_img in dataloader:
            mf_img = mf_img.to(self.device)
            with torch.no_grad():
                output = self.model.generate({"image": mf_img})
            mfr_res.extend(output["pred_str"])
        for res, latex in zip(backfill_list, mfr_res):
            res["latex"] = latex_rm_whitespace(latex)
        return images_formula_list
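

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes a
# local UniMERNet weight directory, a config file, and a YOLO-based
# math-formula-detection model whose results expose boxes with xyxy/conf/cls;
# all paths below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np
    from ultralytics import YOLO  # assumed detector for formula regions

    weight_dir = "./models/unimernet"             # placeholder: holds pytorch_model.pth + tokenizer
    cfg_path = "./configs/unimernet_config.yaml"  # placeholder: UniMERNet config
    device = "cuda" if torch.cuda.is_available() else "cpu"

    mfr_model = UnimernetModel(weight_dir, cfg_path, _device_=device)
    mfd_model = YOLO("./models/mfd_yolo.pt")      # placeholder: formula-detection weights

    # One page image as a numpy array, plus its formula-detection result.
    page = np.array(Image.open("page_0.png").convert("RGB"))
    mfd_res = mfd_model.predict(page, conf=0.25)[0]

    formulas_per_page = mfr_model.batch_predict([mfd_res], [page], batch_size=32)
    for item in formulas_per_page[0]:
        print(item["score"], item["latex"])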