소스 검색

move the filed to from

gaotingquan 7 달 전
부모
커밋
c4a87b48e1

+ 2 - 0
paddlex/modules/anomaly_detection/trainer.py

@@ -65,4 +65,6 @@ class UadTrainer(BaseTrainer):
             train_args["do_eval"] = True
             train_args["save_interval"] = self.train_config.eval_interval
         train_args["dy2st"] = self.train_config.get("dy2st", False)
+        # amp support 'O1', 'O2', 'OFF'
+        train_args["amp"] = self.train_config.get("amp", "OFF")
         return train_args

+ 1 - 0
paddlex/modules/formula_recognition/trainer.py

@@ -116,4 +116,5 @@ class FormulaRecTrainer(BaseTrainer):
         return {
             "device": self.get_device(),
             "dy2st": self.train_config.get("dy2st", False),
+            "amp": self.train_config.get("amp", "OFF"),  # amp support 'O1', 'O2', 'OFF'
         }

+ 2 - 0
paddlex/modules/image_classification/trainer.py

@@ -77,4 +77,6 @@ class ClsTrainer(BaseTrainer):
         ):
             train_args["resume_path"] = self.train_config.resume_path
         train_args["dy2st"] = self.train_config.get("dy2st", False)
+        # amp support 'O1', 'O2', 'OFF'
+        train_args["amp"] = self.train_config.get("amp", "OFF")
         return train_args

+ 2 - 0
paddlex/modules/m_3d_bev_detection/trainer.py

@@ -61,6 +61,8 @@ class BEVFusionTrainer(BaseTrainer):
         """
         train_args = {"device": self.get_device()}
         train_args["dy2st"] = self.train_config.get("dy2st", False)
+        # amp support 'O1', 'O2', 'OFF'
+        train_args["amp"] = self.train_config.get("amp", "OFF")
         if self.global_config.output is not None:
             train_args["save_dir"] = self.global_config.output
         return train_args

+ 2 - 0
paddlex/modules/multilabel_classification/trainer.py

@@ -80,4 +80,6 @@ class MLClsTrainer(BaseTrainer):
         ):
             train_args["resume_path"] = self.train_config.resume_path
         train_args["dy2st"] = self.train_config.get("dy2st", False)
+        # amp support 'O1', 'O2', 'OFF'
+        train_args["amp"] = self.train_config.get("amp", "OFF")
         return train_args

+ 4 - 1
paddlex/modules/multilingual_speech_recognition/trainer.py

@@ -35,5 +35,8 @@ class WhisperTrainer(BaseTrainer):
         Returns:
             dict: the arguments of training function.
         """
-        train_args = {"device": self.get_device()}
+        train_args = {
+            "device": self.get_device(),
+            "amp": self.train_config.get("amp", "OFF"),  # amp support 'O1', 'O2', 'OFF'
+        }
         return train_args

+ 2 - 0
paddlex/modules/object_detection/trainer.py

@@ -93,4 +93,6 @@ class DetTrainer(BaseTrainer):
         ):
             train_args["resume_path"] = self.train_config.resume_path
         train_args["dy2st"] = self.train_config.get("dy2st", False)
+        # amp support 'O1', 'O2', 'OFF'
+        train_args["amp"] = self.train_config.get("amp", "OFF")
         return train_args

+ 4 - 1
paddlex/modules/open_vocabulary_detection/trainer.py

@@ -37,5 +37,8 @@ class OVDetTrainer(BaseTrainer):
         Returns:
             dict: the arguments of training function.
         """
-        train_args = {"device": self.get_device()}
+        train_args = {
+            "device": self.get_device(),
+            "amp": self.train_config.get("amp", "OFF"),  # amp support 'O1', 'O2', 'OFF'
+        }
         return train_args

+ 4 - 1
paddlex/modules/open_vocabulary_segmentation/trainer.py

@@ -37,5 +37,8 @@ class OVSegTrainer(BaseTrainer):
         Returns:
             dict: the arguments of training function.
         """
-        train_args = {"device": self.get_device()}
+        train_args = {
+            "device": self.get_device(),
+            "amp": self.train_config.get("amp", "OFF"),  # amp support 'O1', 'O2', 'OFF'
+        }
         return train_args

+ 2 - 0
paddlex/modules/semantic_segmentation/trainer.py

@@ -67,4 +67,6 @@ class SegTrainer(BaseTrainer):
         train_args["dy2st"] = self.train_config.get("dy2st", False)
         if self.train_config.get("input_shape") is not None:
             train_args["input_shape"] = self.train_config.input_shape
+        # amp support 'O1', 'O2', 'OFF'
+        train_args["amp"] = self.train_config.get("amp", "OFF")
         return train_args

+ 1 - 0
paddlex/modules/table_recognition/trainer.py

@@ -63,4 +63,5 @@ class TableRecTrainer(BaseTrainer):
         return {
             "device": self.get_device(),
             "dy2st": self.train_config.get("dy2st", False),
+            "amp": self.train_config.get("amp", "OFF"),  # amp support 'O1', 'O2', 'OFF'
         }

+ 1 - 0
paddlex/modules/text_detection/trainer.py

@@ -61,4 +61,5 @@ class TextDetTrainer(BaseTrainer):
         return {
             "device": self.get_device(),
             "dy2st": self.train_config.get("dy2st", False),
+            "amp": self.train_config.get("amp", "OFF"),  # amp support 'O1', 'O2', 'OFF'
         }

+ 1 - 0
paddlex/modules/text_recognition/trainer.py

@@ -101,4 +101,5 @@ class TextRecTrainer(BaseTrainer):
         return {
             "device": self.get_device(),
             "dy2st": self.train_config.get("dy2st", False),
+            "amp": self.train_config.get("amp", "OFF"),  # amp support 'O1', 'O2', 'OFF'
         }

+ 2 - 0
paddlex/modules/ts_anomaly_detection/trainer.py

@@ -107,4 +107,6 @@ training!"
         train_args = {"device": self.get_device(using_device_number=1)}
         if self.global_config.output is not None:
             train_args["save_dir"] = self.global_config.output
+        # amp support 'O1', 'O2', 'OFF'
+        train_args["amp"] = self.train_config.get("amp", "OFF")
         return train_args

+ 2 - 0
paddlex/modules/ts_classification/trainer.py

@@ -102,4 +102,6 @@ training!"
         train_args = {"device": self.get_device(using_device_number=1)}
         if self.global_config.output is not None:
             train_args["save_dir"] = self.global_config.output
+        # amp support 'O1', 'O2', 'OFF'
+        train_args["amp"] = self.train_config.get("amp", "OFF")
         return train_args

+ 2 - 0
paddlex/modules/ts_forecast/trainer.py

@@ -102,4 +102,6 @@ training!"
         train_args = {"device": self.get_device(using_device_number=1)}
         if self.global_config.output is not None:
             train_args["save_dir"] = self.global_config.output
+        # amp support 'O1', 'O2', 'OFF'
+        train_args["amp"] = self.train_config.get("amp", "OFF")
         return train_args

+ 2 - 0
paddlex/modules/video_classification/trainer.py

@@ -83,4 +83,6 @@ class VideoClsTrainer(BaseTrainer):
         ):
             train_args["resume_path"] = self.train_config.resume_path
         train_args["dy2st"] = self.train_config.get("dy2st", False)
+        # amp support 'O1', 'O2', 'OFF'
+        train_args["amp"] = self.train_config.get("amp", "OFF")
         return train_args

+ 2 - 0
paddlex/modules/video_detection/trainer.py

@@ -77,4 +77,6 @@ class VideoDetTrainer(BaseTrainer):
         ):
             train_args["resume_path"] = self.train_config.resume_path
         train_args["dy2st"] = self.train_config.get("dy2st", False)
+        # amp support 'O1', 'O2', 'OFF'
+        train_args["amp"] = self.train_config.get("amp", "OFF")
         return train_args