mirror of
https://github.com/microsoft/vscode.git
synced 2026-04-28 12:33:35 +01:00
Merge branch 'main' into cli-ensure-code-tunnel-service-remains-headless-on-windows
This commit is contained in:
@@ -4,7 +4,10 @@
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
use crate::{constants::APPLICATION_NAME, util::errors::CodeError};
|
||||
use async_trait::async_trait;
|
||||
use std::path::{Path, PathBuf};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpListener;
|
||||
use uuid::Uuid;
|
||||
|
||||
// todo: we could probably abstract this into some crate, if one doesn't already exist
|
||||
@@ -39,7 +42,7 @@ cfg_if::cfg_if! {
|
||||
pipe.into_split()
|
||||
}
|
||||
} else {
|
||||
use tokio::{time::sleep, io::{AsyncRead, AsyncWrite, ReadBuf}};
|
||||
use tokio::{time::sleep, io::ReadBuf};
|
||||
use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions, NamedPipeClient, NamedPipeServer};
|
||||
use std::{time::Duration, pin::Pin, task::{Context, Poll}, io};
|
||||
use pin_project::pin_project;
|
||||
@@ -181,3 +184,34 @@ pub fn get_socket_name() -> PathBuf {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type AcceptedRW = (
|
||||
Box<dyn AsyncRead + Send + Unpin>,
|
||||
Box<dyn AsyncWrite + Send + Unpin>,
|
||||
);
|
||||
|
||||
#[async_trait]
|
||||
pub trait AsyncRWAccepter {
|
||||
async fn accept_rw(&mut self) -> Result<AcceptedRW, CodeError>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AsyncRWAccepter for AsyncPipeListener {
|
||||
async fn accept_rw(&mut self) -> Result<AcceptedRW, CodeError> {
|
||||
let pipe = self.accept().await?;
|
||||
let (read, write) = socket_stream_split(pipe);
|
||||
Ok((Box::new(read), Box::new(write)))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AsyncRWAccepter for TcpListener {
|
||||
async fn accept_rw(&mut self) -> Result<AcceptedRW, CodeError> {
|
||||
let (stream, _) = self
|
||||
.accept()
|
||||
.await
|
||||
.map_err(CodeError::AsyncPipeListenerFailed)?;
|
||||
let (read, write) = tokio::io::split(stream);
|
||||
Ok((Box::new(read), Box::new(write)))
|
||||
}
|
||||
}
|
||||
|
||||
109
cli/src/auth.rs
109
cli/src/auth.rs
@@ -10,7 +10,8 @@ use crate::{
|
||||
trace,
|
||||
util::{
|
||||
errors::{
|
||||
wrap, AnyError, OAuthError, RefreshTokenNotAvailableError, StatusError, WrappedError,
|
||||
wrap, AnyError, CodeError, OAuthError, RefreshTokenNotAvailableError, StatusError,
|
||||
WrappedError,
|
||||
},
|
||||
input::prompt_options,
|
||||
},
|
||||
@@ -160,6 +161,7 @@ impl StoredCredential {
|
||||
|
||||
struct StorageWithLastRead {
|
||||
storage: Box<dyn StorageImplementation>,
|
||||
fallback_storage: Option<FileStorage>,
|
||||
last_read: Cell<Result<Option<StoredCredential>, WrappedError>>,
|
||||
}
|
||||
|
||||
@@ -172,9 +174,9 @@ pub struct Auth {
|
||||
}
|
||||
|
||||
trait StorageImplementation: Send + Sync {
|
||||
fn read(&mut self) -> Result<Option<StoredCredential>, WrappedError>;
|
||||
fn store(&mut self, value: StoredCredential) -> Result<(), WrappedError>;
|
||||
fn clear(&mut self) -> Result<(), WrappedError>;
|
||||
fn read(&mut self) -> Result<Option<StoredCredential>, AnyError>;
|
||||
fn store(&mut self, value: StoredCredential) -> Result<(), AnyError>;
|
||||
fn clear(&mut self) -> Result<(), AnyError>;
|
||||
}
|
||||
|
||||
// unseal decrypts and deserializes the value
|
||||
@@ -217,16 +219,34 @@ struct ThreadKeyringStorage {
|
||||
}
|
||||
|
||||
impl ThreadKeyringStorage {
|
||||
fn thread_op<R, Fn>(&mut self, f: Fn) -> R
|
||||
fn thread_op<R, Fn>(&mut self, f: Fn) -> Result<R, AnyError>
|
||||
where
|
||||
Fn: 'static + Send + FnOnce(&mut KeyringStorage) -> R,
|
||||
Fn: 'static + Send + FnOnce(&mut KeyringStorage) -> Result<R, AnyError>,
|
||||
R: 'static + Send,
|
||||
{
|
||||
let mut s = self.s.take().unwrap();
|
||||
let handler = thread::spawn(move || (f(&mut s), s));
|
||||
let (r, s) = handler.join().unwrap();
|
||||
self.s = Some(s);
|
||||
r
|
||||
let mut s = match self.s.take() {
|
||||
Some(s) => s,
|
||||
None => return Err(CodeError::KeyringTimeout.into()),
|
||||
};
|
||||
|
||||
// It seems like on Linux communication to the keyring can block indefinitely.
|
||||
// Fall back after a 5 second timeout.
|
||||
let (sender, receiver) = std::sync::mpsc::channel();
|
||||
let tsender = sender.clone();
|
||||
|
||||
thread::spawn(move || sender.send(Some((f(&mut s), s))));
|
||||
thread::spawn(move || {
|
||||
thread::sleep(std::time::Duration::from_secs(5));
|
||||
let _ = tsender.send(None);
|
||||
});
|
||||
|
||||
match receiver.recv().unwrap() {
|
||||
Some((r, s)) => {
|
||||
self.s = Some(s);
|
||||
r
|
||||
}
|
||||
None => Err(CodeError::KeyringTimeout.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -239,15 +259,15 @@ impl Default for ThreadKeyringStorage {
|
||||
}
|
||||
|
||||
impl StorageImplementation for ThreadKeyringStorage {
|
||||
fn read(&mut self) -> Result<Option<StoredCredential>, WrappedError> {
|
||||
fn read(&mut self) -> Result<Option<StoredCredential>, AnyError> {
|
||||
self.thread_op(|s| s.read())
|
||||
}
|
||||
|
||||
fn store(&mut self, value: StoredCredential) -> Result<(), WrappedError> {
|
||||
fn store(&mut self, value: StoredCredential) -> Result<(), AnyError> {
|
||||
self.thread_op(move |s| s.store(value))
|
||||
}
|
||||
|
||||
fn clear(&mut self) -> Result<(), WrappedError> {
|
||||
fn clear(&mut self) -> Result<(), AnyError> {
|
||||
self.thread_op(|s| s.clear())
|
||||
}
|
||||
}
|
||||
@@ -273,7 +293,7 @@ macro_rules! get_next_entry {
|
||||
}
|
||||
|
||||
impl StorageImplementation for KeyringStorage {
|
||||
fn read(&mut self) -> Result<Option<StoredCredential>, WrappedError> {
|
||||
fn read(&mut self) -> Result<Option<StoredCredential>, AnyError> {
|
||||
let mut str = String::new();
|
||||
|
||||
for i in 0.. {
|
||||
@@ -281,7 +301,7 @@ impl StorageImplementation for KeyringStorage {
|
||||
let next_chunk = match entry.get_password() {
|
||||
Ok(value) => value,
|
||||
Err(keyring::Error::NoEntry) => return Ok(None), // missing entries?
|
||||
Err(e) => return Err(wrap(e, "error reading keyring")),
|
||||
Err(e) => return Err(wrap(e, "error reading keyring").into()),
|
||||
};
|
||||
|
||||
if next_chunk.ends_with(CONTINUE_MARKER) {
|
||||
@@ -295,7 +315,7 @@ impl StorageImplementation for KeyringStorage {
|
||||
Ok(unseal(&str))
|
||||
}
|
||||
|
||||
fn store(&mut self, value: StoredCredential) -> Result<(), WrappedError> {
|
||||
fn store(&mut self, value: StoredCredential) -> Result<(), AnyError> {
|
||||
let sealed = seal(&value);
|
||||
let step_size = KEYCHAIN_ENTRY_LIMIT - CONTINUE_MARKER.len();
|
||||
|
||||
@@ -312,14 +332,14 @@ impl StorageImplementation for KeyringStorage {
|
||||
};
|
||||
|
||||
if let Err(e) = stored {
|
||||
return Err(wrap(e, "error updating keyring"));
|
||||
return Err(wrap(e, "error updating keyring").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn clear(&mut self) -> Result<(), WrappedError> {
|
||||
fn clear(&mut self) -> Result<(), AnyError> {
|
||||
self.read().ok(); // make sure component parts are available
|
||||
for entry in self.entries.iter() {
|
||||
entry
|
||||
@@ -335,16 +355,16 @@ impl StorageImplementation for KeyringStorage {
|
||||
struct FileStorage(PersistedState<Option<String>>);
|
||||
|
||||
impl StorageImplementation for FileStorage {
|
||||
fn read(&mut self) -> Result<Option<StoredCredential>, WrappedError> {
|
||||
fn read(&mut self) -> Result<Option<StoredCredential>, AnyError> {
|
||||
Ok(self.0.load().and_then(|s| unseal(&s)))
|
||||
}
|
||||
|
||||
fn store(&mut self, value: StoredCredential) -> Result<(), WrappedError> {
|
||||
self.0.save(Some(seal(&value)))
|
||||
fn store(&mut self, value: StoredCredential) -> Result<(), AnyError> {
|
||||
self.0.save(Some(seal(&value))).map_err(|e| e.into())
|
||||
}
|
||||
|
||||
fn clear(&mut self) -> Result<(), WrappedError> {
|
||||
self.0.save(None)
|
||||
fn clear(&mut self) -> Result<(), AnyError> {
|
||||
self.0.save(None).map_err(|e| e.into())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -373,20 +393,32 @@ impl Auth {
|
||||
let mut keyring_storage = ThreadKeyringStorage::default();
|
||||
let mut file_storage = FileStorage(PersistedState::new(self.file_storage_path.clone()));
|
||||
|
||||
let keyring_storage_result = match std::env::var("VSCODE_CLI_USE_FILE_KEYCHAIN") {
|
||||
Ok(_) => Err(wrap("", "user prefers file storage")),
|
||||
_ => keyring_storage.read(),
|
||||
let native_storage_result = if std::env::var("VSCODE_CLI_USE_FILE_KEYCHAIN").is_ok()
|
||||
|| self.file_storage_path.exists()
|
||||
{
|
||||
Err(wrap("", "user prefers file storage").into())
|
||||
} else {
|
||||
keyring_storage.read()
|
||||
};
|
||||
|
||||
let mut storage = match keyring_storage_result {
|
||||
let mut storage = match native_storage_result {
|
||||
Ok(v) => StorageWithLastRead {
|
||||
last_read: Cell::new(Ok(v)),
|
||||
fallback_storage: Some(file_storage),
|
||||
storage: Box::new(keyring_storage),
|
||||
},
|
||||
Err(_) => StorageWithLastRead {
|
||||
last_read: Cell::new(file_storage.read()),
|
||||
storage: Box::new(file_storage),
|
||||
},
|
||||
Err(e) => {
|
||||
debug!(self.log, "Using file keychain storage due to: {}", e);
|
||||
StorageWithLastRead {
|
||||
last_read: Cell::new(
|
||||
file_storage
|
||||
.read()
|
||||
.map_err(|e| wrap(e, "could not read from file storage")),
|
||||
),
|
||||
fallback_storage: None,
|
||||
storage: Box::new(file_storage),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let out = op(&mut storage);
|
||||
@@ -419,7 +451,7 @@ impl Auth {
|
||||
}
|
||||
|
||||
/// Clears login info from the keyring.
|
||||
pub fn clear_credentials(&self) -> Result<(), WrappedError> {
|
||||
pub fn clear_credentials(&self) -> Result<(), AnyError> {
|
||||
self.with_storage(|storage| {
|
||||
storage.storage.clear()?;
|
||||
storage.last_read.set(Ok(None));
|
||||
@@ -505,7 +537,18 @@ impl Auth {
|
||||
"Failed to update keyring with new credentials: {}",
|
||||
e
|
||||
);
|
||||
|
||||
if let Some(fb) = storage.fallback_storage.take() {
|
||||
storage.storage = Box::new(fb);
|
||||
match storage.storage.store(creds.clone()) {
|
||||
Err(e) => {
|
||||
warning!(self.log, "Also failed to update fallback storage: {}", e)
|
||||
}
|
||||
Ok(_) => debug!(self.log, "Updated fallback storage successfully"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
storage.last_read.set(Ok(Some(creds)));
|
||||
})
|
||||
}
|
||||
|
||||
@@ -95,7 +95,9 @@ async fn main() -> Result<(), std::convert::Infallible> {
|
||||
args::VersionSubcommand::Show => version::show(context!()).await,
|
||||
},
|
||||
|
||||
Some(args::Commands::CommandShell) => tunnels::command_shell(context!()).await,
|
||||
Some(args::Commands::CommandShell(cs_args)) => {
|
||||
tunnels::command_shell(context!(), cs_args).await
|
||||
}
|
||||
|
||||
Some(args::Commands::Tunnel(tunnel_args)) => match tunnel_args.subcommand {
|
||||
Some(args::TunnelSubcommand::Prune) => tunnels::prune(context!()).await,
|
||||
@@ -112,6 +114,9 @@ async fn main() -> Result<(), std::convert::Infallible> {
|
||||
Some(args::TunnelSubcommand::Service(service_args)) => {
|
||||
tunnels::service(context_no_logger(), service_args).await
|
||||
}
|
||||
Some(args::TunnelSubcommand::ForwardInternal(forward_args)) => {
|
||||
tunnels::forward(context_no_logger(), forward_args).await
|
||||
}
|
||||
None => tunnels::serve(context_no_logger(), tunnel_args.serve_args).await,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -174,7 +174,20 @@ pub enum Commands {
|
||||
|
||||
/// Runs the control server on process stdin/stdout
|
||||
#[clap(hide = true)]
|
||||
CommandShell,
|
||||
CommandShell(CommandShellArgs),
|
||||
}
|
||||
|
||||
#[derive(Args, Debug, Clone)]
|
||||
pub struct CommandShellArgs {
|
||||
/// Listen on a socket instead of stdin/stdout.
|
||||
#[clap(long)]
|
||||
pub on_socket: bool,
|
||||
/// Listen on a port instead of stdin/stdout.
|
||||
#[clap(long)]
|
||||
pub on_port: bool,
|
||||
/// Require the given token string to be given in the handshake.
|
||||
#[clap(long)]
|
||||
pub require_token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Args, Debug, Clone)]
|
||||
@@ -548,6 +561,7 @@ pub enum OutputFormat {
|
||||
#[derive(Args, Clone, Debug, Default)]
|
||||
pub struct ExistingTunnelArgs {
|
||||
/// Name you'd like to assign preexisting tunnel to use to connect the tunnel
|
||||
/// Old option, new code sohuld just use `--name`.
|
||||
#[clap(long, hide = true)]
|
||||
pub tunnel_name: Option<String>,
|
||||
|
||||
@@ -626,6 +640,10 @@ pub enum TunnelSubcommand {
|
||||
/// (Preview) Manages the tunnel when installed as a system service,
|
||||
#[clap(subcommand)]
|
||||
Service(TunnelServiceSubCommands),
|
||||
|
||||
/// (Preview) Forwards local port using the dev tunnel
|
||||
#[clap(hide = true)]
|
||||
ForwardInternal(TunnelForwardArgs),
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug, Clone)]
|
||||
@@ -649,6 +667,10 @@ pub struct TunnelServiceInstallArgs {
|
||||
/// If set, the user accepts the server license terms and the server will be started without a user prompt.
|
||||
#[clap(long)]
|
||||
pub accept_server_license_terms: bool,
|
||||
|
||||
/// Sets the machine name for port forwarding service
|
||||
#[clap(long)]
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Args, Debug, Clone)]
|
||||
@@ -657,6 +679,16 @@ pub struct TunnelRenameArgs {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Args, Debug, Clone)]
|
||||
pub struct TunnelForwardArgs {
|
||||
/// One or more ports to forward.
|
||||
pub ports: Vec<u16>,
|
||||
|
||||
/// Login args -- used for convenience so the forwarding call is a single action.
|
||||
#[clap(flatten)]
|
||||
pub login: LoginArgs,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug, Clone)]
|
||||
pub enum TunnelUserSubCommands {
|
||||
/// Log in to port forwarding service
|
||||
|
||||
@@ -5,27 +5,37 @@
|
||||
|
||||
use async_trait::async_trait;
|
||||
use base64::{engine::general_purpose as b64, Engine as _};
|
||||
use futures::{stream::FuturesUnordered, StreamExt};
|
||||
use serde::Serialize;
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::{str::FromStr, time::Duration};
|
||||
use sysinfo::Pid;
|
||||
use tokio::{
|
||||
io::{AsyncBufReadExt, BufReader},
|
||||
sync::watch,
|
||||
};
|
||||
|
||||
use super::{
|
||||
args::{
|
||||
AuthProvider, CliCore, ExistingTunnelArgs, TunnelRenameArgs, TunnelServeArgs,
|
||||
TunnelServiceSubCommands, TunnelUserSubCommands,
|
||||
AuthProvider, CliCore, CommandShellArgs, ExistingTunnelArgs, TunnelForwardArgs,
|
||||
TunnelRenameArgs, TunnelServeArgs, TunnelServiceSubCommands, TunnelUserSubCommands,
|
||||
},
|
||||
CommandContext,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
async_pipe::{get_socket_name, listen_socket_rw_stream, AsyncRWAccepter},
|
||||
auth::Auth,
|
||||
constants::{APPLICATION_NAME, TUNNEL_CLI_LOCK_NAME, TUNNEL_SERVICE_LOCK_NAME},
|
||||
constants::{
|
||||
APPLICATION_NAME, CONTROL_PORT, IS_A_TTY, TUNNEL_CLI_LOCK_NAME, TUNNEL_SERVICE_LOCK_NAME,
|
||||
},
|
||||
log,
|
||||
state::LauncherPaths,
|
||||
tunnels::{
|
||||
code_server::CodeServerArgs,
|
||||
create_service_manager, dev_tunnels, legal,
|
||||
create_service_manager,
|
||||
dev_tunnels::{self, DevTunnels},
|
||||
forwarding, legal,
|
||||
paths::get_all_servers,
|
||||
protocol, serve_stream,
|
||||
shutdown_signal::ShutdownRequest,
|
||||
@@ -33,7 +43,7 @@ use crate::{
|
||||
singleton_server::{
|
||||
make_singleton_server, start_singleton_server, BroadcastLogSink, SingletonServerArgs,
|
||||
},
|
||||
Next, ServeStreamParams, ServiceContainer, ServiceManager,
|
||||
AuthRequired, Next, ServeStreamParams, ServiceContainer, ServiceManager,
|
||||
},
|
||||
util::{
|
||||
app_lock::AppMutex,
|
||||
@@ -59,20 +69,31 @@ impl From<AuthProvider> for crate::auth::AuthProvider {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ExistingTunnelArgs> for Option<dev_tunnels::ExistingTunnel> {
|
||||
fn from(d: ExistingTunnelArgs) -> Option<dev_tunnels::ExistingTunnel> {
|
||||
if let (Some(tunnel_id), Some(tunnel_name), Some(cluster), Some(host_token)) =
|
||||
(d.tunnel_id, d.tunnel_name, d.cluster, d.host_token)
|
||||
{
|
||||
fn fulfill_existing_tunnel_args(
|
||||
d: ExistingTunnelArgs,
|
||||
name_arg: &Option<String>,
|
||||
) -> Option<dev_tunnels::ExistingTunnel> {
|
||||
let tunnel_name = d.tunnel_name.or_else(|| name_arg.clone());
|
||||
|
||||
match (d.tunnel_id, d.cluster, d.host_token) {
|
||||
(Some(tunnel_id), None, Some(host_token)) => {
|
||||
let i = tunnel_id.find('.')?;
|
||||
Some(dev_tunnels::ExistingTunnel {
|
||||
tunnel_id,
|
||||
tunnel_id: tunnel_id[..i].to_string(),
|
||||
cluster: tunnel_id[i + 1..].to_string(),
|
||||
tunnel_name,
|
||||
host_token,
|
||||
cluster,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
||||
(Some(tunnel_id), Some(cluster), Some(host_token)) => Some(dev_tunnels::ExistingTunnel {
|
||||
tunnel_id,
|
||||
tunnel_name,
|
||||
host_token,
|
||||
cluster,
|
||||
}),
|
||||
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,23 +130,71 @@ impl ServiceContainer for TunnelServiceContainer {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn command_shell(ctx: CommandContext) -> Result<i32, AnyError> {
|
||||
pub async fn command_shell(ctx: CommandContext, args: CommandShellArgs) -> Result<i32, AnyError> {
|
||||
let platform = PreReqChecker::new().verify().await?;
|
||||
serve_stream(
|
||||
tokio::io::stdin(),
|
||||
tokio::io::stderr(),
|
||||
ServeStreamParams {
|
||||
log: ctx.log,
|
||||
launcher_paths: ctx.paths,
|
||||
platform,
|
||||
requires_auth: true,
|
||||
exit_barrier: ShutdownRequest::create_rx([ShutdownRequest::CtrlC]),
|
||||
code_server_args: (&ctx.args).into(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
let mut params = ServeStreamParams {
|
||||
log: ctx.log,
|
||||
launcher_paths: ctx.paths,
|
||||
platform,
|
||||
requires_auth: args
|
||||
.require_token
|
||||
.map(AuthRequired::VSDAWithToken)
|
||||
.unwrap_or(AuthRequired::VSDA),
|
||||
exit_barrier: ShutdownRequest::create_rx([ShutdownRequest::CtrlC]),
|
||||
code_server_args: (&ctx.args).into(),
|
||||
};
|
||||
|
||||
Ok(0)
|
||||
let mut listener: Box<dyn AsyncRWAccepter> = match (args.on_port, args.on_socket) {
|
||||
(_, true) => {
|
||||
let socket = get_socket_name();
|
||||
let listener = listen_socket_rw_stream(&socket)
|
||||
.await
|
||||
.map_err(|e| wrap(e, "error listening on socket"))?;
|
||||
|
||||
params
|
||||
.log
|
||||
.result(format!("Listening on {}", socket.display()));
|
||||
|
||||
Box::new(listener)
|
||||
}
|
||||
(true, _) => {
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.map_err(|e| wrap(e, "error listening on port"))?;
|
||||
|
||||
params
|
||||
.log
|
||||
.result(format!("Listening on {}", listener.local_addr().unwrap()));
|
||||
|
||||
Box::new(listener)
|
||||
}
|
||||
_ => {
|
||||
serve_stream(tokio::io::stdin(), tokio::io::stderr(), params).await;
|
||||
return Ok(0);
|
||||
}
|
||||
};
|
||||
|
||||
let mut servers = FuturesUnordered::new();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(_) = servers.next() => {},
|
||||
socket = listener.accept_rw() => {
|
||||
match socket {
|
||||
Ok((read, write)) => servers.push(serve_stream(read, write, params.clone())),
|
||||
Err(e) => {
|
||||
error!(params.log, &format!("Error accepting connection: {}", e));
|
||||
return Ok(1);
|
||||
}
|
||||
}
|
||||
},
|
||||
_ = params.exit_barrier.wait() => {
|
||||
// wait for all servers to finish up:
|
||||
while (servers.next().await).is_some() { }
|
||||
return Ok(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn service(
|
||||
@@ -135,10 +204,17 @@ pub async fn service(
|
||||
let manager = create_service_manager(ctx.log.clone(), &ctx.paths);
|
||||
match service_args {
|
||||
TunnelServiceSubCommands::Install(args) => {
|
||||
// ensure logged in, otherwise subsequent serving will fail
|
||||
Auth::new(&ctx.paths, ctx.log.clone())
|
||||
.get_credential()
|
||||
.await?;
|
||||
let auth = Auth::new(&ctx.paths, ctx.log.clone());
|
||||
|
||||
if let Some(name) = &args.name {
|
||||
// ensure the name matches, and tunnel exists
|
||||
dev_tunnels::DevTunnels::new_remote_tunnel(&ctx.log, auth, &ctx.paths)
|
||||
.rename_tunnel(name)
|
||||
.await?;
|
||||
} else {
|
||||
// still ensure they're logged in, otherwise subsequent serving will fail
|
||||
auth.get_credential().await?;
|
||||
}
|
||||
|
||||
// likewise for license consent
|
||||
legal::require_consent(&ctx.paths, args.accept_server_license_terms)?;
|
||||
@@ -203,23 +279,23 @@ pub async fn user(ctx: CommandContext, user_args: TunnelUserSubCommands) -> Resu
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
/// Remove the tunnel used by this gateway, if any.
|
||||
/// Remove the tunnel used by this tunnel, if any.
|
||||
pub async fn rename(ctx: CommandContext, rename_args: TunnelRenameArgs) -> Result<i32, AnyError> {
|
||||
let auth = Auth::new(&ctx.paths, ctx.log.clone());
|
||||
let mut dt = dev_tunnels::DevTunnels::new(&ctx.log, auth, &ctx.paths);
|
||||
let mut dt = dev_tunnels::DevTunnels::new_remote_tunnel(&ctx.log, auth, &ctx.paths);
|
||||
dt.rename_tunnel(&rename_args.name).await?;
|
||||
ctx.log.result(format!(
|
||||
"Successfully renamed this gateway to {}",
|
||||
"Successfully renamed this tunnel to {}",
|
||||
&rename_args.name
|
||||
));
|
||||
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
/// Remove the tunnel used by this gateway, if any.
|
||||
/// Remove the tunnel used by this tunnel, if any.
|
||||
pub async fn unregister(ctx: CommandContext) -> Result<i32, AnyError> {
|
||||
let auth = Auth::new(&ctx.paths, ctx.log.clone());
|
||||
let mut dt = dev_tunnels::DevTunnels::new(&ctx.log, auth, &ctx.paths);
|
||||
let mut dt = dev_tunnels::DevTunnels::new_remote_tunnel(&ctx.log, auth, &ctx.paths);
|
||||
dt.remove_tunnel().await?;
|
||||
Ok(0)
|
||||
}
|
||||
@@ -327,6 +403,88 @@ pub async fn serve(ctx: CommandContext, gateway_args: TunnelServeArgs) -> Result
|
||||
result
|
||||
}
|
||||
|
||||
/// Internal command used by port forwarding. It reads requests for forwarded ports
|
||||
/// on lines from stdin, as JSON. It uses singleton logic as well (though on
|
||||
/// a different tunnel than the main one used for the control server) so that
|
||||
/// all forward requests on a single machine go through a single hosted tunnel
|
||||
/// process. Without singleton logic, requests could get routed to processes
|
||||
/// that aren't forwarding a given port and then fail.
|
||||
pub async fn forward(
|
||||
ctx: CommandContext,
|
||||
mut forward_args: TunnelForwardArgs,
|
||||
) -> Result<i32, AnyError> {
|
||||
// Spooky: check IS_A_TTY before starting the stdin reader, since IS_A_TTY will
|
||||
// access stdin but a lock will later be held on stdin by the line-reader.
|
||||
if *IS_A_TTY {
|
||||
trace!(ctx.log, "port forwarding is an internal preview feature");
|
||||
}
|
||||
|
||||
// #region stdin reading logic:
|
||||
let (own_ports_tx, own_ports_rx) = watch::channel(vec![]);
|
||||
let ports_process_log = ctx.log.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(tokio::io::stdin()).lines();
|
||||
while let Ok(Some(line)) = lines.next_line().await {
|
||||
match serde_json::from_str(&line) {
|
||||
Ok(p) => {
|
||||
let _ = own_ports_tx.send(p);
|
||||
}
|
||||
Err(e) => warning!(ports_process_log, "error parsing ports: {}", e),
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// #region singleton acquisition
|
||||
let shutdown = ShutdownRequest::create_rx([ShutdownRequest::CtrlC]);
|
||||
let server = loop {
|
||||
if shutdown.is_open() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
match acquire_singleton(&ctx.paths.forwarding_lockfile()).await {
|
||||
Ok(SingletonConnection::Client(stream)) => {
|
||||
debug!(ctx.log, "starting as client to singleton");
|
||||
let r = forwarding::client(forwarding::SingletonClientArgs {
|
||||
log: ctx.log.clone(),
|
||||
shutdown: shutdown.clone(),
|
||||
stream,
|
||||
port_requests: own_ports_rx.clone(),
|
||||
})
|
||||
.await;
|
||||
if let Err(e) = r {
|
||||
warning!(ctx.log, "error contacting forwarding singleton: {}", e);
|
||||
}
|
||||
}
|
||||
Ok(SingletonConnection::Singleton(server)) => break server,
|
||||
Err(e) => {
|
||||
warning!(ctx.log, "error access singleton, retrying: {}", e);
|
||||
tokio::time::sleep(Duration::from_secs(2)).await
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// #region singleton handler
|
||||
let auth = Auth::new(&ctx.paths, ctx.log.clone());
|
||||
println!("preauth {:?}", forward_args.login);
|
||||
if let (Some(p), Some(at)) = (
|
||||
forward_args.login.provider.take(),
|
||||
forward_args.login.access_token.take(),
|
||||
) {
|
||||
auth.login(Some(p.into()), Some(at)).await?;
|
||||
}
|
||||
println!("auth done");
|
||||
|
||||
let mut tunnels = DevTunnels::new_port_forwarding(&ctx.log, auth, &ctx.paths);
|
||||
let tunnel = tunnels
|
||||
.start_new_launcher_tunnel(None, true, &forward_args.ports)
|
||||
.await?;
|
||||
println!("made tunnel");
|
||||
|
||||
forwarding::server(ctx.log, tunnel, server, own_ports_rx, shutdown).await?;
|
||||
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
fn get_connection_token(tunnel: &ActiveTunnel) -> String {
|
||||
let mut hash = Sha256::new();
|
||||
hash.update(tunnel.id.as_bytes());
|
||||
@@ -374,7 +532,7 @@ async fn serve_with_csa(
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
match acquire_singleton(paths.tunnel_lockfile()).await {
|
||||
match acquire_singleton(&paths.tunnel_lockfile()).await {
|
||||
Ok(SingletonConnection::Client(stream)) => {
|
||||
debug!(log, "starting as client to singleton");
|
||||
let should_exit = start_singleton_client(SingletonClientArgs {
|
||||
@@ -403,13 +561,19 @@ async fn serve_with_csa(
|
||||
let _lock = app_mutex_name.map(AppMutex::new);
|
||||
|
||||
let auth = Auth::new(&paths, log.clone());
|
||||
let mut dt = dev_tunnels::DevTunnels::new(&log, auth, &paths);
|
||||
let mut dt = dev_tunnels::DevTunnels::new_remote_tunnel(&log, auth, &paths);
|
||||
loop {
|
||||
let tunnel = if let Some(d) = gateway_args.tunnel.clone().into() {
|
||||
dt.start_existing_tunnel(d).await
|
||||
let tunnel = if let Some(t) =
|
||||
fulfill_existing_tunnel_args(gateway_args.tunnel.clone(), &gateway_args.name)
|
||||
{
|
||||
dt.start_existing_tunnel(t).await
|
||||
} else {
|
||||
dt.start_new_launcher_tunnel(gateway_args.name.as_deref(), gateway_args.random_name)
|
||||
.await
|
||||
dt.start_new_launcher_tunnel(
|
||||
gateway_args.name.as_deref(),
|
||||
gateway_args.random_name,
|
||||
&[CONTROL_PORT],
|
||||
)
|
||||
.await
|
||||
}?;
|
||||
|
||||
csa.connection_token = Some(get_connection_token(&tunnel));
|
||||
|
||||
@@ -122,7 +122,7 @@ pub struct MsgPackCodec<T> {
|
||||
impl<T> MsgPackCodec<T> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
_marker: std::marker::PhantomData::default(),
|
||||
_marker: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,7 +117,7 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
|
||||
Ok(p) => p,
|
||||
Err(err) => {
|
||||
return id.map(|id| {
|
||||
serial.serialize(&ErrorResponse {
|
||||
serial.serialize(ErrorResponse {
|
||||
id,
|
||||
error: ResponseError {
|
||||
code: 0,
|
||||
@@ -131,7 +131,7 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
|
||||
match callback(param.params, &context) {
|
||||
Ok(result) => id.map(|id| serial.serialize(&SuccessResponse { id, result })),
|
||||
Err(err) => id.map(|id| {
|
||||
serial.serialize(&ErrorResponse {
|
||||
serial.serialize(ErrorResponse {
|
||||
id,
|
||||
error: ResponseError {
|
||||
code: -1,
|
||||
@@ -161,7 +161,7 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
|
||||
Ok(p) => p,
|
||||
Err(err) => {
|
||||
return future::ready(id.map(|id| {
|
||||
serial.serialize(&ErrorResponse {
|
||||
serial.serialize(ErrorResponse {
|
||||
id,
|
||||
error: ResponseError {
|
||||
code: 0,
|
||||
@@ -182,7 +182,7 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
|
||||
id.map(|id| serial.serialize(&SuccessResponse { id, result }))
|
||||
}
|
||||
Err(err) => id.map(|id| {
|
||||
serial.serialize(&ErrorResponse {
|
||||
serial.serialize(ErrorResponse {
|
||||
id,
|
||||
error: ResponseError {
|
||||
code: -1,
|
||||
@@ -222,7 +222,7 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
|
||||
return (
|
||||
None,
|
||||
future::ready(id.map(|id| {
|
||||
serial.serialize(&ErrorResponse {
|
||||
serial.serialize(ErrorResponse {
|
||||
id,
|
||||
error: ResponseError {
|
||||
code: 0,
|
||||
@@ -255,7 +255,7 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
|
||||
match callback(servers, param.params, context).await {
|
||||
Ok(r) => id.map(|id| serial.serialize(&SuccessResponse { id, result: r })),
|
||||
Err(err) => id.map(|id| {
|
||||
serial.serialize(&ErrorResponse {
|
||||
serial.serialize(ErrorResponse {
|
||||
id,
|
||||
error: ResponseError {
|
||||
code: -1,
|
||||
@@ -427,7 +427,7 @@ impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
|
||||
Some(Method::Async(callback)) => MaybeSync::Future(callback(id, body)),
|
||||
Some(Method::Duplex(callback)) => MaybeSync::Stream(callback(id, body)),
|
||||
None => MaybeSync::Sync(id.map(|id| {
|
||||
self.serializer.serialize(&ErrorResponse {
|
||||
self.serializer.serialize(ErrorResponse {
|
||||
id,
|
||||
error: ResponseError {
|
||||
code: -1,
|
||||
|
||||
@@ -53,12 +53,12 @@ struct LockFileMatter {
|
||||
|
||||
/// Tries to acquire the singleton homed at the given lock file, either starting
|
||||
/// a new singleton if it doesn't exist, or connecting otherwise.
|
||||
pub async fn acquire_singleton(lock_file: PathBuf) -> Result<SingletonConnection, CodeError> {
|
||||
pub async fn acquire_singleton(lock_file: &Path) -> Result<SingletonConnection, CodeError> {
|
||||
let file = OpenOptions::new()
|
||||
.read(true)
|
||||
.write(true)
|
||||
.create(true)
|
||||
.open(&lock_file)
|
||||
.open(lock_file)
|
||||
.map_err(CodeError::SingletonLockfileOpenFailed)?;
|
||||
|
||||
match FileLock::acquire(file) {
|
||||
@@ -158,7 +158,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_acquires_singleton() {
|
||||
let dir = tempfile::tempdir().expect("expected to make temp dir");
|
||||
let s = acquire_singleton(dir.path().join("lock"))
|
||||
let s = acquire_singleton(&dir.path().join("lock"))
|
||||
.await
|
||||
.expect("expected to acquire");
|
||||
|
||||
@@ -172,7 +172,7 @@ mod tests {
|
||||
async fn test_acquires_client() {
|
||||
let dir = tempfile::tempdir().expect("expected to make temp dir");
|
||||
let lockfile = dir.path().join("lock");
|
||||
let s1 = acquire_singleton(lockfile.clone())
|
||||
let s1 = acquire_singleton(&lockfile)
|
||||
.await
|
||||
.expect("expected to acquire1");
|
||||
match s1 {
|
||||
@@ -182,7 +182,7 @@ mod tests {
|
||||
_ => panic!("expected to be singleton"),
|
||||
};
|
||||
|
||||
let s2 = acquire_singleton(lockfile)
|
||||
let s2 = acquire_singleton(&lockfile)
|
||||
.await
|
||||
.expect("expected to acquire2");
|
||||
match s2 {
|
||||
|
||||
@@ -187,6 +187,14 @@ impl LauncherPaths {
|
||||
))
|
||||
}
|
||||
|
||||
/// Lockfile for port forwarding
|
||||
pub fn forwarding_lockfile(&self) -> PathBuf {
|
||||
self.root.join(format!(
|
||||
"forwarding-{}.lock",
|
||||
VSCODE_CLI_QUALITY.unwrap_or("oss")
|
||||
))
|
||||
}
|
||||
|
||||
/// Suggested path for tunnel service logs, when using file logs
|
||||
pub fn service_log_file(&self) -> PathBuf {
|
||||
self.root.join("tunnel-service.log")
|
||||
|
||||
@@ -11,6 +11,7 @@ pub mod protocol;
|
||||
pub mod shutdown_signal;
|
||||
pub mod singleton_client;
|
||||
pub mod singleton_server;
|
||||
pub mod forwarding;
|
||||
|
||||
mod wsl_detect;
|
||||
mod challenge;
|
||||
@@ -34,7 +35,7 @@ mod service_macos;
|
||||
mod service_windows;
|
||||
mod socket_signal;
|
||||
|
||||
pub use control_server::{serve, serve_stream, Next, ServeStreamParams};
|
||||
pub use control_server::{serve, serve_stream, Next, ServeStreamParams, AuthRequired};
|
||||
pub use nosleep::SleepInhibitor;
|
||||
pub use service::{
|
||||
create_service_manager, ServiceContainer, ServiceManager, SERVICE_LOG_FILE_NAME,
|
||||
|
||||
@@ -575,7 +575,17 @@ impl<'a> ServerBuilder<'a> {
|
||||
}
|
||||
|
||||
fn get_base_command(&self) -> Command {
|
||||
#[cfg(not(windows))]
|
||||
let mut cmd = Command::new(&self.server_paths.executable);
|
||||
#[cfg(windows)]
|
||||
let mut cmd = {
|
||||
let mut cmd = Command::new("cmd");
|
||||
cmd.arg("/Q");
|
||||
cmd.arg("/C");
|
||||
cmd.arg(&self.server_paths.executable);
|
||||
cmd
|
||||
};
|
||||
|
||||
cmd.stdin(std::process::Stdio::null())
|
||||
.args(self.server_params.code_server_args.command_arguments());
|
||||
cmd
|
||||
|
||||
@@ -48,11 +48,11 @@ use super::dev_tunnels::ActiveTunnel;
|
||||
use super::paths::prune_stopped_servers;
|
||||
use super::port_forwarder::{PortForwarding, PortForwardingProcessor};
|
||||
use super::protocol::{
|
||||
AcquireCliParams, CallServerHttpParams, CallServerHttpResult, ChallengeIssueResponse,
|
||||
ChallengeVerifyParams, ClientRequestMethod, EmptyObject, ForwardParams, ForwardResult,
|
||||
FsStatRequest, FsStatResponse, GetEnvResponse, GetHostnameResponse, HttpBodyParams,
|
||||
HttpHeadersParams, ServeParams, ServerLog, ServerMessageParams, SpawnParams, SpawnResult,
|
||||
ToClientRequest, UnforwardParams, UpdateParams, UpdateResult, VersionResponse,
|
||||
AcquireCliParams, CallServerHttpParams, CallServerHttpResult, ChallengeIssueParams,
|
||||
ChallengeIssueResponse, ChallengeVerifyParams, ClientRequestMethod, EmptyObject, ForwardParams,
|
||||
ForwardResult, FsStatRequest, FsStatResponse, GetEnvResponse, GetHostnameResponse,
|
||||
HttpBodyParams, HttpHeadersParams, ServeParams, ServerLog, ServerMessageParams, SpawnParams,
|
||||
SpawnResult, ToClientRequest, UnforwardParams, UpdateParams, UpdateResult, VersionResponse,
|
||||
METHOD_CHALLENGE_VERIFY,
|
||||
};
|
||||
use super::server_bridge::ServerBridge;
|
||||
@@ -94,8 +94,8 @@ struct HandlerContext {
|
||||
|
||||
/// Handler auth state.
|
||||
enum AuthState {
|
||||
/// Auth is required, we're waiting for the client to send its challenge.
|
||||
WaitingForChallenge,
|
||||
/// Auth is required, we're waiting for the client to send its challenge optionally bearing a token.
|
||||
WaitingForChallenge(Option<String>),
|
||||
/// A challenge has been issued. Waiting for a verification.
|
||||
ChallengeIssued(String),
|
||||
/// Auth is no longer required.
|
||||
@@ -215,7 +215,7 @@ pub async fn serve(
|
||||
code_server_args: own_code_server_args,
|
||||
platform,
|
||||
exit_barrier: own_exit,
|
||||
requires_auth: false,
|
||||
requires_auth: AuthRequired::None,
|
||||
}).with_context(cx.clone()).await;
|
||||
|
||||
cx.span().add_event(
|
||||
@@ -233,12 +233,20 @@ pub async fn serve(
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum AuthRequired {
|
||||
None,
|
||||
VSDA,
|
||||
VSDAWithToken(String),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ServeStreamParams {
|
||||
pub log: log::Logger,
|
||||
pub launcher_paths: LauncherPaths,
|
||||
pub code_server_args: CodeServerArgs,
|
||||
pub platform: Platform,
|
||||
pub requires_auth: bool,
|
||||
pub requires_auth: AuthRequired,
|
||||
pub exit_barrier: Barrier<ShutdownSignal>,
|
||||
}
|
||||
|
||||
@@ -268,7 +276,7 @@ fn make_socket_rpc(
|
||||
launcher_paths: LauncherPaths,
|
||||
code_server_args: CodeServerArgs,
|
||||
port_forwarding: Option<PortForwarding>,
|
||||
requires_auth: bool,
|
||||
requires_auth: AuthRequired,
|
||||
platform: Platform,
|
||||
) -> RpcDispatcher<MsgPackSerializer, HandlerContext> {
|
||||
let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new()));
|
||||
@@ -276,8 +284,9 @@ fn make_socket_rpc(
|
||||
let mut rpc = RpcBuilder::new(MsgPackSerializer {}).methods(HandlerContext {
|
||||
did_update: Arc::new(AtomicBool::new(false)),
|
||||
auth_state: Arc::new(std::sync::Mutex::new(match requires_auth {
|
||||
true => AuthState::WaitingForChallenge,
|
||||
false => AuthState::Authenticated,
|
||||
AuthRequired::VSDAWithToken(t) => AuthState::WaitingForChallenge(Some(t)),
|
||||
AuthRequired::VSDA => AuthState::WaitingForChallenge(None),
|
||||
AuthRequired::None => AuthState::Authenticated,
|
||||
})),
|
||||
socket_tx,
|
||||
log: log.clone(),
|
||||
@@ -304,8 +313,8 @@ fn make_socket_rpc(
|
||||
ensure_auth(&c.auth_state)?;
|
||||
handle_get_env()
|
||||
});
|
||||
rpc.register_sync(METHOD_CHALLENGE_ISSUE, |_: EmptyObject, c| {
|
||||
handle_challenge_issue(&c.auth_state)
|
||||
rpc.register_sync(METHOD_CHALLENGE_ISSUE, |p: ChallengeIssueParams, c| {
|
||||
handle_challenge_issue(p, &c.auth_state)
|
||||
});
|
||||
rpc.register_sync(METHOD_CHALLENGE_VERIFY, |p: ChallengeVerifyParams, c| {
|
||||
handle_challenge_verify(p.response, &c.auth_state)
|
||||
@@ -422,6 +431,7 @@ async fn process_socket(
|
||||
let rx_counter = Arc::new(AtomicUsize::new(0));
|
||||
let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new()));
|
||||
|
||||
let already_authed = matches!(requires_auth, AuthRequired::None);
|
||||
let rpc = make_socket_rpc(
|
||||
log.clone(),
|
||||
socket_tx.clone(),
|
||||
@@ -439,7 +449,7 @@ async fn process_socket(
|
||||
let socket_tx = socket_tx.clone();
|
||||
let exit_barrier = exit_barrier.clone();
|
||||
tokio::spawn(async move {
|
||||
if !requires_auth {
|
||||
if already_authed {
|
||||
send_version(&socket_tx).await;
|
||||
}
|
||||
|
||||
@@ -825,13 +835,22 @@ fn handle_get_env() -> Result<GetEnvResponse, AnyError> {
|
||||
}
|
||||
|
||||
fn handle_challenge_issue(
|
||||
params: ChallengeIssueParams,
|
||||
auth_state: &Arc<std::sync::Mutex<AuthState>>,
|
||||
) -> Result<ChallengeIssueResponse, AnyError> {
|
||||
let challenge = create_challenge();
|
||||
|
||||
let mut auth_state = auth_state.lock().unwrap();
|
||||
*auth_state = AuthState::ChallengeIssued(challenge.clone());
|
||||
if let AuthState::WaitingForChallenge(Some(s)) = &*auth_state {
|
||||
println!("looking for token {}, got {:?}", s, params.token);
|
||||
match ¶ms.token {
|
||||
Some(t) if s != t => return Err(CodeError::AuthChallengeBadToken.into()),
|
||||
None => return Err(CodeError::AuthChallengeBadToken.into()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
*auth_state = AuthState::ChallengeIssued(challenge.clone());
|
||||
Ok(ChallengeIssueResponse { challenge })
|
||||
}
|
||||
|
||||
@@ -843,7 +862,7 @@ fn handle_challenge_verify(
|
||||
|
||||
match &*auth_state {
|
||||
AuthState::Authenticated => Ok(EmptyObject {}),
|
||||
AuthState::WaitingForChallenge => Err(CodeError::AuthChallengeNotIssued.into()),
|
||||
AuthState::WaitingForChallenge(_) => Err(CodeError::AuthChallengeNotIssued.into()),
|
||||
AuthState::ChallengeIssued(c) => match verify_challenge(c, &response) {
|
||||
false => Err(CodeError::AuthChallengeNotIssued.into()),
|
||||
true => {
|
||||
|
||||
@@ -3,12 +3,11 @@
|
||||
* Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
use crate::auth;
|
||||
use crate::constants::{
|
||||
CONTROL_PORT, IS_INTERACTIVE_CLI, PROTOCOL_VERSION_TAG, TUNNEL_SERVICE_USER_AGENT,
|
||||
};
|
||||
use crate::constants::{IS_INTERACTIVE_CLI, PROTOCOL_VERSION_TAG, TUNNEL_SERVICE_USER_AGENT};
|
||||
use crate::state::{LauncherPaths, PersistedState};
|
||||
use crate::util::errors::{
|
||||
wrap, AnyError, DevTunnelError, InvalidTunnelName, TunnelCreationFailed, WrappedError,
|
||||
wrap, AnyError, CodeError, DevTunnelError, InvalidTunnelName, TunnelCreationFailed,
|
||||
WrappedError,
|
||||
};
|
||||
use crate::util::input::prompt_placeholder;
|
||||
use crate::{debug, info, log, spanf, trace, warning};
|
||||
@@ -33,6 +32,8 @@ use tunnels::management::{
|
||||
|
||||
use super::wsl_detect::is_wsl_installed;
|
||||
|
||||
static TUNNEL_COUNT_LIMIT_NAME: &str = "TunnelsPerUserPerLocation";
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct PersistedTunnel {
|
||||
pub name: String,
|
||||
@@ -134,6 +135,7 @@ pub struct DevTunnels {
|
||||
log: log::Logger,
|
||||
launcher_tunnel: PersistedState<Option<PersistedTunnel>>,
|
||||
client: TunnelManagementClient,
|
||||
tag: &'static str,
|
||||
}
|
||||
|
||||
/// Representation of a tunnel returned from the `start` methods.
|
||||
@@ -162,30 +164,43 @@ impl ActiveTunnel {
|
||||
}
|
||||
|
||||
/// Forwards a port over TCP.
|
||||
pub async fn add_port_tcp(&mut self, port_number: u16) -> Result<(), AnyError> {
|
||||
pub async fn add_port_tcp(&self, port_number: u16) -> Result<(), AnyError> {
|
||||
self.manager.add_port_tcp(port_number).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Removes a forwarded port TCP.
|
||||
pub async fn remove_port(&mut self, port_number: u16) -> Result<(), AnyError> {
|
||||
pub async fn remove_port(&self, port_number: u16) -> Result<(), AnyError> {
|
||||
self.manager.remove_port(port_number).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Gets the public URI on which a forwarded port can be access in browser.
|
||||
pub async fn get_port_uri(&mut self, port: u16) -> Result<String, AnyError> {
|
||||
let endpoint = self.manager.get_endpoint().await?;
|
||||
let format = endpoint
|
||||
.base
|
||||
.port_uri_format
|
||||
.expect("expected to have port format");
|
||||
/// Gets the template string for forming forwarded port web URIs..
|
||||
pub fn get_port_format(&self) -> Result<String, AnyError> {
|
||||
if let Some(details) = &*self.manager.endpoint_rx.borrow() {
|
||||
return details
|
||||
.as_ref()
|
||||
.map(|r| {
|
||||
r.base
|
||||
.port_uri_format
|
||||
.clone()
|
||||
.expect("expected to have port format")
|
||||
})
|
||||
.map_err(|e| e.clone().into());
|
||||
}
|
||||
|
||||
Ok(format.replace(PORT_TOKEN, &port.to_string()))
|
||||
Err(CodeError::NoTunnelEndpoint.into())
|
||||
}
|
||||
|
||||
/// Gets the public URI on which a forwarded port can be access in browser.
|
||||
pub fn get_port_uri(&self, port: u16) -> Result<String, AnyError> {
|
||||
self.get_port_format()
|
||||
.map(|f| f.replace(PORT_TOKEN, &port.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
const VSCODE_CLI_TUNNEL_TAG: &str = "vscode-server-launcher";
|
||||
const VSCODE_CLI_FORWARDING_TAG: &str = "vscode-port-forward";
|
||||
const MAX_TUNNEL_NAME_LENGTH: usize = 20;
|
||||
|
||||
fn get_host_token_from_tunnel(tunnel: &Tunnel) -> String {
|
||||
@@ -229,7 +244,7 @@ lazy_static! {
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ExistingTunnel {
|
||||
/// Name you'd like to assign preexisting tunnel to use to connect to the VS Code Server
|
||||
pub tunnel_name: String,
|
||||
pub tunnel_name: Option<String>,
|
||||
|
||||
/// Token to authenticate and use preexisting tunnel
|
||||
pub host_token: String,
|
||||
@@ -242,7 +257,29 @@ pub struct ExistingTunnel {
|
||||
}
|
||||
|
||||
impl DevTunnels {
|
||||
pub fn new(log: &log::Logger, auth: auth::Auth, paths: &LauncherPaths) -> DevTunnels {
|
||||
/// Creates a new DevTunnels client used for port forwarding.
|
||||
pub fn new_port_forwarding(
|
||||
log: &log::Logger,
|
||||
auth: auth::Auth,
|
||||
paths: &LauncherPaths,
|
||||
) -> DevTunnels {
|
||||
let mut client = new_tunnel_management(&TUNNEL_SERVICE_USER_AGENT);
|
||||
client.authorization_provider(auth);
|
||||
|
||||
DevTunnels {
|
||||
log: log.clone(),
|
||||
client: client.into(),
|
||||
launcher_tunnel: PersistedState::new(paths.root().join("port_forwarding_tunnel.json")),
|
||||
tag: VSCODE_CLI_FORWARDING_TAG,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new DevTunnels client used for the Remote Tunnels extension to access the VS Code Server.
|
||||
pub fn new_remote_tunnel(
|
||||
log: &log::Logger,
|
||||
auth: auth::Auth,
|
||||
paths: &LauncherPaths,
|
||||
) -> DevTunnels {
|
||||
let mut client = new_tunnel_management(&TUNNEL_SERVICE_USER_AGENT);
|
||||
client.authorization_provider(auth);
|
||||
|
||||
@@ -250,6 +287,7 @@ impl DevTunnels {
|
||||
log: log.clone(),
|
||||
client: client.into(),
|
||||
launcher_tunnel: PersistedState::new(paths.root().join("code_tunnel.json")),
|
||||
tag: VSCODE_CLI_TUNNEL_TAG,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -275,7 +313,9 @@ impl DevTunnels {
|
||||
|
||||
/// Renames the current tunnel to the new name.
|
||||
pub async fn rename_tunnel(&mut self, name: &str) -> Result<(), AnyError> {
|
||||
self.update_tunnel_name(None, name).await.map(|_| ())
|
||||
self.update_tunnel_name(self.launcher_tunnel.load(), name)
|
||||
.await
|
||||
.map(|_| ())
|
||||
}
|
||||
|
||||
/// Updates the name of the existing persisted tunnel to the new name.
|
||||
@@ -286,28 +326,34 @@ impl DevTunnels {
|
||||
name: &str,
|
||||
) -> Result<(Tunnel, PersistedTunnel), AnyError> {
|
||||
let name = name.to_ascii_lowercase();
|
||||
self.check_is_name_free(&name).await?;
|
||||
|
||||
debug!(self.log, "Tunnel name changed, applying updates...");
|
||||
|
||||
let (mut full_tunnel, mut persisted, is_new) = match persisted {
|
||||
Some(persisted) => {
|
||||
debug!(
|
||||
self.log,
|
||||
"Found a persisted tunnel, seeing if the name matches..."
|
||||
);
|
||||
self.get_or_create_tunnel(persisted, Some(&name), NO_REQUEST_OPTIONS)
|
||||
.await
|
||||
}
|
||||
None => self
|
||||
.create_tunnel(&name, NO_REQUEST_OPTIONS)
|
||||
.await
|
||||
.map(|(pt, t)| (t, pt, true)),
|
||||
None => {
|
||||
debug!(self.log, "Creating a new tunnel with the requested name");
|
||||
self.create_tunnel(&name, NO_REQUEST_OPTIONS)
|
||||
.await
|
||||
.map(|(pt, t)| (t, pt, true))
|
||||
}
|
||||
}?;
|
||||
|
||||
if is_new {
|
||||
let desired_tags = self.get_tags(&name);
|
||||
if is_new || vec_eq_as_set(&full_tunnel.tags, &desired_tags) {
|
||||
return Ok((full_tunnel, persisted));
|
||||
}
|
||||
|
||||
full_tunnel.tags = self.get_tags(&name);
|
||||
debug!(self.log, "Tunnel name changed, applying updates...");
|
||||
|
||||
let new_tunnel = spanf!(
|
||||
full_tunnel.tags = desired_tags;
|
||||
|
||||
let updated_tunnel = spanf!(
|
||||
self.log,
|
||||
self.log.span("dev-tunnel.tag.update"),
|
||||
self.client.update_tunnel(&full_tunnel, NO_REQUEST_OPTIONS)
|
||||
@@ -317,7 +363,7 @@ impl DevTunnels {
|
||||
persisted.name = name;
|
||||
self.launcher_tunnel.save(Some(persisted.clone()))?;
|
||||
|
||||
Ok((new_tunnel, persisted))
|
||||
Ok((updated_tunnel, persisted))
|
||||
}
|
||||
|
||||
/// Gets the persisted tunnel from the service, or creates a new one.
|
||||
@@ -356,6 +402,7 @@ impl DevTunnels {
|
||||
&mut self,
|
||||
preferred_name: Option<&str>,
|
||||
use_random_name: bool,
|
||||
preserve_ports: &[u16],
|
||||
) -> Result<ActiveTunnel, AnyError> {
|
||||
let (mut tunnel, persisted) = match self.launcher_tunnel.load() {
|
||||
Some(mut persisted) => {
|
||||
@@ -385,7 +432,12 @@ impl DevTunnels {
|
||||
};
|
||||
|
||||
tunnel = self
|
||||
.sync_tunnel_tags(&persisted.name, tunnel, &HOST_TUNNEL_REQUEST_OPTIONS)
|
||||
.sync_tunnel_tags(
|
||||
&self.client,
|
||||
&persisted.name,
|
||||
tunnel,
|
||||
&HOST_TUNNEL_REQUEST_OPTIONS,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let locator = TunnelLocator::try_from(&tunnel).unwrap();
|
||||
@@ -394,7 +446,7 @@ impl DevTunnels {
|
||||
for port_to_delete in tunnel
|
||||
.ports
|
||||
.iter()
|
||||
.filter(|p| p.port_number != CONTROL_PORT)
|
||||
.filter(|p: &&TunnelPort| !preserve_ports.contains(&p.port_number))
|
||||
{
|
||||
let output_fut = self.client.delete_tunnel_port(
|
||||
&locator,
|
||||
@@ -443,14 +495,10 @@ impl DevTunnels {
|
||||
) -> Result<(PersistedTunnel, Tunnel), AnyError> {
|
||||
info!(self.log, "Creating tunnel with the name: {}", name);
|
||||
|
||||
let mut tried_recycle = false;
|
||||
self.check_is_name_free(name).await?;
|
||||
|
||||
let new_tunnel = Tunnel {
|
||||
tags: vec![
|
||||
name.to_string(),
|
||||
PROTOCOL_VERSION_TAG.to_string(),
|
||||
VSCODE_CLI_TUNNEL_TAG.to_string(),
|
||||
],
|
||||
tags: self.get_tags(name),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
@@ -465,13 +513,14 @@ impl DevTunnels {
|
||||
Err(HttpError::ResponseError(e))
|
||||
if e.status_code == StatusCode::TOO_MANY_REQUESTS =>
|
||||
{
|
||||
if !tried_recycle && self.try_recycle_tunnel().await? {
|
||||
tried_recycle = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(d) = e.get_details() {
|
||||
let detail = d.detail.unwrap_or_else(|| "unknown".to_string());
|
||||
if detail.contains(TUNNEL_COUNT_LIMIT_NAME)
|
||||
&& self.try_recycle_tunnel().await?
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
return Err(AnyError::from(TunnelCreationFailed(
|
||||
name.to_string(),
|
||||
detail,
|
||||
@@ -508,7 +557,7 @@ impl DevTunnels {
|
||||
let mut tags = vec![
|
||||
name.to_string(),
|
||||
PROTOCOL_VERSION_TAG.to_string(),
|
||||
VSCODE_CLI_TUNNEL_TAG.to_string(),
|
||||
self.tag.to_string(),
|
||||
];
|
||||
|
||||
if is_wsl_installed(&self.log) {
|
||||
@@ -522,12 +571,13 @@ impl DevTunnels {
|
||||
/// other version tags.
|
||||
async fn sync_tunnel_tags(
|
||||
&self,
|
||||
client: &TunnelManagementClient,
|
||||
name: &str,
|
||||
tunnel: Tunnel,
|
||||
options: &TunnelRequestOptions,
|
||||
) -> Result<Tunnel, AnyError> {
|
||||
let new_tags = self.get_tags(name);
|
||||
if vec_eq_unsorted(&tunnel.tags, &new_tags) {
|
||||
if vec_eq_as_set(&tunnel.tags, &new_tags) {
|
||||
return Ok(tunnel);
|
||||
}
|
||||
|
||||
@@ -548,7 +598,7 @@ impl DevTunnels {
|
||||
let result = spanf!(
|
||||
self.log,
|
||||
self.log.span("dev-tunnel.protocol-tag-update"),
|
||||
self.client.update_tunnel(&tunnel_update, options)
|
||||
client.update_tunnel(&tunnel_update, options)
|
||||
);
|
||||
|
||||
result.map_err(|e| wrap(e, "tunnel tag update failed").into())
|
||||
@@ -599,7 +649,7 @@ impl DevTunnels {
|
||||
self.log,
|
||||
self.log.span("dev-tunnel.listall"),
|
||||
self.client.list_all_tunnels(&TunnelRequestOptions {
|
||||
tags: vec![VSCODE_CLI_TUNNEL_TAG.to_string()],
|
||||
tags: vec![self.tag.to_string()],
|
||||
require_all_tags: true,
|
||||
..Default::default()
|
||||
})
|
||||
@@ -610,11 +660,11 @@ impl DevTunnels {
|
||||
}
|
||||
|
||||
async fn check_is_name_free(&mut self, name: &str) -> Result<(), AnyError> {
|
||||
let existing = spanf!(
|
||||
let existing: Vec<Tunnel> = spanf!(
|
||||
self.log,
|
||||
self.log.span("dev-tunnel.rename.search"),
|
||||
self.client.list_all_tunnels(&TunnelRequestOptions {
|
||||
tags: vec![VSCODE_CLI_TUNNEL_TAG.to_string(), name.to_string()],
|
||||
tags: vec![self.tag.to_string(), name.to_string()],
|
||||
require_all_tags: true,
|
||||
..Default::default()
|
||||
})
|
||||
@@ -629,6 +679,12 @@ impl DevTunnels {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_placeholder_name() -> String {
|
||||
let mut n = clean_hostname_for_tunnel(&gethostname::gethostname().to_string_lossy());
|
||||
n.make_ascii_lowercase();
|
||||
n
|
||||
}
|
||||
|
||||
async fn get_name_for_tunnel(
|
||||
&mut self,
|
||||
preferred_name: Option<&str>,
|
||||
@@ -660,10 +716,7 @@ impl DevTunnels {
|
||||
use_random_name = true;
|
||||
}
|
||||
|
||||
let mut placeholder_name =
|
||||
clean_hostname_for_tunnel(&gethostname::gethostname().to_string_lossy());
|
||||
placeholder_name.make_ascii_lowercase();
|
||||
|
||||
let mut placeholder_name = Self::get_placeholder_name();
|
||||
if !is_name_free(&placeholder_name) {
|
||||
for i in 2.. {
|
||||
let fixed_name = format!("{}{}", placeholder_name, i);
|
||||
@@ -705,7 +758,10 @@ impl DevTunnels {
|
||||
tunnel: ExistingTunnel,
|
||||
) -> Result<ActiveTunnel, AnyError> {
|
||||
let tunnel_details = PersistedTunnel {
|
||||
name: tunnel.tunnel_name,
|
||||
name: match tunnel.tunnel_name {
|
||||
Some(n) => n,
|
||||
None => Self::get_placeholder_name(),
|
||||
},
|
||||
id: tunnel.tunnel_id,
|
||||
cluster: tunnel.cluster,
|
||||
};
|
||||
@@ -715,10 +771,23 @@ impl DevTunnels {
|
||||
tunnel.host_token.clone(),
|
||||
));
|
||||
|
||||
let client = mgmt.into();
|
||||
self.sync_tunnel_tags(
|
||||
&client,
|
||||
&tunnel_details.name,
|
||||
Tunnel {
|
||||
cluster_id: Some(tunnel_details.cluster.clone()),
|
||||
tunnel_id: Some(tunnel_details.id.clone()),
|
||||
..Default::default()
|
||||
},
|
||||
&HOST_TUNNEL_REQUEST_OPTIONS,
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.start_tunnel(
|
||||
tunnel_details.locator(),
|
||||
&tunnel_details,
|
||||
mgmt.into(),
|
||||
client,
|
||||
StaticAccessTokenProvider::new(tunnel.host_token),
|
||||
)
|
||||
.await
|
||||
@@ -998,7 +1067,7 @@ fn clean_hostname_for_tunnel(hostname: &str) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_eq_unsorted(a: &[String], b: &[String]) -> bool {
|
||||
fn vec_eq_as_set(a: &[String], b: &[String]) -> bool {
|
||||
if a.len() != b.len() {
|
||||
return false;
|
||||
}
|
||||
|
||||
284
cli/src/tunnels/forwarding.rs
Normal file
284
cli/src/tunnels/forwarding.rs
Normal file
@@ -0,0 +1,284 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use tokio::{
|
||||
pin,
|
||||
sync::{mpsc, watch},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
async_pipe::{socket_stream_split, AsyncPipe},
|
||||
json_rpc::{new_json_rpc, start_json_rpc},
|
||||
log,
|
||||
singleton::SingletonServer,
|
||||
util::{errors::CodeError, sync::Barrier},
|
||||
};
|
||||
|
||||
use super::{
|
||||
dev_tunnels::ActiveTunnel,
|
||||
protocol::{
|
||||
self,
|
||||
forward_singleton::{PortList, SetPortsResponse},
|
||||
},
|
||||
shutdown_signal::ShutdownSignal,
|
||||
};
|
||||
|
||||
type PortMap = HashMap<u16, u32>;
|
||||
|
||||
/// The PortForwardingHandle is given out to multiple consumers to allow
|
||||
/// them to set_ports that they want to be forwarded.
|
||||
struct PortForwardingSender {
|
||||
/// Todo: when `SyncUnsafeCell` is no longer nightly, we can use it here with
|
||||
/// the following comment:
|
||||
///
|
||||
/// SyncUnsafeCell is used and safe here because PortForwardingSender is used
|
||||
/// exclusively in synchronous dispatch *and* we create a new sender in the
|
||||
/// context for each connection, in `serve_singleton_rpc`.
|
||||
///
|
||||
/// If PortForwardingSender is ever used in a different context, this should
|
||||
/// be refactored, e.g. to use locks or `&mut self` in set_ports`
|
||||
///
|
||||
/// see https://doc.rust-lang.org/stable/std/cell/struct.SyncUnsafeCell.html
|
||||
current: Mutex<PortList>,
|
||||
sender: Arc<Mutex<watch::Sender<PortMap>>>,
|
||||
}
|
||||
|
||||
impl PortForwardingSender {
|
||||
pub fn set_ports(&self, ports: PortList) {
|
||||
let mut current = self.current.lock().unwrap();
|
||||
self.sender.lock().unwrap().send_modify(|v| {
|
||||
for p in current.iter() {
|
||||
if !ports.contains(p) {
|
||||
match v.get(p) {
|
||||
Some(1) => {
|
||||
v.remove(p);
|
||||
}
|
||||
Some(n) => {
|
||||
v.insert(*p, n - 1);
|
||||
}
|
||||
None => unreachable!("removed port not in map"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for p in ports.iter() {
|
||||
if !current.contains(p) {
|
||||
match v.get(p) {
|
||||
Some(n) => v.insert(*p, n + 1),
|
||||
None => v.insert(*p, 1),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
current.splice(.., ports);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for PortForwardingSender {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
current: Mutex::new(vec![]),
|
||||
sender: self.sender.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PortForwardingSender {
|
||||
fn drop(&mut self) {
|
||||
self.set_ports(vec![]);
|
||||
}
|
||||
}
|
||||
|
||||
struct PortForwardingReceiver {
|
||||
receiver: watch::Receiver<PortMap>,
|
||||
}
|
||||
|
||||
impl PortForwardingReceiver {
|
||||
pub fn new() -> (PortForwardingSender, Self) {
|
||||
let (sender, receiver) = watch::channel(HashMap::new());
|
||||
let handle = PortForwardingSender {
|
||||
current: Mutex::new(vec![]),
|
||||
sender: Arc::new(Mutex::new(sender)),
|
||||
};
|
||||
|
||||
let tracker = Self { receiver };
|
||||
|
||||
(handle, tracker)
|
||||
}
|
||||
|
||||
/// Applies all changes from PortForwardingHandles to the tunnel.
|
||||
pub async fn apply_to(&mut self, log: log::Logger, tunnel: Arc<ActiveTunnel>) {
|
||||
let mut current = vec![];
|
||||
while self.receiver.changed().await.is_ok() {
|
||||
let next = self.receiver.borrow().keys().copied().collect::<Vec<_>>();
|
||||
|
||||
for p in current.iter() {
|
||||
if !next.contains(p) {
|
||||
match tunnel.remove_port(*p).await {
|
||||
Ok(_) => info!(log, "stopped forwarding port {}", p),
|
||||
Err(e) => error!(log, "failed to stop forwarding port {}: {}", p, e),
|
||||
}
|
||||
}
|
||||
}
|
||||
for p in next.iter() {
|
||||
if !current.contains(p) {
|
||||
match tunnel.add_port_tcp(*p).await {
|
||||
Ok(_) => info!(log, "forwarding port {}", p),
|
||||
Err(e) => error!(log, "failed to forward port {}: {}", p, e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
current = next;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SingletonClientArgs {
|
||||
pub log: log::Logger,
|
||||
pub stream: AsyncPipe,
|
||||
pub shutdown: Barrier<ShutdownSignal>,
|
||||
pub port_requests: watch::Receiver<PortList>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct SingletonServerContext {
|
||||
log: log::Logger,
|
||||
handle: PortForwardingSender,
|
||||
tunnel: Arc<ActiveTunnel>,
|
||||
}
|
||||
|
||||
/// Serves a client singleton for port forwarding.
|
||||
pub async fn client(args: SingletonClientArgs) -> Result<(), std::io::Error> {
|
||||
let mut rpc = new_json_rpc();
|
||||
let (msg_tx, msg_rx) = mpsc::unbounded_channel();
|
||||
let SingletonClientArgs {
|
||||
log,
|
||||
shutdown,
|
||||
stream,
|
||||
mut port_requests,
|
||||
} = args;
|
||||
|
||||
debug!(
|
||||
log,
|
||||
"An existing port forwarding process is running on this machine, connecting to it..."
|
||||
);
|
||||
|
||||
let caller = rpc.get_caller(msg_tx);
|
||||
let rpc = rpc.methods(()).build(log.clone());
|
||||
let (read, write) = socket_stream_split(stream);
|
||||
|
||||
let serve = start_json_rpc(rpc, read, write, msg_rx, shutdown);
|
||||
let forward = async move {
|
||||
while port_requests.changed().await.is_ok() {
|
||||
let ports = port_requests.borrow().clone();
|
||||
let r = caller
|
||||
.call::<_, _, protocol::forward_singleton::SetPortsResponse>(
|
||||
protocol::forward_singleton::METHOD_SET_PORTS,
|
||||
protocol::forward_singleton::SetPortsParams { ports },
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match r {
|
||||
Err(e) => error!(log, "failed to set ports: {:?}", e),
|
||||
Ok(r) => print_forwarding_addr(&r),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
tokio::select! {
|
||||
r = serve => r.map(|_| ()),
|
||||
_ = forward => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Serves a port-forwarding singleton.
|
||||
pub async fn server(
|
||||
log: log::Logger,
|
||||
tunnel: ActiveTunnel,
|
||||
server: SingletonServer,
|
||||
mut port_requests: watch::Receiver<PortList>,
|
||||
shutdown_rx: Barrier<ShutdownSignal>,
|
||||
) -> Result<(), CodeError> {
|
||||
let tunnel = Arc::new(tunnel);
|
||||
let (forward_tx, mut forward_rx) = PortForwardingReceiver::new();
|
||||
|
||||
let forward_own_tunnel = tunnel.clone();
|
||||
let forward_own_tx = forward_tx.clone();
|
||||
let forward_own = async move {
|
||||
while port_requests.changed().await.is_ok() {
|
||||
forward_own_tx.set_ports(port_requests.borrow().clone());
|
||||
print_forwarding_addr(&SetPortsResponse {
|
||||
port_format: forward_own_tunnel.get_port_format().ok(),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
tokio::select! {
|
||||
_ = forward_own => Ok(()),
|
||||
_ = forward_rx.apply_to(log.clone(), tunnel.clone()) => Ok(()),
|
||||
r = serve_singleton_rpc(server, log, tunnel, forward_tx, shutdown_rx) => r,
|
||||
}
|
||||
}
|
||||
|
||||
async fn serve_singleton_rpc(
|
||||
mut server: SingletonServer,
|
||||
log: log::Logger,
|
||||
tunnel: Arc<ActiveTunnel>,
|
||||
forward_tx: PortForwardingSender,
|
||||
shutdown_rx: Barrier<ShutdownSignal>,
|
||||
) -> Result<(), CodeError> {
|
||||
let mut own_shutdown = shutdown_rx.clone();
|
||||
let shutdown_fut = own_shutdown.wait();
|
||||
pin!(shutdown_fut);
|
||||
|
||||
loop {
|
||||
let cnx = tokio::select! {
|
||||
c = server.accept() => c?,
|
||||
_ = &mut shutdown_fut => return Ok(()),
|
||||
};
|
||||
|
||||
let (read, write) = socket_stream_split(cnx);
|
||||
let shutdown_rx = shutdown_rx.clone();
|
||||
|
||||
let handle = forward_tx.clone();
|
||||
let log = log.clone();
|
||||
let tunnel = tunnel.clone();
|
||||
tokio::spawn(async move {
|
||||
// we make an rpc for the connection instead of re-using a dispatcher
|
||||
// so that we can have the "handle" drop when the connection drops.
|
||||
let rpc = new_json_rpc();
|
||||
let mut rpc = rpc.methods(SingletonServerContext {
|
||||
log: log.clone(),
|
||||
handle,
|
||||
tunnel,
|
||||
});
|
||||
|
||||
rpc.register_sync(
|
||||
protocol::forward_singleton::METHOD_SET_PORTS,
|
||||
|p: protocol::forward_singleton::SetPortsParams, ctx| {
|
||||
info!(ctx.log, "client setting ports to {:?}", p.ports);
|
||||
ctx.handle.set_ports(p.ports);
|
||||
Ok(SetPortsResponse {
|
||||
port_format: ctx.tunnel.get_port_format().ok(),
|
||||
})
|
||||
},
|
||||
);
|
||||
|
||||
let _ = start_json_rpc(rpc.build(log), read, write, (), shutdown_rx).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn print_forwarding_addr(r: &SetPortsResponse) {
|
||||
eprintln!("{}\n", serde_json::to_string(r).unwrap());
|
||||
}
|
||||
@@ -91,10 +91,15 @@ impl InstalledServer {
|
||||
pub fn server_paths(&self, p: &LauncherPaths) -> ServerPaths {
|
||||
let server_dir = self.get_install_folder(p);
|
||||
ServerPaths {
|
||||
executable: server_dir
|
||||
.join(SERVER_FOLDER_NAME)
|
||||
.join("bin")
|
||||
.join(self.quality.server_entrypoint()),
|
||||
// allow using the OSS server in development via an override
|
||||
executable: if let Some(p) = option_env!("VSCODE_CLI_OVERRIDE_SERVER_PATH") {
|
||||
PathBuf::from(p)
|
||||
} else {
|
||||
server_dir
|
||||
.join(SERVER_FOLDER_NAME)
|
||||
.join("bin")
|
||||
.join(self.quality.server_entrypoint())
|
||||
},
|
||||
logfile: server_dir.join("log.txt"),
|
||||
pidfile: server_dir.join("pid.txt"),
|
||||
server_dir,
|
||||
|
||||
@@ -91,7 +91,7 @@ impl PortForwardingProcessor {
|
||||
self.forwarded.insert(port);
|
||||
}
|
||||
|
||||
tunnel.get_port_uri(port).await
|
||||
tunnel.get_port_uri(port)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -199,6 +199,11 @@ pub struct SpawnResult {
|
||||
pub const METHOD_CHALLENGE_ISSUE: &str = "challenge_issue";
|
||||
pub const METHOD_CHALLENGE_VERIFY: &str = "challenge_verify";
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ChallengeIssueParams {
|
||||
pub token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ChallengeIssueResponse {
|
||||
pub challenge: String,
|
||||
@@ -209,6 +214,24 @@ pub struct ChallengeVerifyParams {
|
||||
pub response: String,
|
||||
}
|
||||
|
||||
pub mod forward_singleton {
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub const METHOD_SET_PORTS: &str = "set_ports";
|
||||
|
||||
pub type PortList = Vec<u16>;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct SetPortsParams {
|
||||
pub ports: PortList,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct SetPortsResponse {
|
||||
pub port_format: Option<String>,
|
||||
}
|
||||
}
|
||||
|
||||
pub mod singleton {
|
||||
use crate::log;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -78,6 +78,7 @@ impl CliServiceManager for WindowsService {
|
||||
cmd.stderr(Stdio::null());
|
||||
cmd.stdout(Stdio::null());
|
||||
cmd.stdin(Stdio::null());
|
||||
cmd.creation_flags(CREATE_NEW_PROCESS_GROUP | DETACHED_PROCESS);
|
||||
cmd.spawn()
|
||||
.map_err(|e| wrapdbg(e, "error starting service"))?;
|
||||
|
||||
@@ -121,8 +122,12 @@ impl CliServiceManager for WindowsService {
|
||||
|
||||
async fn unregister(&self) -> Result<(), AnyError> {
|
||||
let key = WindowsService::open_key()?;
|
||||
key.delete_value(TUNNEL_ACTIVITY_NAME)
|
||||
.map_err(|e| AnyError::from(wrap(e, "error deleting registry key")))?;
|
||||
match key.delete_value(TUNNEL_ACTIVITY_NAME) {
|
||||
Ok(_) => {}
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
|
||||
Err(e) => return Err(wrap(e, "error deleting registry key").into()),
|
||||
}
|
||||
|
||||
info!(self.log, "Tunnel service uninstalled");
|
||||
|
||||
let r = do_single_rpc_call::<_, ()>(
|
||||
|
||||
@@ -217,7 +217,7 @@ impl BroadcastLogSink {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_brocaster(&self) -> broadcast::Sender<Vec<u8>> {
|
||||
pub fn get_brocaster(&self) -> broadcast::Sender<Vec<u8>> {
|
||||
self.tx.clone()
|
||||
}
|
||||
|
||||
|
||||
@@ -509,8 +509,14 @@ pub enum CodeError {
|
||||
ServerAuthRequired,
|
||||
#[error("challenge not yet issued")]
|
||||
AuthChallengeNotIssued,
|
||||
#[error("challenge token is invalid")]
|
||||
AuthChallengeBadToken,
|
||||
#[error("unauthorized client refused")]
|
||||
AuthMismatch,
|
||||
#[error("keyring communication timed out after 5s")]
|
||||
KeyringTimeout,
|
||||
#[error("no host is connected to the tunnel relay")]
|
||||
NoTunnelEndpoint,
|
||||
}
|
||||
|
||||
makeAnyError!(
|
||||
|
||||
Reference in New Issue
Block a user