spider_middleware/
user_agent.rs

1//! User-agent middleware.
2//!
3//! [`UserAgentMiddleware`] selects and rotates outgoing user-agent values using
4//! built-in lists, custom lists, or per-domain configuration.
5
6use async_trait::async_trait;
7use dashmap::DashMap;
8use log::{debug, info, warn};
9use moka::sync::Cache;
10use reqwest::header::{HeaderValue, USER_AGENT};
11use serde::{Deserialize, Deserializer, Serialize, Serializer};
12use std::fmt::Debug;
13use std::fs::File;
14use std::io::{BufRead, BufReader};
15use std::path::{Path, PathBuf};
16use std::sync::Arc;
17use std::sync::atomic::{AtomicUsize, Ordering};
18use std::time::Duration;
19use ua_generator::ua::*;
20
21use rand::seq::SliceRandom;
22
23use crate::middleware::{Middleware, MiddlewareAction};
24use spider_util::error::SpiderError;
25use spider_util::request::Request;
26
27/// Defines the strategy for rotating User-Agents.
28#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
29pub enum UserAgentRotationStrategy {
30    /// Randomly selects a User-Agent from the available pool.
31    #[default]
32    Random,
33    /// Sequentially cycles through the available User-Agents.
34    Sequential,
35    /// Selects a User-Agent on first encounter with a domain and uses it for all subsequent requests to that domain.
36    Sticky,
37    /// Selects a User-Agent on first encounter with a domain and uses it for a configured duration (session).
38    StickySession,
39}
40
41/// Predefined lists of User-Agents for common scenarios.
42#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
43pub enum BuiltinUserAgentList {
44    /// Generic Chrome User-Agents.
45    Chrome,
46    /// Chrome User-Agents on Linux.
47    ChromeLinux,
48    /// Chrome User-Agents on Mac.
49    ChromeMac,
50    /// Chrome Mobile User-Agents.
51    ChromeMobile,
52    /// Chrome Tablet User-Agents.
53    ChromeTablet,
54    /// Chrome User-Agents on Windows.
55    ChromeWindows,
56    /// Generic Firefox User-Agents.
57    Firefox,
58    /// Firefox User-Agents on Linux.
59    FirefoxLinux,
60    /// Firefox User-Agents on Mac.
61    FirefoxMac,
62    /// Firefox Mobile User-Agents.
63    FirefoxMobile,
64    /// Firefox Tablet User-Agents.
65    FirefoxTablet,
66    /// Firefox User-Agents on Windows.
67    FirefoxWindows,
68    /// Generic Safari User-Agents.
69    Safari,
70    /// Safari User-Agents on Mac.
71    SafariMac,
72    /// Safari Mobile User-Agents.
73    SafariMobile,
74    /// Safari Tablet User-Agents.
75    SafariTablet,
76    /// Safari User-Agents on Windows.
77    SafariWindows,
78    /// A random selection from all available User-Agents.
79    Random,
80}
81
82/// Defines the source from which User-Agents are loaded.
83#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
84#[serde(untagged)]
85pub enum UserAgentSource {
86    /// A direct list of User-Agent strings.
87    List(Vec<String>),
88    /// Path to a file containing User-Agent strings, one per line.
89    File(PathBuf),
90    /// Use a predefined, built-in list of User-Agents.
91    Builtin(BuiltinUserAgentList),
92    /// No User-Agent source specified, will fallback to a default if available.
93    None,
94}
95
96impl Default for UserAgentSource {
97    fn default() -> Self {
98        UserAgentSource::Builtin(BuiltinUserAgentList::Random)
99    }
100}
101
102/// Serializes `Arc<String>` as a string.
103fn serialize_arc_string<S>(x: &Arc<String>, s: S) -> Result<S::Ok, S::Error>
104where
105    S: Serializer,
106{
107    s.serialize_str(x.as_str())
108}
109
110/// Deserializes a string into `Arc<String>`.
111fn deserialize_arc_string<'de, D>(deserializer: D) -> Result<Arc<String>, D::Error>
112where
113    D: Deserializer<'de>,
114{
115    let s = String::deserialize(deserializer)?;
116    Ok(Arc::new(s))
117}
118
119/// Represents a User-Agent profile, including the User-Agent string and other associated headers.
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub(crate) struct UserAgentProfile {
122    /// The User-Agent string.
123    #[serde(
124        serialize_with = "serialize_arc_string",
125        deserialize_with = "deserialize_arc_string"
126    )]
127    pub user_agent: Arc<String>,
128    /// Additional headers that should be sent with this User-Agent to mimic a real browser.
129    #[serde(default)]
130    pub headers: DashMap<String, String>,
131}
132
133impl From<String> for UserAgentProfile {
134    fn from(user_agent: String) -> Self {
135        UserAgentProfile {
136            user_agent: Arc::new(user_agent),
137            headers: DashMap::new(),
138        }
139    }
140}
141
142impl From<&str> for UserAgentProfile {
143    fn from(user_agent: &str) -> Self {
144        UserAgentProfile {
145            user_agent: Arc::new(user_agent.to_string()),
146            headers: DashMap::new(),
147        }
148    }
149}
150
151/// Builder for creating a [`UserAgentMiddleware`].
152#[derive(Debug, Clone, Default, Serialize, Deserialize)]
153pub struct UserAgentMiddlewareBuilder {
154    source: UserAgentSource,
155    strategy: UserAgentRotationStrategy,
156    fallback_user_agent: Option<String>,
157    per_domain_source: DashMap<String, UserAgentSource>,
158    per_domain_strategy: DashMap<String, UserAgentRotationStrategy>,
159    session_duration: Option<Duration>,
160}
161
162impl UserAgentMiddlewareBuilder {
163    /// Sets the primary source for User-Agents.
164    pub fn source(mut self, source: UserAgentSource) -> Self {
165        self.source = source;
166        self
167    }
168
169    /// Sets the default strategy to use for rotating User-Agents.
170    pub fn strategy(mut self, strategy: UserAgentRotationStrategy) -> Self {
171        self.strategy = strategy;
172        self
173    }
174
175    /// Sets the duration for a "sticky session" in the `StickySession` strategy.
176    pub fn session_duration(mut self, duration: Duration) -> Self {
177        self.session_duration = Some(duration);
178        self
179    }
180
181    /// Sets a fallback User-Agent to use if no other User-Agents are available.
182    pub fn fallback_user_agent(mut self, fallback_user_agent: impl Into<String>) -> Self {
183        self.fallback_user_agent = Some(fallback_user_agent.into());
184        self
185    }
186
187    /// Adds a domain-specific User-Agent source.
188    pub fn per_domain_source(self, domain: impl Into<String>, source: UserAgentSource) -> Self {
189        self.per_domain_source.insert(domain.into(), source);
190        self
191    }
192
193    /// Adds a domain-specific User-Agent rotation strategy, overriding the default.
194    pub fn per_domain_strategy(
195        self,
196        domain: impl Into<String>,
197        strategy: UserAgentRotationStrategy,
198    ) -> Self {
199        self.per_domain_strategy.insert(domain.into(), strategy);
200        self
201    }
202
203    /// Builds the `UserAgentMiddleware`.
204    ///
205    /// # Errors
206    ///
207    /// Returns an error if a configured source file cannot be read.
208    pub fn build(self) -> Result<UserAgentMiddleware, SpiderError> {
209        let default_pool = Arc::new(UserAgentMiddleware::load_user_agents(&self.source)?);
210
211        let domain_cache = Cache::builder()
212            .time_to_live(Duration::from_secs(30 * 60)) // 30 minutes
213            .build();
214
215        for entry in self.per_domain_source.iter() {
216            let domain = entry.key().clone();
217            let source = entry.value().clone();
218            let pool = Arc::new(UserAgentMiddleware::load_user_agents(&source)?);
219            domain_cache.insert(domain, pool);
220        }
221
222        let session_cache = Cache::builder()
223            .time_to_live(self.session_duration.unwrap_or(Duration::from_secs(5 * 60)))
224            .build();
225
226        let middleware = UserAgentMiddleware {
227            strategy: self.strategy,
228            fallback_user_agent: self.fallback_user_agent,
229            domain_cache,
230            default_pool,
231            sticky_cache: DashMap::new(),
232            session_cache,
233            per_domain_strategy: self.per_domain_strategy,
234            current_index: AtomicUsize::new(0),
235        };
236
237        info!(
238            "Initializing UserAgentMiddleware with config: {:?}",
239            middleware
240        );
241
242        Ok(middleware)
243    }
244}
245
246/// Middleware that sets and rotates `User-Agent` headers for outgoing requests.
247pub struct UserAgentMiddleware {
248    strategy: UserAgentRotationStrategy,
249    fallback_user_agent: Option<String>,
250    domain_cache: Cache<String, Arc<Vec<UserAgentProfile>>>,
251    default_pool: Arc<Vec<UserAgentProfile>>,
252    sticky_cache: DashMap<String, UserAgentProfile>,
253    session_cache: Cache<String, UserAgentProfile>,
254    per_domain_strategy: DashMap<String, UserAgentRotationStrategy>,
255    current_index: AtomicUsize,
256}
257
258impl Debug for UserAgentMiddleware {
259    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260        f.debug_struct("UserAgentMiddleware")
261            .field("strategy", &self.strategy)
262            .field("fallback_user_agent", &self.fallback_user_agent)
263            .field(
264                "domain_cache",
265                &format!("Cache({})", self.domain_cache.weighted_size()),
266            )
267            .field(
268                "default_pool",
269                &format!("Pool({})", self.default_pool.len()),
270            )
271            .field(
272                "sticky_cache",
273                &format!("DashMap({})", self.sticky_cache.len()),
274            )
275            .field(
276                "session_cache",
277                &format!("Cache({})", self.session_cache.weighted_size()),
278            )
279            .field(
280                "per_domain_strategy",
281                &format!("DashMap({})", self.per_domain_strategy.len()),
282            )
283            .field("current_index", &self.current_index)
284            .finish()
285    }
286}
287
288impl UserAgentMiddleware {
289    /// Creates a new `UserAgentMiddlewareBuilder` to start building a `UserAgentMiddleware`.
290    pub fn builder() -> UserAgentMiddlewareBuilder {
291        UserAgentMiddlewareBuilder::default()
292    }
293
294    fn load_user_agents(source: &UserAgentSource) -> Result<Vec<UserAgentProfile>, SpiderError> {
295        match source {
296            UserAgentSource::List(list) => Ok(list
297                .iter()
298                .map(|ua| UserAgentProfile::from(ua.clone()))
299                .collect()),
300            UserAgentSource::File(path) => Self::load_from_file(path),
301            UserAgentSource::Builtin(builtin_list) => {
302                Ok(Self::load_builtin_user_agents(builtin_list))
303            }
304            UserAgentSource::None => Ok(Vec::new()),
305        }
306    }
307
308    fn load_from_file(path: &Path) -> Result<Vec<UserAgentProfile>, SpiderError> {
309        if !path.exists() {
310            return Err(SpiderError::IoError(
311                std::io::Error::new(
312                    std::io::ErrorKind::NotFound,
313                    format!("User-agent file not found: {}", path.display()),
314                )
315                .to_string(),
316            ));
317        }
318        let file = File::open(path)?;
319        let reader = BufReader::new(file);
320        let user_agents: Vec<UserAgentProfile> = reader
321            .lines()
322            .map_while(Result::ok)
323            .filter(|line| !line.trim().is_empty())
324            .map(UserAgentProfile::from)
325            .collect();
326
327        if user_agents.is_empty() {
328            warn!(
329                "User-Agent file {:?} is empty or contains no valid User-Agents.",
330                path
331            );
332        }
333        Ok(user_agents)
334    }
335
336    fn load_builtin_user_agents(list_type: &BuiltinUserAgentList) -> Vec<UserAgentProfile> {
337        let ua = match list_type {
338            BuiltinUserAgentList::Chrome => STATIC_CHROME_AGENTS,
339            BuiltinUserAgentList::ChromeLinux => STATIC_CHROME_LINUX_AGENTS,
340            BuiltinUserAgentList::ChromeMac => STATIC_CHROME_MAC_AGENTS,
341            BuiltinUserAgentList::ChromeMobile => STATIC_CHROME_MOBILE_AGENTS,
342            BuiltinUserAgentList::ChromeTablet => STATIC_CHROME_TABLET_AGENTS,
343            BuiltinUserAgentList::ChromeWindows => STATIC_CHROME_WINDOWS_AGENTS,
344            BuiltinUserAgentList::Firefox => STATIC_FIREFOX_AGENTS,
345            BuiltinUserAgentList::FirefoxLinux => STATIC_FIREFOX_LINUX_AGENTS,
346            BuiltinUserAgentList::FirefoxMac => STATIC_FIREFOX_MAC_AGENTS,
347            BuiltinUserAgentList::FirefoxMobile => STATIC_FIREFOX_MOBILE_AGENTS,
348            BuiltinUserAgentList::FirefoxTablet => STATIC_FIREFOX_TABLET_AGENTS,
349            BuiltinUserAgentList::FirefoxWindows => STATIC_FIREFOX_WINDOWS_AGENTS,
350            BuiltinUserAgentList::Safari => STATIC_SAFARI_AGENTS,
351            BuiltinUserAgentList::SafariMac => STATIC_SAFARI_MAC_AGENTS,
352            BuiltinUserAgentList::SafariMobile => STATIC_SAFARI_MOBILE_AGENTS,
353            BuiltinUserAgentList::SafariTablet => STATIC_SAFARI_TABLET_AGENTS,
354            BuiltinUserAgentList::SafariWindows => STATIC_FIREFOX_WINDOWS_AGENTS,
355            BuiltinUserAgentList::Random => all_static_agents(),
356        };
357
358        ua.iter().map(|&v| UserAgentProfile::from(v)).collect()
359    }
360
361    fn get_user_agent(&self, domain: Option<&str>) -> Option<UserAgentProfile> {
362        let mut rng = rand::thread_rng();
363
364        let domain_str = domain.unwrap_or_default().to_string();
365
366        let strategy = self
367            .per_domain_strategy
368            .get(&domain_str)
369            .map(|s| s.value().clone())
370            .unwrap_or_else(|| self.strategy.clone());
371
372        let pool = || {
373            domain
374                .and_then(|d| self.domain_cache.get(d))
375                .unwrap_or_else(|| self.default_pool.clone())
376        };
377
378        let get_fallback = || {
379            debug!("User-Agent pool is empty or no UA selected.");
380            self.fallback_user_agent
381                .as_ref()
382                .map(|ua| UserAgentProfile::from(ua.clone()))
383        };
384
385        match strategy {
386            UserAgentRotationStrategy::Random => {
387                let p = pool();
388                if p.is_empty() {
389                    return get_fallback();
390                }
391                p.choose(&mut rng).cloned()
392            }
393            UserAgentRotationStrategy::Sequential => {
394                let p = pool();
395                if p.is_empty() {
396                    return get_fallback();
397                }
398                let current = self.current_index.fetch_add(1, Ordering::SeqCst);
399                let index = current % p.len();
400                p.get(index).cloned()
401            }
402            UserAgentRotationStrategy::Sticky => {
403                if let Some(profile) = self.sticky_cache.get(&domain_str) {
404                    return Some(profile.clone());
405                }
406
407                let p = pool();
408                if p.is_empty() {
409                    return get_fallback();
410                }
411
412                if let Some(profile) = p.choose(&mut rng).cloned() {
413                    self.sticky_cache.insert(domain_str, profile.clone());
414                    Some(profile)
415                } else {
416                    get_fallback()
417                }
418            }
419            UserAgentRotationStrategy::StickySession => {
420                if let Some(profile) = self.session_cache.get(&domain_str) {
421                    return Some(profile);
422                }
423
424                let p = pool();
425                if p.is_empty() {
426                    return get_fallback();
427                }
428
429                if let Some(profile) = p.choose(&mut rng).cloned() {
430                    self.session_cache.insert(domain_str, profile.clone());
431                    Some(profile)
432                } else {
433                    get_fallback()
434                }
435            }
436        }
437    }
438}
439
440#[async_trait]
441impl<C: Send + Sync> Middleware<C> for UserAgentMiddleware {
442    fn name(&self) -> &str {
443        "UserAgentMiddleware"
444    }
445
446    async fn process_request(
447        &self,
448        _client: &C,
449        mut request: Request,
450    ) -> Result<MiddlewareAction<Request>, SpiderError> {
451        let domain = request.url.domain();
452        if let Some(profile) = self.get_user_agent(domain) {
453            debug!("Applying User-Agent: {}", profile.user_agent);
454            request.headers.insert(
455                USER_AGENT,
456                HeaderValue::from_str(&profile.user_agent).map_err(|e| {
457                    SpiderError::HeaderValueError(format!(
458                        "Invalid User-Agent string '{}': {}",
459                        profile.user_agent, e
460                    ))
461                })?,
462            );
463            for header in profile.headers.iter() {
464                request.headers.insert(
465                    reqwest::header::HeaderName::from_bytes(header.key().as_bytes()).map_err(
466                        |e| SpiderError::HeaderValueError(format!("Invalid header name: {}", e)),
467                    )?,
468                    HeaderValue::from_str(header.value().as_str()).map_err(|e| {
469                        SpiderError::HeaderValueError(format!(
470                            "Invalid header value for {}: {}",
471                            header.key(),
472                            e
473                        ))
474                    })?,
475                );
476            }
477        } else {
478            debug!("No User-Agent applied.");
479        }
480        Ok(MiddlewareAction::Continue(request))
481    }
482}