瀏覽代碼

add ppocrv5 rec

zhangyubo0722 6 月之前
父節點
當前提交
648262b857

+ 39 - 0
paddlex/configs/modules/text_recognition/PP-OCRv5_server_rec.yaml

@@ -0,0 +1,39 @@
+Global:
+  model: PP-OCRv5_server_rec
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "/paddle/dataset/paddlex/ocr_rec/ocr_rec_dataset_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert: 
+    enable: False
+    src_dataset_type: null
+  split: 
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  epochs_iters: 20
+  batch_size: 8
+  learning_rate: 0.001
+  pretrain_weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-OCRv5_server_rec_pretrained.pdparams
+  resume_path: null
+  log_interval: 20
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: "output/best_accuracy/best_accuracy.pdparams"
+  log_interval: 1
+
+Export:
+  weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-OCRv5_server_rec_pretrained.pdparams
+
+Predict:
+  batch_size: 1
+  model_dir: "output/best_accuracy/inference"
+  input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_rec_001.png"
+  kernel_option:
+    run_mode: paddle

+ 2 - 0
paddlex/inference/utils/official_models.py

@@ -338,6 +338,8 @@ PP-LCNet_x1_0_vehicle_attribute_infer.tar",
     "PP-DocBee-2B": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0.0/PP-DocBee-2B_infer.tar",
     "PP-DocBee-7B": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0.0/PP-DocBee-7B_infer.tar",
     "PP-Chart2Table": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0.0/PP-Chart2Table_infer.tar",
+    "PP-OCRv5_server_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0.0/\
+PP-OCRv5_server_rec_infer.tar",
 }
 
 

+ 1 - 0
paddlex/modules/text_recognition/model_list.py

@@ -31,4 +31,5 @@ MODELS = [
     "PP-OCRv4_server_rec_doc",
     "ch_SVTRv2_rec",
     "ch_RepSVTR_rec",
+    "PP-OCRv5_server_rec",
 ]

+ 135 - 0
paddlex/repo_apis/PaddleOCR_api/configs/PP-OCRv5_server_rec.yaml

@@ -0,0 +1,135 @@
+Global:
+  debug: false
+  use_gpu: true
+  epoch_num: 75
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/PP-OCRv5_server_rec
+  save_epoch_step: 1
+  eval_batch_step: [0, 2000]
+  cal_metric_during_train: true
+  calc_epoch_interval: 1
+  pretrained_model: 
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: false
+  infer_img: doc/imgs_words/ch/word_1.jpg
+  character_dict_path: ./ppocr/utils/dict/ppocrv5_dict.txt
+  max_text_length: &max_text_length 25
+  infer_mode: false
+  use_space_char: true
+  distributed: true
+  save_res_path: ./output/rec/predicts_ppocrv5.txt
+  d2s_train_image_shape: [3, 48, 320]
+
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.0005
+    warmup_epoch: 1
+  regularizer:
+    name: L2
+    factor: 3.0e-05
+
+
+Architecture:
+  model_type: rec
+  algorithm: SVTR_HGNet
+  Transform:
+  Backbone:
+    name: PPHGNetV2_B4
+    text_rec: True
+  Head:
+    name: MultiHead
+    head_list:
+      - CTCHead:
+          Neck:
+            name: svtr
+            dims: 120
+            depth: 2
+            hidden_dims: 120
+            kernel_size: [1, 3]
+            use_guide: True
+          Head:
+            fc_decay: 0.00001
+      - NRTRHead:
+          nrtr_dim: 384
+          max_text_length: *max_text_length
+
+Loss:
+  name: MultiLoss
+  loss_config_list:
+    - CTCLoss:
+    - NRTRLoss:
+
+PostProcess:  
+  name: CTCLabelDecode
+
+Metric:
+  name: RecMetric
+  main_indicator: acc
+
+Train:
+  dataset:
+    name: MultiScaleDataSet
+    ds_width: false
+    data_dir: ./train_data/
+    ext_op_transform_idx: 1
+    label_file_list:
+    - ./train_data/train_list.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - RecAug:
+    - MultiLabelEncode:
+        gtc_encode: NRTRLabelEncode
+    - KeepKeys:
+        keep_keys:
+        - image
+        - label_ctc
+        - label_gtc
+        - length
+        - valid_ratio
+  sampler:
+    name: MultiScaleSampler
+    scales: [[320, 32], [320, 48], [320, 64]]
+    first_bs: &bs 64
+    fix_bs: false
+    divided_factor: [8, 16] # w, h
+    is_training: True
+  loader:
+    shuffle: true
+    batch_size_per_card: *bs
+    drop_last: true
+    num_workers: 16
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data
+    label_file_list:
+    - ./train_data/val_list.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - MultiLabelEncode:
+        gtc_encode: NRTRLabelEncode
+    - RecResizeImg:
+        image_shape: [3, 48, 320]
+    - KeepKeys:
+        keep_keys:
+        - image
+        - label_ctc
+        - label_gtc
+        - length
+        - valid_ratio
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 128
+    num_workers: 4

+ 9 - 0
paddlex/repo_apis/PaddleOCR_api/text_rec/register.py

@@ -196,3 +196,12 @@ register_model_info(
         "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
     }
 )
+
+register_model_info(
+    {
+        "model_name": "PP-OCRv5_server_rec",
+        "suite": "TextRec",
+        "config_path": osp.join(PDX_CONFIG_DIR, "PP-OCRv5_server_rec.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export"],
+    }
+)