operate.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os.path as osp
  15. import os
  16. import numpy as np
  17. from PIL import Image
  18. import sys
  19. import cv2
  20. import psutil
  21. import shutil
  22. import pickle
  23. import base64
  24. import multiprocessing as mp
  25. from ..utils import (pkill, set_folder_status, get_folder_status, TaskStatus,
  26. PredictStatus, PruneStatus)
  27. from .evaluate.draw_pred_result import visualize_classified_result, visualize_detected_result, visualize_segmented_result
  28. from .visualize import plot_det_label, plot_insseg_label, get_color_map_list
def _call_paddle_prune(best_model_path, prune_analysis_path, params):
    """Worker run in a child process: run prune sensitivity analysis on the
    best saved model and write the visualized result.

    Args:
        best_model_path (str): folder of the best saved model.
        prune_analysis_path (str): folder receiving logs, the
            'sensitivities.data' file and the status flag.
        params (dict): fields include 'task_type', 'dataset_path' and
            'train' (training parameter object).
    """
    mode = 'w'
    # Redirect this process's stdout/stderr to log files so the GUI can
    # surface progress and errors.
    sys.stdout = open(
        osp.join(prune_analysis_path, 'out.log'), mode, encoding='utf-8')
    sys.stderr = open(
        osp.join(prune_analysis_path, 'err.log'), mode, encoding='utf-8')
    sensitivities_path = osp.join(prune_analysis_path, "sensitivities.data")
    task_type = params['task_type']
    dataset_path = params['dataset_path']
    os.environ['CUDA_VISIBLE_DEVICES'] = params['train'].cuda_visible_devices
    # Pick the task-specific prune implementation.
    if task_type == "classification":
        from .prune.classification import prune
    elif task_type in ["detection", "instance_segmentation"]:
        from .prune.detection import prune
    elif task_type == "segmentation":
        from .prune.segmentation import prune
    # NOTE(review): if task_type matches none of the branches above, `prune`
    # is unbound and the call below raises NameError — assumed unreachable.
    batch_size = params['train'].batch_size
    prune(best_model_path, dataset_path, sensitivities_path, batch_size)
    import paddlex as pdx
    from paddlex.cv.models.slim.visualize import visualize
    model = pdx.load_model(best_model_path)
    visualize(model, sensitivities_path, prune_analysis_path)
    set_folder_status(prune_analysis_path, PruneStatus.XSPRUNEDONE)
def _call_paddlex_train(task_path, params):
    """Worker run in a child process: train a model with PaddleX.

    Args:
        task_path (str): folder where logs and training output are written.
        params (dict): fields include 'pretrain_weights_download_save_dir'
            (cache dir for pretrained weights), 'task_type', 'dataset_path'
            and 'train' (training parameter object).
    """
    # Append to the logs when resuming from a checkpoint, otherwise truncate.
    mode = 'w'
    if params['train'].resume_checkpoint is not None:
        mode = 'a'
    # Redirect this process's stdout/stderr to log files for the GUI.
    sys.stdout = open(osp.join(task_path, 'out.log'), mode, encoding='utf-8')
    sys.stderr = open(osp.join(task_path, 'err.log'), mode, encoding='utf-8')
    sys.stdout.write("This log file path is {}\n".format(
        osp.join(task_path, 'out.log')))
    sys.stdout.write("注意:标志为WARNING/INFO类的仅为警告或提示类信息,非错误信息\n")
    sys.stderr.write("This log file path is {}\n".format(
        osp.join(task_path, 'err.log')))
    sys.stderr.write("注意:标志为WARNING/INFO类的仅为警告或提示类信息,非错误信息\n")
    os.environ['CUDA_VISIBLE_DEVICES'] = params['train'].cuda_visible_devices
    import paddlex as pdx
    pdx.gui_mode = True
    pdx.log_level = 3
    pdx.pretrain_dir = params['pretrain_weights_download_save_dir']
    task_type = params['task_type']
    dataset_path = params['dataset_path']
    # Pick the task-specific training entry point.
    if task_type == "classification":
        from .train.classification import train
    elif task_type in ["detection", "instance_segmentation"]:
        from .train.detection import train
    elif task_type == "segmentation":
        from .train.segmentation import train
    train(task_path, dataset_path, params['train'])
    set_folder_status(task_path, TaskStatus.XTRAINDONE)
def _call_paddlex_evaluate_model(task_path,
                                 model_path,
                                 task_type,
                                 epoch,
                                 topk=5,
                                 score_thresh=0.3,
                                 overlap_thresh=0.5):
    """Worker run in a child process: evaluate a saved model and pickle the
    resulting report to task_path/eval_res.pkl.

    Args:
        task_path (str): task folder; logs go to task_path/logs/evaluate.
        model_path (str): folder of the model snapshot to evaluate.
        task_type (str): one of 'classification', 'detection',
            'instance_segmentation', 'segmentation'.
        epoch: epoch identifier recorded into the report.
        topk (int): top-k setting for classification evaluation.
        score_thresh (float): score threshold for (instance) detection.
        overlap_thresh (float): IoU threshold for (instance) detection.
    """
    evaluate_status_path = osp.join(task_path, './logs/evaluate')
    sys.stdout = open(
        osp.join(evaluate_status_path, 'out.log'), 'w', encoding='utf-8')
    sys.stderr = open(
        osp.join(evaluate_status_path, 'err.log'), 'w', encoding='utf-8')
    # Pick the task-specific evaluator.
    if task_type == "classification":
        from .evaluate.classification import Evaluator
        evaluator = Evaluator(model_path, topk=topk)
    elif task_type == "detection":
        from .evaluate.detection import DetEvaluator
        evaluator = DetEvaluator(
            model_path,
            score_threshold=score_thresh,
            overlap_thresh=overlap_thresh)
    elif task_type == "instance_segmentation":
        from .evaluate.detection import InsSegEvaluator
        evaluator = InsSegEvaluator(
            model_path,
            score_threshold=score_thresh,
            overlap_thresh=overlap_thresh)
    elif task_type == "segmentation":
        from .evaluate.segmentation import Evaluator
        evaluator = Evaluator(model_path)
    report = evaluator.generate_report()
    report['epoch'] = epoch
    pickle.dump(report, open(osp.join(task_path, "eval_res.pkl"), "wb"))
    # Flag completion both in the evaluate log folder and on the task itself.
    set_folder_status(evaluate_status_path, TaskStatus.XEVALUATED)
    set_folder_status(task_path, TaskStatus.XEVALUATED)
  119. def _call_paddlex_predict(task_path,
  120. predict_status_path,
  121. params,
  122. img_list,
  123. img_data,
  124. save_dir,
  125. score_thresh,
  126. epoch=None):
  127. total_num = open(
  128. osp.join(predict_status_path, 'total_num'), 'w', encoding='utf-8')
  129. def write_file_num(total_file_num):
  130. total_num.write(str(total_file_num))
  131. total_num.close()
  132. sys.stdout = open(
  133. osp.join(predict_status_path, 'out.log'), 'w', encoding='utf-8')
  134. sys.stderr = open(
  135. osp.join(predict_status_path, 'err.log'), 'w', encoding='utf-8')
  136. import paddlex as pdx
  137. pdx.log_level = 3
  138. task_type = params['task_type']
  139. dataset_path = params['dataset_path']
  140. if epoch is None:
  141. model_path = osp.join(task_path, 'output', 'best_model')
  142. else:
  143. model_path = osp.join(task_path, 'output', 'epoch_{}'.format(epoch))
  144. model = pdx.load_model(model_path)
  145. file_list = dict()
  146. predicted_num = 0
  147. if task_type == "classification":
  148. if img_data is None:
  149. if len(img_list) == 0 and osp.exists(
  150. osp.join(dataset_path, "test_list.txt")):
  151. with open(osp.join(dataset_path, "test_list.txt")) as f:
  152. for line in f:
  153. items = line.strip().split()
  154. file_list[osp.join(dataset_path, items[0])] = items[1]
  155. else:
  156. for image in img_list:
  157. file_list[image] = None
  158. total_file_num = len(file_list)
  159. write_file_num(total_file_num)
  160. for image, label_id in file_list.items():
  161. pred_result = {}
  162. if label_id is not None:
  163. pred_result["gt_label"] = model.labels[int(label_id)]
  164. results = model.predict(img_file=image)
  165. pred_result["label"] = []
  166. pred_result["score"] = []
  167. pred_result["topk"] = len(results)
  168. for res in results:
  169. pred_result["label"].append(res['category'])
  170. pred_result["score"].append(res['score'])
  171. visualize_classified_result(save_dir, image, pred_result)
  172. predicted_num += 1
  173. else:
  174. img_data = base64.b64decode(img_data)
  175. img_array = np.frombuffer(img_data, np.uint8)
  176. img = cv2.imdecode(img_array, cv2.COLOR_RGB2BGR)
  177. results = model.predict(img)
  178. pred_result = {}
  179. pred_result["label"] = []
  180. pred_result["score"] = []
  181. pred_result["topk"] = len(results)
  182. for res in results:
  183. pred_result["label"].append(res['category'])
  184. pred_result["score"].append(res['score'])
  185. visualize_classified_result(save_dir, img, pred_result)
  186. elif task_type in ["detection", "instance_segmentation"]:
  187. if img_data is None:
  188. if task_type == "detection" and osp.exists(
  189. osp.join(dataset_path, "test_list.txt")):
  190. if len(img_list) == 0 and osp.exists(
  191. osp.join(dataset_path, "test_list.txt")):
  192. with open(osp.join(dataset_path, "test_list.txt")) as f:
  193. for line in f:
  194. items = line.strip().split()
  195. file_list[osp.join(dataset_path, items[0])] = \
  196. osp.join(dataset_path, items[1])
  197. else:
  198. for image in img_list:
  199. file_list[image] = None
  200. total_file_num = len(file_list)
  201. write_file_num(total_file_num)
  202. for image, anno in file_list.items():
  203. results = model.predict(img_file=image)
  204. image_pred = pdx.det.visualize(
  205. image, results, threshold=score_thresh, save_dir=None)
  206. save_name = osp.join(save_dir, osp.split(image)[-1])
  207. image_gt = None
  208. if anno is not None:
  209. image_gt = plot_det_label(image, anno, model.labels)
  210. visualize_detected_result(save_name, image_gt, image_pred)
  211. predicted_num += 1
  212. elif len(img_list) == 0 and osp.exists(
  213. osp.join(dataset_path, "test.json")):
  214. from pycocotools.coco import COCO
  215. anno_path = osp.join(dataset_path, "test.json")
  216. coco = COCO(anno_path)
  217. img_ids = coco.getImgIds()
  218. total_file_num = len(img_ids)
  219. write_file_num(total_file_num)
  220. for img_id in img_ids:
  221. img_anno = coco.loadImgs(img_id)[0]
  222. file_name = img_anno['file_name']
  223. name = (osp.split(file_name)[-1]).split(".")[0]
  224. anno = osp.join(dataset_path, "Annotations", name + ".npy")
  225. img_file = osp.join(dataset_path, "JPEGImages", file_name)
  226. results = model.predict(img_file=img_file)
  227. image_pred = pdx.det.visualize(
  228. img_file,
  229. results,
  230. threshold=score_thresh,
  231. save_dir=None)
  232. save_name = osp.join(save_dir, osp.split(img_file)[-1])
  233. if task_type == "detection":
  234. image_gt = plot_det_label(img_file, anno, model.labels)
  235. else:
  236. image_gt = plot_insseg_label(img_file, anno,
  237. model.labels)
  238. visualize_detected_result(save_name, image_gt, image_pred)
  239. predicted_num += 1
  240. else:
  241. total_file_num = len(img_list)
  242. write_file_num(total_file_num)
  243. for image in img_list:
  244. results = model.predict(img_file=image)
  245. image_pred = pdx.det.visualize(
  246. image, results, threshold=score_thresh, save_dir=None)
  247. save_name = osp.join(save_dir, osp.split(image)[-1])
  248. visualize_detected_result(save_name, None, image_pred)
  249. predicted_num += 1
  250. else:
  251. img_data = base64.b64decode(img_data)
  252. img_array = np.frombuffer(img_data, np.uint8)
  253. img = cv2.imdecode(img_array, cv2.COLOR_RGB2BGR)
  254. results = model.predict(img)
  255. image_pred = pdx.det.visualize(
  256. img, results, threshold=score_thresh, save_dir=None)
  257. image_gt = None
  258. save_name = osp.join(save_dir, 'predict_result.png')
  259. visualize_detected_result(save_name, image_gt, image_pred)
  260. elif task_type == "segmentation":
  261. if img_data is None:
  262. if len(img_list) == 0 and osp.exists(
  263. osp.join(dataset_path, "test_list.txt")):
  264. with open(osp.join(dataset_path, "test_list.txt")) as f:
  265. for line in f:
  266. items = line.strip().split()
  267. file_list[osp.join(dataset_path, items[0])] = \
  268. osp.join(dataset_path, items[1])
  269. else:
  270. for image in img_list:
  271. file_list[image] = None
  272. total_file_num = len(file_list)
  273. write_file_num(total_file_num)
  274. color_map = get_color_map_list(256)
  275. legend = {}
  276. for i in range(len(model.labels)):
  277. legend[model.labels[i]] = color_map[i]
  278. for image, anno in file_list.items():
  279. results = model.predict(img_file=image)
  280. image_pred = pdx.seg.visualize(image, results, save_dir=None)
  281. pse_pred = pdx.seg.visualize(
  282. image, results, weight=0, save_dir=None)
  283. image_ground = None
  284. pse_label = None
  285. if anno is not None:
  286. label = np.asarray(Image.open(anno)).astype('uint8')
  287. image_ground = pdx.seg.visualize(
  288. image, {'label_map': label}, save_dir=None)
  289. pse_label = pdx.seg.visualize(
  290. image, {'label_map': label}, weight=0, save_dir=None)
  291. save_name = osp.join(save_dir, osp.split(image)[-1])
  292. visualize_segmented_result(save_name, image_ground, pse_label,
  293. image_pred, pse_pred, legend)
  294. predicted_num += 1
  295. else:
  296. img_data = base64.b64decode(img_data)
  297. img_array = np.frombuffer(img_data, np.uint8)
  298. img = cv2.imdecode(img_array, cv2.COLOR_RGB2BGR)
  299. color_map = get_color_map_list(256)
  300. legend = {}
  301. for i in range(len(model.labels)):
  302. legend[model.labels[i]] = color_map[i]
  303. results = model.predict(img)
  304. image_pred = pdx.seg.visualize(image, results, save_dir=None)
  305. pse_pred = pdx.seg.visualize(
  306. image, results, weight=0, save_dir=None)
  307. image_ground = None
  308. pse_label = None
  309. save_name = osp.join(save_dir, 'predict_result.png')
  310. visualize_segmented_result(save_name, image_ground, pse_label,
  311. image_pred, pse_pred, legend)
  312. set_folder_status(predict_status_path, PredictStatus.XPREDONE)
  313. def _call_paddlex_export_infer(task_path, save_dir, export_status_path, epoch):
  314. # 导出模型不使用GPU
  315. sys.stdout = open(
  316. osp.join(export_status_path, 'out.log'), 'w', encoding='utf-8')
  317. sys.stderr = open(
  318. osp.join(export_status_path, 'err.log'), 'w', encoding='utf-8')
  319. import os
  320. os.environ['CUDA_VISIBLE_DEVICES'] = ''
  321. import paddlex as pdx
  322. if epoch is not None:
  323. model_dir = "epoch_{}".format(epoch)
  324. model_path = osp.join(task_path, 'output', model_dir)
  325. else:
  326. model_path = osp.join(task_path, 'output', 'best_model')
  327. model = pdx.load_model(model_path)
  328. model.export_inference_model(save_dir)
  329. set_folder_status(export_status_path, TaskStatus.XEXPORTED)
  330. set_folder_status(task_path, TaskStatus.XEXPORTED)
def _call_paddlex_export_quant(task_path, params, save_dir, export_status_path,
                               epoch):
    """Worker run in a child process: post-training-quantize a saved model
    and record before/after evaluation metrics to 'quant_result.json'.

    Args:
        task_path (str): training task folder (snapshots under 'output/').
        params (dict): fields include 'dataset_path', 'task_type' and
            'train' (training parameter object).
        save_dir (str): destination folder for the quantized model.
        export_status_path (str): folder receiving logs, status flags and
            the metrics JSON.
        epoch (int|None): snapshot epoch to quantize; None selects best_model.
    """
    sys.stdout = open(
        osp.join(export_status_path, 'out.log'), 'w', encoding='utf-8')
    sys.stderr = open(
        osp.join(export_status_path, 'err.log'), 'w', encoding='utf-8')
    dataset_path = params['dataset_path']
    task_type = params['task_type']
    os.environ['CUDA_VISIBLE_DEVICES'] = params['train'].cuda_visible_devices
    import paddlex as pdx
    if epoch is not None:
        model_dir = "epoch_{}".format(epoch)
        model_path = osp.join(task_path, 'output', model_dir)
    else:
        model_path = osp.join(task_path, 'output', 'best_model')
    model = pdx.load_model(model_path)
    # Build a calibration dataset (train split, test transforms) and an
    # evaluation dataset (val split, eval transforms) matching the task type.
    if task_type == "classification":
        train_file_list = osp.join(dataset_path, 'train_list.txt')
        val_file_list = osp.join(dataset_path, 'val_list.txt')
        label_list = osp.join(dataset_path, 'labels.txt')
        quant_dataset = pdx.datasets.ImageNet(
            data_dir=dataset_path,
            file_list=train_file_list,
            label_list=label_list,
            transforms=model.test_transforms)
        eval_dataset = pdx.datasets.ImageNet(
            data_dir=dataset_path,
            file_list=val_file_list,
            label_list=label_list,
            transforms=model.eval_transforms)
    elif task_type == "detection":
        train_file_list = osp.join(dataset_path, 'train_list.txt')
        val_file_list = osp.join(dataset_path, 'val_list.txt')
        label_list = osp.join(dataset_path, 'labels.txt')
        quant_dataset = pdx.datasets.VOCDetection(
            data_dir=dataset_path,
            file_list=train_file_list,
            label_list=label_list,
            transforms=model.test_transforms)
        eval_dataset = pdx.datasets.VOCDetection(
            data_dir=dataset_path,
            file_list=val_file_list,
            label_list=label_list,
            transforms=model.eval_transforms)
    elif task_type == "instance_segmentation":
        train_json = osp.join(dataset_path, 'train.json')
        val_json = osp.join(dataset_path, 'val.json')
        quant_dataset = pdx.datasets.CocoDetection(
            data_dir=osp.join(dataset_path, 'JPEGImages'),
            ann_file=train_json,
            transforms=model.test_transforms)
        eval_dataset = pdx.datasets.CocoDetection(
            data_dir=osp.join(dataset_path, 'JPEGImages'),
            ann_file=val_json,
            transforms=model.eval_transforms)
    elif task_type == "segmentation":
        train_file_list = osp.join(dataset_path, 'train_list.txt')
        val_file_list = osp.join(dataset_path, 'val_list.txt')
        label_list = osp.join(dataset_path, 'labels.txt')
        quant_dataset = pdx.datasets.SegDataset(
            data_dir=dataset_path,
            file_list=train_file_list,
            label_list=label_list,
            transforms=model.test_transforms)
        eval_dataset = pdx.datasets.SegDataset(
            data_dir=dataset_path,
            file_list=val_file_list,
            label_list=label_list,
            transforms=model.eval_transforms)
    # Measure accuracy before quantization, quantize, then measure again.
    metric_before = model.evaluate(eval_dataset)
    pdx.log_level = 3
    pdx.slim.export_quant_model(
        model, quant_dataset, batch_size=1, save_dir=save_dir, cache_dir=None)
    model_quant = pdx.load_model(save_dir)
    metric_after = model_quant.evaluate(eval_dataset)
    metrics = {}
    if task_type == "segmentation":
        # Segmentation reports many metrics; keep only mIoU for comparison.
        metrics['before'] = {'miou': metric_before['miou']}
        metrics['after'] = {'miou': metric_after['miou']}
    else:
        metrics['before'] = metric_before
        metrics['after'] = metric_after
    import json
    with open(
            osp.join(export_status_path, 'quant_result.json'),
            'w',
            encoding='utf-8') as f:
        json.dump(metrics, f)
    set_folder_status(export_status_path, TaskStatus.XEXPORTED)
    set_folder_status(task_path, TaskStatus.XEXPORTED)
  421. def _call_paddlelite_export_lite(model_path, save_dir=None, place="arm"):
  422. import paddlelite.lite as lite
  423. opt = lite.Opt()
  424. model_file = os.path.join(model_path, '__model__')
  425. params_file = os.path.join(model_path, '__params__')
  426. if save_dir is None:
  427. save_dir = osp.join(model_path, "lite_model")
  428. if not osp.exists(save_dir):
  429. os.makedirs(save_dir)
  430. path = osp.join(save_dir, "model")
  431. opt.run_optimize("", model_file, params_file, "naive_buffer", place, path)
  432. def safe_clean_folder(folder):
  433. if osp.exists(folder):
  434. try:
  435. shutil.rmtree(folder)
  436. os.makedirs(folder)
  437. except Exception as e:
  438. pass
  439. if osp.exists(folder):
  440. for root, dirs, files in os.walk(folder):
  441. for name in files:
  442. try:
  443. os.remove(os.path.join(root, name))
  444. except Exception as e:
  445. pass
  446. else:
  447. os.makedirs(folder)
  448. else:
  449. os.makedirs(folder)
  450. if not osp.exists(folder):
  451. os.makedirs(folder)
  452. def get_task_max_saved_epochs(task_path):
  453. saved_epoch_num = -1
  454. output_path = osp.join(task_path, "output")
  455. if osp.exists(output_path):
  456. for f in os.listdir(output_path):
  457. if f.startswith("epoch_"):
  458. if not osp.exists(osp.join(output_path, f, '.success')):
  459. continue
  460. curr_epoch_num = int(f[6:])
  461. if curr_epoch_num > saved_epoch_num:
  462. saved_epoch_num = curr_epoch_num
  463. return saved_epoch_num
def get_task_status(task_path):
    """Return the (status, message) of a training task folder.

    When the folder is flagged as training but the recorded process has
    died (or become a zombie), the status is downgraded to XTRAINFAIL and
    persisted back to the folder.

    Args:
        task_path (str): task folder whose status flags are inspected.

    Returns:
        (TaskStatus, str): current status and its message (the message is
        the trainer pid while training, or an error description on failure).
    """
    status, message = get_folder_status(task_path, True)
    task_id = os.path.split(task_path)[-1]
    err_log = os.path.join(task_path, 'err.log')
    if status in [TaskStatus.XTRAINING, TaskStatus.XPRUNETRAIN]:
        # While training, the status message stores the trainer's pid.
        pid = int(message)
        is_dead = False
        if not psutil.pid_exists(pid):
            is_dead = True
        else:
            p = psutil.Process(pid)
            if p.status() == 'zombie':
                is_dead = True
        if is_dead:
            status = TaskStatus.XTRAINFAIL
            # NOTE(review): the advice list numbers "3." twice — the last
            # item is presumably meant to be "4.". Left unchanged here.
            message = "训练任务{}异常终止,请查阅错误日志具体确认原因{}。\n\n 如若通过日志无法确定原因,可尝试以下几种方法,\n" \
                      "1. 尝试重新启动训练,看是否能正常训练; \n" \
                      "2. 调低batch_size(需同时按比例调低学习率等参数)排除是否是显存或内存不足的原因导致;\n" \
                      "3. 前往GitHub提ISSUE,描述清楚问题会有工程师及时回复: https://github.com/PaddlePaddle/PaddleX/issues ; \n" \
                      "3. 加QQ群1045148026或邮件至paddlex@baidu.com在线咨询工程师".format(task_id, err_log)
            set_folder_status(task_path, status, message)
    return status, message
  486. def train_model(task_path):
  487. """训练模型
  488. Args:
  489. task_path(str): 模型训练的参数保存在task_path下的'params.pkl'文件中
  490. """
  491. params_conf_file = osp.join(task_path, 'params.pkl')
  492. assert osp.exists(
  493. params_conf_file), "任务无法启动,路径{}下不存在参数配置文件params.pkl".format(task_path)
  494. with open(params_conf_file, 'rb') as f:
  495. params = pickle.load(f)
  496. sensitivities_path = params['train'].sensitivities_path
  497. p = mp.Process(target=_call_paddlex_train, args=(task_path, params))
  498. p.start()
  499. if sensitivities_path is None:
  500. set_folder_status(task_path, TaskStatus.XTRAINING, p.pid)
  501. else:
  502. set_folder_status(task_path, TaskStatus.XPRUNETRAIN, p.pid)
  503. return p
  504. def stop_train_model(task_path):
  505. """停止正在训练的模型
  506. Args:
  507. task_path(str): 从task_path下的'XTRANING'文件中获取训练的进程id
  508. """
  509. status, message = get_task_status(task_path)
  510. if status in [TaskStatus.XTRAINING, TaskStatus.XPRUNETRAIN]:
  511. pid = int(message)
  512. pkill(pid)
  513. best_model_saved = True
  514. if not osp.exists(osp.join(task_path, 'output', 'best_model')):
  515. best_model_saved = False
  516. set_folder_status(task_path, TaskStatus.XTRAINEXIT, best_model_saved)
  517. else:
  518. raise Exception("模型训练任务没在运行中")
  519. def prune_analysis_model(task_path):
  520. """模型裁剪分析
  521. Args:
  522. task_path(str): 模型训练的参数保存在task_path
  523. dataset_path(str) 模型裁剪中评估数据集的路径
  524. """
  525. best_model_path = osp.join(task_path, 'output', 'best_model')
  526. assert osp.exists(best_model_path), "该任务暂未保存模型,无法进行模型裁剪分析"
  527. prune_analysis_path = osp.join(task_path, 'prune')
  528. if not osp.exists(prune_analysis_path):
  529. os.makedirs(prune_analysis_path)
  530. params_conf_file = osp.join(task_path, 'params.pkl')
  531. assert osp.exists(
  532. params_conf_file), "任务无法启动,路径{}下不存在参数配置文件params.pkl".format(task_path)
  533. with open(params_conf_file, 'rb') as f:
  534. params = pickle.load(f)
  535. assert params['train'].model.lower() not in [
  536. "ppyolo", "fasterrcnn", "maskrcnn", "fastscnn", "HRNet_W18"
  537. ], "暂不支持PPYOLO、FasterRCNN、MaskRCNN、HRNet_W18、FastSCNN模型裁剪"
  538. p = mp.Process(
  539. target=_call_paddle_prune,
  540. args=(best_model_path, prune_analysis_path, params))
  541. p.start()
  542. set_folder_status(prune_analysis_path, PruneStatus.XSPRUNEING, p.pid)
  543. set_folder_status(task_path, TaskStatus.XPRUNEING, p.pid)
  544. return p
  545. def get_prune_status(prune_path):
  546. status, message = get_folder_status(prune_path, True)
  547. if status in [PruneStatus.XSPRUNEING]:
  548. pid = int(message)
  549. is_dead = False
  550. if not psutil.pid_exists(pid):
  551. is_dead = True
  552. else:
  553. p = psutil.Process(pid)
  554. if p.status() == 'zombie':
  555. is_dead = True
  556. if is_dead:
  557. status = PruneStatus.XSPRUNEFAIL
  558. message = "模型裁剪异常终止,可能原因如下:\n1.暂不支持FasterRCNN、MaskRCNN模型的模型裁剪\n2.模型裁剪过程中进程被异常结束,建议重新启动模型裁剪任务"
  559. set_folder_status(prune_path, status, message)
  560. return status, message
  561. def stop_prune_analysis(prune_path):
  562. """停止正在裁剪分析的模型
  563. Args:
  564. prune_path(str): prune_path'XSSLMING'文件中获取训练的进程id
  565. """
  566. status, message = get_prune_status(prune_path)
  567. if status == PruneStatus.XSPRUNEING:
  568. pid = int(message)
  569. pkill(pid)
  570. set_folder_status(prune_path, PruneStatus.XSPRUNEEXIT)
  571. else:
  572. raise Exception("模型裁剪分析任务未在运行中")
  573. def evaluate_model(task_path,
  574. task_type,
  575. epoch=None,
  576. topk=5,
  577. score_thresh=0.3,
  578. overlap_thresh=0.5):
  579. """评估最优模型
  580. Args:
  581. task_path(str): 模型训练相关结果的保存路径
  582. """
  583. output_path = osp.join(task_path, 'output')
  584. if not osp.exists(osp.join(output_path, 'best_model')):
  585. raise Exception("未在训练路径{}下发现保存的best_model,无法进行评估".format(output_path))
  586. evaluate_status_path = osp.join(task_path, './logs/evaluate')
  587. safe_clean_folder(evaluate_status_path)
  588. if epoch is None:
  589. model_path = osp.join(output_path, 'best_model')
  590. else:
  591. epoch_dir = "{}_{}".format('epoch', epoch)
  592. model_path = osp.join(output_path, epoch_dir)
  593. p = mp.Process(
  594. target=_call_paddlex_evaluate_model,
  595. args=(task_path, model_path, task_type, epoch, topk, score_thresh,
  596. overlap_thresh))
  597. p.start()
  598. set_folder_status(evaluate_status_path, TaskStatus.XEVALUATING, p.pid)
  599. return p
def get_evaluate_status(task_path):
    """Return the (status, message) of the evaluation under task_path.

    (The original docstring said "export status" — this inspects the
    evaluation folder.)

    Args:
        task_path (str): training task folder; the evaluation status lives
            under task_path/logs/evaluate.

    Returns:
        (TaskStatus, str) or (None, str) when no evaluate folder exists.

    Raises:
        ValueError: when the stored status is not an evaluation status.
    """
    evaluate_status_path = osp.join(task_path, './logs/evaluate')
    if not osp.exists(evaluate_status_path):
        return None, "No evaluate fold in path {}".format(task_path)
    status, message = get_folder_status(evaluate_status_path, True)
    if status == TaskStatus.XEVALUATING:
        # While evaluating, the status message stores the worker's pid;
        # downgrade to failure if that process has died.
        pid = int(message)
        is_dead = False
        if not psutil.pid_exists(pid):
            is_dead = True
        else:
            p = psutil.Process(pid)
            if p.status() == 'zombie':
                is_dead = True
        if is_dead:
            status = TaskStatus.XEVALUATEFAIL
            message = "评估过程出现异常,请尝试重新评估!"
            set_folder_status(evaluate_status_path, status, message)
    if status not in [
            TaskStatus.XEVALUATING, TaskStatus.XEVALUATED,
            TaskStatus.XEVALUATEFAIL
    ]:
        raise ValueError("Wrong status in evaluate task {}".format(status))
    return status, message
def get_predict_status(task_path):
    """Return the prediction task's status and progress.

    Args:
        task_path (str): the predictor pid is read from the 'XPRESTART'
            status under task_path/logs/predict.

    Returns:
        (status, message, predict_num, total_num); (None, reason, 0, 0)
        when the visualized-results folder is missing.

    Raises:
        ValueError: when the stored status is not a prediction status.
    """
    from ..utils import list_files
    predict_status_path = osp.join(task_path, "./logs/predict")
    save_dir = osp.join(task_path, "visualized_test_results")
    if not osp.exists(save_dir):
        return None, "任务目录下没有visualized_test_results文件夹,{}".format(
            task_path), 0, 0
    status, message = get_folder_status(predict_status_path, True)
    if status == PredictStatus.XPRESTART:
        # While predicting, the status message stores the worker's pid;
        # downgrade to failure if that process has died.
        pid = int(message)
        is_dead = False
        if not psutil.pid_exists(pid):
            is_dead = True
        else:
            p = psutil.Process(pid)
            if p.status() == 'zombie':
                is_dead = True
        if is_dead:
            status = PredictStatus.XPREFAIL
            message = "图片预测过程出现异常,请尝试重新预测!"
            set_folder_status(predict_status_path, status, message)
    if status not in [
            PredictStatus.XPRESTART, PredictStatus.XPREDONE,
            PredictStatus.XPREFAIL
    ]:
        raise ValueError("预测任务状态异常,{}".format(status))
    # Progress = number of already-visualized files vs. planned total.
    predict_num = len(list_files(save_dir))
    if predict_num > 0:
        if predict_num == 1:
            # Presumably the single-image (base64) path, which writes no
            # 'total_num' file — TODO confirm against _call_paddlex_predict.
            total_num = 1
        else:
            total_num = int(
                open(
                    osp.join(predict_status_path, "total_num"),
                    encoding='utf-8').readline().strip())
    else:
        predict_num = 0
        total_num = 0
    return status, message, predict_num, total_num
  671. def predict_test_pics(task_path,
  672. img_list=[],
  673. img_data=None,
  674. save_dir=None,
  675. score_thresh=0.5,
  676. epoch=None):
  677. """模型预测
  678. Args:
  679. task_path(str): 模型训练的参数保存在task_path下的'params.pkl'文件中
  680. """
  681. params_conf_file = osp.join(task_path, 'params.pkl')
  682. assert osp.exists(
  683. params_conf_file), "任务无法启动,路径{}下不存在参数配置文件params.pkl".format(task_path)
  684. with open(params_conf_file, 'rb') as f:
  685. params = pickle.load(f)
  686. predict_status_path = osp.join(task_path, "./logs/predict")
  687. safe_clean_folder(predict_status_path)
  688. save_dir = osp.join(task_path, 'visualized_test_results')
  689. safe_clean_folder(save_dir)
  690. p = mp.Process(
  691. target=_call_paddlex_predict,
  692. args=(task_path, predict_status_path, params, img_list, img_data,
  693. save_dir, score_thresh, epoch))
  694. p.start()
  695. set_folder_status(predict_status_path, PredictStatus.XPRESTART, p.pid)
  696. return p, save_dir
def stop_predict_task(task_path):
    """Stop a running prediction task and report its final progress.

    Args:
        task_path (str): the predictor pid is read from the 'XPRESTART'
            status under task_path/logs/predict.

    Returns:
        (status, message, predict_num, total_num); (None, reason, 0, 0)
        when the visualized-results folder is missing.

    Raises:
        ValueError: when the stored status is not a prediction status.
    """
    from ..utils import list_files
    predict_status_path = osp.join(task_path, "./logs/predict")
    save_dir = osp.join(task_path, "visualized_test_results")
    if not osp.exists(save_dir):
        return None, "任务目录下没有visualized_test_results文件夹,{}".format(
            task_path), 0, 0
    status, message = get_folder_status(predict_status_path, True)
    if status == PredictStatus.XPRESTART:
        pid = int(message)
        is_dead = False
        if not psutil.pid_exists(pid):
            is_dead = True
        else:
            p = psutil.Process(pid)
            if p.status() == 'zombie':
                is_dead = True
        if is_dead:
            # Process already gone: just record the failure.
            status = PredictStatus.XPREFAIL
            message = "图片预测过程出现异常,请尝试重新预测!"
            set_folder_status(predict_status_path, status, message)
        else:
            # Still alive: kill it, then record the stop.
            pkill(pid)
            status = PredictStatus.XPREFAIL
            message = "图片预测进程已停止!"
            set_folder_status(predict_status_path, status, message)
    if status not in [
            PredictStatus.XPRESTART, PredictStatus.XPREDONE,
            PredictStatus.XPREFAIL
    ]:
        raise ValueError("预测任务状态异常,{}".format(status))
    # Progress = number of already-visualized files vs. planned total.
    predict_num = len(list_files(save_dir))
    if predict_num > 0:
        total_num = int(
            open(
                osp.join(predict_status_path, "total_num"), encoding='utf-8')
            .readline().strip())
    else:
        predict_num = 0
        total_num = 0
    return status, message, predict_num, total_num
  742. def get_export_status(task_path):
  743. """获取导出状态
  744. Args:
  745. task_path(str): 从task_path下的'export/XEXPORTING'文件中获取训练的进程id
  746. Return:
  747. 导出的状态和其他消息.
  748. """
  749. export_status_path = osp.join(task_path, './logs/export')
  750. if not osp.exists(export_status_path):
  751. return None, "{}任务目录下没有export文件夹".format(task_path)
  752. status, message = get_folder_status(export_status_path, True)
  753. if status == TaskStatus.XEXPORTING:
  754. pid = int(message)
  755. is_dead = False
  756. if not psutil.pid_exists(pid):
  757. is_dead = True
  758. else:
  759. p = psutil.Process(pid)
  760. if p.status() == 'zombie':
  761. is_dead = True
  762. if is_dead:
  763. status = TaskStatus.XEXPORTFAIL
  764. message = "导出过程出现异常,请尝试重新评估!"
  765. set_folder_status(export_status_path, status, message)
  766. if status not in [
  767. TaskStatus.XEXPORTING, TaskStatus.XEXPORTED, TaskStatus.XEXPORTFAIL
  768. ]:
  769. # raise ValueError("获取到的导出状态异常,{}。".format(status))
  770. return None, "获取到的导出状态异常,{}。".format(status)
  771. return status, message
  772. def export_quant_model(task_path, save_dir, epoch=None):
  773. """导出量化模型
  774. Args:
  775. task_path(str): 模型训练的路径
  776. save_dir(str): 导出后的模型保存路径
  777. """
  778. output_path = osp.join(task_path, 'output')
  779. if not osp.exists(osp.join(output_path, 'best_model')):
  780. raise Exception("未在训练路径{}下发现保存的best_model,导出失败".format(output_path))
  781. export_status_path = osp.join(task_path, './logs/export')
  782. safe_clean_folder(export_status_path)
  783. params_conf_file = osp.join(task_path, 'params.pkl')
  784. assert osp.exists(
  785. params_conf_file), "任务无法启动,路径{}下不存在参数配置文件params.pkl".format(task_path)
  786. with open(params_conf_file, 'rb') as f:
  787. params = pickle.load(f)
  788. p = mp.Process(
  789. target=_call_paddlex_export_quant,
  790. args=(task_path, params, save_dir, export_status_path, epoch))
  791. p.start()
  792. set_folder_status(export_status_path, TaskStatus.XEXPORTING, p.pid)
  793. set_folder_status(task_path, TaskStatus.XEXPORTING, p.pid)
  794. return p
  795. def export_noquant_model(task_path, save_dir, epoch=None):
  796. """导出inference模型
  797. Args:
  798. task_path(str): 模型训练的路径
  799. save_dir(str): 导出后的模型保存路径
  800. """
  801. output_path = osp.join(task_path, 'output')
  802. if not osp.exists(osp.join(output_path, 'best_model')):
  803. raise Exception("未在训练路径{}下发现保存的best_model,导出失败".format(output_path))
  804. export_status_path = osp.join(task_path, './logs/export')
  805. safe_clean_folder(export_status_path)
  806. p = mp.Process(
  807. target=_call_paddlex_export_infer,
  808. args=(task_path, save_dir, export_status_path, epoch))
  809. p.start()
  810. set_folder_status(export_status_path, TaskStatus.XEXPORTING, p.pid)
  811. set_folder_status(task_path, TaskStatus.XEXPORTING, p.pid)
  812. return p
  813. def opt_lite_model(model_path, save_dir=None, place='arm'):
  814. p = mp.Process(
  815. target=_call_paddlelite_export_lite,
  816. args=(model_path, save_dir, place))
  817. p.start()
  818. p.join()
  819. def stop_export_task(task_path):
  820. """停止导出
  821. Args:
  822. task_path(str): 从task_path下的'export/XEXPORTING'文件中获取训练的进程id
  823. Return:
  824. the export status and message.
  825. """
  826. export_status_path = osp.join(task_path, './logs/export')
  827. if not osp.exists(export_status_path):
  828. return None, "{}任务目录下没有export文件夹".format(task_path)
  829. status, message = get_folder_status(export_status_path, True)
  830. if status == TaskStatus.XEXPORTING:
  831. pid = int(message)
  832. is_dead = False
  833. if not psutil.pid_exists(pid):
  834. is_dead = True
  835. else:
  836. p = psutil.Process(pid)
  837. if p.status() == 'zombie':
  838. is_dead = True
  839. if is_dead:
  840. status = TaskStatus.XEXPORTFAIL
  841. message = "导出过程出现异常,请尝试重新评估!"
  842. set_folder_status(export_status_path, status, message)
  843. else:
  844. pkill(pid)
  845. status = TaskStatus.XEXPORTFAIL
  846. message = "已停止导出进程!"
  847. set_folder_status(export_status_path, status, message)
  848. if status not in [
  849. TaskStatus.XEXPORTING, TaskStatus.XEXPORTED, TaskStatus.XEXPORTFAIL
  850. ]:
  851. raise ValueError("获取到的导出状态异常,{}。".format(status))
  852. return status, message