|
@@ -1,8 +1,6 @@
|
|
|
from loguru import logger
|
|
from loguru import logger
|
|
|
import os
|
|
import os
|
|
|
import time
|
|
import time
|
|
|
-from pathlib import Path
|
|
|
|
|
-import shutil
|
|
|
|
|
from magic_pdf.libs.Constants import *
|
|
from magic_pdf.libs.Constants import *
|
|
|
from magic_pdf.libs.clean_memory import clean_memory
|
|
from magic_pdf.libs.clean_memory import clean_memory
|
|
|
from magic_pdf.model.model_list import AtomicModel
|
|
from magic_pdf.model.model_list import AtomicModel
|
|
@@ -27,6 +25,7 @@ try:
|
|
|
import unimernet.tasks as tasks
|
|
import unimernet.tasks as tasks
|
|
|
from unimernet.processors import load_processor
|
|
from unimernet.processors import load_processor
|
|
|
from doclayout_yolo import YOLOv10
|
|
from doclayout_yolo import YOLOv10
|
|
|
|
|
+ from rapid_table import RapidTable
|
|
|
|
|
|
|
|
except ImportError as e:
|
|
except ImportError as e:
|
|
|
logger.exception(e)
|
|
logger.exception(e)
|
|
@@ -51,9 +50,12 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
|
|
|
"device": _device_
|
|
"device": _device_
|
|
|
}
|
|
}
|
|
|
table_model = ppTableModel(config)
|
|
table_model = ppTableModel(config)
|
|
|
|
|
+ elif table_model_type == MODEL_NAME.RAPID_TABLE:
|
|
|
|
|
+ table_model = RapidTable()
|
|
|
else:
|
|
else:
|
|
|
logger.error("table model type not allow")
|
|
logger.error("table model type not allow")
|
|
|
exit(1)
|
|
exit(1)
|
|
|
|
|
+
|
|
|
return table_model
|
|
return table_model
|
|
|
|
|
|
|
|
|
|
|
|
@@ -226,7 +228,7 @@ class CustomPEKModel:
|
|
|
self.table_config = kwargs.get("table_config")
|
|
self.table_config = kwargs.get("table_config")
|
|
|
self.apply_table = self.table_config.get("enable", False)
|
|
self.apply_table = self.table_config.get("enable", False)
|
|
|
self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
|
|
self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
|
|
|
- self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER)
|
|
|
|
|
|
|
+ self.table_model_name = self.table_config.get("model", MODEL_NAME.RAPID_TABLE)
|
|
|
|
|
|
|
|
# ocr config
|
|
# ocr config
|
|
|
self.apply_ocr = ocr
|
|
self.apply_ocr = ocr
|
|
@@ -281,13 +283,13 @@ class CustomPEKModel:
|
|
|
doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
|
|
doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
|
|
|
)
|
|
)
|
|
|
# 初始化ocr
|
|
# 初始化ocr
|
|
|
- if self.apply_ocr:
|
|
|
|
|
- self.ocr_model = atom_model_manager.get_atom_model(
|
|
|
|
|
- atom_model_name=AtomicModel.OCR,
|
|
|
|
|
- ocr_show_log=show_log,
|
|
|
|
|
- det_db_box_thresh=0.3,
|
|
|
|
|
- lang=self.lang
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ # if self.apply_ocr:
|
|
|
|
|
+ self.ocr_model = atom_model_manager.get_atom_model(
|
|
|
|
|
+ atom_model_name=AtomicModel.OCR,
|
|
|
|
|
+ ocr_show_log=show_log,
|
|
|
|
|
+ det_db_box_thresh=0.3,
|
|
|
|
|
+ lang=self.lang
|
|
|
|
|
+ )
|
|
|
# init table model
|
|
# init table model
|
|
|
if self.apply_table:
|
|
if self.apply_table:
|
|
|
table_model_dir = self.configs["weights"][self.table_model_name]
|
|
table_model_dir = self.configs["weights"][self.table_model_name]
|
|
@@ -451,8 +453,16 @@ class CustomPEKModel:
|
|
|
table_result = self.table_model.predict(new_image, "html")
|
|
table_result = self.table_model.predict(new_image, "html")
|
|
|
if len(table_result) > 0:
|
|
if len(table_result) > 0:
|
|
|
html_code = table_result[0]
|
|
html_code = table_result[0]
|
|
|
- else:
|
|
|
|
|
|
|
+ elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
|
|
|
html_code = self.table_model.img2html(new_image)
|
|
html_code = self.table_model.img2html(new_image)
|
|
|
|
|
+ elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
|
|
|
|
|
+ new_image_bgr = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
|
|
|
|
|
+ ocr_result = self.ocr_model.ocr(new_image_bgr)[0]
|
|
|
|
|
+ new_ocr_result = []
|
|
|
|
|
+ for box_ocr_res in ocr_result:
|
|
|
|
|
+ text, score = box_ocr_res[1]
|
|
|
|
|
+ new_ocr_result.append([box_ocr_res[0], text, score])
|
|
|
|
|
+ html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(new_image), new_ocr_result)
|
|
|
|
|
|
|
|
run_time = time.time() - single_table_start_time
|
|
run_time = time.time() - single_table_start_time
|
|
|
# logger.info(f"------------table recognition processing ends within {run_time}s-----")
|
|
# logger.info(f"------------table recognition processing ends within {run_time}s-----")
|