spider_middleware/
proxy.rs

1//! Proxy middleware.
2//!
3//! [`ProxyMiddleware`] attaches proxy metadata to outgoing requests so the
4//! downloader can route traffic through rotating or fixed proxy endpoints.
5
6use async_trait::async_trait;
7use log::{info, warn};
8use rand::seq::SliceRandom;
9use serde::{Deserialize, Serialize};
10use std::fmt::Debug;
11use std::fs::File;
12use std::io::{BufRead, BufReader};
13use std::path::{Path, PathBuf};
14use std::sync::Arc;
15use std::sync::atomic::{AtomicUsize, Ordering};
16
17use crate::middleware::{Middleware, MiddlewareAction};
18use spider_util::error::SpiderError;
19use spider_util::request::Request;
20use spider_util::response::Response;
21
22/// Defines the strategy for rotating proxies.
23#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
24pub enum ProxyRotationStrategy {
25    /// Sequentially cycles through the available proxies.
26    #[default]
27    Sequential,
28    /// Randomly selects a proxy from the available pool.
29    Random,
30    /// Uses one proxy until a failure is detected (based on status or body), then rotates.
31    StickyFailover,
32}
33
34/// Defines the source from which proxies are loaded.
35#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
36#[serde(untagged)]
37pub enum ProxySource {
38    /// A direct list of proxy URLs.
39    List(Vec<String>),
40    /// Path to a file containing proxy URLs, one per line.
41    File(PathBuf),
42}
43
44impl Default for ProxySource {
45    fn default() -> Self {
46        ProxySource::List(Vec::new())
47    }
48}
49
50/// Builder for creating a [`ProxyMiddleware`].
51#[derive(Debug, Clone, Default, Serialize, Deserialize)]
52pub struct ProxyMiddlewareBuilder {
53    source: ProxySource,
54    strategy: ProxyRotationStrategy,
55    block_detection_texts: Vec<String>,
56}
57
58impl ProxyMiddlewareBuilder {
59    /// Sets the primary source for proxies.
60    pub fn source(mut self, source: ProxySource) -> Self {
61        self.source = source;
62        self
63    }
64
65    /// Sets the strategy to use for rotating proxies.
66    pub fn strategy(mut self, strategy: ProxyRotationStrategy) -> Self {
67        self.strategy = strategy;
68        self
69    }
70
71    /// Sets the texts to detect in the response body to trigger a proxy rotation.
72    /// This is only used with the `StickyFailover` strategy.
73    pub fn with_block_detection_texts<I, S>(mut self, texts: I) -> Self
74    where
75        I: IntoIterator<Item = S>,
76        S: Into<String>,
77    {
78        self.block_detection_texts = texts.into_iter().map(Into::into).collect();
79        self
80    }
81
82    /// Builds the `ProxyMiddleware`.
83    ///
84    /// # Errors
85    ///
86    /// Returns an error if the configured source file cannot be read.
87    pub fn build(self) -> Result<ProxyMiddleware, SpiderError> {
88        let proxies = Arc::new(ProxyMiddleware::load_proxies(&self.source)?);
89
90        let block_texts = if self.block_detection_texts.is_empty() {
91            None
92        } else {
93            Some(self.block_detection_texts)
94        };
95
96        let middleware = ProxyMiddleware {
97            strategy: self.strategy,
98            proxies,
99            current_index: AtomicUsize::new(0),
100            block_detection_texts: block_texts,
101        };
102
103        info!("Initializing ProxyMiddleware with config: {:?}", middleware);
104
105        Ok(middleware)
106    }
107}
108
109/// Middleware that assigns proxies to outgoing requests and rotates them based on strategy.
110pub struct ProxyMiddleware {
111    strategy: ProxyRotationStrategy,
112    proxies: Arc<Vec<String>>,
113    current_index: AtomicUsize,
114    block_detection_texts: Option<Vec<String>>,
115}
116
117impl Debug for ProxyMiddleware {
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        f.debug_struct("ProxyMiddleware")
120            .field("strategy", &self.strategy)
121            .field("proxies", &format!("Pool({})", self.proxies.len()))
122            .field("current_index", &self.current_index)
123            .field("block_detection_texts", &self.block_detection_texts)
124            .finish()
125    }
126}
127
128impl ProxyMiddleware {
129    /// Creates a new `ProxyMiddlewareBuilder` to start building the middleware.
130    pub fn builder() -> ProxyMiddlewareBuilder {
131        ProxyMiddlewareBuilder::default()
132    }
133
134    fn load_proxies(source: &ProxySource) -> Result<Vec<String>, SpiderError> {
135        match source {
136            ProxySource::List(list) => Ok(list.clone()),
137            ProxySource::File(path) => Self::load_from_file(path),
138        }
139    }
140
141    fn load_from_file(path: &Path) -> Result<Vec<String>, SpiderError> {
142        if !path.exists() {
143            return Err(SpiderError::IoError(
144                std::io::Error::new(
145                    std::io::ErrorKind::NotFound,
146                    format!("Proxy file not found: {}", path.display()),
147                )
148                .to_string(),
149            ));
150        }
151        let file = File::open(path)?;
152        let reader = BufReader::new(file);
153        let proxies: Vec<String> = reader
154            .lines()
155            .map_while(Result::ok)
156            .filter(|line| !line.trim().is_empty())
157            .collect();
158
159        if proxies.is_empty() {
160            warn!(
161                "Proxy file {:?} is empty or contains no valid proxy URLs.",
162                path
163            );
164        }
165        Ok(proxies)
166    }
167
168    fn get_proxy(&self) -> Option<String> {
169        if self.proxies.is_empty() {
170            return None;
171        }
172
173        match self.strategy {
174            ProxyRotationStrategy::Sequential => {
175                let current = self.current_index.fetch_add(1, Ordering::SeqCst);
176                let index = current % self.proxies.len();
177                self.proxies.get(index).cloned()
178            }
179            ProxyRotationStrategy::Random => {
180                let mut rng = rand::thread_rng();
181                self.proxies.choose(&mut rng).cloned()
182            }
183            ProxyRotationStrategy::StickyFailover => {
184                let current = self.current_index.load(Ordering::SeqCst);
185                let index = current % self.proxies.len();
186                self.proxies.get(index).cloned()
187            }
188        }
189    }
190
191    fn rotate_proxy(&self) {
192        if !self.proxies.is_empty() {
193            self.current_index.fetch_add(1, Ordering::SeqCst);
194            info!("Proxy rotation triggered due to failure.");
195        }
196    }
197}
198
199#[async_trait]
200impl<C: Send + Sync> Middleware<C> for ProxyMiddleware {
201    fn name(&self) -> &str {
202        "ProxyMiddleware"
203    }
204
205    async fn process_request(
206        &self,
207        _client: &C,
208        mut request: Request,
209    ) -> Result<MiddlewareAction<Request>, SpiderError> {
210        if let Some(proxy) = self.get_proxy() {
211            request.insert_meta("proxy".to_string(), proxy.into());
212        }
213        Ok(MiddlewareAction::Continue(request))
214    }
215
216    async fn process_response(
217        &self,
218        response: Response,
219    ) -> Result<MiddlewareAction<Response>, SpiderError> {
220        if self.strategy != ProxyRotationStrategy::StickyFailover {
221            return Ok(MiddlewareAction::Continue(response));
222        }
223
224        let mut rotate = false;
225        let status = response.status;
226
227        if status.is_client_error() || status.is_server_error() {
228            rotate = true;
229        }
230
231        if status.is_success()
232            && let Some(texts) = &self.block_detection_texts
233        {
234            let body_str = String::from_utf8_lossy(&response.body);
235            if texts.iter().any(|text| body_str.contains(text)) {
236                rotate = true;
237                info!(
238                    "Block detection text found in response body from {}",
239                    response.url
240                );
241            }
242        }
243
244        if rotate {
245            self.rotate_proxy();
246        }
247
248        Ok(MiddlewareAction::Continue(response))
249    }
250
251    async fn handle_error(
252        &self,
253        _request: &Request,
254        error: &SpiderError,
255    ) -> Result<MiddlewareAction<Request>, SpiderError> {
256        if self.strategy == ProxyRotationStrategy::StickyFailover {
257            self.rotate_proxy();
258        }
259
260        Err(error.clone())
261    }
262}