149 lines
4.9 KiB
Rust
149 lines
4.9 KiB
Rust
use std::net::SocketAddr;
|
|
use std::ops::ControlFlow;
|
|
use std::sync::Arc;
|
|
|
|
use axum::extract::ws::{Message, WebSocket};
|
|
use dashmap::DashMap;
|
|
use futures::lock::Mutex;
|
|
use futures::stream::{SplitSink, SplitStream};
|
|
use futures::{SinkExt, StreamExt};
|
|
use lazy_static::lazy_static;
|
|
use library::context::Context;
|
|
use library::model::response::{ResErr, ResResult};
|
|
|
|
lazy_static! {
    /// Registry of live websocket connections, keyed by account id.
    ///
    /// Each value is the write half of that account's socket, wrapped in
    /// `Arc<Mutex<..>>` so senders can clone the handle and take turns writing.
    /// Inserting an existing key replaces the previous sink, so there is at most
    /// one entry per account id.
    static ref WS_POOL: DashMap<String, Arc<Mutex<SplitSink<WebSocket, Message>>>> = DashMap::new();
}
|
|
|
|
/// Actual websocket statemachine (one will be spawned per connection)
|
|
pub async fn handle_socket(socket: WebSocket, _who: SocketAddr, context: Context) {
|
|
let account = context.get_account().unwrap();
|
|
tracing::info!("`{:?}` at {:?} connected, user is {:?}", account, _who, account.username);
|
|
let (sender, receiver) = socket.split();
|
|
|
|
// tokio::spawn(write(sender));
|
|
tokio::spawn(read(account.id.clone(), receiver));
|
|
|
|
// 测试消息发送
|
|
tokio::spawn(send());
|
|
|
|
WS_POOL.insert(account.id.clone(), Arc::new(Mutex::new(sender)));
|
|
}
|
|
|
|
/// 测试消息发送
|
|
async fn send() {
|
|
loop {
|
|
tokio::time::sleep(std::time::Duration::from_millis(2000)).await;
|
|
let _ = broadcast_message(Message::Text("哈哈".into())).await;
|
|
}
|
|
}
|
|
|
|
async fn read(account_id: String, mut receiver: SplitStream<WebSocket>) {
|
|
loop {
|
|
match receiver.next().await {
|
|
Some(Ok(msg)) => {
|
|
if let ControlFlow::Break(_) = process_message(msg) {
|
|
// 收到关闭通知,移除该连接,并跳出循环
|
|
WS_POOL.remove(&account_id);
|
|
break;
|
|
}
|
|
},
|
|
Some(Err(err)) => {
|
|
tracing::error!("读取消息失败 {:?}", err);
|
|
}
|
|
None => {},
|
|
}
|
|
}
|
|
}
|
|
|
|
fn process_message(msg: Message) -> ControlFlow<(), ()> {
|
|
match msg {
|
|
Message::Text(t) => {
|
|
tracing::info!("接收到消息 :{}", t.to_string());
|
|
}
|
|
Message::Close(c) => {
|
|
if let Some(cf) = c {
|
|
tracing::info!("收到关闭通知 code {}, reason `{}`", cf.code, cf.reason);
|
|
} else {
|
|
tracing::info!(">>> somehow sent close message without CloseFrame");
|
|
}
|
|
return ControlFlow::Break(());
|
|
}
|
|
|
|
Message::Pong(v) => {
|
|
tracing::info!(">>> sent pong with {v:?}");
|
|
}
|
|
_ => {}
|
|
}
|
|
ControlFlow::Continue(())
|
|
}
|
|
|
|
/// 发送广播消息
|
|
pub async fn broadcast_message(msg: Message) {
|
|
for ele in WS_POOL.iter() {
|
|
let mut sender = ele.value().lock().await;
|
|
if let Err(err) = sender.send(msg.clone()).await {
|
|
tracing::error!("Failed to send message: {:?}", err);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// 发送定向消息
|
|
pub async fn send_message(msg: Message, accoudIds: Vec<String>) {
|
|
for ele in WS_POOL.iter() {
|
|
let mut sender = ele.value().lock().await;
|
|
if accoudIds.contains(&ele.key().to_string()) {
|
|
if let Err(err) = sender.send(msg.clone()).await {
|
|
tracing::error!("Failed to send message: {:?}", err);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// async fn write(mut sender: SplitSink<WebSocket, Message>) {
|
|
// loop {
|
|
// match sender.send(Message::text("haha")).await {
|
|
// Ok(_) => {
|
|
// // tracing::info!("发送成功");
|
|
// },
|
|
// Err(err) => {
|
|
// tracing::error!("发送失败 {:?}", err);
|
|
// },
|
|
// }
|
|
// tokio::time::sleep(std::time::Duration::from_millis(2000)).await;
|
|
// }
|
|
// }
|
|
|
|
|
|
// fn process_message(msg: Message, who: SocketAddr) -> ControlFlow<(), ()> {
|
|
// match msg {
|
|
// Message::Text(t) => {
|
|
// tracing::info!(">>> {who} sent str: {t:?}");
|
|
// }
|
|
// Message::Binary(d) => {
|
|
// tracing::info!(">>> {} sent {} bytes: {:?}", who, d.len(), d);
|
|
// }
|
|
// Message::Close(c) => {
|
|
// if let Some(cf) = c {
|
|
// tracing::info!(
|
|
// ">>> {} sent close with code {} and reason `{}`",
|
|
// who, cf.code, cf.reason
|
|
// );
|
|
// } else {
|
|
// tracing::info!(">>> {who} somehow sent close message without CloseFrame");
|
|
// }
|
|
// return ControlFlow::Break(());
|
|
// }
|
|
|
|
// Message::Pong(v) => {
|
|
// tracing::info!(">>> {who} sent pong with {v:?}");
|
|
// }
|
|
// // You should never need to manually handle Message::Ping, as axum's websocket library
|
|
// // will do so for you automagically by replying with Pong and copying the v according to
|
|
// // spec. But if you need the contents of the pings you can see them here.
|
|
// Message::Ping(v) => {
|
|
// // tracing::info!(">>> {who} sent ping with {v:?}");
|
|
// }
|
|
// }
|
|
// ControlFlow::Continue(())
|
|
// }
|