# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import workspace_pb2 as w
import os
import os.path as osp
import shutil
import time
import pickle
import json
import multiprocessing as mp
import xlwt
import numpy as np
from ..utils import set_folder_status, TaskStatus, get_folder_status, is_available, get_ip, trans_name
from .train.params import ClsParams, DetParams, SegParams
from paddlex_restful.restful.dataset.utils import get_encoding


def create_task(data, workspace):
    """Create a task from the request.

    Args:
        data is a dict whose keys include
        'pid': the id of the project the task belongs to; 'train': training
        parameters. Training and data-augmentation parameters are pickled to
        the file params.pkl under the task directory. 'parent_id' (optional):
        the parent task of this pruning-training task; 'desc' (optional):
        a task description.
    """
    create_time = time.time()
    time_array = time.localtime(create_time)
    create_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array)
    id = workspace.max_task_id + 1
    workspace.max_task_id = id
    if id < 10000:
        id = 'T%04d' % id
    else:
        id = 'T{}'.format(id)
    pid = data['pid']
    assert pid in workspace.projects, "[create_task] Project ID '{}' does not exist.".format(pid)
    assert id not in workspace.tasks, "[create_task] Task ID '{}' is already in use.".format(id)
    did = workspace.projects[pid].did
    assert did in workspace.datasets, "[create_task] Dataset ID '{}' does not exist.".format(did)
    path = osp.join(workspace.projects[pid].path, id)
    if not osp.exists(path):
        os.makedirs(path)
    set_folder_status(path, TaskStatus.XINIT)
    data['task_type'] = workspace.projects[pid].type
    data['dataset_path'] = workspace.datasets[did].path
    data['pretrain_weights_download_save_dir'] = osp.join(workspace.path,
                                                          'pretrain')
    # Parse the training parameters.
    if 'train' in data:
        params_json = json.loads(data['train'])
        if data['task_type'] == 'classification':
            params_init = ClsParams()
        if (data['task_type'] == 'detection' or
                data['task_type'] == 'instance_segmentation'):
            params_init = DetParams()
        if (data['task_type'] == 'segmentation' or
                data['task_type'] == 'remote_segmentation'):
            params_init = SegParams()
        params_init.load_from_dict(params_json)
        data['train'] = params_init
    parent_id = ''
    if 'parent_id' in data:
        data['tid'] = data['parent_id']
        parent_id = data['parent_id']
        assert data['parent_id'] in workspace.tasks, \
            "[create_task] Parent task ID '{}' does not exist.".format(
                data['parent_id'])
        # A pruning-training task inherits its parent's training parameters.
        r = get_task_params(data, workspace)
        train_params = r['train']
        data['train'] = train_params
    desc = ""
    if 'desc' in data:
        desc = data['desc']
    with open(osp.join(path, 'params.pkl'), 'wb') as f:
        pickle.dump(data, f)
    task = w.Task(
        id=id,
        pid=pid,
        path=path,
        create_time=create_time,
        parent_id=parent_id,
        desc=desc)
    workspace.tasks[id].CopyFrom(task)
    with open(os.path.join(path, 'info.pb'), 'wb') as f:
        f.write(task.SerializeToString())
    return {'status': 1, 'tid': id}
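
# Usage sketch (hypothetical, for illustration only; the project id and the
# parameter values below are assumptions, not part of this module):
#
#   import json
#   train_cfg = json.dumps({"num_epochs": 20, "batch_size": 8})
#   resp = create_task({"pid": "P0001", "train": train_cfg,
#                       "desc": "baseline run"}, workspace)
#   assert resp["status"] == 1
#   tid = resp["tid"]  # e.g. "T0001"; ids below 10000 are zero-padded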


def delete_task(data, workspace):
    """Delete a task.

    Args:
        data is a dict whose keys include
        'tid': the task id.
    """
    task_id = data['tid']
    assert task_id in workspace.tasks, "Task ID '{}' does not exist.".format(task_id)
    if osp.exists(workspace.tasks[task_id].path):
        shutil.rmtree(workspace.tasks[task_id].path)
    del workspace.tasks[task_id]
    return {'status': 1}


def get_task_params(data, workspace):
    """Get the parameters of a task from the request.

    Args:
        data is a dict whose keys include
        'tid': the task id.
    """
    tid = data['tid']
    assert tid in workspace.tasks, "[create_task] Task ID '{}' does not exist.".format(tid)
    path = workspace.tasks[tid].path
    with open(osp.join(path, 'params.pkl'), 'rb') as f:
        task_params = pickle.load(f)
    return {'status': 1, 'train': task_params['train']}


def list_tasks(data, workspace):
    '''List tasks, optionally filtered by the request parameters.

    Args:
        data is a dict whose keys include
        'pid' (optional): the id of the project the tasks belong to.
    '''
    from .operate import get_task_status
    task_list = list()
    for key in workspace.tasks:
        task_id = workspace.tasks[key].id
        task_name = workspace.tasks[key].name
        task_desc = workspace.tasks[key].desc
        task_pid = workspace.tasks[key].pid
        task_path = workspace.tasks[key].path
        task_create_time = workspace.tasks[key].create_time
        task_type = workspace.projects[task_pid].type
        path = workspace.tasks[task_id].path
        status, message = get_task_status(path)
        if data is not None:
            if "pid" in data:
                if data["pid"] != task_pid:
                    continue
        attr = {
            "id": task_id,
            "name": task_name,
            "desc": task_desc,
            "pid": task_pid,
            "path": task_path,
            "create_time": task_create_time,
            "status": status.value,
            'type': task_type
        }
        task_list.append(attr)
    return {'status': 1, 'tasks': task_list}
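
# Usage sketch (hypothetical): listing only the tasks of one project by
# passing the optional 'pid' filter; the project id is an assumption.
#
#   resp = list_tasks({"pid": "P0001"}, workspace)
#   for t in resp["tasks"]:
#       print(t["id"], t["status"], t["create_time"])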


def set_task_params(data, workspace):
    """Set the parameters of a task from the request. This only takes effect
    while the task is in the TaskStatus.XINIT state.

    Args:
        data is a dict whose keys include
        'tid': the task id; 'train': training parameters. Training and
        data-augmentation parameters are pickled to the file params.pkl
        under the task directory.
    """
    tid = data['tid']
    train = data['train']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    path = workspace.tasks[tid].path
    status = get_folder_status(path)
    assert status == TaskStatus.XINIT, "The task is not in the initialization stage; failed to set parameters."
    with open(osp.join(path, 'params.pkl'), 'rb') as f:
        task_params = pickle.load(f)
    train_json = json.loads(train)
    task_params['train'].load_from_dict(train_json)
    with open(osp.join(path, 'params.pkl'), 'wb') as f:
        pickle.dump(task_params, f)
    return {'status': 1}


def get_default_params(data, workspace, machine_info):
    from .train.params_v2 import get_params
    from ..dataset.dataset import get_dataset_details
    pid = data['pid']
    assert pid in workspace.projects, "Project ID {} does not exist.".format(pid)
    project_type = workspace.projects[pid].type
    did = workspace.projects[pid].did
    result = get_dataset_details({'did': did}, workspace)
    if result['status'] == 1:
        details = result['details']
    else:
        raise Exception("Fail to get dataset details!")
    train_num = len(details['train_files'])
    class_num = len(details['labels'])
    if machine_info['gpu_num'] == 0:
        gpu_num = 0
        per_gpu_memory = 0
        gpu_list = None
    else:
        if 'gpu_list' in data:
            gpu_list = data['gpu_list']
            gpu_num = len(gpu_list)
            # Use the minimum free memory across the requested GPUs.
            per_gpu_memory = None
            for gpu_id in gpu_list:
                if per_gpu_memory is None:
                    per_gpu_memory = machine_info['gpu_free_mem'][gpu_id]
                elif machine_info['gpu_free_mem'][gpu_id] < per_gpu_memory:
                    per_gpu_memory = machine_info['gpu_free_mem'][gpu_id]
        else:
            gpu_num = 1
            per_gpu_memory = machine_info['gpu_free_mem'][0]
            gpu_list = [0]
    params = get_params(data, project_type, train_num, class_num, gpu_num,
                        per_gpu_memory, gpu_list)
    return {"status": 1, "train": params}
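
# Usage sketch (hypothetical machine_info; the field values are assumptions):
# with 'gpu_list' given, per_gpu_memory becomes the minimum free memory over
# the requested GPUs, so the recommended settings fit the tightest card.
#
#   machine_info = {"gpu_num": 2, "gpu_free_mem": [10000, 6000]}
#   resp = get_default_params({"pid": "P0001", "gpu_list": [0, 1]},
#                             workspace, machine_info)
#   # -> get_params(...) is called with gpu_num=2, per_gpu_memory=6000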


def get_task_status(data, workspace):
    """Get the status of a task.

    Args:
        data is a dict whose keys include
        'tid': the task id; 'resume' (optional): also report whether
        training can be resumed.
    """
    from .operate import get_task_status, get_task_max_saved_epochs
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    path = workspace.tasks[tid].path
    status, message = get_task_status(path)
    task_pid = workspace.tasks[tid].pid
    task_type = workspace.projects[task_pid].type
    if 'resume' in data:
        max_saved_epochs = get_task_max_saved_epochs(path)
        params = {'tid': tid}
        results = get_task_params(params, workspace)
        total_epochs = results['train'].num_epochs
        resumable = max_saved_epochs > 0 and max_saved_epochs < total_epochs
        return {
            'status': 1,
            'task_status': status.value,
            'message': message,
            'resumable': resumable,
            'max_saved_epochs': max_saved_epochs,
            'type': task_type
        }
    return {
        'status': 1,
        'task_status': status.value,
        'message': message,
        'type': task_type
    }
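
# Usage sketch (hypothetical): querying whether a stopped task can be
# resumed; the presence of the 'resume' key enables the extra fields.
#
#   resp = get_task_status({"tid": "T0001", "resume": True}, workspace)
#   if resp["resumable"]:
#       resume_train_task({"tid": "T0001",
#                          "epoch": resp["max_saved_epochs"]},
#                         workspace, monitored_processes)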


def get_train_metrics(data, workspace):
    """Get the training log of a task.

    Args:
        data is a dict whose keys include
        'tid': the task id.
    Return:
        train_log (dict): 'eta': time remaining; 'train_metrics': training
        metrics; 'eval_metrics': evaluation metrics; 'download_status':
        model download status; 'eval_done': whether a model has been saved;
        'train_error': the reason training failed.
    """
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    from ..utils import TrainLogReader
    task_path = workspace.tasks[tid].path
    log_file = osp.join(task_path, 'out.log')
    train_log = TrainLogReader(log_file)
    train_log.update()
    train_log = train_log.__dict__
    return {'status': 1, 'train_log': train_log}


def get_eval_metrics(data, workspace):
    """Get the evaluation metrics of the best saved model.

    Args:
        data is a dict whose keys include
        'tid': the parent task id.
    """
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    best_model_path = osp.join(workspace.tasks[tid].path, "output",
                               "best_model", "model.yml")
    import yaml
    with open(best_model_path, "r", encoding="utf-8") as f:
        eval_metrics = yaml.load(f, Loader=yaml.FullLoader)['_Attributes']['eval_metrics']
    return {'status': 1, 'eval_metric': eval_metrics}


def get_eval_all_metrics(data, workspace):
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    output_dir = osp.join(workspace.tasks[tid].path, "output")
    epoch_result_dict = dict()
    best_epoch = -1
    best_result = -1
    # Guard against the case where no finished epoch directory exists.
    key = None
    import yaml
    for file_dir in os.listdir(output_dir):
        if file_dir.startswith("epoch"):
            epoch_dir = osp.join(output_dir, file_dir)
            if osp.exists(osp.join(epoch_dir, ".success")):
                epoch_index = int(file_dir.split('_')[-1])
                yml_file_path = osp.join(epoch_dir, "model.yml")
                with open(yml_file_path, 'r', encoding='utf-8') as f:
                    yml_file = yaml.load(f.read(), Loader=yaml.FullLoader)
                result = yml_file["_Attributes"]["eval_metrics"]
                key = list(result.keys())[0]
                value = result[key]
                if value > best_result:
                    best_result = value
                    best_epoch = epoch_index
                elif value == best_result:
                    if epoch_index < best_epoch:
                        best_epoch = epoch_index
                epoch_result_dict[epoch_index] = value
    return {
        'status': 1,
        'key': key,
        'epoch_result_dict': epoch_result_dict,
        'best_epoch': best_epoch,
        'best_result': best_result
    }
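
# Shape of the returned value (illustrative numbers, assuming a detection
# task whose first eval metric is 'bbox_map'):
#
#   {"status": 1, "key": "bbox_map",
#    "epoch_result_dict": {2: 41.3, 4: 47.9, 6: 47.9},
#    "best_epoch": 4, "best_result": 47.9}
#
# Ties are broken toward the earlier epoch, as implemented above.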


def get_sensitivities_loss_img(data, workspace):
    """Get the plot of sensitivity versus model pruning ratio.

    Args:
        data is a dict whose keys include
        'tid': the task id.
    """
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    task_path = workspace.tasks[tid].path
    pkl_path = osp.join(task_path, 'prune', 'sensitivities_xy.pkl')
    with open(pkl_path, 'rb') as f:
        sensitivities_xy = pickle.load(f)
    return {'status': 1, 'sensitivities_loss_img': sensitivities_xy}


def start_train_task(data, workspace, monitored_processes):
    """Start a training task.

    Args:
        data is a dict whose keys include
        'tid': the task id; 'eval_metric_loss' (optional): the evaluation
        loss required by a pruning task.
    """
    from .operate import train_model
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    path = workspace.tasks[tid].path
    if 'eval_metric_loss' in data and \
            data['eval_metric_loss'] is not None:
        # Pruning task.
        parent_id = workspace.tasks[tid].parent_id
        assert parent_id != "", "Task {} is not a pruning-training task.".format(tid)
        parent_path = workspace.tasks[parent_id].path
        sensitivities_path = osp.join(parent_path, 'prune',
                                      'sensitivities.data')
        eval_metric_loss = data['eval_metric_loss']
        parent_best_model_path = osp.join(parent_path, 'output', 'best_model')
        params_conf_file = osp.join(path, 'params.pkl')
        with open(params_conf_file, 'rb') as f:
            params = pickle.load(f)
        params['train'].sensitivities_path = sensitivities_path
        params['train'].eval_metric_loss = eval_metric_loss
        params['train'].pretrain_weights = parent_best_model_path
        with open(params_conf_file, 'wb') as f:
            pickle.dump(params, f)
    p = train_model(path)
    monitored_processes.put(p.pid)
    return {'status': 1}
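
# Usage sketch (hypothetical): a pruning-training task must have been
# created with 'parent_id' pointing at a finished task whose pruning
# analysis produced prune/sensitivities.data; the ids and the loss value
# below are assumptions.
#
#   start_train_task({"tid": "T0002", "eval_metric_loss": 0.05},
#                    workspace, monitored_processes)
#   # Without 'eval_metric_loss', the same call starts a normal training run.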


def resume_train_task(data, workspace, monitored_processes):
    """Resume a training task.

    Args:
        data is a dict whose keys include
        'tid': the task id; 'epoch': the epoch from which to resume training.
    """
    from .operate import train_model
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    path = workspace.tasks[tid].path
    epoch_path = "epoch_" + str(data['epoch'])
    resume_checkpoint_path = osp.join(path, "output", epoch_path)
    params_conf_file = osp.join(path, 'params.pkl')
    with open(params_conf_file, 'rb') as f:
        params = pickle.load(f)
    params['train'].resume_checkpoint = resume_checkpoint_path
    with open(params_conf_file, 'wb') as f:
        pickle.dump(params, f)
    p = train_model(path)
    monitored_processes.put(p.pid)
    return {'status': 1}


def stop_train_task(data, workspace):
    """Stop a training task.

    Args:
        data is a dict whose keys include
        'tid': the task id.
    """
    from .operate import stop_train_model
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    path = workspace.tasks[tid].path
    stop_train_model(path)
    return {'status': 1}


def start_prune_analysis(data, workspace, monitored_processes):
    """Start the model pruning analysis.

    Args:
        data is a dict whose keys include
        'tid': the task id.
    """
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    task_path = workspace.tasks[tid].path
    from .operate import prune_analysis_model
    p = prune_analysis_model(task_path)
    monitored_processes.put(p.pid)
    return {'status': 1}
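
# Usage sketch (hypothetical workflow): run the analysis on a trained task,
# then poll its progress and status with the readers defined below.
#
#   start_prune_analysis({"tid": "T0001"}, workspace, monitored_processes)
#   log = get_prune_metrics({"tid": "T0001"}, workspace)["prune_log"]
#   state = get_prune_status({"tid": "T0001"}, workspace)["prune_status"]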


def get_prune_metrics(data, workspace):
    """Get the log of the model pruning analysis.

    Args:
        data is a dict whose keys include
        'tid': the task id.
    Return:
        prune_log (dict): 'eta': time remaining; 'iters': total number of
        pruning iterations; 'current': the current iteration; 'progress':
        pruning progress.
    """
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    from ..utils import PruneLogReader
    task_path = workspace.tasks[tid].path
    log_file = osp.join(task_path, 'prune', 'out.log')
    # The pruning analysis may not have started yet; report an empty log
    # instead of failing.
    if not osp.exists(log_file):
        return {'status': 1, 'prune_log': None}
    prune_log = PruneLogReader(log_file)
    prune_log.update()
    prune_log = prune_log.__dict__
    return {'status': 1, 'prune_log': prune_log}


def get_prune_status(data, workspace):
    """Get the status of the model pruning analysis.

    Args:
        data is a dict whose keys include
        'tid': the task id.
    """
    from .operate import get_prune_status
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    path = workspace.tasks[tid].path
    prune_path = osp.join(path, "prune")
    status, message = get_prune_status(prune_path)
    if status is not None:
        status = status.value
    return {'status': 1, 'prune_status': status, 'message': message}


def stop_prune_analysis(data, workspace):
    """Stop the model pruning analysis.

    Args:
        data is a dict whose keys include
        'tid': the task id.
    """
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    from .operate import stop_prune_analysis
    prune_path = osp.join(workspace.tasks[tid].path, 'prune')
    stop_prune_analysis(prune_path)
    return {'status': 1}


def evaluate_model(data, workspace, monitored_processes):
    """Evaluate a model.

    Args:
        data is a dict whose keys include
        'tid': the task id, plus the parameters required for evaluation:
        'epoch', 'topk', 'score_thresh' and 'overlap_thresh'.
    Return:
        None
    """
    from .operate import evaluate_model
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    pid = workspace.tasks[tid].pid
    assert pid in workspace.projects, "Project ID '{}' does not exist.".format(pid)
    path = workspace.tasks[tid].path
    type = workspace.projects[pid].type
    p = evaluate_model(path, type, data['epoch'], data['topk'],
                       data['score_thresh'], data['overlap_thresh'])
    monitored_processes.put(p.pid)
    return {'status': 1}


def get_evaluate_result(data, workspace):
    """Get the evaluation result.

    Args:
        data is a dict whose keys include
        'tid': the task id.
    Return:
        a dict containing the evaluation metrics.
    """
    from .operate import get_evaluate_status
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    task_path = workspace.tasks[tid].path
    status, message = get_evaluate_status(task_path)
    if status == TaskStatus.XEVALUATED:
        result_file = osp.join(task_path, 'eval_res.pkl')
        if os.path.exists(result_file):
            with open(result_file, "rb") as f:
                result = pickle.load(f)
            return {
                'status': 1,
                'evaluate_status': status,
                'message': "Task {} evaluated successfully.".format(tid),
                'path': result_file,
                'result': result
            }
        else:
            return {
                'status': -1,
                'evaluate_status': status,
                'message': "The evaluation result is missing; please evaluate again!",
                'result': None
            }
    if status == TaskStatus.XEVALUATEFAIL:
        return {
            'status': -1,
            'evaluate_status': status,
            'message': "Evaluation failed; please evaluate again!",
            'result': None
        }
    return {
        'status': 1,
        'evaluate_status': status,
        'message': "Task {} is still being evaluated; please wait.".format(tid),
        'result': None
    }
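
# Usage sketch (hypothetical polling loop); note that 'evaluate_status' is
# the raw TaskStatus enum member, not its .value:
#
#   import time
#   while True:
#       resp = get_evaluate_result({"tid": "T0001"}, workspace)
#       if resp["evaluate_status"] in (TaskStatus.XEVALUATED,
#                                      TaskStatus.XEVALUATEFAIL):
#           break
#       time.sleep(1)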


def import_evaluate_excel(data, result, workspace):
    excel_ret = dict()
    workbook = xlwt.Workbook()
    labels = None
    START_ROW = 0
    sheet = workbook.add_sheet("Evaluation Report")
    if 'label_list' in result:
        labels = result['label_list']
    for k, v in result.items():
        if k == 'label_list':
            continue
        if type(v) == np.ndarray:
            # Write a matrix-valued metric as a table with labels on both axes.
            sheet.write(START_ROW + 0, 0, trans_name(k))
            sheet.write(START_ROW + 1, 1, trans_name("Class"))
            if labels is None:
                labels = ["{}".format(x) for x in range(len(v))]
            for i in range(len(labels)):
                sheet.write(START_ROW + 1, 2 + i, labels[i])
                sheet.write(START_ROW + 2 + i, 1, labels[i])
            for i in range(len(labels)):
                for j in range(len(labels)):
                    sheet.write(START_ROW + 2 + i, 2 + j, str(v[i, j]))
            START_ROW = (START_ROW + 4 + len(labels))
        if type(v) == dict:
            sheet.write(START_ROW + 0, 0, trans_name(k))
            multi_row = False
            Cols = ["Class"]
            # Inspect only the first item to decide the column layout.
            for k1, v1 in v.items():
                if type(v1) == dict:
                    multi_row = True
                    for sub_k, sub_v in v1.items():
                        Cols.append(sub_k)
                else:
                    Cols.append(k)
                break
            for i in range(len(Cols)):
                sheet.write(START_ROW + 1, 1 + i, trans_name(Cols[i]))
            index = 2
            for k1, v1 in v.items():
                sheet.write(START_ROW + index, 1, k1)
                if multi_row:
                    for sub_k, sub_v in v1.items():
                        sheet.write(START_ROW + index,
                                    Cols.index(sub_k) + 1, "nan"
                                    if (sub_v is None) or sub_v == -1 else
                                    "{:.4f}".format(sub_v))
                else:
                    sheet.write(START_ROW + index, 2, "{}".format(v1))
                index += 1
            START_ROW = (START_ROW + index + 2)
        # np.float was a plain alias of float and has been removed from
        # NumPy, so float alone covers it here.
        if type(v) in [float, np.float32, np.float64, type(None)]:
            front_str = "{}".format(trans_name(k))
            if k == "Acck":
                if "topk" in data:
                    front_str = front_str.format(data["topk"])
                else:
                    front_str = front_str.format(5)
            sheet.write(START_ROW + 0, 0, front_str)
            sheet.write(START_ROW + 1, 1, "{:.4f}".format(v)
                        if v is not None else "nan")
            START_ROW = (START_ROW + 2 + 2)
    tid = data['tid']
    path = workspace.tasks[tid].path
    final_save = os.path.join(path, 'report-task{}.xls'.format(tid))
    workbook.save(final_save)
    excel_ret['status'] = 1
    excel_ret['path'] = final_save
    excel_ret['message'] = "Successfully exported the results to Excel."
    return excel_ret


def get_predict_status(data, workspace):
    from .operate import get_predict_status
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    path = workspace.tasks[tid].path
    status, message, predict_num, total_num = get_predict_status(path)
    return {
        'status': 1,
        'predict_status': status.value,
        'message': message,
        'predict_num': predict_num,
        'total_num': total_num
    }


def predict_test_pics(data, workspace, monitored_processes):
    from .operate import predict_test_pics
    tid = data['tid']
    if 'img_list' in data:
        img_list = data['img_list']
    else:
        img_list = list()
    if 'image_data' in data:
        image_data = data['image_data']
    else:
        image_data = None
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    path = workspace.tasks[tid].path
    save_dir = data['save_dir'] if 'save_dir' in data else None
    epoch = data['epoch'] if 'epoch' in data else None
    score_thresh = data['score_thresh'] if 'score_thresh' in data else 0.5
    p, save_dir = predict_test_pics(
        path,
        save_dir=save_dir,
        img_list=img_list,
        img_data=image_data,
        score_thresh=score_thresh,
        epoch=epoch)
    monitored_processes.put(p.pid)
    if 'image_data' in data:
        path = osp.join(save_dir, 'predict_result.png')
    else:
        path = None
    return {'status': 1, 'path': path}
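
# Usage sketch (hypothetical): predicting a list of image files with the
# weights saved at a given epoch; the file path and values are assumptions.
#
#   resp = predict_test_pics({"tid": "T0001",
#                             "img_list": ["/data/imgs/1.jpg"],
#                             "score_thresh": 0.3, "epoch": 10},
#                            workspace, monitored_processes)
#   # resp["path"] is None here; it is only set for 'image_data' requests,
#   # where it points at <save_dir>/predict_result.png.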


def stop_predict_task(data, workspace):
    from .operate import stop_predict_task
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    path = workspace.tasks[tid].path
    status, message, predict_num, total_num = stop_predict_task(path)
    return {
        'status': 1,
        'predict_status': status.value,
        'message': message,
        'predict_num': predict_num,
        'total_num': total_num
    }


def get_quant_progress(data, workspace):
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    from ..utils import QuantLogReader
    export_path = osp.join(workspace.tasks[tid].path, "./logs/export")
    log_file = osp.join(export_path, 'out.log')
    quant_log = QuantLogReader(log_file)
    quant_log.update()
    quant_log = quant_log.__dict__
    return {'status': 1, 'quant_log': quant_log}


def get_quant_result(data, workspace):
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    export_path = osp.join(workspace.tasks[tid].path, "./logs/export")
    result_json = osp.join(export_path, 'quant_result.json')
    result = {}
    if osp.exists(result_json):
        with open(result_json, 'r', encoding=get_encoding(result_json)) as f:
            result = json.load(f)
    return {'status': 1, 'quant_result': result}


def get_export_status(data, workspace):
    """Get the export status.

    Args:
        data is a dict whose keys include
        'tid': the task id.
    Return:
        the current export status.
    """
    from .operate import get_export_status
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    task_path = workspace.tasks[tid].path
    status, message = get_export_status(task_path)
    if status == TaskStatus.XEXPORTED:
        return {
            'status': 1,
            'export_status': status,
            'message': "Congratulations, the model of task {} was exported successfully!".format(tid)
        }
    if status == TaskStatus.XEXPORTFAIL:
        return {
            'status': -1,
            'export_status': status,
            'message': "Failed to export the model of task {}; please retry!".format(tid)
        }
    return {
        'status': 1,
        'export_status': status,
        'message': "The model of task {} is being exported; please wait.".format(tid)
    }


def export_infer_model(data, workspace, monitored_processes):
    """Export a deployment (inference) model.

    Args:
        data is a dict whose keys include
        'tid': the task id; 'save_dir': the directory to save the exported
        model to.
    """
    from .operate import export_noquant_model, export_quant_model
    tid = data['tid']
    save_dir = data['save_dir']
    epoch = data['epoch'] if 'epoch' in data else None
    quant = data['quant'] if 'quant' in data else False
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    path = workspace.tasks[tid].path
    if quant:
        p = export_quant_model(path, save_dir, epoch)
    else:
        p = export_noquant_model(path, save_dir, epoch)
    monitored_processes.put(p.pid)
    return {'status': 1, 'save_dir': save_dir}
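
# Usage sketch (hypothetical): exporting the trained model for deployment
# and then polling the export status; the save directory is an assumption.
#
#   export_infer_model({"tid": "T0001", "save_dir": "/tmp/infer_model",
#                       "quant": False}, workspace, monitored_processes)
#   resp = get_export_status({"tid": "T0001"}, workspace)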


def export_lite_model(data, workspace):
    """Export a Paddle Lite model.

    Args:
        data is a dict whose keys include
        'tid': the task id; 'model_path': the inference model to convert;
        'save_dir': the directory to save the exported model to.
    """
    from .operate import opt_lite_model
    model_path = data['model_path']
    save_dir = data['save_dir']
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    opt_lite_model(model_path, save_dir)
    if not osp.exists(osp.join(save_dir, "model.nb")):
        if osp.exists(save_dir):
            shutil.rmtree(save_dir)
        return {'status': -1, 'message': "Failed to export the Lite model."}
    return {'status': 1, 'message': "Done."}


def stop_export_task(data, workspace):
    """Stop an export task.

    Args:
        data is a dict whose keys include
        'tid': the task id.
    Return:
        the current export status.
    """
    from .operate import stop_export_task
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    task_path = workspace.tasks[tid].path
    status, message = stop_export_task(task_path)
    return {'status': 1, 'export_status': status.value, 'message': message}


def _open_vdl(logdir, current_port):
    from visualdl.server import app
    app.run(logdir=logdir, host='0.0.0.0', port=current_port)


def open_vdl(data, workspace, current_port, monitored_processes,
             running_boards):
    """Open the VisualDL page.

    Args:
        data is a dict whose keys include
        'tid': the task id.
    """
    tid = data['tid']
    assert tid in workspace.tasks, "Task ID '{}' does not exist.".format(tid)
    ip = get_ip()
    if tid in running_boards:
        url = ip + ":{}".format(running_boards[tid][0])
        return {'status': 1, 'url': url}
    task_path = workspace.tasks[tid].path
    logdir = osp.join(task_path, 'output', 'vdl_log')
    assert osp.exists(logdir), "The task has not produced any log files yet."
    port_available = is_available(ip, current_port)
    while not port_available:
        current_port += 1
        port_available = is_available(ip, current_port)
        assert current_port <= 8500, "No available port could be found."
    p = mp.Process(target=_open_vdl, args=(logdir, current_port))
    p.start()
    monitored_processes.put(p.pid)
    url = ip + ":{}".format(current_port)
    running_boards[tid] = [current_port, p.pid]
    current_port += 1
    total_time = 0
    # Wait until the VisualDL server starts listening on the chosen port.
    while True:
        if not is_available(ip, current_port - 1):
            break
        time.sleep(0.5)
        total_time += 0.5
        assert total_time <= 8, "Timed out starting the VisualDL service; please try opening it again."
    return {'status': 1, 'url': url}
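
# Usage sketch (hypothetical): `current_port` is the first port to probe and
# `running_boards` caches one VisualDL process per task; the values below
# are assumptions.
#
#   running_boards = {}
#   resp = open_vdl({"tid": "T0001"}, workspace, 8040,
#                   monitored_processes, running_boards)
#   print(resp["url"])  # e.g. "192.168.1.5:8040"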