bazarr/libs/flask_restx/swagger.py
2022-11-07 13:08:27 -05:00

745 lines
26 KiB
Python

# -*- coding: utf-8 -*-
import itertools
import re
from inspect import isclass, getdoc
from collections import OrderedDict
from collections.abc import Hashable
from flask import current_app
from . import fields
from .model import Model, ModelBase, OrderedModel
from .reqparse import RequestParser
from .utils import merge, not_none, not_none_sorted
from ._http import HTTPStatus
from urllib.parse import quote
#: Maps Flask/Werkzeug routing converter names to Swagger types
PATH_TYPES = {
    "int": "integer",
    "float": "number",
    "string": "string",
    "default": "string",
}
#: Maps Python primitive types to Swagger ones
PY_TYPES = {
    int: "integer",
    float: "number",
    str: "string",
    bool: "boolean",
    None: "void",
}
#: Matches a ``<converter:variable>`` or ``<variable>`` placeholder in a URL
#: rule; group 1 captures the variable name.
RE_URL = re.compile(r"<(?:[^:<>]+:)?([^<>]+)>")
DEFAULT_RESPONSE_DESCRIPTION = "Success"
#: Shared response template; copy it before adding keys to it.
DEFAULT_RESPONSE = {"description": DEFAULT_RESPONSE_DESCRIPTION}
#: Matches ``:raises Name: description`` lines inside a docstring.
RE_RAISES = re.compile(
    r"^:raises\s+(?P<name>[\w\d_]+)\s*:\s*(?P<description>.*)$", re.MULTILINE
)
#: Tokenizer for URL rules, consumed by :func:`parse_rule`. Each match covers
#: an optional static prefix followed by one ``<...>`` placeholder.
RE_PARSE_RULE = re.compile(
    r"""
    (?P<static>[^<]*)                           # static rule data
    <
    (?:
        (?P<converter>[a-zA-Z_][a-zA-Z0-9_]*)   # converter name
        (?:\((?P<args>.*?)\))?                  # converter arguments
        \:                                      # variable delimiter
    )?
    (?P<variable>[a-zA-Z_][a-zA-Z0-9_]*)        # variable name
    >
    """,
    re.VERBOSE,
)
def ref(model):
    """
    Build a JSON reference to *model* under ``#/definitions``.

    ``model`` is either a ModelBase instance (its ``name`` is used) or a
    model name. The name is percent-encoded with ``safe=""``, so even ``/``
    is escaped.
    """
    if isinstance(model, ModelBase):
        model_name = model.name
    else:
        model_name = model
    encoded_name = quote(model_name, safe="")
    return {"$ref": "#/definitions/{0}".format(encoded_name)}
def _v(value):
"""Dereference values (callable)"""
return value() if callable(value) else value
def extract_path(path):
    """
    Rewrite a Flask/Werkzeug URL rule as a Swagger path template.

    Every ``<converter:name>`` or ``<name>`` placeholder becomes ``{name}``,
    e.g. ``/users/<int:id>`` -> ``/users/{id}``.
    """
    swagger_path = RE_URL.sub(r"{\1}", path)
    return swagger_path
def parse_rule(rule):
    """
    Tokenize a URL rule, lazily, in rule order.

    Yields ``(converter, arguments, variable)`` tuples:

    - static chunks of the rule come out as ``(None, None, text)``;
    - placeholders come out as ``(converter or "default", arguments or None,
      variable_name)``.

    :raises ValueError: when a variable name is used twice, or when the text
        after the last placeholder still contains ``<`` or ``>``.

    Note: This originally lived in werkzeug.routing.parse_rule until it was
    removed in werkzeug 2.2.0.
    """
    position = 0
    rule_length = len(rule)
    seen_variables = set()
    while position < rule_length:
        token = RE_PARSE_RULE.match(rule, position)
        if token is None:
            break
        groups = token.groupdict()
        static_text = groups["static"]
        if static_text:
            yield None, None, static_text
        variable = groups["variable"]
        if variable in seen_variables:
            raise ValueError(f"variable name {variable!r} used twice.")
        seen_variables.add(variable)
        yield groups["converter"] or "default", groups["args"] or None, variable
        position = token.end()
    if position < rule_length:
        # Trailing static text after the last placeholder.
        remaining = rule[position:]
        if "<" in remaining or ">" in remaining:
            raise ValueError(f"malformed url rule: {rule!r}")
        yield None, None, remaining
def extract_path_params(path):
    """
    Collect the dynamic segments of a Flask-style URL rule as Swagger
    ``in: path`` parameters, keyed by variable name in rule order.

    Converters listed in ``PATH_TYPES`` map to their Swagger type. Any other
    converter registered on the current app's URL map is documented as
    ``string``.

    :raises ValueError: for a converter known to neither.
    """
    params = OrderedDict()
    for converter, _arguments, variable in parse_rule(path):
        if not converter:
            # static chunk of the rule, not a parameter
            continue
        if converter in PATH_TYPES:
            swagger_type = PATH_TYPES[converter]
        elif converter in current_app.url_map.converters:
            swagger_type = "string"
        else:
            raise ValueError("Unsupported type converter: %s" % converter)
        params[variable] = {
            "name": variable,
            "in": "path",
            "required": True,
            "type": swagger_type,
        }
    return params
def _param_to_header(param):
    """
    Turn a parameter definition into a header definition.

    The parameter-only ``in`` and ``name`` keys are popped from *param*
    itself (in place) before the rest goes through ``_clean_header``.
    """
    for parameter_only_key in ("in", "name"):
        param.pop(parameter_only_key, None)
    return _clean_header(param)
def _clean_header(header):
    """
    Normalize a response header declaration into a Swagger header object.

    :param header: either a plain description string, or a dict whose
        optional ``type`` (default ``"string"``) may be:

        - a Python primitive listed in ``PY_TYPES``;
        - a one-element list/tuple of such a primitive, documented as an
          array of it;
        - an object exposing ``__schema__``, which is merged in;
        - anything else, which is used as the Swagger type verbatim.
    :returns: a new dict, passed through ``not_none``. The caller's dict is
        not modified.
    """
    if isinstance(header, str):
        header = {"description": header}
    else:
        # Work on a copy: header dicts may come straight from decorator
        # metadata (such as an error handler's ``__apidoc__``), which must not
        # be rewritten as a side effect of building the spec.
        header = dict(header)
    typedef = header.get("type", "string")
    if isinstance(typedef, Hashable) and typedef in PY_TYPES:
        header["type"] = PY_TYPES[typedef]
    elif (
        isinstance(typedef, (list, tuple))
        and len(typedef) == 1
        and typedef[0] in PY_TYPES
    ):
        header["type"] = "array"
        header["items"] = {"type": PY_TYPES[typedef[0]]}
    elif hasattr(typedef, "__schema__"):
        header.update(typedef.__schema__)
    else:
        header["type"] = typedef
    return not_none(header)
def parse_docstring(obj):
    """
    Split the docstring of *obj* into the parts used by the Swagger output.

    :returns: a dict with the keys

        - ``raw``: the docstring as returned by ``inspect.getdoc``;
        - ``summary``: the first line up to its first ``.`` (or ``None``);
        - ``details``: the remaining text without ``:raises Name: text``
          lines (or ``None``);
        - ``raises``: exception name -> description, read from those lines;
        - ``returns``: always ``None``;
        - ``params``: always ``[]``.
    """
    raw = getdoc(obj)
    summary = raw.strip(" \n").split("\n")[0].split(".")[0] if raw else None
    details = None
    if raw:
        # Remove only the leading summary. A plain replace() would also erase
        # later repetitions of the same words inside the details.
        details = raw.replace(summary, "", 1).lstrip(". \n").strip(" \n")
    raises = {}
    for match in RE_RAISES.finditer(raw or ""):
        raises[match.group("name")] = match.group("description")
        if details:
            details = details.replace(match.group(0), "")
    if details:
        # Removing ``:raises`` lines can leave blank lines behind.
        details = details.strip(" \n")
    parsed = {
        "raw": raw,
        "summary": summary or None,
        "details": details or None,
        "returns": None,
        "params": [],
        "raises": raises,
    }
    return parsed
def is_hidden(resource, route_doc=None):
    """
    Tell whether a Resource is excluded from the Swagger documentation,
    e.g. through the ``Api.doc(False)`` decorator.

    A ``route_doc`` of ``False`` hides the route whatever the resource says.
    Otherwise the resource is hidden only when its ``__apidoc__`` is
    ``False``.
    """
    return route_doc is False or getattr(resource, "__apidoc__", None) is False
def build_request_body_parameters_schema(body_params):
    """
    Combine the body arguments of a request parser into a single Swagger
    ``in: body`` parameter named ``payload``.

    :param body_params: JSON schema dicts of the body parameters, generated
        from the json body parameters of a request parser. Each one
        contributes a property named after its ``name``, typed by its
        ``type`` (``"string"`` when absent). Other keys are not carried over.
    :type body_params: list of dict
    :return dict: The Swagger schema representation of the request body

    :Example:

        {
            'name': 'payload',
            'required': True,
            'in': 'body',
            'schema': {
                'type': 'object',
                'properties': {
                    'parameter1': {
                        'type': 'integer'
                    },
                    'parameter2': {
                        'type': 'string'
                    }
                }
            }
        }
    """
    properties = {
        param["name"]: {"type": param.get("type", "string")} for param in body_params
    }
    return {
        "name": "payload",
        "required": True,
        "in": "body",
        "schema": {"type": "object", "properties": properties},
    }
class Swagger(object):
    """
    A Swagger documentation wrapper for an API instance.

    :meth:`as_dict` walks the namespaces, resources, error handlers and models
    registered on ``api`` and renders them as a Swagger 2.0 specification.
    Every model referenced along the way is recorded in
    ``_registered_models`` and emitted under ``definitions``.
    """

    def __init__(self, api):
        self.api = api
        # model name -> model specs, filled by register_model()
        self._registered_models = {}

    @staticmethod
    def _response_code(code):
        """
        Normalize a response code to the string key used in the spec.

        ``int`` subclasses such as ``http.HTTPStatus`` members are converted
        through ``int`` first. Before Python 3.11, ``str()`` of an ``IntEnum``
        member gives its qualified name (``"HTTPStatus.OK"``), not the number.
        Anything else (e.g. ``"default"``) is passed through ``str``.
        """
        if isinstance(code, int) and not isinstance(code, bool):
            return str(int(code))
        return str(code)

    @staticmethod
    def _has_model(model):
        """
        Tell whether a response specification carries a model to document.

        Models are mapping-like, so a model with no fields of its own (for
        instance one that only combines ``__parents__``) can be falsy. A plain
        truthiness test would silently drop its schema.
        """
        return model is not None and (bool(model) or isinstance(model, ModelBase))

    def as_dict(self):
        """
        Output the specification as a serializable ``dict``.

        Side effect: authorizations declared on namespaces are merged into
        ``self.api.authorizations``.

        :returns: the full Swagger specification in a serializable format
        :rtype: dict
        """
        basepath = self.api.base_path
        if len(basepath) > 1 and basepath.endswith("/"):
            # drop a trailing slash, but keep a bare "/" as-is
            basepath = basepath[:-1]
        infos = {
            "title": _v(self.api.title),
            "version": _v(self.api.version),
        }
        if self.api.description:
            infos["description"] = _v(self.api.description)
        if self.api.terms_url:
            infos["termsOfService"] = _v(self.api.terms_url)
        if self.api.contact and (self.api.contact_email or self.api.contact_url):
            infos["contact"] = {
                "name": _v(self.api.contact),
                "email": _v(self.api.contact_email),
                "url": _v(self.api.contact_url),
            }
        if self.api.license:
            infos["license"] = {"name": _v(self.api.license)}
            if self.api.license_url:
                infos["license"]["url"] = _v(self.api.license_url)

        paths = {}
        tags = self.extract_tags(self.api)

        # register errors
        responses = self.register_errors()

        for ns in self.api.namespaces:
            for resource, urls, route_doc, kwargs in ns.resources:
                for url in self.api.ns_urls(ns, urls):
                    path = extract_path(url)
                    serialized = self.serialize_resource(
                        ns, resource, url, route_doc=route_doc, **kwargs
                    )
                    paths[path] = serialized

        # register all models if required
        if current_app.config["RESTX_INCLUDE_ALL_MODELS"]:
            for m in self.api.models:
                self.register_model(m)

        # merge in the top-level authorizations
        for ns in self.api.namespaces:
            if ns.authorizations:
                if self.api.authorizations is None:
                    self.api.authorizations = {}
                self.api.authorizations = merge(
                    self.api.authorizations, ns.authorizations
                )

        specs = {
            "swagger": "2.0",
            "basePath": basepath,
            "paths": not_none_sorted(paths),
            "info": infos,
            "produces": list(self.api.representations.keys()),
            "consumes": ["application/json"],
            "securityDefinitions": self.api.authorizations or None,
            "security": self.security_requirements(self.api.security) or None,
            "tags": tags,
            # computed last: serializing paths/errors registers the models
            "definitions": self.serialize_definitions() or None,
            "responses": responses or None,
            "host": self.get_host(),
        }
        return not_none(specs)

    def get_host(self):
        """
        Return the spec ``host``: the app's ``SERVER_NAME``, prefixed with the
        blueprint subdomain when there is one, or ``None`` when unset.
        """
        hostname = current_app.config.get("SERVER_NAME", None) or None
        if hostname and self.api.blueprint and self.api.blueprint.subdomain:
            hostname = ".".join((self.api.blueprint.subdomain, hostname))
        return hostname

    def extract_tags(self, api):
        """
        Build the top-level ``tags`` list.

        Explicit ``api.tags`` come first. Each may be a name, a
        ``(name, description)`` pair, or a dict with a ``name``. Then one tag
        is added per visible namespace. A namespace description overrides the
        description of an explicit tag with the same name.

        :raises ValueError: for a tag in an unsupported format.
        """
        tags = []
        by_name = {}
        for tag in api.tags:
            if isinstance(tag, str):
                tag = {"name": tag}
            elif isinstance(tag, (list, tuple)):
                tag = {"name": tag[0], "description": tag[1]}
            elif isinstance(tag, dict) and "name" in tag:
                pass
            else:
                raise ValueError("Unsupported tag format for {0}".format(tag))
            tags.append(tag)
            by_name[tag["name"]] = tag
        for ns in api.namespaces:
            # hide namespaces without any Resources
            if not ns.resources:
                continue
            # hide namespaces with all Resources hidden from Swagger documentation
            if all(is_hidden(r.resource, route_doc=r.route_doc) for r in ns.resources):
                continue
            if ns.name not in by_name:
                tags.append(
                    {"name": ns.name, "description": ns.description}
                    if ns.description
                    else {"name": ns.name}
                )
            elif ns.description:
                by_name[ns.name]["description"] = ns.description
        return tags

    def extract_resource_doc(self, resource, url, route_doc=None):
        """
        Merge the resource-level and route-level documentation of one route.

        :returns: the merged doc dict, with one entry per HTTP method plus the
            shared ``params``, or ``False`` when the route/resource is hidden.
        """
        route_doc = {} if route_doc is None else route_doc
        if route_doc is False:
            return False
        doc = merge(getattr(resource, "__apidoc__", {}), route_doc)
        if doc is False:
            return False

        # ensure unique names for multiple routes to the same resource
        # provides different Swagger operationId's
        doc["name"] = (
            "{}_{}".format(resource.__name__, url) if route_doc else resource.__name__
        )

        params = merge(self.expected_params(doc), doc.get("params", OrderedDict()))
        params = merge(params, extract_path_params(url))
        # Track parameters for late deduplication
        up_params = {(n, p.get("in", "query")): p for n, p in params.items()}
        need_to_go_down = set()
        methods = [m.lower() for m in resource.methods or []]
        for method in methods:
            method_doc = doc.get(method, OrderedDict())
            method_impl = getattr(resource, method)
            # unwrap bound methods so their own __apidoc__/docstring are read
            if hasattr(method_impl, "im_func"):
                method_impl = method_impl.im_func
            elif hasattr(method_impl, "__func__"):
                method_impl = method_impl.__func__
            method_doc = merge(
                method_doc, getattr(method_impl, "__apidoc__", OrderedDict())
            )
            if method_doc is not False:
                method_doc["docstring"] = parse_docstring(method_impl)
                method_params = self.expected_params(method_doc)
                method_params = merge(method_params, method_doc.get("params", {}))
                inherited_params = OrderedDict(
                    (k, v) for k, v in params.items() if k in method_params
                )
                method_doc["params"] = merge(inherited_params, method_params)
                for name, param in method_doc["params"].items():
                    key = (name, param.get("in", "query"))
                    if key in up_params:
                        need_to_go_down.add(key)
            doc[method] = method_doc
        # Deduplicate parameters
        # For each couple (name, in), if a method overrides it,
        # we need to move the parameter down to each method
        if need_to_go_down:
            for method in methods:
                method_doc = doc.get(method)
                if not method_doc:
                    continue
                present = {
                    (n, p.get("in", "query")): p
                    for n, p in (method_doc["params"] or {}).items()
                }
                for key in need_to_go_down:
                    if key not in present:
                        method_doc["params"][key[0]] = up_params[key]
        doc["params"] = OrderedDict(
            (k[0], p) for k, p in up_params.items() if k not in need_to_go_down
        )
        return doc

    def expected_params(self, doc):
        """
        Translate the ``expect`` entries of *doc* into Swagger parameters.

        A RequestParser contributes its non-body arguments individually and
        its body arguments as one ``payload``. A model, a ``[model]`` list or
        a ``(model, description)`` pair becomes the ``payload`` body
        parameter.
        """
        params = OrderedDict()
        if "expect" not in doc:
            return params

        for expect in doc.get("expect", []):
            if isinstance(expect, RequestParser):
                parser_params = OrderedDict(
                    (p["name"], p) for p in expect.__schema__ if p["in"] != "body"
                )
                params.update(parser_params)

                body_params = [p for p in expect.__schema__ if p["in"] == "body"]
                if body_params:
                    params["payload"] = build_request_body_parameters_schema(
                        body_params
                    )
            elif isinstance(expect, ModelBase):
                params["payload"] = not_none(
                    {
                        "name": "payload",
                        "required": True,
                        "in": "body",
                        "schema": self.serialize_schema(expect),
                    }
                )
            elif isinstance(expect, (list, tuple)):
                if len(expect) == 2:
                    # this is (payload, description) shortcut
                    model, description = expect
                    params["payload"] = not_none(
                        {
                            "name": "payload",
                            "required": True,
                            "in": "body",
                            "schema": self.serialize_schema(model),
                            "description": description,
                        }
                    )
                else:
                    params["payload"] = not_none(
                        {
                            "name": "payload",
                            "required": True,
                            "in": "body",
                            "schema": self.serialize_schema(expect),
                        }
                    )
        return params

    def register_errors(self):
        """
        Build the top-level ``responses`` section from the API error handlers.

        Entries are keyed by exception class name. The handler docstring
        summary is the description. The first response declared on the
        handler (a description string, or a ``(description, model)`` or
        ``(description, model, kwargs)`` tuple) supplies the schema when it
        has a model slot.

        :raises ValueError: for a response tuple of unsupported length.
        """
        responses = {}
        for exception, handler in self.api.error_handlers.items():
            doc = parse_docstring(handler)
            response = {"description": doc["summary"]}
            apidoc = getattr(handler, "__apidoc__", {})
            self.process_headers(response, apidoc)
            error_specs = apidoc.get("responses")
            if error_specs:
                spec = next(iter(error_specs.values()))
                # A bare description string carries no model to document.
                if not isinstance(spec, str):
                    if len(spec) not in (2, 3):
                        raise ValueError("Unsupported response specification")
                    response["schema"] = self.serialize_schema(spec[1])
            responses[exception.__name__] = not_none(response)
        return responses

    def serialize_resource(self, ns, resource, url, route_doc=None, **kwargs):
        """
        Serialize one resource route as a Swagger path item.

        Returns ``None`` when the route/resource is hidden. A ``methods``
        keyword restricts which HTTP methods of the resource are documented.
        """
        doc = self.extract_resource_doc(resource, url, route_doc=route_doc)
        if doc is False:
            return
        path = {"parameters": self.parameters_for(doc) or None}
        # Route-level method filter; loop-invariant, so computed once.
        allowed = [m.lower() for m in kwargs.get("methods") or []]
        for method in [m.lower() for m in resource.methods or []]:
            if doc[method] is False or (allowed and method not in allowed):
                continue
            path[method] = self.serialize_operation(doc, method)
            path[method]["tags"] = [ns.name]
        return not_none(path)

    def serialize_operation(self, doc, method):
        """Serialize the documentation of one HTTP *method* as an operation."""
        operation = {
            "responses": self.responses_for(doc, method) or None,
            "summary": doc[method]["docstring"]["summary"],
            "description": self.description_for(doc, method) or None,
            "operationId": self.operation_id_for(doc, method),
            "parameters": self.parameters_for(doc[method]) or None,
            "security": self.security_for(doc, method),
        }
        # Handle 'produces' mimetypes documentation
        if "produces" in doc[method]:
            operation["produces"] = doc[method]["produces"]
        # Handle deprecated annotation
        if doc.get("deprecated") or doc[method].get("deprecated"):
            operation["deprecated"] = True
        # Form parameters need form content types; file uploads need multipart.
        # Body parameters carry a "schema" instead of a "type", hence .get().
        doc_params = list(doc.get("params", {}).values())
        all_params = doc_params + (operation["parameters"] or [])
        if all_params and any(p.get("in") == "formData" for p in all_params):
            if any(p.get("type") == "file" for p in all_params):
                operation["consumes"] = ["multipart/form-data"]
            else:
                operation["consumes"] = [
                    "application/x-www-form-urlencoded",
                    "multipart/form-data",
                ]
        operation.update(self.vendor_fields(doc, method))
        return not_none(operation)

    def vendor_fields(self, doc, method):
        """
        Extract custom 3rd party Vendor fields prefixed with ``x-``
        See: https://swagger.io/specification/#specification-extensions
        """
        return dict(
            (k if k.startswith("x-") else "x-{0}".format(k), v)
            for k, v in doc[method].get("vendor", {}).items()
        )

    def description_for(self, doc, method):
        """Extract the description metadata and fallback on the whole docstring"""
        parts = []
        if "description" in doc:
            parts.append(doc["description"] or "")
        if method in doc and "description" in doc[method]:
            parts.append(doc[method]["description"])
        if doc[method]["docstring"]["details"]:
            parts.append(doc[method]["docstring"]["details"])

        return "\n".join(parts).strip()

    def operation_id_for(self, doc, method):
        """Extract the operation id"""
        return (
            doc[method]["id"]
            if "id" in doc[method]
            else self.api.default_id(doc["name"], method)
        )

    def parameters_for(self, doc):
        """
        Turn ``doc["params"]`` into a list of Swagger parameter objects.

        The parameter dicts are completed in place: ``name`` is set, ``type``
        defaults to ``"string"`` (unless a ``schema`` is present), ``in``
        defaults to ``"query"``, and Python types are translated via
        ``PY_TYPES``. A fields-mask header parameter is appended when the
        doc declares ``__mask__`` and ``RESTX_MASK_SWAGGER`` is enabled.
        """
        params = []
        for name, param in doc["params"].items():
            param["name"] = name
            if "type" not in param and "schema" not in param:
                param["type"] = "string"
            if "in" not in param:
                param["in"] = "query"

            if "type" in param and "schema" not in param:
                ptype = param.get("type", None)
                if isinstance(ptype, (list, tuple)):
                    typ = ptype[0]
                    param["type"] = "array"
                    param["items"] = {"type": PY_TYPES.get(typ, typ)}

                elif isinstance(ptype, (type, type(None))) and ptype in PY_TYPES:
                    param["type"] = PY_TYPES[ptype]

            params.append(param)

        # Handle fields mask
        mask = doc.get("__mask__")
        if mask and current_app.config["RESTX_MASK_SWAGGER"]:
            param = {
                "name": current_app.config["RESTX_MASK_HEADER"],
                "in": "header",
                "type": "string",
                "format": "mask",
                "description": "An optional fields mask",
            }
            if isinstance(mask, str):
                param["default"] = mask
            params.append(param)

        return params

    def responses_for(self, doc, method):
        """
        Build the ``responses`` of one operation from resource-level and
        method-level docs.

        Response specs may be a description string, ``(description, model)``
        or ``(description, model, kwargs)``; ``kwargs`` may hold ``envelope``
        and ``headers``. A ``model`` entry (from marshalling) documents the
        ``default_code`` response. ``:raises`` docstring lines referencing a
        documented error handler become ``$ref`` entries. A default 200
        response is emitted when nothing else is documented.

        :raises ValueError: for a response tuple of unsupported length.
        """
        # TODO: simplify/refactor responses/model handling
        responses = {}

        for d in doc, doc[method]:
            if "responses" in d:
                for code, response in d["responses"].items():
                    code = self._response_code(code)
                    if isinstance(response, str):
                        description = response
                        model = None
                        kwargs = {}
                    elif len(response) == 3:
                        description, model, kwargs = response
                    elif len(response) == 2:
                        description, model = response
                        kwargs = {}
                    else:
                        raise ValueError("Unsupported response specification")
                    description = description or DEFAULT_RESPONSE_DESCRIPTION
                    if code in responses:
                        responses[code].update(description=description)
                    else:
                        responses[code] = {"description": description}
                    if self._has_model(model):
                        schema = self.serialize_schema(model)
                        envelope = kwargs.get("envelope")
                        if envelope:
                            schema = {"properties": {envelope: schema}}
                        responses[code]["schema"] = schema
                    self.process_headers(
                        responses[code], doc, method, kwargs.get("headers")
                    )
            if "model" in d:
                code = self._response_code(d.get("default_code", HTTPStatus.OK))
                if code not in responses:
                    responses[code] = self.process_headers(
                        DEFAULT_RESPONSE.copy(), doc, method
                    )
                responses[code]["schema"] = self.serialize_schema(d["model"])

            if "docstring" in d:
                for name, description in d["docstring"]["raises"].items():
                    for exception, handler in self.api.error_handlers.items():
                        error_responses = getattr(handler, "__apidoc__", {}).get(
                            "responses", {}
                        )
                        code = (
                            self._response_code(next(iter(error_responses)))
                            if error_responses
                            else None
                        )
                        if code and exception.__name__ == name:
                            responses[code] = {"$ref": "#/responses/{0}".format(name)}
                            break

        if not responses:
            responses[str(HTTPStatus.OK.value)] = self.process_headers(
                DEFAULT_RESPONSE.copy(), doc, method
            )
        return responses

    def process_headers(self, response, doc, method=None, headers=None):
        """
        Attach the declared headers to *response* (in place) and return it.

        Headers are gathered from ``doc``, then from ``doc[method]``, then
        from *headers*. For a repeated name, the later source wins.
        """
        method_doc = doc.get(method, {})
        if "headers" in doc or "headers" in method_doc or headers:
            response["headers"] = dict(
                (k, _clean_header(v))
                for k, v in itertools.chain(
                    doc.get("headers", {}).items(),
                    method_doc.get("headers", {}).items(),
                    (headers or {}).items(),
                )
            )
        return response

    def serialize_definitions(self):
        """Return the ``definitions`` section: name -> schema of every registered model."""
        return dict(
            (name, model.__schema__) for name, model in self._registered_models.items()
        )

    def serialize_schema(self, model):
        """
        Return the Swagger schema for *model*.

        Lists/tuples become arrays of their first item. Models and model names
        are registered and referenced. Field classes and instances use their
        ``__schema__``. Python primitives map through ``PY_TYPES``.

        :raises ValueError: for anything else.
        """
        if isinstance(model, (list, tuple)):
            model = model[0]
            return {
                "type": "array",
                "items": self.serialize_schema(model),
            }

        elif isinstance(model, ModelBase):
            self.register_model(model)
            return ref(model)

        elif isinstance(model, str):
            self.register_model(model)
            return ref(model)

        elif isclass(model) and issubclass(model, fields.Raw):
            return self.serialize_schema(model())

        elif isinstance(model, fields.Raw):
            return model.__schema__

        elif isinstance(model, (type, type(None))) and model in PY_TYPES:
            return {"type": PY_TYPES[model]}

        raise ValueError("Model {0} not registered".format(model))

    def register_model(self, model):
        """
        Record *model* (and, recursively, its parents and nested models) for
        the ``definitions`` section, and return a reference to it.

        :raises ValueError: if the model is not known to the API.
        """
        name = model.name if isinstance(model, ModelBase) else model
        if name not in self.api.models:
            raise ValueError("Model {0} not registered".format(name))
        specs = self.api.models[name]
        if name in self._registered_models:
            # already registered: also stops recursion on self-references
            return ref(model)
        self._registered_models[name] = specs
        if isinstance(specs, ModelBase):
            for parent in specs.__parents__:
                self.register_model(parent)
        if isinstance(specs, (Model, OrderedModel)):
            for field in specs.values():
                self.register_field(field)
        return ref(model)

    def register_field(self, field):
        """Register every model reachable from *field* (polymorph, nested, list, wildcard)."""
        if isinstance(field, fields.Polymorph):
            for model in field.mapping.values():
                self.register_model(model)

        elif isinstance(field, fields.Nested):
            self.register_model(field.nested)

        elif isinstance(field, (fields.List, fields.Wildcard)):
            self.register_field(field.container)

    def security_for(self, doc, method):
        """Return the operation security requirements; method-level ones override resource-level ones."""
        security = None
        if "security" in doc:
            auth = doc["security"]
            security = self.security_requirements(auth)

        if "security" in doc[method]:
            auth = doc[method]["security"]
            security = self.security_requirements(auth)

        return security

    def security_requirements(self, value):
        """Normalize a security declaration (name, dict, or list of them) into a list."""
        if isinstance(value, (list, tuple)):
            return [self.security_requirement(v) for v in value]
        elif value:
            requirement = self.security_requirement(value)
            return [requirement] if requirement else None
        else:
            return []

    def security_requirement(self, value):
        """Normalize one requirement: a name maps to no scopes; dict values become scope lists."""
        if isinstance(value, (str)):
            return {value: []}
        elif isinstance(value, dict):
            return dict(
                (k, v if isinstance(v, (list, tuple)) else [v])
                for k, v in value.items()
            )
        else:
            return None