浏览代码

add SVTRv2 from PaddleOCR (#1877)

WangZ 1 年之前
父节点
当前提交
f9b97d06fc

+ 37 - 0
paddlex/configs/text_detection/RepSVTR_mobile_det.yaml

@@ -0,0 +1,37 @@
+Global:
+  model: RepSVTR_mobile_det
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  module: text_det
+  dataset_dir: "/paddle/dataset/paddlex/ocr_det/ocr_det_dataset_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert:
+    enable: False
+    src_dataset_type: null
+  split:
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  epochs_iters: 100
+  batch_size: 4
+  learning_rate: 0.001
+  pretrain_weight_path: null
+  resume_path: null
+  log_interval: 10
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: "output/best_accuracy.pdparams"
+  log_interval: 1
+
+Predict:
+  model_dir: "output/best_accuracy"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_001.png"
+  kernel_option:
+    run_mode: paddle
+    batch_size: 1

+ 36 - 0
paddlex/configs/text_recognition/RepSVTR_mobile_rec.yaml

@@ -0,0 +1,36 @@
+Global:
+  model: RepSVTR_mobile_rec
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "/paddle/dataset/paddlex/ocr_rec/ocr_rec_dataset_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert: 
+    enable: False
+    src_dataset_type: null
+  split: 
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  epochs_iters: 20
+  batch_size: 8
+  learning_rate: 0.001
+  pretrain_weight_path: null
+  resume_path: null
+  log_interval: 20
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: "output/best_accuracy.pdparams"
+  log_interval: 1
+
+Predict:
+  model_dir: "output/best_accuracy"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_rec_001.png"
+  kernel_option:
+    run_mode: paddle
+    batch_size: 1

+ 36 - 0
paddlex/configs/text_recognition/SVTRv2_server_rec.yaml

@@ -0,0 +1,36 @@
+Global:
+  model: SVTRv2_server_rec
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "/paddle/dataset/paddlex/ocr_rec/ocr_rec_dataset_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert: 
+    enable: False
+    src_dataset_type: null
+  split: 
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  epochs_iters: 20
+  batch_size: 8
+  learning_rate: 0.001
+  pretrain_weight_path: null
+  resume_path: null
+  log_interval: 20
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: "output/best_accuracy.pdparams"
+  log_interval: 1
+
+Predict:
+  model_dir: "output/best_accuracy"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_rec_001.png"
+  kernel_option:
+    run_mode: paddle
+    batch_size: 1

+ 9 - 0
paddlex/modules/base/predictor/utils/official_models.py

@@ -251,6 +251,15 @@ PP-OCRv4_server_det_infer.tar",
     "PP-OCRv4_mobile_det":
     "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/\
 PP-OCRv4_mobile_det_infer.tar",
+    "RepSVTR_mobile_det":
+    "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/\
+openatom_det_repsvtr_ch_infer.tar",
+    "RepSVTR_mobile_rec":
+    "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/\
+openatom_rec_repsvtr_ch_infer.tar",
+    "SVTRv2_server_rec":
+    "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/\
+openatom_rec_svtrv2_ch_infer.tar",
     "PicoDet_layout_1x":
     "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/PicoDet-L_layout_infer.tar",
     "SLANet":

+ 1 - 1
paddlex/modules/text_detection/model_list.py

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 MODELS = [
     'PP-OCRv4_mobile_det',
     'PP-OCRv4_server_det',
+    'RepSVTR_mobile_det',
 ]

+ 2 - 1
paddlex/modules/text_recognition/model_list.py

@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 MODELS = [
     'PP-OCRv4_mobile_rec',
     'PP-OCRv4_server_rec',
+    'SVTRv2_server_rec',
+    'RepSVTR_mobile_rec',
 ]

+ 169 - 0
paddlex/repo_apis/PaddleOCR_api/configs/RepSVTR_mobile_det.yaml

@@ -0,0 +1,169 @@
+Global:
+  debug: false
+  use_gpu: true
+  epoch_num: &epoch_num 500
+  log_smooth_window: 20
+  print_batch_step: 100
+  save_model_dir: ./output/det_repsvtr_db
+  save_epoch_step: 10
+  eval_batch_step:
+  - 0
+  - 1000
+  cal_metric_during_train: false
+  checkpoints:
+  pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/ch_SVTR_det_mobile_trained.pdparams
+  save_inference_dir: null
+  use_visualdl: false
+  infer_img: doc/imgs_en/img_10.jpg
+  save_res_path: ./checkpoints/det_db/predicts_db.txt
+  distributed: true
+
+Architecture:
+  model_type: det
+  algorithm: DB
+  Transform: null
+  Backbone:
+    name: RepSVTR_det
+  Neck:
+    name: RSEFPN
+    out_channels: 96
+    shortcut: True
+  Head:
+    name: DBHead
+    k: 50
+
+Loss:
+  name: DBLoss
+  balance_loss: true
+  main_loss_type: DiceLoss
+  alpha: 5
+  beta: 10
+  ohem_ratio: 3
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.001 #(8*8c)
+    warmup_epoch: 2
+  regularizer:
+    name: L2
+    factor: 5.0e-05
+
+PostProcess:
+  name: DBPostProcess
+  thresh: 0.3
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 1.5
+
+Metric:
+  name: DetMetric
+  main_indicator: hmean
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+    ratio_list: [1.0]
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - CopyPaste: null
+    - IaaAugment:
+        augmenter_args:
+        - type: Fliplr
+          args:
+            p: 0.5
+        - type: Affine
+          args:
+            rotate:
+            - -10
+            - 10
+        - type: Resize
+          args:
+            size:
+            - 0.5
+            - 3
+    - EastRandomCropData:
+        size:
+        - 640
+        - 640
+        max_tries: 50
+        keep_ratio: true
+    - MakeBorderMap:
+        shrink_ratio: 0.4
+        thresh_min: 0.3
+        thresh_max: 0.7
+        total_epoch: *epoch_num
+    - MakeShrinkMap:
+        shrink_ratio: 0.4
+        min_text_size: 8
+        total_epoch: *epoch_num
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - threshold_map
+        - threshold_mask
+        - shrink_map
+        - shrink_mask
+  loader:
+    shuffle: true
+    drop_last: false
+    batch_size_per_card: 8
+    num_workers: 8
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+      - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - DetResizeForTest:
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - shape
+        - polys
+        - ignore_tags
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 1
+    num_workers: 2
+profiler_options: null

+ 134 - 0
paddlex/repo_apis/PaddleOCR_api/configs/RepSVTR_mobile_rec.yaml

@@ -0,0 +1,134 @@
+Global:
+  debug: false
+  use_gpu: true
+  epoch_num: 200
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/rec_repsvtr_ch
+  save_epoch_step: 10
+  eval_batch_step: [0, 1000]
+  cal_metric_during_train: False
+  pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/ch_SVTRv2_rec_mobile_trained.pdparams
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: false
+  infer_img: doc/imgs_words/ch/word_1.jpg
+  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
+  max_text_length: &max_text_length 25
+  infer_mode: false
+  use_space_char: true
+  distributed: true
+  save_res_path: ./output/rec/predicts_repsvtr.txt
+
+Optimizer:
+  name: AdamW
+  beta1: 0.9
+  beta2: 0.999
+  epsilon: 1.e-8
+  weight_decay: 0.025
+  no_weight_decay_name: norm
+  one_dim_param_no_weight_decay: True
+  lr:
+    name: Cosine
+    learning_rate: 0.001 # 8gpus 192bs
+    warmup_epoch: 5
+
+
+Architecture:
+  model_type: rec
+  algorithm: SVTR_HGNet
+  Transform:
+  Backbone:
+    name: RepSVTR
+  Head:
+    name: MultiHead
+    head_list:
+      - CTCHead:
+          Neck:
+            name: svtr
+            dims: 256
+            depth: 2
+            hidden_dims: 256
+            kernel_size: [1, 3]
+            use_guide: True
+          Head:
+            fc_decay: 0.00001
+      - NRTRHead:
+          nrtr_dim: 384
+          max_text_length: *max_text_length
+          num_decoder_layers: 2
+
+Loss:
+  name: MultiLoss
+  loss_config_list:
+    - CTCLoss:
+    - NRTRLoss:
+
+PostProcess:  
+  name: CTCLabelDecode
+
+Metric:
+  name: RecMetric
+  main_indicator: acc
+
+
+Train:
+  dataset:
+    name: MultiScaleDataSet
+    ds_width: false
+    data_dir: ./train_data/
+    ext_op_transform_idx: 1
+    label_file_list:
+    - ./train_data/train_list.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - RecAug:
+    - MultiLabelEncode:
+        gtc_encode: NRTRLabelEncode
+    - KeepKeys:
+        keep_keys:
+        - image
+        - label_ctc
+        - label_gtc
+        - length
+        - valid_ratio
+  sampler:
+    name: MultiScaleSampler
+    scales: [[320, 32], [320, 48], [320, 64]]
+    first_bs: &bs 192
+    fix_bs: false
+    divided_factor: [8, 16] # w, h
+    is_training: True
+  loader:
+    shuffle: true
+    batch_size_per_card: *bs
+    drop_last: true
+    num_workers: 8
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data
+    label_file_list:
+    - ./train_data/val_list.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - MultiLabelEncode:
+        gtc_encode: NRTRLabelEncode
+    - RecResizeImg:
+        image_shape: [3, 48, 320]
+    - KeepKeys:
+        keep_keys:
+        - image
+        - label_ctc
+        - label_gtc
+        - length
+        - valid_ratio
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 128
+    num_workers: 4

+ 143 - 0
paddlex/repo_apis/PaddleOCR_api/configs/SVTRv2_server_rec.yaml

@@ -0,0 +1,143 @@
+Global:
+  debug: false
+  use_gpu: true
+  epoch_num: 200
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/rec_svtrv2_ch
+  save_epoch_step: 10
+  eval_batch_step: [0, 1000]
+  cal_metric_during_train: False
+  pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/ch_SVTRv2_rec_server_trained.pdparams
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: false
+  infer_img: doc/imgs_words/ch/word_1.jpg
+  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
+  max_text_length: &max_text_length 25
+  infer_mode: false
+  use_space_char: true
+  distributed: true
+  save_res_path: ./output/rec/predicts_svrtv2.txt
+
+
+Optimizer:
+  name: AdamW
+  beta1: 0.9
+  beta2: 0.999
+  epsilon: 1.e-8
+  weight_decay: 0.05
+  no_weight_decay_name: norm
+  one_dim_param_no_weight_decay: True
+  lr:
+    name: Cosine
+    learning_rate: 0.001 # 8gpus 192bs
+    warmup_epoch: 5
+
+
+Architecture:
+  model_type: rec
+  algorithm: SVTR_HGNet
+  Transform:
+  Backbone:
+    name: SVTRv2
+    use_pos_embed: False
+    dims: [128, 256, 384]
+    depths: [6, 6, 6]
+    num_heads: [4, 8, 12]
+    mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','Global','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+    local_k: [[5, 5], [5, 5], [-1, -1]]
+    sub_k: [[2, 1], [2, 1], [-1, -1]]
+    last_stage: False
+    use_pool: True
+  Head:
+    name: MultiHead
+    head_list:
+      - CTCHead:
+          Neck:
+            name: svtr
+            dims: 256
+            depth: 2
+            hidden_dims: 256
+            kernel_size: [1, 3]
+            use_guide: True
+          Head:
+            fc_decay: 0.00001
+      - NRTRHead:
+          nrtr_dim: 384
+          max_text_length: *max_text_length
+          num_decoder_layers: 2
+
+Loss:
+  name: MultiLoss
+  loss_config_list:
+    - CTCLoss:
+    - NRTRLoss:
+
+PostProcess:  
+  name: CTCLabelDecode
+
+Metric:
+  name: RecMetric
+  main_indicator: acc
+
+Train:
+  dataset:
+    name: MultiScaleDataSet
+    ds_width: false
+    data_dir: ./train_data/
+    ext_op_transform_idx: 1
+    label_file_list:
+    - ./train_data/train_list.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - RecAug:
+    - MultiLabelEncode:
+        gtc_encode: NRTRLabelEncode
+    - KeepKeys:
+        keep_keys:
+        - image
+        - label_ctc
+        - label_gtc
+        - length
+        - valid_ratio
+  sampler:
+    name: MultiScaleSampler
+    scales: [[320, 32], [320, 48], [320, 64]]
+    first_bs: &bs 192
+    fix_bs: false
+    divided_factor: [8, 16] # w, h
+    is_training: True
+  loader:
+    shuffle: true
+    batch_size_per_card: *bs
+    drop_last: true
+    num_workers: 8
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data
+    label_file_list:
+    - ./train_data/val_list.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - MultiLabelEncode:
+        gtc_encode: NRTRLabelEncode
+    - RecResizeImg:
+        image_shape: [3, 48, 320]
+    - KeepKeys:
+        keep_keys:
+        - image
+        - label_ctc
+        - label_gtc
+        - length
+        - valid_ratio
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 128
+    num_workers: 4

+ 7 - 0
paddlex/repo_apis/PaddleOCR_api/text_det/register.py

@@ -45,3 +45,10 @@ register_model_info({
     'config_path': osp.join(PDX_CONFIG_DIR, 'PP-OCRv4_server_det.yaml'),
     'supported_apis': ['train', 'evaluate', 'predict', 'export']
 })
+
+register_model_info({
+    'model_name': 'RepSVTR_mobile_det',
+    'suite': 'TextDet',
+    'config_path': osp.join(PDX_CONFIG_DIR, 'RepSVTR_mobile_det.yaml'),
+    'supported_apis': ['train', 'evaluate', 'predict', 'export']
+})

+ 14 - 0
paddlex/repo_apis/PaddleOCR_api/text_rec/register.py

@@ -44,3 +44,17 @@ register_model_info({
     'config_path': osp.join(PDX_CONFIG_DIR, 'PP-OCRv4_server_rec.yaml'),
     'supported_apis': ['train', 'evaluate', 'predict', 'export']
 })
+
+register_model_info({
+    'model_name': 'SVTRv2_server_rec',
+    'suite': 'TextRec',
+    'config_path': osp.join(PDX_CONFIG_DIR, 'SVTRv2_server_rec.yaml'),
+    'supported_apis': ['train', 'evaluate', 'predict', 'export', 'infer']
+})
+
+register_model_info({
+    'model_name': 'RepSVTR_mobile_rec',
+    'suite': 'TextRec',
+    'config_path': osp.join(PDX_CONFIG_DIR, 'RepSVTR_mobile_rec.yaml'),
+    'supported_apis': ['train', 'evaluate', 'predict', 'export', 'infer']
+})