utils.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import psutil
  15. import shutil
  16. import os
  17. import os.path as osp
  18. from enum import Enum
  19. import multiprocessing as mp
  20. from queue import Queue
  21. import time
  22. import threading
  23. from ctypes import CDLL, c_char, c_uint, c_ulonglong
  24. from _ctypes import byref, Structure, POINTER
  25. import platform
  26. import string
  27. import logging
  28. import socket
  29. import logging.handlers
  30. import requests
  31. import json
  32. from json import JSONEncoder
  33. class CustomEncoder(JSONEncoder):
  34. def default(self, o):
  35. return o.__dict__
class ShareData():
    """Module-wide mutable state shared by the server components.

    All attributes are class-level, so every importer observes the same
    values without instantiating the class.
    """
    # Currently loaded workspace object; None until one is opened.
    workspace = None
    # Filesystem path of the active workspace; empty until set.
    workspace_dir = ""
    # GPU-availability flag; starts optimistic (True).
    has_gpu = True
    # Bounded queue of process handles/ids to watch; presumably drained by a
    # monitoring loop elsewhere — TODO confirm against the consumer.
    monitored_processes = mp.Queue(4096)
    # Next candidate TCP port when launching auxiliary boards.
    current_port = 8000
    # Running board processes; key semantics defined by the callers.
    running_boards = {}
    # Cached machine/hardware description (cpu/gpu info).
    machine_info = dict()
    # Bookkeeping for demo-project loading: process handles and project data.
    load_demo_proc_dict = {}
    load_demo_proj_data_dict = {}
  46. DatasetStatus = Enum(
  47. 'DatasetStatus', ('XEMPTY', 'XCHECKING', 'XCHECKFAIL', 'XCOPYING',
  48. 'XCOPYDONE', 'XCOPYFAIL', 'XSPLITED'),
  49. start=0)
  50. TaskStatus = Enum(
  51. 'TaskStatus', ('XUNINIT', 'XINIT', 'XDOWNLOADING', 'XTRAINING',
  52. 'XTRAINDONE', 'XEVALUATED', 'XEXPORTING', 'XEXPORTED',
  53. 'XTRAINEXIT', 'XDOWNLOADFAIL', 'XTRAINFAIL', 'XEVALUATING',
  54. 'XEVALUATEFAIL', 'XEXPORTFAIL', 'XPRUNEING', 'XPRUNETRAIN'),
  55. start=0)
  56. ProjectType = Enum(
  57. 'ProjectType', ('classification', 'detection', 'segmentation',
  58. 'instance_segmentation', 'remote_segmentation'),
  59. start=0)
  60. DownloadStatus = Enum(
  61. 'DownloadStatus',
  62. ('XDDOWNLOADING', 'XDDOWNLOADFAIL', 'XDDOWNLOADDONE', 'XDDECOMPRESSED'),
  63. start=0)
  64. PredictStatus = Enum(
  65. 'PredictStatus', ('XPRESTART', 'XPREDONE', 'XPREFAIL'), start=0)
  66. PruneStatus = Enum(
  67. 'PruneStatus', ('XSPRUNESTART', 'XSPRUNEING', 'XSPRUNEDONE', 'XSPRUNEFAIL',
  68. 'XSPRUNEEXIT'),
  69. start=0)
  70. PretrainedModelStatus = Enum(
  71. 'PretrainedModelStatus',
  72. ('XPINIT', 'XPSAVING', 'XPSAVEFAIL', 'XPSAVEDONE'),
  73. start=0)
  74. ExportedModelType = Enum(
  75. 'ExportedModelType', ('XQUANTMOBILE', 'XPRUNEMOBILE', 'XTRAINMOBILE',
  76. 'XQUANTSERVER', 'XPRUNESERVER', 'XTRAINSERVER'),
  77. start=0)
# Chinese display names for metric keys when rendered inside result tables
# (shorter labels than the `translate_chinese` variants used elsewhere).
translate_chinese_table = {
    "Confusion_matrix": "各个类别之间的混淆矩阵",
    "Precision": "精准率",
    "Accuracy": "准确率",
    "Recall": "召回率",
    "Class": "类别",
    "Topk": "K取值",
    "Auc": "AUC",
    "Per_ap": "类别平均精准率",
    "Map": "类别平均精准率(AP)的均值(mAP)",
    "Mean_iou": "平均交并比",
    "Mean_acc": "平均准确率",
    "Category_iou": "各类别交并比",
    "Category_acc": "各类别准确率",
    "Ap": "平均精准率",
    "F1": "F1-score",
    "Iou": "交并比"
}
# Chinese display names (with English abbreviations) for metric keys used in
# full evaluation reports; consumed by trans_name().
translate_chinese = {
    "Confusion_matrix": "混淆矩阵",
    "Mask_confusion_matrix": "Mask混淆矩阵",
    "Bbox_confusion_matrix": "Bbox混淆矩阵",
    "Precision": "精准率(Precision)",
    "Accuracy": "准确率(Accuracy)",
    "Recall": "召回率(Recall)",
    "Class": "类别(Class)",
    "PRF1": "整体分类评估结果",
    "PRF1_TOPk": "TopK评估结果",
    "Topk": "K取值",
    "AUC": "Area Under Curve",
    "Auc": "Area Under Curve",
    "F1": "F1-score",
    "Iou": "交并比(IoU)",
    "Per_ap": "各类别的平均精准率(AP)",
    "mAP": "平均精准率的均值(mAP)",
    "Mask_mAP": "Mask的平均精准率的均值(mAP)",
    "BBox_mAP": "Bbox的平均精准率的均值(mAP)",
    "Mean_iou": "平均交并比(mIoU)",
    "Mean_acc": "平均准确率(mAcc)",
    "Ap": "平均精准率(Average Precision)",
    "Category_iou": "各类别的交并比(IoU)",
    "Category_acc": "各类别的准确率(Accuracy)",
    "PRAP": "整体检测评估结果",
    "BBox_PRAP": "Bbox评估结果",
    "Mask_PRAP": "Mask评估结果",
    "Overall": "整体平均指标",
    "PRF1_average": "整体平均指标",
    "overall_det": "整体平均指标",
    "PRIoU": "整体平均指标",
    "Acc1": "预测Top1的准确率",
    "Acck": "预测Top{}的准确率"
}
  130. process_pool = Queue(1000)
  131. def get_ip():
  132. try:
  133. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  134. s.connect(('8.8.8.8', 80))
  135. ip = s.getsockname()[0]
  136. finally:
  137. s.close()
  138. return ip
  139. def get_logger(filename):
  140. flask_logger = logging.getLogger()
  141. flask_logger.setLevel(level=logging.INFO)
  142. fmt = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s:%(message)s'
  143. format_str = logging.Formatter(fmt)
  144. ch = logging.StreamHandler()
  145. ch.setLevel(level=logging.INFO)
  146. ch.setFormatter(format_str)
  147. th = logging.handlers.TimedRotatingFileHandler(
  148. filename=filename, when='D', backupCount=5, encoding='utf-8')
  149. th.setFormatter(format_str)
  150. flask_logger.addHandler(th)
  151. flask_logger.addHandler(ch)
  152. return flask_logger
  153. def start_process(target, args):
  154. global process_pool
  155. p = mp.Process(target=target, args=args)
  156. p.start()
  157. process_pool.put(p)
  158. def pkill(pid):
  159. """结束进程pid,和与其相关的子进程
  160. Args:
  161. pid(int): 进程id
  162. """
  163. try:
  164. parent = psutil.Process(pid)
  165. for child in parent.children(recursive=True):
  166. child.kill()
  167. parent.kill()
  168. except:
  169. print("Try to kill process {} failed.".format(pid))
  170. def set_folder_status(dirname, status, message=""):
  171. """设置目录状态
  172. Args:
  173. dirname(str): 目录路径
  174. status(DatasetStatus): 状态
  175. message(str): 需要写到状态文件里的信息
  176. """
  177. if not osp.isdir(dirname):
  178. raise Exception("目录路径{}不存在".format(dirname))
  179. tmp_file = osp.join(dirname, status.name + '.tmp')
  180. with open(tmp_file, 'w', encoding='utf-8') as f:
  181. f.write("{}\n".format(message))
  182. shutil.move(tmp_file, osp.join(dirname, status.name))
  183. for status_type in [
  184. DatasetStatus, TaskStatus, PredictStatus, PruneStatus,
  185. DownloadStatus, PretrainedModelStatus
  186. ]:
  187. for s in status_type:
  188. if s == status:
  189. continue
  190. if osp.exists(osp.join(dirname, s.name)):
  191. os.remove(osp.join(dirname, s.name))
  192. def get_folder_status(dirname, with_message=False):
  193. """获取目录状态
  194. Args:
  195. dirname(str): 目录路径
  196. with_message(bool): 是否需要返回状态文件内的信息
  197. """
  198. status = None
  199. closest_time = 0
  200. message = ''
  201. for status_type in [
  202. DatasetStatus, TaskStatus, PredictStatus, PruneStatus,
  203. DownloadStatus, PretrainedModelStatus
  204. ]:
  205. for s in status_type:
  206. if osp.exists(osp.join(dirname, s.name)):
  207. modify_time = os.stat(osp.join(dirname, s.name)).st_mtime
  208. if modify_time > closest_time:
  209. closest_time = modify_time
  210. status = getattr(status_type, s.name)
  211. if with_message:
  212. encoding = 'utf-8'
  213. try:
  214. f = open(
  215. osp.join(dirname, s.name),
  216. 'r',
  217. encoding=encoding)
  218. message = f.read()
  219. f.close()
  220. except:
  221. try:
  222. import chardet
  223. f = open(filename, 'rb')
  224. data = f.read()
  225. f.close()
  226. encoding = chardet.detect(data).get('encoding')
  227. f = open(
  228. osp.join(dirname, s.name),
  229. 'r',
  230. encoding=encoding)
  231. message = f.read()
  232. f.close()
  233. except:
  234. pass
  235. if with_message:
  236. return status, message
  237. return status
  238. def _machine_check_proc(queue, path):
  239. info = dict()
  240. p = PyNvml()
  241. gpu_num = 0
  242. try:
  243. # import paddle.fluid.core as core
  244. # gpu_num = core.get_cuda_device_count()
  245. p.nvml_init(path)
  246. gpu_num = p.nvml_device_get_count()
  247. driver_version = bytes.decode(p.nvml_system_get_driver_version())
  248. except:
  249. driver_version = "N/A"
  250. info['gpu_num'] = gpu_num
  251. info['gpu_free_mem'] = list()
  252. try:
  253. for i in range(gpu_num):
  254. handle = p.nvml_device_get_handle_by_index(i)
  255. meminfo = p.nvml_device_get_memory_info(handle)
  256. free_mem = meminfo.free / 1024 / 1024
  257. info['gpu_free_mem'].append(free_mem)
  258. except:
  259. pass
  260. info['cpu_num'] = os.environ.get('CPU_NUM', 1)
  261. info['driver_version'] = driver_version
  262. info['path'] = p.nvml_lib_path
  263. queue.put(info, timeout=2)
  264. def get_machine_info(path=None):
  265. queue = mp.Queue(1)
  266. p = mp.Process(target=_machine_check_proc, args=(queue, path))
  267. p.start()
  268. p.join()
  269. return queue.get(timeout=2)
  270. def download(url, target_path):
  271. if not osp.exists(target_path):
  272. os.makedirs(target_path)
  273. fname = osp.split(url)[-1]
  274. fullname = osp.join(target_path, fname)
  275. retry_cnt = 0
  276. DOWNLOAD_RETRY_LIMIT = 3
  277. while not (osp.exists(fullname)):
  278. if retry_cnt < DOWNLOAD_RETRY_LIMIT:
  279. retry_cnt += 1
  280. else:
  281. # 设置下载失败
  282. msg = "Download from {} failed. Retry limit reached".format(url)
  283. set_folder_status(target_path, DownloadStatus.XDDOWNLOADFAIL, msg)
  284. raise RuntimeError(msg)
  285. req = requests.get(url, stream=True)
  286. if req.status_code != 200:
  287. msg = "Downloading from {} failed with code {}!".format(
  288. url, req.status_code)
  289. set_folder_status(target_path, DownloadStatus.XDDOWNLOADFAIL, msg)
  290. raise RuntimeError(msg)
  291. # For protecting download interupted, download to
  292. # tmp_fullname firstly, move tmp_fullname to fullname
  293. # after download finished
  294. tmp_fullname = fullname + "_tmp"
  295. total_size = req.headers.get('content-length')
  296. set_folder_status(target_path, DownloadStatus.XDDOWNLOADING,
  297. total_size)
  298. with open(tmp_fullname, 'wb') as f:
  299. if total_size:
  300. download_size = 0
  301. for chunk in req.iter_content(chunk_size=1024):
  302. f.write(chunk)
  303. download_size += 1024
  304. else:
  305. for chunk in req.iter_content(chunk_size=1024):
  306. if chunk:
  307. f.write(chunk)
  308. shutil.move(tmp_fullname, fullname)
  309. set_folder_status(target_path, DownloadStatus.XDDOWNLOADDONE)
  310. return fullname
  311. def trans_name(key, in_table=False):
  312. if in_table:
  313. if key in translate_chinese_table:
  314. key = "{}".format(translate_chinese_table[key])
  315. if key.capitalize() in translate_chinese_table:
  316. key = "{}".format(translate_chinese_table[key.capitalize()])
  317. return key
  318. else:
  319. if key in translate_chinese:
  320. key = "{}".format(translate_chinese[key])
  321. if key.capitalize() in translate_chinese:
  322. key = "{}".format(translate_chinese[key.capitalize()])
  323. return key
  324. return key
  325. def is_pic(filename):
  326. suffixes = {'JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png'}
  327. suffix = filename.strip().split('.')[-1]
  328. if suffix not in suffixes:
  329. return False
  330. return True
  331. def is_available(ip, port):
  332. s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  333. try:
  334. s.connect((ip, int(port)))
  335. s.shutdown(2)
  336. return False
  337. except:
  338. return True
  339. def list_files(dirname):
  340. """ 列出目录下所有文件(包括所属的一级子目录下文件)
  341. Args:
  342. dirname: 目录路径
  343. """
  344. def filter_file(f):
  345. if f.startswith('.'):
  346. return True
  347. if hasattr(PretrainedModelStatus, f):
  348. return True
  349. return False
  350. all_files = list()
  351. dirs = list()
  352. for f in os.listdir(dirname):
  353. if filter_file(f):
  354. continue
  355. if osp.isdir(osp.join(dirname, f)):
  356. dirs.append(f)
  357. else:
  358. all_files.append(f)
  359. for d in dirs:
  360. for f in os.listdir(osp.join(dirname, d)):
  361. if filter_file(f):
  362. continue
  363. if osp.isdir(osp.join(dirname, d, f)):
  364. continue
  365. all_files.append(osp.join(d, f))
  366. return all_files
  367. def copy_model_directory(src, dst, files=None, filter_files=[]):
  368. """从src目录copy文件至dst目录,
  369. 注意:拷贝前会先清空dst中的所有文件
  370. Args:
  371. src: 源目录路径
  372. dst: 目标目录路径
  373. files: 需要拷贝的文件列表(src的相对路径)
  374. """
  375. set_folder_status(dst, PretrainedModelStatus.XPSAVING, os.getpid())
  376. if files is None:
  377. files = list_files(src)
  378. try:
  379. message = '{} {}'.format(os.getpid(), len(files))
  380. set_folder_status(dst, PretrainedModelStatus.XPSAVING, message)
  381. if not osp.samefile(src, dst):
  382. for i, f in enumerate(files):
  383. items = osp.split(f)
  384. if len(items) > 2:
  385. continue
  386. if len(items) == 2:
  387. if not osp.isdir(osp.join(dst, items[0])):
  388. if osp.exists(osp.join(dst, items[0])):
  389. os.remove(osp.join(dst, items[0]))
  390. os.makedirs(osp.join(dst, items[0]))
  391. if f not in filter_files:
  392. shutil.copy(osp.join(src, f), osp.join(dst, f))
  393. set_folder_status(dst, PretrainedModelStatus.XPSAVEDONE)
  394. except Exception as e:
  395. import traceback
  396. error_info = traceback.format_exc()
  397. set_folder_status(dst, PretrainedModelStatus.XPSAVEFAIL, error_info)
  398. def copy_pretrained_model(src, dst):
  399. p = mp.Process(
  400. target=copy_model_directory, args=(src, dst, None, ['model.pdopt']))
  401. p.start()
  402. return p
  403. def _get_gpu_info(queue):
  404. gpu_info = dict()
  405. mem_free = list()
  406. mem_used = list()
  407. mem_total = list()
  408. import pycuda.driver as drv
  409. from pycuda.tools import clear_context_caches
  410. drv.init()
  411. driver_version = drv.get_driver_version()
  412. gpu_num = drv.Device.count()
  413. for gpu_id in range(gpu_num):
  414. dev = drv.Device(gpu_id)
  415. try:
  416. context = dev.make_context()
  417. free, total = drv.mem_get_info()
  418. context.pop()
  419. free = free // 1024 // 1024
  420. total = total // 1024 // 1024
  421. used = total - free
  422. except:
  423. free = 0
  424. total = 0
  425. used = 0
  426. mem_free.append(free)
  427. mem_used.append(used)
  428. mem_total.append(total)
  429. gpu_info['mem_free'] = mem_free
  430. gpu_info['mem_used'] = mem_used
  431. gpu_info['mem_total'] = mem_total
  432. gpu_info['driver_version'] = driver_version
  433. gpu_info['gpu_num'] = gpu_num
  434. queue.put(gpu_info)
  435. def get_gpu_info():
  436. try:
  437. import pycuda
  438. except:
  439. gpu_info = dict()
  440. message = "未检测到GPU \n 若存在GPU请确保安装pycuda \n 若未安装pycuda请使用'pip install pycuda'来安装"
  441. gpu_info['gpu_num'] = 0
  442. return gpu_info, message
  443. queue = mp.Queue(1)
  444. p = mp.Process(target=_get_gpu_info, args=(queue, ))
  445. p.start()
  446. p.join()
  447. gpu_info = queue.get(timeout=2)
  448. if gpu_info['gpu_num'] == 0:
  449. message = "未检测到GPU"
  450. else:
  451. message = "检测到GPU"
  452. return gpu_info, message
class TrainLogReader(object):
    """Incremental parser for a training log file.

    update() re-reads the log and refreshes: eta (seconds remaining),
    train_metrics / eval_metrics (dicts of the latest values),
    download_status (None / dict of progress / 'Done' / 'Failed'),
    eval_done, train_error, train_stage and running_duration.
    """

    def __init__(self, log_file):
        # Path to the training log; the file may not exist yet.
        self.log_file = log_file
        self.eta = None
        self.train_metrics = None
        self.eval_metrics = None
        self.download_status = None
        self.eval_done = False
        self.train_error = None
        self.train_stage = None
        self.running_duration = None

    def update(self):
        """Re-parse the whole log file and refresh the derived attributes."""
        if not osp.exists(self.log_file):
            return
        # Terminal states: once errored, failed downloading, or complete,
        # there is nothing further to parse.
        if self.train_stage == "Train Error":
            return
        if self.download_status == "Failed":
            return
        if self.train_stage == "Train Complete":
            return
        logs = open(self.log_file, encoding='utf-8').read().strip().split('\n')
        # Reset per-refresh fields; a finished download stays 'Done'.
        self.eta = None
        self.train_metrics = None
        self.eval_metrics = None
        if self.download_status != "Done":
            self.download_status = None
        # Start time: first line with a parseable "%Y-%m-%d %H:%M:%S" prefix,
        # falling back to the file's creation time.
        start_time_timestamp = osp.getctime(self.log_file)
        for line in logs[::1]:
            try:
                start_time_str = " ".join(line.split()[0:2])
                start_time_array = time.strptime(start_time_str,
                                                 "%Y-%m-%d %H:%M:%S")
                start_time_timestamp = time.mktime(start_time_array)
                break
            except Exception as e:
                pass
        # Scan newest-to-oldest so the first match of each field wins.
        for line in logs[::-1]:
            if line.count('Train Complete!'):
                self.train_stage = "Train Complete"
            if line.count('Training stop with error!'):
                self.train_error = line
            # Early exit once every field of interest is populated.
            if self.train_metrics is not None \
                    and self.eval_metrics is not None \
                    and self.eval_done and self.eta is not None:
                break
            items = line.strip().split()
            if line.count('Model saved in'):
                self.eval_done = True
            if line.count('download completed'):
                self.download_status = 'Done'
                break
            if line.count('download failed'):
                self.download_status = 'Failed'
                break
            if self.download_status != 'Done':
                # Progress lines end with "KB/s"; the last three tokens carry
                # total=..M, downloaded=..M and speed=..KB/s values.
                if line.count('[DEBUG]\tDownloading'
                              ) and self.download_status is None:
                    self.download_status = dict()
                if not line.endswith('KB/s'):
                    continue
                speed = items[-1].strip('KB/s').split('=')[-1]
                download = items[-2].strip('M, ').split('=')[-1]
                total = items[-3].strip('M, ').split('=')[-1]
                self.download_status['speed'] = speed
                self.download_status['download'] = float(download)
                self.download_status['total'] = float(total)
            if self.eta is None:
                # ETA lines end with an "eta=H:MM:SS"-style token.
                if line.count('eta') > 0 and (line[-3] == ':' or
                                              line[-4] == ':'):
                    eta = items[-1].strip().split('=')[1]
                    h, m, s = [int(x) for x in eta.split(':')]
                    self.eta = h * 3600 + m * 60 + s
            if self.train_metrics is None:
                # "[INFO]\t[TRAIN] ... Step ... name=value ... eta=..." lines.
                if line.count('[INFO]\t[TRAIN]') > 0 and line.count(
                        'Step') > 0:
                    if not items[-1].startswith('eta'):
                        continue
                    self.train_metrics = dict()
                    metrics = items[4:]
                    for metric in metrics:
                        try:
                            name, value = metric.strip(', ').split('=')
                            value = value.split('/')[0]
                            if value.count('.') > 0:
                                value = float(value)
                            elif value == 'nan':
                                value = 'nan'
                            else:
                                value = int(value)
                            self.train_metrics[name] = value
                        except:
                            pass
            if self.eval_metrics is None:
                # "[INFO]\t[EVAL] Finished ... name=value ... ." lines.
                if line.count('[INFO]\t[EVAL]') > 0 and line.count(
                        'Finished') > 0:
                    if not line.strip().endswith(' .'):
                        continue
                    self.eval_metrics = dict()
                    metrics = items[5:]
                    for metric in metrics:
                        try:
                            name, value = metric.strip(', ').split('=')
                            value = value.split('/')[0]
                            if value.count('.') > 0:
                                value = float(value)
                            else:
                                value = int(value)
                            self.eval_metrics[name] = value
                        except:
                            pass
        # Wall-clock duration: first log timestamp to last file modification.
        end_time_timestamp = osp.getmtime(self.log_file)
        t_diff = time.gmtime(end_time_timestamp - start_time_timestamp)
        self.running_duration = "{}小时{}分{}秒".format(
            t_diff.tm_hour, t_diff.tm_min, t_diff.tm_sec)
  566. class PruneLogReader(object):
  567. def init_attr(self):
  568. self.eta = None
  569. self.iters = None
  570. self.current = None
  571. self.progress = None
  572. def __init__(self, log_file):
  573. self.log_file = log_file
  574. self.init_attr()
  575. def update(self):
  576. if not osp.exists(self.log_file):
  577. return
  578. logs = open(self.log_file, encoding='utf-8').read().strip().split('\n')
  579. self.init_attr()
  580. for line in logs[::-1]:
  581. metric_loaded = True
  582. for k, v in self.__dict__.items():
  583. if v is None:
  584. metric_loaded = False
  585. break
  586. if metric_loaded:
  587. break
  588. if line.count("Total evaluate iters") > 0:
  589. items = line.split(',')
  590. for item in items:
  591. kv_list = item.strip().split()[-1].split('=')
  592. kv_list = [v.strip() for v in kv_list]
  593. setattr(self, kv_list[0], kv_list[1])
  594. class QuantLogReader:
  595. def __init__(self, log_file):
  596. self.log_file = log_file
  597. self.stage = None
  598. self.running_duration = None
  599. def update(self):
  600. if not osp.exists(self.log_file):
  601. return
  602. logs = open(self.log_file, encoding='utf-8').read().strip().split('\n')
  603. for line in logs[::-1]:
  604. items = line.strip().split(' ')
  605. if line.count('[Run batch data]'):
  606. info = items[-3][:-1].split('=')[1]
  607. batch_id = float(info.split('/')[0])
  608. batch_all = float(info.split('/')[1])
  609. self.running_duration = \
  610. batch_id / batch_all * (10.0 / 30.0)
  611. self.stage = 'Batch'
  612. break
  613. elif line.count('[Calculate weight]'):
  614. info = items[-3][:-1].split('=')[1]
  615. weight_id = float(info.split('/')[0])
  616. weight_all = float(info.split('/')[1])
  617. self.running_duration = \
  618. weight_id / weight_all * (3.0 / 30.0) + (10.0 / 30.0)
  619. self.stage = 'Weight'
  620. break
  621. elif line.count('[Calculate activation]'):
  622. info = items[-3][:-1].split('=')[1]
  623. activation_id = float(info.split('/')[0])
  624. activation_all = float(info.split('/')[1])
  625. self.running_duration = \
  626. activation_id / activation_all * (16.0 / 30.0) + (13.0 / 30.0)
  627. self.stage = 'Activation'
  628. break
  629. elif line.count('Finish quant!'):
  630. self.stage = 'Finish'
  631. break
  632. class PyNvml(object):
  633. """ Nvidia GPU驱动检测类,可检测当前GPU驱动版本"""
  634. class PrintableStructure(Structure):
  635. _fmt_ = {}
  636. def __str__(self):
  637. result = []
  638. for x in self._fields_:
  639. key = x[0]
  640. value = getattr(self, key)
  641. fmt = "%s"
  642. if key in self._fmt_:
  643. fmt = self._fmt_[key]
  644. elif "<default>" in self._fmt_:
  645. fmt = self._fmt_["<default>"]
  646. result.append(("%s: " + fmt) % (key, value))
  647. return self.__class__.__name__ + "(" + string.join(result,
  648. ", ") + ")"
  649. class c_nvmlMemory_t(PrintableStructure):
  650. _fields_ = [
  651. ('total', c_ulonglong),
  652. ('free', c_ulonglong),
  653. ('used', c_ulonglong),
  654. ]
  655. _fmt_ = {'<default>': "%d B"}
  656. ## Device structures
  657. class struct_c_nvmlDevice_t(Structure):
  658. pass # opaque handle
  659. c_nvmlDevice_t = POINTER(struct_c_nvmlDevice_t)
  660. def __init__(self):
  661. self.nvml_lib = None
  662. self.nvml_lib_refcount = 0
  663. self.lib_load_lock = threading.Lock()
  664. self.nvml_lib_path = None
  665. def nvml_init(self, nvml_lib_path=None):
  666. self.lib_load_lock.acquire()
  667. sysstr = platform.system()
  668. if nvml_lib_path is None or nvml_lib_path.strip() == "":
  669. if sysstr == "Windows":
  670. nvml_lib_path = osp.join(
  671. os.getenv("ProgramFiles", "C:/Program Files"),
  672. "NVIDIA Corporation/NVSMI")
  673. if not osp.exists(osp.join(nvml_lib_path, "nvml.dll")):
  674. nvml_lib_path = "C:\\Windows\\System32"
  675. elif sysstr == "Linux":
  676. p1 = "/usr/lib/x86_64-linux-gnu"
  677. p2 = "/usr/lib/i386-linux-gnu"
  678. if osp.exists(osp.join(p1, "libnvidia-ml.so.1")):
  679. nvml_lib_path = p1
  680. elif osp.exists(osp.join(p2, "libnvidia-ml.so.1")):
  681. nvml_lib_path = p2
  682. else:
  683. nvml_lib_path = ""
  684. else:
  685. nvml_lib_path = "N/A"
  686. nvml_lib_dir = nvml_lib_path
  687. if sysstr == "Windows":
  688. nvml_lib_path = osp.join(nvml_lib_dir, "nvml.dll")
  689. else:
  690. nvml_lib_path = osp.join(nvml_lib_dir, "libnvidia-ml.so.1")
  691. self.nvml_lib_path = nvml_lib_path
  692. try:
  693. self.nvml_lib = CDLL(nvml_lib_path)
  694. fn = self._get_fn_ptr("nvmlInit_v2")
  695. fn()
  696. if sysstr == "Windows":
  697. driver_version = bytes.decode(
  698. self.nvml_system_get_driver_version())
  699. if driver_version.strip() == "":
  700. nvml_lib_path = osp.join(nvml_lib_dir, "nvml9.dll")
  701. self.nvml_lib = CDLL(nvml_lib_path)
  702. fn = self._get_fn_ptr("nvmlInit_v2")
  703. fn()
  704. except Exception as e:
  705. raise e
  706. finally:
  707. self.lib_load_lock.release()
  708. self.lib_load_lock.acquire()
  709. self.nvml_lib_refcount += 1
  710. self.lib_load_lock.release()
  711. def create_string_buffer(self, init, size=None):
  712. if isinstance(init, bytes):
  713. if size is None:
  714. size = len(init) + 1
  715. buftype = c_char * size
  716. buf = buftype()
  717. buf.value = init
  718. return buf
  719. elif isinstance(init, int):
  720. buftype = c_char * init
  721. buf = buftype()
  722. return buf
  723. raise TypeError(init)
  724. def _get_fn_ptr(self, name):
  725. return getattr(self.nvml_lib, name)
  726. def nvml_system_get_driver_version(self):
  727. c_version = self.create_string_buffer(81)
  728. fn = self._get_fn_ptr("nvmlSystemGetDriverVersion")
  729. ret = fn(c_version, c_uint(81))
  730. return c_version.value
  731. def nvml_device_get_count(self):
  732. c_count = c_uint()
  733. fn = self._get_fn_ptr("nvmlDeviceGetCount_v2")
  734. ret = fn(byref(c_count))
  735. return c_count.value
  736. def nvml_device_get_handle_by_index(self, index):
  737. c_index = c_uint(index)
  738. device = PyNvml.c_nvmlDevice_t()
  739. fn = self._get_fn_ptr("nvmlDeviceGetHandleByIndex_v2")
  740. ret = fn(c_index, byref(device))
  741. return device
  742. def nvml_device_get_memory_info(self, handle):
  743. c_memory = PyNvml.c_nvmlMemory_t()
  744. fn = self._get_fn_ptr("nvmlDeviceGetMemoryInfo")
  745. ret = fn(handle, byref(c_memory))
  746. return c_memory