1use crate::pipeline::Pipeline;
4use async_trait::async_trait;
5use serde_json::{Map, Value};
6use spider_util::error::PipelineError;
7use spider_util::item::{
8 FieldValueType, ItemFieldSchema, ItemSchema, ScrapedItem, TypedItemSchema,
9};
10use std::collections::BTreeMap;
11use std::marker::PhantomData;
12use std::sync::Arc;
13
14type SchemaValidatorFn<I> = dyn Fn(&I, &ItemSchema, &Value) -> Result<(), String> + Send + Sync;
15type SchemaTransformFn<I> = dyn Fn(I) -> Result<I, String> + Send + Sync;
16
/// A single schema validation failure.
///
/// `field` names the offending field when the failure is tied to one, and is
/// `None` for failures that concern the item as a whole. `message` is a
/// human-readable description of the problem.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SchemaViolation {
    /// Name of the field that failed validation, if any.
    pub field: Option<String>,
    /// Human-readable explanation of the failure.
    pub message: String,
}
23
/// Options controlling how schema fields are named and emitted on export.
///
/// Built with chained `with_*` style methods starting from [`SchemaExportConfig::new`].
#[derive(Clone, Debug, Default)]
pub struct SchemaExportConfig {
    /// Maps a schema field name to the name it should be exported under.
    field_aliases: BTreeMap<String, String>,
    /// When set, the name of an extra output field carrying the schema version.
    schema_version_field: Option<String>,
    /// When true, nullable fields that are missing are exported as explicit nulls.
    inject_nulls_for_missing_optional: bool,
}

impl SchemaExportConfig {
    /// Creates a configuration with no aliases, no version field, and null
    /// injection disabled (same as `Default::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `export_name` as the exported name for `field_name`.
    ///
    /// A later call for the same `field_name` replaces the earlier alias.
    pub fn with_field_alias(
        mut self,
        field_name: impl Into<String>,
        export_name: impl Into<String>,
    ) -> Self {
        let (source_name, alias) = (field_name.into(), export_name.into());
        self.field_aliases.insert(source_name, alias);
        self
    }

    /// Sets the name of the output field that will carry the schema version.
    pub fn with_schema_version_field(self, field_name: impl Into<String>) -> Self {
        Self {
            schema_version_field: Some(field_name.into()),
            ..self
        }
    }

    /// Enables or disables emitting explicit nulls for missing nullable fields.
    pub fn inject_nulls_for_missing_optional(self, enabled: bool) -> Self {
        Self {
            inject_nulls_for_missing_optional: enabled,
            ..self
        }
    }

    /// Returns the exported name for `field_name`: its alias when one is
    /// registered, otherwise `field_name` itself.
    pub fn export_name_for<'a>(&'a self, field_name: &'a str) -> &'a str {
        match self.field_aliases.get(field_name) {
            Some(alias) => alias.as_str(),
            None => field_name,
        }
    }
}
64
65pub struct SchemaValidationPipeline<I: ScrapedItem + TypedItemSchema> {
67 validators: Vec<Arc<SchemaValidatorFn<I>>>,
68 expected_schema_version: Option<u32>,
69 _phantom: PhantomData<I>,
70}
71
72impl<I: ScrapedItem + TypedItemSchema> SchemaValidationPipeline<I> {
73 pub fn new() -> Self {
74 Self {
75 validators: Vec::new(),
76 expected_schema_version: None,
77 _phantom: PhantomData,
78 }
79 }
80
81 pub fn expect_schema_version(mut self, version: u32) -> Self {
82 self.expected_schema_version = Some(version);
83 self
84 }
85
86 pub fn with_validator<F>(mut self, validator: F) -> Self
87 where
88 F: Fn(&I, &ItemSchema, &Value) -> Result<(), String> + Send + Sync + 'static,
89 {
90 self.validators.push(Arc::new(validator));
91 self
92 }
93}
94
95impl<I: ScrapedItem + TypedItemSchema> Default for SchemaValidationPipeline<I> {
96 fn default() -> Self {
97 Self::new()
98 }
99}
100
101#[async_trait]
102impl<I: ScrapedItem + TypedItemSchema> Pipeline<I> for SchemaValidationPipeline<I> {
103 fn name(&self) -> &str {
104 "SchemaValidationPipeline"
105 }
106
107 async fn process_item(&self, item: I) -> Result<Option<I>, PipelineError> {
108 let schema = I::schema();
109 if let Some(expected_version) = self.expected_schema_version
110 && schema.version != expected_version
111 {
112 return Ok(None);
113 }
114
115 let json = item.to_json_value();
116 if validate_value_against_schema(&schema, &json).is_err() {
117 return Ok(None);
118 }
119
120 for validator in &self.validators {
121 if validator(&item, &schema, &json).is_err() {
122 return Ok(None);
123 }
124 }
125
126 Ok(Some(item))
127 }
128}
129
130pub struct SchemaTransformPipeline<I: ScrapedItem + TypedItemSchema> {
132 transforms: Vec<Arc<SchemaTransformFn<I>>>,
133 _phantom: PhantomData<I>,
134}
135
136impl<I: ScrapedItem + TypedItemSchema> SchemaTransformPipeline<I> {
137 pub fn new() -> Self {
138 Self {
139 transforms: Vec::new(),
140 _phantom: PhantomData,
141 }
142 }
143
144 pub fn with_transform<F>(mut self, transform: F) -> Self
145 where
146 F: Fn(I) -> Result<I, String> + Send + Sync + 'static,
147 {
148 self.transforms.push(Arc::new(transform));
149 self
150 }
151}
152
153impl<I: ScrapedItem + TypedItemSchema> Default for SchemaTransformPipeline<I> {
154 fn default() -> Self {
155 Self::new()
156 }
157}
158
159#[async_trait]
160impl<I: ScrapedItem + TypedItemSchema> Pipeline<I> for SchemaTransformPipeline<I> {
161 fn name(&self) -> &str {
162 "SchemaTransformPipeline"
163 }
164
165 async fn process_item(&self, mut item: I) -> Result<Option<I>, PipelineError> {
166 for transform in &self.transforms {
167 item = transform(item).map_err(PipelineError::ItemError)?;
168 }
169
170 Ok(Some(item))
171 }
172}
173
174pub fn export_schema_for_item<I: ScrapedItem>(
175 item: &I,
176 config: Option<&SchemaExportConfig>,
177) -> Option<Vec<ItemFieldSchema>> {
178 let schema = item.item_schema()?;
179 Some(
180 schema
181 .fields
182 .iter()
183 .map(|field| ItemFieldSchema {
184 name: config
185 .map(|cfg| cfg.export_name_for(&field.name).to_string())
186 .unwrap_or_else(|| field.name.clone()),
187 rust_type: field.rust_type.clone(),
188 value_type: field.value_type.clone(),
189 nullable: field.nullable,
190 })
191 .collect(),
192 )
193}
194
195pub fn map_item_for_export<I: ScrapedItem>(item: &I, config: Option<&SchemaExportConfig>) -> Value {
196 let raw = item.to_json_value();
197 let Some(schema) = item.item_schema() else {
198 return raw;
199 };
200 let Some(source) = raw.as_object() else {
201 return raw;
202 };
203
204 let mut output = Map::new();
205 for field in &schema.fields {
206 let export_name = config
207 .map(|cfg| cfg.export_name_for(&field.name).to_string())
208 .unwrap_or_else(|| field.name.clone());
209 match source.get(&field.name) {
210 Some(value) => {
211 output.insert(export_name, value.clone());
212 }
213 None if field.nullable
214 && config
215 .map(|cfg| cfg.inject_nulls_for_missing_optional)
216 .unwrap_or(false) =>
217 {
218 output.insert(export_name, Value::Null);
219 }
220 None => {}
221 }
222 }
223
224 if let Some(version_field) = config.and_then(|cfg| cfg.schema_version_field.as_ref()) {
225 output.insert(
226 version_field.clone(),
227 Value::from(item.item_schema_version()),
228 );
229 }
230
231 Value::Object(output)
232}
233
234pub fn sqlite_type_for_field(field: &ItemFieldSchema) -> &'static str {
235 match field.value_type {
236 FieldValueType::Bool | FieldValueType::Integer => "INTEGER",
237 FieldValueType::Float => "REAL",
238 FieldValueType::String
239 | FieldValueType::Json
240 | FieldValueType::Sequence
241 | FieldValueType::Map
242 | FieldValueType::Unknown => "TEXT",
243 }
244}
245
246fn validate_value_against_schema(schema: &ItemSchema, json: &Value) -> Result<(), SchemaViolation> {
247 let map = json.as_object().ok_or_else(|| SchemaViolation {
248 field: None,
249 message: "Item must serialize to a JSON object for schema validation.".to_string(),
250 })?;
251
252 for field in &schema.fields {
253 let value = map.get(&field.name);
254 if value.is_none() && !field.nullable {
255 return Err(SchemaViolation {
256 field: Some(field.name.clone()),
257 message: format!("Missing non-nullable field '{}'.", field.name),
258 });
259 }
260
261 if let Some(value) = value {
262 if value.is_null() && !field.nullable {
263 return Err(SchemaViolation {
264 field: Some(field.name.clone()),
265 message: format!("Field '{}' cannot be null.", field.name),
266 });
267 }
268
269 if !matches_field_type(field, value) {
270 return Err(SchemaViolation {
271 field: Some(field.name.clone()),
272 message: format!(
273 "Field '{}' does not match declared schema type '{}'.",
274 field.name, field.rust_type
275 ),
276 });
277 }
278 }
279 }
280
281 Ok(())
282}
283
284fn matches_field_type(field: &ItemFieldSchema, value: &Value) -> bool {
285 if value.is_null() {
286 return field.nullable;
287 }
288
289 match field.value_type {
290 FieldValueType::Bool => value.is_boolean(),
291 FieldValueType::Integer => value.as_i64().is_some() || value.as_u64().is_some(),
292 FieldValueType::Float => value.is_number(),
293 FieldValueType::String => value.is_string(),
294 FieldValueType::Json => true,
295 FieldValueType::Sequence => value.is_array(),
296 FieldValueType::Map => value.is_object(),
297 FieldValueType::Unknown => true,
298 }
299}