Selaa lähdekoodia

Merge pull request #2071 from myhloli/dev

refactor(ocr): remove redundant code and improve code quality
Xiaomeng Zhao 7 kuukautta sitten
vanhempi
commit
bb30f32ee2

+ 40 - 38
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py

@@ -1,6 +1,7 @@
 # Copyright (c) Opendatalab. All rights reserved.
 import copy
 import os.path
+import warnings
 from pathlib import Path
 
 import cv2
@@ -92,45 +93,46 @@ class PytorchPaddleOCR(TextSystem):
             exit(0)
         img = check_img(img)
         imgs = [img]
-
-        if det and rec:
-            ocr_res = []
-            for img in imgs:
-                img = preprocess_image(img)
-                dt_boxes, rec_res = self.__call__(img, mfd_res=mfd_res)
-                if not dt_boxes and not rec_res:
-                    ocr_res.append(None)
-                    continue
-                tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
-                ocr_res.append(tmp_res)
-            return ocr_res
-        elif det and not rec:
-            ocr_res = []
-            for img in imgs:
-                img = preprocess_image(img)
-                dt_boxes, elapse = self.text_detector(img)
-                # logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
-                if dt_boxes is None:
-                    ocr_res.append(None)
-                    continue
-                dt_boxes = sorted_boxes(dt_boxes)
-                # merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
-                dt_boxes = merge_det_boxes(dt_boxes)
-                if mfd_res:
-                    dt_boxes = update_det_boxes(dt_boxes, mfd_res)
-                tmp_res = [box.tolist() for box in dt_boxes]
-                ocr_res.append(tmp_res)
-            return ocr_res
-        elif not det and rec:
-            ocr_res = []
-            for img in imgs:
-                if not isinstance(img, list):
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", category=RuntimeWarning)
+            if det and rec:
+                ocr_res = []
+                for img in imgs:
+                    img = preprocess_image(img)
+                    dt_boxes, rec_res = self.__call__(img, mfd_res=mfd_res)
+                    if not dt_boxes and not rec_res:
+                        ocr_res.append(None)
+                        continue
+                    tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
+                    ocr_res.append(tmp_res)
+                return ocr_res
+            elif det and not rec:
+                ocr_res = []
+                for img in imgs:
                     img = preprocess_image(img)
-                    img = [img]
-                rec_res, elapse = self.text_recognizer(img)
-                # logger.debug("rec_res num  : {}, elapsed : {}".format(len(rec_res), elapse))
-                ocr_res.append(rec_res)
-            return ocr_res
+                    dt_boxes, elapse = self.text_detector(img)
+                    # logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
+                    if dt_boxes is None:
+                        ocr_res.append(None)
+                        continue
+                    dt_boxes = sorted_boxes(dt_boxes)
+                    # merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
+                    dt_boxes = merge_det_boxes(dt_boxes)
+                    if mfd_res:
+                        dt_boxes = update_det_boxes(dt_boxes, mfd_res)
+                    tmp_res = [box.tolist() for box in dt_boxes]
+                    ocr_res.append(tmp_res)
+                return ocr_res
+            elif not det and rec:
+                ocr_res = []
+                for img in imgs:
+                    if not isinstance(img, list):
+                        img = preprocess_image(img)
+                        img = [img]
+                    rec_res, elapse = self.text_recognizer(img)
+                    # logger.debug("rec_res num  : {}, elapsed : {}".format(len(rec_res), elapse))
+                    ocr_res.append(rec_res)
+                return ocr_res
 
     def __call__(self, img, mfd_res=None):
 

+ 0 - 8
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py

@@ -371,12 +371,6 @@ class TextRecognizer(BaseOCRV20):
                     gsrm_slf_attn_bias1_inp = torch.from_numpy(gsrm_slf_attn_bias1_list)
                     gsrm_slf_attn_bias2_inp = torch.from_numpy(gsrm_slf_attn_bias2_list)
 
-                    # if self.use_gpu:
-                    #     inp = inp.cuda()
-                    #     encoder_word_pos_inp = encoder_word_pos_inp.cuda()
-                    #     gsrm_word_pos_inp = gsrm_word_pos_inp.cuda()
-                    #     gsrm_slf_attn_bias1_inp = gsrm_slf_attn_bias1_inp.cuda()
-                    #     gsrm_slf_attn_bias2_inp = gsrm_slf_attn_bias2_inp.cuda()
                     inp = inp.to(self.device)
                     encoder_word_pos_inp = encoder_word_pos_inp.to(self.device)
                     gsrm_word_pos_inp = gsrm_word_pos_inp.to(self.device)
@@ -398,8 +392,6 @@ class TextRecognizer(BaseOCRV20):
 
                 with torch.no_grad():
                     inp = torch.from_numpy(norm_img_batch)
-                    # if self.use_gpu:
-                    #     inp = inp.cuda()
                     inp = inp.to(self.device)
                     preds = self.net(inp)