Jelajahi Sumber

fis the post quant

sunyanfang01 5 tahun lalu
induk
melakukan
1ee6d162d1
1 mengubah file dengan 10 tambahan dan 17 penghapusan
  1. 10 17
      paddlex/cv/models/base.py

+ 10 - 17
paddlex/cv/models/base.py

@@ -15,6 +15,7 @@
 from __future__ import absolute_import
 import paddle.fluid as fluid
 import os
+import sys
 import numpy as np
 import time
 import math
@@ -252,6 +253,9 @@ class BaseAPI:
             del self.init_params['self']
         if '__class__' in self.init_params:
             del self.init_params['__class__']
+        if 'model_name' in self.init_params:
+            del self.init_params['model_name']
+
         info['_init_params'] = self.init_params
 
         info['_Attributes']['num_classes'] = self.num_classes
@@ -372,6 +376,8 @@ class BaseAPI:
                    use_vdl=False,
                    early_stop=False,
                    early_stop_patience=5):
+        if train_dataset.num_samples < train_batch_size:
+            raise Exception('The amount of training datset must be larger than batch size.')
         if not osp.isdir(save_dir):
             if osp.exists(save_dir):
                 os.remove(save_dir)
@@ -429,9 +435,7 @@ class BaseAPI:
 
         if use_vdl:
             # VisualDL component
-            log_writer = LogWriter(vdl_logdir, sync_cycle=20)
-            train_step_component = OrderedDict()
-            eval_component = OrderedDict()
+            log_writer = LogWriter(vdl_logdir)
 
         thresh = 0.0001
         if early_stop:
@@ -469,13 +473,7 @@ class BaseAPI:
 
                     if use_vdl:
                         for k, v in step_metrics.items():
-                            if k not in train_step_component.keys():
-                                with log_writer.mode('Each_Step_while_Training'
-                                                     ) as step_logger:
-                                    train_step_component[
-                                        k] = step_logger.scalar(
-                                            'Training: {}'.format(k))
-                            train_step_component[k].add_record(num_steps, v)
+                            log_writer.add_scalar('Metrics/Training(Step): {}'.format(k), v, num_steps)
 
                     # 估算剩余时间
                     avg_step_time = np.mean(time_stat)
@@ -536,12 +534,7 @@ class BaseAPI:
                             if isinstance(v, np.ndarray):
                                 if v.size > 1:
                                     continue
-                            if k not in eval_component:
-                                with log_writer.mode('Each_Epoch_on_Eval_Data'
-                                                     ) as eval_logger:
-                                    eval_component[k] = eval_logger.scalar(
-                                        'Evaluation: {}'.format(k))
-                            eval_component[k].add_record(i + 1, v)
+                            log_writer.add_scalar("Metrics/Eval(Epoch): {}".format(k), v, i+1)
                 self.save_model(save_dir=current_save_dir)
                 time_eval_one_epoch = time.time() - eval_epoch_start_time
                 eval_epoch_start_time = time.time()
@@ -552,4 +545,4 @@ class BaseAPI:
                                 best_accuracy))
                 if eval_dataset is not None and early_stop:
                     if earlystop(current_accuracy):
-                        break
+                        break