spider_middleware/
cookies.rs1use async_trait::async_trait;
7use cookie::Cookie;
8use cookie_store::CookieStore;
9use std::fs::File;
10use std::io::{BufRead, BufReader};
11use std::path::Path;
12use std::sync::Arc;
13use time::OffsetDateTime;
14use tokio::sync::RwLock;
15use url::Url;
16
17use crate::middleware::{Middleware, MiddlewareAction};
18use spider_util::error::SpiderError;
19use spider_util::request::Request;
20use spider_util::response::Response;
21
22pub struct CookieMiddleware {
24 pub store: Arc<RwLock<CookieStore>>,
25}
26
27impl CookieMiddleware {
28 pub fn new() -> Self {
30 Self::with_store(CookieStore::default())
31 }
32
33 pub fn with_store(store: CookieStore) -> Self {
35 Self {
36 store: Arc::new(RwLock::new(store)),
37 }
38 }
39
40 pub async fn from_json<P: AsRef<Path>>(path: P) -> Result<Self, SpiderError> {
43 let file = File::open(path)?;
44 let reader = BufReader::new(file);
45
46 let store =
47 CookieStore::load_json(reader).map_err(|e| SpiderError::GeneralError(e.to_string()))?;
48
49 Ok(Self::with_store(store))
50 }
51
52 pub async fn from_netscape_file<P: AsRef<Path>>(path: P) -> Result<Self, SpiderError> {
55 let file = File::open(path)?;
56 let reader = BufReader::new(file);
57 let mut store = CookieStore::default();
58
59 for line in reader.lines() {
60 let line = line?;
61 if line.starts_with('#') || line.trim().is_empty() {
62 continue;
63 }
64
65 let parts: Vec<&str> = line.split('\t').collect();
66 if parts.len() != 7 {
67 return Err(SpiderError::GeneralError(format!(
68 "Malformed Netscape cookie line: Expected 7 parts, got {}",
69 parts.len()
70 )));
71 }
72
73 let domain = parts[0];
74 let secure = parts[3].eq_ignore_ascii_case("TRUE");
75 let path = parts[2];
76 let expires_timestamp = parts[4].parse::<i64>().map_err(|e| {
77 SpiderError::GeneralError(format!("Invalid timestamp format: {}", e))
78 })?;
79
80 let expires = if expires_timestamp == 0 {
81 None
82 } else {
83 Some(
84 OffsetDateTime::from_unix_timestamp(expires_timestamp).map_err(|e| {
85 SpiderError::GeneralError(format!("Invalid timestamp value: {}", e))
86 })?,
87 )
88 };
89
90 let name = parts[5].to_string();
91 let value = parts[6].to_string();
92
93 let mut cookie_builder = Cookie::build(name, value).path(path).secure(secure);
94 if let Some(expires) = expires {
95 cookie_builder = cookie_builder.expires(expires);
96 }
97
98 let mut domain_for_url = domain;
99 if domain_for_url.starts_with('.') {
100 domain_for_url = &domain_for_url[1..];
101 }
102
103 let url_str = format!(
104 "{}://{}",
105 if secure { "https" } else { "http" },
106 domain_for_url
107 );
108
109 let url = Url::parse(&url_str)?;
110 let cookie = cookie_builder
111 .domain(domain.to_string())
112 .finish()
113 .into_owned();
114
115 store.store_response_cookies(std::iter::once(cookie), &url);
116 }
117
118 Ok(Self::with_store(store))
119 }
120
121 pub async fn from_rfc6265<P: AsRef<Path>>(path: P) -> Result<Self, SpiderError> {
124 let file = File::open(path)?;
125 let reader = BufReader::new(file);
126 let mut store = CookieStore::default();
127
128 for line in reader.lines() {
129 let line = line?;
130 if line.trim().is_empty() {
131 continue;
132 }
133
134 let cookie = Cookie::parse(line)
135 .map_err(|e| SpiderError::GeneralError(format!("Failed to parse cookie: {}", e)))?;
136
137 let domain = cookie.domain().ok_or_else(|| {
138 SpiderError::GeneralError(
139 "Cookie in file must have an explicit Domain attribute".to_string(),
140 )
141 })?;
142
143 let secure = cookie.secure().unwrap_or(false);
144 let url_str = format!("{}://{}", if secure { "https" } else { "http" }, domain);
145 let url = Url::parse(&url_str)?;
146 store.store_response_cookies(std::iter::once(cookie), &url);
147 }
148
149 Ok(Self::with_store(store))
150 }
151}
152
153#[async_trait]
154impl<C: Send + Sync> Middleware<C> for CookieMiddleware {
155 fn name(&self) -> &str {
156 "CookieMiddleware"
157 }
158
159 async fn process_request(
160 &self,
161 _client: &C,
162 mut request: Request,
163 ) -> Result<MiddlewareAction<Request>, SpiderError> {
164 let store = self.store.read().await;
165
166 let cookie_header = store
167 .get_request_values(&request.url)
168 .map(|(name, value)| format!("{}={}", name, value))
169 .collect::<Vec<_>>()
170 .join("; ");
171
172 if !cookie_header.is_empty() {
173 request.headers.insert(
174 http::header::COOKIE,
175 http::HeaderValue::from_str(&cookie_header)?,
176 );
177 }
178
179 Ok(MiddlewareAction::Continue(request))
180 }
181
182 async fn process_response(
183 &self,
184 response: Response,
185 ) -> Result<MiddlewareAction<Response>, SpiderError> {
186 let cookies_to_store = response
187 .headers
188 .get_all(http::header::SET_COOKIE)
189 .iter()
190 .filter_map(|val| val.to_str().ok())
191 .filter_map(|s| Cookie::parse(s).ok());
192
193 self.store
194 .write()
195 .await
196 .store_response_cookies(cookies_to_store.map(|c| c.into_owned()), &response.url);
197
198 Ok(MiddlewareAction::Continue(response))
199 }
200}
201
202impl Default for CookieMiddleware {
203 fn default() -> Self {
204 Self::new()
205 }
206}