瀏覽代碼

support print mem info and save log by ranks

zhouchangda 1 年之前
父節點
當前提交
754acae192
共有 2 個文件被更改,包括 19 次插入1 次删除
  1. 17 0
      paddlex/repo_apis/PaddleClas_api/cls/config.py
  2. 2 1
      paddlex/repo_apis/PaddleClas_api/cls/model.py

+ 17 - 0
paddlex/repo_apis/PaddleClas_api/cls/config.py

@@ -315,6 +315,23 @@ indicating that no pretrained model to be used."
         """
         self.update([f'Global.save_interval={save_interval}'])
 
+    def update_log_ranks(self, device):
+        """update log ranks
+
+        Args:
+            device (str): the running device to set
+        """
+        log_ranks = device.split(':')[1]
+        self.update([f'Global.log_ranks="{log_ranks}"'])
+
+    def enable_print_mem_info(self):
+        """print memory info"""
+        self.update([f'Global.print_mem_info=True'])
+
+    def disable_print_mem_info(self):
+        """do not print memory info"""
+        self.update([f'Global.print_mem_info=False'])
+
     def _update_predict_img(self, infer_img: str, infer_list: str=None):
         """update image to be predicted
 

+ 2 - 1
paddlex/repo_apis/PaddleClas_api/cls/model.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import os
 
 from ...base import BaseModel
@@ -67,6 +66,8 @@ class ClsModel(BaseModel):
             config.update_device(device)
             config._update_to_static(dy2st)
             config._update_use_vdl(use_vdl)
+            config.update_log_ranks(device)
+            config.enable_print_mem_info()
 
             if batch_size is not None:
                 config.update_batch_size(batch_size)