|
|
@@ -3,8 +3,8 @@ import torch
|
|
|
import yaml
|
|
|
from pathlib import Path
|
|
|
from tqdm import tqdm
|
|
|
-from mineru.model.ocr.paddleocr2pytorch.tools.infer import pytorchocr_utility
|
|
|
-from mineru.model.ocr.paddleocr2pytorch.pytorchocr.base_ocr_v20 import BaseOCRV20
|
|
|
+from mineru.model.utils.tools.infer import pytorchocr_utility
|
|
|
+from mineru.model.utils.pytorchocr.base_ocr_v20 import BaseOCRV20
|
|
|
from .processors import (
|
|
|
UniMERNetImgDecode,
|
|
|
UniMERNetTestTransform,
|
|
|
@@ -24,7 +24,14 @@ class FormulaRecognizer(BaseOCRV20):
|
|
|
weight_dir,
|
|
|
"PP-FormulaNet_plus-M.pth",
|
|
|
)
|
|
|
- self.yaml_path = Path(__file__).parent / "config" / "arch_config.yaml"
|
|
|
+ self.yaml_path = os.path.join(
|
|
|
+ Path(__file__).parent.parent.parent,
|
|
|
+ "utils",
|
|
|
+ "pytorchocr",
|
|
|
+ "utils",
|
|
|
+ "resources",
|
|
|
+ "pp_formulanet_arch_config.yaml"
|
|
|
+ )
|
|
|
self.infer_yaml_path = os.path.join(
|
|
|
weight_dir,
|
|
|
"PP-FormulaNet_plus-M_inference.yml",
|
|
|
@@ -38,6 +45,7 @@ class FormulaRecognizer(BaseOCRV20):
|
|
|
super(FormulaRecognizer, self).__init__(network_config)
|
|
|
|
|
|
self.load_state_dict(weights)
|
|
|
+ # device = "cpu"
|
|
|
self.device = torch.device(device) if isinstance(device, str) else device
|
|
|
self.net.to(self.device)
|
|
|
self.net.eval()
|