diff --git a/build/azure-pipelines/win32/cli-build-win32.yml b/build/azure-pipelines/win32/cli-build-win32.yml index c4c54a35ab4..7c476519aa6 100644 --- a/build/azure-pipelines/win32/cli-build-win32.yml +++ b/build/azure-pipelines/win32/cli-build-win32.yml @@ -60,6 +60,7 @@ steps: VSCODE_CLI_ENV: OPENSSL_LIB_DIR: $(Build.ArtifactStagingDirectory)/openssl/x64-windows-static-md/lib OPENSSL_INCLUDE_DIR: $(Build.ArtifactStagingDirectory)/openssl/x64-windows-static-md/include + RUSTFLAGS: '-C target-feature=+crt-static' - ${{ if eq(parameters.VSCODE_BUILD_WIN32_ARM64, true) }}: - template: ../cli/cli-compile-and-publish.yml @@ -69,6 +70,7 @@ steps: VSCODE_CLI_ENV: OPENSSL_LIB_DIR: $(Build.ArtifactStagingDirectory)/openssl/arm64-windows-static-md/lib OPENSSL_INCLUDE_DIR: $(Build.ArtifactStagingDirectory)/openssl/arm64-windows-static-md/include + RUSTFLAGS: '-C target-feature=+crt-static' - ${{ if eq(parameters.VSCODE_BUILD_WIN32_32BIT, true) }}: - template: ../cli/cli-compile-and-publish.yml @@ -78,3 +80,4 @@ steps: VSCODE_CLI_ENV: OPENSSL_LIB_DIR: $(Build.ArtifactStagingDirectory)/openssl/x86-windows-static-md/lib OPENSSL_INCLUDE_DIR: $(Build.ArtifactStagingDirectory)/openssl/x86-windows-static-md/include + RUSTFLAGS: '-C target-feature=+crt-static' diff --git a/cli/.cargo/config.toml b/cli/.cargo/config.toml deleted file mode 100644 index 35c67ad3d28..00000000000 --- a/cli/.cargo/config.toml +++ /dev/null @@ -1,2 +0,0 @@ -[target.'cfg(all(windows, target_env = "msvc"))'] -rustflags = ["-C", "target-feature=+crt-static"] diff --git a/cli/Cargo.lock b/cli/Cargo.lock index 8bee1ff21ec..73cdff12358 100644 --- a/cli/Cargo.lock +++ b/cli/Cargo.lock @@ -230,6 +230,7 @@ dependencies = [ "async-trait", "atty", "base64", + "cfg-if", "chrono", "clap", "clap_lex", @@ -249,6 +250,7 @@ dependencies = [ "open", "opentelemetry", "opentelemetry-application-insights", + "pin-project", "rand 0.8.5", "regex", "reqwest", @@ -261,6 +263,7 @@ dependencies = [ "sysinfo", "tar", "tempfile", + "thiserror", "tokio", "tokio-util", "tunnels", @@ -1533,6 +1536,26 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" +[[package]] +name = "pin-project" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "pin-project-lite" version = "0.2.9" @@ -2207,18 +2230,18 @@ checksum = "949517c0cf1bf4ee812e2e07e08ab448e3ae0d23472aee8a06c985f0c8815b16" [[package]] name = "thiserror" -version = "1.0.37" +version = "1.0.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10deb33631e3c9018b9baf9dcbbc4f737320d2b576bac10f6aefa048fa407e3e" +checksum = "a5ab016db510546d856297882807df8da66a16fb8c4101cb8b30054b0d5b2d9c" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.37" +version = "1.0.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "982d17546b47146b28f7c22e3d08465f6b8903d0ea13c1660d9d84a6e7adcdbb" +checksum = "5420d42e90af0c38c3290abcca25b9b3bdf379fc9f55c528f53a269d9c9a267e" dependencies = [ "proc-macro2", "quote", diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 6d140536900..536c08ff829 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -50,6 +50,9 @@ const_format = "0.2" sha2 = "0.10" base64 = "0.13" shell-escape = "0.1.5" +thiserror = "1.0" +cfg-if = "1.0.0" +pin-project = "1.0" [build-dependencies] serde = { version = "1.0" } diff --git a/cli/src/async_pipe.rs b/cli/src/async_pipe.rs new file mode 100644 index 00000000000..dcbe0d16017 --- /dev/null +++ b/cli/src/async_pipe.rs @@ -0,0 +1,183 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +use crate::{constants::APPLICATION_NAME, util::errors::CodeError}; +use std::path::{Path, PathBuf}; +use uuid::Uuid; + +// todo: we could probably abstract this into some crate, if one doesn't already exist + +cfg_if::cfg_if! { + if #[cfg(unix)] { + pub type AsyncPipe = tokio::net::UnixStream; + pub type AsyncPipeWriteHalf = tokio::net::unix::OwnedWriteHalf; + pub type AsyncPipeReadHalf = tokio::net::unix::OwnedReadHalf; + + pub async fn get_socket_rw_stream(path: &Path) -> Result { + tokio::net::UnixStream::connect(path) + .await + .map_err(CodeError::AsyncPipeFailed) + } + + pub async fn listen_socket_rw_stream(path: &Path) -> Result { + tokio::net::UnixListener::bind(path) + .map(AsyncPipeListener) + .map_err(CodeError::AsyncPipeListenerFailed) + } + + pub struct AsyncPipeListener(tokio::net::UnixListener); + + impl AsyncPipeListener { + pub async fn accept(&mut self) -> Result { + self.0.accept().await.map_err(CodeError::AsyncPipeListenerFailed).map(|(s, _)| s) + } + } + + pub fn socket_stream_split(pipe: AsyncPipe) -> (AsyncPipeReadHalf, AsyncPipeWriteHalf) { + pipe.into_split() + } + } else { + use tokio::{time::sleep, io::{AsyncRead, AsyncWrite, ReadBuf}}; + use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions, NamedPipeClient, NamedPipeServer}; + use std::{time::Duration, pin::Pin, task::{Context, Poll}, io}; + use pin_project::pin_project; + + #[pin_project(project = AsyncPipeProj)] + pub enum AsyncPipe { + PipeClient(#[pin] NamedPipeClient), + PipeServer(#[pin] NamedPipeServer), + } + + impl AsyncRead for AsyncPipe { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.project() { + AsyncPipeProj::PipeClient(c) => c.poll_read(cx, buf), + AsyncPipeProj::PipeServer(c) => c.poll_read(cx, buf), + } + } + } + + impl AsyncWrite for AsyncPipe { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.project() { + AsyncPipeProj::PipeClient(c) => c.poll_write(cx, buf), + AsyncPipeProj::PipeServer(c) => c.poll_write(cx, buf), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + match self.project() { + AsyncPipeProj::PipeClient(c) => c.poll_write_vectored(cx, bufs), + AsyncPipeProj::PipeServer(c) => c.poll_write_vectored(cx, bufs), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + AsyncPipeProj::PipeClient(c) => c.poll_flush(cx), + AsyncPipeProj::PipeServer(c) => c.poll_flush(cx), + } + } + + fn is_write_vectored(&self) -> bool { + match self { + AsyncPipe::PipeClient(c) => c.is_write_vectored(), + AsyncPipe::PipeServer(c) => c.is_write_vectored(), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + AsyncPipeProj::PipeClient(c) => c.poll_shutdown(cx), + AsyncPipeProj::PipeServer(c) => c.poll_shutdown(cx), + } + } + } + + pub type AsyncPipeWriteHalf = tokio::io::WriteHalf; + pub type AsyncPipeReadHalf = tokio::io::ReadHalf; + + pub async fn get_socket_rw_stream(path: &Path) -> Result { + // Tokio says we can need to try in a loop. Do so. + // https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.NamedPipeClient.html + let client = loop { + match ClientOptions::new().open(path) { + Ok(client) => break client, + // ERROR_PIPE_BUSY https://docs.microsoft.com/en-us/windows/win32/debug/system-error-codes--0-499- + Err(e) if e.raw_os_error() == Some(231) => sleep(Duration::from_millis(100)).await, + Err(e) => return Err(CodeError::AsyncPipeFailed(e)), + } + }; + + Ok(AsyncPipe::PipeClient(client)) + } + + pub struct AsyncPipeListener { + path: PathBuf, + server: NamedPipeServer + } + + impl AsyncPipeListener { + pub async fn accept(&mut self) -> Result { + // see https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.NamedPipeServer.html + // this is a bit weird in that the server becomes the client once + // they get a connection, and we create a new client. + + self.server + .connect() + .await + .map_err(CodeError::AsyncPipeListenerFailed)?; + + // Construct the next server to be connected before sending the one + // we already have of onto a task. This ensures that the server + // isn't closed (after it's done in the task) before a new one is + // available. Otherwise the client might error with + // `io::ErrorKind::NotFound`. + let next_server = ServerOptions::new() + .create(&self.path) + .map_err(CodeError::AsyncPipeListenerFailed)?; + + + Ok(AsyncPipe::PipeServer(std::mem::replace(&mut self.server, next_server))) + } + } + + pub async fn listen_socket_rw_stream(path: &Path) -> Result { + let server = ServerOptions::new() + .first_pipe_instance(true) + .create(path) + .map_err(CodeError::AsyncPipeListenerFailed)?; + + Ok(AsyncPipeListener { path: path.to_owned(), server }) + } + + pub fn socket_stream_split(pipe: AsyncPipe) -> (AsyncPipeReadHalf, AsyncPipeWriteHalf) { + tokio::io::split(pipe) + } + } +} + +/// Gets a random name for a pipe/socket on the paltform +pub fn get_socket_name() -> PathBuf { + cfg_if::cfg_if! { + if #[cfg(unix)] { + std::env::temp_dir().join(format!("{}-{}", APPLICATION_NAME, Uuid::new_v4())) + } else { + PathBuf::from(format!(r"\\.\pipe\{}-{}", APPLICATION_NAME, Uuid::new_v4())) + } + } +} diff --git a/cli/src/commands.rs b/cli/src/commands.rs index 32b1ac3592b..082031af201 100644 --- a/cli/src/commands.rs +++ b/cli/src/commands.rs @@ -6,8 +6,8 @@ mod context; pub mod args; +pub mod internal_wsl; pub mod tunnels; pub mod update; pub mod version; -pub mod internal_wsl; pub use context::CommandContext; diff --git a/cli/src/commands/internal_wsl.rs b/cli/src/commands/internal_wsl.rs index 9912b59428b..483ee52c6aa 100644 --- a/cli/src/commands/internal_wsl.rs +++ b/cli/src/commands/internal_wsl.rs @@ -4,14 +4,14 @@ *--------------------------------------------------------------------------------------------*/ use crate::{ - tunnels::{serve_wsl, shutdown_signal::ShutdownSignal}, + tunnels::{serve_wsl, shutdown_signal::ShutdownRequest}, util::{errors::AnyError, prereqs::PreReqChecker}, }; use super::CommandContext; pub async fn serve(ctx: CommandContext) -> Result { - let signal = ShutdownSignal::create_rx(&[ShutdownSignal::CtrlC]); + let signal = ShutdownRequest::create_rx([ShutdownRequest::CtrlC]); let platform = spanf!( ctx.log, ctx.log.span("prereq"), diff --git a/cli/src/commands/tunnels.rs b/cli/src/commands/tunnels.rs index 01e5aa280d9..9414ca59cac 100644 --- a/cli/src/commands/tunnels.rs +++ b/cli/src/commands/tunnels.rs @@ -5,9 +5,8 @@ use async_trait::async_trait; use sha2::{Digest, Sha256}; -use std::str::FromStr; +use std::{str::FromStr, time::Duration}; use sysinfo::Pid; -use tokio::sync::mpsc; use super::{ args::{ @@ -17,21 +16,31 @@ use super::{ CommandContext, }; -use crate::tunnels::shutdown_signal::ShutdownSignal; -use crate::tunnels::{dev_tunnels::ActiveTunnel, SleepInhibitor}; use crate::{ auth::Auth, log::{self, Logger}, state::LauncherPaths, tunnels::{ - code_server::CodeServerArgs, create_service_manager, dev_tunnels, legal, - paths::get_all_servers, ServiceContainer, ServiceManager, + code_server::CodeServerArgs, + create_service_manager, dev_tunnels, legal, + paths::get_all_servers, + shutdown_signal::ShutdownRequest, + singleton_server::{start_singleton_server, SingletonServerArgs, BroadcastLogSink}, + ServiceContainer, ServiceManager, }, util::{ errors::{wrap, AnyError}, prereqs::PreReqChecker, }, }; +use crate::{ + singleton::{acquire_singleton, SingletonConnection}, + tunnels::{ + dev_tunnels::ActiveTunnel, + singleton_client::{start_singleton_client, SingletonClientArgs}, + SleepInhibitor, + }, +}; impl From for crate::auth::AuthProvider { fn from(auth_provider: AuthProvider) -> Self { @@ -75,7 +84,6 @@ impl ServiceContainer for TunnelServiceContainer { &mut self, log: log::Logger, launcher_paths: LauncherPaths, - shutdown_rx: mpsc::UnboundedReceiver, ) -> Result<(), AnyError> { let csa = (&self.args).into(); serve_with_csa( @@ -86,7 +94,6 @@ impl ServiceContainer for TunnelServiceContainer { ..Default::default() }, csa, - Some(shutdown_rx), ) .await?; Ok(()) @@ -227,7 +234,7 @@ pub async fn serve(ctx: CommandContext, gateway_args: TunnelServeArgs) -> Result legal::require_consent(&paths, gateway_args.accept_server_license_terms)?; let csa = (&args).into(); - let result = serve_with_csa(paths, log, gateway_args, csa, None).await; + let result = serve_with_csa(paths, log, gateway_args, csa).await; drop(no_sleep); result @@ -242,15 +249,52 @@ fn get_connection_token(tunnel: &ActiveTunnel) -> String { async fn serve_with_csa( paths: LauncherPaths, - log: Logger, + mut log: Logger, gateway_args: TunnelServeArgs, mut csa: CodeServerArgs, - shutdown_rx: Option>, ) -> Result { + let shutdown = match gateway_args + .parent_process_id + .and_then(|p| Pid::from_str(&p).ok()) + { + Some(pid) => ShutdownRequest::create_rx([ + ShutdownRequest::CtrlC, + ShutdownRequest::ParentProcessKilled(pid), + ]), + None => ShutdownRequest::create_rx([ShutdownRequest::CtrlC]), + }; + // Intentionally read before starting the server. If the server updated and // respawn is requested, the old binary will get renamed, and then // current_exe will point to the wrong path. let current_exe = std::env::current_exe().unwrap(); + let server = loop { + if shutdown.is_open() { + return Ok(0); + } + + match acquire_singleton(paths.root().join("tunnel.lock")).await { + Ok(SingletonConnection::Client(stream)) => { + debug!(log, "starting as client to singleton"); + start_singleton_client(SingletonClientArgs { + log: log.clone(), + shutdown: shutdown.clone(), + stream, + }) + .await + } + Ok(SingletonConnection::Singleton(server)) => break server, + Err(e) => { + warning!(log, "error access singleton, retrying: {}", e); + tokio::time::sleep(Duration::from_secs(2)).await + } + } + }; + + debug!(log, "starting as new singleton"); + + let log_broadcast = BroadcastLogSink::new(); + log = log.tee(log_broadcast.clone()); let platform = spanf!(log, log.span("prereq"), PreReqChecker::new().verify())?; let auth = Auth::new(&paths, log.clone()); @@ -264,21 +308,17 @@ async fn serve_with_csa( csa.connection_token = Some(get_connection_token(&tunnel)); - let shutdown_tx = if let Some(tx) = shutdown_rx { - tx - } else if let Some(pid) = gateway_args - .parent_process_id - .and_then(|p| Pid::from_str(&p).ok()) - { - ShutdownSignal::create_rx(&[ - ShutdownSignal::CtrlC, - ShutdownSignal::ParentProcessKilled(pid), - ]) - } else { - ShutdownSignal::create_rx(&[ShutdownSignal::CtrlC]) - }; - - let mut r = crate::tunnels::serve(&log, tunnel, &paths, &csa, platform, shutdown_tx).await?; + let mut r = start_singleton_server(SingletonServerArgs { + log: log.clone(), + tunnel, + paths, + code_server_args: csa, + platform, + log_broadcast, + shutdown, + server, + }) + .await?; r.tunnel.close().await.ok(); if r.respawn { diff --git a/cli/src/json_rpc.rs b/cli/src/json_rpc.rs index 9cc5ad1ade1..68ba6fc64fb 100644 --- a/cli/src/json_rpc.rs +++ b/cli/src/json_rpc.rs @@ -5,12 +5,16 @@ use tokio::{ io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}, + pin, sync::mpsc, }; use crate::{ rpc::{self, MaybeSync, Serialization}, - util::errors::InvalidRpcDataError, + util::{ + errors::InvalidRpcDataError, + sync::{Barrier, Receivable}, + }, }; use std::io; @@ -39,34 +43,38 @@ pub fn new_json_rpc() -> rpc::RpcBuilder { } #[allow(dead_code)] -pub async fn start_json_rpc( +pub async fn start_json_rpc( dispatcher: rpc::RpcDispatcher, read: impl AsyncRead + Unpin, mut write: impl AsyncWrite + Unpin, - mut msg_rx: mpsc::UnboundedReceiver>, - mut shutdown_rx: mpsc::UnboundedReceiver, + mut msg_rx: impl Receivable>, + mut shutdown_rx: Barrier, ) -> io::Result> { let (write_tx, mut write_rx) = mpsc::unbounded_channel::>(); let mut read = BufReader::new(read); let mut read_buf = String::new(); + let shutdown_fut = shutdown_rx.wait(); + pin!(shutdown_fut); loop { tokio::select! { - r = shutdown_rx.recv() => return Ok(r), + r = &mut shutdown_fut => return Ok(r.ok()), Some(w) = write_rx.recv() => { write.write_all(&w).await?; }, - Some(w) = msg_rx.recv() => { + Some(w) = msg_rx.recv_msg() => { write.write_all(&w).await?; }, n = read.read_line(&mut read_buf) => { let r = match n { Ok(0) => return Ok(None), - Ok(n) => dispatcher.dispatch(read_buf[..n].as_bytes()), + Ok(n) => dispatcher.dispatch(read_buf[..n].as_bytes()), Err(e) => return Err(e) }; + read_buf.truncate(0); + match r { MaybeSync::Sync(Some(v)) => { write_tx.send(v).ok(); diff --git a/cli/src/lib.rs b/cli/src/lib.rs index fd0917843ba..0fe65c9d8c8 100644 --- a/cli/src/lib.rs +++ b/cli/src/lib.rs @@ -18,6 +18,8 @@ pub mod tunnels; pub mod update_service; pub mod util; -mod rpc; +mod async_pipe; mod json_rpc; mod msgpack_rpc; +mod rpc; +mod singleton; diff --git a/cli/src/log.rs b/cli/src/log.rs index a008a1ba06d..15a9e9d88cf 100644 --- a/cli/src/log.rs +++ b/cli/src/log.rs @@ -8,6 +8,7 @@ use opentelemetry::{ sdk::trace::{Tracer, TracerProvider}, trace::{SpanBuilder, Tracer as TraitTracer, TracerProvider as TracerProviderTrait}, }; +use serde::{Deserialize, Serialize}; use std::fmt; use std::{env, path::Path, sync::Arc}; use std::{ @@ -25,7 +26,7 @@ pub fn next_counter() -> u32 { } // Log level -#[derive(clap::ArgEnum, PartialEq, Eq, PartialOrd, Clone, Copy, Debug)] +#[derive(clap::ArgEnum, PartialEq, Eq, PartialOrd, Clone, Copy, Debug, Serialize, Deserialize)] pub enum Level { Trace = 0, Debug, diff --git a/cli/src/msgpack_rpc.rs b/cli/src/msgpack_rpc.rs index b00b4c11ed8..de46e738da8 100644 --- a/cli/src/msgpack_rpc.rs +++ b/cli/src/msgpack_rpc.rs @@ -5,12 +5,16 @@ use tokio::{ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader}, + pin, sync::mpsc, }; use crate::{ rpc::{self, MaybeSync, Serialization}, - util::errors::{AnyError, InvalidRpcDataError}, + util::{ + errors::{AnyError, InvalidRpcDataError}, + sync::{Barrier, Receivable}, + }, }; use std::io; @@ -35,17 +39,20 @@ pub fn new_msgpack_rpc() -> rpc::RpcBuilder { } #[allow(clippy::read_zero_byte_vec)] // false positive -pub async fn start_msgpack_rpc( +pub async fn start_msgpack_rpc( dispatcher: rpc::RpcDispatcher, read: impl AsyncRead + Unpin, mut write: impl AsyncWrite + Unpin, - mut msg_rx: mpsc::UnboundedReceiver>, - mut shutdown_rx: mpsc::UnboundedReceiver, + mut msg_rx: impl Receivable>, + mut shutdown_rx: Barrier, ) -> io::Result> { let (write_tx, mut write_rx) = mpsc::unbounded_channel::>(); let mut read = BufReader::new(read); let mut decode_buf = vec![]; + let shutdown_fut = shutdown_rx.wait(); + pin!(shutdown_fut); + loop { tokio::select! { u = read.read_u32() => { @@ -66,16 +73,16 @@ pub async fn start_msgpack_rpc( }); } }, - r = shutdown_rx.recv() => return Ok(r), + r = &mut shutdown_fut => return Ok(r.ok()), }; }, Some(m) = write_rx.recv() => { write.write_all(&m).await?; }, - Some(m) = msg_rx.recv() => { + Some(m) = msg_rx.recv_msg() => { write.write_all(&m).await?; }, - r = shutdown_rx.recv() => return Ok(r), + r = &mut shutdown_fut => return Ok(r.ok()), } write.flush().await?; diff --git a/cli/src/rpc.rs b/cli/src/rpc.rs index 2249e5b7cf1..0c48600cc32 100644 --- a/cli/src/rpc.rs +++ b/cli/src/rpc.rs @@ -204,19 +204,28 @@ pub struct RpcCaller { } impl RpcCaller { + pub fn serialize_notify(serializer: &S, method: M, params: A) -> Vec + where + S: Serialization, + M: AsRef + serde::Serialize, + A: Serialize, + { + serializer.serialize(&FullRequest { + id: None, + method, + params, + }) + } + /// Enqueues an outbound call. Returns whether the message was enqueued. pub fn notify(&self, method: M, params: A) -> bool where - M: Into, + M: AsRef + serde::Serialize, A: Serialize, { - let body = self.serializer.serialize(&FullRequest { - id: None, - method: method.into(), - params, - }); - - self.sender.send(body).is_ok() + self.sender + .send(Self::serialize_notify(&self.serializer, method, params)) + .is_ok() } /// Enqueues an outbound call, returning its result. @@ -227,7 +236,7 @@ impl RpcCaller { params: A, ) -> oneshot::Receiver> where - M: Into, + M: AsRef + serde::Serialize, A: Serialize, R: DeserializeOwned + Send + 'static, { @@ -235,7 +244,7 @@ impl RpcCaller { let id = next_message_id(); let body = self.serializer.serialize(&FullRequest { id: Some(id), - method: method.into(), + method, params, }); @@ -349,9 +358,9 @@ struct PartialIncoming { } #[derive(Serialize)] -pub struct FullRequest

{ +pub struct FullRequest, P> { pub id: Option, - pub method: String, + pub method: M, pub params: P, } diff --git a/cli/src/singleton.rs b/cli/src/singleton.rs new file mode 100644 index 00000000000..0e2c8cd844b --- /dev/null +++ b/cli/src/singleton.rs @@ -0,0 +1,177 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +use serde::{Deserialize, Serialize}; +use std::{ + fs::{File, OpenOptions}, + io::{Seek, SeekFrom, Write}, + path::{Path, PathBuf}, + time::Duration, +}; +use sysinfo::{Pid, PidExt}; + +use crate::{ + async_pipe::{ + get_socket_name, get_socket_rw_stream, listen_socket_rw_stream, AsyncPipe, + AsyncPipeListener, + }, + util::{ + errors::CodeError, + file_lock::{FileLock, Lock, PREFIX_LOCKED_BYTES}, + machine::wait_until_process_exits, + }, +}; + +pub struct SingletonServer { + server: AsyncPipeListener, + _lock: FileLock, +} + +impl SingletonServer { + pub async fn accept(&mut self) -> Result { + self.server.accept().await + } +} + +pub enum SingletonConnection { + /// This instance got the singleton lock. It started listening on a socket + /// and has the read/write pair. If this gets dropped, the lock is released. + Singleton(SingletonServer), + /// Another instance is a singleton, and this client connected to it. + Client(AsyncPipe), +} + +/// Contents of the lock file; the listening socket ID and process ID +/// doing the listening. +#[derive(Deserialize, Serialize)] +struct LockFileMatter { + socket_path: String, + pid: u32, +} + +pub async fn acquire_singleton(lock_file: PathBuf) -> Result { + let file = OpenOptions::new() + .read(true) + .write(true) + .create(true) + .open(&lock_file) + .map_err(CodeError::SingletonLockfileOpenFailed)?; + + match FileLock::acquire(file) { + Ok(Lock::AlreadyLocked(mut file)) => connect_as_client(&mut file).await, + Ok(Lock::Acquired(lock)) => start_singleton_server(lock).await, + Err(e) => Err(e), + } +} + +async fn start_singleton_server(mut lock: FileLock) -> Result { + let socket_path = get_socket_name(); + + let mut vec = Vec::with_capacity(128); + let _ = vec.write(&[0; PREFIX_LOCKED_BYTES]); + let _ = rmp_serde::encode::write( + &mut vec, + &LockFileMatter { + socket_path: socket_path.to_string_lossy().to_string(), + pid: std::process::id(), + }, + ); + + lock.file_mut() + .write_all(&vec) + .map_err(CodeError::SingletonLockfileOpenFailed)?; + + let server = listen_socket_rw_stream(&socket_path).await?; + Ok(SingletonConnection::Singleton(SingletonServer { + server, + _lock: lock, + })) +} + +const MAX_CLIENT_ATTEMPTS: i32 = 10; + +async fn connect_as_client(mut file: &mut File) -> Result { + // retry, since someone else could get a lock and we could read it before + // the JSON info was finished writing out + let mut attempt = 0; + loop { + let _ = file.seek(SeekFrom::Start(PREFIX_LOCKED_BYTES as u64)); + let r = match rmp_serde::from_read::<_, LockFileMatter>(&mut file) { + Ok(prev) => { + let socket_path = PathBuf::from(prev.socket_path); + + tokio::select! { + p = retry_get_socket_rw_stream(&socket_path, 5, Duration::from_millis(500)) => p, + _ = wait_until_process_exits(Pid::from_u32(prev.pid), 500) => Err(CodeError::SingletonLockedProcessExited(prev.pid)), + } + } + Err(e) => Err(CodeError::SingletonLockfileReadFailed(e)), + }; + + if r.is_ok() || attempt == MAX_CLIENT_ATTEMPTS { + return r.map(SingletonConnection::Client); + } + + attempt += 1; + tokio::time::sleep(Duration::from_millis(500)).await; + } +} + +async fn retry_get_socket_rw_stream( + path: &Path, + max_tries: usize, + interval: Duration, +) -> Result { + for i in 0.. { + match get_socket_rw_stream(path).await { + Ok(s) => return Ok(s), + Err(e) if i == max_tries => return Err(e), + Err(_) => tokio::time::sleep(interval).await, + } + } + + unreachable!() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_acquires_singleton() { + let dir = tempfile::tempdir().expect("expected to make temp dir"); + let s = acquire_singleton(dir.path().join("lock")) + .await + .expect("expected to acquire"); + + match s { + SingletonConnection::Singleton(_) => {} + _ => panic!("expected to be singleton"), + } + } + + #[tokio::test] + async fn test_acquires_client() { + let dir = tempfile::tempdir().expect("expected to make temp dir"); + let lockfile = dir.path().join("lock"); + let s1 = acquire_singleton(lockfile.clone()) + .await + .expect("expected to acquire1"); + match s1 { + SingletonConnection::Singleton(mut l) => tokio::spawn(async move { + l.accept().await.expect("expected to accept"); + }), + _ => panic!("expected to be singleton"), + }; + + let s2 = acquire_singleton(lockfile) + .await + .expect("expected to acquire2"); + match s2 { + SingletonConnection::Client(_) => {} + _ => panic!("expected to be client"), + } + } +} diff --git a/cli/src/tunnels.rs b/cli/src/tunnels.rs index 73c5ffeb907..2ab7435245a 100644 --- a/cli/src/tunnels.rs +++ b/cli/src/tunnels.rs @@ -8,6 +8,8 @@ pub mod dev_tunnels; pub mod legal; pub mod paths; pub mod shutdown_signal; +pub mod singleton_client; +pub mod singleton_server; mod control_server; mod nosleep; @@ -19,8 +21,6 @@ mod nosleep_macos; mod nosleep_windows; mod port_forwarder; mod protocol; -#[cfg_attr(unix, path = "tunnels/server_bridge_unix.rs")] -#[cfg_attr(windows, path = "tunnels/server_bridge_windows.rs")] mod server_bridge; mod server_multiplexer; mod service; diff --git a/cli/src/tunnels/code_server.rs b/cli/src/tunnels/code_server.rs index 357f04750ba..30e2f2780a8 100644 --- a/cli/src/tunnels/code_server.rs +++ b/cli/src/tunnels/code_server.rs @@ -3,6 +3,7 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ use super::paths::{InstalledServer, LastUsedServers, ServerPaths}; +use crate::async_pipe::get_socket_name; use crate::constants::{APPLICATION_NAME, QUALITYLESS_PRODUCT_NAME, QUALITYLESS_SERVER_NAME}; use crate::options::{Quality, TelemetryLevel}; use crate::state::LauncherPaths; @@ -32,7 +33,6 @@ use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::{Child, Command}; use tokio::sync::oneshot::Receiver; use tokio::time::{interval, timeout}; -use uuid::Uuid; lazy_static! { static ref LISTENING_PORT_RE: Regex = @@ -539,12 +539,7 @@ impl<'a, Http: SimpleHttp + Send + Sync + Clone + 'static> ServerBuilder<'a, Htt } pub async fn listen_on_default_socket(&self) -> Result { - let requested_file = if cfg!(target_os = "windows") { - PathBuf::from(format!(r"\\.\pipe\vscode-server-{}", Uuid::new_v4())) - } else { - std::env::temp_dir().join(format!("vscode-server-{}", Uuid::new_v4())) - }; - + let requested_file = get_socket_name(); self.listen_on_socket(&requested_file).await } diff --git a/cli/src/tunnels/control_server.rs b/cli/src/tunnels/control_server.rs index aec9309ace8..945fba1a366 100644 --- a/cli/src/tunnels/control_server.rs +++ b/cli/src/tunnels/control_server.rs @@ -2,6 +2,7 @@ * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ +use crate::async_pipe::get_socket_rw_stream; use crate::constants::{CONTROL_PORT, EDITOR_WEB_URL, QUALITYLESS_SERVER_NAME}; use crate::log; use crate::rpc::{MaybeSync, RpcBuilder, RpcDispatcher, Serialization}; @@ -30,7 +31,6 @@ use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Instant; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader}; -use tokio::pin; use tokio::sync::{mpsc, Mutex}; use super::code_server::{ @@ -45,7 +45,7 @@ use super::protocol::{ ServerMessageParams, ToClientRequest, UnforwardParams, UpdateParams, UpdateResult, VersionParams, }; -use super::server_bridge::{get_socket_rw_stream, ServerBridge}; +use super::server_bridge::ServerBridge; use super::server_multiplexer::ServerMultiplexer; use super::shutdown_signal::ShutdownSignal; use super::socket_signal::{ @@ -155,7 +155,7 @@ pub async fn serve( launcher_paths: &LauncherPaths, code_server_args: &CodeServerArgs, platform: Platform, - shutdown_rx: mpsc::UnboundedReceiver, + mut shutdown_rx: Barrier, ) -> Result { let mut port = tunnel.add_port_direct(CONTROL_PORT).await?; print_listening(log, &tunnel.name); @@ -164,12 +164,10 @@ pub async fn serve( let (tx, mut rx) = mpsc::channel::(4); let (exit_barrier, signal_exit) = new_barrier(); - pin!(shutdown_rx); - loop { tokio::select! { - Some(r) = shutdown_rx.recv() => { - info!(log, "Shutting down: {}", r ); + Ok(r) = shutdown_rx.wait() => { + info!(log, "Shutting down: {}", r); drop(signal_exit); return Ok(ServerTermination { respawn: false, diff --git a/cli/src/tunnels/legal.rs b/cli/src/tunnels/legal.rs index 1e3d7b1bac3..84b72bf8e69 100644 --- a/cli/src/tunnels/legal.rs +++ b/cli/src/tunnels/legal.rs @@ -41,7 +41,10 @@ pub fn require_consent( if accept_server_license_terms { load.consented = Some(true); } else if !*IS_INTERACTIVE_CLI { - return Err(MissingLegalConsent("Run this command again with --accept-server-license-terms to indicate your agreement.".to_string()) + return Err(MissingLegalConsent( + "Run this command again with --accept-server-license-terms to indicate your agreement." + .to_string(), + ) .into()); } else { match prompt_yn(prompt) { diff --git a/cli/src/tunnels/protocol.rs b/cli/src/tunnels/protocol.rs index f2d6dd65dae..846ee904166 100644 --- a/cli/src/tunnels/protocol.rs +++ b/cli/src/tunnels/protocol.rs @@ -4,7 +4,10 @@ *--------------------------------------------------------------------------------------------*/ use std::collections::HashMap; -use crate::{constants::{VSCODE_CLI_VERSION, PROTOCOL_VERSION}, options::Quality}; +use crate::{ + constants::{PROTOCOL_VERSION, VSCODE_CLI_VERSION}, + options::Quality, +}; use serde::{Deserialize, Serialize}; #[derive(Serialize, Debug)] @@ -154,3 +157,22 @@ impl Default for VersionParams { } } } + +pub mod singleton { + use crate::log; + use serde::{Deserialize, Serialize}; + + #[derive(Serialize)] + pub struct LogMessage<'a> { + pub level: log::Level, + pub prefix: &'a str, + pub message: &'a str, + } + + #[derive(Deserialize)] + pub struct LogMessageOwned { + pub level: log::Level, + pub prefix: String, + pub message: String, + } +} diff --git a/cli/src/tunnels/server_bridge_unix.rs b/cli/src/tunnels/server_bridge.rs similarity index 75% rename from cli/src/tunnels/server_bridge_unix.rs rename to cli/src/tunnels/server_bridge.rs index c7be34cf5d0..50dde8e7303 100644 --- a/cli/src/tunnels/server_bridge_unix.rs +++ b/cli/src/tunnels/server_bridge.rs @@ -2,36 +2,19 @@ * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -use std::path::Path; - -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::{unix::OwnedWriteHalf, UnixStream}, -}; - -use crate::util::errors::{wrap, AnyError}; - use super::socket_signal::{ClientMessageDecoder, ServerMessageSink}; +use crate::{ + async_pipe::{get_socket_rw_stream, socket_stream_split, AsyncPipeWriteHalf}, + util::errors::AnyError, +}; +use std::path::Path; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; pub struct ServerBridge { - write: OwnedWriteHalf, + write: AsyncPipeWriteHalf, decoder: ClientMessageDecoder, } -pub async fn get_socket_rw_stream(path: &Path) -> Result { - let s = UnixStream::connect(path).await.map_err(|e| { - wrap( - e, - format!( - "error connecting to vscode server socket in {}", - path.display() - ), - ) - })?; - - Ok(s) -} - const BUFFER_SIZE: usize = 65536; impl ServerBridge { @@ -41,7 +24,7 @@ impl ServerBridge { decoder: ClientMessageDecoder, ) -> Result { let stream = get_socket_rw_stream(path).await?; - let (mut read, write) = stream.into_split(); + let (mut read, write) = socket_stream_split(stream); tokio::spawn(async move { let mut read_buf = vec![0; BUFFER_SIZE]; diff --git a/cli/src/tunnels/server_bridge_windows.rs b/cli/src/tunnels/server_bridge_windows.rs deleted file mode 100644 index ca604468518..00000000000 --- a/cli/src/tunnels/server_bridge_windows.rs +++ /dev/null @@ -1,132 +0,0 @@ -/*--------------------------------------------------------------------------------------------- - * Copyright (c) Microsoft Corporation. All rights reserved. - * Licensed under the MIT License. See License.txt in the project root for license information. - *--------------------------------------------------------------------------------------------*/ - -use std::{path::Path, time::Duration}; - -use tokio::{ - io::{self, Interest}, - net::windows::named_pipe::{ClientOptions, NamedPipeClient}, - sync::mpsc, - time::sleep, -}; - -use crate::util::errors::{wrap, AnyError}; - -use super::socket_signal::{ClientMessageDecoder, ServerMessageSink}; - -pub struct ServerBridge { - write_tx: mpsc::Sender>, - decoder: ClientMessageDecoder, -} - -const BUFFER_SIZE: usize = 65536; - -pub async fn get_socket_rw_stream(path: &Path) -> Result { - // Tokio says we can need to try in a loop. Do so. - // https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.NamedPipeClient.html - let client = loop { - match ClientOptions::new().open(path) { - Ok(client) => break client, - // ERROR_PIPE_BUSY https://docs.microsoft.com/en-us/windows/win32/debug/system-error-codes--0-499- - Err(e) if e.raw_os_error() == Some(231) => sleep(Duration::from_millis(100)).await, - Err(e) => { - return Err(AnyError::WrappedError(wrap( - e, - format!( - "error connecting to vscode server socket in {}", - path.display() - ), - ))) - } - } - }; - - Ok(client) -} - -impl ServerBridge { - pub async fn new( - path: &Path, - mut target: ServerMessageSink, - decoder: ClientMessageDecoder, - ) -> Result { - let client = get_socket_rw_stream(path).await?; - let (write_tx, mut write_rx) = mpsc::channel(4); - tokio::spawn(async move { - let mut read_buf = vec![0; BUFFER_SIZE]; - let mut pending_recv: Option> = None; - - // See https://docs.rs/tokio/1.17.0/tokio/net/windows/named_pipe/struct.NamedPipeClient.html#method.ready - // With additional complications. If there's nothing queued to write, we wait for the - // pipe to be readable, or for something to come in. If there is something to - // write, wait until the pipe is either readable or writable. - loop { - let ready_result = if pending_recv.is_none() { - tokio::select! { - msg = write_rx.recv() => match msg { - Some(msg) => { - pending_recv = Some(msg); - client.ready(Interest::READABLE | Interest::WRITABLE).await - }, - None => return - }, - r = client.ready(Interest::READABLE) => r, - } - } else { - client.ready(Interest::READABLE | Interest::WRITABLE).await - }; - - let ready = match ready_result { - Ok(r) => r, - Err(_) => return, - }; - - if ready.is_readable() { - match client.try_read(&mut read_buf) { - Ok(0) => return, // EOF - Ok(s) => { - let send = target.server_message(&read_buf[..s]).await; - if send.is_err() { - return; - } - } - Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { - continue; - } - Err(_) => return, - } - } - - if let Some(msg) = &pending_recv { - if ready.is_writable() { - match client.try_write(msg) { - Ok(n) if n == msg.len() => pending_recv = None, - Ok(n) => pending_recv = Some(msg[n..].to_vec()), - Err(e) if e.kind() == io::ErrorKind::WouldBlock => { - continue; - } - Err(_) => return, - } - } - } - } - }); - - Ok(ServerBridge { write_tx, decoder }) - } - - pub async fn write(&mut self, b: Vec) -> std::io::Result<()> { - let dec = self.decoder.decode(&b)?; - if !dec.is_empty() { - self.write_tx.send(dec.to_vec()).await.ok(); - } - Ok(()) - } - - pub async fn close(self) -> std::io::Result<()> { - drop(self.write_tx); - Ok(()) - } -} diff --git a/cli/src/tunnels/service.rs b/cli/src/tunnels/service.rs index a0a2129bef9..31bf6890996 100644 --- a/cli/src/tunnels/service.rs +++ b/cli/src/tunnels/service.rs @@ -6,15 +6,12 @@ use std::path::{Path, PathBuf}; use async_trait::async_trait; -use tokio::sync::mpsc; use crate::log; use crate::state::LauncherPaths; use crate::util::errors::{wrap, AnyError}; use crate::util::io::{tailf, TailEvent}; -use super::shutdown_signal::ShutdownSignal; - pub const SERVICE_LOG_FILE_NAME: &str = "tunnel-service.log"; #[async_trait] @@ -23,7 +20,6 @@ pub trait ServiceContainer: Send { &mut self, log: log::Logger, launcher_paths: LauncherPaths, - shutdown_rx: mpsc::UnboundedReceiver, ) -> Result<(), AnyError>; } diff --git a/cli/src/tunnels/service_linux.rs b/cli/src/tunnels/service_linux.rs index 022d4cee409..725b72a8d6d 100644 --- a/cli/src/tunnels/service_linux.rs +++ b/cli/src/tunnels/service_linux.rs @@ -10,7 +10,6 @@ use std::{ process::Command, }; -use super::shutdown_signal::ShutdownSignal; use async_trait::async_trait; use zbus::{dbus_proxy, zvariant, Connection}; @@ -119,8 +118,7 @@ impl ServiceManager for SystemdService { launcher_paths: crate::state::LauncherPaths, mut handle: impl 'static + super::ServiceContainer, ) -> Result<(), crate::util::errors::AnyError> { - let rx = ShutdownSignal::create_rx(&[ShutdownSignal::CtrlC]); - handle.run_service(self.log, launcher_paths, rx).await + handle.run_service(self.log, launcher_paths).await } async fn show_logs(&self) -> Result<(), AnyError> { diff --git a/cli/src/tunnels/service_macos.rs b/cli/src/tunnels/service_macos.rs index 6c83af0f039..7344f34b0ac 100644 --- a/cli/src/tunnels/service_macos.rs +++ b/cli/src/tunnels/service_macos.rs @@ -9,7 +9,6 @@ use std::{ path::{Path, PathBuf}, }; -use super::shutdown_signal::ShutdownSignal; use async_trait::async_trait; use crate::{ @@ -73,8 +72,7 @@ impl ServiceManager for LaunchdService { launcher_paths: crate::state::LauncherPaths, mut handle: impl 'static + super::ServiceContainer, ) -> Result<(), crate::util::errors::AnyError> { - let rx = ShutdownSignal::create_rx(&[ShutdownSignal::CtrlC]); - handle.run_service(self.log, launcher_paths, rx).await + handle.run_service(self.log, launcher_paths).await } async fn unregister(&self) -> Result<(), crate::util::errors::AnyError> { diff --git a/cli/src/tunnels/service_windows.rs b/cli/src/tunnels/service_windows.rs index 7d839ccf330..d230d2e454f 100644 --- a/cli/src/tunnels/service_windows.rs +++ b/cli/src/tunnels/service_windows.rs @@ -17,7 +17,6 @@ use crate::{ constants::TUNNEL_ACTIVITY_NAME, log, state::LauncherPaths, - tunnels::shutdown_signal::ShutdownSignal, util::errors::{wrap, wrapdbg, AnyError}, }; @@ -90,8 +89,7 @@ impl CliServiceManager for WindowsService { launcher_paths: LauncherPaths, mut handle: impl 'static + ServiceContainer, ) -> Result<(), AnyError> { - let rx = ShutdownSignal::create_rx(&[ShutdownSignal::CtrlC]); - handle.run_service(self.log, launcher_paths, rx).await + handle.run_service(self.log, launcher_paths).await } async fn unregister(&self) -> Result<(), AnyError> { diff --git a/cli/src/tunnels/shutdown_signal.rs b/cli/src/tunnels/shutdown_signal.rs index 9e185770058..f42e0acc19b 100644 --- a/cli/src/tunnels/shutdown_signal.rs +++ b/cli/src/tunnels/shutdown_signal.rs @@ -3,16 +3,21 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -use std::{fmt, time::Duration}; +use std::fmt; +use sysinfo::Pid; -use sysinfo::{Pid, SystemExt}; -use tokio::{sync::mpsc, time::sleep}; +use crate::util::{ + machine::wait_until_process_exits, + sync::{new_barrier, Barrier}, +}; /// Describes the signal to manully stop the server +#[derive(Copy, Clone)] pub enum ShutdownSignal { CtrlC, ParentProcessKilled(Pid), ServiceStopped, + RpcShutdownRequested, } impl fmt::Display for ShutdownSignal { @@ -23,41 +28,57 @@ impl fmt::Display for ShutdownSignal { write!(f, "Parent process {} no longer exists", p) } ShutdownSignal::ServiceStopped => write!(f, "Service stopped"), + ShutdownSignal::RpcShutdownRequested => write!(f, "RPC client requested shutdown"), } } } -impl ShutdownSignal { +pub enum ShutdownRequest { + CtrlC, + ParentProcessKilled(Pid), + RpcShutdownRequested(Barrier<()>), + Derived(Barrier), +} + +impl ShutdownRequest { /// Creates a receiver channel sent to once any of the signals are received. /// Note: does not handle ServiceStopped - pub fn create_rx(signals: &[ShutdownSignal]) -> mpsc::UnboundedReceiver { - let (tx, rx) = mpsc::unbounded_channel(); - for signal in signals { - let tx = tx.clone(); + pub fn create_rx( + signals: impl IntoIterator, + ) -> Barrier { + let (barrier, opener) = new_barrier(); + for signal in signals.into_iter() { + let opener = opener.clone(); match signal { - ShutdownSignal::CtrlC => { + ShutdownRequest::CtrlC => { let ctrl_c = tokio::signal::ctrl_c(); tokio::spawn(async move { ctrl_c.await.ok(); - tx.send(ShutdownSignal::CtrlC).ok(); + opener.open(ShutdownSignal::CtrlC) }); } - ShutdownSignal::ParentProcessKilled(pid) => { - let pid = *pid; - let tx = tx.clone(); + ShutdownRequest::ParentProcessKilled(pid) => { tokio::spawn(async move { - let mut s = sysinfo::System::new(); - while s.refresh_process(pid) { - sleep(Duration::from_millis(2000)).await; - } - tx.send(ShutdownSignal::ParentProcessKilled(pid)).ok(); + wait_until_process_exits(pid, 2000).await; + opener.open(ShutdownSignal::ParentProcessKilled(pid)) }); } - ShutdownSignal::ServiceStopped => { - unreachable!("Cannot use ServiceStopped in ShutdownSignal::create_rx"); + ShutdownRequest::RpcShutdownRequested(mut rx) => { + tokio::spawn(async move { + let _ = rx.wait().await; + opener.open(ShutdownSignal::RpcShutdownRequested) + }); + } + ShutdownRequest::Derived(mut rx) => { + tokio::spawn(async move { + if let Ok(s) = rx.wait().await { + opener.open(s); + } + }); } } } - rx + + barrier } } diff --git a/cli/src/tunnels/singleton_client.rs b/cli/src/tunnels/singleton_client.rs new file mode 100644 index 00000000000..5ae208dc4a8 --- /dev/null +++ b/cli/src/tunnels/singleton_client.rs @@ -0,0 +1,45 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +use crate::{ + async_pipe::{socket_stream_split, AsyncPipe}, + json_rpc::{new_json_rpc, start_json_rpc}, + log, + util::sync::Barrier, +}; + +use super::{protocol, shutdown_signal::ShutdownSignal}; + +pub struct SingletonClientArgs { + pub log: log::Logger, + pub stream: AsyncPipe, + pub shutdown: Barrier, +} + +struct SingletonServerContext { + log: log::Logger, +} + +pub async fn start_singleton_client(args: SingletonClientArgs) { + let rpc = new_json_rpc(); + + debug!( + args.log, + "An existing tunnel is running on this machine, connecting to it..." + ); + + let mut rpc = rpc.methods(SingletonServerContext { + log: args.log.clone(), + }); + + rpc.register_sync("log", |log: protocol::singleton::LogMessageOwned, c| { + c.log + .emit(log.level, &format!("{}: {}", log.prefix, log.message)); + Ok(()) + }); + + let (read, write) = socket_stream_split(args.stream); + let _ = start_json_rpc(rpc.build(args.log), read, write, (), args.shutdown.clone()).await; +} diff --git a/cli/src/tunnels/singleton_server.rs b/cli/src/tunnels/singleton_server.rs new file mode 100644 index 00000000000..73660b32ff4 --- /dev/null +++ b/cli/src/tunnels/singleton_server.rs @@ -0,0 +1,185 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +use std::sync::{Arc, Mutex}; + +use super::{ + code_server::CodeServerArgs, + control_server::ServerTermination, + dev_tunnels::ActiveTunnel, + protocol, + shutdown_signal::{ShutdownRequest, ShutdownSignal}, +}; +use crate::{ + async_pipe::socket_stream_split, + json_rpc::{new_json_rpc, start_json_rpc, JsonRpcSerializer}, + log, + rpc::{RpcCaller, RpcDispatcher}, + singleton::SingletonServer, + state::LauncherPaths, + update_service::Platform, + util::{ + errors::{AnyError, CodeError}, + ring_buffer::RingBuffer, + sync::{new_barrier, Barrier, BarrierOpener, ConcatReceivable}, + }, +}; +use tokio::{ + pin, + sync::{broadcast, mpsc}, +}; + +pub struct SingletonServerArgs { + pub log: log::Logger, + pub tunnel: ActiveTunnel, + pub paths: LauncherPaths, + pub code_server_args: CodeServerArgs, + pub platform: Platform, + pub server: SingletonServer, + pub shutdown: Barrier, + pub log_broadcast: BroadcastLogSink, +} + +#[derive(Clone)] +struct SingletonServerContext { + shutdown: BarrierOpener<()>, +} + +pub async fn start_singleton_server( + mut args: SingletonServerArgs, +) -> Result { + let (shutdown_rx, shutdown_tx) = new_barrier(); + let shutdown_rx = ShutdownRequest::create_rx([ + ShutdownRequest::RpcShutdownRequested(shutdown_rx), + ShutdownRequest::Derived(args.shutdown), + ]); + + let rpc = new_json_rpc(); + + let mut rpc = rpc.methods(SingletonServerContext { + shutdown: shutdown_tx, + }); + + rpc.register_sync("shutdown", |_: protocol::EmptyObject, ctx| { + ctx.shutdown.open(()); + Ok(()) + }); + + let (r1, r2) = tokio::join!( + serve_singleton_rpc( + args.log_broadcast, + &mut args.server, + rpc.build(args.log.clone()), + shutdown_rx.clone(), + ), + super::serve( + &args.log, + args.tunnel, + &args.paths, + &args.code_server_args, + args.platform, + shutdown_rx, + ), + ); + + r1?; + r2 +} + +async fn serve_singleton_rpc( + log_broadcast: BroadcastLogSink, + server: &mut SingletonServer, + dispatcher: RpcDispatcher, + shutdown_rx: Barrier, +) -> Result<(), CodeError> { + let mut own_shutdown = shutdown_rx.clone(); + let shutdown_fut = own_shutdown.wait(); + pin!(shutdown_fut); + + loop { + let cnx = tokio::select! { + c = server.accept() => c?, + _ = &mut shutdown_fut => return Ok(()), + }; + + let (read, write) = socket_stream_split(cnx); + let dispatcher = dispatcher.clone(); + let msg_rx = log_broadcast.replay_and_subscribe(); + let shutdown_rx = shutdown_rx.clone(); + tokio::spawn(async move { + let _ = start_json_rpc(dispatcher.clone(), read, write, msg_rx, shutdown_rx).await; + }); + } +} + +/// Log sink that can broadcast and replay log events. Used for transmitting +/// logs from the singleton to all clients. This should be created and injected +/// into other services, like the tunnel, before `start_singleton_server` +/// is called. +#[derive(Clone)] +pub struct BroadcastLogSink { + recent: Arc>>>, + tx: broadcast::Sender>, +} + +impl Default for BroadcastLogSink { + fn default() -> Self { + Self::new() + } +} + +impl BroadcastLogSink { + pub fn new() -> Self { + let (tx, _) = broadcast::channel(64); + Self { + tx, + recent: Arc::new(Mutex::new(RingBuffer::new(50))), + } + } + + fn replay_and_subscribe( + &self, + ) -> ConcatReceivable, mpsc::UnboundedReceiver>, broadcast::Receiver>> { + let (log_replay_tx, log_replay_rx) = mpsc::unbounded_channel(); + + for log in self.recent.lock().unwrap().iter() { + let _ = log_replay_tx.send(log.clone()); + } + + let _ = log_replay_tx.send(RpcCaller::serialize_notify( + &JsonRpcSerializer {}, + "log", + protocol::singleton::LogMessage { + level: log::Level::Info, + prefix: "", + message: "Connected to an existing tunnel process running on this machined.", + }, + )); + + ConcatReceivable::new(log_replay_rx, self.tx.subscribe()) + } +} + +impl log::LogSink for BroadcastLogSink { + fn write_log(&self, level: log::Level, prefix: &str, message: &str) { + let s = JsonRpcSerializer {}; + let serialized = RpcCaller::serialize_notify( + &s, + "log", + protocol::singleton::LogMessage { + level, + prefix, + message, + }, + ); + + let _ = self.tx.send(serialized.clone()); + self.recent.lock().unwrap().push(serialized); + } + + fn write_result(&self, message: &str) { + self.write_log(log::Level::Info, "", message); + } +} diff --git a/cli/src/tunnels/wsl_server.rs b/cli/src/tunnels/wsl_server.rs index 8859cc7f53e..b6250c8247f 100644 --- a/cli/src/tunnels/wsl_server.rs +++ b/cli/src/tunnels/wsl_server.rs @@ -16,6 +16,7 @@ use crate::{ wrap, AnyError, InvalidRpcDataError, MismatchedLaunchModeError, NoAttachedServerError, }, http::ReqwestSimpleHttp, + sync::Barrier, }, }; @@ -69,7 +70,7 @@ pub async fn serve_wsl( code_server_args: CodeServerArgs, platform: Platform, http: reqwest::Client, - shutdown_rx: mpsc::UnboundedReceiver, + shutdown_rx: Barrier, ) -> Result { let (caller_tx, caller_rx) = mpsc::unbounded_channel(); let mut rpc = new_msgpack_rpc(); diff --git a/cli/src/util.rs b/cli/src/util.rs index 2ed47f2f263..48eb634eb8a 100644 --- a/cli/src/util.rs +++ b/cli/src/util.rs @@ -12,8 +12,10 @@ pub mod input; pub mod io; pub mod machine; pub mod prereqs; +pub mod ring_buffer; pub mod sync; pub use is_integrated::*; +pub mod file_lock; #[cfg(target_os = "linux")] pub mod tar; diff --git a/cli/src/util/errors.rs b/cli/src/util/errors.rs index fa5d67db300..fea4ca25389 100644 --- a/cli/src/util/errors.rs +++ b/cli/src/util/errors.rs @@ -2,11 +2,11 @@ * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -use std::fmt::Display; - use crate::constants::{ APPLICATION_NAME, CONTROL_PORT, DOCUMENTATION_URL, QUALITYLESS_PRODUCT_NAME, }; +use std::fmt::Display; +use thiserror::Error; // Wraps another error with additional info. #[derive(Debug, Clone)] @@ -475,6 +475,22 @@ macro_rules! makeAnyError { }; } +/// Internal errors in the VS Code CLI. +/// Note: other error should be migrated to this type gradually +#[derive(Error, Debug)] +pub enum CodeError { + #[error("could not connect to socket/pipe")] + AsyncPipeFailed(std::io::Error), + #[error("could not listen on socket/pipe")] + AsyncPipeListenerFailed(std::io::Error), + #[error("could not create singleton lock file")] + SingletonLockfileOpenFailed(std::io::Error), + #[error("could not read singleton lock file")] + SingletonLockfileReadFailed(rmp_serde::decode::Error), + #[error("the process holding the singleton lock file exited")] + SingletonLockedProcessExited(u32), +} + makeAnyError!( MissingLegalConsent, MismatchConnectionToken, @@ -505,7 +521,8 @@ makeAnyError!( MissingHomeDirectory, CommandFailed, OAuthError, - InvalidRpcDataError + InvalidRpcDataError, + CodeError ); impl From for AnyError { diff --git a/cli/src/util/file_lock.rs b/cli/src/util/file_lock.rs new file mode 100644 index 00000000000..8ee60cba4f8 --- /dev/null +++ b/cli/src/util/file_lock.rs @@ -0,0 +1,125 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +use crate::util::errors::CodeError; +use std::{fs::File, io}; + +pub struct FileLock { + file: File, + #[cfg(windows)] + overlapped: winapi::um::minwinbase::OVERLAPPED, +} + +#[cfg(windows)] // overlapped is thread-safe, mark it so with this +unsafe impl Send for FileLock {} + +pub enum Lock { + Acquired(FileLock), + AlreadyLocked(File), +} + +/// Number of locked bytes in the file. On Windows, locking prevents reads, +/// but consumers of the lock may still want to read what the locking file +/// as written. Thus, only PREFIX_LOCKED_BYTES are locked, and any globally- +/// readable content should be written after the prefix. +#[cfg(windows)] +pub const PREFIX_LOCKED_BYTES: usize = 1; + +#[cfg(unix)] +pub const PREFIX_LOCKED_BYTES: usize = 0; + +impl FileLock { + #[cfg(windows)] + pub fn acquire(file: File) -> Result { + use std::os::windows::prelude::AsRawHandle; + use winapi::{ + shared::winerror::{ERROR_IO_PENDING, ERROR_LOCK_VIOLATION}, + um::{ + fileapi::LockFileEx, + minwinbase::{LOCKFILE_EXCLUSIVE_LOCK, LOCKFILE_FAIL_IMMEDIATELY}, + }, + }; + + let handle = file.as_raw_handle(); + let (overlapped, ok) = unsafe { + let mut overlapped = std::mem::zeroed(); + let ok = LockFileEx( + handle, + LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY, + 0, + PREFIX_LOCKED_BYTES as u32, + 0, + &mut overlapped, + ); + + (overlapped, ok) + }; + + if ok != 0 { + return Ok(Lock::Acquired(Self { file, overlapped })); + } + + let err = io::Error::last_os_error(); + let raw = err.raw_os_error(); + // docs report it should return ERROR_IO_PENDING, but in my testing it actually + // returns ERROR_LOCK_VIOLATION. Or maybe winapi is wrong? + if raw == Some(ERROR_IO_PENDING as i32) || raw == Some(ERROR_LOCK_VIOLATION as i32) { + return Ok(Lock::AlreadyLocked(file)); + } + + Err(CodeError::SingletonLockfileOpenFailed(err)) + } + + #[cfg(unix)] + pub fn acquire(file: File) -> Result { + use std::os::unix::io::AsRawFd; + + let fd = file.as_raw_fd(); + let res = unsafe { libc::flock(fd, libc::LOCK_EX | libc::LOCK_NB) }; + if res == 0 { + return Ok(Lock::Acquired(Self { file })); + } + + let err = io::Error::last_os_error(); + if err.kind() == io::ErrorKind::WouldBlock { + return Ok(Lock::AlreadyLocked(file)); + } + + Err(CodeError::SingletonLockfileOpenFailed(err)) + } + + pub fn file(&self) -> &File { + &self.file + } + + pub fn file_mut(&mut self) -> &mut File { + &mut self.file + } +} + +impl Drop for FileLock { + #[cfg(windows)] + fn drop(&mut self) { + use std::os::windows::prelude::AsRawHandle; + use winapi::um::fileapi::UnlockFileEx; + + unsafe { + UnlockFileEx( + self.file.as_raw_handle(), + 0, + u32::MAX, + u32::MAX, + &mut self.overlapped, + ) + }; + } + + #[cfg(unix)] + fn drop(&mut self) { + use std::os::unix::io::AsRawFd; + + unsafe { libc::flock(self.file.as_raw_fd(), libc::LOCK_UN) }; + } +} diff --git a/cli/src/util/io.rs b/cli/src/util/io.rs index a21a2ceb632..95b378c0c65 100644 --- a/cli/src/util/io.rs +++ b/cli/src/util/io.rs @@ -15,6 +15,8 @@ use tokio::{ time::sleep, }; +use super::ring_buffer::RingBuffer; + pub trait ReportCopyProgress { fn report_progress(&mut self, bytes_so_far: u64, total_bytes: u64); } @@ -132,8 +134,7 @@ pub fn tailf(file: File, n: usize) -> mpsc::UnboundedReceiver { // Read the initial "n" lines back from the request. initial_lines // is a small ring buffer. - let mut initial_lines = Vec::with_capacity(n); - let mut initial_lines_i = 0; + let mut initial_lines = RingBuffer::new(n); loop { let mut line = String::new(); let bytes_read = match reader.read_line(&mut line) { @@ -151,26 +152,11 @@ pub fn tailf(file: File, n: usize) -> mpsc::UnboundedReceiver { } pos += bytes_read as u64; - if initial_lines.len() < initial_lines.capacity() { - initial_lines.push(line) - } else { - initial_lines[initial_lines_i] = line; - } - - initial_lines_i = (initial_lines_i + 1) % n; + initial_lines.push(line); } - // remove tail lines... - if initial_lines_i < initial_lines.len() { - for line in initial_lines.drain((initial_lines_i)..) { - tx.send(TailEvent::Line(line)).ok(); - } - } - // then the remaining lines - if !initial_lines.is_empty() { - for line in initial_lines.drain(0..) { - tx.send(TailEvent::Line(line)).ok(); - } + for line in initial_lines.into_iter() { + tx.send(TailEvent::Line(line)).ok(); } // now spawn the poll process to keep reading new lines diff --git a/cli/src/util/machine.rs b/cli/src/util/machine.rs index c3e0e2bfb98..e97a043a637 100644 --- a/cli/src/util/machine.rs +++ b/cli/src/util/machine.rs @@ -3,7 +3,7 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -use std::path::Path; +use std::{path::Path, time::Duration}; use sysinfo::{Pid, PidExt, ProcessExt, System, SystemExt}; pub fn process_at_path_exists(pid: u32, name: &Path) -> bool { @@ -29,6 +29,14 @@ pub fn process_exists(pid: u32) -> bool { sys.refresh_process(Pid::from_u32(pid)) } +pub async fn wait_until_process_exits(pid: Pid, poll_ms: u64) { + let mut s = System::new(); + let duration = Duration::from_millis(poll_ms); + while s.refresh_process(pid) { + tokio::time::sleep(duration).await; + } +} + pub fn find_running_process(name: &Path) -> Option { let mut sys = System::new(); sys.refresh_processes(); diff --git a/cli/src/util/ring_buffer.rs b/cli/src/util/ring_buffer.rs new file mode 100644 index 00000000000..3dfb8c587d9 --- /dev/null +++ b/cli/src/util/ring_buffer.rs @@ -0,0 +1,142 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +pub struct RingBuffer { + data: Vec, + i: usize, +} + +impl RingBuffer { + pub fn new(capacity: usize) -> Self { + Self { + data: Vec::with_capacity(capacity), + i: 0, + } + } + + pub fn capacity(&self) -> usize { + self.data.capacity() + } + + pub fn len(&self) -> usize { + self.data.len() + } + + pub fn is_full(&self) -> bool { + self.data.len() == self.data.capacity() + } + + pub fn is_empty(&self) -> bool { + self.data.len() == 0 + } + + pub fn push(&mut self, value: T) { + if self.data.len() == self.data.capacity() { + self.data[self.i] = value; + } else { + self.data.push(value); + } + + self.i = (self.i + 1) % self.data.capacity(); + } + + pub fn iter(&self) -> RingBufferIter<'_, T> { + RingBufferIter { + index: 0, + buffer: self, + } + } +} + +impl IntoIterator for RingBuffer { + type Item = T; + type IntoIter = OwnedRingBufferIter; + + fn into_iter(self) -> OwnedRingBufferIter + where + T: Default, + { + OwnedRingBufferIter { + index: 0, + buffer: self, + } + } +} + +pub struct OwnedRingBufferIter { + buffer: RingBuffer, + index: usize, +} + +impl Iterator for OwnedRingBufferIter { + type Item = T; + + fn next(&mut self) -> Option { + if self.index == self.buffer.len() { + return None; + } + + let ii = (self.index + self.buffer.i) % self.buffer.len(); + let item = std::mem::take(&mut self.buffer.data[ii]); + self.index += 1; + Some(item) + } +} + +pub struct RingBufferIter<'a, T> { + buffer: &'a RingBuffer, + index: usize, +} + +impl<'a, T> Iterator for RingBufferIter<'a, T> { + type Item = &'a T; + + fn next(&mut self) -> Option { + if self.index == self.buffer.len() { + return None; + } + + let ii = (self.index + self.buffer.i) % self.buffer.len(); + let item = &self.buffer.data[ii]; + self.index += 1; + Some(item) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_inserts() { + let mut rb = RingBuffer::new(3); + assert_eq!(rb.capacity(), 3); + assert!(!rb.is_full()); + assert_eq!(rb.len(), 0); + assert_eq!(rb.iter().copied().count(), 0); + + rb.push(1); + assert!(!rb.is_full()); + assert_eq!(rb.len(), 1); + assert_eq!(rb.iter().copied().collect::>(), vec![1]); + + rb.push(2); + assert!(!rb.is_full()); + assert_eq!(rb.len(), 2); + assert_eq!(rb.iter().copied().collect::>(), vec![1, 2]); + + rb.push(3); + assert!(rb.is_full()); + assert_eq!(rb.len(), 3); + assert_eq!(rb.iter().copied().collect::>(), vec![1, 2, 3]); + + rb.push(4); + assert!(rb.is_full()); + assert_eq!(rb.len(), 3); + assert_eq!(rb.iter().copied().collect::>(), vec![2, 3, 4]); + + assert_eq!(rb.into_iter().collect::>(), vec![2, 3, 4]); + } +} diff --git a/cli/src/util/sync.rs b/cli/src/util/sync.rs index 5f33419488a..d57cbdecf17 100644 --- a/cli/src/util/sync.rs +++ b/cli/src/util/sync.rs @@ -2,38 +2,53 @@ * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -use tokio::sync::watch::{ - self, - error::{RecvError, SendError}, +use async_trait::async_trait; +use std::{marker::PhantomData, sync::Arc}; +use tokio::sync::{ + broadcast, mpsc, + watch::{self, error::RecvError}, }; #[derive(Clone)] pub struct Barrier(watch::Receiver>) where - T: Copy; + T: Clone; impl Barrier where - T: Copy, + T: Clone, { /// Waits for the barrier to be closed, returning a value if one was sent. pub async fn wait(&mut self) -> Result { loop { self.0.changed().await?; - if let Some(v) = *(self.0.borrow()) { + if let Some(v) = self.0.borrow().clone() { return Ok(v); } } } + + /// Gets whether the barrier is currently open + pub fn is_open(&self) -> bool { + self.0.borrow().is_some() + } } -pub struct BarrierOpener(watch::Sender>); +#[derive(Clone)] +pub struct BarrierOpener(Arc>>); -impl BarrierOpener { - /// Closes the barrier. - pub fn open(self, value: T) -> Result<(), SendError>> { - self.0.send(Some(value)) +impl BarrierOpener { + /// Opens the barrier. + pub fn open(&self, value: T) { + self.0.send_if_modified(|v| { + if v.is_none() { + *v = Some(value); + true + } else { + false + } + }); } } @@ -44,7 +59,119 @@ where T: Copy, { let (closed_tx, closed_rx) = watch::channel(None); - (Barrier(closed_rx), BarrierOpener(closed_tx)) + (Barrier(closed_rx), BarrierOpener(Arc::new(closed_tx))) +} + +/// Type that can receive messages in an async way. +#[async_trait] +pub trait Receivable { + async fn recv_msg(&mut self) -> Option; +} + +// todo: ideally we would use an Arc in the broadcast::Receiver to avoid having +// to clone bytes everywhere, requires updating rpc consumers as well. +#[async_trait] +impl Receivable for broadcast::Receiver { + async fn recv_msg(&mut self) -> Option { + loop { + match self.recv().await { + Ok(v) => return Some(v), + Err(broadcast::error::RecvError::Lagged(_)) => continue, + Err(broadcast::error::RecvError::Closed) => return None, + } + } + } +} + +#[async_trait] +impl Receivable for mpsc::UnboundedReceiver { + async fn recv_msg(&mut self) -> Option { + self.recv().await + } +} + +#[async_trait] +impl Receivable for () { + async fn recv_msg(&mut self) -> Option { + futures::future::pending().await + } +} + +pub struct ConcatReceivable, B: Receivable> { + left: Option, + right: B, + _marker: PhantomData, +} + +impl, B: Receivable> ConcatReceivable { + pub fn new(left: A, right: B) -> Self { + Self { + left: Some(left), + right, + _marker: PhantomData, + } + } +} + +#[async_trait] +impl, B: Send + Receivable> Receivable + for ConcatReceivable +{ + async fn recv_msg(&mut self) -> Option { + if let Some(left) = &mut self.left { + match left.recv_msg().await { + Some(v) => return Some(v), + None => { + self.left = None; + } + } + } + + return self.right.recv_msg().await; + } +} + +pub struct MergedReceivable, B: Receivable> { + left: Option, + right: Option, + _marker: PhantomData, +} + +impl, B: Receivable> MergedReceivable { + pub fn new(left: A, right: B) -> Self { + Self { + left: Some(left), + right: Some(right), + _marker: PhantomData, + } + } +} + +#[async_trait] +impl, B: Send + Receivable> Receivable + for MergedReceivable +{ + async fn recv_msg(&mut self) -> Option { + loop { + match (&mut self.left, &mut self.right) { + (Some(left), Some(right)) => { + tokio::select! { + left = left.recv_msg() => match left { + Some(v) => return Some(v), + None => { self.left = None; continue; }, + }, + right = right.recv_msg() => match right { + Some(v) => return Some(v), + None => { self.right = None; continue; }, + }, + } + } + (Some(a), None) => break a.recv_msg().await, + (None, Some(b)) => break b.recv_msg().await, + (None, None) => break None, + } + } + } } #[cfg(test)] @@ -60,7 +187,7 @@ mod tests { tx.send(barrier.wait().await.unwrap()).unwrap(); }); - opener.open(42).unwrap(); + opener.open(42); assert!(rx.await.unwrap() == 42); } @@ -71,7 +198,7 @@ mod tests { let (tx1, rx1) = tokio::sync::oneshot::channel::(); let (tx2, rx2) = tokio::sync::oneshot::channel::(); - opener.open(42).unwrap(); + opener.open(42); let mut b1 = barrier.clone(); tokio::spawn(async move { tx1.send(b1.wait().await.unwrap()).unwrap();