# utils.py
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import psutil
  15. import shutil
  16. import os
  17. import os.path as osp
  18. from enum import Enum
  19. import multiprocessing as mp
  20. from queue import Queue
  21. import time
  22. import threading
  23. from ctypes import CDLL, c_char, c_uint, c_ulonglong
  24. from _ctypes import byref, Structure, POINTER
  25. import platform
  26. import string
  27. import logging
  28. import socket
  29. import logging.handlers
  30. import requests
  31. from json import JSONEncoder
  32. class CustomEncoder(JSONEncoder):
  33. def default(self, o):
  34. try:
  35. return o.__dict__
  36. except AttributeError:
  37. return o.tolist()
class ShareData():
    """Process-wide mutable state shared across the application's modules.

    All attributes are class-level, so every importer sees the same values.
    """
    # active workspace object and its on-disk directory
    workspace = None
    workspace_dir = ""
    # whether a usable GPU was detected at startup
    has_gpu = True
    # pids/handles of child processes being watched (bounded queue)
    monitored_processes = mp.Queue(4096)
    # main service port and the next port to hand out to visualization boards
    port = 5000
    current_port = 8000
    # running visualization boards, keyed by id — NOTE(review): schema not
    # visible here; confirm against callers
    running_boards = {}
    # cached result of get_machine_info()
    machine_info = dict()
    # bookkeeping for demo-project loading
    load_demo_proc_dict = {}
    load_demo_proj_data_dict = {}
# Lifecycle state enums for long-running artifacts. All start at value 0 so
# members can be stored/compared as plain ints; the leading 'X' keeps member
# names from colliding with ordinary file names (see set_folder_status, which
# uses member names as marker file names).

# Dataset import/verification pipeline states.
DatasetStatus = Enum(
    'DatasetStatus', ('XEMPTY', 'XCHECKING', 'XCHECKFAIL', 'XCOPYING',
                      'XCOPYDONE', 'XCOPYFAIL', 'XSPLITED'),
    start=0)

# Training-task states, from initialization through export/prune.
TaskStatus = Enum(
    'TaskStatus', ('XUNINIT', 'XINIT', 'XDOWNLOADING', 'XTRAINING',
                   'XTRAINDONE', 'XEVALUATED', 'XEXPORTING', 'XEXPORTED',
                   'XTRAINEXIT', 'XDOWNLOADFAIL', 'XTRAINFAIL', 'XEVALUATING',
                   'XEVALUATEFAIL', 'XEXPORTFAIL', 'XPRUNEING', 'XPRUNETRAIN'),
    start=0)

# Kinds of computer-vision projects supported.
ProjectType = Enum(
    'ProjectType', ('classification', 'detection', 'segmentation',
                    'instance_segmentation', 'remote_segmentation'),
    start=0)

# States of a file download (see download()).
DownloadStatus = Enum(
    'DownloadStatus',
    ('XDDOWNLOADING', 'XDDOWNLOADFAIL', 'XDDOWNLOADDONE', 'XDDECOMPRESSED'),
    start=0)

# States of a prediction run.
PredictStatus = Enum(
    'PredictStatus', ('XPRESTART', 'XPREDONE', 'XPREFAIL'), start=0)

# States of a pruning run.
PruneStatus = Enum(
    'PruneStatus', ('XSPRUNESTART', 'XSPRUNEING', 'XSPRUNEDONE', 'XSPRUNEFAIL',
                    'XSPRUNEEXIT'),
    start=0)

# States of saving a pretrained model (see copy_model_directory()).
PretrainedModelStatus = Enum(
    'PretrainedModelStatus',
    ('XPINIT', 'XPSAVING', 'XPSAVEFAIL', 'XPSAVEDONE'),
    start=0)

# Flavors of exported models: quantized/pruned/plain, for mobile or server.
ExportedModelType = Enum(
    'ExportedModelType', ('XQUANTMOBILE', 'XPRUNEMOBILE', 'XTRAINMOBILE',
                          'XQUANTSERVER', 'XPRUNESERVER', 'XTRAINSERVER'),
    start=0)
# Metric-name -> Chinese display label, used for values shown inside result
# tables (see trans_name(in_table=True)).
translate_chinese_table = {
    "Confusion_matrix": "各个类别之间的混淆矩阵",
    "Precision": "精准率",
    "Accuracy": "准确率",
    "Recall": "召回率",
    "Class": "类别",
    "Topk": "K取值",
    "Auc": "AUC",
    "Per_ap": "类别平均精准率",
    "Map": "类别平均精准率(AP)的均值(mAP)",
    "Mean_iou": "平均交并比",
    "Mean_acc": "平均准确率",
    "Category_iou": "各类别交并比",
    "Category_acc": "各类别准确率",
    "Ap": "平均精准率",
    "F1": "F1-score",
    "Iou": "交并比"
}

# Metric-name -> Chinese display label for headline metrics/section titles
# (see trans_name(in_table=False)).
translate_chinese = {
    "Confusion_matrix": "混淆矩阵",
    "Mask_confusion_matrix": "Mask混淆矩阵",
    "Bbox_confusion_matrix": "Bbox混淆矩阵",
    "Precision": "精准率(Precision)",
    "Accuracy": "准确率(Accuracy)",
    "Recall": "召回率(Recall)",
    "Class": "类别(Class)",
    "PRF1": "整体分类评估结果",
    "PRF1_TOPk": "TopK评估结果",
    "Topk": "K取值",
    "AUC": "Area Under Curve",
    "Auc": "Area Under Curve",
    "F1": "F1-score",
    "Iou": "交并比(IoU)",
    "Per_ap": "各类别的平均精准率(AP)",
    "mAP": "平均精准率的均值(mAP)",
    "Mask_mAP": "Mask的平均精准率的均值(mAP)",
    "BBox_mAP": "Bbox的平均精准率的均值(mAP)",
    "Mean_iou": "平均交并比(mIoU)",
    "Mean_acc": "平均准确率(mAcc)",
    "Ap": "平均精准率(Average Precision)",
    "Category_iou": "各类别的交并比(IoU)",
    "Category_acc": "各类别的准确率(Accuracy)",
    "PRAP": "整体检测评估结果",
    "BBox_PRAP": "Bbox评估结果",
    "Mask_PRAP": "Mask评估结果",
    "Overall": "整体平均指标",
    "PRF1_average": "整体平均指标",
    "overall_det": "整体平均指标",
    "PRIoU": "整体平均指标",
    "Acc1": "预测Top1的准确率",
    "Acck": "预测Top{}的准确率"
}

# Bounded pool of child processes spawned via start_process().
process_pool = Queue(1000)
  134. def get_ip():
  135. try:
  136. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  137. s.connect(('8.8.8.8', 80))
  138. ip = s.getsockname()[0]
  139. finally:
  140. s.close()
  141. return ip
  142. def get_logger(filename):
  143. flask_logger = logging.getLogger()
  144. flask_logger.setLevel(level=logging.INFO)
  145. fmt = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s:%(message)s'
  146. format_str = logging.Formatter(fmt)
  147. ch = logging.StreamHandler()
  148. ch.setLevel(level=logging.INFO)
  149. ch.setFormatter(format_str)
  150. th = logging.handlers.TimedRotatingFileHandler(
  151. filename=filename, when='D', backupCount=5, encoding='utf-8')
  152. th.setFormatter(format_str)
  153. flask_logger.addHandler(th)
  154. flask_logger.addHandler(ch)
  155. return flask_logger
  156. def start_process(target, args):
  157. global process_pool
  158. p = mp.Process(target=target, args=args)
  159. p.start()
  160. process_pool.put(p)
  161. def pkill(pid):
  162. """结束进程pid,和与其相关的子进程
  163. Args:
  164. pid(int): 进程id
  165. """
  166. try:
  167. parent = psutil.Process(pid)
  168. for child in parent.children(recursive=True):
  169. child.kill()
  170. parent.kill()
  171. except:
  172. print("Try to kill process {} failed.".format(pid))
  173. def set_folder_status(dirname, status, message=""):
  174. """设置目录状态
  175. Args:
  176. dirname(str): 目录路径
  177. status(DatasetStatus): 状态
  178. message(str): 需要写到状态文件里的信息
  179. """
  180. if not osp.isdir(dirname):
  181. raise Exception("目录路径{}不存在".format(dirname))
  182. tmp_file = osp.join(dirname, status.name + '.tmp')
  183. with open(tmp_file, 'w', encoding='utf-8') as f:
  184. f.write("{}\n".format(message))
  185. shutil.move(tmp_file, osp.join(dirname, status.name))
  186. for status_type in [
  187. DatasetStatus, TaskStatus, PredictStatus, PruneStatus,
  188. DownloadStatus, PretrainedModelStatus
  189. ]:
  190. for s in status_type:
  191. if s == status:
  192. continue
  193. if osp.exists(osp.join(dirname, s.name)):
  194. os.remove(osp.join(dirname, s.name))
  195. def get_folder_status(dirname, with_message=False):
  196. """获取目录状态
  197. Args:
  198. dirname(str): 目录路径
  199. with_message(bool): 是否需要返回状态文件内的信息
  200. """
  201. status = None
  202. closest_time = 0
  203. message = ''
  204. for status_type in [
  205. DatasetStatus, TaskStatus, PredictStatus, PruneStatus,
  206. DownloadStatus, PretrainedModelStatus
  207. ]:
  208. for s in status_type:
  209. if osp.exists(osp.join(dirname, s.name)):
  210. modify_time = os.stat(osp.join(dirname, s.name)).st_mtime
  211. if modify_time > closest_time:
  212. closest_time = modify_time
  213. status = getattr(status_type, s.name)
  214. if with_message:
  215. encoding = 'utf-8'
  216. try:
  217. f = open(
  218. osp.join(dirname, s.name),
  219. 'r',
  220. encoding=encoding)
  221. message = f.read()
  222. f.close()
  223. except:
  224. try:
  225. import chardet
  226. f = open(filename, 'rb')
  227. data = f.read()
  228. f.close()
  229. encoding = chardet.detect(data).get('encoding')
  230. f = open(
  231. osp.join(dirname, s.name),
  232. 'r',
  233. encoding=encoding)
  234. message = f.read()
  235. f.close()
  236. except:
  237. pass
  238. if with_message:
  239. return status, message
  240. return status
  241. def _machine_check_proc(queue, path):
  242. info = dict()
  243. p = PyNvml()
  244. gpu_num = 0
  245. try:
  246. # import paddle.fluid.core as core
  247. # gpu_num = core.get_cuda_device_count()
  248. p.nvml_init(path)
  249. gpu_num = p.nvml_device_get_count()
  250. driver_version = bytes.decode(p.nvml_system_get_driver_version())
  251. except Exception as e:
  252. gpu_num = 0
  253. driver_version = "N/A"
  254. info['gpu_num'] = gpu_num
  255. info['gpu_free_mem'] = list()
  256. try:
  257. for i in range(gpu_num):
  258. handle = p.nvml_device_get_handle_by_index(i)
  259. meminfo = p.nvml_device_get_memory_info(handle)
  260. free_mem = meminfo.free / 1024 / 1024
  261. info['gpu_free_mem'].append(free_mem)
  262. except:
  263. pass
  264. info['cpu_num'] = int(os.environ.get('CPU_NUM', 1))
  265. info['driver_version'] = driver_version
  266. info['path'] = p.nvml_lib_path
  267. queue.put(info, timeout=3)
  268. def get_machine_info(path=None):
  269. queue = mp.Queue(1)
  270. p = mp.Process(target=_machine_check_proc, args=(queue, path))
  271. p.start()
  272. p.join()
  273. return queue.get(timeout=2)
  274. def download(url, target_path):
  275. if not osp.exists(target_path):
  276. os.makedirs(target_path)
  277. fname = osp.split(url)[-1]
  278. fullname = osp.join(target_path, fname)
  279. retry_cnt = 0
  280. DOWNLOAD_RETRY_LIMIT = 3
  281. while not (osp.exists(fullname)):
  282. if retry_cnt < DOWNLOAD_RETRY_LIMIT:
  283. retry_cnt += 1
  284. else:
  285. # 设置下载失败
  286. msg = "Download from {} failed. Retry limit reached".format(url)
  287. set_folder_status(target_path, DownloadStatus.XDDOWNLOADFAIL, msg)
  288. raise RuntimeError(msg)
  289. req = requests.get(url, stream=True)
  290. if req.status_code != 200:
  291. msg = "Downloading from {} failed with code {}!".format(
  292. url, req.status_code)
  293. set_folder_status(target_path, DownloadStatus.XDDOWNLOADFAIL, msg)
  294. raise RuntimeError(msg)
  295. # For protecting download interupted, download to
  296. # tmp_fullname firstly, move tmp_fullname to fullname
  297. # after download finished
  298. tmp_fullname = fullname + "_tmp"
  299. total_size = req.headers.get('content-length')
  300. set_folder_status(target_path, DownloadStatus.XDDOWNLOADING,
  301. total_size)
  302. with open(tmp_fullname, 'wb') as f:
  303. if total_size:
  304. download_size = 0
  305. for chunk in req.iter_content(chunk_size=1024):
  306. f.write(chunk)
  307. download_size += 1024
  308. else:
  309. for chunk in req.iter_content(chunk_size=1024):
  310. if chunk:
  311. f.write(chunk)
  312. shutil.move(tmp_fullname, fullname)
  313. set_folder_status(target_path, DownloadStatus.XDDOWNLOADDONE)
  314. return fullname
  315. def trans_name(key, in_table=False):
  316. if in_table:
  317. if key in translate_chinese_table:
  318. key = "{}".format(translate_chinese_table[key])
  319. if key.capitalize() in translate_chinese_table:
  320. key = "{}".format(translate_chinese_table[key.capitalize()])
  321. return key
  322. else:
  323. if key in translate_chinese:
  324. key = "{}".format(translate_chinese[key])
  325. if key.capitalize() in translate_chinese:
  326. key = "{}".format(translate_chinese[key.capitalize()])
  327. return key
  328. return key
  329. def is_pic(filename):
  330. suffixes = {'JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png'}
  331. suffix = filename.strip().split('.')[-1]
  332. if suffix not in suffixes:
  333. return False
  334. return True
  335. def is_available(ip, port):
  336. s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  337. try:
  338. s.connect((ip, int(port)))
  339. s.shutdown(2)
  340. return False
  341. except:
  342. return True
  343. def list_files(dirname):
  344. """ 列出目录下所有文件(包括所属的一级子目录下文件)
  345. Args:
  346. dirname: 目录路径
  347. """
  348. def filter_file(f):
  349. if f.startswith('.'):
  350. return True
  351. if hasattr(PretrainedModelStatus, f):
  352. return True
  353. return False
  354. all_files = list()
  355. dirs = list()
  356. for f in os.listdir(dirname):
  357. if filter_file(f):
  358. continue
  359. if osp.isdir(osp.join(dirname, f)):
  360. dirs.append(f)
  361. else:
  362. all_files.append(f)
  363. for d in dirs:
  364. for f in os.listdir(osp.join(dirname, d)):
  365. if filter_file(f):
  366. continue
  367. if osp.isdir(osp.join(dirname, d, f)):
  368. continue
  369. all_files.append(osp.join(d, f))
  370. return all_files
  371. def copy_model_directory(src, dst, files=None, filter_files=[]):
  372. """从src目录copy文件至dst目录,
  373. 注意:拷贝前会先清空dst中的所有文件
  374. Args:
  375. src: 源目录路径
  376. dst: 目标目录路径
  377. files: 需要拷贝的文件列表(src的相对路径)
  378. """
  379. set_folder_status(dst, PretrainedModelStatus.XPSAVING, os.getpid())
  380. if files is None:
  381. files = list_files(src)
  382. try:
  383. message = '{} {}'.format(os.getpid(), len(files))
  384. set_folder_status(dst, PretrainedModelStatus.XPSAVING, message)
  385. if not osp.samefile(src, dst):
  386. for i, f in enumerate(files):
  387. items = osp.split(f)
  388. if len(items) > 2:
  389. continue
  390. if len(items) == 2:
  391. if not osp.isdir(osp.join(dst, items[0])):
  392. if osp.exists(osp.join(dst, items[0])):
  393. os.remove(osp.join(dst, items[0]))
  394. os.makedirs(osp.join(dst, items[0]))
  395. if f not in filter_files:
  396. shutil.copy(osp.join(src, f), osp.join(dst, f))
  397. set_folder_status(dst, PretrainedModelStatus.XPSAVEDONE)
  398. except Exception as e:
  399. import traceback
  400. error_info = traceback.format_exc()
  401. set_folder_status(dst, PretrainedModelStatus.XPSAVEFAIL, error_info)
  402. def copy_pretrained_model(src, dst):
  403. p = mp.Process(
  404. target=copy_model_directory, args=(src, dst, None, ['model.pdopt']))
  405. p.start()
  406. return p
  407. def _get_gpu_info(queue):
  408. gpu_info = dict()
  409. mem_free = list()
  410. mem_used = list()
  411. mem_total = list()
  412. import pycuda.driver as drv
  413. from pycuda.tools import clear_context_caches
  414. drv.init()
  415. driver_version = drv.get_driver_version()
  416. gpu_num = drv.Device.count()
  417. for gpu_id in range(gpu_num):
  418. dev = drv.Device(gpu_id)
  419. try:
  420. context = dev.make_context()
  421. free, total = drv.mem_get_info()
  422. context.pop()
  423. free = free // 1024 // 1024
  424. total = total // 1024 // 1024
  425. used = total - free
  426. except:
  427. free = 0
  428. total = 0
  429. used = 0
  430. mem_free.append(free)
  431. mem_used.append(used)
  432. mem_total.append(total)
  433. gpu_info['mem_free'] = mem_free
  434. gpu_info['mem_used'] = mem_used
  435. gpu_info['mem_total'] = mem_total
  436. gpu_info['driver_version'] = driver_version
  437. gpu_info['gpu_num'] = gpu_num
  438. queue.put(gpu_info)
  439. def get_gpu_info():
  440. try:
  441. import pycuda
  442. except:
  443. gpu_info = dict()
  444. message = "未检测到GPU \n 若存在GPU请确保安装pycuda \n 若未安装pycuda请使用'pip install pycuda'来安装"
  445. gpu_info['gpu_num'] = 0
  446. return gpu_info, message
  447. queue = mp.Queue(1)
  448. p = mp.Process(target=_get_gpu_info, args=(queue, ))
  449. p.start()
  450. p.join()
  451. gpu_info = queue.get(timeout=2)
  452. if gpu_info['gpu_num'] == 0:
  453. message = "未检测到GPU"
  454. else:
  455. message = "检测到GPU"
  456. return gpu_info, message
class TrainLogReader(object):
    """Incremental parser for a training log file.

    Each update() call re-scans the log and refreshes eta, train/eval
    metrics, download progress and the elapsed running time. Instances are
    meant to be polled repeatedly while training runs.
    """

    def __init__(self, log_file):
        # log_file: path to the training log; the file may not exist yet
        self.log_file = log_file
        self.eta = None               # remaining seconds, parsed from 'eta=H:M:S'
        self.train_metrics = None     # dict from newest '[TRAIN] ... Step' line
        self.eval_metrics = None      # dict from newest '[EVAL] ... Finished' line
        self.download_status = None   # None | dict(progress) | 'Done' | 'Failed'
        self.eval_done = False        # True once a 'Model saved in' line is seen
        self.train_error = None       # the 'Training stop with error!' line, if any
        self.train_stage = None       # None | 'Train Complete' | 'Train Error'
        self.running_duration = None  # human-readable elapsed time

    def update(self):
        """Re-parse the log file and refresh all public attributes."""
        if not osp.exists(self.log_file):
            return
        # terminal states: nothing further will appear in the log
        if self.train_stage == "Train Error":
            return
        if self.download_status == "Failed":
            return
        if self.train_stage == "Train Complete":
            return
        logs = open(self.log_file, encoding='utf-8').read().strip().split('\n')
        self.eta = None
        self.train_metrics = None
        self.eval_metrics = None
        # 'Done' is sticky across updates; anything else is re-detected
        if self.download_status != "Done":
            self.download_status = None
        # start time: first line with a parsable timestamp, else file ctime
        start_time_timestamp = osp.getctime(self.log_file)
        for line in logs[::1]:
            try:
                start_time_str = " ".join(line.split()[0:2])
                start_time_array = time.strptime(start_time_str,
                                                 "%Y-%m-%d %H:%M:%S")
                start_time_timestamp = time.mktime(start_time_array)
                break
            except Exception as e:
                pass
        # scan newest-to-oldest so the freshest values win
        for line in logs[::-1]:
            if line.count('Train Complete!'):
                self.train_stage = "Train Complete"
            if line.count('Training stop with error!'):
                self.train_error = line
            # everything of interest found: stop scanning older lines
            if self.train_metrics is not None \
                    and self.eval_metrics is not None and self.eval_done and self.eta is not None:
                break
            items = line.strip().split()
            if line.count('Model saved in'):
                self.eval_done = True
            if line.count('download completed'):
                self.download_status = 'Done'
                break
            if line.count('download failed'):
                self.download_status = 'Failed'
                break
            if self.download_status != 'Done':
                # NOTE(review): until a 'download completed' line flips the
                # status to 'Done', only 'KB/s' progress lines are parsed and
                # all other lines are skipped — metrics parsing below only
                # runs on later update() calls once 'Done' is sticky. Confirm
                # this matches the emitting log format.
                if line.count('[DEBUG]\tDownloading'
                              ) and self.download_status is None:
                    self.download_status = dict()
                if not line.endswith('KB/s'):
                    continue
                # progress line tail: '... total=XM, download=YM, speed=ZKB/s'
                speed = items[-1].strip('KB/s').split('=')[-1]
                download = items[-2].strip('M, ').split('=')[-1]
                total = items[-3].strip('M, ').split('=')[-1]
                self.download_status['speed'] = speed
                self.download_status['download'] = float(download)
                self.download_status['total'] = float(total)
            if self.eta is None:
                # accept 'eta=H:MM:SS' or 'eta=H:MM:S' shapes at line end
                if line.count('eta') > 0 and (line[-3] == ':' or
                                              line[-4] == ':'):
                    eta = items[-1].strip().split('=')[1]
                    h, m, s = [int(x) for x in eta.split(':')]
                    self.eta = h * 3600 + m * 60 + s
            if self.train_metrics is None:
                if line.count('[INFO]\t[TRAIN]') > 0 and line.count(
                        'Step') > 0:
                    # only complete lines (ending with the eta field) count
                    if not items[-1].startswith('eta'):
                        continue
                    self.train_metrics = dict()
                    metrics = items[4:]
                    for metric in metrics:
                        try:
                            name, value = metric.strip(', ').split('=')
                            value = value.split('/')[0]
                            if value.count('.') > 0:
                                value = float(value)
                            elif value == 'nan':
                                value = 'nan'
                            else:
                                value = int(value)
                            self.train_metrics[name] = value
                        except:
                            pass
            if self.eval_metrics is None:
                if line.count('[INFO]\t[EVAL]') > 0 and line.count(
                        'Finished') > 0:
                    # only lines terminated with ' .' are complete
                    if not line.strip().endswith(' .'):
                        continue
                    self.eval_metrics = dict()
                    metrics = items[5:]
                    for metric in metrics:
                        try:
                            name, value = metric.strip(', ').split('=')
                            value = value.split('/')[0]
                            if value.count('.') > 0:
                                value = float(value)
                            else:
                                value = int(value)
                            self.eval_metrics[name] = value
                        except:
                            pass
        # elapsed time = last file modification - parsed/ctime start
        end_time_timestamp = osp.getmtime(self.log_file)
        t_diff = time.gmtime(end_time_timestamp - start_time_timestamp)
        self.running_duration = "{}小时{}分{}秒".format(
            t_diff.tm_hour, t_diff.tm_min, t_diff.tm_sec)
class PruneLogReader(object):
    """Parser for a model-pruning log.

    Fields are filled from a 'Total evaluate iters ...' log line whose
    comma-separated items end in 'key=value' pairs; each pair is applied
    with setattr, so attribute names mirror the keys in the log.
    """

    def init_attr(self):
        # reset all parsed fields before each scan
        self.eta = None
        self.iters = None
        self.current = None
        self.progress = None

    def __init__(self, log_file):
        # log_file: path to the pruning log
        self.log_file = log_file
        self.init_attr()

    def update(self):
        """Re-scan the log (newest line first) until every field is set."""
        if not osp.exists(self.log_file):
            return
        logs = open(self.log_file, encoding='utf-8').read().strip().split('\n')
        self.init_attr()
        for line in logs[::-1]:
            # stop as soon as no attribute is None
            metric_loaded = True
            for k, v in self.__dict__.items():
                if v is None:
                    metric_loaded = False
                    break
            if metric_loaded:
                break
            if line.count("Total evaluate iters") > 0:
                items = line.split(',')
                for item in items:
                    # each item looks like '... key=value'
                    kv_list = item.strip().split()[-1].split('=')
                    kv_list = [v.strip() for v in kv_list]
                    setattr(self, kv_list[0], kv_list[1])
class QuantLogReader:
    """Parser for a quantization log, mapping log phases to overall progress.

    running_duration holds a progress fraction in [0, 1) built from
    hard-coded phase weights: batch-data 10/30, weight 3/30, activation
    16/30 of the total work.
    """

    def __init__(self, log_file):
        # log_file: path to the quantization log
        self.log_file = log_file
        self.stage = None             # 'Batch' | 'Weight' | 'Activation' | 'Finish'
        self.running_duration = None  # progress fraction in [0, 1)

    def update(self):
        """Scan the log newest-line-first and set stage/progress from the
        most recent phase marker found."""
        if not osp.exists(self.log_file):
            return
        logs = open(self.log_file, encoding='utf-8').read().strip().split('\n')
        for line in logs[::-1]:
            items = line.strip().split(' ')
            if line.count('[Run batch data]'):
                # items[-3] presumably like 'batch=12/100,' — strip trailing
                # char, take the 'cur/total' part (TODO confirm log format)
                info = items[-3][:-1].split('=')[1]
                batch_id = float(info.split('/')[0])
                batch_all = float(info.split('/')[1])
                self.running_duration = \
                    batch_id / batch_all * (10.0 / 30.0)
                self.stage = 'Batch'
                break
            elif line.count('[Calculate weight]'):
                info = items[-3][:-1].split('=')[1]
                weight_id = float(info.split('/')[0])
                weight_all = float(info.split('/')[1])
                self.running_duration = \
                    weight_id / weight_all * (3.0 / 30.0) + (10.0 / 30.0)
                self.stage = 'Weight'
                break
            elif line.count('[Calculate activation]'):
                info = items[-3][:-1].split('=')[1]
                activation_id = float(info.split('/')[0])
                activation_all = float(info.split('/')[1])
                self.running_duration = \
                    activation_id / activation_all * (16.0 / 30.0) + (13.0 / 30.0)
                self.stage = 'Activation'
                break
            elif line.count('Finish quant!'):
                self.stage = 'Finish'
                break
  636. class PyNvml(object):
  637. """ Nvidia GPU驱动检测类,可检测当前GPU驱动版本"""
  638. class PrintableStructure(Structure):
  639. _fmt_ = {}
  640. def __str__(self):
  641. result = []
  642. for x in self._fields_:
  643. key = x[0]
  644. value = getattr(self, key)
  645. fmt = "%s"
  646. if key in self._fmt_:
  647. fmt = self._fmt_[key]
  648. elif "<default>" in self._fmt_:
  649. fmt = self._fmt_["<default>"]
  650. result.append(("%s: " + fmt) % (key, value))
  651. return self.__class__.__name__ + "(" + string.join(result,
  652. ", ") + ")"
  653. class c_nvmlMemory_t(PrintableStructure):
  654. _fields_ = [
  655. ('total', c_ulonglong),
  656. ('free', c_ulonglong),
  657. ('used', c_ulonglong),
  658. ]
  659. _fmt_ = {'<default>': "%d B"}
  660. ## Device structures
  661. class struct_c_nvmlDevice_t(Structure):
  662. pass # opaque handle
  663. c_nvmlDevice_t = POINTER(struct_c_nvmlDevice_t)
  664. def __init__(self):
  665. self.nvml_lib = None
  666. self.nvml_lib_refcount = 0
  667. self.lib_load_lock = threading.Lock()
  668. self.nvml_lib_path = None
  669. def nvml_init(self, nvml_lib_path=None):
  670. self.lib_load_lock.acquire()
  671. sysstr = platform.system()
  672. if nvml_lib_path is None or nvml_lib_path.strip() == "":
  673. if sysstr == "Windows":
  674. nvml_lib_path = osp.join(
  675. os.getenv("ProgramFiles", "C:/Program Files"),
  676. "NVIDIA Corporation/NVSMI")
  677. if not osp.exists(osp.join(nvml_lib_path, "nvml.dll")):
  678. nvml_lib_path = "C:\\Windows\\System32"
  679. elif sysstr == "Linux":
  680. p1 = "/usr/lib/x86_64-linux-gnu"
  681. p2 = "/usr/lib/i386-linux-gnu"
  682. if osp.exists(osp.join(p1, "libnvidia-ml.so.1")):
  683. nvml_lib_path = p1
  684. elif osp.exists(osp.join(p2, "libnvidia-ml.so.1")):
  685. nvml_lib_path = p2
  686. else:
  687. nvml_lib_path = ""
  688. else:
  689. nvml_lib_path = "N/A"
  690. nvml_lib_dir = nvml_lib_path
  691. if sysstr == "Windows":
  692. nvml_lib_path = osp.join(nvml_lib_dir, "nvml.dll")
  693. else:
  694. nvml_lib_path = osp.join(nvml_lib_dir, "libnvidia-ml.so.1")
  695. self.nvml_lib_path = nvml_lib_path
  696. try:
  697. self.nvml_lib = CDLL(nvml_lib_path)
  698. fn = self._get_fn_ptr("nvmlInit")
  699. fn()
  700. if sysstr == "Windows":
  701. driver_version = bytes.decode(
  702. self.nvml_system_get_driver_version())
  703. if driver_version.strip() == "":
  704. nvml_lib_path = osp.join(nvml_lib_dir, "nvml9.dll")
  705. self.nvml_lib = CDLL(nvml_lib_path)
  706. fn = self._get_fn_ptr("nvmlInit")
  707. fn()
  708. except Exception as e:
  709. raise e
  710. finally:
  711. self.lib_load_lock.release()
  712. self.lib_load_lock.acquire()
  713. self.nvml_lib_refcount += 1
  714. self.lib_load_lock.release()
  715. def create_string_buffer(self, init, size=None):
  716. if isinstance(init, bytes):
  717. if size is None:
  718. size = len(init) + 1
  719. buftype = c_char * size
  720. buf = buftype()
  721. buf.value = init
  722. return buf
  723. elif isinstance(init, int):
  724. buftype = c_char * init
  725. buf = buftype()
  726. return buf
  727. raise TypeError(init)
  728. def _get_fn_ptr(self, name):
  729. return getattr(self.nvml_lib, name)
  730. def nvml_system_get_driver_version(self):
  731. c_version = self.create_string_buffer(81)
  732. fn = self._get_fn_ptr("nvmlSystemGetDriverVersion")
  733. ret = fn(c_version, c_uint(81))
  734. return c_version.value
  735. def nvml_device_get_count(self):
  736. c_count = c_uint()
  737. fn = self._get_fn_ptr("nvmlDeviceGetCount_v2")
  738. ret = fn(byref(c_count))
  739. return c_count.value
  740. def nvml_device_get_handle_by_index(self, index):
  741. c_index = c_uint(index)
  742. device = PyNvml.c_nvmlDevice_t()
  743. fn = self._get_fn_ptr("nvmlDeviceGetHandleByIndex_v2")
  744. ret = fn(c_index, byref(device))
  745. return device
  746. def nvml_device_get_memory_info(self, handle):
  747. c_memory = PyNvml.c_nvmlMemory_t()
  748. fn = self._get_fn_ptr("nvmlDeviceGetMemoryInfo")
  749. ret = fn(handle, byref(c_memory))
  750. return c_memory