Skip to main content

ruso_runtime/runtime/
http.rs

1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use futures::StreamExt;
5use reqwest::header::HeaderMap;
6use reqwest::{Client, Method, RequestBuilder};
7
8use crate::contract::HttpMethod;
9use crate::runtime::body::{object_to_form, object_to_json, object_to_multipart};
10use crate::runtime::bytes::decode_hex;
11use crate::runtime::context::VariableValue;
12use crate::runtime::duration::parse_duration;
13use crate::runtime::error::RuntimeError;
14use crate::runtime::interpolate::interpolate;
15use crate::runtime::response::HttpResponse;
16use crate::runtime::spec::HttpRequestSpec;
17
18pub async fn execute_http(
19    client: &Client,
20    base_url: &str,
21    spec: &HttpRequestSpec,
22    variables: &HashMap<String, VariableValue>,
23    max_response_bytes: usize,
24    retries: u32,
25) -> Result<HttpResponse, RuntimeError> {
26    let raw_path = spec.path.as_str();
27    let path = interpolate(raw_path, variables)?;
28    let url = join_url(base_url, raw_path, &path)?;
29    let method = to_reqwest_method(&spec.method);
30    let timeout = spec.timeout.as_deref().map(parse_duration).transpose()?;
31
32    tracing::debug!(%url, ?method, "http request");
33
34    let mut builder = client.request(method, url);
35    if let Some(duration) = timeout {
36        builder = builder.timeout(duration);
37    }
38
39    if let Some(agent) = &spec.user_agent {
40        builder = builder.header("user-agent", interpolate(agent, variables)?);
41    }
42
43    for (name, value) in &spec.headers {
44        builder = builder.header(
45            interpolate(name, variables)?,
46            interpolate(value, variables)?,
47        );
48    }
49
50    // RFC 6265 §5.4 requires all cookies to ride in a single `Cookie:` header
51    // joined by `"; "`. reqwest's `.header()` appends, so emitting one call
52    // per cookie produced multiple `Cookie:` headers — accepted leniently by
53    // some servers, rejected outright by others. Build one joined string.
54    if !spec.cookies.is_empty() {
55        let mut parts = Vec::with_capacity(spec.cookies.len());
56        for (name, value) in &spec.cookies {
57            let name = interpolate(name, variables)?;
58            let value = interpolate(value, variables)?;
59            parts.push(format!("{name}={value}"));
60        }
61        builder = builder.header("cookie", parts.join("; "));
62    }
63
64    if !spec.queries.is_empty() {
65        let pairs: Vec<(String, String)> = spec
66            .queries
67            .iter()
68            .map(|(name, value)| {
69                Ok((
70                    interpolate(name, variables)?,
71                    interpolate(value, variables)?,
72                ))
73            })
74            .collect::<Result<_, RuntimeError>>()?;
75        builder = builder.query(&pairs);
76    }
77
78    builder = apply_body(builder, spec, variables)?;
79
80    let (response, elapsed) = send_with_retries(builder, retries).await?;
81    let status = response.status().as_u16();
82    let headers = flatten_headers(response.headers());
83    let body = read_body_capped(response, max_response_bytes).await?;
84
85    Ok(HttpResponse {
86        status,
87        headers,
88        body,
89        elapsed,
90    })
91}
92
93/// Stream the response body into a buffer, stopping once `max_bytes` is
94/// reached. Caps memory use against malicious targets returning multi-GB
95/// payloads. The returned `String` is a lossy UTF-8 decode of the truncated
96/// bytes — matchers that need byte-precise comparison should use socket
97/// probes, not HTTP body.
98async fn read_body_capped(
99    response: reqwest::Response,
100    max_bytes: usize,
101) -> Result<String, RuntimeError> {
102    if max_bytes == 0 {
103        return Ok(String::new());
104    }
105    let mut buf: Vec<u8> = Vec::new();
106    let mut stream = response.bytes_stream();
107    while let Some(chunk) = stream.next().await {
108        let chunk = chunk?;
109        let remaining = max_bytes.saturating_sub(buf.len());
110        if remaining == 0 {
111            tracing::warn!(
112                limit = max_bytes,
113                "http response body truncated at max_response_bytes"
114            );
115            break;
116        }
117        let take = chunk.len().min(remaining);
118        buf.extend_from_slice(&chunk[..take]);
119        if take < chunk.len() {
120            tracing::warn!(
121                limit = max_bytes,
122                "http response body truncated at max_response_bytes"
123            );
124            break;
125        }
126    }
127    Ok(String::from_utf8_lossy(&buf).into_owned())
128}
129
130/// Send `builder`, retrying transient transport failures up to `retries` times.
131///
132/// Returns the response and the elapsed time of the attempt that produced it
133/// (backoff waits are excluded, so `response_time` matchers stay meaningful).
134/// Only connection-level failures are retried — a reset peer or a connect/read
135/// timeout — never a received HTTP response (any status) or a permanent TLS
136/// certificate rejection. Each retry clones the request; a non-cloneable
137/// (streaming) body is sent once.
138async fn send_with_retries(
139    builder: RequestBuilder,
140    retries: u32,
141) -> Result<(reqwest::Response, Duration), reqwest::Error> {
142    let mut attempt = 0u32;
143    loop {
144        // Clone for every attempt except the last so `builder` survives to
145        // retry; on the last attempt (or a non-cloneable body) consume it.
146        let cloned = if attempt < retries {
147            builder.try_clone()
148        } else {
149            None
150        };
151        let Some(request) = cloned else {
152            let started = Instant::now();
153            return builder.send().await.map(|resp| (resp, started.elapsed()));
154        };
155
156        let started = Instant::now();
157        match request.send().await {
158            Ok(response) => return Ok((response, started.elapsed())),
159            Err(err) if is_transient(&err) => {
160                attempt += 1;
161                tokio::time::sleep(retry_backoff(attempt)).await;
162            }
163            Err(err) => return Err(err),
164        }
165    }
166}
167
/// Pause to observe before retry attempt `n` (counted from 1).
///
/// The first retry waits 300ms, the second 800ms, and every other value of
/// `n` waits 1.5s.
fn retry_backoff(n: u32) -> Duration {
    let millis: u64 = if n == 1 {
        300
    } else if n == 2 {
        800
    } else {
        1500
    };
    Duration::from_millis(millis)
}
176
177/// Whether a failed request is worth retrying: a timeout, or a connection-level
178/// failure that is *not* a TLS certificate rejection (those are permanent —
179/// retrying only repeats the same handshake error).
180fn is_transient(err: &reqwest::Error) -> bool {
181    if err.is_timeout() {
182        return true;
183    }
184    err.is_connect() && !error_mentions(err, "certificate")
185}
186
187/// Case-insensitive search for `needle` across an error's full source chain.
188fn error_mentions(err: &reqwest::Error, needle: &str) -> bool {
189    let mut source: Option<&dyn std::error::Error> = Some(err);
190    while let Some(cause) = source {
191        if cause.to_string().to_ascii_lowercase().contains(needle) {
192            return true;
193        }
194        source = cause.source();
195    }
196    false
197}
198
199pub fn build_client(
200    default_timeout: Option<std::time::Duration>,
201    follow_redirect: bool,
202    verify_ssl: bool,
203    proxy: Option<&str>,
204) -> Result<Client, RuntimeError> {
205    let mut builder = Client::builder();
206    if let Some(timeout) = default_timeout {
207        builder = builder.timeout(timeout);
208    }
209    builder = builder.redirect(if follow_redirect {
210        reqwest::redirect::Policy::limited(10)
211    } else {
212        reqwest::redirect::Policy::none()
213    });
214    if !verify_ssl {
215        builder = builder.danger_accept_invalid_certs(true);
216    }
217    if let Some(proxy_url) = proxy {
218        builder = builder.proxy(reqwest::Proxy::all(proxy_url)?);
219    }
220    Ok(builder.build()?)
221}
222
223/// Attach request body; first matching mode in the request block wins (see `http_spec`).
224fn apply_body(
225    builder: RequestBuilder,
226    spec: &HttpRequestSpec,
227    variables: &HashMap<String, VariableValue>,
228) -> Result<RequestBuilder, RuntimeError> {
229    if let Some(body) = &spec.json_body {
230        let json = object_to_json(body, variables)?;
231        return Ok(builder
232            .header("content-type", "application/json")
233            .body(json));
234    }
235    if let Some(body) = &spec.data_body {
236        let form = object_to_form(body, variables)?;
237        return Ok(builder.form(&form));
238    }
239    if let Some(raw) = &spec.raw_body {
240        return Ok(builder.body(interpolate(raw, variables)?));
241    }
242    if let Some(body) = &spec.multipart_body {
243        let form = object_to_multipart(body, variables)?;
244        return Ok(builder.multipart(form));
245    }
246    if let Some(hex) = &spec.body_bytes {
247        let bytes = decode_hex(&interpolate(hex, variables)?)?;
248        return Ok(builder.body(bytes));
249    }
250    Ok(builder)
251}
252
253fn to_reqwest_method(method: &HttpMethod) -> Method {
254    match method {
255        HttpMethod::Get => Method::GET,
256        HttpMethod::Post => Method::POST,
257        HttpMethod::Put => Method::PUT,
258        HttpMethod::Patch => Method::PATCH,
259        HttpMethod::Delete => Method::DELETE,
260        HttpMethod::Head => Method::HEAD,
261        HttpMethod::Options => Method::OPTIONS,
262    }
263}
264
265/// Resolve the request URL relative to the scan target (`base`).
266///
267/// `raw_path` is the path written in the script (pre-interpolation);
268/// `resolved_path` is the same string after `{{ var }}` substitution.
269/// We allow absolute URLs *if and only if* the script itself wrote one —
270/// allowing interpolation to switch the scheme/host opens an SSRF where a
271/// previously-extracted variable can redirect later probes to internal
272/// services (`http://169.254.169.254/...`, `http://localhost:6379/...`).
273fn join_url(base: &str, raw_path: &str, resolved_path: &str) -> Result<String, RuntimeError> {
274    let raw_is_absolute = is_absolute_http_url(raw_path);
275    let resolved_is_absolute = is_absolute_http_url(resolved_path);
276    if raw_is_absolute {
277        return Ok(resolved_path.to_string());
278    }
279    if resolved_is_absolute {
280        // Interpolation produced an absolute URL from a relative template —
281        // a hostname swap by variable. Reject.
282        return Err(RuntimeError::Other(format!(
283            "interpolated path switched to absolute URL; refusing as SSRF guard: {resolved_path}"
284        )));
285    }
286    let base = base.trim_end_matches('/');
287    let path = if resolved_path.starts_with('/') {
288        resolved_path.to_string()
289    } else {
290        format!("/{resolved_path}")
291    };
292    Ok(format!("{base}{path}"))
293}
294
/// Report whether `value` starts with an `http://` or `https://` scheme.
///
/// Leading whitespace is ignored. The scheme is compared without regard to
/// ASCII case, because URI schemes are case-insensitive (RFC 3986 §3.1):
/// `HTTP://host/` is just as absolute as `http://host/`.
fn is_absolute_http_url(value: &str) -> bool {
    let trimmed = value.trim_start();
    ["http://", "https://"].iter().any(|scheme| {
        // `get` yields `None` when the input is shorter than the scheme or
        // the cut would fall inside a multi-byte character, so this never
        // panics on arbitrary input.
        trimmed
            .get(..scheme.len())
            .is_some_and(|prefix| prefix.eq_ignore_ascii_case(scheme))
    })
}
299
300/// Convert reqwest's `HeaderMap` into a `HashMap<String, String>` for
301/// downstream matchers. HTTP allows multi-valued headers (most notably
302/// `Set-Cookie`, also `WWW-Authenticate`, `Link`, etc.); reqwest's
303/// `HeaderMap` keeps them as separate entries, but matchers expect one
304/// string per header name. We collapse duplicates by joining with `", "`,
305/// which is the RFC 7230 §3.2.2 combining rule for the headers it applies
306/// to; `Set-Cookie` is the documented exception, but matchers operate by
307/// substring search ("`HttpOnly`", "`Secure`", domain names) which still
308/// works correctly against a joined string.
309fn flatten_headers(headers: &HeaderMap) -> HashMap<String, String> {
310    let mut map: HashMap<String, String> = HashMap::new();
311    for (name, value) in headers.iter() {
312        let Ok(text) = value.to_str() else { continue };
313        match map.entry(name.as_str().to_string()) {
314            std::collections::hash_map::Entry::Occupied(mut e) => {
315                let combined = format!("{}, {}", e.get(), text);
316                e.insert(combined);
317            }
318            std::collections::hash_map::Entry::Vacant(e) => {
319                e.insert(text.to_string());
320            }
321        }
322    }
323    map
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::{HeaderMap, HeaderValue};

    /// Scan target shared by the `join_url` tests.
    const TARGET: &str = "http://t.example";

    #[test]
    fn join_url_appends_relative_path() {
        let joined = join_url(TARGET, "/api", "/api").unwrap();
        assert_eq!(joined, "http://t.example/api");
    }

    #[test]
    fn join_url_handles_path_without_leading_slash() {
        let joined = join_url(TARGET, "api", "api").unwrap();
        assert_eq!(joined, "http://t.example/api");
    }

    #[test]
    fn join_url_strips_trailing_base_slash() {
        let joined = join_url("http://t.example/", "/api", "/api").unwrap();
        assert_eq!(joined, "http://t.example/api");
    }

    #[test]
    fn join_url_allows_explicitly_absolute_path() {
        // A full URL written directly in the script is an explicit opt-in
        // (e.g. probing another origin during a multi-host check), so it is
        // passed through unchanged.
        let absolute = "https://other.example/x";
        let joined = join_url(TARGET, absolute, absolute).unwrap();
        assert_eq!(joined, "https://other.example/x");
    }

    #[test]
    fn join_url_rejects_interpolated_scheme_switch() {
        // Relative template, but the variable expanded into a full URL:
        // that is the SSRF vector the guard exists for.
        let outcome = join_url(
            TARGET,
            "/api/{{ next }}",
            "http://169.254.169.254/latest/meta-data",
        );
        let msg = outcome.unwrap_err().to_string();
        assert!(msg.contains("SSRF"), "expected SSRF guard, got: {msg}");
    }

    #[test]
    fn join_url_rejects_interpolated_localhost_redirect() {
        let outcome = join_url(TARGET, "{{ extracted }}", "http://localhost:6379/info");
        let err = outcome.unwrap_err();
        assert!(err.to_string().contains("SSRF"));
    }

    #[test]
    fn flatten_headers_combines_duplicates() {
        let mut headers = HeaderMap::new();
        for raw in ["a=1", "b=2"] {
            headers.append("set-cookie", HeaderValue::from_static(raw));
        }
        let flat = flatten_headers(&headers);
        let cookie = flat.get("set-cookie").expect("set-cookie");
        // Both values survive in the joined string, so substring matchers on
        // cookie attributes keep working.
        assert!(cookie.contains("a=1"));
        assert!(cookie.contains("b=2"));
    }

    #[test]
    fn flatten_headers_keeps_single_value_intact() {
        let mut headers = HeaderMap::new();
        headers.insert("x-foo", HeaderValue::from_static("bar"));
        let flat = flatten_headers(&headers);
        assert_eq!(flat.get("x-foo").map(String::as_str), Some("bar"));
    }

    /// Start a localhost server and return its port.
    ///
    /// The first connection is held open without a reply for longer than the
    /// probe timeout (a transient failure); every later connection gets a
    /// `200 OK`. Each connection runs on its own thread so the accept loop
    /// is never blocked by the stalled one.
    fn spawn_stall_then_ok_server() -> u16 {
        use std::io::{Read, Write};
        use std::net::TcpListener;
        use std::sync::Arc;
        use std::sync::atomic::{AtomicU32, Ordering};

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        let connections = Arc::new(AtomicU32::new(0));

        std::thread::spawn(move || {
            for mut stream in listener.incoming().flatten() {
                let connections = Arc::clone(&connections);
                std::thread::spawn(move || {
                    let is_first = connections.fetch_add(1, Ordering::SeqCst) == 0;
                    if is_first {
                        // Hold the socket open silently, then drop it.
                        std::thread::sleep(Duration::from_secs(2));
                        return;
                    }
                    let mut request = [0u8; 512];
                    let _ = stream.read(&mut request);
                    let _ = stream.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n");
                });
            }
        });

        port
    }

    /// `GET /` with a 700ms timeout — shorter than the server's stall.
    fn probe_spec() -> HttpRequestSpec {
        HttpRequestSpec {
            method: HttpMethod::Get,
            path: "/".into(),
            timeout: Some("700ms".into()),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn retries_recover_a_transient_failure() {
        let port = spawn_stall_then_ok_server();
        let base = format!("http://127.0.0.1:{port}");
        let client = build_client(None, false, true, None).unwrap();
        let spec = probe_spec();
        let variables = HashMap::new();
        let result = execute_http(&client, &base, &spec, &variables, 1 << 20, 2).await;
        assert!(result.is_ok(), "retry should recover; got {result:?}");
        assert_eq!(result.unwrap().status, 200);
    }

    #[tokio::test]
    async fn zero_retries_surfaces_the_transient_failure() {
        // A probe driven by the script's own `retry` gets 0 retries here:
        // the single stalled attempt has to fail instead of being retried
        // silently.
        let port = spawn_stall_then_ok_server();
        let base = format!("http://127.0.0.1:{port}");
        let client = build_client(None, false, true, None).unwrap();
        let spec = probe_spec();
        let variables = HashMap::new();
        let result = execute_http(&client, &base, &spec, &variables, 1 << 20, 0).await;
        assert!(result.is_err(), "with no retries the failure must surface");
    }
}
465}