|
|
@@ -100,20 +100,61 @@ class UnimernetModel(object):
|
|
|
res["latex"] = latex_rm_whitespace(latex)
|
|
|
return formula_list
|
|
|
|
|
|
- def batch_predict(
|
|
|
- self, images_mfd_res: list, images: list, batch_size: int = 64
|
|
|
- ) -> list:
|
|
|
+ # def batch_predict(
|
|
|
+ # self, images_mfd_res: list, images: list, batch_size: int = 64
|
|
|
+ # ) -> list:
|
|
|
+ # images_formula_list = []
|
|
|
+ # mf_image_list = []
|
|
|
+ # backfill_list = []
|
|
|
+ # for image_index in range(len(images_mfd_res)):
|
|
|
+ # mfd_res = images_mfd_res[image_index]
|
|
|
+ # pil_img = Image.fromarray(images[image_index])
|
|
|
+ # formula_list = []
|
|
|
+ #
|
|
|
+ # for xyxy, conf, cla in zip(
|
|
|
+ # mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
|
|
|
+ # ):
|
|
|
+ # xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
|
|
+ # new_item = {
|
|
|
+ # "category_id": 13 + int(cla.item()),
|
|
|
+ # "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
|
|
+ # "score": round(float(conf.item()), 2),
|
|
|
+ # "latex": "",
|
|
|
+ # }
|
|
|
+ # formula_list.append(new_item)
|
|
|
+ # bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
|
|
|
+ # mf_image_list.append(bbox_img)
|
|
|
+ #
|
|
|
+ # images_formula_list.append(formula_list)
|
|
|
+ # backfill_list += formula_list
|
|
|
+ #
|
|
|
+ # dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
|
|
|
+ # dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
|
|
|
+ # mfr_res = []
|
|
|
+ # for mf_img in dataloader:
|
|
|
+ # mf_img = mf_img.to(self.device)
|
|
|
+ # with torch.no_grad():
|
|
|
+ # output = self.model.generate({"image": mf_img})
|
|
|
+ # mfr_res.extend(output["pred_str"])
|
|
|
+ # for res, latex in zip(backfill_list, mfr_res):
|
|
|
+ # res["latex"] = latex_rm_whitespace(latex)
|
|
|
+ # return images_formula_list
|
|
|
+
|
|
|
+ def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
|
|
|
images_formula_list = []
|
|
|
mf_image_list = []
|
|
|
backfill_list = []
|
|
|
+ image_info = [] # Store (area, original_index, image) tuples
|
|
|
+
|
|
|
+ # Collect images with their original indices
|
|
|
for image_index in range(len(images_mfd_res)):
|
|
|
mfd_res = images_mfd_res[image_index]
|
|
|
pil_img = Image.fromarray(images[image_index])
|
|
|
formula_list = []
|
|
|
|
|
|
- for xyxy, conf, cla in zip(
|
|
|
- mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
|
|
|
- ):
|
|
|
+ for idx, (xyxy, conf, cla) in enumerate(zip(
|
|
|
+ mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
|
|
|
+ )):
|
|
|
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
|
|
new_item = {
|
|
|
"category_id": 13 + int(cla.item()),
|
|
|
@@ -123,19 +164,43 @@ class UnimernetModel(object):
|
|
|
}
|
|
|
formula_list.append(new_item)
|
|
|
bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
|
|
|
+ area = (xmax - xmin) * (ymax - ymin)
|
|
|
+
|
|
|
+ curr_idx = len(mf_image_list)
|
|
|
+ image_info.append((area, curr_idx, bbox_img))
|
|
|
mf_image_list.append(bbox_img)
|
|
|
|
|
|
images_formula_list.append(formula_list)
|
|
|
backfill_list += formula_list
|
|
|
|
|
|
- dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
|
|
|
+ # Stable sort by area
|
|
|
+ image_info.sort(key=lambda x: x[0]) # sort by area
|
|
|
+ sorted_indices = [x[1] for x in image_info]
|
|
|
+ sorted_images = [x[2] for x in image_info]
|
|
|
+
|
|
|
+ # Create mapping for results
|
|
|
+ index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
|
|
|
+
|
|
|
+ # Create dataset with sorted images
|
|
|
+ dataset = MathDataset(sorted_images, transform=self.mfr_transform)
|
|
|
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
|
|
|
+
|
|
|
+ # Process batches and store results
|
|
|
mfr_res = []
|
|
|
for mf_img in dataloader:
|
|
|
mf_img = mf_img.to(self.device)
|
|
|
with torch.no_grad():
|
|
|
output = self.model.generate({"image": mf_img})
|
|
|
mfr_res.extend(output["pred_str"])
|
|
|
- for res, latex in zip(backfill_list, mfr_res):
|
|
|
- res["latex"] = latex_rm_whitespace(latex)
|
|
|
+
|
|
|
+ # Restore original order
|
|
|
+ unsorted_results = [""] * len(mfr_res)
|
|
|
+ for new_idx, latex in enumerate(mfr_res):
|
|
|
+ original_idx = index_mapping[new_idx]
|
|
|
+ unsorted_results[original_idx] = latex_rm_whitespace(latex)
|
|
|
+
|
|
|
+ # Fill results back
|
|
|
+ for res, latex in zip(backfill_list, unsorted_results):
|
|
|
+ res["latex"] = latex
|
|
|
+
|
|
|
return images_formula_list
|