use std::sync::Arc; use axum::{extract::Request, middleware::Next, response::{IntoResponse, Response}}; use http::header; use i18n::{message, message_ids::MessageId}; use jsonwebtoken::{decode, DecodingKey, Validation}; use crate::{cache::inner_cache::LOGIN_ACCOUNT_CACHE, config, context::Context, model::response::ResErr, token::Claims, utils::request_util}; const WHITE_LIST: &[(&str, &str)] = &[ ("GET", "/"), ("POST", "/account/sys"), ("POST", "/account/google"), ("GET", "/wechat/access_token"), ("POST", "/wechat/code_2_session"), ("POST", "/wechat/check_session"), ]; /// 认证中间件,包括网络请求白名单、token验证、登录缓存 pub async fn authenticate_ctx(mut req: Request, next: Next) -> Response { // 获取请求的url和method,然后判断是否在白名单中,如果在白名单中,则直接返回next(req),否则继续执行下面的代码 let method = req.method().clone().to_string(); let mut uri = req.uri().path_and_query().unwrap().to_string(); uri = uri.replace(&config!().server.prefix_url, ""); tracing::debug!("请求路径: {}", uri); if WHITE_LIST.into_iter().find(|item| { return item.0 == method && item.1 == uri; }).is_some() { // 解析语言 let language = request_util::get_lang_tag(req.headers()); req.extensions_mut().insert(Context { lang_tag: Arc::new(language.clone()), account: None, token: None }); return next.run(req).await; } // 解析token let auth_header = req.headers().get(header::AUTHORIZATION); let token = match auth_header { Some(header_value) => { let parts: Vec<&str> = header_value.to_str().unwrap_or("").split_whitespace().collect(); if parts.len() != 2 || parts[0] != "Bearer" { tracing::error!("无效的 authorization 请求头参数"); // 解析语言 let language = request_util::get_lang_tag(req.headers()); return ResErr::params(message!(&language, MessageId::BadRequest)).into_response(); } parts[1] }, None => { tracing::error!("缺少 authorization 请求头参数"); // 解析语言 let language = request_util::get_lang_tag(req.headers()); return ResErr::auth(message!(&language, MessageId::BadRequest)).into_response() }, }; let validation = Validation::default(); match decode::(token, &DecodingKey::from_secret(config!().jwt.token_secret.as_bytes()), &validation) { Ok(decoded) => { // 从缓存中获取当前用户信息 let account = LOGIN_ACCOUNT_CACHE.get(&decoded.claims.sub).await; if account.is_none() { tracing::error!("无效的 token, 无缓存的登陆用户信息"); // 解析语言 let language = request_util::get_lang_tag(req.headers()); return ResErr::auth(message!(&language, MessageId::BadRequest)).into_response(); } let account = account.unwrap(); // 判断token是否有效 if account.token != Arc::new(String::from(token)) { tracing::error!("无效的 token, 缓存的登陆用户信息和token不一致"); return (hyper::StatusCode::UNAUTHORIZED, "Invalid token".to_string()).into_response(); } let mut language = account.account.clone().lang_tag.clone(); if language.is_empty() { language = request_util::get_lang_tag(req.headers()); } // 将Claims附加到请求扩展中,以便后续处理使用 req.extensions_mut().insert( Context { account: Some(account.account.clone()), token: Some(account.token.clone()), lang_tag: Arc::new(language) }); next.run(req).await }, Err(_) => { tracing::error!("无效的 token, 解析失败"); // 解析语言 let language = request_util::get_lang_tag(req.headers()); return ResErr::auth(message!(&language, MessageId::BadRequest)).into_response(); } } }