|
|
@@ -27,7 +27,7 @@ except ImportError as e:
|
|
|
logger.exception(e)
|
|
|
logger.error(
|
|
|
'Required dependency not installed, please install by \n'
|
|
|
- '"pip install magic-pdf[full] detectron2 --extra-index-url https://myhloli.github.io/wheels/"')
|
|
|
+ '"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
|
|
|
exit(1)
|
|
|
|
|
|
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
|
|
|
@@ -188,50 +188,56 @@ class CustomPEKModel:
|
|
|
mfr_cost = round(time.time() - mfr_start, 2)
|
|
|
logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
|
|
|
|
|
|
+ # Select regions for OCR / formula regions / table regions
|
|
|
+ ocr_res_list = []
|
|
|
+ table_res_list = []
|
|
|
+ single_page_mfdetrec_res = []
|
|
|
+ for res in layout_res:
|
|
|
+ if int(res['category_id']) in [13, 14]:
|
|
|
+ single_page_mfdetrec_res.append({
|
|
|
+ "bbox": [int(res['poly'][0]), int(res['poly'][1]),
|
|
|
+ int(res['poly'][4]), int(res['poly'][5])],
|
|
|
+ })
|
|
|
+ elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
|
|
|
+ ocr_res_list.append(res)
|
|
|
+ elif int(res['category_id']) in [5]:
|
|
|
+ table_res_list.append(res)
|
|
|
+
|
|
|
+ # Unified crop img logic
|
|
|
+ def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
|
|
|
+ crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
|
|
|
+ crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
|
|
|
+ # Create a white background with an additional width and height of 50
|
|
|
+ crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
|
|
|
+ crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
|
|
|
+ return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
|
|
|
+
|
|
|
+ # Crop image
|
|
|
+ crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
|
|
|
+ cropped_img = input_pil_img.crop(crop_box)
|
|
|
+ return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
|
|
|
+ return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
|
|
|
+ return return_image, return_list
|
|
|
+
|
|
|
+ pil_img = Image.fromarray(image)
|
|
|
+
|
|
|
# ocr识别
|
|
|
if self.apply_ocr:
|
|
|
ocr_start = time.time()
|
|
|
- pil_img = Image.fromarray(image)
|
|
|
-
|
|
|
- # 筛选出需要OCR的区域和公式区域
|
|
|
- ocr_res_list = []
|
|
|
- single_page_mfdetrec_res = []
|
|
|
- for res in layout_res:
|
|
|
- if int(res['category_id']) in [13, 14]:
|
|
|
- single_page_mfdetrec_res.append({
|
|
|
- "bbox": [int(res['poly'][0]), int(res['poly'][1]),
|
|
|
- int(res['poly'][4]), int(res['poly'][5])],
|
|
|
- })
|
|
|
- elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
|
|
|
- ocr_res_list.append(res)
|
|
|
-
|
|
|
- # 对每一个需OCR处理的区域进行处理
|
|
|
+ # Process each area that requires OCR processing
|
|
|
for res in ocr_res_list:
|
|
|
- xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
|
|
|
- xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
|
|
|
-
|
|
|
- paste_x = 50
|
|
|
- paste_y = 50
|
|
|
- # 创建一个宽高各多50的白色背景
|
|
|
- new_width = xmax - xmin + paste_x * 2
|
|
|
- new_height = ymax - ymin + paste_y * 2
|
|
|
- new_image = Image.new('RGB', (new_width, new_height), 'white')
|
|
|
-
|
|
|
- # 裁剪图像
|
|
|
- crop_box = (xmin, ymin, xmax, ymax)
|
|
|
- cropped_img = pil_img.crop(crop_box)
|
|
|
- new_image.paste(cropped_img, (paste_x, paste_y))
|
|
|
-
|
|
|
- # 调整公式区域坐标
|
|
|
+ new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
|
|
|
+ paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
|
|
|
+ # Adjust the coordinates of the formula area
|
|
|
adjusted_mfdetrec_res = []
|
|
|
for mf_res in single_page_mfdetrec_res:
|
|
|
mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
|
|
|
- # 将公式区域坐标调整为相对于裁剪区域的坐标
|
|
|
+ # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
|
|
|
x0 = mf_xmin - xmin + paste_x
|
|
|
y0 = mf_ymin - ymin + paste_y
|
|
|
x1 = mf_xmax - xmin + paste_x
|
|
|
y1 = mf_ymax - ymin + paste_y
|
|
|
- # 过滤在图外的公式块
|
|
|
+ # Filter formula blocks outside the graph
|
|
|
if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
|
|
|
continue
|
|
|
else:
|
|
|
@@ -239,17 +245,17 @@ class CustomPEKModel:
|
|
|
"bbox": [x0, y0, x1, y1],
|
|
|
})
|
|
|
|
|
|
- # OCR识别
|
|
|
+ # OCR recognition
|
|
|
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
|
|
|
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
|
|
|
|
|
|
- # 整合结果
|
|
|
+ # Integration results
|
|
|
if ocr_res:
|
|
|
for box_ocr_res in ocr_res:
|
|
|
p1, p2, p3, p4 = box_ocr_res[0]
|
|
|
text, score = box_ocr_res[1]
|
|
|
|
|
|
- # 将坐标转换回原图坐标系
|
|
|
+ # Convert the coordinates back to the original coordinate system
|
|
|
p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
|
|
|
p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
|
|
|
p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
|
|
|
@@ -267,35 +273,23 @@ class CustomPEKModel:
|
|
|
|
|
|
# 表格识别 table recognition
|
|
|
if self.apply_table:
|
|
|
- pil_img = Image.fromarray(image)
|
|
|
- for layout in layout_res:
|
|
|
- if layout.get("category_id", -1) == 5:
|
|
|
- poly = layout["poly"]
|
|
|
- xmin, ymin = int(poly[0]), int(poly[1])
|
|
|
- xmax, ymax = int(poly[4]), int(poly[5])
|
|
|
-
|
|
|
- paste_x = 50
|
|
|
- paste_y = 50
|
|
|
- # 创建一个宽高各多50的白色背景 create a whiteboard with 50 larger width and length
|
|
|
- new_width = xmax - xmin + paste_x * 2
|
|
|
- new_height = ymax - ymin + paste_y * 2
|
|
|
- new_image = Image.new('RGB', (new_width, new_height), 'white')
|
|
|
-
|
|
|
- # 裁剪图像 crop image
|
|
|
- crop_box = (xmin, ymin, xmax, ymax)
|
|
|
- cropped_img = pil_img.crop(crop_box)
|
|
|
- new_image.paste(cropped_img, (paste_x, paste_y))
|
|
|
- start_time = time.time()
|
|
|
- logger.info("------------------table recognition processing begins-----------------")
|
|
|
+ table_start = time.time()
|
|
|
+ for res in table_res_list:
|
|
|
+ new_image, _ = crop_img(res, pil_img)
|
|
|
+ single_table_start_time = time.time()
|
|
|
+ logger.info("------------------table recognition processing begins-----------------")
|
|
|
+ with torch.no_grad():
|
|
|
latex_code = self.table_model.image2latex(new_image)[0]
|
|
|
- end_time = time.time()
|
|
|
- run_time = end_time - start_time
|
|
|
- logger.info(f"------------table recognition processing ends within {run_time}s-----")
|
|
|
- if run_time > self.table_max_time:
|
|
|
- logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
|
|
|
- # 判断是否返回正常
|
|
|
- if latex_code and latex_code.strip().endswith('end{tabular}'):
|
|
|
- layout["latex"] = latex_code
|
|
|
- else:
|
|
|
- logger.warning(f"------------table recognition processing fails----------")
|
|
|
+ run_time = time.time() - single_table_start_time
|
|
|
+ logger.info(f"------------table recognition processing ends within {run_time}s-----")
|
|
|
+ if run_time > self.table_max_time:
|
|
|
+ logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
|
|
|
+ # 判断是否返回正常
|
|
|
+ if latex_code and latex_code.strip().endswith('end{tabular}'):
|
|
|
+ res["latex"] = latex_code
|
|
|
+ else:
|
|
|
+ logger.warning(f"------------table recognition processing fails----------")
|
|
|
+ table_cost = round(time.time() - table_start, 2)
|
|
|
+ logger.info(f"table cost: {table_cost}")
|
|
|
+
|
|
|
return layout_res
|