workspace.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. from . import workspace_pb2 as w
  2. from .utils import get_logger
  3. from .dir import *
  4. import os
  5. import os.path as osp
  6. from threading import Thread
  7. import traceback
  8. import platform
  9. import configparser
  10. import time
  11. import shutil
  12. import copy
  class Workspace:
      """Own a workspace message and keep it synced to disk.

      Construction loads the workspace from ``workspace.path`` (migrating and
      recovering it when needed) and starts a background thread that writes
      it back every two seconds until ``stop_running`` is set to True.
      """

      def __init__(self, workspace, dirname, logger):
          """Load *workspace* from disk and start the periodic sync thread.

          Args:
              workspace: workspace message (presumably from ``workspace_pb2``);
                  ``workspace.path`` must already name the workspace folder.
              dirname: directory stored as ``workspacedir`` in the app config.
              logger: logger used to report non-fatal errors.
          """
          self.workspace = workspace
          # Application-level config (paddlex.cfg).
          self.init_app_resource(dirname)
          # Current workspace format version.
          self.current_version = "0.2.0"
          self.logger = logger
          self.load_workspace()
          self.stop_running = False
          self.sync_thread = self.sync_with_local(interval=2)
          # Hardware detection is currently disabled; to enable it call:
          # self.check_hardware_env()

      @staticmethod
      def _app_cfg_path():
          """Return the path of the application config file (``paddlex.cfg``)."""
          return os.path.join(PADDLEX_HOME, "paddlex.cfg")

      def init_app_resource(self, dirname):
          """Read the app config and record *dirname* as ``USERCFG.workspacedir``.

          Only the in-memory config is updated; ``save_app_cfg_file`` writes it.
          """
          self.m_cfgfile = configparser.ConfigParser()
          paddlex_cfg_file = self._app_cfg_path()
          try:
              self.m_cfgfile.read(paddlex_cfg_file)
          except Exception:
              # A malformed config is not fatal; keep whatever was loaded.
              print("[ERROR] Fail to read {}".format(paddlex_cfg_file))
          # Create the section only when it is missing: an existing [USERCFG]
          # without the option must be reused (add_section would raise
          # DuplicateSectionError).
          if not self.m_cfgfile.has_section("USERCFG"):
              self.m_cfgfile.add_section("USERCFG")
          self.m_cfgfile.set("USERCFG", "workspacedir", dirname)

      @staticmethod
      def _version_tuple(version):
          """Turn a dotted version string into a tuple of ints for comparison.

          Each component contributes its leading digits; parsing stops at the
          first component without any, so ``""`` -> ``()`` and
          ``"0.2.0rc1"`` -> ``(0, 2, 0)``.
          """
          parts = []
          for piece in str(version).split('.'):
              digits = piece[:len(piece) - len(piece.lstrip('0123456789'))]
              if not digits:
                  break
              parts.append(int(digits))
          return tuple(parts)

      def load_workspace(self):
          """Populate ``self.workspace`` from the newest usable on-disk snapshot.

          ``workspace.newest.pb`` is used when the ``.pb.success`` flag exists
          (the sync loop creates the flag only after the write finished);
          otherwise ``workspace.bak.pb`` is used if present. Pre-0.2.0
          workspaces are migrated, then ``recover_workspace`` runs.
          """
          path = self.workspace.path
          newest_file = osp.join(path, 'workspace.newest.pb')
          bak_file = osp.join(path, 'workspace.bak.pb')
          flag_file = osp.join(path, '.pb.success')
          self.workspace.version = self.current_version
          try:
              if osp.exists(flag_file):
                  with open(newest_file, 'rb') as f:
                      self.workspace.ParseFromString(f.read())
              elif osp.exists(bak_file):
                  with open(bak_file, 'rb') as f:
                      self.workspace.ParseFromString(f.read())
              else:
                  print("it is a new workspace")
          except Exception:
              print(traceback.format_exc())
          # Parsing may overwrite ``path`` with the stored value; restore it.
          self.workspace.path = path
          # Numeric comparison, so e.g. "0.10.0" is not considered < "0.2.0".
          if self._version_tuple(self.workspace.version) < self._version_tuple("0.2.0"):
              self.update_workspace()
          self.recover_workspace()

      def _dump_info_pb(self, entries):
          """Write every message in *entries* to ``<message.path>/info.pb``.

          Failures are logged and skipped so one bad entry does not stop the rest.
          """
          for key in entries:
              entry = entries[key]
              try:
                  with open(os.path.join(entry.path, 'info.pb'), 'wb') as f:
                      f.write(entry.SerializeToString())
              except Exception:
                  self.logger.info(traceback.format_exc())

      def update_workspace(self):
          """Migrate an old workspace to 0.2.0 by writing per-item ``info.pb`` files.

          Datasets, projects and tasks each get their serialized message written
          into their own directory; afterwards the version is set to 0.2.0.
          """
          if not self.workspace.projects and not self.workspace.datasets:
              self.workspace.version = '0.2.0'
              return
          self._dump_info_pb(self.workspace.datasets)
          self._dump_info_pb(self.workspace.projects)
          self._dump_info_pb(self.workspace.tasks)
          self.workspace.version = '0.2.0'

      @staticmethod
      def _parse_prefixed_id(name, prefix):
          """Return ``int(name[len(prefix):])`` if *name* starts with *prefix*, else None."""
          if not name.startswith(prefix):
              return None
          try:
              return int(name[len(prefix):])
          except ValueError:
              return None

      def _load_info_pb(self, info_pb_file, message_cls, target, key):
          """Parse *info_pb_file* as a *message_cls* and copy it into ``target[key]``.

          Errors are logged and otherwise ignored.
          """
          try:
              message = message_cls()
              with open(info_pb_file, 'rb') as f:
                  message.ParseFromString(f.read())
              target[key].CopyFrom(message)
          except Exception:
              self.logger.info(traceback.format_exc())

      def recover_workspace(self):
          """Rebuild an empty workspace from its ``projects``/``datasets`` folders.

          Returns immediately if the workspace already has projects or datasets.
          Otherwise scans ``P<n>`` project dirs (and, for projects with an
          ``info.pb``, their ``T<n>`` task dirs) and ``D<n>`` dataset dirs,
          loads every ``info.pb`` found, and sets ``max_project_id``,
          ``max_dataset_id`` and ``max_task_id`` to the largest ``<n>`` seen.
          """
          if self.workspace.projects or self.workspace.datasets:
              return
          projects_dir = os.path.join(self.workspace.path, 'projects')
          datasets_dir = os.path.join(self.workspace.path, 'datasets')
          for folder in (projects_dir, datasets_dir):
              if not os.path.exists(folder):
                  os.makedirs(folder)
          max_project_id = max_dataset_id = max_task_id = 0
          for pd in os.listdir(projects_dir):
              project_num = self._parse_prefixed_id(pd, 'P')
              if project_num is None:
                  continue
              max_project_id = max(max_project_id, project_num)
              project_dir = os.path.join(projects_dir, pd)
              info_pb_file = os.path.join(project_dir, 'info.pb')
              if not os.path.exists(info_pb_file):
                  continue
              self._load_info_pb(info_pb_file, w.Project, self.workspace.projects, pd)
              for td in os.listdir(project_dir):
                  task_num = self._parse_prefixed_id(td, 'T')
                  if task_num is None:
                      continue
                  max_task_id = max(max_task_id, task_num)
                  task_info_file = os.path.join(project_dir, td, 'info.pb')
                  if os.path.exists(task_info_file):
                      self._load_info_pb(task_info_file, w.Task, self.workspace.tasks, td)
          for dd in os.listdir(datasets_dir):
              dataset_num = self._parse_prefixed_id(dd, 'D')
              if dataset_num is None:
                  continue
              max_dataset_id = max(max_dataset_id, dataset_num)
              info_pb_file = os.path.join(datasets_dir, dd, 'info.pb')
              if os.path.exists(info_pb_file):
                  self._load_info_pb(info_pb_file, w.Dataset, self.workspace.datasets, dd)
          self.workspace.max_dataset_id = max_dataset_id
          self.workspace.max_project_id = max_project_id
          self.workspace.max_task_id = max_task_id

      def sync_with_local(self, interval=2):
          """Start a thread that saves the workspace to disk every *interval* seconds.

          Each round stamps ``workspace.current_time``, removes the
          ``.pb.success`` flag, writes ``workspace.newest.pb``, recreates the
          flag, copies ``workspace.stable.pb`` (if any) to ``workspace.bak.pb``
          and finally copies the newest snapshot to ``workspace.stable.pb``.
          The loop exits after a round in which ``stop_running`` is True.

          Returns:
              threading.Thread: the started sync thread.
          """
          def sync_func(s, interval_seconds=2):
              newest_file = osp.join(s.workspace.path, 'workspace.newest.pb')
              stable_file = osp.join(s.workspace.path, 'workspace.stable.pb')
              bak_file = osp.join(s.workspace.path, 'workspace.bak.pb')
              flag_file = osp.join(s.workspace.path, '.pb.success')
              while True:
                  s.workspace.current_time = time.strftime(
                      "%Y-%m-%d %H:%M:%S", time.localtime())
                  try:
                      if osp.exists(flag_file):
                          os.remove(flag_file)
                      with open(newest_file, mode='wb') as f:
                          f.write(s.workspace.SerializeToString())
                      # The flag marks newest_file as completely written.
                      open(flag_file, 'w').close()
                      if osp.exists(stable_file):
                          shutil.copyfile(stable_file, bak_file)
                      shutil.copyfile(newest_file, stable_file)
                  except Exception:
                      # Log and retry next round instead of terminating the
                      # thread, which would stop all further syncing.
                      s.logger.info(traceback.format_exc())
                  if s.stop_running:
                      break
                  time.sleep(interval_seconds)

          t = Thread(target=sync_func, args=(self, interval))
          t.start()
          return t

      def check_hardware_env(self):
          """Detect GPU availability and then save the app config.

          Sets ``self.m_HasGpu`` (False when detection fails for any reason)
          and, on successful detection, ``self.machine_info``.
          NOTE: no GPU driver version check is performed.
          """
          try:
              from .system import get_system_info
              info = get_system_info()['info']
              has_gpu = info['gpu_num'] > 0
              self.machine_info = info
          except Exception:
              has_gpu = False
          self.m_HasGpu = has_gpu
          self.save_app_cfg_file()

      def save_app_cfg_file(self):
          """Write the in-memory app config back to ``paddlex.cfg``."""
          with open(self._app_cfg_path(), 'w') as file:
              self.m_cfgfile.write(file)
  def init_workspace(workspace, dirname, logger):
      """Create a ``Workspace`` for *workspace* and report success.

      Constructing ``Workspace`` loads the workspace and starts its sync
      thread; the instance itself is not kept here.

      Returns:
          dict: ``{'status': 1}``.
      """
      Workspace(workspace, dirname, logger)
      return {'status': 1}
  def set_attr(data, workspace):
      """Assign new attribute values to a dataset, project or task.

      After the in-memory entry is updated it is serialized to
      ``<entry.path>/info.pb``.

      Args:
          data (dict): request with keys
              ``'struct'`` -- entry kind: ``'dataset'``, ``'project'`` or ``'task'``;
              ``'id'`` -- key of the entry in the matching workspace map;
              ``'attr_dict'`` -- mapping of attribute name to new value.
          workspace: object exposing ``datasets``, ``projects`` and ``tasks``
              maps (presumably the workspace protobuf message).

      Returns:
          dict: ``{'status': 1}``.

      Raises:
          AssertionError: if ``struct`` is invalid, ``id`` is not in the map,
              or an attribute name does not exist on the entry. The entry is
              left unmodified in these cases.
      """
      # Explicit ``raise AssertionError`` instead of ``assert``: same exception
      # type for callers, but validation still runs under ``python -O``
      # (otherwise a missing id would be looked up anyway; protobuf message
      # maps may insert a default entry on such a lookup).
      maps = {
          'dataset': ('datasets', "数据集ID'{}'不存在"),
          'project': ('projects', "项目ID'{}'不存在"),
          'task': ('tasks', "任务ID'{}'不存在"),
      }
      struct = data['struct']
      struct_id = data['id']
      if struct not in ('dataset', 'project', 'task'):
          raise AssertionError("struct只能为dataset, project或task")
      map_name, missing_msg = maps[struct]
      entries = getattr(workspace, map_name)
      if struct_id not in entries:
          raise AssertionError(missing_msg.format(struct_id))
      modify_struct = entries[struct_id]
      updates = data['attr_dict']
      # Validate every name first so a bad request does not partially apply.
      for k in updates:
          if not hasattr(modify_struct, k):
              raise AssertionError("{}不存在成员变量'{}'".format(type(modify_struct), k))
      for k, v in updates.items():
          setattr(modify_struct, k, v)
      with open(os.path.join(modify_struct.path, 'info.pb'), 'wb') as f:
          f.write(modify_struct.SerializeToString())
      return {'status': 1}
  def get_attr(data, workspace):
      """Read attribute values from a dataset, project or task.

      Args:
          data (dict): request with keys
              ``'struct'`` -- entry kind: ``'dataset'``, ``'project'`` or ``'task'``;
              ``'id'`` -- key of the entry in the matching workspace map;
              ``'attr_list'`` -- attribute names to read (``'id'`` and
              ``'struct'`` entries are skipped).
          workspace: object exposing ``datasets``, ``projects`` and ``tasks``
              maps (presumably the workspace protobuf message).

      Returns:
          dict: ``{'status': 1, 'attr': {name: value, ...}}``.

      Raises:
          AssertionError: if ``struct`` is invalid, ``id`` is not in the map,
              or a requested attribute does not exist on the entry.
      """
      # Explicit ``raise AssertionError`` instead of ``assert``: same exception
      # type for callers, but validation still runs under ``python -O``
      # (otherwise a missing id would be looked up anyway; protobuf message
      # maps may insert a default entry on such a lookup).
      maps = {
          'dataset': ('datasets', "数据集ID'{}'不存在"),
          'project': ('projects', "项目ID'{}'不存在"),
          'task': ('tasks', "任务ID'{}'不存在"),
      }
      struct = data['struct']
      struct_id = data['id']
      if struct not in ('dataset', 'project', 'task'):
          raise AssertionError("struct只能为dataset, project或task")
      map_name, missing_msg = maps[struct]
      entries = getattr(workspace, map_name)
      if struct_id not in entries:
          raise AssertionError(missing_msg.format(struct_id))
      modify_struct = entries[struct_id]
      attr = {}
      for k in data['attr_list']:
          if k in ('id', 'struct'):
              continue
          if not hasattr(modify_struct, k):
              raise AssertionError("{}不存在成员变量'{}'".format(type(modify_struct), k))
          attr[k] = getattr(modify_struct, k)
      return {'status': 1, 'attr': attr}