|
|
@@ -79,7 +79,8 @@ class TextRecConfig(BaseConfig):
|
|
|
train_list_path = f"{train_list_path}"
|
|
|
else:
|
|
|
train_list_path = os.path.join(dataset_path, 'train.txt')
|
|
|
- if (dataset_type == 'TextRecDataset') or (dataset_type=="MSTextRecDataset"):
|
|
|
+ if (dataset_type == 'TextRecDataset') or (
|
|
|
+ dataset_type == "MSTextRecDataset"):
|
|
|
_cfg = {
|
|
|
'Train.dataset.name': dataset_type,
|
|
|
'Train.dataset.data_dir': dataset_path,
|
|
|
@@ -94,18 +95,20 @@ class TextRecConfig(BaseConfig):
|
|
|
self.update(_cfg)
|
|
|
elif dataset_type == "LaTeXOCRDataSet":
|
|
|
_cfg = {
|
|
|
- 'Train.dataset.name': dataset_type,
|
|
|
- 'Train.dataset.data_dir': dataset_path,
|
|
|
- 'Train.dataset.data': os.path.join(dataset_path, "latexocr_train.pkl"),
|
|
|
- 'Train.dataset.label_file_list': [train_list_path],
|
|
|
- 'Eval.dataset.name': dataset_type,
|
|
|
- 'Eval.dataset.data_dir': dataset_path,
|
|
|
- 'Eval.dataset.data': os.path.join(dataset_path, "latexocr_val.pkl"),
|
|
|
- 'Eval.dataset.label_file_list':
|
|
|
- [os.path.join(dataset_path, 'val.txt')],
|
|
|
- 'Global.character_dict_path':
|
|
|
- os.path.join(dataset_path, 'dict.txt')
|
|
|
- }
|
|
|
+ 'Train.dataset.name': dataset_type,
|
|
|
+ 'Train.dataset.data_dir': dataset_path,
|
|
|
+ 'Train.dataset.data':
|
|
|
+ os.path.join(dataset_path, "latexocr_train.pkl"),
|
|
|
+ 'Train.dataset.label_file_list': [train_list_path],
|
|
|
+ 'Eval.dataset.name': dataset_type,
|
|
|
+ 'Eval.dataset.data_dir': dataset_path,
|
|
|
+ 'Eval.dataset.data':
|
|
|
+ os.path.join(dataset_path, "latexocr_val.pkl"),
|
|
|
+ 'Eval.dataset.label_file_list':
|
|
|
+ [os.path.join(dataset_path, 'val.txt')],
|
|
|
+ 'Global.character_dict_path':
|
|
|
+ os.path.join(dataset_path, 'dict.txt')
|
|
|
+ }
|
|
|
self.update(_cfg)
|
|
|
else:
|
|
|
raise ValueError(f"{repr(dataset_type)} is not supported.")
|
|
|
@@ -129,7 +132,10 @@ class TextRecConfig(BaseConfig):
|
|
|
_cfg['Train.sampler.first_bs'] = batch_size
|
|
|
self.update(_cfg)
|
|
|
|
|
|
- def update_batch_size_pair(self, batch_size_train: int, batch_size_val: int, mode: str='train'):
|
|
|
+ def update_batch_size_pair(self,
|
|
|
+ batch_size_train: int,
|
|
|
+ batch_size_val: int,
|
|
|
+ mode: str='train'):
|
|
|
"""update batch size setting
|
|
|
Args:
|
|
|
batch_size (int): the batch size number to set.
|
|
|
@@ -350,7 +356,7 @@ class TextRecConfig(BaseConfig):
|
|
|
cal_metrics (bool): whether or not to calculate metrics during train
|
|
|
"""
|
|
|
assert isinstance(cal_metrics, bool), "cal_metrics should be a bool"
|
|
|
- self.update({'Global.cal_metric_during_train': f'{cal_metrics}'})
|
|
|
+ self.update({'Global.cal_metric_during_train': cal_metrics})
|
|
|
|
|
|
def update_seed(self, seed: int):
|
|
|
"""update seed
|