task.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797
  1. # copytrue (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from .. import workspace_pb2 as w
  15. import os
  16. import os.path as osp
  17. import shutil
  18. import time
  19. import pickle
  20. import json
  21. import multiprocessing as mp
  22. from ..utils import set_folder_status, TaskStatus, get_folder_status, is_available, get_ip
  23. from .train.params import ClsParams, DetParams, SegParams
  24. def create_task(data, workspace):
  25. """根据request创建task。
  26. Args:
  27. data为dict,key包括
  28. 'pid'所属项目id, 'train'训练参数。训练参数和数据增强参数以pickle的形式保存
  29. 在任务目录下的params.pkl文件中。 'parent_id'(可选)该裁剪训练任务的父任务,
  30. 'desc'(可选)任务描述。
  31. """
  32. create_time = time.time()
  33. time_array = time.localtime(create_time)
  34. create_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array)
  35. id = workspace.max_task_id + 1
  36. workspace.max_task_id = id
  37. if id < 10000:
  38. id = 'T%04d' % id
  39. else:
  40. id = 'T{}'.format(id)
  41. pid = data['pid']
  42. assert pid in workspace.projects, "【任务创建】项目ID'{}'不存在.".format(pid)
  43. assert not id in workspace.tasks, "【任务创建】任务ID'{}'已经被占用.".format(id)
  44. did = workspace.projects[pid].did
  45. assert did in workspace.datasets, "【任务创建】数据集ID'{}'不存在".format(did)
  46. path = osp.join(workspace.projects[pid].path, id)
  47. if not osp.exists(path):
  48. os.makedirs(path)
  49. set_folder_status(path, TaskStatus.XINIT)
  50. data['task_type'] = workspace.projects[pid].type
  51. data['dataset_path'] = workspace.datasets[did].path
  52. data['pretrain_weights_download_save_dir'] = osp.join(workspace.path,
  53. 'pretrain')
  54. #获取参数
  55. if 'train' in data:
  56. params_json = json.loads(data['train'])
  57. if (data['task_type'] == 'classification'):
  58. params_init = ClsParams()
  59. if (data['task_type'] == 'detection' or
  60. data['task_type'] == 'instance_segmentation'):
  61. params_init = DetParams()
  62. if (data['task_type'] == 'segmentation' or
  63. data['task_type'] == 'remote_segmentation'):
  64. params_init = SegParams()
  65. params_init.load_from_dict(params_json)
  66. data['train'] = params_init
  67. parent_id = ''
  68. if 'parent_id' in data:
  69. data['tid'] = data['parent_id']
  70. parent_id = data['parent_id']
  71. assert data['parent_id'] in workspace.tasks, "【任务创建】裁剪任务创建失败".format(
  72. data['parent_id'])
  73. r = get_task_params(data, workspace)
  74. train_params = r['train']
  75. data['train'] = train_params
  76. desc = ""
  77. if 'desc' in data:
  78. desc = data['desc']
  79. with open(osp.join(path, 'params.pkl'), 'wb') as f:
  80. pickle.dump(data, f)
  81. task = w.Task(
  82. id=id,
  83. pid=pid,
  84. path=path,
  85. create_time=create_time,
  86. parent_id=parent_id,
  87. desc=desc)
  88. workspace.tasks[id].CopyFrom(task)
  89. with open(os.path.join(path, 'info.pb'), 'wb') as f:
  90. f.write(task.SerializeToString())
  91. return {'status': 1, 'tid': id}
  92. def delete_task(data, workspace):
  93. """删除task。
  94. Args:
  95. data为dict,key包括
  96. 'tid'任务id
  97. """
  98. task_id = data['tid']
  99. assert task_id in workspace.tasks, "任务ID'{}'不存在.".format(task_id)
  100. if osp.exists(workspace.tasks[task_id].path):
  101. shutil.rmtree(workspace.tasks[task_id].path)
  102. del workspace.tasks[task_id]
  103. return {'status': 1}
  104. def get_task_params(data, workspace):
  105. """根据request获取task的参数。
  106. Args:
  107. data为dict,key包括
  108. 'tid'任务id
  109. """
  110. tid = data['tid']
  111. assert tid in workspace.tasks, "【任务创建】任务ID'{}'不存在.".format(tid)
  112. path = workspace.tasks[tid].path
  113. with open(osp.join(path, 'params.pkl'), 'rb') as f:
  114. task_params = pickle.load(f)
  115. return {'status': 1, 'train': task_params['train']}
  116. def list_tasks(data, workspace):
  117. '''列出任务列表,可request的参数进行筛选
  118. Args:
  119. data为dict, 包括
  120. 'pid'(可选)所属项目id
  121. '''
  122. task_list = list()
  123. for key in workspace.tasks:
  124. task_id = workspace.tasks[key].id
  125. task_name = workspace.tasks[key].name
  126. task_desc = workspace.tasks[key].desc
  127. task_pid = workspace.tasks[key].pid
  128. task_path = workspace.tasks[key].path
  129. task_create_time = workspace.tasks[key].create_time
  130. from .operate import get_task_status
  131. path = workspace.tasks[task_id].path
  132. status, message = get_task_status(path)
  133. if data is not None:
  134. if "pid" in data:
  135. if data["pid"] != task_pid:
  136. continue
  137. attr = {
  138. "id": task_id,
  139. "name": task_name,
  140. "desc": task_desc,
  141. "pid": task_pid,
  142. "path": task_path,
  143. "create_time": task_create_time,
  144. "status": status.value
  145. }
  146. task_list.append(attr)
  147. return {'status': 1, 'tasks': task_list}
  148. def set_task_params(data, workspace):
  149. """根据request设置task的参数。只有在task是TaskStatus.XINIT状态时才有效
  150. Args:
  151. data为dict,key包括
  152. 'tid'任务id, 'train'训练参数. 训练
  153. 参数和数据增强参数以pickle的形式保存在任务目录下的params.pkl文件
  154. 中。
  155. """
  156. tid = data['tid']
  157. train = data['train']
  158. assert tid in workspace.tasks, "【任务创建】任务ID'{}'不存在.".format(tid)
  159. path = workspace.tasks[tid].path
  160. status = get_folder_status(path)
  161. assert status == TaskStatus.XINIT, "该任务不在初始化阶段,设置参数失败"
  162. with open(osp.join(path, 'params.pkl'), 'rb') as f:
  163. task_params = pickle.load(f)
  164. train_json = json.loads(train)
  165. task_params['train'].load_from_dict(train_json)
  166. with open(osp.join(path, 'params.pkl'), 'wb') as f:
  167. pickle.dump(task_params, f)
  168. return {'status': 1}
  169. def get_default_params(data, workspace, machine_info):
  170. from .train.params_v2 import get_params
  171. from ..dataset.dataset import get_dataset_details
  172. pid = data['pid']
  173. assert pid in workspace.projects, "项目ID{}不存在.".format(pid)
  174. project_type = workspace.projects[pid].type
  175. did = workspace.projects[pid].did
  176. result = get_dataset_details({'did': did}, workspace)
  177. if result['status'] == 1:
  178. details = result['details']
  179. else:
  180. raise Exception("Fail to get dataset details!")
  181. train_num = len(details['train_files'])
  182. class_num = len(details['labels'])
  183. if machine_info['gpu_num'] == 0:
  184. gpu_num = 0
  185. per_gpu_memory = 0
  186. gpu_list = None
  187. else:
  188. if gpu_list in data:
  189. gpu_list = data['gpu_list']
  190. gpu_num = len(gpu_list)
  191. per_gpu_memory = None
  192. for gpu_id in gpu_list:
  193. if per_gpu_memory is None:
  194. per_gpu_memory = machine_info['gpu_free_mem'][gpu_id]
  195. elif machine_info['gpu_free_mem'][gpu_id] < per_gpu_memory:
  196. per_gpu_memory = machine_info['gpu_free_mem'][gpu_id]
  197. else:
  198. gpu_num = 1
  199. per_gpu_memory = machine_info['gpu_free_mem'][0]
  200. gpu_list = [0]
  201. params = get_params(data, project_type, train_num, class_num, gpu_num,
  202. per_gpu_memory, gpu_list)
  203. return {"status": 1, "train": params}
def get_task_params(data, workspace):
    # NOTE(review): this re-defines get_task_params and shadows the
    # identical definition earlier in this module — consider removing
    # one of the two copies.
    """Read a task's stored training params from its params.pkl.

    Args:
        data (dict): key 'tid' (task id).

    Returns:
        dict: {'status': 1, 'train': <pickled training params>}
    """
    tid = data['tid']
    assert tid in workspace.tasks, "【任务创建】任务ID'{}'不存在.".format(tid)
    path = workspace.tasks[tid].path
    with open(osp.join(path, 'params.pkl'), 'rb') as f:
        task_params = pickle.load(f)
    return {'status': 1, 'train': task_params['train']}
  216. def get_task_status(data, workspace):
  217. """ 获取任务状态
  218. Args:
  219. data为dict, key包括
  220. 'tid'任务id, 'resume'(可选):获取是否可以恢复训练的状态
  221. """
  222. from .operate import get_task_status, get_task_max_saved_epochs
  223. tid = data['tid']
  224. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  225. path = workspace.tasks[tid].path
  226. status, message = get_task_status(path)
  227. if 'resume' in data:
  228. max_saved_epochs = get_task_max_saved_epochs(path)
  229. params = {'tid': tid}
  230. results = get_task_params(params, workspace)
  231. total_epochs = results['train'].num_epochs
  232. resumable = max_saved_epochs > 0 and max_saved_epochs < total_epochs
  233. return {
  234. 'status': 1,
  235. 'task_status': status.value,
  236. 'message': message,
  237. 'resumable': resumable,
  238. 'max_saved_epochs': max_saved_epochs
  239. }
  240. return {'status': 1, 'task_status': status.value, 'message': message}
  241. def get_train_metrics(data, workspace):
  242. """ 获取任务日志
  243. Args:
  244. data为dict, key包括
  245. 'tid'任务id
  246. Return:
  247. train_log(dict): 'eta':剩余时间,'train_metrics': 训练指标,'eval_metircs': 评估指标,
  248. 'download_status': 下载模型状态,'eval_done': 是否已保存模型,'train_error': 训练错误原因
  249. """
  250. tid = data['tid']
  251. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  252. from ..utils import TrainLogReader
  253. task_path = workspace.tasks[tid].path
  254. log_file = osp.join(task_path, 'out.log')
  255. train_log = TrainLogReader(log_file)
  256. train_log.update()
  257. train_log = train_log.__dict__
  258. return {'status': 1, 'train_log': train_log}
  259. def get_eval_metrics(data, workspace):
  260. """ 获取任务日志
  261. Args:
  262. data为dict, key包括
  263. 'tid'父任务id
  264. """
  265. tid = data['tid']
  266. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  267. best_model_path = osp.join(workspace.tasks[tid].path, "output",
  268. "best_model", "model.yml")
  269. import yaml
  270. f = open(best_model_path, "r", encoding="utf-8")
  271. eval_metrics = yaml.load(f)['_Attributes']['eval_metrics']
  272. f.close()
  273. return {'status': 1, 'eval_metric': eval_metrics}
  274. def get_eval_all_metrics(data, workspace):
  275. tid = data['tid']
  276. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  277. output_dir = osp.join(workspace.tasks[tid].path, "output")
  278. epoch_result_dict = dict()
  279. best_epoch = -1
  280. best_result = -1
  281. import yaml
  282. for file_dir in os.listdir(output_dir):
  283. if file_dir.startswith("epoch"):
  284. epoch_dir = osp.join(output_dir, file_dir)
  285. if osp.exists(osp.join(epoch_dir, ".success")):
  286. epoch_index = int(file_dir.split('_')[-1])
  287. yml_file_path = osp.join(epoch_dir, "model.yml")
  288. f = open(yml_file_path, 'r', encoding='utf-8')
  289. yml_file = yaml.load(f.read())
  290. result = yml_file["_Attributes"]["eval_metrics"]
  291. key = list(result.keys())[0]
  292. value = result[key]
  293. if value > best_result:
  294. best_result = value
  295. best_epoch = epoch_index
  296. elif value == best_result:
  297. if epoch_index < best_epoch:
  298. best_epoch = epoch_index
  299. epoch_result_dict[epoch_index] = value
  300. return {
  301. 'status': 1,
  302. 'key': key,
  303. 'epoch_result_dict': epoch_result_dict,
  304. 'best_epoch': best_epoch,
  305. 'best_result': best_result
  306. }
  307. def get_sensitivities_loss_img(data, workspace):
  308. """ 获取敏感度与模型裁剪率关系图
  309. Args:
  310. data为dict, key包括
  311. 'tid'任务id
  312. """
  313. tid = data['tid']
  314. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  315. task_path = workspace.tasks[tid].path
  316. pkl_path = osp.join(task_path, 'prune', 'sensitivities_xy.pkl')
  317. import pickle
  318. f = open(pkl_path, 'rb')
  319. sensitivities_xy = pickle.load(f)
  320. return {'status': 1, 'sensitivities_loss_img': sensitivities_xy}
def start_train_task(data, workspace, monitored_processes):
    """Start a training task.

    Args:
        data (dict): key 'tid' (task id); optional key 'eval_metric_loss'
            (acceptable eval-metric loss, present for pruning tasks).
        workspace: global workspace object.
        monitored_processes: queue of child pids watched by the server.
    """
    from .operate import train_model
    tid = data['tid']
    assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
    path = workspace.tasks[tid].path
    if 'eval_metric_loss' in data and \
        data['eval_metric_loss'] is not None:
        # Pruning task: rewrite the stored params so training starts from
        # the parent's best model using the parent's sensitivity analysis.
        parent_id = workspace.tasks[tid].parent_id
        assert parent_id != "", "任务{}不是裁剪训练任务".format(tid)
        parent_path = workspace.tasks[parent_id].path
        sensitivities_path = osp.join(parent_path, 'prune',
                                      'sensitivities.data')
        eval_metric_loss = data['eval_metric_loss']
        parent_best_model_path = osp.join(parent_path, 'output', 'best_model')
        params_conf_file = osp.join(path, 'params.pkl')
        with open(params_conf_file, 'rb') as f:
            params = pickle.load(f)
        params['train'].sensitivities_path = sensitivities_path
        params['train'].eval_metric_loss = eval_metric_loss
        params['train'].pretrain_weights = parent_best_model_path
        # Persist the modified params before the subprocess reads them.
        with open(params_conf_file, 'wb') as f:
            pickle.dump(params, f)
    # train_model spawns a subprocess; register its pid for monitoring.
    p = train_model(path)
    monitored_processes.put(p.pid)
    return {'status': 1}
  352. def resume_train_task(data, workspace, monitored_processes):
  353. """恢复训练任务
  354. Args:
  355. data为dict, key包括
  356. 'tid'任务id,'epoch'恢复训练的起始轮数
  357. """
  358. from .operate import train_model
  359. tid = data['tid']
  360. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  361. path = workspace.tasks[tid].path
  362. epoch_path = "epoch_" + str(data['epoch'])
  363. resume_checkpoint_path = osp.join(path, "output", epoch_path)
  364. params_conf_file = osp.join(path, 'params.pkl')
  365. with open(params_conf_file, 'rb') as f:
  366. params = pickle.load(f)
  367. params['train'].resume_checkpoint = resume_checkpoint_path
  368. with open(params_conf_file, 'wb') as f:
  369. pickle.dump(params, f)
  370. p = train_model(path)
  371. monitored_processes.put(p.pid)
  372. return {'status': 1}
  373. def stop_train_task(data, workspace):
  374. """停止训练任务
  375. Args:
  376. data为dict, key包括
  377. 'tid'任务id
  378. """
  379. from .operate import stop_train_model
  380. tid = data['tid']
  381. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  382. path = workspace.tasks[tid].path
  383. stop_train_model(path)
  384. return {'status': 1}
  385. def start_prune_analysis(data, workspace, monitored_processes):
  386. """开始模型裁剪分析
  387. Args:
  388. data为dict, key包括
  389. 'tid'任务id
  390. """
  391. tid = data['tid']
  392. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  393. task_path = workspace.tasks[tid].path
  394. from .operate import prune_analysis_model
  395. p = prune_analysis_model(task_path)
  396. monitored_processes.put(p.pid)
  397. return {'status': 1}
  398. def get_prune_metrics(data, workspace):
  399. """ 获取模型裁剪分析日志
  400. Args:
  401. data为dict, key包括
  402. 'tid'任务id
  403. Return:
  404. prune_log(dict): 'eta':剩余时间,'iters': 模型裁剪总轮数,'current': 当前轮数,
  405. 'progress': 模型裁剪进度
  406. """
  407. tid = data['tid']
  408. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  409. from ..utils import PruneLogReader
  410. task_path = workspace.tasks[tid].path
  411. log_file = osp.join(task_path, 'prune', 'out.log')
  412. # assert osp.exists(log_file), "模型裁剪分析任务还未开始,请稍等"
  413. if not osp.exists(log_file):
  414. return {'status': 1, 'prune_log': None}
  415. prune_log = PruneLogReader(log_file)
  416. prune_log.update()
  417. prune_log = prune_log.__dict__
  418. return {'status': 1, 'prune_log': prune_log}
  419. def get_prune_status(data, workspace):
  420. """ 获取模型裁剪状态
  421. Args:
  422. data为dict, key包括
  423. 'tid'任务id
  424. """
  425. from .operate import get_prune_status
  426. tid = data['tid']
  427. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  428. path = workspace.tasks[tid].path
  429. prune_path = osp.join(path, "prune")
  430. status, message = get_prune_status(prune_path)
  431. if status is not None:
  432. status = status.value
  433. return {'status': 1, 'prune_status': status, 'message': message}
  434. def stop_prune_analysis(data, workspace):
  435. """停止模型裁剪分析
  436. Args:
  437. data为dict, key包括
  438. 'tid'任务id
  439. """
  440. tid = data['tid']
  441. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  442. from .operate import stop_prune_analysis
  443. prune_path = osp.join(workspace.tasks[tid].path, 'prune')
  444. stop_prune_analysis(prune_path)
  445. return {'status': 1}
  446. def evaluate_model(data, workspace, monitored_processes):
  447. """ 模型评估
  448. Args:
  449. data为dict, key包括
  450. 'tid'任务id, topk, score_thresh, overlap_thresh这些评估所需参数
  451. Return:
  452. None
  453. """
  454. from .operate import evaluate_model
  455. tid = data['tid']
  456. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  457. pid = workspace.tasks[tid].pid
  458. assert pid in workspace.projects, "项目ID'{}'不存在".format(pid)
  459. path = workspace.tasks[tid].path
  460. type = workspace.projects[pid].type
  461. p = evaluate_model(path, type, data['epoch'], data['topk'],
  462. data['score_thresh'], data['overlap_thresh'])
  463. monitored_processes.put(p.pid)
  464. return {'status': 1}
  465. def get_evaluate_result(data, workspace):
  466. """ 获评估结果
  467. Args:
  468. data为dict, key包括
  469. 'tid'任务id
  470. Return:
  471. 包含评估指标的dict
  472. """
  473. from .operate import get_evaluate_status
  474. tid = data['tid']
  475. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  476. task_path = workspace.tasks[tid].path
  477. status, message = get_evaluate_status(task_path)
  478. if status == TaskStatus.XEVALUATED:
  479. result_file = osp.join(task_path, 'eval_res.pkl')
  480. if os.path.exists(result_file):
  481. result = pickle.load(open(result_file, "rb"))
  482. return {
  483. 'status': 1,
  484. 'evaluate_status': status,
  485. 'message': "{}评估完成".format(tid),
  486. 'path': result_file,
  487. 'result': result
  488. }
  489. else:
  490. return {
  491. 'status': -1,
  492. 'evaluate_status': status,
  493. 'message': "评估结果丢失,建议重新评估!",
  494. 'result': None
  495. }
  496. if status == TaskStatus.XEVALUATEFAIL:
  497. return {
  498. 'status': -1,
  499. 'evaluate_status': status,
  500. 'message': "评估失败,请重新评估!",
  501. 'result': None
  502. }
  503. return {
  504. 'status': 1,
  505. 'evaluate_status': status,
  506. 'message': "{}正在评估中,请稍后!".format(tid),
  507. 'result': None
  508. }
  509. def get_predict_status(data, workspace):
  510. from .operate import get_predict_status
  511. tid = data['tid']
  512. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  513. path = workspace.tasks[tid].path
  514. status, message, predict_num, total_num = get_predict_status(path)
  515. return {
  516. 'status': 1,
  517. 'predict_status': status.value,
  518. 'message': message,
  519. 'predict_num': predict_num,
  520. 'total_num': total_num
  521. }
  522. def predict_test_pics(data, workspace, monitored_processes):
  523. from .operate import predict_test_pics
  524. tid = data['tid']
  525. if 'img_list' in data:
  526. img_list = data['img_list']
  527. else:
  528. img_list = list()
  529. if 'image_data' in data:
  530. image_data = data['image_data']
  531. else:
  532. image_data = None
  533. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  534. path = workspace.tasks[tid].path
  535. save_dir = data['save_dir'] if 'save_dir' in data else None
  536. epoch = data['epoch'] if 'epoch' in data else None
  537. score_thresh = data['score_thresh'] if 'score_thresh' in data else 0.5
  538. p, save_dir = predict_test_pics(
  539. path,
  540. save_dir=save_dir,
  541. img_list=img_list,
  542. img_data=image_data,
  543. score_thresh=score_thresh,
  544. epoch=epoch)
  545. monitored_processes.put(p.pid)
  546. if 'image_data' in data:
  547. path = osp.join(save_dir, 'predict_result.png')
  548. else:
  549. path = None
  550. return {'status': 1, 'path': path}
  551. def stop_predict_task(data, workspace):
  552. from .operate import stop_predict_task
  553. tid = data['tid']
  554. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  555. path = workspace.tasks[tid].path
  556. status, message, predict_num, total_num = stop_predict_task(path)
  557. return {
  558. 'status': 1,
  559. 'predict_status': status.value,
  560. 'message': message,
  561. 'predict_num': predict_num,
  562. 'total_num': total_num
  563. }
def get_quant_progress(data, workspace):
    """Read the quantization log of a task's export.

    Args:
        data (dict): key 'tid' (task id).

    Returns:
        dict: {'status': 1, 'quant_log': <QuantLogReader state as dict>}
    """
    tid = data['tid']
    assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
    from ..utils import QuantLogReader
    # Quantization logs live under <task>/logs/export/out.log.
    export_path = osp.join(workspace.tasks[tid].path, "./logs/export")
    log_file = osp.join(export_path, 'out.log')
    quant_log = QuantLogReader(log_file)
    quant_log.update()
    # Expose the reader's parsed state as a plain dict.
    quant_log = quant_log.__dict__
    return {'status': 1, 'quant_log': quant_log}
  574. def get_quant_result(data, workspace):
  575. tid = data['tid']
  576. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  577. export_path = osp.join(workspace.tasks[tid].path, "./logs/export")
  578. result_json = osp.join(export_path, 'quant_result.json')
  579. result = {}
  580. import json
  581. if osp.exists(result_json):
  582. with open(result_json, 'r') as f:
  583. result = json.load(f)
  584. return {'status': 1, 'quant_result': result}
  585. def get_export_status(data, workspace):
  586. """ 获取导出状态
  587. Args:
  588. data为dict, key包括
  589. 'tid'任务id
  590. Return:
  591. 目前导出状态.
  592. """
  593. from .operate import get_export_status
  594. tid = data['tid']
  595. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  596. task_path = workspace.tasks[tid].path
  597. status, message = get_export_status(task_path)
  598. if status == TaskStatus.XEXPORTED:
  599. return {
  600. 'status': 1,
  601. 'export_status': status,
  602. 'message': "恭喜您,{}任务模型导出成功!".format(tid)
  603. }
  604. if status == TaskStatus.XEXPORTFAIL:
  605. return {
  606. 'status': -1,
  607. 'export_status': status,
  608. 'message': "{}任务模型导出失败,请重试!".format(tid)
  609. }
  610. return {
  611. 'status': 1,
  612. 'export_status': status,
  613. 'message': "{}任务模型导出中,请稍等!".format(tid)
  614. }
  615. def export_infer_model(data, workspace, monitored_processes):
  616. """导出部署模型
  617. Args:
  618. data为dict,key包括
  619. 'tid'任务id, 'save_dir'导出模型保存路径
  620. """
  621. from .operate import export_noquant_model, export_quant_model
  622. tid = data['tid']
  623. save_dir = data['save_dir']
  624. epoch = data['epoch'] if 'epoch' in data else None
  625. quant = data['quant'] if 'quant' in data else False
  626. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  627. path = workspace.tasks[tid].path
  628. if quant:
  629. p = export_quant_model(path, save_dir, epoch)
  630. else:
  631. p = export_noquant_model(path, save_dir, epoch)
  632. monitored_processes.put(p.pid)
  633. return {'status': 1, 'save_dir': save_dir}
  634. def export_lite_model(data, workspace):
  635. """ 导出lite模型
  636. Args:
  637. data为dict, key包括
  638. 'tid'任务id, 'save_dir'导出模型保存路径
  639. """
  640. from .operate import opt_lite_model
  641. model_path = data['model_path']
  642. save_dir = data['save_dir']
  643. tid = data['tid']
  644. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  645. opt_lite_model(model_path, save_dir)
  646. if not osp.exists(osp.join(save_dir, "model.nb")):
  647. if osp.exists(save_dir):
  648. shutil.rmtree(save_dir)
  649. return {'status': -1, 'message': "导出为lite模型失败"}
  650. return {'status': 1, 'message': "完成"}
  651. def stop_export_task(data, workspace):
  652. """ 停止导出任务
  653. Args:
  654. data为dict, key包括
  655. 'tid'任务id
  656. Return:
  657. 目前导出的状态.
  658. """
  659. from .operate import stop_export_task
  660. tid = data['tid']
  661. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  662. task_path = workspace.tasks[tid].path
  663. status, message = stop_export_task(task_path)
  664. return {'status': 1, 'export_status': status.value, 'message': message}
def _open_vdl(logdir, current_port):
    # Runs in a child process: serve VisualDL for `logdir` on all
    # interfaces at `current_port` (app.run blocks).
    from visualdl.server import app
    app.run(logdir=logdir, host='0.0.0.0', port=current_port)
def open_vdl(data, workspace, current_port, monitored_processes,
             running_boards):
    """Open a VisualDL page for a task.

    Args:
        data (dict): key 'tid' (task id).
        workspace: global workspace object.
        current_port (int): first port to try for the VisualDL server.
        monitored_processes: queue of child pids watched by the server.
        running_boards (dict): tid -> [port, pid] of already-running boards.

    Returns:
        dict: {'status': 1, 'url': '<ip>:<port>'}
    """
    tid = data['tid']
    assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
    ip = get_ip()
    # Reuse an already-running board for this task if one exists.
    if tid in running_boards:
        url = ip + ":{}".format(running_boards[tid][0])
        return {'status': 1, 'url': url}
    task_path = workspace.tasks[tid].path
    logdir = osp.join(task_path, 'output', 'vdl_log')
    assert osp.exists(logdir), "该任务还未正常产生日志文件"
    # Scan upward for a free port, giving up beyond 8500.
    port_available = is_available(ip, current_port)
    while not port_available:
        current_port += 1
        port_available = is_available(ip, current_port)
        assert current_port <= 8500, "找不到可用的端口"
    p = mp.Process(target=_open_vdl, args=(logdir, current_port))
    p.start()
    monitored_processes.put(p.pid)
    url = ip + ":{}".format(current_port)
    running_boards[tid] = [current_port, p.pid]
    current_port += 1
    # Wait (up to ~8s) until the spawned server actually binds its port:
    # is_available(...) turning False means something is now listening on
    # current_port - 1 (the port the server was started on).
    total_time = 0
    while True:
        if not is_available(ip, current_port - 1):
            break
        print(current_port)
        time.sleep(0.5)
        total_time += 0.5
        assert total_time <= 8, "VisualDL服务启动超时,请重新尝试打开"
    return {'status': 1, 'url': url}