utils.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823
  1. import psutil
  2. import shutil
  3. import os
  4. import os.path as osp
  5. from enum import Enum
  6. import multiprocessing as mp
  7. from queue import Queue
  8. import time
  9. import threading
  10. from ctypes import CDLL, c_char, c_uint, c_ulonglong
  11. from _ctypes import byref, Structure, POINTER
  12. import platform
  13. import string
  14. import logging
  15. import socket
  16. import logging.handlers
  17. import requests
  18. import json
  19. from json import JSONEncoder
  20. class CustomEncoder(JSONEncoder):
  21. def default(self, o):
  22. return o.__dict__
class ShareData():
    """Process-wide mutable state shared across the server's components."""
    # Active workspace object (populated at startup) -- None until then.
    workspace = None
    # Filesystem path of the workspace directory.
    workspace_dir = ""
    # Whether this machine is assumed to have a usable GPU.
    has_gpu = True
    # Handles of spawned child processes that must be watched/reaped.
    monitored_processes = mp.Queue(4096)
    # Next TCP port to hand out (e.g. for visualization boards).
    current_port = 8000
    # Currently running boards; presumably keyed by project/task id -- TODO confirm.
    running_boards = {}
    # Cached hardware info, as produced by get_machine_info().
    machine_info = dict()
    # Bookkeeping for demo-project loading: process handles and project data.
    load_demo_proc_dict = {}
    load_demo_proj_data_dict = {}
  33. DatasetStatus = Enum(
  34. 'DatasetStatus', ('XEMPTY', 'XCHECKING', 'XCHECKFAIL', 'XCOPYING',
  35. 'XCOPYDONE', 'XCOPYFAIL', 'XSPLITED'),
  36. start=0)
  37. TaskStatus = Enum(
  38. 'TaskStatus', ('XUNINIT', 'XINIT', 'XDOWNLOADING', 'XTRAINING',
  39. 'XTRAINDONE', 'XEVALUATED', 'XEXPORTING', 'XEXPORTED',
  40. 'XTRAINEXIT', 'XDOWNLOADFAIL', 'XTRAINFAIL', 'XEVALUATING',
  41. 'XEVALUATEFAIL', 'XEXPORTFAIL', 'XPRUNEING', 'XPRUNETRAIN'),
  42. start=0)
  43. ProjectType = Enum(
  44. 'ProjectType', ('classification', 'detection', 'segmentation',
  45. 'instance_segmentation', 'remote_segmentation'),
  46. start=0)
  47. DownloadStatus = Enum(
  48. 'DownloadStatus',
  49. ('XDDOWNLOADING', 'XDDOWNLOADFAIL', 'XDDOWNLOADDONE', 'XDDECOMPRESSED'),
  50. start=0)
  51. PredictStatus = Enum(
  52. 'PredictStatus', ('XPRESTART', 'XPREDONE', 'XPREFAIL'), start=0)
  53. PruneStatus = Enum(
  54. 'PruneStatus', ('XSPRUNESTART', 'XSPRUNEING', 'XSPRUNEDONE', 'XSPRUNEFAIL',
  55. 'XSPRUNEEXIT'),
  56. start=0)
  57. PretrainedModelStatus = Enum(
  58. 'PretrainedModelStatus',
  59. ('XPINIT', 'XPSAVING', 'XPSAVEFAIL', 'XPSAVEDONE'),
  60. start=0)
  61. ExportedModelType = Enum(
  62. 'ExportedModelType', ('XQUANTMOBILE', 'XPRUNEMOBILE', 'XTRAINMOBILE',
  63. 'XQUANTSERVER', 'XPRUNESERVER', 'XTRAINSERVER'),
  64. start=0)
# Metric-name -> Chinese label mapping used when rendering metrics in
# tabular form (compact labels, no English annotation in parentheses).
translate_chinese_table = {
    "Confusion_matrix": "各个类别之间的混淆矩阵",
    "Precision": "精准率",
    "Accuracy": "准确率",
    "Recall": "召回率",
    "Class": "类别",
    "Topk": "K取值",
    "Auc": "AUC",
    "Per_ap": "类别平均精准率",
    "Map": "类别平均精准率(AP)的均值(mAP)",
    "Mean_iou": "平均交并比",
    "Mean_acc": "平均准确率",
    "Category_iou": "各类别交并比",
    "Category_acc": "各类别准确率",
    "Ap": "平均精准率",
    "F1": "F1-score",
    "Iou": "交并比"
}
# Metric-name -> Chinese label mapping for the general (non-table) UI;
# most labels keep the English metric name in parentheses.
translate_chinese = {
    "Confusion_matrix": "混淆矩阵",
    "Mask_confusion_matrix": "Mask混淆矩阵",
    "Bbox_confusion_matrix": "Bbox混淆矩阵",
    "Precision": "精准率(Precision)",
    "Accuracy": "准确率(Accuracy)",
    "Recall": "召回率(Recall)",
    "Class": "类别(Class)",
    "PRF1": "整体分类评估结果",
    "PRF1_TOPk": "TopK评估结果",
    "Topk": "K取值",
    "AUC": "Area Under Curve",
    "Auc": "Area Under Curve",
    "F1": "F1-score",
    "Iou": "交并比(IoU)",
    "Per_ap": "各类别的平均精准率(AP)",
    "mAP": "平均精准率的均值(mAP)",
    "Mask_mAP": "Mask的平均精准率的均值(mAP)",
    "BBox_mAP": "Bbox的平均精准率的均值(mAP)",
    "Mean_iou": "平均交并比(mIoU)",
    "Mean_acc": "平均准确率(mAcc)",
    "Ap": "平均精准率(Average Precision)",
    "Category_iou": "各类别的交并比(IoU)",
    "Category_acc": "各类别的准确率(Accuracy)",
    "PRAP": "整体检测评估结果",
    "BBox_PRAP": "Bbox评估结果",
    "Mask_PRAP": "Mask评估结果",
    "Overall": "整体平均指标",
    "PRF1_average": "整体平均指标",
    "overall_det": "整体平均指标",
    "PRIoU": "整体平均指标",
    "Acc1": "预测Top1的准确率",
    "Acck": "预测Top{}的准确率"
}
  117. process_pool = Queue(1000)
  118. def get_ip():
  119. try:
  120. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  121. s.connect(('8.8.8.8', 80))
  122. ip = s.getsockname()[0]
  123. finally:
  124. s.close()
  125. return ip
  126. def get_logger(filename):
  127. flask_logger = logging.getLogger()
  128. flask_logger.setLevel(level=logging.INFO)
  129. fmt = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s:%(message)s'
  130. format_str = logging.Formatter(fmt)
  131. ch = logging.StreamHandler()
  132. ch.setLevel(level=logging.INFO)
  133. ch.setFormatter(format_str)
  134. th = logging.handlers.TimedRotatingFileHandler(
  135. filename=filename, when='D', backupCount=5, encoding='utf-8')
  136. th.setFormatter(format_str)
  137. flask_logger.addHandler(th)
  138. flask_logger.addHandler(ch)
  139. return flask_logger
  140. def start_process(target, args):
  141. global process_pool
  142. p = mp.Process(target=target, args=args)
  143. p.start()
  144. process_pool.put(p)
  145. def pkill(pid):
  146. """结束进程pid,和与其相关的子进程
  147. Args:
  148. pid(int): 进程id
  149. """
  150. try:
  151. parent = psutil.Process(pid)
  152. for child in parent.children(recursive=True):
  153. child.kill()
  154. parent.kill()
  155. except:
  156. print("Try to kill process {} failed.".format(pid))
  157. def set_folder_status(dirname, status, message=""):
  158. """设置目录状态
  159. Args:
  160. dirname(str): 目录路径
  161. status(DatasetStatus): 状态
  162. message(str): 需要写到状态文件里的信息
  163. """
  164. if not osp.isdir(dirname):
  165. raise Exception("目录路径{}不存在".format(dirname))
  166. tmp_file = osp.join(dirname, status.name + '.tmp')
  167. with open(tmp_file, 'w', encoding='utf-8') as f:
  168. f.write("{}\n".format(message))
  169. shutil.move(tmp_file, osp.join(dirname, status.name))
  170. for status_type in [
  171. DatasetStatus, TaskStatus, PredictStatus, PruneStatus,
  172. DownloadStatus, PretrainedModelStatus
  173. ]:
  174. for s in status_type:
  175. if s == status:
  176. continue
  177. if osp.exists(osp.join(dirname, s.name)):
  178. os.remove(osp.join(dirname, s.name))
  179. def get_folder_status(dirname, with_message=False):
  180. """获取目录状态
  181. Args:
  182. dirname(str): 目录路径
  183. with_message(bool): 是否需要返回状态文件内的信息
  184. """
  185. status = None
  186. closest_time = 0
  187. message = ''
  188. for status_type in [
  189. DatasetStatus, TaskStatus, PredictStatus, PruneStatus,
  190. DownloadStatus, PretrainedModelStatus
  191. ]:
  192. for s in status_type:
  193. if osp.exists(osp.join(dirname, s.name)):
  194. modify_time = os.stat(osp.join(dirname, s.name)).st_mtime
  195. if modify_time > closest_time:
  196. closest_time = modify_time
  197. status = getattr(status_type, s.name)
  198. if with_message:
  199. encoding = 'utf-8'
  200. try:
  201. f = open(
  202. osp.join(dirname, s.name),
  203. 'r',
  204. encoding=encoding)
  205. message = f.read()
  206. f.close()
  207. except:
  208. try:
  209. import chardet
  210. f = open(filename, 'rb')
  211. data = f.read()
  212. f.close()
  213. encoding = chardet.detect(data).get('encoding')
  214. f = open(
  215. osp.join(dirname, s.name),
  216. 'r',
  217. encoding=encoding)
  218. message = f.read()
  219. f.close()
  220. except:
  221. pass
  222. if with_message:
  223. return status, message
  224. return status
  225. def _machine_check_proc(queue, path):
  226. info = dict()
  227. p = PyNvml()
  228. gpu_num = 0
  229. try:
  230. # import paddle.fluid.core as core
  231. # gpu_num = core.get_cuda_device_count()
  232. p.nvml_init(path)
  233. gpu_num = p.nvml_device_get_count()
  234. driver_version = bytes.decode(p.nvml_system_get_driver_version())
  235. except:
  236. driver_version = "N/A"
  237. info['gpu_num'] = gpu_num
  238. info['gpu_free_mem'] = list()
  239. try:
  240. for i in range(gpu_num):
  241. handle = p.nvml_device_get_handle_by_index(i)
  242. meminfo = p.nvml_device_get_memory_info(handle)
  243. free_mem = meminfo.free / 1024 / 1024
  244. info['gpu_free_mem'].append(free_mem)
  245. except:
  246. pass
  247. info['cpu_num'] = os.environ.get('CPU_NUM', 1)
  248. info['driver_version'] = driver_version
  249. info['path'] = p.nvml_lib_path
  250. queue.put(info, timeout=2)
  251. def get_machine_info(path=None):
  252. queue = mp.Queue(1)
  253. p = mp.Process(target=_machine_check_proc, args=(queue, path))
  254. p.start()
  255. p.join()
  256. return queue.get(timeout=2)
def download(url, target_path):
    """Download *url* into *target_path*, tracking progress via status markers.

    Retries up to 3 times; writes to a "<name>_tmp" file first and renames it
    on success so an interrupted download never leaves a partial final file.
    Raises RuntimeError (after marking XDDOWNLOADFAIL) on HTTP errors or when
    the retry limit is reached. Returns the full path of the downloaded file.
    """
    if not osp.exists(target_path):
        os.makedirs(target_path)
    fname = osp.split(url)[-1]
    fullname = osp.join(target_path, fname)
    retry_cnt = 0
    DOWNLOAD_RETRY_LIMIT = 3
    # Loop until the final file exists (i.e. a download attempt succeeded).
    while not (osp.exists(fullname)):
        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
            retry_cnt += 1
        else:
            # Mark the download as failed before raising.
            msg = "Download from {} failed. Retry limit reached".format(url)
            set_folder_status(target_path, DownloadStatus.XDDOWNLOADFAIL, msg)
            raise RuntimeError(msg)
        req = requests.get(url, stream=True)
        if req.status_code != 200:
            msg = "Downloading from {} failed with code {}!".format(
                url, req.status_code)
            set_folder_status(target_path, DownloadStatus.XDDOWNLOADFAIL, msg)
            raise RuntimeError(msg)
        # For protecting download interrupted, download to
        # tmp_fullname firstly, move tmp_fullname to fullname
        # after download finished.
        tmp_fullname = fullname + "_tmp"
        total_size = req.headers.get('content-length')
        # The expected size (bytes, as a string) is stored in the marker file.
        set_folder_status(target_path, DownloadStatus.XDDOWNLOADING,
                          total_size)
        with open(tmp_fullname, 'wb') as f:
            if total_size:
                # download_size is only an estimate (chunk granularity) and
                # is not otherwise consumed here.
                download_size = 0
                for chunk in req.iter_content(chunk_size=1024):
                    f.write(chunk)
                    download_size += 1024
            else:
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
        shutil.move(tmp_fullname, fullname)
    set_folder_status(target_path, DownloadStatus.XDDOWNLOADDONE)
    return fullname
  298. def trans_name(key, in_table=False):
  299. if in_table:
  300. if key in translate_chinese_table:
  301. key = "{}".format(translate_chinese_table[key])
  302. if key.capitalize() in translate_chinese_table:
  303. key = "{}".format(translate_chinese_table[key.capitalize()])
  304. return key
  305. else:
  306. if key in translate_chinese:
  307. key = "{}".format(translate_chinese[key])
  308. if key.capitalize() in translate_chinese:
  309. key = "{}".format(translate_chinese[key.capitalize()])
  310. return key
  311. return key
  312. def is_pic(filename):
  313. suffixes = {'JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png'}
  314. suffix = filename.strip().split('.')[-1]
  315. if suffix not in suffixes:
  316. return False
  317. return True
  318. def is_available(ip, port):
  319. s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  320. try:
  321. s.connect((ip, int(port)))
  322. s.shutdown(2)
  323. return False
  324. except:
  325. return True
  326. def list_files(dirname):
  327. """ 列出目录下所有文件(包括所属的一级子目录下文件)
  328. Args:
  329. dirname: 目录路径
  330. """
  331. def filter_file(f):
  332. if f.startswith('.'):
  333. return True
  334. if hasattr(PretrainedModelStatus, f):
  335. return True
  336. return False
  337. all_files = list()
  338. dirs = list()
  339. for f in os.listdir(dirname):
  340. if filter_file(f):
  341. continue
  342. if osp.isdir(osp.join(dirname, f)):
  343. dirs.append(f)
  344. else:
  345. all_files.append(f)
  346. for d in dirs:
  347. for f in os.listdir(osp.join(dirname, d)):
  348. if filter_file(f):
  349. continue
  350. if osp.isdir(osp.join(dirname, d, f)):
  351. continue
  352. all_files.append(osp.join(d, f))
  353. return all_files
  354. def copy_model_directory(src, dst, files=None, filter_files=[]):
  355. """从src目录copy文件至dst目录,
  356. 注意:拷贝前会先清空dst中的所有文件
  357. Args:
  358. src: 源目录路径
  359. dst: 目标目录路径
  360. files: 需要拷贝的文件列表(src的相对路径)
  361. """
  362. set_folder_status(dst, PretrainedModelStatus.XPSAVING, os.getpid())
  363. if files is None:
  364. files = list_files(src)
  365. try:
  366. message = '{} {}'.format(os.getpid(), len(files))
  367. set_folder_status(dst, PretrainedModelStatus.XPSAVING, message)
  368. if not osp.samefile(src, dst):
  369. for i, f in enumerate(files):
  370. items = osp.split(f)
  371. if len(items) > 2:
  372. continue
  373. if len(items) == 2:
  374. if not osp.isdir(osp.join(dst, items[0])):
  375. if osp.exists(osp.join(dst, items[0])):
  376. os.remove(osp.join(dst, items[0]))
  377. os.makedirs(osp.join(dst, items[0]))
  378. if f not in filter_files:
  379. shutil.copy(osp.join(src, f), osp.join(dst, f))
  380. set_folder_status(dst, PretrainedModelStatus.XPSAVEDONE)
  381. except Exception as e:
  382. import traceback
  383. error_info = traceback.format_exc()
  384. set_folder_status(dst, PretrainedModelStatus.XPSAVEFAIL, error_info)
  385. def copy_pretrained_model(src, dst):
  386. p = mp.Process(
  387. target=copy_model_directory, args=(src, dst, None, ['model.pdopt']))
  388. p.start()
  389. return p
  390. def _get_gpu_info(queue):
  391. gpu_info = dict()
  392. mem_free = list()
  393. mem_used = list()
  394. mem_total = list()
  395. import pycuda.driver as drv
  396. from pycuda.tools import clear_context_caches
  397. drv.init()
  398. driver_version = drv.get_driver_version()
  399. gpu_num = drv.Device.count()
  400. for gpu_id in range(gpu_num):
  401. dev = drv.Device(gpu_id)
  402. try:
  403. context = dev.make_context()
  404. free, total = drv.mem_get_info()
  405. context.pop()
  406. free = free // 1024 // 1024
  407. total = total // 1024 // 1024
  408. used = total - free
  409. except:
  410. free = 0
  411. total = 0
  412. used = 0
  413. mem_free.append(free)
  414. mem_used.append(used)
  415. mem_total.append(total)
  416. gpu_info['mem_free'] = mem_free
  417. gpu_info['mem_used'] = mem_used
  418. gpu_info['mem_total'] = mem_total
  419. gpu_info['driver_version'] = driver_version
  420. gpu_info['gpu_num'] = gpu_num
  421. queue.put(gpu_info)
  422. def get_gpu_info():
  423. try:
  424. import pycuda
  425. except:
  426. gpu_info = dict()
  427. message = "未检测到GPU \n 若存在GPU请确保安装pycuda \n 若未安装pycuda请使用'pip install pycuda'来安装"
  428. gpu_info['gpu_num'] = 0
  429. return gpu_info, message
  430. queue = mp.Queue(1)
  431. p = mp.Process(target=_get_gpu_info, args=(queue, ))
  432. p.start()
  433. p.join()
  434. gpu_info = queue.get(timeout=2)
  435. if gpu_info['gpu_num'] == 0:
  436. message = "未检测到GPU"
  437. else:
  438. message = "检测到GPU"
  439. return gpu_info, message
class TrainLogReader(object):
    """Incremental parser for a training log file.

    Each update() re-reads the log and refreshes download progress, eta,
    the latest train/eval metrics, error state and total running duration.
    Once a terminal state is reached ("Train Error", download "Failed",
    "Train Complete"), update() becomes a no-op.
    """

    def __init__(self, log_file):
        # Path of the training log to poll.
        self.log_file = log_file
        self.eta = None                # remaining seconds, from "eta=h:m:s"
        self.train_metrics = None      # dict of latest [TRAIN] step metrics
        self.eval_metrics = None       # dict of latest [EVAL] metrics
        self.download_status = None    # None | progress dict | 'Done' | 'Failed'
        self.eval_done = False         # True once a "Model saved in" line is seen
        self.train_error = None        # raw "Training stop with error!" line
        self.train_stage = None        # None | "Train Complete"
        self.running_duration = None   # human-readable elapsed time

    def update(self):
        """Re-parse the log file and refresh all derived fields."""
        if not osp.exists(self.log_file):
            return
        # Terminal states: nothing further to parse.
        if self.train_stage == "Train Error":
            return
        if self.download_status == "Failed":
            return
        if self.train_stage == "Train Complete":
            return
        logs = open(self.log_file, encoding='utf-8').read().strip().split('\n')
        # Reset per-update fields; a completed download stays 'Done'.
        self.eta = None
        self.train_metrics = None
        self.eval_metrics = None
        if self.download_status != "Done":
            self.download_status = None
        # Start time: first line with a parsable "YYYY-mm-dd HH:MM:SS"
        # prefix, falling back to the log file's creation time.
        start_time_timestamp = osp.getctime(self.log_file)
        for line in logs[::1]:
            try:
                start_time_str = " ".join(line.split()[0:2])
                start_time_array = time.strptime(start_time_str,
                                                 "%Y-%m-%d %H:%M:%S")
                start_time_timestamp = time.mktime(start_time_array)
                break
            except Exception as e:
                pass
        # Scan newest-first so the most recent values win.
        for line in logs[::-1]:
            if line.count('Train Complete!'):
                self.train_stage = "Train Complete"
            if line.count('Training stop with error!'):
                self.train_error = line
            # Stop early once everything of interest has been found.
            if self.train_metrics is not None \
                    and self.eval_metrics is not None and self.eval_done and self.eta is not None:
                break
            items = line.strip().split()
            if line.count('Model saved in'):
                self.eval_done = True
            if line.count('download completed'):
                self.download_status = 'Done'
                break
            if line.count('download failed'):
                self.download_status = 'Failed'
                break
            if self.download_status != 'Done':
                # Weights download still in progress: parse
                # "total=..M, download=..M, speed=..KB/s" style lines.
                if line.count('[DEBUG]\tDownloading'
                              ) and self.download_status is None:
                    self.download_status = dict()
                if not line.endswith('KB/s'):
                    continue
                speed = items[-1].strip('KB/s').split('=')[-1]
                download = items[-2].strip('M, ').split('=')[-1]
                total = items[-3].strip('M, ').split('=')[-1]
                self.download_status['speed'] = speed
                self.download_status['download'] = float(download)
                self.download_status['total'] = float(total)
            if self.eta is None:
                # Accept lines ending in "eta=H:MM:SS" (the ':' position
                # check tolerates single-digit seconds).
                if line.count('eta') > 0 and (line[-3] == ':' or
                                              line[-4] == ':'):
                    eta = items[-1].strip().split('=')[1]
                    h, m, s = [int(x) for x in eta.split(':')]
                    self.eta = h * 3600 + m * 60 + s
            if self.train_metrics is None:
                if line.count('[INFO]\t[TRAIN]') > 0 and line.count(
                        'Step') > 0:
                    # Only complete step lines (those ending with eta=...).
                    if not items[-1].startswith('eta'):
                        continue
                    self.train_metrics = dict()
                    # Metric tokens look like "name=value," or "name=a/b,".
                    metrics = items[4:]
                    for metric in metrics:
                        try:
                            name, value = metric.strip(', ').split('=')
                            value = value.split('/')[0]
                            if value.count('.') > 0:
                                value = float(value)
                            elif value == 'nan':
                                value = 'nan'
                            else:
                                value = int(value)
                            self.train_metrics[name] = value
                        except:
                            pass
            if self.eval_metrics is None:
                if line.count('[INFO]\t[EVAL]') > 0 and line.count(
                        'Finished') > 0:
                    # Only fully flushed eval lines (they end with " .").
                    if not line.strip().endswith(' .'):
                        continue
                    self.eval_metrics = dict()
                    metrics = items[5:]
                    for metric in metrics:
                        try:
                            name, value = metric.strip(', ').split('=')
                            value = value.split('/')[0]
                            if value.count('.') > 0:
                                value = float(value)
                            else:
                                value = int(value)
                            self.eval_metrics[name] = value
                        except:
                            pass
        # Elapsed time = last file modification - parsed start time.
        end_time_timestamp = osp.getmtime(self.log_file)
        t_diff = time.gmtime(end_time_timestamp - start_time_timestamp)
        self.running_duration = "{}小时{}分{}秒".format(
            t_diff.tm_hour, t_diff.tm_min, t_diff.tm_sec)
  553. class PruneLogReader(object):
  554. def init_attr(self):
  555. self.eta = None
  556. self.iters = None
  557. self.current = None
  558. self.progress = None
  559. def __init__(self, log_file):
  560. self.log_file = log_file
  561. self.init_attr()
  562. def update(self):
  563. if not osp.exists(self.log_file):
  564. return
  565. logs = open(self.log_file, encoding='utf-8').read().strip().split('\n')
  566. self.init_attr()
  567. for line in logs[::-1]:
  568. metric_loaded = True
  569. for k, v in self.__dict__.items():
  570. if v is None:
  571. metric_loaded = False
  572. break
  573. if metric_loaded:
  574. break
  575. if line.count("Total evaluate iters") > 0:
  576. items = line.split(',')
  577. for item in items:
  578. kv_list = item.strip().split()[-1].split('=')
  579. kv_list = [v.strip() for v in kv_list]
  580. setattr(self, kv_list[0], kv_list[1])
  581. class QuantLogReader:
  582. def __init__(self, log_file):
  583. self.log_file = log_file
  584. self.stage = None
  585. self.running_duration = None
  586. def update(self):
  587. if not osp.exists(self.log_file):
  588. return
  589. logs = open(self.log_file, encoding='utf-8').read().strip().split('\n')
  590. for line in logs[::-1]:
  591. items = line.strip().split(' ')
  592. if line.count('[Run batch data]'):
  593. info = items[-3][:-1].split('=')[1]
  594. batch_id = float(info.split('/')[0])
  595. batch_all = float(info.split('/')[1])
  596. self.running_duration = \
  597. batch_id / batch_all * (10.0 / 30.0)
  598. self.stage = 'Batch'
  599. break
  600. elif line.count('[Calculate weight]'):
  601. info = items[-3][:-1].split('=')[1]
  602. weight_id = float(info.split('/')[0])
  603. weight_all = float(info.split('/')[1])
  604. self.running_duration = \
  605. weight_id / weight_all * (3.0 / 30.0) + (10.0 / 30.0)
  606. self.stage = 'Weight'
  607. break
  608. elif line.count('[Calculate activation]'):
  609. info = items[-3][:-1].split('=')[1]
  610. activation_id = float(info.split('/')[0])
  611. activation_all = float(info.split('/')[1])
  612. self.running_duration = \
  613. activation_id / activation_all * (16.0 / 30.0) + (13.0 / 30.0)
  614. self.stage = 'Activation'
  615. break
  616. elif line.count('Finish quant!'):
  617. self.stage = 'Finish'
  618. break
  619. class PyNvml(object):
  620. """ Nvidia GPU驱动检测类,可检测当前GPU驱动版本"""
  621. class PrintableStructure(Structure):
  622. _fmt_ = {}
  623. def __str__(self):
  624. result = []
  625. for x in self._fields_:
  626. key = x[0]
  627. value = getattr(self, key)
  628. fmt = "%s"
  629. if key in self._fmt_:
  630. fmt = self._fmt_[key]
  631. elif "<default>" in self._fmt_:
  632. fmt = self._fmt_["<default>"]
  633. result.append(("%s: " + fmt) % (key, value))
  634. return self.__class__.__name__ + "(" + string.join(result,
  635. ", ") + ")"
  636. class c_nvmlMemory_t(PrintableStructure):
  637. _fields_ = [
  638. ('total', c_ulonglong),
  639. ('free', c_ulonglong),
  640. ('used', c_ulonglong),
  641. ]
  642. _fmt_ = {'<default>': "%d B"}
  643. ## Device structures
  644. class struct_c_nvmlDevice_t(Structure):
  645. pass # opaque handle
  646. c_nvmlDevice_t = POINTER(struct_c_nvmlDevice_t)
  647. def __init__(self):
  648. self.nvml_lib = None
  649. self.nvml_lib_refcount = 0
  650. self.lib_load_lock = threading.Lock()
  651. self.nvml_lib_path = None
  652. def nvml_init(self, nvml_lib_path=None):
  653. self.lib_load_lock.acquire()
  654. sysstr = platform.system()
  655. if nvml_lib_path is None or nvml_lib_path.strip() == "":
  656. if sysstr == "Windows":
  657. nvml_lib_path = osp.join(
  658. os.getenv("ProgramFiles", "C:/Program Files"),
  659. "NVIDIA Corporation/NVSMI")
  660. if not osp.exists(osp.join(nvml_lib_path, "nvml.dll")):
  661. nvml_lib_path = "C:\\Windows\\System32"
  662. elif sysstr == "Linux":
  663. p1 = "/usr/lib/x86_64-linux-gnu"
  664. p2 = "/usr/lib/i386-linux-gnu"
  665. if osp.exists(osp.join(p1, "libnvidia-ml.so.1")):
  666. nvml_lib_path = p1
  667. elif osp.exists(osp.join(p2, "libnvidia-ml.so.1")):
  668. nvml_lib_path = p2
  669. else:
  670. nvml_lib_path = ""
  671. else:
  672. nvml_lib_path = "N/A"
  673. nvml_lib_dir = nvml_lib_path
  674. if sysstr == "Windows":
  675. nvml_lib_path = osp.join(nvml_lib_dir, "nvml.dll")
  676. else:
  677. nvml_lib_path = osp.join(nvml_lib_dir, "libnvidia-ml.so.1")
  678. self.nvml_lib_path = nvml_lib_path
  679. try:
  680. self.nvml_lib = CDLL(nvml_lib_path)
  681. fn = self._get_fn_ptr("nvmlInit_v2")
  682. fn()
  683. if sysstr == "Windows":
  684. driver_version = bytes.decode(
  685. self.nvml_system_get_driver_version())
  686. if driver_version.strip() == "":
  687. nvml_lib_path = osp.join(nvml_lib_dir, "nvml9.dll")
  688. self.nvml_lib = CDLL(nvml_lib_path)
  689. fn = self._get_fn_ptr("nvmlInit_v2")
  690. fn()
  691. except Exception as e:
  692. raise e
  693. finally:
  694. self.lib_load_lock.release()
  695. self.lib_load_lock.acquire()
  696. self.nvml_lib_refcount += 1
  697. self.lib_load_lock.release()
  698. def create_string_buffer(self, init, size=None):
  699. if isinstance(init, bytes):
  700. if size is None:
  701. size = len(init) + 1
  702. buftype = c_char * size
  703. buf = buftype()
  704. buf.value = init
  705. return buf
  706. elif isinstance(init, int):
  707. buftype = c_char * init
  708. buf = buftype()
  709. return buf
  710. raise TypeError(init)
  711. def _get_fn_ptr(self, name):
  712. return getattr(self.nvml_lib, name)
  713. def nvml_system_get_driver_version(self):
  714. c_version = self.create_string_buffer(81)
  715. fn = self._get_fn_ptr("nvmlSystemGetDriverVersion")
  716. ret = fn(c_version, c_uint(81))
  717. return c_version.value
  718. def nvml_device_get_count(self):
  719. c_count = c_uint()
  720. fn = self._get_fn_ptr("nvmlDeviceGetCount_v2")
  721. ret = fn(byref(c_count))
  722. return c_count.value
  723. def nvml_device_get_handle_by_index(self, index):
  724. c_index = c_uint(index)
  725. device = PyNvml.c_nvmlDevice_t()
  726. fn = self._get_fn_ptr("nvmlDeviceGetHandleByIndex_v2")
  727. ret = fn(c_index, byref(device))
  728. return device
  729. def nvml_device_get_memory_info(self, handle):
  730. c_memory = PyNvml.c_nvmlMemory_t()
  731. fn = self._get_fn_ptr("nvmlDeviceGetMemoryInfo")
  732. ret = fn(handle, byref(c_memory))
  733. return c_memory