Sfoglia il codice sorgente

Merge pull request #2505 from myhloli/dev

feat(ocr): add PPHGNetV2_B4 backbone and update OCR models
Xiaomeng Zhao 5 mesi fa
parent
commit
ea3003f6ef

+ 1 - 1
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -189,7 +189,7 @@ def batch_doc_analyze(
     formula_enable=None,
     table_enable=None,
 ):
-    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
+    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
     batch_size = MIN_BATCH_INFERENCE_SIZE
     page_wh_list = []
 

+ 2 - 1
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py

@@ -35,7 +35,7 @@ def build_backbone(config, model_type):
         from .rec_mobilenet_v3 import MobileNetV3
         from .rec_svtrnet import SVTRNet
         from .rec_mv1_enhance import MobileNetV1Enhance
-
+        from .rec_pphgnetv2 import PPHGNetV2_B4
         support_dict = [
             "MobileNetV1Enhance",
             "MobileNetV3",
@@ -48,6 +48,7 @@ def build_backbone(config, model_type):
             "DenseNet",
             "PPLCNetV3",
             "PPHGNet_small",
+            "PPHGNetV2_B4",
         ]
     else:
         raise NotImplementedError

+ 18 - 5
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py

@@ -9,14 +9,27 @@ class Im2Seq(nn.Module):
         super().__init__()
         self.out_channels = in_channels
 
+    # def forward(self, x):
+    #     B, C, H, W = x.shape
+    #     # assert H == 1
+    #     x = x.squeeze(dim=2)
+    #     # x = x.transpose([0, 2, 1])  # paddle (NTC)(batch, width, channels)
+    #     x = x.permute(0, 2, 1)
+    #     return x
+
     def forward(self, x):
         B, C, H, W = x.shape
-        # assert H == 1
-        x = x.squeeze(dim=2)
-        # x = x.transpose([0, 2, 1])  # paddle (NTC)(batch, width, channels)
-        x = x.permute(0, 2, 1)
-        return x
+        # 处理四维张量,将空间维度展平为序列
+        if H == 1:
+            # 原来的处理逻辑,适用于H=1的情况
+            x = x.squeeze(dim=2)
+            x = x.permute(0, 2, 1)  # (B, W, C)
+        else:
+            # 处理H不为1的情况
+            x = x.permute(0, 2, 3, 1)  # (B, H, W, C)
+            x = x.reshape(B, H * W, C)  # (B, H*W, C)
 
+        return x
 
 class EncoderWithRNN_(nn.Module):
     def __init__(self, in_channels, hidden_size):

+ 26 - 0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/arch_config.yaml

@@ -212,6 +212,32 @@ ch_PP-OCRv4_rec_server_doc_infer:
           nrtr_dim: 384
           max_text_length: 25
 
+ch_PP-OCRv5_rec_server_infer:
+  model_type: rec
+  algorithm: SVTR_HGNet
+  Transform:
+  Backbone:
+    name: PPHGNetV2_B4
+    text_rec: True
+  Head:
+    name: MultiHead
+    out_channels_list:
+      CTCLabelDecode: 18385
+    head_list:
+      - CTCHead:
+          Neck:
+            name: svtr
+            dims: 120
+            depth: 2
+            hidden_dims: 120
+            kernel_size: [ 1, 3 ]
+            use_guide: True
+          Head:
+            fc_decay: 0.00001
+      - NRTRHead:
+          nrtr_dim: 384
+          max_text_length: 25
+
 ch_PP-OCRv5_rec_infer:
   model_type: rec
   algorithm: SVTR_HGNet

+ 8 - 4
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/models_config.yml

@@ -1,14 +1,18 @@
 lang:
+  ch_lite:
+    det: ch_PP-OCRv3_det_infer.pth
+    rec: ch_PP-OCRv5_rec_infer.pth
+    dict: ppocrv5_dict.txt
   ch_lite_v4:
     det: ch_PP-OCRv3_det_infer.pth
     rec: ch_PP-OCRv4_rec_infer.pth
     dict: ppocr_keys_v1.txt
-  ch_lite:
-    det: ch_PP-OCRv5_det_infer.pth
-    rec: ch_PP-OCRv5_rec_infer.pth
-    dict: ppocrv5_dict.txt
   ch_server:
     det: ch_PP-OCRv3_det_infer.pth
+    rec: ch_PP-OCRv5_rec_server_infer.pth
+    dict: ppocrv5_dict.txt
+  ch_server_v4:
+    det: ch_PP-OCRv3_det_infer.pth
     rec: ch_PP-OCRv4_rec_server_infer.pth
     dict: ppocr_keys_v1.txt
   ch: