1use std::collections::HashMap;
4use std::net::IpAddr;
5use std::sync::{Arc, OnceLock};
6use std::time::{Duration, Instant};
7
8use tokio::net::TcpStream;
9use tokio::sync::Mutex;
10use tokio::time::timeout;
11
12use reqwest::Url;
13
14use crate::runtime::context::VariableValue;
15use crate::runtime::interpolate::interpolate;
16use crate::runtime::spec::{ProbeKind, ProgramSpec};
17
/// How long a cached probe result is served before the endpoint is probed again.
const CACHE_TTL: Duration = Duration::from_secs(30);
/// Upper bound on a single TCP connect attempt; an attempt that exceeds it is
/// reported as closed.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(1500);
/// Number of cached endpoints at which inserting a new one evicts the oldest entry.
const MAX_CACHE_ENTRIES: usize = 4096;
23
/// Outcome of a TCP connect probe against a single endpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Reachability {
    /// A TCP connection was established within the connect timeout.
    Open,
    /// The connection attempt failed or did not complete within the timeout.
    Closed,
}
29
/// A cached probe result together with the moment it was recorded.
#[derive(Debug)]
struct Entry {
    /// Result of the most recent probe for the endpoint.
    state: Reachability,
    /// When the result was stored; compared against `CACHE_TTL` to decide freshness
    /// and used to pick the oldest entry for eviction.
    checked_at: Instant,
}
35
/// Reachability result for one endpoint, as reported by [`PortCache::check_for_run`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PortCheck {
    /// Host that was checked (an IP literal or hostname).
    pub host: String,
    /// TCP port that was checked.
    pub port: u16,
    /// `true` when the endpoint was found reachable, either by a fresh probe or
    /// from a still-valid cached result.
    pub open: bool,
}
43
/// Cache of TCP reachability results keyed by `(host, port)`.
#[derive(Debug)]
pub struct PortCache {
    /// Probe results keyed by normalized host and port, guarded by an async mutex.
    entries: Mutex<HashMap<(String, u16), Entry>>,
}
48
49impl PortCache {
50 pub fn new() -> Arc<Self> {
51 Arc::new(Self {
52 entries: Mutex::new(HashMap::new()),
53 })
54 }
55
56 pub fn global() -> Arc<Self> {
58 static CACHE: OnceLock<Arc<PortCache>> = OnceLock::new();
59 CACHE.get_or_init(PortCache::new).clone()
60 }
61
62 pub fn endpoints_for_run(spec: &ProgramSpec, base_url: &str) -> Vec<(String, u16)> {
64 let scan_vars = scan_target_variables(base_url);
69 let resolve =
70 |host: &str| interpolate(host, &scan_vars).unwrap_or_else(|_| host.to_string());
71
72 let mut out = Vec::new();
73 for kind in spec.probes.values() {
74 match kind {
75 ProbeKind::Tcp(socket) => {
82 if let Some(port) = socket.port {
83 out.push((normalize_host(&resolve(&socket.host)), port));
84 }
85 }
86 ProbeKind::Udp(_) | ProbeKind::Dns(_) | ProbeKind::Http(_) => {}
87 }
88 }
89
90 let has_http = spec
91 .probes
92 .values()
93 .any(|k| matches!(k, ProbeKind::Http(_)));
94 if has_http && let Some((host, port)) = scan_target_host_port(base_url) {
95 out.push((host, port));
96 }
97
98 dedupe_endpoints(out)
99 }
100
101 pub async fn check_for_run(
103 &self,
104 spec: &ProgramSpec,
105 base_url: &str,
106 ) -> (Vec<PortCheck>, Option<(String, u16)>) {
107 let endpoints = Self::endpoints_for_run(spec, base_url);
108 let mut checks = Vec::with_capacity(endpoints.len());
109 let mut first_closed = None;
110
111 for (host, port) in endpoints {
112 let open = self.is_open(&host, port).await;
113 checks.push(PortCheck {
114 host: host.clone(),
115 port,
116 open,
117 });
118 if !open && first_closed.is_none() {
119 first_closed = Some((host, port));
120 }
121 }
122
123 (checks, first_closed)
124 }
125
126 pub async fn is_open(&self, host: &str, port: u16) -> bool {
127 self.get_state(host, port).await == Reachability::Open
128 }
129
130 async fn get_state(&self, host: &str, port: u16) -> Reachability {
131 let normalized = normalize_host(host);
132 let key = (normalized, port);
133 {
134 let guard = self.entries.lock().await;
135 if let Some(entry) = guard.get(&key)
136 && entry.checked_at.elapsed() < CACHE_TTL
137 {
138 return entry.state;
139 }
140 }
141
142 let state = probe_tcp(&key.0, port).await;
143 let mut guard = self.entries.lock().await;
144 if guard.len() >= MAX_CACHE_ENTRIES
147 && !guard.contains_key(&key)
148 && let Some(oldest_key) = guard
149 .iter()
150 .min_by_key(|(_, e)| e.checked_at)
151 .map(|(k, _)| k.clone())
152 {
153 guard.remove(&oldest_key);
154 }
155 guard.insert(
156 key,
157 Entry {
158 state,
159 checked_at: Instant::now(),
160 },
161 );
162 state
163 }
164}
165
166pub fn scan_target_host_port(base_url: &str) -> Option<(String, u16)> {
168 let url = Url::parse(base_url).ok()?;
169 let host = url.host_str()?;
170 let port = url
171 .port()
172 .unwrap_or_else(|| if url.scheme() == "https" { 443 } else { 80 });
173 Some((normalize_host(host), port))
174}
175
176fn scan_target_variables(base_url: &str) -> HashMap<String, VariableValue> {
181 let mut vars = HashMap::new();
182 if let Some((host, port)) = scan_target_host_port(base_url) {
183 vars.insert("scan_host".to_string(), VariableValue::String(host));
184 vars.insert(
185 "scan_port".to_string(),
186 VariableValue::String(port.to_string()),
187 );
188 }
189 if !base_url.is_empty() {
190 vars.insert(
191 "scan_url".to_string(),
192 VariableValue::String(base_url.to_string()),
193 );
194 }
195 vars
196}
197
/// Produces the canonical form of a host.
///
/// Surrounding `[`/`]` are stripped before attempting to parse an IP address; a
/// successful parse is rendered in the address's canonical text form, so
/// equivalent spellings such as `::1` and `0:0:0:0:0:0:0:1` compare equal.
/// Anything that is not an IP literal is returned as the original input
/// (brackets included) in ASCII lowercase.
fn normalize_host(host: &str) -> String {
    let unbracketed = host.trim_start_matches('[').trim_end_matches(']');
    match unbracketed.parse::<IpAddr>() {
        Ok(ip) => ip.to_string(),
        Err(_) => host.to_ascii_lowercase(),
    }
}
209
/// Returns the endpoints in ascending `(host, port)` order with duplicates removed.
fn dedupe_endpoints(endpoints: Vec<(String, u16)>) -> Vec<(String, u16)> {
    // An ordered set yields exactly the sorted, unique sequence.
    let unique: std::collections::BTreeSet<(String, u16)> = endpoints.into_iter().collect();
    unique.into_iter().collect()
}
215
/// Formats `host` and `port` as a `host:port` string suitable for connecting.
///
/// IPv6 literals, with or without surrounding brackets, are emitted as
/// `[addr]:port` using the address text as given; every other host is passed
/// through unchanged in front of `:port`.
pub fn format_socket_addr(host: &str, port: u16) -> String {
    let bare = host.trim_start_matches('[').trim_end_matches(']');
    match bare.parse::<std::net::Ipv6Addr>() {
        Ok(_) => format!("[{bare}]:{port}"),
        Err(_) => format!("{host}:{port}"),
    }
}
226
227async fn probe_tcp(host: &str, port: u16) -> Reachability {
228 let address = format_socket_addr(host, port);
229 match timeout(CONNECT_TIMEOUT, TcpStream::connect(address.as_str())).await {
230 Ok(Ok(_)) => Reachability::Open,
231 _ => Reachability::Closed,
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::spec::{CheckMetadata, SocketProbeSpec};

    /// Wraps a single probe, registered under the name "svc", in an otherwise
    /// default spec.
    fn single_probe_spec(kind: ProbeKind) -> ProgramSpec {
        ProgramSpec {
            probes: std::collections::HashMap::from([("svc".to_string(), kind)]),
            metadata: CheckMetadata::default(),
        }
    }

    /// Builds a spec holding one TCP probe against `host` on port 6379.
    fn tcp_spec(host: &str) -> ProgramSpec {
        single_probe_spec(ProbeKind::Tcp(SocketProbeSpec {
            host: host.to_string(),
            port: Some(6379),
            ..SocketProbeSpec::default()
        }))
    }

    #[test]
    fn normalize_ipv6_collapses_long_form() {
        let canonical = normalize_host("::1");
        assert_eq!(canonical, normalize_host("0:0:0:0:0:0:0:1"));
        assert_eq!(normalize_host("[::1]"), canonical);
    }

    #[test]
    fn normalize_hostname_lowercases() {
        assert_eq!(normalize_host("Example.COM"), "example.com");
    }

    #[test]
    fn format_socket_addr_brackets_ipv6() {
        assert_eq!(format_socket_addr("::1", 443), "[::1]:443");
    }

    #[test]
    fn format_socket_addr_passes_through_ipv4() {
        assert_eq!(format_socket_addr("127.0.0.1", 80), "127.0.0.1:80");
    }

    #[test]
    fn format_socket_addr_passes_through_hostname() {
        assert_eq!(format_socket_addr("example.com", 80), "example.com:80");
    }

    #[test]
    fn scan_target_handles_ipv6_url() {
        assert_eq!(
            scan_target_host_port("http://[::1]:8080/api"),
            Some(("::1".to_string(), 8080))
        );
    }

    #[test]
    fn endpoints_interpolate_scan_host_for_socket_probes() {
        let spec = tcp_spec("{{scan_host}}");
        let eps = PortCache::endpoints_for_run(&spec, "http://127.0.0.1:6379");
        assert_eq!(eps, vec![("127.0.0.1".to_string(), 6379)]);
    }

    #[test]
    fn endpoints_keep_static_host_unchanged() {
        let spec = tcp_spec("scanme.example.com");
        let eps = PortCache::endpoints_for_run(&spec, "");
        assert_eq!(eps, vec![("scanme.example.com".to_string(), 6379)]);
    }

    #[test]
    fn udp_and_dns_probes_get_no_tcp_precheck() {
        // Same templated host for both probe kinds; only the port differs.
        let socket = |port: u16| SocketProbeSpec {
            host: "{{scan_host}}".into(),
            port: Some(port),
            ..SocketProbeSpec::default()
        };

        for kind in [ProbeKind::Udp(socket(123)), ProbeKind::Dns(socket(53))] {
            let spec = single_probe_spec(kind);
            assert!(PortCache::endpoints_for_run(&spec, "http://127.0.0.1").is_empty());
        }
    }
}