|
|
@@ -1,3 +1,5 @@
|
|
|
+import os
|
|
|
+from pathlib import Path
|
|
|
import cv2
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
@@ -17,7 +19,9 @@ class RapidTableModel(object):
|
|
|
if torch.cuda.is_available() and table_sub_model_name == "unitable":
|
|
|
input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True, device=get_device())
|
|
|
else:
|
|
|
- input_args = RapidTableInput(model_type=table_sub_model_name)
|
|
|
+ root_dir = Path(__file__).absolute().parent.parent.parent.parent.parent
|
|
|
+ slanet_plus_model_path = os.path.join(root_dir, 'resources', 'slanet_plus', 'slanet-plus.onnx')
|
|
|
+ input_args = RapidTableInput(model_type=table_sub_model_name, model_path=slanet_plus_model_path)
|
|
|
else:
|
|
|
raise ValueError(f"Invalid table_sub_model_name: {table_sub_model_name}. It must be one of {sub_model_list}")
|
|
|
|