|
|
@@ -19,7 +19,11 @@ from ...utils.models_download_utils import auto_download_and_get_model_root_path
|
|
|
def img_orientation_cls_model_init():
    """Build the image-orientation classifier on top of an OCR engine.

    The OCR engine is requested from ``AtomModelSingleton.get_atom_model``
    with the ``ch_lite`` language pack, a detection box threshold of 0.5,
    an unclip ratio of 1.6 and detection-box merging switched off.

    Returns:
        A ``PaddleOrientationClsModel`` wrapping that OCR engine.
    """
    atom_model_manager = AtomModelSingleton()
    ocr_engine = atom_model_manager.get_atom_model(
        atom_model_name=AtomicModel.OCR,
        det_db_box_thresh=0.5,
        det_db_unclip_ratio=1.6,
        lang="ch_lite",
        # NOTE(review): merging is disabled here while ocr_model_init defaults
        # it to True; presumably the classifier wants unmerged boxes - confirm.
        enable_merge_det_boxes=False,
    )
    cls_model = PaddleOrientationClsModel(ocr_engine)
    return cls_model
|
|
|
@@ -32,7 +36,11 @@ def table_cls_model_init():
|
|
|
def wired_table_model_init(lang=None):
    """Build the wired-table recognition model on top of an OCR engine.

    Args:
        lang: Language passed through to the OCR engine request; ``None``
            by default.

    Returns:
        A ``UnetTableModel`` wrapping an OCR engine obtained from
        ``AtomModelSingleton.get_atom_model`` with a detection box threshold
        of 0.5, an unclip ratio of 1.6 and detection-box merging switched off.
    """
    atom_model_manager = AtomModelSingleton()
    ocr_engine = atom_model_manager.get_atom_model(
        atom_model_name=AtomicModel.OCR,
        det_db_box_thresh=0.5,
        det_db_unclip_ratio=1.6,
        lang=lang,
        # NOTE(review): presumably table cells must stay as separate boxes,
        # hence no merging - confirm against the table model's expectations.
        enable_merge_det_boxes=False,
    )
    table_model = UnetTableModel(ocr_engine)
    return table_model
|
|
|
@@ -41,7 +49,11 @@ def wired_table_model_init(lang=None):
|
|
|
def wireless_table_model_init(lang=None):
    """Build the wireless-table recognition model on top of an OCR engine.

    Args:
        lang: Language passed through to the OCR engine request; ``None``
            by default.

    Returns:
        A ``RapidTableModel`` wrapping an OCR engine obtained from
        ``AtomModelSingleton.get_atom_model`` with a detection box threshold
        of 0.5, an unclip ratio of 1.6 and detection-box merging switched off.
    """
    atom_model_manager = AtomModelSingleton()
    ocr_engine = atom_model_manager.get_atom_model(
        atom_model_name=AtomicModel.OCR,
        det_db_box_thresh=0.5,
        det_db_unclip_ratio=1.6,
        lang=lang,
        # NOTE(review): presumably table cells must stay as separate boxes,
        # hence no merging - confirm against the table model's expectations.
        enable_merge_det_boxes=False,
    )
    table_model = RapidTableModel(ocr_engine)
    return table_model
|
|
|
@@ -67,21 +79,23 @@ def doclayout_yolo_model_init(weight, device='cpu'):
|
|
|
|
|
|
def ocr_model_init(det_db_box_thresh=0.3,
                   lang=None,
                   det_db_unclip_ratio=1.8,
                   enable_merge_det_boxes=True
                   ):
    """Create a ``PytorchPaddleOCR`` engine with the given detection settings.

    Dilation is always enabled (``use_dilation=True``); it is no longer a
    parameter of this function.

    Args:
        det_db_box_thresh: Detection box threshold forwarded to the engine.
        lang: Language forwarded to the engine. When it is ``None`` or the
            empty string the argument is omitted, so the engine applies
            its own default.
        det_db_unclip_ratio: Unclip ratio forwarded to the engine.
        enable_merge_det_boxes: Forwarded to the engine unchanged.

    Returns:
        The constructed ``PytorchPaddleOCR`` instance.
    """
    ocr_kwargs = {
        'det_db_box_thresh': det_db_box_thresh,
        'use_dilation': True,
        'det_db_unclip_ratio': det_db_unclip_ratio,
        'enable_merge_det_boxes': enable_merge_det_boxes,
    }
    # Only pass ``lang`` when the caller supplied a real value; otherwise
    # leave it out, exactly as the language-less constructor call did.
    if lang is not None and lang != '':
        ocr_kwargs['lang'] = lang

    model = PytorchPaddleOCR(**ocr_kwargs)
    return model
|
|
|
|
|
|
@@ -99,8 +113,14 @@ class AtomModelSingleton:
|
|
|
|
|
|
lang = kwargs.get('lang', None)
|
|
|
|
|
|
- if atom_model_name in [AtomicModel.OCR, AtomicModel.WiredTable, AtomicModel.WirelessTable]:
|
|
|
+ if atom_model_name in [AtomicModel.WiredTable, AtomicModel.WirelessTable]:
|
|
|
key = (atom_model_name, lang)
|
|
|
+ elif atom_model_name in [AtomicModel.OCR]:
|
|
|
+ key = (atom_model_name,
|
|
|
+ kwargs.get('det_db_box_thresh', 0.3),
|
|
|
+ lang, kwargs.get('det_db_unclip_ratio', 1.8),
|
|
|
+ kwargs.get('enable_merge_det_boxes', True)
|
|
|
+ )
|
|
|
else:
|
|
|
key = atom_model_name
|
|
|
|
|
|
@@ -127,8 +147,10 @@ def atom_model_init(model_name: str, **kwargs):
|
|
|
)
|
|
|
elif model_name == AtomicModel.OCR:
|
|
|
atom_model = ocr_model_init(
|
|
|
- kwargs.get('det_db_box_thresh'),
|
|
|
+ kwargs.get('det_db_box_thresh', 0.3),
|
|
|
kwargs.get('lang'),
|
|
|
+ kwargs.get('det_db_unclip_ratio', 1.8),
|
|
|
+ kwargs.get('enable_merge_det_boxes', True)
|
|
|
)
|
|
|
elif model_name == AtomicModel.WirelessTable:
|
|
|
atom_model = wireless_table_model_init(
|