
repair formula module to support cl train (#3240)

liuhongen1234567 committed 9 months ago
commit 1c215dd494

+ 2 - 2
paddlex/configs/modules/formula_recognition/LaTeX_OCR_rec.yaml

@@ -16,8 +16,7 @@ CheckDataset:
 
 Train:
   epochs_iters: 20
-  batch_size_train: 30
-  batch_size_val: 10
+  batch_size: 30
   learning_rate: 0.0001
   pretrain_weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/LaTeX_OCR_rec_pretrained.pdparams
   resume_path: null
@@ -26,6 +25,7 @@ Train:
   save_interval: 1
 
 Evaluate:
+  batch_size: 10
   weight_path: output/best_accuracy/best_accuracy.pdparams
   log_interval: 1
 
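Note: the per-split keys batch_size_train / batch_size_val collapse into a single batch_size under each of Train: and Evaluate: (same pattern in the three module configs below). A minimal sketch of how the new layout reads, using plain PyYAML rather than the PaddleX loader, with illustrative values:

import yaml  # PyYAML, assumed installed

cfg_text = """
Train:
  batch_size: 30
Evaluate:
  batch_size: 10
"""

cfg = yaml.safe_load(cfg_text)
train_bs = cfg["Train"]["batch_size"]    # consumed by the trainer (mode="train")
eval_bs = cfg["Evaluate"]["batch_size"]  # consumed by the evaluator (mode="eval")
print(train_bs, eval_bs)                 # -> 30 10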

+ 2 - 2
paddlex/configs/modules/formula_recognition/PP-FormulaNet-L.yaml

@@ -16,8 +16,7 @@ CheckDataset:
 
 Train:
   epochs_iters: 20
-  batch_size_train: 5
-  batch_size_val: 5
+  batch_size: 5
   learning_rate: 0.0001
   pretrain_weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-FormulaNet-L_pretrained.pdparams
   resume_path: null
@@ -26,6 +25,7 @@ Train:
   save_interval: 1
 
 Evaluate:
+  batch_size: 5
   weight_path: output/best_accuracy/best_accuracy.pdparams
   log_interval: 1
 

+ 2 - 2
paddlex/configs/modules/formula_recognition/PP-FormulaNet-S.yaml

@@ -16,8 +16,7 @@ CheckDataset:
 
 Train:
   epochs_iters: 20
-  batch_size_train: 30
-  batch_size_val: 10
+  batch_size: 14
   learning_rate: 0.0001
   pretrain_weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-FormulaNet-S_pretrained.pdparams
   resume_path: null
@@ -26,6 +25,7 @@ Train:
   save_interval: 1
 
 Evaluate:
+  batch_size: 20
   weight_path: output/best_accuracy/best_accuracy.pdparams
   log_interval: 1
 

+ 2 - 2
paddlex/configs/modules/formula_recognition/UniMERNet.yaml

@@ -16,8 +16,7 @@ CheckDataset:
 
 Train:
   epochs_iters: 20
-  batch_size_train: 7
-  batch_size_val: 20
+  batch_size: 7
   learning_rate: 0.0001
   pretrain_weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/UniMERNet_pretrained.pdparams
   resume_path: null
@@ -26,6 +25,7 @@ Train:
   save_interval: 1
 
 Evaluate:
+  batch_size: 20
   weight_path: output/best_accuracy/best_accuracy.pdparams
   log_interval: 1
 

+ 13 - 0
paddlex/modules/formula_recognition/evaluator.py

@@ -52,6 +52,19 @@ class FormulaRecEvaluator(BaseEvaluator):
         if label_dict_path is not None:
             self.pdx_config.update_label_dict_path(label_dict_path)
 
+        if self.eval_config.batch_size is not None:
+            if self.global_config["model"] == "LaTeX_OCR_rec":
+                self.pdx_config.update_batch_size_pair(
+                    self.eval_config.batch_size, mode="eval"
+                )
+            else:
+                self.pdx_config.update_batch_size(
+                    self.eval_config.batch_size, mode="eval"
+                )
+
+        if self.eval_config.get("delimiter", None) is not None:
+            self.pdx_config.update_delimiter(self.eval_config.delimiter, mode="eval")
+
     def get_eval_kwargs(self) -> dict:
         """get key-value arguments of model evalution function
 
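The evaluator now routes Evaluate.batch_size to the dataset-appropriate setter: LaTeX_OCR_rec uses the pair-based LaTeXOCRDataSet key, the other formula models use the standard loader key. A standalone sketch of that dispatch (the real code lives in update_config above; pdx_config stands in for the FormulaRecConfig instance):

def apply_eval_batch_size(pdx_config, model_name: str, batch_size) -> None:
    # Skip when the Evaluate section does not set a batch size.
    if batch_size is None:
        return
    if model_name == "LaTeX_OCR_rec":
        # LaTeXOCRDataSet batches pairs -> Eval.dataset.batch_size_per_pair
        pdx_config.update_batch_size_pair(batch_size, mode="eval")
    else:
        # SimpleDataSet loader -> Eval.loader.batch_size_per_card
        pdx_config.update_batch_size(batch_size, mode="eval")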

+ 22 - 12
paddlex/modules/formula_recognition/trainer.py

@@ -70,25 +70,35 @@ class FormulaRecTrainer(BaseTrainer):
                 self.train_config.pretrain_weight_path
             )
 
-        if self.global_config["model"] == "LaTeX_OCR_rec":
-            if (
-                self.train_config.batch_size_train is not None
-                and self.train_config.batch_size_val is not None
-            ):
+        if self.train_config.batch_size is not None:
+            if self.global_config["model"] == "LaTeX_OCR_rec":
                 self.pdx_config.update_batch_size_pair(
-                    self.train_config.batch_size_train, self.train_config.batch_size_val
+                    self.train_config.batch_size, mode="train"
                 )
-        else:
-            if (
-                self.train_config.batch_size_train is not None
-                and self.train_config.batch_size_val is not None
-            ):
+            else:
                 self.pdx_config.update_batch_size(
-                    self.train_config.batch_size_train, self.train_config.batch_size_val
+                    self.train_config.batch_size, mode="train"
+                )
+
+        if self.eval_config.batch_size is not None:
+            if self.global_config["model"] == "LaTeX_OCR_rec":
+                self.pdx_config.update_batch_size_pair(
+                    self.eval_config.batch_size, mode="eval"
+                )
+            else:
+                self.pdx_config.update_batch_size(
+                    self.eval_config.batch_size, mode="eval"
                 )
 
         if self.train_config.learning_rate is not None:
             self.pdx_config.update_learning_rate(self.train_config.learning_rate)
+
+        if self.train_config.get("delimiter", None) is not None:
+            self.pdx_config.update_delimiter(self.train_config.delimiter, mode="train")
+
+        if self.eval_config.get("delimiter", None) is not None:
+            self.pdx_config.update_delimiter(self.eval_config.delimiter, mode="eval")
+
         if self.train_config.epochs_iters is not None:
             self.pdx_config._update_epochs(self.train_config.epochs_iters)
         if (
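During training setup, both sections are now applied: Train.batch_size drives the training loader and Evaluate.batch_size the in-training evaluation, while the new delimiter keys are read with .get() so existing configs without them remain valid. A small sketch of that optional-key guard, using dict stand-ins with hypothetical values:

# Dict stand-ins for the Train / Evaluate config sections; "delimiter" is optional.
train_config = {"batch_size": 30, "delimiter": "\\t"}
eval_config = {"batch_size": 10}  # no delimiter -> update_delimiter is skipped

for section, mode in ((train_config, "train"), (eval_config, "eval")):
    if section.get("delimiter", None) is not None:
        print(f"would call update_delimiter({section['delimiter']!r}, mode={mode!r})")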

+ 2 - 2
paddlex/repo_apis/PaddleOCR_api/configs/LaTeX_OCR_rec.yaml

@@ -9,7 +9,7 @@ Global:
   # evaluation is run every 60000 iterations (22 epoch)(batch_size = 56)
   eval_batch_step: [0, 60000]
   cal_metric_during_train: True
-  pretrained_model: https://paddle-model-ecology.bj.bcebos.com/pretrained/rec_latex_ocr_trained.pdparams
+  pretrained_model: 
   checkpoints:
   save_inference_dir:
   use_visualdl: False
@@ -66,7 +66,7 @@ PostProcess:
 Metric:
   name: LaTeXOCRMetric
   main_indicator:  exp_rate
-  cal_blue_score: False
+  cal_bleu_score: True
 
 Train:
   dataset:

+ 1 - 1
paddlex/repo_apis/PaddleOCR_api/configs/PP-FormulaNet-L.yaml

@@ -69,7 +69,7 @@ PostProcess:
 Metric:
   name: LaTeXOCRMetric
   main_indicator:  exp_rate
-  cal_blue_score: False
+  cal_bleu_score: True
 
 Train:
   dataset:

+ 1 - 1
paddlex/repo_apis/PaddleOCR_api/configs/PP-FormulaNet-S.yaml

@@ -67,7 +67,7 @@ PostProcess:
 Metric:
   name: LaTeXOCRMetric
   main_indicator:  exp_rate
-  cal_blue_score: False
+  cal_bleu_score: True
 
 Train:
   dataset:

+ 1 - 1
paddlex/repo_apis/PaddleOCR_api/configs/UniMERNet.yaml

@@ -66,7 +66,7 @@ PostProcess:
 Metric:
   name: LaTeXOCRMetric
   main_indicator:  exp_rate
-  cal_blue_score: False
+  cal_bleu_score: True
 
 Train:
   dataset:

+ 46 - 20
paddlex/repo_apis/PaddleOCR_api/formula_rec/config.py

@@ -107,42 +107,50 @@ class FormulaRecConfig(BaseConfig):
         else:
             raise ValueError(f"{repr(dataset_type)} is not supported.")
 
-    def update_batch_size(
-        self, batch_size_train: int, batch_size_val: int, mode: str = "train"
-    ):
-        """update batch size setting
+    def update_batch_size(self, batch_size: int, mode: str = "train"):
+        """update batch size setting for SimpleDataSet
 
         Args:
             batch_size (int): the batch size number to set.
-            mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval', 'test'.
+            mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval'
                 Defaults to 'train'.
 
         Raises:
-            ValueError: mode error.
+            ValueError: `mode` error.
         """
 
-        _cfg = {
-            "Train.loader.batch_size_per_card": batch_size_train,
-            "Eval.loader.batch_size_per_card": batch_size_val,
-        }
+        if mode == "train":
+            _cfg = {
+                "Train.loader.batch_size_per_card": batch_size,
+            }
+        elif mode == "eval":
+            _cfg = {
+                "Eval.loader.batch_size_per_card": batch_size,
+            }
+        else:
+            raise ValueError("The input `mode` should be train or eval.")
         self.update(_cfg)
 
-    def update_batch_size_pair(
-        self, batch_size_train: int, batch_size_val: int, mode: str = "train"
-    ):
-        """update batch size setting
+    def update_batch_size_pair(self, batch_size: int, mode: str = "train"):
+        """update batch size setting for LaTeXOCRDataSet
+
         Args:
             batch_size (int): the batch size number to set.
-            mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval', 'test'.
+            mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval'
                 Defaults to 'train'.
+
         Raises:
-            ValueError: mode error.
+            ValueError: `mode` error.
         """
-        _cfg = {
-            "Train.dataset.batch_size_per_pair": batch_size_train,
-            "Eval.dataset.batch_size_per_pair": batch_size_val,
-        }
 
+        if mode == "train":
+            _cfg = {
+                "Train.dataset.batch_size_per_pair": batch_size,
+            }
+        elif mode == "eval":
+            _cfg = {"Eval.dataset.batch_size_per_pair": batch_size}
+        else:
+            raise ValueError("The input `mode` should be train or eval.")
         self.update(_cfg)
 
     def update_learning_rate(self, learning_rate: float):
@@ -376,6 +384,24 @@ class FormulaRecConfig(BaseConfig):
         """
         self._update_eval_interval(eval_start_step, eval_interval)
 
+    def update_delimiter(self, delimiter: str, mode: str = "train"):
+        """update_delimiter
+
+        Args:
+            delimiter (str): the dataset delimiter value to set.
+            mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval'
+                Defaults to 'train'.
+        """
+        delimiter = delimiter.encode().decode("unicode_escape")
+
+        if mode == "train":
+            _cfg = {"Train.dataset.delimiter": delimiter}
+        elif mode == "eval":
+            _cfg = {"Eval.dataset.delimiter": delimiter}
+        else:
+            raise ValueError("The input `mode` should be train or eval.")
+        self.update(_cfg)
+
     def _update_save_interval(self, save_interval: int):
         """update save interval(by steps)