// spider_pipeline/dedup.rs

1//! Pipeline that drops duplicate items by selected fields.
2
3use crate::pipeline::Pipeline;
4use async_trait::async_trait;
5use dashmap::DashSet;
6use log::{debug, info};
7use seahash::SeaHasher;
8use serde_json::Value;
9use spider_util::{error::PipelineError, item::ScrapedItem};
10use std::collections::HashSet;
11use std::hash::Hasher;
12use std::marker::PhantomData;
13
/// Pipeline that filters duplicate items based on a configurable field set.
///
/// Previously seen item hashes are kept in a lock-free concurrent set
/// (`DashSet`), so a single pipeline instance can be shared across tasks
/// through `&self` without external locking.
pub struct DeduplicationPipeline<I: ScrapedItem> {
    /// JSON field names whose values identify an item; the order fixed at
    /// construction participates in the hash.
    unique_fields: Vec<String>,
    /// SeaHash digests of every unique item observed so far.
    seen_hashes: DashSet<u64>,
    /// Ties the pipeline to item type `I` without storing a value of it.
    _phantom: PhantomData<I>,
}
20
21impl<I: ScrapedItem> DeduplicationPipeline<I> {
22    /// Creates a new `DeduplicationPipeline` with a specified set of unique fields.
23    pub fn new<F, S>(unique_fields: F) -> Self
24    where
25        F: IntoIterator<Item = S>,
26        S: AsRef<str>,
27    {
28        let unique_fields: Vec<String> = unique_fields
29            .into_iter()
30            .map(|field| field.as_ref().to_string())
31            .collect();
32        info!(
33            "Initializing DeduplicationPipeline with unique fields: {:?}",
34            unique_fields
35        );
36        DeduplicationPipeline {
37            unique_fields,
38            seen_hashes: DashSet::new(),
39            _phantom: PhantomData,
40        }
41    }
42
43    /// Generates a hash for an item based on its unique fields.
44    fn generate_hash(&self, item: &I) -> Result<u64, PipelineError> {
45        let item_value = item.to_json_value();
46        let mut hasher = SeaHasher::new();
47
48        if let Some(map) = item_value.as_object() {
49            for field_name in &self.unique_fields {
50                if let Some(value) = map.get(field_name) {
51                    hasher.write(field_name.as_bytes());
52
53                    if let Some(str_val) = value.as_str() {
54                        hasher.write(str_val.as_bytes());
55                    } else {
56                        hasher.write(value.to_string().as_bytes());
57                    };
58                } else {
59                    hasher.write(field_name.as_bytes());
60                    hasher.write("".as_bytes());
61                }
62            }
63        } else {
64            return Err(PipelineError::ItemError(
65                "Item for deduplication must be a JSON object.".to_string(),
66            ));
67        }
68        Ok(hasher.finish())
69    }
70}
71
72#[async_trait]
73impl<I: ScrapedItem> Pipeline<I> for DeduplicationPipeline<I> {
74    fn name(&self) -> &str {
75        "DeduplicationPipeline"
76    }
77
78    async fn process_item(&self, item: I) -> Result<Option<I>, PipelineError> {
79        debug!("DeduplicationPipeline processing item.");
80
81        let item_hash = self.generate_hash(&item)?;
82
83        if self.seen_hashes.insert(item_hash) {
84            debug!("Unique item, passing through: {:?}", item);
85            Ok(Some(item))
86        } else {
87            debug!("Duplicate item detected, dropping: {:?}", item);
88            Ok(None)
89        }
90    }
91
92    async fn get_state(&self) -> Result<Option<Value>, PipelineError> {
93        let hashes: HashSet<u64> = self.seen_hashes.iter().map(|r| *r).collect();
94        let state = serde_json::to_value(hashes).map_err(|e| {
95            PipelineError::Other(format!("Failed to serialize deduplication state: {}", e))
96        })?;
97        Ok(Some(state))
98    }
99
100    async fn restore_state(&self, state: Value) -> Result<(), PipelineError> {
101        let hashes: HashSet<u64> = serde_json::from_value(state).map_err(|e| {
102            PipelineError::Other(format!("Failed to deserialize deduplication state: {}", e))
103        })?;
104
105        self.seen_hashes.clear();
106        for hash in hashes {
107            self.seen_hashes.insert(hash);
108        }
109
110        info!(
111            "Restored {} seen items in DeduplicationPipeline.",
112            self.seen_hashes.len()
113        );
114
115        Ok(())
116    }
117}