添加path_validator和body_validator

This commit is contained in:
李运家 2024-10-01 15:55:05 +08:00
parent bd3e18dd8d
commit 4751d54c96
16 changed files with 246 additions and 49 deletions

36
Cargo.lock generated
View File

@ -111,9 +111,9 @@ checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
[[package]] [[package]]
name = "axum" name = "axum"
version = "0.7.5" version = "0.7.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf" checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"axum-core", "axum-core",
@ -137,7 +137,7 @@ dependencies = [
"serde_urlencoded", "serde_urlencoded",
"sync_wrapper 1.0.1", "sync_wrapper 1.0.1",
"tokio", "tokio",
"tower", "tower 0.5.1",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
"tracing", "tracing",
@ -145,9 +145,9 @@ dependencies = [
[[package]] [[package]]
name = "axum-core" name = "axum-core"
version = "0.4.3" version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a15c63fd72d41492dc4f497196f5da1fb04fb7529e631d73630d1b491e47a2e3" checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"bytes", "bytes",
@ -158,7 +158,7 @@ dependencies = [
"mime", "mime",
"pin-project-lite", "pin-project-lite",
"rustversion", "rustversion",
"sync_wrapper 0.1.2", "sync_wrapper 1.0.1",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
"tracing", "tracing",
@ -180,7 +180,7 @@ dependencies = [
"mime", "mime",
"pin-project-lite", "pin-project-lite",
"serde", "serde",
"tower", "tower 0.4.13",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
"tracing", "tracing",
@ -988,7 +988,7 @@ dependencies = [
"pin-project-lite", "pin-project-lite",
"socket2", "socket2",
"tokio", "tokio",
"tower", "tower 0.4.13",
"tower-service", "tower-service",
"tracing", "tracing",
] ]
@ -1149,7 +1149,7 @@ dependencies = [
"tokio", "tokio",
"tokio-cron-scheduler", "tokio-cron-scheduler",
"toml", "toml",
"tower", "tower 0.4.13",
"tower-http", "tower-http",
"tracing", "tracing",
"tracing-appender", "tracing-appender",
@ -2058,7 +2058,7 @@ dependencies = [
"sqlx", "sqlx",
"tokio", "tokio",
"tokio-cron-scheduler", "tokio-cron-scheduler",
"tower", "tower 0.4.13",
"tower-http", "tower-http",
"tracing", "tracing",
"validator", "validator",
@ -2720,6 +2720,22 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "tower"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f"
dependencies = [
"futures-core",
"futures-util",
"pin-project-lite",
"sync_wrapper 0.1.2",
"tokio",
"tower-layer",
"tower-service",
"tracing",
]
[[package]] [[package]]
name = "tower-http" name = "tower-http"
version = "0.5.2" version = "0.5.2"

5
README.MD Normal file
View File

@ -0,0 +1,5 @@
### todo
- [*] 使用Extractor方式提取数据包括Body、Query和Path
- [ ] multipart/form-data的实现
- [ ] 参考example中的jwt实现方式移除context对extension的依赖那么language-tag该怎么处理
- [ ] 参考rocket移除参数的元组类型

View File

@ -1,7 +1,7 @@
use serde::Deserialize; use serde::Deserialize;
use validator::Validate; use validator::Validate;
#[derive(Deserialize, Validate)] #[derive(Debug, Deserialize, Validate)]
pub struct PageParams { pub struct PageParams {
#[validate(required(message = "ValidatePageablePageRequired"), range(min = 1, message = "ValidatePageablePageRequired"))] #[validate(required(message = "ValidatePageablePageRequired"), range(min = 1, message = "ValidatePageablePageRequired"))]
pub page: Option<i64>, pub page: Option<i64>,

View File

@ -9,3 +9,4 @@ pub mod cache;
pub mod context; pub mod context;
pub mod task; pub mod task;
pub mod utils; pub mod utils;
pub mod validator;

View File

@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use axum::{extract::Request, middleware::Next, response::{IntoResponse, Response}}; use axum::{extract::Request, middleware::Next, response::{IntoResponse, Response}};
use http::{header, StatusCode}; use http::header;
use i18n::{message, message_ids::MessageId}; use i18n::{message, message_ids::MessageId};
use jsonwebtoken::{decode, DecodingKey, Validation}; use jsonwebtoken::{decode, DecodingKey, Validation};

View File

@ -1,3 +1 @@
pub mod response; pub mod response;
pub mod validator;
pub mod query_validator;

View File

@ -79,6 +79,7 @@ impl WechatSocial {
/// ///
/// https://developers.weixin.qq.com/minigame/dev/api-backend/open-api/access-token/auth.getAccessToken.html /// https://developers.weixin.qq.com/minigame/dev/api-backend/open-api/access-token/auth.getAccessToken.html
/// https://developers.weixin.qq.com/miniprogram/dev/OpenApiDoc/mp-access-token/getAccessToken.html /// https://developers.weixin.qq.com/miniprogram/dev/OpenApiDoc/mp-access-token/getAccessToken.html
#[allow(dead_code)]
async fn fetch_and_parse_wechat_access_token(&self) -> SocialResult<WeChatAccessToken> { async fn fetch_and_parse_wechat_access_token(&self) -> SocialResult<WeChatAccessToken> {
let mut wechat_access_token = self.access_token.lock().await; let mut wechat_access_token = self.access_token.lock().await;
if wechat_access_token.access_token.is_empty() if wechat_access_token.access_token.is_empty()

View File

@ -0,0 +1,80 @@
use axum::{async_trait, extract::{rejection::{FormRejection, JsonRejection}, FromRequest, Request}};
use http::header::CONTENT_TYPE;
use i18n::{message, message_ids::MessageId};
use validator::Validate;
use crate::{context::Context, model::response::ResErr};
pub struct JsonBody<T>(pub T);
#[async_trait]
impl<S, T> FromRequest<S> for JsonBody<T>
where
axum::Json<T>: FromRequest<S, Rejection = JsonRejection>,
axum::Form<T>: FromRequest<S, Rejection = FormRejection>,
S: Send + Sync,
T: Validate
{
type Rejection = ResErr;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
let (parts, body) = req.into_parts();
// We can use other extractors to provide better rejection messages.
// For example, here we are using `axum::extract::MatchedPath` to
// provide a better error message.
//
// Have to run that first since `Json` extraction consumes the request.
// let path = parts
// .extract::<MatchedPath>()
// .await
// .map(|path| path.as_str().to_owned())
// .ok();
let context_parts = parts.clone();
let context: &Context = context_parts.extensions.get().unwrap();
let content_type_header = context_parts.headers.get(CONTENT_TYPE);
let content_type = content_type_header.and_then(|value| value.to_str().ok());
let lang_tag = context.get_lang_tag();
let req = Request::from_parts(parts, body);
if let Some(content_type) = content_type {
if content_type.starts_with("application/json") {
match axum::Json::<T>::from_request(req, state).await {
Ok(value) => {
let data = value.0;
match super::validator::validate_params(&data, lang_tag) {
Ok(_) => return Ok(Self(data)),
Err(err) => return Err(err),
}
},
// convert the error from `axum::Json` into whatever we want
Err(rejection) => {
tracing::error!("无效的json数据: {:?}", rejection);
return Err(ResErr::params(message!(lang_tag, MessageId::InvalidParams)))
}
}
} else if content_type.starts_with("application/x-www-form-urlencoded") {
match axum::Form::<T>::from_request(req, state).await {
Ok(value) => {
let data = value.0;
match super::validator::validate_params(&data, lang_tag) {
Ok(_) => return Ok(Self(data)),
Err(err) => return Err(err),
}
},
// convert the error from `axum::Json` into whatever we want
Err(rejection) => {
tracing::error!("无效的json数据: {:?}", rejection);
return Err(ResErr::params(message!(lang_tag, MessageId::InvalidParams)))
}
}
}
return Err(ResErr::params(message!(lang_tag, MessageId::InvalidParams)))
}
return Err(ResErr::params(message!(lang_tag, MessageId::InvalidParams)))
}
}

View File

@ -0,0 +1,4 @@
pub mod validator;
pub mod query_validator;
pub mod path_validator;
pub mod body_validator;

View File

@ -0,0 +1,94 @@
use std::fmt::Display;
use axum::{async_trait, extract::{path::ErrorKind, rejection::PathRejection, FromRequestParts}};
use http::request::Parts;
use i18n::{message, message_ids::MessageId};
use serde::de::DeserializeOwned;
use crate::{context::Context, model::response::ResErr};
pub struct PathVar<T>(pub T);
#[async_trait]
impl<S, T> FromRequestParts<S> for PathVar<T>
where
T: DeserializeOwned + Send + Display,
S: Send + Sync,
{
type Rejection = ResErr;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let context_parts = parts.clone();
let context: &Context = context_parts.extensions.get().unwrap();
let lang_tag = context.get_lang_tag();
match axum::extract::Path::<T>::from_request_parts(parts, state).await {
Ok(value) => {
tracing::info!("path params: {}", value.0);
Ok(Self(value.0))
},
Err(rejection) => {
let err_rep = match rejection {
PathRejection::FailedToDeserializePathParams(inner) => {
let kind = inner.into_kind();
let body = match &kind {
ErrorKind::WrongNumberOfParameters { got, expected } => {
tracing::error!("无效的路径参数: WrongNumberOfParameters, {} - {}", got, expected);
ResErr::params(message!(lang_tag, MessageId::InvalidParams))
},
ErrorKind::ParseErrorAtKey { key, .. } => {
tracing::error!("无效的路径参数: ParseErrorAtKey, key: {}", key);
ResErr::params(message!(lang_tag, MessageId::InvalidParams))
},
ErrorKind::ParseErrorAtIndex { index, .. } => {
tracing::error!("无效的路径参数: ParseErrorAtIndex, index; {}", index);
ResErr::params(message!(lang_tag, MessageId::InvalidParams))
},
ErrorKind::ParseError { .. } => {
tracing::error!("无效的路径参数: ParseError");
ResErr::params(message!(lang_tag, MessageId::InvalidParams))
},
ErrorKind::InvalidUtf8InPathParam { key } => {
tracing::error!("无效的路径参数: InvalidUtf8InPathParam, key: {}", key);
ResErr::params(message!(lang_tag, MessageId::InvalidParams))
},
ErrorKind::UnsupportedType { .. } => {
// this error is caused by the programmer using an unsupported type
// (such as nested maps) so respond with `500` instead
tracing::error!("无效的路径参数: UnsupportedType");
ResErr::params(message!(lang_tag, MessageId::InvalidParams))
}
ErrorKind::Message(msg) => {
tracing::error!("无效的路径参数: Message, msg: {}", msg);
ResErr::params(message!(lang_tag, MessageId::InvalidParams))
},
_ => {
tracing::error!("无效的路径参数");
ResErr::params(message!(lang_tag, MessageId::InvalidParams))
},
};
body
}
PathRejection::MissingPathParams(error) => {
tracing::error!("无效的路径参数: MissingPathParams, err: {}", error);
ResErr::params(message!(lang_tag, MessageId::InvalidParams))
},
_ => {
tracing::error!("无效的路径参数: Others");
ResErr::params(message!(lang_tag, MessageId::InvalidParams))
},
};
Err(err_rep)
}
}
}
}

View File

@ -3,14 +3,12 @@ use http::Request;
use i18n::{message, message_ids::MessageId}; use i18n::{message, message_ids::MessageId};
use validator::Validate; use validator::Validate;
use crate::context::Context; use crate::{context::Context, model::response::ResErr};
use super::response::ResErr; pub struct QueryParams<T>(pub T);
pub struct QueryValidator<T>(pub T);
#[async_trait] #[async_trait]
impl<S, T> FromRequest<S> for QueryValidator<T> impl<S, T> FromRequest<S> for QueryParams<T>
where where
S: Send + Sync, S: Send + Sync,
T: Validate, T: Validate,
@ -23,11 +21,10 @@ where
let query = Query::<T>::from_request(Request::from_parts(parts.clone(), body), state).await; let query = Query::<T>::from_request(Request::from_parts(parts.clone(), body), state).await;
let context: &Context = parts.extensions.get().unwrap(); let context: &Context = parts.extensions.get().unwrap();
tracing::info!("{:?}", context);
if let Ok(Query(data)) = query { if let Ok(Query(data)) = query {
match super::validator::validate_params(&data, context.get_lang_tag()) { match super::validator::validate_params(&data, context.get_lang_tag()) {
Ok(_) => Ok(QueryValidator(data)), Ok(_) => Ok(Self(data)),
Err(err) => Err(err), Err(err) => Err(err),
} }
} else { } else {

View File

@ -3,7 +3,7 @@ use std::str::FromStr;
use i18n::{message, message_ids::MessageId}; use i18n::{message, message_ids::MessageId};
use validator::Validate; use validator::Validate;
use super::response::{ResData, ResErr, ResResult}; use crate::model::response::{ResData, ResErr, ResResult};
/// 验证请求参数 /// 验证请求参数
pub fn validate_params(params: &impl Validate, local: &str) -> ResResult<ResData<()>> { pub fn validate_params(params: &impl Validate, local: &str) -> ResResult<ResData<()>> {

View File

@ -1,6 +1,6 @@
use axum::{Extension, Json}; use axum::Extension;
use domain::{dto::account::{AuthenticateGooleAccountReq, AuthenticateWithPassword, RefreshToken}, vo::account::{LoginAccount, RefreshTokenResult}}; use domain::{dto::account::{AuthenticateGooleAccountReq, AuthenticateWithPassword, RefreshToken}, vo::account::{LoginAccount, RefreshTokenResult}};
use library::{context::Context, model::{response::ResResult, validator}}; use library::{context::Context, model::response::ResResult, validator::body_validator::JsonBody};
use crate::service; use crate::service;
@ -9,9 +9,8 @@ use crate::service;
/// google账号登录 /// google账号登录
pub async fn authenticate_google( pub async fn authenticate_google(
Extension(context): Extension<Context>, Extension(context): Extension<Context>,
Json(req): Json<AuthenticateGooleAccountReq> JsonBody(req): JsonBody<AuthenticateGooleAccountReq>
) -> ResResult<LoginAccount> { ) -> ResResult<LoginAccount> {
validator::validate_params(&req, context.get_lang_tag())?;
service::account_service::authenticate_google(context, req).await service::account_service::authenticate_google(context, req).await
} }
@ -20,9 +19,8 @@ pub async fn authenticate_google(
/// 账号密码登录 /// 账号密码登录
pub async fn authenticate_with_password( pub async fn authenticate_with_password(
Extension(context): Extension<Context>, Extension(context): Extension<Context>,
Json(req): Json<AuthenticateWithPassword> JsonBody(req): JsonBody<AuthenticateWithPassword>
) -> ResResult<LoginAccount> { ) -> ResResult<LoginAccount> {
validator::validate_params(&req, context.get_lang_tag())?;
service::sys_account_service::authenticate_with_password(context, req).await service::sys_account_service::authenticate_with_password(context, req).await
} }
@ -31,14 +29,14 @@ pub async fn authenticate_with_password(
/// 刷新token /// 刷新token
pub async fn refresh_token( pub async fn refresh_token(
Extension(context): Extension<Context>, Extension(context): Extension<Context>,
Json(refresh_token): Json<RefreshToken> JsonBody(refresh_token): JsonBody<RefreshToken>
) -> ResResult<RefreshTokenResult> { ) -> ResResult<RefreshTokenResult> {
tracing::debug!("刷新token, {:?}", context); tracing::debug!("刷新token, {:?}", context);
validator::validate_params(&refresh_token, "")?;
service::account_service::refresh_token(context, refresh_token.token).await service::account_service::refresh_token(context, refresh_token.token).await
} }
/// 添加管理员账号 /// 添加管理员账号
#[allow(dead_code)]
pub async fn add_account() -> ResResult<()> { pub async fn add_account() -> ResResult<()> {
service::sys_account_service::add_account().await service::sys_account_service::add_account().await
} }

View File

@ -1,12 +1,11 @@
use axum::extract::Query; use axum::Extension;
use axum::{Extension, Json};
use domain::dto::feedback::FeedbackAdd; use domain::dto::feedback::FeedbackAdd;
use domain::dto::pageable::PageParams; use domain::dto::pageable::PageParams;
use domain::vo::feedback::FeedbackPageable; use domain::vo::feedback::FeedbackPageable;
use library::context::Context; use library::context::Context;
use library::model::query_validator::QueryValidator;
use library::model::response::ResResult; use library::model::response::ResResult;
use library::model::validator; use library::validator::body_validator::JsonBody;
use library::validator::query_validator::QueryParams;
use crate::service; use crate::service;
@ -15,9 +14,8 @@ use crate::service;
/// 添加反馈信息 /// 添加反馈信息
pub async fn add_feedback( pub async fn add_feedback(
Extension(context): Extension<Context>, Extension(context): Extension<Context>,
Json(req): Json<FeedbackAdd>, JsonBody(req): JsonBody<FeedbackAdd>,
) -> ResResult<()> { ) -> ResResult<()> {
validator::validate_params(&req, context.get_lang_tag())?;
service::feedback_service::add_feedback(context, req).await service::feedback_service::add_feedback(context, req).await
} }
@ -26,9 +24,8 @@ pub async fn add_feedback(
/// 获取反馈信息列表 /// 获取反馈信息列表
pub async fn get_feedback_list_by_page( pub async fn get_feedback_list_by_page(
Extension(context): Extension<Context>, Extension(context): Extension<Context>,
QueryValidator(page_params): QueryValidator<PageParams>, QueryParams(page_params): QueryParams<PageParams>,
) -> ResResult<FeedbackPageable> { ) -> ResResult<FeedbackPageable> {
validator::validate_params(&page_params, context.get_lang_tag())?;
service::feedback_service::get_feedback_list_by_page( service::feedback_service::get_feedback_list_by_page(
context, context,
page_params.page.unwrap(), page_params.page.unwrap(),

View File

@ -20,6 +20,6 @@ pub fn init() -> Router {
.route( .route(
"/feedback", "/feedback",
post(feedback_controller::add_feedback) post(feedback_controller::add_feedback)
.get(feedback_controller::get_feedback_list_by_page), .get(feedback_controller::get_feedback_list_by_page)
) )
} }

View File

@ -1,4 +1,4 @@
use axum::{body::Body, extract::Request, routing::get, Router}; use axum::{body::Body, extract::Request, http, routing::get, Router};
use library::{config, task}; use library::{config, task};
use tasks::get_tasks; use tasks::get_tasks;
use tower::ServiceBuilder; use tower::ServiceBuilder;
@ -36,21 +36,27 @@ fn init() -> Router {
tracing::error_span!("request_id", id = req_id) tracing::error_span!("request_id", id = req_id)
}); });
// 配置路由
// layer之间存在顺序依赖勿改。layer执行顺序和配置顺序一致
// fallback路由放到最后如果无匹配的路由则不会执行layer直接返回404
Router::new() Router::new()
.route("/", get(|| async { "hello" })) .route("/", get(|| async { "hello" }))
.nest(&config!().server.prefix_url, auth) .nest(&config!().server.prefix_url, auth)
.layer( .layer(
ServiceBuilder::new() ServiceBuilder::new()
.layer(axum::middleware::from_fn( .layer(axum::middleware::from_fn(
library::middleware::req_ctx::authenticate_ctx, library::middleware::req_id::handle,
)) ))
.layer(trace_layer)
.layer(axum::middleware::from_fn(
library::middleware::cors::handle)
)
.layer(axum::middleware::from_fn( .layer(axum::middleware::from_fn(
library::middleware::req_log::handle, library::middleware::req_log::handle,
)) ))
.layer(axum::middleware::from_fn(library::middleware::cors::handle))
.layer(trace_layer)
.layer(axum::middleware::from_fn( .layer(axum::middleware::from_fn(
library::middleware::req_id::handle, library::middleware::req_ctx::authenticate_ctx,
)) ))
) )
.fallback(|| async { (http::status::StatusCode::NOT_FOUND, "Not Found") })
} }