// adguardian/welcome.rs

1use base64::{engine::general_purpose::STANDARD, Engine as _};
2use colored::*;
3use crossterm::{
4  event::{self, Event, KeyCode, KeyModifiers},
5  terminal::{disable_raw_mode, enable_raw_mode},
6};
7use reqwest::{Client, Error};
8use std::{
9  cmp::Ordering,
10  env,
11  fmt::Display,
12  io::{self, IsTerminal, Write},
13  time::Duration,
14};
15
16use semver::Version;
17use serde::Deserialize;
18use serde_json::Value;
19
20/// Reusable function that just prints success messages to the console
21fn print_info(text: &str, is_secondary: bool) {
22  if is_secondary {
23    println!("{}", text.green().italic().dimmed());
24  } else {
25    println!("{}", text.green());
26  };
27}
28
29/// Prints the AdGuardian ASCII art to console
30fn print_ascii_art() {
31  let art = r"
32 █████╗ ██████╗  ██████╗ ██╗   ██╗ █████╗ ██████╗ ██████╗ ██╗ █████╗ ███╗   ██╗
33██╔══██╗██╔══██╗██╔════╝ ██║   ██║██╔══██╗██╔══██╗██╔══██╗██║██╔══██╗████╗  ██║
34███████║██║  ██║██║  ███╗██║   ██║███████║██████╔╝██║  ██║██║███████║██╔██╗ ██║
35██╔══██║██║  ██║██║   ██║██║   ██║██╔══██║██╔══██╗██║  ██║██║██╔══██║██║╚██╗██║
36██║  ██║██████╔╝╚██████╔╝╚██████╔╝██║  ██║██║  ██║██████╔╝██║██║  ██║██║ ╚████║
37╚═╝  ╚═╝╚═════╝  ╚═════╝  ╚═════╝ ╚═╝  ╚═╝╚═╝  ╚═╝╚═════╝ ╚═╝╚═╝  ╚═╝╚═╝  ╚═══╝
38";
39  print_info(art, false);
40  print_info("\nWelcome to AdGuardian Terminal Edition!", false);
41  print_info(
42    "Terminal-based, real-time traffic monitoring and statistics for your AdGuard Home instance",
43    true,
44  );
45  print_info(
46    "For documentation and support, please visit: https://github.com/lissy93/adguardian-term",
47    true,
48  );
49}
50
51/// Print error message, along with (optional) stack trace, then exit
52fn print_error(message: &str, sub_message: &str, error: Option<&Error>) -> ! {
53  eprintln!(
54    "{}{}{}",
55    message.red(),
56    match error {
57      Some(err) => format!("\n{}", err).red().dimmed(),
58      None => "".red().dimmed(),
59    },
60    format!("\n{}", sub_message).yellow(),
61  );
62
63  std::process::exit(1);
64}
65
66/// Given a key, get the value from the environmental variables, and print it to the console
67fn get_env(key: &str) -> Result<String, env::VarError> {
68  env::var(key).inspect(|v| {
69    println!(
70      "{}",
71      format!(
72        "{} is set to {}",
73        key.bold(),
74        if key.contains("PASSWORD") {
75          "******"
76        } else {
77          v
78        }
79      )
80      .green()
81    );
82  })
83}
84
85/// Given a possibly undefined version number, check if it's present and supported
86fn check_version(version: Option<&str>) {
87  let min_version = Version::parse("0.107.29").unwrap();
88
89  match version {
90    Some(version_str) => {
91      match Version::parse(version_str.strip_prefix('v').unwrap_or(version_str)) {
92        Ok(adguard_version) if adguard_version < min_version => print_error(
93          "AdGuard Home version is too old, and is now unsupported",
94          format!(
95            "You're running AdGuard {}. Please upgrade to v{} or later.",
96            version_str, min_version
97          )
98          .as_str(),
99          None,
100        ),
101        Ok(_) => {}
102        Err(_) => print_error(
103          "Unsupported AdGuard Home version",
104          "Couldn't parse the version number reported by your AdGuard Home instance.",
105          None,
106        ),
107      }
108    }
109    None => {
110      print_error(
111        "Unsupported AdGuard Home version",
112        format!(
113          concat!(
114            "Failed to get the version number of your AdGuard Home instance.\n",
115            "This usually means you're running an old, and unsupported version.\n",
116            "Please upgrade to v{} or later."
117          ),
118          min_version
119        )
120        .as_str(),
121        None,
122      );
123    }
124  }
125}
126
127/// Run an async operation, retrying on error up to `attempts` times, `delay` apart.
128/// Each failure is reported; the last error is returned once attempts are exhausted.
129pub async fn with_retries<T, E, F, Fut>(
130  attempts: u32,
131  delay: Duration,
132  label: &str,
133  mut operation: F,
134) -> Result<T, E>
135where
136  F: FnMut() -> Fut,
137  Fut: std::future::Future<Output = Result<T, E>>,
138  E: Display,
139{
140  let mut attempt = 1;
141  loop {
142    match operation().await {
143      Ok(value) => return Ok(value),
144      Err(e) if attempt < attempts => {
145        println!(
146          "{}",
147          format!(
148            "{} failed (attempt {}/{}): {}\nRetrying in {}s...",
149            label,
150            attempt,
151            attempts,
152            e,
153            delay.as_secs()
154          )
155          .yellow()
156        );
157        tokio::time::sleep(delay).await;
158        attempt += 1;
159      }
160      Err(e) => return Err(e),
161    }
162  }
163}
164
165/// With the users specified AdGuard details, verify the connection.
166/// Returns `Err` on a failed connection (so the caller can retry); exits on
167/// rejected auth or an unsupported version, which retrying wouldn't fix.
168async fn verify_connection(
169  client: &Client,
170  ip: &str,
171  port: &str,
172  protocol: &str,
173  username: &str,
174  password: &str,
175) -> Result<(), Box<dyn std::error::Error>> {
176  println!(
177    "{}",
178    "\nVerifying connection to your AdGuard instance...".blue()
179  );
180
181  let auth_string = format!("{}:{}", username, password);
182  let auth_header_value = format!("Basic {}", STANDARD.encode(&auth_string));
183  let mut headers = reqwest::header::HeaderMap::new();
184  headers.insert("Authorization", auth_header_value.parse()?);
185
186  let url = format!("{}://{}:{}/control/status", protocol, ip, port);
187
188  match client
189    .get(&url)
190    .headers(headers)
191    .timeout(Duration::from_secs(2))
192    .send()
193    .await
194  {
195    Ok(res) if res.status().is_success() => {
196      // Get version string (if present), and check if valid - exit if not
197      let body: Value = res.json().await?;
198      check_version(body["version"].as_str());
199      // All good! Print success message :)
200      let safe_version = body["version"].as_str().unwrap_or("mystery version");
201      println!(
202        "{}",
203        format!("AdGuard ({}) connection successful!\n", safe_version).green()
204      );
205      Ok(())
206    }
207    // Connection failed to authenticate. Print error and exit
208    Ok(_) => print_error(
209      &format!("Authentication with AdGuard at {}:{} failed", ip, port),
210      "Check the credentials you passed as environmental variables and try again.",
211      None,
212    ),
213    // Connection failed to establish - return so the caller can retry
214    Err(e) => Err(e.into()),
215  }
216}
217
218#[derive(Deserialize)]
219struct CratesIoResponse {
220  #[serde(rename = "crate")]
221  krate: Crate,
222}
223
224#[derive(Deserialize)]
225struct Crate {
226  max_version: String,
227}
228
229/// Gets the latest version of the crate from crates.io
230async fn get_latest_version(crate_name: &str) -> Result<String, Box<dyn std::error::Error>> {
231  let url = format!("https://crates.io/api/v1/crates/{}", crate_name);
232  let client = reqwest::Client::new();
233  let res = client
234    .get(&url)
235    .header(
236      reqwest::header::USER_AGENT,
237      "version_check (adguardian.as93.net)",
238    )
239    .timeout(Duration::from_secs(2))
240    .send()
241    .await?;
242
243  if res.status().is_success() {
244    let response: CratesIoResponse = res.json().await?;
245    Ok(response.krate.max_version)
246  } else {
247    let status = res.status();
248    let body = res.text().await?;
249    Err(format!("Request failed with status {}: body: {}", status, body).into())
250  }
251}
252
253/// Checks for updates to the crate, and prints a message if an update is available
254async fn check_for_updates() {
255  // Get crate name and version from Cargo.toml
256  let crate_name = env!("CARGO_PKG_NAME");
257  let crate_version = env!("CARGO_PKG_VERSION");
258  println!("{}", "\nChecking for updates...".blue());
259  // Parse the current version, and fetch and parse the latest version
260  let zero = Version::new(0, 0, 0);
261  let current_version = Version::parse(crate_version).unwrap_or_else(|_| zero.clone());
262  let latest_version = Version::parse(
263    &get_latest_version(crate_name)
264      .await
265      .unwrap_or_else(|_| "0.0.0".to_string()),
266  )
267  .unwrap_or_else(|_| zero.clone());
268
269  // Compare the current and latest versions, and print the appropriate message
270  if current_version == zero || latest_version == zero {
271    println!("{}", "Unable to check for updates".yellow());
272    return;
273  }
274  match current_version.cmp(&latest_version) {
275    Ordering::Less => println!(
276      "{}",
277      format!(
278        "A new version of AdGuardian is available.\nUpdate from {} to {} for the best experience",
279        current_version.to_string().bold(),
280        latest_version.to_string().bold()
281      )
282      .yellow()
283    ),
284    Ordering::Equal => println!(
285      "{}",
286      format!(
287        "AdGuardian is up-to-date, running version {}",
288        current_version.to_string().bold()
289      )
290      .green()
291    ),
292    Ordering::Greater => println!(
293      "{}",
294      format!(
295        "Running a pre-released edition of AdGuardian, version {}",
296        current_version.to_string().bold()
297      )
298      .green()
299    ),
300  }
301}
302
/// Look up the value pre-filled in a field's interactive prompt.
///
/// Only `ADGUARD_IP` and `ADGUARD_PORT` have defaults; every other key yields `None`.
fn default_for(key: &str) -> Option<&'static str> {
  const DEFAULTS: [(&str, &str); 2] = [("ADGUARD_IP", "127.0.0.1"), ("ADGUARD_PORT", "3000")];
  DEFAULTS
    .iter()
    .find(|&&(name, _)| name == key)
    .map(|&(_, value)| value)
}
311
312/// Read a line from the terminal in raw mode, echoing nothing. Ctrl-C cancels.
313fn read_masked() -> io::Result<String> {
314  enable_raw_mode()?;
315  let result = masked_loop();
316  let _ = disable_raw_mode();
317  if result.is_ok() {
318    println!();
319  }
320  result
321}
322
323fn masked_loop() -> io::Result<String> {
324  let mut value = String::new();
325  loop {
326    if let Event::Key(key) = event::read()? {
327      let ctrl = key.modifiers.contains(KeyModifiers::CONTROL);
328      match key.code {
329        KeyCode::Enter => return Ok(value),
330        KeyCode::Char('c') if ctrl => return Err(io::ErrorKind::Interrupted.into()),
331        KeyCode::Char(c) if !ctrl => value.push(c),
332        KeyCode::Backspace => {
333          value.pop();
334        }
335        _ => {}
336      }
337    }
338  }
339}
340
341/// Print the prompt and read a value, masking secret fields on an interactive terminal
342fn read_field(prompt: &ColoredString, secret: bool) -> io::Result<String> {
343  print!("{}", prompt);
344  io::stdout().flush()?;
345  if secret && io::stdin().is_terminal() {
346    read_masked()
347  } else {
348    let mut value = String::new();
349    io::stdin().read_line(&mut value)?;
350    Ok(value)
351  }
352}
353
354/// Read a field off the async runtime threads
355async fn read_input(prompt: ColoredString, secret: bool) -> io::Result<String> {
356  tokio::task::spawn_blocking(move || read_field(&prompt, secret))
357    .await
358    .expect("input task panicked")
359}
360
361/// Print the cancellation notice and exit cleanly
362fn exit_interrupted() -> ! {
363  println!(
364    "{}",
365    "\n\nAdGuardian setup interrupted by user, exiting...".yellow()
366  );
367  std::process::exit(0);
368}
369
370/// Prompt for a single field, re-prompting until the input is valid.
371/// Masks passwords, applies the field's default on empty input, validates the
372/// port is numeric, and exits cleanly if the user interrupts with Ctrl-C.
373async fn prompt_for(key: &str) -> Result<String, Box<dyn std::error::Error>> {
374  let default = default_for(key);
375  let secret = key.contains("PASSWORD");
376  loop {
377    let hint = default.map(|d| format!(" [{}]", d)).unwrap_or_default();
378    let prompt = format!("› Enter a value for {}{}: ", key, hint)
379      .blue()
380      .bold();
381
382    let input = tokio::select! {
383      res = read_input(prompt, secret) => match res {
384        Ok(value) => value,
385        Err(e) if e.kind() == io::ErrorKind::Interrupted => exit_interrupted(),
386        Err(e) => return Err(e.into()),
387      },
388      _ = tokio::signal::ctrl_c() => exit_interrupted(),
389    };
390
391    let value = match input.trim() {
392      "" => default.unwrap_or_default(),
393      trimmed => trimmed,
394    };
395
396    if key == "ADGUARD_PORT" && value.parse::<u16>().is_err() {
397      println!("{}", "Port must be a number, and a valid port".yellow());
398      continue;
399    }
400    return Ok(value.to_string());
401  }
402}
403
404/// Initiate the welcome script
405/// This function will:
406/// - Print the AdGuardian ASCII art
407/// - Check if there's an update available
408/// - Check for the required environmental variables
409/// - Prompt the user to enter any missing variables
410/// - Verify the connection to the AdGuard instance
411/// - Verify authentication is successful
412/// - Verify the AdGuard Home version is supported
413/// - Then either print a success message, or show instructions to fix and exit
414pub async fn welcome() -> Result<(), Box<dyn std::error::Error>> {
415  print_ascii_art();
416
417  // Check for updates
418  check_for_updates().await;
419
420  println!("{}", "\nStarting initialization checks...".blue());
421
422  let client = Client::new();
423
424  // List of available flags, ant their associated env vars
425  let flags = [
426    ("--adguard-ip", "ADGUARD_IP"),
427    ("--adguard-port", "ADGUARD_PORT"),
428    ("--adguard-username", "ADGUARD_USERNAME"),
429    ("--adguard-password", "ADGUARD_PASSWORD"),
430  ];
431
432  let protocol: String = env::var("ADGUARD_PROTOCOL")
433    .unwrap_or_else(|_| "http".into())
434    .parse()?;
435  env::set_var("ADGUARD_PROTOCOL", protocol);
436
437  // Parse command line arguments
438  let mut args = std::env::args().peekable();
439  while let Some(arg) = args.next() {
440    for &(flag, var) in &flags {
441      if arg == flag {
442        if let Some(value) = args.peek().filter(|v| !v.starts_with("--")) {
443          env::set_var(var, value);
444          args.next();
445        }
446      }
447    }
448  }
449
450  // If any of the env variables or flags are not yet set, prompt the user to enter them
451  for &key in &[
452    "ADGUARD_IP",
453    "ADGUARD_PORT",
454    "ADGUARD_USERNAME",
455    "ADGUARD_PASSWORD",
456  ] {
457    if env::var(key).is_err() {
458      println!(
459        "{}",
460        format!("The {} environmental variable is not yet set", key.bold()).yellow()
461      );
462      env::set_var(key, prompt_for(key).await?);
463    }
464  }
465
466  // Grab the values of the (now set) environmental variables
467  let ip = get_env("ADGUARD_IP")?;
468  let port = get_env("ADGUARD_PORT")?;
469  let protocol = get_env("ADGUARD_PROTOCOL")?;
470  let username = get_env("ADGUARD_USERNAME")?;
471  let password = get_env("ADGUARD_PASSWORD")?;
472
473  // Verify we can connect, authenticate, and that the version is supported
474  let connected = with_retries(3, Duration::from_secs(5), "AdGuard connection", || {
475    verify_connection(&client, &ip, &port, &protocol, &username, &password)
476  })
477  .await;
478
479  if connected.is_err() {
480    print_error(
481      &format!(
482        "Could not reach AdGuard at {}:{} after 3 attempts",
483        ip, port
484      ),
485      "Please check that AdGuard Home is running and your settings are correct.",
486      None,
487    );
488  }
489
490  Ok(())
491}