Unimernet.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import os
  2. import argparse
  3. import re
  4. from PIL import Image
  5. import torch
  6. from torch.utils.data import Dataset, DataLoader
  7. from torchvision import transforms
  8. from unimernet.common.config import Config
  9. import unimernet.tasks as tasks
  10. from unimernet.processors import load_processor
  11. class MathDataset(Dataset):
  12. def __init__(self, image_paths, transform=None):
  13. self.image_paths = image_paths
  14. self.transform = transform
  15. def __len__(self):
  16. return len(self.image_paths)
  17. def __getitem__(self, idx):
  18. # if not pil image, then convert to pil image
  19. if isinstance(self.image_paths[idx], str):
  20. raw_image = Image.open(self.image_paths[idx])
  21. else:
  22. raw_image = self.image_paths[idx]
  23. if self.transform:
  24. image = self.transform(raw_image)
  25. return image
  26. def latex_rm_whitespace(s: str):
  27. """Remove unnecessary whitespace from LaTeX code.
  28. """
  29. text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
  30. letter = '[a-zA-Z]'
  31. noletter = '[\W_^\d]'
  32. names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
  33. s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
  34. news = s
  35. while True:
  36. s = news
  37. news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
  38. news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
  39. news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
  40. if news == s:
  41. break
  42. return s
  43. class UnimernetModel(object):
  44. def __init__(self, weight_dir, cfg_path, _device_='cpu'):
  45. args = argparse.Namespace(cfg_path=cfg_path, options=None)
  46. cfg = Config(args)
  47. cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
  48. cfg.config.model.model_config.model_name = weight_dir
  49. cfg.config.model.tokenizer_config.path = weight_dir
  50. task = tasks.setup_task(cfg)
  51. self.model = task.build_model(cfg)
  52. self.device = _device_
  53. self.model.to(_device_)
  54. self.model.eval()
  55. vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
  56. self.mfr_transform = transforms.Compose([vis_processor, ])
  57. def predict(self, mfd_res, image):
  58. formula_list = []
  59. mf_image_list = []
  60. for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
  61. xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
  62. new_item = {
  63. 'category_id': 13 + int(cla.item()),
  64. 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  65. 'score': round(float(conf.item()), 2),
  66. 'latex': '',
  67. }
  68. formula_list.append(new_item)
  69. pil_img = Image.fromarray(image)
  70. bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
  71. mf_image_list.append(bbox_img)
  72. dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
  73. dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
  74. mfr_res = []
  75. for mf_img in dataloader:
  76. mf_img = mf_img.to(self.device)
  77. with torch.no_grad():
  78. output = self.model.generate({'image': mf_img})
  79. mfr_res.extend(output['pred_str'])
  80. for res, latex in zip(formula_list, mfr_res):
  81. res['latex'] = latex_rm_whitespace(latex)
  82. return formula_list