Browse Source

revert base.py

jiangjiajun 5 years ago
parent
commit
b885a6d188
1 changed files with 22 additions and 28 deletions
  1. 22 28
      paddlex/cv/models/base.py

+ 22 - 28
paddlex/cv/models/base.py

@@ -79,9 +79,9 @@ class BaseAPI:
             return int(batch_size // len(self.places))
         else:
             raise Exception("Please support correct batch_size, \
-                            which can be divided by available cards({}) in {}"
-                            .format(paddlex.env_info['num'], paddlex.env_info[
-                                'place']))
+                            which can be divided by available cards({}) in {}".
+                            format(paddlex.env_info['num'],
+                                   paddlex.env_info['place']))
 
     def build_program(self):
         # 构建训练网络
@@ -141,7 +141,7 @@ class BaseAPI:
             from .slim.post_quantization import PaddleXPostTrainingQuantization
         except:
             raise Exception(
-                "Model Quantization is not available, try to upgrade your paddlepaddle>=1.8.0"
+                "Model Quantization is not available, try to upgrade your paddlepaddle>=1.7.0"
             )
         is_use_cache_file = True
         if cache_dir is None:
@@ -209,8 +209,8 @@ class BaseAPI:
             paddlex.utils.utils.load_pretrain_weights(
                 self.exe, self.train_prog, resume_checkpoint, resume=True)
             if not osp.exists(osp.join(resume_checkpoint, "model.yml")):
-                raise Exception("There's not model.yml in {}".format(
-                    resume_checkpoint))
+                raise Exception(
+                    "There's not model.yml in {}".format(resume_checkpoint))
             with open(osp.join(resume_checkpoint, "model.yml")) as f:
                 info = yaml.load(f.read(), Loader=yaml.Loader)
                 self.completed_epochs = info['completed_epochs']
@@ -361,8 +361,8 @@ class BaseAPI:
 
         # 模型保存成功的标志
         open(osp.join(save_dir, '.success'), 'w').close()
-        logging.info("Model for inference deploy saved in {}.".format(
-            save_dir))
+        logging.info(
+            "Model for inference deploy saved in {}.".format(save_dir))
 
     def train_loop(self,
                    num_epochs,
@@ -376,8 +376,7 @@ class BaseAPI:
                    early_stop=False,
                    early_stop_patience=5):
         if train_dataset.num_samples < train_batch_size:
-            raise Exception(
-                'The amount of training datset must be larger than batch size.')
+            raise Exception('The amount of training datset must be larger than batch size.')
         if not osp.isdir(save_dir):
             if osp.exists(save_dir):
                 os.remove(save_dir)
@@ -415,8 +414,8 @@ class BaseAPI:
                     build_strategy=build_strategy,
                     exec_strategy=exec_strategy)
 
-        total_num_steps = math.floor(train_dataset.num_samples /
-                                     train_batch_size)
+        total_num_steps = math.floor(
+            train_dataset.num_samples / train_batch_size)
         num_steps = 0
         time_stat = list()
         time_train_one_epoch = None
@@ -430,8 +429,8 @@ class BaseAPI:
         if self.model_type == 'detector':
             eval_batch_size = self._get_single_card_bs(train_batch_size)
         if eval_dataset is not None:
-            total_num_steps_eval = math.ceil(eval_dataset.num_samples /
-                                             eval_batch_size)
+            total_num_steps_eval = math.ceil(
+                eval_dataset.num_samples / eval_batch_size)
 
         if use_vdl:
             # VisualDL component
@@ -473,9 +472,7 @@ class BaseAPI:
 
                     if use_vdl:
                         for k, v in step_metrics.items():
-                            log_writer.add_scalar(
-                                'Metrics/Training(Step): {}'.format(k), v,
-                                num_steps)
+                            log_writer.add_scalar('Metrics/Training(Step): {}'.format(k), v, num_steps)
 
                     # 估算剩余时间
                     avg_step_time = np.mean(time_stat)
@@ -483,12 +480,11 @@ class BaseAPI:
                         eta = (num_epochs - i - 1) * time_train_one_epoch + (
                             total_num_steps - step - 1) * avg_step_time
                     else:
-                        eta = ((num_epochs - i) * total_num_steps - step - 1
-                               ) * avg_step_time
+                        eta = ((num_epochs - i) * total_num_steps - step -
+                               1) * avg_step_time
                     if time_eval_one_epoch is not None:
-                        eval_eta = (
-                            total_eval_times - i // save_interval_epochs
-                        ) * time_eval_one_epoch
+                        eval_eta = (total_eval_times - i //
+                                    save_interval_epochs) * time_eval_one_epoch
                     else:
                         eval_eta = (
                             total_eval_times - i // save_interval_epochs
@@ -498,11 +494,10 @@ class BaseAPI:
                     logging.info(
                         "[TRAIN] Epoch={}/{}, Step={}/{}, {}, time_each_step={}s, eta={}"
                         .format(i + 1, num_epochs, step + 1, total_num_steps,
-                                dict2str(step_metrics),
-                                round(avg_step_time, 2), eta_str))
+                                dict2str(step_metrics), round(
+                                    avg_step_time, 2), eta_str))
             train_metrics = OrderedDict(
-                zip(list(self.train_outputs.keys()), np.mean(
-                    records, axis=0)))
+                zip(list(self.train_outputs.keys()), np.mean(records, axis=0)))
             logging.info('[TRAIN] Epoch {} finished, {} .'.format(
                 i + 1, dict2str(train_metrics)))
             time_train_one_epoch = time.time() - epoch_start_time
@@ -538,8 +533,7 @@ class BaseAPI:
                             if isinstance(v, np.ndarray):
                                 if v.size > 1:
                                     continue
-                            log_writer.add_scalar(
-                                "Metrics/Eval(Epoch): {}".format(k), v, i + 1)
+                            log_writer.add_scalar("Metrics/Eval(Epoch): {}".format(k), v, i+1)
                 self.save_model(save_dir=current_save_dir)
                 time_eval_one_epoch = time.time() - eval_epoch_start_time
                 eval_epoch_start_time = time.time()