# SLANet_plus_paddleocr.yml
# SLANet_plus table-structure recognition config (PaddleOCR / PaddleX).
  Global:
    model_name: SLANet_plus  # To use static model for inference.
    use_gpu: false  # use CPU
    epoch_num: 8  # number of training epochs; can be increased as needed
    log_smooth_window: 10  # log smoothing window
    print_batch_step: 10  # print a log line every 10 batches
    save_model_dir: ./output/SLANet_plus
    save_epoch_step: 2  # save the model every 2 epochs
    # evaluation is run every 50 iterations after the 0th iteration
    # (more frequent evaluation)
    eval_batch_step: [0, 50]
    cal_metric_during_train: true
    pretrained_model: "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/SLANet_plus_pretrained.pdparams"
    checkpoints: null
    save_inference_dir: ./output/SLANet_plus/infer
    use_visualdl: false
    infer_img: paddlex/repo_manager/repos/PaddleOCR/ppstructure/docs/table/table.jpg
    # for data or label process
    character_dict_path: paddlex/repo_manager/repos/PaddleOCR/ppocr/utils/dict/table_structure_dict_ch.txt
    character_type: en
    # anchored: reused by Architecture.Head and the TableLabelEncode transforms
    max_text_length: &max_text_length 500
    # anchored: reused by Metric and the TableBoxEncode transforms
    box_format: &box_format xyxyxyxy  # 'xywh', 'xyxy', 'xyxyxyxy'
    infer_mode: false
    use_sync_bn: false  # whether to use synchronized BN; must be off on Mac
    save_res_path: output/infer
    d2s_train_image_shape: [3, 488, 488]
  Optimizer:
    name: Adam
    beta1: 0.9
    beta2: 0.999
    clip_norm: 5.0
    lr:
      learning_rate: 0.0008  # fine-tuning learning rate
    regularizer:
      name: 'L2'
      factor: 0.00000
  Architecture:
    model_type: table
    algorithm: SLANet
    Backbone:
      name: PPLCNet
      scale: 1.0
      pretrained: true
      use_ssld: true
    Neck:
      name: CSPPAN
      out_channels: 96
    Head:
      name: SLAHead
      hidden_size: 256
      max_text_length: *max_text_length
      # anchored: reused by Metric and the TableLabelEncode transforms.
      # NOTE(review): presumably must equal the 8 coordinates of the
      # xyxyxyxy box_format — confirm before changing either value.
      loc_reg_num: &loc_reg_num 8
  Loss:
    name: SLALoss
    structure_weight: 1.0
    loc_weight: 2.0
    loc_loss: smooth_l1
  PostProcess:
    name: TableLabelDecode
    # anchored: the TableLabelEncode transforms in Train/Eval reuse this value
    merge_no_span_structure: &merge_no_span_structure true
  Metric:
    name: TableMetric
    main_indicator: acc
    compute_bbox_metric: false
    loc_reg_num: *loc_reg_num
    box_format: *box_format
    del_thead_tbody: true
  Train:
    dataset:
      name: PubTabDataSet
      data_dir: dataset/table_rec_dataset_examples/  # verified dataset path
      label_file_list: [dataset/table_rec_dataset_examples/train.txt]  # verified label file
      transforms:
        - DecodeImage:
            img_mode: BGR
            channel_first: false
        - TableLabelEncode:
            learn_empty_box: false
            merge_no_span_structure: *merge_no_span_structure
            replace_empty_cell_token: false
            loc_reg_num: *loc_reg_num
            max_text_length: *max_text_length
        - TableBoxEncode:
            in_box_format: *box_format
            out_box_format: *box_format
        - ResizeTableImage:
            max_len: 488
        - NormalizeImage:
            scale: '1./255.'  # deliberately a string expression, not a float
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: 'hwc'
        - PaddingTableImage:
            size: [488, 488]
        - ToCHWImage: null
        - KeepKeys:
            keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape']
    loader:
      shuffle: true
      batch_size_per_card: 32  # smaller batch size for more stable training
      drop_last: true  # drop the incomplete final batch
      num_workers: 2  # fewer workers to avoid excessive CPU load
  Eval:
    dataset:
      name: PubTabDataSet
      data_dir: dataset/table_rec_dataset_examples/  # verified dataset path
      label_file_list: [dataset/table_rec_dataset_examples/val.txt]  # verified label file
      ratio_list: [1.0]  # must have the same length as label_file_list
      transforms:
        - DecodeImage:
            img_mode: BGR
            channel_first: false
        - TableLabelEncode:
            learn_empty_box: false
            merge_no_span_structure: *merge_no_span_structure
            replace_empty_cell_token: false
            loc_reg_num: *loc_reg_num
            max_text_length: *max_text_length
        - TableBoxEncode:
            in_box_format: *box_format
            out_box_format: *box_format
        - ResizeTableImage:
            max_len: 488
        - NormalizeImage:
            scale: '1./255.'  # deliberately a string expression, not a float
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: 'hwc'
        - PaddingTableImage:
            size: [488, 488]
        - ToCHWImage: null
        - KeepKeys:
            keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks', 'length', 'shape']
    loader:
      shuffle: false
      drop_last: false
      batch_size_per_card: 48
      num_workers: 1