# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import workspace_pb2 as w
from .utils import get_logger
from .dir import *
import os
import os.path as osp
from threading import Thread
import traceback
import platform
import configparser
import time
import shutil
import copy


class Workspace():
    def __init__(self, workspace, dirname, logger):
        self.workspace = workspace
        #self.machine_info = {}
        # App init
        self.init_app_resource(dirname)
        # Current workspace format version
        self.current_version = "0.2.0"
        self.logger = logger
        # Set the download/storage path for PaddleX pretrained models;
        # once set, the same model will not be downloaded twice.
        self.load_workspace()
        self.stop_running = False
        self.sync_thread = self.sync_with_local(interval=2)
        # Check the hardware environment
        #self.check_hardware_env()
    def init_app_resource(self, dirname):
        self.m_cfgfile = configparser.ConfigParser()
        app_conf_file_name = "PaddleX".lower() + ".cfg"
        paddlex_cfg_file = os.path.join(PADDLEX_HOME, app_conf_file_name)
        try:
            self.m_cfgfile.read(paddlex_cfg_file)
        except Exception as e:
            print("[ERROR] Failed to read {}".format(paddlex_cfg_file))
        # Guard on the section, not the option: add_section() raises
        # DuplicateSectionError if USERCFG already exists in the file
        if not self.m_cfgfile.has_section("USERCFG"):
            self.m_cfgfile.add_section("USERCFG")
        self.m_cfgfile["USERCFG"]["workspacedir"] = dirname
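    # For reference, the resulting paddlex.cfg is a plain INI file written
    # by configparser (the path value shown is illustrative):
    #
    #   [USERCFG]
    #   workspacedir = /path/to/workspace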
    def load_workspace(self):
        path = self.workspace.path
        newest_file = osp.join(self.workspace.path, 'workspace.newest.pb')
        bak_file = osp.join(self.workspace.path, 'workspace.bak.pb')
        flag_file = osp.join(self.workspace.path, '.pb.success')
        self.workspace.version = self.current_version
        try:
            if osp.exists(flag_file):
                # The flag file marks workspace.newest.pb as fully written
                with open(newest_file, 'rb') as f:
                    self.workspace.ParseFromString(f.read())
            elif osp.exists(bak_file):
                with open(bak_file, 'rb') as f:
                    self.workspace.ParseFromString(f.read())
            else:
                print("This is a new workspace")
        except Exception as e:
            print(traceback.format_exc())
        self.workspace.path = path
        if self.workspace.version < "0.2.0":
            self.update_workspace()
        self.recover_workspace()
    def update_workspace(self):
        # Nothing to migrate in an empty workspace
        if len(self.workspace.projects) == 0 and len(
                self.workspace.datasets) == 0:
            self.workspace.version = '0.2.0'
            return
        # Write a per-entity info.pb so each dataset/project/task can be
        # recovered independently (see recover_workspace)
        for key in self.workspace.datasets:
            ds = self.workspace.datasets[key]
            try:
                info_file = os.path.join(ds.path, 'info.pb')
                with open(info_file, 'wb') as f:
                    f.write(ds.SerializeToString())
            except Exception as e:
                self.logger.info(traceback.format_exc())
        for key in self.workspace.projects:
            pj = self.workspace.projects[key]
            try:
                info_file = os.path.join(pj.path, 'info.pb')
                with open(info_file, 'wb') as f:
                    f.write(pj.SerializeToString())
            except Exception as e:
                self.logger.info(traceback.format_exc())
        for key in self.workspace.tasks:
            task = self.workspace.tasks[key]
            try:
                info_file = os.path.join(task.path, 'info.pb')
                with open(info_file, 'wb') as f:
                    f.write(task.SerializeToString())
            except Exception as e:
                self.logger.info(traceback.format_exc())
        self.workspace.version = '0.2.0'
    def recover_workspace(self):
        if len(self.workspace.projects) > 0 or len(
                self.workspace.datasets) > 0:
            return
        projects_dir = os.path.join(self.workspace.path, 'projects')
        datasets_dir = os.path.join(self.workspace.path, 'datasets')
        if not os.path.exists(projects_dir):
            os.makedirs(projects_dir)
        if not os.path.exists(datasets_dir):
            os.makedirs(datasets_dir)
        max_project_id = 0
        max_dataset_id = 0
        max_task_id = 0
        for pd in os.listdir(projects_dir):
            try:
                # Project directories are named 'P<id>'
                if pd[0] != 'P':
                    continue
                if int(pd[1:]) > max_project_id:
                    max_project_id = int(pd[1:])
            except (IndexError, ValueError):
                continue
            info_pb_file = os.path.join(projects_dir, pd, 'info.pb')
            if not os.path.exists(info_pb_file):
                continue
            try:
                pj = w.Project()
                with open(info_pb_file, 'rb') as f:
                    pj.ParseFromString(f.read())
                self.workspace.projects[pd].CopyFrom(pj)
            except Exception as e:
                self.logger.info(traceback.format_exc())
            for td in os.listdir(os.path.join(projects_dir, pd)):
                try:
                    # Task directories are named 'T<id>'
                    if td[0] != 'T':
                        continue
                    if int(td[1:]) > max_task_id:
                        max_task_id = int(td[1:])
                except (IndexError, ValueError):
                    continue
                info_pb_file = os.path.join(projects_dir, pd, td, 'info.pb')
                if not os.path.exists(info_pb_file):
                    continue
                try:
                    task = w.Task()
                    with open(info_pb_file, 'rb') as f:
                        task.ParseFromString(f.read())
                    self.workspace.tasks[td].CopyFrom(task)
                except Exception as e:
                    self.logger.info(traceback.format_exc())
        for dd in os.listdir(datasets_dir):
            try:
                # Dataset directories are named 'D<id>'
                if dd[0] != 'D':
                    continue
                if int(dd[1:]) > max_dataset_id:
                    max_dataset_id = int(dd[1:])
            except (IndexError, ValueError):
                continue
            info_pb_file = os.path.join(datasets_dir, dd, 'info.pb')
            if not os.path.exists(info_pb_file):
                continue
            try:
                ds = w.Dataset()
                with open(info_pb_file, 'rb') as f:
                    ds.ParseFromString(f.read())
                self.workspace.datasets[dd].CopyFrom(ds)
            except Exception as e:
                self.logger.info(traceback.format_exc())
        self.workspace.max_dataset_id = max_dataset_id
        self.workspace.max_project_id = max_project_id
        self.workspace.max_task_id = max_task_id
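    # Illustrative on-disk layout scanned by recover_workspace above
    # (IDs are examples; task directories live inside their project):
    #
    #   <workspace>/projects/P1/info.pb
    #   <workspace>/projects/P1/T1/info.pb
    #   <workspace>/datasets/D1/info.pb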
    # Sync the workspace to local files every `interval` seconds
    def sync_with_local(self, interval=2):
        def sync_func(s, interval_seconds=2):
            newest_file = osp.join(self.workspace.path, 'workspace.newest.pb')
            stable_file = osp.join(self.workspace.path, 'workspace.stable.pb')
            bak_file = osp.join(self.workspace.path, 'workspace.bak.pb')
            flag_file = osp.join(self.workspace.path, '.pb.success')
            while True:
                current_time = time.time()
                time_array = time.localtime(current_time)
                current_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array)
                self.workspace.current_time = current_time
                # Drop the success flag while newest_file is being rewritten
                if osp.exists(flag_file):
                    os.remove(flag_file)
                with open(newest_file, mode='wb') as f:
                    f.write(s.workspace.SerializeToString())
                open(flag_file, 'w').close()
                # Rotate: previous stable becomes the backup, newest becomes stable
                if osp.exists(stable_file):
                    shutil.copyfile(stable_file, bak_file)
                shutil.copyfile(newest_file, stable_file)
                if s.stop_running:
                    break
                time.sleep(interval_seconds)

        t = Thread(target=sync_func, args=(self, interval))
        t.start()
        return t
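    # Crash-safety contract shared with load_workspace: '.pb.success' is
    # deleted before 'workspace.newest.pb' is rewritten and recreated once
    # the write finishes, so the flag's presence means the newest snapshot
    # is complete; otherwise load_workspace falls back to 'workspace.bak.pb'.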
    def check_hardware_env(self):
        # Check whether a GPU is present and whether CPU settings exist
        hasGpu = True
        try:
            '''data = {'path': path}
            from .system import get_machine_info
            info = get_machine_info(data, self.machine_info)['info']
            if info is None:
                return
            if (info['gpu_num'] == 0 and self.sysstr == "Windows"):
                data['path'] = os.path.abspath(os.path.dirname(__file__))
                info = get_machine_info(data, self.machine_info)['info']'''
            from .system import get_system_info
            info = get_system_info()['info']
            hasGpu = (info['gpu_num'] > 0)
            self.machine_info = info
            #driver_ver = info['driver_version']
            #driver_ver_list = driver_ver.split(".")
            #major_ver, minor_ver = driver_ver_list[0:2]
            #if sysstr == "Windows":
            #    if int(major_ver) < 411 or \
            #            (int(major_ver) == 411 and int(minor_ver) < 31):
            #        raise Exception("The GPU driver version should be larger than 411.31")
            #elif sysstr == "Linux":
            #    if int(major_ver) < 410 or \
            #            (int(major_ver) == 410 and int(minor_ver) < 48):
            #        raise Exception("The GPU driver version should be larger than 410.48")
        except Exception as e:
            hasGpu = False
        self.m_HasGpu = hasGpu
        self.save_app_cfg_file()
    def save_app_cfg_file(self):
        # Persist the application configuration
        app_conf_file_name = 'PaddleX'.lower() + ".cfg"
        with open(os.path.join(PADDLEX_HOME, app_conf_file_name),
                  'w+') as file:
            self.m_cfgfile.write(file)

def init_workspace(workspace, dirname, logger):
    wp = Workspace(workspace, dirname, logger)
    #if not machine_info:
    #    machine_info.update(wp.machine_info)
    return {'status': 1}
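# Usage sketch (illustrative, not part of the module): init_workspace expects
# a generated workspace_pb2.Workspace message and any logger object with a
# logging.Logger-style `info` method.
#
#   import logging
#   ws = w.Workspace()
#   init_workspace(ws, '/path/to/workspace', logging.getLogger('paddlex'))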
def set_attr(data, workspace):
    """Modify attributes of a dataset, project or task in the workspace.

    Args:
        data (dict): keys include
            'struct': the entity type, one of 'dataset', 'project' or 'task';
            'id': the id of the entity to look up;
            'attr_dict': a dict mapping attribute names to their new values.
    """
    struct = data['struct']
    id = data['id']
    assert struct in ['dataset', 'project', 'task'
                      ], "struct must be one of dataset, project or task"
    if struct == 'dataset':
        assert id in workspace.datasets, "Dataset ID '{}' does not exist".format(id)
        modify_struct = workspace.datasets[id]
    elif struct == 'project':
        assert id in workspace.projects, "Project ID '{}' does not exist".format(id)
        modify_struct = workspace.projects[id]
    elif struct == 'task':
        assert id in workspace.tasks, "Task ID '{}' does not exist".format(id)
        modify_struct = workspace.tasks[id]
    for k, v in data['attr_dict'].items():
        assert hasattr(modify_struct,
                       k), "{} has no attribute '{}'".format(type(modify_struct), k)
        setattr(modify_struct, k, v)
    # Persist the change alongside the entity
    with open(os.path.join(modify_struct.path, 'info.pb'), 'wb') as f:
        f.write(modify_struct.SerializeToString())
    return {'status': 1}
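# Usage sketch for set_attr (illustrative; attribute names depend on the
# fields defined in the workspace_pb2 schema):
#
#   set_attr({'struct': 'dataset', 'id': 'D1',
#             'attr_dict': {'name': 'renamed_dataset'}}, ws)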
def get_attr(data, workspace):
    """Read attributes of a dataset, project or task in the workspace.

    Args:
        data (dict): keys include
            'struct': the entity type, one of 'dataset', 'project' or 'task';
            'id': the id of the entity to look up;
            'attr_list': the list of attribute names to fetch.
    """
    struct = data['struct']
    id = data['id']
    assert struct in ['dataset', 'project', 'task'
                      ], "struct must be one of dataset, project or task"
    if struct == 'dataset':
        assert id in workspace.datasets, "Dataset ID '{}' does not exist".format(id)
        modify_struct = workspace.datasets[id]
    elif struct == 'project':
        assert id in workspace.projects, "Project ID '{}' does not exist".format(id)
        modify_struct = workspace.projects[id]
    elif struct == 'task':
        assert id in workspace.tasks, "Task ID '{}' does not exist".format(id)
        modify_struct = workspace.tasks[id]
    attr = {}
    for k in data['attr_list']:
        if k in ['id', 'struct']:
            continue
        assert hasattr(modify_struct,
                       k), "{} has no attribute '{}'".format(type(modify_struct), k)
        v = getattr(modify_struct, k)
        attr[k] = v
    return {'status': 1, 'attr': attr}
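# Usage sketch for get_attr (illustrative; attribute names depend on the
# fields defined in the workspace_pb2 schema):
#
#   result = get_attr({'struct': 'project', 'id': 'P1',
#                      'attr_list': ['name', 'path']}, ws)
#   print(result['attr'])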