| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
- import os
- import argparse
- import re
- from PIL import Image
- import torch
- from torch.utils.data import Dataset, DataLoader
- from torchvision import transforms
- from unimernet.common.config import Config
- import unimernet.tasks as tasks
- from unimernet.processors import load_processor
class MathDataset(Dataset):
    """Dataset over formula images given as file paths or already-loaded images.

    Args:
        image_paths: Indexable collection whose items are either filesystem
            paths (``str`` or ``os.PathLike``) or image objects. Paths are
            opened with ``PIL.Image.open``; anything else is used as-is.
        transform: Optional callable applied to each image before it is
            returned. When omitted, the image is returned untransformed.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        item = self.image_paths[idx]
        # Only path-like entries need opening; other entries are assumed to
        # already be images and are passed through unchanged.
        if isinstance(item, (str, os.PathLike)):
            raw_image = Image.open(item)
        else:
            raw_image = item
        # Without a transform, hand back the raw image instead of falling
        # through to an unbound local variable.
        if self.transform is None:
            return raw_image
        return self.transform(raw_image)
def latex_rm_whitespace(s: str) -> str:
    r"""Remove unnecessary whitespace from LaTeX code.

    Two passes are applied:

    1. Every space inside a ``\operatorname``, ``\mathrm``, ``\text`` or
       ``\mathbf`` group written with a space before the opening brace
       (e.g. ``\mathrm { d }``) is removed, including spaces between words
       inside the braces.
    2. Whitespace between two adjacent characters is dropped repeatedly, until
       the string stops changing, whenever at least one of the two characters
       is not an ASCII letter. Whitespace separating two letters is kept, so
       ``\alpha x`` is not fused into ``\alphax``. A backslash that is
       immediately followed by a space (the control space ``\ ``) is never
       used as the left-hand character of a match.

    Args:
        s: LaTeX source string.

    Returns:
        The string with the redundant whitespace removed.
    """
    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
    letter = r'[a-zA-Z]'
    # Raw string: '\W' and '\d' are regex classes, not Python string escapes.
    noletter = r'[\W_^\d]'

    # Strip the spaces directly in the substitution callback; the callback's
    # return value is inserted literally, so backslashes are left untouched.
    s = re.sub(text_reg, lambda match: match.group(0).replace(' ', ''), s)

    # Compiled once, outside the fixed-point loop. The order matters and
    # mirrors the sequence in which the substitutions are chained.
    pair_rules = (
        re.compile(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter)),
        re.compile(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter)),
        re.compile(r'(%s)\s+?(%s)' % (letter, noletter)),
    )
    while True:
        news = s
        for rule in pair_rules:
            news = rule.sub(r'\1\2', news)
        # One sweep can expose new adjacent pairs, so repeat until stable.
        if news == s:
            return s
        s = news
class UnimernetModel(object):
    """Wrapper that loads a UniMERNet model and recognizes detected formulas.

    Args:
        weight_dir: Directory used for three things: the checkpoint
            ``pytorch_model.pth`` is read from it, and it is set as both the
            model name and the tokenizer path in the config.
        cfg_path: Path to the config file passed to ``Config``.
        _device_: Device the model is moved to and on which input batches are
            placed (default ``'cpu'``).
    """

    def __init__(self, weight_dir, cfg_path, _device_='cpu'):
        args = argparse.Namespace(cfg_path=cfg_path, options=None)
        cfg = Config(args)
        # Point every weight-related config entry at the given directory.
        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
        cfg.config.model.model_config.model_name = weight_dir
        cfg.config.model.tokenizer_config.path = weight_dir
        task = tasks.setup_task(cfg)
        self.model = task.build_model(cfg)
        self.device = _device_
        self.model.to(_device_)
        self.model.eval()
        vis_processor = load_processor(
            'formula_image_eval',
            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
        )
        self.mfr_transform = transforms.Compose([vis_processor, ])

    def predict(self, mfd_res, image, batch_size=64):
        """Recognize the LaTeX of every formula box in a detection result.

        Args:
            mfd_res: Detection result exposing ``boxes.xyxy``, ``boxes.conf``
                and ``boxes.cls`` tensors (presumably a YOLO-style result
                object; confirm against the caller).
            image: Page image in a form accepted by ``PIL.Image.fromarray``.
            batch_size: Number of formula crops sent to the model per batch.

        Returns:
            A list with one dict per detected box, with keys ``category_id``
            (``13 +`` the detected class index), ``poly`` (the box corners as
            ``[xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]``), ``score``
            (confidence rounded to two decimals) and ``latex`` (the
            recognized formula after whitespace cleanup).
        """
        boxes = mfd_res.boxes
        formula_list = []
        crop_boxes = []
        for xyxy, conf, cla in zip(boxes.xyxy.cpu(), boxes.conf.cpu(), boxes.cls.cpu()):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            formula_list.append({
                # NOTE(review): the offset 13 presumably maps detector classes
                # onto a wider category scheme; confirm with the consumer.
                'category_id': 13 + int(cla.item()),
                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                'score': round(float(conf.item()), 2),
                'latex': '',
            })
            crop_boxes.append((xmin, ymin, xmax, ymax))

        # Nothing detected: skip image conversion and the model entirely.
        if not formula_list:
            return formula_list

        # Convert the page once; every crop is taken from the same PIL image.
        pil_img = Image.fromarray(image)
        mf_image_list = [pil_img.crop(box) for box in crop_boxes]

        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
        mfr_res = []
        with torch.no_grad():
            for mf_img in dataloader:
                mf_img = mf_img.to(self.device)
                output = self.model.generate({'image': mf_img})
                mfr_res.extend(output['pred_str'])

        # Predictions come back in dataset order, i.e. the order of the boxes.
        for res, latex in zip(formula_list, mfr_res):
            res['latex'] = latex_rm_whitespace(latex)
        return formula_list
|