operate.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os.path as osp
  15. import os
  16. import numpy as np
  17. from PIL import Image
  18. import sys
  19. import cv2
  20. import psutil
  21. import shutil
  22. import pickle
  23. import base64
  24. import multiprocessing as mp
  25. from ..utils import (pkill, set_folder_status, get_folder_status, TaskStatus,
  26. PredictStatus, PruneStatus)
  27. from .evaluate.draw_pred_result import visualize_classified_result, visualize_detected_result, visualize_segmented_result
  28. from .visualize import plot_det_label, plot_insseg_label, get_color_map_list
  29. from paddlex_restful.restful.dataset.utils import get_encoding
  30. def _call_paddle_prune(best_model_path, prune_analysis_path, params):
  31. mode = 'w'
  32. sys.stdout = open(
  33. osp.join(prune_analysis_path, 'out.log'), mode, encoding='utf-8')
  34. sys.stderr = open(
  35. osp.join(prune_analysis_path, 'err.log'), mode, encoding='utf-8')
  36. task_type = params['task_type']
  37. dataset_path = params['dataset_path']
  38. os.environ['CUDA_VISIBLE_DEVICES'] = params['train'].cuda_visible_devices
  39. if task_type == "classification":
  40. from .prune.classification import prune
  41. elif task_type in ["detection", "instance_segmentation"]:
  42. from .prune.detection import prune
  43. elif task_type == "segmentation":
  44. from .prune.segmentation import prune
  45. batch_size = params['train'].batch_size
  46. prune(best_model_path, dataset_path, prune_analysis_path, batch_size)
  47. set_folder_status(prune_analysis_path, PruneStatus.XSPRUNEDONE)
  48. def _call_paddlex_train(task_path, params):
  49. '''
  50. Args:
  51. params为dict,字段包括'pretrain_weights_download_save_dir': 预训练模型保存路径,
  52. 'task_type': 任务类型,'dataset_path': 数据集路径,'train':训练参数
  53. '''
  54. mode = 'w'
  55. if params['train'].resume_checkpoint is not None:
  56. params['train'].pretrain_weights = None
  57. mode = 'a'
  58. sys.stdout = open(osp.join(task_path, 'out.log'), mode, encoding='utf-8')
  59. sys.stderr = open(osp.join(task_path, 'err.log'), mode, encoding='utf-8')
  60. sys.stdout.write("This log file path is {}\n".format(
  61. osp.join(task_path, 'out.log')))
  62. sys.stdout.write("注意:标志为WARNING/INFO类的仅为警告或提示类信息,非错误信息\n")
  63. sys.stderr.write("This log file path is {}\n".format(
  64. osp.join(task_path, 'err.log')))
  65. sys.stderr.write("注意:标志为WARNING/INFO类的仅为警告或提示类信息,非错误信息\n")
  66. os.environ['CUDA_VISIBLE_DEVICES'] = params['train'].cuda_visible_devices
  67. import paddlex as pdx
  68. pdx.gui_mode = True
  69. pdx.log_level = 3
  70. pdx.pretrain_dir = params['pretrain_weights_download_save_dir']
  71. task_type = params['task_type']
  72. dataset_path = params['dataset_path']
  73. if task_type == "classification":
  74. from .train.classification import train
  75. elif task_type in ["detection", "instance_segmentation"]:
  76. from .train.detection import train
  77. elif task_type == "segmentation":
  78. from .train.segmentation import train
  79. train(task_path, dataset_path, params['train'])
  80. set_folder_status(task_path, TaskStatus.XTRAINDONE)
  81. def _call_paddlex_evaluate_model(task_path,
  82. model_path,
  83. task_type,
  84. epoch,
  85. topk=5,
  86. score_thresh=0.3,
  87. overlap_thresh=0.5):
  88. evaluate_status_path = osp.join(task_path, './logs/evaluate')
  89. sys.stdout = open(
  90. osp.join(evaluate_status_path, 'out.log'), 'w', encoding='utf-8')
  91. sys.stderr = open(
  92. osp.join(evaluate_status_path, 'err.log'), 'w', encoding='utf-8')
  93. if task_type == "classification":
  94. from .evaluate.classification import Evaluator
  95. evaluator = Evaluator(model_path, topk=topk)
  96. elif task_type == "detection":
  97. from .evaluate.detection import DetEvaluator
  98. evaluator = DetEvaluator(
  99. model_path,
  100. score_threshold=score_thresh,
  101. overlap_thresh=overlap_thresh)
  102. elif task_type == "instance_segmentation":
  103. from .evaluate.detection import InsSegEvaluator
  104. evaluator = InsSegEvaluator(
  105. model_path,
  106. score_threshold=score_thresh,
  107. overlap_thresh=overlap_thresh)
  108. elif task_type == "segmentation":
  109. from .evaluate.segmentation import Evaluator
  110. evaluator = Evaluator(model_path)
  111. report = evaluator.generate_report()
  112. report['epoch'] = epoch
  113. pickle.dump(report, open(osp.join(task_path, "eval_res.pkl"), "wb"))
  114. set_folder_status(evaluate_status_path, TaskStatus.XEVALUATED)
  115. set_folder_status(task_path, TaskStatus.XEVALUATED)
  116. def _call_paddlex_predict(task_path,
  117. predict_status_path,
  118. params,
  119. img_list,
  120. img_data,
  121. save_dir,
  122. score_thresh,
  123. epoch=None):
  124. total_num = open(
  125. osp.join(predict_status_path, 'total_num'), 'w', encoding='utf-8')
  126. def write_file_num(total_file_num):
  127. total_num.write(str(total_file_num))
  128. total_num.close()
  129. sys.stdout = open(
  130. osp.join(predict_status_path, 'out.log'), 'w', encoding='utf-8')
  131. sys.stderr = open(
  132. osp.join(predict_status_path, 'err.log'), 'w', encoding='utf-8')
  133. import paddlex as pdx
  134. pdx.log_level = 3
  135. task_type = params['task_type']
  136. dataset_path = params['dataset_path']
  137. if epoch is None:
  138. model_path = osp.join(task_path, 'output', 'best_model')
  139. else:
  140. model_path = osp.join(task_path, 'output', 'epoch_{}'.format(epoch))
  141. model = pdx.load_model(model_path)
  142. file_list = dict()
  143. predicted_num = 0
  144. if task_type == "classification":
  145. if img_data is None:
  146. if len(img_list) == 0 and osp.exists(
  147. osp.join(dataset_path, "test_list.txt")):
  148. with open(osp.join(dataset_path, "test_list.txt")) as f:
  149. for line in f:
  150. items = line.strip().split()
  151. file_list[osp.join(dataset_path, items[0])] = items[1]
  152. else:
  153. for image in img_list:
  154. file_list[image] = None
  155. total_file_num = len(file_list)
  156. write_file_num(total_file_num)
  157. for image, label_id in file_list.items():
  158. pred_result = {}
  159. if label_id is not None:
  160. pred_result["gt_label"] = model.labels[int(label_id)]
  161. results = model.predict(img_file=image)
  162. pred_result["label"] = []
  163. pred_result["score"] = []
  164. pred_result["topk"] = len(results)
  165. for res in results:
  166. pred_result["label"].append(res['category'])
  167. pred_result["score"].append(res['score'])
  168. visualize_classified_result(save_dir, image, pred_result)
  169. predicted_num += 1
  170. else:
  171. img_data = base64.b64decode(img_data)
  172. img_array = np.frombuffer(img_data, np.uint8)
  173. img = cv2.imdecode(img_array, cv2.COLOR_RGB2BGR)
  174. results = model.predict(img)
  175. pred_result = {}
  176. pred_result["label"] = []
  177. pred_result["score"] = []
  178. pred_result["topk"] = len(results)
  179. for res in results:
  180. pred_result["label"].append(res['category'])
  181. pred_result["score"].append(res['score'])
  182. visualize_classified_result(save_dir, img, pred_result)
  183. elif task_type in ["detection", "instance_segmentation"]:
  184. if img_data is None:
  185. if task_type == "detection" and osp.exists(
  186. osp.join(dataset_path, "test_list.txt")):
  187. if len(img_list) == 0 and osp.exists(
  188. osp.join(dataset_path, "test_list.txt")):
  189. with open(
  190. osp.join(dataset_path, "test_list.txt"),
  191. encoding=get_encoding(
  192. osp.join(dataset_path, "test_list.txt"))) as f:
  193. for line in f:
  194. items = line.strip().split()
  195. file_list[osp.join(dataset_path, items[0])] = \
  196. osp.join(dataset_path, items[1])
  197. else:
  198. for image in img_list:
  199. file_list[image] = None
  200. total_file_num = len(file_list)
  201. write_file_num(total_file_num)
  202. for image, anno in file_list.items():
  203. results = model.predict(img_file=image)
  204. image_pred = pdx.det.visualize(
  205. image, results, threshold=score_thresh, save_dir=None)
  206. save_name = osp.join(save_dir, osp.split(image)[-1])
  207. image_gt = None
  208. if anno is not None:
  209. image_gt = plot_det_label(image, anno, model.labels)
  210. visualize_detected_result(save_name, image_gt, image_pred)
  211. predicted_num += 1
  212. elif len(img_list) == 0 and osp.exists(
  213. osp.join(dataset_path, "test.json")):
  214. from pycocotools.coco import COCO
  215. anno_path = osp.join(dataset_path, "test.json")
  216. coco = COCO(anno_path)
  217. img_ids = coco.getImgIds()
  218. total_file_num = len(img_ids)
  219. write_file_num(total_file_num)
  220. for img_id in img_ids:
  221. img_anno = coco.loadImgs(img_id)[0]
  222. file_name = img_anno['file_name']
  223. name = (osp.split(file_name)[-1]).split(".")[0]
  224. anno = osp.join(dataset_path, "Annotations", name + ".npy")
  225. img_file = osp.join(dataset_path, "JPEGImages", file_name)
  226. results = model.predict(img_file=img_file)
  227. image_pred = pdx.det.visualize(
  228. img_file,
  229. results,
  230. threshold=score_thresh,
  231. save_dir=None)
  232. save_name = osp.join(save_dir, osp.split(img_file)[-1])
  233. if task_type == "detection":
  234. image_gt = plot_det_label(img_file, anno, model.labels)
  235. else:
  236. image_gt = plot_insseg_label(img_file, anno,
  237. model.labels)
  238. visualize_detected_result(save_name, image_gt, image_pred)
  239. predicted_num += 1
  240. else:
  241. total_file_num = len(img_list)
  242. write_file_num(total_file_num)
  243. for image in img_list:
  244. results = model.predict(img_file=image)
  245. image_pred = pdx.det.visualize(
  246. image, results, threshold=score_thresh, save_dir=None)
  247. save_name = osp.join(save_dir, osp.split(image)[-1])
  248. visualize_detected_result(save_name, None, image_pred)
  249. predicted_num += 1
  250. else:
  251. img_data = base64.b64decode(img_data)
  252. img_array = np.frombuffer(img_data, np.uint8)
  253. img = cv2.imdecode(img_array, cv2.COLOR_RGB2BGR)
  254. results = model.predict(img)
  255. image_pred = pdx.det.visualize(
  256. img, results, threshold=score_thresh, save_dir=None)
  257. image_gt = None
  258. save_name = osp.join(save_dir, 'predict_result.png')
  259. visualize_detected_result(save_name, image_gt, image_pred)
  260. elif task_type == "segmentation":
  261. if img_data is None:
  262. if len(img_list) == 0 and osp.exists(
  263. osp.join(dataset_path, "test_list.txt")):
  264. with open(
  265. osp.join(dataset_path, "test_list.txt"),
  266. encoding=get_encoding(
  267. osp.join(dataset_path, "test_list.txt"))) as f:
  268. for line in f:
  269. items = line.strip().split()
  270. file_list[osp.join(dataset_path, items[0])] = \
  271. osp.join(dataset_path, items[1])
  272. else:
  273. for image in img_list:
  274. file_list[image] = None
  275. total_file_num = len(file_list)
  276. write_file_num(total_file_num)
  277. color_map = get_color_map_list(256)
  278. legend = {}
  279. for i in range(len(model.labels)):
  280. legend[model.labels[i]] = color_map[i]
  281. for image, anno in file_list.items():
  282. results = model.predict(img_file=image)
  283. image_pred = pdx.seg.visualize(image, results, save_dir=None)
  284. pse_pred = pdx.seg.visualize(
  285. image, results, weight=0, save_dir=None)
  286. image_ground = None
  287. pse_label = None
  288. if anno is not None:
  289. label = np.asarray(Image.open(anno)).astype('uint8')
  290. image_ground = pdx.seg.visualize(
  291. image, {'label_map': label}, save_dir=None)
  292. pse_label = pdx.seg.visualize(
  293. image, {'label_map': label}, weight=0, save_dir=None)
  294. save_name = osp.join(save_dir, osp.split(image)[-1])
  295. visualize_segmented_result(save_name, image_ground, pse_label,
  296. image_pred, pse_pred, legend)
  297. predicted_num += 1
  298. else:
  299. img_data = base64.b64decode(img_data)
  300. img_array = np.frombuffer(img_data, np.uint8)
  301. img = cv2.imdecode(img_array, cv2.COLOR_RGB2BGR)
  302. color_map = get_color_map_list(256)
  303. legend = {}
  304. for i in range(len(model.labels)):
  305. legend[model.labels[i]] = color_map[i]
  306. results = model.predict(img)
  307. image_pred = pdx.seg.visualize(img, results, save_dir=None)
  308. pse_pred = pdx.seg.visualize(img, results, weight=0, save_dir=None)
  309. image_ground = None
  310. pse_label = None
  311. save_name = osp.join(save_dir, 'predict_result.png')
  312. visualize_segmented_result(save_name, image_ground, pse_label,
  313. image_pred, pse_pred, legend)
  314. set_folder_status(predict_status_path, PredictStatus.XPREDONE)
  315. def _call_paddlex_export_infer(task_path, save_dir, export_status_path, epoch):
  316. # 导出模型不使用GPU
  317. sys.stdout = open(
  318. osp.join(export_status_path, 'out.log'), 'w', encoding='utf-8')
  319. sys.stderr = open(
  320. osp.join(export_status_path, 'err.log'), 'w', encoding='utf-8')
  321. import os
  322. os.environ['CUDA_VISIBLE_DEVICES'] = ''
  323. os.environ['PADDLEX_EXPORT_STAGE'] = 'True'
  324. os.environ['PADDLESEG_EXPORT_STAGE'] = 'True'
  325. import paddlex as pdx
  326. model_dir = "epoch_{}".format(epoch) if epoch is not None else "best_model"
  327. model_path = osp.join(task_path, 'output', model_dir)
  328. model = pdx.load_model(model_path)
  329. model._export_inference_model(save_dir)
  330. '''
  331. model_dir = "epoch_{}".format(epoch)
  332. model_path = osp.join(task_path, 'output', model_dir)
  333. if os.path.exists(save_dir):
  334. shutil.rmtree(save_dir)
  335. shutil.copytree(model_path, save_dir)
  336. '''
  337. set_folder_status(export_status_path, TaskStatus.XEXPORTED)
  338. set_folder_status(task_path, TaskStatus.XEXPORTED)
  339. def _call_paddlex_export_quant(task_path, params, save_dir, export_status_path,
  340. epoch):
  341. sys.stdout = open(
  342. osp.join(export_status_path, 'out.log'), 'w', encoding='utf-8')
  343. sys.stderr = open(
  344. osp.join(export_status_path, 'err.log'), 'w', encoding='utf-8')
  345. dataset_path = params['dataset_path']
  346. task_type = params['task_type']
  347. os.environ['CUDA_VISIBLE_DEVICES'] = params['train'].cuda_visible_devices
  348. import paddlex as pdx
  349. if epoch is not None:
  350. model_dir = "epoch_{}".format(epoch)
  351. model_path = osp.join(task_path, 'output', model_dir)
  352. else:
  353. model_path = osp.join(task_path, 'output', 'best_model')
  354. model = pdx.load_model(model_path)
  355. if task_type == "classification":
  356. train_file_list = osp.join(dataset_path, 'train_list.txt')
  357. val_file_list = osp.join(dataset_path, 'val_list.txt')
  358. label_list = osp.join(dataset_path, 'labels.txt')
  359. quant_dataset = pdx.datasets.ImageNet(
  360. data_dir=dataset_path,
  361. file_list=train_file_list,
  362. label_list=label_list,
  363. transforms=model.test_transforms)
  364. eval_dataset = pdx.datasets.ImageNet(
  365. data_dir=dataset_path,
  366. file_list=val_file_list,
  367. label_list=label_list,
  368. transforms=model.eval_transforms)
  369. elif task_type == "detection":
  370. train_file_list = osp.join(dataset_path, 'train_list.txt')
  371. val_file_list = osp.join(dataset_path, 'val_list.txt')
  372. label_list = osp.join(dataset_path, 'labels.txt')
  373. quant_dataset = pdx.datasets.VOCDetection(
  374. data_dir=dataset_path,
  375. file_list=train_file_list,
  376. label_list=label_list,
  377. transforms=model.test_transforms)
  378. eval_dataset = pdx.datasets.VOCDetection(
  379. data_dir=dataset_path,
  380. file_list=val_file_list,
  381. label_list=label_list,
  382. transforms=model.eval_transforms)
  383. elif task_type == "instance_segmentation":
  384. train_json = osp.join(dataset_path, 'train.json')
  385. val_json = osp.join(dataset_path, 'val.json')
  386. quant_dataset = pdx.datasets.CocoDetection(
  387. data_dir=osp.join(dataset_path, 'JPEGImages'),
  388. ann_file=train_json,
  389. transforms=model.test_transforms)
  390. eval_dataset = pdx.datasets.CocoDetection(
  391. data_dir=osp.join(dataset_path, 'JPEGImages'),
  392. ann_file=val_json,
  393. transforms=model.eval_transforms)
  394. elif task_type == "segmentation":
  395. train_file_list = osp.join(dataset_path, 'train_list.txt')
  396. val_file_list = osp.join(dataset_path, 'val_list.txt')
  397. label_list = osp.join(dataset_path, 'labels.txt')
  398. quant_dataset = pdx.datasets.SegDataset(
  399. data_dir=dataset_path,
  400. file_list=train_file_list,
  401. label_list=label_list,
  402. transforms=model.test_transforms)
  403. eval_dataset = pdx.datasets.SegDataset(
  404. data_dir=dataset_path,
  405. file_list=val_file_list,
  406. label_list=label_list,
  407. transforms=model.eval_transforms)
  408. metric_before = model.evaluate(eval_dataset)
  409. pdx.log_level = 3
  410. pdx.slim.export_quant_model(
  411. model, quant_dataset, batch_size=1, save_dir=save_dir, cache_dir=None)
  412. model_quant = pdx.load_model(save_dir)
  413. metric_after = model_quant.evaluate(eval_dataset)
  414. metrics = {}
  415. if task_type == "segmentation":
  416. metrics['before'] = {'miou': metric_before['miou']}
  417. metrics['after'] = {'miou': metric_after['miou']}
  418. else:
  419. metrics['before'] = metric_before
  420. metrics['after'] = metric_after
  421. import json
  422. with open(
  423. osp.join(export_status_path, 'quant_result.json'),
  424. 'w',
  425. encoding='utf-8') as f:
  426. json.dump(metrics, f)
  427. set_folder_status(export_status_path, TaskStatus.XEXPORTED)
  428. set_folder_status(task_path, TaskStatus.XEXPORTED)
  429. def _call_paddlelite_export_lite(model_path, save_dir=None, place="arm"):
  430. import paddlelite.lite as lite
  431. opt = lite.Opt()
  432. model_file = os.path.join(model_path, '__model__')
  433. params_file = os.path.join(model_path, '__params__')
  434. if save_dir is None:
  435. save_dir = osp.join(model_path, "lite_model")
  436. if not osp.exists(save_dir):
  437. os.makedirs(save_dir)
  438. path = osp.join(save_dir, "model")
  439. opt.run_optimize("", model_file, params_file, "naive_buffer", place, path)
  440. def safe_clean_folder(folder):
  441. if osp.exists(folder):
  442. try:
  443. shutil.rmtree(folder)
  444. os.makedirs(folder)
  445. except Exception as e:
  446. pass
  447. if osp.exists(folder):
  448. for root, dirs, files in os.walk(folder):
  449. for name in files:
  450. try:
  451. os.remove(os.path.join(root, name))
  452. except Exception as e:
  453. pass
  454. else:
  455. os.makedirs(folder)
  456. else:
  457. os.makedirs(folder)
  458. if not osp.exists(folder):
  459. os.makedirs(folder)
  460. def get_task_max_saved_epochs(task_path):
  461. saved_epoch_num = -1
  462. output_path = osp.join(task_path, "output")
  463. if osp.exists(output_path):
  464. for f in os.listdir(output_path):
  465. if f.startswith("epoch_"):
  466. if not osp.exists(osp.join(output_path, f, '.success')):
  467. continue
  468. curr_epoch_num = int(f[6:])
  469. if curr_epoch_num > saved_epoch_num:
  470. saved_epoch_num = curr_epoch_num
  471. return saved_epoch_num
  472. def get_task_status(task_path):
  473. status, message = get_folder_status(task_path, True)
  474. task_id = os.path.split(task_path)[-1]
  475. err_log = os.path.join(task_path, 'err.log')
  476. if status in [TaskStatus.XTRAINING, TaskStatus.XPRUNETRAIN]:
  477. pid = int(message)
  478. is_dead = False
  479. if not psutil.pid_exists(pid):
  480. is_dead = True
  481. else:
  482. p = psutil.Process(pid)
  483. if p.status() == 'zombie':
  484. is_dead = True
  485. if is_dead:
  486. status = TaskStatus.XTRAINFAIL
  487. message = "训练任务{}异常终止,请查阅错误日志具体确认原因{}。\n\n 如若通过日志无法确定原因,可尝试以下几种方法,\n" \
  488. "1. 尝试重新启动训练,看是否能正常训练; \n" \
  489. "2. 调低batch_size(需同时按比例调低学习率等参数)排除是否是显存或内存不足的原因导致;\n" \
  490. "3. 前往GitHub提ISSUE,描述清楚问题会有工程师及时回复: https://github.com/PaddlePaddle/PaddleX/issues ; \n" \
  491. "3. 加QQ群1045148026或邮件至paddlex@baidu.com在线咨询工程师".format(task_id, err_log)
  492. set_folder_status(task_path, status, message)
  493. return status, message
  494. def train_model(task_path):
  495. """训练模型
  496. Args:
  497. task_path(str): 模型训练的参数保存在task_path下的'params.pkl'文件中
  498. """
  499. params_conf_file = osp.join(task_path, 'params.pkl')
  500. assert osp.exists(
  501. params_conf_file), "任务无法启动,路径{}下不存在参数配置文件params.pkl".format(task_path)
  502. with open(params_conf_file, 'rb') as f:
  503. params = pickle.load(f)
  504. sensitivities_path = params['train'].sensitivities_path
  505. p = mp.Process(target=_call_paddlex_train, args=(task_path, params))
  506. p.start()
  507. if sensitivities_path is None:
  508. set_folder_status(task_path, TaskStatus.XTRAINING, p.pid)
  509. else:
  510. set_folder_status(task_path, TaskStatus.XPRUNETRAIN, p.pid)
  511. return p
  512. def stop_train_model(task_path):
  513. """停止正在训练的模型
  514. Args:
  515. task_path(str): 从task_path下的'XTRANING'文件中获取训练的进程id
  516. """
  517. status, message = get_task_status(task_path)
  518. if status in [TaskStatus.XTRAINING, TaskStatus.XPRUNETRAIN]:
  519. pid = int(message)
  520. pkill(pid)
  521. best_model_saved = True
  522. if not osp.exists(osp.join(task_path, 'output', 'best_model')):
  523. best_model_saved = False
  524. set_folder_status(task_path, TaskStatus.XTRAINEXIT, best_model_saved)
  525. else:
  526. raise Exception("模型训练任务没在运行中")
  527. def prune_analysis_model(task_path):
  528. """模型裁剪分析
  529. Args:
  530. task_path(str): 模型训练的参数保存在task_path
  531. dataset_path(str) 模型裁剪中评估数据集的路径
  532. """
  533. best_model_path = osp.join(task_path, 'output', 'best_model')
  534. assert osp.exists(best_model_path), "该任务暂未保存模型,无法进行模型裁剪分析"
  535. prune_analysis_path = osp.join(task_path, 'prune')
  536. if not osp.exists(prune_analysis_path):
  537. os.makedirs(prune_analysis_path)
  538. params_conf_file = osp.join(task_path, 'params.pkl')
  539. assert osp.exists(
  540. params_conf_file), "任务无法启动,路径{}下不存在参数配置文件params.pkl".format(task_path)
  541. with open(params_conf_file, 'rb') as f:
  542. params = pickle.load(f)
  543. assert params['train'].model.lower() not in [
  544. "fasterrcnn", "maskrcnn"
  545. ], "暂不支持FasterRCNN、MaskRCNN模型裁剪"
  546. p = mp.Process(
  547. target=_call_paddle_prune,
  548. args=(best_model_path, prune_analysis_path, params))
  549. p.start()
  550. set_folder_status(prune_analysis_path, PruneStatus.XSPRUNEING, p.pid)
  551. set_folder_status(task_path, TaskStatus.XPRUNEING, p.pid)
  552. return p
  553. def get_prune_status(prune_path):
  554. status, message = get_folder_status(prune_path, True)
  555. if status in [PruneStatus.XSPRUNEING]:
  556. pid = int(message)
  557. is_dead = False
  558. if not psutil.pid_exists(pid):
  559. is_dead = True
  560. else:
  561. p = psutil.Process(pid)
  562. if p.status() == 'zombie':
  563. is_dead = True
  564. if is_dead:
  565. status = PruneStatus.XSPRUNEFAIL
  566. message = "模型裁剪异常终止,可能原因如下:\n1.暂不支持FasterRCNN、MaskRCNN模型的模型裁剪\n2.模型裁剪过程中进程被异常结束,建议重新启动模型裁剪任务"
  567. set_folder_status(prune_path, status, message)
  568. return status, message
  569. def stop_prune_analysis(prune_path):
  570. """停止正在裁剪分析的模型
  571. Args:
  572. prune_path(str): prune_path'XSSLMING'文件中获取训练的进程id
  573. """
  574. status, message = get_prune_status(prune_path)
  575. if status == PruneStatus.XSPRUNEING:
  576. pid = int(message)
  577. pkill(pid)
  578. set_folder_status(prune_path, PruneStatus.XSPRUNEEXIT)
  579. else:
  580. raise Exception("模型裁剪分析任务未在运行中")
  581. def evaluate_model(task_path,
  582. task_type,
  583. epoch=None,
  584. topk=5,
  585. score_thresh=0.3,
  586. overlap_thresh=0.5):
  587. """评估最优模型
  588. Args:
  589. task_path(str): 模型训练相关结果的保存路径
  590. """
  591. output_path = osp.join(task_path, 'output')
  592. if not osp.exists(osp.join(output_path, 'best_model')):
  593. raise Exception("未在训练路径{}下发现保存的best_model,无法进行评估".format(output_path))
  594. evaluate_status_path = osp.join(task_path, './logs/evaluate')
  595. safe_clean_folder(evaluate_status_path)
  596. if epoch is None:
  597. model_path = osp.join(output_path, 'best_model')
  598. else:
  599. epoch_dir = "{}_{}".format('epoch', epoch)
  600. model_path = osp.join(output_path, epoch_dir)
  601. p = mp.Process(
  602. target=_call_paddlex_evaluate_model,
  603. args=(task_path, model_path, task_type, epoch, topk, score_thresh,
  604. overlap_thresh))
  605. p.start()
  606. set_folder_status(evaluate_status_path, TaskStatus.XEVALUATING, p.pid)
  607. return p
  608. def get_evaluate_status(task_path):
  609. """获取导出状态
  610. Args:
  611. task_path(str): 训练任务文件夹
  612. """
  613. evaluate_status_path = osp.join(task_path, './logs/evaluate')
  614. if not osp.exists(evaluate_status_path):
  615. return None, "No evaluate fold in path {}".format(task_path)
  616. status, message = get_folder_status(evaluate_status_path, True)
  617. if status == TaskStatus.XEVALUATING:
  618. pid = int(message)
  619. is_dead = False
  620. if not psutil.pid_exists(pid):
  621. is_dead = True
  622. else:
  623. p = psutil.Process(pid)
  624. if p.status() == 'zombie':
  625. is_dead = True
  626. if is_dead:
  627. status = TaskStatus.XEVALUATEFAIL
  628. message = "评估过程出现异常,请尝试重新评估!"
  629. set_folder_status(evaluate_status_path, status, message)
  630. if status not in [
  631. TaskStatus.XEVALUATING, TaskStatus.XEVALUATED,
  632. TaskStatus.XEVALUATEFAIL
  633. ]:
  634. raise ValueError("Wrong status in evaluate task {}".format(status))
  635. return status, message
  636. def get_predict_status(task_path):
  637. """获取预测任务状态
  638. Args:
  639. task_path(str): 从predict_path下的'XPRESTART'文件中获取训练的进程id
  640. """
  641. from ..utils import list_files
  642. predict_status_path = osp.join(task_path, "./logs/predict")
  643. save_dir = osp.join(task_path, "visualized_test_results")
  644. if not osp.exists(save_dir):
  645. return None, "任务目录下没有visualized_test_results文件夹,{}".format(
  646. task_path), 0, 0
  647. status, message = get_folder_status(predict_status_path, True)
  648. if status == PredictStatus.XPRESTART:
  649. pid = int(message)
  650. is_dead = False
  651. if not psutil.pid_exists(pid):
  652. is_dead = True
  653. else:
  654. p = psutil.Process(pid)
  655. if p.status() == 'zombie':
  656. is_dead = True
  657. if is_dead:
  658. status = PredictStatus.XPREFAIL
  659. message = "图片预测过程出现异常,请尝试重新预测!"
  660. set_folder_status(predict_status_path, status, message)
  661. if status not in [
  662. PredictStatus.XPRESTART, PredictStatus.XPREDONE,
  663. PredictStatus.XPREFAIL
  664. ]:
  665. raise ValueError("预测任务状态异常,{}".format(status))
  666. predict_num = len(list_files(save_dir))
  667. if predict_num > 0:
  668. if predict_num == 1:
  669. total_num = 1
  670. else:
  671. total_num = int(
  672. open(
  673. osp.join(predict_status_path, "total_num"),
  674. encoding='utf-8').readline().strip())
  675. else:
  676. predict_num = 0
  677. total_num = 0
  678. return status, message, predict_num, total_num
  679. def predict_test_pics(task_path,
  680. img_list=[],
  681. img_data=None,
  682. save_dir=None,
  683. score_thresh=0.5,
  684. epoch=None):
  685. """模型预测
  686. Args:
  687. task_path(str): 模型训练的参数保存在task_path下的'params.pkl'文件中
  688. """
  689. params_conf_file = osp.join(task_path, 'params.pkl')
  690. assert osp.exists(
  691. params_conf_file), "任务无法启动,路径{}下不存在参数配置文件params.pkl".format(task_path)
  692. with open(params_conf_file, 'rb') as f:
  693. params = pickle.load(f)
  694. predict_status_path = osp.join(task_path, "./logs/predict")
  695. safe_clean_folder(predict_status_path)
  696. save_dir = osp.join(task_path, 'visualized_test_results')
  697. safe_clean_folder(save_dir)
  698. p = mp.Process(
  699. target=_call_paddlex_predict,
  700. args=(task_path, predict_status_path, params, img_list, img_data,
  701. save_dir, score_thresh, epoch))
  702. p.start()
  703. set_folder_status(predict_status_path, PredictStatus.XPRESTART, p.pid)
  704. return p, save_dir
  705. def stop_predict_task(task_path):
  706. """停止预测任务
  707. Args:
  708. task_path(str): 从predict_path下的'XPRESTART'文件中获取训练的进程id
  709. """
  710. from ..utils import list_files
  711. predict_status_path = osp.join(task_path, "./logs/predict")
  712. save_dir = osp.join(task_path, "visualized_test_results")
  713. if not osp.exists(save_dir):
  714. return None, "任务目录下没有visualized_test_results文件夹,{}".format(
  715. task_path), 0, 0
  716. status, message = get_folder_status(predict_status_path, True)
  717. if status == PredictStatus.XPRESTART:
  718. pid = int(message)
  719. is_dead = False
  720. if not psutil.pid_exists(pid):
  721. is_dead = True
  722. else:
  723. p = psutil.Process(pid)
  724. if p.status() == 'zombie':
  725. is_dead = True
  726. if is_dead:
  727. status = PredictStatus.XPREFAIL
  728. message = "图片预测过程出现异常,请尝试重新预测!"
  729. set_folder_status(predict_status_path, status, message)
  730. else:
  731. pkill(pid)
  732. status = PredictStatus.XPREFAIL
  733. message = "图片预测进程已停止!"
  734. set_folder_status(predict_status_path, status, message)
  735. if status not in [
  736. PredictStatus.XPRESTART, PredictStatus.XPREDONE,
  737. PredictStatus.XPREFAIL
  738. ]:
  739. raise ValueError("预测任务状态异常,{}".format(status))
  740. predict_num = len(list_files(save_dir))
  741. if predict_num > 0:
  742. total_num = int(
  743. open(
  744. osp.join(predict_status_path, "total_num"), encoding='utf-8')
  745. .readline().strip())
  746. else:
  747. predict_num = 0
  748. total_num = 0
  749. return status, message, predict_num, total_num
  750. def get_export_status(task_path):
  751. """获取导出状态
  752. Args:
  753. task_path(str): 从task_path下的'export/XEXPORTING'文件中获取训练的进程id
  754. Return:
  755. 导出的状态和其他消息.
  756. """
  757. export_status_path = osp.join(task_path, './logs/export')
  758. if not osp.exists(export_status_path):
  759. return None, "{}任务目录下没有export文件夹".format(task_path)
  760. status, message = get_folder_status(export_status_path, True)
  761. if status == TaskStatus.XEXPORTING:
  762. pid = int(message)
  763. is_dead = False
  764. if not psutil.pid_exists(pid):
  765. is_dead = True
  766. else:
  767. p = psutil.Process(pid)
  768. if p.status() == 'zombie':
  769. is_dead = True
  770. if is_dead:
  771. status = TaskStatus.XEXPORTFAIL
  772. message = "导出过程出现异常,请尝试重新评估!"
  773. set_folder_status(export_status_path, status, message)
  774. if status not in [
  775. TaskStatus.XEXPORTING, TaskStatus.XEXPORTED, TaskStatus.XEXPORTFAIL
  776. ]:
  777. # raise ValueError("获取到的导出状态异常,{}。".format(status))
  778. return None, "获取到的导出状态异常,{}。".format(status)
  779. return status, message
  780. def export_quant_model(task_path, save_dir, epoch=None):
  781. """导出量化模型
  782. Args:
  783. task_path(str): 模型训练的路径
  784. save_dir(str): 导出后的模型保存路径
  785. """
  786. output_path = osp.join(task_path, 'output')
  787. if not osp.exists(osp.join(output_path, 'best_model')):
  788. raise Exception("未在训练路径{}下发现保存的best_model,导出失败".format(output_path))
  789. export_status_path = osp.join(task_path, './logs/export')
  790. safe_clean_folder(export_status_path)
  791. params_conf_file = osp.join(task_path, 'params.pkl')
  792. assert osp.exists(
  793. params_conf_file), "任务无法启动,路径{}下不存在参数配置文件params.pkl".format(task_path)
  794. with open(params_conf_file, 'rb') as f:
  795. params = pickle.load(f)
  796. p = mp.Process(
  797. target=_call_paddlex_export_quant,
  798. args=(task_path, params, save_dir, export_status_path, epoch))
  799. p.start()
  800. set_folder_status(export_status_path, TaskStatus.XEXPORTING, p.pid)
  801. set_folder_status(task_path, TaskStatus.XEXPORTING, p.pid)
  802. return p
  803. def export_noquant_model(task_path, save_dir, epoch=None):
  804. """导出inference模型
  805. Args:
  806. task_path(str): 模型训练的路径
  807. save_dir(str): 导出后的模型保存路径
  808. """
  809. output_path = osp.join(task_path, 'output')
  810. if not osp.exists(osp.join(output_path, 'best_model')):
  811. raise Exception("未在训练路径{}下发现保存的best_model,导出失败".format(output_path))
  812. export_status_path = osp.join(task_path, './logs/export')
  813. safe_clean_folder(export_status_path)
  814. p = mp.Process(
  815. target=_call_paddlex_export_infer,
  816. args=(task_path, save_dir, export_status_path, epoch))
  817. p.start()
  818. set_folder_status(export_status_path, TaskStatus.XEXPORTING, p.pid)
  819. set_folder_status(task_path, TaskStatus.XEXPORTING, p.pid)
  820. return p
  821. def opt_lite_model(model_path, save_dir=None, place='arm'):
  822. p = mp.Process(
  823. target=_call_paddlelite_export_lite,
  824. args=(model_path, save_dir, place))
  825. p.start()
  826. p.join()
  827. def stop_export_task(task_path):
  828. """停止导出
  829. Args:
  830. task_path(str): 从task_path下的'export/XEXPORTING'文件中获取训练的进程id
  831. Return:
  832. the export status and message.
  833. """
  834. export_status_path = osp.join(task_path, './logs/export')
  835. if not osp.exists(export_status_path):
  836. return None, "{}任务目录下没有export文件夹".format(task_path)
  837. status, message = get_folder_status(export_status_path, True)
  838. if status == TaskStatus.XEXPORTING:
  839. pid = int(message)
  840. is_dead = False
  841. if not psutil.pid_exists(pid):
  842. is_dead = True
  843. else:
  844. p = psutil.Process(pid)
  845. if p.status() == 'zombie':
  846. is_dead = True
  847. if is_dead:
  848. status = TaskStatus.XEXPORTFAIL
  849. message = "导出过程出现异常,请尝试重新评估!"
  850. set_folder_status(export_status_path, status, message)
  851. else:
  852. pkill(pid)
  853. status = TaskStatus.XEXPORTFAIL
  854. message = "已停止导出进程!"
  855. set_folder_status(export_status_path, status, message)
  856. if status not in [
  857. TaskStatus.XEXPORTING, TaskStatus.XEXPORTED, TaskStatus.XEXPORTFAIL
  858. ]:
  859. raise ValueError("获取到的导出状态异常,{}。".format(status))
  860. return status, message