From d002c68fc2da41498baf5b268a772b4810cfd671 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E8=BF=90=E5=AE=B6?= Date: Mon, 30 Sep 2024 18:01:20 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9context=EF=BC=8C=E7=A7=BB?= =?UTF-8?q?=E9=99=A4WhiteContext?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- library/src/context.rs | 32 ++++++++++----------- library/src/middleware/req_ctx.rs | 24 ++++++++++++---- library/src/model/query_validator.rs | 12 ++++---- library/src/utils/request_util.rs | 1 - server/src/controller/account_controller.rs | 6 ++-- server/src/service/account_service.rs | 9 +++--- server/src/service/feedback_service.rs | 4 +-- server/src/service/sys_account_service.rs | 4 +-- 8 files changed, 50 insertions(+), 42 deletions(-) diff --git a/library/src/context.rs b/library/src/context.rs index 16c32d7..f62400c 100644 --- a/library/src/context.rs +++ b/library/src/context.rs @@ -4,30 +4,28 @@ use domain::entities::account::Account; #[derive(Debug, Clone)] pub struct Context { - pub account: Arc, - pub token: Arc, + pub account: Option>, + pub token: Option>, + pub lang_tag: Arc, } impl Context { - pub fn get_account(&self) -> &Account { - &self.account + pub fn get_account(&self) -> Option> { + if let Some(account) = &self.account { + return Some(account.clone()); + } else { + None + } } - pub fn get_token(&self) -> &String { - &self.token + pub fn get_token(&self) -> Option> { + if let Some(token) = &self.token { + return Some(token.clone()); + } else { + None + } } - pub fn get_lang_tag(&self) -> &String { - &self.account.lang_tag - } -} - -#[derive(Debug, Clone)] -pub struct WhiteContext { - pub lang_tag: String, -} - -impl WhiteContext { pub fn get_lang_tag(&self) -> &String { &self.lang_tag } diff --git a/library/src/middleware/req_ctx.rs b/library/src/middleware/req_ctx.rs index 3ff8df8..5f8e024 100644 --- a/library/src/middleware/req_ctx.rs +++ b/library/src/middleware/req_ctx.rs @@ -1,9 +1,11 @@ +use std::sync::Arc; + use axum::{extract::Request, middleware::Next, response::{IntoResponse, Response}}; use http::{header, StatusCode}; use i18n::{message, message_ids::MessageId}; use jsonwebtoken::{decode, DecodingKey, Validation}; -use crate::{cache::account_cache::LOGIN_CACHE, config, context::{Context, WhiteContext}, model::response::ResErr, token::Claims, utils::request_util}; +use crate::{cache::account_cache::LOGIN_CACHE, config, context::Context, model::response::ResErr, token::Claims, utils::request_util}; const WHITE_LIST: &[(&str, &str)] = &[ ("POST", "/account/sys"), @@ -12,9 +14,6 @@ const WHITE_LIST: &[(&str, &str)] = &[ /// 认证中间件,包括网络请求白名单、token验证、登录缓存 pub async fn authenticate_ctx(mut req: Request, next: Next) -> Response { - // 解析语言 - let language = request_util::get_lang_tag(req.headers()); - req.extensions_mut().insert(WhiteContext { lang_tag: language.clone() }); // 获取请求的url和method,然后判断是否在白名单中,如果在白名单中,则直接返回next(req),否则继续执行下面的代码 let method = req.method().clone().to_string(); let mut uri = req.uri().path_and_query().unwrap().to_string(); @@ -23,6 +22,9 @@ pub async fn authenticate_ctx(mut req: Request, next: Next) -> Response { if WHITE_LIST.into_iter().find(|item| { return item.0 == method && item.1 == uri; }).is_some() { + // 解析语言 + let language = request_util::get_lang_tag(req.headers()); + req.extensions_mut().insert(Context { lang_tag: Arc::new(language.clone()), account: None, token: None }); return next.run(req).await; } @@ -33,12 +35,16 @@ pub async fn authenticate_ctx(mut req: Request, next: Next) -> Response { let parts: Vec<&str> = header_value.to_str().unwrap_or("").split_whitespace().collect(); if parts.len() != 2 || parts[0] != "Bearer" { tracing::error!("无效的 authorization 请求头参数"); + // 解析语言 + let language = request_util::get_lang_tag(req.headers()); return ResErr::params(message!(&language, MessageId::BadRequest)).into_response(); } parts[1] }, None => { tracing::error!("缺少 authorization 请求头参数"); + // 解析语言 + let language = request_util::get_lang_tag(req.headers()); return ResErr::auth(message!(&language, MessageId::BadRequest)).into_response() }, }; @@ -50,6 +56,8 @@ pub async fn authenticate_ctx(mut req: Request, next: Next) -> Response { let account = LOGIN_CACHE.get(&decoded.claims.sub).await; if account.is_none() { tracing::error!("无效的 token"); + // 解析语言 + let language = request_util::get_lang_tag(req.headers()); return ResErr::auth(message!(&language, MessageId::BadRequest)).into_response(); } let account = account.unwrap(); @@ -58,16 +66,20 @@ pub async fn authenticate_ctx(mut req: Request, next: Next) -> Response { // if account.token != token { // return (StatusCode::UNAUTHORIZED, "Invalid token".to_string()).into_response(); // } + let language = account.account.clone().lang_tag.clone(); // 将Claims附加到请求扩展中,以便后续处理使用 req.extensions_mut().insert( Context { - account: account.account.clone(), - token: account.token.clone() + account: Some(account.account.clone()), + token: Some(account.token.clone()), + lang_tag: Arc::new(language) }); next.run(req).await }, Err(_) => { tracing::error!("无效的 token"); + // 解析语言 + let language = request_util::get_lang_tag(req.headers()); return ResErr::auth(message!(&language, MessageId::BadRequest)).into_response(); } } diff --git a/library/src/model/query_validator.rs b/library/src/model/query_validator.rs index f33f57f..ad5fc54 100644 --- a/library/src/model/query_validator.rs +++ b/library/src/model/query_validator.rs @@ -19,21 +19,19 @@ where type Rejection = ResErr; async fn from_request(req: http::Request, state: &S) -> Result { + // let context = req.extensions().get().unwrap(); + let (parts, body) = req.into_parts(); let header = &parts.headers; let query = Query::::from_request(Request::from_parts(parts.clone(), body), state).await; + let lang_tag = request_util::get_lang_tag(header); if let Ok(Query(data)) = query { - match data.validate() { + match super::validator::validate_params(&data, &lang_tag) { Ok(_) => Ok(QueryValidator(data)), - Err(_) => { - let lang_tag = request_util::get_lang_tag(header); - let err = Err(ResErr::params(message!(&lang_tag, MessageId::InvalidParams))); - err - }, + Err(err) => Err(err), } } else { - let lang_tag = request_util::get_lang_tag(header); let err = Err(ResErr::params(message!(&lang_tag, MessageId::InvalidParams))); err } diff --git a/library/src/utils/request_util.rs b/library/src/utils/request_util.rs index 7ddc573..e56b537 100644 --- a/library/src/utils/request_util.rs +++ b/library/src/utils/request_util.rs @@ -1,4 +1,3 @@ -use axum::extract::Request; use http::{header, HeaderMap, HeaderValue}; /// 获取请求的语言 diff --git a/server/src/controller/account_controller.rs b/server/src/controller/account_controller.rs index fb83279..4f7171c 100644 --- a/server/src/controller/account_controller.rs +++ b/server/src/controller/account_controller.rs @@ -1,6 +1,6 @@ use axum::{Extension, Json}; use domain::{dto::account::{AuthenticateGooleAccountReq, AuthenticateWithPassword, RefreshToken}, vo::account::{LoginAccount, RefreshTokenResult}}; -use library::{context::{Context, WhiteContext}, model::{response::ResResult, validator}}; +use library::{context::Context, model::{response::ResResult, validator}}; use crate::service; @@ -8,7 +8,7 @@ use crate::service; /// /// google账号登录 pub async fn authenticate_google( - Extension(context): Extension, + Extension(context): Extension, Json(req): Json ) -> ResResult { validator::validate_params(&req, context.get_lang_tag())?; @@ -19,7 +19,7 @@ pub async fn authenticate_google( /// /// 账号密码登录 pub async fn authenticate_with_password( - Extension(context): Extension, + Extension(context): Extension, Json(req): Json ) -> ResResult { validator::validate_params(&req, context.get_lang_tag())?; diff --git a/server/src/service/account_service.rs b/server/src/service/account_service.rs index 6bdfbda..0349354 100644 --- a/server/src/service/account_service.rs +++ b/server/src/service/account_service.rs @@ -7,7 +7,7 @@ use domain::vo::account::{LoginAccount, RefreshTokenResult}; use i18n::message; use i18n::message_ids::MessageId; use library::cache::account_cache::{CacheAccount, LOGIN_CACHE}; -use library::context::{Context, WhiteContext}; +use library::context::Context; use library::model::response::ResErr::ErrPerm; use library::model::response::{ResErr, ResResult}; use library::social::google::GOOGLE_SOCIAL; @@ -15,7 +15,7 @@ use library::token::{generate_refresh_token, generate_token}; use library::{db, token}; pub async fn authenticate_google( - context: WhiteContext, + context: Context, req: AuthenticateGooleAccountReq, ) -> ResResult { let verify_result = GOOGLE_SOCIAL @@ -100,11 +100,12 @@ pub async fn refresh_token( context: Context, refresh_token: String, ) -> ResResult { - let account = context.account.clone(); + let lang = context.get_lang_tag(); + let account = context.get_account().unwrap(); if token::verify_refresh_token(&refresh_token).is_err() { return Err(ResErr::params(message!( - context.get_lang_tag(), + lang, MessageId::InvalidToken ))); } diff --git a/server/src/service/feedback_service.rs b/server/src/service/feedback_service.rs index a9fbb83..eef7fe2 100644 --- a/server/src/service/feedback_service.rs +++ b/server/src/service/feedback_service.rs @@ -11,7 +11,7 @@ pub async fn get_feedback_list_by_page( page: i64, page_size: i64, ) -> ResResult { - if !context.account.role.is_admin() { + if !context.account.unwrap().role.is_admin() { tracing::error!("非管理员用户,无法获取反馈信息列表"); return Ok(FeedbackPageable::empty(page, page_size)); } @@ -41,7 +41,7 @@ async fn get_feedback_count() -> i64 { /// 添加反馈信息 pub async fn add_feedback(context: Context, req: FeedbackAdd) -> ResResult<()> { - let account = context.account; + let account = context.account.unwrap(); let mut transaction = db!().begin().await?; match Feedback::add_feedback( &mut Feedback { diff --git a/server/src/service/sys_account_service.rs b/server/src/service/sys_account_service.rs index f2fa130..5bcd310 100644 --- a/server/src/service/sys_account_service.rs +++ b/server/src/service/sys_account_service.rs @@ -11,12 +11,12 @@ use i18n::{ message_ids::MessageId, }; use library::{ - cache::account_cache::{CacheAccount, LOGIN_CACHE}, context::WhiteContext, db, model::response::{ResErr, ResResult}, token::{generate_refresh_token, generate_token} + cache::account_cache::{CacheAccount, LOGIN_CACHE}, context::Context, db, model::response::{ResErr, ResResult}, token::{generate_refresh_token, generate_token} }; /// 登录, 使用账号和密码 pub async fn authenticate_with_password( - context: WhiteContext, + context: Context, req: AuthenticateWithPassword, ) -> ResResult { let account =