|
|
@@ -68,13 +68,14 @@ class FormulaRecognizer(BaseOCRV20):
|
|
|
batch_imgs = self.pre_tfs["UniMERNetImgDecode"](imgs=img_list)
|
|
|
batch_imgs = self.pre_tfs["UniMERNetTestTransform"](imgs=batch_imgs)
|
|
|
batch_imgs = self.pre_tfs["LatexImageFormat"](imgs=batch_imgs)
|
|
|
- x = self.pre_tfs["ToBatch"](imgs=batch_imgs)
|
|
|
- x = torch.from_numpy(x[0]).to(self.device)
|
|
|
+ inp = self.pre_tfs["ToBatch"](imgs=batch_imgs)
|
|
|
+ inp = torch.from_numpy(inp[0])
|
|
|
+ inp = inp.to(self.device)
|
|
|
rec_formula = []
|
|
|
with torch.no_grad():
|
|
|
- with tqdm(total=len(x), desc="Formula Predict") as pbar:
|
|
|
- for index in range(0, len(x), batch_size):
|
|
|
- batch_data = x[index: index + batch_size]
|
|
|
+ with tqdm(total=len(inp), desc="Formula Predict") as pbar:
|
|
|
+ for index in range(0, len(inp), batch_size):
|
|
|
+ batch_data = inp[index: index + batch_size]
|
|
|
batch_preds = [self.net(batch_data)]
|
|
|
batch_preds = [p.reshape([-1]) for p in batch_preds[0]]
|
|
|
rec_formula += self.post_op(batch_preds)
|