|
@@ -1,13 +1,5 @@
|
|
|
-import argparse
|
|
|
|
|
-import os
|
|
|
|
|
-import re
|
|
|
|
|
-
|
|
|
|
|
import torch
|
|
import torch
|
|
|
-import unimernet.tasks as tasks
|
|
|
|
|
from torch.utils.data import DataLoader, Dataset
|
|
from torch.utils.data import DataLoader, Dataset
|
|
|
-from torchvision import transforms
|
|
|
|
|
-from unimernet.common.config import Config
|
|
|
|
|
-from unimernet.processors import load_processor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MathDataset(Dataset):
|
|
class MathDataset(Dataset):
|
|
@@ -18,46 +10,26 @@ class MathDataset(Dataset):
|
|
|
def __len__(self):
|
|
def __len__(self):
|
|
|
return len(self.image_paths)
|
|
return len(self.image_paths)
|
|
|
|
|
|
|
|
-
|
|
|
|
|
-def latex_rm_whitespace(s: str):
|
|
|
|
|
- """Remove unnecessary whitespace from LaTeX code."""
|
|
|
|
|
- text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
|
|
|
|
|
- letter = "[a-zA-Z]"
|
|
|
|
|
- noletter = "[\W_^\d]"
|
|
|
|
|
- names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
|
|
|
|
|
- s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
|
|
|
|
|
- news = s
|
|
|
|
|
- while True:
|
|
|
|
|
- s = news
|
|
|
|
|
- news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
|
|
|
|
|
- news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
|
|
|
|
|
- news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
|
|
|
|
|
- if news == s:
|
|
|
|
|
- break
|
|
|
|
|
- return s
|
|
|
|
|
|
|
+ def __getitem__(self, idx):
|
|
|
|
|
+ raw_image = self.image_paths[idx]
|
|
|
|
|
+ if self.transform:
|
|
|
|
|
+ image = self.transform(raw_image)
|
|
|
|
|
+ return image
|
|
|
|
|
|
|
|
|
|
|
|
|
class UnimernetModel(object):
|
|
class UnimernetModel(object):
|
|
|
def __init__(self, weight_dir, cfg_path, _device_="cpu"):
|
|
def __init__(self, weight_dir, cfg_path, _device_="cpu"):
|
|
|
- args = argparse.Namespace(cfg_path=cfg_path, options=None)
|
|
|
|
|
- cfg = Config(args)
|
|
|
|
|
- cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
|
|
|
|
|
- cfg.config.model.model_config.model_name = weight_dir
|
|
|
|
|
- cfg.config.model.tokenizer_config.path = weight_dir
|
|
|
|
|
- task = tasks.setup_task(cfg)
|
|
|
|
|
- self.model = task.build_model(cfg)
|
|
|
|
|
|
|
+ from .unimernet_hf import UnimernetModel
|
|
|
|
|
+ if _device_.startswith("mps"):
|
|
|
|
|
+ self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
|
|
|
|
|
+ else:
|
|
|
|
|
+ self.model = UnimernetModel.from_pretrained(weight_dir)
|
|
|
self.device = _device_
|
|
self.device = _device_
|
|
|
self.model.to(_device_)
|
|
self.model.to(_device_)
|
|
|
|
|
+ if not _device_.startswith("cpu"):
|
|
|
|
|
+ self.model = self.model.to(dtype=torch.float16)
|
|
|
self.model.eval()
|
|
self.model.eval()
|
|
|
- vis_processor = load_processor(
|
|
|
|
|
- "formula_image_eval",
|
|
|
|
|
- cfg.config.datasets.formula_rec_eval.vis_processor.eval,
|
|
|
|
|
- )
|
|
|
|
|
- self.mfr_transform = transforms.Compose(
|
|
|
|
|
- [
|
|
|
|
|
- vis_processor,
|
|
|
|
|
- ]
|
|
|
|
|
- )
|
|
|
|
|
|
|
+
|
|
|
|
|
|
|
|
def predict(self, mfd_res, image):
|
|
def predict(self, mfd_res, image):
|
|
|
formula_list = []
|
|
formula_list = []
|
|
@@ -76,16 +48,17 @@ class UnimernetModel(object):
|
|
|
bbox_img = image[ymin:ymax, xmin:xmax]
|
|
bbox_img = image[ymin:ymax, xmin:xmax]
|
|
|
mf_image_list.append(bbox_img)
|
|
mf_image_list.append(bbox_img)
|
|
|
|
|
|
|
|
- dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
|
|
|
|
|
|
|
+ dataset = MathDataset(mf_image_list, transform=self.model.transform)
|
|
|
dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
|
|
dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
|
|
|
mfr_res = []
|
|
mfr_res = []
|
|
|
for mf_img in dataloader:
|
|
for mf_img in dataloader:
|
|
|
|
|
+ mf_img = mf_img.to(dtype=self.model.dtype)
|
|
|
mf_img = mf_img.to(self.device)
|
|
mf_img = mf_img.to(self.device)
|
|
|
with torch.no_grad():
|
|
with torch.no_grad():
|
|
|
output = self.model.generate({"image": mf_img})
|
|
output = self.model.generate({"image": mf_img})
|
|
|
- mfr_res.extend(output["pred_str"])
|
|
|
|
|
|
|
+ mfr_res.extend(output["fixed_str"])
|
|
|
for res, latex in zip(formula_list, mfr_res):
|
|
for res, latex in zip(formula_list, mfr_res):
|
|
|
- res["latex"] = latex_rm_whitespace(latex)
|
|
|
|
|
|
|
+ res["latex"] = latex
|
|
|
return formula_list
|
|
return formula_list
|
|
|
|
|
|
|
|
def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
|
|
def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
|
|
@@ -130,22 +103,23 @@ class UnimernetModel(object):
|
|
|
index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
|
|
index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
|
|
|
|
|
|
|
|
# Create dataset with sorted images
|
|
# Create dataset with sorted images
|
|
|
- dataset = MathDataset(sorted_images, transform=self.mfr_transform)
|
|
|
|
|
|
|
+ dataset = MathDataset(sorted_images, transform=self.model.transform)
|
|
|
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
|
|
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
|
|
|
|
|
|
|
|
# Process batches and store results
|
|
# Process batches and store results
|
|
|
mfr_res = []
|
|
mfr_res = []
|
|
|
for mf_img in dataloader:
|
|
for mf_img in dataloader:
|
|
|
|
|
+ mf_img = mf_img.to(dtype=self.model.dtype)
|
|
|
mf_img = mf_img.to(self.device)
|
|
mf_img = mf_img.to(self.device)
|
|
|
with torch.no_grad():
|
|
with torch.no_grad():
|
|
|
output = self.model.generate({"image": mf_img})
|
|
output = self.model.generate({"image": mf_img})
|
|
|
- mfr_res.extend(output["pred_str"])
|
|
|
|
|
|
|
+ mfr_res.extend(output["fixed_str"])
|
|
|
|
|
|
|
|
# Restore original order
|
|
# Restore original order
|
|
|
unsorted_results = [""] * len(mfr_res)
|
|
unsorted_results = [""] * len(mfr_res)
|
|
|
for new_idx, latex in enumerate(mfr_res):
|
|
for new_idx, latex in enumerate(mfr_res):
|
|
|
original_idx = index_mapping[new_idx]
|
|
original_idx = index_mapping[new_idx]
|
|
|
- unsorted_results[original_idx] = latex_rm_whitespace(latex)
|
|
|
|
|
|
|
+ unsorted_results[original_idx] = latex
|
|
|
|
|
|
|
|
# Fill results back
|
|
# Fill results back
|
|
|
for res, latex in zip(backfill_list, unsorted_results):
|
|
for res, latex in zip(backfill_list, unsorted_results):
|