增加网络请求中query参数的校验处理

This commit is contained in:
李运家 2024-09-30 09:46:55 +08:00
parent f34062e162
commit 94cb27d1e7
10 changed files with 127 additions and 18 deletions

View File

@ -6,10 +6,11 @@ AccountNoPermission,account has no permission,账户无权限
IncorrectUsernameOrPassword,incorrect username or password,用户名或密码错误 IncorrectUsernameOrPassword,incorrect username or password,用户名或密码错误
InvalidToken,invalid token,无效令牌 InvalidToken,invalid token,无效令牌
ValidateFeedbackContentRequired,feedback content is required,反馈内容不能为空 ValidateFeedbackContentRequired,feedback content is required,反馈内容不能为空
ValidateAccountNameRequired,username is required,"用户名称不能为空" ValidateAccountNameRequired,username is required,用户名称不能为空
ValidateAccountPasswordRequired,password is required,密码不能为空 ValidateAccountPasswordRequired,password is required,密码不能为空
ValidateAccountIdTokenRequired,ID Token is required,用户ID Token不能为空 ValidateAccountIdTokenRequired,ID Token is required,用户ID Token不能为空
ValidateAccountLangTagRequired,lang tag is required,用户语言标识不能为空 ValidateAccountLangTagRequired,lang tag is required,用户语言标识不能为空
ValidatePageablePageRequired,invalid page number,页码无效 ValidatePageablePageRequired,invalid page number,页码无效
ValidatePageablePageSizeRequired,invalid quantity per page,每页数量无效 ValidatePageablePageSizeRequired,invalid quantity per page,每页数量无效
BadRequest,bad request,无效请求 BadRequest,bad request,无效请求
InvalidParams,invalid params,无效参数
1 id en-US zh-CN
6 IncorrectUsernameOrPassword incorrect username or password 用户名或密码错误
7 InvalidToken invalid token 无效令牌
8 ValidateFeedbackContentRequired feedback content is required 反馈内容不能为空
9 ValidateAccountNameRequired username is required 用户名称不能为空
10 ValidateAccountPasswordRequired password is required 密码不能为空
11 ValidateAccountIdTokenRequired ID Token is required 用户ID Token不能为空
12 ValidateAccountLangTagRequired lang tag is required 用户语言标识不能为空
13 ValidatePageablePageRequired invalid page number 页码无效
14 ValidatePageablePageSizeRequired invalid quantity per page 每页数量无效
15 BadRequest bad request 无效请求
16 InvalidParams invalid params 无效参数

View File

@ -10,6 +10,7 @@ pub enum MessageId {
AccountNoPermission, AccountNoPermission,
IncorrectUsernameOrPassword, IncorrectUsernameOrPassword,
InvalidToken, InvalidToken,
InvalidParams,
ValidateFeedbackContentRequired, ValidateFeedbackContentRequired,
ValidateAccountNameRequired, ValidateAccountNameRequired,

View File

@ -8,3 +8,4 @@ pub mod social;
pub mod cache; pub mod cache;
pub mod context; pub mod context;
pub mod task; pub mod task;
pub mod utils;

View File

@ -3,7 +3,7 @@ use http::{header, StatusCode};
use i18n::{message, message_ids::MessageId}; use i18n::{message, message_ids::MessageId};
use jsonwebtoken::{decode, DecodingKey, Validation}; use jsonwebtoken::{decode, DecodingKey, Validation};
use crate::{cache::account_cache::LOGIN_CACHE, config, context::{Context, WhiteContext}, model::response::ResErr, token::Claims}; use crate::{cache::account_cache::LOGIN_CACHE, config, context::{Context, WhiteContext}, model::response::ResErr, token::Claims, utils::request_util};
const WHITE_LIST: &[(&str, &str)] = &[ const WHITE_LIST: &[(&str, &str)] = &[
("POST", "/account/sys"), ("POST", "/account/sys"),
@ -13,19 +13,7 @@ const WHITE_LIST: &[(&str, &str)] = &[
/// 认证中间件包括网络请求白名单、token验证、登录缓存 /// 认证中间件包括网络请求白名单、token验证、登录缓存
pub async fn authenticate_ctx(mut req: Request, next: Next) -> Response { pub async fn authenticate_ctx(mut req: Request, next: Next) -> Response {
// 解析语言 // 解析语言
let mut language = String::from("zh-CN"); let language = request_util::get_lang_tag(req.headers());
let language_header = req.headers().get(header::ACCEPT_LANGUAGE);
language = match language_header {
Some(value) => {
let value_str: Vec<&str> = value.to_str().unwrap_or("zh-CN").split(',').collect();
if value_str.is_empty() {
language
} else {
String::from(value_str[0])
}
},
None => language,
};
req.extensions_mut().insert(WhiteContext { lang_tag: language.clone() }); req.extensions_mut().insert(WhiteContext { lang_tag: language.clone() });
// 获取请求的url和method然后判断是否在白名单中如果在白名单中则直接返回next(req),否则继续执行下面的代码 // 获取请求的url和method然后判断是否在白名单中如果在白名单中则直接返回next(req),否则继续执行下面的代码
let method = req.method().clone().to_string(); let method = req.method().clone().to_string();

View File

@ -1,2 +1,3 @@
pub mod response; pub mod response;
pub mod validator; pub mod validator;
pub mod query_validator;

View File

@ -0,0 +1,41 @@
use axum::{async_trait, body::Body, extract::{FromRequest, FromRequestParts, Query}};
use http::Request;
use i18n::{message, message_ids::MessageId};
use validator::Validate;
use crate::utils::request_util;
use super::response::ResErr;
pub struct QueryValidator<T>(pub T);
#[async_trait]
impl<S, T> FromRequest<S> for QueryValidator<T>
where
S: Send + Sync,
T: Validate,
Query<T>: FromRequestParts<S>,
{
type Rejection = ResErr;
async fn from_request(req: http::Request<Body>, state: &S) -> Result<Self, Self::Rejection> {
let (parts, body) = req.into_parts();
let query = Query::<T>::from_request(Request::from_parts(parts.clone(), body), state).await;
let header = &parts.headers;
if let Ok(Query(data)) = query {
match data.validate() {
Ok(_) => Ok(QueryValidator(data)),
Err(_) => {
let lang_tag = request_util::get_lang_tag(header);
let err = Err(ResErr::params(message!(&lang_tag, MessageId::InvalidParams)));
err
},
}
} else {
let lang_tag = request_util::get_lang_tag(header);
let err = Err(ResErr::params(message!(&lang_tag, MessageId::InvalidParams)));
err
}
}
}

1
library/src/utils/mod.rs Normal file
View File

@ -0,0 +1 @@
pub mod request_util;

View File

@ -0,0 +1,20 @@
use axum::extract::Request;
use http::{header, HeaderMap, HeaderValue};
/// 获取请求的语言
#[inline]
pub fn get_lang_tag(headers: &HeaderMap<HeaderValue>) -> String {
let language = String::from("zh-CN");
let language_header = headers.get(header::ACCEPT_LANGUAGE);
match language_header {
Some(value) => {
let value_str: Vec<&str> = value.to_str().unwrap_or("zh-CN").split(',').collect();
if value_str.is_empty() {
language
} else {
String::from(value_str[0])
}
},
None => language,
}
}

View File

@ -4,6 +4,7 @@ use domain::dto::feedback::FeedbackAdd;
use domain::dto::pageable::PageParams; use domain::dto::pageable::PageParams;
use domain::vo::feedback::FeedbackPageable; use domain::vo::feedback::FeedbackPageable;
use library::context::Context; use library::context::Context;
use library::model::query_validator::QueryValidator;
use library::model::response::ResResult; use library::model::response::ResResult;
use library::model::validator; use library::model::validator;
@ -25,7 +26,7 @@ pub async fn add_feedback(
/// 获取反馈信息列表 /// 获取反馈信息列表
pub async fn get_feedback_list_by_page( pub async fn get_feedback_list_by_page(
Extension(context): Extension<Context>, Extension(context): Extension<Context>,
Query(page_params): Query<PageParams>, QueryValidator(page_params): QueryValidator<PageParams>,
) -> ResResult<FeedbackPageable> { ) -> ResResult<FeedbackPageable> {
validator::validate_params(&page_params, context.get_lang_tag())?; validator::validate_params(&page_params, context.get_lang_tag())?;
service::feedback_service::get_feedback_list_by_page( service::feedback_service::get_feedback_list_by_page(

View File

@ -0,0 +1,54 @@
// https://github.com/tokio-rs/axum/discussions/2081
// use axum::{
// async_trait,
// extract::{FromRequest, FromRequestParts, Query},
// http, response,
// };
// use serde::{Deserialize, Serialize};
// use validator::Validate;
// pub struct Validated<T>(pub T);
// #[async_trait]
// impl<S, B, T> FromRequest<S, B> for Validated<T>
// where
// S: Send + Sync,
// B: Send + 'static,
// T: Validate,
// Query<T>: FromRequestParts<S>,
// {
// type Rejection = (http::StatusCode, String);
// async fn from_request(req: http::Request<B>, state: &S) -> Result<Self, Self::Rejection> {
// let query = Query::<T>::from_request(req, state).await;
// if let Ok(Query(data)) = query {
// match data.validate() {
// Ok(_) => Ok(Validated(data)),
// Err(err) => Err((http::StatusCode::BAD_REQUEST, err.to_string())),
// }
// } else {
// Err((
// http::StatusCode::INTERNAL_SERVER_ERROR,
// "internal server error".to_string(),
// ))
// }
// }
// }
// // my handler
// #[derive(Deserialize, Serialize, Validate)]
// pub struct Pagination {
// #[validate(range(min = 1))]
// page: usize,
// #[validate(range(max = 100))]
// per_page: usize,
// }
// #[axum::debug_handler]
// pub async fn get_query_string(
// Validated(pagination): Validated<Pagination>,
// ) -> response::Json<Pagination> {
// response::Json(pagination)
// }