|
|
@@ -1,6 +1,8 @@
|
|
|
+import torch
|
|
|
from loguru import logger
|
|
|
|
|
|
from magic_pdf.config.constants import MODEL_NAME
|
|
|
+from magic_pdf.libs.config_reader import get_device
|
|
|
from magic_pdf.model.model_list import AtomicModel
|
|
|
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
|
|
|
DocLayoutYOLOModel
|
|
|
@@ -19,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
|
|
|
TableMasterPaddleModel
|
|
|
|
|
|
|
|
|
-def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
|
|
|
+def table_model_init(table_model_type, model_path, max_time, _device_='cpu', lang=None):
|
|
|
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
|
|
|
table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
|
|
|
elif table_model_type == MODEL_NAME.TABLE_MASTER:
|
|
|
@@ -29,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
|
|
|
}
|
|
|
table_model = TableMasterPaddleModel(config)
|
|
|
elif table_model_type == MODEL_NAME.RAPID_TABLE:
|
|
|
- table_model = RapidTableModel()
|
|
|
+ table_model = RapidTableModel(lang)
|
|
|
else:
|
|
|
logger.error('table model type not allow')
|
|
|
exit(1)
|
|
|
@@ -38,6 +40,8 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
|
|
|
|
|
|
|
|
|
def mfd_model_init(weight, device='cpu'):
|
|
|
+ if str(device).startswith("npu"):
|
|
|
+ device = torch.device(device)
|
|
|
mfd_model = YOLOv8MFDModel(weight, device)
|
|
|
return mfd_model
|
|
|
|
|
|
@@ -53,6 +57,8 @@ def layout_model_init(weight, config_file, device):
|
|
|
|
|
|
|
|
|
def doclayout_yolo_model_init(weight, device='cpu'):
|
|
|
+ if str(device).startswith("npu"):
|
|
|
+ device = torch.device(device)
|
|
|
model = DocLayoutYOLOModel(weight, device)
|
|
|
return model
|
|
|
|
|
|
@@ -63,6 +69,12 @@ def ocr_model_init(show_log: bool = False,
|
|
|
use_dilation=True,
|
|
|
det_db_unclip_ratio=1.8,
|
|
|
):
|
|
|
+
|
|
|
+ use_npu = False
|
|
|
+ device = get_device()
|
|
|
+ if str(device).startswith("npu"):
|
|
|
+ use_npu = True
|
|
|
+
|
|
|
if lang is not None and lang != '':
|
|
|
model = ModifiedPaddleOCR(
|
|
|
show_log=show_log,
|
|
|
@@ -70,6 +82,7 @@ def ocr_model_init(show_log: bool = False,
|
|
|
lang=lang,
|
|
|
use_dilation=use_dilation,
|
|
|
det_db_unclip_ratio=det_db_unclip_ratio,
|
|
|
+ use_npu=use_npu,
|
|
|
)
|
|
|
else:
|
|
|
model = ModifiedPaddleOCR(
|
|
|
@@ -77,7 +90,7 @@ def ocr_model_init(show_log: bool = False,
|
|
|
det_db_box_thresh=det_db_box_thresh,
|
|
|
use_dilation=use_dilation,
|
|
|
det_db_unclip_ratio=det_db_unclip_ratio,
|
|
|
- # use_angle_cls=True,
|
|
|
+ use_npu=use_npu,
|
|
|
)
|
|
|
return model
|
|
|
|
|
|
@@ -146,7 +159,8 @@ def atom_model_init(model_name: str, **kwargs):
|
|
|
kwargs.get('table_model_name'),
|
|
|
kwargs.get('table_model_path'),
|
|
|
kwargs.get('table_max_time'),
|
|
|
- kwargs.get('device')
|
|
|
+ kwargs.get('device'),
|
|
|
+ kwargs.get('lang'),
|
|
|
)
|
|
|
else:
|
|
|
logger.error('model name not allow')
|