import argparse
import os
import re

import torch
import unimernet.tasks as tasks
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from unimernet.common.config import Config
from unimernet.processors import load_processor

# Patterns used by latex_rm_whitespace, compiled once at import time.
_LETTER = "[a-zA-Z]"
_NOLETTER = r"[\W_^\d]"
_TEXT_CMD_RE = re.compile(r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})")
_NOLETTER_NOLETTER_RE = re.compile(r"(?!\\ )(%s)\s+?(%s)" % (_NOLETTER, _NOLETTER))
_NOLETTER_LETTER_RE = re.compile(r"(?!\\ )(%s)\s+?(%s)" % (_NOLETTER, _LETTER))
_LETTER_NOLETTER_RE = re.compile(r"(%s)\s+?(%s)" % (_LETTER, _NOLETTER))


class MathDataset(Dataset):
    """Dataset over formula images given either as file paths or as image objects.

    Args:
        image_paths: Indexable collection whose items are either ``str`` paths
            (opened with PIL on access) or already-loaded image objects, which
            are passed through untouched.
        transform: Optional callable applied to each image before it is returned.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        item = self.image_paths[idx]
        # Only string entries are treated as paths; anything else is assumed
        # to already be an image the transform can consume.
        raw_image = Image.open(item) if isinstance(item, str) else item
        if self.transform:
            return self.transform(raw_image)
        # Without a transform, hand back the raw image instead of failing on
        # an unassigned local.
        return raw_image


def latex_rm_whitespace(s: str):
    r"""Remove unnecessary whitespace from LaTeX code.

    Spaces inside ``\operatorname``/``\mathrm``/``\text``/``\mathbf`` groups are
    stripped first. Then whitespace between two tokens is removed repeatedly
    until the string stops changing, except between two letters, where the
    whitespace is kept.

    Args:
        s: LaTeX source string.

    Returns:
        The LaTeX string with redundant whitespace removed.
    """
    # The outer capture group spans the whole match, so stripping spaces from
    # group(1) rewrites each text command in place.
    s = _TEXT_CMD_RE.sub(lambda match: match.group(1).replace(" ", ""), s)
    while True:
        news = _NOLETTER_NOLETTER_RE.sub(r"\1\2", s)
        news = _NOLETTER_LETTER_RE.sub(r"\1\2", news)
        news = _LETTER_NOLETTER_RE.sub(r"\1\2", news)
        if news == s:
            # Fixed point reached: another pass would change nothing.
            return s
        s = news


class UnimernetModel(object):
    """Wrapper that loads a UniMERNet model and recognises formulas as LaTeX.

    Args:
        weight_dir: Directory holding ``pytorch_model.pth``; also used as the
            model name and the tokenizer path in the loaded config.
        cfg_path: Path to the config file handed to ``Config``.
        _device_: Device the model and input batches are moved to
            (default ``"cpu"``).
    """

    def __init__(self, weight_dir, cfg_path, _device_="cpu"):
        args = argparse.Namespace(cfg_path=cfg_path, options=None)
        cfg = Config(args)
        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
        cfg.config.model.model_config.model_name = weight_dir
        cfg.config.model.tokenizer_config.path = weight_dir
        task = tasks.setup_task(cfg)
        self.model = task.build_model(cfg)
        self.device = _device_
        self.model.to(_device_)
        self.model.eval()
        vis_processor = load_processor(
            "formula_image_eval",
            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
        )
        self.mfr_transform = transforms.Compose(
            [
                vis_processor,
            ]
        )

    @staticmethod
    def _make_formula_item(xyxy, conf, cla):
        """Convert one detection into a result dict plus its integer bbox.

        Returns:
            ``(item, (xmin, ymin, xmax, ymax))`` where ``item`` has an empty
            ``"latex"`` field to be filled in after recognition.
        """
        xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
        item = {
            # NOTE(review): the offset 13 presumably maps detector classes
            # onto the caller's category ids - confirm against the caller.
            "category_id": 13 + int(cla.item()),
            # Corners in order: top-left, top-right, bottom-right, bottom-left.
            "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
            "score": round(float(conf.item()), 2),
            "latex": "",
        }
        return item, (xmin, ymin, xmax, ymax)

    def _recognize(self, images, batch_size):
        """Run the model over ``images`` in batches and return raw predictions.

        Predictions are returned in the same order as ``images``.
        """
        dataset = MathDataset(images, transform=self.mfr_transform)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
        predictions = []
        for batch in dataloader:
            batch = batch.to(self.device)
            with torch.no_grad():
                output = self.model.generate({"image": batch})
            predictions.extend(output["pred_str"])
        return predictions

    def predict(self, mfd_res, image):
        """Recognise every detected formula in a single image.

        Args:
            mfd_res: Detection result exposing ``boxes.xyxy``, ``boxes.conf``
                and ``boxes.cls`` as tensors.
            image: Page image accepted by ``PIL.Image.fromarray``
                (presumably a numpy array - confirm against the caller).

        Returns:
            A list of dicts with ``category_id``, ``poly``, ``score`` and
            ``latex`` keys, one per detected box, in detection order.
        """
        formula_list = []
        crops = []
        pil_img = None
        boxes = mfd_res.boxes
        for xyxy, conf, cla in zip(boxes.xyxy.cpu(), boxes.conf.cpu(), boxes.cls.cpu()):
            item, bbox = self._make_formula_item(xyxy, conf, cla)
            formula_list.append(item)
            if pil_img is None:
                # Convert the page once, and only when there is something to crop.
                pil_img = Image.fromarray(image)
            crops.append(pil_img.crop(bbox))

        for res, latex in zip(formula_list, self._recognize(crops, batch_size=32)):
            res["latex"] = latex_rm_whitespace(latex)
        return formula_list

    def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
        """Recognise the formulas of several images in shared batches.

        Crops from all images are pooled and processed in ascending order of
        bbox area, then the results are written back in the original order.

        Args:
            images_mfd_res: One detection result per image, each exposing
                ``boxes.xyxy``, ``boxes.conf`` and ``boxes.cls``.
            images: Page images aligned with ``images_mfd_res``; each is sliced
                as ``image[ymin:ymax, xmin:xmax]``.
            batch_size: Number of crops per model call.

        Returns:
            One list of result dicts per input image, in input order.
        """
        images_formula_list = []
        backfill_list = []  # every result dict, flattened in crop order
        crops = []
        areas = []

        for image_index, mfd_res in enumerate(images_mfd_res):
            np_array_image = images[image_index]
            boxes = mfd_res.boxes
            formula_list = []
            for xyxy, conf, cla in zip(boxes.xyxy, boxes.conf, boxes.cls):
                item, (xmin, ymin, xmax, ymax) = self._make_formula_item(xyxy, conf, cla)
                formula_list.append(item)
                crops.append(np_array_image[ymin:ymax, xmin:xmax])
                areas.append((xmax - xmin) * (ymax - ymin))
            images_formula_list.append(formula_list)
            backfill_list += formula_list

        # Stable sort by area: equal-area crops keep their original order.
        sorted_indices = sorted(range(len(crops)), key=areas.__getitem__)
        sorted_images = [crops[i] for i in sorted_indices]

        mfr_res = self._recognize(sorted_images, batch_size=batch_size)

        # Restore original order. The buffer is sized by the number of crops
        # so a short model output leaves "" rather than indexing out of range.
        unsorted_results = [""] * len(crops)
        for original_idx, latex in zip(sorted_indices, mfr_res):
            unsorted_results[original_idx] = latex_rm_whitespace(latex)

        for res, latex in zip(backfill_list, unsorted_results):
            res["latex"] = latex
        return images_formula_list