task.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from .. import workspace_pb2 as w
  15. import os
  16. import os.path as osp
  17. import shutil
  18. import time
  19. import pickle
  20. import json
  21. import multiprocessing as mp
  22. from ..utils import set_folder_status, TaskStatus, get_folder_status, is_available, get_ip
  23. from .train.params import ClsParams, DetParams, SegParams
  24. def create_task(data, workspace):
  25. """根据request创建task。
  26. Args:
  27. data为dict,key包括
  28. 'pid'所属项目id, 'train'训练参数。训练参数和数据增强参数以pickle的形式保存
  29. 在任务目录下的params.pkl文件中。 'parent_id'(可选)该裁剪训练任务的父任务,
  30. 'desc'(可选)任务描述。
  31. """
  32. create_time = time.time()
  33. time_array = time.localtime(create_time)
  34. create_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array)
  35. id = workspace.max_task_id + 1
  36. workspace.max_task_id = id
  37. if id < 10000:
  38. id = 'T%04d' % id
  39. else:
  40. id = 'T{}'.format(id)
  41. pid = data['pid']
  42. assert pid in workspace.projects, "【任务创建】项目ID'{}'不存在.".format(pid)
  43. assert not id in workspace.tasks, "【任务创建】任务ID'{}'已经被占用.".format(id)
  44. did = workspace.projects[pid].did
  45. assert did in workspace.datasets, "【任务创建】数据集ID'{}'不存在".format(did)
  46. path = osp.join(workspace.projects[pid].path, id)
  47. if not osp.exists(path):
  48. os.makedirs(path)
  49. set_folder_status(path, TaskStatus.XINIT)
  50. data['task_type'] = workspace.projects[pid].type
  51. data['dataset_path'] = workspace.datasets[did].path
  52. data['pretrain_weights_download_save_dir'] = osp.join(workspace.path,
  53. 'pretrain')
  54. #获取参数
  55. if 'train' in data:
  56. params_json = json.loads(data['train'])
  57. if (data['task_type'] == 'classification'):
  58. params_init = ClsParams()
  59. if (data['task_type'] == 'detection' or
  60. data['task_type'] == 'instance_segmentation'):
  61. params_init = DetParams()
  62. if (data['task_type'] == 'segmentation' or
  63. data['task_type'] == 'remote_segmentation'):
  64. params_init = SegParams()
  65. params_init.load_from_dict(params_json)
  66. data['train'] = params_init
  67. parent_id = ''
  68. if 'parent_id' in data:
  69. data['tid'] = data['parent_id']
  70. parent_id = data['parent_id']
  71. assert data['parent_id'] in workspace.tasks, "【任务创建】裁剪任务创建失败".format(
  72. data['parent_id'])
  73. r = get_task_params(data, workspace)
  74. train_params = r['train']
  75. data['train'] = train_params
  76. desc = ""
  77. if 'desc' in data:
  78. desc = data['desc']
  79. with open(osp.join(path, 'params.pkl'), 'wb') as f:
  80. pickle.dump(data, f)
  81. task = w.Task(
  82. id=id,
  83. pid=pid,
  84. path=path,
  85. create_time=create_time,
  86. parent_id=parent_id,
  87. desc=desc)
  88. workspace.tasks[id].CopyFrom(task)
  89. with open(os.path.join(path, 'info.pb'), 'wb') as f:
  90. f.write(task.SerializeToString())
  91. return {'status': 1, 'tid': id}
  92. def delete_task(data, workspace):
  93. """删除task。
  94. Args:
  95. data为dict,key包括
  96. 'tid'任务id
  97. """
  98. task_id = data['tid']
  99. assert task_id in workspace.tasks, "任务ID'{}'不存在.".format(task_id)
  100. if osp.exists(workspace.tasks[task_id].path):
  101. shutil.rmtree(workspace.tasks[task_id].path)
  102. del workspace.tasks[task_id]
  103. return {'status': 1}
  104. def get_task_params(data, workspace):
  105. """根据request获取task的参数。
  106. Args:
  107. data为dict,key包括
  108. 'tid'任务id
  109. """
  110. tid = data['tid']
  111. assert tid in workspace.tasks, "【任务创建】任务ID'{}'不存在.".format(tid)
  112. path = workspace.tasks[tid].path
  113. with open(osp.join(path, 'params.pkl'), 'rb') as f:
  114. task_params = pickle.load(f)
  115. return {'status': 1, 'train': task_params['train']}
  116. def list_tasks(data, workspace):
  117. '''列出任务列表,可request的参数进行筛选
  118. Args:
  119. data为dict, 包括
  120. 'pid'(可选)所属项目id
  121. '''
  122. task_list = list()
  123. for key in workspace.tasks:
  124. task_id = workspace.tasks[key].id
  125. task_name = workspace.tasks[key].name
  126. task_desc = workspace.tasks[key].desc
  127. task_pid = workspace.tasks[key].pid
  128. task_path = workspace.tasks[key].path
  129. task_create_time = workspace.tasks[key].create_time
  130. task_type = workspace.projects[task_pid].type
  131. from .operate import get_task_status
  132. path = workspace.tasks[task_id].path
  133. status, message = get_task_status(path)
  134. if data is not None:
  135. if "pid" in data:
  136. if data["pid"] != task_pid:
  137. continue
  138. attr = {
  139. "id": task_id,
  140. "name": task_name,
  141. "desc": task_desc,
  142. "pid": task_pid,
  143. "path": task_path,
  144. "create_time": task_create_time,
  145. "status": status.value,
  146. 'type': task_type
  147. }
  148. task_list.append(attr)
  149. return {'status': 1, 'tasks': task_list}
  150. def set_task_params(data, workspace):
  151. """根据request设置task的参数。只有在task是TaskStatus.XINIT状态时才有效
  152. Args:
  153. data为dict,key包括
  154. 'tid'任务id, 'train'训练参数. 训练
  155. 参数和数据增强参数以pickle的形式保存在任务目录下的params.pkl文件
  156. 中。
  157. """
  158. tid = data['tid']
  159. train = data['train']
  160. assert tid in workspace.tasks, "【任务创建】任务ID'{}'不存在.".format(tid)
  161. path = workspace.tasks[tid].path
  162. status = get_folder_status(path)
  163. assert status == TaskStatus.XINIT, "该任务不在初始化阶段,设置参数失败"
  164. with open(osp.join(path, 'params.pkl'), 'rb') as f:
  165. task_params = pickle.load(f)
  166. train_json = json.loads(train)
  167. task_params['train'].load_from_dict(train_json)
  168. with open(osp.join(path, 'params.pkl'), 'wb') as f:
  169. pickle.dump(task_params, f)
  170. return {'status': 1}
  171. def get_default_params(data, workspace, machine_info):
  172. from .train.params_v2 import get_params
  173. from ..dataset.dataset import get_dataset_details
  174. pid = data['pid']
  175. assert pid in workspace.projects, "项目ID{}不存在.".format(pid)
  176. project_type = workspace.projects[pid].type
  177. did = workspace.projects[pid].did
  178. result = get_dataset_details({'did': did}, workspace)
  179. if result['status'] == 1:
  180. details = result['details']
  181. else:
  182. raise Exception("Fail to get dataset details!")
  183. train_num = len(details['train_files'])
  184. class_num = len(details['labels'])
  185. if machine_info['gpu_num'] == 0:
  186. gpu_num = 0
  187. per_gpu_memory = 0
  188. gpu_list = None
  189. else:
  190. if gpu_list in data:
  191. gpu_list = data['gpu_list']
  192. gpu_num = len(gpu_list)
  193. per_gpu_memory = None
  194. for gpu_id in gpu_list:
  195. if per_gpu_memory is None:
  196. per_gpu_memory = machine_info['gpu_free_mem'][gpu_id]
  197. elif machine_info['gpu_free_mem'][gpu_id] < per_gpu_memory:
  198. per_gpu_memory = machine_info['gpu_free_mem'][gpu_id]
  199. else:
  200. gpu_num = 1
  201. per_gpu_memory = machine_info['gpu_free_mem'][0]
  202. gpu_list = [0]
  203. params = get_params(data, project_type, train_num, class_num, gpu_num,
  204. per_gpu_memory, gpu_list)
  205. return {"status": 1, "train": params}
def get_task_params(data, workspace):
    """Fetch the pickled training parameters of a task.

    NOTE(review): this is a byte-for-byte duplicate of the get_task_params
    defined earlier in this module; at import time this later definition
    silently shadows the earlier one. Consider removing one of the two.

    Args:
        data (dict): contains 'tid', the task id.
        workspace: workspace object holding the task registry.

    Returns:
        dict: {'status': 1, 'train': the stored training parameters}
    """
    tid = data['tid']
    assert tid in workspace.tasks, "【任务创建】任务ID'{}'不存在.".format(tid)
    path = workspace.tasks[tid].path
    with open(osp.join(path, 'params.pkl'), 'rb') as f:
        task_params = pickle.load(f)
    return {'status': 1, 'train': task_params['train']}
  218. def get_task_status(data, workspace):
  219. """ 获取任务状态
  220. Args:
  221. data为dict, key包括
  222. 'tid'任务id, 'resume'(可选):获取是否可以恢复训练的状态
  223. """
  224. from .operate import get_task_status, get_task_max_saved_epochs
  225. tid = data['tid']
  226. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  227. path = workspace.tasks[tid].path
  228. status, message = get_task_status(path)
  229. task_pid = workspace.tasks[tid].pid
  230. task_type = workspace.projects[task_pid].type
  231. if 'resume' in data:
  232. max_saved_epochs = get_task_max_saved_epochs(path)
  233. params = {'tid': tid}
  234. results = get_task_params(params, workspace)
  235. total_epochs = results['train'].num_epochs
  236. resumable = max_saved_epochs > 0 and max_saved_epochs < total_epochs
  237. return {
  238. 'status': 1,
  239. 'task_status': status.value,
  240. 'message': message,
  241. 'resumable': resumable,
  242. 'max_saved_epochs': max_saved_epochs,
  243. 'type': task_type
  244. }
  245. return {
  246. 'status': 1,
  247. 'task_status': status.value,
  248. 'message': message,
  249. 'type': task_type
  250. }
  251. def get_train_metrics(data, workspace):
  252. """ 获取任务日志
  253. Args:
  254. data为dict, key包括
  255. 'tid'任务id
  256. Return:
  257. train_log(dict): 'eta':剩余时间,'train_metrics': 训练指标,'eval_metircs': 评估指标,
  258. 'download_status': 下载模型状态,'eval_done': 是否已保存模型,'train_error': 训练错误原因
  259. """
  260. tid = data['tid']
  261. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  262. from ..utils import TrainLogReader
  263. task_path = workspace.tasks[tid].path
  264. log_file = osp.join(task_path, 'out.log')
  265. train_log = TrainLogReader(log_file)
  266. train_log.update()
  267. train_log = train_log.__dict__
  268. return {'status': 1, 'train_log': train_log}
  269. def get_eval_metrics(data, workspace):
  270. """ 获取任务日志
  271. Args:
  272. data为dict, key包括
  273. 'tid'父任务id
  274. """
  275. tid = data['tid']
  276. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  277. best_model_path = osp.join(workspace.tasks[tid].path, "output",
  278. "best_model", "model.yml")
  279. import yaml
  280. f = open(best_model_path, "r", encoding="utf-8")
  281. eval_metrics = yaml.load(f)['_Attributes']['eval_metrics']
  282. f.close()
  283. return {'status': 1, 'eval_metric': eval_metrics}
  284. def get_eval_all_metrics(data, workspace):
  285. tid = data['tid']
  286. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  287. output_dir = osp.join(workspace.tasks[tid].path, "output")
  288. epoch_result_dict = dict()
  289. best_epoch = -1
  290. best_result = -1
  291. import yaml
  292. for file_dir in os.listdir(output_dir):
  293. if file_dir.startswith("epoch"):
  294. epoch_dir = osp.join(output_dir, file_dir)
  295. if osp.exists(osp.join(epoch_dir, ".success")):
  296. epoch_index = int(file_dir.split('_')[-1])
  297. yml_file_path = osp.join(epoch_dir, "model.yml")
  298. f = open(yml_file_path, 'r', encoding='utf-8')
  299. yml_file = yaml.load(f.read())
  300. result = yml_file["_Attributes"]["eval_metrics"]
  301. key = list(result.keys())[0]
  302. value = result[key]
  303. if value > best_result:
  304. best_result = value
  305. best_epoch = epoch_index
  306. elif value == best_result:
  307. if epoch_index < best_epoch:
  308. best_epoch = epoch_index
  309. epoch_result_dict[epoch_index] = value
  310. return {
  311. 'status': 1,
  312. 'key': key,
  313. 'epoch_result_dict': epoch_result_dict,
  314. 'best_epoch': best_epoch,
  315. 'best_result': best_result
  316. }
  317. def get_sensitivities_loss_img(data, workspace):
  318. """ 获取敏感度与模型裁剪率关系图
  319. Args:
  320. data为dict, key包括
  321. 'tid'任务id
  322. """
  323. tid = data['tid']
  324. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  325. task_path = workspace.tasks[tid].path
  326. pkl_path = osp.join(task_path, 'prune', 'sensitivities_xy.pkl')
  327. import pickle
  328. f = open(pkl_path, 'rb')
  329. sensitivities_xy = pickle.load(f)
  330. return {'status': 1, 'sensitivities_loss_img': sensitivities_xy}
  331. def start_train_task(data, workspace, monitored_processes):
  332. """启动训练任务。
  333. Args:
  334. data为dict,key包括
  335. 'tid'任务id, 'eval_metric_loss'(可选)裁剪任务所需的评估loss
  336. """
  337. from .operate import train_model
  338. tid = data['tid']
  339. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  340. path = workspace.tasks[tid].path
  341. if 'eval_metric_loss' in data and \
  342. data['eval_metric_loss'] is not None:
  343. # 裁剪任务
  344. parent_id = workspace.tasks[tid].parent_id
  345. assert parent_id != "", "任务{}不是裁剪训练任务".format(tid)
  346. parent_path = workspace.tasks[parent_id].path
  347. sensitivities_path = osp.join(parent_path, 'prune',
  348. 'sensitivities.data')
  349. eval_metric_loss = data['eval_metric_loss']
  350. parent_best_model_path = osp.join(parent_path, 'output', 'best_model')
  351. params_conf_file = osp.join(path, 'params.pkl')
  352. with open(params_conf_file, 'rb') as f:
  353. params = pickle.load(f)
  354. params['train'].sensitivities_path = sensitivities_path
  355. params['train'].eval_metric_loss = eval_metric_loss
  356. params['train'].pretrain_weights = parent_best_model_path
  357. with open(params_conf_file, 'wb') as f:
  358. pickle.dump(params, f)
  359. p = train_model(path)
  360. monitored_processes.put(p.pid)
  361. return {'status': 1}
  362. def resume_train_task(data, workspace, monitored_processes):
  363. """恢复训练任务
  364. Args:
  365. data为dict, key包括
  366. 'tid'任务id,'epoch'恢复训练的起始轮数
  367. """
  368. from .operate import train_model
  369. tid = data['tid']
  370. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  371. path = workspace.tasks[tid].path
  372. epoch_path = "epoch_" + str(data['epoch'])
  373. resume_checkpoint_path = osp.join(path, "output", epoch_path)
  374. params_conf_file = osp.join(path, 'params.pkl')
  375. with open(params_conf_file, 'rb') as f:
  376. params = pickle.load(f)
  377. params['train'].resume_checkpoint = resume_checkpoint_path
  378. with open(params_conf_file, 'wb') as f:
  379. pickle.dump(params, f)
  380. p = train_model(path)
  381. monitored_processes.put(p.pid)
  382. return {'status': 1}
  383. def stop_train_task(data, workspace):
  384. """停止训练任务
  385. Args:
  386. data为dict, key包括
  387. 'tid'任务id
  388. """
  389. from .operate import stop_train_model
  390. tid = data['tid']
  391. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  392. path = workspace.tasks[tid].path
  393. stop_train_model(path)
  394. return {'status': 1}
  395. def start_prune_analysis(data, workspace, monitored_processes):
  396. """开始模型裁剪分析
  397. Args:
  398. data为dict, key包括
  399. 'tid'任务id
  400. """
  401. tid = data['tid']
  402. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  403. task_path = workspace.tasks[tid].path
  404. from .operate import prune_analysis_model
  405. p = prune_analysis_model(task_path)
  406. monitored_processes.put(p.pid)
  407. return {'status': 1}
  408. def get_prune_metrics(data, workspace):
  409. """ 获取模型裁剪分析日志
  410. Args:
  411. data为dict, key包括
  412. 'tid'任务id
  413. Return:
  414. prune_log(dict): 'eta':剩余时间,'iters': 模型裁剪总轮数,'current': 当前轮数,
  415. 'progress': 模型裁剪进度
  416. """
  417. tid = data['tid']
  418. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  419. from ..utils import PruneLogReader
  420. task_path = workspace.tasks[tid].path
  421. log_file = osp.join(task_path, 'prune', 'out.log')
  422. # assert osp.exists(log_file), "模型裁剪分析任务还未开始,请稍等"
  423. if not osp.exists(log_file):
  424. return {'status': 1, 'prune_log': None}
  425. prune_log = PruneLogReader(log_file)
  426. prune_log.update()
  427. prune_log = prune_log.__dict__
  428. return {'status': 1, 'prune_log': prune_log}
  429. def get_prune_status(data, workspace):
  430. """ 获取模型裁剪状态
  431. Args:
  432. data为dict, key包括
  433. 'tid'任务id
  434. """
  435. from .operate import get_prune_status
  436. tid = data['tid']
  437. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  438. path = workspace.tasks[tid].path
  439. prune_path = osp.join(path, "prune")
  440. status, message = get_prune_status(prune_path)
  441. if status is not None:
  442. status = status.value
  443. return {'status': 1, 'prune_status': status, 'message': message}
  444. def stop_prune_analysis(data, workspace):
  445. """停止模型裁剪分析
  446. Args:
  447. data为dict, key包括
  448. 'tid'任务id
  449. """
  450. tid = data['tid']
  451. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  452. from .operate import stop_prune_analysis
  453. prune_path = osp.join(workspace.tasks[tid].path, 'prune')
  454. stop_prune_analysis(prune_path)
  455. return {'status': 1}
  456. def evaluate_model(data, workspace, monitored_processes):
  457. """ 模型评估
  458. Args:
  459. data为dict, key包括
  460. 'tid'任务id, topk, score_thresh, overlap_thresh这些评估所需参数
  461. Return:
  462. None
  463. """
  464. from .operate import evaluate_model
  465. tid = data['tid']
  466. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  467. pid = workspace.tasks[tid].pid
  468. assert pid in workspace.projects, "项目ID'{}'不存在".format(pid)
  469. path = workspace.tasks[tid].path
  470. type = workspace.projects[pid].type
  471. p = evaluate_model(path, type, data['epoch'], data['topk'],
  472. data['score_thresh'], data['overlap_thresh'])
  473. monitored_processes.put(p.pid)
  474. return {'status': 1}
  475. def get_evaluate_result(data, workspace):
  476. """ 获评估结果
  477. Args:
  478. data为dict, key包括
  479. 'tid'任务id
  480. Return:
  481. 包含评估指标的dict
  482. """
  483. from .operate import get_evaluate_status
  484. tid = data['tid']
  485. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  486. task_path = workspace.tasks[tid].path
  487. status, message = get_evaluate_status(task_path)
  488. if status == TaskStatus.XEVALUATED:
  489. result_file = osp.join(task_path, 'eval_res.pkl')
  490. if os.path.exists(result_file):
  491. result = pickle.load(open(result_file, "rb"))
  492. return {
  493. 'status': 1,
  494. 'evaluate_status': status,
  495. 'message': "{}评估完成".format(tid),
  496. 'path': result_file,
  497. 'result': result
  498. }
  499. else:
  500. return {
  501. 'status': -1,
  502. 'evaluate_status': status,
  503. 'message': "评估结果丢失,建议重新评估!",
  504. 'result': None
  505. }
  506. if status == TaskStatus.XEVALUATEFAIL:
  507. return {
  508. 'status': -1,
  509. 'evaluate_status': status,
  510. 'message': "评估失败,请重新评估!",
  511. 'result': None
  512. }
  513. return {
  514. 'status': 1,
  515. 'evaluate_status': status,
  516. 'message': "{}正在评估中,请稍后!".format(tid),
  517. 'result': None
  518. }
  519. def get_predict_status(data, workspace):
  520. from .operate import get_predict_status
  521. tid = data['tid']
  522. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  523. path = workspace.tasks[tid].path
  524. status, message, predict_num, total_num = get_predict_status(path)
  525. return {
  526. 'status': 1,
  527. 'predict_status': status.value,
  528. 'message': message,
  529. 'predict_num': predict_num,
  530. 'total_num': total_num
  531. }
  532. def predict_test_pics(data, workspace, monitored_processes):
  533. from .operate import predict_test_pics
  534. tid = data['tid']
  535. if 'img_list' in data:
  536. img_list = data['img_list']
  537. else:
  538. img_list = list()
  539. if 'image_data' in data:
  540. image_data = data['image_data']
  541. else:
  542. image_data = None
  543. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  544. path = workspace.tasks[tid].path
  545. save_dir = data['save_dir'] if 'save_dir' in data else None
  546. epoch = data['epoch'] if 'epoch' in data else None
  547. score_thresh = data['score_thresh'] if 'score_thresh' in data else 0.5
  548. p, save_dir = predict_test_pics(
  549. path,
  550. save_dir=save_dir,
  551. img_list=img_list,
  552. img_data=image_data,
  553. score_thresh=score_thresh,
  554. epoch=epoch)
  555. monitored_processes.put(p.pid)
  556. if 'image_data' in data:
  557. path = osp.join(save_dir, 'predict_result.png')
  558. else:
  559. path = None
  560. return {'status': 1, 'path': path}
  561. def stop_predict_task(data, workspace):
  562. from .operate import stop_predict_task
  563. tid = data['tid']
  564. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  565. path = workspace.tasks[tid].path
  566. status, message, predict_num, total_num = stop_predict_task(path)
  567. return {
  568. 'status': 1,
  569. 'predict_status': status.value,
  570. 'message': message,
  571. 'predict_num': predict_num,
  572. 'total_num': total_num
  573. }
  574. def get_quant_progress(data, workspace):
  575. tid = data['tid']
  576. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  577. from ..utils import QuantLogReader
  578. export_path = osp.join(workspace.tasks[tid].path, "./logs/export")
  579. log_file = osp.join(export_path, 'out.log')
  580. quant_log = QuantLogReader(log_file)
  581. quant_log.update()
  582. quant_log = quant_log.__dict__
  583. return {'status': 1, 'quant_log': quant_log}
  584. def get_quant_result(data, workspace):
  585. tid = data['tid']
  586. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  587. export_path = osp.join(workspace.tasks[tid].path, "./logs/export")
  588. result_json = osp.join(export_path, 'quant_result.json')
  589. result = {}
  590. import json
  591. if osp.exists(result_json):
  592. with open(result_json, 'r') as f:
  593. result = json.load(f)
  594. return {'status': 1, 'quant_result': result}
  595. def get_export_status(data, workspace):
  596. """ 获取导出状态
  597. Args:
  598. data为dict, key包括
  599. 'tid'任务id
  600. Return:
  601. 目前导出状态.
  602. """
  603. from .operate import get_export_status
  604. tid = data['tid']
  605. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  606. task_path = workspace.tasks[tid].path
  607. status, message = get_export_status(task_path)
  608. if status == TaskStatus.XEXPORTED:
  609. return {
  610. 'status': 1,
  611. 'export_status': status,
  612. 'message': "恭喜您,{}任务模型导出成功!".format(tid)
  613. }
  614. if status == TaskStatus.XEXPORTFAIL:
  615. return {
  616. 'status': -1,
  617. 'export_status': status,
  618. 'message': "{}任务模型导出失败,请重试!".format(tid)
  619. }
  620. return {
  621. 'status': 1,
  622. 'export_status': status,
  623. 'message': "{}任务模型导出中,请稍等!".format(tid)
  624. }
  625. def export_infer_model(data, workspace, monitored_processes):
  626. """导出部署模型
  627. Args:
  628. data为dict,key包括
  629. 'tid'任务id, 'save_dir'导出模型保存路径
  630. """
  631. from .operate import export_noquant_model, export_quant_model
  632. tid = data['tid']
  633. save_dir = data['save_dir']
  634. epoch = data['epoch'] if 'epoch' in data else None
  635. quant = data['quant'] if 'quant' in data else False
  636. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  637. path = workspace.tasks[tid].path
  638. if quant:
  639. p = export_quant_model(path, save_dir, epoch)
  640. else:
  641. p = export_noquant_model(path, save_dir, epoch)
  642. monitored_processes.put(p.pid)
  643. return {'status': 1, 'save_dir': save_dir}
  644. def export_lite_model(data, workspace):
  645. """ 导出lite模型
  646. Args:
  647. data为dict, key包括
  648. 'tid'任务id, 'save_dir'导出模型保存路径
  649. """
  650. from .operate import opt_lite_model
  651. model_path = data['model_path']
  652. save_dir = data['save_dir']
  653. tid = data['tid']
  654. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  655. opt_lite_model(model_path, save_dir)
  656. if not osp.exists(osp.join(save_dir, "model.nb")):
  657. if osp.exists(save_dir):
  658. shutil.rmtree(save_dir)
  659. return {'status': -1, 'message': "导出为lite模型失败"}
  660. return {'status': 1, 'message': "完成"}
  661. def stop_export_task(data, workspace):
  662. """ 停止导出任务
  663. Args:
  664. data为dict, key包括
  665. 'tid'任务id
  666. Return:
  667. 目前导出的状态.
  668. """
  669. from .operate import stop_export_task
  670. tid = data['tid']
  671. assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
  672. task_path = workspace.tasks[tid].path
  673. status, message = stop_export_task(task_path)
  674. return {'status': 1, 'export_status': status.value, 'message': message}
def _open_vdl(logdir, current_port):
    # Blocking helper meant to run in a child process: serve VisualDL for
    # `logdir` on all interfaces at `current_port`.
    from visualdl.server import app
    app.run(logdir=logdir, host='0.0.0.0', port=current_port)
def open_vdl(data, workspace, current_port, monitored_processes,
             running_boards):
    """Open a VisualDL page for a task's training logs.

    Args:
        data (dict): contains 'tid', the task id.
        workspace: workspace object holding the task registry.
        current_port (int): first port to try for the VisualDL server.
        monitored_processes: queue collecting pids of launched processes.
        running_boards (dict): tid -> [port, pid] of already-running boards.

    Returns:
        dict: {'status': 1, 'url': "ip:port" of the VisualDL server}
    """
    tid = data['tid']
    assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
    ip = get_ip()
    # Reuse an already-running board for this task instead of starting a
    # second server.
    if tid in running_boards:
        url = ip + ":{}".format(running_boards[tid][0])
        return {'status': 1, 'url': url}
    task_path = workspace.tasks[tid].path
    logdir = osp.join(task_path, 'output', 'vdl_log')
    assert osp.exists(logdir), "该任务还未正常产生日志文件"
    # Probe upward from current_port for a free port, capped at 8500.
    port_available = is_available(ip, current_port)
    while not port_available:
        current_port += 1
        port_available = is_available(ip, current_port)
        assert current_port <= 8500, "找不到可用的端口"
    p = mp.Process(target=_open_vdl, args=(logdir, current_port))
    p.start()
    monitored_processes.put(p.pid)
    url = ip + ":{}".format(current_port)
    running_boards[tid] = [current_port, p.pid]
    current_port += 1
    # Wait (0.5s polls, 8s budget) until the chosen port stops being
    # "available", i.e. the VisualDL server has bound it. Note current_port
    # was just incremented, hence the `current_port - 1` below.
    total_time = 0
    while True:
        if not is_available(ip, current_port - 1):
            break
        # NOTE(review): debug print left in; consider removing or logging.
        print(current_port)
        time.sleep(0.5)
        total_time += 0.5
        assert total_time <= 8, "VisualDL服务启动超时,请重新尝试打开"
    return {'status': 1, 'url': url}