فهرست منبع

优化ocr推理性能400%

cjsdurj 1 ماه پیش
والد
کامیت
af66bc02c2

+ 3 - 3
mineru/model/ocr/paddleocr2pytorch/pytorch_paddle.py

@@ -12,9 +12,9 @@ from loguru import logger
 from mineru.utils.config_reader import get_device
 from mineru.utils.enum_class import ModelPath
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
-from ....utils.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
-from .tools.infer.predict_system import TextSystem
-from .tools.infer import pytorchocr_utility as utility
+from mineru.utils.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
+from mineru.model.ocr.paddleocr2pytorch.tools.infer.predict_system import TextSystem
+from mineru.model.ocr.paddleocr2pytorch.tools.infer import pytorchocr_utility as utility
 import argparse
 
 

+ 5 - 5
mineru/model/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py

@@ -23,6 +23,7 @@ import sys
 import six
 import cv2
 import numpy as np
+from PIL import Image
 
 
 class DecodeImage(object):
@@ -104,16 +105,15 @@ class NormalizeImage(object):
         shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
         self.mean = np.array(mean).reshape(shape).astype('float32')
         self.std = np.array(std).reshape(shape).astype('float32')
+        self.scale = self.scale / self.std
+        self.mean = self.mean / self.std
+
 
     def __call__(self, data):
         img = data['image']
-        from PIL import Image
         if isinstance(img, Image.Image):
             img = np.array(img)
-        assert isinstance(img,
-                          np.ndarray), "invalid input 'img' in NormalizeImage"
-        data['image'] = (
-            img.astype('float32') * self.scale - self.mean) / self.std
+        data['image'] = img.astype('float32') * self.scale - self.mean
         return data
 
 

+ 7 - 7
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_lcnetv3.py

@@ -245,18 +245,18 @@ class LearnableRepLayer(nn.Module):
             return 0, 0
         elif isinstance(branch, ConvBNLayer):
             kernel = branch.conv.weight
-            running_mean = branch.bn._mean
-            running_var = branch.bn._variance
+            running_mean = branch.bn.running_mean
+            running_var = branch.bn.running_var
             gamma = branch.bn.weight
             beta = branch.bn.bias
-            eps = branch.bn._epsilon
+            eps = branch.bn.eps
         else:
             assert isinstance(branch, nn.BatchNorm2d)
             if not hasattr(self, "id_tensor"):
                 input_dim = self.in_channels // self.groups
                 kernel_value = torch.zeros(
                     (self.in_channels, input_dim, self.kernel_size, self.kernel_size),
-                    dtype=branch.weight.dtype,
+                    dtype=branch.weight.dtype,  device= branch.weight.device,
                 )
                 for i in range(self.in_channels):
                     kernel_value[
@@ -264,11 +264,11 @@ class LearnableRepLayer(nn.Module):
                     ] = 1
                 self.id_tensor = kernel_value
             kernel = self.id_tensor
-            running_mean = branch._mean
-            running_var = branch._variance
+            running_mean = branch.running_mean
+            running_var = branch.running_var
             gamma = branch.weight
             beta = branch.bias
-            eps = branch._epsilon
+            eps = branch.eps
         std = (running_var + eps).sqrt()
         t = (gamma / std).reshape((-1, 1, 1, 1))
         return kernel * t, beta - running_mean * gamma / std

+ 25 - 28
mineru/model/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py

@@ -47,7 +47,7 @@ class BaseRecLabelDecode(object):
         self.dict = {}
         for i, char in enumerate(dict_character):
             self.dict[char] = i
-        self.character = dict_character
+        self.character = np.array(dict_character)
 
     def pred_reverse(self, pred):
         pred_re = []
@@ -143,27 +143,27 @@ class BaseRecLabelDecode(object):
     ):
         """ convert text-index into text-label. """
         result_list = []
-        ignored_tokens = self.get_ignored_tokens()
-        batch_size = len(text_index)
+        batch_size = text_index.shape[0]
+        blank_word = self.get_ignored_tokens()[0]
         for batch_idx in range(batch_size):
-            char_list = []
-            conf_list = []
-            for idx in range(len(text_index[batch_idx])):
-                if text_index[batch_idx][idx] in ignored_tokens:
-                    continue
-                if is_remove_duplicate:
-                    # only for predict
-                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
-                            batch_idx][idx]:
-                        continue
-                char_list.append(self.character[int(text_index[batch_idx][
-                    idx])])
-                if text_prob is not None:
-                    conf_list.append(text_prob[batch_idx][idx])
-                else:
-                    conf_list.append(1)
-            text = ''.join(char_list)
-            result_list.append((text, np.mean(conf_list)))
+            probs = None if text_prob is None else np.array(text_prob[batch_idx])
+            sequence = text_index[batch_idx]
+
+            final_mask = sequence != blank_word
+            if is_remove_duplicate:
+                duplicate_mask = np.insert(sequence[1:] != sequence[:-1], 0, True)
+                final_mask &= duplicate_mask
+
+            sequence = sequence[final_mask]
+            probs = None if probs is None else probs[final_mask]
+            text = "".join(self.character[sequence])
+
+            if text_prob is not None and probs is not None and len(probs) > 0:
+                mean_conf = np.mean(probs)
+            else:
+                # If no probabilities were provided or the final result is empty, default the confidence to 1.0
+                mean_conf = 1.0
+            result_list.append((text, mean_conf))
         return result_list
 
     def get_ignored_tokens(self):
@@ -181,13 +181,10 @@ class CTCLabelDecode(BaseRecLabelDecode):
                                              use_space_char)
 
     def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
-        if isinstance(preds, torch.Tensor):
-            preds = preds.numpy()
-        preds_idx = preds.argmax(axis=2)
-        preds_prob = preds.max(axis=2)
+        preds_prob, preds_idx = preds.max(axis=2)
         text = self.decode(
-            preds_idx,
-            preds_prob,
+            preds_idx.cpu().numpy(),
+            preds_prob.float().cpu().numpy(),
             is_remove_duplicate=True,
             return_word_box=return_word_box,
         )
@@ -199,7 +196,7 @@ class CTCLabelDecode(BaseRecLabelDecode):
 
         if label is None:
             return text
-        label = self.decode(label)
+        label = self.decode(label.cpu().numpy())
         return text, label
 
     def add_special_char(self, dict_character):

+ 6 - 3
mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py

@@ -116,6 +116,9 @@ class TextDetector(BaseOCRV20):
         self.load_pytorch_weights(self.weights_path)
         self.net.eval()
         self.net.to(self.device)
+        for module in self.net.modules():
+            if hasattr(module, 'rep'):
+                module.rep()
 
     def _batch_process_same_size(self, img_list):
         """
@@ -293,7 +296,7 @@ class TextDetector(BaseOCRV20):
         return dt_boxes
 
     def __call__(self, img):
-        ori_im = img.copy()
+        ori_shape = img.shape
         data = {'image': img}
         data = transform(data, self.preprocess_op)
         img, shape_list = data
@@ -331,9 +334,9 @@ class TextDetector(BaseOCRV20):
         if (self.det_algorithm == "SAST" and
             self.det_sast_polygon) or (self.det_algorithm in ["PSE", "FCE"] and
                                        self.postprocess_op.box_type == 'poly'):
-            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
+            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_shape)
         else:
-            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
+            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_shape)
 
         elapse = time.time() - starttime
         return dt_boxes, elapse

+ 16 - 25
mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py

@@ -9,6 +9,7 @@ from tqdm import tqdm
 from ...pytorchocr.base_ocr_v20 import BaseOCRV20
 from . import pytorchocr_utility as utility
 from ...pytorchocr.postprocess import build_post_process
+from ...pytorchocr.modeling.backbones.rec_hgnet import ConvBNAct
 
 
 class TextRecognizer(BaseOCRV20):
@@ -93,6 +94,12 @@ class TextRecognizer(BaseOCRV20):
         self.load_state_dict(weights)
         self.net.eval()
         self.net.to(self.device)
+        for module in self.net.modules():
+            if isinstance(module, ConvBNAct):
+                if module.use_act:
+                    torch.quantization.fuse_modules(module, ['conv', 'bn', 'act'], inplace=True)
+                else:
+                    torch.quantization.fuse_modules(module, ['conv', 'bn'], inplace=True)
 
     def resize_norm_img(self, img, max_wh_ratio):
         imgC, imgH, imgW = self.rec_image_shape
@@ -125,23 +132,15 @@ class TextRecognizer(BaseOCRV20):
 
         assert imgC == img.shape[2]
         max_wh_ratio = max(max_wh_ratio, imgW / imgH)
-        imgW = int((imgH * max_wh_ratio))
+        imgW = int(imgH * max_wh_ratio)
         imgW = max(min(imgW, self.limited_max_width), self.limited_min_width)
         h, w = img.shape[:2]
         ratio = w / float(h)
-        ratio_imgH = math.ceil(imgH * ratio)
-        ratio_imgH = max(ratio_imgH, self.limited_min_width)
-        if ratio_imgH > imgW:
-            resized_w = imgW
-        else:
-            resized_w = int(ratio_imgH)
-        resized_image = cv2.resize(img, (resized_w, imgH))
-        resized_image = resized_image.astype('float32')
-        resized_image = resized_image.transpose((2, 0, 1)) / 255
-        resized_image -= 0.5
-        resized_image /= 0.5
+        ratio_imgH = max(math.ceil(imgH * ratio), self.limited_min_width)
+        resized_w = min(imgW,int(ratio_imgH))
+        resized_image = cv2.resize(img, (resized_w, imgH)) /127.5 - 1
         padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
-        padding_im[:, :, 0:resized_w] = resized_image
+        padding_im[:, :, 0:resized_w] = resized_image.transpose((2, 0, 1))
         return padding_im
 
     def resize_norm_img_svtr(self, img, image_shape):
@@ -307,12 +306,7 @@ class TextRecognizer(BaseOCRV20):
             for beg_img_no in range(0, img_num, batch_num):
                 end_img_no = min(img_num, beg_img_no + batch_num)
                 norm_img_batch = []
-                max_wh_ratio = 0
-                for ino in range(beg_img_no, end_img_no):
-                    # h, w = img_list[ino].shape[0:2]
-                    h, w = img_list[indices[ino]].shape[0:2]
-                    wh_ratio = w * 1.0 / h
-                    max_wh_ratio = max(max_wh_ratio, wh_ratio)
+                max_wh_ratio = width_list[indices[end_img_no - 1]]
                 for ino in range(beg_img_no, end_img_no):
                     if self.rec_algorithm == "SAR":
                         norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
@@ -420,14 +414,11 @@ class TextRecognizer(BaseOCRV20):
                     with torch.no_grad():
                         inp = torch.from_numpy(norm_img_batch)
                         inp = inp.to(self.device)
-                        prob_out = self.net(inp)
+                        preds = self.net(inp)
 
-                    if isinstance(prob_out, list):
-                        preds = [v.cpu().numpy() for v in prob_out]
-                    else:
-                        preds = prob_out.cpu().numpy()
+                with torch.no_grad():
+                    rec_result = self.postprocess_op(preds)
 
-                rec_result = self.postprocess_op(preds)
                 for rno in range(len(rec_result)):
                     rec_res[indices[beg_img_no + rno]] = rec_result[rno]
                 elapse += time.time() - starttime

+ 16 - 0
mineru/utils/ocr_utils.py

@@ -406,6 +406,12 @@ def calculate_is_angle(poly):
         # logger.info((p3[1] - p1[1])/height)
         return True
 
+def is_bbox_aligned_rect(points):
+    x_coords = points[:, 0]
+    y_coords = points[:, 1]
+    unique_x = np.unique(x_coords)
+    unique_y = np.unique(y_coords)
+    return len(unique_x) == 2 and len(unique_y) == 2
 
 def get_rotate_crop_image(img, points):
     '''
@@ -419,6 +425,16 @@ def get_rotate_crop_image(img, points):
     points[:, 1] = points[:, 1] - top
     '''
     assert len(points) == 4, "shape of points must be 4*2"
+
+    if is_bbox_aligned_rect(points):
+        xmin = int(np.min(points[:, 0]))
+        xmax = int(np.max(points[:, 0]))
+        ymin = int(np.min(points[:, 1]))
+        ymax = int(np.max(points[:, 1]))
+        new_img = img[ymin:ymax, xmin:xmax].copy()
+        if new_img.shape[0] > 0 and new_img.shape[1] > 0:
+            return new_img
+
     img_crop_width = int(
         max(
             np.linalg.norm(points[0] - points[1]),