92 lines
4.1 KiB
Rust
92 lines
4.1 KiB
Rust
use std::sync::Arc;
|
||
|
||
use axum::{extract::Request, middleware::Next, response::{IntoResponse, Response}};
|
||
use http::header;
|
||
use i18n::{message, message_ids::MessageId};
|
||
use jsonwebtoken::{decode, DecodingKey, Validation};
|
||
|
||
use crate::{cache::account_cache::LOGIN_CACHE, config, context::Context, model::response::ResErr, token::Claims, utils::request_util};
|
||
|
||
/// Requests that bypass token authentication, as `(HTTP method, path)` pairs.
///
/// Paths are listed without the configured server prefix; the middleware
/// removes that prefix from the request path before comparing.
const WHITE_LIST: &[(&str, &str)] = &[
    // account endpoints
    ("POST", "/account/sys"),
    ("POST", "/account/google"),
    // WeChat endpoints
    ("GET", "/wechat/access_token"),
    ("POST", "/wechat/code_2_session"),
    ("POST", "/wechat/check_session"),
];
|
||
|
||
/// 认证中间件,包括网络请求白名单、token验证、登录缓存
|
||
pub async fn authenticate_ctx(mut req: Request, next: Next) -> Response {
|
||
// 获取请求的url和method,然后判断是否在白名单中,如果在白名单中,则直接返回next(req),否则继续执行下面的代码
|
||
let method = req.method().clone().to_string();
|
||
let mut uri = req.uri().path_and_query().unwrap().to_string();
|
||
uri = uri.replace(&config!().server.prefix_url, "");
|
||
tracing::debug!("请求路径: {}", uri);
|
||
if WHITE_LIST.into_iter().find(|item| {
|
||
return item.0 == method && item.1 == uri;
|
||
}).is_some() {
|
||
// 解析语言
|
||
let language = request_util::get_lang_tag(req.headers());
|
||
req.extensions_mut().insert(Context { lang_tag: Arc::new(language.clone()), account: None, token: None });
|
||
return next.run(req).await;
|
||
}
|
||
|
||
// 解析token
|
||
let auth_header = req.headers().get(header::AUTHORIZATION);
|
||
let token = match auth_header {
|
||
Some(header_value) => {
|
||
let parts: Vec<&str> = header_value.to_str().unwrap_or("").split_whitespace().collect();
|
||
if parts.len() != 2 || parts[0] != "Bearer" {
|
||
tracing::error!("无效的 authorization 请求头参数");
|
||
// 解析语言
|
||
let language = request_util::get_lang_tag(req.headers());
|
||
return ResErr::params(message!(&language, MessageId::BadRequest)).into_response();
|
||
}
|
||
parts[1]
|
||
},
|
||
None => {
|
||
tracing::error!("缺少 authorization 请求头参数");
|
||
// 解析语言
|
||
let language = request_util::get_lang_tag(req.headers());
|
||
return ResErr::auth(message!(&language, MessageId::BadRequest)).into_response()
|
||
},
|
||
};
|
||
|
||
let validation = Validation::default();
|
||
match decode::<Claims>(token, &DecodingKey::from_secret(config!().jwt.token_secret.as_bytes()), &validation) {
|
||
Ok(decoded) => {
|
||
// 从缓存中获取当前用户信息
|
||
let account = LOGIN_CACHE.get(&decoded.claims.sub).await;
|
||
if account.is_none() {
|
||
tracing::error!("无效的 token");
|
||
// 解析语言
|
||
let language = request_util::get_lang_tag(req.headers());
|
||
return ResErr::auth(message!(&language, MessageId::BadRequest)).into_response();
|
||
}
|
||
let account = account.unwrap();
|
||
// 判断token是否有效(注释掉,如果服务因为升级等原因手动重启了,缓存的数据也不再存在)
|
||
// let account = account.unwrap();
|
||
// if account.token != token {
|
||
// return (StatusCode::UNAUTHORIZED, "Invalid token".to_string()).into_response();
|
||
// }
|
||
let mut language = account.account.clone().lang_tag.clone();
|
||
if language.is_empty() {
|
||
language = request_util::get_lang_tag(req.headers());
|
||
}
|
||
// 将Claims附加到请求扩展中,以便后续处理使用
|
||
req.extensions_mut().insert(
|
||
Context {
|
||
account: Some(account.account.clone()),
|
||
token: Some(account.token.clone()),
|
||
lang_tag: Arc::new(language)
|
||
});
|
||
next.run(req).await
|
||
},
|
||
Err(_) => {
|
||
tracing::error!("无效的 token");
|
||
// 解析语言
|
||
let language = request_util::get_lang_tag(req.headers());
|
||
return ResErr::auth(message!(&language, MessageId::BadRequest)).into_response();
|
||
}
|
||
}
|
||
} |