1use async_trait::async_trait;
7use log::debug;
8use moka::future::Cache;
9use rand::distributions::{Distribution, Uniform};
10use spider_util::constants::{MIDDLEWARE_CACHE_CAPACITY, MIDDLEWARE_CACHE_TTL_SECS};
11use spider_util::error::SpiderError;
12use spider_util::request::Request;
13use spider_util::response::Response;
14use std::time::{Duration, SystemTime, UNIX_EPOCH};
15use tokio::sync::Mutex;
16use tokio::time::{Instant, sleep};
17
18use crate::middleware::{Middleware, MiddlewareAction};
19
/// Request-metadata key under which `process_request` stores the request's start time
/// as milliseconds since the Unix epoch; `process_response` reads it back and subtracts
/// it from the current wall-clock time to estimate response latency.
const STARTED_AT_META_KEY: &str = "__autothrottle_started_at_ms";
21
/// Granularity at which [`AutoThrottleMiddleware`] keeps its throttling state.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum Scope {
    /// One delay shared by every request, whatever its target.
    Global,
    /// An independent delay for each normalized request origin.
    Domain,
}
30
/// Mutable throttling bookkeeping for a single scope key.
#[derive(Clone, Debug)]
struct ThrottleState {
    /// Spacing currently enforced between consecutive request slots.
    delay: Duration,
    /// Earliest instant at which the next request slot may begin.
    next_allowed_at: Instant,
}
36
37pub struct AutoThrottleMiddleware {
39 scope: Scope,
40 states: Cache<String, std::sync::Arc<Mutex<ThrottleState>>>,
41 min_delay: Duration,
42 max_delay: Duration,
43 target_concurrency: f64,
44 smoothing_factor: f64,
45 error_penalty: f64,
46 forbidden_penalty: f64,
47 too_many_penalty: f64,
48 jitter: bool,
49}
50
51impl Default for AutoThrottleMiddleware {
52 fn default() -> Self {
53 Self::builder().build()
54 }
55}
56
57impl AutoThrottleMiddleware {
58 pub fn builder() -> AutoThrottleMiddlewareBuilder {
60 AutoThrottleMiddlewareBuilder::default()
61 }
62
63 fn scope_key(&self, request: &Request) -> String {
64 match self.scope {
65 Scope::Global => "global".to_string(),
66 Scope::Domain => spider_util::util::normalize_origin(request),
67 }
68 }
69
70 fn apply_jitter(&self, delay: Duration) -> Duration {
71 if !self.jitter || delay.is_zero() {
72 return delay;
73 }
74
75 let jitter_window = delay.mul_f64(0.25).min(Duration::from_millis(500));
76 let low = delay.saturating_sub(jitter_window);
77 let high = delay + jitter_window;
78
79 let mut rng = rand::thread_rng();
80 let uniform = Uniform::new_inclusive(low, high);
81 uniform.sample(&mut rng)
82 }
83}
84
85#[async_trait]
86impl<C: Send + Sync> Middleware<C> for AutoThrottleMiddleware {
87 fn name(&self) -> &str {
88 "AutoThrottleMiddleware"
89 }
90
91 async fn process_request(
92 &self,
93 _client: &C,
94 mut request: Request,
95 ) -> Result<MiddlewareAction<Request>, SpiderError> {
96 let key = self.scope_key(&request);
97 let state = self
98 .states
99 .get_with(key, async {
100 std::sync::Arc::new(Mutex::new(ThrottleState {
101 delay: self.min_delay,
102 next_allowed_at: Instant::now(),
103 }))
104 })
105 .await;
106
107 let sleep_duration = {
108 let mut state_guard = state.lock().await;
109 let now = Instant::now();
110 let delay = state_guard.delay;
111
112 if now < state_guard.next_allowed_at {
113 let wait = state_guard.next_allowed_at - now;
114 state_guard.next_allowed_at += delay;
115 wait
116 } else {
117 state_guard.next_allowed_at = now + delay;
118 Duration::ZERO
119 }
120 };
121
122 let sleep_duration = self.apply_jitter(sleep_duration);
123 if !sleep_duration.is_zero() {
124 sleep(sleep_duration).await;
125 }
126
127 if let Ok(since_epoch) = SystemTime::now().duration_since(UNIX_EPOCH) {
128 request.insert_meta(
129 STARTED_AT_META_KEY.to_string(),
130 serde_json::Value::from(since_epoch.as_millis().min(u128::from(u64::MAX)) as u64),
131 );
132 }
133
134 Ok(MiddlewareAction::Continue(request))
135 }
136
137 async fn process_response(
138 &self,
139 response: Response,
140 ) -> Result<MiddlewareAction<Response>, SpiderError> {
141 if response.cached {
142 return Ok(MiddlewareAction::Continue(response));
143 }
144
145 let key = self.scope_key(&response.request_from_response());
146
147 let Some(state) = self.states.get(&key).await else {
148 return Ok(MiddlewareAction::Continue(response));
149 };
150
151 let observed_latency = response
152 .meta
153 .as_ref()
154 .and_then(|meta| meta.get(STARTED_AT_META_KEY).map(|v| v.value().clone()))
155 .and_then(|v| v.as_u64())
156 .and_then(|started_at_ms| {
157 SystemTime::now()
158 .duration_since(UNIX_EPOCH)
159 .ok()
160 .map(|now| now.as_millis().saturating_sub(u128::from(started_at_ms)))
161 })
162 .map(|delta_ms| {
163 let bounded = delta_ms.min(u128::from(u64::MAX)) as u64;
164 Duration::from_millis(bounded)
165 });
166
167 let status = response.status.as_u16();
168 let mut guard = state.lock().await;
169 let old_delay = guard.delay;
170
171 if let Some(latency) = observed_latency {
172 let target_delay = latency
173 .div_f64(self.target_concurrency.max(0.1))
174 .clamp(self.min_delay, self.max_delay);
175 let smoothed = old_delay.mul_f64(1.0 - self.smoothing_factor)
176 + target_delay.mul_f64(self.smoothing_factor);
177 guard.delay = smoothed.clamp(self.min_delay, self.max_delay);
178 }
179
180 match status {
181 429 => guard.delay = guard.delay.mul_f64(self.too_many_penalty),
182 403 => guard.delay = guard.delay.mul_f64(self.forbidden_penalty),
183 500..=599 => guard.delay = guard.delay.mul_f64(self.error_penalty),
184 _ => {}
185 }
186 guard.delay = guard.delay.clamp(self.min_delay, self.max_delay);
187
188 if old_delay != guard.delay {
189 debug!(
190 "AutoThrottle adjusted delay for '{}': {:?} -> {:?} (status={})",
191 key, old_delay, guard.delay, status
192 );
193 }
194
195 Ok(MiddlewareAction::Continue(response))
196 }
197}
198
199pub struct AutoThrottleMiddlewareBuilder {
201 scope: Scope,
202 min_delay: Duration,
203 max_delay: Duration,
204 target_concurrency: f64,
205 smoothing_factor: f64,
206 error_penalty: f64,
207 forbidden_penalty: f64,
208 too_many_penalty: f64,
209 cache_ttl: Duration,
210 cache_capacity: u64,
211 jitter: bool,
212}
213
214impl Default for AutoThrottleMiddlewareBuilder {
215 fn default() -> Self {
216 Self {
217 scope: Scope::Domain,
218 min_delay: Duration::from_millis(50),
219 max_delay: Duration::from_secs(60),
220 target_concurrency: 1.0,
221 smoothing_factor: 0.3,
222 error_penalty: 1.5,
223 forbidden_penalty: 1.2,
224 too_many_penalty: 2.0,
225 cache_ttl: Duration::from_secs(MIDDLEWARE_CACHE_TTL_SECS),
226 cache_capacity: MIDDLEWARE_CACHE_CAPACITY,
227 jitter: true,
228 }
229 }
230}
231
232impl AutoThrottleMiddlewareBuilder {
233 pub fn scope(mut self, scope: Scope) -> Self {
235 self.scope = scope;
236 self
237 }
238
239 pub fn min_delay(mut self, min_delay: Duration) -> Self {
241 self.min_delay = min_delay;
242 self
243 }
244
245 pub fn max_delay(mut self, max_delay: Duration) -> Self {
247 self.max_delay = max_delay;
248 self
249 }
250
251 pub fn target_concurrency(mut self, target_concurrency: f64) -> Self {
253 self.target_concurrency = target_concurrency;
254 self
255 }
256
257 pub fn smoothing_factor(mut self, smoothing_factor: f64) -> Self {
259 self.smoothing_factor = smoothing_factor.clamp(0.0, 1.0);
260 self
261 }
262
263 pub fn error_penalty(mut self, error_penalty: f64) -> Self {
265 self.error_penalty = error_penalty.max(1.0);
266 self
267 }
268
269 pub fn forbidden_penalty(mut self, forbidden_penalty: f64) -> Self {
271 self.forbidden_penalty = forbidden_penalty.max(1.0);
272 self
273 }
274
275 pub fn too_many_penalty(mut self, too_many_penalty: f64) -> Self {
277 self.too_many_penalty = too_many_penalty.max(1.0);
278 self
279 }
280
281 pub fn jitter(mut self, jitter: bool) -> Self {
283 self.jitter = jitter;
284 self
285 }
286
287 pub fn cache_ttl(mut self, cache_ttl: Duration) -> Self {
289 self.cache_ttl = cache_ttl;
290 self
291 }
292
293 pub fn cache_capacity(mut self, cache_capacity: u64) -> Self {
295 self.cache_capacity = cache_capacity;
296 self
297 }
298
299 pub fn build(self) -> AutoThrottleMiddleware {
301 let min_delay = self.min_delay.min(self.max_delay);
302 let max_delay = self.max_delay.max(self.min_delay);
303
304 AutoThrottleMiddleware {
305 scope: self.scope,
306 states: Cache::builder()
307 .time_to_idle(self.cache_ttl)
308 .max_capacity(self.cache_capacity)
309 .build(),
310 min_delay,
311 max_delay,
312 target_concurrency: self.target_concurrency.max(0.1),
313 smoothing_factor: self.smoothing_factor.clamp(0.0, 1.0),
314 error_penalty: self.error_penalty.max(1.0),
315 forbidden_penalty: self.forbidden_penalty.max(1.0),
316 too_many_penalty: self.too_many_penalty.max(1.0),
317 jitter: self.jitter,
318 }
319 }
320}