|
|
@@ -18,16 +18,6 @@ class MathDataset(Dataset):
|
|
|
def __len__(self):
|
|
|
return len(self.image_paths)
|
|
|
|
|
|
- def __getitem__(self, idx):
|
|
|
- # if not pil image, then convert to pil image
|
|
|
- if isinstance(self.image_paths[idx], str):
|
|
|
- raw_image = Image.open(self.image_paths[idx])
|
|
|
- else:
|
|
|
- raw_image = self.image_paths[idx]
|
|
|
- if self.transform:
|
|
|
- image = self.transform(raw_image)
|
|
|
- return image
|
|
|
-
|
|
|
|
|
|
def latex_rm_whitespace(s: str):
|
|
|
"""Remove unnecessary whitespace from LaTeX code."""
|
|
|
@@ -83,8 +73,7 @@ class UnimernetModel(object):
|
|
|
"latex": "",
|
|
|
}
|
|
|
formula_list.append(new_item)
|
|
|
- pil_img = Image.fromarray(image)
|
|
|
- bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
|
|
|
+ bbox_img = image[ymin:ymax, xmin:xmax]
|
|
|
mf_image_list.append(bbox_img)
|
|
|
|
|
|
dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
|
|
|
@@ -99,7 +88,6 @@ class UnimernetModel(object):
|
|
|
res["latex"] = latex_rm_whitespace(latex)
|
|
|
return formula_list
|
|
|
|
|
|
-
|
|
|
def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
|
|
|
images_formula_list = []
|
|
|
mf_image_list = []
|