@@ -1,13 +1,13 @@
-import os
 import argparse
+import os
 import re
-from PIL import Image
 import torch
-from torch.utils.data import Dataset, DataLoader
+import unimernet.tasks as tasks
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset
 from torchvision import transforms
 from unimernet.common.config import Config
-import unimernet.tasks as tasks
 from unimernet.processors import load_processor
@@ -31,27 +31,25 @@ class MathDataset(Dataset):
 def latex_rm_whitespace(s: str):
-    """Remove unnecessary whitespace from LaTeX code.
-    """
-    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
-    letter = '[a-zA-Z]'
-    noletter = '[\W_^\d]'
-    names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
+    """Remove unnecessary whitespace from LaTeX code."""
+    text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
+    letter = "[a-zA-Z]"
+    noletter = "[\W_^\d]"
+    names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
     s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
     news = s
     while True:
         s = news
-        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
-        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
-        news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
+        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
+        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
+        news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
         if news == s:
             break
     return s


 class UnimernetModel(object):
-    def __init__(self, weight_dir, cfg_path, _device_='cpu'):
-
+    def __init__(self, weight_dir, cfg_path, _device_="cpu"):
         args = argparse.Namespace(cfg_path=cfg_path, options=None)
         cfg = Config(args)
         cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
@@ -62,20 +60,28 @@ class UnimernetModel(object):
         self.device = _device_
         self.model.to(_device_)
         self.model.eval()
-        vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
-        self.mfr_transform = transforms.Compose([vis_processor, ])
+        vis_processor = load_processor(
+            "formula_image_eval",
+            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
+        )
+        self.mfr_transform = transforms.Compose(
+            [
+                vis_processor,
+            ]
+        )

     def predict(self, mfd_res, image):
-
         formula_list = []
         mf_image_list = []
-        for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
+        for xyxy, conf, cla in zip(
+            mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()
+        ):
             xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
             new_item = {
-                'category_id': 13 + int(cla.item()),
-                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-                'score': round(float(conf.item()), 2),
-                'latex': '',
+                "category_id": 13 + int(cla.item()),
+                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                "score": round(float(conf.item()), 2),
+                "latex": "",
             }
             formula_list.append(new_item)
             pil_img = Image.fromarray(image)
@@ -88,11 +94,48 @@ class UnimernetModel(object):
         for mf_img in dataloader:
             mf_img = mf_img.to(self.device)
             with torch.no_grad():
-                output = self.model.generate({'image': mf_img})
-            mfr_res.extend(output['pred_str'])
+                output = self.model.generate({"image": mf_img})
+            mfr_res.extend(output["pred_str"])
         for res, latex in zip(formula_list, mfr_res):
-            res['latex'] = latex_rm_whitespace(latex)
+            res["latex"] = latex_rm_whitespace(latex)
         return formula_list


-
+    def batch_predict(
+        self, images_mfd_res: list, images: list, batch_size: int = 64
+    ) -> list:
+        images_formula_list = []
+        mf_image_list = []
+        backfill_list = []
+        for image_index in range(len(images_mfd_res)):
+            mfd_res = images_mfd_res[image_index]
+            pil_img = Image.fromarray(images[image_index])
+            formula_list = []
+
+            for xyxy, conf, cla in zip(
+                mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
+            ):
+                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+                new_item = {
+                    "category_id": 13 + int(cla.item()),
+                    "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                    "score": round(float(conf.item()), 2),
+                    "latex": "",
+                }
+                formula_list.append(new_item)
+                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+                mf_image_list.append(bbox_img)
+
+            images_formula_list.append(formula_list)
+            backfill_list += formula_list
+
+        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
+        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
+        mfr_res = []
+        for mf_img in dataloader:
+            mf_img = mf_img.to(self.device)
+            with torch.no_grad():
+                output = self.model.generate({"image": mf_img})
+            mfr_res.extend(output["pred_str"])
+        for res, latex in zip(backfill_list, mfr_res):
+            res["latex"] = latex_rm_whitespace(latex)
+        return images_formula_list
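A minimal usage sketch of the new batch_predict path, for context. The weight directory, config path, and the SimpleNamespace stand-in for a formula-detection result are illustrative assumptions, not part of this change; only the UnimernetModel interface and the attributes the method actually reads (boxes.xyxy, boxes.conf, boxes.cls) come from the diff, and running it still requires real UniMERNet weights and config.

# Hypothetical usage of UnimernetModel.batch_predict; the paths and the fake
# detection result below are placeholders, not part of this diff.
from types import SimpleNamespace

import numpy as np
import torch

model = UnimernetModel(
    weight_dir="/models/unimernet",          # assumed location of pytorch_model.pth
    cfg_path="/models/unimernet/demo.yaml",  # assumed UniMERNet config file
    _device_="cuda" if torch.cuda.is_available() else "cpu",
)

# One blank page with a single detected formula box; batch_predict only reads
# .boxes.xyxy, .boxes.conf and .boxes.cls from each per-page detection result.
page = np.full((1000, 800, 3), 255, dtype=np.uint8)
mfd_res = SimpleNamespace(
    boxes=SimpleNamespace(
        xyxy=torch.tensor([[100.0, 200.0, 400.0, 260.0]]),
        conf=torch.tensor([0.9]),
        cls=torch.tensor([0.0]),
    )
)

# Crops every detected formula across all pages, recognizes them in batches of
# `batch_size`, and returns one list of formula dicts per input page.
per_page_formulas = model.batch_predict([mfd_res], [page], batch_size=64)
print(per_page_formulas[0][0]["latex"])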