|
|
@@ -57,6 +57,11 @@ def doclayout_yolo_model_init(weight, device='cpu'):
|
|
|
return model
|
|
|
|
|
|
|
|
|
+import threading
|
|
|
+current_thread = threading.current_thread()
|
|
|
+current_thread_id = current_thread.ident
|
|
|
+
|
|
|
+
|
|
|
def ocr_model_init(show_log: bool = False,
|
|
|
det_db_box_thresh=0.3,
|
|
|
lang=None,
|
|
|
@@ -92,14 +97,24 @@ class AtomModelSingleton:
|
|
|
return cls._instance
|
|
|
|
|
|
def get_atom_model(self, atom_model_name: str, **kwargs):
|
|
|
+
|
|
|
lang = kwargs.get('lang', None)
|
|
|
layout_model_name = kwargs.get('layout_model_name', None)
|
|
|
- key = (atom_model_name, layout_model_name, lang)
|
|
|
+ table_model_name = kwargs.get('table_model_name', None)
|
|
|
+
|
|
|
+ if atom_model_name in [AtomicModel.OCR]:
|
|
|
+ key = (atom_model_name, lang, current_thread_id)
|
|
|
+ elif atom_model_name in [AtomicModel.Layout]:
|
|
|
+ key = (atom_model_name, layout_model_name)
|
|
|
+ elif atom_model_name in [AtomicModel.Table]:
|
|
|
+ key = (atom_model_name, table_model_name)
|
|
|
+ else:
|
|
|
+ key = atom_model_name
|
|
|
+
|
|
|
if key not in self._models:
|
|
|
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
|
|
|
return self._models[key]
|
|
|
|
|
|
-
|
|
|
def atom_model_init(model_name: str, **kwargs):
|
|
|
atom_model = None
|
|
|
if model_name == AtomicModel.Layout:
|
|
|
@@ -129,7 +144,7 @@ def atom_model_init(model_name: str, **kwargs):
|
|
|
atom_model = ocr_model_init(
|
|
|
kwargs.get('ocr_show_log'),
|
|
|
kwargs.get('det_db_box_thresh'),
|
|
|
- kwargs.get('lang')
|
|
|
+ kwargs.get('lang'),
|
|
|
)
|
|
|
elif model_name == AtomicModel.Table:
|
|
|
atom_model = table_model_init(
|