| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- import torch
- from torch.utils.data import DataLoader, Dataset
- from tqdm import tqdm
- class MathDataset(Dataset):
- def __init__(self, image_paths, transform=None):
- self.image_paths = image_paths
- self.transform = transform
- def __len__(self):
- return len(self.image_paths)
- def __getitem__(self, idx):
- raw_image = self.image_paths[idx]
- if self.transform:
- image = self.transform(raw_image)
- return image
- class UnimernetModel(object):
- def __init__(self, weight_dir, _device_="cpu"):
- from .unimernet_hf import UnimernetModel
- if _device_.startswith("mps") or _device_.startswith("npu"):
- self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
- else:
- self.model = UnimernetModel.from_pretrained(weight_dir)
- self.device = _device_
- self.model.to(_device_)
- if not _device_.startswith("cpu"):
- self.model = self.model.to(dtype=torch.float16)
- self.model.eval()
- def predict(self, mfd_res, image):
- formula_list = []
- mf_image_list = []
- for xyxy, conf, cla in zip(
- mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()
- ):
- xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
- new_item = {
- "category_id": 13 + int(cla.item()),
- "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
- "score": round(float(conf.item()), 2),
- "latex": "",
- }
- formula_list.append(new_item)
- bbox_img = image[ymin:ymax, xmin:xmax]
- mf_image_list.append(bbox_img)
- dataset = MathDataset(mf_image_list, transform=self.model.transform)
- dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
- mfr_res = []
- for mf_img in dataloader:
- mf_img = mf_img.to(dtype=self.model.dtype)
- mf_img = mf_img.to(self.device)
- with torch.no_grad():
- output = self.model.generate({"image": mf_img})
- mfr_res.extend(output["fixed_str"])
- for res, latex in zip(formula_list, mfr_res):
- res["latex"] = latex
- return formula_list
- def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
- images_formula_list = []
- mf_image_list = []
- backfill_list = []
- image_info = [] # Store (area, original_index, image) tuples
- # Collect images with their original indices
- for image_index in range(len(images_mfd_res)):
- mfd_res = images_mfd_res[image_index]
- image = images[image_index]
- formula_list = []
- for idx, (xyxy, conf, cla) in enumerate(zip(
- mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
- )):
- xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
- new_item = {
- "category_id": 13 + int(cla.item()),
- "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
- "score": round(float(conf.item()), 2),
- "latex": "",
- }
- formula_list.append(new_item)
- bbox_img = image[ymin:ymax, xmin:xmax]
- area = (xmax - xmin) * (ymax - ymin)
- curr_idx = len(mf_image_list)
- image_info.append((area, curr_idx, bbox_img))
- mf_image_list.append(bbox_img)
- images_formula_list.append(formula_list)
- backfill_list += formula_list
- # Stable sort by area
- image_info.sort(key=lambda x: x[0]) # sort by area
- sorted_indices = [x[1] for x in image_info]
- sorted_images = [x[2] for x in image_info]
- # Create mapping for results
- index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
- # Create dataset with sorted images
- dataset = MathDataset(sorted_images, transform=self.model.transform)
- # 如果batch_size > len(sorted_images),则设置为不超过len(sorted_images)的2的幂
- batch_size = min(batch_size, max(1, 2 ** (len(sorted_images).bit_length() - 1))) if sorted_images else 1
- dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
- # Process batches and store results
- mfr_res = []
- # for mf_img in dataloader:
- with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar:
- for index, mf_img in enumerate(dataloader):
- mf_img = mf_img.to(dtype=self.model.dtype)
- mf_img = mf_img.to(self.device)
- with torch.no_grad():
- output = self.model.generate({"image": mf_img}, batch_size=batch_size)
- mfr_res.extend(output["fixed_str"])
- # 更新进度条,每次增加batch_size,但要注意最后一个batch可能不足batch_size
- current_batch_size = min(batch_size, len(sorted_images) - index * batch_size)
- pbar.update(current_batch_size)
- # Restore original order
- unsorted_results = [""] * len(mfr_res)
- for new_idx, latex in enumerate(mfr_res):
- original_idx = index_mapping[new_idx]
- unsorted_results[original_idx] = latex
- # Fill results back
- for res, latex in zip(backfill_list, unsorted_results):
- res["latex"] = latex
- return images_formula_list
|