$ ls ~yifei/notes/

FastAPI 实现 OAuth2 登录认证

Posted on:

Last modified:

一个传统的 Web 网站只需要实现"登录表单+设置 Cookie"的登录方式。在 FastAPI 中,一个后端 可能要面对多种前端,比如 Web 端和 Android/iOS 端,可能还要实现用微信、微博登录等第三方 登录方式。这时候可以考虑一些更成熟的登录协议,比如 OAuth2.

常见的登录协议有:OAuth2、OAuth1、Open ID Connect、Open ID 等。根据 FastAPI 官方文档总结, 实际上只有 OAuth2 是比较常用的。

如果只是第一方应用,OAuth2 的 password flow 也很简单:客户端向服务器发送用户名和密码请求 登录,然后服务器返回一个 token,客户端再次请求的时候在 Authorization Header 中附上这个 token。这和传统的 Cookie 认证登录几乎是一样的,只不过浏览器在请求时会自动附上 Cookie,而 Authorization Header 需要手动添加。另外,在手机应用中就一样了,不管 Cookie 还是 Authorization 都需要自己添加 Header。

在 FastAPI 中使用 OAuth2 相对于自己实现认证登录还有一点好处——会生成自动的 OpenAPI 文档, Swagger 中会自带一个登录表单并在当前页面的请求中都附上 token,方便调试。

至于 access_token 的格式,一般使用 jwt(json web tokens),也是一种通用的格式。需要如下库:

pip install python-jose[cryptography]
pip install passlib[bcrypt]

FastAPI 官方文档中的例子并不很完备,没有验证 access token 是否过期等,我改了一下,下面是 一个完整带注释的例子。

import time
from datetime import timedelta

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from passlib.context import CryptContext
from jose import JWTError, jwt

# openssl rand -hex 32 生成 secret_key
SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
# 加密算法
ALGORITHM = "HS256"
# 默认过期时间
ACCESS_TOKEN_EXPIRE_MINUTES = 30
# 加盐
SALT = "10dd0cd14462a4dda5a6a3ec4b71f2e0"

app = FastAPI()

# 这里是 tokenUrl,而不是 token_url,是为了和 OAuth2 规范统一
# tokenUrl 是为了指定 OpenAPI 前端登录时的接口,在自己的程序中并无用处
# OAuthPasswordBearer 实现的功能很简单,只是把 Authorization Header 的 Bearer 取出来罢了
oauth2_bearer = OAuth2PasswordBearer(tokenUrl="/token")
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

class User(BaseModel):
    username: str
    password: str
    email: Optional[str] = None
    full_name: Optional[str] = None
    disabled: Optional[bool] = None

class Token(BaseModel):
    access_token: str
    token_type: str

def unauthorized(detail: str) -> HTTPException:
    """返回一个 401 错误"""
    return HTTPException(status_code=401, detail=detail, headers={"WWW-Authenticate": "Bearer"})

def _read_user(username) -> User:
    """从数据库中读取用户信息"""
    # read user from db
    return User.get_by_username(username)

def auth(username: str, password: str) -> User:
    user = _read_user(username)
    if not user:
        return None
    # 密码要加盐,这是为了防范彩虹表攻击
    if not pwd_context.verify(password + SALT, user.password):
        return None
    return user

# 如果只是验证一下 token,可能得到的信息还不够,可以再定义一个依赖,然后在其中读取用户。
async def dep_user(token: str = Depends(oauth2_bearer)):
    error = unauthorized("Could not validate access token")
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username = payload.get("sub")
        if not username:
            raise error
        # 验证 token 是否过期
        expires = payload.get("exp")
        if expires < time.time():
            raise unauthorized("Token has expired")
    except JWTError:
        raise error
    user = _read_user(username)
    if user is None:
        raise error
    return user

def _create_token(data: dict, expires: timedelta = timedelta(minutes=60)) -> str:
    data = {**data, "exp": time.time() + expires.seconds}
    token = jwt.encode(data, SECRET_KEY, algorithm=ALGORITHM)
    return token

@app.post("/token", response_model=Token)
async def login(form: OAuth2PasswordRequestForm = Depends()):
    user = auth(form.username, form.password)
    if not user:
        raise unauthorized("Incorrect username or password")
    token = _create_token({"sub": user.username}, timedelta(minutes=ACCESS_TOKEN_EXPIRES_MINUTES))
    return {"access_token": token, "token_type": "bearer"}

# exclude 去掉 password
@app.get("/users/me", response_model=User, response_model_exclude=["password"])
async def my_info(user: User = Depends(dep_user)):
    return user

参考

  1. https://fastapi.tiangolo.com/tutorial/security/oauth2-jwt/
WeChat Qr Code

© 2016-2022 Yifei Kong. Powered by ynotes

All contents are under the CC-BY-NC-SA license, if not otherwise specified.

Opinions expressed here are solely my own and do not express the views or opinions of my employer.

友情链接: MySQL 教程站