Flask Part 3: API Decorators and Helpers

This is the third of three posts about building a JSON API with Flask.
Make sure you start with part 1 and part 2.

In the first post, we used a custom base SQLAlchemy class to serialize and deserialize database models to and from JSON. The second post created a RESTful API for listing and updating users. Now we’re ready for extra API features: caching responses, rate limiting, and preventing brute-force attacks. We’ll also add a helper method to our BaseModel class from the first post that fetches a row, or creates it when it doesn’t already exist.

Decorators

Use these decorators to add extra functionality to your API view functions.

For example, to rate limit and cache your users API resource:

@app.route("/api/users")
@rate_limited
@cached
def users():
    return json.dumps([user.to_dict() for user in User.query.all()])

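The order of decorators matters. Python applies decorators from the bottom up, so at request time their wrappers run from the top down. Keep @app.route on top so Flask registers the fully decorated view, and always put @cached below @rate_limited and @protected so those checks still run when a cached response is available.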
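To see the ordering for yourself, here is a small standalone sketch with no Flask involved. The trace decorator factory is a hypothetical stand-in for the real decorators; it only records when each decorator is applied and when its wrapper runs.

```python
from functools import wraps

events = []

def trace(name):
    """Hypothetical decorator factory that logs apply time and call time."""
    def decorator(func):
        events.append("apply " + name)  # happens once, at definition time

        @wraps(func)
        def inner(*args, **kwargs):
            events.append("run " + name)  # happens on every call
            return func(*args, **kwargs)
        return inner
    return decorator

@trace("rate_limited")
@trace("cached")
def users():
    events.append("view")
    return "[]"

users()
print(events)
# ['apply cached', 'apply rate_limited', 'run rate_limited', 'run cached', 'view']
```

Because rate_limited is the outer wrapper here, every request counts toward the limit, even one that cached could have answered from memcached.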

To customize decorator options, pass arguments to the decorators:

@app.route("/api/users")
@rate_limited(limit=50, minutes=60)  # only 50 requests from user/ip to this endpoint allowed per hour
@protected(limit=2, minutes=720)  # only 2 404 requests from same ip to this endpoint allowed per 12 hours
@cached(minutes=5)  # response cached for 5 minutes
def users():
    return json.dumps([user.to_dict() for user in User.query.all()])
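All three decorators support both the bare @cached form and the @cached(minutes=5) form with the same trick: with no parentheses, the decorated function arrives as the first positional argument fn; with parentheses, fn is None and the decorator returns the real wrapper. Here’s that pattern in isolation, with a hypothetical tagged decorator standing in for the real ones:

```python
from functools import wraps

def tagged(fn=None, label="default"):
    """Hypothetical decorator usable as @tagged or @tagged(label=...)."""
    def wrapper(func):
        @wraps(func)
        def inner(*args, **kwargs):
            return "{}: {}".format(label, func(*args, **kwargs))
        return inner
    # Bare @tagged passes the function in as fn; @tagged(...) leaves fn as
    # None and returns the real decorator for Python to apply next.
    return wrapper(fn) if fn else wrapper

@tagged
def plain():
    return "hello"

@tagged(label="custom")
def configured():
    return "hello"

print(plain())       # default: hello
print(configured())  # custom: hello
```

One catch: options must be passed as keyword arguments, since a positional argument such as @cached(5) would be bound to fn and mistaken for the function.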

Cache an API response with memcached

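import hashlib
import memcache
import traceback
from cachelib import MemcachedCache  # werkzeug.contrib.cache was removed in Werkzeug 1.0
from flask import request
from functools import wraps
from wakatime_website import app

# memcache.Client requires a server list; adjust to your memcached host(s).
mc = memcache.Client(['127.0.0.1:11211'])
cache = MemcachedCache(mc)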

def cached(fn=None, unique_per_user=True, minutes=30):
    """Caches a Flask route/view in memcached.

    The request url, args, and (if unique_per_user is set) the current user
    are used to build the cache key.
    Only GET requests are cached.
    By default, cached requests expire after 30 minutes.
    """

    if not isinstance(minutes, int):
        raise Exception('Minutes must be an integer number.')

    def wrapper(func):
        @wraps(func)
        def inner(*args, **kwargs):
            # Only cache reads; other methods may change state.
            if request.method != 'GET':
                return func(*args, **kwargs)

            prefix = 'flask-request'
            path = request.full_path  # path plus query string
            user_id = None
            if unique_per_user and app.current_user.is_authenticated:
                user_id = app.current_user.id
            key = '{user}-{method}-{path}'.format(
                user=user_id,
                method=request.method,
                path=path,
            )
            # Hashing keeps long urls within memcached's 250-byte key limit.
            hashed = hashlib.md5(key.encode('utf8')).hexdigest()
            hashed = '{prefix}-{hashed}'.format(prefix=prefix, hashed=hashed)

            try:
                resp = cache.get(hashed)
                if resp is not None:
                    return resp
            except Exception:
                # A cache outage shouldn't break the API; fall through.
                app.logger.error(traceback.format_exc())

            resp = func(*args, **kwargs)
            try:
                cache.set(hashed, resp, timeout=minutes * 60)
            except Exception:
                app.logger.error(traceback.format_exc())
            return resp

        return inner
    return wrapper(fn) if fn else wrapper
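To experiment with the caching logic without a memcached server, here’s a standalone sketch of the same key scheme plus a tiny in-process store with expiry. make_cache_key and TTLCache are hypothetical names; the cachelib package also ships a SimpleCache with the same get/set(..., timeout=...) interface if you’d rather not write your own.

```python
import hashlib
import time

def make_cache_key(user_id, method, full_path, prefix="flask-request"):
    """Mirrors the key built in cached(): user, method and full path, hashed."""
    raw = "{user}-{method}-{path}".format(user=user_id, method=method, path=full_path)
    return "{}-{}".format(prefix, hashlib.md5(raw.encode("utf8")).hexdigest())

class TTLCache:
    """Hypothetical in-memory stand-in for memcached: entries expire after timeout seconds."""

    def __init__(self, clock=time.monotonic):
        self._clock = clock  # injectable so tests can fake the passage of time
        self._store = {}

    def get(self, key):
        item = self._store.get(key)
        if item is None:
            return None
        value, expires_at = item
        if self._clock() >= expires_at:
            del self._store[key]  # lazily evict expired entries
            return None
        return value

    def set(self, key, value, timeout=300):
        self._store[key] = (value, self._clock() + timeout)

now = [0.0]
cache = TTLCache(clock=lambda: now[0])
key = make_cache_key(42, "GET", "/api/users?page=2")
cache.set(key, '[{"id": 42}]', timeout=5 * 60)
print(cache.get(key))  # [{"id": 42}]
now[0] += 5 * 60       # jump forward five minutes
print(cache.get(key))  # None
```

Keying on the user as well as the full path means two users hitting the same url never see each other’s cached responses.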

Rate limit API requests by IP or current user

import redis
import traceback
from flask import abort, request
from functools import wraps
from wakatime_website import app

r = redis.Redis(decode_responses=True)

def rate_limited(fn=None, limit=20, methods=None, ip=True, user=True, minutes=1):
    """Limits requests to this endpoint to `limit` per `minutes`.

    Requests are counted per IP and, for logged-in users, per user. Pass
    `methods` (e.g. ['POST']) to only count those HTTP methods.
    """

    if not isinstance(limit, int):
        raise Exception('Limit must be an integer number.')
    if limit < 1:
        raise Exception('Limit must be greater than zero.')

    def wrapper(func):
        @wraps(func)
        def inner(*args, **kwargs):
            if not methods or request.method in methods:

                if ip:
                    increment_counter(type='ip', for_methods=methods,
                                      minutes=minutes)
                    count = get_count(type='ip', for_methods=methods)
                    if count > limit:
                        abort(429)  # Too Many Requests

                # Not nested under `if ip`, so per-user limits also apply
                # when ip=False.
                if user and app.current_user.is_authenticated:
                    increment_counter(type='user', for_methods=methods,
                                      minutes=minutes)
                    count = get_count(type='user', for_methods=methods)
                    if count > limit:
                        abort(429)

            return func(*args, **kwargs)

        return inner
    return wrapper(fn) if fn else wrapper

def get_counter_key(type=None, for_only_this_route=True, for_methods=None):
    if not isinstance(for_methods, (list, tuple)):
        for_methods = []
    if type == 'ip':
        key = request.remote_addr
    elif type == 'user':
        key = app.current_user.id if app.current_user.is_authenticated else None
    else:
        raise Exception('Unknown rate limit type: {0}'.format(type))
    route = ''
    if for_only_this_route:
        # The separator keeps e.g. user 1 + endpoint "2x" and user 12 +
        # endpoint "x" from sharing a key: "ip--127.0.0.1-users".
        route = '-{endpoint}'.format(endpoint=request.endpoint)
    return '{type}-{methods}-{key}{route}'.format(
        type=type,
        key=key,
        methods=','.join(for_methods),
        route=route,
    )

def increment_counter(type=None, for_only_this_route=True, for_methods=None,
                      minutes=1):
    if type not in ['ip', 'user']:
        raise Exception('Type must be ip or user.')

    key = get_counter_key(type=type, for_only_this_route=for_only_this_route,
                          for_methods=for_methods)
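    try:
        count = r.incr(key)
        if count == 1:
            # Start the window on the first hit only. Re-setting the expiry on
            # every hit would keep a busy counter alive forever and eventually
            # block clients that stay under `limit` per `minutes`.
            r.expire(key, time=60 * minutes)
    except Exception:
        app.logger.error(traceback.format_exc())

def get_count(type=None, for_only_this_route=True, for_methods=None):
    key = get_counter_key(type=type, for_only_this_route=for_only_this_route,
                          for_methods=for_methods)
    try:
        return int(r.get(key) or 0)
    except Exception:
        app.logger.error(traceback.format_exc())
        return 0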
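Here is the counting logic in isolation, as a fixed window that starts at a key’s first hit, with a plain dict standing in for Redis. FixedWindowLimiter is a hypothetical name, and this is a sketch for unit tests rather than a drop-in replacement.

```python
import time

class FixedWindowLimiter:
    """Hypothetical in-memory stand-in for the Redis INCR/EXPIRE counters."""

    def __init__(self, limit, minutes=1, clock=time.monotonic):
        self.limit = limit
        self.window = 60 * minutes
        self._clock = clock
        self._counters = {}  # key -> [count, window_expires_at]

    def hit(self, key):
        """Counts one request for key; returns False once it exceeds the limit."""
        now = self._clock()
        entry = self._counters.get(key)
        if entry is None or now >= entry[1]:
            # First hit, or the previous window expired: start a new window.
            entry = [0, now + self.window]
            self._counters[key] = entry
        entry[0] += 1
        return entry[0] <= self.limit

now = [0.0]
limiter = FixedWindowLimiter(limit=3, minutes=1, clock=lambda: now[0])
key = "ip-127.0.0.1"
print([limiter.hit(key) for _ in range(4)])  # [True, True, True, False]
now[0] += 60                                 # the window passes
print(limiter.hit(key))                      # True
```

Unlike the decorator, this only works within one process; keeping the counts in Redis is what lets every app server share them.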

Prevent brute-forcing secrets or tokens

import redis
import traceback
from flask import request
from functools import wraps
from wakatime_website import app
from werkzeug.exceptions import NotFound

r = redis.Redis(decode_responses=True)

def protected(fn=None, limit=10, minutes=60):
    """Bans an IP after too many 404s from a protected resource.

    Once an IP has received `limit` 404 responses from the decorated route,
    further requests from it get a 404 without running the view, until
    `minutes` pass with no new attempts. This prevents enumerating secrets or
    tokens from urls or query arguments.
    """

    if not isinstance(limit, int):
        raise Exception('Limit must be an integer number.')
    if not isinstance(minutes, int):
        raise Exception('Minutes must be an integer number.')

    def wrapper(func):
        @wraps(func)
        def inner(*args, **kwargs):
            key = 'bruteforce-{}-{}'.format(request.endpoint, request.remote_addr)
            try:
                count = int(r.get(key) or 0)
                if count >= limit:
                    # Banned: count this attempt too (extending the ban) and
                    # answer as if the resource doesn't exist.
                    _record_miss(key, minutes)
                    app.logger.info('Request blocked by protected decorator.')
                    return '404', 404
            except Exception:
                app.logger.error(traceback.format_exc())

            try:
                result = func(*args, **kwargs)
            except NotFound:
                # abort(404) inside the view counts as a miss.
                _record_miss(key, minutes)
                raise

            # So does returning a 404 tuple or a Response with a 404 status.
            is_404 = (
                (isinstance(result, tuple) and len(result) > 1 and result[1] == 404)
                or getattr(result, 'status_code', None) == 404
            )
            if is_404:
                _record_miss(key, minutes)

            return result

        return inner
    return wrapper(fn) if fn else wrapper

def _record_miss(key, minutes):
    """Counts a 404 for this endpoint and IP, restarting the ban window."""
    try:
        r.incr(key)
        r.expire(key, time=60 * minutes)
    except Exception:
        app.logger.error(traceback.format_exc())
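The ban logic itself is small enough to test without Redis or Flask. Below is a sketch of a hypothetical in-memory MissTracker that counts 404s per (endpoint, IP), blocks once limit misses are recorded, and restarts the ban window on every further miss, the way INCR followed by EXPIRE does.

```python
import time

class MissTracker:
    """Hypothetical in-memory version of the protected decorator's 404 counter."""

    def __init__(self, limit=10, minutes=60, clock=time.monotonic):
        self.limit = limit
        self.ttl = 60 * minutes
        self._clock = clock
        self._misses = {}  # (endpoint, ip) -> [count, expires_at]

    def _entry(self, endpoint, ip):
        entry = self._misses.get((endpoint, ip))
        if entry is not None and self._clock() >= entry[1]:
            del self._misses[(endpoint, ip)]  # ban window over: forget the misses
            return None
        return entry

    def is_blocked(self, endpoint, ip):
        entry = self._entry(endpoint, ip)
        return entry is not None and entry[0] >= self.limit

    def record_miss(self, endpoint, ip):
        entry = self._entry(endpoint, ip) or [0, 0.0]
        # Like INCR then EXPIRE: bump the count and restart the window.
        entry[0] += 1
        entry[1] = self._clock() + self.ttl
        self._misses[(endpoint, ip)] = entry

now = [0.0]
tracker = MissTracker(limit=2, minutes=720, clock=lambda: now[0])
tracker.record_miss("api_token", "203.0.113.7")
print(tracker.is_blocked("api_token", "203.0.113.7"))  # False
tracker.record_miss("api_token", "203.0.113.7")
print(tracker.is_blocked("api_token", "203.0.113.7"))  # True
now[0] += 720 * 60
print(tracker.is_blocked("api_token", "203.0.113.7"))  # False
```

Because every miss pushes the expiry forward, an attacker who keeps guessing stays banned; the ban only lifts after a full quiet period.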

Get or Create SQLAlchemy Helper

A common pattern is to fetch a User from the database, creating one if necessary, then update some attributes on that user and save it back to the database. With SQLAlchemy, this looks something like:

user = User.query.filter_by(username="zzzeek").first()
if not user:
    user = User(username="zzzeek")
    db.session.add(user)
user.first_name = "Mike"
db.session.commit()

This works, but it means repeating boilerplate code. It’s also prone to race conditions: if another request creates a user with the same username after our query but before our commit, our commit fails on the unique constraint (or, without one, leaves a duplicate row).

A better way is to add a get_or_create convenience method to the BaseModel SQLAlchemy class from the first post:

from sqlalchemy.exc import IntegrityError, OperationalError

class BaseModel(db.Model):
    __abstract__ = True

    ...

    @classmethod
    def _get_or_create(
        cls,
        _session=None,
        _filters=None,
        _defaults=None,
        _retry_count=0,
        _max_retries=3,
        **kwargs
    ):
        """Returns (instance, created). The caller still commits the session."""
        if not _session:
            _session = db.session
        query = _session.query(cls)
        if _filters is not None:
            query = query.filter(*_filters)
        if len(kwargs) > 0:
            query = query.filter_by(**kwargs)

        instance = query.first()
        if instance is not None:
            return instance, False

        # Copy instead of mutating kwargs, so a retry still filters only on
        # the lookup columns and not on the defaults.
        params = dict(kwargs)
        params.update(_defaults or {})

        try:
            # Insert inside a SAVEPOINT: if it fails, only the savepoint is
            # rolled back, not the caller's whole transaction.
            with _session.begin_nested():
                instance = cls(**params)
                _session.add(instance)
            return instance, True

        except IntegrityError:
            # Lost the race: another transaction inserted the same row first.
            instance = query.first()
            if instance is None:
                raise
            return instance, False

        except OperationalError:
            instance = query.first()
            if instance is None:
                if _retry_count < _max_retries:
                    return cls._get_or_create(
                        _session=_session,
                        _filters=_filters,
                        _defaults=_defaults,
                        _retry_count=_retry_count + 1,
                        _max_retries=_max_retries,
                        **kwargs
                    )
                raise
            return instance, False

    @classmethod
    def get_or_create(cls, **kwargs):
        return cls._get_or_create(**kwargs)[0]
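If you want to try the insert-then-fall-back idea without SQLAlchemy, here’s a minimal sketch using Python’s built-in sqlite3 module. It assumes a hypothetical users table with a UNIQUE username column; that unique constraint is what turns a lost race into an IntegrityError we can recover from. get_or_create_user is a hypothetical helper, not part of the post’s code.

```python
import sqlite3

def get_or_create_user(conn, username, first_name=None):
    """Returns (user_id, created). Hypothetical helper for a users table."""
    select = "SELECT id FROM users WHERE username = ?"
    row = conn.execute(select, (username,)).fetchone()
    if row is not None:
        return row[0], False
    try:
        with conn:  # commits on success, rolls back if the INSERT fails
            cur = conn.execute(
                "INSERT INTO users (username, first_name) VALUES (?, ?)",
                (username, first_name),
            )
        return cur.lastrowid, True
    except sqlite3.IntegrityError:
        # Another connection created the same username after our SELECT.
        row = conn.execute(select, (username,)).fetchone()
        if row is None:
            raise  # some other constraint failed; don't hide it
        return row[0], False

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT UNIQUE NOT NULL,"
    " first_name TEXT)"
)
print(get_or_create_user(conn, "zzzeek", "Mike"))  # (1, True)
print(get_or_create_user(conn, "zzzeek", "Mike"))  # (1, False)
```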

With this helper, the code above becomes:

updates = {"first_name": "Mike"}
user = User.get_or_create(username="zzzeek", _defaults=updates)
user.from_dict(**updates)
db.session.commit()

Authentication, Permissions, and Access Control

Authentication is handled by Flask-Login. If your API needs access control, use OAuth, with Flask-Login’s custom request loader looking up the user from the request’s access token. Then decorate your API views with the OAuth scopes that token must grant:

@app.route("/api/users")
@oauth(required_scopes=['user:read'])
def users():
    return json.dumps([user.to_dict() for user in User.query.all()])
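The oauth decorator itself isn’t shown in this post. As a rough sketch of how one might work, assuming a hypothetical get_token_scopes callable that returns the scopes granted to the current request’s access token (in a Flask app it would read whatever your request loader stored), the check boils down to a subset test. ScopeError is a hypothetical stand-in for abort(403) so the sketch runs without Flask.

```python
from functools import wraps

class ScopeError(Exception):
    """Hypothetical stand-in for Flask's abort(403)."""

def oauth(required_scopes, get_token_scopes):
    """Hypothetical decorator: the token must grant every required scope."""
    required = set(required_scopes)

    def wrapper(func):
        @wraps(func)
        def inner(*args, **kwargs):
            missing = required - set(get_token_scopes())
            if missing:
                raise ScopeError("missing scopes: " + ", ".join(sorted(missing)))
            return func(*args, **kwargs)
        return inner
    return wrapper

granted = {"user:read"}

@oauth(required_scopes=["user:read"], get_token_scopes=lambda: granted)
def users():
    return "[]"

@oauth(required_scopes=["user:read", "user:write"], get_token_scopes=lambda: granted)
def update_user():
    return "ok"

print(users())  # []
try:
    update_user()
except ScopeError as exc:
    print(exc)  # missing scopes: user:write
```

In a Flask app you would likely read the scopes from the request instead of passing get_token_scopes in, and call abort(403) instead of raising ScopeError.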

This provides resource-level access control rather than column- or table-level control, but it has worked well for WakaTime’s public API.

Schema Migrations with Alembic

Alembic is a schema migration tool written by the author of SQLAlchemy. Its autogenerate feature compares your SQLAlchemy models against the database and writes the corresponding CREATE TABLE, ALTER TABLE, etc. operations into a versioned migration file. Review each generated file, since autogenerate can’t detect everything (renamed tables or columns, for example). Revisions are identified by random IDs instead of auto-incrementing integers, so two developers never create different migrations with the same version number; if both branch from the same parent revision, the alembic merge command joins the two heads. I highly recommend using Alembic to manage your database schema changes with SQLAlchemy.

Conclusion

Hopefully these patterns and base methods will make creating APIs with Flask a breeze!

By the way, WakaTime is built with Flask along with these patterns ;)
