app.py 35 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from flask import Flask, request, render_template, send_from_directory, jsonify, session, send_file
  15. from werkzeug.utils import secure_filename
  16. from flask_cors import CORS
  17. import argparse
  18. from os import path as osp
  19. import os
  20. import time
  21. import json
  22. import sys
  23. import multiprocessing as mp
  24. from . import workspace_pb2 as w
  25. from .utils import CustomEncoder, ShareData, is_pic, get_logger, TaskStatus, get_ip
  26. from paddlex_restful.restful.dataset.utils import get_encoding
  27. import numpy as np
  28. app = Flask(__name__)
  29. CORS(app, supports_credentials=True)
  30. SESSION_TYPE = 'filesystem'
  31. app.config.from_object(__name__)
  32. SD = ShareData()
  33. def init(dirname, logger):
  34. #初始化工作空间
  35. from .workspace import init_workspace
  36. from .system import get_system_info
  37. SD.workspace = w.Workspace(path=dirname)
  38. init_workspace(SD.workspace, dirname, logger)
  39. SD.workspace_dir = dirname
  40. get_system_info(SD.machine_info)
  41. @app.errorhandler(Exception)
  42. def handle_exception(e):
  43. ret = {"status": -1, 'message': repr(e)}
  44. return ret
  45. @app.route('/workspace', methods=['GET', 'PUT'])
  46. def workspace():
  47. """
  48. methods=='GET':获取工作目录中项目、数据集、任务的属性
  49. Args:
  50. struct(str):结构类型,可以是'dataset', 'project'或'task',
  51. id(str):结构类型对应的id
  52. attr_list(list):需要获取的属性的列表
  53. Return:
  54. attr(dict):key为属性,value为属性的值,
  55. status
  56. methods=='PUT':修改工作目录中项目、数据集、任务的属性
  57. Args:
  58. struct(str):结构类型,可以是'dataset', 'project'或'task',
  59. id(str):结构类型对应的id
  60. attr_dict(dict):key:需要修改的属性,value:需要修改属性的值
  61. Return:
  62. status
  63. """
  64. data = request.get_json()
  65. if data is None:
  66. data = request.args
  67. if request.method == 'GET':
  68. if data:
  69. from .workspace import get_attr
  70. ret = get_attr(data, SD.workspace)
  71. return ret
  72. return {'status': 1, 'dirname': SD.workspace_dir}
  73. if request.method == 'PUT':
  74. from .workspace import set_attr
  75. ret = set_attr(data, SD.workspace)
  76. return ret
  77. @app.route('/dataset', methods=['GET', 'POST', 'PUT', 'DELETE'])
  78. def dataset():
  79. """
  80. methods=='GET':获取所有数据集或者单个数据集的信息
  81. Args:
  82. did(str, optional):数据集id(可选),如果存在就返回数据集id对应数据集的信息
  83. Ruturn:
  84. status
  85. if 'did' in Args:
  86. id(str):数据集id,
  87. dataset_status(int):数据集状态(DatasetStatus)枚举变量的值
  88. message(str):数据集状态信息
  89. attr(dict):数据集属性
  90. else:
  91. datasets(list):所有数据集属性的列表
  92. methods=='POST':创建一个新的数据集
  93. Args:
  94. name(str):数据集名字
  95. desc(str):数据集描述
  96. dataset_type(str):数据集类型,可以是['classification', 'detection', 'segmentation','instance_segmentation','remote_segmentation']
  97. Return:
  98. did(str):数据集id
  99. status
  100. methods=='PUT':异步,向数据集导入数据,支持分类、检测、语义分割、实例分割、摇杆分割数据集类型
  101. Args:
  102. did(str):数据集id
  103. path(str):数据集路径
  104. Return:
  105. status
  106. methods=='DELETE':删除已有的某个数据集
  107. Args:
  108. did(str):数据集id
  109. Return:
  110. status
  111. """
  112. data = request.get_json()
  113. if data is None:
  114. data = request.args
  115. if request.method == 'GET':
  116. if 'did' in data:
  117. from .dataset.dataset import get_dataset_status
  118. ret = get_dataset_status(data, SD.workspace)
  119. return ret
  120. from .dataset.dataset import list_datasets
  121. ret = list_datasets(SD.workspace)
  122. return ret
  123. if request.method == 'POST':
  124. from .dataset.dataset import create_dataset
  125. ret = create_dataset(data, SD.workspace)
  126. return ret
  127. if request.method == 'PUT':
  128. from .dataset.dataset import import_dataset
  129. ret = import_dataset(data, SD.workspace, SD.monitored_processes,
  130. SD.load_demo_proc_dict)
  131. return ret
  132. if request.method == 'DELETE':
  133. from .dataset.dataset import delete_dataset
  134. ret = delete_dataset(data, SD.workspace)
  135. return ret
  136. @app.route('/dataset/details', methods=['GET'])
  137. def dataset_details():
  138. """
  139. methods=='GET':获取某个数据集的详细信息
  140. Args:
  141. did(str):数据集id
  142. Return:
  143. details(dict):数据集详细信息,
  144. status
  145. """
  146. data = request.get_json()
  147. if data is None:
  148. data = request.args
  149. if request.method == 'GET':
  150. from .dataset.dataset import get_dataset_details
  151. ret = get_dataset_details(data, SD.workspace)
  152. return ret
  153. @app.route('/dataset/split', methods=['PUT'])
  154. def dataset_split():
  155. """
  156. Args:
  157. did(str):数据集id
  158. val_split(float): 验证集比例
  159. test_split(float): 测试集比例
  160. Return:
  161. status
  162. """
  163. data = request.get_json()
  164. if request.method == 'PUT':
  165. from .dataset.dataset import split_dataset
  166. ret = split_dataset(data, SD.workspace)
  167. return ret
  168. @app.route('/dataset/image', methods=['GET'])
  169. def dataset_img_base64():
  170. """
  171. Args:
  172. GET: 获取图片base64数据,参数:'path' 图片绝对路径
  173. """
  174. data = request.get_json()
  175. if request.method == 'GET':
  176. from .dataset.dataset import img_base64
  177. ret = img_base64(data)
  178. return ret
  179. @app.route('/dataset/file', methods=['GET'])
  180. def get_image_file():
  181. """
  182. Args:
  183. GET: 获取文件数据,参数:'path' 文件绝对路径
  184. """
  185. data = request.get_json()
  186. if request.method == 'GET':
  187. ret = data['path']
  188. assert os.path.abspath(ret).startswith(
  189. os.path.abspath(SD.workspace_dir)
  190. ) and ".." not in ret, "Illegal path {}.".format(ret)
  191. return send_file(ret)
  192. @app.route('/dataset/npy', methods=['GET'])
  193. def get_npyfile():
  194. """
  195. Args:
  196. GET: 获取文件数据,参数:'path' npy文件绝对路径
  197. """
  198. data = request.get_json()
  199. if request.method == 'GET':
  200. npy = np.load(data['path'], allow_pickle=True).tolist()
  201. npy['gt_bbox'] = npy['gt_bbox'].tolist()
  202. return npy
  203. @app.route('/file', methods=['GET'])
  204. def get_file():
  205. """
  206. Args:
  207. path'(str):文件在服务端的路径
  208. Return:
  209. #数据为图片
  210. img_data(str): base64图片数据
  211. status
  212. #数据为xml文件
  213. ret:数据流
  214. #数据为log文件
  215. ret:json数据
  216. """
  217. data = request.get_json()
  218. if data is None:
  219. data = request.args
  220. if request.method == 'GET':
  221. path = data['path']
  222. if not os.path.exists(path):
  223. return {'status': -1}
  224. if is_pic(path):
  225. from .dataset.dataset import img_base64
  226. ret = img_base64(data, SD.workspace)
  227. return ret
  228. file_type = path[(path.rfind('.') + 1):]
  229. if file_type in ['xml', 'npy', 'log']:
  230. return send_file(path)
  231. else:
  232. pass
  233. @app.route('/project', methods=['GET', 'POST', 'DELETE'])
  234. def project():
  235. """
  236. methods=='GET':获取指定项目id的信息
  237. Args:
  238. 'id'(str, optional):项目id,可选,如果存在就返回项目id对应项目的信息
  239. Return:
  240. status,
  241. if 'id' in Args:
  242. attr(dict):项目属性
  243. else:
  244. projects(list):所有项目属性
  245. methods=='POST':创建一个项目
  246. Args:
  247. name(str): 项目名
  248. desc(str):项目描述
  249. project_type(str):项目类型
  250. Return:
  251. pid(str):项目id
  252. status
  253. methods=='DELETE':删除一个项目,以及项目相关的task
  254. Args:
  255. pid(str):项目id
  256. Return:
  257. status
  258. """
  259. data = request.get_json()
  260. if data is None:
  261. data = request.args
  262. if request.method == 'GET':
  263. from .project.project import list_projects
  264. from .project.project import get_project
  265. if 'id' in data:
  266. ret = get_project(data, SD.workspace)
  267. return ret
  268. ret = list_projects(SD.workspace)
  269. return ret
  270. if request.method == 'POST':
  271. from .project.project import create_project
  272. ret = create_project(data, SD.workspace)
  273. return ret
  274. if request.method == 'DELETE':
  275. from .project.project import delete_project
  276. ret = delete_project(data, SD.workspace)
  277. return ret
  278. @app.route('/project/task', methods=['GET', 'POST', 'DELETE'])
  279. def task():
  280. """
  281. methods=='GET':#获取某个任务的信息或者所有任务的信息
  282. Args:
  283. tid(str, optional):任务id,可选,若存在即返回id对应任务的信息
  284. resume(str, optional):获取是否可以恢复训练的状态,可选,需在存在tid的情况下才生效
  285. pid(str, optional):项目id,可选,若存在即返回该项目id下所有任务信息
  286. Return:
  287. status
  288. if 'tid' in Args:
  289. task_status(int):任务状态(TaskStatus)枚举变量的值
  290. message(str):任务状态信息
  291. type:任务类型包括{'classification', 'detection', 'segmentation', 'instance_segmentation'}
  292. resumable(bool):仅Args中存在resume时返回,任务训练是否可以恢复
  293. max_saved_epochs(int):仅Args中存在resume时返回,当前训练模型保存的最大epoch
  294. else:
  295. tasks(list):所有任务属性
  296. methods=='POST':#创建任务(训练或者裁剪)
  297. Args:
  298. pid(str):项目id
  299. train(dict):训练参数
  300. desc(str, optional):任务描述,可选
  301. parent_id(str, optional):可选,若存在即表示新建的任务为裁剪任务,parent_id的值为裁剪任务对应的训练任务id
  302. Return:
  303. tid(str):任务id
  304. status
  305. methods=='DELETE':#删除任务
  306. Args:
  307. tid(str):任务id
  308. Return:
  309. status
  310. """
  311. data = request.get_json()
  312. if data is None:
  313. data = request.args
  314. if request.method == 'GET':
  315. if data:
  316. if 'pid' not in data:
  317. from .project.task import get_task_status
  318. ret = get_task_status(data, SD.workspace)
  319. return ret
  320. from .project.task import list_tasks
  321. ret = list_tasks(data, SD.workspace)
  322. return ret
  323. if request.method == 'POST':
  324. from .project.task import create_task
  325. ret = create_task(data, SD.workspace)
  326. return ret
  327. if request.method == 'DELETE':
  328. from .project.task import delete_task
  329. ret = delete_task(data, SD.workspace)
  330. return ret
  331. @app.route('/project/task/params', methods=['GET', 'POST'])
  332. def task_params():
  333. """
  334. methods=='GET':#获取任务id对应的参数,或者获取项目默认参数
  335. Args:
  336. tid(str, optional):获取任务对应的参数
  337. pid(str,optional):获取项目对应的默认参数
  338. model_type(str,optional):pid存在下有效,对应项目下获取指定模型的默认参数
  339. gpu_list(list,optional):pid存在下有效,默认值为[0],使用指定的gpu并获取相应的默认参数
  340. Return:
  341. train(dict):训练或者裁剪的参数
  342. status
  343. methods=='POST':#设置任务参数,将前端用户设置训练参数dict保存在后端的pkl文件中
  344. Args:
  345. tid(str):任务id
  346. train(dict):训练参数
  347. Return:
  348. status
  349. """
  350. data = request.get_json()
  351. if data is None:
  352. data = request.args
  353. if request.method == 'GET':
  354. if 'tid' in data:
  355. from .project.task import get_task_params
  356. ret = get_task_params(data, SD.workspace)
  357. ret['train'] = CustomEncoder().encode(ret['train'])
  358. ret['train'] = json.loads(ret['train'])
  359. return ret
  360. if 'pid' in data:
  361. from .project.task import get_default_params
  362. ret = get_default_params(data, SD.workspace, SD.machine_info)
  363. return ret
  364. if request.method == 'POST':
  365. from .project.task import set_task_params
  366. ret = set_task_params(data, SD.workspace)
  367. return ret
  368. @app.route('/project/task/metrics', methods=['GET'])
  369. def task_metrics():
  370. """
  371. methods=='GET':#获取日志数据
  372. Args:
  373. tid(str):任务id
  374. type(str):可以获取日志的类型,[train,eval,sensitivities,prune],包括训练,评估,敏感度与模型裁剪率关系图,裁剪的日志
  375. Return:
  376. status
  377. if type == 'train':
  378. train_log(dict): 训练日志
  379. elif type == 'eval':
  380. eval_metrics(dict): 评估结果
  381. elif type == 'sensitivities':
  382. sensitivities_loss_img(dict): 敏感度与模型裁剪率关系图
  383. elif type == 'prune':
  384. prune_log(dict):裁剪日志
  385. """
  386. data = request.get_json()
  387. if data is None:
  388. data = request.args
  389. if request.method == 'GET':
  390. if data['type'] == 'train':
  391. from .project.task import get_train_metrics
  392. ret = get_train_metrics(data, SD.workspace)
  393. return ret
  394. if data['type'] == 'eval':
  395. from .project.task import get_eval_metrics
  396. ret = get_eval_metrics(data, SD.workspace)
  397. return ret
  398. if data['type'] == 'eval_all':
  399. from .project.task import get_eval_all_metrics
  400. ret = get_eval_all_metrics(data, SD.workspace)
  401. return ret
  402. if data['type'] == 'sensitivities':
  403. from .project.task import get_sensitivities_loss_img
  404. ret = get_sensitivities_loss_img(data, SD.workspace)
  405. return ret
  406. if data['type'] == 'prune':
  407. from .project.task import get_prune_metrics
  408. ret = get_prune_metrics(data, SD.workspace)
  409. return ret
  410. @app.route('/project/task/train', methods=['POST', 'PUT'])
  411. def task_train():
  412. """
  413. methods=='POST':#异步,启动训练或者裁剪任务
  414. Args:
  415. tid(str):任务id
  416. eval_metric_loss(int,optional):可选,裁剪任务时可用,裁剪任务所需的评估loss
  417. Return:
  418. status
  419. methods=='PUT':#改变任务训练的状态,即终止训练或者恢复训练
  420. Args:
  421. tid(str):任务id
  422. act(str):[stop,resume]暂停或者恢复
  423. epoch(int):(resume下可以设置)恢复训练的起始轮数
  424. Return:
  425. status
  426. """
  427. data = request.get_json()
  428. if request.method == 'POST':
  429. from .project.task import start_train_task
  430. ret = start_train_task(data, SD.workspace, SD.monitored_processes)
  431. return ret
  432. if request.method == 'PUT':
  433. if data['act'] == 'resume':
  434. from .project.task import resume_train_task
  435. ret = resume_train_task(data, SD.workspace, SD.monitored_processes)
  436. return ret
  437. if data['act'] == 'stop':
  438. from .project.task import stop_train_task
  439. ret = stop_train_task(data, SD.workspace)
  440. return ret
  441. @app.route('/project/task/train/file', methods=['GET'])
  442. def log_file():
  443. data = request.get_json()
  444. if request.method == 'GET':
  445. path = data['path']
  446. if not os.path.exists(path):
  447. return {'status': -1}
  448. logs = open(path, encoding='utf-8').readlines()
  449. if len(logs) < 50:
  450. return {'status': 1, 'log': logs}
  451. else:
  452. logs = logs[-50:]
  453. return {'status': 1, 'log': logs}
  454. @app.route('/project/task/prune', methods=['GET', 'POST', 'PUT'])
  455. def task_prune():
  456. """
  457. methods=='GET':#获取裁剪任务的状态
  458. Args:
  459. tid(str):任务id
  460. Return:
  461. prune_status(int): 裁剪任务状态(PruneStatus)枚举变量的值
  462. status
  463. methods=='POST':#异步,创建一个裁剪分析,对于启动裁剪任务前需要先启动裁剪分析
  464. Args:
  465. tid(str):任务id
  466. Return:
  467. status
  468. methods=='PUT':#改变裁剪分析任务的状态
  469. Args:
  470. tid(str):任务id
  471. act(str):[stop],目前仅支持停止一个裁剪分析任务
  472. Return
  473. status
  474. """
  475. data = request.get_json()
  476. if data is None:
  477. data = request.args
  478. if request.method == 'GET':
  479. from .project.task import get_prune_status
  480. ret = get_prune_status(data, SD.workspace)
  481. return ret
  482. if request.method == 'POST':
  483. from .project.task import start_prune_analysis
  484. ret = start_prune_analysis(data, SD.workspace, SD.monitored_processes)
  485. return ret
  486. if request.method == 'PUT':
  487. if data['act'] == 'stop':
  488. from .project.task import stop_prune_analysis
  489. ret = stop_prune_analysis(data, SD.workspace)
  490. return ret
  491. @app.route('/project/task/evaluate', methods=['GET', 'POST'])
  492. def task_evaluate():
  493. '''
  494. methods=='GET':#获取模型评估的结果
  495. Args:
  496. tid(str):任务id
  497. Return:
  498. evaluate_status(int): 任务状态(TaskStatus)枚举变量的值
  499. message(str):描述评估任务的信息
  500. result(dict):如果评估成功,返回评估结果的dict,否则为None
  501. status
  502. methods=='POST':#异步,创建一个评估任务
  503. Args:
  504. tid(str):任务id
  505. epoch(int,optional):需要评估的epoch,如果为None则会评估训练时指标最好的epoch
  506. topk(int,optional):分类任务topk指标,如果为None默认输入为5
  507. score_thresh(float):检测任务类别的score threshhold值,如果为None默认输入为0.5
  508. overlap_thresh(float):实例分割任务IOU threshhold值,如果为None默认输入为0.3
  509. Return:
  510. status
  511. '''
  512. data = request.get_json()
  513. if data is None:
  514. data = request.args
  515. if request.method == 'GET':
  516. from .project.task import get_evaluate_result
  517. ret = get_evaluate_result(data, SD.workspace)
  518. if ret['evaluate_status'] == TaskStatus.XEVALUATED and ret[
  519. 'result'] is not None:
  520. if 'Confusion_Matrix' in ret['result']:
  521. ret['result']['Confusion_Matrix'] = ret['result'][
  522. 'Confusion_Matrix'].tolist()
  523. ret['result'] = CustomEncoder().encode(ret['result'])
  524. ret['result'] = json.loads(ret['result'])
  525. ret['evaluate_status'] = ret['evaluate_status'].value
  526. return ret
  527. if request.method == 'POST':
  528. from .project.task import evaluate_model
  529. ret = evaluate_model(data, SD.workspace, SD.monitored_processes)
  530. return ret
  531. @app.route('/project/task/evaluate/file', methods=['GET'])
  532. def task_evaluate_file():
  533. data = request.get_json()
  534. if request.method == 'GET':
  535. if 'path' in data:
  536. ret = data['path']
  537. assert os.path.abspath(ret).startswith(
  538. os.path.abspath(SD.workspace_dir)
  539. ) and ".." not in ret, "Illegal path {}.".format(ret)
  540. return send_file(ret)
  541. else:
  542. from .project.task import get_evaluate_result
  543. from .project.task import import_evaluate_excel
  544. ret = get_evaluate_result(data, SD.workspace)
  545. if ret['evaluate_status'] == TaskStatus.XEVALUATED and ret[
  546. 'result'] is not None:
  547. result = ret['result']
  548. excel_ret = dict()
  549. excel_ret = import_evaluate_excel(data, result, SD.workspace)
  550. return excel_ret
  551. else:
  552. excel_ret = dict()
  553. excel_ret['path'] = None
  554. excel_ret['status'] = -1
  555. excel_ret['message'] = "评估尚未完成或评估失败"
  556. return excel_ret
  557. @app.route('/project/task/predict', methods=['GET', 'POST', 'PUT'])
  558. def task_predict():
  559. '''
  560. methods=='GET':#获取预测状态
  561. Args:
  562. tid(str):任务id
  563. Return:
  564. predict_status(int): 预测任务状态(PredictStatus)枚举变量的值
  565. message(str): 预测信息
  566. status
  567. methods=='POST':#创建预测任务,目前仅支持单张图片的预测
  568. Args:
  569. tid(str):任务id
  570. image_data(str):base64编码的image数据
  571. score_thresh(float,optional):可选,检测任务时有效,检测类别的score threashold值默认是0.5
  572. epoch(int,float,optional):可选,选择需要做预测的ephoch,默认为评估指标最好的那一个epoch
  573. Return:
  574. path(str):服务器上保存预测结果图片的路径
  575. status
  576. '''
  577. data = request.get_json()
  578. if data is None:
  579. data = request.args
  580. if request.method == 'GET':
  581. from .project.task import get_predict_status
  582. ret = get_predict_status(data, SD.workspace)
  583. return ret
  584. if request.method == 'POST':
  585. from .project.task import predict_test_pics
  586. ret = predict_test_pics(data, SD.workspace, SD.monitored_processes)
  587. if 'img_list' in data:
  588. del ret['path']
  589. return ret
  590. return ret
  591. if request.method == 'PUT':
  592. from .project.task import stop_predict_task
  593. ret = stop_predict_task(data, SD.workspace)
  594. return ret
  595. @app.route('/project/task/export', methods=['GET', 'POST', 'PUT'])
  596. def task_export():
  597. '''
  598. methods=='GET':#获取导出模型的状态
  599. Args:
  600. tid(str):任务id
  601. quant(str,optional)可选,[log,result],导出量模型导出状态,若值为log则返回量化的日志;若值为result则返回量化的结果
  602. Return:
  603. status
  604. if quant == 'log':
  605. quant_log(dict):量化日志
  606. if quant == 'result'
  607. quant_result(dict):量化结果
  608. if quant not in Args:
  609. export_status(int):模型导出状态(PredictStatus)枚举变量的值
  610. message(str):模型导出提示信息
  611. methods=='POST':#导出inference模型或者导出lite模型
  612. Args:
  613. tid(str):任务id
  614. type(str):保存模型的类别[infer,lite],支持inference模型导出和lite的模型导出
  615. save_dir(str):保存模型的路径
  616. epoch(str,optional)可选,指定导出的epoch数默认为评估效果最好的epoch
  617. quant(bool,optional)可选,type为infer有效,是否导出量化后的模型,默认为False
  618. model_path(str,optional)可选,type为lite时有效,inference模型的地址
  619. Return:
  620. status
  621. if type == 'infer':
  622. save_dir:模型保存路径
  623. if type == 'lite':
  624. message:模型保存信息
  625. methods=='PUT':#停止导出模型
  626. Args:
  627. tid(str):任务id
  628. Return:
  629. export_status(int):模型导出状态(PredictStatus)枚举变量的值
  630. message(str):停止模型导出提示信息
  631. status
  632. '''
  633. data = request.get_json()
  634. if data is None:
  635. data = request.args
  636. if request.method == 'GET':
  637. if 'quant' in data:
  638. if data['quant'] == 'log':
  639. from .project.task import get_quant_progress
  640. ret = get_quant_progress(data, SD.workspace)
  641. return ret
  642. if data['quant'] == 'result':
  643. from .project.task import get_quant_result
  644. ret = get_quant_result(data, SD.workspace)
  645. return ret
  646. from .project.task import get_export_status
  647. ret = get_export_status(data, SD.workspace)
  648. ret['export_status'] = ret['export_status'].value
  649. return ret
  650. if request.method == 'POST':
  651. if data['type'] == 'infer':
  652. from .project.task import export_infer_model
  653. ret = export_infer_model(data, SD.workspace,
  654. SD.monitored_processes)
  655. return ret
  656. if data['type'] == 'lite':
  657. from .project.task import export_lite_model
  658. ret = export_lite_model(data, SD.workspace)
  659. return ret
  660. if request.method == 'PUT':
  661. from .project.task import stop_export_task
  662. stop_export_task(data, SD.workspace)
  663. return ret
  664. @app.route('/project/task/vdl', methods=['GET'])
  665. def task_vdl():
  666. '''
  667. methods=='GET':#打开某个任务的可视化分析工具(VisualDL)
  668. Args:
  669. tid(str):任务id
  670. Return:
  671. url(str):vdl地址
  672. status
  673. '''
  674. data = request.get_json()
  675. if data is None:
  676. data = request.args
  677. if request.method == 'GET':
  678. from .project.task import open_vdl
  679. ret = open_vdl(data, SD.workspace, SD.current_port,
  680. SD.monitored_processes, SD.running_boards)
  681. return ret
  682. @app.route('/system', methods=['GET', 'DELETE'])
  683. def system():
  684. '''
  685. methods=='GET':#获取系统GPU、CPU信息
  686. Args:
  687. type(str):[machine_info,gpu_memory_size]选择需要获取的系统信息
  688. Return:
  689. status
  690. if type=='machine_info'
  691. info(dict):服务端信息
  692. if type=='gpu_memory_size'
  693. gpu_mem_infos(list):GPU内存信息
  694. '''
  695. data = request.get_json()
  696. if data is None:
  697. data = request.args
  698. if request.method == 'GET':
  699. if data['type'] == 'machine_info':
  700. '''if 'path' not in data:
  701. data['path'] = None
  702. from .system import get_machine_info
  703. ret = get_machine_info(data, SD.machine_info)'''
  704. from .system import get_system_info
  705. ret = get_system_info(SD.machine_info)
  706. return ret
  707. if data['type'] == 'gpu_memory_size':
  708. #from .system import get_gpu_memory_size
  709. from .system import get_gpu_memory_info
  710. ret = get_gpu_memory_info(SD.machine_info)
  711. return ret
  712. if request.method == 'DELETE':
  713. from .system import exit_system
  714. ret = exit_system(SD.monitored_processes)
  715. return ret
  716. @app.route('/demo', methods=['GET', 'POST', 'PUT'])
  717. def demo():
  718. '''
  719. methods=='GET':#获取demo下载进度
  720. Args:
  721. prj_type(int):项目类型ProjectType枚举变量的int值
  722. Return:
  723. status
  724. attr(dict):demo下载信息
  725. methods=='POST':#下载或创建demo工程
  726. Args:
  727. type(str):{download,load}下载或者创建样例
  728. prj_type(int):项目类型ProjectType枚举变量的int值
  729. Return:
  730. status
  731. if type=='load':
  732. did:数据集id
  733. pid:项目id
  734. methods=='PUT':#停止下载或创建demo工程
  735. Args:
  736. prj_type(int):项目类型ProjectType枚举变量的int值
  737. Return:
  738. status
  739. '''
  740. data = request.get_json()
  741. if data is None:
  742. data = request.args
  743. if request.method == 'GET':
  744. from .demo import get_download_demo_progress
  745. ret = get_download_demo_progress(data, SD.workspace)
  746. return ret
  747. if request.method == 'POST':
  748. if data['type'] == 'download':
  749. from .demo import download_demo_dataset
  750. ret = download_demo_dataset(data, SD.workspace,
  751. SD.load_demo_proc_dict)
  752. return ret
  753. if data['type'] == 'load':
  754. from .demo import load_demo_project
  755. ret = load_demo_project(data, SD.workspace, SD.monitored_processes,
  756. SD.load_demo_proj_data_dict,
  757. SD.load_demo_proc_dict)
  758. return ret
  759. if request.method == 'PUT':
  760. from .demo import stop_import_demo
  761. ret = stop_import_demo(data, SD.workspace, SD.load_demo_proc_dict,
  762. SD.load_demo_proj_data_dict)
  763. return ret
  764. @app.route('/model', methods=['GET', 'POST', 'DELETE'])
  765. def model():
  766. '''
  767. methods=='GET':#获取一个或者所有模型的信息
  768. Args:
  769. mid(str,optional)可选,若存在则返回某个模型的信息
  770. type(str,optional)可选,[pretrained,exported].若存在则返回对应类型下所有的模型信息
  771. Return:
  772. status
  773. if mid in Args:
  774. dataset_attr(dict):数据集属性
  775. task_params(dict):模型训练参数
  776. eval_result(dict):模型评估结果
  777. if type in Args and type == 'pretrained':
  778. pretrained_models(list):所有预训练模型信息
  779. if type in Args and type == 'exported':
  780. exported_models(list):所有inference模型的信息
  781. methods=='POST':#创建一个模型
  782. Args:
  783. pid(str):项目id
  784. tid(str):任务id
  785. name(str):模型名字
  786. type(str):创建模型的类型,[pretrained,exported],pretrained代表创建预训练模型、exported代表创建inference或者lite模型
  787. source_path(str):仅type为pretrained时有效,训练好的模型的路径
  788. path(str):仅type为exported时有效,inference或者lite模型的路径
  789. exported_type(int):0为inference模型,1为lite模型
  790. eval_results(dict,optional):可选,仅type为pretrained时有效,模型评估的指标
  791. Return:
  792. status
  793. if type == 'pretrained':
  794. pmid(str):预训练模型id
  795. if type == 'exported':
  796. emid(str):inference模型id
  797. methods=='DELETE':删除一个模型
  798. Args:
  799. type(str):删除模型的类型,[pretrained,exported],pretrained代表创建预训练模型、exported代表创建inference或者lite模型
  800. if type='pretrained':
  801. pmid:预训练模型id
  802. if type='exported':
  803. emid:inference或者lite模型id
  804. Return:
  805. status
  806. '''
  807. data = request.get_json()
  808. if data is None:
  809. data = request.args
  810. if request.method == 'GET':
  811. if 'type' in data:
  812. if data['type'] == 'pretrained':
  813. from .model import list_pretrained_models
  814. ret = list_pretrained_models(SD.workspace)
  815. return ret
  816. if data['type'] == 'exported':
  817. from .model import list_exported_models
  818. ret = list_exported_models(SD.workspace)
  819. return ret
  820. from .model import get_model_details
  821. ret = get_model_details(data, SD.workspace)
  822. ret['eval_result']['Confusion_Matrix'] = ret['eval_result'][
  823. 'Confusion_Matrix'].tolist()
  824. ret['eval_result'] = CustomEncoder().encode(ret['eval_result'])
  825. ret['task_params'] = CustomEncoder().encode(ret['task_params'])
  826. return ret
  827. if request.method == 'POST':
  828. if data['type'] == 'pretrained':
  829. if 'eval_results' in data:
  830. data['eval_results']['Confusion_Matrix'] = np.array(data[
  831. 'eval_results']['Confusion_Matrix'])
  832. from .model import create_pretrained_model
  833. ret = create_pretrained_model(data, SD.workspace,
  834. SD.monitored_processes)
  835. return ret
  836. if data['type'] == 'exported':
  837. from .model import create_exported_model
  838. ret = create_exported_model(data, SD.workspace)
  839. return ret
  840. if request.method == 'DELETE':
  841. if data['type'] == 'pretrained':
  842. from .model import delete_pretrained_model
  843. ret = delete_pretrained_model(data, SD.workspace)
  844. return ret
  845. if data['type'] == 'exported':
  846. from .model import delete_exported_model
  847. ret = delete_exported_model(data, SD.workspace)
  848. return ret
  849. @app.route('/model/file', methods=['GET'])
  850. def model_file():
  851. data = request.get_json()
  852. if request.method == 'GET':
  853. ret = data['path']
  854. assert os.path.abspath(ret).startswith(
  855. os.path.abspath(SD.workspace_dir)
  856. ) and ".." not in ret, "Illegal path {}.".format(ret)
  857. return send_file(ret)
  858. @app.route('/', methods=['GET'])
  859. def gui():
  860. if request.method == 'GET':
  861. file_path = osp.join(
  862. osp.dirname(__file__), 'templates', 'paddlex_restful_demo.html')
  863. ip = get_ip()
  864. url = 'var str_srv_url = "http://' + ip + ':' + str(SD.port) + '";'
  865. f = open(file_path, 'r+', encoding=get_encoding(file_path))
  866. lines = f.readlines()
  867. for i, line in enumerate(lines):
  868. if '0.0.0.0:8080' in line:
  869. lines[i] = url
  870. break
  871. f.close()
  872. f = open(file_path, 'w+', encoding=get_encoding(file_path))
  873. f.writelines(lines)
  874. f.close()
  875. return render_template('/paddlex_restful_demo.html')
  876. def run(port, workspace_dir):
  877. if workspace_dir is None:
  878. user_home = os.path.expanduser('~')
  879. dirname = osp.join(user_home, "paddlex_workspace")
  880. else:
  881. dirname = workspace_dir
  882. if not osp.exists(dirname):
  883. os.makedirs(dirname)
  884. else:
  885. if not osp.isdir(dirname):
  886. os.remove(dirname)
  887. os.makedirs(dirname)
  888. logger = get_logger(osp.join(dirname, "mcessages.log"))
  889. init(dirname, logger)
  890. SD.port = port
  891. ip = get_ip()
  892. url = ip + ':' + str(port)
  893. try:
  894. logger.info("RESTful服务启动成功后,您可以在浏览器打开 {} 使用WEB版本GUI".format(url))
  895. app.run(host='0.0.0.0', port=port, threaded=True)
  896. except:
  897. print("服务启动不成功,请确保端口号:{}未被防火墙限制".format(port))