|
|
@@ -77,6 +77,8 @@ class FormulaRecognizer(BaseOCRV20):
|
|
|
with tqdm(total=len(inp), desc="MFR Predict") as pbar:
|
|
|
for index in range(0, len(inp), batch_size):
|
|
|
batch_data = inp[index: index + batch_size]
|
|
|
+ # with torch.amp.autocast(device_type=self.device.type):
|
|
|
+ # batch_preds = [self.net(batch_data)]
|
|
|
batch_preds = [self.net(batch_data)]
|
|
|
batch_preds = [p.reshape([-1]) for p in batch_preds[0]]
|
|
|
batch_preds = [bp.cpu().numpy() for bp in batch_preds]
|