|
|
@@ -1,6 +1,7 @@
|
|
|
from loguru import logger
|
|
|
import os
|
|
|
import time
|
|
|
+from pypandoc import convert_text
|
|
|
|
|
|
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
|
|
|
try:
|
|
|
@@ -10,6 +11,7 @@ try:
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
import torchtext
|
|
|
+
|
|
|
if torchtext.__version__ >= "0.18.0":
|
|
|
torchtext.disable_torchtext_deprecation_warning()
|
|
|
from PIL import Image
|
|
|
@@ -30,6 +32,12 @@ except ImportError as e:
|
|
|
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
|
|
|
from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
|
|
|
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
|
|
|
+from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
|
|
|
+
|
|
|
+
|
|
|
def table_model_init(model_path, max_time=400, _device_='cpu'):
    """Construct the table-recognition model.

    Args:
        model_path: Passed through as the first positional argument of
            ``StructTableModel``.
        max_time: Passed through as the ``max_time`` keyword; defaults to 400.
            NOTE(review): presumably a generation budget -- units are not
            visible from this file, confirm against ``StructTableModel``.
        _device_: Passed through as the ``device`` keyword; defaults to
            ``'cpu'``.

    Returns:
        The freshly created ``StructTableModel`` instance.
    """
    return StructTableModel(model_path, max_time=max_time, device=_device_)
|
|
|
|
|
|
|
|
|
def mfd_model_init(weight):
|
|
|
@@ -95,6 +103,8 @@ class CustomPEKModel:
|
|
|
# 初始化解析配置
|
|
|
self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
|
|
|
self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
|
|
|
+ self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
|
|
|
+ self.apply_table = self.table_config.get("is_table_recog_enable", False)
|
|
|
self.apply_ocr = ocr
|
|
|
logger.info(
|
|
|
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}".format(
|
|
|
@@ -129,6 +139,11 @@ class CustomPEKModel:
|
|
|
if self.apply_ocr:
|
|
|
self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
|
|
|
|
|
|
+ # init structeqtable
|
|
|
+ if self.apply_table:
|
|
|
+ max_time = self.table_config.get("max_time", 400)
|
|
|
+ self.table_model = table_model_init(str(os.path.join(models_dir, self.configs["weights"]["table"])),
|
|
|
+ max_time=max_time, _device_=self.device)
|
|
|
logger.info('DocAnalysis init done!')
|
|
|
|
|
|
def __call__(self, image):
|
|
|
@@ -249,4 +264,32 @@ class CustomPEKModel:
|
|
|
ocr_cost = round(time.time() - ocr_start, 2)
|
|
|
logger.info(f"ocr cost: {ocr_cost}")
|
|
|
|
|
|
+ # 表格识别 table recognition
|
|
|
+ if self.apply_table:
|
|
|
+ pil_img = Image.fromarray(image)
|
|
|
+ for layout in layout_res:
|
|
|
+ if layout.get("category_id", -1) == 5:
|
|
|
+ poly = layout["poly"]
|
|
|
+ xmin, ymin = int(poly[0]), int(poly[1])
|
|
|
+ xmax, ymax = int(poly[4]), int(poly[5])
|
|
|
+
|
|
|
+ paste_x = 50
|
|
|
+ paste_y = 50
|
|
|
+ # Create a white canvas with a 50 px margin on every side of the table crop (width and height each grow by 100 px)
|
|
|
+ new_width = xmax - xmin + paste_x * 2
|
|
|
+ new_height = ymax - ymin + paste_y * 2
|
|
|
+ new_image = Image.new('RGB', (new_width, new_height), 'white')
|
|
|
+
|
|
|
+ # 裁剪图像 crop image
|
|
|
+ crop_box = (xmin, ymin, xmax, ymax)
|
|
|
+ cropped_img = pil_img.crop(crop_box)
|
|
|
+ new_image.paste(cropped_img, (paste_x, paste_y))
|
|
|
+ start_time = time.time()
|
|
|
+ logger.info("------------------table recognition processing begins-----------------")
|
|
|
+ latex_code = self.table_model.image2latex(new_image)[0]
|
|
|
+ end_time = time.time()
|
|
|
+ run_time = end_time - start_time
|
|
|
+ logger.info(f"------------table recognition processing ends within {run_time}s-----")
|
|
|
+ layout["latex"] = latex_code
|
|
|
+
|
|
|
return layout_res
|