# UniMERNet.yaml: Global, Optimizer, Architecture, Loss, PostProcess, Metric,
# Train and Eval settings for the UniMERNet recognition model.
  1. Global:
  2. use_gpu: True
  3. epoch_num: 40
  4. log_smooth_window: 10
  5. print_batch_step: 10
  6. save_model_dir: ./output/rec/unimernet/
  7. save_epoch_step: 5
  8. # evaluation is run every 37880 iterations after the 0th iteration
  9. eval_batch_step: [0, 10]
  10. cal_metric_during_train: True
  11. pretrained_model:
  12. checkpoints:
  13. save_inference_dir:
  14. use_visualdl: False
  15. infer_img: doc/datasets/pme_demo/0000013.png
  16. infer_mode: False
  17. use_space_char: False
  18. rec_char_dict_path: &rec_char_dict_path ppocr/utils/dict/unimernet_tokenizer
  19. input_size: &input_size [192, 672]
  20. max_seq_len: &max_seq_len 1024
  21. save_res_path: ./output/rec/predicts_unimernet.txt
  22. allow_resize_largeImg: False
  23. d2s_train_image_shape: [1,192,672]
  24. Optimizer:
  25. name: AdamW
  26. beta1: 0.9
  27. beta2: 0.999
  28. weight_decay: 0.05
  29. lr:
  30. name: LinearWarmupCosine
  31. learning_rate: 1e-4
  32. start_lr: 1e-5
  33. min_lr: 1e-8
  34. warmup_steps: 5000
  35. Architecture:
  36. model_type: rec
  37. algorithm: UniMERNet
  38. in_channels: 3
  39. Transform:
  40. Backbone:
  41. name: DonutSwinModel
  42. hidden_size : 1024
  43. num_layers: 4
  44. num_heads: [4, 8, 16, 32]
  45. add_pooling_layer: True
  46. use_mask_token: False
  47. Head:
  48. name: UniMERNetHead
  49. max_new_tokens: 1536
  50. decoder_start_token_id: 0
  51. temperature: 0.2
  52. do_sample: False
  53. top_p: 0.95
  54. encoder_hidden_size: 1024
  55. is_export: False
  56. length_aware: True
  57. Loss:
  58. name: UniMERNetLoss
  59. PostProcess:
  60. name: UniMERNetDecode
  61. rec_char_dict_path: *rec_char_dict_path
  62. Metric:
  63. name: LaTeXOCRMetric
  64. main_indicator: exp_rate
  65. cal_bleu_score: True
  66. Train:
  67. dataset:
  68. name: SimpleDataSet
  69. data_dir: ./train_data/UniMERNet/
  70. label_file_list: ["./train_data/UniMERNet/train_unimernet_1M.txt"]
  71. transforms:
  72. - UniMERNetImgDecode:
  73. input_size: *input_size
  74. - UniMERNetTrainTransform:
  75. - UniMERNetImageFormat:
  76. - UniMERNetLabelEncode:
  77. rec_char_dict_path: *rec_char_dict_path
  78. max_seq_len: *max_seq_len
  79. - KeepKeys:
  80. keep_keys: ['image', 'label', 'attention_mask']
  81. loader:
  82. shuffle: False
  83. drop_last: False
  84. batch_size_per_card: 7
  85. num_workers: 0
  86. collate_fn: UniMERNetCollator
  87. Eval:
  88. dataset:
  89. name: SimpleDataSet
  90. data_dir: ./train_data/UniMERNet/UniMER-Test/cpe
  91. label_file_list: ["./train_data/UniMERNet/test_unimernet_cpe.txt"]
  92. transforms:
  93. - UniMERNetImgDecode:
  94. input_size: *input_size
  95. - UniMERNetTestTransform:
  96. - UniMERNetImageFormat:
  97. - UniMERNetLabelEncode:
  98. max_seq_len: *max_seq_len
  99. rec_char_dict_path: *rec_char_dict_path
  100. - KeepKeys:
  101. keep_keys: ['image', 'label', 'attention_mask']
  102. loader:
  103. shuffle: False
  104. drop_last: False
  105. batch_size_per_card: 30
  106. num_workers: 0
  107. collate_fn: UniMERNetCollator