cli: allow exec server to listen on socket (#188123)

* cli: allow exec server to listen on socket

For remote ssh

* fix lint
This commit is contained in:
Connor Peet
2023-07-18 09:19:44 -07:00
committed by GitHub
parent a40a4f728a
commit 3e0786633b
4 changed files with 63 additions and 19 deletions

View File

@@ -95,7 +95,9 @@ async fn main() -> Result<(), std::convert::Infallible> {
args::VersionSubcommand::Show => version::show(context!()).await,
},
Some(args::Commands::CommandShell) => tunnels::command_shell(context!()).await,
Some(args::Commands::CommandShell(cs_args)) => {
tunnels::command_shell(context!(), cs_args).await
}
Some(args::Commands::Tunnel(tunnel_args)) => match tunnel_args.subcommand {
Some(args::TunnelSubcommand::Prune) => tunnels::prune(context!()).await,

View File

@@ -174,7 +174,14 @@ pub enum Commands {
/// Runs the control server on process stdin/stdout
#[clap(hide = true)]
CommandShell,
CommandShell(CommandShellArgs),
}
#[derive(Args, Debug, Clone)]
pub struct CommandShellArgs {
/// Listen on a socket instead of stdin/stdout.
#[clap(long)]
pub on_socket: bool,
}
#[derive(Args, Debug, Clone)]

View File

@@ -5,6 +5,7 @@
use async_trait::async_trait;
use base64::{engine::general_purpose as b64, Engine as _};
use futures::{stream::FuturesUnordered, StreamExt};
use serde::Serialize;
use sha2::{Digest, Sha256};
use std::{str::FromStr, time::Duration};
@@ -12,13 +13,14 @@ use sysinfo::Pid;
use super::{
args::{
AuthProvider, CliCore, ExistingTunnelArgs, TunnelRenameArgs, TunnelServeArgs,
TunnelServiceSubCommands, TunnelUserSubCommands,
AuthProvider, CliCore, CommandShellArgs, ExistingTunnelArgs, TunnelRenameArgs,
TunnelServeArgs, TunnelServiceSubCommands, TunnelUserSubCommands,
},
CommandContext,
};
use crate::{
async_pipe::{get_socket_name, listen_socket_rw_stream, socket_stream_split},
auth::Auth,
constants::{APPLICATION_NAME, TUNNEL_CLI_LOCK_NAME, TUNNEL_SERVICE_LOCK_NAME},
log,
@@ -120,23 +122,55 @@ impl ServiceContainer for TunnelServiceContainer {
}
}
pub async fn command_shell(ctx: CommandContext) -> Result<i32, AnyError> {
pub async fn command_shell(ctx: CommandContext, args: CommandShellArgs) -> Result<i32, AnyError> {
let platform = PreReqChecker::new().verify().await?;
serve_stream(
tokio::io::stdin(),
tokio::io::stderr(),
ServeStreamParams {
let mut params = ServeStreamParams {
log: ctx.log,
launcher_paths: ctx.paths,
platform,
requires_auth: true,
exit_barrier: ShutdownRequest::create_rx([ShutdownRequest::CtrlC]),
code_server_args: (&ctx.args).into(),
},
)
.await;
};
Ok(0)
if !args.on_socket {
serve_stream(tokio::io::stdin(), tokio::io::stderr(), params).await;
return Ok(0);
}
let socket = get_socket_name();
let mut listener = listen_socket_rw_stream(&socket)
.await
.map_err(|e| wrap(e, "error listening on socket"))?;
params
.log
.result(format!("Listening on {}", socket.display()));
let mut servers = FuturesUnordered::new();
loop {
tokio::select! {
Some(_) = servers.next() => {},
socket = listener.accept() => {
match socket {
Ok(s) => {
let (read, write) = socket_stream_split(s);
servers.push(serve_stream(read, write, params.clone()));
},
Err(e) => {
error!(params.log, &format!("Error accepting connection: {}", e));
return Ok(1);
}
}
},
_ = params.exit_barrier.wait() => {
// wait for all servers to finish up:
while (servers.next().await).is_some() { }
return Ok(0);
}
}
}
}
pub async fn service(

View File

@@ -233,6 +233,7 @@ pub async fn serve(
}
}
#[derive(Clone)]
pub struct ServeStreamParams {
pub log: log::Logger,
pub launcher_paths: LauncherPaths,