|
|
@@ -107,83 +107,78 @@ class CustomPEKModel:
|
|
|
|
|
|
logger.info('DocAnalysis init done!')
|
|
|
|
|
|
- def __call__(self, images):
|
|
|
- # layout检测 + 公式检测
|
|
|
- doc_layout_result = []
|
|
|
+ def __call__(self, image):
|
|
|
+
|
|
|
latex_filling_list = []
|
|
|
mf_image_list = []
|
|
|
- for idx, img_dict in enumerate(images):
|
|
|
- image = img_dict["img"]
|
|
|
- img_height, img_width = img_dict["height"], img_dict["width"]
|
|
|
- layout_res = self.layout_model(image, ignore_catids=[])
|
|
|
- # 公式检测
|
|
|
- mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
|
|
|
- for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
|
|
|
- xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
|
|
- new_item = {
|
|
|
- 'category_id': 13 + int(cla.item()),
|
|
|
- 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
|
|
- 'score': round(float(conf.item()), 2),
|
|
|
- 'latex': '',
|
|
|
- }
|
|
|
- layout_res['layout_dets'].append(new_item)
|
|
|
- latex_filling_list.append(new_item)
|
|
|
- bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
|
|
|
- mf_image_list.append(bbox_img)
|
|
|
-
|
|
|
- layout_res['page_info'] = dict(
|
|
|
- page_no=idx,
|
|
|
- height=img_height,
|
|
|
- width=img_width
|
|
|
- )
|
|
|
- doc_layout_result.append(layout_res)
|
|
|
|
|
|
- # 公式识别,因为识别速度较慢,为了提速,把单个pdf的所有公式裁剪完,一起批量做识别。
|
|
|
- a = time.time()
|
|
|
+ # layout检测
|
|
|
+ layout_start = time.time()
|
|
|
+ layout_res = self.layout_model(image, ignore_catids=[])
|
|
|
+ layout_cost = round(time.time() - layout_start, 2)
|
|
|
+ logger.info(f"layout detection cost: {layout_cost}")
|
|
|
+
|
|
|
+ # 公式检测
|
|
|
+ mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
|
|
|
+ for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
|
|
|
+ xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
|
|
+ new_item = {
|
|
|
+ 'category_id': 13 + int(cla.item()),
|
|
|
+ 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
|
|
+ 'score': round(float(conf.item()), 2),
|
|
|
+ 'latex': '',
|
|
|
+ }
|
|
|
+ layout_res.append(new_item)
|
|
|
+ latex_filling_list.append(new_item)
|
|
|
+ bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
|
|
|
+ mf_image_list.append(bbox_img)
|
|
|
+
|
|
|
+ # 公式识别
|
|
|
+ mfr_start = time.time()
|
|
|
dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
|
|
|
- dataloader = DataLoader(dataset, batch_size=128, num_workers=0)
|
|
|
+ dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
|
|
|
mfr_res = []
|
|
|
- for imgs in dataloader:
|
|
|
- imgs = imgs.to(self.device)
|
|
|
- output = self.mfr_model.generate({'image': imgs})
|
|
|
+ for mf_img in dataloader:
|
|
|
+ mf_img = mf_img.to(self.device)
|
|
|
+ output = self.mfr_model.generate({'image': mf_img})
|
|
|
mfr_res.extend(output['pred_str'])
|
|
|
for res, latex in zip(latex_filling_list, mfr_res):
|
|
|
res['latex'] = latex_rm_whitespace(latex)
|
|
|
- b = time.time()
|
|
|
- logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {round(b - a, 2)}")
|
|
|
+ mfr_cost = round(time.time() - mfr_start, 2)
|
|
|
+ logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
|
|
|
|
|
|
# ocr识别
|
|
|
if self.apply_ocr:
|
|
|
- for idx, img_dict in enumerate(images):
|
|
|
- image = img_dict["img"]
|
|
|
- pil_img = Image.fromarray(image)
|
|
|
- single_page_res = doc_layout_result[idx]['layout_dets']
|
|
|
- single_page_mfdetrec_res = []
|
|
|
- for res in single_page_res:
|
|
|
- if int(res['category_id']) in [13, 14]:
|
|
|
- xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
|
|
|
- xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
|
|
|
- single_page_mfdetrec_res.append({
|
|
|
- "bbox": [xmin, ymin, xmax, ymax],
|
|
|
- })
|
|
|
- for res in single_page_res:
|
|
|
- if int(res['category_id']) in [0, 1, 2, 4, 6, 7]: # 需要进行ocr的类别
|
|
|
- xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
|
|
|
- xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
|
|
|
- crop_box = [xmin, ymin, xmax, ymax]
|
|
|
- cropped_img = Image.new('RGB', pil_img.size, 'white')
|
|
|
- cropped_img.paste(pil_img.crop(crop_box), crop_box)
|
|
|
- cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
|
|
|
- ocr_res = self.ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
|
|
|
- if ocr_res:
|
|
|
- for box_ocr_res in ocr_res:
|
|
|
- p1, p2, p3, p4 = box_ocr_res[0]
|
|
|
- text, score = box_ocr_res[1]
|
|
|
- doc_layout_result[idx]['layout_dets'].append({
|
|
|
- 'category_id': 15,
|
|
|
- 'poly': p1 + p2 + p3 + p4,
|
|
|
- 'score': round(score, 2),
|
|
|
- 'text': text,
|
|
|
- })
|
|
|
-
|
|
|
- return doc_layout_result
|
|
|
+ ocr_start = time.time()
|
|
|
+ pil_img = Image.fromarray(image)
|
|
|
+ single_page_mfdetrec_res = []
|
|
|
+ for res in layout_res:
|
|
|
+ if int(res['category_id']) in [13, 14]:
|
|
|
+ xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
|
|
|
+ xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
|
|
|
+ single_page_mfdetrec_res.append({
|
|
|
+ "bbox": [xmin, ymin, xmax, ymax],
|
|
|
+ })
|
|
|
+ for res in layout_res:
|
|
|
+ if int(res['category_id']) in [0, 1, 2, 4, 6, 7]: # 需要进行ocr的类别
|
|
|
+ xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
|
|
|
+ xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
|
|
|
+ crop_box = (xmin, ymin, xmax, ymax)
|
|
|
+ cropped_img = Image.new('RGB', pil_img.size, 'white')
|
|
|
+ cropped_img.paste(pil_img.crop(crop_box), crop_box)
|
|
|
+ cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
|
|
|
+ ocr_res = self.ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
|
|
|
+ if ocr_res:
|
|
|
+ for box_ocr_res in ocr_res:
|
|
|
+ p1, p2, p3, p4 = box_ocr_res[0]
|
|
|
+ text, score = box_ocr_res[1]
|
|
|
+ layout_res.append({
|
|
|
+ 'category_id': 15,
|
|
|
+ 'poly': p1 + p2 + p3 + p4,
|
|
|
+ 'score': round(score, 2),
|
|
|
+ 'text': text,
|
|
|
+ })
|
|
|
+ ocr_cost = round(time.time() - ocr_start, 2)
|
|
|
+ logger.info(f"ocr cost: {ocr_cost}")
|
|
|
+
|
|
|
+ return layout_res
|