Merge branch 'main' into cli-ensure-code-tunnel-service-remains-headless-on-windows

This commit is contained in:
kernel-sanders
2023-08-02 10:57:13 -04:00
committed by GitHub
1081 changed files with 52167 additions and 16852 deletions

View File

@@ -4,7 +4,10 @@
*--------------------------------------------------------------------------------------------*/
use crate::{constants::APPLICATION_NAME, util::errors::CodeError};
use async_trait::async_trait;
use std::path::{Path, PathBuf};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use uuid::Uuid;
// todo: we could probably abstract this into some crate, if one doesn't already exist
@@ -39,7 +42,7 @@ cfg_if::cfg_if! {
pipe.into_split()
}
} else {
use tokio::{time::sleep, io::{AsyncRead, AsyncWrite, ReadBuf}};
use tokio::{time::sleep, io::ReadBuf};
use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions, NamedPipeClient, NamedPipeServer};
use std::{time::Duration, pin::Pin, task::{Context, Poll}, io};
use pin_project::pin_project;
@@ -181,3 +184,34 @@ pub fn get_socket_name() -> PathBuf {
}
}
}
pub type AcceptedRW = (
Box<dyn AsyncRead + Send + Unpin>,
Box<dyn AsyncWrite + Send + Unpin>,
);
#[async_trait]
pub trait AsyncRWAccepter {
async fn accept_rw(&mut self) -> Result<AcceptedRW, CodeError>;
}
#[async_trait]
impl AsyncRWAccepter for AsyncPipeListener {
async fn accept_rw(&mut self) -> Result<AcceptedRW, CodeError> {
let pipe = self.accept().await?;
let (read, write) = socket_stream_split(pipe);
Ok((Box::new(read), Box::new(write)))
}
}
#[async_trait]
impl AsyncRWAccepter for TcpListener {
async fn accept_rw(&mut self) -> Result<AcceptedRW, CodeError> {
let (stream, _) = self
.accept()
.await
.map_err(CodeError::AsyncPipeListenerFailed)?;
let (read, write) = tokio::io::split(stream);
Ok((Box::new(read), Box::new(write)))
}
}

View File

@@ -10,7 +10,8 @@ use crate::{
trace,
util::{
errors::{
wrap, AnyError, OAuthError, RefreshTokenNotAvailableError, StatusError, WrappedError,
wrap, AnyError, CodeError, OAuthError, RefreshTokenNotAvailableError, StatusError,
WrappedError,
},
input::prompt_options,
},
@@ -160,6 +161,7 @@ impl StoredCredential {
struct StorageWithLastRead {
storage: Box<dyn StorageImplementation>,
fallback_storage: Option<FileStorage>,
last_read: Cell<Result<Option<StoredCredential>, WrappedError>>,
}
@@ -172,9 +174,9 @@ pub struct Auth {
}
trait StorageImplementation: Send + Sync {
fn read(&mut self) -> Result<Option<StoredCredential>, WrappedError>;
fn store(&mut self, value: StoredCredential) -> Result<(), WrappedError>;
fn clear(&mut self) -> Result<(), WrappedError>;
fn read(&mut self) -> Result<Option<StoredCredential>, AnyError>;
fn store(&mut self, value: StoredCredential) -> Result<(), AnyError>;
fn clear(&mut self) -> Result<(), AnyError>;
}
// unseal decrypts and deserializes the value
@@ -217,16 +219,34 @@ struct ThreadKeyringStorage {
}
impl ThreadKeyringStorage {
fn thread_op<R, Fn>(&mut self, f: Fn) -> R
fn thread_op<R, Fn>(&mut self, f: Fn) -> Result<R, AnyError>
where
Fn: 'static + Send + FnOnce(&mut KeyringStorage) -> R,
Fn: 'static + Send + FnOnce(&mut KeyringStorage) -> Result<R, AnyError>,
R: 'static + Send,
{
let mut s = self.s.take().unwrap();
let handler = thread::spawn(move || (f(&mut s), s));
let (r, s) = handler.join().unwrap();
self.s = Some(s);
r
let mut s = match self.s.take() {
Some(s) => s,
None => return Err(CodeError::KeyringTimeout.into()),
};
// It seems like on Linux communication to the keyring can block indefinitely.
// Fall back after a 5 second timeout.
let (sender, receiver) = std::sync::mpsc::channel();
let tsender = sender.clone();
thread::spawn(move || sender.send(Some((f(&mut s), s))));
thread::spawn(move || {
thread::sleep(std::time::Duration::from_secs(5));
let _ = tsender.send(None);
});
match receiver.recv().unwrap() {
Some((r, s)) => {
self.s = Some(s);
r
}
None => Err(CodeError::KeyringTimeout.into()),
}
}
}
@@ -239,15 +259,15 @@ impl Default for ThreadKeyringStorage {
}
impl StorageImplementation for ThreadKeyringStorage {
fn read(&mut self) -> Result<Option<StoredCredential>, WrappedError> {
fn read(&mut self) -> Result<Option<StoredCredential>, AnyError> {
self.thread_op(|s| s.read())
}
fn store(&mut self, value: StoredCredential) -> Result<(), WrappedError> {
fn store(&mut self, value: StoredCredential) -> Result<(), AnyError> {
self.thread_op(move |s| s.store(value))
}
fn clear(&mut self) -> Result<(), WrappedError> {
fn clear(&mut self) -> Result<(), AnyError> {
self.thread_op(|s| s.clear())
}
}
@@ -273,7 +293,7 @@ macro_rules! get_next_entry {
}
impl StorageImplementation for KeyringStorage {
fn read(&mut self) -> Result<Option<StoredCredential>, WrappedError> {
fn read(&mut self) -> Result<Option<StoredCredential>, AnyError> {
let mut str = String::new();
for i in 0.. {
@@ -281,7 +301,7 @@ impl StorageImplementation for KeyringStorage {
let next_chunk = match entry.get_password() {
Ok(value) => value,
Err(keyring::Error::NoEntry) => return Ok(None), // missing entries?
Err(e) => return Err(wrap(e, "error reading keyring")),
Err(e) => return Err(wrap(e, "error reading keyring").into()),
};
if next_chunk.ends_with(CONTINUE_MARKER) {
@@ -295,7 +315,7 @@ impl StorageImplementation for KeyringStorage {
Ok(unseal(&str))
}
fn store(&mut self, value: StoredCredential) -> Result<(), WrappedError> {
fn store(&mut self, value: StoredCredential) -> Result<(), AnyError> {
let sealed = seal(&value);
let step_size = KEYCHAIN_ENTRY_LIMIT - CONTINUE_MARKER.len();
@@ -312,14 +332,14 @@ impl StorageImplementation for KeyringStorage {
};
if let Err(e) = stored {
return Err(wrap(e, "error updating keyring"));
return Err(wrap(e, "error updating keyring").into());
}
}
Ok(())
}
fn clear(&mut self) -> Result<(), WrappedError> {
fn clear(&mut self) -> Result<(), AnyError> {
self.read().ok(); // make sure component parts are available
for entry in self.entries.iter() {
entry
@@ -335,16 +355,16 @@ impl StorageImplementation for KeyringStorage {
struct FileStorage(PersistedState<Option<String>>);
impl StorageImplementation for FileStorage {
fn read(&mut self) -> Result<Option<StoredCredential>, WrappedError> {
fn read(&mut self) -> Result<Option<StoredCredential>, AnyError> {
Ok(self.0.load().and_then(|s| unseal(&s)))
}
fn store(&mut self, value: StoredCredential) -> Result<(), WrappedError> {
self.0.save(Some(seal(&value)))
fn store(&mut self, value: StoredCredential) -> Result<(), AnyError> {
self.0.save(Some(seal(&value))).map_err(|e| e.into())
}
fn clear(&mut self) -> Result<(), WrappedError> {
self.0.save(None)
fn clear(&mut self) -> Result<(), AnyError> {
self.0.save(None).map_err(|e| e.into())
}
}
@@ -373,20 +393,32 @@ impl Auth {
let mut keyring_storage = ThreadKeyringStorage::default();
let mut file_storage = FileStorage(PersistedState::new(self.file_storage_path.clone()));
let keyring_storage_result = match std::env::var("VSCODE_CLI_USE_FILE_KEYCHAIN") {
Ok(_) => Err(wrap("", "user prefers file storage")),
_ => keyring_storage.read(),
let native_storage_result = if std::env::var("VSCODE_CLI_USE_FILE_KEYCHAIN").is_ok()
|| self.file_storage_path.exists()
{
Err(wrap("", "user prefers file storage").into())
} else {
keyring_storage.read()
};
let mut storage = match keyring_storage_result {
let mut storage = match native_storage_result {
Ok(v) => StorageWithLastRead {
last_read: Cell::new(Ok(v)),
fallback_storage: Some(file_storage),
storage: Box::new(keyring_storage),
},
Err(_) => StorageWithLastRead {
last_read: Cell::new(file_storage.read()),
storage: Box::new(file_storage),
},
Err(e) => {
debug!(self.log, "Using file keychain storage due to: {}", e);
StorageWithLastRead {
last_read: Cell::new(
file_storage
.read()
.map_err(|e| wrap(e, "could not read from file storage")),
),
fallback_storage: None,
storage: Box::new(file_storage),
}
}
};
let out = op(&mut storage);
@@ -419,7 +451,7 @@ impl Auth {
}
/// Clears login info from the keyring.
pub fn clear_credentials(&self) -> Result<(), WrappedError> {
pub fn clear_credentials(&self) -> Result<(), AnyError> {
self.with_storage(|storage| {
storage.storage.clear()?;
storage.last_read.set(Ok(None));
@@ -505,7 +537,18 @@ impl Auth {
"Failed to update keyring with new credentials: {}",
e
);
if let Some(fb) = storage.fallback_storage.take() {
storage.storage = Box::new(fb);
match storage.storage.store(creds.clone()) {
Err(e) => {
warning!(self.log, "Also failed to update fallback storage: {}", e)
}
Ok(_) => debug!(self.log, "Updated fallback storage successfully"),
}
}
}
storage.last_read.set(Ok(Some(creds)));
})
}

View File

@@ -95,7 +95,9 @@ async fn main() -> Result<(), std::convert::Infallible> {
args::VersionSubcommand::Show => version::show(context!()).await,
},
Some(args::Commands::CommandShell) => tunnels::command_shell(context!()).await,
Some(args::Commands::CommandShell(cs_args)) => {
tunnels::command_shell(context!(), cs_args).await
}
Some(args::Commands::Tunnel(tunnel_args)) => match tunnel_args.subcommand {
Some(args::TunnelSubcommand::Prune) => tunnels::prune(context!()).await,
@@ -112,6 +114,9 @@ async fn main() -> Result<(), std::convert::Infallible> {
Some(args::TunnelSubcommand::Service(service_args)) => {
tunnels::service(context_no_logger(), service_args).await
}
Some(args::TunnelSubcommand::ForwardInternal(forward_args)) => {
tunnels::forward(context_no_logger(), forward_args).await
}
None => tunnels::serve(context_no_logger(), tunnel_args.serve_args).await,
},
},

View File

@@ -174,7 +174,20 @@ pub enum Commands {
/// Runs the control server on process stdin/stdout
#[clap(hide = true)]
CommandShell,
CommandShell(CommandShellArgs),
}
#[derive(Args, Debug, Clone)]
pub struct CommandShellArgs {
/// Listen on a socket instead of stdin/stdout.
#[clap(long)]
pub on_socket: bool,
/// Listen on a port instead of stdin/stdout.
#[clap(long)]
pub on_port: bool,
/// Require the given token string to be given in the handshake.
#[clap(long)]
pub require_token: Option<String>,
}
#[derive(Args, Debug, Clone)]
@@ -548,6 +561,7 @@ pub enum OutputFormat {
#[derive(Args, Clone, Debug, Default)]
pub struct ExistingTunnelArgs {
/// Name you'd like to assign preexisting tunnel to use to connect the tunnel
/// Old option, new code sohuld just use `--name`.
#[clap(long, hide = true)]
pub tunnel_name: Option<String>,
@@ -626,6 +640,10 @@ pub enum TunnelSubcommand {
/// (Preview) Manages the tunnel when installed as a system service,
#[clap(subcommand)]
Service(TunnelServiceSubCommands),
/// (Preview) Forwards local port using the dev tunnel
#[clap(hide = true)]
ForwardInternal(TunnelForwardArgs),
}
#[derive(Subcommand, Debug, Clone)]
@@ -649,6 +667,10 @@ pub struct TunnelServiceInstallArgs {
/// If set, the user accepts the server license terms and the server will be started without a user prompt.
#[clap(long)]
pub accept_server_license_terms: bool,
/// Sets the machine name for port forwarding service
#[clap(long)]
pub name: Option<String>,
}
#[derive(Args, Debug, Clone)]
@@ -657,6 +679,16 @@ pub struct TunnelRenameArgs {
pub name: String,
}
#[derive(Args, Debug, Clone)]
pub struct TunnelForwardArgs {
/// One or more ports to forward.
pub ports: Vec<u16>,
/// Login args -- used for convenience so the forwarding call is a single action.
#[clap(flatten)]
pub login: LoginArgs,
}
#[derive(Subcommand, Debug, Clone)]
pub enum TunnelUserSubCommands {
/// Log in to port forwarding service

View File

@@ -5,27 +5,37 @@
use async_trait::async_trait;
use base64::{engine::general_purpose as b64, Engine as _};
use futures::{stream::FuturesUnordered, StreamExt};
use serde::Serialize;
use sha2::{Digest, Sha256};
use std::{str::FromStr, time::Duration};
use sysinfo::Pid;
use tokio::{
io::{AsyncBufReadExt, BufReader},
sync::watch,
};
use super::{
args::{
AuthProvider, CliCore, ExistingTunnelArgs, TunnelRenameArgs, TunnelServeArgs,
TunnelServiceSubCommands, TunnelUserSubCommands,
AuthProvider, CliCore, CommandShellArgs, ExistingTunnelArgs, TunnelForwardArgs,
TunnelRenameArgs, TunnelServeArgs, TunnelServiceSubCommands, TunnelUserSubCommands,
},
CommandContext,
};
use crate::{
async_pipe::{get_socket_name, listen_socket_rw_stream, AsyncRWAccepter},
auth::Auth,
constants::{APPLICATION_NAME, TUNNEL_CLI_LOCK_NAME, TUNNEL_SERVICE_LOCK_NAME},
constants::{
APPLICATION_NAME, CONTROL_PORT, IS_A_TTY, TUNNEL_CLI_LOCK_NAME, TUNNEL_SERVICE_LOCK_NAME,
},
log,
state::LauncherPaths,
tunnels::{
code_server::CodeServerArgs,
create_service_manager, dev_tunnels, legal,
create_service_manager,
dev_tunnels::{self, DevTunnels},
forwarding, legal,
paths::get_all_servers,
protocol, serve_stream,
shutdown_signal::ShutdownRequest,
@@ -33,7 +43,7 @@ use crate::{
singleton_server::{
make_singleton_server, start_singleton_server, BroadcastLogSink, SingletonServerArgs,
},
Next, ServeStreamParams, ServiceContainer, ServiceManager,
AuthRequired, Next, ServeStreamParams, ServiceContainer, ServiceManager,
},
util::{
app_lock::AppMutex,
@@ -59,20 +69,31 @@ impl From<AuthProvider> for crate::auth::AuthProvider {
}
}
impl From<ExistingTunnelArgs> for Option<dev_tunnels::ExistingTunnel> {
fn from(d: ExistingTunnelArgs) -> Option<dev_tunnels::ExistingTunnel> {
if let (Some(tunnel_id), Some(tunnel_name), Some(cluster), Some(host_token)) =
(d.tunnel_id, d.tunnel_name, d.cluster, d.host_token)
{
fn fulfill_existing_tunnel_args(
d: ExistingTunnelArgs,
name_arg: &Option<String>,
) -> Option<dev_tunnels::ExistingTunnel> {
let tunnel_name = d.tunnel_name.or_else(|| name_arg.clone());
match (d.tunnel_id, d.cluster, d.host_token) {
(Some(tunnel_id), None, Some(host_token)) => {
let i = tunnel_id.find('.')?;
Some(dev_tunnels::ExistingTunnel {
tunnel_id,
tunnel_id: tunnel_id[..i].to_string(),
cluster: tunnel_id[i + 1..].to_string(),
tunnel_name,
host_token,
cluster,
})
} else {
None
}
(Some(tunnel_id), Some(cluster), Some(host_token)) => Some(dev_tunnels::ExistingTunnel {
tunnel_id,
tunnel_name,
host_token,
cluster,
}),
_ => None,
}
}
@@ -109,23 +130,71 @@ impl ServiceContainer for TunnelServiceContainer {
}
}
pub async fn command_shell(ctx: CommandContext) -> Result<i32, AnyError> {
pub async fn command_shell(ctx: CommandContext, args: CommandShellArgs) -> Result<i32, AnyError> {
let platform = PreReqChecker::new().verify().await?;
serve_stream(
tokio::io::stdin(),
tokio::io::stderr(),
ServeStreamParams {
log: ctx.log,
launcher_paths: ctx.paths,
platform,
requires_auth: true,
exit_barrier: ShutdownRequest::create_rx([ShutdownRequest::CtrlC]),
code_server_args: (&ctx.args).into(),
},
)
.await;
let mut params = ServeStreamParams {
log: ctx.log,
launcher_paths: ctx.paths,
platform,
requires_auth: args
.require_token
.map(AuthRequired::VSDAWithToken)
.unwrap_or(AuthRequired::VSDA),
exit_barrier: ShutdownRequest::create_rx([ShutdownRequest::CtrlC]),
code_server_args: (&ctx.args).into(),
};
Ok(0)
let mut listener: Box<dyn AsyncRWAccepter> = match (args.on_port, args.on_socket) {
(_, true) => {
let socket = get_socket_name();
let listener = listen_socket_rw_stream(&socket)
.await
.map_err(|e| wrap(e, "error listening on socket"))?;
params
.log
.result(format!("Listening on {}", socket.display()));
Box::new(listener)
}
(true, _) => {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.map_err(|e| wrap(e, "error listening on port"))?;
params
.log
.result(format!("Listening on {}", listener.local_addr().unwrap()));
Box::new(listener)
}
_ => {
serve_stream(tokio::io::stdin(), tokio::io::stderr(), params).await;
return Ok(0);
}
};
let mut servers = FuturesUnordered::new();
loop {
tokio::select! {
Some(_) = servers.next() => {},
socket = listener.accept_rw() => {
match socket {
Ok((read, write)) => servers.push(serve_stream(read, write, params.clone())),
Err(e) => {
error!(params.log, &format!("Error accepting connection: {}", e));
return Ok(1);
}
}
},
_ = params.exit_barrier.wait() => {
// wait for all servers to finish up:
while (servers.next().await).is_some() { }
return Ok(0);
}
}
}
}
pub async fn service(
@@ -135,10 +204,17 @@ pub async fn service(
let manager = create_service_manager(ctx.log.clone(), &ctx.paths);
match service_args {
TunnelServiceSubCommands::Install(args) => {
// ensure logged in, otherwise subsequent serving will fail
Auth::new(&ctx.paths, ctx.log.clone())
.get_credential()
.await?;
let auth = Auth::new(&ctx.paths, ctx.log.clone());
if let Some(name) = &args.name {
// ensure the name matches, and tunnel exists
dev_tunnels::DevTunnels::new_remote_tunnel(&ctx.log, auth, &ctx.paths)
.rename_tunnel(name)
.await?;
} else {
// still ensure they're logged in, otherwise subsequent serving will fail
auth.get_credential().await?;
}
// likewise for license consent
legal::require_consent(&ctx.paths, args.accept_server_license_terms)?;
@@ -203,23 +279,23 @@ pub async fn user(ctx: CommandContext, user_args: TunnelUserSubCommands) -> Resu
Ok(0)
}
/// Remove the tunnel used by this gateway, if any.
/// Remove the tunnel used by this tunnel, if any.
pub async fn rename(ctx: CommandContext, rename_args: TunnelRenameArgs) -> Result<i32, AnyError> {
let auth = Auth::new(&ctx.paths, ctx.log.clone());
let mut dt = dev_tunnels::DevTunnels::new(&ctx.log, auth, &ctx.paths);
let mut dt = dev_tunnels::DevTunnels::new_remote_tunnel(&ctx.log, auth, &ctx.paths);
dt.rename_tunnel(&rename_args.name).await?;
ctx.log.result(format!(
"Successfully renamed this gateway to {}",
"Successfully renamed this tunnel to {}",
&rename_args.name
));
Ok(0)
}
/// Remove the tunnel used by this gateway, if any.
/// Remove the tunnel used by this tunnel, if any.
pub async fn unregister(ctx: CommandContext) -> Result<i32, AnyError> {
let auth = Auth::new(&ctx.paths, ctx.log.clone());
let mut dt = dev_tunnels::DevTunnels::new(&ctx.log, auth, &ctx.paths);
let mut dt = dev_tunnels::DevTunnels::new_remote_tunnel(&ctx.log, auth, &ctx.paths);
dt.remove_tunnel().await?;
Ok(0)
}
@@ -327,6 +403,88 @@ pub async fn serve(ctx: CommandContext, gateway_args: TunnelServeArgs) -> Result
result
}
/// Internal command used by port forwarding. It reads requests for forwarded ports
/// on lines from stdin, as JSON. It uses singleton logic as well (though on
/// a different tunnel than the main one used for the control server) so that
/// all forward requests on a single machine go through a single hosted tunnel
/// process. Without singleton logic, requests could get routed to processes
/// that aren't forwarding a given port and then fail.
pub async fn forward(
ctx: CommandContext,
mut forward_args: TunnelForwardArgs,
) -> Result<i32, AnyError> {
// Spooky: check IS_A_TTY before starting the stdin reader, since IS_A_TTY will
// access stdin but a lock will later be held on stdin by the line-reader.
if *IS_A_TTY {
trace!(ctx.log, "port forwarding is an internal preview feature");
}
// #region stdin reading logic:
let (own_ports_tx, own_ports_rx) = watch::channel(vec![]);
let ports_process_log = ctx.log.clone();
tokio::spawn(async move {
let mut lines = BufReader::new(tokio::io::stdin()).lines();
while let Ok(Some(line)) = lines.next_line().await {
match serde_json::from_str(&line) {
Ok(p) => {
let _ = own_ports_tx.send(p);
}
Err(e) => warning!(ports_process_log, "error parsing ports: {}", e),
}
}
});
// #region singleton acquisition
let shutdown = ShutdownRequest::create_rx([ShutdownRequest::CtrlC]);
let server = loop {
if shutdown.is_open() {
return Ok(0);
}
match acquire_singleton(&ctx.paths.forwarding_lockfile()).await {
Ok(SingletonConnection::Client(stream)) => {
debug!(ctx.log, "starting as client to singleton");
let r = forwarding::client(forwarding::SingletonClientArgs {
log: ctx.log.clone(),
shutdown: shutdown.clone(),
stream,
port_requests: own_ports_rx.clone(),
})
.await;
if let Err(e) = r {
warning!(ctx.log, "error contacting forwarding singleton: {}", e);
}
}
Ok(SingletonConnection::Singleton(server)) => break server,
Err(e) => {
warning!(ctx.log, "error access singleton, retrying: {}", e);
tokio::time::sleep(Duration::from_secs(2)).await
}
}
};
// #region singleton handler
let auth = Auth::new(&ctx.paths, ctx.log.clone());
println!("preauth {:?}", forward_args.login);
if let (Some(p), Some(at)) = (
forward_args.login.provider.take(),
forward_args.login.access_token.take(),
) {
auth.login(Some(p.into()), Some(at)).await?;
}
println!("auth done");
let mut tunnels = DevTunnels::new_port_forwarding(&ctx.log, auth, &ctx.paths);
let tunnel = tunnels
.start_new_launcher_tunnel(None, true, &forward_args.ports)
.await?;
println!("made tunnel");
forwarding::server(ctx.log, tunnel, server, own_ports_rx, shutdown).await?;
Ok(0)
}
fn get_connection_token(tunnel: &ActiveTunnel) -> String {
let mut hash = Sha256::new();
hash.update(tunnel.id.as_bytes());
@@ -374,7 +532,7 @@ async fn serve_with_csa(
return Ok(0);
}
match acquire_singleton(paths.tunnel_lockfile()).await {
match acquire_singleton(&paths.tunnel_lockfile()).await {
Ok(SingletonConnection::Client(stream)) => {
debug!(log, "starting as client to singleton");
let should_exit = start_singleton_client(SingletonClientArgs {
@@ -403,13 +561,19 @@ async fn serve_with_csa(
let _lock = app_mutex_name.map(AppMutex::new);
let auth = Auth::new(&paths, log.clone());
let mut dt = dev_tunnels::DevTunnels::new(&log, auth, &paths);
let mut dt = dev_tunnels::DevTunnels::new_remote_tunnel(&log, auth, &paths);
loop {
let tunnel = if let Some(d) = gateway_args.tunnel.clone().into() {
dt.start_existing_tunnel(d).await
let tunnel = if let Some(t) =
fulfill_existing_tunnel_args(gateway_args.tunnel.clone(), &gateway_args.name)
{
dt.start_existing_tunnel(t).await
} else {
dt.start_new_launcher_tunnel(gateway_args.name.as_deref(), gateway_args.random_name)
.await
dt.start_new_launcher_tunnel(
gateway_args.name.as_deref(),
gateway_args.random_name,
&[CONTROL_PORT],
)
.await
}?;
csa.connection_token = Some(get_connection_token(&tunnel));

View File

@@ -122,7 +122,7 @@ pub struct MsgPackCodec<T> {
impl<T> MsgPackCodec<T> {
pub fn new() -> Self {
Self {
_marker: std::marker::PhantomData::default(),
_marker: std::marker::PhantomData,
}
}
}

View File

@@ -117,7 +117,7 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
Ok(p) => p,
Err(err) => {
return id.map(|id| {
serial.serialize(&ErrorResponse {
serial.serialize(ErrorResponse {
id,
error: ResponseError {
code: 0,
@@ -131,7 +131,7 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
match callback(param.params, &context) {
Ok(result) => id.map(|id| serial.serialize(&SuccessResponse { id, result })),
Err(err) => id.map(|id| {
serial.serialize(&ErrorResponse {
serial.serialize(ErrorResponse {
id,
error: ResponseError {
code: -1,
@@ -161,7 +161,7 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
Ok(p) => p,
Err(err) => {
return future::ready(id.map(|id| {
serial.serialize(&ErrorResponse {
serial.serialize(ErrorResponse {
id,
error: ResponseError {
code: 0,
@@ -182,7 +182,7 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
id.map(|id| serial.serialize(&SuccessResponse { id, result }))
}
Err(err) => id.map(|id| {
serial.serialize(&ErrorResponse {
serial.serialize(ErrorResponse {
id,
error: ResponseError {
code: -1,
@@ -222,7 +222,7 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
return (
None,
future::ready(id.map(|id| {
serial.serialize(&ErrorResponse {
serial.serialize(ErrorResponse {
id,
error: ResponseError {
code: 0,
@@ -255,7 +255,7 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
match callback(servers, param.params, context).await {
Ok(r) => id.map(|id| serial.serialize(&SuccessResponse { id, result: r })),
Err(err) => id.map(|id| {
serial.serialize(&ErrorResponse {
serial.serialize(ErrorResponse {
id,
error: ResponseError {
code: -1,
@@ -427,7 +427,7 @@ impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
Some(Method::Async(callback)) => MaybeSync::Future(callback(id, body)),
Some(Method::Duplex(callback)) => MaybeSync::Stream(callback(id, body)),
None => MaybeSync::Sync(id.map(|id| {
self.serializer.serialize(&ErrorResponse {
self.serializer.serialize(ErrorResponse {
id,
error: ResponseError {
code: -1,

View File

@@ -53,12 +53,12 @@ struct LockFileMatter {
/// Tries to acquire the singleton homed at the given lock file, either starting
/// a new singleton if it doesn't exist, or connecting otherwise.
pub async fn acquire_singleton(lock_file: PathBuf) -> Result<SingletonConnection, CodeError> {
pub async fn acquire_singleton(lock_file: &Path) -> Result<SingletonConnection, CodeError> {
let file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.open(&lock_file)
.open(lock_file)
.map_err(CodeError::SingletonLockfileOpenFailed)?;
match FileLock::acquire(file) {
@@ -158,7 +158,7 @@ mod tests {
#[tokio::test]
async fn test_acquires_singleton() {
let dir = tempfile::tempdir().expect("expected to make temp dir");
let s = acquire_singleton(dir.path().join("lock"))
let s = acquire_singleton(&dir.path().join("lock"))
.await
.expect("expected to acquire");
@@ -172,7 +172,7 @@ mod tests {
async fn test_acquires_client() {
let dir = tempfile::tempdir().expect("expected to make temp dir");
let lockfile = dir.path().join("lock");
let s1 = acquire_singleton(lockfile.clone())
let s1 = acquire_singleton(&lockfile)
.await
.expect("expected to acquire1");
match s1 {
@@ -182,7 +182,7 @@ mod tests {
_ => panic!("expected to be singleton"),
};
let s2 = acquire_singleton(lockfile)
let s2 = acquire_singleton(&lockfile)
.await
.expect("expected to acquire2");
match s2 {

View File

@@ -187,6 +187,14 @@ impl LauncherPaths {
))
}
/// Lockfile for port forwarding
pub fn forwarding_lockfile(&self) -> PathBuf {
self.root.join(format!(
"forwarding-{}.lock",
VSCODE_CLI_QUALITY.unwrap_or("oss")
))
}
/// Suggested path for tunnel service logs, when using file logs
pub fn service_log_file(&self) -> PathBuf {
self.root.join("tunnel-service.log")

View File

@@ -11,6 +11,7 @@ pub mod protocol;
pub mod shutdown_signal;
pub mod singleton_client;
pub mod singleton_server;
pub mod forwarding;
mod wsl_detect;
mod challenge;
@@ -34,7 +35,7 @@ mod service_macos;
mod service_windows;
mod socket_signal;
pub use control_server::{serve, serve_stream, Next, ServeStreamParams};
pub use control_server::{serve, serve_stream, Next, ServeStreamParams, AuthRequired};
pub use nosleep::SleepInhibitor;
pub use service::{
create_service_manager, ServiceContainer, ServiceManager, SERVICE_LOG_FILE_NAME,

View File

@@ -575,7 +575,17 @@ impl<'a> ServerBuilder<'a> {
}
fn get_base_command(&self) -> Command {
#[cfg(not(windows))]
let mut cmd = Command::new(&self.server_paths.executable);
#[cfg(windows)]
let mut cmd = {
let mut cmd = Command::new("cmd");
cmd.arg("/Q");
cmd.arg("/C");
cmd.arg(&self.server_paths.executable);
cmd
};
cmd.stdin(std::process::Stdio::null())
.args(self.server_params.code_server_args.command_arguments());
cmd

View File

@@ -48,11 +48,11 @@ use super::dev_tunnels::ActiveTunnel;
use super::paths::prune_stopped_servers;
use super::port_forwarder::{PortForwarding, PortForwardingProcessor};
use super::protocol::{
AcquireCliParams, CallServerHttpParams, CallServerHttpResult, ChallengeIssueResponse,
ChallengeVerifyParams, ClientRequestMethod, EmptyObject, ForwardParams, ForwardResult,
FsStatRequest, FsStatResponse, GetEnvResponse, GetHostnameResponse, HttpBodyParams,
HttpHeadersParams, ServeParams, ServerLog, ServerMessageParams, SpawnParams, SpawnResult,
ToClientRequest, UnforwardParams, UpdateParams, UpdateResult, VersionResponse,
AcquireCliParams, CallServerHttpParams, CallServerHttpResult, ChallengeIssueParams,
ChallengeIssueResponse, ChallengeVerifyParams, ClientRequestMethod, EmptyObject, ForwardParams,
ForwardResult, FsStatRequest, FsStatResponse, GetEnvResponse, GetHostnameResponse,
HttpBodyParams, HttpHeadersParams, ServeParams, ServerLog, ServerMessageParams, SpawnParams,
SpawnResult, ToClientRequest, UnforwardParams, UpdateParams, UpdateResult, VersionResponse,
METHOD_CHALLENGE_VERIFY,
};
use super::server_bridge::ServerBridge;
@@ -94,8 +94,8 @@ struct HandlerContext {
/// Handler auth state.
enum AuthState {
/// Auth is required, we're waiting for the client to send its challenge.
WaitingForChallenge,
/// Auth is required, we're waiting for the client to send its challenge optionally bearing a token.
WaitingForChallenge(Option<String>),
/// A challenge has been issued. Waiting for a verification.
ChallengeIssued(String),
/// Auth is no longer required.
@@ -215,7 +215,7 @@ pub async fn serve(
code_server_args: own_code_server_args,
platform,
exit_barrier: own_exit,
requires_auth: false,
requires_auth: AuthRequired::None,
}).with_context(cx.clone()).await;
cx.span().add_event(
@@ -233,12 +233,20 @@ pub async fn serve(
}
}
#[derive(Clone)]
pub enum AuthRequired {
None,
VSDA,
VSDAWithToken(String),
}
#[derive(Clone)]
pub struct ServeStreamParams {
pub log: log::Logger,
pub launcher_paths: LauncherPaths,
pub code_server_args: CodeServerArgs,
pub platform: Platform,
pub requires_auth: bool,
pub requires_auth: AuthRequired,
pub exit_barrier: Barrier<ShutdownSignal>,
}
@@ -268,7 +276,7 @@ fn make_socket_rpc(
launcher_paths: LauncherPaths,
code_server_args: CodeServerArgs,
port_forwarding: Option<PortForwarding>,
requires_auth: bool,
requires_auth: AuthRequired,
platform: Platform,
) -> RpcDispatcher<MsgPackSerializer, HandlerContext> {
let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new()));
@@ -276,8 +284,9 @@ fn make_socket_rpc(
let mut rpc = RpcBuilder::new(MsgPackSerializer {}).methods(HandlerContext {
did_update: Arc::new(AtomicBool::new(false)),
auth_state: Arc::new(std::sync::Mutex::new(match requires_auth {
true => AuthState::WaitingForChallenge,
false => AuthState::Authenticated,
AuthRequired::VSDAWithToken(t) => AuthState::WaitingForChallenge(Some(t)),
AuthRequired::VSDA => AuthState::WaitingForChallenge(None),
AuthRequired::None => AuthState::Authenticated,
})),
socket_tx,
log: log.clone(),
@@ -304,8 +313,8 @@ fn make_socket_rpc(
ensure_auth(&c.auth_state)?;
handle_get_env()
});
rpc.register_sync(METHOD_CHALLENGE_ISSUE, |_: EmptyObject, c| {
handle_challenge_issue(&c.auth_state)
rpc.register_sync(METHOD_CHALLENGE_ISSUE, |p: ChallengeIssueParams, c| {
handle_challenge_issue(p, &c.auth_state)
});
rpc.register_sync(METHOD_CHALLENGE_VERIFY, |p: ChallengeVerifyParams, c| {
handle_challenge_verify(p.response, &c.auth_state)
@@ -422,6 +431,7 @@ async fn process_socket(
let rx_counter = Arc::new(AtomicUsize::new(0));
let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new()));
let already_authed = matches!(requires_auth, AuthRequired::None);
let rpc = make_socket_rpc(
log.clone(),
socket_tx.clone(),
@@ -439,7 +449,7 @@ async fn process_socket(
let socket_tx = socket_tx.clone();
let exit_barrier = exit_barrier.clone();
tokio::spawn(async move {
if !requires_auth {
if already_authed {
send_version(&socket_tx).await;
}
@@ -825,13 +835,22 @@ fn handle_get_env() -> Result<GetEnvResponse, AnyError> {
}
fn handle_challenge_issue(
params: ChallengeIssueParams,
auth_state: &Arc<std::sync::Mutex<AuthState>>,
) -> Result<ChallengeIssueResponse, AnyError> {
let challenge = create_challenge();
let mut auth_state = auth_state.lock().unwrap();
*auth_state = AuthState::ChallengeIssued(challenge.clone());
if let AuthState::WaitingForChallenge(Some(s)) = &*auth_state {
println!("looking for token {}, got {:?}", s, params.token);
match &params.token {
Some(t) if s != t => return Err(CodeError::AuthChallengeBadToken.into()),
None => return Err(CodeError::AuthChallengeBadToken.into()),
_ => {}
}
}
*auth_state = AuthState::ChallengeIssued(challenge.clone());
Ok(ChallengeIssueResponse { challenge })
}
@@ -843,7 +862,7 @@ fn handle_challenge_verify(
match &*auth_state {
AuthState::Authenticated => Ok(EmptyObject {}),
AuthState::WaitingForChallenge => Err(CodeError::AuthChallengeNotIssued.into()),
AuthState::WaitingForChallenge(_) => Err(CodeError::AuthChallengeNotIssued.into()),
AuthState::ChallengeIssued(c) => match verify_challenge(c, &response) {
false => Err(CodeError::AuthChallengeNotIssued.into()),
true => {

View File

@@ -3,12 +3,11 @@
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use crate::auth;
use crate::constants::{
CONTROL_PORT, IS_INTERACTIVE_CLI, PROTOCOL_VERSION_TAG, TUNNEL_SERVICE_USER_AGENT,
};
use crate::constants::{IS_INTERACTIVE_CLI, PROTOCOL_VERSION_TAG, TUNNEL_SERVICE_USER_AGENT};
use crate::state::{LauncherPaths, PersistedState};
use crate::util::errors::{
wrap, AnyError, DevTunnelError, InvalidTunnelName, TunnelCreationFailed, WrappedError,
wrap, AnyError, CodeError, DevTunnelError, InvalidTunnelName, TunnelCreationFailed,
WrappedError,
};
use crate::util::input::prompt_placeholder;
use crate::{debug, info, log, spanf, trace, warning};
@@ -33,6 +32,8 @@ use tunnels::management::{
use super::wsl_detect::is_wsl_installed;
static TUNNEL_COUNT_LIMIT_NAME: &str = "TunnelsPerUserPerLocation";
#[derive(Clone, Serialize, Deserialize)]
pub struct PersistedTunnel {
pub name: String,
@@ -134,6 +135,7 @@ pub struct DevTunnels {
log: log::Logger,
launcher_tunnel: PersistedState<Option<PersistedTunnel>>,
client: TunnelManagementClient,
tag: &'static str,
}
/// Representation of a tunnel returned from the `start` methods.
@@ -162,30 +164,43 @@ impl ActiveTunnel {
}
/// Forwards a port over TCP.
pub async fn add_port_tcp(&mut self, port_number: u16) -> Result<(), AnyError> {
pub async fn add_port_tcp(&self, port_number: u16) -> Result<(), AnyError> {
self.manager.add_port_tcp(port_number).await?;
Ok(())
}
/// Removes a forwarded port TCP.
pub async fn remove_port(&mut self, port_number: u16) -> Result<(), AnyError> {
pub async fn remove_port(&self, port_number: u16) -> Result<(), AnyError> {
self.manager.remove_port(port_number).await?;
Ok(())
}
/// Gets the public URI on which a forwarded port can be access in browser.
pub async fn get_port_uri(&mut self, port: u16) -> Result<String, AnyError> {
let endpoint = self.manager.get_endpoint().await?;
let format = endpoint
.base
.port_uri_format
.expect("expected to have port format");
/// Gets the template string for forming forwarded port web URIs..
pub fn get_port_format(&self) -> Result<String, AnyError> {
if let Some(details) = &*self.manager.endpoint_rx.borrow() {
return details
.as_ref()
.map(|r| {
r.base
.port_uri_format
.clone()
.expect("expected to have port format")
})
.map_err(|e| e.clone().into());
}
Ok(format.replace(PORT_TOKEN, &port.to_string()))
Err(CodeError::NoTunnelEndpoint.into())
}
/// Gets the public URI on which a forwarded port can be access in browser.
pub fn get_port_uri(&self, port: u16) -> Result<String, AnyError> {
self.get_port_format()
.map(|f| f.replace(PORT_TOKEN, &port.to_string()))
}
}
const VSCODE_CLI_TUNNEL_TAG: &str = "vscode-server-launcher";
const VSCODE_CLI_FORWARDING_TAG: &str = "vscode-port-forward";
const MAX_TUNNEL_NAME_LENGTH: usize = 20;
fn get_host_token_from_tunnel(tunnel: &Tunnel) -> String {
@@ -229,7 +244,7 @@ lazy_static! {
#[derive(Clone, Debug)]
pub struct ExistingTunnel {
/// Name you'd like to assign preexisting tunnel to use to connect to the VS Code Server
pub tunnel_name: String,
pub tunnel_name: Option<String>,
/// Token to authenticate and use preexisting tunnel
pub host_token: String,
@@ -242,7 +257,29 @@ pub struct ExistingTunnel {
}
impl DevTunnels {
pub fn new(log: &log::Logger, auth: auth::Auth, paths: &LauncherPaths) -> DevTunnels {
/// Creates a new DevTunnels client used for port forwarding.
pub fn new_port_forwarding(
log: &log::Logger,
auth: auth::Auth,
paths: &LauncherPaths,
) -> DevTunnels {
let mut client = new_tunnel_management(&TUNNEL_SERVICE_USER_AGENT);
client.authorization_provider(auth);
DevTunnels {
log: log.clone(),
client: client.into(),
launcher_tunnel: PersistedState::new(paths.root().join("port_forwarding_tunnel.json")),
tag: VSCODE_CLI_FORWARDING_TAG,
}
}
/// Creates a new DevTunnels client used for the Remote Tunnels extension to access the VS Code Server.
pub fn new_remote_tunnel(
log: &log::Logger,
auth: auth::Auth,
paths: &LauncherPaths,
) -> DevTunnels {
let mut client = new_tunnel_management(&TUNNEL_SERVICE_USER_AGENT);
client.authorization_provider(auth);
@@ -250,6 +287,7 @@ impl DevTunnels {
log: log.clone(),
client: client.into(),
launcher_tunnel: PersistedState::new(paths.root().join("code_tunnel.json")),
tag: VSCODE_CLI_TUNNEL_TAG,
}
}
@@ -275,7 +313,9 @@ impl DevTunnels {
/// Renames the current tunnel to the new name.
pub async fn rename_tunnel(&mut self, name: &str) -> Result<(), AnyError> {
self.update_tunnel_name(None, name).await.map(|_| ())
self.update_tunnel_name(self.launcher_tunnel.load(), name)
.await
.map(|_| ())
}
/// Updates the name of the existing persisted tunnel to the new name.
@@ -286,28 +326,34 @@ impl DevTunnels {
name: &str,
) -> Result<(Tunnel, PersistedTunnel), AnyError> {
let name = name.to_ascii_lowercase();
self.check_is_name_free(&name).await?;
debug!(self.log, "Tunnel name changed, applying updates...");
let (mut full_tunnel, mut persisted, is_new) = match persisted {
Some(persisted) => {
debug!(
self.log,
"Found a persisted tunnel, seeing if the name matches..."
);
self.get_or_create_tunnel(persisted, Some(&name), NO_REQUEST_OPTIONS)
.await
}
None => self
.create_tunnel(&name, NO_REQUEST_OPTIONS)
.await
.map(|(pt, t)| (t, pt, true)),
None => {
debug!(self.log, "Creating a new tunnel with the requested name");
self.create_tunnel(&name, NO_REQUEST_OPTIONS)
.await
.map(|(pt, t)| (t, pt, true))
}
}?;
if is_new {
let desired_tags = self.get_tags(&name);
if is_new || vec_eq_as_set(&full_tunnel.tags, &desired_tags) {
return Ok((full_tunnel, persisted));
}
full_tunnel.tags = self.get_tags(&name);
debug!(self.log, "Tunnel name changed, applying updates...");
let new_tunnel = spanf!(
full_tunnel.tags = desired_tags;
let updated_tunnel = spanf!(
self.log,
self.log.span("dev-tunnel.tag.update"),
self.client.update_tunnel(&full_tunnel, NO_REQUEST_OPTIONS)
@@ -317,7 +363,7 @@ impl DevTunnels {
persisted.name = name;
self.launcher_tunnel.save(Some(persisted.clone()))?;
Ok((new_tunnel, persisted))
Ok((updated_tunnel, persisted))
}
/// Gets the persisted tunnel from the service, or creates a new one.
@@ -356,6 +402,7 @@ impl DevTunnels {
&mut self,
preferred_name: Option<&str>,
use_random_name: bool,
preserve_ports: &[u16],
) -> Result<ActiveTunnel, AnyError> {
let (mut tunnel, persisted) = match self.launcher_tunnel.load() {
Some(mut persisted) => {
@@ -385,7 +432,12 @@ impl DevTunnels {
};
tunnel = self
.sync_tunnel_tags(&persisted.name, tunnel, &HOST_TUNNEL_REQUEST_OPTIONS)
.sync_tunnel_tags(
&self.client,
&persisted.name,
tunnel,
&HOST_TUNNEL_REQUEST_OPTIONS,
)
.await?;
let locator = TunnelLocator::try_from(&tunnel).unwrap();
@@ -394,7 +446,7 @@ impl DevTunnels {
for port_to_delete in tunnel
.ports
.iter()
.filter(|p| p.port_number != CONTROL_PORT)
.filter(|p: &&TunnelPort| !preserve_ports.contains(&p.port_number))
{
let output_fut = self.client.delete_tunnel_port(
&locator,
@@ -443,14 +495,10 @@ impl DevTunnels {
) -> Result<(PersistedTunnel, Tunnel), AnyError> {
info!(self.log, "Creating tunnel with the name: {}", name);
let mut tried_recycle = false;
self.check_is_name_free(name).await?;
let new_tunnel = Tunnel {
tags: vec![
name.to_string(),
PROTOCOL_VERSION_TAG.to_string(),
VSCODE_CLI_TUNNEL_TAG.to_string(),
],
tags: self.get_tags(name),
..Default::default()
};
@@ -465,13 +513,14 @@ impl DevTunnels {
Err(HttpError::ResponseError(e))
if e.status_code == StatusCode::TOO_MANY_REQUESTS =>
{
if !tried_recycle && self.try_recycle_tunnel().await? {
tried_recycle = true;
continue;
}
if let Some(d) = e.get_details() {
let detail = d.detail.unwrap_or_else(|| "unknown".to_string());
if detail.contains(TUNNEL_COUNT_LIMIT_NAME)
&& self.try_recycle_tunnel().await?
{
continue;
}
return Err(AnyError::from(TunnelCreationFailed(
name.to_string(),
detail,
@@ -508,7 +557,7 @@ impl DevTunnels {
let mut tags = vec![
name.to_string(),
PROTOCOL_VERSION_TAG.to_string(),
VSCODE_CLI_TUNNEL_TAG.to_string(),
self.tag.to_string(),
];
if is_wsl_installed(&self.log) {
@@ -522,12 +571,13 @@ impl DevTunnels {
/// other version tags.
async fn sync_tunnel_tags(
&self,
client: &TunnelManagementClient,
name: &str,
tunnel: Tunnel,
options: &TunnelRequestOptions,
) -> Result<Tunnel, AnyError> {
let new_tags = self.get_tags(name);
if vec_eq_unsorted(&tunnel.tags, &new_tags) {
if vec_eq_as_set(&tunnel.tags, &new_tags) {
return Ok(tunnel);
}
@@ -548,7 +598,7 @@ impl DevTunnels {
let result = spanf!(
self.log,
self.log.span("dev-tunnel.protocol-tag-update"),
self.client.update_tunnel(&tunnel_update, options)
client.update_tunnel(&tunnel_update, options)
);
result.map_err(|e| wrap(e, "tunnel tag update failed").into())
@@ -599,7 +649,7 @@ impl DevTunnels {
self.log,
self.log.span("dev-tunnel.listall"),
self.client.list_all_tunnels(&TunnelRequestOptions {
tags: vec![VSCODE_CLI_TUNNEL_TAG.to_string()],
tags: vec![self.tag.to_string()],
require_all_tags: true,
..Default::default()
})
@@ -610,11 +660,11 @@ impl DevTunnels {
}
async fn check_is_name_free(&mut self, name: &str) -> Result<(), AnyError> {
let existing = spanf!(
let existing: Vec<Tunnel> = spanf!(
self.log,
self.log.span("dev-tunnel.rename.search"),
self.client.list_all_tunnels(&TunnelRequestOptions {
tags: vec![VSCODE_CLI_TUNNEL_TAG.to_string(), name.to_string()],
tags: vec![self.tag.to_string(), name.to_string()],
require_all_tags: true,
..Default::default()
})
@@ -629,6 +679,12 @@ impl DevTunnels {
Ok(())
}
fn get_placeholder_name() -> String {
let mut n = clean_hostname_for_tunnel(&gethostname::gethostname().to_string_lossy());
n.make_ascii_lowercase();
n
}
async fn get_name_for_tunnel(
&mut self,
preferred_name: Option<&str>,
@@ -660,10 +716,7 @@ impl DevTunnels {
use_random_name = true;
}
let mut placeholder_name =
clean_hostname_for_tunnel(&gethostname::gethostname().to_string_lossy());
placeholder_name.make_ascii_lowercase();
let mut placeholder_name = Self::get_placeholder_name();
if !is_name_free(&placeholder_name) {
for i in 2.. {
let fixed_name = format!("{}{}", placeholder_name, i);
@@ -705,7 +758,10 @@ impl DevTunnels {
tunnel: ExistingTunnel,
) -> Result<ActiveTunnel, AnyError> {
let tunnel_details = PersistedTunnel {
name: tunnel.tunnel_name,
name: match tunnel.tunnel_name {
Some(n) => n,
None => Self::get_placeholder_name(),
},
id: tunnel.tunnel_id,
cluster: tunnel.cluster,
};
@@ -715,10 +771,23 @@ impl DevTunnels {
tunnel.host_token.clone(),
));
let client = mgmt.into();
self.sync_tunnel_tags(
&client,
&tunnel_details.name,
Tunnel {
cluster_id: Some(tunnel_details.cluster.clone()),
tunnel_id: Some(tunnel_details.id.clone()),
..Default::default()
},
&HOST_TUNNEL_REQUEST_OPTIONS,
)
.await?;
self.start_tunnel(
tunnel_details.locator(),
&tunnel_details,
mgmt.into(),
client,
StaticAccessTokenProvider::new(tunnel.host_token),
)
.await
@@ -998,7 +1067,7 @@ fn clean_hostname_for_tunnel(hostname: &str) -> String {
}
}
fn vec_eq_unsorted(a: &[String], b: &[String]) -> bool {
fn vec_eq_as_set(a: &[String], b: &[String]) -> bool {
if a.len() != b.len() {
return false;
}

View File

@@ -0,0 +1,284 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
use tokio::{
pin,
sync::{mpsc, watch},
};
use crate::{
async_pipe::{socket_stream_split, AsyncPipe},
json_rpc::{new_json_rpc, start_json_rpc},
log,
singleton::SingletonServer,
util::{errors::CodeError, sync::Barrier},
};
use super::{
dev_tunnels::ActiveTunnel,
protocol::{
self,
forward_singleton::{PortList, SetPortsResponse},
},
shutdown_signal::ShutdownSignal,
};
type PortMap = HashMap<u16, u32>;
/// The PortForwardingHandle is given out to multiple consumers to allow
/// them to set_ports that they want to be forwarded.
struct PortForwardingSender {
/// Todo: when `SyncUnsafeCell` is no longer nightly, we can use it here with
/// the following comment:
///
/// SyncUnsafeCell is used and safe here because PortForwardingSender is used
/// exclusively in synchronous dispatch *and* we create a new sender in the
/// context for each connection, in `serve_singleton_rpc`.
///
/// If PortForwardingSender is ever used in a different context, this should
/// be refactored, e.g. to use locks or `&mut self` in set_ports`
///
/// see https://doc.rust-lang.org/stable/std/cell/struct.SyncUnsafeCell.html
current: Mutex<PortList>,
sender: Arc<Mutex<watch::Sender<PortMap>>>,
}
impl PortForwardingSender {
pub fn set_ports(&self, ports: PortList) {
let mut current = self.current.lock().unwrap();
self.sender.lock().unwrap().send_modify(|v| {
for p in current.iter() {
if !ports.contains(p) {
match v.get(p) {
Some(1) => {
v.remove(p);
}
Some(n) => {
v.insert(*p, n - 1);
}
None => unreachable!("removed port not in map"),
}
}
}
for p in ports.iter() {
if !current.contains(p) {
match v.get(p) {
Some(n) => v.insert(*p, n + 1),
None => v.insert(*p, 1),
};
}
}
current.splice(.., ports);
});
}
}
impl Clone for PortForwardingSender {
fn clone(&self) -> Self {
Self {
current: Mutex::new(vec![]),
sender: self.sender.clone(),
}
}
}
impl Drop for PortForwardingSender {
fn drop(&mut self) {
self.set_ports(vec![]);
}
}
struct PortForwardingReceiver {
receiver: watch::Receiver<PortMap>,
}
impl PortForwardingReceiver {
pub fn new() -> (PortForwardingSender, Self) {
let (sender, receiver) = watch::channel(HashMap::new());
let handle = PortForwardingSender {
current: Mutex::new(vec![]),
sender: Arc::new(Mutex::new(sender)),
};
let tracker = Self { receiver };
(handle, tracker)
}
/// Applies all changes from PortForwardingHandles to the tunnel.
pub async fn apply_to(&mut self, log: log::Logger, tunnel: Arc<ActiveTunnel>) {
let mut current = vec![];
while self.receiver.changed().await.is_ok() {
let next = self.receiver.borrow().keys().copied().collect::<Vec<_>>();
for p in current.iter() {
if !next.contains(p) {
match tunnel.remove_port(*p).await {
Ok(_) => info!(log, "stopped forwarding port {}", p),
Err(e) => error!(log, "failed to stop forwarding port {}: {}", p, e),
}
}
}
for p in next.iter() {
if !current.contains(p) {
match tunnel.add_port_tcp(*p).await {
Ok(_) => info!(log, "forwarding port {}", p),
Err(e) => error!(log, "failed to forward port {}: {}", p, e),
}
}
}
current = next;
}
}
}
pub struct SingletonClientArgs {
pub log: log::Logger,
pub stream: AsyncPipe,
pub shutdown: Barrier<ShutdownSignal>,
pub port_requests: watch::Receiver<PortList>,
}
#[derive(Clone)]
struct SingletonServerContext {
log: log::Logger,
handle: PortForwardingSender,
tunnel: Arc<ActiveTunnel>,
}
/// Serves a client singleton for port forwarding.
pub async fn client(args: SingletonClientArgs) -> Result<(), std::io::Error> {
let mut rpc = new_json_rpc();
let (msg_tx, msg_rx) = mpsc::unbounded_channel();
let SingletonClientArgs {
log,
shutdown,
stream,
mut port_requests,
} = args;
debug!(
log,
"An existing port forwarding process is running on this machine, connecting to it..."
);
let caller = rpc.get_caller(msg_tx);
let rpc = rpc.methods(()).build(log.clone());
let (read, write) = socket_stream_split(stream);
let serve = start_json_rpc(rpc, read, write, msg_rx, shutdown);
let forward = async move {
while port_requests.changed().await.is_ok() {
let ports = port_requests.borrow().clone();
let r = caller
.call::<_, _, protocol::forward_singleton::SetPortsResponse>(
protocol::forward_singleton::METHOD_SET_PORTS,
protocol::forward_singleton::SetPortsParams { ports },
)
.await
.unwrap();
match r {
Err(e) => error!(log, "failed to set ports: {:?}", e),
Ok(r) => print_forwarding_addr(&r),
};
}
};
tokio::select! {
r = serve => r.map(|_| ()),
_ = forward => Ok(()),
}
}
/// Serves a port-forwarding singleton.
pub async fn server(
log: log::Logger,
tunnel: ActiveTunnel,
server: SingletonServer,
mut port_requests: watch::Receiver<PortList>,
shutdown_rx: Barrier<ShutdownSignal>,
) -> Result<(), CodeError> {
let tunnel = Arc::new(tunnel);
let (forward_tx, mut forward_rx) = PortForwardingReceiver::new();
let forward_own_tunnel = tunnel.clone();
let forward_own_tx = forward_tx.clone();
let forward_own = async move {
while port_requests.changed().await.is_ok() {
forward_own_tx.set_ports(port_requests.borrow().clone());
print_forwarding_addr(&SetPortsResponse {
port_format: forward_own_tunnel.get_port_format().ok(),
});
}
};
tokio::select! {
_ = forward_own => Ok(()),
_ = forward_rx.apply_to(log.clone(), tunnel.clone()) => Ok(()),
r = serve_singleton_rpc(server, log, tunnel, forward_tx, shutdown_rx) => r,
}
}
async fn serve_singleton_rpc(
mut server: SingletonServer,
log: log::Logger,
tunnel: Arc<ActiveTunnel>,
forward_tx: PortForwardingSender,
shutdown_rx: Barrier<ShutdownSignal>,
) -> Result<(), CodeError> {
let mut own_shutdown = shutdown_rx.clone();
let shutdown_fut = own_shutdown.wait();
pin!(shutdown_fut);
loop {
let cnx = tokio::select! {
c = server.accept() => c?,
_ = &mut shutdown_fut => return Ok(()),
};
let (read, write) = socket_stream_split(cnx);
let shutdown_rx = shutdown_rx.clone();
let handle = forward_tx.clone();
let log = log.clone();
let tunnel = tunnel.clone();
tokio::spawn(async move {
// we make an rpc for the connection instead of re-using a dispatcher
// so that we can have the "handle" drop when the connection drops.
let rpc = new_json_rpc();
let mut rpc = rpc.methods(SingletonServerContext {
log: log.clone(),
handle,
tunnel,
});
rpc.register_sync(
protocol::forward_singleton::METHOD_SET_PORTS,
|p: protocol::forward_singleton::SetPortsParams, ctx| {
info!(ctx.log, "client setting ports to {:?}", p.ports);
ctx.handle.set_ports(p.ports);
Ok(SetPortsResponse {
port_format: ctx.tunnel.get_port_format().ok(),
})
},
);
let _ = start_json_rpc(rpc.build(log), read, write, (), shutdown_rx).await;
});
}
}
fn print_forwarding_addr(r: &SetPortsResponse) {
eprintln!("{}\n", serde_json::to_string(r).unwrap());
}

View File

@@ -91,10 +91,15 @@ impl InstalledServer {
pub fn server_paths(&self, p: &LauncherPaths) -> ServerPaths {
let server_dir = self.get_install_folder(p);
ServerPaths {
executable: server_dir
.join(SERVER_FOLDER_NAME)
.join("bin")
.join(self.quality.server_entrypoint()),
// allow using the OSS server in development via an override
executable: if let Some(p) = option_env!("VSCODE_CLI_OVERRIDE_SERVER_PATH") {
PathBuf::from(p)
} else {
server_dir
.join(SERVER_FOLDER_NAME)
.join("bin")
.join(self.quality.server_entrypoint())
},
logfile: server_dir.join("log.txt"),
pidfile: server_dir.join("pid.txt"),
server_dir,

View File

@@ -91,7 +91,7 @@ impl PortForwardingProcessor {
self.forwarded.insert(port);
}
tunnel.get_port_uri(port).await
tunnel.get_port_uri(port)
}
}

View File

@@ -199,6 +199,11 @@ pub struct SpawnResult {
pub const METHOD_CHALLENGE_ISSUE: &str = "challenge_issue";
pub const METHOD_CHALLENGE_VERIFY: &str = "challenge_verify";
#[derive(Serialize, Deserialize)]
pub struct ChallengeIssueParams {
pub token: Option<String>,
}
#[derive(Serialize, Deserialize)]
pub struct ChallengeIssueResponse {
pub challenge: String,
@@ -209,6 +214,24 @@ pub struct ChallengeVerifyParams {
pub response: String,
}
pub mod forward_singleton {
use serde::{Deserialize, Serialize};
pub const METHOD_SET_PORTS: &str = "set_ports";
pub type PortList = Vec<u16>;
#[derive(Serialize, Deserialize)]
pub struct SetPortsParams {
pub ports: PortList,
}
#[derive(Serialize, Deserialize)]
pub struct SetPortsResponse {
pub port_format: Option<String>,
}
}
pub mod singleton {
use crate::log;
use serde::{Deserialize, Serialize};

View File

@@ -78,6 +78,7 @@ impl CliServiceManager for WindowsService {
cmd.stderr(Stdio::null());
cmd.stdout(Stdio::null());
cmd.stdin(Stdio::null());
cmd.creation_flags(CREATE_NEW_PROCESS_GROUP | DETACHED_PROCESS);
cmd.spawn()
.map_err(|e| wrapdbg(e, "error starting service"))?;
@@ -121,8 +122,12 @@ impl CliServiceManager for WindowsService {
async fn unregister(&self) -> Result<(), AnyError> {
let key = WindowsService::open_key()?;
key.delete_value(TUNNEL_ACTIVITY_NAME)
.map_err(|e| AnyError::from(wrap(e, "error deleting registry key")))?;
match key.delete_value(TUNNEL_ACTIVITY_NAME) {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
Err(e) => return Err(wrap(e, "error deleting registry key").into()),
}
info!(self.log, "Tunnel service uninstalled");
let r = do_single_rpc_call::<_, ()>(

View File

@@ -217,7 +217,7 @@ impl BroadcastLogSink {
}
}
fn get_brocaster(&self) -> broadcast::Sender<Vec<u8>> {
pub fn get_brocaster(&self) -> broadcast::Sender<Vec<u8>> {
self.tx.clone()
}

View File

@@ -509,8 +509,14 @@ pub enum CodeError {
ServerAuthRequired,
#[error("challenge not yet issued")]
AuthChallengeNotIssued,
#[error("challenge token is invalid")]
AuthChallengeBadToken,
#[error("unauthorized client refused")]
AuthMismatch,
#[error("keyring communication timed out after 5s")]
KeyringTimeout,
#[error("no host is connected to the tunnel relay")]
NoTunnelEndpoint,
}
makeAnyError!(