diff --git a/Cargo.lock b/Cargo.lock index 8bc0e76..828bda5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -111,9 +111,9 @@ checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" [[package]] name = "axum" -version = "0.7.5" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf" +checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae" dependencies = [ "async-trait", "axum-core", @@ -137,7 +137,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper 1.0.1", "tokio", - "tower", + "tower 0.5.1", "tower-layer", "tower-service", "tracing", @@ -145,9 +145,9 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.4.3" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a15c63fd72d41492dc4f497196f5da1fb04fb7529e631d73630d1b491e47a2e3" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" dependencies = [ "async-trait", "bytes", @@ -158,7 +158,7 @@ dependencies = [ "mime", "pin-project-lite", "rustversion", - "sync_wrapper 0.1.2", + "sync_wrapper 1.0.1", "tower-layer", "tower-service", "tracing", @@ -180,7 +180,7 @@ dependencies = [ "mime", "pin-project-lite", "serde", - "tower", + "tower 0.4.13", "tower-layer", "tower-service", "tracing", @@ -988,7 +988,7 @@ dependencies = [ "pin-project-lite", "socket2", "tokio", - "tower", + "tower 0.4.13", "tower-service", "tracing", ] @@ -1149,7 +1149,7 @@ dependencies = [ "tokio", "tokio-cron-scheduler", "toml", - "tower", + "tower 0.4.13", "tower-http", "tracing", "tracing-appender", @@ -2058,7 +2058,7 @@ dependencies = [ "sqlx", "tokio", "tokio-cron-scheduler", - "tower", + "tower 0.4.13", "tower-http", "tracing", "validator", @@ -2720,6 +2720,22 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper 0.1.2", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tower-http" version = "0.5.2" diff --git a/README.MD b/README.MD new file mode 100644 index 0000000..2ea6b33 --- /dev/null +++ b/README.MD @@ -0,0 +1,5 @@ +### todo +- [*] 使用Extractor方式提取数据,包括Body、Query和Path +- [ ] multipart/form-data的实现 +- [ ] 参考example中的jwt实现方式,移除context对extension的依赖?那么language-tag该怎么处理? +- [ ] 参考rocket,移除参数的元组类型 diff --git a/domain/src/dto/pageable.rs b/domain/src/dto/pageable.rs index 0a686cf..e9ee3cd 100644 --- a/domain/src/dto/pageable.rs +++ b/domain/src/dto/pageable.rs @@ -1,7 +1,7 @@ use serde::Deserialize; use validator::Validate; -#[derive(Deserialize, Validate)] +#[derive(Debug, Deserialize, Validate)] pub struct PageParams { #[validate(required(message = "ValidatePageablePageRequired"), range(min = 1, message = "ValidatePageablePageRequired"))] pub page: Option, diff --git a/library/src/lib.rs b/library/src/lib.rs index 7a2a7f2..0c49982 100644 --- a/library/src/lib.rs +++ b/library/src/lib.rs @@ -8,4 +8,5 @@ pub mod social; pub mod cache; pub mod context; pub mod task; -pub mod utils; \ No newline at end of file +pub mod utils; +pub mod validator; \ No newline at end of file diff --git a/library/src/middleware/req_ctx.rs b/library/src/middleware/req_ctx.rs index 8804fb8..7c6d1d4 100644 --- a/library/src/middleware/req_ctx.rs +++ b/library/src/middleware/req_ctx.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use axum::{extract::Request, middleware::Next, response::{IntoResponse, Response}}; -use http::{header, StatusCode}; +use http::header; use i18n::{message, message_ids::MessageId}; use jsonwebtoken::{decode, DecodingKey, Validation}; diff --git a/library/src/model/mod.rs b/library/src/model/mod.rs index 4eda66d..3ae0378 100644 --- a/library/src/model/mod.rs +++ b/library/src/model/mod.rs @@ -1,3 +1 @@ -pub mod response; -pub mod validator; -pub mod query_validator; \ No newline at end of file +pub mod response; \ No newline at end of file diff --git a/library/src/social/wechat.rs b/library/src/social/wechat.rs index 2ae1b9f..a7922bc 100644 --- a/library/src/social/wechat.rs +++ b/library/src/social/wechat.rs @@ -79,6 +79,7 @@ impl WechatSocial { /// /// https://developers.weixin.qq.com/minigame/dev/api-backend/open-api/access-token/auth.getAccessToken.html /// https://developers.weixin.qq.com/miniprogram/dev/OpenApiDoc/mp-access-token/getAccessToken.html + #[allow(dead_code)] async fn fetch_and_parse_wechat_access_token(&self) -> SocialResult { let mut wechat_access_token = self.access_token.lock().await; if wechat_access_token.access_token.is_empty() diff --git a/library/src/validator/body_validator.rs b/library/src/validator/body_validator.rs new file mode 100644 index 0000000..8f96859 --- /dev/null +++ b/library/src/validator/body_validator.rs @@ -0,0 +1,80 @@ +use axum::{async_trait, extract::{rejection::{FormRejection, JsonRejection}, FromRequest, Request}}; +use http::header::CONTENT_TYPE; +use i18n::{message, message_ids::MessageId}; +use validator::Validate; + +use crate::{context::Context, model::response::ResErr}; + + +pub struct JsonBody(pub T); + +#[async_trait] +impl FromRequest for JsonBody +where + axum::Json: FromRequest, + axum::Form: FromRequest, + S: Send + Sync, + T: Validate +{ + type Rejection = ResErr; + + async fn from_request(req: Request, state: &S) -> Result { + let (parts, body) = req.into_parts(); + + // We can use other extractors to provide better rejection messages. + // For example, here we are using `axum::extract::MatchedPath` to + // provide a better error message. + // + // Have to run that first since `Json` extraction consumes the request. + // let path = parts + // .extract::() + // .await + // .map(|path| path.as_str().to_owned()) + // .ok(); + + let context_parts = parts.clone(); + let context: &Context = context_parts.extensions.get().unwrap(); + let content_type_header = context_parts.headers.get(CONTENT_TYPE); + let content_type = content_type_header.and_then(|value| value.to_str().ok()); + + let lang_tag = context.get_lang_tag(); + let req = Request::from_parts(parts, body); + + if let Some(content_type) = content_type { + if content_type.starts_with("application/json") { + match axum::Json::::from_request(req, state).await { + Ok(value) => { + let data = value.0; + match super::validator::validate_params(&data, lang_tag) { + Ok(_) => return Ok(Self(data)), + Err(err) => return Err(err), + } + }, + // convert the error from `axum::Json` into whatever we want + Err(rejection) => { + tracing::error!("无效的json数据: {:?}", rejection); + return Err(ResErr::params(message!(lang_tag, MessageId::InvalidParams))) + } + } + } else if content_type.starts_with("application/x-www-form-urlencoded") { + match axum::Form::::from_request(req, state).await { + Ok(value) => { + let data = value.0; + match super::validator::validate_params(&data, lang_tag) { + Ok(_) => return Ok(Self(data)), + Err(err) => return Err(err), + } + }, + // convert the error from `axum::Json` into whatever we want + Err(rejection) => { + tracing::error!("无效的json数据: {:?}", rejection); + return Err(ResErr::params(message!(lang_tag, MessageId::InvalidParams))) + } + } + } + return Err(ResErr::params(message!(lang_tag, MessageId::InvalidParams))) + } + + return Err(ResErr::params(message!(lang_tag, MessageId::InvalidParams))) + } +} diff --git a/library/src/validator/mod.rs b/library/src/validator/mod.rs new file mode 100644 index 0000000..099d689 --- /dev/null +++ b/library/src/validator/mod.rs @@ -0,0 +1,4 @@ +pub mod validator; +pub mod query_validator; +pub mod path_validator; +pub mod body_validator; \ No newline at end of file diff --git a/library/src/validator/path_validator.rs b/library/src/validator/path_validator.rs new file mode 100644 index 0000000..6948043 --- /dev/null +++ b/library/src/validator/path_validator.rs @@ -0,0 +1,94 @@ +use std::fmt::Display; + +use axum::{async_trait, extract::{path::ErrorKind, rejection::PathRejection, FromRequestParts}}; +use http::request::Parts; +use i18n::{message, message_ids::MessageId}; +use serde::de::DeserializeOwned; + +use crate::{context::Context, model::response::ResErr}; + + +pub struct PathVar(pub T); + +#[async_trait] +impl FromRequestParts for PathVar +where + T: DeserializeOwned + Send + Display, + S: Send + Sync, +{ + type Rejection = ResErr; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let context_parts = parts.clone(); + let context: &Context = context_parts.extensions.get().unwrap(); + let lang_tag = context.get_lang_tag(); + + match axum::extract::Path::::from_request_parts(parts, state).await { + Ok(value) => { + tracing::info!("path params: {}", value.0); + Ok(Self(value.0)) + }, + Err(rejection) => { + let err_rep = match rejection { + PathRejection::FailedToDeserializePathParams(inner) => { + let kind = inner.into_kind(); + let body = match &kind { + ErrorKind::WrongNumberOfParameters { got, expected } => { + tracing::error!("无效的路径参数: WrongNumberOfParameters, {} - {}", got, expected); + ResErr::params(message!(lang_tag, MessageId::InvalidParams)) + }, + + ErrorKind::ParseErrorAtKey { key, .. } => { + tracing::error!("无效的路径参数: ParseErrorAtKey, key: {}", key); + ResErr::params(message!(lang_tag, MessageId::InvalidParams)) + }, + + ErrorKind::ParseErrorAtIndex { index, .. } => { + tracing::error!("无效的路径参数: ParseErrorAtIndex, index; {}", index); + ResErr::params(message!(lang_tag, MessageId::InvalidParams)) + }, + + ErrorKind::ParseError { .. } => { + tracing::error!("无效的路径参数: ParseError"); + ResErr::params(message!(lang_tag, MessageId::InvalidParams)) + }, + + ErrorKind::InvalidUtf8InPathParam { key } => { + tracing::error!("无效的路径参数: InvalidUtf8InPathParam, key: {}", key); + ResErr::params(message!(lang_tag, MessageId::InvalidParams)) + }, + + ErrorKind::UnsupportedType { .. } => { + // this error is caused by the programmer using an unsupported type + // (such as nested maps) so respond with `500` instead + tracing::error!("无效的路径参数: UnsupportedType"); + ResErr::params(message!(lang_tag, MessageId::InvalidParams)) + } + + ErrorKind::Message(msg) => { + tracing::error!("无效的路径参数: Message, msg: {}", msg); + ResErr::params(message!(lang_tag, MessageId::InvalidParams)) + }, + + _ => { + tracing::error!("无效的路径参数"); + ResErr::params(message!(lang_tag, MessageId::InvalidParams)) + }, + }; + + body + } + PathRejection::MissingPathParams(error) => { + tracing::error!("无效的路径参数: MissingPathParams, err: {}", error); + ResErr::params(message!(lang_tag, MessageId::InvalidParams)) + }, + _ => { + tracing::error!("无效的路径参数: Others"); + ResErr::params(message!(lang_tag, MessageId::InvalidParams)) + }, + }; + Err(err_rep) + } + } + } +} \ No newline at end of file diff --git a/library/src/model/query_validator.rs b/library/src/validator/query_validator.rs similarity index 79% rename from library/src/model/query_validator.rs rename to library/src/validator/query_validator.rs index 02e4d05..2475a4d 100644 --- a/library/src/model/query_validator.rs +++ b/library/src/validator/query_validator.rs @@ -3,14 +3,12 @@ use http::Request; use i18n::{message, message_ids::MessageId}; use validator::Validate; -use crate::context::Context; +use crate::{context::Context, model::response::ResErr}; -use super::response::ResErr; - -pub struct QueryValidator(pub T); +pub struct QueryParams(pub T); #[async_trait] -impl FromRequest for QueryValidator +impl FromRequest for QueryParams where S: Send + Sync, T: Validate, @@ -23,11 +21,10 @@ where let query = Query::::from_request(Request::from_parts(parts.clone(), body), state).await; let context: &Context = parts.extensions.get().unwrap(); - tracing::info!("{:?}", context); if let Ok(Query(data)) = query { match super::validator::validate_params(&data, context.get_lang_tag()) { - Ok(_) => Ok(QueryValidator(data)), + Ok(_) => Ok(Self(data)), Err(err) => Err(err), } } else { diff --git a/library/src/model/validator.rs b/library/src/validator/validator.rs similarity index 97% rename from library/src/model/validator.rs rename to library/src/validator/validator.rs index 42e2536..b720393 100644 --- a/library/src/model/validator.rs +++ b/library/src/validator/validator.rs @@ -3,7 +3,7 @@ use std::str::FromStr; use i18n::{message, message_ids::MessageId}; use validator::Validate; -use super::response::{ResData, ResErr, ResResult}; +use crate::model::response::{ResData, ResErr, ResResult}; /// 验证请求参数 pub fn validate_params(params: &impl Validate, local: &str) -> ResResult> { diff --git a/server/src/controller/account_controller.rs b/server/src/controller/account_controller.rs index 4f7171c..c4b620c 100644 --- a/server/src/controller/account_controller.rs +++ b/server/src/controller/account_controller.rs @@ -1,6 +1,6 @@ -use axum::{Extension, Json}; +use axum::Extension; use domain::{dto::account::{AuthenticateGooleAccountReq, AuthenticateWithPassword, RefreshToken}, vo::account::{LoginAccount, RefreshTokenResult}}; -use library::{context::Context, model::{response::ResResult, validator}}; +use library::{context::Context, model::response::ResResult, validator::body_validator::JsonBody}; use crate::service; @@ -9,9 +9,8 @@ use crate::service; /// google账号登录 pub async fn authenticate_google( Extension(context): Extension, - Json(req): Json + JsonBody(req): JsonBody ) -> ResResult { - validator::validate_params(&req, context.get_lang_tag())?; service::account_service::authenticate_google(context, req).await } @@ -20,9 +19,8 @@ pub async fn authenticate_google( /// 账号密码登录 pub async fn authenticate_with_password( Extension(context): Extension, - Json(req): Json + JsonBody(req): JsonBody ) -> ResResult { - validator::validate_params(&req, context.get_lang_tag())?; service::sys_account_service::authenticate_with_password(context, req).await } @@ -31,14 +29,14 @@ pub async fn authenticate_with_password( /// 刷新token pub async fn refresh_token( Extension(context): Extension, - Json(refresh_token): Json + JsonBody(refresh_token): JsonBody ) -> ResResult { tracing::debug!("刷新token, {:?}", context); - validator::validate_params(&refresh_token, "")?; service::account_service::refresh_token(context, refresh_token.token).await } /// 添加管理员账号 +#[allow(dead_code)] pub async fn add_account() -> ResResult<()> { service::sys_account_service::add_account().await } \ No newline at end of file diff --git a/server/src/controller/feedback_controller.rs b/server/src/controller/feedback_controller.rs index 04f60a8..fe33b22 100644 --- a/server/src/controller/feedback_controller.rs +++ b/server/src/controller/feedback_controller.rs @@ -1,38 +1,35 @@ -use axum::extract::Query; -use axum::{Extension, Json}; +use axum::Extension; use domain::dto::feedback::FeedbackAdd; use domain::dto::pageable::PageParams; use domain::vo::feedback::FeedbackPageable; use library::context::Context; -use library::model::query_validator::QueryValidator; use library::model::response::ResResult; -use library::model::validator; +use library::validator::body_validator::JsonBody; +use library::validator::query_validator::QueryParams; use crate::service; /// post: /feedback -/// +/// /// 添加反馈信息 pub async fn add_feedback( Extension(context): Extension, - Json(req): Json, + JsonBody(req): JsonBody, ) -> ResResult<()> { - validator::validate_params(&req, context.get_lang_tag())?; service::feedback_service::add_feedback(context, req).await } /// get: /feedback -/// +/// /// 获取反馈信息列表 pub async fn get_feedback_list_by_page( Extension(context): Extension, - QueryValidator(page_params): QueryValidator, + QueryParams(page_params): QueryParams, ) -> ResResult { - validator::validate_params(&page_params, context.get_lang_tag())?; service::feedback_service::get_feedback_list_by_page( context, page_params.page.unwrap(), page_params.page_size.unwrap(), ) .await -} +} \ No newline at end of file diff --git a/server/src/controller/mod.rs b/server/src/controller/mod.rs index 4ea5550..6fc7a11 100644 --- a/server/src/controller/mod.rs +++ b/server/src/controller/mod.rs @@ -20,6 +20,6 @@ pub fn init() -> Router { .route( "/feedback", post(feedback_controller::add_feedback) - .get(feedback_controller::get_feedback_list_by_page), + .get(feedback_controller::get_feedback_list_by_page) ) } \ No newline at end of file diff --git a/server/src/lib.rs b/server/src/lib.rs index 5f4b997..5d58218 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -1,4 +1,4 @@ -use axum::{body::Body, extract::Request, routing::get, Router}; +use axum::{body::Body, extract::Request, http, routing::get, Router}; use library::{config, task}; use tasks::get_tasks; use tower::ServiceBuilder; @@ -36,21 +36,27 @@ fn init() -> Router { tracing::error_span!("request_id", id = req_id) }); + // 配置路由 + // layer之间存在顺序依赖,勿改。layer执行顺序和配置顺序一致 + // fallback路由放到最后,如果无匹配的路由,则不会执行layer,直接返回404 Router::new() .route("/", get(|| async { "hello" })) .nest(&config!().server.prefix_url, auth) .layer( ServiceBuilder::new() .layer(axum::middleware::from_fn( - library::middleware::req_ctx::authenticate_ctx, + library::middleware::req_id::handle, )) + .layer(trace_layer) + .layer(axum::middleware::from_fn( + library::middleware::cors::handle) + ) .layer(axum::middleware::from_fn( library::middleware::req_log::handle, )) - .layer(axum::middleware::from_fn(library::middleware::cors::handle)) - .layer(trace_layer) .layer(axum::middleware::from_fn( - library::middleware::req_id::handle, + library::middleware::req_ctx::authenticate_ctx, )) ) + .fallback(|| async { (http::status::StatusCode::NOT_FOUND, "Not Found") }) } \ No newline at end of file