spider_middleware/
cookies.rs

1//! Cookie persistence middleware.
2//!
3//! [`CookieMiddleware`] stores cookies from responses and attaches matching
4//! cookies to later requests, which is useful for session-based crawls.
5
6use async_trait::async_trait;
7use cookie::Cookie;
8use cookie_store::CookieStore;
9use std::fs::File;
10use std::io::{BufRead, BufReader};
11use std::path::Path;
12use std::sync::Arc;
13use time::OffsetDateTime;
14use tokio::sync::RwLock;
15use url::Url;
16
17use crate::middleware::{Middleware, MiddlewareAction};
18use spider_util::error::SpiderError;
19use spider_util::request::Request;
20use spider_util::response::Response;
21
22/// Middleware that keeps a shared cookie store across requests.
23pub struct CookieMiddleware {
24    pub store: Arc<RwLock<CookieStore>>,
25}
26
27impl CookieMiddleware {
28    /// Creates a new `CookieMiddleware` with a shared `CookieStore`.
29    pub fn new() -> Self {
30        Self::with_store(CookieStore::default())
31    }
32
33    /// Creates a new `CookieMiddleware` with a pre-populated `CookieStore`.
34    pub fn with_store(store: CookieStore) -> Self {
35        Self {
36            store: Arc::new(RwLock::new(store)),
37        }
38    }
39
40    /// Load cookies from a JSON file, replacing all cookies in the store.
41    /// The JSON format is the one used by the `cookie_store` crate.
42    pub async fn from_json<P: AsRef<Path>>(path: P) -> Result<Self, SpiderError> {
43        let file = File::open(path)?;
44        let reader = BufReader::new(file);
45
46        let store =
47            CookieStore::load_json(reader).map_err(|e| SpiderError::GeneralError(e.to_string()))?;
48
49        Ok(Self::with_store(store))
50    }
51
52    /// Load cookies from a Netscape cookie file.
53    /// This will add to, not replace, the existing cookies in the store.
54    pub async fn from_netscape_file<P: AsRef<Path>>(path: P) -> Result<Self, SpiderError> {
55        let file = File::open(path)?;
56        let reader = BufReader::new(file);
57        let mut store = CookieStore::default();
58
59        for line in reader.lines() {
60            let line = line?;
61            if line.starts_with('#') || line.trim().is_empty() {
62                continue;
63            }
64
65            let parts: Vec<&str> = line.split('\t').collect();
66            if parts.len() != 7 {
67                return Err(SpiderError::GeneralError(format!(
68                    "Malformed Netscape cookie line: Expected 7 parts, got {}",
69                    parts.len()
70                )));
71            }
72
73            let domain = parts[0];
74            let secure = parts[3].eq_ignore_ascii_case("TRUE");
75            let path = parts[2];
76            let expires_timestamp = parts[4].parse::<i64>().map_err(|e| {
77                SpiderError::GeneralError(format!("Invalid timestamp format: {}", e))
78            })?;
79
80            let expires = if expires_timestamp == 0 {
81                None
82            } else {
83                Some(
84                    OffsetDateTime::from_unix_timestamp(expires_timestamp).map_err(|e| {
85                        SpiderError::GeneralError(format!("Invalid timestamp value: {}", e))
86                    })?,
87                )
88            };
89
90            let name = parts[5].to_string();
91            let value = parts[6].to_string();
92
93            let mut cookie_builder = Cookie::build(name, value).path(path).secure(secure);
94            if let Some(expires) = expires {
95                cookie_builder = cookie_builder.expires(expires);
96            }
97
98            let mut domain_for_url = domain;
99            if domain_for_url.starts_with('.') {
100                domain_for_url = &domain_for_url[1..];
101            }
102
103            let url_str = format!(
104                "{}://{}",
105                if secure { "https" } else { "http" },
106                domain_for_url
107            );
108
109            let url = Url::parse(&url_str)?;
110            let cookie = cookie_builder
111                .domain(domain.to_string())
112                .finish()
113                .into_owned();
114
115            store.store_response_cookies(std::iter::once(cookie), &url);
116        }
117
118        Ok(Self::with_store(store))
119    }
120
121    /// Load cookies from a file where each line is a `Set-Cookie` header value.
122    /// The `Domain` attribute must be explicitly set in each cookie line.
123    pub async fn from_rfc6265<P: AsRef<Path>>(path: P) -> Result<Self, SpiderError> {
124        let file = File::open(path)?;
125        let reader = BufReader::new(file);
126        let mut store = CookieStore::default();
127
128        for line in reader.lines() {
129            let line = line?;
130            if line.trim().is_empty() {
131                continue;
132            }
133
134            let cookie = Cookie::parse(line)
135                .map_err(|e| SpiderError::GeneralError(format!("Failed to parse cookie: {}", e)))?;
136
137            let domain = cookie.domain().ok_or_else(|| {
138                SpiderError::GeneralError(
139                    "Cookie in file must have an explicit Domain attribute".to_string(),
140                )
141            })?;
142
143            let secure = cookie.secure().unwrap_or(false);
144            let url_str = format!("{}://{}", if secure { "https" } else { "http" }, domain);
145            let url = Url::parse(&url_str)?;
146            store.store_response_cookies(std::iter::once(cookie), &url);
147        }
148
149        Ok(Self::with_store(store))
150    }
151}
152
153#[async_trait]
154impl<C: Send + Sync> Middleware<C> for CookieMiddleware {
155    fn name(&self) -> &str {
156        "CookieMiddleware"
157    }
158
159    async fn process_request(
160        &self,
161        _client: &C,
162        mut request: Request,
163    ) -> Result<MiddlewareAction<Request>, SpiderError> {
164        let store = self.store.read().await;
165
166        let cookie_header = store
167            .get_request_values(&request.url)
168            .map(|(name, value)| format!("{}={}", name, value))
169            .collect::<Vec<_>>()
170            .join("; ");
171
172        if !cookie_header.is_empty() {
173            request.headers.insert(
174                http::header::COOKIE,
175                http::HeaderValue::from_str(&cookie_header)?,
176            );
177        }
178
179        Ok(MiddlewareAction::Continue(request))
180    }
181
182    async fn process_response(
183        &self,
184        response: Response,
185    ) -> Result<MiddlewareAction<Response>, SpiderError> {
186        let cookies_to_store = response
187            .headers
188            .get_all(http::header::SET_COOKIE)
189            .iter()
190            .filter_map(|val| val.to_str().ok())
191            .filter_map(|s| Cookie::parse(s).ok());
192
193        self.store
194            .write()
195            .await
196            .store_response_cookies(cookies_to_store.map(|c| c.into_owned()), &response.url);
197
198        Ok(MiddlewareAction::Continue(response))
199    }
200}
201
202impl Default for CookieMiddleware {
203    fn default() -> Self {
204        Self::new()
205    }
206}