task.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from .. import workspace_pb2 as w
  15. import os
  16. import os.path as osp
  17. import shutil
  18. import time
  19. import pickle
  20. import json
  21. import multiprocessing as mp
  22. import xlwt
  23. import numpy as np
  24. from ..utils import set_folder_status, TaskStatus, get_folder_status, is_available, get_ip, trans_name
  25. from .train.params import ClsParams, DetParams, SegParams
  26. def create_task(data, workspace):
  27. """根据request创建task。
  28. Args:
  29. data为dict,key包括
  30. 'pid'所属项目id, 'train'训练参数。训练参数和数据增强参数以pickle的形式保存
  31. 在任务目录下的params.pkl文件中。 'parent_id'(可选)该裁剪训练任务的父任务,
  32. 'desc'(可选)任务描述。
  33. """
  34. create_time = time.time()
  35. time_array = time.localtime(create_time)
  36. create_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array)
  37. id = workspace.max_task_id + 1
  38. workspace.max_task_id = id
  39. if id < 10000:
  40. id = 'T%04d' % id
  41. else:
  42. id = 'T{}'.format(id)
  43. pid = data['pid']
  44. assert pid in workspace.projects, "【任务创建】项目ID'{}'不存在.".format(pid)
  45. assert not id in workspace.tasks, "【任务创建】任务ID'{}'已经被占用.".format(id)
  46. did = workspace.projects[pid].did
  47. assert did in workspace.datasets, "【任务创建】数据集ID'{}'不存在".format(did)
  48. path = osp.join(workspace.projects[pid].path, id)
  49. if not osp.exists(path):
  50. os.makedirs(path)
  51. set_folder_status(path, TaskStatus.XINIT)
  52. data['task_type'] = workspace.projects[pid].type
  53. data['dataset_path'] = workspace.datasets[did].path
  54. data['pretrain_weights_download_save_dir'] = osp.join(workspace.path,
  55. 'pretrain')
  56. #获取参数
  57. if 'train' in data:
  58. params_json = json.loads(data['train'])
  59. if (data['task_type'] == 'classification'):
  60. params_init = ClsParams()
  61. if (data['task_type'] == 'detection' or
  62. data['task_type'] == 'instance_segmentation'):
  63. params_init = DetParams()
  64. if (data['task_type'] == 'segmentation' or
  65. data['task_type'] == 'remote_segmentation'):
  66. params_init = SegParams()
  67. params_init.load_from_dict(params_json)
  68. data['train'] = params_init
  69. parent_id = ''
  70. if 'parent_id' in data:
  71. data['tid'] = data['parent_id']
  72. parent_id = data['parent_id']
  73. assert data['parent_id'] in workspace.tasks, "【任务创建】裁剪任务创建失败".format(
  74. data['parent_id'])
  75. r = get_task_params(data, workspace)
  76. train_params = r['train']
  77. data['train'] = train_params
  78. desc = ""
  79. if 'desc' in data:
  80. desc = data['desc']
  81. with open(osp.join(path, 'params.pkl'), 'wb') as f:
  82. pickle.dump(data, f)
  83. task = w.Task(
  84. id=id,
  85. pid=pid,
  86. path=path,
  87. create_time=create_time,
  88. parent_id=parent_id,
  89. desc=desc)
  90. workspace.tasks[id].CopyFrom(task)
  91. with open(os.path.join(path, 'info.pb'), 'wb') as f:
  92. f.write(task.SerializeToString())
  93. return {'status': 1, 'tid': id}
  94. def delete_task(data, workspace):
  95. """删除task。
  96. Args:
  97. data为dict,key包括
  98. 'tid'任务id
  99. """
  100. task_id = data['tid']
  101. assert task_id in workspace.tasks, "任务ID'{}'不存在.".format(task_id)
  102. if osp.exists(workspace.tasks[task_id].path):
  103. shutil.rmtree(workspace.tasks[task_id].path)
  104. del workspace.tasks[task_id]
  105. return {'status': 1}
  106. def get_task_params(data, workspace):
  107. """根据request获取task的参数。
  108. Args:
  109. data为dict,key包括
  110. 'tid'任务id
  111. """
  112. tid = data['tid']
  113. assert tid in workspace.tasks, "【任务创建】任务ID'{}'不存在.".format(tid)
  114. path = workspace.tasks[tid].path
  115. with open(osp.join(path, 'params.pkl'), 'rb') as f:
  116. task_params = pickle.load(f)
  117. return {'status': 1, 'train': task_params['train']}
  118. def list_tasks(data, workspace):
  119. '''列出任务列表,可request的参数进行筛选
  120. Args:
  121. data为dict, 包括
  122. 'pid'(可选)所属项目id
  123. '''
  124. task_list = list()
  125. for key in workspace.tasks:
  126. task_id = workspace.tasks[key].id
  127. task_name = workspace.tasks[key].name
  128. task_desc = workspace.tasks[key].desc
  129. task_pid = workspace.tasks[key].pid
  130. task_path = workspace.tasks[key].path
  131. task_create_time = workspace.tasks[key].create_time
  132. task_type = workspace.projects[task_pid].type
  133. from .operate import get_task_status
  134. path = workspace.tasks[task_id].path
  135. status, message = get_task_status(path)
  136. if data is not None:
  137. if "pid" in data:
  138. if data["pid"] != task_pid:
  139. continue
  140. attr = {
  141. "id": task_id,
  142. "name": task_name,
  143. "desc": task_desc,
  144. "pid": task_pid,
  145. "path": task_path,
  146. "create_time": task_create_time,
  147. "status": status.value,
  148. 'type': task_type
  149. }
  150. task_list.append(attr)
  151. return {'status': 1, 'tasks': task_list}
  152. def set_task_params(data, workspace):
  153. """根据request设置task的参数。只有在task是TaskStatus.XINIT状态时才有效
  154. Args:
  155. data为dict,key包括
  156. 'tid'任务id, 'train'训练参数. 训练
  157. 参数和数据增强参数以pickle的形式保存在任务目录下的params.pkl文件
  158. 中。
  159. """
  160. tid = data['tid']
  161. train = data['train']
  162. assert tid in workspace.tasks, "【任务创建】任务ID'{}'不存在.".format(tid)
  163. path = workspace.tasks[tid].path
  164. status = get_folder_status(path)
  165. assert status == TaskStatus.XINIT, "该任务不在初始化阶段,设置参数失败"
  166. with open(osp.join(path, 'params.pkl'), 'rb') as f:
  167. task_params = pickle.load(f)
  168. train_json = json.loads(train)
  169. task_params['train'].load_from_dict(train_json)
  170. with open(osp.join(path, 'params.pkl'), 'wb') as f:
  171. pickle.dump(task_params, f)
  172. return {'status': 1}
  173. def get_default_params(data, workspace, machine_info):
  174. from .train.params_v2 import get_params
  175. from ..dataset.dataset import get_dataset_details
  176. pid = data['pid']
  177. assert pid in workspace.projects, "项目ID{}不存在.".format(pid)
  178. project_type = workspace.projects[pid].type
  179. did = workspace.projects[pid].did
  180. result = get_dataset_details({'did': did}, workspace)
  181. if result['status'] == 1:
  182. details = result['details']
  183. else:
  184. raise Exception("Fail to get dataset details!")
  185. train_num = len(details['train_files'])
  186. class_num = len(details['labels'])
  187. if machine_info['gpu_num'] == 0:
  188. gpu_num = 0
  189. per_gpu_memory = 0
  190. gpu_list = None
  191. else:
  192. if 'gpu_list' in data:
  193. gpu_list = data['gpu_list']
  194. gpu_num = len(gpu_list)
  195. per_gpu_memory = None
  196. for gpu_id in gpu_list:
  197. if per_gpu_memory is None:
  198. per_gpu_memory = machine_info['gpu_free_mem'][gpu_id]
  199. elif machine_info['gpu_free_mem'][gpu_id] < per_gpu_memory:
  200. per_gpu_memory = machine_info['gpu_free_mem'][gpu_id]
  201. else:
  202. gpu_num = 1
  203. per_gpu_memory = machine_info['gpu_free_mem'][0]
  204. gpu_list = [0]
  205. params = get_params(data, project_type, train_num, class_num, gpu_num,
  206. per_gpu_memory, gpu_list)
  207. return {"status": 1, "train": params}
  208. def get_task_params(data, workspace):
  209. """根据request获取task的参数。
  210. Args:
  211. data为dict,key包括
  212. 'tid'任务id
  213. """
  214. tid = data['tid']
  215. assert tid in workspace.tasks, "【任务创建】任务ID'{}'不存在.".format(tid)
  216. path = workspace.tasks[tid].path
  217. with open(osp.join(path, 'params.pkl'), 'rb') as f:
  218. task_params = pickle.load(f)
  219. return {'status': 1, 'train': task_params['train']}
  220. def get_task_status(data, workspace):
  221. """ 获取任务状态
  222. Args:
  223. data为dict, key包括
  224. 'tid'任务id, 'resume'(可选):获取是否可以恢复训练的状态
  225. """
  226. from .operate import get_task_status, get_task_max_saved_epochs
  227. tid = data['tid']
  228. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  229. path = workspace.tasks[tid].path
  230. status, message = get_task_status(path)
  231. task_pid = workspace.tasks[tid].pid
  232. task_type = workspace.projects[task_pid].type
  233. if 'resume' in data:
  234. max_saved_epochs = get_task_max_saved_epochs(path)
  235. params = {'tid': tid}
  236. results = get_task_params(params, workspace)
  237. total_epochs = results['train'].num_epochs
  238. resumable = max_saved_epochs > 0 and max_saved_epochs < total_epochs
  239. return {
  240. 'status': 1,
  241. 'task_status': status.value,
  242. 'message': message,
  243. 'resumable': resumable,
  244. 'max_saved_epochs': max_saved_epochs,
  245. 'type': task_type
  246. }
  247. return {
  248. 'status': 1,
  249. 'task_status': status.value,
  250. 'message': message,
  251. 'type': task_type
  252. }
  253. def get_train_metrics(data, workspace):
  254. """ 获取任务日志
  255. Args:
  256. data为dict, key包括
  257. 'tid'任务id
  258. Return:
  259. train_log(dict): 'eta':剩余时间,'train_metrics': 训练指标,'eval_metircs': 评估指标,
  260. 'download_status': 下载模型状态,'eval_done': 是否已保存模型,'train_error': 训练错误原因
  261. """
  262. tid = data['tid']
  263. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  264. from ..utils import TrainLogReader
  265. task_path = workspace.tasks[tid].path
  266. log_file = osp.join(task_path, 'out.log')
  267. train_log = TrainLogReader(log_file)
  268. train_log.update()
  269. train_log = train_log.__dict__
  270. return {'status': 1, 'train_log': train_log}
  271. def get_eval_metrics(data, workspace):
  272. """ 获取任务日志
  273. Args:
  274. data为dict, key包括
  275. 'tid'父任务id
  276. """
  277. tid = data['tid']
  278. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  279. best_model_path = osp.join(workspace.tasks[tid].path, "output",
  280. "best_model", "model.yml")
  281. import yaml
  282. f = open(best_model_path, "r", encoding="utf-8")
  283. eval_metrics = yaml.load(f)['_Attributes']['eval_metrics']
  284. f.close()
  285. return {'status': 1, 'eval_metric': eval_metrics}
  286. def get_eval_all_metrics(data, workspace):
  287. tid = data['tid']
  288. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  289. output_dir = osp.join(workspace.tasks[tid].path, "output")
  290. epoch_result_dict = dict()
  291. best_epoch = -1
  292. best_result = -1
  293. import yaml
  294. for file_dir in os.listdir(output_dir):
  295. if file_dir.startswith("epoch"):
  296. epoch_dir = osp.join(output_dir, file_dir)
  297. if osp.exists(osp.join(epoch_dir, ".success")):
  298. epoch_index = int(file_dir.split('_')[-1])
  299. yml_file_path = osp.join(epoch_dir, "model.yml")
  300. f = open(yml_file_path, 'r', encoding='utf-8')
  301. yml_file = yaml.load(f.read())
  302. result = yml_file["_Attributes"]["eval_metrics"]
  303. key = list(result.keys())[0]
  304. value = result[key]
  305. if value > best_result:
  306. best_result = value
  307. best_epoch = epoch_index
  308. elif value == best_result:
  309. if epoch_index < best_epoch:
  310. best_epoch = epoch_index
  311. epoch_result_dict[epoch_index] = value
  312. return {
  313. 'status': 1,
  314. 'key': key,
  315. 'epoch_result_dict': epoch_result_dict,
  316. 'best_epoch': best_epoch,
  317. 'best_result': best_result
  318. }
  319. def get_sensitivities_loss_img(data, workspace):
  320. """ 获取敏感度与模型裁剪率关系图
  321. Args:
  322. data为dict, key包括
  323. 'tid'任务id
  324. """
  325. tid = data['tid']
  326. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  327. task_path = workspace.tasks[tid].path
  328. pkl_path = osp.join(task_path, 'prune', 'sensitivities_xy.pkl')
  329. import pickle
  330. f = open(pkl_path, 'rb')
  331. sensitivities_xy = pickle.load(f)
  332. return {'status': 1, 'sensitivities_loss_img': sensitivities_xy}
  333. def start_train_task(data, workspace, monitored_processes):
  334. """启动训练任务。
  335. Args:
  336. data为dict,key包括
  337. 'tid'任务id, 'eval_metric_loss'(可选)裁剪任务所需的评估loss
  338. """
  339. from .operate import train_model
  340. tid = data['tid']
  341. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  342. path = workspace.tasks[tid].path
  343. if 'eval_metric_loss' in data and \
  344. data['eval_metric_loss'] is not None:
  345. # 裁剪任务
  346. parent_id = workspace.tasks[tid].parent_id
  347. assert parent_id != "", "任务{}不是裁剪训练任务".format(tid)
  348. parent_path = workspace.tasks[parent_id].path
  349. sensitivities_path = osp.join(parent_path, 'prune',
  350. 'sensitivities.data')
  351. eval_metric_loss = data['eval_metric_loss']
  352. parent_best_model_path = osp.join(parent_path, 'output', 'best_model')
  353. params_conf_file = osp.join(path, 'params.pkl')
  354. with open(params_conf_file, 'rb') as f:
  355. params = pickle.load(f)
  356. params['train'].sensitivities_path = sensitivities_path
  357. params['train'].eval_metric_loss = eval_metric_loss
  358. params['train'].pretrain_weights = parent_best_model_path
  359. with open(params_conf_file, 'wb') as f:
  360. pickle.dump(params, f)
  361. p = train_model(path)
  362. monitored_processes.put(p.pid)
  363. return {'status': 1}
  364. def resume_train_task(data, workspace, monitored_processes):
  365. """恢复训练任务
  366. Args:
  367. data为dict, key包括
  368. 'tid'任务id,'epoch'恢复训练的起始轮数
  369. """
  370. from .operate import train_model
  371. tid = data['tid']
  372. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  373. path = workspace.tasks[tid].path
  374. epoch_path = "epoch_" + str(data['epoch'])
  375. resume_checkpoint_path = osp.join(path, "output", epoch_path)
  376. params_conf_file = osp.join(path, 'params.pkl')
  377. with open(params_conf_file, 'rb') as f:
  378. params = pickle.load(f)
  379. params['train'].resume_checkpoint = resume_checkpoint_path
  380. with open(params_conf_file, 'wb') as f:
  381. pickle.dump(params, f)
  382. p = train_model(path)
  383. monitored_processes.put(p.pid)
  384. return {'status': 1}
  385. def stop_train_task(data, workspace):
  386. """停止训练任务
  387. Args:
  388. data为dict, key包括
  389. 'tid'任务id
  390. """
  391. from .operate import stop_train_model
  392. tid = data['tid']
  393. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  394. path = workspace.tasks[tid].path
  395. stop_train_model(path)
  396. return {'status': 1}
  397. def start_prune_analysis(data, workspace, monitored_processes):
  398. """开始模型裁剪分析
  399. Args:
  400. data为dict, key包括
  401. 'tid'任务id
  402. """
  403. tid = data['tid']
  404. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  405. task_path = workspace.tasks[tid].path
  406. from .operate import prune_analysis_model
  407. p = prune_analysis_model(task_path)
  408. monitored_processes.put(p.pid)
  409. return {'status': 1}
  410. def get_prune_metrics(data, workspace):
  411. """ 获取模型裁剪分析日志
  412. Args:
  413. data为dict, key包括
  414. 'tid'任务id
  415. Return:
  416. prune_log(dict): 'eta':剩余时间,'iters': 模型裁剪总轮数,'current': 当前轮数,
  417. 'progress': 模型裁剪进度
  418. """
  419. tid = data['tid']
  420. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  421. from ..utils import PruneLogReader
  422. task_path = workspace.tasks[tid].path
  423. log_file = osp.join(task_path, 'prune', 'out.log')
  424. # assert osp.exists(log_file), "模型裁剪分析任务还未开始,请稍等"
  425. if not osp.exists(log_file):
  426. return {'status': 1, 'prune_log': None}
  427. prune_log = PruneLogReader(log_file)
  428. prune_log.update()
  429. prune_log = prune_log.__dict__
  430. return {'status': 1, 'prune_log': prune_log}
  431. def get_prune_status(data, workspace):
  432. """ 获取模型裁剪状态
  433. Args:
  434. data为dict, key包括
  435. 'tid'任务id
  436. """
  437. from .operate import get_prune_status
  438. tid = data['tid']
  439. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  440. path = workspace.tasks[tid].path
  441. prune_path = osp.join(path, "prune")
  442. status, message = get_prune_status(prune_path)
  443. if status is not None:
  444. status = status.value
  445. return {'status': 1, 'prune_status': status, 'message': message}
  446. def stop_prune_analysis(data, workspace):
  447. """停止模型裁剪分析
  448. Args:
  449. data为dict, key包括
  450. 'tid'任务id
  451. """
  452. tid = data['tid']
  453. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  454. from .operate import stop_prune_analysis
  455. prune_path = osp.join(workspace.tasks[tid].path, 'prune')
  456. stop_prune_analysis(prune_path)
  457. return {'status': 1}
  458. def evaluate_model(data, workspace, monitored_processes):
  459. """ 模型评估
  460. Args:
  461. data为dict, key包括
  462. 'tid'任务id, topk, score_thresh, overlap_thresh这些评估所需参数
  463. Return:
  464. None
  465. """
  466. from .operate import evaluate_model
  467. tid = data['tid']
  468. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  469. pid = workspace.tasks[tid].pid
  470. assert pid in workspace.projects, "项目ID'{}'不存在".format(pid)
  471. path = workspace.tasks[tid].path
  472. type = workspace.projects[pid].type
  473. p = evaluate_model(path, type, data['epoch'], data['topk'],
  474. data['score_thresh'], data['overlap_thresh'])
  475. monitored_processes.put(p.pid)
  476. return {'status': 1}
  477. def get_evaluate_result(data, workspace):
  478. """ 获评估结果
  479. Args:
  480. data为dict, key包括
  481. 'tid'任务id
  482. Return:
  483. 包含评估指标的dict
  484. """
  485. from .operate import get_evaluate_status
  486. tid = data['tid']
  487. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  488. task_path = workspace.tasks[tid].path
  489. status, message = get_evaluate_status(task_path)
  490. if status == TaskStatus.XEVALUATED:
  491. result_file = osp.join(task_path, 'eval_res.pkl')
  492. if os.path.exists(result_file):
  493. result = pickle.load(open(result_file, "rb"))
  494. return {
  495. 'status': 1,
  496. 'evaluate_status': status,
  497. 'message': "{}评估完成".format(tid),
  498. 'path': result_file,
  499. 'result': result
  500. }
  501. else:
  502. return {
  503. 'status': -1,
  504. 'evaluate_status': status,
  505. 'message': "评估结果丢失,建议重新评估!",
  506. 'result': None
  507. }
  508. if status == TaskStatus.XEVALUATEFAIL:
  509. return {
  510. 'status': -1,
  511. 'evaluate_status': status,
  512. 'message': "评估失败,请重新评估!",
  513. 'result': None
  514. }
  515. return {
  516. 'status': 1,
  517. 'evaluate_status': status,
  518. 'message': "{}正在评估中,请稍后!".format(tid),
  519. 'result': None
  520. }
  521. def import_evaluate_excel(data, result, workspace):
  522. excel_ret = dict()
  523. workbook = xlwt.Workbook()
  524. labels = None
  525. START_ROW = 0
  526. sheet = workbook.add_sheet("评估报告")
  527. if 'label_list' not in result:
  528. pass
  529. else:
  530. labels = result['label_list']
  531. for k, v in result.items():
  532. if k == 'label_list':
  533. continue
  534. if type(v) == np.ndarray:
  535. sheet.write(START_ROW + 0, 0, trans_name(k))
  536. sheet.write(START_ROW + 1, 1, trans_name("Class"))
  537. if labels is None:
  538. labels = ["{}".format(x) for x in range(len(v))]
  539. for i in range(len(labels)):
  540. sheet.write(START_ROW + 1, 2 + i, labels[i])
  541. sheet.write(START_ROW + 2 + i, 1, labels[i])
  542. for i in range(len(labels)):
  543. for j in range(len(labels)):
  544. sheet.write(START_ROW + 2 + i, 2 + j, str(v[i, j]))
  545. START_ROW = (START_ROW + 4 + len(labels))
  546. if type(v) == dict:
  547. sheet.write(START_ROW + 0, 0, trans_name(k))
  548. multi_row = False
  549. Cols = ["Class"]
  550. for k1, v1 in v.items():
  551. if type(v1) == dict:
  552. multi_row = True
  553. for sub_k, sub_v in v1.items():
  554. Cols.append(sub_k)
  555. else:
  556. Cols.append(k)
  557. break
  558. for i in range(len(Cols)):
  559. sheet.write(START_ROW + 1, 1 + i, trans_name(Cols[i]))
  560. index = 2
  561. for k1, v1 in v.items():
  562. sheet.write(START_ROW + index, 1, k1)
  563. if multi_row:
  564. for sub_k, sub_v in v1.items():
  565. sheet.write(START_ROW + index,
  566. Cols.index(sub_k) + 1, "nan"
  567. if (sub_v is None) or sub_v == -1 else
  568. "{:.4f}".format(sub_v))
  569. else:
  570. sheet.write(START_ROW + index, 2, "{}".format(v1))
  571. index += 1
  572. START_ROW = (START_ROW + index + 2)
  573. if type(v) in [float, np.float, np.float32, np.float64, type(None)]:
  574. front_str = "{}".format(trans_name(k))
  575. if k == "Acck":
  576. if "topk" in data:
  577. front_str = front_str.format(data["topk"])
  578. else:
  579. front_str = front_str.format(5)
  580. sheet.write(START_ROW + 0, 0, front_str)
  581. sheet.write(START_ROW + 1, 1, "{:.4f}".format(v)
  582. if v is not None else "nan")
  583. START_ROW = (START_ROW + 2 + 2)
  584. tid = data['tid']
  585. path = workspace.tasks[tid].path
  586. final_save = os.path.join(path, 'report-task{}.xls'.format(tid))
  587. workbook.save(final_save)
  588. excel_ret['status'] = 1
  589. excel_ret['path'] = final_save
  590. excel_ret['message'] = "成功导出结果到excel"
  591. return excel_ret
  592. def get_predict_status(data, workspace):
  593. from .operate import get_predict_status
  594. tid = data['tid']
  595. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  596. path = workspace.tasks[tid].path
  597. status, message, predict_num, total_num = get_predict_status(path)
  598. return {
  599. 'status': 1,
  600. 'predict_status': status.value,
  601. 'message': message,
  602. 'predict_num': predict_num,
  603. 'total_num': total_num
  604. }
  605. def predict_test_pics(data, workspace, monitored_processes):
  606. from .operate import predict_test_pics
  607. tid = data['tid']
  608. if 'img_list' in data:
  609. img_list = data['img_list']
  610. else:
  611. img_list = list()
  612. if 'image_data' in data:
  613. image_data = data['image_data']
  614. else:
  615. image_data = None
  616. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  617. path = workspace.tasks[tid].path
  618. save_dir = data['save_dir'] if 'save_dir' in data else None
  619. epoch = data['epoch'] if 'epoch' in data else None
  620. score_thresh = data['score_thresh'] if 'score_thresh' in data else 0.5
  621. p, save_dir = predict_test_pics(
  622. path,
  623. save_dir=save_dir,
  624. img_list=img_list,
  625. img_data=image_data,
  626. score_thresh=score_thresh,
  627. epoch=epoch)
  628. monitored_processes.put(p.pid)
  629. if 'image_data' in data:
  630. path = osp.join(save_dir, 'predict_result.png')
  631. else:
  632. path = None
  633. return {'status': 1, 'path': path}
  634. def stop_predict_task(data, workspace):
  635. from .operate import stop_predict_task
  636. tid = data['tid']
  637. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  638. path = workspace.tasks[tid].path
  639. status, message, predict_num, total_num = stop_predict_task(path)
  640. return {
  641. 'status': 1,
  642. 'predict_status': status.value,
  643. 'message': message,
  644. 'predict_num': predict_num,
  645. 'total_num': total_num
  646. }
  647. def get_quant_progress(data, workspace):
  648. tid = data['tid']
  649. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  650. from ..utils import QuantLogReader
  651. export_path = osp.join(workspace.tasks[tid].path, "./logs/export")
  652. log_file = osp.join(export_path, 'out.log')
  653. quant_log = QuantLogReader(log_file)
  654. quant_log.update()
  655. quant_log = quant_log.__dict__
  656. return {'status': 1, 'quant_log': quant_log}
  657. def get_quant_result(data, workspace):
  658. tid = data['tid']
  659. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  660. export_path = osp.join(workspace.tasks[tid].path, "./logs/export")
  661. result_json = osp.join(export_path, 'quant_result.json')
  662. result = {}
  663. import json
  664. if osp.exists(result_json):
  665. with open(result_json, 'r') as f:
  666. result = json.load(f)
  667. return {'status': 1, 'quant_result': result}
  668. def get_export_status(data, workspace):
  669. """ 获取导出状态
  670. Args:
  671. data为dict, key包括
  672. 'tid'任务id
  673. Return:
  674. 目前导出状态.
  675. """
  676. from .operate import get_export_status
  677. tid = data['tid']
  678. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  679. task_path = workspace.tasks[tid].path
  680. status, message = get_export_status(task_path)
  681. if status == TaskStatus.XEXPORTED:
  682. return {
  683. 'status': 1,
  684. 'export_status': status,
  685. 'message': "恭喜您,{}任务模型导出成功!".format(tid)
  686. }
  687. if status == TaskStatus.XEXPORTFAIL:
  688. return {
  689. 'status': -1,
  690. 'export_status': status,
  691. 'message': "{}任务模型导出失败,请重试!".format(tid)
  692. }
  693. return {
  694. 'status': 1,
  695. 'export_status': status,
  696. 'message': "{}任务模型导出中,请稍等!".format(tid)
  697. }
  698. def export_infer_model(data, workspace, monitored_processes):
  699. """导出部署模型
  700. Args:
  701. data为dict,key包括
  702. 'tid'任务id, 'save_dir'导出模型保存路径
  703. """
  704. from .operate import export_noquant_model, export_quant_model
  705. tid = data['tid']
  706. save_dir = data['save_dir']
  707. epoch = data['epoch'] if 'epoch' in data else None
  708. quant = data['quant'] if 'quant' in data else False
  709. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  710. path = workspace.tasks[tid].path
  711. if quant:
  712. p = export_quant_model(path, save_dir, epoch)
  713. else:
  714. p = export_noquant_model(path, save_dir, epoch)
  715. monitored_processes.put(p.pid)
  716. return {'status': 1, 'save_dir': save_dir}
  717. def export_lite_model(data, workspace):
  718. """ 导出lite模型
  719. Args:
  720. data为dict, key包括
  721. 'tid'任务id, 'save_dir'导出模型保存路径
  722. """
  723. from .operate import opt_lite_model
  724. model_path = data['model_path']
  725. save_dir = data['save_dir']
  726. tid = data['tid']
  727. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  728. opt_lite_model(model_path, save_dir)
  729. if not osp.exists(osp.join(save_dir, "model.nb")):
  730. if osp.exists(save_dir):
  731. shutil.rmtree(save_dir)
  732. return {'status': -1, 'message': "导出为lite模型失败"}
  733. return {'status': 1, 'message': "完成"}
  734. def stop_export_task(data, workspace):
  735. """ 停止导出任务
  736. Args:
  737. data为dict, key包括
  738. 'tid'任务id
  739. Return:
  740. 目前导出的状态.
  741. """
  742. from .operate import stop_export_task
  743. tid = data['tid']
  744. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  745. task_path = workspace.tasks[tid].path
  746. status, message = stop_export_task(task_path)
  747. return {'status': 1, 'export_status': status.value, 'message': message}
  748. def _open_vdl(logdir, current_port):
  749. from visualdl.server import app
  750. app.run(logdir=logdir, host='0.0.0.0', port=current_port)
  751. def open_vdl(data, workspace, current_port, monitored_processes,
  752. running_boards):
  753. """打开vdl页面
  754. Args:
  755. data为dict,
  756. 'tid' 任务id
  757. """
  758. tid = data['tid']
  759. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  760. ip = get_ip()
  761. if tid in running_boards:
  762. url = ip + ":{}".format(running_boards[tid][0])
  763. return {'status': 1, 'url': url}
  764. task_path = workspace.tasks[tid].path
  765. logdir = osp.join(task_path, 'output', 'vdl_log')
  766. assert osp.exists(logdir), "该任务还未正常产生日志文件"
  767. port_available = is_available(ip, current_port)
  768. while not port_available:
  769. current_port += 1
  770. port_available = is_available(ip, current_port)
  771. assert current_port <= 8500, "找不到可用的端口"
  772. p = mp.Process(target=_open_vdl, args=(logdir, current_port))
  773. p.start()
  774. monitored_processes.put(p.pid)
  775. url = ip + ":{}".format(current_port)
  776. running_boards[tid] = [current_port, p.pid]
  777. current_port += 1
  778. total_time = 0
  779. while True:
  780. if not is_available(ip, current_port - 1):
  781. break
  782. print(current_port)
  783. time.sleep(0.5)
  784. total_time += 0.5
  785. assert total_time <= 8, "VisualDL服务启动超时,请重新尝试打开"
  786. return {'status': 1, 'url': url}