|
|
@@ -32,7 +32,7 @@ except ImportError as e:
|
|
|
exit(1)
|
|
|
|
|
|
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
|
|
|
-from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
|
|
|
+from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
|
|
|
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
|
|
|
from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
|
|
|
from magic_pdf.model.ppTableModel import ppTableModel
|
|
|
@@ -63,7 +63,8 @@ def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
|
|
|
cfg.config.model.tokenizer_config.path = weight_dir
|
|
|
task = tasks.setup_task(cfg)
|
|
|
model = task.build_model(cfg)
|
|
|
- model = model.to(_device_)
|
|
|
+ model.to(_device_)
|
|
|
+ model.eval()
|
|
|
vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
|
|
|
mfr_transform = transforms.Compose([vis_processor, ])
|
|
|
return [model, mfr_transform]
|
|
|
@@ -154,6 +155,23 @@ def atom_model_init(model_name: str, **kwargs):
|
|
|
return atom_model
|
|
|
|
|
|
|
|
|
+# Unified crop img logic
|
|
|
+def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
|
|
|
+ crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
|
|
|
+ crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
|
|
|
+ # Create a white background with an additional width and height of 50
|
|
|
+ crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
|
|
|
+ crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
|
|
|
+ return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
|
|
|
+
|
|
|
+ # Crop image
|
|
|
+ crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
|
|
|
+ cropped_img = input_pil_img.crop(crop_box)
|
|
|
+ return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
|
|
|
+ return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
|
|
|
+ return return_image, return_list
|
|
|
+
|
|
|
+
|
|
|
class CustomPEKModel:
|
|
|
|
|
|
def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
|
|
|
@@ -264,6 +282,8 @@ class CustomPEKModel:
|
|
|
layout_cost = round(time.time() - layout_start, 2)
|
|
|
logger.info(f"layout detection cost: {layout_cost}")
|
|
|
|
|
|
+ pil_img = Image.fromarray(image)
|
|
|
+
|
|
|
if self.apply_formula:
|
|
|
# 公式检测
|
|
|
mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
|
|
|
@@ -277,7 +297,8 @@ class CustomPEKModel:
|
|
|
}
|
|
|
layout_res.append(new_item)
|
|
|
latex_filling_list.append(new_item)
|
|
|
- bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
|
|
|
+ # bbox_img = get_croped_image(pil_img, [xmin, ymin, xmax, ymax])
|
|
|
+ bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
|
|
|
mf_image_list.append(bbox_img)
|
|
|
|
|
|
# 公式识别
|
|
|
@@ -309,24 +330,6 @@ class CustomPEKModel:
|
|
|
elif int(res['category_id']) in [5]:
|
|
|
table_res_list.append(res)
|
|
|
|
|
|
- # Unified crop img logic
|
|
|
- def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
|
|
|
- crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
|
|
|
- crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
|
|
|
- # Create a white background with an additional width and height of 50
|
|
|
- crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
|
|
|
- crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
|
|
|
- return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
|
|
|
-
|
|
|
- # Crop image
|
|
|
- crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
|
|
|
- cropped_img = input_pil_img.crop(crop_box)
|
|
|
- return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
|
|
|
- return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
|
|
|
- return return_image, return_list
|
|
|
-
|
|
|
- pil_img = Image.fromarray(image)
|
|
|
-
|
|
|
# ocr识别
|
|
|
if self.apply_ocr:
|
|
|
ocr_start = time.time()
|