From 3d5e8e027679ee8cacc311024079e964a835e227 Mon Sep 17 00:00:00 2001 From: rustdesk Date: Wed, 22 Mar 2023 16:50:59 +0800 Subject: [PATCH] vdi: new message loop --- vdi/host/Cargo.lock | 1 - vdi/host/Cargo.toml | 1 - vdi/host/src/connection.rs | 10 ++++ vdi/host/src/console.rs | 92 +++--------------------------- vdi/host/src/server.rs | 111 +++++++++++++++++++++++-------------- 5 files changed, 88 insertions(+), 127 deletions(-) diff --git a/vdi/host/Cargo.lock b/vdi/host/Cargo.lock index 161664625..0b2e8ca2b 100644 --- a/vdi/host/Cargo.lock +++ b/vdi/host/Cargo.lock @@ -1528,7 +1528,6 @@ version = "0.1.0" dependencies = [ "async-trait", "clap", - "derivative", "hbb_common", "image", "qemu-display", diff --git a/vdi/host/Cargo.toml b/vdi/host/Cargo.toml index a1fd1e68e..42b6fc83a 100644 --- a/vdi/host/Cargo.toml +++ b/vdi/host/Cargo.toml @@ -9,6 +9,5 @@ qemu-display = { git = "https://github.com/rustdesk/qemu-display" } hbb_common = { path = "../../libs/hbb_common" } clap = { version = "4.1", features = ["derive"] } zbus = { version = "3.0" } -derivative = "2.2" image = "0.24" async-trait = "0.1" diff --git a/vdi/host/src/connection.rs b/vdi/host/src/connection.rs index f416ce4e0..9f856fa2e 100644 --- a/vdi/host/src/connection.rs +++ b/vdi/host/src/connection.rs @@ -1 +1,11 @@ use hbb_common::{message_proto::*, tokio, ResultType}; +pub use tokio::sync::{mpsc, Mutex}; +pub struct Connection { + pub tx: mpsc::UnboundedSender, +} + +impl Connection { + pub async fn on_message(&mut self, message: Message) -> ResultType { + Ok(true) + } +} diff --git a/vdi/host/src/console.rs b/vdi/host/src/console.rs index 29dcbbfc9..a342f1a9a 100644 --- a/vdi/host/src/console.rs +++ b/vdi/host/src/console.rs @@ -1,7 +1,7 @@ -use hbb_common::{log, tokio, ResultType}; +use hbb_common::{tokio, ResultType}; use image::GenericImage; use qemu_display::{Console, ConsoleListenerHandler, MouseButton}; -use std::{collections::HashSet, sync::Arc, time}; +use std::{collections::HashSet, sync::Arc}; pub use tokio::sync::{mpsc, Mutex}; #[derive(Debug)] @@ -54,91 +54,17 @@ impl ConsoleListenerHandler for ConsoleListener { } fn disconnected(&mut self) { - dbg!(); + self.tx.send(Event::Disconnected).ok(); } } -#[derive(derivative::Derivative)] -#[derivative(Debug)] -pub struct Client { - #[derivative(Debug = "ignore")] - console: Arc>, - last_update: Option, - has_update: bool, - req_update: bool, - last_buttons: HashSet, - dimensions: (u16, u16), - image: Arc>, -} - -impl Client { - pub fn new(console: Arc>, image: Arc>) -> Self { - Self { - console, - image, - last_update: None, - has_update: false, - req_update: false, - last_buttons: HashSet::new(), - dimensions: (0, 0), - } - } - - pub fn update_pending(&self) -> bool { - self.has_update && self.req_update - } - - pub async fn key_event(&self, qnum: u32, down: bool) -> ResultType<()> { - let console = self.console.lock().await; - if down { - console.keyboard.press(qnum).await?; - } else { - console.keyboard.release(qnum).await?; - } - Ok(()) - } - - pub async fn desktop_resize(&mut self) -> ResultType<()> { - let image = self.image.lock().await; - let (width, height) = (image.width() as _, image.height() as _); - if (width, height) == self.dimensions { - return Ok(()); - } - self.dimensions = (width, height); - Ok(()) - } - - pub async fn send_framebuffer_update(&mut self) -> ResultType<()> { - self.desktop_resize().await?; - if self.has_update && self.req_update { - if let Some(last_update) = self.last_update { - if last_update.elapsed().as_millis() < 10 { - log::info!("TODO: <10ms, could delay update..") - } - } - // self.server.send_framebuffer_update(&self.vnc_server)?; - self.last_update = Some(time::Instant::now()); - self.has_update = false; - self.req_update = false; - } - Ok(()) - } - - pub async fn handle_event(&mut self, event: Option) -> ResultType { - match event { - Some(Event::ConsoleUpdate(_)) => { - self.has_update = true; - } - Some(Event::Disconnected) => { - return Ok(false); - } - None => { - self.send_framebuffer_update().await?; - } - } - - Ok(true) +pub async fn key_event(console: &mut Console, qnum: u32, down: bool) -> ResultType<()> { + if down { + console.keyboard.press(qnum).await?; + } else { + console.keyboard.release(qnum).await?; } + Ok(()) } fn image_from_vec(format: u32, width: u32, height: u32, stride: u32, data: Vec) -> BgraImage { diff --git a/vdi/host/src/server.rs b/vdi/host/src/server.rs index 5fd28d2d7..b43bd364f 100644 --- a/vdi/host/src/server.rs +++ b/vdi/host/src/server.rs @@ -1,13 +1,18 @@ use clap::Parser; -use hbb_common::{anyhow::Context, log, tokio, ResultType}; -use qemu_display::{Console, VMProxy}; -use std::{ - borrow::Borrow, - net::{TcpListener, TcpStream}, - sync::Arc, - thread, +use hbb_common::{ + allow_err, + anyhow::{bail, Context}, + log, + message_proto::*, + protobuf::Message as _, + tokio, + tokio::net::TcpListener, + ResultType, Stream, }; +use qemu_display::{Console, VMProxy}; +use std::{borrow::Borrow, sync::Arc}; +use crate::connection::*; use crate::console::*; #[derive(Parser, Debug)] @@ -37,8 +42,10 @@ struct Cli { #[derive(Debug)] struct Server { vm_name: String, - rx: mpsc::UnboundedReceiver, - tx: mpsc::UnboundedSender, + rx_console: mpsc::UnboundedReceiver, + tx_console: mpsc::UnboundedSender, + rx_conn: mpsc::UnboundedReceiver, + tx_conn: mpsc::UnboundedSender, image: Arc>, console: Arc>, } @@ -48,12 +55,15 @@ impl Server { let width = console.width().await?; let height = console.height().await?; let image = BgraImage::new(width as _, height as _); - let (tx, rx) = mpsc::unbounded_channel(); + let (tx_console, rx_console) = mpsc::unbounded_channel(); + let (tx_conn, rx_conn) = mpsc::unbounded_channel(); Ok(Self { vm_name, - rx, + rx_console, + tx_console, + rx_conn, + tx_conn, image: Arc::new(Mutex::new(image)), - tx, console: Arc::new(Mutex::new(console)), }) } @@ -69,7 +79,7 @@ impl Server { .await .register_listener(ConsoleListener { image: self.image.clone(), - tx: self.tx.clone(), + tx: self.tx_console.clone(), }) .await?; Ok(()) @@ -80,33 +90,47 @@ impl Server { (image.width() as u16, image.height() as u16) } - async fn handle_connection(&mut self, stream: TcpStream) -> ResultType<()> { - let (width, height) = self.dimensions().await; - - let tx = self.tx.clone(); - let _client_thread = thread::spawn(move || loop {}); - - let mut client = Client::new(self.console.clone(), self.image.clone()); + async fn handle_connection(&mut self, stream: Stream) -> ResultType<()> { + let mut stream = stream; self.run_console().await?; + let mut conn = Connection { + tx: self.tx_conn.clone(), + }; + loop { - let ev = if client.update_pending() { - match self.rx.try_recv() { - Ok(e) => Some(e), - Err(mpsc::error::TryRecvError::Empty) => None, - Err(e) => { - return Err(e.into()); + tokio::select! { + Some(evt) = self.rx_console.recv() => { + match evt { + _ => {} + } + } + Some(msg) = self.rx_conn.recv() => { + allow_err!(stream.send(&msg).await); + } + res = stream.next() => { + if let Some(res) = res { + match res { + Err(err) => { + bail!(err); + } + Ok(bytes) => { + if let Ok(msg_in) = Message::parse_from_bytes(&bytes) { + match conn.on_message(msg_in).await { + Ok(false) => { + break; + } + Err(err) => { + log::error!("{err}"); + } + _ => {} + } + } + } + } + } else { + bail!("Reset by the peer"); } } - } else { - Some( - self.rx - .recv() - .await - .context("Channel closed unexpectedly")?, - ) - }; - if !client.handle_event(ev).await? { - break; } } @@ -119,7 +143,9 @@ impl Server { pub async fn run() -> ResultType<()> { let args = Cli::parse(); - let listener = TcpListener::bind::(args.address.into()).unwrap(); + let listener = TcpListener::bind::(args.address.into()) + .await + .unwrap(); let dbus = if let Some(addr) = args.dbus_address { zbus::ConnectionBuilder::address(addr.borrow())? .build() @@ -134,12 +160,13 @@ pub async fn run() -> ResultType<()> { .await .context("Failed to get the console")?; let mut server = Server::new(format!("qemu-rustdesk ({})", vm_name), console).await?; - for stream in listener.incoming() { - let stream = stream?; + loop { + let (stream, addr) = listener.accept().await?; + stream.set_nodelay(true).ok(); + let laddr = stream.local_addr()?; + let stream = Stream::from(stream, laddr); if let Err(err) = server.handle_connection(stream).await { - log::error!("Connection closed: {err}"); + log::error!("Connection from {addr} closed: {err}"); } } - - Ok(()) }