1use base64::{engine::general_purpose::STANDARD, Engine as _};
2use colored::*;
3use crossterm::{
4 event::{self, Event, KeyCode, KeyModifiers},
5 terminal::{disable_raw_mode, enable_raw_mode},
6};
7use reqwest::{Client, Error};
8use std::{
9 cmp::Ordering,
10 env,
11 fmt::Display,
12 io::{self, IsTerminal, Write},
13 time::Duration,
14};
15
16use semver::Version;
17use serde::Deserialize;
18use serde_json::Value;
19
20fn print_info(text: &str, is_secondary: bool) {
22 if is_secondary {
23 println!("{}", text.green().italic().dimmed());
24 } else {
25 println!("{}", text.green());
26 };
27}
28
29fn print_ascii_art() {
31 let art = r"
32 █████╗ ██████╗ ██████╗ ██╗ ██╗ █████╗ ██████╗ ██████╗ ██╗ █████╗ ███╗ ██╗
33██╔══██╗██╔══██╗██╔════╝ ██║ ██║██╔══██╗██╔══██╗██╔══██╗██║██╔══██╗████╗ ██║
34███████║██║ ██║██║ ███╗██║ ██║███████║██████╔╝██║ ██║██║███████║██╔██╗ ██║
35██╔══██║██║ ██║██║ ██║██║ ██║██╔══██║██╔══██╗██║ ██║██║██╔══██║██║╚██╗██║
36██║ ██║██████╔╝╚██████╔╝╚██████╔╝██║ ██║██║ ██║██████╔╝██║██║ ██║██║ ╚████║
37╚═╝ ╚═╝╚═════╝ ╚═════╝ ╚═════╝ ╚═╝ ╚═╝╚═╝ ╚═╝╚═════╝ ╚═╝╚═╝ ╚═╝╚═╝ ╚═══╝
38";
39 print_info(art, false);
40 print_info("\nWelcome to AdGuardian Terminal Edition!", false);
41 print_info(
42 "Terminal-based, real-time traffic monitoring and statistics for your AdGuard Home instance",
43 true,
44 );
45 print_info(
46 "For documentation and support, please visit: https://github.com/lissy93/adguardian-term",
47 true,
48 );
49}
50
51fn print_error(message: &str, sub_message: &str, error: Option<&Error>) -> ! {
53 eprintln!(
54 "{}{}{}",
55 message.red(),
56 match error {
57 Some(err) => format!("\n{}", err).red().dimmed(),
58 None => "".red().dimmed(),
59 },
60 format!("\n{}", sub_message).yellow(),
61 );
62
63 std::process::exit(1);
64}
65
66fn get_env(key: &str) -> Result<String, env::VarError> {
68 env::var(key).inspect(|v| {
69 println!(
70 "{}",
71 format!(
72 "{} is set to {}",
73 key.bold(),
74 if key.contains("PASSWORD") {
75 "******"
76 } else {
77 v
78 }
79 )
80 .green()
81 );
82 })
83}
84
85fn check_version(version: Option<&str>) {
87 let min_version = Version::parse("0.107.29").unwrap();
88
89 match version {
90 Some(version_str) => {
91 match Version::parse(version_str.strip_prefix('v').unwrap_or(version_str)) {
92 Ok(adguard_version) if adguard_version < min_version => print_error(
93 "AdGuard Home version is too old, and is now unsupported",
94 format!(
95 "You're running AdGuard {}. Please upgrade to v{} or later.",
96 version_str, min_version
97 )
98 .as_str(),
99 None,
100 ),
101 Ok(_) => {}
102 Err(_) => print_error(
103 "Unsupported AdGuard Home version",
104 "Couldn't parse the version number reported by your AdGuard Home instance.",
105 None,
106 ),
107 }
108 }
109 None => {
110 print_error(
111 "Unsupported AdGuard Home version",
112 format!(
113 concat!(
114 "Failed to get the version number of your AdGuard Home instance.\n",
115 "This usually means you're running an old, and unsupported version.\n",
116 "Please upgrade to v{} or later."
117 ),
118 min_version
119 )
120 .as_str(),
121 None,
122 );
123 }
124 }
125}
126
127pub async fn with_retries<T, E, F, Fut>(
130 attempts: u32,
131 delay: Duration,
132 label: &str,
133 mut operation: F,
134) -> Result<T, E>
135where
136 F: FnMut() -> Fut,
137 Fut: std::future::Future<Output = Result<T, E>>,
138 E: Display,
139{
140 let mut attempt = 1;
141 loop {
142 match operation().await {
143 Ok(value) => return Ok(value),
144 Err(e) if attempt < attempts => {
145 println!(
146 "{}",
147 format!(
148 "{} failed (attempt {}/{}): {}\nRetrying in {}s...",
149 label,
150 attempt,
151 attempts,
152 e,
153 delay.as_secs()
154 )
155 .yellow()
156 );
157 tokio::time::sleep(delay).await;
158 attempt += 1;
159 }
160 Err(e) => return Err(e),
161 }
162 }
163}
164
165async fn verify_connection(
169 client: &Client,
170 ip: &str,
171 port: &str,
172 protocol: &str,
173 username: &str,
174 password: &str,
175) -> Result<(), Box<dyn std::error::Error>> {
176 println!(
177 "{}",
178 "\nVerifying connection to your AdGuard instance...".blue()
179 );
180
181 let auth_string = format!("{}:{}", username, password);
182 let auth_header_value = format!("Basic {}", STANDARD.encode(&auth_string));
183 let mut headers = reqwest::header::HeaderMap::new();
184 headers.insert("Authorization", auth_header_value.parse()?);
185
186 let url = format!("{}://{}:{}/control/status", protocol, ip, port);
187
188 match client
189 .get(&url)
190 .headers(headers)
191 .timeout(Duration::from_secs(2))
192 .send()
193 .await
194 {
195 Ok(res) if res.status().is_success() => {
196 let body: Value = res.json().await?;
198 check_version(body["version"].as_str());
199 let safe_version = body["version"].as_str().unwrap_or("mystery version");
201 println!(
202 "{}",
203 format!("AdGuard ({}) connection successful!\n", safe_version).green()
204 );
205 Ok(())
206 }
207 Ok(_) => print_error(
209 &format!("Authentication with AdGuard at {}:{} failed", ip, port),
210 "Check the credentials you passed as environmental variables and try again.",
211 None,
212 ),
213 Err(e) => Err(e.into()),
215 }
216}
217
/// Subset of the crates.io `GET /api/v1/crates/{name}` JSON response that is read here.
#[derive(Deserialize)]
struct CratesIoResponse {
    // The JSON key is `crate`, a reserved word in Rust, hence the serde rename.
    #[serde(rename = "crate")]
    krate: Crate,
}

/// The crate metadata object nested under the `crate` key of the crates.io response.
#[derive(Deserialize)]
struct Crate {
    /// Highest published version of the crate, kept as a raw string (parsed as semver by the caller).
    max_version: String,
}
228
229async fn get_latest_version(crate_name: &str) -> Result<String, Box<dyn std::error::Error>> {
231 let url = format!("https://crates.io/api/v1/crates/{}", crate_name);
232 let client = reqwest::Client::new();
233 let res = client
234 .get(&url)
235 .header(
236 reqwest::header::USER_AGENT,
237 "version_check (adguardian.as93.net)",
238 )
239 .timeout(Duration::from_secs(2))
240 .send()
241 .await?;
242
243 if res.status().is_success() {
244 let response: CratesIoResponse = res.json().await?;
245 Ok(response.krate.max_version)
246 } else {
247 let status = res.status();
248 let body = res.text().await?;
249 Err(format!("Request failed with status {}: body: {}", status, body).into())
250 }
251}
252
253async fn check_for_updates() {
255 let crate_name = env!("CARGO_PKG_NAME");
257 let crate_version = env!("CARGO_PKG_VERSION");
258 println!("{}", "\nChecking for updates...".blue());
259 let zero = Version::new(0, 0, 0);
261 let current_version = Version::parse(crate_version).unwrap_or_else(|_| zero.clone());
262 let latest_version = Version::parse(
263 &get_latest_version(crate_name)
264 .await
265 .unwrap_or_else(|_| "0.0.0".to_string()),
266 )
267 .unwrap_or_else(|_| zero.clone());
268
269 if current_version == zero || latest_version == zero {
271 println!("{}", "Unable to check for updates".yellow());
272 return;
273 }
274 match current_version.cmp(&latest_version) {
275 Ordering::Less => println!(
276 "{}",
277 format!(
278 "A new version of AdGuardian is available.\nUpdate from {} to {} for the best experience",
279 current_version.to_string().bold(),
280 latest_version.to_string().bold()
281 )
282 .yellow()
283 ),
284 Ordering::Equal => println!(
285 "{}",
286 format!(
287 "AdGuardian is up-to-date, running version {}",
288 current_version.to_string().bold()
289 )
290 .green()
291 ),
292 Ordering::Greater => println!(
293 "{}",
294 format!(
295 "Running a pre-released edition of AdGuardian, version {}",
296 current_version.to_string().bold()
297 )
298 .green()
299 ),
300 }
301}
302
/// Returns the fallback value offered when prompting for environment variable `key`.
///
/// Only `ADGUARD_IP` (`127.0.0.1`) and `ADGUARD_PORT` (`3000`) have defaults; every other key
/// yields `None`.
fn default_for(key: &str) -> Option<&'static str> {
    const DEFAULTS: [(&str, &str); 2] = [("ADGUARD_IP", "127.0.0.1"), ("ADGUARD_PORT", "3000")];
    DEFAULTS
        .iter()
        .find(|(name, _)| *name == key)
        .map(|&(_, value)| value)
}
311
312fn read_masked() -> io::Result<String> {
314 enable_raw_mode()?;
315 let result = masked_loop();
316 let _ = disable_raw_mode();
317 if result.is_ok() {
318 println!();
319 }
320 result
321}
322
323fn masked_loop() -> io::Result<String> {
324 let mut value = String::new();
325 loop {
326 if let Event::Key(key) = event::read()? {
327 let ctrl = key.modifiers.contains(KeyModifiers::CONTROL);
328 match key.code {
329 KeyCode::Enter => return Ok(value),
330 KeyCode::Char('c') if ctrl => return Err(io::ErrorKind::Interrupted.into()),
331 KeyCode::Char(c) if !ctrl => value.push(c),
332 KeyCode::Backspace => {
333 value.pop();
334 }
335 _ => {}
336 }
337 }
338 }
339}
340
341fn read_field(prompt: &ColoredString, secret: bool) -> io::Result<String> {
343 print!("{}", prompt);
344 io::stdout().flush()?;
345 if secret && io::stdin().is_terminal() {
346 read_masked()
347 } else {
348 let mut value = String::new();
349 io::stdin().read_line(&mut value)?;
350 Ok(value)
351 }
352}
353
354async fn read_input(prompt: ColoredString, secret: bool) -> io::Result<String> {
356 tokio::task::spawn_blocking(move || read_field(&prompt, secret))
357 .await
358 .expect("input task panicked")
359}
360
361fn exit_interrupted() -> ! {
363 println!(
364 "{}",
365 "\n\nAdGuardian setup interrupted by user, exiting...".yellow()
366 );
367 std::process::exit(0);
368}
369
370async fn prompt_for(key: &str) -> Result<String, Box<dyn std::error::Error>> {
374 let default = default_for(key);
375 let secret = key.contains("PASSWORD");
376 loop {
377 let hint = default.map(|d| format!(" [{}]", d)).unwrap_or_default();
378 let prompt = format!("› Enter a value for {}{}: ", key, hint)
379 .blue()
380 .bold();
381
382 let input = tokio::select! {
383 res = read_input(prompt, secret) => match res {
384 Ok(value) => value,
385 Err(e) if e.kind() == io::ErrorKind::Interrupted => exit_interrupted(),
386 Err(e) => return Err(e.into()),
387 },
388 _ = tokio::signal::ctrl_c() => exit_interrupted(),
389 };
390
391 let value = match input.trim() {
392 "" => default.unwrap_or_default(),
393 trimmed => trimmed,
394 };
395
396 if key == "ADGUARD_PORT" && value.parse::<u16>().is_err() {
397 println!("{}", "Port must be a number, and a valid port".yellow());
398 continue;
399 }
400 return Ok(value.to_string());
401 }
402}
403
404pub async fn welcome() -> Result<(), Box<dyn std::error::Error>> {
415 print_ascii_art();
416
417 check_for_updates().await;
419
420 println!("{}", "\nStarting initialization checks...".blue());
421
422 let client = Client::new();
423
424 let flags = [
426 ("--adguard-ip", "ADGUARD_IP"),
427 ("--adguard-port", "ADGUARD_PORT"),
428 ("--adguard-username", "ADGUARD_USERNAME"),
429 ("--adguard-password", "ADGUARD_PASSWORD"),
430 ];
431
432 let protocol: String = env::var("ADGUARD_PROTOCOL")
433 .unwrap_or_else(|_| "http".into())
434 .parse()?;
435 env::set_var("ADGUARD_PROTOCOL", protocol);
436
437 let mut args = std::env::args().peekable();
439 while let Some(arg) = args.next() {
440 for &(flag, var) in &flags {
441 if arg == flag {
442 if let Some(value) = args.peek().filter(|v| !v.starts_with("--")) {
443 env::set_var(var, value);
444 args.next();
445 }
446 }
447 }
448 }
449
450 for &key in &[
452 "ADGUARD_IP",
453 "ADGUARD_PORT",
454 "ADGUARD_USERNAME",
455 "ADGUARD_PASSWORD",
456 ] {
457 if env::var(key).is_err() {
458 println!(
459 "{}",
460 format!("The {} environmental variable is not yet set", key.bold()).yellow()
461 );
462 env::set_var(key, prompt_for(key).await?);
463 }
464 }
465
466 let ip = get_env("ADGUARD_IP")?;
468 let port = get_env("ADGUARD_PORT")?;
469 let protocol = get_env("ADGUARD_PROTOCOL")?;
470 let username = get_env("ADGUARD_USERNAME")?;
471 let password = get_env("ADGUARD_PASSWORD")?;
472
473 let connected = with_retries(3, Duration::from_secs(5), "AdGuard connection", || {
475 verify_connection(&client, &ip, &port, &protocol, &username, &password)
476 })
477 .await;
478
479 if connected.is_err() {
480 print_error(
481 &format!(
482 "Could not reach AdGuard at {}:{} after 3 attempts",
483 ip, port
484 ),
485 "Please check that AdGuard Home is running and your settings are correct.",
486 None,
487 );
488 }
489
490 Ok(())
491}