spider_middleware/
autothrottle.rs

1//! Adaptive throttling middleware.
2//!
3//! [`AutoThrottleMiddleware`] adjusts request pacing from observed latency and
4//! response status so a crawl can speed up or slow down automatically.
5
6use async_trait::async_trait;
7use log::debug;
8use moka::future::Cache;
9use rand::distributions::{Distribution, Uniform};
10use spider_util::constants::{MIDDLEWARE_CACHE_CAPACITY, MIDDLEWARE_CACHE_TTL_SECS};
11use spider_util::error::SpiderError;
12use spider_util::request::Request;
13use spider_util::response::Response;
14use std::time::{Duration, SystemTime, UNIX_EPOCH};
15use tokio::sync::Mutex;
16use tokio::time::{Instant, sleep};
17
18use crate::middleware::{Middleware, MiddlewareAction};
19
/// Request-meta key holding the wall-clock send time, in milliseconds since the Unix epoch,
/// which the response path reads back to measure observed latency.
const STARTED_AT_META_KEY: &str = "__autothrottle_started_at_ms";
21
/// Selects how throttle state is partitioned across requests.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum Scope {
    /// Every request shares one throttle state, so pacing applies crawl-wide.
    Global,
    /// Each origin (`scheme://host:port`) keeps its own independent throttle state.
    Domain,
}
30
31#[derive(Debug, Clone)]
32struct ThrottleState {
33    delay: Duration,
34    next_allowed_at: Instant,
35}
36
37/// Middleware that adapts pacing dynamically based on response feedback.
38pub struct AutoThrottleMiddleware {
39    scope: Scope,
40    states: Cache<String, std::sync::Arc<Mutex<ThrottleState>>>,
41    min_delay: Duration,
42    max_delay: Duration,
43    target_concurrency: f64,
44    smoothing_factor: f64,
45    error_penalty: f64,
46    forbidden_penalty: f64,
47    too_many_penalty: f64,
48    jitter: bool,
49}
50
51impl Default for AutoThrottleMiddleware {
52    fn default() -> Self {
53        Self::builder().build()
54    }
55}
56
57impl AutoThrottleMiddleware {
58    /// Creates a new builder for [`AutoThrottleMiddleware`].
59    pub fn builder() -> AutoThrottleMiddlewareBuilder {
60        AutoThrottleMiddlewareBuilder::default()
61    }
62
63    fn scope_key(&self, request: &Request) -> String {
64        match self.scope {
65            Scope::Global => "global".to_string(),
66            Scope::Domain => spider_util::util::normalize_origin(request),
67        }
68    }
69
70    fn apply_jitter(&self, delay: Duration) -> Duration {
71        if !self.jitter || delay.is_zero() {
72            return delay;
73        }
74
75        let jitter_window = delay.mul_f64(0.25).min(Duration::from_millis(500));
76        let low = delay.saturating_sub(jitter_window);
77        let high = delay + jitter_window;
78
79        let mut rng = rand::thread_rng();
80        let uniform = Uniform::new_inclusive(low, high);
81        uniform.sample(&mut rng)
82    }
83}
84
85#[async_trait]
86impl<C: Send + Sync> Middleware<C> for AutoThrottleMiddleware {
87    fn name(&self) -> &str {
88        "AutoThrottleMiddleware"
89    }
90
91    async fn process_request(
92        &self,
93        _client: &C,
94        mut request: Request,
95    ) -> Result<MiddlewareAction<Request>, SpiderError> {
96        let key = self.scope_key(&request);
97        let state = self
98            .states
99            .get_with(key, async {
100                std::sync::Arc::new(Mutex::new(ThrottleState {
101                    delay: self.min_delay,
102                    next_allowed_at: Instant::now(),
103                }))
104            })
105            .await;
106
107        let sleep_duration = {
108            let mut state_guard = state.lock().await;
109            let now = Instant::now();
110            let delay = state_guard.delay;
111
112            if now < state_guard.next_allowed_at {
113                let wait = state_guard.next_allowed_at - now;
114                state_guard.next_allowed_at += delay;
115                wait
116            } else {
117                state_guard.next_allowed_at = now + delay;
118                Duration::ZERO
119            }
120        };
121
122        let sleep_duration = self.apply_jitter(sleep_duration);
123        if !sleep_duration.is_zero() {
124            sleep(sleep_duration).await;
125        }
126
127        if let Ok(since_epoch) = SystemTime::now().duration_since(UNIX_EPOCH) {
128            request.insert_meta(
129                STARTED_AT_META_KEY.to_string(),
130                serde_json::Value::from(since_epoch.as_millis().min(u128::from(u64::MAX)) as u64),
131            );
132        }
133
134        Ok(MiddlewareAction::Continue(request))
135    }
136
137    async fn process_response(
138        &self,
139        response: Response,
140    ) -> Result<MiddlewareAction<Response>, SpiderError> {
141        if response.cached {
142            return Ok(MiddlewareAction::Continue(response));
143        }
144
145        let key = self.scope_key(&response.request_from_response());
146
147        let Some(state) = self.states.get(&key).await else {
148            return Ok(MiddlewareAction::Continue(response));
149        };
150
151        let observed_latency = response
152            .meta
153            .as_ref()
154            .and_then(|meta| meta.get(STARTED_AT_META_KEY).map(|v| v.value().clone()))
155            .and_then(|v| v.as_u64())
156            .and_then(|started_at_ms| {
157                SystemTime::now()
158                    .duration_since(UNIX_EPOCH)
159                    .ok()
160                    .map(|now| now.as_millis().saturating_sub(u128::from(started_at_ms)))
161            })
162            .map(|delta_ms| {
163                let bounded = delta_ms.min(u128::from(u64::MAX)) as u64;
164                Duration::from_millis(bounded)
165            });
166
167        let status = response.status.as_u16();
168        let mut guard = state.lock().await;
169        let old_delay = guard.delay;
170
171        if let Some(latency) = observed_latency {
172            let target_delay = latency
173                .div_f64(self.target_concurrency.max(0.1))
174                .clamp(self.min_delay, self.max_delay);
175            let smoothed = old_delay.mul_f64(1.0 - self.smoothing_factor)
176                + target_delay.mul_f64(self.smoothing_factor);
177            guard.delay = smoothed.clamp(self.min_delay, self.max_delay);
178        }
179
180        match status {
181            429 => guard.delay = guard.delay.mul_f64(self.too_many_penalty),
182            403 => guard.delay = guard.delay.mul_f64(self.forbidden_penalty),
183            500..=599 => guard.delay = guard.delay.mul_f64(self.error_penalty),
184            _ => {}
185        }
186        guard.delay = guard.delay.clamp(self.min_delay, self.max_delay);
187
188        if old_delay != guard.delay {
189            debug!(
190                "AutoThrottle adjusted delay for '{}': {:?} -> {:?} (status={})",
191                key, old_delay, guard.delay, status
192            );
193        }
194
195        Ok(MiddlewareAction::Continue(response))
196    }
197}
198
199/// Builder for [`AutoThrottleMiddleware`].
200pub struct AutoThrottleMiddlewareBuilder {
201    scope: Scope,
202    min_delay: Duration,
203    max_delay: Duration,
204    target_concurrency: f64,
205    smoothing_factor: f64,
206    error_penalty: f64,
207    forbidden_penalty: f64,
208    too_many_penalty: f64,
209    cache_ttl: Duration,
210    cache_capacity: u64,
211    jitter: bool,
212}
213
214impl Default for AutoThrottleMiddlewareBuilder {
215    fn default() -> Self {
216        Self {
217            scope: Scope::Domain,
218            min_delay: Duration::from_millis(50),
219            max_delay: Duration::from_secs(60),
220            target_concurrency: 1.0,
221            smoothing_factor: 0.3,
222            error_penalty: 1.5,
223            forbidden_penalty: 1.2,
224            too_many_penalty: 2.0,
225            cache_ttl: Duration::from_secs(MIDDLEWARE_CACHE_TTL_SECS),
226            cache_capacity: MIDDLEWARE_CACHE_CAPACITY,
227            jitter: true,
228        }
229    }
230}
231
232impl AutoThrottleMiddlewareBuilder {
233    /// Sets throttling scope.
234    pub fn scope(mut self, scope: Scope) -> Self {
235        self.scope = scope;
236        self
237    }
238
239    /// Sets minimum delay between requests.
240    pub fn min_delay(mut self, min_delay: Duration) -> Self {
241        self.min_delay = min_delay;
242        self
243    }
244
245    /// Sets maximum delay between requests.
246    pub fn max_delay(mut self, max_delay: Duration) -> Self {
247        self.max_delay = max_delay;
248        self
249    }
250
251    /// Sets target concurrency used in `latency / target_concurrency`.
252    pub fn target_concurrency(mut self, target_concurrency: f64) -> Self {
253        self.target_concurrency = target_concurrency;
254        self
255    }
256
257    /// Sets smoothing factor (0.0..=1.0) for delay updates.
258    pub fn smoothing_factor(mut self, smoothing_factor: f64) -> Self {
259        self.smoothing_factor = smoothing_factor.clamp(0.0, 1.0);
260        self
261    }
262
263    /// Sets multiplier for 5xx responses.
264    pub fn error_penalty(mut self, error_penalty: f64) -> Self {
265        self.error_penalty = error_penalty.max(1.0);
266        self
267    }
268
269    /// Sets multiplier for 403 responses.
270    pub fn forbidden_penalty(mut self, forbidden_penalty: f64) -> Self {
271        self.forbidden_penalty = forbidden_penalty.max(1.0);
272        self
273    }
274
275    /// Sets multiplier for 429 responses.
276    pub fn too_many_penalty(mut self, too_many_penalty: f64) -> Self {
277        self.too_many_penalty = too_many_penalty.max(1.0);
278        self
279    }
280
281    /// Enables/disables sleep jitter.
282    pub fn jitter(mut self, jitter: bool) -> Self {
283        self.jitter = jitter;
284        self
285    }
286
287    /// Sets middleware state cache TTL.
288    pub fn cache_ttl(mut self, cache_ttl: Duration) -> Self {
289        self.cache_ttl = cache_ttl;
290        self
291    }
292
293    /// Sets middleware state cache capacity.
294    pub fn cache_capacity(mut self, cache_capacity: u64) -> Self {
295        self.cache_capacity = cache_capacity;
296        self
297    }
298
299    /// Builds [`AutoThrottleMiddleware`].
300    pub fn build(self) -> AutoThrottleMiddleware {
301        let min_delay = self.min_delay.min(self.max_delay);
302        let max_delay = self.max_delay.max(self.min_delay);
303
304        AutoThrottleMiddleware {
305            scope: self.scope,
306            states: Cache::builder()
307                .time_to_idle(self.cache_ttl)
308                .max_capacity(self.cache_capacity)
309                .build(),
310            min_delay,
311            max_delay,
312            target_concurrency: self.target_concurrency.max(0.1),
313            smoothing_factor: self.smoothing_factor.clamp(0.0, 1.0),
314            error_penalty: self.error_penalty.max(1.0),
315            forbidden_penalty: self.forbidden_penalty.max(1.0),
316            too_many_penalty: self.too_many_penalty.max(1.0),
317            jitter: self.jitter,
318        }
319    }
320}