|
|
@@ -26,6 +26,7 @@ try:
|
|
|
from unimernet.processors import load_processor
|
|
|
from doclayout_yolo import YOLOv10
|
|
|
from rapid_table import RapidTable
|
|
|
+ from rapidocr_paddle import RapidOCR
|
|
|
|
|
|
except ImportError as e:
|
|
|
logger.exception(e)
|
|
|
@@ -42,6 +43,7 @@ from magic_pdf.model.ppTableModel import ppTableModel
|
|
|
|
|
|
|
|
|
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
|
|
|
+ ocr_engine = None
|
|
|
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
|
|
|
table_model = StructTableModel(model_path, max_time=max_time)
|
|
|
elif table_model_type == MODEL_NAME.TABLE_MASTER:
|
|
|
@@ -52,11 +54,15 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
|
|
|
table_model = ppTableModel(config)
|
|
|
elif table_model_type == MODEL_NAME.RAPID_TABLE:
|
|
|
table_model = RapidTable()
|
|
|
+ ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
|
|
|
else:
|
|
|
logger.error("table model type not allow")
|
|
|
exit(1)
|
|
|
|
|
|
- return table_model
|
|
|
+ if ocr_engine:
|
|
|
+ return [table_model, ocr_engine]
|
|
|
+ else:
|
|
|
+ return table_model
|
|
|
|
|
|
|
|
|
def mfd_model_init(weight):
|
|
|
@@ -283,23 +289,32 @@ class CustomPEKModel:
|
|
|
doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
|
|
|
)
|
|
|
# 初始化ocr
|
|
|
- # if self.apply_ocr:
|
|
|
- self.ocr_model = atom_model_manager.get_atom_model(
|
|
|
- atom_model_name=AtomicModel.OCR,
|
|
|
- ocr_show_log=show_log,
|
|
|
- det_db_box_thresh=0.3,
|
|
|
- lang=self.lang
|
|
|
- )
|
|
|
+ if self.apply_ocr:
|
|
|
+ self.ocr_model = atom_model_manager.get_atom_model(
|
|
|
+ atom_model_name=AtomicModel.OCR,
|
|
|
+ ocr_show_log=show_log,
|
|
|
+ det_db_box_thresh=0.3,
|
|
|
+ lang=self.lang
|
|
|
+ )
|
|
|
# init table model
|
|
|
if self.apply_table:
|
|
|
table_model_dir = self.configs["weights"][self.table_model_name]
|
|
|
- self.table_model = atom_model_manager.get_atom_model(
|
|
|
- atom_model_name=AtomicModel.Table,
|
|
|
- table_model_name=self.table_model_name,
|
|
|
- table_model_path=str(os.path.join(models_dir, table_model_dir)),
|
|
|
- table_max_time=self.table_max_time,
|
|
|
- device=self.device
|
|
|
- )
|
|
|
+ if self.table_model_name in [MODEL_NAME.STRUCT_EQTABLE, MODEL_NAME.TABLE_MASTER]:
|
|
|
+ self.table_model = atom_model_manager.get_atom_model(
|
|
|
+ atom_model_name=AtomicModel.Table,
|
|
|
+ table_model_name=self.table_model_name,
|
|
|
+ table_model_path=str(os.path.join(models_dir, table_model_dir)),
|
|
|
+ table_max_time=self.table_max_time,
|
|
|
+ device=self.device
|
|
|
+ )
|
|
|
+ elif self.table_model_name in [MODEL_NAME.RAPID_TABLE]:
|
|
|
+ self.table_model, self.ocr_engine =atom_model_manager.get_atom_model(
|
|
|
+ atom_model_name=AtomicModel.Table,
|
|
|
+ table_model_name=self.table_model_name,
|
|
|
+ table_model_path=str(os.path.join(models_dir, table_model_dir)),
|
|
|
+ table_max_time=self.table_max_time,
|
|
|
+ device=self.device
|
|
|
+ )
|
|
|
|
|
|
logger.info('DocAnalysis init done!')
|
|
|
|
|
|
@@ -381,9 +396,8 @@ class CustomPEKModel:
|
|
|
table_res_list.append(res)
|
|
|
|
|
|
if torch.cuda.is_available() and self.device != 'cpu':
|
|
|
- properties = torch.cuda.get_device_properties(self.device)
|
|
|
- total_memory = properties.total_memory / (1024 ** 3) # 将字节转换为 GB
|
|
|
- if total_memory <= 10:
|
|
|
+ total_memory = torch.cuda.get_device_properties(self.device).total_memory / (1024 ** 3) # 将字节转换为 GB
|
|
|
+ if total_memory <= 8:
|
|
|
gc_start = time.time()
|
|
|
clean_memory()
|
|
|
gc_time = round(time.time() - gc_start, 2)
|
|
|
@@ -456,13 +470,8 @@ class CustomPEKModel:
|
|
|
elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
|
|
|
html_code = self.table_model.img2html(new_image)
|
|
|
elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
|
|
|
- new_image_bgr = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
|
|
|
- ocr_result = self.ocr_model.ocr(new_image_bgr)[0]
|
|
|
- new_ocr_result = []
|
|
|
- for box_ocr_res in ocr_result:
|
|
|
- text, score = box_ocr_res[1]
|
|
|
- new_ocr_result.append([box_ocr_res[0], text, score])
|
|
|
- html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(new_image), new_ocr_result)
|
|
|
+ ocr_result, _ = self.ocr_engine(np.asarray(new_image))
|
|
|
+ html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(new_image), ocr_result)
|
|
|
|
|
|
run_time = time.time() - single_table_start_time
|
|
|
# logger.info(f"------------table recognition processing ends within {run_time}s-----")
|