spider_middleware/
http_cache.rs

//! On-disk HTTP response cache middleware.
//!
//! This middleware stores successful responses on disk, keyed by request
//! fingerprint, and can short-circuit later requests by returning the cached
//! response directly. Cache freshness is evaluated per response from its HTTP
//! caching headers (`Cache-Control` and `Expires`).
6
7use async_trait::async_trait;
8use bytes::Bytes;
9use log::{debug, info, trace, warn};
10use reqwest::StatusCode;
11use reqwest::header::{CACHE_CONTROL, EXPIRES, HeaderMap, HeaderName, HeaderValue};
12use std::path::PathBuf;
13use std::time::{Duration, SystemTime, UNIX_EPOCH};
14use time::OffsetDateTime;
15use time::format_description::well_known::Rfc2822;
16use tokio::fs;
17
18use crate::middleware::{Middleware, MiddlewareAction};
19use serde::{Deserialize, Deserializer, Serialize, Serializer};
20use spider_util::error::SpiderError;
21use spider_util::request::Request;
22use spider_util::response::Response;
23use url::Url;
24
25fn serialize_headermap<S>(headers: &HeaderMap, serializer: S) -> Result<S::Ok, S::Error>
26where
27    S: Serializer,
28{
29    let mut map = std::collections::HashMap::<String, String>::new();
30    for (name, value) in headers.iter() {
31        map.insert(
32            name.to_string(),
33            value.to_str().unwrap_or_default().to_string(),
34        );
35    }
36    map.serialize(serializer)
37}
38
39fn deserialize_headermap<'de, D>(deserializer: D) -> Result<HeaderMap, D::Error>
40where
41    D: Deserializer<'de>,
42{
43    let map = std::collections::HashMap::<String, String>::deserialize(deserializer)?;
44    let mut headers = HeaderMap::new();
45    for (name, value) in map {
46        if let (Ok(header_name), Ok(header_value)) =
47            (name.parse::<HeaderName>(), value.parse::<HeaderValue>())
48        {
49            headers.insert(header_name, header_value);
50        } else {
51            warn!("Failed to parse header: {} = {}", name, value);
52        }
53    }
54    Ok(headers)
55}
56
57fn serialize_statuscode<S>(status: &StatusCode, serializer: S) -> Result<S::Ok, S::Error>
58where
59    S: Serializer,
60{
61    status.as_u16().serialize(serializer)
62}
63
64fn deserialize_statuscode<'de, D>(deserializer: D) -> Result<StatusCode, D::Error>
65where
66    D: Deserializer<'de>,
67{
68    let status_u16 = u16::deserialize(deserializer)?;
69    StatusCode::from_u16(status_u16).map_err(serde::de::Error::custom)
70}
71
72fn serialize_url<S>(url: &Url, serializer: S) -> Result<S::Ok, S::Error>
73where
74    S: Serializer,
75{
76    url.to_string().serialize(serializer)
77}
78
79fn deserialize_url<'de, D>(deserializer: D) -> Result<Url, D::Error>
80where
81    D: Deserializer<'de>,
82{
83    let s = String::deserialize(deserializer)?;
84    Url::parse(&s).map_err(serde::de::Error::custom)
85}
86
87/// Serialized response data used for cache storage.
88#[derive(Debug, Clone, Serialize, Deserialize)]
89struct CachedResponse {
90    #[serde(serialize_with = "serialize_url", deserialize_with = "deserialize_url")]
91    url: Url,
92    #[serde(
93        serialize_with = "serialize_statuscode",
94        deserialize_with = "deserialize_statuscode"
95    )]
96    status: StatusCode,
97    #[serde(
98        serialize_with = "serialize_headermap",
99        deserialize_with = "deserialize_headermap"
100    )]
101    headers: HeaderMap,
102    body: Vec<u8>,
103    #[serde(serialize_with = "serialize_url", deserialize_with = "deserialize_url")]
104    request_url: Url,
105    #[serde(default)]
106    cached_at_unix_secs: u64,
107    #[serde(default)]
108    expires_at_unix_secs: Option<u64>,
109}
110
111impl From<Response> for CachedResponse {
112    fn from(response: Response) -> Self {
113        CachedResponse {
114            url: response.url,
115            status: response.status,
116            headers: response.headers,
117            body: response.body.to_vec(),
118            request_url: response.request_url,
119            cached_at_unix_secs: now_unix_secs(),
120            expires_at_unix_secs: None,
121        }
122    }
123}
124
125impl From<CachedResponse> for Response {
126    fn from(cached_response: CachedResponse) -> Self {
127        Response {
128            url: cached_response.url,
129            status: cached_response.status,
130            headers: cached_response.headers,
131            body: Bytes::from(cached_response.body),
132            request_url: cached_response.request_url,
133            request_priority: 0,
134            meta: Default::default(),
135            cached: true,
136        }
137    }
138}
139
/// Storage decision derived from a response's caching headers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CachePolicy {
    /// The response must not be written to the cache.
    DoNotStore,
    /// The response may be stored.
    Store {
        /// Unix time, in seconds, after which the entry is stale;
        /// `None` when no expiry could be determined.
        expires_at_unix_secs: Option<u64>,
    },
}
145
146impl CachedResponse {
147    fn is_fresh_at(&self, now_unix_secs: u64) -> bool {
148        match self.expires_at_unix_secs {
149            Some(expires_at_unix_secs) => now_unix_secs < expires_at_unix_secs,
150            None => true,
151        }
152    }
153}
154
/// Configures and constructs an [`HttpCacheMiddleware`].
#[derive(Default)]
pub struct HttpCacheMiddlewareBuilder {
    /// Explicit cache directory; when `None`, `build` falls back to a
    /// location under the platform cache directory.
    cache_dir: Option<PathBuf>,
}
160
161impl HttpCacheMiddlewareBuilder {
162    /// Sets the directory where cache files will be stored.
163    pub fn cache_dir(mut self, path: impl Into<PathBuf>) -> Self {
164        self.cache_dir = Some(path.into());
165        self
166    }
167
168    /// Builds the `HttpCacheMiddleware`.
169    ///
170    /// # Errors
171    ///
172    /// Returns an error if the cache directory cannot be resolved or created.
173    pub fn build(self) -> Result<HttpCacheMiddleware, SpiderError> {
174        let cache_dir = if let Some(path) = self.cache_dir {
175            path
176        } else {
177            dirs::cache_dir()
178                .ok_or_else(|| {
179                    SpiderError::ConfigurationError(
180                        "Could not determine cache directory".to_string(),
181                    )
182                })?
183                .join("spider-lib")
184                .join("http_cache")
185        };
186
187        std::fs::create_dir_all(&cache_dir)?;
188
189        let middleware = HttpCacheMiddleware { cache_dir };
190        info!(
191            "Initializing HttpCacheMiddleware with config: {:?}",
192            middleware
193        );
194
195        Ok(middleware)
196    }
197}
198
/// Middleware that persists successful HTTP responses on disk and replays
/// them for later matching requests.
#[derive(Debug)]
pub struct HttpCacheMiddleware {
    /// Directory that holds one cache file per request fingerprint.
    cache_dir: PathBuf,
}
204
205impl HttpCacheMiddleware {
206    /// Creates a new `HttpCacheMiddlewareBuilder` to start building an `HttpCacheMiddleware`.
207    pub fn builder() -> HttpCacheMiddlewareBuilder {
208        HttpCacheMiddlewareBuilder::default()
209    }
210
211    fn get_cache_file_path(&self, fingerprint: &str) -> PathBuf {
212        self.cache_dir.join(format!("{}.bin", fingerprint))
213    }
214}
215
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn now_unix_secs() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |since_epoch: Duration| since_epoch.as_secs())
}
222
223fn parse_cache_policy(headers: &HeaderMap, cached_at_unix_secs: u64) -> CachePolicy {
224    if let Some(policy) = parse_cache_control(headers, cached_at_unix_secs) {
225        return policy;
226    }
227
228    CachePolicy::Store {
229        expires_at_unix_secs: parse_expires(headers),
230    }
231}
232
233fn parse_cache_control(headers: &HeaderMap, cached_at_unix_secs: u64) -> Option<CachePolicy> {
234    let cache_control = headers.get(CACHE_CONTROL)?.to_str().ok()?;
235    let mut max_age_secs = None;
236
237    for directive in cache_control.split(',') {
238        let directive = directive.trim();
239        if directive.eq_ignore_ascii_case("no-store") {
240            return Some(CachePolicy::DoNotStore);
241        }
242
243        let Some((name, value)) = directive.split_once('=') else {
244            continue;
245        };
246
247        if !name.trim().eq_ignore_ascii_case("max-age") {
248            continue;
249        }
250
251        let value = value.trim().trim_matches('"');
252        if let Ok(parsed) = value.parse::<u64>() {
253            max_age_secs = Some(parsed);
254        }
255    }
256
257    max_age_secs.map(|max_age_secs| CachePolicy::Store {
258        expires_at_unix_secs: cached_at_unix_secs.checked_add(max_age_secs),
259    })
260}
261
262fn parse_expires(headers: &HeaderMap) -> Option<u64> {
263    let expires = headers.get(EXPIRES)?.to_str().ok()?;
264    let parsed = OffsetDateTime::parse(expires, &Rfc2822).ok()?;
265    u64::try_from(parsed.unix_timestamp()).ok()
266}
267
268#[async_trait]
269impl<C: Send + Sync> Middleware<C> for HttpCacheMiddleware {
270    fn name(&self) -> &str {
271        "HttpCacheMiddleware"
272    }
273
274    async fn process_request(
275        &self,
276        _client: &C,
277        request: Request,
278    ) -> Result<MiddlewareAction<Request>, SpiderError> {
279        let fingerprint = request.fingerprint();
280        let cache_file_path = self.get_cache_file_path(&fingerprint);
281
282        trace!(
283            "Checking cache for request: {} (fingerprint: {})",
284            request.url, fingerprint
285        );
286        if fs::metadata(&cache_file_path).await.is_ok() {
287            debug!("Cache hit for request: {}", request.url);
288            match fs::read(&cache_file_path).await {
289                Ok(cached_bytes) => match bincode::deserialize::<CachedResponse>(&cached_bytes) {
290                    Ok(cached_resp) => {
291                        let now_unix_secs = now_unix_secs();
292                        if !cached_resp.is_fresh_at(now_unix_secs) {
293                            debug!(
294                                "Cached response expired for {} at {:?}, refreshing from network",
295                                request.url, cached_resp.expires_at_unix_secs
296                            );
297                            return Ok(MiddlewareAction::Continue(request));
298                        }
299
300                        trace!(
301                            "Successfully deserialized cached response for {}",
302                            request.url
303                        );
304                        let mut response: Response = cached_resp.into();
305                        response.meta = request.clone_meta();
306                        debug!("Returning cached response for {}", response.url);
307                        return Ok(MiddlewareAction::ReturnResponse(response));
308                    }
309                    Err(e) => {
310                        warn!(
311                            "Failed to deserialize cached response from {}: {}. Deleting invalid cache file.",
312                            cache_file_path.display(),
313                            e
314                        );
315                        fs::remove_file(&cache_file_path).await.ok();
316                    }
317                },
318                Err(e) => {
319                    warn!(
320                        "Failed to read cache file {}: {}. Deleting invalid cache file.",
321                        cache_file_path.display(),
322                        e
323                    );
324                    fs::remove_file(&cache_file_path).await.ok();
325                }
326            }
327        } else {
328            trace!(
329                "Cache miss for request: {} (no cache file found)",
330                request.url
331            );
332        }
333
334        trace!("Continuing request to downloader: {}", request.url);
335        Ok(MiddlewareAction::Continue(request))
336    }
337
338    async fn process_response(
339        &self,
340        response: Response,
341    ) -> Result<MiddlewareAction<Response>, SpiderError> {
342        trace!(
343            "Processing response for caching: {} with status: {}",
344            response.url, response.status
345        );
346
347        // Only cache successful responses (e.g., 200 OK)
348        if response.status.is_success() {
349            let original_request_fingerprint = response.request_from_response().fingerprint();
350            let cache_file_path = self.get_cache_file_path(&original_request_fingerprint);
351            let cached_at_unix_secs = now_unix_secs();
352            let cache_policy = parse_cache_policy(&response.headers, cached_at_unix_secs);
353
354            if matches!(cache_policy, CachePolicy::DoNotStore) {
355                debug!(
356                    "Skipping cache storage for {} due to Cache-Control: no-store",
357                    response.url
358                );
359                return Ok(MiddlewareAction::Continue(response));
360            }
361
362            trace!(
363                "Serializing response for caching to: {}",
364                cache_file_path.display()
365            );
366            let mut cached_response: CachedResponse = response.clone().into();
367            cached_response.cached_at_unix_secs = cached_at_unix_secs;
368            cached_response.expires_at_unix_secs = match cache_policy {
369                CachePolicy::Store {
370                    expires_at_unix_secs,
371                } => expires_at_unix_secs,
372                CachePolicy::DoNotStore => None,
373            };
374            match bincode::serialize(&cached_response) {
375                Ok(serialized_bytes) => {
376                    let bytes_count = serialized_bytes.len();
377                    trace!(
378                        "Writing {} bytes to cache file: {}",
379                        bytes_count,
380                        cache_file_path.display()
381                    );
382                    fs::write(&cache_file_path, serialized_bytes)
383                        .await
384                        .map_err(|e| SpiderError::IoError(e.to_string()))?;
385                    debug!(
386                        "Cached response for {} ({} bytes)",
387                        response.url, bytes_count
388                    );
389                }
390                Err(e) => {
391                    warn!(
392                        "Failed to serialize response for caching {}: {}",
393                        response.url, e
394                    );
395                }
396            }
397        } else {
398            trace!(
399                "Response status {} is not successful, skipping cache for: {}",
400                response.status, response.url
401            );
402        }
403
404        trace!("Continuing response: {}", response.url);
405        Ok(MiddlewareAction::Continue(response))
406    }
407}