Przeglądaj źródła

modify dataset num_samples

jiangjiajun 5 lat temu
rodzic
commit
6baed993fb
2 zmienionych plików z 15 dodań i 6 usunięć
  1. 8 0
      paddlex/cv/datasets/dataset.py
  2. 7 6
      paddlex/cv/models/base.py

+ 8 - 0
paddlex/cv/datasets/dataset.py

@@ -254,3 +254,11 @@ class Dataset:
             buffer_size=self.buffer_size,
             batch_size=batch_size,
             drop_last=drop_last)
+
+    def set_num_samples(self, num_samples):
+        if num_samples > len(self.file_list):
+            logging.warning(
+                "You want set num_samples to {}, but your dataset only has {} samples, so we will keep your dataset num_samples as {}"
+                .format(num_samples, len(self.file_list), len(self.file_list)))
+            num_samples = len(self.file_list)
+        self.num_samples = num_samples

+ 7 - 6
paddlex/cv/models/base.py

@@ -417,7 +417,7 @@ class BaseAPI:
             earlystop = EarlyStop(early_stop_patience, thresh)
         best_accuracy_key = ""
         best_accuracy = -1.0
-        best_model_epoch = 1
+        best_model_epoch = -1
         for i in range(num_epochs):
             records = list()
             step_start_time = time.time()
@@ -490,7 +490,7 @@ class BaseAPI:
                 current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
                 if not osp.isdir(current_save_dir):
                     os.makedirs(current_save_dir)
-                if eval_dataset is not None:
+                if eval_dataset is not None and eval_dataset.num_samples > 0:
                     self.eval_metrics, self.eval_details = self.evaluate(
                         eval_dataset=eval_dataset,
                         batch_size=eval_batch_size,
@@ -522,10 +522,11 @@ class BaseAPI:
                 self.save_model(save_dir=current_save_dir)
                 time_eval_one_epoch = time.time() - eval_epoch_start_time
                 eval_epoch_start_time = time.time()
-                logging.info(
-                    'Current evaluated best model in eval_dataset is epoch_{}, {}={}'
-                    .format(best_model_epoch, best_accuracy_key,
-                            best_accuracy))
+                if best_model_epoch > 0:
+                    logging.info(
+                        'Current evaluated best model in eval_dataset is epoch_{}, {}={}'
+                        .format(best_model_epoch, best_accuracy_key,
+                                best_accuracy))
                 if eval_dataset is not None and early_stop:
                     if earlystop(current_accuracy):
                         break