1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use futures::StreamExt;
5use reqwest::header::HeaderMap;
6use reqwest::{Client, Method, RequestBuilder};
7
8use crate::contract::HttpMethod;
9use crate::runtime::body::{object_to_form, object_to_json, object_to_multipart};
10use crate::runtime::bytes::decode_hex;
11use crate::runtime::context::VariableValue;
12use crate::runtime::duration::parse_duration;
13use crate::runtime::error::RuntimeError;
14use crate::runtime::interpolate::interpolate;
15use crate::runtime::response::HttpResponse;
16use crate::runtime::spec::HttpRequestSpec;
17
18pub async fn execute_http(
19 client: &Client,
20 base_url: &str,
21 spec: &HttpRequestSpec,
22 variables: &HashMap<String, VariableValue>,
23 max_response_bytes: usize,
24 retries: u32,
25) -> Result<HttpResponse, RuntimeError> {
26 let raw_path = spec.path.as_str();
27 let path = interpolate(raw_path, variables)?;
28 let url = join_url(base_url, raw_path, &path)?;
29 let method = to_reqwest_method(&spec.method);
30 let timeout = spec.timeout.as_deref().map(parse_duration).transpose()?;
31
32 tracing::debug!(%url, ?method, "http request");
33
34 let mut builder = client.request(method, url);
35 if let Some(duration) = timeout {
36 builder = builder.timeout(duration);
37 }
38
39 if let Some(agent) = &spec.user_agent {
40 builder = builder.header("user-agent", interpolate(agent, variables)?);
41 }
42
43 for (name, value) in &spec.headers {
44 builder = builder.header(
45 interpolate(name, variables)?,
46 interpolate(value, variables)?,
47 );
48 }
49
50 if !spec.cookies.is_empty() {
55 let mut parts = Vec::with_capacity(spec.cookies.len());
56 for (name, value) in &spec.cookies {
57 let name = interpolate(name, variables)?;
58 let value = interpolate(value, variables)?;
59 parts.push(format!("{name}={value}"));
60 }
61 builder = builder.header("cookie", parts.join("; "));
62 }
63
64 if !spec.queries.is_empty() {
65 let pairs: Vec<(String, String)> = spec
66 .queries
67 .iter()
68 .map(|(name, value)| {
69 Ok((
70 interpolate(name, variables)?,
71 interpolate(value, variables)?,
72 ))
73 })
74 .collect::<Result<_, RuntimeError>>()?;
75 builder = builder.query(&pairs);
76 }
77
78 builder = apply_body(builder, spec, variables)?;
79
80 let (response, elapsed) = send_with_retries(builder, retries).await?;
81 let status = response.status().as_u16();
82 let headers = flatten_headers(response.headers());
83 let body = read_body_capped(response, max_response_bytes).await?;
84
85 Ok(HttpResponse {
86 status,
87 headers,
88 body,
89 elapsed,
90 })
91}
92
93async fn read_body_capped(
99 response: reqwest::Response,
100 max_bytes: usize,
101) -> Result<String, RuntimeError> {
102 if max_bytes == 0 {
103 return Ok(String::new());
104 }
105 let mut buf: Vec<u8> = Vec::new();
106 let mut stream = response.bytes_stream();
107 while let Some(chunk) = stream.next().await {
108 let chunk = chunk?;
109 let remaining = max_bytes.saturating_sub(buf.len());
110 if remaining == 0 {
111 tracing::warn!(
112 limit = max_bytes,
113 "http response body truncated at max_response_bytes"
114 );
115 break;
116 }
117 let take = chunk.len().min(remaining);
118 buf.extend_from_slice(&chunk[..take]);
119 if take < chunk.len() {
120 tracing::warn!(
121 limit = max_bytes,
122 "http response body truncated at max_response_bytes"
123 );
124 break;
125 }
126 }
127 Ok(String::from_utf8_lossy(&buf).into_owned())
128}
129
130async fn send_with_retries(
139 builder: RequestBuilder,
140 retries: u32,
141) -> Result<(reqwest::Response, Duration), reqwest::Error> {
142 let mut attempt = 0u32;
143 loop {
144 let cloned = if attempt < retries {
147 builder.try_clone()
148 } else {
149 None
150 };
151 let Some(request) = cloned else {
152 let started = Instant::now();
153 return builder.send().await.map(|resp| (resp, started.elapsed()));
154 };
155
156 let started = Instant::now();
157 match request.send().await {
158 Ok(response) => return Ok((response, started.elapsed())),
159 Err(err) if is_transient(&err) => {
160 attempt += 1;
161 tokio::time::sleep(retry_backoff(attempt)).await;
162 }
163 Err(err) => return Err(err),
164 }
165 }
166}
167
/// Delay to wait after the `n`-th failed attempt (1-based).
///
/// The first retry waits 300 ms and the second 800 ms. Every later value of `n`, including
/// the never-passed `0`, waits 1500 ms.
fn retry_backoff(n: u32) -> Duration {
    const FINAL_DELAY_MS: u64 = 1500;
    let millis = match n {
        1 => 300,
        2 => 800,
        _ => FINAL_DELAY_MS,
    };
    Duration::from_millis(millis)
}
176
177fn is_transient(err: &reqwest::Error) -> bool {
181 if err.is_timeout() {
182 return true;
183 }
184 err.is_connect() && !error_mentions(err, "certificate")
185}
186
/// Returns `true` when `needle` occurs in the `Display` text of `err` or of any error in its
/// `source()` chain. The comparison ignores ASCII case.
///
/// Takes any `&dyn Error`, so a `&reqwest::Error` coerces directly. Wrapped causes are walked
/// via `source()`.
fn error_mentions(err: &dyn std::error::Error, needle: &str) -> bool {
    // Lowercase the needle once. Each message is lowercased below, so a mixed-case needle
    // would otherwise never match.
    let needle = needle.to_ascii_lowercase();
    let mut source: Option<&dyn std::error::Error> = Some(err);
    while let Some(cause) = source {
        if cause
            .to_string()
            .to_ascii_lowercase()
            .contains(needle.as_str())
        {
            return true;
        }
        source = cause.source();
    }
    false
}
198
199pub fn build_client(
200 default_timeout: Option<std::time::Duration>,
201 follow_redirect: bool,
202 verify_ssl: bool,
203 proxy: Option<&str>,
204) -> Result<Client, RuntimeError> {
205 let mut builder = Client::builder();
206 if let Some(timeout) = default_timeout {
207 builder = builder.timeout(timeout);
208 }
209 builder = builder.redirect(if follow_redirect {
210 reqwest::redirect::Policy::limited(10)
211 } else {
212 reqwest::redirect::Policy::none()
213 });
214 if !verify_ssl {
215 builder = builder.danger_accept_invalid_certs(true);
216 }
217 if let Some(proxy_url) = proxy {
218 builder = builder.proxy(reqwest::Proxy::all(proxy_url)?);
219 }
220 Ok(builder.build()?)
221}
222
223fn apply_body(
225 builder: RequestBuilder,
226 spec: &HttpRequestSpec,
227 variables: &HashMap<String, VariableValue>,
228) -> Result<RequestBuilder, RuntimeError> {
229 if let Some(body) = &spec.json_body {
230 let json = object_to_json(body, variables)?;
231 return Ok(builder
232 .header("content-type", "application/json")
233 .body(json));
234 }
235 if let Some(body) = &spec.data_body {
236 let form = object_to_form(body, variables)?;
237 return Ok(builder.form(&form));
238 }
239 if let Some(raw) = &spec.raw_body {
240 return Ok(builder.body(interpolate(raw, variables)?));
241 }
242 if let Some(body) = &spec.multipart_body {
243 let form = object_to_multipart(body, variables)?;
244 return Ok(builder.multipart(form));
245 }
246 if let Some(hex) = &spec.body_bytes {
247 let bytes = decode_hex(&interpolate(hex, variables)?)?;
248 return Ok(builder.body(bytes));
249 }
250 Ok(builder)
251}
252
253fn to_reqwest_method(method: &HttpMethod) -> Method {
254 match method {
255 HttpMethod::Get => Method::GET,
256 HttpMethod::Post => Method::POST,
257 HttpMethod::Put => Method::PUT,
258 HttpMethod::Patch => Method::PATCH,
259 HttpMethod::Delete => Method::DELETE,
260 HttpMethod::Head => Method::HEAD,
261 HttpMethod::Options => Method::OPTIONS,
262 }
263}
264
265fn join_url(base: &str, raw_path: &str, resolved_path: &str) -> Result<String, RuntimeError> {
274 let raw_is_absolute = is_absolute_http_url(raw_path);
275 let resolved_is_absolute = is_absolute_http_url(resolved_path);
276 if raw_is_absolute {
277 return Ok(resolved_path.to_string());
278 }
279 if resolved_is_absolute {
280 return Err(RuntimeError::Other(format!(
283 "interpolated path switched to absolute URL; refusing as SSRF guard: {resolved_path}"
284 )));
285 }
286 let base = base.trim_end_matches('/');
287 let path = if resolved_path.starts_with('/') {
288 resolved_path.to_string()
289 } else {
290 format!("/{resolved_path}")
291 };
292 Ok(format!("{base}{path}"))
293}
294
/// Returns `true` when `value`, ignoring leading whitespace, starts with an `http://` or
/// `https://` scheme.
///
/// URI scheme names are case-insensitive (RFC 3986 §3.1), so `HTTPS://host/x` also counts as
/// absolute. A case-sensitive check would make `join_url` treat such a URL as a relative path
/// and append it to the base URL.
fn is_absolute_http_url(value: &str) -> bool {
    let trimmed = value.trim_start();
    ["http://", "https://"].iter().any(|scheme| {
        // `get` (not slicing) so a short string or a non-ASCII first character cannot panic.
        matches!(
            trimmed.get(..scheme.len()),
            Some(prefix) if prefix.eq_ignore_ascii_case(scheme)
        )
    })
}
299
300fn flatten_headers(headers: &HeaderMap) -> HashMap<String, String> {
310 let mut map: HashMap<String, String> = HashMap::new();
311 for (name, value) in headers.iter() {
312 let Ok(text) = value.to_str() else { continue };
313 match map.entry(name.as_str().to_string()) {
314 std::collections::hash_map::Entry::Occupied(mut e) => {
315 let combined = format!("{}, {}", e.get(), text);
316 e.insert(combined);
317 }
318 std::collections::hash_map::Entry::Vacant(e) => {
319 e.insert(text.to_string());
320 }
321 }
322 }
323 map
324}
325
#[cfg(test)]
mod tests {
    //! Unit tests for URL joining and header flattening, plus two loopback integration tests for
    //! the retry path of `execute_http`.
    use super::*;
    use reqwest::header::{HeaderMap, HeaderValue};

    #[test]
    fn join_url_appends_relative_path() {
        let url = join_url("http://t.example", "/api", "/api").unwrap();
        assert_eq!(url, "http://t.example/api");
    }

    #[test]
    fn join_url_handles_path_without_leading_slash() {
        // A `/` separator is inserted when the path has no leading slash.
        let url = join_url("http://t.example", "api", "api").unwrap();
        assert_eq!(url, "http://t.example/api");
    }

    #[test]
    fn join_url_strips_trailing_base_slash() {
        // No double slash when the base already ends with `/`.
        let url = join_url("http://t.example/", "/api", "/api").unwrap();
        assert_eq!(url, "http://t.example/api");
    }

    #[test]
    fn join_url_allows_explicitly_absolute_path() {
        // An absolute URL written literally in the spec overrides the base.
        let url = join_url(
            "http://t.example",
            "https://other.example/x",
            "https://other.example/x",
        )
        .unwrap();
        assert_eq!(url, "https://other.example/x");
    }

    #[test]
    fn join_url_rejects_interpolated_scheme_switch() {
        // A relative template that interpolates into an absolute URL (here the cloud metadata
        // address) must be refused by the SSRF guard.
        let err = join_url(
            "http://t.example",
            "/api/{{ next }}",
            "http://169.254.169.254/latest/meta-data",
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("SSRF"), "expected SSRF guard, got: {msg}");
    }

    #[test]
    fn join_url_rejects_interpolated_localhost_redirect() {
        // Same guard when the whole template is a single variable.
        let err = join_url(
            "http://t.example",
            "{{ extracted }}",
            "http://localhost:6379/info",
        )
        .unwrap_err();
        assert!(err.to_string().contains("SSRF"));
    }

    #[test]
    fn flatten_headers_combines_duplicates() {
        // Both values of a repeated header must survive flattening.
        let mut headers = HeaderMap::new();
        headers.append("set-cookie", HeaderValue::from_static("a=1"));
        headers.append("set-cookie", HeaderValue::from_static("b=2"));
        let flat = flatten_headers(&headers);
        let cookie = flat.get("set-cookie").expect("set-cookie");
        assert!(cookie.contains("a=1"));
        assert!(cookie.contains("b=2"));
    }

    #[test]
    fn flatten_headers_keeps_single_value_intact() {
        let mut headers = HeaderMap::new();
        headers.insert("x-foo", HeaderValue::from_static("bar"));
        let flat = flatten_headers(&headers);
        assert_eq!(flat.get("x-foo").map(String::as_str), Some("bar"));
    }

    /// Starts a loopback TCP server on an ephemeral port and returns the port.
    ///
    /// The first accepted connection is held for 2 s without reading or replying. That is
    /// longer than the 700 ms timeout in `probe_spec`, so the first attempt times out. Every
    /// later connection reads once and gets an empty `200 OK`. The accept thread is never
    /// joined; it lives until the test process exits.
    fn spawn_stall_then_ok_server() -> u16 {
        use std::io::{Read, Write};
        use std::net::TcpListener;
        use std::sync::Arc;
        use std::sync::atomic::{AtomicU32, Ordering};

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        // Counts accepted connections so only the very first one stalls.
        let seen = Arc::new(AtomicU32::new(0));
        std::thread::spawn(move || {
            for stream in listener.incoming().flatten() {
                let seen = seen.clone();
                std::thread::spawn(move || {
                    let mut stream = stream;
                    if seen.fetch_add(1, Ordering::SeqCst) == 0 {
                        std::thread::sleep(Duration::from_secs(2));
                        return;
                    }
                    let mut buf = [0u8; 512];
                    let _ = stream.read(&mut buf);
                    let _ = stream.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n");
                });
            }
        });
        port
    }

    /// A GET to `/` with a 700 ms per-request timeout, shorter than the server's 2 s stall.
    fn probe_spec() -> HttpRequestSpec {
        HttpRequestSpec {
            method: HttpMethod::Get,
            path: "/".into(),
            timeout: Some("700ms".into()),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn retries_recover_a_transient_failure() {
        // With 2 retries, the first attempt's timeout is retried on a fresh connection, which
        // the server answers with 200.
        let base = format!("http://127.0.0.1:{}", spawn_stall_then_ok_server());
        let client = build_client(None, false, true, None).unwrap();
        let result = execute_http(&client, &base, &probe_spec(), &HashMap::new(), 1 << 20, 2).await;
        assert!(result.is_ok(), "retry should recover; got {result:?}");
        assert_eq!(result.unwrap().status, 200);
    }

    #[tokio::test]
    async fn zero_retries_surfaces_the_transient_failure() {
        // With retries disabled, the single attempt hits the stalled connection and times out.
        let base = format!("http://127.0.0.1:{}", spawn_stall_then_ok_server());
        let client = build_client(None, false, true, None).unwrap();
        let result = execute_http(&client, &base, &probe_spec(), &HashMap::new(), 1 << 20, 0).await;
        assert!(result.is_err(), "with no retries the failure must surface");
    }
}