//! Proxy rotation middleware (`spider_middleware/proxy.rs`).

use std::fmt::Debug;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

use async_trait::async_trait;
use log::{info, warn};
use rand::seq::SliceRandom;
use serde::{Deserialize, Serialize};
use spider_util::error::SpiderError;
use spider_util::request::Request;
use spider_util::response::Response;

use crate::middleware::{Middleware, MiddlewareAction};
21
22#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
24pub enum ProxyRotationStrategy {
25 #[default]
27 Sequential,
28 Random,
30 StickyFailover,
32}
33
34#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
36#[serde(untagged)]
37pub enum ProxySource {
38 List(Vec<String>),
40 File(PathBuf),
42}
43
44impl Default for ProxySource {
45 fn default() -> Self {
46 ProxySource::List(Vec::new())
47 }
48}
49
50#[derive(Debug, Clone, Default, Serialize, Deserialize)]
52pub struct ProxyMiddlewareBuilder {
53 source: ProxySource,
54 strategy: ProxyRotationStrategy,
55 block_detection_texts: Vec<String>,
56}
57
58impl ProxyMiddlewareBuilder {
59 pub fn source(mut self, source: ProxySource) -> Self {
61 self.source = source;
62 self
63 }
64
65 pub fn strategy(mut self, strategy: ProxyRotationStrategy) -> Self {
67 self.strategy = strategy;
68 self
69 }
70
71 pub fn with_block_detection_texts<I, S>(mut self, texts: I) -> Self
74 where
75 I: IntoIterator<Item = S>,
76 S: Into<String>,
77 {
78 self.block_detection_texts = texts.into_iter().map(Into::into).collect();
79 self
80 }
81
82 pub fn build(self) -> Result<ProxyMiddleware, SpiderError> {
88 let proxies = Arc::new(ProxyMiddleware::load_proxies(&self.source)?);
89
90 let block_texts = if self.block_detection_texts.is_empty() {
91 None
92 } else {
93 Some(self.block_detection_texts)
94 };
95
96 let middleware = ProxyMiddleware {
97 strategy: self.strategy,
98 proxies,
99 current_index: AtomicUsize::new(0),
100 block_detection_texts: block_texts,
101 };
102
103 info!("Initializing ProxyMiddleware with config: {:?}", middleware);
104
105 Ok(middleware)
106 }
107}
108
109pub struct ProxyMiddleware {
111 strategy: ProxyRotationStrategy,
112 proxies: Arc<Vec<String>>,
113 current_index: AtomicUsize,
114 block_detection_texts: Option<Vec<String>>,
115}
116
117impl Debug for ProxyMiddleware {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 f.debug_struct("ProxyMiddleware")
120 .field("strategy", &self.strategy)
121 .field("proxies", &format!("Pool({})", self.proxies.len()))
122 .field("current_index", &self.current_index)
123 .field("block_detection_texts", &self.block_detection_texts)
124 .finish()
125 }
126}
127
128impl ProxyMiddleware {
129 pub fn builder() -> ProxyMiddlewareBuilder {
131 ProxyMiddlewareBuilder::default()
132 }
133
134 fn load_proxies(source: &ProxySource) -> Result<Vec<String>, SpiderError> {
135 match source {
136 ProxySource::List(list) => Ok(list.clone()),
137 ProxySource::File(path) => Self::load_from_file(path),
138 }
139 }
140
141 fn load_from_file(path: &Path) -> Result<Vec<String>, SpiderError> {
142 if !path.exists() {
143 return Err(SpiderError::IoError(
144 std::io::Error::new(
145 std::io::ErrorKind::NotFound,
146 format!("Proxy file not found: {}", path.display()),
147 )
148 .to_string(),
149 ));
150 }
151 let file = File::open(path)?;
152 let reader = BufReader::new(file);
153 let proxies: Vec<String> = reader
154 .lines()
155 .map_while(Result::ok)
156 .filter(|line| !line.trim().is_empty())
157 .collect();
158
159 if proxies.is_empty() {
160 warn!(
161 "Proxy file {:?} is empty or contains no valid proxy URLs.",
162 path
163 );
164 }
165 Ok(proxies)
166 }
167
168 fn get_proxy(&self) -> Option<String> {
169 if self.proxies.is_empty() {
170 return None;
171 }
172
173 match self.strategy {
174 ProxyRotationStrategy::Sequential => {
175 let current = self.current_index.fetch_add(1, Ordering::SeqCst);
176 let index = current % self.proxies.len();
177 self.proxies.get(index).cloned()
178 }
179 ProxyRotationStrategy::Random => {
180 let mut rng = rand::thread_rng();
181 self.proxies.choose(&mut rng).cloned()
182 }
183 ProxyRotationStrategy::StickyFailover => {
184 let current = self.current_index.load(Ordering::SeqCst);
185 let index = current % self.proxies.len();
186 self.proxies.get(index).cloned()
187 }
188 }
189 }
190
191 fn rotate_proxy(&self) {
192 if !self.proxies.is_empty() {
193 self.current_index.fetch_add(1, Ordering::SeqCst);
194 info!("Proxy rotation triggered due to failure.");
195 }
196 }
197}
198
199#[async_trait]
200impl<C: Send + Sync> Middleware<C> for ProxyMiddleware {
201 fn name(&self) -> &str {
202 "ProxyMiddleware"
203 }
204
205 async fn process_request(
206 &self,
207 _client: &C,
208 mut request: Request,
209 ) -> Result<MiddlewareAction<Request>, SpiderError> {
210 if let Some(proxy) = self.get_proxy() {
211 request.insert_meta("proxy".to_string(), proxy.into());
212 }
213 Ok(MiddlewareAction::Continue(request))
214 }
215
216 async fn process_response(
217 &self,
218 response: Response,
219 ) -> Result<MiddlewareAction<Response>, SpiderError> {
220 if self.strategy != ProxyRotationStrategy::StickyFailover {
221 return Ok(MiddlewareAction::Continue(response));
222 }
223
224 let mut rotate = false;
225 let status = response.status;
226
227 if status.is_client_error() || status.is_server_error() {
228 rotate = true;
229 }
230
231 if status.is_success()
232 && let Some(texts) = &self.block_detection_texts
233 {
234 let body_str = String::from_utf8_lossy(&response.body);
235 if texts.iter().any(|text| body_str.contains(text)) {
236 rotate = true;
237 info!(
238 "Block detection text found in response body from {}",
239 response.url
240 );
241 }
242 }
243
244 if rotate {
245 self.rotate_proxy();
246 }
247
248 Ok(MiddlewareAction::Continue(response))
249 }
250
251 async fn handle_error(
252 &self,
253 _request: &Request,
254 error: &SpiderError,
255 ) -> Result<MiddlewareAction<Request>, SpiderError> {
256 if self.strategy == ProxyRotationStrategy::StickyFailover {
257 self.rotate_proxy();
258 }
259
260 Err(error.clone())
261 }
262}