from __future__ import absolute_import
import logging
import os
from typing import Callable, Dict, List, Tuple, Union
from typing_inspect import get_origin
try:
from flask import current_app, request
from flask_restful import Resource
from werkzeug.exceptions import (BadRequest, Conflict, Forbidden,
HTTPException, NotFound, Unauthorized,
InternalServerError)
except ImportError: # pragma: no cover
raise ImportError('You must install flask to use the '
'doctor.flask module.')
from .constants import HTTP_METHODS_WITH_JSON_BODY
from .errors import (ForbiddenError, ImmutableError, InvalidValueError,
NotFoundError, TypeSystemError, UnauthorizedError)
from .parsers import map_param_names, parse_form_and_query_params
from .response import Response
from .routing import create_routes as doctor_create_routes
from .routing import Route
# Success status codes used by default for HTTP methods that should not
# answer with a plain 200.  Methods missing from this map fall back to 200.
STATUS_CODE_MAP = {'POST': 201, 'DELETE': 204}

# Annotation alias for a value that is either a list or None.
ListOrNone = Union[List, None]
class SchematicHTTPException(HTTPException):
    """Base class for the HTTP errors raised by this module.

    It extends werkzeug's ``HTTPException`` with two extra attributes:
    ``data``, a dict holding the status code and the message, and
    ``errors``, the validation errors handed to the constructor.  The
    ``data`` attribute is flask-restful specific; the error would not
    render properly without it.

    :param description: The error description.
    :param errors: A dict containing all validation errors during the request.
        The key is the param name and the value is the error message.
    """

    def __init__(self, description: str = None, errors: dict = None):
        super().__init__(description)
        self.errors = errors
        self.data = {'status': self.code, 'message': description}

    def __str__(self):
        # Rendered as "<code>: <name>: <description>".
        fields = (self.code, self.name, self.description)
        return '%d: %s: %s' % fields
class HTTP400Exception(SchematicHTTPException, BadRequest):
    """HTTP 400 error, combining the schematic base with ``BadRequest``.

    :param description: The error description.
    :param errors: A dict containing all validation errors during the request.
        The key is the param name and the value is the error message.
    """
class HTTP401Exception(SchematicHTTPException, Unauthorized):
    """HTTP 401 error; ``handle_http`` raises it for ``UnauthorizedError``."""
class HTTP403Exception(SchematicHTTPException, Forbidden):
    """HTTP 403 error; ``handle_http`` raises it for ``ForbiddenError``."""
class HTTP404Exception(SchematicHTTPException, NotFound):
    """HTTP 404 error; ``handle_http`` raises it for ``NotFoundError``."""
class HTTP409Exception(SchematicHTTPException, Conflict):
    """HTTP 409 error; ``handle_http`` raises it for ``ImmutableError``."""
class HTTP500Exception(SchematicHTTPException, InternalServerError):
    """HTTP 500 error; ``handle_http`` raises it for uncaught exceptions."""
def should_raise_response_validation_errors() -> bool:
    """Tell whether response validation errors should be raised.

    The decision is driven by the `RAISE_RESPONSE_VALIDATION_ERRORS`
    environment variable: any non-empty value enables raising, while an
    unset or empty variable disables it.

    :returns: True if it should, False otherwise.
    """
    flag = os.environ.get('RAISE_RESPONSE_VALIDATION_ERRORS', False)
    return True if flag else False
def handle_http(handler: Resource, args: Tuple, kwargs: Dict, logic: Callable):
    """Handle a Flask HTTP request.

    Gathers the request parameters (from the JSON body or from the form and
    query values), checks that every required parameter is present,
    validates and coerces them, calls ``logic`` and finally validates the
    value it returned against its return annotation, if it has one.

    :param handler: flask_restful.Resource: An instance of a Flask Restful
        resource class.  It is not used by this function.
    :param tuple args: Any positional arguments passed to the wrapper method.
        They are forwarded to ``logic`` as positional arguments.
    :param dict kwargs: Any keyword arguments passed to the wrapper method.
        They are merged into the request parameters.
    :param callable logic: The callable to invoke to actually perform the
        business logic for this request.  It must carry the ``_doctor_*``
        attributes read in the body.
    :returns: ``(content, status_code, headers)`` when ``logic`` returns a
        ``Response``, otherwise ``(response, status_code)``.  When no status
        code is given it comes from ``STATUS_CODE_MAP`` for the request
        method, falling back to 200.
    :raises HTTP400Exception: On ``InvalidValueError`` or ``TypeSystemError``
        (missing required params, invalid params, or an invalid response
        when response validation errors are configured to be raised).
    :raises HTTP401Exception: On ``UnauthorizedError``.
    :raises HTTP403Exception: On ``ForbiddenError``.
    :raises HTTP404Exception: On ``NotFoundError``.
    :raises HTTP409Exception: On ``ImmutableError``.
    :raises HTTP500Exception: On any other exception, unless the app's
        ``DEBUG`` config is enabled or the exception is an instance of one of
        ``logic._doctor_allowed_exceptions``; in those cases the original
        exception is re-raised unchanged.
    """
    try:
        # We are checking mimetype here instead of content_type because
        # mimetype is just the content-type, where as content_type can
        # contain encoding, charset, and language information.  e.g.
        # `Content-Type: application/json; charset=UTF8`
        if (request.mimetype == 'application/json' and
                request.method in HTTP_METHODS_WITH_JSON_BODY):
            # This is a proper typed JSON request. The parameters will be
            # encoded into the request body as a JSON blob.
            if not logic._doctor_req_obj_type:
                request_params = map_param_names(
                    request.json, logic._doctor_signature.parameters)
            else:
                request_params = request.json
        else:
            # Try to parse things from normal HTTP parameters
            request_params = parse_form_and_query_params(
                request.values, logic._doctor_signature.parameters)

        # Only filter out additional params if a req_obj_type was not
        # specified.
        if logic._doctor_req_obj_type:
            # Work on a shallow copy: `request_params` may be the very object
            # returned by `request.json`, and merging the wrapper kwargs into
            # it in place would alter the body seen by later readers.
            params = request_params.copy()
        else:
            # Filter out any params not part of the logic signature.
            all_params = logic._doctor_params.all
            params = {k: v for k, v in request_params.items()
                      if k in all_params}
        params.update(**kwargs)

        # Check for required params
        missing = [required for required in logic._doctor_params.required
                   if required not in params]
        if missing:
            # A single missing param is reported by name, several are
            # reported as the list itself.
            if len(missing) == 1:
                verb = 'is'
                subject = missing[0]
            else:
                verb = 'are'
                subject = missing
            error = '{} {} required.'.format(subject, verb)
            raise InvalidValueError(error)

        # Validate and coerce parameters to the appropriate types.
        errors = {}
        sig = logic._doctor_signature

        # If a `req_obj_type` was defined for the route, pass all request
        # params to that type for validation/coercion
        if logic._doctor_req_obj_type:
            annotation = logic._doctor_req_obj_type
            try:
                # NOTE: We calculate the value before applying native type in
                # order to support UnionType types which dynamically modifies
                # the native_type property based on the initialized value.
                value = annotation(params)
                params = annotation.native_type(value)
            except TypeError:
                logging.exception(
                    'Error casting and validating params with value `%s`.',
                    params)
                raise
            except TypeSystemError as e:
                errors['__all__'] = e.detail
        else:
            # Only values of existing keys are replaced while iterating, so
            # the dict never changes size during the loop.
            for name, value in params.items():
                annotation = sig.parameters[name].annotation
                if annotation.nullable and value is None:
                    continue
                try:
                    # NOTE: We calculate the value before applying native type
                    # in order to support UnionType types which dynamically
                    # modifies the native_type property based on the
                    # initialized value.
                    value = annotation(value)
                    params[name] = annotation.native_type(value)
                except TypeSystemError as e:
                    errors[name] = e.detail
        if errors:
            raise TypeSystemError(errors, errors=errors)

        if logic._doctor_req_obj_type:
            # Pass any positional arguments followed by the coerced request
            # parameters to the logic function.
            response = logic(*args, params)
        else:
            # Only pass request parameters defined by the logic signature.
            logic_params = {k: v for k, v in params.items()
                            if k in logic._doctor_params.logic}
            response = logic(*args, **logic_params)

        # response validation
        if sig.return_annotation != sig.empty:
            return_annotation = sig.return_annotation
            _response = response
            if isinstance(response, Response):
                _response = response.content
            # Check if our return annotation is a Response that supplied a
            # type to validate against.  If so, use that type for validation
            # e.g. def logic() -> Response[MyType]
            if ((get_origin(return_annotation) == Response) and
                    return_annotation.__args__ is not None):
                return_annotation = return_annotation.__args__[0]
            try:
                return_annotation(_response)
            except TypeSystemError as e:
                # An invalid response is always logged; it only becomes an
                # error when raising was explicitly requested.
                response_str = str(_response)
                logging.warning('Response to %s %s does not validate: %s.',
                                request.method, request.path,
                                response_str, exc_info=e)
                if should_raise_response_validation_errors():
                    error = ('Response to {method} {path} `{response}` does not'
                             ' validate: {error}'.format(
                                 method=request.method, path=request.path,
                                 response=response, error=e.detail))
                    raise TypeSystemError(error)

        if isinstance(response, Response):
            status_code = response.status_code
            if status_code is None:
                status_code = STATUS_CODE_MAP.get(request.method, 200)
            return (response.content, status_code, response.headers)
        return response, STATUS_CODE_MAP.get(request.method, 200)
    except (InvalidValueError, TypeSystemError) as e:
        errors = getattr(e, 'errors', None)
        raise HTTP400Exception(e, errors=errors) from e
    except UnauthorizedError as e:
        raise HTTP401Exception(e) from e
    except ForbiddenError as e:
        raise HTTP403Exception(e) from e
    except NotFoundError as e:
        raise HTTP404Exception(e) from e
    except ImmutableError as e:
        raise HTTP409Exception(e) from e
    except Exception as e:
        # NOTE(review): werkzeug HTTPExceptions raised inside the try block
        # (presumably e.g. by `request.json` on a malformed body) also land
        # here and are turned into a 500 -- confirm that this is intended.
        # Always re-raise exceptions when DEBUG is enabled for development.
        if current_app.config.get('DEBUG', False):
            raise
        allowed_exceptions = logic._doctor_allowed_exceptions
        if allowed_exceptions and isinstance(e, tuple(allowed_exceptions)):
            raise
        logging.exception(e)
        raise HTTP500Exception('Uncaught error in logic function') from e
def create_routes(routes: Tuple[Route]) -> List[Tuple[str, Resource]]:
    """Create the routes using the flask specific handler and base class.

    Delegates to doctor's generic ``create_routes``, supplying
    ``handle_http`` as the request handler and flask-restful's ``Resource``
    as the default base handler class.

    :param routes: A tuple containing the route and another tuple with
        all http methods allowed for the route.
    :returns: A list of tuples containing the route and generated handler.
    """
    flask_options = {'default_base_handler_class': Resource}
    return doctor_create_routes(routes, handle_http, **flask_options)