diff --git a/Cargo.lock b/Cargo.lock index 8b7c7c0..c14bb23 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3074,6 +3074,7 @@ dependencies = [ "anyhow", "arboard", "ashpd", + "chrono", "clap", "dialoguer", "directories", @@ -3086,8 +3087,10 @@ dependencies = [ "thiserror 2.0.18", "tokio", "tokio-util", + "toml", "tracing", "tracing-subscriber", + "ureq", "uuid", "x11rb", ] @@ -4448,6 +4451,34 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "ureq" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dea7109cdcd5864d4eeb1b58a1648dc9bf520360d7af16ec26d0a9354bafcfc0" +dependencies = [ + "base64", + "log", + "percent-encoding", + "rustls", + "rustls-pki-types", + "ureq-proto", + "utf8-zero", + "webpki-roots", +] + +[[package]] +name = "ureq-proto" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e994ba84b0bd1b1b0cf92878b7ef898a5c1760108fe7b6010327e274917a808c" +dependencies = [ + "base64", + "http", + "httparse", + "log", +] + [[package]] name = "url" version = "2.5.8" @@ -4461,6 +4492,12 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "utf8-zero" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8c0a043c9540bae7c578c88f91dda8bd82e59ae27c21baca69c8b191aaf5a6e" + [[package]] name = "utf8_iter" version = "1.0.4" diff --git a/Cargo.toml b/Cargo.toml index b3a3e4e..d6a1f26 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,9 @@ uuid = { version = "1", features = ["v4"] } iroh-tickets = "1.0.0-rc.0" dialoguer = { version = "0.12", default-features = false } arboard = { version = "3", default-features = false, features = ["wayland-data-control"] } +ureq = { version = "3", default-features = false, features = ["rustls"] } +toml = "1" +chrono = { version = "0.4", default-features = false, features = ["clock", "serde"] } [profile.release] lto = "thin" diff --git a/src/cli.rs b/src/cli.rs index b1d7188..467edea 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -46,9 +46,11 @@ pub struct Cli { pub low_latency: bool, /// Maximum number of concurrent viewers. Additional connections are - /// politely refused with a "host full" message. - #[arg(long, default_value_t = 2)] - pub max_viewers: u32, + /// politely refused with a "host full" message. Defaults to the + /// connection-aware recommendation from the bandwidth pre-flight if + /// available, otherwise 2. + #[arg(long)] + pub max_viewers: Option, // ── viewer options ──────────────────────────────────────────────── /// Local TCP port for the viewer to expose (default: random). @@ -63,6 +65,12 @@ pub struct Cli { /// Clean up orphaned PipeWire state from a crashed host run, then exit. #[arg(long)] pub repair: bool, + + /// Re-run the bandwidth pre-flight test, save the result, then exit. + /// Use this if your connection has changed (new ISP, moved house, etc.) + /// or if the previously saved test result is stale. + #[arg(long)] + pub reconfigure: bool, } #[derive(ValueEnum, Clone, Copy, Debug)] @@ -81,7 +89,7 @@ pub struct HostOpts { pub framerate: u32, pub no_hwencode: bool, pub low_latency: bool, - pub max_viewers: u32, + pub max_viewers: Option, pub interactive: bool, } diff --git a/src/common/bandwidth.rs b/src/common/bandwidth.rs new file mode 100644 index 0000000..5fcc080 --- /dev/null +++ b/src/common/bandwidth.rs @@ -0,0 +1,73 @@ +//! One-shot upstream bandwidth measurement against Cloudflare's open +//! speed-test endpoint. POST a fixed payload, time it, derive Mbps. +//! +//! Run via `tokio::task::spawn_blocking` from async contexts — ureq is a +//! blocking client and we don't want to wedge the tokio runtime during +//! the test. + +use anyhow::{Context, Result}; +use std::time::{Duration, Instant}; + +const ENDPOINT: &str = "https://speed.cloudflare.com/__up"; +const PAYLOAD_BYTES: usize = 5 * 1024 * 1024; // 5 MiB +const HTTP_TIMEOUT: Duration = Duration::from_secs(30); +/// Multiplier applied to the raw measurement. TCP slow-start, ramp-up, and +/// real-world contention all mean a one-shot upstream test slightly +/// overestimates sustainable throughput; clamp to 80% for headroom. +const SAFETY_FACTOR: f64 = 0.80; + +/// Result of a successful measurement. +#[derive(Debug, Clone)] +pub struct Measurement { + /// Raw measured throughput in megabits per second. + pub raw_mbps: f64, + /// `raw_mbps * SAFETY_FACTOR` — the value to use when sizing things. + pub safe_mbps: f64, + /// How long the upload took. + pub elapsed: Duration, +} + +/// Blocking upload-speed test. Call from a `spawn_blocking` task. +pub fn measure_upstream_blocking() -> Result { + let payload = vec![0u8; PAYLOAD_BYTES]; + let agent = ureq::Agent::config_builder() + .timeout_global(Some(HTTP_TIMEOUT)) + .build() + .new_agent(); + + let start = Instant::now(); + let response = agent + .post(ENDPOINT) + .content_type("application/octet-stream") + .send(&payload[..]) + .context("upload request to Cloudflare failed")?; + let elapsed = start.elapsed(); + + let status = response.status(); + if !status.is_success() { + anyhow::bail!("Cloudflare returned HTTP {status}"); + } + + let bits = (PAYLOAD_BYTES as f64) * 8.0; + let seconds = elapsed.as_secs_f64().max(0.001); + let raw_mbps = bits / seconds / 1_000_000.0; + let safe_mbps = raw_mbps * SAFETY_FACTOR; + + Ok(Measurement { + raw_mbps, + safe_mbps, + elapsed, + }) +} + +/// Convert a safe-upstream Mbps figure plus the host's per-viewer bitrate +/// (kbps for video, ignoring audio + protocol overhead which we account for +/// via SAFETY_FACTOR) into a recommended viewer count. Floors to at least 1. +pub fn recommended_max_viewers(safe_mbps: f64, bitrate_kbps: u32) -> u32 { + let per_viewer_mbps = (bitrate_kbps as f64) / 1000.0; + if per_viewer_mbps <= 0.0 { + return 1; + } + let n = (safe_mbps / per_viewer_mbps).floor(); + if n < 1.0 { 1 } else { n as u32 } +} diff --git a/src/common/config.rs b/src/common/config.rs new file mode 100644 index 0000000..6dfec06 --- /dev/null +++ b/src/common/config.rs @@ -0,0 +1,101 @@ +//! Persistent user-level config at `~/.config/pixelpass/config.toml`. +//! +//! Right now this only tracks the bandwidth pre-flight result. Future +//! preferences (default player, default bitrate, etc.) can hang off the +//! same file under their own `[section]`. + +use anyhow::{Context, Result}; +use chrono::{DateTime, Utc}; +use directories::ProjectDirs; +use serde::{Deserialize, Serialize}; +use std::fs; +use std::io::Write; +use std::path::PathBuf; + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct Config { + #[serde(default)] + pub bandwidth: BandwidthEntry, +} + +/// Result of the first-run upstream measurement. +/// +/// `status = "unmeasured"` means we've never asked the user — show the +/// first-run dialog. `"measured"` means we have a number. `"skipped"` +/// means the user opted out (sticky — don't ask again). `"failed"` +/// means the last attempt errored and we should ask the user on next +/// interactive launch whether to retry. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct BandwidthEntry { + #[serde(default = "default_status")] + pub status: BandwidthStatus, + #[serde(default)] + pub upstream_mbps: Option, + #[serde(default)] + pub measured_at: Option>, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum BandwidthStatus { + Unmeasured, + Measured, + Skipped, + Failed, +} + +impl Default for BandwidthStatus { + fn default() -> Self { + Self::Unmeasured + } +} + +fn default_status() -> BandwidthStatus { + BandwidthStatus::Unmeasured +} + +/// Returns `~/.config/pixelpass/config.toml` (or the XDG equivalent on other +/// platforms). The parent directory is created lazily by [`save`]. +pub fn config_path() -> Result { + let dirs = ProjectDirs::from("", "", "pixelpass") + .context("could not locate a config directory for pixelpass")?; + Ok(dirs.config_dir().join("config.toml")) +} + +/// Returns the loaded config, or a `Default` instance if the file doesn't +/// exist yet. Bubble up parse errors so we don't silently overwrite a +/// hand-edited config the user is debugging. +pub fn load() -> Result { + let path = config_path()?; + match fs::read_to_string(&path) { + Ok(s) => toml::from_str::(&s) + .with_context(|| format!("failed to parse {}", path.display())), + Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(Config::default()), + Err(e) => Err(e).with_context(|| format!("failed to read {}", path.display())), + } +} + +/// Atomic write via tempfile-in-same-dir + rename. +pub fn save(cfg: &Config) -> Result<()> { + let path = config_path()?; + let parent = path + .parent() + .context("config path has no parent directory")?; + fs::create_dir_all(parent) + .with_context(|| format!("failed to create {}", parent.display()))?; + + let serialized = + toml::to_string_pretty(cfg).context("failed to serialize config to TOML")?; + + let tmp = parent.join(format!(".config.toml.tmp.{}", std::process::id())); + { + let mut f = fs::File::create(&tmp) + .with_context(|| format!("failed to create {}", tmp.display()))?; + f.write_all(serialized.as_bytes()) + .with_context(|| format!("failed to write {}", tmp.display()))?; + f.sync_all().ok(); + } + fs::rename(&tmp, &path) + .with_context(|| format!("failed to rename {} -> {}", tmp.display(), path.display()))?; + Ok(()) +} diff --git a/src/common/mod.rs b/src/common/mod.rs index fe24f8e..99d45e7 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,4 +1,6 @@ pub mod alpn; +pub mod bandwidth; +pub mod config; pub mod deps; pub mod display; pub mod process; diff --git a/src/host/mod.rs b/src/host/mod.rs index e7adf01..fe29b2a 100644 --- a/src/host/mod.rs +++ b/src/host/mod.rs @@ -10,7 +10,10 @@ use tokio::sync::{mpsc, oneshot}; use tokio_util::sync::CancellationToken; use crate::cli::HostOpts; -use crate::common::{alpn::ALPN, deps, display::DisplayServer, signal, tunnel}; +use crate::common::{ + alpn::ALPN, bandwidth, config, config::BandwidthStatus, deps, display::DisplayServer, signal, + tunnel, +}; use self::capture::CaptureHandle; @@ -36,7 +39,8 @@ pub async fn run(opts: HostOpts) -> Result<()> { ); } - if opts.max_viewers == 0 { + let resolution = resolve_max_viewers(&opts); + if resolution.value == 0 { bail!("--max-viewers must be at least 1"); } @@ -50,10 +54,10 @@ pub async fn run(opts: HostOpts) -> Result<()> { let addr = endpoint.addr(); let ticket = EndpointTicket::new(addr); let clipboard_ok = opts.interactive && copy_to_clipboard(&ticket.to_string()); - print_host_banner(&ticket, display, &opts, clipboard_ok); + print_host_banner(&ticket, display, &opts, &resolution, clipboard_ok); let (sup_tx, sup_rx) = mpsc::channel::(16); - let supervisor = tokio::spawn(supervise(opts.clone(), display, sup_rx)); + let supervisor = tokio::spawn(supervise(opts.clone(), display, resolution.value, sup_rx)); accept_loop(&endpoint, sup_tx.clone(), cancel.clone()).await; @@ -160,11 +164,12 @@ async fn handle_peer( /// Owns the single shared CaptureHandle and the active viewer count. Spawns /// capture lazily on the first AddViewer; tears it down when the count drops -/// back to zero. Enforces the --max-viewers cap by refusing AddViewer when +/// back to zero. Enforces the max-viewers cap by refusing AddViewer when /// the count is already at the cap. async fn supervise( opts: HostOpts, display: DisplayServer, + max_viewers: u32, mut rx: mpsc::Receiver, ) { let mut handle: Option = None; @@ -173,10 +178,9 @@ async fn supervise( while let Some(msg) = rx.recv().await { match msg { SupervisorMsg::AddViewer(reply) => { - if count >= opts.max_viewers { + if count >= max_viewers { let _ = reply.send(Err(format!( - "host is full ({} of {} viewers connected)", - count, opts.max_viewers + "host is full ({count} of {max_viewers} viewers connected)" ))); continue; } @@ -195,11 +199,11 @@ async fn supervise( let port = handle.as_ref().expect("handle was just set").local_port(); count += 1; let _ = reply.send(Ok(port)); - tracing::info!(active = count, cap = opts.max_viewers, "viewer joined"); + tracing::info!(active = count, cap = max_viewers, "viewer joined"); } SupervisorMsg::RemoveViewer => { count = count.saturating_sub(1); - tracing::info!(active = count, cap = opts.max_viewers, "viewer left"); + tracing::info!(active = count, cap = max_viewers, "viewer left"); if count == 0 && let Some(h) = handle.take() { @@ -220,6 +224,7 @@ fn print_host_banner( ticket: &EndpointTicket, display: DisplayServer, opts: &HostOpts, + resolution: &MaxViewersResolution, clipboard_ok: bool, ) { eprintln!(); @@ -228,7 +233,7 @@ fn print_host_banner( eprintln!("│ capture : {}", capture_summary(opts)); eprintln!("│ bitrate / fps : {} kbps @ {} fps", opts.bitrate, opts.framerate); eprintln!("│ hw encode : {}", if opts.no_hwencode { "off" } else { "auto (VAAPI if available)" }); - eprintln!("│ max viewers : {}", opts.max_viewers); + eprintln!("│ max viewers : {} ({})", resolution.value, resolution.source.label()); eprintln!("│"); if clipboard_ok { eprintln!("│ Your share code has been copied to your clipboard."); @@ -247,6 +252,60 @@ fn print_host_banner( eprintln!(); } +/// How we arrived at the final viewer cap. Surfaced in the banner so the +/// user can tell at a glance whether the number is what they specified, +/// what their measured upstream supports, or just the fallback default. +struct MaxViewersResolution { + value: u32, + source: MaxViewersSource, +} + +enum MaxViewersSource { + /// User passed --max-viewers explicitly. + UserFlag, + /// Derived from the saved bandwidth measurement. + BandwidthMeasurement { safe_mbps: f64 }, + /// No flag, no measurement — falling back. + DefaultFallback, +} + +impl MaxViewersSource { + fn label(&self) -> String { + match self { + MaxViewersSource::UserFlag => "user-specified".to_string(), + MaxViewersSource::BandwidthMeasurement { safe_mbps } => { + format!("auto: {safe_mbps:.1} Mbps measured upstream") + } + MaxViewersSource::DefaultFallback => { + "default — run `pixelpass --reconfigure` for a connection-aware value".to_string() + } + } + } +} + +fn resolve_max_viewers(opts: &HostOpts) -> MaxViewersResolution { + if let Some(n) = opts.max_viewers { + return MaxViewersResolution { + value: n, + source: MaxViewersSource::UserFlag, + }; + } + if let Ok(cfg) = config::load() + && cfg.bandwidth.status == BandwidthStatus::Measured + && let Some(upstream) = cfg.bandwidth.upstream_mbps + { + let n = bandwidth::recommended_max_viewers(upstream, opts.bitrate); + return MaxViewersResolution { + value: n, + source: MaxViewersSource::BandwidthMeasurement { safe_mbps: upstream }, + }; + } + MaxViewersResolution { + value: 2, + source: MaxViewersSource::DefaultFallback, + } +} + fn copy_to_clipboard(text: &str) -> bool { match arboard::Clipboard::new().and_then(|mut cb| cb.set_text(text.to_owned())) { Ok(()) => true, diff --git a/src/interactive.rs b/src/interactive.rs index 068d023..56a6149 100644 --- a/src/interactive.rs +++ b/src/interactive.rs @@ -4,6 +4,7 @@ use iroh_tickets::endpoint::EndpointTicket; use std::str::FromStr; use crate::cli::Cli; +use crate::common::{bandwidth, config}; use crate::{host, viewer}; pub async fn run(cli: Cli) -> Result<()> { @@ -20,7 +21,10 @@ pub async fn run(cli: Cli) -> Result<()> { .interact()?; match choice { - 0 => host::run(cli.into_host_opts(true)).await, + 0 => { + preflight_if_needed(&theme).await; + host::run(cli.into_host_opts(true)).await + } _ => { let ticket = prompt_ticket(&theme)?; viewer::run(ticket, cli.into_viewer_opts(true)).await @@ -28,6 +32,138 @@ pub async fn run(cli: Cli) -> Result<()> { } } +/// `pixelpass --reconfigure` entry point: unconditionally re-run the +/// bandwidth pre-flight test, save the result, and return. Used to +/// refresh a stale measurement (e.g. user moved house, changed ISP). +pub async fn run_reconfigure() -> Result<()> { + eprintln!(); + eprintln!("Re-running bandwidth pre-flight test…"); + let mut cfg = config::load().unwrap_or_default(); + run_bandwidth_test(&mut cfg).await; + Ok(()) +} + +/// First-run pre-flight gate. Called once, when the user picks "Host" in +/// the interactive menu. Behavior by saved status: +/// - Unmeasured (first ever launch): explain + offer Run / Skip +/// - Failed (previous attempt errored): offer Retry / give-up-and-skip +/// - Measured or Skipped: silent — never re-prompts +async fn preflight_if_needed(theme: &ColorfulTheme) { + let mut cfg = config::load().unwrap_or_default(); + match cfg.bandwidth.status { + config::BandwidthStatus::Measured | config::BandwidthStatus::Skipped => return, + config::BandwidthStatus::Unmeasured => { + eprintln!(); + eprintln!("First-time setup"); + eprintln!("────────────────"); + eprintln!("PixelPass can measure your upload speed to recommend a safe"); + eprintln!("default for how many viewers your connection can handle."); + eprintln!("The test takes about 5 seconds and uploads ~5 MB to"); + eprintln!("Cloudflare's open speed-test endpoint."); + eprintln!(); + eprintln!("If you skip, a conservative default (2 viewers) is used."); + eprintln!("You can run the test later with `pixelpass --reconfigure`."); + eprintln!(); + + let Ok(choice) = Select::with_theme(theme) + .with_prompt("What would you like to do?") + .items(&[ + "Run the bandwidth test (recommended)", + "Skip — use the conservative default", + ]) + .default(0) + .interact() + else { + return; + }; + + if choice == 1 { + cfg.bandwidth = config::BandwidthEntry { + status: config::BandwidthStatus::Skipped, + upstream_mbps: None, + measured_at: None, + }; + let _ = config::save(&cfg); + eprintln!("Pre-flight skipped."); + return; + } + run_bandwidth_test(&mut cfg).await; + } + config::BandwidthStatus::Failed => { + eprintln!(); + let Ok(choice) = Select::with_theme(theme) + .with_prompt("Last bandwidth test failed. Try again?") + .items(&[ + "Yes — retry now", + "No — use the conservative default", + ]) + .default(0) + .interact() + else { + return; + }; + + if choice == 1 { + cfg.bandwidth = config::BandwidthEntry { + status: config::BandwidthStatus::Skipped, + upstream_mbps: None, + measured_at: None, + }; + let _ = config::save(&cfg); + eprintln!("OK — using the conservative default."); + return; + } + run_bandwidth_test(&mut cfg).await; + } + } +} + +async fn run_bandwidth_test(cfg: &mut config::Config) { + eprintln!(); + eprintln!("Measuring upstream…"); + + let result = tokio::task::spawn_blocking(bandwidth::measure_upstream_blocking).await; + let measurement = match result { + Ok(Ok(m)) => m, + Ok(Err(e)) => { + eprintln!("Test failed: {e:#}"); + eprintln!("Marking as failed — you'll be asked again on next launch."); + cfg.bandwidth = config::BandwidthEntry { + status: config::BandwidthStatus::Failed, + upstream_mbps: None, + measured_at: None, + }; + let _ = config::save(cfg); + return; + } + Err(join_err) => { + eprintln!("Test task panicked: {join_err}"); + cfg.bandwidth = config::BandwidthEntry { + status: config::BandwidthStatus::Failed, + upstream_mbps: None, + measured_at: None, + }; + let _ = config::save(cfg); + return; + } + }; + + eprintln!( + "Measured {:.2} Mbps up (safe estimate {:.2} Mbps, took {:.1}s).", + measurement.raw_mbps, + measurement.safe_mbps, + measurement.elapsed.as_secs_f64() + ); + cfg.bandwidth = config::BandwidthEntry { + status: config::BandwidthStatus::Measured, + upstream_mbps: Some(measurement.safe_mbps), + measured_at: Some(chrono::Utc::now()), + }; + if let Err(e) = config::save(cfg) { + eprintln!("Warning: failed to save result: {e:#}"); + } +} + fn print_welcome() { eprintln!(); eprintln!("Welcome to PixelPass."); diff --git a/src/main.rs b/src/main.rs index c27b028..4e91275 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,6 +20,10 @@ async fn main() -> Result<()> { return repair::run().await; } + if cli.reconfigure { + return interactive::run_reconfigure().await; + } + match cli.ticket.as_deref() { Some(s) => { let ticket: EndpointTicket = s.parse().map_err(|e| {