// chuanyue-service/library/src/social/google.rs

use std::{collections::HashMap, sync::Arc};
use chrono::Utc;
use futures_util::lock::Mutex;
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use lazy_static::lazy_static;
use reqwest::Client;
use serde::Deserialize;
use serde_json::Value;
use crate::resp::response::ResErr;
use super::SocialResult;
lazy_static! {
    /// Process-wide verifier; every caller shares its cached Google key set.
    pub static ref GOOGLE_SOCIAL: GoogleSocial = Default::default();
}

/// Verifies Google ID tokens against Google's published signing keys.
///
/// The most recently fetched key set is kept behind an async mutex so it can be
/// reused across calls and refreshed in place when it goes stale.
#[derive(Default)]
pub struct GoogleSocial {
    /// Cached key set together with the timestamp at which it must be re-fetched.
    google_public_keys: Arc<Mutex<GooglePublicKeys>>,
}
/// Google endpoint serving the public keys used to verify Google ID tokens.
const GOOGLE_PUBLIC_CERT_URL: &str =
    "https://www.googleapis.com/oauth2/v3/certs";
/// Claims extracted from a Google ID token.
#[derive(Debug, Default)]
pub struct GoogleJwtProfile {
    /// `iss` (issuer): who issued the token.
    pub iss: String,
    /// `sub` (subject): the subject the token is about.
    pub sub: String,
    /// `azp` claim (authorized party).
    pub azp: String,
    /// `aud` (audience): the intended recipient(s) of the token.
    pub aud: String,
    /// `iat` (issued at), compared against `Utc::now().timestamp()` i.e. seconds.
    pub iat: i64,
    /// `exp` (expiration time), compared against `Utc::now().timestamp()` i.e. seconds.
    pub exp: i64,
    /// `email` claim.
    pub email: String,
    /// `email_verified` claim.
    pub email_verified: bool,
    /// `at_hash` claim.
    pub at_hash: String,
    /// `name` claim.
    pub name: String,
    /// `picture` claim.
    pub picture: String,
    /// `given_name` claim.
    pub given_name: String,
    /// `family_name` claim.
    pub family_name: String,
    /// `locale` claim.
    pub locale: String,
}

impl GoogleJwtProfile {
    /// Creates a profile with every field at its default (`""`, `0`, `false`).
    fn new() -> Self {
        Self::default()
    }
}
impl From<Value> for GoogleJwtProfile {
fn from(value: Value) -> Self {
let mut google_jwt_profile = GoogleJwtProfile::new();
if let Some(value) = value.get("iss") {
google_jwt_profile.iss = value.as_str().unwrap().to_string();
}
if let Some(value) = value.get("sub") {
google_jwt_profile.sub = value.as_str().unwrap().to_string();
}
if let Some(value) = value.get("azp") {
google_jwt_profile.azp = value.as_str().unwrap().to_string();
}
if let Some(value) = value.get("aud") {
google_jwt_profile.aud = value.as_str().unwrap().to_string();
}
if let Some(value) = value.get("iat") {
google_jwt_profile.iat = value.as_i64().unwrap_or_default();
}
if let Some(value) = value.get("exp") {
google_jwt_profile.exp = value.as_i64().unwrap_or_default();
}
if let Some(value) = value.get("email") {
google_jwt_profile.email = value.as_str().unwrap().to_string();
}
if let Some(value) = value.get("email_verified") {
google_jwt_profile.email_verified = value.as_bool().unwrap_or_default();
}
if let Some(value) = value.get("at_hash") {
google_jwt_profile.at_hash = value.as_str().unwrap().to_string();
}
if let Some(value) = value.get("name") {
google_jwt_profile.name = value.as_str().unwrap().to_string();
}
if let Some(value) = value.get("picture") {
google_jwt_profile.picture = value.as_str().unwrap().to_string();
}
if let Some(value) = value.get("given_name") {
google_jwt_profile.given_name = value.as_str().unwrap().to_string();
}
if let Some(value) = value.get("family_name") {
google_jwt_profile.family_name = value.as_str().unwrap().to_string();
}
if let Some(value) = value.get("locale") {
google_jwt_profile.locale = value.as_str().unwrap().to_string();
}
google_jwt_profile
}
}
#[derive(Deserialize, Debug, Clone, Default)]
struct GooglePublicKey {
e: String,
n: String,
kty: String,
use_: String,
alg: String,
kid: String,
}
#[derive(Deserialize, Debug, Clone, Default)]
struct GooglePublicKeys {
keys: Vec<GooglePublicKey>,
refresh_at: i64,
}
impl GoogleSocial {
    /// Returns Google's RSA signing keys indexed by `kid`.
    ///
    /// The cached key set is re-fetched from `GOOGLE_PUBLIC_CERT_URL` when it is
    /// empty or its `refresh_at` has passed. The cache lock is held across the
    /// HTTP request, so concurrent callers wait for one refresh instead of each
    /// issuing their own request.
    ///
    /// # Errors
    /// Fails if the request errors or times out, or if Google answers with a
    /// non-success status. It also fails if the body cannot be deserialized
    /// into `GooglePublicKeys`.
    async fn fetch_and_parse_google_public_keys(
        &self,
    ) -> SocialResult<HashMap<String, GooglePublicKey>> {
        let mut public_keys = self.google_public_keys.lock().await;
        if public_keys.keys.is_empty() || public_keys.refresh_at < Utc::now().timestamp() {
            let response = Client::new()
                .get(GOOGLE_PUBLIC_CERT_URL)
                // The cache lock is held while this request is in flight; bound it
                // so a hung connection cannot stall every token verification.
                .timeout(std::time::Duration::from_secs(10))
                .send()
                .await?
                // Surface HTTP errors directly instead of failing on an error page body.
                .error_for_status()?;
            let mut google_keys: GooglePublicKeys = response.json().await?;
            tracing::info!("Google公钥获取成功, {:?}", google_keys);
            // Re-fetch one hour after this successful fetch.
            google_keys.refresh_at = Utc::now().timestamp() + 3600;
            *public_keys = google_keys;
        }
        // Only RSA keys can back the RS256 validation used below.
        let key_map: HashMap<String, GooglePublicKey> = public_keys
            .keys
            .iter()
            .filter(|key| key.kty == "RSA")
            .map(|key| (key.kid.clone(), key.clone()))
            .collect();
        tracing::debug!("Google公钥解析成功, {:?}", key_map);
        Ok(key_map)
    }

    /// Verifies a Google ID token (signature, issuer, expiry) and returns its claims.
    ///
    /// NOTE(review): this does NOT check the `aud` claim, so a valid Google ID
    /// token minted for any OAuth client is accepted. Prefer
    /// [`Self::verify_id_token_for_audience`] with this application's client ID(s).
    ///
    /// # Errors
    /// See [`Self::verify_id_token_for_audience`].
    pub async fn verify_id_token(&self, id_token: &str) -> SocialResult<GoogleJwtProfile> {
        self.verify_id_token_for_audience(id_token, &[]).await
    }

    /// Verifies a Google ID token and returns its claims.
    ///
    /// The token must meet all of these conditions:
    /// - its header carries a `kid` that matches one of Google's current RSA keys;
    /// - its RS256 signature verifies against that key;
    /// - it was issued by `accounts.google.com`;
    /// - it has not expired.
    ///
    /// When `audience` is non-empty, the token's `aud` must also be one of
    /// `audience`. An empty slice skips the audience check, which is the legacy
    /// behaviour of [`Self::verify_id_token`].
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - the token header is malformed or has no `kid`;
    /// - no Google key matches the `kid`, or the key set cannot be fetched;
    /// - the key components are invalid;
    /// - signature or claim validation fails;
    /// - `exp` is in the past.
    pub async fn verify_id_token_for_audience(
        &self,
        id_token: &str,
        audience: &[&str],
    ) -> SocialResult<GoogleJwtProfile> {
        // Parse the header first so malformed input is rejected before the key
        // cache (and possibly a network refresh) is touched. Never panic here:
        // `id_token` comes from the client.
        let token_header = jsonwebtoken::decode_header(id_token)?;
        let kid = match token_header.kid {
            Some(kid) => kid,
            None => return Err(Box::new(ResErr::social("校验Token失败,未找到有效的kid"))),
        };

        let public_keys = self.fetch_and_parse_google_public_keys().await?;
        let key = public_keys
            .get(&kid)
            .ok_or_else(|| Box::new(ResErr::social("校验Token失败,未找到正确的公钥")))?;
        tracing::debug!("public key : {:?}", key);

        let mut validation = Validation::new(Algorithm::RS256);
        // Google issues ID tokens under either of these issuer strings.
        validation.set_issuer(&["https://accounts.google.com", "accounts.google.com"]);
        if audience.is_empty() {
            validation.validate_aud = false;
        } else {
            validation.set_audience(audience);
        }

        // Invalid key material is reported as an error rather than panicking.
        let decoding_key = DecodingKey::from_rsa_components(key.n.as_str(), key.e.as_str())?;
        let claims: Value = decode::<Value>(id_token, &decoding_key, &validation)?.claims;
        let google_jwt_profile = GoogleJwtProfile::from(claims);

        // NOTE(review): jsonwebtoken's `Validation` presumably checks `exp` itself
        // (possibly with leeway); this explicit check rejects tokens strictly at expiry.
        if google_jwt_profile.exp < Utc::now().timestamp() {
            return Err(Box::new(ResErr::social("校验Token失败,token有效期无效")));
        }
        Ok(google_jwt_profile)
    }
}