1use async_trait::async_trait;
7use dashmap::DashMap;
8use log::{debug, info, warn};
9use moka::sync::Cache;
10use reqwest::header::{HeaderValue, USER_AGENT};
11use serde::{Deserialize, Deserializer, Serialize, Serializer};
12use std::fmt::Debug;
13use std::fs::File;
14use std::io::{BufRead, BufReader};
15use std::path::{Path, PathBuf};
16use std::sync::Arc;
17use std::sync::atomic::{AtomicUsize, Ordering};
18use std::time::Duration;
19use ua_generator::ua::*;
20
21use rand::seq::SliceRandom;
22
23use crate::middleware::{Middleware, MiddlewareAction};
24use spider_util::error::SpiderError;
25use spider_util::request::Request;
26
/// Strategy used to choose a User-Agent profile from the active pool.
///
/// `Random` is the default.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum UserAgentRotationStrategy {
    /// Choose a profile at random from the pool on every request.
    #[default]
    Random,
    /// Step through the pool in order using one shared counter, wrapping around at the end.
    Sequential,
    /// Choose a random profile the first time a domain is seen and reuse it for that
    /// domain afterwards.
    Sticky,
    /// Like `Sticky`, but the remembered profile is kept in a cache with a time-to-live,
    /// so a new profile is chosen once the entry has expired.
    StickySession,
}
40
/// Names one of the built-in User-Agent lists.
///
/// A variant is resolved to a concrete list of strings by
/// `UserAgentMiddleware::load_builtin_user_agents`, which selects one of the
/// `STATIC_*_AGENTS` constants (or `all_static_agents()` for `Random`) brought into
/// scope by the `ua_generator::ua` glob import.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum BuiltinUserAgentList {
    // Chrome family.
    Chrome,
    ChromeLinux,
    ChromeMac,
    ChromeMobile,
    ChromeTablet,
    ChromeWindows,
    // Firefox family.
    Firefox,
    FirefoxLinux,
    FirefoxMac,
    FirefoxMobile,
    FirefoxTablet,
    FirefoxWindows,
    // Safari family.
    Safari,
    SafariMac,
    SafariMobile,
    SafariTablet,
    SafariWindows,
    /// The combined list returned by `all_static_agents()`.
    Random,
}
81
82#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
84#[serde(untagged)]
85pub enum UserAgentSource {
86 List(Vec<String>),
88 File(PathBuf),
90 Builtin(BuiltinUserAgentList),
92 None,
94}
95
96impl Default for UserAgentSource {
97 fn default() -> Self {
98 UserAgentSource::Builtin(BuiltinUserAgentList::Random)
99 }
100}
101
102fn serialize_arc_string<S>(x: &Arc<String>, s: S) -> Result<S::Ok, S::Error>
104where
105 S: Serializer,
106{
107 s.serialize_str(x.as_str())
108}
109
110fn deserialize_arc_string<'de, D>(deserializer: D) -> Result<Arc<String>, D::Error>
112where
113 D: Deserializer<'de>,
114{
115 let s = String::deserialize(deserializer)?;
116 Ok(Arc::new(s))
117}
118
/// A User-Agent string together with any extra headers to send alongside it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct UserAgentProfile {
    /// The User-Agent header value; (de)serialized as a plain string through
    /// `serialize_arc_string` / `deserialize_arc_string`.
    #[serde(
        serialize_with = "serialize_arc_string",
        deserialize_with = "deserialize_arc_string"
    )]
    pub user_agent: Arc<String>,
    /// Additional header name/value pairs applied to the request with this profile.
    /// Falls back to an empty map when the field is absent from the serialized form.
    #[serde(default)]
    pub headers: DashMap<String, String>,
}
132
133impl From<String> for UserAgentProfile {
134 fn from(user_agent: String) -> Self {
135 UserAgentProfile {
136 user_agent: Arc::new(user_agent),
137 headers: DashMap::new(),
138 }
139 }
140}
141
142impl From<&str> for UserAgentProfile {
143 fn from(user_agent: &str) -> Self {
144 UserAgentProfile {
145 user_agent: Arc::new(user_agent.to_string()),
146 headers: DashMap::new(),
147 }
148 }
149}
150
/// Builder that collects the configuration for a `UserAgentMiddleware`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UserAgentMiddlewareBuilder {
    /// Source of the default pool, used for domains without their own source.
    source: UserAgentSource,
    /// Rotation strategy used for domains without their own strategy.
    strategy: UserAgentRotationStrategy,
    /// User-Agent returned when the selected pool is empty.
    fallback_user_agent: Option<String>,
    /// Per-domain source overrides, keyed by domain.
    per_domain_source: DashMap<String, UserAgentSource>,
    /// Per-domain rotation strategy overrides, keyed by domain.
    per_domain_strategy: DashMap<String, UserAgentRotationStrategy>,
    /// Time-to-live of the sticky-session cache; `build` uses 5 minutes when unset.
    session_duration: Option<Duration>,
}
161
162impl UserAgentMiddlewareBuilder {
163 pub fn source(mut self, source: UserAgentSource) -> Self {
165 self.source = source;
166 self
167 }
168
169 pub fn strategy(mut self, strategy: UserAgentRotationStrategy) -> Self {
171 self.strategy = strategy;
172 self
173 }
174
175 pub fn session_duration(mut self, duration: Duration) -> Self {
177 self.session_duration = Some(duration);
178 self
179 }
180
181 pub fn fallback_user_agent(mut self, fallback_user_agent: impl Into<String>) -> Self {
183 self.fallback_user_agent = Some(fallback_user_agent.into());
184 self
185 }
186
187 pub fn per_domain_source(self, domain: impl Into<String>, source: UserAgentSource) -> Self {
189 self.per_domain_source.insert(domain.into(), source);
190 self
191 }
192
193 pub fn per_domain_strategy(
195 self,
196 domain: impl Into<String>,
197 strategy: UserAgentRotationStrategy,
198 ) -> Self {
199 self.per_domain_strategy.insert(domain.into(), strategy);
200 self
201 }
202
203 pub fn build(self) -> Result<UserAgentMiddleware, SpiderError> {
209 let default_pool = Arc::new(UserAgentMiddleware::load_user_agents(&self.source)?);
210
211 let domain_cache = Cache::builder()
212 .time_to_live(Duration::from_secs(30 * 60)) .build();
214
215 for entry in self.per_domain_source.iter() {
216 let domain = entry.key().clone();
217 let source = entry.value().clone();
218 let pool = Arc::new(UserAgentMiddleware::load_user_agents(&source)?);
219 domain_cache.insert(domain, pool);
220 }
221
222 let session_cache = Cache::builder()
223 .time_to_live(self.session_duration.unwrap_or(Duration::from_secs(5 * 60)))
224 .build();
225
226 let middleware = UserAgentMiddleware {
227 strategy: self.strategy,
228 fallback_user_agent: self.fallback_user_agent,
229 domain_cache,
230 default_pool,
231 sticky_cache: DashMap::new(),
232 session_cache,
233 per_domain_strategy: self.per_domain_strategy,
234 current_index: AtomicUsize::new(0),
235 };
236
237 info!(
238 "Initializing UserAgentMiddleware with config: {:?}",
239 middleware
240 );
241
242 Ok(middleware)
243 }
244}
245
/// Middleware that sets the `User-Agent` header (and any profile headers) on requests.
pub struct UserAgentMiddleware {
    /// Rotation strategy for domains without a per-domain strategy.
    strategy: UserAgentRotationStrategy,
    /// User-Agent used when the selected pool is empty.
    fallback_user_agent: Option<String>,
    /// Per-domain pools, keyed by domain.
    domain_cache: Cache<String, Arc<Vec<UserAgentProfile>>>,
    /// Pool used for domains without a per-domain pool.
    default_pool: Arc<Vec<UserAgentProfile>>,
    /// Profile remembered per domain for the `Sticky` strategy.
    sticky_cache: DashMap<String, UserAgentProfile>,
    /// Profile remembered per domain for the `StickySession` strategy.
    session_cache: Cache<String, UserAgentProfile>,
    /// Per-domain rotation strategy overrides, keyed by domain.
    per_domain_strategy: DashMap<String, UserAgentRotationStrategy>,
    /// Counter advanced on every `Sequential` selection.
    current_index: AtomicUsize,
}
257
258impl Debug for UserAgentMiddleware {
259 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260 f.debug_struct("UserAgentMiddleware")
261 .field("strategy", &self.strategy)
262 .field("fallback_user_agent", &self.fallback_user_agent)
263 .field(
264 "domain_cache",
265 &format!("Cache({})", self.domain_cache.weighted_size()),
266 )
267 .field(
268 "default_pool",
269 &format!("Pool({})", self.default_pool.len()),
270 )
271 .field(
272 "sticky_cache",
273 &format!("DashMap({})", self.sticky_cache.len()),
274 )
275 .field(
276 "session_cache",
277 &format!("Cache({})", self.session_cache.weighted_size()),
278 )
279 .field(
280 "per_domain_strategy",
281 &format!("DashMap({})", self.per_domain_strategy.len()),
282 )
283 .field("current_index", &self.current_index)
284 .finish()
285 }
286}
287
288impl UserAgentMiddleware {
289 pub fn builder() -> UserAgentMiddlewareBuilder {
291 UserAgentMiddlewareBuilder::default()
292 }
293
294 fn load_user_agents(source: &UserAgentSource) -> Result<Vec<UserAgentProfile>, SpiderError> {
295 match source {
296 UserAgentSource::List(list) => Ok(list
297 .iter()
298 .map(|ua| UserAgentProfile::from(ua.clone()))
299 .collect()),
300 UserAgentSource::File(path) => Self::load_from_file(path),
301 UserAgentSource::Builtin(builtin_list) => {
302 Ok(Self::load_builtin_user_agents(builtin_list))
303 }
304 UserAgentSource::None => Ok(Vec::new()),
305 }
306 }
307
308 fn load_from_file(path: &Path) -> Result<Vec<UserAgentProfile>, SpiderError> {
309 if !path.exists() {
310 return Err(SpiderError::IoError(
311 std::io::Error::new(
312 std::io::ErrorKind::NotFound,
313 format!("User-agent file not found: {}", path.display()),
314 )
315 .to_string(),
316 ));
317 }
318 let file = File::open(path)?;
319 let reader = BufReader::new(file);
320 let user_agents: Vec<UserAgentProfile> = reader
321 .lines()
322 .map_while(Result::ok)
323 .filter(|line| !line.trim().is_empty())
324 .map(UserAgentProfile::from)
325 .collect();
326
327 if user_agents.is_empty() {
328 warn!(
329 "User-Agent file {:?} is empty or contains no valid User-Agents.",
330 path
331 );
332 }
333 Ok(user_agents)
334 }
335
336 fn load_builtin_user_agents(list_type: &BuiltinUserAgentList) -> Vec<UserAgentProfile> {
337 let ua = match list_type {
338 BuiltinUserAgentList::Chrome => STATIC_CHROME_AGENTS,
339 BuiltinUserAgentList::ChromeLinux => STATIC_CHROME_LINUX_AGENTS,
340 BuiltinUserAgentList::ChromeMac => STATIC_CHROME_MAC_AGENTS,
341 BuiltinUserAgentList::ChromeMobile => STATIC_CHROME_MOBILE_AGENTS,
342 BuiltinUserAgentList::ChromeTablet => STATIC_CHROME_TABLET_AGENTS,
343 BuiltinUserAgentList::ChromeWindows => STATIC_CHROME_WINDOWS_AGENTS,
344 BuiltinUserAgentList::Firefox => STATIC_FIREFOX_AGENTS,
345 BuiltinUserAgentList::FirefoxLinux => STATIC_FIREFOX_LINUX_AGENTS,
346 BuiltinUserAgentList::FirefoxMac => STATIC_FIREFOX_MAC_AGENTS,
347 BuiltinUserAgentList::FirefoxMobile => STATIC_FIREFOX_MOBILE_AGENTS,
348 BuiltinUserAgentList::FirefoxTablet => STATIC_FIREFOX_TABLET_AGENTS,
349 BuiltinUserAgentList::FirefoxWindows => STATIC_FIREFOX_WINDOWS_AGENTS,
350 BuiltinUserAgentList::Safari => STATIC_SAFARI_AGENTS,
351 BuiltinUserAgentList::SafariMac => STATIC_SAFARI_MAC_AGENTS,
352 BuiltinUserAgentList::SafariMobile => STATIC_SAFARI_MOBILE_AGENTS,
353 BuiltinUserAgentList::SafariTablet => STATIC_SAFARI_TABLET_AGENTS,
354 BuiltinUserAgentList::SafariWindows => STATIC_FIREFOX_WINDOWS_AGENTS,
355 BuiltinUserAgentList::Random => all_static_agents(),
356 };
357
358 ua.iter().map(|&v| UserAgentProfile::from(v)).collect()
359 }
360
361 fn get_user_agent(&self, domain: Option<&str>) -> Option<UserAgentProfile> {
362 let mut rng = rand::thread_rng();
363
364 let domain_str = domain.unwrap_or_default().to_string();
365
366 let strategy = self
367 .per_domain_strategy
368 .get(&domain_str)
369 .map(|s| s.value().clone())
370 .unwrap_or_else(|| self.strategy.clone());
371
372 let pool = || {
373 domain
374 .and_then(|d| self.domain_cache.get(d))
375 .unwrap_or_else(|| self.default_pool.clone())
376 };
377
378 let get_fallback = || {
379 debug!("User-Agent pool is empty or no UA selected.");
380 self.fallback_user_agent
381 .as_ref()
382 .map(|ua| UserAgentProfile::from(ua.clone()))
383 };
384
385 match strategy {
386 UserAgentRotationStrategy::Random => {
387 let p = pool();
388 if p.is_empty() {
389 return get_fallback();
390 }
391 p.choose(&mut rng).cloned()
392 }
393 UserAgentRotationStrategy::Sequential => {
394 let p = pool();
395 if p.is_empty() {
396 return get_fallback();
397 }
398 let current = self.current_index.fetch_add(1, Ordering::SeqCst);
399 let index = current % p.len();
400 p.get(index).cloned()
401 }
402 UserAgentRotationStrategy::Sticky => {
403 if let Some(profile) = self.sticky_cache.get(&domain_str) {
404 return Some(profile.clone());
405 }
406
407 let p = pool();
408 if p.is_empty() {
409 return get_fallback();
410 }
411
412 if let Some(profile) = p.choose(&mut rng).cloned() {
413 self.sticky_cache.insert(domain_str, profile.clone());
414 Some(profile)
415 } else {
416 get_fallback()
417 }
418 }
419 UserAgentRotationStrategy::StickySession => {
420 if let Some(profile) = self.session_cache.get(&domain_str) {
421 return Some(profile);
422 }
423
424 let p = pool();
425 if p.is_empty() {
426 return get_fallback();
427 }
428
429 if let Some(profile) = p.choose(&mut rng).cloned() {
430 self.session_cache.insert(domain_str, profile.clone());
431 Some(profile)
432 } else {
433 get_fallback()
434 }
435 }
436 }
437 }
438}
439
440#[async_trait]
441impl<C: Send + Sync> Middleware<C> for UserAgentMiddleware {
442 fn name(&self) -> &str {
443 "UserAgentMiddleware"
444 }
445
446 async fn process_request(
447 &self,
448 _client: &C,
449 mut request: Request,
450 ) -> Result<MiddlewareAction<Request>, SpiderError> {
451 let domain = request.url.domain();
452 if let Some(profile) = self.get_user_agent(domain) {
453 debug!("Applying User-Agent: {}", profile.user_agent);
454 request.headers.insert(
455 USER_AGENT,
456 HeaderValue::from_str(&profile.user_agent).map_err(|e| {
457 SpiderError::HeaderValueError(format!(
458 "Invalid User-Agent string '{}': {}",
459 profile.user_agent, e
460 ))
461 })?,
462 );
463 for header in profile.headers.iter() {
464 request.headers.insert(
465 reqwest::header::HeaderName::from_bytes(header.key().as_bytes()).map_err(
466 |e| SpiderError::HeaderValueError(format!("Invalid header name: {}", e)),
467 )?,
468 HeaderValue::from_str(header.value().as_str()).map_err(|e| {
469 SpiderError::HeaderValueError(format!(
470 "Invalid header value for {}: {}",
471 header.key(),
472 e
473 ))
474 })?,
475 );
476 }
477 } else {
478 debug!("No User-Agent applied.");
479 }
480 Ok(MiddlewareAction::Continue(request))
481 }
482}