|
|
@@ -107,42 +107,50 @@ class FormulaRecConfig(BaseConfig):
|
|
|
else:
|
|
|
raise ValueError(f"{repr(dataset_type)} is not supported.")
|
|
|
|
|
|
- def update_batch_size(
|
|
|
- self, batch_size_train: int, batch_size_val: int, mode: str = "train"
|
|
|
- ):
|
|
|
- """update batch size setting
|
|
|
+ def update_batch_size(self, batch_size: int, mode: str = "train"):
|
|
|
+ """update batch size setting for SimpleDataSet
|
|
|
|
|
|
Args:
|
|
|
batch_size (int): the batch size number to set.
|
|
|
- mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval', 'test'.
|
|
|
+ mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval'
|
|
|
Defaults to 'train'.
|
|
|
|
|
|
Raises:
|
|
|
- ValueError: mode error.
|
|
|
+ ValueError: `mode` error.
|
|
|
"""
|
|
|
|
|
|
- _cfg = {
|
|
|
- "Train.loader.batch_size_per_card": batch_size_train,
|
|
|
- "Eval.loader.batch_size_per_card": batch_size_val,
|
|
|
- }
|
|
|
+ if mode == "train":
|
|
|
+ _cfg = {
|
|
|
+ "Train.loader.batch_size_per_card": batch_size,
|
|
|
+ }
|
|
|
+ elif mode == "eval":
|
|
|
+ _cfg = {
|
|
|
+ "Eval.loader.batch_size_per_card": batch_size,
|
|
|
+ }
|
|
|
+ else:
|
|
|
+ raise ValueError("The input `mode` should be train or eval.")
|
|
|
self.update(_cfg)
|
|
|
|
|
|
- def update_batch_size_pair(
|
|
|
- self, batch_size_train: int, batch_size_val: int, mode: str = "train"
|
|
|
- ):
|
|
|
- """update batch size setting
|
|
|
+ def update_batch_size_pair(self, batch_size: int, mode: str = "train"):
|
|
|
+ """update batch size setting for LaTeXOCRDataSet
|
|
|
+
|
|
|
Args:
|
|
|
batch_size (int): the batch size number to set.
|
|
|
- mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval', 'test'.
|
|
|
+ mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval'
|
|
|
Defaults to 'train'.
|
|
|
+
|
|
|
Raises:
|
|
|
- ValueError: mode error.
|
|
|
+ ValueError: `mode` error.
|
|
|
"""
|
|
|
- _cfg = {
|
|
|
- "Train.dataset.batch_size_per_pair": batch_size_train,
|
|
|
- "Eval.dataset.batch_size_per_pair": batch_size_val,
|
|
|
- }
|
|
|
|
|
|
+ if mode == "train":
|
|
|
+ _cfg = {
|
|
|
+ "Train.dataset.batch_size_per_pair": batch_size,
|
|
|
+ }
|
|
|
+ elif mode == "eval":
|
|
|
+ _cfg = {"Eval.dataset.batch_size_per_pair": batch_size}
|
|
|
+ else:
|
|
|
+ raise ValueError("The input `mode` should be train or eval.")
|
|
|
self.update(_cfg)
|
|
|
|
|
|
def update_learning_rate(self, learning_rate: float):
|
|
|
@@ -376,6 +384,24 @@ class FormulaRecConfig(BaseConfig):
|
|
|
"""
|
|
|
self._update_eval_interval(eval_start_step, eval_interval)
|
|
|
|
|
|
+ def update_delimiter(self, delimiter: str, mode: str = "train"):
|
|
|
+ """update_delimiter
|
|
|
+
|
|
|
+ Args:
|
|
|
+ delimiter (str): the dataset delimiter value to set.
|
|
|
+ mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval'
|
|
|
+ Defaults to 'train'.
|
|
|
+ """
|
|
|
+ delimiter = delimiter.encode().decode("unicode_escape")
|
|
|
+
|
|
|
+ if mode == "train":
|
|
|
+ _cfg = {"Train.dataset.delimiter": delimiter}
|
|
|
+ elif mode == "eval":
|
|
|
+ _cfg = {"Eval.dataset.delimiter": delimiter}
|
|
|
+ else:
|
|
|
+ raise ValueError("The input `mode` should be train or eval.")
|
|
|
+ self.update(_cfg)
|
|
|
+
|
|
|
def _update_save_interval(self, save_interval: int):
|
|
|
"""update save interval(by steps)
|
|
|
|