1use crate::pipeline::Pipeline;
4use async_trait::async_trait;
5use dashmap::DashSet;
6use log::{debug, info};
7use seahash::SeaHasher;
8use serde_json::Value;
9use spider_util::{error::PipelineError, item::ScrapedItem};
10use std::collections::HashSet;
11use std::hash::Hasher;
12use std::marker::PhantomData;
13
/// Pipeline stage that drops items whose configured unique fields match
/// an item that has already passed through.
pub struct DeduplicationPipeline<I: ScrapedItem> {
    /// Names of the top-level JSON fields whose combined values identify an item.
    unique_fields: Vec<String>,
    /// Hashes of every unique-field combination seen so far; a concurrent
    /// set, so `process_item` can insert through `&self`.
    seen_hashes: DashSet<u64>,
    /// Ties the pipeline to its item type without storing one.
    _phantom: PhantomData<I>,
}
20
21impl<I: ScrapedItem> DeduplicationPipeline<I> {
22 pub fn new<F, S>(unique_fields: F) -> Self
24 where
25 F: IntoIterator<Item = S>,
26 S: AsRef<str>,
27 {
28 let unique_fields: Vec<String> = unique_fields
29 .into_iter()
30 .map(|field| field.as_ref().to_string())
31 .collect();
32 info!(
33 "Initializing DeduplicationPipeline with unique fields: {:?}",
34 unique_fields
35 );
36 DeduplicationPipeline {
37 unique_fields,
38 seen_hashes: DashSet::new(),
39 _phantom: PhantomData,
40 }
41 }
42
43 fn generate_hash(&self, item: &I) -> Result<u64, PipelineError> {
45 let item_value = item.to_json_value();
46 let mut hasher = SeaHasher::new();
47
48 if let Some(map) = item_value.as_object() {
49 for field_name in &self.unique_fields {
50 if let Some(value) = map.get(field_name) {
51 hasher.write(field_name.as_bytes());
52
53 if let Some(str_val) = value.as_str() {
54 hasher.write(str_val.as_bytes());
55 } else {
56 hasher.write(value.to_string().as_bytes());
57 };
58 } else {
59 hasher.write(field_name.as_bytes());
60 hasher.write("".as_bytes());
61 }
62 }
63 } else {
64 return Err(PipelineError::ItemError(
65 "Item for deduplication must be a JSON object.".to_string(),
66 ));
67 }
68 Ok(hasher.finish())
69 }
70}
71
72#[async_trait]
73impl<I: ScrapedItem> Pipeline<I> for DeduplicationPipeline<I> {
74 fn name(&self) -> &str {
75 "DeduplicationPipeline"
76 }
77
78 async fn process_item(&self, item: I) -> Result<Option<I>, PipelineError> {
79 debug!("DeduplicationPipeline processing item.");
80
81 let item_hash = self.generate_hash(&item)?;
82
83 if self.seen_hashes.insert(item_hash) {
84 debug!("Unique item, passing through: {:?}", item);
85 Ok(Some(item))
86 } else {
87 debug!("Duplicate item detected, dropping: {:?}", item);
88 Ok(None)
89 }
90 }
91
92 async fn get_state(&self) -> Result<Option<Value>, PipelineError> {
93 let hashes: HashSet<u64> = self.seen_hashes.iter().map(|r| *r).collect();
94 let state = serde_json::to_value(hashes).map_err(|e| {
95 PipelineError::Other(format!("Failed to serialize deduplication state: {}", e))
96 })?;
97 Ok(Some(state))
98 }
99
100 async fn restore_state(&self, state: Value) -> Result<(), PipelineError> {
101 let hashes: HashSet<u64> = serde_json::from_value(state).map_err(|e| {
102 PipelineError::Other(format!("Failed to deserialize deduplication state: {}", e))
103 })?;
104
105 self.seen_hashes.clear();
106 for hash in hashes {
107 self.seen_hashes.insert(hash);
108 }
109
110 info!(
111 "Restored {} seen items in DeduplicationPipeline.",
112 self.seen_hashes.len()
113 );
114
115 Ok(())
116 }
117}