Selaa lähdekoodia

fix ocr benchmark bug

zhouchangda 1 vuosi sitten
vanhempi
commit
4225538491
1 muutettua tiedostoa jossa 21 lisäystä ja 15 poistoa
  1. 21 15
      paddlex/repo_apis/PaddleOCR_api/text_rec/config.py

+ 21 - 15
paddlex/repo_apis/PaddleOCR_api/text_rec/config.py

@@ -79,7 +79,8 @@ class TextRecConfig(BaseConfig):
             train_list_path = f"{train_list_path}"
         else:
             train_list_path = os.path.join(dataset_path, 'train.txt')
-        if (dataset_type == 'TextRecDataset') or (dataset_type=="MSTextRecDataset"):
+        if (dataset_type == 'TextRecDataset') or (
+                dataset_type == "MSTextRecDataset"):
             _cfg = {
                 'Train.dataset.name': dataset_type,
                 'Train.dataset.data_dir': dataset_path,
@@ -94,18 +95,20 @@ class TextRecConfig(BaseConfig):
             self.update(_cfg)
         elif dataset_type == "LaTeXOCRDataSet":
             _cfg = {
-                    'Train.dataset.name': dataset_type,
-                    'Train.dataset.data_dir': dataset_path,
-                    'Train.dataset.data': os.path.join(dataset_path, "latexocr_train.pkl"),
-                    'Train.dataset.label_file_list': [train_list_path],
-                    'Eval.dataset.name': dataset_type,
-                    'Eval.dataset.data_dir': dataset_path,
-                    'Eval.dataset.data': os.path.join(dataset_path, "latexocr_val.pkl"),
-                    'Eval.dataset.label_file_list':
-                    [os.path.join(dataset_path, 'val.txt')],
-                    'Global.character_dict_path':
-                    os.path.join(dataset_path, 'dict.txt')
-                }
+                'Train.dataset.name': dataset_type,
+                'Train.dataset.data_dir': dataset_path,
+                'Train.dataset.data':
+                os.path.join(dataset_path, "latexocr_train.pkl"),
+                'Train.dataset.label_file_list': [train_list_path],
+                'Eval.dataset.name': dataset_type,
+                'Eval.dataset.data_dir': dataset_path,
+                'Eval.dataset.data':
+                os.path.join(dataset_path, "latexocr_val.pkl"),
+                'Eval.dataset.label_file_list':
+                [os.path.join(dataset_path, 'val.txt')],
+                'Global.character_dict_path':
+                os.path.join(dataset_path, 'dict.txt')
+            }
             self.update(_cfg)
         else:
             raise ValueError(f"{repr(dataset_type)} is not supported.")
@@ -129,7 +132,10 @@ class TextRecConfig(BaseConfig):
             _cfg['Train.sampler.first_bs'] = batch_size
         self.update(_cfg)
 
-    def update_batch_size_pair(self, batch_size_train: int, batch_size_val: int, mode: str='train'):
+    def update_batch_size_pair(self,
+                               batch_size_train: int,
+                               batch_size_val: int,
+                               mode: str='train'):
         """update batch size setting
         Args:
             batch_size (int): the batch size number to set.
@@ -350,7 +356,7 @@ class TextRecConfig(BaseConfig):
             cal_metrics (bool): whether or not to calculate metrics during train
         """
         assert isinstance(cal_metrics, bool), "cal_metrics should be a bool"
-        self.update({'Global.cal_metric_during_train': f'{cal_metrics}'})
+        self.update({'Global.cal_metric_during_train': cal_metrics})
 
     def update_seed(self, seed: int):
         """update seed