|
|
@@ -2,7 +2,7 @@ from loguru import logger
|
|
|
import os
|
|
|
import time
|
|
|
|
|
|
-from magic_pdf.libs.Constants import TABLE_MAX_TIME_VALUE
|
|
|
+from magic_pdf.libs.Constants import *
|
|
|
|
|
|
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # Stop albumentations from checking for updates
|
|
|
try:
|
|
|
@@ -34,10 +34,18 @@ from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Pre
|
|
|
from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
|
|
|
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
|
|
|
from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
|
|
|
-
|
|
|
-
|
|
|
-def table_model_init(model_path, max_time, _device_='cpu'):
|
|
|
- table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
|
|
|
+from magic_pdf.model.ppTableModel import ppTableModel
|
|
|
+
|
|
|
+
|
|
|
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
    """Build the table-recognition model selected by ``table_model_type``.

    Args:
        table_model_type: Model identifier. ``STRUCT_EQTABLE`` selects
            ``StructTableModel``; every other value selects ``ppTableModel``.
        model_path: Location of the model weights. Passed positionally to
            ``StructTableModel`` or as the ``"model_dir"`` config entry of
            ``ppTableModel``.
        max_time: Forwarded to ``StructTableModel`` as ``max_time``. It is
            not passed to ``ppTableModel``.
        _device_: Device name handed to the model; defaults to ``'cpu'``.

    Returns:
        The constructed ``StructTableModel`` or ``ppTableModel`` instance.
    """
    if table_model_type == STRUCT_EQTABLE:
        return StructTableModel(model_path, max_time=max_time, device=_device_)

    if table_model_type != TABLE_MASTER:
        # Keep the existing fallback, but make a misspelled or unsupported
        # model name visible instead of silently loading another model.
        logger.warning(
            f"unknown table model type {table_model_type!r}, "
            "falling back to ppTableModel"
        )

    config = {
        "model_dir": model_path,
        "device": _device_,
    }
    return ppTableModel(config)
|
|
|
|
|
|
|
|
|
@@ -104,9 +112,11 @@ class CustomPEKModel:
|
|
|
# 初始化解析配置
|
|
|
self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
|
|
|
self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
|
|
|
+ # table config
|
|
|
self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
|
|
|
self.apply_table = self.table_config.get("is_table_recog_enable", False)
|
|
|
self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
|
|
|
+ self.table_model_type = self.table_config.get("model", TABLE_MASTER)
|
|
|
self.apply_ocr = ocr
|
|
|
logger.info(
|
|
|
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}".format(
|
|
|
@@ -141,10 +151,11 @@ class CustomPEKModel:
|
|
|
if self.apply_ocr:
|
|
|
self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
|
|
|
|
|
|
- # init structeqtable
|
|
|
+ # init table model
|
|
|
if self.apply_table:
|
|
|
- self.table_model = table_model_init(str(os.path.join(models_dir, self.configs["weights"]["table"])),
|
|
|
- max_time = self.table_max_time, _device_=self.device)
|
|
|
+ table_model_dir = self.configs["weights"][self.table_model_type]
|
|
|
+ self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
|
|
|
+ max_time=self.table_max_time, _device_=self.device)
|
|
|
logger.info('DocAnalysis init done!')
|
|
|
|
|
|
def __call__(self, image):
|
|
|
@@ -278,16 +289,28 @@ class CustomPEKModel:
|
|
|
new_image, _ = crop_img(res, pil_img)
|
|
|
single_table_start_time = time.time()
|
|
|
logger.info("------------------table recognition processing begins-----------------")
|
|
|
+ latex_code = None
|
|
|
+ html_code = None
|
|
|
with torch.no_grad():
|
|
|
- latex_code = self.table_model.image2latex(new_image)[0]
|
|
|
+ if self.table_model_type == STRUCT_EQTABLE:
|
|
|
+ latex_code = self.table_model.image2latex(new_image)[0]
|
|
|
+ else:
|
|
|
+ html_code = self.table_model.img2html(new_image)
|
|
|
run_time = time.time() - single_table_start_time
|
|
|
logger.info(f"------------table recognition processing ends within {run_time}s-----")
|
|
|
if run_time > self.table_max_time:
|
|
|
logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
|
|
|
# 判断是否返回正常
|
|
|
- expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
|
|
|
- if latex_code and expected_ending:
|
|
|
- res["latex"] = latex_code
|
|
|
+
|
|
|
+ if latex_code:
|
|
|
+ expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith(
|
|
|
+ 'end{table}')
|
|
|
+ if expected_ending:
|
|
|
+ res["latex"] = latex_code
|
|
|
+ else:
|
|
|
+ logger.warning(f"------------table recognition processing fails----------")
|
|
|
+ elif html_code:
|
|
|
+ res["html"] = html_code
|
|
|
else:
|
|
|
logger.warning(f"------------table recognition processing fails----------")
|
|
|
table_cost = round(time.time() - table_start, 2)
|