utils.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import psutil
  15. import shutil
  16. import os
  17. import os.path as osp
  18. from enum import Enum
  19. import multiprocessing as mp
  20. from queue import Queue
  21. import time
  22. import threading
  23. from ctypes import CDLL, c_char, c_uint, c_ulonglong
  24. from _ctypes import byref, Structure, POINTER
  25. import platform
  26. import string
  27. import logging
  28. import socket
  29. import logging.handlers
  30. import requests
  31. import json
  32. from json import JSONEncoder
  33. class CustomEncoder(JSONEncoder):
  34. def default(self, o):
  35. return o.__dict__
class ShareData():
    """Process-wide mutable state shared across the restful server."""
    # Currently opened workspace object (assigned at server startup).
    workspace = None
    # Filesystem path of the workspace directory.
    workspace_dir = ""
    # Whether a GPU is assumed present; flipped off when detection fails.
    has_gpu = True
    # Queue of PIDs/processes that the monitor thread watches (bounded at 4096).
    monitored_processes = mp.Queue(4096)
    # Port the restful server listens on.
    port = 5000
    # Next port handed out for VisualDL boards.
    current_port = 8000
    # Maps project/task identifiers to their running VisualDL board info.
    running_boards = {}
    # Cached hardware info (CPU/GPU) for this machine.
    machine_info = dict()
    # Bookkeeping for demo-project loading processes and their data.
    load_demo_proc_dict = {}
    load_demo_proj_data_dict = {}
# Lifecycle status enums shared across the server. `start=0` makes the first
# member's value 0 so statuses can be stored/compared as plain integers.

# Dataset pipeline: empty -> checking -> check failed / copying -> done/fail -> split.
DatasetStatus = Enum(
    'DatasetStatus', ('XEMPTY', 'XCHECKING', 'XCHECKFAIL', 'XCOPYING',
                      'XCOPYDONE', 'XCOPYFAIL', 'XSPLITED'),
    start=0)
# Training-task lifecycle, including evaluation, export and pruning phases.
TaskStatus = Enum(
    'TaskStatus', ('XUNINIT', 'XINIT', 'XDOWNLOADING', 'XTRAINING',
                   'XTRAINDONE', 'XEVALUATED', 'XEXPORTING', 'XEXPORTED',
                   'XTRAINEXIT', 'XDOWNLOADFAIL', 'XTRAINFAIL', 'XEVALUATING',
                   'XEVALUATEFAIL', 'XEXPORTFAIL', 'XPRUNEING', 'XPRUNETRAIN'),
    start=0)
# Supported computer-vision project categories.
ProjectType = Enum(
    'ProjectType', ('classification', 'detection', 'segmentation',
                    'instance_segmentation', 'remote_segmentation'),
    start=0)
# States of a background download (see `download()` below).
DownloadStatus = Enum(
    'DownloadStatus',
    ('XDDOWNLOADING', 'XDDOWNLOADFAIL', 'XDDOWNLOADDONE', 'XDDECOMPRESSED'),
    start=0)
# Prediction-job states.
PredictStatus = Enum(
    'PredictStatus', ('XPRESTART', 'XPREDONE', 'XPREFAIL'), start=0)
# Pruning-job states.
PruneStatus = Enum(
    'PruneStatus', ('XSPRUNESTART', 'XSPRUNEING', 'XSPRUNEDONE', 'XSPRUNEFAIL',
                    'XSPRUNEEXIT'),
    start=0)
# States of saving/copying a pretrained model directory.
PretrainedModelStatus = Enum(
    'PretrainedModelStatus',
    ('XPINIT', 'XPSAVING', 'XPSAVEFAIL', 'XPSAVEDONE'),
    start=0)
# Flavours of an exported model (quantized/pruned/plain x mobile/server).
ExportedModelType = Enum(
    'ExportedModelType', ('XQUANTMOBILE', 'XPRUNEMOBILE', 'XTRAINMOBILE',
                          'XQUANTSERVER', 'XPRUNESERVER', 'XTRAINSERVER'),
    start=0)
# English metric key -> Chinese label, used when rendering metric TABLES.
# Keys/values are consumed at runtime by `trans_name()`; do not rename them.
translate_chinese_table = {
    "Confusion_matrix": "各个类别之间的混淆矩阵",
    "Precision": "精准率",
    "Accuracy": "准确率",
    "Recall": "召回率",
    "Class": "类别",
    "Topk": "K取值",
    "Auc": "AUC",
    "Per_ap": "类别平均精准率",
    "Map": "类别平均精准率(AP)的均值(mAP)",
    "Mean_iou": "平均交并比",
    "Mean_acc": "平均准确率",
    "Category_iou": "各类别交并比",
    "Category_acc": "各类别准确率",
    "Ap": "平均精准率",
    "F1": "F1-score",
    "Iou": "交并比"
}
# English metric key -> Chinese label for non-table (report/chart) display.
translate_chinese = {
    "Confusion_matrix": "混淆矩阵",
    "Mask_confusion_matrix": "Mask混淆矩阵",
    "Bbox_confusion_matrix": "Bbox混淆矩阵",
    "Precision": "精准率(Precision)",
    "Accuracy": "准确率(Accuracy)",
    "Recall": "召回率(Recall)",
    "Class": "类别(Class)",
    "PRF1": "整体分类评估结果",
    "PRF1_TOPk": "TopK评估结果",
    "Topk": "K取值",
    "AUC": "Area Under Curve",
    "Auc": "Area Under Curve",
    "F1": "F1-score",
    "Iou": "交并比(IoU)",
    "Per_ap": "各类别的平均精准率(AP)",
    "mAP": "平均精准率的均值(mAP)",
    "Mask_mAP": "Mask的平均精准率的均值(mAP)",
    "BBox_mAP": "Bbox的平均精准率的均值(mAP)",
    "Mean_iou": "平均交并比(mIoU)",
    "Mean_acc": "平均准确率(mAcc)",
    "Ap": "平均精准率(Average Precision)",
    "Category_iou": "各类别的交并比(IoU)",
    "Category_acc": "各类别的准确率(Accuracy)",
    "PRAP": "整体检测评估结果",
    "BBox_PRAP": "Bbox评估结果",
    "Mask_PRAP": "Mask评估结果",
    "Overall": "整体平均指标",
    "PRF1_average": "整体平均指标",
    "overall_det": "整体平均指标",
    "PRIoU": "整体平均指标",
    "Acc1": "预测Top1的准确率",
    "Acck": "预测Top{}的准确率"
}
# Bounded pool of child processes started via `start_process()`.
process_pool = Queue(1000)
  132. def get_ip():
  133. try:
  134. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  135. s.connect(('8.8.8.8', 80))
  136. ip = s.getsockname()[0]
  137. finally:
  138. s.close()
  139. return ip
  140. def get_logger(filename):
  141. flask_logger = logging.getLogger()
  142. flask_logger.setLevel(level=logging.INFO)
  143. fmt = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s:%(message)s'
  144. format_str = logging.Formatter(fmt)
  145. ch = logging.StreamHandler()
  146. ch.setLevel(level=logging.INFO)
  147. ch.setFormatter(format_str)
  148. th = logging.handlers.TimedRotatingFileHandler(
  149. filename=filename, when='D', backupCount=5, encoding='utf-8')
  150. th.setFormatter(format_str)
  151. flask_logger.addHandler(th)
  152. flask_logger.addHandler(ch)
  153. return flask_logger
  154. def start_process(target, args):
  155. global process_pool
  156. p = mp.Process(target=target, args=args)
  157. p.start()
  158. process_pool.put(p)
  159. def pkill(pid):
  160. """结束进程pid,和与其相关的子进程
  161. Args:
  162. pid(int): 进程id
  163. """
  164. try:
  165. parent = psutil.Process(pid)
  166. for child in parent.children(recursive=True):
  167. child.kill()
  168. parent.kill()
  169. except:
  170. print("Try to kill process {} failed.".format(pid))
  171. def set_folder_status(dirname, status, message=""):
  172. """设置目录状态
  173. Args:
  174. dirname(str): 目录路径
  175. status(DatasetStatus): 状态
  176. message(str): 需要写到状态文件里的信息
  177. """
  178. if not osp.isdir(dirname):
  179. raise Exception("目录路径{}不存在".format(dirname))
  180. tmp_file = osp.join(dirname, status.name + '.tmp')
  181. with open(tmp_file, 'w', encoding='utf-8') as f:
  182. f.write("{}\n".format(message))
  183. shutil.move(tmp_file, osp.join(dirname, status.name))
  184. for status_type in [
  185. DatasetStatus, TaskStatus, PredictStatus, PruneStatus,
  186. DownloadStatus, PretrainedModelStatus
  187. ]:
  188. for s in status_type:
  189. if s == status:
  190. continue
  191. if osp.exists(osp.join(dirname, s.name)):
  192. os.remove(osp.join(dirname, s.name))
  193. def get_folder_status(dirname, with_message=False):
  194. """获取目录状态
  195. Args:
  196. dirname(str): 目录路径
  197. with_message(bool): 是否需要返回状态文件内的信息
  198. """
  199. status = None
  200. closest_time = 0
  201. message = ''
  202. for status_type in [
  203. DatasetStatus, TaskStatus, PredictStatus, PruneStatus,
  204. DownloadStatus, PretrainedModelStatus
  205. ]:
  206. for s in status_type:
  207. if osp.exists(osp.join(dirname, s.name)):
  208. modify_time = os.stat(osp.join(dirname, s.name)).st_mtime
  209. if modify_time > closest_time:
  210. closest_time = modify_time
  211. status = getattr(status_type, s.name)
  212. if with_message:
  213. encoding = 'utf-8'
  214. try:
  215. f = open(
  216. osp.join(dirname, s.name),
  217. 'r',
  218. encoding=encoding)
  219. message = f.read()
  220. f.close()
  221. except:
  222. try:
  223. import chardet
  224. f = open(filename, 'rb')
  225. data = f.read()
  226. f.close()
  227. encoding = chardet.detect(data).get('encoding')
  228. f = open(
  229. osp.join(dirname, s.name),
  230. 'r',
  231. encoding=encoding)
  232. message = f.read()
  233. f.close()
  234. except:
  235. pass
  236. if with_message:
  237. return status, message
  238. return status
  239. def _machine_check_proc(queue, path):
  240. info = dict()
  241. p = PyNvml()
  242. gpu_num = 0
  243. try:
  244. # import paddle.fluid.core as core
  245. # gpu_num = core.get_cuda_device_count()
  246. p.nvml_init(path)
  247. gpu_num = p.nvml_device_get_count()
  248. driver_version = bytes.decode(p.nvml_system_get_driver_version())
  249. except Exception as e:
  250. gpu_num = 0
  251. driver_version = "N/A"
  252. info['gpu_num'] = gpu_num
  253. info['gpu_free_mem'] = list()
  254. try:
  255. for i in range(gpu_num):
  256. handle = p.nvml_device_get_handle_by_index(i)
  257. meminfo = p.nvml_device_get_memory_info(handle)
  258. free_mem = meminfo.free / 1024 / 1024
  259. info['gpu_free_mem'].append(free_mem)
  260. except:
  261. pass
  262. info['cpu_num'] = int(os.environ.get('CPU_NUM', 1))
  263. info['driver_version'] = driver_version
  264. info['path'] = p.nvml_lib_path
  265. queue.put(info, timeout=3)
  266. def get_machine_info(path=None):
  267. queue = mp.Queue(1)
  268. p = mp.Process(target=_machine_check_proc, args=(queue, path))
  269. p.start()
  270. p.join()
  271. return queue.get(timeout=2)
def download(url, target_path):
    """Download `url` into directory `target_path`, tracking progress through
    DownloadStatus marker files written into `target_path`.

    Returns the full path of the downloaded file. Raises RuntimeError on a
    non-200 response or when the retry limit is exhausted.
    """
    if not osp.exists(target_path):
        os.makedirs(target_path)
    fname = osp.split(url)[-1]
    fullname = osp.join(target_path, fname)
    retry_cnt = 0
    DOWNLOAD_RETRY_LIMIT = 3
    # NOTE(review): a requests exception propagates immediately, so the retry
    # counter only takes effect if the finished file disappears — confirm
    # whether per-attempt retrying was intended.
    while not (osp.exists(fullname)):
        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
            retry_cnt += 1
        else:
            # Mark the download as failed before bailing out.
            msg = "Download from {} failed. Retry limit reached".format(url)
            set_folder_status(target_path, DownloadStatus.XDDOWNLOADFAIL, msg)
            raise RuntimeError(msg)
        req = requests.get(url, stream=True)
        if req.status_code != 200:
            msg = "Downloading from {} failed with code {}!".format(
                url, req.status_code)
            set_folder_status(target_path, DownloadStatus.XDDOWNLOADFAIL, msg)
            raise RuntimeError(msg)
        # To guard against interrupted downloads, write to tmp_fullname first
        # and move it to fullname only after the download finishes.
        tmp_fullname = fullname + "_tmp"
        total_size = req.headers.get('content-length')
        set_folder_status(target_path, DownloadStatus.XDDOWNLOADING,
                          total_size)
        with open(tmp_fullname, 'wb') as f:
            if total_size:
                download_size = 0
                for chunk in req.iter_content(chunk_size=1024):
                    f.write(chunk)
                    # NOTE(review): counts the nominal chunk size, so the
                    # tally can overshoot the real byte count on the last chunk.
                    download_size += 1024
            else:
                # No content-length header: write non-empty chunks until EOF.
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
        shutil.move(tmp_fullname, fullname)
    set_folder_status(target_path, DownloadStatus.XDDOWNLOADDONE)
    return fullname
  313. def trans_name(key, in_table=False):
  314. if in_table:
  315. if key in translate_chinese_table:
  316. key = "{}".format(translate_chinese_table[key])
  317. if key.capitalize() in translate_chinese_table:
  318. key = "{}".format(translate_chinese_table[key.capitalize()])
  319. return key
  320. else:
  321. if key in translate_chinese:
  322. key = "{}".format(translate_chinese[key])
  323. if key.capitalize() in translate_chinese:
  324. key = "{}".format(translate_chinese[key.capitalize()])
  325. return key
  326. return key
  327. def is_pic(filename):
  328. suffixes = {'JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png'}
  329. suffix = filename.strip().split('.')[-1]
  330. if suffix not in suffixes:
  331. return False
  332. return True
  333. def is_available(ip, port):
  334. s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  335. try:
  336. s.connect((ip, int(port)))
  337. s.shutdown(2)
  338. return False
  339. except:
  340. return True
  341. def list_files(dirname):
  342. """ 列出目录下所有文件(包括所属的一级子目录下文件)
  343. Args:
  344. dirname: 目录路径
  345. """
  346. def filter_file(f):
  347. if f.startswith('.'):
  348. return True
  349. if hasattr(PretrainedModelStatus, f):
  350. return True
  351. return False
  352. all_files = list()
  353. dirs = list()
  354. for f in os.listdir(dirname):
  355. if filter_file(f):
  356. continue
  357. if osp.isdir(osp.join(dirname, f)):
  358. dirs.append(f)
  359. else:
  360. all_files.append(f)
  361. for d in dirs:
  362. for f in os.listdir(osp.join(dirname, d)):
  363. if filter_file(f):
  364. continue
  365. if osp.isdir(osp.join(dirname, d, f)):
  366. continue
  367. all_files.append(osp.join(d, f))
  368. return all_files
  369. def copy_model_directory(src, dst, files=None, filter_files=[]):
  370. """从src目录copy文件至dst目录,
  371. 注意:拷贝前会先清空dst中的所有文件
  372. Args:
  373. src: 源目录路径
  374. dst: 目标目录路径
  375. files: 需要拷贝的文件列表(src的相对路径)
  376. """
  377. set_folder_status(dst, PretrainedModelStatus.XPSAVING, os.getpid())
  378. if files is None:
  379. files = list_files(src)
  380. try:
  381. message = '{} {}'.format(os.getpid(), len(files))
  382. set_folder_status(dst, PretrainedModelStatus.XPSAVING, message)
  383. if not osp.samefile(src, dst):
  384. for i, f in enumerate(files):
  385. items = osp.split(f)
  386. if len(items) > 2:
  387. continue
  388. if len(items) == 2:
  389. if not osp.isdir(osp.join(dst, items[0])):
  390. if osp.exists(osp.join(dst, items[0])):
  391. os.remove(osp.join(dst, items[0]))
  392. os.makedirs(osp.join(dst, items[0]))
  393. if f not in filter_files:
  394. shutil.copy(osp.join(src, f), osp.join(dst, f))
  395. set_folder_status(dst, PretrainedModelStatus.XPSAVEDONE)
  396. except Exception as e:
  397. import traceback
  398. error_info = traceback.format_exc()
  399. set_folder_status(dst, PretrainedModelStatus.XPSAVEFAIL, error_info)
  400. def copy_pretrained_model(src, dst):
  401. p = mp.Process(
  402. target=copy_model_directory, args=(src, dst, None, ['model.pdopt']))
  403. p.start()
  404. return p
  405. def _get_gpu_info(queue):
  406. gpu_info = dict()
  407. mem_free = list()
  408. mem_used = list()
  409. mem_total = list()
  410. import pycuda.driver as drv
  411. from pycuda.tools import clear_context_caches
  412. drv.init()
  413. driver_version = drv.get_driver_version()
  414. gpu_num = drv.Device.count()
  415. for gpu_id in range(gpu_num):
  416. dev = drv.Device(gpu_id)
  417. try:
  418. context = dev.make_context()
  419. free, total = drv.mem_get_info()
  420. context.pop()
  421. free = free // 1024 // 1024
  422. total = total // 1024 // 1024
  423. used = total - free
  424. except:
  425. free = 0
  426. total = 0
  427. used = 0
  428. mem_free.append(free)
  429. mem_used.append(used)
  430. mem_total.append(total)
  431. gpu_info['mem_free'] = mem_free
  432. gpu_info['mem_used'] = mem_used
  433. gpu_info['mem_total'] = mem_total
  434. gpu_info['driver_version'] = driver_version
  435. gpu_info['gpu_num'] = gpu_num
  436. queue.put(gpu_info)
  437. def get_gpu_info():
  438. try:
  439. import pycuda
  440. except:
  441. gpu_info = dict()
  442. message = "未检测到GPU \n 若存在GPU请确保安装pycuda \n 若未安装pycuda请使用'pip install pycuda'来安装"
  443. gpu_info['gpu_num'] = 0
  444. return gpu_info, message
  445. queue = mp.Queue(1)
  446. p = mp.Process(target=_get_gpu_info, args=(queue, ))
  447. p.start()
  448. p.join()
  449. gpu_info = queue.get(timeout=2)
  450. if gpu_info['gpu_num'] == 0:
  451. message = "未检测到GPU"
  452. else:
  453. message = "检测到GPU"
  454. return gpu_info, message
class TrainLogReader(object):
    """Incremental parser for a training log file.

    Each `update()` re-reads the log and refreshes: ETA, latest train/eval
    metrics, download progress, error/completion stage, and running duration.
    """

    def __init__(self, log_file):
        # Path of the log file to parse.
        self.log_file = log_file
        # Remaining seconds, parsed from the latest "eta=h:m:s" field.
        self.eta = None
        # Latest [TRAIN]/[EVAL] metric dicts ({name: number}).
        self.train_metrics = None
        self.eval_metrics = None
        # None | dict(progress) | 'Done' | 'Failed' for pretrained download.
        self.download_status = None
        # True once a "Model saved in" line has been seen.
        self.eval_done = False
        # The raw error line, when training stopped with an error.
        self.train_error = None
        # None | "Train Complete" | "Train Error".
        self.train_stage = None
        # Human-readable wall-clock duration string.
        self.running_duration = None

    def update(self):
        """Re-parse the log file and refresh all derived fields.

        No-op when the file is missing or a terminal state was already
        reached (error, failed download, completed training).
        """
        if not osp.exists(self.log_file):
            return
        if self.train_stage == "Train Error":
            return
        if self.download_status == "Failed":
            return
        if self.train_stage == "Train Complete":
            return
        logs = open(self.log_file, encoding='utf-8').read().strip().split('\n')
        # Reset per-update fields; keep download_status only once 'Done'.
        self.eta = None
        self.train_metrics = None
        self.eval_metrics = None
        if self.download_status != "Done":
            self.download_status = None
        # Start time: first parseable "YYYY-mm-dd HH:MM:SS" prefix in the log,
        # falling back to the file's creation time.
        start_time_timestamp = osp.getctime(self.log_file)
        for line in logs[::1]:
            try:
                start_time_str = " ".join(line.split()[0:2])
                start_time_array = time.strptime(start_time_str,
                                                 "%Y-%m-%d %H:%M:%S")
                start_time_timestamp = time.mktime(start_time_array)
                break
            except Exception as e:
                # Line has no timestamp prefix; keep scanning.
                pass
        # Scan newest-to-oldest so the latest values win.
        for line in logs[::-1]:
            if line.count('Train Complete!'):
                self.train_stage = "Train Complete"
            if line.count('Training stop with error!'):
                self.train_error = line
            # Stop early once every field has been filled in.
            if self.train_metrics is not None \
                    and self.eval_metrics is not None and self.eval_done and self.eta is not None:
                break
            items = line.strip().split()
            if line.count('Model saved in'):
                self.eval_done = True
            if line.count('download completed'):
                self.download_status = 'Done'
                break
            if line.count('download failed'):
                self.download_status = 'Failed'
                break
            if self.download_status != 'Done':
                # Progress lines look like "... total=xM, download=yM, speed=zKB/s".
                if line.count('[DEBUG]\tDownloading'
                              ) and self.download_status is None:
                    self.download_status = dict()
                if not line.endswith('KB/s'):
                    continue
                speed = items[-1].strip('KB/s').split('=')[-1]
                download = items[-2].strip('M, ').split('=')[-1]
                total = items[-3].strip('M, ').split('=')[-1]
                self.download_status['speed'] = speed
                self.download_status['download'] = float(download)
                self.download_status['total'] = float(total)
            if self.eta is None:
                # ETA field is the trailing "eta=h:m:s" token.
                if line.count('eta') > 0 and (line[-3] == ':' or
                                              line[-4] == ':'):
                    eta = items[-1].strip().split('=')[1]
                    h, m, s = [int(x) for x in eta.split(':')]
                    self.eta = h * 3600 + m * 60 + s
            if self.train_metrics is None:
                if line.count('[INFO]\t[TRAIN]') > 0 and line.count(
                        'Step') > 0:
                    # Only complete lines (ending in the eta field) are parsed.
                    if not items[-1].startswith('eta'):
                        continue
                    self.train_metrics = dict()
                    metrics = items[4:]
                    for metric in metrics:
                        try:
                            name, value = metric.strip(', ').split('=')
                            value = value.split('/')[0]
                            if value.count('.') > 0:
                                value = float(value)
                            elif value == 'nan':
                                value = 'nan'
                            else:
                                value = int(value)
                            self.train_metrics[name] = value
                        except:
                            # Skip tokens that are not name=value metrics.
                            pass
            if self.eval_metrics is None:
                if line.count('[INFO]\t[EVAL]') > 0 and line.count(
                        'Finished') > 0:
                    if not line.strip().endswith(' .'):
                        continue
                    self.eval_metrics = dict()
                    metrics = items[5:]
                    for metric in metrics:
                        try:
                            name, value = metric.strip(', ').split('=')
                            value = value.split('/')[0]
                            if value.count('.') > 0:
                                value = float(value)
                            else:
                                value = int(value)
                            self.eval_metrics[name] = value
                        except:
                            pass
        # Duration = file mtime - parsed start time, rendered as h/m/s.
        end_time_timestamp = osp.getmtime(self.log_file)
        t_diff = time.gmtime(end_time_timestamp - start_time_timestamp)
        self.running_duration = "{}小时{}分{}秒".format(
            t_diff.tm_hour, t_diff.tm_min, t_diff.tm_sec)
class PruneLogReader(object):
    """Incremental parser for a pruning log file.

    Extracts eta/iters/current/progress from lines of the form
    "... Total evaluate iters=..., current=..., progress=..., eta=...".
    """

    def init_attr(self):
        # Parsed fields; all None until found in the log.
        self.eta = None
        self.iters = None
        self.current = None
        self.progress = None

    def __init__(self, log_file):
        # Path of the log file to parse.
        self.log_file = log_file
        self.init_attr()

    def update(self):
        """Re-parse the log (newest line first) and refresh the fields."""
        if not osp.exists(self.log_file):
            return
        logs = open(self.log_file, encoding='utf-8').read().strip().split('\n')
        self.init_attr()
        for line in logs[::-1]:
            # Stop once every attribute (eta/iters/current/progress) is set.
            metric_loaded = True
            for k, v in self.__dict__.items():
                if v is None:
                    metric_loaded = False
                    break
            if metric_loaded:
                break
            if line.count("Total evaluate iters") > 0:
                # Each comma-separated item ends with a "name=value" token.
                items = line.split(',')
                for item in items:
                    kv_list = item.strip().split()[-1].split('=')
                    kv_list = [v.strip() for v in kv_list]
                    setattr(self, kv_list[0], kv_list[1])
class QuantLogReader:
    """Incremental parser for a quantization log file.

    Derives an overall progress fraction (`running_duration`, in [0, 1])
    from the three quantization phases, weighted 10/30, 3/30 and 16/30.
    """

    def __init__(self, log_file):
        # Path of the log file to parse.
        self.log_file = log_file
        # Current phase: 'Batch' | 'Weight' | 'Activation' | 'Finish'.
        self.stage = None
        # Progress fraction of the overall quantization run.
        self.running_duration = None

    def update(self):
        """Re-parse the log (newest line first) and refresh stage/progress."""
        if not osp.exists(self.log_file):
            return
        logs = open(self.log_file, encoding='utf-8').read().strip().split('\n')
        for line in logs[::-1]:
            items = line.strip().split(' ')
            if line.count('[Run batch data]'):
                # items[-3] looks like "batch=<id>/<all>,"; take the fraction.
                info = items[-3][:-1].split('=')[1]
                batch_id = float(info.split('/')[0])
                batch_all = float(info.split('/')[1])
                # Batch phase accounts for the first 10/30 of total progress.
                self.running_duration = \
                    batch_id / batch_all * (10.0 / 30.0)
                self.stage = 'Batch'
                break
            elif line.count('[Calculate weight]'):
                info = items[-3][:-1].split('=')[1]
                weight_id = float(info.split('/')[0])
                weight_all = float(info.split('/')[1])
                # Weight phase: 3/30 of progress, offset by the batch phase.
                self.running_duration = \
                    weight_id / weight_all * (3.0 / 30.0) + (10.0 / 30.0)
                self.stage = 'Weight'
                break
            elif line.count('[Calculate activation]'):
                info = items[-3][:-1].split('=')[1]
                activation_id = float(info.split('/')[0])
                activation_all = float(info.split('/')[1])
                # Activation phase: 16/30 of progress, offset by 13/30.
                self.running_duration = \
                    activation_id / activation_all * (16.0 / 30.0) + (13.0 / 30.0)
                self.stage = 'Activation'
                break
            elif line.count('Finish quant!'):
                self.stage = 'Finish'
                break
  634. class PyNvml(object):
  635. """ Nvidia GPU驱动检测类,可检测当前GPU驱动版本"""
  636. class PrintableStructure(Structure):
  637. _fmt_ = {}
  638. def __str__(self):
  639. result = []
  640. for x in self._fields_:
  641. key = x[0]
  642. value = getattr(self, key)
  643. fmt = "%s"
  644. if key in self._fmt_:
  645. fmt = self._fmt_[key]
  646. elif "<default>" in self._fmt_:
  647. fmt = self._fmt_["<default>"]
  648. result.append(("%s: " + fmt) % (key, value))
  649. return self.__class__.__name__ + "(" + string.join(result,
  650. ", ") + ")"
  651. class c_nvmlMemory_t(PrintableStructure):
  652. _fields_ = [
  653. ('total', c_ulonglong),
  654. ('free', c_ulonglong),
  655. ('used', c_ulonglong),
  656. ]
  657. _fmt_ = {'<default>': "%d B"}
  658. ## Device structures
  659. class struct_c_nvmlDevice_t(Structure):
  660. pass # opaque handle
  661. c_nvmlDevice_t = POINTER(struct_c_nvmlDevice_t)
  662. def __init__(self):
  663. self.nvml_lib = None
  664. self.nvml_lib_refcount = 0
  665. self.lib_load_lock = threading.Lock()
  666. self.nvml_lib_path = None
  667. def nvml_init(self, nvml_lib_path=None):
  668. self.lib_load_lock.acquire()
  669. sysstr = platform.system()
  670. if nvml_lib_path is None or nvml_lib_path.strip() == "":
  671. if sysstr == "Windows":
  672. nvml_lib_path = osp.join(
  673. os.getenv("ProgramFiles", "C:/Program Files"),
  674. "NVIDIA Corporation/NVSMI")
  675. if not osp.exists(osp.join(nvml_lib_path, "nvml.dll")):
  676. nvml_lib_path = "C:\\Windows\\System32"
  677. elif sysstr == "Linux":
  678. p1 = "/usr/lib/x86_64-linux-gnu"
  679. p2 = "/usr/lib/i386-linux-gnu"
  680. if osp.exists(osp.join(p1, "libnvidia-ml.so.1")):
  681. nvml_lib_path = p1
  682. elif osp.exists(osp.join(p2, "libnvidia-ml.so.1")):
  683. nvml_lib_path = p2
  684. else:
  685. nvml_lib_path = ""
  686. else:
  687. nvml_lib_path = "N/A"
  688. nvml_lib_dir = nvml_lib_path
  689. if sysstr == "Windows":
  690. nvml_lib_path = osp.join(nvml_lib_dir, "nvml.dll")
  691. else:
  692. nvml_lib_path = osp.join(nvml_lib_dir, "libnvidia-ml.so.1")
  693. self.nvml_lib_path = nvml_lib_path
  694. try:
  695. self.nvml_lib = CDLL(nvml_lib_path)
  696. fn = self._get_fn_ptr("nvmlInit")
  697. fn()
  698. if sysstr == "Windows":
  699. driver_version = bytes.decode(
  700. self.nvml_system_get_driver_version())
  701. if driver_version.strip() == "":
  702. nvml_lib_path = osp.join(nvml_lib_dir, "nvml9.dll")
  703. self.nvml_lib = CDLL(nvml_lib_path)
  704. fn = self._get_fn_ptr("nvmlInit")
  705. fn()
  706. except Exception as e:
  707. raise e
  708. finally:
  709. self.lib_load_lock.release()
  710. self.lib_load_lock.acquire()
  711. self.nvml_lib_refcount += 1
  712. self.lib_load_lock.release()
  713. def create_string_buffer(self, init, size=None):
  714. if isinstance(init, bytes):
  715. if size is None:
  716. size = len(init) + 1
  717. buftype = c_char * size
  718. buf = buftype()
  719. buf.value = init
  720. return buf
  721. elif isinstance(init, int):
  722. buftype = c_char * init
  723. buf = buftype()
  724. return buf
  725. raise TypeError(init)
  726. def _get_fn_ptr(self, name):
  727. return getattr(self.nvml_lib, name)
  728. def nvml_system_get_driver_version(self):
  729. c_version = self.create_string_buffer(81)
  730. fn = self._get_fn_ptr("nvmlSystemGetDriverVersion")
  731. ret = fn(c_version, c_uint(81))
  732. return c_version.value
  733. def nvml_device_get_count(self):
  734. c_count = c_uint()
  735. fn = self._get_fn_ptr("nvmlDeviceGetCount_v2")
  736. ret = fn(byref(c_count))
  737. return c_count.value
  738. def nvml_device_get_handle_by_index(self, index):
  739. c_index = c_uint(index)
  740. device = PyNvml.c_nvmlDevice_t()
  741. fn = self._get_fn_ptr("nvmlDeviceGetHandleByIndex_v2")
  742. ret = fn(c_index, byref(device))
  743. return device
  744. def nvml_device_get_memory_info(self, handle):
  745. c_memory = PyNvml.c_nvmlMemory_t()
  746. fn = self._get_fn_ptr("nvmlDeviceGetMemoryInfo")
  747. ret = fn(handle, byref(c_memory))
  748. return c_memory