Browse code

move logging info out of build_dataloader

will-jl944 4 years ago
parent
commit
79d4164c5c

+ 2 - 13
dygraph/paddlex/cv/models/base.py

@@ -157,22 +157,11 @@ class BaseModel:
         logging.info("Model saved in {}.".format(save_dir))
 
     def build_data_loader(self, dataset, batch_size, mode='train'):
-        batch_size_each_card = get_single_card_bs(batch_size=batch_size)
-        if mode == 'eval':
-            if self.model_type == 'detector':
-                # detector only supports single card eval with batch size 1
-                total_steps = dataset.num_samples
-            else:
-                batch_size = batch_size_each_card
-                total_steps = math.ceil(dataset.num_samples * 1.0 / batch_size)
-            logging.info(
-                "Start to evaluate(total_samples={}, total_steps={})...".
-                format(dataset.num_samples, total_steps))
         if dataset.num_samples < batch_size:
             raise Exception(
-                'The volume of datset({}) must be larger than batch size({}).'
+                'The volume of dataset({}) must be larger than batch size({}).'
                 .format(dataset.num_samples, batch_size))
-
+        batch_size_each_card = get_single_card_bs(batch_size=batch_size)
         # TODO detection eval阶段需做判断
         batch_sampler = DistributedBatchSampler(
             dataset,

+ 5 - 0
dygraph/paddlex/cv/models/classifier.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from __future__ import absolute_import
+import math
 import os.path as osp
 from collections import OrderedDict
 import numpy as np
@@ -271,6 +272,10 @@ class BaseClassifier(BaseModel):
         if return_details:
             eval_details = list()
 
+        logging.info(
+            "Start to evaluate(total_samples={}, total_steps={})...".format(
+                eval_dataset.num_samples,
+                math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
         with paddle.no_grad():
             for step, data in enumerate(self.eval_data_loader()):
                 outputs = self.run(self.net, data, mode='eval')

+ 4 - 1
dygraph/paddlex/cv/models/detector.py

@@ -27,7 +27,7 @@ from paddlex.cv.nets.ppdet.modeling import *
 from paddlex.cv.nets.ppdet.modeling.post_process import *
 from paddlex.cv.nets.ppdet.modeling.layers import YOLOBox, MultiClassNMS, RCNNBox
 from paddlex.cv.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH
-from paddlex.cv.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, _BatchPadding, _Gt2YoloTarget, _Permute
+from paddlex.cv.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, _BatchPadding, _Gt2YoloTarget
 from paddlex.cv.transforms import arrange_transforms
 from .base import BaseModel
 from .utils.det_metrics import VOCMetric, COCOMetric
@@ -337,6 +337,9 @@ class BaseDetector(BaseModel):
                         is_bbox_normalized=is_bbox_normalized,
                         classwise=False)
             scores = collections.OrderedDict()
+            logging.info(
+                "Start to evaluate(total_samples={}, total_steps={})...".
+                format(eval_dataset.num_samples, eval_dataset.num_samples))
             with paddle.no_grad():
                 for step, data in enumerate(self.eval_data_loader):
                     outputs = self.run(self.net, data, 'eval')

+ 5 - 0
dygraph/paddlex/cv/models/segmenter.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 import os.path as osp
 import numpy as np
 from collections import OrderedDict
@@ -271,6 +272,10 @@ class BaseSegmenter(BaseModel):
         intersect_area_all = 0
         pred_area_all = 0
         label_area_all = 0
+        logging.info(
+            "Start to evaluate(total_samples={}, total_steps={})...".format(
+                eval_dataset.num_samples,
+                math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
         with paddle.no_grad():
             for step, data in enumerate(self.eval_data_loader):
                 data.append(eval_dataset.transforms.transforms)