|
|
@@ -1,7 +1,9 @@
|
|
|
import copy
|
|
|
+import platform
|
|
|
import time
|
|
|
import cv2
|
|
|
import numpy as np
|
|
|
+import torch
|
|
|
|
|
|
from paddleocr import PaddleOCR
|
|
|
from ppocr.utils.logging import get_logger
|
|
|
@@ -9,12 +11,23 @@ from ppocr.utils.utility import alpha_to_color, binarize_img
|
|
|
from tools.infer.predict_system import sorted_boxes
|
|
|
from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop
|
|
|
|
|
|
-from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes, check_img
|
|
|
+from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes, check_img, \
|
|
|
+ ONNXModelSingleton
|
|
|
|
|
|
logger = get_logger()
|
|
|
|
|
|
|
|
|
class ModifiedPaddleOCR(PaddleOCR):
|
|
|
def __init__(self, *args, **kwargs):
    """Build the PaddleOCR pipeline and, where applicable, an ONNX fallback.

    All positional and keyword arguments are forwarded unchanged to the
    parent ``PaddleOCR`` constructor. When CUDA is unavailable and the
    machine architecture is ARM (``arm64`` / ``aarch64``), an additional
    OCR object is obtained from ``ONNXModelSingleton`` with the same
    keyword arguments, and ``self.use_onnx`` is set so that the
    detection / recognition calls in this class are routed to it.

    Attributes set here:
        use_onnx: ``True`` only when the ONNX fallback was created,
            otherwise ``False``.
        additional_ocr: the object returned by
            ``ONNXModelSingleton().get_onnx_model(**kwargs)``, or ``None``
            when the fallback is not used.
    """
    super().__init__(*args, **kwargs)

    # Use the ONNX models when the CPU architecture is ARM and CUDA is
    # not available.
    if not torch.cuda.is_available() and platform.machine() in ('arm64', 'aarch64'):
        self.use_onnx = True
        onnx_model_manager = ONNXModelSingleton()
        self.additional_ocr = onnx_model_manager.get_onnx_model(**kwargs)
    else:
        # The inference methods of this class read ``self.use_onnx``
        # unconditionally, so the flag must exist on every platform;
        # leaving it unset raises AttributeError on non-ARM or CUDA hosts.
        self.use_onnx = False
        self.additional_ocr = None
|
|
|
+
|
|
|
def ocr(self,
|
|
|
img,
|
|
|
det=True,
|
|
|
@@ -79,7 +92,10 @@ class ModifiedPaddleOCR(PaddleOCR):
|
|
|
ocr_res = []
|
|
|
for img in imgs:
|
|
|
img = preprocess_image(img)
|
|
|
- dt_boxes, elapse = self.text_detector(img)
|
|
|
+ if self.use_onnx:
|
|
|
+ dt_boxes, elapse = self.additional_ocr.text_detector(img)
|
|
|
+ else:
|
|
|
+ dt_boxes, elapse = self.text_detector(img)
|
|
|
if dt_boxes is None:
|
|
|
ocr_res.append(None)
|
|
|
continue
|
|
|
@@ -106,7 +122,10 @@ class ModifiedPaddleOCR(PaddleOCR):
|
|
|
img, cls_res_tmp, elapse = self.text_classifier(img)
|
|
|
if not rec:
|
|
|
cls_res.append(cls_res_tmp)
|
|
|
- rec_res, elapse = self.text_recognizer(img)
|
|
|
+ if self.use_onnx:
|
|
|
+ rec_res, elapse = self.additional_ocr.text_recognizer(img)
|
|
|
+ else:
|
|
|
+ rec_res, elapse = self.text_recognizer(img)
|
|
|
ocr_res.append(rec_res)
|
|
|
if not rec:
|
|
|
return cls_res
|
|
|
@@ -121,7 +140,10 @@ class ModifiedPaddleOCR(PaddleOCR):
|
|
|
|
|
|
start = time.time()
|
|
|
ori_im = img.copy()
|
|
|
- dt_boxes, elapse = self.text_detector(img)
|
|
|
+ if self.use_onnx:
|
|
|
+ dt_boxes, elapse = self.additional_ocr.text_detector(img)
|
|
|
+ else:
|
|
|
+ dt_boxes, elapse = self.text_detector(img)
|
|
|
time_dict['det'] = elapse
|
|
|
|
|
|
if dt_boxes is None:
|
|
|
@@ -159,8 +181,10 @@ class ModifiedPaddleOCR(PaddleOCR):
|
|
|
time_dict['cls'] = elapse
|
|
|
logger.debug("cls num : {}, elapsed : {}".format(
|
|
|
len(img_crop_list), elapse))
|
|
|
-
|
|
|
- rec_res, elapse = self.text_recognizer(img_crop_list)
|
|
|
+ if self.use_onnx:
|
|
|
+ rec_res, elapse = self.additional_ocr.text_recognizer(img_crop_list)
|
|
|
+ else:
|
|
|
+ rec_res, elapse = self.text_recognizer(img_crop_list)
|
|
|
time_dict['rec'] = elapse
|
|
|
logger.debug("rec_res num : {}, elapsed : {}".format(
|
|
|
len(rec_res), elapse))
|