diff --git a/library/src/errors.rs b/library/src/errors.rs new file mode 100644 index 0000000..c1e343c --- /dev/null +++ b/library/src/errors.rs @@ -0,0 +1,36 @@ +use core::fmt; +use std::error::Error; + +#[derive(Debug)] +pub struct MessageError(Option); + +// 实现Error trait +impl Error for MessageError { + + fn cause(&self) -> Option<&dyn Error> { + self.source() + } + +} + +// 实现Display trait,这是Error trait的一部分要求 +impl fmt::Display for MessageError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0 { + Some(message) => write!(f, "{}", message), + None => write!(f, "") + } + } +} + +impl From for MessageError { + fn from(value: String) -> Self { + MessageError(Some(value)) + } +} + +impl From<&str> for MessageError { + fn from(value: &str) -> Self { + MessageError(Some(value.to_string())) + } +} \ No newline at end of file diff --git a/library/src/lib.rs b/library/src/lib.rs index 8db8ea3..b0e081e 100644 --- a/library/src/lib.rs +++ b/library/src/lib.rs @@ -5,4 +5,5 @@ pub mod resp; pub mod middleware; pub mod token; pub mod cache; -pub mod social; \ No newline at end of file +pub mod social; +pub mod errors; \ No newline at end of file diff --git a/library/src/resp/response.rs b/library/src/resp/response.rs index bc038ae..d99f312 100644 --- a/library/src/resp/response.rs +++ b/library/src/resp/response.rs @@ -1,3 +1,5 @@ +use std::{error::Error as StdError, fmt::Display}; + use axum::{ response::{IntoResponse, Response}, Json, @@ -22,6 +24,7 @@ where } } +#[derive(Debug)] pub enum ResErr { Error(i32, String), ErrParams(Option), @@ -31,6 +34,7 @@ pub enum ResErr { ErrSystem(Option), ErrData(Option), ErrService(Option), + ErrSocial(Option), } use ResErr::*; @@ -95,10 +99,81 @@ impl IntoResponse for ResErr { None => Status::<()>::Err(code, String::from("服务异常")), } } + ErrSocial(msg) => { + let code = 80000; + match msg { + Some(v) => Status::<()>::Err(code, v), + None => Status::<()>::Err(code, String::from("社交服务异常")), + } + } }; Json(status.to_reply()).into_response() } } +// impl Display for ResErr { +// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +// // write!(f, "{:?}", self.message) +// match self { +// Error(value, message) => write!(f, "{} - {:?}", value, message), +// ErrParams(message) +// | ErrAuth(message) +// | ErrPerm(message) +// | ErrNotFound(message) +// | ErrSystem(message) +// | ErrData(message) +// | ErrService(message) +// | ErrSocial(message) => write!(f, "{:?}", message), +// } +// } +// } + +// impl StdError for ResErr { +// fn cause(&self) -> Option<&dyn StdError> { +// match self { +// Error(_, _) => None, +// ErrParams(_) +// | ErrAuth(_) +// | ErrPerm(_) +// | ErrNotFound(_) +// | ErrSystem(_) +// | ErrData(_) +// | ErrService(_) +// | ErrSocial(_) => None, +// } +// } + +// fn source(&self) -> Option<&(dyn StdError + 'static)> { +// None +// } + +// fn description(&self) -> &str { +// match self { +// Error(value, message) => &format!("Error, {} - {}", value, message), +// ErrParams(message) => &format!("ErrParams, {}", message), +// ErrAuth(message) => &format!("ErrAuth, {}", message), +// ErrPerm(message) => &format!("ErrPerm, {}", message), +// ErrNotFound(message) => &format!("ErrNotFound, {}", message), +// ErrSystem(message) => &format!("ErrSystem, {}", message), +// ErrData(message) => &format!("ErrData, {}", message), +// ErrService(message) => &format!("ErrService, {}", message), +// ErrSocial(message) => &format!("ErrSocial, {}", message), +// } +// } + +// fn provide<'a>(&'a self, request: &mut std::error::Request<'a>) { +// match self { +// ResErr::ErrSystem(e) => request.provide_ref(e), +// ResErr::ErrData(e) => request.provide_ref(e), +// ResErr::ErrService(e) => request.provide_ref(e), +// ResErr::ErrSocial(e) => request.provide_ref(e), +// ResErr::ErrNotFound(e) => request.provide_ref(e), +// ResErr::ErrPerm(e) => request.provide_ref(e), +// ResErr::ErrAuth(e) => request.provide_ref(e), +// ResErr::ErrParams(e) => request.provide_ref(e), +// ResErr::Error(_, e) => request.provide_ref(e), +// } +// } + pub type ResResult = std::result::Result; diff --git a/library/src/social/google.rs b/library/src/social/google.rs index d9980b1..99d07a3 100644 --- a/library/src/social/google.rs +++ b/library/src/social/google.rs @@ -1,14 +1,99 @@ -use std::collections::HashMap; +use std::{collections::HashMap, fmt::Error}; -use jsonwebtoken::{decode, errors::{Error, ErrorKind}, Algorithm, DecodingKey, TokenData, Validation}; +use chrono::Utc; +use jsonwebtoken::{decode, errors::ErrorKind, Algorithm, DecodingKey, TokenData, Validation}; use reqwest::Client; use serde_json::Value; -type SocialResult = Result>; +use crate::{errors::MessageError, resp::response::ResErr}; + +use super::SocialResult; // 假设GOOGLE_PUBLIC_CERT_URL是Google提供的公钥URL const GOOGLE_PUBLIC_CERT_URL: &str = "https://www.googleapis.com/oauth2/v3/certs"; +#[derive(Debug, Default)] +pub struct GoogleJwtProfile { + // iss (issuer):签发人 + pub iss: String, + // sub (subject):主题 + pub sub: String, + pub azp: String, + // aud (audience):受众 + pub aud: String, + // iat (Issued At):签发时间 + pub iat: i64, + // exp (expiration time):过期时间 + pub exp: i64, + pub email: String, + pub email_verified: bool, + pub at_hash: String, + pub name: String, + pub picture: String, + pub given_name: String, + pub family_name: String, + pub locale: String, +} + +impl GoogleJwtProfile { + fn new() -> Self { + GoogleJwtProfile{ + ..Default::default() + } + } +} + +impl From for GoogleJwtProfile { + + fn from(value: Value) -> Self { + let mut google_jwt_profile = GoogleJwtProfile::new(); + if let Some(value) = value.get("iss") { + google_jwt_profile.iss = value.to_string(); + } + if let Some(value) = value.get("sub") { + google_jwt_profile.sub = value.to_string(); + } + if let Some(value) = value.get("azp") { + google_jwt_profile.azp = value.to_string(); + } + if let Some(value) = value.get("aud") { + google_jwt_profile.aud = value.to_string(); + } + if let Some(value) = value.get("iat") { + google_jwt_profile.iat = value.as_i64().unwrap_or_default(); + } + if let Some(value) = value.get("exp") { + google_jwt_profile.exp = value.as_i64().unwrap_or_default(); + } + if let Some(value) = value.get("email") { + google_jwt_profile.email = value.to_string(); + } + if let Some(value) = value.get("email_verified") { + google_jwt_profile.email_verified = value.as_bool().unwrap_or_default(); + } + if let Some(value) = value.get("at_hash") { + google_jwt_profile.at_hash = value.to_string(); + } + if let Some(value) = value.get("name") { + google_jwt_profile.name = value.to_string(); + } + if let Some(value) = value.get("picture") { + google_jwt_profile.picture = value.to_string(); + } + if let Some(value) = value.get("given_name") { + google_jwt_profile.given_name = value.to_string(); + } + if let Some(value) = value.get("family_name") { + google_jwt_profile.family_name = value.to_string(); + } + if let Some(value) = value.get("locale") { + google_jwt_profile.locale = value.to_string(); + } + google_jwt_profile + } + +} + // 异步获取并解析Google公钥 async fn fetch_and_parse_google_public_keys() -> SocialResult> { let response = Client::new().get(GOOGLE_PUBLIC_CERT_URL).send().await?; @@ -27,7 +112,7 @@ async fn fetch_and_parse_google_public_keys() -> SocialResult SocialResult { +async fn verify_id_token(id_token: &str) -> SocialResult { // 获取并解析公钥 let public_keys = fetch_and_parse_google_public_keys().await?; @@ -37,28 +122,36 @@ async fn verify_id_token(id_token: &str) -> SocialResult { // 检查是否找到了有效的kid if kid.is_none() { - return Err(Box::new(Error::from(ErrorKind::InvalidToken))); + return Err(Box::new(MessageError::from("校验Token失败,未找到有效的kid"))); } let kid = kid.unwrap(); // 根据kid找到正确的公钥 - let key = public_keys.get(&kid).ok_or_else(|| Box::new(Error::from(ErrorKind::InvalidToken)))?; + let key = public_keys.get(&kid).ok_or_else(|| Box::new(MessageError::from("校验Token失败,未找到正确的公钥")))?; // 验证Token let mut validation: Validation = Validation::new(Algorithm::RS256); - validation.set_issuer(&["https://accounts.google.com"]);// 设置预期的发行者 + validation.set_issuer(&["https://accounts.google.com", "accounts.google.com"]);// 设置预期的发行者 let decoded = decode::(id_token, &DecodingKey::from_rsa_pem(key.as_bytes())?, &validation)?; - let claims = decoded.claims; - // 进一步校验iss字段 - if let Some(issuer) = claims.get("iss") { - if issuer.as_str().unwrap() != "https://accounts.google.com" { - return Err(Box::new(Error::from(ErrorKind::InvalidToken))); - } - } else { - return Err(Box::new(Error::from(ErrorKind::InvalidToken))); - } - Ok(claims) + let claims: Value = decoded.claims; + let google_jwt_profile = GoogleJwtProfile::from(claims); + // 校验有效期 + if google_jwt_profile.exp < Utc::now().timestamp() { + return Err(Box::new(MessageError::from("校验Token失败,token有效期无效"))); + } + // 校验签发时间 + // if google_jwt_profile.iat > Utc::now().timestamp() { + // return Err(Box::new(Error::from(ErrorKind::InvalidToken))); + // } + // if google_jwt_profile.aud != config::GOOGLE_CLIENT_ID { + // } + // 校验iss字段 + if google_jwt_profile.iss != "accounts.google.com" && google_jwt_profile.iss != "https://accounts.google.com" { + return Err(Box::new(MessageError::from("校验Token失败,token签发人非法"))); + } + + Ok(google_jwt_profile) } diff --git a/library/src/social/mod.rs b/library/src/social/mod.rs index e729dfa..278413d 100644 --- a/library/src/social/mod.rs +++ b/library/src/social/mod.rs @@ -1,3 +1,4 @@ +type SocialResult = Result>; pub mod google; pub mod facebook;