utils.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import psutil
  15. import shutil
  16. import os
  17. import os.path as osp
  18. from enum import Enum
  19. import multiprocessing as mp
  20. from queue import Queue
  21. import time
  22. import threading
  23. from ctypes import CDLL, c_char, c_uint, c_ulonglong
  24. from _ctypes import byref, Structure, POINTER
  25. import platform
  26. import string
  27. import logging
  28. import socket
  29. import logging.handlers
  30. import requests
  31. import json
  32. from json import JSONEncoder
  33. class CustomEncoder(JSONEncoder):
  34. def default(self, o):
  35. return o.__dict__
class ShareData():
    """Process-wide shared state container (class attributes only; no
    instances are required — callers read/write the class attributes)."""

    # Current workspace object and its on-disk directory.
    workspace = None
    workspace_dir = ""
    # Optimistic default; presumably updated after GPU detection — TODO confirm.
    has_gpu = True
    # Queue of processes being monitored; capacity 4096.
    monitored_processes = mp.Queue(4096)
    # Server port and the next port to hand out for boards — schema of use
    # is set by callers elsewhere in the project.
    port = 5000
    current_port = 8000
    # Bookkeeping dicts populated by callers (running visual boards,
    # machine info cache, demo-loading processes/data).
    running_boards = {}
    machine_info = dict()
    load_demo_proc_dict = {}
    load_demo_proj_data_dict = {}
# ---------------------------------------------------------------------------
# Workflow status enums. All start at 0; member NAMES double as on-disk
# status-marker filenames (see set_folder_status / get_folder_status).
# ---------------------------------------------------------------------------

# Dataset lifecycle: empty -> checking (-> fail) -> copying -> done/fail -> split.
DatasetStatus = Enum(
    'DatasetStatus', ('XEMPTY', 'XCHECKING', 'XCHECKFAIL', 'XCOPYING',
                      'XCOPYDONE', 'XCOPYFAIL', 'XSPLITED'),
    start=0)
# Training-task lifecycle, including download, train, eval, export and prune.
TaskStatus = Enum(
    'TaskStatus', ('XUNINIT', 'XINIT', 'XDOWNLOADING', 'XTRAINING',
                   'XTRAINDONE', 'XEVALUATED', 'XEXPORTING', 'XEXPORTED',
                   'XTRAINEXIT', 'XDOWNLOADFAIL', 'XTRAINFAIL', 'XEVALUATING',
                   'XEVALUATEFAIL', 'XEXPORTFAIL', 'XPRUNEING', 'XPRUNETRAIN'),
    start=0)
# Supported project/task types.
ProjectType = Enum(
    'ProjectType', ('classification', 'detection', 'segmentation',
                    'instance_segmentation', 'remote_segmentation'),
    start=0)
# File-download lifecycle (see download()).
DownloadStatus = Enum(
    'DownloadStatus',
    ('XDDOWNLOADING', 'XDDOWNLOADFAIL', 'XDDOWNLOADDONE', 'XDDECOMPRESSED'),
    start=0)
# Prediction lifecycle.
PredictStatus = Enum(
    'PredictStatus', ('XPRESTART', 'XPREDONE', 'XPREFAIL'), start=0)
# Pruning lifecycle.
PruneStatus = Enum(
    'PruneStatus', ('XSPRUNESTART', 'XSPRUNEING', 'XSPRUNEDONE', 'XSPRUNEFAIL',
                    'XSPRUNEEXIT'),
    start=0)
# Pretrained-model save lifecycle (see copy_model_directory()).
PretrainedModelStatus = Enum(
    'PretrainedModelStatus',
    ('XPINIT', 'XPSAVING', 'XPSAVEFAIL', 'XPSAVEDONE'),
    start=0)
# Kinds of exported models (quantized/pruned/plain x mobile/server).
ExportedModelType = Enum(
    'ExportedModelType', ('XQUANTMOBILE', 'XPRUNEMOBILE', 'XTRAINMOBILE',
                          'XQUANTSERVER', 'XPRUNESERVER', 'XTRAINSERVER'),
    start=0)

# Metric-name -> Chinese display-name mapping used for table output
# (see trans_name(key, in_table=True)).
translate_chinese_table = {
    "Confusion_matrix": "各个类别之间的混淆矩阵",
    "Precision": "精准率",
    "Accuracy": "准确率",
    "Recall": "召回率",
    "Class": "类别",
    "Topk": "K取值",
    "Auc": "AUC",
    "Per_ap": "类别平均精准率",
    "Map": "类别平均精准率(AP)的均值(mAP)",
    "Mean_iou": "平均交并比",
    "Mean_acc": "平均准确率",
    "Category_iou": "各类别交并比",
    "Category_acc": "各类别准确率",
    "Ap": "平均精准率",
    "F1": "F1-score",
    "Iou": "交并比"
}
# Metric-name -> Chinese display-name mapping for non-table output
# (see trans_name(key, in_table=False)).
translate_chinese = {
    "Confusion_matrix": "混淆矩阵",
    "Mask_confusion_matrix": "Mask混淆矩阵",
    "Bbox_confusion_matrix": "Bbox混淆矩阵",
    "Precision": "精准率(Precision)",
    "Accuracy": "准确率(Accuracy)",
    "Recall": "召回率(Recall)",
    "Class": "类别(Class)",
    "PRF1": "整体分类评估结果",
    "PRF1_TOPk": "TopK评估结果",
    "Topk": "K取值",
    "AUC": "Area Under Curve",
    "Auc": "Area Under Curve",
    "F1": "F1-score",
    "Iou": "交并比(IoU)",
    "Per_ap": "各类别的平均精准率(AP)",
    "mAP": "平均精准率的均值(mAP)",
    "Mask_mAP": "Mask的平均精准率的均值(mAP)",
    "BBox_mAP": "Bbox的平均精准率的均值(mAP)",
    "Mean_iou": "平均交并比(mIoU)",
    "Mean_acc": "平均准确率(mAcc)",
    "Ap": "平均精准率(Average Precision)",
    "Category_iou": "各类别的交并比(IoU)",
    "Category_acc": "各类别的准确率(Accuracy)",
    "PRAP": "整体检测评估结果",
    "BBox_PRAP": "Bbox评估结果",
    "Mask_PRAP": "Mask评估结果",
    "Overall": "整体平均指标",
    "PRF1_average": "整体平均指标",
    "overall_det": "整体平均指标",
    "PRIoU": "整体平均指标",
    "Acc1": "预测Top1的准确率",
    "Acck": "预测Top{}的准确率"
}
# Queue of processes started via start_process(); capacity 1000.
process_pool = Queue(1000)
  132. def get_ip():
  133. try:
  134. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  135. s.connect(('8.8.8.8', 80))
  136. ip = s.getsockname()[0]
  137. finally:
  138. s.close()
  139. return ip
  140. def get_logger(filename):
  141. flask_logger = logging.getLogger()
  142. flask_logger.setLevel(level=logging.INFO)
  143. fmt = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s:%(message)s'
  144. format_str = logging.Formatter(fmt)
  145. ch = logging.StreamHandler()
  146. ch.setLevel(level=logging.INFO)
  147. ch.setFormatter(format_str)
  148. th = logging.handlers.TimedRotatingFileHandler(
  149. filename=filename, when='D', backupCount=5, encoding='utf-8')
  150. th.setFormatter(format_str)
  151. flask_logger.addHandler(th)
  152. flask_logger.addHandler(ch)
  153. return flask_logger
  154. def start_process(target, args):
  155. global process_pool
  156. p = mp.Process(target=target, args=args)
  157. p.start()
  158. process_pool.put(p)
  159. def pkill(pid):
  160. """结束进程pid,和与其相关的子进程
  161. Args:
  162. pid(int): 进程id
  163. """
  164. try:
  165. parent = psutil.Process(pid)
  166. for child in parent.children(recursive=True):
  167. child.kill()
  168. parent.kill()
  169. except:
  170. print("Try to kill process {} failed.".format(pid))
  171. def set_folder_status(dirname, status, message=""):
  172. """设置目录状态
  173. Args:
  174. dirname(str): 目录路径
  175. status(DatasetStatus): 状态
  176. message(str): 需要写到状态文件里的信息
  177. """
  178. if not osp.isdir(dirname):
  179. raise Exception("目录路径{}不存在".format(dirname))
  180. tmp_file = osp.join(dirname, status.name + '.tmp')
  181. with open(tmp_file, 'w', encoding='utf-8') as f:
  182. f.write("{}\n".format(message))
  183. shutil.move(tmp_file, osp.join(dirname, status.name))
  184. for status_type in [
  185. DatasetStatus, TaskStatus, PredictStatus, PruneStatus,
  186. DownloadStatus, PretrainedModelStatus
  187. ]:
  188. for s in status_type:
  189. if s == status:
  190. continue
  191. if osp.exists(osp.join(dirname, s.name)):
  192. os.remove(osp.join(dirname, s.name))
  193. def get_folder_status(dirname, with_message=False):
  194. """获取目录状态
  195. Args:
  196. dirname(str): 目录路径
  197. with_message(bool): 是否需要返回状态文件内的信息
  198. """
  199. status = None
  200. closest_time = 0
  201. message = ''
  202. for status_type in [
  203. DatasetStatus, TaskStatus, PredictStatus, PruneStatus,
  204. DownloadStatus, PretrainedModelStatus
  205. ]:
  206. for s in status_type:
  207. if osp.exists(osp.join(dirname, s.name)):
  208. modify_time = os.stat(osp.join(dirname, s.name)).st_mtime
  209. if modify_time > closest_time:
  210. closest_time = modify_time
  211. status = getattr(status_type, s.name)
  212. if with_message:
  213. encoding = 'utf-8'
  214. try:
  215. f = open(
  216. osp.join(dirname, s.name),
  217. 'r',
  218. encoding=encoding)
  219. message = f.read()
  220. f.close()
  221. except:
  222. try:
  223. import chardet
  224. f = open(filename, 'rb')
  225. data = f.read()
  226. f.close()
  227. encoding = chardet.detect(data).get('encoding')
  228. f = open(
  229. osp.join(dirname, s.name),
  230. 'r',
  231. encoding=encoding)
  232. message = f.read()
  233. f.close()
  234. except:
  235. pass
  236. if with_message:
  237. return status, message
  238. return status
def _machine_check_proc(queue, path):
    """Probe GPU availability via NVML and put a result dict on ``queue``.

    Intended to run in a child process (see get_machine_info) so a broken
    NVML library cannot crash the server.

    Result keys: gpu_num, gpu_free_mem (MB per GPU), cpu_num,
    driver_version, path (NVML library path that was used).
    """
    info = dict()
    p = PyNvml()
    gpu_num = 0
    try:
        # import paddle.fluid.core as core
        # gpu_num = core.get_cuda_device_count()
        p.nvml_init(path)
        gpu_num = p.nvml_device_get_count()
        driver_version = bytes.decode(p.nvml_system_get_driver_version())
    except:
        # NVML could not be loaded/initialized: report no driver.
        driver_version = "N/A"
    info['gpu_num'] = gpu_num
    info['gpu_free_mem'] = list()
    try:
        for i in range(gpu_num):
            handle = p.nvml_device_get_handle_by_index(i)
            meminfo = p.nvml_device_get_memory_info(handle)
            free_mem = meminfo.free / 1024 / 1024
            info['gpu_free_mem'].append(free_mem)
    except:
        # Best-effort: leave gpu_free_mem partially filled on failure.
        pass
    # NOTE(review): value is a string when CPU_NUM is set, int 1 otherwise.
    info['cpu_num'] = os.environ.get('CPU_NUM', 1)
    info['driver_version'] = driver_version
    info['path'] = p.nvml_lib_path
    queue.put(info, timeout=2)
  265. def get_machine_info(path=None):
  266. queue = mp.Queue(1)
  267. p = mp.Process(target=_machine_check_proc, args=(queue, path))
  268. p.start()
  269. p.join()
  270. return queue.get(timeout=2)
def download(url, target_path):
    """Download ``url`` into directory ``target_path`` (created if missing),
    writing DownloadStatus marker files into the directory as it goes.

    Returns:
        str: full path of the downloaded file.

    Raises:
        RuntimeError: on a non-200 response or when the retry limit is hit.
    """
    if not osp.exists(target_path):
        os.makedirs(target_path)
    fname = osp.split(url)[-1]
    fullname = osp.join(target_path, fname)
    retry_cnt = 0
    DOWNLOAD_RETRY_LIMIT = 3
    # Each loop iteration is one download attempt; the loop exits once the
    # final file exists (the move below makes it exist).
    while not (osp.exists(fullname)):
        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
            retry_cnt += 1
        else:
            # Mark the download as failed before raising.
            msg = "Download from {} failed. Retry limit reached".format(url)
            set_folder_status(target_path, DownloadStatus.XDDOWNLOADFAIL, msg)
            raise RuntimeError(msg)
        req = requests.get(url, stream=True)
        if req.status_code != 200:
            msg = "Downloading from {} failed with code {}!".format(
                url, req.status_code)
            set_folder_status(target_path, DownloadStatus.XDDOWNLOADFAIL, msg)
            raise RuntimeError(msg)
        # For protecting download interupted, download to
        # tmp_fullname firstly, move tmp_fullname to fullname
        # after download finished
        tmp_fullname = fullname + "_tmp"
        total_size = req.headers.get('content-length')
        set_folder_status(target_path, DownloadStatus.XDDOWNLOADING,
                          total_size)
        with open(tmp_fullname, 'wb') as f:
            if total_size:
                download_size = 0
                for chunk in req.iter_content(chunk_size=1024):
                    f.write(chunk)
                    # NOTE(review): download_size is accumulated but never
                    # read — progress is presumably derived elsewhere from
                    # the partial file's size; confirm before removing.
                    download_size += 1024
            else:
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
        shutil.move(tmp_fullname, fullname)
    set_folder_status(target_path, DownloadStatus.XDDOWNLOADDONE)
    return fullname
  312. def trans_name(key, in_table=False):
  313. if in_table:
  314. if key in translate_chinese_table:
  315. key = "{}".format(translate_chinese_table[key])
  316. if key.capitalize() in translate_chinese_table:
  317. key = "{}".format(translate_chinese_table[key.capitalize()])
  318. return key
  319. else:
  320. if key in translate_chinese:
  321. key = "{}".format(translate_chinese[key])
  322. if key.capitalize() in translate_chinese:
  323. key = "{}".format(translate_chinese[key.capitalize()])
  324. return key
  325. return key
  326. def is_pic(filename):
  327. suffixes = {'JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png'}
  328. suffix = filename.strip().split('.')[-1]
  329. if suffix not in suffixes:
  330. return False
  331. return True
  332. def is_available(ip, port):
  333. s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  334. try:
  335. s.connect((ip, int(port)))
  336. s.shutdown(2)
  337. return False
  338. except:
  339. return True
  340. def list_files(dirname):
  341. """ 列出目录下所有文件(包括所属的一级子目录下文件)
  342. Args:
  343. dirname: 目录路径
  344. """
  345. def filter_file(f):
  346. if f.startswith('.'):
  347. return True
  348. if hasattr(PretrainedModelStatus, f):
  349. return True
  350. return False
  351. all_files = list()
  352. dirs = list()
  353. for f in os.listdir(dirname):
  354. if filter_file(f):
  355. continue
  356. if osp.isdir(osp.join(dirname, f)):
  357. dirs.append(f)
  358. else:
  359. all_files.append(f)
  360. for d in dirs:
  361. for f in os.listdir(osp.join(dirname, d)):
  362. if filter_file(f):
  363. continue
  364. if osp.isdir(osp.join(dirname, d, f)):
  365. continue
  366. all_files.append(osp.join(d, f))
  367. return all_files
  368. def copy_model_directory(src, dst, files=None, filter_files=[]):
  369. """从src目录copy文件至dst目录,
  370. 注意:拷贝前会先清空dst中的所有文件
  371. Args:
  372. src: 源目录路径
  373. dst: 目标目录路径
  374. files: 需要拷贝的文件列表(src的相对路径)
  375. """
  376. set_folder_status(dst, PretrainedModelStatus.XPSAVING, os.getpid())
  377. if files is None:
  378. files = list_files(src)
  379. try:
  380. message = '{} {}'.format(os.getpid(), len(files))
  381. set_folder_status(dst, PretrainedModelStatus.XPSAVING, message)
  382. if not osp.samefile(src, dst):
  383. for i, f in enumerate(files):
  384. items = osp.split(f)
  385. if len(items) > 2:
  386. continue
  387. if len(items) == 2:
  388. if not osp.isdir(osp.join(dst, items[0])):
  389. if osp.exists(osp.join(dst, items[0])):
  390. os.remove(osp.join(dst, items[0]))
  391. os.makedirs(osp.join(dst, items[0]))
  392. if f not in filter_files:
  393. shutil.copy(osp.join(src, f), osp.join(dst, f))
  394. set_folder_status(dst, PretrainedModelStatus.XPSAVEDONE)
  395. except Exception as e:
  396. import traceback
  397. error_info = traceback.format_exc()
  398. set_folder_status(dst, PretrainedModelStatus.XPSAVEFAIL, error_info)
  399. def copy_pretrained_model(src, dst):
  400. p = mp.Process(
  401. target=copy_model_directory, args=(src, dst, None, ['model.pdopt']))
  402. p.start()
  403. return p
def _get_gpu_info(queue):
    """Collect per-GPU memory statistics via pycuda and put a dict on
    ``queue``. Intended to run in a subprocess (see get_gpu_info).

    Result keys: mem_free, mem_used, mem_total (lists of MB per GPU),
    driver_version, gpu_num.
    """
    gpu_info = dict()
    mem_free = list()
    mem_used = list()
    mem_total = list()
    import pycuda.driver as drv
    from pycuda.tools import clear_context_caches
    drv.init()
    driver_version = drv.get_driver_version()
    gpu_num = drv.Device.count()
    for gpu_id in range(gpu_num):
        dev = drv.Device(gpu_id)
        try:
            # A CUDA context is required to query memory; pop it right after.
            context = dev.make_context()
            free, total = drv.mem_get_info()
            context.pop()
            # Convert bytes -> MB.
            free = free // 1024 // 1024
            total = total // 1024 // 1024
            used = total - free
        except:
            # Device unusable: report zeros rather than failing the probe.
            free = 0
            total = 0
            used = 0
        mem_free.append(free)
        mem_used.append(used)
        mem_total.append(total)
    gpu_info['mem_free'] = mem_free
    gpu_info['mem_used'] = mem_used
    gpu_info['mem_total'] = mem_total
    gpu_info['driver_version'] = driver_version
    gpu_info['gpu_num'] = gpu_num
    queue.put(gpu_info)
  436. def get_gpu_info():
  437. try:
  438. import pycuda
  439. except:
  440. gpu_info = dict()
  441. message = "未检测到GPU \n 若存在GPU请确保安装pycuda \n 若未安装pycuda请使用'pip install pycuda'来安装"
  442. gpu_info['gpu_num'] = 0
  443. return gpu_info, message
  444. queue = mp.Queue(1)
  445. p = mp.Process(target=_get_gpu_info, args=(queue, ))
  446. p.start()
  447. p.join()
  448. gpu_info = queue.get(timeout=2)
  449. if gpu_info['gpu_num'] == 0:
  450. message = "未检测到GPU"
  451. else:
  452. message = "检测到GPU"
  453. return gpu_info, message
class TrainLogReader(object):
    """Incremental parser for a PaddleX training log file.

    Each call to update() re-reads the whole log and refreshes eta,
    train/eval metrics, download status, train stage and running duration.
    Parsing depends on the exact PaddleX log line layout.
    """

    def __init__(self, log_file):
        # Path to the log file; it may not exist yet at construction time.
        self.log_file = log_file
        self.eta = None                 # remaining seconds parsed from "eta=H:M:S"
        self.train_metrics = None       # dict of latest [TRAIN] step metrics
        self.eval_metrics = None        # dict of latest [EVAL] metrics
        self.download_status = None     # None / progress dict / 'Done' / 'Failed'
        self.eval_done = False          # True once a "Model saved in" line is seen
        self.train_error = None         # the error line, if training stopped
        self.train_stage = None         # None / "Train Complete" / "Train Error"
        self.running_duration = None    # human-readable elapsed time (Chinese)

    def update(self):
        """Re-parse the log file and refresh all derived fields."""
        if not osp.exists(self.log_file):
            return
        # Terminal states: once reached, no further parsing happens.
        if self.train_stage == "Train Error":
            return
        if self.download_status == "Failed":
            return
        if self.train_stage == "Train Complete":
            return
        logs = open(self.log_file, encoding='utf-8').read().strip().split('\n')
        self.eta = None
        self.train_metrics = None
        self.eval_metrics = None
        if self.download_status != "Done":
            self.download_status = None
        # Training start time: first line with a parseable
        # "YYYY-mm-dd HH:MM:SS" prefix; fall back to the file's ctime.
        start_time_timestamp = osp.getctime(self.log_file)
        for line in logs[::1]:
            try:
                start_time_str = " ".join(line.split()[0:2])
                start_time_array = time.strptime(start_time_str,
                                                 "%Y-%m-%d %H:%M:%S")
                start_time_timestamp = time.mktime(start_time_array)
                break
            except Exception as e:
                pass
        # Scan newest-to-oldest so the most recent values win.
        for line in logs[::-1]:
            if line.count('Train Complete!'):
                self.train_stage = "Train Complete"
            if line.count('Training stop with error!'):
                self.train_error = line
            # Stop early once every field has been filled in.
            if self.train_metrics is not None \
                    and self.eval_metrics is not None and self.eval_done and self.eta is not None:
                break
            items = line.strip().split()
            if line.count('Model saved in'):
                self.eval_done = True
            if line.count('download completed'):
                self.download_status = 'Done'
                break
            if line.count('download failed'):
                self.download_status = 'Failed'
                break
            if self.download_status != 'Done':
                # Progress lines look like "...[DEBUG]\tDownloading ...
                # total=XM, download=YM, speed=Z KB/s" — TODO confirm layout.
                if line.count('[DEBUG]\tDownloading'
                              ) and self.download_status is None:
                    self.download_status = dict()
                    if not line.endswith('KB/s'):
                        continue
                    speed = items[-1].strip('KB/s').split('=')[-1]
                    download = items[-2].strip('M, ').split('=')[-1]
                    total = items[-3].strip('M, ').split('=')[-1]
                    self.download_status['speed'] = speed
                    self.download_status['download'] = float(download)
                    self.download_status['total'] = float(total)
            if self.eta is None:
                # eta appears as the last token "eta=H:M:S".
                if line.count('eta') > 0 and (line[-3] == ':' or
                                              line[-4] == ':'):
                    eta = items[-1].strip().split('=')[1]
                    h, m, s = [int(x) for x in eta.split(':')]
                    self.eta = h * 3600 + m * 60 + s
            if self.train_metrics is None:
                if line.count('[INFO]\t[TRAIN]') > 0 and line.count(
                        'Step') > 0:
                    # Only complete lines (ending in the eta token) count.
                    if not items[-1].startswith('eta'):
                        continue
                    self.train_metrics = dict()
                    metrics = items[4:]
                    for metric in metrics:
                        try:
                            name, value = metric.strip(', ').split('=')
                            value = value.split('/')[0]
                            if value.count('.') > 0:
                                value = float(value)
                            elif value == 'nan':
                                value = 'nan'
                            else:
                                value = int(value)
                            self.train_metrics[name] = value
                        except:
                            pass
            if self.eval_metrics is None:
                if line.count('[INFO]\t[EVAL]') > 0 and line.count(
                        'Finished') > 0:
                    # Only complete lines (terminated by " .") count.
                    if not line.strip().endswith(' .'):
                        continue
                    self.eval_metrics = dict()
                    metrics = items[5:]
                    for metric in metrics:
                        try:
                            name, value = metric.strip(', ').split('=')
                            value = value.split('/')[0]
                            if value.count('.') > 0:
                                value = float(value)
                            else:
                                value = int(value)
                            self.eval_metrics[name] = value
                        except:
                            pass
        # Elapsed time = file mtime - parsed start time, formatted h/m/s.
        end_time_timestamp = osp.getmtime(self.log_file)
        t_diff = time.gmtime(end_time_timestamp - start_time_timestamp)
        self.running_duration = "{}小时{}分{}秒".format(
            t_diff.tm_hour, t_diff.tm_min, t_diff.tm_sec)
class PruneLogReader(object):
    """Incremental parser for a model-pruning log file."""

    def init_attr(self):
        # Reset all parsed fields; update() calls this before each pass.
        self.eta = None
        self.iters = None
        self.current = None
        self.progress = None

    def __init__(self, log_file):
        self.log_file = log_file
        self.init_attr()

    def update(self):
        """Re-parse the log; stops as soon as every attribute has a value."""
        if not osp.exists(self.log_file):
            return
        logs = open(self.log_file, encoding='utf-8').read().strip().split('\n')
        self.init_attr()
        for line in logs[::-1]:
            # Stop when no attribute (incl. log_file) is still None.
            metric_loaded = True
            for k, v in self.__dict__.items():
                if v is None:
                    metric_loaded = False
                    break
            if metric_loaded:
                break
            if line.count("Total evaluate iters") > 0:
                # NOTE(review): assumes comma-separated "key=value" pairs
                # whose last whitespace token is the pair — confirm against
                # the actual pruning log format. Values stay strings.
                items = line.split(',')
                for item in items:
                    kv_list = item.strip().split()[-1].split('=')
                    kv_list = [v.strip() for v in kv_list]
                    setattr(self, kv_list[0], kv_list[1])
class QuantLogReader:
    """Incremental parser for a quantization log file.

    Tracks the current quantization stage and an overall progress fraction
    in [0, 1] weighted per stage (10/30 batch, 3/30 weight, 16/30 activation).
    """

    def __init__(self, log_file):
        self.log_file = log_file
        self.stage = None               # 'Batch' / 'Weight' / 'Activation' / 'Finish'
        self.running_duration = None    # overall progress fraction

    def update(self):
        """Scan the log from the end and set stage/progress from the most
        recent recognizable progress line."""
        if not osp.exists(self.log_file):
            return
        logs = open(self.log_file, encoding='utf-8').read().strip().split('\n')
        for line in logs[::-1]:
            items = line.strip().split(' ')
            if line.count('[Run batch data]'):
                # items[-3] holds "key=cur/total," — take the fraction.
                info = items[-3][:-1].split('=')[1]
                batch_id = float(info.split('/')[0])
                batch_all = float(info.split('/')[1])
                # Batch stage covers the first 10/30 of overall progress.
                self.running_duration = \
                    batch_id / batch_all * (10.0 / 30.0)
                self.stage = 'Batch'
                break
            elif line.count('[Calculate weight]'):
                info = items[-3][:-1].split('=')[1]
                weight_id = float(info.split('/')[0])
                weight_all = float(info.split('/')[1])
                # Weight stage covers 3/30, offset by the batch stage's 10/30.
                self.running_duration = \
                    weight_id / weight_all * (3.0 / 30.0) + (10.0 / 30.0)
                self.stage = 'Weight'
                break
            elif line.count('[Calculate activation]'):
                info = items[-3][:-1].split('=')[1]
                activation_id = float(info.split('/')[0])
                activation_all = float(info.split('/')[1])
                # Activation stage covers the final 16/30, offset by 13/30.
                self.running_duration = \
                    activation_id / activation_all * (16.0 / 30.0) + (13.0 / 30.0)
                self.stage = 'Activation'
                break
            elif line.count('Finish quant!'):
                self.stage = 'Finish'
                break
  633. class PyNvml(object):
  634. """ Nvidia GPU驱动检测类,可检测当前GPU驱动版本"""
  635. class PrintableStructure(Structure):
  636. _fmt_ = {}
  637. def __str__(self):
  638. result = []
  639. for x in self._fields_:
  640. key = x[0]
  641. value = getattr(self, key)
  642. fmt = "%s"
  643. if key in self._fmt_:
  644. fmt = self._fmt_[key]
  645. elif "<default>" in self._fmt_:
  646. fmt = self._fmt_["<default>"]
  647. result.append(("%s: " + fmt) % (key, value))
  648. return self.__class__.__name__ + "(" + string.join(result,
  649. ", ") + ")"
  650. class c_nvmlMemory_t(PrintableStructure):
  651. _fields_ = [
  652. ('total', c_ulonglong),
  653. ('free', c_ulonglong),
  654. ('used', c_ulonglong),
  655. ]
  656. _fmt_ = {'<default>': "%d B"}
  657. ## Device structures
  658. class struct_c_nvmlDevice_t(Structure):
  659. pass # opaque handle
  660. c_nvmlDevice_t = POINTER(struct_c_nvmlDevice_t)
  661. def __init__(self):
  662. self.nvml_lib = None
  663. self.nvml_lib_refcount = 0
  664. self.lib_load_lock = threading.Lock()
  665. self.nvml_lib_path = None
  666. def nvml_init(self, nvml_lib_path=None):
  667. self.lib_load_lock.acquire()
  668. sysstr = platform.system()
  669. if nvml_lib_path is None or nvml_lib_path.strip() == "":
  670. if sysstr == "Windows":
  671. nvml_lib_path = osp.join(
  672. os.getenv("ProgramFiles", "C:/Program Files"),
  673. "NVIDIA Corporation/NVSMI")
  674. if not osp.exists(osp.join(nvml_lib_path, "nvml.dll")):
  675. nvml_lib_path = "C:\\Windows\\System32"
  676. elif sysstr == "Linux":
  677. p1 = "/usr/lib/x86_64-linux-gnu"
  678. p2 = "/usr/lib/i386-linux-gnu"
  679. if osp.exists(osp.join(p1, "libnvidia-ml.so.1")):
  680. nvml_lib_path = p1
  681. elif osp.exists(osp.join(p2, "libnvidia-ml.so.1")):
  682. nvml_lib_path = p2
  683. else:
  684. nvml_lib_path = ""
  685. else:
  686. nvml_lib_path = "N/A"
  687. nvml_lib_dir = nvml_lib_path
  688. if sysstr == "Windows":
  689. nvml_lib_path = osp.join(nvml_lib_dir, "nvml.dll")
  690. else:
  691. nvml_lib_path = osp.join(nvml_lib_dir, "libnvidia-ml.so.1")
  692. self.nvml_lib_path = nvml_lib_path
  693. try:
  694. self.nvml_lib = CDLL(nvml_lib_path)
  695. fn = self._get_fn_ptr("nvmlInit_v2")
  696. fn()
  697. if sysstr == "Windows":
  698. driver_version = bytes.decode(
  699. self.nvml_system_get_driver_version())
  700. if driver_version.strip() == "":
  701. nvml_lib_path = osp.join(nvml_lib_dir, "nvml9.dll")
  702. self.nvml_lib = CDLL(nvml_lib_path)
  703. fn = self._get_fn_ptr("nvmlInit_v2")
  704. fn()
  705. except Exception as e:
  706. raise e
  707. finally:
  708. self.lib_load_lock.release()
  709. self.lib_load_lock.acquire()
  710. self.nvml_lib_refcount += 1
  711. self.lib_load_lock.release()
  712. def create_string_buffer(self, init, size=None):
  713. if isinstance(init, bytes):
  714. if size is None:
  715. size = len(init) + 1
  716. buftype = c_char * size
  717. buf = buftype()
  718. buf.value = init
  719. return buf
  720. elif isinstance(init, int):
  721. buftype = c_char * init
  722. buf = buftype()
  723. return buf
  724. raise TypeError(init)
  725. def _get_fn_ptr(self, name):
  726. return getattr(self.nvml_lib, name)
  727. def nvml_system_get_driver_version(self):
  728. c_version = self.create_string_buffer(81)
  729. fn = self._get_fn_ptr("nvmlSystemGetDriverVersion")
  730. ret = fn(c_version, c_uint(81))
  731. return c_version.value
  732. def nvml_device_get_count(self):
  733. c_count = c_uint()
  734. fn = self._get_fn_ptr("nvmlDeviceGetCount_v2")
  735. ret = fn(byref(c_count))
  736. return c_count.value
  737. def nvml_device_get_handle_by_index(self, index):
  738. c_index = c_uint(index)
  739. device = PyNvml.c_nvmlDevice_t()
  740. fn = self._get_fn_ptr("nvmlDeviceGetHandleByIndex_v2")
  741. ret = fn(c_index, byref(device))
  742. return device
  743. def nvml_device_get_memory_info(self, handle):
  744. c_memory = PyNvml.c_nvmlMemory_t()
  745. fn = self._get_fn_ptr("nvmlDeviceGetMemoryInfo")
  746. ret = fn(handle, byref(c_memory))
  747. return c_memory