|
|
@@ -1,6 +1,8 @@
|
|
|
from loguru import logger
|
|
|
import os
|
|
|
import time
|
|
|
+
|
|
|
+os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
|
|
|
try:
|
|
|
import cv2
|
|
|
import yaml
|
|
|
@@ -17,14 +19,17 @@ try:
|
|
|
import unimernet.tasks as tasks
|
|
|
from unimernet.processors import load_processor
|
|
|
|
|
|
- from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
|
|
|
- from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
|
|
|
- from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
|
|
|
except ImportError as e:
|
|
|
logger.exception(e)
|
|
|
- logger.error('Required dependency not installed, please install by \n"pip install magic-pdf[full-cpu] detectron2 --extra-index-url https://myhloli.github.io/wheels/"')
|
|
|
+ logger.error(
|
|
|
+ 'Required dependency not installed, please install by \n'
|
|
|
+ '"pip install magic-pdf[full] detectron2 --extra-index-url https://myhloli.github.io/wheels/"')
|
|
|
exit(1)
|
|
|
|
|
|
+from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
|
|
|
+from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
|
|
|
+from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
|
|
|
+
|
|
|
|
|
|
def mfd_model_init(weight):
|
|
|
mfd_model = YOLO(weight)
|
|
|
@@ -84,7 +89,7 @@ class CustomPEKModel:
|
|
|
model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
|
|
|
# 构建 model_configs.yaml 文件的完整路径
|
|
|
config_path = os.path.join(model_config_dir, 'model_configs.yaml')
|
|
|
- with open(config_path, "r") as f:
|
|
|
+ with open(config_path, "r", encoding='utf-8') as f:
|
|
|
self.configs = yaml.load(f, Loader=yaml.FullLoader)
|
|
|
# 初始化解析配置
|
|
|
self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
|
|
|
@@ -100,6 +105,7 @@ class CustomPEKModel:
|
|
|
self.device = kwargs.get("device", self.configs["config"]["device"])
|
|
|
logger.info("using device: {}".format(self.device))
|
|
|
models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
|
|
|
+ logger.info("using models_dir: {}".format(models_dir))
|
|
|
|
|
|
# 初始化公式识别
|
|
|
if self.apply_formula:
|
|
|
@@ -135,66 +141,110 @@ class CustomPEKModel:
|
|
|
layout_cost = round(time.time() - layout_start, 2)
|
|
|
logger.info(f"layout detection cost: {layout_cost}")
|
|
|
|
|
|
- # 公式检测
|
|
|
- mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
|
|
|
- for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
|
|
|
- xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
|
|
- new_item = {
|
|
|
- 'category_id': 13 + int(cla.item()),
|
|
|
- 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
|
|
- 'score': round(float(conf.item()), 2),
|
|
|
- 'latex': '',
|
|
|
- }
|
|
|
- layout_res.append(new_item)
|
|
|
- latex_filling_list.append(new_item)
|
|
|
- bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
|
|
|
- mf_image_list.append(bbox_img)
|
|
|
-
|
|
|
- # 公式识别
|
|
|
- mfr_start = time.time()
|
|
|
- dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
|
|
|
- dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
|
|
|
- mfr_res = []
|
|
|
- for mf_img in dataloader:
|
|
|
- mf_img = mf_img.to(self.device)
|
|
|
- output = self.mfr_model.generate({'image': mf_img})
|
|
|
- mfr_res.extend(output['pred_str'])
|
|
|
- for res, latex in zip(latex_filling_list, mfr_res):
|
|
|
- res['latex'] = latex_rm_whitespace(latex)
|
|
|
- mfr_cost = round(time.time() - mfr_start, 2)
|
|
|
- logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
|
|
|
+ if self.apply_formula:
|
|
|
+ # 公式检测
|
|
|
+ mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
|
|
|
+ for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
|
|
|
+ xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
|
|
+ new_item = {
|
|
|
+ 'category_id': 13 + int(cla.item()),
|
|
|
+ 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
|
|
+ 'score': round(float(conf.item()), 2),
|
|
|
+ 'latex': '',
|
|
|
+ }
|
|
|
+ layout_res.append(new_item)
|
|
|
+ latex_filling_list.append(new_item)
|
|
|
+ bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
|
|
|
+ mf_image_list.append(bbox_img)
|
|
|
+
|
|
|
+ # 公式识别
|
|
|
+ mfr_start = time.time()
|
|
|
+ dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
|
|
|
+ dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
|
|
|
+ mfr_res = []
|
|
|
+ for mf_img in dataloader:
|
|
|
+ mf_img = mf_img.to(self.device)
|
|
|
+ output = self.mfr_model.generate({'image': mf_img})
|
|
|
+ mfr_res.extend(output['pred_str'])
|
|
|
+ for res, latex in zip(latex_filling_list, mfr_res):
|
|
|
+ res['latex'] = latex_rm_whitespace(latex)
|
|
|
+ mfr_cost = round(time.time() - mfr_start, 2)
|
|
|
+ logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
|
|
|
|
|
|
# ocr识别
|
|
|
if self.apply_ocr:
|
|
|
ocr_start = time.time()
|
|
|
pil_img = Image.fromarray(image)
|
|
|
+
|
|
|
+ # 筛选出需要OCR的区域和公式区域
|
|
|
+ ocr_res_list = []
|
|
|
single_page_mfdetrec_res = []
|
|
|
for res in layout_res:
|
|
|
if int(res['category_id']) in [13, 14]:
|
|
|
- xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
|
|
|
- xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
|
|
|
single_page_mfdetrec_res.append({
|
|
|
- "bbox": [xmin, ymin, xmax, ymax],
|
|
|
+ "bbox": [int(res['poly'][0]), int(res['poly'][1]),
|
|
|
+ int(res['poly'][4]), int(res['poly'][5])],
|
|
|
})
|
|
|
- for res in layout_res:
|
|
|
- if int(res['category_id']) in [0, 1, 2, 4, 6, 7]: # 需要进行ocr的类别
|
|
|
- xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
|
|
|
- xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
|
|
|
- crop_box = (xmin, ymin, xmax, ymax)
|
|
|
- cropped_img = Image.new('RGB', pil_img.size, 'white')
|
|
|
- cropped_img.paste(pil_img.crop(crop_box), crop_box)
|
|
|
- cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
|
|
|
- ocr_res = self.ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
|
|
|
- if ocr_res:
|
|
|
- for box_ocr_res in ocr_res:
|
|
|
- p1, p2, p3, p4 = box_ocr_res[0]
|
|
|
- text, score = box_ocr_res[1]
|
|
|
- layout_res.append({
|
|
|
- 'category_id': 15,
|
|
|
- 'poly': p1 + p2 + p3 + p4,
|
|
|
- 'score': round(score, 2),
|
|
|
- 'text': text,
|
|
|
- })
|
|
|
+ elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
|
|
|
+ ocr_res_list.append(res)
|
|
|
+
|
|
|
+ # 对每一个需OCR处理的区域进行处理
|
|
|
+ for res in ocr_res_list:
|
|
|
+ xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
|
|
|
+ xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
|
|
|
+
|
|
|
+ paste_x = 50
|
|
|
+ paste_y = 50
|
|
|
+ # Create a white canvas with a 50 px margin on every side (width and height each grow by 100 px)
|
|
|
+ new_width = xmax - xmin + paste_x * 2
|
|
|
+ new_height = ymax - ymin + paste_y * 2
|
|
|
+ new_image = Image.new('RGB', (new_width, new_height), 'white')
|
|
|
+
|
|
|
+ # 裁剪图像
|
|
|
+ crop_box = (xmin, ymin, xmax, ymax)
|
|
|
+ cropped_img = pil_img.crop(crop_box)
|
|
|
+ new_image.paste(cropped_img, (paste_x, paste_y))
|
|
|
+
|
|
|
+ # 调整公式区域坐标
|
|
|
+ adjusted_mfdetrec_res = []
|
|
|
+ for mf_res in single_page_mfdetrec_res:
|
|
|
+ mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
|
|
|
+ # 将公式区域坐标调整为相对于裁剪区域的坐标
|
|
|
+ x0 = mf_xmin - xmin + paste_x
|
|
|
+ y0 = mf_ymin - ymin + paste_y
|
|
|
+ x1 = mf_xmax - xmin + paste_x
|
|
|
+ y1 = mf_ymax - ymin + paste_y
|
|
|
+ # 过滤在图外的公式块
|
|
|
+ if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
|
|
|
+ continue
|
|
|
+ else:
|
|
|
+ adjusted_mfdetrec_res.append({
|
|
|
+ "bbox": [x0, y0, x1, y1],
|
|
|
+ })
|
|
|
+
|
|
|
+ # OCR识别
|
|
|
+ new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
|
|
|
+ ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
|
|
|
+
|
|
|
+ # 整合结果
|
|
|
+ if ocr_res:
|
|
|
+ for box_ocr_res in ocr_res:
|
|
|
+ p1, p2, p3, p4 = box_ocr_res[0]
|
|
|
+ text, score = box_ocr_res[1]
|
|
|
+
|
|
|
+ # 将坐标转换回原图坐标系
|
|
|
+ p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
|
|
|
+ p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
|
|
|
+ p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
|
|
|
+ p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
|
|
|
+
|
|
|
+ layout_res.append({
|
|
|
+ 'category_id': 15,
|
|
|
+ 'poly': p1 + p2 + p3 + p4,
|
|
|
+ 'score': round(score, 2),
|
|
|
+ 'text': text,
|
|
|
+ })
|
|
|
+
|
|
|
ocr_cost = round(time.time() - ocr_start, 2)
|
|
|
logger.info(f"ocr cost: {ocr_cost}")
|
|
|
|