use std::net::SocketAddr; use std::ops::ControlFlow; use std::sync::Arc; use axum::extract::ws::{Message, WebSocket}; use dashmap::DashMap; use futures::lock::Mutex; use futures::stream::{SplitSink, SplitStream}; use futures::{SinkExt, StreamExt}; use lazy_static::lazy_static; use library::context::Context; use library::model::response::{ResErr, ResResult}; lazy_static!{ static ref WS_POOL: DashMap>>> = DashMap::>>>::new(); } /// Actual websocket statemachine (one will be spawned per connection) pub async fn handle_socket(socket: WebSocket, _who: SocketAddr, context: Context) { let account = context.get_account().unwrap(); tracing::info!("`{:?}` at {:?} connected, user is {:?}", account, _who, account.username); let (sender, receiver) = socket.split(); // tokio::spawn(write(sender)); tokio::spawn(read(account.id.clone(), receiver)); // 测试消息发送 tokio::spawn(send()); WS_POOL.insert(account.id.clone(), Arc::new(Mutex::new(sender))); } /// 测试消息发送 async fn send() { loop { tokio::time::sleep(std::time::Duration::from_millis(2000)).await; let _ = broadcast_message(Message::Text("哈哈".into())).await; } } async fn read(account_id: String, mut receiver: SplitStream) { loop { match receiver.next().await { Some(Ok(msg)) => { if let ControlFlow::Break(_) = process_message(msg) { // 收到关闭通知,移除该连接,并跳出循环 WS_POOL.remove(&account_id); break; } }, Some(Err(err)) => { tracing::error!("读取消息失败 {:?}", err); } None => {}, } } } fn process_message(msg: Message) -> ControlFlow<(), ()> { match msg { Message::Text(t) => { tracing::info!("接收到消息 :{}", t.to_string()); } Message::Close(c) => { if let Some(cf) = c { tracing::info!("收到关闭通知 code {}, reason `{}`", cf.code, cf.reason); } else { tracing::info!(">>> somehow sent close message without CloseFrame"); } return ControlFlow::Break(()); } Message::Pong(v) => { tracing::info!(">>> sent pong with {v:?}"); } _ => {} } ControlFlow::Continue(()) } /// 发送广播消息 pub async fn broadcast_message(msg: Message) { for ele in WS_POOL.iter() { let mut sender = ele.value().lock().await; if let Err(err) = sender.send(msg.clone()).await { tracing::error!("Failed to send message: {:?}", err); } } } /// 发送定向消息 pub async fn send_message(msg: Message, accoudIds: Vec) { for ele in WS_POOL.iter() { let mut sender = ele.value().lock().await; if accoudIds.contains(&ele.key().to_string()) { if let Err(err) = sender.send(msg.clone()).await { tracing::error!("Failed to send message: {:?}", err); } } } } // async fn write(mut sender: SplitSink) { // loop { // match sender.send(Message::text("haha")).await { // Ok(_) => { // // tracing::info!("发送成功"); // }, // Err(err) => { // tracing::error!("发送失败 {:?}", err); // }, // } // tokio::time::sleep(std::time::Duration::from_millis(2000)).await; // } // } // fn process_message(msg: Message, who: SocketAddr) -> ControlFlow<(), ()> { // match msg { // Message::Text(t) => { // tracing::info!(">>> {who} sent str: {t:?}"); // } // Message::Binary(d) => { // tracing::info!(">>> {} sent {} bytes: {:?}", who, d.len(), d); // } // Message::Close(c) => { // if let Some(cf) = c { // tracing::info!( // ">>> {} sent close with code {} and reason `{}`", // who, cf.code, cf.reason // ); // } else { // tracing::info!(">>> {who} somehow sent close message without CloseFrame"); // } // return ControlFlow::Break(()); // } // Message::Pong(v) => { // tracing::info!(">>> {who} sent pong with {v:?}"); // } // // You should never need to manually handle Message::Ping, as axum's websocket library // // will do so for you automagically by replying with Pong and copying the v according to // // spec. But if you need the contents of the pings you can see them here. // Message::Ping(v) => { // // tracing::info!(">>> {who} sent ping with {v:?}"); // } // } // ControlFlow::Continue(()) // }