utils.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810
  1. import inspect
  2. from contextlib import contextmanager
  3. from copy import deepcopy
  4. from typing import (
  5. Any,
  6. Callable,
  7. Coroutine,
  8. Dict,
  9. ForwardRef,
  10. List,
  11. Mapping,
  12. Optional,
  13. Sequence,
  14. Tuple,
  15. Type,
  16. Union,
  17. cast,
  18. )
  19. import anyio
  20. from fastapi import params
  21. from fastapi._compat import (
  22. PYDANTIC_V2,
  23. ErrorWrapper,
  24. ModelField,
  25. Required,
  26. Undefined,
  27. _regenerate_error_with_loc,
  28. copy_field_info,
  29. create_body_model,
  30. evaluate_forwardref,
  31. field_annotation_is_scalar,
  32. get_annotation_from_field_info,
  33. get_missing_field_error,
  34. is_bytes_field,
  35. is_bytes_sequence_field,
  36. is_scalar_field,
  37. is_scalar_sequence_field,
  38. is_sequence_field,
  39. is_uploadfile_or_nonable_uploadfile_annotation,
  40. is_uploadfile_sequence_annotation,
  41. lenient_issubclass,
  42. sequence_types,
  43. serialize_sequence_value,
  44. value_is_sequence,
  45. )
  46. from fastapi.background import BackgroundTasks
  47. from fastapi.concurrency import (
  48. AsyncExitStack,
  49. asynccontextmanager,
  50. contextmanager_in_threadpool,
  51. )
  52. from fastapi.dependencies.models import Dependant, SecurityRequirement
  53. from fastapi.logger import logger
  54. from fastapi.security.base import SecurityBase
  55. from fastapi.security.oauth2 import OAuth2, SecurityScopes
  56. from fastapi.security.open_id_connect_url import OpenIdConnect
  57. from fastapi.utils import create_response_field, get_path_param_names
  58. from pydantic.fields import FieldInfo
  59. from starlette.background import BackgroundTasks as StarletteBackgroundTasks
  60. from starlette.concurrency import run_in_threadpool
  61. from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
  62. from starlette.requests import HTTPConnection, Request
  63. from starlette.responses import Response
  64. from starlette.websockets import WebSocket
  65. from typing_extensions import Annotated, get_args, get_origin
  66. multipart_not_installed_error = (
  67. 'Form data requires "python-multipart" to be installed. \n'
  68. 'You can install "python-multipart" with: \n\n'
  69. "pip install python-multipart\n"
  70. )
  71. multipart_incorrect_install_error = (
  72. 'Form data requires "python-multipart" to be installed. '
  73. 'It seems you installed "multipart" instead. \n'
  74. 'You can remove "multipart" with: \n\n'
  75. "pip uninstall multipart\n\n"
  76. 'And then install "python-multipart" with: \n\n'
  77. "pip install python-multipart\n"
  78. )
  79. def check_file_field(field: ModelField) -> None:
  80. field_info = field.field_info
  81. if isinstance(field_info, params.Form):
  82. try:
  83. # __version__ is available in both multiparts, and can be mocked
  84. from multipart import __version__ # type: ignore
  85. assert __version__
  86. try:
  87. # parse_options_header is only available in the right multipart
  88. from multipart.multipart import parse_options_header # type: ignore
  89. assert parse_options_header
  90. except ImportError:
  91. logger.error(multipart_incorrect_install_error)
  92. raise RuntimeError(multipart_incorrect_install_error) from None
  93. except ImportError:
  94. logger.error(multipart_not_installed_error)
  95. raise RuntimeError(multipart_not_installed_error) from None
  96. def get_param_sub_dependant(
  97. *,
  98. param_name: str,
  99. depends: params.Depends,
  100. path: str,
  101. security_scopes: Optional[List[str]] = None,
  102. ) -> Dependant:
  103. assert depends.dependency
  104. return get_sub_dependant(
  105. depends=depends,
  106. dependency=depends.dependency,
  107. path=path,
  108. name=param_name,
  109. security_scopes=security_scopes,
  110. )
  111. def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
  112. assert callable(
  113. depends.dependency
  114. ), "A parameter-less dependency must have a callable dependency"
  115. return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path)
  116. def get_sub_dependant(
  117. *,
  118. depends: params.Depends,
  119. dependency: Callable[..., Any],
  120. path: str,
  121. name: Optional[str] = None,
  122. security_scopes: Optional[List[str]] = None,
  123. ) -> Dependant:
  124. security_requirement = None
  125. security_scopes = security_scopes or []
  126. if isinstance(depends, params.Security):
  127. dependency_scopes = depends.scopes
  128. security_scopes.extend(dependency_scopes)
  129. if isinstance(dependency, SecurityBase):
  130. use_scopes: List[str] = []
  131. if isinstance(dependency, (OAuth2, OpenIdConnect)):
  132. use_scopes = security_scopes
  133. security_requirement = SecurityRequirement(
  134. security_scheme=dependency, scopes=use_scopes
  135. )
  136. sub_dependant = get_dependant(
  137. path=path,
  138. call=dependency,
  139. name=name,
  140. security_scopes=security_scopes,
  141. use_cache=depends.use_cache,
  142. )
  143. if security_requirement:
  144. sub_dependant.security_requirements.append(security_requirement)
  145. return sub_dependant
  146. CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
  147. def get_flat_dependant(
  148. dependant: Dependant,
  149. *,
  150. skip_repeats: bool = False,
  151. visited: Optional[List[CacheKey]] = None,
  152. ) -> Dependant:
  153. if visited is None:
  154. visited = []
  155. visited.append(dependant.cache_key)
  156. flat_dependant = Dependant(
  157. path_params=dependant.path_params.copy(),
  158. query_params=dependant.query_params.copy(),
  159. header_params=dependant.header_params.copy(),
  160. cookie_params=dependant.cookie_params.copy(),
  161. body_params=dependant.body_params.copy(),
  162. security_schemes=dependant.security_requirements.copy(),
  163. use_cache=dependant.use_cache,
  164. path=dependant.path,
  165. )
  166. for sub_dependant in dependant.dependencies:
  167. if skip_repeats and sub_dependant.cache_key in visited:
  168. continue
  169. flat_sub = get_flat_dependant(
  170. sub_dependant, skip_repeats=skip_repeats, visited=visited
  171. )
  172. flat_dependant.path_params.extend(flat_sub.path_params)
  173. flat_dependant.query_params.extend(flat_sub.query_params)
  174. flat_dependant.header_params.extend(flat_sub.header_params)
  175. flat_dependant.cookie_params.extend(flat_sub.cookie_params)
  176. flat_dependant.body_params.extend(flat_sub.body_params)
  177. flat_dependant.security_requirements.extend(flat_sub.security_requirements)
  178. return flat_dependant
  179. def get_flat_params(dependant: Dependant) -> List[ModelField]:
  180. flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
  181. return (
  182. flat_dependant.path_params
  183. + flat_dependant.query_params
  184. + flat_dependant.header_params
  185. + flat_dependant.cookie_params
  186. )
  187. def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
  188. signature = inspect.signature(call)
  189. globalns = getattr(call, "__globals__", {})
  190. typed_params = [
  191. inspect.Parameter(
  192. name=param.name,
  193. kind=param.kind,
  194. default=param.default,
  195. annotation=get_typed_annotation(param.annotation, globalns),
  196. )
  197. for param in signature.parameters.values()
  198. ]
  199. typed_signature = inspect.Signature(typed_params)
  200. return typed_signature
  201. def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
  202. if isinstance(annotation, str):
  203. annotation = ForwardRef(annotation)
  204. annotation = evaluate_forwardref(annotation, globalns, globalns)
  205. return annotation
  206. def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
  207. signature = inspect.signature(call)
  208. annotation = signature.return_annotation
  209. if annotation is inspect.Signature.empty:
  210. return None
  211. globalns = getattr(call, "__globals__", {})
  212. return get_typed_annotation(annotation, globalns)
  213. def get_dependant(
  214. *,
  215. path: str,
  216. call: Callable[..., Any],
  217. name: Optional[str] = None,
  218. security_scopes: Optional[List[str]] = None,
  219. use_cache: bool = True,
  220. ) -> Dependant:
  221. path_param_names = get_path_param_names(path)
  222. endpoint_signature = get_typed_signature(call)
  223. signature_params = endpoint_signature.parameters
  224. dependant = Dependant(
  225. call=call,
  226. name=name,
  227. path=path,
  228. security_scopes=security_scopes,
  229. use_cache=use_cache,
  230. )
  231. for param_name, param in signature_params.items():
  232. is_path_param = param_name in path_param_names
  233. type_annotation, depends, param_field = analyze_param(
  234. param_name=param_name,
  235. annotation=param.annotation,
  236. value=param.default,
  237. is_path_param=is_path_param,
  238. )
  239. if depends is not None:
  240. sub_dependant = get_param_sub_dependant(
  241. param_name=param_name,
  242. depends=depends,
  243. path=path,
  244. security_scopes=security_scopes,
  245. )
  246. dependant.dependencies.append(sub_dependant)
  247. continue
  248. if add_non_field_param_to_dependency(
  249. param_name=param_name,
  250. type_annotation=type_annotation,
  251. dependant=dependant,
  252. ):
  253. assert (
  254. param_field is None
  255. ), f"Cannot specify multiple FastAPI annotations for {param_name!r}"
  256. continue
  257. assert param_field is not None
  258. if is_body_param(param_field=param_field, is_path_param=is_path_param):
  259. dependant.body_params.append(param_field)
  260. else:
  261. add_param_to_fields(field=param_field, dependant=dependant)
  262. return dependant
  263. def add_non_field_param_to_dependency(
  264. *, param_name: str, type_annotation: Any, dependant: Dependant
  265. ) -> Optional[bool]:
  266. if lenient_issubclass(type_annotation, Request):
  267. dependant.request_param_name = param_name
  268. return True
  269. elif lenient_issubclass(type_annotation, WebSocket):
  270. dependant.websocket_param_name = param_name
  271. return True
  272. elif lenient_issubclass(type_annotation, HTTPConnection):
  273. dependant.http_connection_param_name = param_name
  274. return True
  275. elif lenient_issubclass(type_annotation, Response):
  276. dependant.response_param_name = param_name
  277. return True
  278. elif lenient_issubclass(type_annotation, StarletteBackgroundTasks):
  279. dependant.background_tasks_param_name = param_name
  280. return True
  281. elif lenient_issubclass(type_annotation, SecurityScopes):
  282. dependant.security_scopes_param_name = param_name
  283. return True
  284. return None
def analyze_param(
    *,
    param_name: str,
    annotation: Any,
    value: Any,
    is_path_param: bool,
) -> Tuple[Any, Optional[params.Depends], Optional[ModelField]]:
    """Classify one parameter of an endpoint or dependency.

    Looks at the parameter's annotation (including any ``Annotated`` extras)
    and its default value to decide whether it is a sub-dependency
    (``Depends``), a request parameter/body field (a ``FieldInfo``), or one
    of the framework-provided special types.

    Args:
        param_name: Name of the parameter in the callable's signature.
        annotation: The parameter's annotation as taken from the typed
            signature, or ``inspect.Signature.empty``.
        value: The parameter's default value, or ``inspect.Signature.empty``.
        is_path_param: Whether ``param_name`` appears in the route path.

    Returns:
        ``(type_annotation, depends, field)``:

        * ``type_annotation``: the bare type, with ``Annotated`` stripped, or
          ``Any`` if unannotated.
        * ``depends``: the ``Depends`` marker if this is a dependency.
        * ``field``: the ``ModelField`` built if it is a parameter/body field.

        For Request, WebSocket, HTTPConnection, Response, background tasks
        and SecurityScopes, both ``depends`` and ``field`` are ``None``.

    Raises:
        AssertionError: on invalid or conflicting declarations, for example:

        * several FastAPI markers in ``Annotated``;
        * a default set inside ``Annotated``;
        * a default on a path parameter;
        * a marker in both ``Annotated`` and the default;
        * a non-``Path`` marker on a path parameter.
    """
    field_info = None
    depends = None
    type_annotation: Any = Any
    # `Annotated[T, *extras]`: T is the real type, and at most one FastAPI
    # marker (a FieldInfo or a Depends) may appear among the extras.
    if (
        annotation is not inspect.Signature.empty
        and get_origin(annotation) is Annotated
    ):
        annotated_args = get_args(annotation)
        type_annotation = annotated_args[0]
        fastapi_annotations = [
            arg
            for arg in annotated_args[1:]
            if isinstance(arg, (FieldInfo, params.Depends))
        ]
        assert (
            len(fastapi_annotations) <= 1
        ), f"Cannot specify multiple `Annotated` FastAPI arguments for {param_name!r}"
        fastapi_annotation = next(iter(fastapi_annotations), None)
        if isinstance(fastapi_annotation, FieldInfo):
            # Copy `field_info` because we mutate `field_info.default` below.
            field_info = copy_field_info(
                field_info=fastapi_annotation, annotation=annotation
            )
            assert field_info.default is Undefined or field_info.default is Required, (
                f"`{field_info.__class__.__name__}` default value cannot be set in"
                f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
            )
            # With `Annotated`, the default comes from the signature's `= value`.
            if value is not inspect.Signature.empty:
                assert not is_path_param, "Path parameters cannot have default values"
                field_info.default = value
            else:
                field_info.default = Required
        elif isinstance(fastapi_annotation, params.Depends):
            depends = fastapi_annotation
    elif annotation is not inspect.Signature.empty:
        type_annotation = annotation

    # A marker may instead be given as the default value (`= Depends(...)`,
    # `= Query(...)`). Combining it with an `Annotated` marker is rejected.
    if isinstance(value, params.Depends):
        assert depends is None, (
            "Cannot specify `Depends` in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        assert field_info is None, (
            "Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a"
            f" default value together for {param_name!r}"
        )
        depends = value
    elif isinstance(value, FieldInfo):
        assert field_info is None, (
            "Cannot specify FastAPI annotations in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        # NOTE(review): unlike the `Annotated` path, this FieldInfo is not
        # copied. The mutations below (annotation, in_, alias) land on the
        # caller's object, so one FieldInfo instance shared by several
        # parameters could carry e.g. an alias from one to the next.
        # Confirm this is intended.
        field_info = value
        if PYDANTIC_V2:
            field_info.annotation = type_annotation

    # `Depends()` with no argument: the annotated type is the dependency.
    if depends is not None and depends.dependency is None:
        depends.dependency = type_annotation

    # Framework-provided types are injected directly and never become fields.
    if lenient_issubclass(
        type_annotation,
        (
            Request,
            WebSocket,
            HTTPConnection,
            Response,
            StarletteBackgroundTasks,
            SecurityScopes,
        ),
    ):
        assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}"
        assert (
            field_info is None
        ), f"Cannot specify FastAPI annotation for type {type_annotation!r}"
    # No explicit marker: infer Path / File / Body / Query from context and type.
    elif field_info is None and depends is None:
        default_value = value if value is not inspect.Signature.empty else Required
        if is_path_param:
            # We might check here that `default_value is Required`, but the fact is that the same
            # parameter might sometimes be a path parameter and sometimes not. See
            # `tests/test_infer_param_optionality.py` for an example.
            field_info = params.Path(annotation=type_annotation)
        elif is_uploadfile_or_nonable_uploadfile_annotation(
            type_annotation
        ) or is_uploadfile_sequence_annotation(type_annotation):
            field_info = params.File(annotation=type_annotation, default=default_value)
        elif not field_annotation_is_scalar(annotation=type_annotation):
            field_info = params.Body(annotation=type_annotation, default=default_value)
        else:
            field_info = params.Query(annotation=type_annotation, default=default_value)

    field = None
    if field_info is not None:
        if is_path_param:
            assert isinstance(field_info, params.Path), (
                f"Cannot use `{field_info.__class__.__name__}` for path param"
                f" {param_name!r}"
            )
        # A Param marker without an explicit location defaults to query.
        elif (
            isinstance(field_info, params.Param)
            and getattr(field_info, "in_", None) is None
        ):
            field_info.in_ = params.ParamTypes.query
        use_annotation = get_annotation_from_field_info(
            type_annotation,
            field_info,
            param_name,
        )
        # Without an explicit alias, markers with a truthy `convert_underscores`
        # (presumably Header) map `x_token` to `x-token`. Otherwise the alias
        # falls back to the parameter name.
        if not field_info.alias and getattr(field_info, "convert_underscores", None):
            alias = param_name.replace("_", "-")
        else:
            alias = field_info.alias or param_name
        field_info.alias = alias
        field = create_response_field(
            name=param_name,
            type_=use_annotation,
            default=field_info.default,
            alias=alias,
            required=field_info.default in (Required, Undefined),
            field_info=field_info,
        )

    return type_annotation, depends, field
  409. def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
  410. if is_path_param:
  411. assert is_scalar_field(
  412. field=param_field
  413. ), "Path params must be of one of the supported types"
  414. return False
  415. elif is_scalar_field(field=param_field):
  416. return False
  417. elif isinstance(
  418. param_field.field_info, (params.Query, params.Header)
  419. ) and is_scalar_sequence_field(param_field):
  420. return False
  421. else:
  422. assert isinstance(
  423. param_field.field_info, params.Body
  424. ), f"Param: {param_field.name} can only be a request body, using Body()"
  425. return True
  426. def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
  427. field_info = cast(params.Param, field.field_info)
  428. if field_info.in_ == params.ParamTypes.path:
  429. dependant.path_params.append(field)
  430. elif field_info.in_ == params.ParamTypes.query:
  431. dependant.query_params.append(field)
  432. elif field_info.in_ == params.ParamTypes.header:
  433. dependant.header_params.append(field)
  434. else:
  435. assert (
  436. field_info.in_ == params.ParamTypes.cookie
  437. ), f"non-body parameters must be in path, query, header or cookie: {field.name}"
  438. dependant.cookie_params.append(field)
  439. def is_coroutine_callable(call: Callable[..., Any]) -> bool:
  440. if inspect.isroutine(call):
  441. return inspect.iscoroutinefunction(call)
  442. if inspect.isclass(call):
  443. return False
  444. dunder_call = getattr(call, "__call__", None) # noqa: B004
  445. return inspect.iscoroutinefunction(dunder_call)
  446. def is_async_gen_callable(call: Callable[..., Any]) -> bool:
  447. if inspect.isasyncgenfunction(call):
  448. return True
  449. dunder_call = getattr(call, "__call__", None) # noqa: B004
  450. return inspect.isasyncgenfunction(dunder_call)
  451. def is_gen_callable(call: Callable[..., Any]) -> bool:
  452. if inspect.isgeneratorfunction(call):
  453. return True
  454. dunder_call = getattr(call, "__call__", None) # noqa: B004
  455. return inspect.isgeneratorfunction(dunder_call)
  456. async def solve_generator(
  457. *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
  458. ) -> Any:
  459. if is_gen_callable(call):
  460. cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
  461. elif is_async_gen_callable(call):
  462. cm = asynccontextmanager(call)(**sub_values)
  463. return await stack.enter_async_context(cm)
async def solve_dependencies(
    *,
    request: Union[Request, WebSocket],
    dependant: Dependant,
    body: Optional[Union[Dict[str, Any], FormData]] = None,
    background_tasks: Optional[StarletteBackgroundTasks] = None,
    response: Optional[Response] = None,
    dependency_overrides_provider: Optional[Any] = None,
    dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
) -> Tuple[
    Dict[str, Any],
    List[Any],
    Optional[StarletteBackgroundTasks],
    Response,
    Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
]:
    """Resolve every keyword argument needed to call ``dependant.call``.

    The steps run in this order:

    1. Each sub-dependency is solved recursively and then invoked.
    2. The dependant's own path/query/header/cookie/body parameters are
       extracted and validated.
    3. The special parameters are filled in: connection, request/websocket,
       background tasks, response and security scopes.

    Args:
        request: The incoming connection. If any dependency is a generator,
            ``request.scope["fastapi_astack"]`` must be an ``AsyncExitStack``.
        dependant: The dependency-tree node to solve.
        body: The already-parsed request body (dict or form data), if any.
        background_tasks: Shared background-tasks object. One is created
            lazily if a parameter asks for it and none was passed.
        response: Shared response that dependencies may modify. If omitted, a
            fresh one is created with no content-length and
            ``status_code = None``.
        dependency_overrides_provider: Object whose ``dependency_overrides``
            maps original callables to replacements, or ``None``.
        dependency_cache: Already-solved values keyed by ``cache_key``. A
            falsy (e.g. empty) cache is replaced by a new dict rather than
            filled in place, so use the returned cache.

    Returns:
        ``(values, errors, background_tasks, response, dependency_cache)``:

        * ``values``: keyword arguments for ``dependant.call``, keyed by
          parameter name;
        * ``errors``: the accumulated validation errors;
        * the possibly newly created shared background tasks, response and
          cache.
    """
    values: Dict[str, Any] = {}
    errors: List[Any] = []
    if response is None:
        response = Response()
        del response.headers["content-length"]
        # NOTE(review): None presumably marks "no status set by a dependency".
        response.status_code = None  # type: ignore
    dependency_cache = dependency_cache or {}
    sub_dependant: Dependant
    for sub_dependant in dependant.dependencies:
        sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
        sub_dependant.cache_key = cast(
            Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
        )
        call = sub_dependant.call
        use_sub_dependant = sub_dependant
        if (
            dependency_overrides_provider
            and dependency_overrides_provider.dependency_overrides
        ):
            # Swap in the override (if any) and re-analyze its own signature.
            # NOTE(review): use_cache is not forwarded, so the rebuilt
            # Dependant uses get_dependant's default.
            original_call = sub_dependant.call
            call = getattr(
                dependency_overrides_provider, "dependency_overrides", {}
            ).get(original_call, original_call)
            use_path: str = sub_dependant.path  # type: ignore
            use_sub_dependant = get_dependant(
                path=use_path,
                call=call,
                name=sub_dependant.name,
                security_scopes=sub_dependant.security_scopes,
            )

        solved_result = await solve_dependencies(
            request=request,
            dependant=use_sub_dependant,
            body=body,
            background_tasks=background_tasks,
            response=response,
            dependency_overrides_provider=dependency_overrides_provider,
            dependency_cache=dependency_cache,
        )
        (
            sub_values,
            sub_errors,
            background_tasks,
            _,  # the subdependency returns the same response we have
            sub_dependency_cache,
        ) = solved_result
        dependency_cache.update(sub_dependency_cache)
        # Don't invoke a dependency whose own inputs failed validation.
        if sub_errors:
            errors.extend(sub_errors)
            continue
        # The cache is keyed by the ORIGINAL sub-dependant's cache_key, and is
        # checked only after that dependency's inputs were solved above.
        if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
            solved = dependency_cache[sub_dependant.cache_key]
        elif is_gen_callable(call) or is_async_gen_callable(call):
            # Generator dependencies are entered on the exit stack stored in
            # the request scope. Their teardown runs when that stack closes.
            stack = request.scope.get("fastapi_astack")
            assert isinstance(stack, AsyncExitStack)
            solved = await solve_generator(
                call=call, stack=stack, sub_values=sub_values
            )
        elif is_coroutine_callable(call):
            solved = await call(**sub_values)
        else:
            # Plain sync callables run in the threadpool.
            solved = await run_in_threadpool(call, **sub_values)
        if sub_dependant.name is not None:
            values[sub_dependant.name] = solved
        # Only the first value computed for a key is stored.
        if sub_dependant.cache_key not in dependency_cache:
            dependency_cache[sub_dependant.cache_key] = solved
    path_values, path_errors = request_params_to_args(
        dependant.path_params, request.path_params
    )
    query_values, query_errors = request_params_to_args(
        dependant.query_params, request.query_params
    )
    header_values, header_errors = request_params_to_args(
        dependant.header_params, request.headers
    )
    cookie_values, cookie_errors = request_params_to_args(
        dependant.cookie_params, request.cookies
    )
    values.update(path_values)
    values.update(query_values)
    values.update(header_values)
    values.update(cookie_values)
    errors += path_errors + query_errors + header_errors + cookie_errors
    if dependant.body_params:
        (
            body_values,
            body_errors,
        ) = await request_body_to_args(  # body_params checked above
            required_params=dependant.body_params, received_body=body
        )
        values.update(body_values)
        errors.extend(body_errors)
    # Special parameters recorded by add_non_field_param_to_dependency.
    if dependant.http_connection_param_name:
        values[dependant.http_connection_param_name] = request
    if dependant.request_param_name and isinstance(request, Request):
        values[dependant.request_param_name] = request
    elif dependant.websocket_param_name and isinstance(request, WebSocket):
        values[dependant.websocket_param_name] = request
    if dependant.background_tasks_param_name:
        if background_tasks is None:
            background_tasks = BackgroundTasks()
        values[dependant.background_tasks_param_name] = background_tasks
    if dependant.response_param_name:
        values[dependant.response_param_name] = response
    if dependant.security_scopes_param_name:
        values[dependant.security_scopes_param_name] = SecurityScopes(
            scopes=dependant.security_scopes
        )
    return values, errors, background_tasks, response, dependency_cache
  589. def request_params_to_args(
  590. required_params: Sequence[ModelField],
  591. received_params: Union[Mapping[str, Any], QueryParams, Headers],
  592. ) -> Tuple[Dict[str, Any], List[Any]]:
  593. values = {}
  594. errors = []
  595. for field in required_params:
  596. if is_scalar_sequence_field(field) and isinstance(
  597. received_params, (QueryParams, Headers)
  598. ):
  599. value = received_params.getlist(field.alias) or field.default
  600. else:
  601. value = received_params.get(field.alias)
  602. field_info = field.field_info
  603. assert isinstance(
  604. field_info, params.Param
  605. ), "Params must be subclasses of Param"
  606. loc = (field_info.in_.value, field.alias)
  607. if value is None:
  608. if field.required:
  609. errors.append(get_missing_field_error(loc=loc))
  610. else:
  611. values[field.name] = deepcopy(field.default)
  612. continue
  613. v_, errors_ = field.validate(value, values, loc=loc)
  614. if isinstance(errors_, ErrorWrapper):
  615. errors.append(errors_)
  616. elif isinstance(errors_, list):
  617. new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
  618. errors.extend(new_errors)
  619. else:
  620. values[field.name] = v_
  621. return values, errors
  622. async def request_body_to_args(
  623. required_params: List[ModelField],
  624. received_body: Optional[Union[Dict[str, Any], FormData]],
  625. ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
  626. values = {}
  627. errors: List[Dict[str, Any]] = []
  628. if required_params:
  629. field = required_params[0]
  630. field_info = field.field_info
  631. embed = getattr(field_info, "embed", None)
  632. field_alias_omitted = len(required_params) == 1 and not embed
  633. if field_alias_omitted:
  634. received_body = {field.alias: received_body}
  635. for field in required_params:
  636. loc: Tuple[str, ...]
  637. if field_alias_omitted:
  638. loc = ("body",)
  639. else:
  640. loc = ("body", field.alias)
  641. value: Optional[Any] = None
  642. if received_body is not None:
  643. if (is_sequence_field(field)) and isinstance(received_body, FormData):
  644. value = received_body.getlist(field.alias)
  645. else:
  646. try:
  647. value = received_body.get(field.alias)
  648. except AttributeError:
  649. errors.append(get_missing_field_error(loc))
  650. continue
  651. if (
  652. value is None
  653. or (isinstance(field_info, params.Form) and value == "")
  654. or (
  655. isinstance(field_info, params.Form)
  656. and is_sequence_field(field)
  657. and len(value) == 0
  658. )
  659. ):
  660. if field.required:
  661. errors.append(get_missing_field_error(loc))
  662. else:
  663. values[field.name] = deepcopy(field.default)
  664. continue
  665. if (
  666. isinstance(field_info, params.File)
  667. and is_bytes_field(field)
  668. and isinstance(value, UploadFile)
  669. ):
  670. value = await value.read()
  671. elif (
  672. is_bytes_sequence_field(field)
  673. and isinstance(field_info, params.File)
  674. and value_is_sequence(value)
  675. ):
  676. # For types
  677. assert isinstance(value, sequence_types) # type: ignore[arg-type]
  678. results: List[Union[bytes, str]] = []
  679. async def process_fn(
  680. fn: Callable[[], Coroutine[Any, Any, Any]]
  681. ) -> None:
  682. result = await fn()
  683. results.append(result) # noqa: B023
  684. async with anyio.create_task_group() as tg:
  685. for sub_value in value:
  686. tg.start_soon(process_fn, sub_value.read)
  687. value = serialize_sequence_value(field=field, value=results)
  688. v_, errors_ = field.validate(value, values, loc=loc)
  689. if isinstance(errors_, list):
  690. errors.extend(errors_)
  691. elif errors_:
  692. errors.append(errors_)
  693. else:
  694. values[field.name] = v_
  695. return values, errors
  696. def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
  697. flat_dependant = get_flat_dependant(dependant)
  698. if not flat_dependant.body_params:
  699. return None
  700. first_param = flat_dependant.body_params[0]
  701. field_info = first_param.field_info
  702. embed = getattr(field_info, "embed", None)
  703. body_param_names_set = {param.name for param in flat_dependant.body_params}
  704. if len(body_param_names_set) == 1 and not embed:
  705. check_file_field(first_param)
  706. return first_param
  707. # If one field requires to embed, all have to be embedded
  708. # in case a sub-dependency is evaluated with a single unique body field
  709. # That is combined (embedded) with other body fields
  710. for param in flat_dependant.body_params:
  711. setattr(param.field_info, "embed", True) # noqa: B010
  712. model_name = "Body_" + name
  713. BodyModel = create_body_model(
  714. fields=flat_dependant.body_params, model_name=model_name
  715. )
  716. required = any(True for f in flat_dependant.body_params if f.required)
  717. BodyFieldInfo_kwargs: Dict[str, Any] = {
  718. "annotation": BodyModel,
  719. "alias": "body",
  720. }
  721. if not required:
  722. BodyFieldInfo_kwargs["default"] = None
  723. if any(isinstance(f.field_info, params.File) for f in flat_dependant.body_params):
  724. BodyFieldInfo: Type[params.Body] = params.File
  725. elif any(isinstance(f.field_info, params.Form) for f in flat_dependant.body_params):
  726. BodyFieldInfo = params.Form
  727. else:
  728. BodyFieldInfo = params.Body
  729. body_param_media_types = [
  730. f.field_info.media_type
  731. for f in flat_dependant.body_params
  732. if isinstance(f.field_info, params.Body)
  733. ]
  734. if len(set(body_param_media_types)) == 1:
  735. BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
  736. final_field = create_response_field(
  737. name="body",
  738. type_=BodyModel,
  739. required=required,
  740. alias="body",
  741. field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
  742. )
  743. check_file_field(final_field)
  744. return final_field