config.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. import os
  2. import typing
  3. from collections.abc import MutableMapping
  4. from pathlib import Path
  class undefined:
      """Sentinel meaning "no default was supplied" to a Config lookup.

      The class object itself is passed around as the marker value, so that
      ``None`` stays available as a legitimate, distinct default.
      """
  class EnvironError(Exception):
      """Raised when code tries to set or delete an environment key after it was read."""
  class Environ(MutableMapping):
      """Mutable mapping over an environment dict that freezes keys once read.

      Every key looked up through this wrapper is remembered. A later attempt
      to assign or delete such a key raises EnvironError, so a setting that has
      already been consumed cannot silently change underneath its reader.

      ``MutableMapping`` derives ``__contains__`` and ``get`` from
      ``__getitem__``, so membership tests and ``.get()`` also count as reads,
      even when the key turns out to be absent.
      """

      def __init__(self, environ: typing.MutableMapping = os.environ):
          self._environ = environ
          # Keys looked up at least once, whether or not the lookup succeeded.
          self._has_been_read: typing.Set[typing.Any] = set()

      def _ensure_unread(self, key: typing.Any, action: str) -> None:
          """Raise EnvironError if *key* was already read; *action* names the mutation."""
          if key not in self._has_been_read:
              return
          raise EnvironError(
              f"Attempting to {action} environ['{key}'], but the value has already "
              "been read."
          )

      def __getitem__(self, key: typing.Any) -> typing.Any:
          # Record the read before the lookup so that misses are tracked too.
          self._has_been_read.add(key)
          return self._environ[key]

      def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
          self._ensure_unread(key, "set")
          self._environ[key] = value

      def __delitem__(self, key: typing.Any) -> None:
          self._ensure_unread(key, "delete")
          del self._environ[key]

      def __iter__(self) -> typing.Iterator:
          return iter(self._environ)

      def __len__(self) -> int:
          return len(self._environ)


  # Module-wide read-tracking view of os.environ; Config uses it as its default source.
  environ = Environ()
  # Result type of Config.__call__ when a typed cast is supplied.
  T = typing.TypeVar("T")


  class Config:
      """Read typed settings from a mapping, falling back to a ``KEY=value`` file.

      Lookup order for ``config(key)`` / ``config.get(key)``, after prepending
      ``env_prefix`` to the key:

      1. ``environ`` (by default the module-level read-tracking ``environ``),
      2. values parsed from ``env_file``, if that path is an existing file,
      3. the supplied ``default``.

      If none of these applies, ``KeyError`` is raised.
      """

      def __init__(
          self,
          env_file: typing.Optional[typing.Union[str, Path]] = None,
          environ: typing.Mapping[str, str] = environ,
          env_prefix: str = "",
      ) -> None:
          self.environ = environ
          self.env_prefix = env_prefix
          self.file_values: typing.Dict[str, str] = {}
          # A path that is not an existing file is silently ignored.
          if env_file is not None and os.path.isfile(env_file):
              self.file_values = self._read_file(env_file)

      @typing.overload
      def __call__(self, key: str, *, default: None) -> typing.Optional[str]:
          ...

      @typing.overload
      def __call__(self, key: str, cast: typing.Type[T], default: T = ...) -> T:
          ...

      @typing.overload
      def __call__(
          self, key: str, cast: typing.Type[str] = ..., default: str = ...
      ) -> str:
          ...

      @typing.overload
      def __call__(
          self,
          key: str,
          cast: typing.Callable[[typing.Any], T] = ...,
          default: typing.Any = ...,
      ) -> T:
          ...

      @typing.overload
      def __call__(
          self, key: str, cast: typing.Type[str] = ..., default: T = ...
      ) -> typing.Union[T, str]:
          ...

      def __call__(
          self,
          key: str,
          cast: typing.Optional[typing.Callable] = None,
          default: typing.Any = undefined,
      ) -> typing.Any:
          """Shorthand for ``self.get(key, cast, default)``."""
          return self.get(key, cast, default)

      def get(
          self,
          key: str,
          cast: typing.Optional[typing.Callable] = None,
          default: typing.Any = undefined,
      ) -> typing.Any:
          """Return setting *key* (with ``env_prefix`` applied), passed through *cast*.

          Args:
              key: Setting name without the prefix.
              cast: Optional callable applied to the raw value; ``bool`` gets
                  special string parsing (see ``_perform_cast``).
              default: Used when the key is in neither source; it is cast too
                  unless it is ``None``.

          Raises:
              KeyError: the key is absent and no default was given.
              ValueError: the value could not be cast.
          """
          key = self.env_prefix + key
          if key in self.environ:
              return self._perform_cast(key, self.environ[key], cast)
          if key in self.file_values:
              return self._perform_cast(key, self.file_values[key], cast)
          if default is not undefined:
              return self._perform_cast(key, default, cast)
          raise KeyError(f"Config '{key}' is missing, and has no default.")

      def _read_file(self, file_name: typing.Union[str, Path]) -> typing.Dict[str, str]:
          """Parse ``KEY=value`` lines from *file_name* into a dict.

          Lines without ``=`` (including blank lines) and lines starting with
          ``#`` are skipped. Only the first ``=`` separates key from value. Both
          sides are whitespace-trimmed, and any leading/trailing ``"`` or ``'``
          characters are stripped from the value. Later duplicate keys win.
          """
          file_values: typing.Dict[str, str] = {}
          # NOTE(review): opens with the platform default encoding; consider
          # encoding="utf-8" if env files may contain non-ASCII text.
          with open(file_name) as input_file:
              for raw_line in input_file:
                  line = raw_line.strip()
                  if "=" not in line or line.startswith("#"):
                      continue
                  key, value = line.split("=", 1)
                  file_values[key.strip()] = value.strip().strip("\"'")
          return file_values

      def _perform_cast(
          self, key: str, value: typing.Any, cast: typing.Optional[typing.Callable] = None
      ) -> typing.Any:
          """Apply *cast* to *value*, reporting any failure as ``ValueError``.

          A ``None`` value or a ``None`` cast passes the value through unchanged.
          For ``cast=bool`` on a string, only ``true``/``1``/``false``/``0``
          (case-insensitive) are accepted, because ``bool("false")`` would be True.
          """
          if cast is None or value is None:
              return value
          if cast is bool and isinstance(value, str):
              mapping = {"true": True, "1": True, "false": False, "0": False}
              normalized = value.lower()
              if normalized not in mapping:
                  # Report the value as configured, not its lower-cased form.
                  raise ValueError(
                      f"Config '{key}' has value '{value}'. Not a valid bool."
                  )
              return mapping[normalized]
          try:
              return cast(value)
          except (TypeError, ValueError) as exc:
              # Not every callable has __name__ (e.g. functools.partial objects);
              # fall back to repr() so building the message cannot itself fail.
              cast_name = getattr(cast, "__name__", repr(cast))
              raise ValueError(
                  f"Config '{key}' has value '{value}'. Not a valid {cast_name}."
              ) from exc