check.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. import paddle
  20. from paddle import is_compiled_with_cuda
  21. from paddlex.ppcls.arch import get_architectures
  22. from paddlex.ppcls.arch import similar_architectures
  23. from paddlex.ppcls.arch import get_blacklist_model_in_static_mode
  24. from paddlex.ppcls.utils import logger
  25. def check_version():
  26. """
  27. Log error and exit when the installed version of paddlepaddle is
  28. not satisfied.
  29. """
  30. err = "PaddlePaddle version 1.8.0 or higher is required, " \
  31. "or a suitable develop version is satisfied as well. \n" \
  32. "Please make sure the version is good with your code."
  33. try:
  34. pass
  35. # paddle.utils.require_version('0.0.0')
  36. except Exception:
  37. logger.error(err)
  38. sys.exit(1)
  39. def check_gpu():
  40. """
  41. Log error and exit when using paddlepaddle cpu version.
  42. """
  43. err = "You are using paddlepaddle cpu version! Please try to " \
  44. "install paddlepaddle-gpu to run model on GPU."
  45. try:
  46. assert is_compiled_with_cuda()
  47. except AssertionError:
  48. logger.error(err)
  49. sys.exit(1)
  50. def check_architecture(architecture):
  51. """
  52. check architecture and recommend similar architectures
  53. """
  54. assert isinstance(architecture, dict), \
  55. ("the type of architecture({}) should be dict". format(architecture))
  56. assert "name" in architecture, \
  57. ("name must be in the architecture keys, just contains: {}". format(
  58. architecture.keys()))
  59. similar_names = similar_architectures(architecture["name"],
  60. get_architectures())
  61. model_list = ', '.join(similar_names)
  62. err = "Architecture [{}] is not exist! Maybe you want: [{}]" \
  63. "".format(architecture["name"], model_list)
  64. try:
  65. assert architecture["name"] in similar_names
  66. except AssertionError:
  67. logger.error(err)
  68. sys.exit(1)
  69. def check_model_with_running_mode(architecture):
  70. """
  71. check whether the model is consistent with the operating mode
  72. """
  73. # some model are not supported in the static mode
  74. blacklist = get_blacklist_model_in_static_mode()
  75. if not paddle.in_dynamic_mode() and architecture["name"] in blacklist:
  76. logger.error("Model: {} is not supported in the staic mode.".format(
  77. architecture["name"]))
  78. sys.exit(1)
  79. return
  80. def check_mix(architecture, use_mix=False):
  81. """
  82. check mix parameter
  83. """
  84. err = "Cannot use mix processing in GoogLeNet, " \
  85. "please set use_mix = False."
  86. try:
  87. if architecture["name"] == "GoogLeNet":
  88. assert use_mix is not True
  89. except AssertionError:
  90. logger.error(err)
  91. sys.exit(1)
  92. def check_classes_num(classes_num):
  93. """
  94. check classes_num
  95. """
  96. err = "classes_num({}) should be a positive integer" \
  97. "and larger than 1".format(classes_num)
  98. try:
  99. assert isinstance(classes_num, int)
  100. assert classes_num > 1
  101. except AssertionError:
  102. logger.error(err)
  103. sys.exit(1)
  104. def check_data_dir(path):
  105. """
  106. check cata_dir
  107. """
  108. err = "Data path is not exist, please given a right path" \
  109. "".format(path)
  110. try:
  111. assert os.isdir(path)
  112. except AssertionError:
  113. logger.error(err)
  114. sys.exit(1)
  115. def check_function_params(config, key):
  116. """
  117. check specify config
  118. """
  119. k_config = config.get(key)
  120. assert k_config is not None, \
  121. ('{} is required in config'.format(key))
  122. assert k_config.get('function'), \
  123. ('function is required {} config'.format(key))
  124. params = k_config.get('params')
  125. assert params is not None, \
  126. ('params is required in {} config'.format(key))
  127. assert isinstance(params, dict), \
  128. ('the params in {} config should be a dict'.format(key))