// ruso_runtime/runtime/port_cache.rs

//! TCP port reachability cache (30s TTL), shared across executor runs in one process.

use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::{Arc, OnceLock};
use std::time::{Duration, Instant};

use reqwest::Url;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::time::timeout;

use crate::runtime::context::VariableValue;
use crate::runtime::interpolate::interpolate;
use crate::runtime::spec::{ProbeKind, ProgramSpec};

/// How long a cached reachability verdict is trusted before the endpoint is
/// probed again.
const CACHE_TTL: Duration = Duration::from_secs(30);
/// Upper bound on a single TCP connect attempt made by the pre-run check.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(1500);
/// Cap cache size to bound memory growth during long-running bulk scans.
/// At ~50 bytes per entry plus map overhead this is well under 1 MB.
const MAX_CACHE_ENTRIES: usize = 4096;

/// Verdict of a TCP connect attempt, as stored in the cache.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Reachability {
    /// The TCP connect succeeded.
    Open,
    /// The TCP connect did not succeed.
    Closed,
}

30#[derive(Debug)]
31struct Entry {
32    state: Reachability,
33    checked_at: Instant,
34}
35
/// Result of probing one `host:port` (from cache or live TCP connect).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PortCheck {
    /// Host that was checked.
    pub host: String,
    /// TCP port that was checked.
    pub port: u16,
    /// `true` when the endpoint was found reachable.
    pub open: bool,
}

44#[derive(Debug)]
45pub struct PortCache {
46    entries: Mutex<HashMap<(String, u16), Entry>>,
47}
48
49impl PortCache {
50    pub fn new() -> Arc<Self> {
51        Arc::new(Self {
52            entries: Mutex::new(HashMap::new()),
53        })
54    }
55
56    /// Process-wide cache so back-to-back checks in `ruso scan` reuse state.
57    pub fn global() -> Arc<Self> {
58        static CACHE: OnceLock<Arc<PortCache>> = OnceLock::new();
59        CACHE.get_or_init(PortCache::new).clone()
60    }
61
62    /// Endpoints to probe before running: socket probes in the spec plus `--target` host:port for HTTP checks.
63    pub fn endpoints_for_run(spec: &ProgramSpec, base_url: &str) -> Vec<(String, u16)> {
64        // Socket probes may use the documented `host "{{scan_host}}"` form,
65        // which the executor resolves from `--target` at send-time. The
66        // pre-run port check must resolve it identically, or it would probe
67        // the literal placeholder, find it "closed", and skip the whole run.
68        let scan_vars = scan_target_variables(base_url);
69        let resolve =
70            |host: &str| interpolate(host, &scan_vars).unwrap_or_else(|_| host.to_string());
71
72        let mut out = Vec::new();
73        for kind in spec.probes.values() {
74            match kind {
75                // Only connection-oriented TCP probes get a TCP-connect liveness
76                // pre-check. UDP and wire-DNS probes are sent as connectionless
77                // datagrams (`exchange_udp`), so a TCP connect to the same port
78                // proves nothing about the UDP service — and would wrongly skip
79                // the entire run for UDP-only hosts (NTP, SNMP, syslog, UDP-only
80                // resolvers, …). Those probes rely on their own read timeout.
81                ProbeKind::Tcp(socket) => {
82                    if let Some(port) = socket.port {
83                        out.push((normalize_host(&resolve(&socket.host)), port));
84                    }
85                }
86                ProbeKind::Udp(_) | ProbeKind::Dns(_) | ProbeKind::Http(_) => {}
87            }
88        }
89
90        let has_http = spec
91            .probes
92            .values()
93            .any(|k| matches!(k, ProbeKind::Http(_)));
94        if has_http && let Some((host, port)) = scan_target_host_port(base_url) {
95            out.push((host, port));
96        }
97
98        dedupe_endpoints(out)
99    }
100
101    /// Probe every endpoint for this run; returns checks and the first closed `(host, port)` if any.
102    pub async fn check_for_run(
103        &self,
104        spec: &ProgramSpec,
105        base_url: &str,
106    ) -> (Vec<PortCheck>, Option<(String, u16)>) {
107        let endpoints = Self::endpoints_for_run(spec, base_url);
108        let mut checks = Vec::with_capacity(endpoints.len());
109        let mut first_closed = None;
110
111        for (host, port) in endpoints {
112            let open = self.is_open(&host, port).await;
113            checks.push(PortCheck {
114                host: host.clone(),
115                port,
116                open,
117            });
118            if !open && first_closed.is_none() {
119                first_closed = Some((host, port));
120            }
121        }
122
123        (checks, first_closed)
124    }
125
126    pub async fn is_open(&self, host: &str, port: u16) -> bool {
127        self.get_state(host, port).await == Reachability::Open
128    }
129
130    async fn get_state(&self, host: &str, port: u16) -> Reachability {
131        let normalized = normalize_host(host);
132        let key = (normalized, port);
133        {
134            let guard = self.entries.lock().await;
135            if let Some(entry) = guard.get(&key)
136                && entry.checked_at.elapsed() < CACHE_TTL
137            {
138                return entry.state;
139            }
140        }
141
142        let state = probe_tcp(&key.0, port).await;
143        let mut guard = self.entries.lock().await;
144        // Evict the oldest entry if we'd otherwise exceed the cap. Linear
145        // scan is fine at the 4K-entry scale — bulk scans grow this slowly.
146        if guard.len() >= MAX_CACHE_ENTRIES
147            && !guard.contains_key(&key)
148            && let Some(oldest_key) = guard
149                .iter()
150                .min_by_key(|(_, e)| e.checked_at)
151                .map(|(k, _)| k.clone())
152        {
153            guard.remove(&oldest_key);
154        }
155        guard.insert(
156            key,
157            Entry {
158                state,
159                checked_at: Instant::now(),
160            },
161        );
162        state
163    }
164}
165
166/// Host and port from CLI `--target` / executor `base_url` (for `{{scan_host}}` / `{{scan_port}}`).
167pub fn scan_target_host_port(base_url: &str) -> Option<(String, u16)> {
168    let url = Url::parse(base_url).ok()?;
169    let host = url.host_str()?;
170    let port = url
171        .port()
172        .unwrap_or_else(|| if url.scheme() == "https" { 443 } else { 80 });
173    Some((normalize_host(host), port))
174}
175
176/// Build the `{{scan_host}}` / `{{scan_port}}` / `{{scan_url}}` variables from
177/// the CLI `--target` (executor `base_url`). Mirrors the executor's
178/// `inject_scan_target_variables` so the pre-run port check resolves socket
179/// hosts the same way the send path does.
180fn scan_target_variables(base_url: &str) -> HashMap<String, VariableValue> {
181    let mut vars = HashMap::new();
182    if let Some((host, port)) = scan_target_host_port(base_url) {
183        vars.insert("scan_host".to_string(), VariableValue::String(host));
184        vars.insert(
185            "scan_port".to_string(),
186            VariableValue::String(port.to_string()),
187        );
188    }
189    if !base_url.is_empty() {
190        vars.insert(
191            "scan_url".to_string(),
192            VariableValue::String(base_url.to_string()),
193        );
194    }
195    vars
196}
197
/// Canonicalize a host string so equivalent IPv6 representations
/// (`::1`, `0:0:0:0:0:0:0:1`) and arbitrary case in hostnames share one
/// cache entry.
///
/// IP literals (optionally wrapped in `[...]`) are re-rendered in their
/// canonical textual form; anything else is returned ASCII-lowercased, with
/// any brackets left in place.
fn normalize_host(host: &str) -> String {
    // Strip any surrounding brackets so `[::1]` and `::1` normalize the same.
    let unbracketed = host.trim_start_matches('[').trim_end_matches(']');
    match unbracketed.parse::<IpAddr>() {
        Ok(ip) => ip.to_string(),
        Err(_) => host.to_ascii_lowercase(),
    }
}

/// Remove duplicate `(host, port)` pairs.
///
/// The list is sorted first, so the result is in ascending `(host, port)`
/// order rather than in the order the endpoints were collected.
fn dedupe_endpoints(mut endpoints: Vec<(String, u16)>) -> Vec<(String, u16)> {
    // `dedup` only drops adjacent duplicates, hence the sort.
    endpoints.sort_unstable();
    endpoints.dedup();
    endpoints
}

/// Format a `host:port` socket address, bracketing literal IPv6 addresses
/// so they parse correctly via `tokio::net::TcpStream::connect`.
///
/// An IPv6 literal is accepted with or without surrounding brackets. Any
/// other host (IPv4 literal or hostname) is emitted unchanged.
pub fn format_socket_addr(host: &str, port: u16) -> String {
    let unbracketed = host.trim_start_matches('[').trim_end_matches(']');
    match unbracketed.parse::<std::net::Ipv6Addr>() {
        Ok(_) => format!("[{unbracketed}]:{port}"),
        Err(_) => format!("{host}:{port}"),
    }
}

227async fn probe_tcp(host: &str, port: u16) -> Reachability {
228    let address = format_socket_addr(host, port);
229    match timeout(CONNECT_TIMEOUT, TcpStream::connect(address.as_str())).await {
230        Ok(Ok(_)) => Reachability::Open,
231        _ => Reachability::Closed,
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a spec containing exactly one probe, registered as `svc`.
    fn single_probe_spec(kind: ProbeKind) -> ProgramSpec {
        use crate::runtime::spec::CheckMetadata;
        let mut probes = std::collections::HashMap::new();
        probes.insert("svc".to_string(), kind);
        ProgramSpec {
            probes,
            metadata: CheckMetadata::default(),
        }
    }

    /// Build a spec with a single TCP probe against `host` on port 6379.
    fn tcp_spec(host: &str) -> ProgramSpec {
        use crate::runtime::spec::SocketProbeSpec;
        single_probe_spec(ProbeKind::Tcp(SocketProbeSpec {
            host: host.to_string(),
            port: Some(6379),
            ..SocketProbeSpec::default()
        }))
    }

    #[test]
    fn normalize_ipv6_collapses_long_form() {
        assert_eq!(normalize_host("::1"), normalize_host("0:0:0:0:0:0:0:1"));
        assert_eq!(normalize_host("[::1]"), normalize_host("::1"));
    }

    #[test]
    fn normalize_hostname_lowercases() {
        assert_eq!(normalize_host("Example.COM"), "example.com");
    }

    #[test]
    fn format_socket_addr_brackets_ipv6() {
        assert_eq!(format_socket_addr("::1", 443), "[::1]:443");
    }

    #[test]
    fn format_socket_addr_accepts_bracketed_ipv6() {
        // Already-bracketed literals must not end up double-bracketed.
        assert_eq!(format_socket_addr("[::1]", 443), "[::1]:443");
    }

    #[test]
    fn format_socket_addr_passes_through_ipv4() {
        assert_eq!(format_socket_addr("127.0.0.1", 80), "127.0.0.1:80");
    }

    #[test]
    fn format_socket_addr_passes_through_hostname() {
        assert_eq!(format_socket_addr("example.com", 80), "example.com:80");
    }

    #[test]
    fn dedupe_endpoints_sorts_and_removes_duplicates() {
        let eps = vec![
            ("b".to_string(), 80),
            ("a".to_string(), 443),
            ("b".to_string(), 80),
        ];
        assert_eq!(
            dedupe_endpoints(eps),
            vec![("a".to_string(), 443), ("b".to_string(), 80)]
        );
    }

    #[test]
    fn scan_target_handles_ipv6_url() {
        let (host, port) = scan_target_host_port("http://[::1]:8080/api").unwrap();
        assert_eq!(host, "::1");
        assert_eq!(port, 8080);
    }

    #[test]
    fn scan_target_defaults_port_from_scheme() {
        assert_eq!(
            scan_target_host_port("https://example.com"),
            Some(("example.com".to_string(), 443))
        );
        assert_eq!(
            scan_target_host_port("http://example.com/path"),
            Some(("example.com".to_string(), 80))
        );
    }

    #[test]
    fn scan_target_rejects_unparseable_input() {
        assert_eq!(scan_target_host_port(""), None);
        assert_eq!(scan_target_host_port("not a url"), None);
    }

    #[test]
    fn endpoints_interpolate_scan_host_for_socket_probes() {
        // The pre-run port check must resolve `{{scan_host}}` from --target,
        // just like the send path — otherwise socket probes are wrongly skipped.
        let eps = PortCache::endpoints_for_run(&tcp_spec("{{scan_host}}"), "http://127.0.0.1:6379");
        assert_eq!(eps, vec![("127.0.0.1".to_string(), 6379)]);
    }

    #[test]
    fn endpoints_keep_static_host_unchanged() {
        // A hardcoded host has no placeholder and must pass through verbatim,
        // with or without a --target (e.g. banner-grab probes).
        let eps = PortCache::endpoints_for_run(&tcp_spec("scanme.example.com"), "");
        assert_eq!(eps, vec![("scanme.example.com".to_string(), 6379)]);
    }

    #[test]
    fn udp_and_dns_probes_get_no_tcp_precheck() {
        use crate::runtime::spec::SocketProbeSpec;
        // UDP and wire-DNS are connectionless; a TCP-connect pre-check to their
        // port proves nothing and would wrongly skip the run on UDP-only hosts.
        let udp = single_probe_spec(ProbeKind::Udp(SocketProbeSpec {
            host: "{{scan_host}}".into(),
            port: Some(123),
            ..SocketProbeSpec::default()
        }));
        assert!(PortCache::endpoints_for_run(&udp, "http://127.0.0.1").is_empty());

        let dns = single_probe_spec(ProbeKind::Dns(SocketProbeSpec {
            host: "{{scan_host}}".into(),
            port: Some(53),
            ..SocketProbeSpec::default()
        }));
        assert!(PortCache::endpoints_for_run(&dns, "http://127.0.0.1").is_empty());
    }
}