utils.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. import http.client
  2. import inspect
  3. import warnings
  4. from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast
  5. from fastapi import routing
  6. from fastapi._compat import (
  7. GenerateJsonSchema,
  8. JsonSchemaValue,
  9. ModelField,
  10. Undefined,
  11. get_compat_model_name_map,
  12. get_definitions,
  13. get_schema_from_model_field,
  14. lenient_issubclass,
  15. )
  16. from fastapi.datastructures import DefaultPlaceholder
  17. from fastapi.dependencies.models import Dependant
  18. from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
  19. from fastapi.encoders import jsonable_encoder
  20. from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX, REF_TEMPLATE
  21. from fastapi.openapi.models import OpenAPI
  22. from fastapi.params import Body, Param
  23. from fastapi.responses import Response
  24. from fastapi.types import ModelNameMap
  25. from fastapi.utils import (
  26. deep_dict_update,
  27. generate_operation_id_for_path,
  28. is_body_allowed_for_status_code,
  29. )
  30. from starlette.responses import JSONResponse
  31. from starlette.routing import BaseRoute
  32. from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
  33. from typing_extensions import Literal
  34. validation_error_definition = {
  35. "title": "ValidationError",
  36. "type": "object",
  37. "properties": {
  38. "loc": {
  39. "title": "Location",
  40. "type": "array",
  41. "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
  42. },
  43. "msg": {"title": "Message", "type": "string"},
  44. "type": {"title": "Error Type", "type": "string"},
  45. },
  46. "required": ["loc", "msg", "type"],
  47. }
  48. validation_error_response_definition = {
  49. "title": "HTTPValidationError",
  50. "type": "object",
  51. "properties": {
  52. "detail": {
  53. "title": "Detail",
  54. "type": "array",
  55. "items": {"$ref": REF_PREFIX + "ValidationError"},
  56. }
  57. },
  58. }
  59. status_code_ranges: Dict[str, str] = {
  60. "1XX": "Information",
  61. "2XX": "Success",
  62. "3XX": "Redirection",
  63. "4XX": "Client Error",
  64. "5XX": "Server Error",
  65. "DEFAULT": "Default Response",
  66. }
  67. def get_openapi_security_definitions(
  68. flat_dependant: Dependant,
  69. ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
  70. security_definitions = {}
  71. operation_security = []
  72. for security_requirement in flat_dependant.security_requirements:
  73. security_definition = jsonable_encoder(
  74. security_requirement.security_scheme.model,
  75. by_alias=True,
  76. exclude_none=True,
  77. )
  78. security_name = security_requirement.security_scheme.scheme_name
  79. security_definitions[security_name] = security_definition
  80. operation_security.append({security_name: security_requirement.scopes})
  81. return security_definitions, operation_security
  82. def get_openapi_operation_parameters(
  83. *,
  84. all_route_params: Sequence[ModelField],
  85. schema_generator: GenerateJsonSchema,
  86. model_name_map: ModelNameMap,
  87. field_mapping: Dict[
  88. Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
  89. ],
  90. separate_input_output_schemas: bool = True,
  91. ) -> List[Dict[str, Any]]:
  92. parameters = []
  93. for param in all_route_params:
  94. field_info = param.field_info
  95. field_info = cast(Param, field_info)
  96. if not field_info.include_in_schema:
  97. continue
  98. param_schema = get_schema_from_model_field(
  99. field=param,
  100. schema_generator=schema_generator,
  101. model_name_map=model_name_map,
  102. field_mapping=field_mapping,
  103. separate_input_output_schemas=separate_input_output_schemas,
  104. )
  105. parameter = {
  106. "name": param.alias,
  107. "in": field_info.in_.value,
  108. "required": param.required,
  109. "schema": param_schema,
  110. }
  111. if field_info.description:
  112. parameter["description"] = field_info.description
  113. if field_info.openapi_examples:
  114. parameter["examples"] = jsonable_encoder(field_info.openapi_examples)
  115. elif field_info.example != Undefined:
  116. parameter["example"] = jsonable_encoder(field_info.example)
  117. if field_info.deprecated:
  118. parameter["deprecated"] = field_info.deprecated
  119. parameters.append(parameter)
  120. return parameters
  121. def get_openapi_operation_request_body(
  122. *,
  123. body_field: Optional[ModelField],
  124. schema_generator: GenerateJsonSchema,
  125. model_name_map: ModelNameMap,
  126. field_mapping: Dict[
  127. Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
  128. ],
  129. separate_input_output_schemas: bool = True,
  130. ) -> Optional[Dict[str, Any]]:
  131. if not body_field:
  132. return None
  133. assert isinstance(body_field, ModelField)
  134. body_schema = get_schema_from_model_field(
  135. field=body_field,
  136. schema_generator=schema_generator,
  137. model_name_map=model_name_map,
  138. field_mapping=field_mapping,
  139. separate_input_output_schemas=separate_input_output_schemas,
  140. )
  141. field_info = cast(Body, body_field.field_info)
  142. request_media_type = field_info.media_type
  143. required = body_field.required
  144. request_body_oai: Dict[str, Any] = {}
  145. if required:
  146. request_body_oai["required"] = required
  147. request_media_content: Dict[str, Any] = {"schema": body_schema}
  148. if field_info.openapi_examples:
  149. request_media_content["examples"] = jsonable_encoder(
  150. field_info.openapi_examples
  151. )
  152. elif field_info.example != Undefined:
  153. request_media_content["example"] = jsonable_encoder(field_info.example)
  154. request_body_oai["content"] = {request_media_type: request_media_content}
  155. return request_body_oai
  156. def generate_operation_id(
  157. *, route: routing.APIRoute, method: str
  158. ) -> str: # pragma: nocover
  159. warnings.warn(
  160. "fastapi.openapi.utils.generate_operation_id() was deprecated, "
  161. "it is not used internally, and will be removed soon",
  162. DeprecationWarning,
  163. stacklevel=2,
  164. )
  165. if route.operation_id:
  166. return route.operation_id
  167. path: str = route.path_format
  168. return generate_operation_id_for_path(name=route.name, path=path, method=method)
  169. def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
  170. if route.summary:
  171. return route.summary
  172. return route.name.replace("_", " ").title()
  173. def get_openapi_operation_metadata(
  174. *, route: routing.APIRoute, method: str, operation_ids: Set[str]
  175. ) -> Dict[str, Any]:
  176. operation: Dict[str, Any] = {}
  177. if route.tags:
  178. operation["tags"] = route.tags
  179. operation["summary"] = generate_operation_summary(route=route, method=method)
  180. if route.description:
  181. operation["description"] = route.description
  182. operation_id = route.operation_id or route.unique_id
  183. if operation_id in operation_ids:
  184. message = (
  185. f"Duplicate Operation ID {operation_id} for function "
  186. + f"{route.endpoint.__name__}"
  187. )
  188. file_name = getattr(route.endpoint, "__globals__", {}).get("__file__")
  189. if file_name:
  190. message += f" at {file_name}"
  191. warnings.warn(message, stacklevel=1)
  192. operation_ids.add(operation_id)
  193. operation["operationId"] = operation_id
  194. if route.deprecated:
  195. operation["deprecated"] = route.deprecated
  196. return operation
  197. def get_openapi_path(
  198. *,
  199. route: routing.APIRoute,
  200. operation_ids: Set[str],
  201. schema_generator: GenerateJsonSchema,
  202. model_name_map: ModelNameMap,
  203. field_mapping: Dict[
  204. Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
  205. ],
  206. separate_input_output_schemas: bool = True,
  207. ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
  208. path = {}
  209. security_schemes: Dict[str, Any] = {}
  210. definitions: Dict[str, Any] = {}
  211. assert route.methods is not None, "Methods must be a list"
  212. if isinstance(route.response_class, DefaultPlaceholder):
  213. current_response_class: Type[Response] = route.response_class.value
  214. else:
  215. current_response_class = route.response_class
  216. assert current_response_class, "A response class is needed to generate OpenAPI"
  217. route_response_media_type: Optional[str] = current_response_class.media_type
  218. if route.include_in_schema:
  219. for method in route.methods:
  220. operation = get_openapi_operation_metadata(
  221. route=route, method=method, operation_ids=operation_ids
  222. )
  223. parameters: List[Dict[str, Any]] = []
  224. flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
  225. security_definitions, operation_security = get_openapi_security_definitions(
  226. flat_dependant=flat_dependant
  227. )
  228. if operation_security:
  229. operation.setdefault("security", []).extend(operation_security)
  230. if security_definitions:
  231. security_schemes.update(security_definitions)
  232. all_route_params = get_flat_params(route.dependant)
  233. operation_parameters = get_openapi_operation_parameters(
  234. all_route_params=all_route_params,
  235. schema_generator=schema_generator,
  236. model_name_map=model_name_map,
  237. field_mapping=field_mapping,
  238. separate_input_output_schemas=separate_input_output_schemas,
  239. )
  240. parameters.extend(operation_parameters)
  241. if parameters:
  242. all_parameters = {
  243. (param["in"], param["name"]): param for param in parameters
  244. }
  245. required_parameters = {
  246. (param["in"], param["name"]): param
  247. for param in parameters
  248. if param.get("required")
  249. }
  250. # Make sure required definitions of the same parameter take precedence
  251. # over non-required definitions
  252. all_parameters.update(required_parameters)
  253. operation["parameters"] = list(all_parameters.values())
  254. if method in METHODS_WITH_BODY:
  255. request_body_oai = get_openapi_operation_request_body(
  256. body_field=route.body_field,
  257. schema_generator=schema_generator,
  258. model_name_map=model_name_map,
  259. field_mapping=field_mapping,
  260. separate_input_output_schemas=separate_input_output_schemas,
  261. )
  262. if request_body_oai:
  263. operation["requestBody"] = request_body_oai
  264. if route.callbacks:
  265. callbacks = {}
  266. for callback in route.callbacks:
  267. if isinstance(callback, routing.APIRoute):
  268. (
  269. cb_path,
  270. cb_security_schemes,
  271. cb_definitions,
  272. ) = get_openapi_path(
  273. route=callback,
  274. operation_ids=operation_ids,
  275. schema_generator=schema_generator,
  276. model_name_map=model_name_map,
  277. field_mapping=field_mapping,
  278. separate_input_output_schemas=separate_input_output_schemas,
  279. )
  280. callbacks[callback.name] = {callback.path: cb_path}
  281. operation["callbacks"] = callbacks
  282. if route.status_code is not None:
  283. status_code = str(route.status_code)
  284. else:
  285. # It would probably make more sense for all response classes to have an
  286. # explicit default status_code, and to extract it from them, instead of
  287. # doing this inspection tricks, that would probably be in the future
  288. # TODO: probably make status_code a default class attribute for all
  289. # responses in Starlette
  290. response_signature = inspect.signature(current_response_class.__init__)
  291. status_code_param = response_signature.parameters.get("status_code")
  292. if status_code_param is not None:
  293. if isinstance(status_code_param.default, int):
  294. status_code = str(status_code_param.default)
  295. operation.setdefault("responses", {}).setdefault(status_code, {})[
  296. "description"
  297. ] = route.response_description
  298. if route_response_media_type and is_body_allowed_for_status_code(
  299. route.status_code
  300. ):
  301. response_schema = {"type": "string"}
  302. if lenient_issubclass(current_response_class, JSONResponse):
  303. if route.response_field:
  304. response_schema = get_schema_from_model_field(
  305. field=route.response_field,
  306. schema_generator=schema_generator,
  307. model_name_map=model_name_map,
  308. field_mapping=field_mapping,
  309. separate_input_output_schemas=separate_input_output_schemas,
  310. )
  311. else:
  312. response_schema = {}
  313. operation.setdefault("responses", {}).setdefault(
  314. status_code, {}
  315. ).setdefault("content", {}).setdefault(route_response_media_type, {})[
  316. "schema"
  317. ] = response_schema
  318. if route.responses:
  319. operation_responses = operation.setdefault("responses", {})
  320. for (
  321. additional_status_code,
  322. additional_response,
  323. ) in route.responses.items():
  324. process_response = additional_response.copy()
  325. process_response.pop("model", None)
  326. status_code_key = str(additional_status_code).upper()
  327. if status_code_key == "DEFAULT":
  328. status_code_key = "default"
  329. openapi_response = operation_responses.setdefault(
  330. status_code_key, {}
  331. )
  332. assert isinstance(
  333. process_response, dict
  334. ), "An additional response must be a dict"
  335. field = route.response_fields.get(additional_status_code)
  336. additional_field_schema: Optional[Dict[str, Any]] = None
  337. if field:
  338. additional_field_schema = get_schema_from_model_field(
  339. field=field,
  340. schema_generator=schema_generator,
  341. model_name_map=model_name_map,
  342. field_mapping=field_mapping,
  343. separate_input_output_schemas=separate_input_output_schemas,
  344. )
  345. media_type = route_response_media_type or "application/json"
  346. additional_schema = (
  347. process_response.setdefault("content", {})
  348. .setdefault(media_type, {})
  349. .setdefault("schema", {})
  350. )
  351. deep_dict_update(additional_schema, additional_field_schema)
  352. status_text: Optional[str] = status_code_ranges.get(
  353. str(additional_status_code).upper()
  354. ) or http.client.responses.get(int(additional_status_code))
  355. description = (
  356. process_response.get("description")
  357. or openapi_response.get("description")
  358. or status_text
  359. or "Additional Response"
  360. )
  361. deep_dict_update(openapi_response, process_response)
  362. openapi_response["description"] = description
  363. http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
  364. if (all_route_params or route.body_field) and not any(
  365. status in operation["responses"]
  366. for status in [http422, "4XX", "default"]
  367. ):
  368. operation["responses"][http422] = {
  369. "description": "Validation Error",
  370. "content": {
  371. "application/json": {
  372. "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
  373. }
  374. },
  375. }
  376. if "ValidationError" not in definitions:
  377. definitions.update(
  378. {
  379. "ValidationError": validation_error_definition,
  380. "HTTPValidationError": validation_error_response_definition,
  381. }
  382. )
  383. if route.openapi_extra:
  384. deep_dict_update(operation, route.openapi_extra)
  385. path[method.lower()] = operation
  386. return path, security_schemes, definitions
  387. def get_fields_from_routes(
  388. routes: Sequence[BaseRoute],
  389. ) -> List[ModelField]:
  390. body_fields_from_routes: List[ModelField] = []
  391. responses_from_routes: List[ModelField] = []
  392. request_fields_from_routes: List[ModelField] = []
  393. callback_flat_models: List[ModelField] = []
  394. for route in routes:
  395. if getattr(route, "include_in_schema", None) and isinstance(
  396. route, routing.APIRoute
  397. ):
  398. if route.body_field:
  399. assert isinstance(
  400. route.body_field, ModelField
  401. ), "A request body must be a Pydantic Field"
  402. body_fields_from_routes.append(route.body_field)
  403. if route.response_field:
  404. responses_from_routes.append(route.response_field)
  405. if route.response_fields:
  406. responses_from_routes.extend(route.response_fields.values())
  407. if route.callbacks:
  408. callback_flat_models.extend(get_fields_from_routes(route.callbacks))
  409. params = get_flat_params(route.dependant)
  410. request_fields_from_routes.extend(params)
  411. flat_models = callback_flat_models + list(
  412. body_fields_from_routes + responses_from_routes + request_fields_from_routes
  413. )
  414. return flat_models
  415. def get_openapi(
  416. *,
  417. title: str,
  418. version: str,
  419. openapi_version: str = "3.1.0",
  420. summary: Optional[str] = None,
  421. description: Optional[str] = None,
  422. routes: Sequence[BaseRoute],
  423. webhooks: Optional[Sequence[BaseRoute]] = None,
  424. tags: Optional[List[Dict[str, Any]]] = None,
  425. servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
  426. terms_of_service: Optional[str] = None,
  427. contact: Optional[Dict[str, Union[str, Any]]] = None,
  428. license_info: Optional[Dict[str, Union[str, Any]]] = None,
  429. separate_input_output_schemas: bool = True,
  430. ) -> Dict[str, Any]:
  431. info: Dict[str, Any] = {"title": title, "version": version}
  432. if summary:
  433. info["summary"] = summary
  434. if description:
  435. info["description"] = description
  436. if terms_of_service:
  437. info["termsOfService"] = terms_of_service
  438. if contact:
  439. info["contact"] = contact
  440. if license_info:
  441. info["license"] = license_info
  442. output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
  443. if servers:
  444. output["servers"] = servers
  445. components: Dict[str, Dict[str, Any]] = {}
  446. paths: Dict[str, Dict[str, Any]] = {}
  447. webhook_paths: Dict[str, Dict[str, Any]] = {}
  448. operation_ids: Set[str] = set()
  449. all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
  450. model_name_map = get_compat_model_name_map(all_fields)
  451. schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
  452. field_mapping, definitions = get_definitions(
  453. fields=all_fields,
  454. schema_generator=schema_generator,
  455. model_name_map=model_name_map,
  456. separate_input_output_schemas=separate_input_output_schemas,
  457. )
  458. for route in routes or []:
  459. if isinstance(route, routing.APIRoute):
  460. result = get_openapi_path(
  461. route=route,
  462. operation_ids=operation_ids,
  463. schema_generator=schema_generator,
  464. model_name_map=model_name_map,
  465. field_mapping=field_mapping,
  466. separate_input_output_schemas=separate_input_output_schemas,
  467. )
  468. if result:
  469. path, security_schemes, path_definitions = result
  470. if path:
  471. paths.setdefault(route.path_format, {}).update(path)
  472. if security_schemes:
  473. components.setdefault("securitySchemes", {}).update(
  474. security_schemes
  475. )
  476. if path_definitions:
  477. definitions.update(path_definitions)
  478. for webhook in webhooks or []:
  479. if isinstance(webhook, routing.APIRoute):
  480. result = get_openapi_path(
  481. route=webhook,
  482. operation_ids=operation_ids,
  483. schema_generator=schema_generator,
  484. model_name_map=model_name_map,
  485. field_mapping=field_mapping,
  486. separate_input_output_schemas=separate_input_output_schemas,
  487. )
  488. if result:
  489. path, security_schemes, path_definitions = result
  490. if path:
  491. webhook_paths.setdefault(webhook.path_format, {}).update(path)
  492. if security_schemes:
  493. components.setdefault("securitySchemes", {}).update(
  494. security_schemes
  495. )
  496. if path_definitions:
  497. definitions.update(path_definitions)
  498. if definitions:
  499. components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
  500. if components:
  501. output["components"] = components
  502. output["paths"] = paths
  503. if webhook_paths:
  504. output["webhooks"] = webhook_paths
  505. if tags:
  506. output["tags"] = tags
  507. return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True) # type: ignore