# workspace.py
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from . import workspace_pb2 as w
  15. from .utils import get_logger
  16. from .dir import *
  17. import os
  18. import os.path as osp
  19. from threading import Thread
  20. import traceback
  21. import platform
  22. import configparser
  23. import time
  24. import shutil
  25. import copy
  26. from paddlex.utils import get_encoding
  27. class Workspace():
  28. def __init__(self, workspace, dirname, logger):
  29. self.workspace = workspace
  30. #self.machine_info = {}
  31. # app init
  32. self.init_app_resource(dirname)
  33. # 当前workspace版本
  34. self.current_version = "0.2.0"
  35. self.logger = logger
  36. # 设置PaddleX的预训练模型下载存储路径
  37. # 设置路径后不会重复下载相同模型
  38. self.load_workspace()
  39. self.stop_running = False
  40. self.sync_thread = self.sync_with_local(interval=2)
  41. #检查硬件环境
  42. #self.check_hardware_env()
  43. def init_app_resource(self, dirname):
  44. self.m_cfgfile = configparser.ConfigParser()
  45. app_conf_file_name = "PaddleX".lower() + ".cfg"
  46. paddlex_cfg_file = os.path.join(PADDLEX_HOME, app_conf_file_name)
  47. try:
  48. self.m_cfgfile.read(paddlex_cfg_file)
  49. except Exception as e:
  50. print("[ERROR] Fail to read {}".format(paddlex_cfg_file))
  51. if not self.m_cfgfile.has_option("USERCFG", "workspacedir"):
  52. self.m_cfgfile.add_section("USERCFG")
  53. self.m_cfgfile.set("USERCFG", "workspacedir", "")
  54. self.m_cfgfile["USERCFG"]["workspacedir"] = dirname
  55. def load_workspace(self):
  56. path = self.workspace.path
  57. newest_file = osp.join(self.workspace.path, 'workspace.newest.pb')
  58. bak_file = osp.join(self.workspace.path, 'workspace.bak.pb')
  59. flag_file = osp.join(self.workspace.path, '.pb.success')
  60. self.workspace.version = self.current_version
  61. try:
  62. if osp.exists(flag_file):
  63. with open(newest_file, 'rb') as f:
  64. self.workspace.ParseFromString(f.read())
  65. elif osp.exists(bak_file):
  66. with open(bak_file, 'rb') as f:
  67. self.workspace.ParseFromString(f.read())
  68. else:
  69. print("it is a new workspace")
  70. except Exception as e:
  71. print(traceback.format_exc())
  72. self.workspace.path = path
  73. if self.workspace.version < "0.2.0":
  74. self.update_workspace()
  75. self.recover_workspace()
  76. def update_workspace(self):
  77. if len(self.workspace.projects) == 0 and len(
  78. self.workspace.datasets) == 0:
  79. self.workspace.version == '0.2.0'
  80. return
  81. for key in self.workspace.datasets:
  82. ds = self.workspace.datasets[key]
  83. try:
  84. info_file = os.path.join(ds.path, 'info.pb')
  85. with open(info_file, 'wb') as f:
  86. f.write(ds.SerializeToString())
  87. except Exception as e:
  88. self.logger.info(traceback.format_exc())
  89. for key in self.workspace.projects:
  90. pj = self.workspace.projects[key]
  91. try:
  92. info_file = os.path.join(pj.path, 'info.pb')
  93. with open(info_file, 'wb') as f:
  94. f.write(pj.SerializeToString())
  95. except Exception as e:
  96. self.logger.info(traceback.format_exc())
  97. for key in self.workspace.tasks:
  98. task = self.workspace.tasks[key]
  99. try:
  100. info_file = os.path.join(task.path, 'info.pb')
  101. with open(info_file, 'wb') as f:
  102. f.write(task.SerializeToString())
  103. except Exception as e:
  104. self.logger.info(traceback.format_exc())
  105. self.workspace.version == '0.2.0'
  106. def recover_workspace(self):
  107. if len(self.workspace.projects) > 0 or len(
  108. self.workspace.datasets) > 0:
  109. return
  110. projects_dir = os.path.join(self.workspace.path, 'projects')
  111. datasets_dir = os.path.join(self.workspace.path, 'datasets')
  112. if not os.path.exists(projects_dir):
  113. os.makedirs(projects_dir)
  114. if not os.path.exists(datasets_dir):
  115. os.makedirs(datasets_dir)
  116. max_project_id = 0
  117. max_dataset_id = 0
  118. max_task_id = 0
  119. for pd in os.listdir(projects_dir):
  120. try:
  121. if pd[0] != 'P':
  122. continue
  123. if int(pd[1:]) > max_project_id:
  124. max_project_id = int(pd[1:])
  125. except:
  126. continue
  127. info_pb_file = os.path.join(projects_dir, pd, 'info.pb')
  128. if not os.path.exists(info_pb_file):
  129. continue
  130. try:
  131. pj = w.Project()
  132. with open(info_pb_file, 'rb') as f:
  133. pj.ParseFromString(f.read())
  134. self.workspace.projects[pd].CopyFrom(pj)
  135. except Exception as e:
  136. self.logger.info(traceback.format_exc())
  137. for td in os.listdir(os.path.join(projects_dir, pd)):
  138. try:
  139. if td[0] != 'T':
  140. continue
  141. if int(td[1:]) > max_task_id:
  142. max_task_id = int(td[1:])
  143. except:
  144. continue
  145. info_pb_file = os.path.join(projects_dir, pd, td, 'info.pb')
  146. if not os.path.exists(info_pb_file):
  147. continue
  148. try:
  149. task = w.Task()
  150. with open(info_pb_file, 'rb') as f:
  151. task.ParseFromString(f.read())
  152. self.workspace.tasks[td].CopyFrom(task)
  153. except Exception as e:
  154. self.logger.info(traceback.format_exc())
  155. for dd in os.listdir(datasets_dir):
  156. try:
  157. if dd[0] != 'D':
  158. continue
  159. if int(dd[1:]) > max_dataset_id:
  160. max_dataset_id = int(dd[1:])
  161. except:
  162. continue
  163. info_pb_file = os.path.join(datasets_dir, dd, 'info.pb')
  164. if not os.path.exists(info_pb_file):
  165. continue
  166. try:
  167. ds = w.Dataset()
  168. with open(info_pb_file, 'rb') as f:
  169. ds.ParseFromString(f.read())
  170. self.workspace.datasets[dd].CopyFrom(ds)
  171. except Exception as e:
  172. self.logger.info(traceback.format_exc())
  173. self.workspace.max_dataset_id = max_dataset_id
  174. self.workspace.max_project_id = max_project_id
  175. self.workspace.max_task_id = max_task_id
  176. # 每间隔interval秒,将workspace同步到本地文件
  177. def sync_with_local(self, interval=2):
  178. def sync_func(s, interval_seconds=2):
  179. newest_file = osp.join(self.workspace.path, 'workspace.newest.pb')
  180. stable_file = osp.join(self.workspace.path, 'workspace.stable.pb')
  181. bak_file = osp.join(self.workspace.path, 'workspace.bak.pb')
  182. flag_file = osp.join(self.workspace.path, '.pb.success')
  183. while True:
  184. current_time = time.time()
  185. time_array = time.localtime(current_time)
  186. current_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array)
  187. self.workspace.current_time = current_time
  188. if osp.exists(flag_file):
  189. os.remove(flag_file)
  190. f = open(newest_file, mode='wb')
  191. f.write(s.workspace.SerializeToString())
  192. f.close()
  193. open(flag_file, 'w').close()
  194. if osp.exists(stable_file):
  195. shutil.copyfile(stable_file, bak_file)
  196. shutil.copyfile(newest_file, stable_file)
  197. if s.stop_running:
  198. break
  199. time.sleep(interval_seconds)
  200. t = Thread(target=sync_func, args=(self, interval))
  201. t.start()
  202. return t
  203. def check_hardware_env(self):
  204. # 判断是否有gpu,cpu值是否已经设置
  205. hasGpu = True
  206. try:
  207. '''data = {'path' : path}
  208. from .system import get_machine_info
  209. info = get_machine_info(data, self.machine_info)['info']
  210. if info is None:
  211. return
  212. if (info['gpu_num'] == 0 and self.sysstr == "Windows"):
  213. data['path'] = os.path.abspath(os.path.dirname(__file__))
  214. info = get_machine_info(data, self.machine_info)['info']'''
  215. from .system import get_system_info
  216. info = get_system_info()['info']
  217. hasGpu = (info['gpu_num'] > 0)
  218. self.machine_info = info
  219. #driver_ver = info['driver_version']
  220. # driver_ver_list = driver_ver.split(".")
  221. # major_ver, minor_ver = driver_ver_list[0:2]
  222. # if sysstr == "Windows":
  223. # if int(major_ver) < 411 or \
  224. # (int(major_ver) == 411 and int(minor_ver) < 31):
  225. # raise Exception("The GPU dirver version should be larger than 411.31")
  226. #
  227. # elif sysstr == "Linux":
  228. # if int(major_ver) < 410 or \
  229. # (int(major_ver) == 410 and int(minor_ver) < 48):
  230. # raise Exception("The GPU dirver version should be larger than 410.48")
  231. except Exception as e:
  232. hasGpu = False
  233. self.m_HasGpu = hasGpu
  234. self.save_app_cfg_file()
  235. def save_app_cfg_file(self):
  236. #更新程序配置信息
  237. app_conf_file_name = 'PaddleX'.lower() + ".cfg"
  238. with open(
  239. os.path.join(PADDLEX_HOME, app_conf_file_name),
  240. 'w+',
  241. encoding=get_encoding(
  242. os.path.join(PADDLEX_HOME, app_conf_file_name))) as file:
  243. self.m_cfgfile.write(file)
  244. def init_workspace(workspace, dirname, logger):
  245. wp = Workspace(workspace, dirname, logger)
  246. #if not machine_info:
  247. #machine_info.update(wp.machine_info)
  248. return {'status': 1}
  249. def set_attr(data, workspace):
  250. """对workspace中项目,数据,任务变量进行修改赋值
  251. Args:
  252. data为dict,key包括
  253. 'struct'结构类型,可以是'dataset', 'project'或'task';
  254. 'id'查询id, 其余的key:value则分别为待修改的变量名和相应的修改值。
  255. """
  256. struct = data['struct']
  257. id = data['id']
  258. assert struct in ['dataset', 'project', 'task'
  259. ], "struct只能为dataset, project或task"
  260. if struct == 'dataset':
  261. assert id in workspace.datasets, "数据集ID'{}'不存在".format(id)
  262. modify_struct = workspace.datasets[id]
  263. elif struct == 'project':
  264. assert id in workspace.projects, "项目ID'{}'不存在".format(id)
  265. modify_struct = workspace.projects[id]
  266. elif struct == 'task':
  267. assert id in workspace.tasks, "任务ID'{}'不存在".format(id)
  268. modify_struct = workspace.tasks[id]
  269. '''for k, v in data.items():
  270. if k in ['id', 'struct']:
  271. continue
  272. assert hasattr(modify_struct,
  273. k), "{}不存在成员变量'{}'".format(type(modify_struct), k)
  274. setattr(modify_struct, k, v)'''
  275. for k, v in data['attr_dict'].items():
  276. assert hasattr(modify_struct,
  277. k), "{}不存在成员变量'{}'".format(type(modify_struct), k)
  278. setattr(modify_struct, k, v)
  279. with open(os.path.join(modify_struct.path, 'info.pb'), 'wb') as f:
  280. f.write(modify_struct.SerializeToString())
  281. return {'status': 1}
  282. def get_attr(data, workspace):
  283. """取出workspace中项目,数据,任务变量值
  284. Args:
  285. data为dict,key包括
  286. 'struct'结构类型,可以是'dataset', 'project'或'task';
  287. 'id'查询id, 'attr_list'需要获取的属性值列表
  288. """
  289. struct = data['struct']
  290. id = data['id']
  291. assert struct in ['dataset', 'project', 'task'
  292. ], "struct只能为dataset, project或task"
  293. if struct == 'dataset':
  294. assert id in workspace.datasets, "数据集ID'{}'不存在".format(id)
  295. modify_struct = workspace.datasets[id]
  296. elif struct == 'project':
  297. assert id in workspace.projects, "项目ID'{}'不存在".format(id)
  298. modify_struct = workspace.projects[id]
  299. elif struct == 'task':
  300. assert id in workspace.tasks, "任务ID'{}'不存在".format(id)
  301. modify_struct = workspace.tasks[id]
  302. attr = {}
  303. for k in data['attr_list']:
  304. if k in ['id', 'struct']:
  305. continue
  306. assert hasattr(modify_struct,
  307. k), "{}不存在成员变量'{}'".format(type(modify_struct), k)
  308. v = getattr(modify_struct, k)
  309. attr[k] = v
  310. return {'status': 1, 'attr': attr}