model.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import time
  15. import os
  16. import shutil
  17. import pickle
  18. from os import path as osp
  19. from .utils import set_folder_status, TaskStatus, copy_pretrained_model, PretrainedModelStatus
  20. from . import workspace_pb2 as w
  def list_pretrained_models(workspace):
      """List the pretrained models registered in the workspace.

      Args:
          workspace: Object whose ``pretrained_models`` mapping is keyed by
              model id. Each entry exposes the attributes ``id``, ``name``,
              ``model``, ``type``, ``pid``, ``tid``, ``create_time`` and
              ``path``.

      Returns:
          dict: ``{'status': 1, 'pretrained_models': [...]}`` where the list
          holds one plain dict per registered model, carrying the attributes
          listed above.
      """
      # Attributes copied verbatim from every registry entry into its summary.
      exposed_fields = ('id', 'name', 'model', 'type', 'pid', 'tid',
                        'create_time', 'path')

      def _summarize(entry):
          return {field: getattr(entry, field) for field in exposed_fields}

      registry = workspace.pretrained_models
      summaries = [_summarize(registry[key]) for key in registry]
      return {'status': 1, "pretrained_models": summaries}
  def create_pretrained_model(data, workspace, monitored_processes):
      """Create a pretrained-model record from a request.

      Args:
          data (dict): Request payload with the keys ``'pid'`` (owning
              project id), ``'tid'`` (owning task id), ``'name'`` (name of
              the pretrained model), ``'source_path'`` (path of the original
              model) and, optionally, ``'eval_results'`` (evaluation results
              to store with the model).
          workspace: Workspace object. Its ``projects``, ``tasks`` and
              ``datasets`` mappings and its ``path`` are read; its
              ``max_pretrained_model_id`` counter and ``pretrained_models``
              mapping are updated.
          monitored_processes: Object with a ``put`` method; the ``pid`` of
              the object returned by ``copy_pretrained_model`` is put on it.

      Returns:
          dict: ``{'status': 1, 'pmid': <id of the new pretrained model>}``.

      Raises:
          KeyError: If a required key is missing from ``data``.
          AssertionError: If the project id or task id is unknown, the
              source path does not exist, or the generated id is already
              taken.
      """
      pid = data['pid']
      tid = data['tid']
      name = data['name']
      source_path = data['source_path']

      # Validate the request BEFORE touching the id counter, so that a
      # rejected request does not permanently consume a model id.
      assert pid in workspace.projects, "【预训练模型创建】项目ID'{}'不存在.".format(pid)
      assert tid in workspace.tasks, "【预训练模型创建】任务ID'{}'不存在.".format(tid)
      assert osp.exists(source_path), "原模型路径不存在: {}".format(source_path)

      create_time = time.strftime("%Y-%m-%d %H:%M:%S",
                                  time.localtime(time.time()))

      # Allocate the next id: zero-padded to four digits below 10000, plain
      # decimal from there on ("%04d" only sets a minimum width).
      sequence = workspace.max_pretrained_model_id + 1
      workspace.max_pretrained_model_id = sequence
      pmid = 'PM%04d' % sequence
      assert pmid not in workspace.pretrained_models, "【预训练模型创建】预训练模型'{}'已经被占用.".format(
          pmid)

      path = osp.join(workspace.path, 'pretrain', pmid)
      # exist_ok avoids the check-then-create race of exists() + makedirs().
      os.makedirs(path, exist_ok=True)
      set_folder_status(path, PretrainedModelStatus.XPINIT)

      # Imported lazily, as in the original code (presumably to avoid a
      # circular import between this module and .project.task -- confirm).
      from .project.task import get_task_params
      ret = get_task_params({'tid': tid}, workspace)
      train_params = ret['train']

      # Human-readable structure name: "<model>[-<backbone>][-WITH_FPN]".
      model_structure = train_params.model
      if hasattr(train_params, "backbone"):
          model_structure = "{}-{}".format(model_structure,
                                           train_params.backbone)
      if hasattr(train_params, "with_fpn") and train_params.with_fpn:
          model_structure = "{}-{}".format(model_structure, "WITH_FPN")

      pm = w.PretrainedModel(
          id=pmid,
          name=name,
          model=model_structure,
          type=workspace.projects[pid].type,
          pid=pid,
          tid=tid,
          create_time=create_time,
          path=path)
      workspace.pretrained_models[pmid].CopyFrom(pm)

      # Save the evaluation results. They are written into source_path, not
      # into the new model folder (presumably so that the copy at the end
      # carries them along -- TODO confirm).
      if 'eval_results' in data:
          with open(osp.join(source_path, "eval_res.pkl"), "wb") as f:
              pickle.dump(data['eval_results'], f)

      # Copy the training-parameter file of the task, if there is one.
      task_params_path = osp.join(workspace.tasks[tid].path, 'params.pkl')
      if osp.exists(task_params_path):
          shutil.copy(task_params_path, path)

      # Copy the dataset statistics file, if there is one.
      did = workspace.projects[pid].did
      dataset = workspace.datasets[did]
      dataset_info_path = osp.join(dataset.path, "statis.pkl")
      if osp.exists(dataset_info_path):
          # Stamp the dataset name/description into the statistics file
          # (this rewrites the dataset's own file in place) before copying.
          with open(dataset_info_path, "rb") as f:
              dataset_info_dict = pickle.load(f)
          dataset_info_dict['name'] = dataset.name
          dataset_info_dict['desc'] = dataset.desc
          with open(dataset_info_path, "wb") as f:
              pickle.dump(dataset_info_dict, f)
          shutil.copy(dataset_info_path, path)

      # Copy from source_path to path; the helper returns an object whose
      # pid is handed to the caller's process monitor.
      proc = copy_pretrained_model(source_path, path)
      monitored_processes.put(proc.pid)
      return {'status': 1, 'pmid': pmid}
  def delete_pretrained_model(data, workspace):
      """Delete a pretrained model and its folder.

      Args:
          data (dict): Request payload; ``data['pmid']`` is the id of the
              pretrained model to delete.
          workspace: Object whose ``pretrained_models`` mapping holds the
              model records (each with a ``path`` attribute).

      Returns:
          dict: ``{'status': 1}``.

      Raises:
          AssertionError: If the id is not registered in the workspace.
      """
      pmid = data['pmid']
      registry = workspace.pretrained_models
      assert pmid in registry, "预训练模型ID'{}'不存在.".format(pmid)

      # Remove the on-disk folder first (if it is still there), then the record.
      model_dir = registry[pmid].path
      if osp.exists(model_dir):
          shutil.rmtree(model_dir)
      del registry[pmid]
      return {'status': 1}
  def create_exported_model(data, workspace):
      """Create an exported (published) model record from a request.

      Args:
          data (dict): Request payload with the keys ``'pid'`` (owning
              project id), ``'tid'`` (owning task id), ``'name'`` (name of
              the exported model), ``'path'`` (model path) and
              ``'exported_type'`` (type of the exported model).
          workspace: Workspace object. Its ``projects``, ``tasks`` and
              ``datasets`` mappings are read; its ``max_exported_model_id``
              counter and ``exported_models`` mapping are updated.

      Returns:
          dict: ``{'status': 1, 'emid': <id of the new exported model>}``.

      Raises:
          KeyError: If a required key is missing from ``data``.
          AssertionError: If the project id or task id is unknown, or the
              generated id is already taken.
      """
      pid = data['pid']
      tid = data['tid']
      name = data['name']
      path = data['path']
      exported_type = data['exported_type']

      # Validate the request BEFORE touching the id counter, so that a
      # rejected request does not permanently consume a model id.
      assert pid in workspace.projects, "【已发布模型创建】项目ID'{}'不存在.".format(pid)
      assert tid in workspace.tasks, "【已发布模型创建】任务ID'{}'不存在.".format(tid)

      create_time = time.strftime("%Y-%m-%d %H:%M:%S",
                                  time.localtime(time.time()))

      # Allocate the next id: zero-padded to four digits below 10000, plain
      # decimal from there on ("%04d" only sets a minimum width).
      sequence = workspace.max_exported_model_id + 1
      workspace.max_exported_model_id = sequence
      emid = 'EM%04d' % sequence
      assert emid not in workspace.exported_models, "【已发布模型创建】已发布模型'{}'已经被占用.".format(
          emid)

      # The target folder is created on demand rather than required to exist.
      # exist_ok avoids the check-then-create race of exists() + makedirs().
      os.makedirs(path, exist_ok=True)

      task_path = workspace.tasks[tid].path
      # Copy the evaluation results and the training-parameter file of the
      # task into the model folder, whichever of them exist.
      for filename in ('eval_res.pkl', 'params.pkl'):
          candidate = osp.join(task_path, filename)
          if osp.exists(candidate):
              shutil.copy(candidate, path)

      # Copy the dataset statistics file, if there is one.
      did = workspace.projects[pid].did
      dataset = workspace.datasets[did]
      dataset_info_path = osp.join(dataset.path, "statis.pkl")
      if osp.exists(dataset_info_path):
          # Stamp the dataset name/description into the statistics file
          # (this rewrites the dataset's own file in place) before copying.
          with open(dataset_info_path, "rb") as f:
              dataset_info_dict = pickle.load(f)
          dataset_info_dict['name'] = dataset.name
          dataset_info_dict['desc'] = dataset.desc
          with open(dataset_info_path, "wb") as f:
              pickle.dump(dataset_info_dict, f)
          shutil.copy(dataset_info_path, path)

      # Imported lazily, as in the original code (presumably to avoid a
      # circular import between this module and .project.task -- confirm).
      from .project.task import get_task_params
      ret = get_task_params({'tid': tid}, workspace)
      train_params = ret['train']

      # Human-readable structure name: "<model>[-<backbone>][-WITH_FPN]".
      model_structure = train_params.model
      if hasattr(train_params, "backbone"):
          model_structure = "{}-{}".format(model_structure,
                                           train_params.backbone)
      if hasattr(train_params, "with_fpn") and train_params.with_fpn:
          model_structure = "{}-{}".format(model_structure, "WITH_FPN")

      em = w.ExportedModel(
          id=emid,
          name=name,
          model=model_structure,
          type=workspace.projects[pid].type,
          pid=pid,
          tid=tid,
          create_time=create_time,
          path=path,
          exported_type=exported_type)
      workspace.exported_models[emid].CopyFrom(em)
      return {'status': 1, 'emid': emid}
  def list_exported_models(workspace):
      """List the exported models registered in the workspace.

      A model whose task id is still present in ``workspace.tasks`` is listed
      only when ``get_export_status`` reports ``TaskStatus.XEXPORTED`` for
      that task. A model whose task id is no longer present is always listed.

      Args:
          workspace: Object exposing the ``exported_models`` and ``tasks``
              mappings. Each exported-model entry exposes ``id``, ``name``,
              ``model``, ``type``, ``pid``, ``tid``, ``create_time``,
              ``path`` and ``exported_type``.

      Returns:
          dict: ``{'status': 1, 'exported_models': [...]}`` with one plain
          dict per listed model, carrying the attributes named above.
      """
      # Attributes copied verbatim from every registry entry into its summary.
      exposed_fields = ('id', 'name', 'model', 'type', 'pid', 'tid',
                        'create_time', 'path', 'exported_type')
      registry = workspace.exported_models
      visible = []
      for key in registry:
          entry = registry[key]
          summary = {field: getattr(entry, field) for field in exposed_fields}
          task_id = summary['tid']

          # The owning task is gone: nothing to check, list the model as is.
          if task_id not in workspace.tasks:
              visible.append(summary)
              continue

          # Task still exists: list the model only once its export finished.
          from .project.task import get_export_status
          export_info = get_export_status({'tid': task_id}, workspace)
          if export_info['export_status'] == TaskStatus.XEXPORTED:
              visible.append(summary)
      return {'status': 1, "exported_models": visible}
  def delete_exported_model(data, workspace):
      """Delete an exported model and its folder.

      Args:
          data (dict): Request payload; ``data['emid']`` is the id of the
              exported model to delete.
          workspace: Object whose ``exported_models`` mapping holds the
              model records (each with a ``path`` attribute).

      Returns:
          dict: ``{'status': 1}``.

      Raises:
          AssertionError: If the id is not registered in the workspace.
      """
      emid = data['emid']
      registry = workspace.exported_models
      assert emid in registry, "已发布模型模型ID'{}'不存在.".format(emid)

      # Remove the on-disk folder first (if it is still there), then the record.
      model_dir = registry[emid].path
      if osp.exists(model_dir):
          shutil.rmtree(model_dir)
      del registry[emid]
      return {'status': 1}
  def get_model_details(data, workspace):
      """Return the stored details of a pretrained or exported model.

      The model folder is looked up first among the pretrained models and
      then among the exported models. Three pickle files are read from it:
      ``statis.pkl`` (dataset statistics), ``params.pkl`` (task parameters)
      and ``eval_res.pkl`` (evaluation results).

      Args:
          data (dict): Request payload; ``data['mid']`` is the model id.
          workspace: Object exposing the ``pretrained_models`` and
              ``exported_models`` mappings, whose entries have a ``path``.

      Returns:
          dict: ``{'status': 1, 'dataset_attr': {...}, 'task_params': ...,
          'eval_result': ...}``. ``dataset_attr`` holds the dataset ``name``,
          ``desc`` and ``labels`` plus the number of train/val/test files.

      Raises:
          ValueError: If the id matches neither a pretrained nor an exported
              model. (The original code raised a bare string, which Python 3
              rejects with a TypeError before the message is ever shown.)
          OSError: If one of the three pickle files cannot be opened.
      """
      mid = data['mid']
      if mid in workspace.pretrained_models:
          model_path = workspace.pretrained_models[mid].path
      elif mid in workspace.exported_models:
          model_path = workspace.exported_models[mid].path
      else:
          raise ValueError("模型{}不存在".format(mid))

      def _load_pickle(filename):
          # NOTE: pickle executes arbitrary code on load; these files must
          # only ever come from the trusted workspace.
          # The context manager closes the handle even if unpickling fails.
          with open(osp.join(model_path, filename), 'rb') as f:
              return pickle.load(f)

      dataset_info = _load_pickle('statis.pkl')
      dataset_attr = {
          'name': dataset_info['name'],
          'desc': dataset_info['desc'],
          'labels': dataset_info['labels'],
          'train_num': len(dataset_info['train_files']),
          'val_num': len(dataset_info['val_files']),
          'test_num': len(dataset_info['test_files'])
      }
      task_params = _load_pickle('params.pkl')
      eval_result = _load_pickle('eval_res.pkl')
      return {
          'status': 1,
          'dataset_attr': dataset_attr,
          'task_params': task_params,
          'eval_result': eval_result
      }