浏览代码

support to set dy2st in Train cfg (#1955)

Tingquan Gao 1 年之前
父节点
当前提交
1e256bfd07

+ 1 - 0
paddlex/modules/image_classification/trainer.py

@@ -90,6 +90,7 @@ class ClsTrainer(BaseTrainer):
             and self.train_config.resume_path != ""
         ):
             train_args["resume_path"] = self.train_config.resume_path
+        train_args["dy2st"] = self.train_config.get("dy2st", False)
         return train_args
 
 

+ 1 - 0
paddlex/modules/object_detection/trainer.py

@@ -93,6 +93,7 @@ class DetTrainer(BaseTrainer):
             and self.train_config.resume_path != ""
         ):
             train_args["resume_path"] = self.train_config.resume_path
+        train_args["dy2st"] = self.train_config.get("dy2st", False)
         return train_args
 
 

+ 1 - 0
paddlex/modules/semantic_segmentation/trainer.py

@@ -79,6 +79,7 @@ class SegTrainer(BaseTrainer):
         if self.train_config.eval_interval:
             train_args["do_eval"] = True
             train_args["save_interval"] = self.train_config.eval_interval
+        train_args["dy2st"] = self.train_config.get("dy2st", False)
         return train_args
 
 

+ 1 - 1
paddlex/modules/table_recognition/trainer.py

@@ -76,7 +76,7 @@ class TableRecTrainer(BaseTrainer):
         Returns:
             dict: the arguments of training function.
         """
-        return {"device": self.get_device()}
+        return {"device": self.get_device(), "dy2st": self.train_config.get("dy2st", False)}
 
 
 class TableRecTrainDeamon(BaseTrainDeamon):

+ 1 - 1
paddlex/modules/text_detection/trainer.py

@@ -74,7 +74,7 @@ class TextDetTrainer(BaseTrainer):
         Returns:
             dict: the arguments of training function.
         """
-        return {"device": self.get_device()}
+        return {"device": self.get_device(), "dy2st": self.train_config.get("dy2st", False)}
 
 
 class TextDetTrainDeamon(BaseTrainDeamon):

+ 1 - 1
paddlex/modules/text_recognition/trainer.py

@@ -108,7 +108,7 @@ class TextRecTrainer(BaseTrainer):
         Returns:
             dict: the arguments of training function.
         """
-        return {"device": self.get_device()}
+        return {"device": self.get_device(), "dy2st": self.train_config.get("dy2st", False)}
 
 
 class TextRecTrainDeamon(BaseTrainDeamon):