1use async_trait::async_trait;
8use bytes::Bytes;
9use log::{debug, info, trace, warn};
10use reqwest::StatusCode;
11use reqwest::header::{CACHE_CONTROL, EXPIRES, HeaderMap, HeaderName, HeaderValue};
12use std::path::PathBuf;
13use std::time::{Duration, SystemTime, UNIX_EPOCH};
14use time::OffsetDateTime;
15use time::format_description::well_known::Rfc2822;
16use tokio::fs;
17
18use crate::middleware::{Middleware, MiddlewareAction};
19use serde::{Deserialize, Deserializer, Serialize, Serializer};
20use spider_util::error::SpiderError;
21use spider_util::request::Request;
22use spider_util::response::Response;
23use url::Url;
24
25fn serialize_headermap<S>(headers: &HeaderMap, serializer: S) -> Result<S::Ok, S::Error>
26where
27 S: Serializer,
28{
29 let mut map = std::collections::HashMap::<String, String>::new();
30 for (name, value) in headers.iter() {
31 map.insert(
32 name.to_string(),
33 value.to_str().unwrap_or_default().to_string(),
34 );
35 }
36 map.serialize(serializer)
37}
38
39fn deserialize_headermap<'de, D>(deserializer: D) -> Result<HeaderMap, D::Error>
40where
41 D: Deserializer<'de>,
42{
43 let map = std::collections::HashMap::<String, String>::deserialize(deserializer)?;
44 let mut headers = HeaderMap::new();
45 for (name, value) in map {
46 if let (Ok(header_name), Ok(header_value)) =
47 (name.parse::<HeaderName>(), value.parse::<HeaderValue>())
48 {
49 headers.insert(header_name, header_value);
50 } else {
51 warn!("Failed to parse header: {} = {}", name, value);
52 }
53 }
54 Ok(headers)
55}
56
57fn serialize_statuscode<S>(status: &StatusCode, serializer: S) -> Result<S::Ok, S::Error>
58where
59 S: Serializer,
60{
61 status.as_u16().serialize(serializer)
62}
63
64fn deserialize_statuscode<'de, D>(deserializer: D) -> Result<StatusCode, D::Error>
65where
66 D: Deserializer<'de>,
67{
68 let status_u16 = u16::deserialize(deserializer)?;
69 StatusCode::from_u16(status_u16).map_err(serde::de::Error::custom)
70}
71
72fn serialize_url<S>(url: &Url, serializer: S) -> Result<S::Ok, S::Error>
73where
74 S: Serializer,
75{
76 url.to_string().serialize(serializer)
77}
78
79fn deserialize_url<'de, D>(deserializer: D) -> Result<Url, D::Error>
80where
81 D: Deserializer<'de>,
82{
83 let s = String::deserialize(deserializer)?;
84 Url::parse(&s).map_err(serde::de::Error::custom)
85}
86
/// Serializable snapshot of a `Response` as it is stored in a cache file.
///
/// The middleware in this file encodes and decodes this type with `bincode`.
/// Fields whose types are handled by the custom helpers above (`Url`,
/// `StatusCode`, `HeaderMap`) are stored as string / `u16` / string-map forms.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CachedResponse {
    /// URL of the response.
    #[serde(serialize_with = "serialize_url", deserialize_with = "deserialize_url")]
    url: Url,
    /// HTTP status of the response, stored as its numeric code.
    #[serde(
        serialize_with = "serialize_statuscode",
        deserialize_with = "deserialize_statuscode"
    )]
    status: StatusCode,
    /// Response headers, stored as a flat `name -> value` string map.
    #[serde(
        serialize_with = "serialize_headermap",
        deserialize_with = "deserialize_headermap"
    )]
    headers: HeaderMap,
    /// Raw response body bytes.
    body: Vec<u8>,
    /// URL of the request that produced this response.
    #[serde(serialize_with = "serialize_url", deserialize_with = "deserialize_url")]
    request_url: Url,
    /// Seconds since the Unix epoch at which the entry was stored.
    // NOTE(review): `#[serde(default)]` only fills in a field that the input
    // format reports as missing; bincode is not self-describing, so confirm
    // this really lets cache files written without these fields be read.
    #[serde(default)]
    cached_at_unix_secs: u64,
    /// Seconds since the Unix epoch after which the entry is stale.
    /// `None` means no expiry is recorded; `is_fresh_at` then always reports
    /// the entry as fresh.
    #[serde(default)]
    expires_at_unix_secs: Option<u64>,
}
110
111impl From<Response> for CachedResponse {
112 fn from(response: Response) -> Self {
113 CachedResponse {
114 url: response.url,
115 status: response.status,
116 headers: response.headers,
117 body: response.body.to_vec(),
118 request_url: response.request_url,
119 cached_at_unix_secs: now_unix_secs(),
120 expires_at_unix_secs: None,
121 }
122 }
123}
124
125impl From<CachedResponse> for Response {
126 fn from(cached_response: CachedResponse) -> Self {
127 Response {
128 url: cached_response.url,
129 status: cached_response.status,
130 headers: cached_response.headers,
131 body: Bytes::from(cached_response.body),
132 request_url: cached_response.request_url,
133 request_priority: 0,
134 meta: Default::default(),
135 cached: true,
136 }
137 }
138}
139
/// Storage decision derived from a response's caching headers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CachePolicy {
    /// The response must not be written to the cache (`Cache-Control: no-store`).
    DoNotStore,
    /// The response may be stored. `expires_at_unix_secs` is the instant
    /// (seconds since the Unix epoch) after which it is stale, or `None` when
    /// no expiry could be determined.
    Store { expires_at_unix_secs: Option<u64> },
}
145
146impl CachedResponse {
147 fn is_fresh_at(&self, now_unix_secs: u64) -> bool {
148 match self.expires_at_unix_secs {
149 Some(expires_at_unix_secs) => now_unix_secs < expires_at_unix_secs,
150 None => true,
151 }
152 }
153}
154
/// Builder for [`HttpCacheMiddleware`].
#[derive(Default)]
pub struct HttpCacheMiddlewareBuilder {
    /// Explicit cache directory; when `None`, `build` falls back to a
    /// location under the platform cache directory.
    cache_dir: Option<PathBuf>,
}
160
161impl HttpCacheMiddlewareBuilder {
162 pub fn cache_dir(mut self, path: impl Into<PathBuf>) -> Self {
164 self.cache_dir = Some(path.into());
165 self
166 }
167
168 pub fn build(self) -> Result<HttpCacheMiddleware, SpiderError> {
174 let cache_dir = if let Some(path) = self.cache_dir {
175 path
176 } else {
177 dirs::cache_dir()
178 .ok_or_else(|| {
179 SpiderError::ConfigurationError(
180 "Could not determine cache directory".to_string(),
181 )
182 })?
183 .join("spider-lib")
184 .join("http_cache")
185 };
186
187 std::fs::create_dir_all(&cache_dir)?;
188
189 let middleware = HttpCacheMiddleware { cache_dir };
190 info!(
191 "Initializing HttpCacheMiddleware with config: {:?}",
192 middleware
193 );
194
195 Ok(middleware)
196 }
197}
198
/// Middleware that stores successful responses on disk and serves them back
/// for later requests with the same fingerprint.
#[derive(Debug)]
pub struct HttpCacheMiddleware {
    /// Directory holding one `<fingerprint>.bin` file per cached response.
    cache_dir: PathBuf,
}
204
205impl HttpCacheMiddleware {
206 pub fn builder() -> HttpCacheMiddlewareBuilder {
208 HttpCacheMiddlewareBuilder::default()
209 }
210
211 fn get_cache_file_path(&self, fingerprint: &str) -> PathBuf {
212 self.cache_dir.join(format!("{}.bin", fingerprint))
213 }
214}
215
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn now_unix_secs() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or(Duration::ZERO);
    since_epoch.as_secs()
}
222
223fn parse_cache_policy(headers: &HeaderMap, cached_at_unix_secs: u64) -> CachePolicy {
224 if let Some(policy) = parse_cache_control(headers, cached_at_unix_secs) {
225 return policy;
226 }
227
228 CachePolicy::Store {
229 expires_at_unix_secs: parse_expires(headers),
230 }
231}
232
233fn parse_cache_control(headers: &HeaderMap, cached_at_unix_secs: u64) -> Option<CachePolicy> {
234 let cache_control = headers.get(CACHE_CONTROL)?.to_str().ok()?;
235 let mut max_age_secs = None;
236
237 for directive in cache_control.split(',') {
238 let directive = directive.trim();
239 if directive.eq_ignore_ascii_case("no-store") {
240 return Some(CachePolicy::DoNotStore);
241 }
242
243 let Some((name, value)) = directive.split_once('=') else {
244 continue;
245 };
246
247 if !name.trim().eq_ignore_ascii_case("max-age") {
248 continue;
249 }
250
251 let value = value.trim().trim_matches('"');
252 if let Ok(parsed) = value.parse::<u64>() {
253 max_age_secs = Some(parsed);
254 }
255 }
256
257 max_age_secs.map(|max_age_secs| CachePolicy::Store {
258 expires_at_unix_secs: cached_at_unix_secs.checked_add(max_age_secs),
259 })
260}
261
262fn parse_expires(headers: &HeaderMap) -> Option<u64> {
263 let expires = headers.get(EXPIRES)?.to_str().ok()?;
264 let parsed = OffsetDateTime::parse(expires, &Rfc2822).ok()?;
265 u64::try_from(parsed.unix_timestamp()).ok()
266}
267
268#[async_trait]
269impl<C: Send + Sync> Middleware<C> for HttpCacheMiddleware {
270 fn name(&self) -> &str {
271 "HttpCacheMiddleware"
272 }
273
274 async fn process_request(
275 &self,
276 _client: &C,
277 request: Request,
278 ) -> Result<MiddlewareAction<Request>, SpiderError> {
279 let fingerprint = request.fingerprint();
280 let cache_file_path = self.get_cache_file_path(&fingerprint);
281
282 trace!(
283 "Checking cache for request: {} (fingerprint: {})",
284 request.url, fingerprint
285 );
286 if fs::metadata(&cache_file_path).await.is_ok() {
287 debug!("Cache hit for request: {}", request.url);
288 match fs::read(&cache_file_path).await {
289 Ok(cached_bytes) => match bincode::deserialize::<CachedResponse>(&cached_bytes) {
290 Ok(cached_resp) => {
291 let now_unix_secs = now_unix_secs();
292 if !cached_resp.is_fresh_at(now_unix_secs) {
293 debug!(
294 "Cached response expired for {} at {:?}, refreshing from network",
295 request.url, cached_resp.expires_at_unix_secs
296 );
297 return Ok(MiddlewareAction::Continue(request));
298 }
299
300 trace!(
301 "Successfully deserialized cached response for {}",
302 request.url
303 );
304 let mut response: Response = cached_resp.into();
305 response.meta = request.clone_meta();
306 debug!("Returning cached response for {}", response.url);
307 return Ok(MiddlewareAction::ReturnResponse(response));
308 }
309 Err(e) => {
310 warn!(
311 "Failed to deserialize cached response from {}: {}. Deleting invalid cache file.",
312 cache_file_path.display(),
313 e
314 );
315 fs::remove_file(&cache_file_path).await.ok();
316 }
317 },
318 Err(e) => {
319 warn!(
320 "Failed to read cache file {}: {}. Deleting invalid cache file.",
321 cache_file_path.display(),
322 e
323 );
324 fs::remove_file(&cache_file_path).await.ok();
325 }
326 }
327 } else {
328 trace!(
329 "Cache miss for request: {} (no cache file found)",
330 request.url
331 );
332 }
333
334 trace!("Continuing request to downloader: {}", request.url);
335 Ok(MiddlewareAction::Continue(request))
336 }
337
338 async fn process_response(
339 &self,
340 response: Response,
341 ) -> Result<MiddlewareAction<Response>, SpiderError> {
342 trace!(
343 "Processing response for caching: {} with status: {}",
344 response.url, response.status
345 );
346
347 if response.status.is_success() {
349 let original_request_fingerprint = response.request_from_response().fingerprint();
350 let cache_file_path = self.get_cache_file_path(&original_request_fingerprint);
351 let cached_at_unix_secs = now_unix_secs();
352 let cache_policy = parse_cache_policy(&response.headers, cached_at_unix_secs);
353
354 if matches!(cache_policy, CachePolicy::DoNotStore) {
355 debug!(
356 "Skipping cache storage for {} due to Cache-Control: no-store",
357 response.url
358 );
359 return Ok(MiddlewareAction::Continue(response));
360 }
361
362 trace!(
363 "Serializing response for caching to: {}",
364 cache_file_path.display()
365 );
366 let mut cached_response: CachedResponse = response.clone().into();
367 cached_response.cached_at_unix_secs = cached_at_unix_secs;
368 cached_response.expires_at_unix_secs = match cache_policy {
369 CachePolicy::Store {
370 expires_at_unix_secs,
371 } => expires_at_unix_secs,
372 CachePolicy::DoNotStore => None,
373 };
374 match bincode::serialize(&cached_response) {
375 Ok(serialized_bytes) => {
376 let bytes_count = serialized_bytes.len();
377 trace!(
378 "Writing {} bytes to cache file: {}",
379 bytes_count,
380 cache_file_path.display()
381 );
382 fs::write(&cache_file_path, serialized_bytes)
383 .await
384 .map_err(|e| SpiderError::IoError(e.to_string()))?;
385 debug!(
386 "Cached response for {} ({} bytes)",
387 response.url, bytes_count
388 );
389 }
390 Err(e) => {
391 warn!(
392 "Failed to serialize response for caching {}: {}",
393 response.url, e
394 );
395 }
396 }
397 } else {
398 trace!(
399 "Response status {} is not successful, skipping cache for: {}",
400 response.status, response.url
401 );
402 }
403
404 trace!("Continuing response: {}", response.url);
405 Ok(MiddlewareAction::Continue(response))
406 }
407}