spider_util/
selector.rs

1//! Cached CSS selector helpers.
2//!
3//! HTML-heavy crawls often reuse the same selectors across thousands of pages.
4//! This module keeps compiled selectors cached so repeated parsing work stays low.
5
6use crate::error::SpiderError;
7use ego_tree::NodeId;
8use ego_tree::iter::Children;
9use once_cell::sync::Lazy;
10use parking_lot::RwLock;
11use scraper::{ElementRef, Html, Selector};
12use std::cell::RefCell;
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15
// Global selector cache to avoid repeated compilation.
// Keyed by the selector string exactly as passed to `get_cached_selector`.
static SELECTOR_CACHE: Lazy<RwLock<HashMap<String, Selector>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
// Global cache of compiled queries (selector plus extraction kind), keyed by
// the query string exactly as passed to `get_cached_compiled_selector`.
static COMPILED_SELECTOR_CACHE: Lazy<RwLock<HashMap<String, CompiledSelector>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

thread_local! {
    // Per-thread cache of parsed documents keyed by document hash. Each entry
    // stores the source HTML next to the parsed tree so `with_document` can
    // compare the text and re-parse when the hash matches but the HTML differs.
    static DOCUMENT_CACHE: RefCell<HashMap<u64, (Arc<str>, Arc<Html>)>> = RefCell::new(HashMap::new());
}
25
/// What a selection yields when its value is extracted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum ExtractionKind {
    /// Yield the element's serialized HTML (plain selector, no suffix).
    Element,
    /// Yield the element's text (query ended with `::text`).
    Text,
    /// Yield the named attribute's value (query ended with `::attr(name)`).
    Attr(String),
}
32
/// A parsed CSS selector paired with the extraction requested by the query.
#[derive(Debug, Clone)]
pub(crate) struct CompiledSelector {
    // The CSS selector with any `::text` / `::attr(...)` suffix already removed.
    selector: Selector,
    // What to extract from each element the selector matches.
    extraction: ExtractionKind,
}

impl CompiledSelector {
    /// Returns the compiled CSS selector.
    pub(crate) fn selector(&self) -> &Selector {
        &self.selector
    }

    /// Returns the extraction kind parsed from the query suffix.
    pub(crate) fn extraction(&self) -> &ExtractionKind {
        &self.extraction
    }
}
48
/// A node selected from an HTML document using the builtin CSS selector API.
///
/// The node does not borrow the parsed tree; it stores the source HTML and a
/// path of child indices, and resolves the element again on each access.
#[derive(Debug, Clone)]
pub struct SelectorNode {
    // Source HTML of the document this node belongs to.
    document_html: Arc<str>,
    // Key used to look the parsed document up in the per-thread cache.
    document_hash: u64,
    // Child indices leading from the tree root down to this node.
    path: Arc<[usize]>,
    // What `get()` extracts from the resolved element.
    extraction: ExtractionKind,
}
57
/// A Scrapy-like selection result list.
///
/// All entries belong to the same document and share one extraction kind.
#[derive(Debug, Clone)]
pub struct SelectorList {
    // Source HTML of the document the selection was made on.
    document_html: Arc<str>,
    // Key used to look the parsed document up in the per-thread cache.
    document_hash: u64,
    // One root-to-node child-index path per matched node.
    paths: Vec<Arc<[usize]>>,
    // What `get()` / `get_all()` extract from each matched element.
    extraction: ExtractionKind,
}
66
67fn assert_selector_types_are_send_sync() {
68    fn assert_traits<T: Send + Sync>() {}
69
70    assert_traits::<SelectorNode>();
71    assert_traits::<SelectorList>();
72}
73
74const _: fn() = assert_selector_types_are_send_sync;
75
76impl SelectorNode {
77    pub(crate) fn new(
78        document_html: Arc<str>,
79        document_hash: u64,
80        path: Arc<[usize]>,
81        extraction: ExtractionKind,
82    ) -> Self {
83        Self {
84            document_html,
85            document_hash,
86            path,
87            extraction,
88        }
89    }
90
91    /// Applies a CSS selector relative to this node.
92    ///
93    /// # Errors
94    ///
95    /// Returns [`SpiderError::HtmlParseError`] when the selector is invalid or
96    /// when chaining from a text/attribute extraction.
97    pub fn css(&self, query: &str) -> Result<SelectorList, SpiderError> {
98        if self.extraction != ExtractionKind::Element {
99            return Err(SpiderError::HtmlParseError(
100                "css() can only be chained from element selections".to_string(),
101            ));
102        }
103
104        let compiled = get_cached_compiled_selector(query)?;
105        with_document(
106            self.document_hash,
107            &self.document_html,
108            |document| -> Result<SelectorList, SpiderError> {
109                let Some(scope) = self.element_ref(document) else {
110                    return Ok(SelectorList::empty(
111                        self.document_html.clone(),
112                        self.document_hash,
113                        compiled.extraction().clone(),
114                    ));
115                };
116
117                let paths = scope
118                    .select(compiled.selector())
119                    .map(|element| node_path(document, element.id()))
120                    .collect();
121
122                Ok(SelectorList::new(
123                    self.document_html.clone(),
124                    self.document_hash,
125                    paths,
126                    compiled.extraction().clone(),
127                ))
128            },
129        )
130    }
131
132    /// Returns the extracted value for this node, if present.
133    pub fn get(&self) -> Option<String> {
134        with_document(self.document_hash, &self.document_html, |document| {
135            self.element_ref(document)
136                .and_then(|element| extract_element_value(element, &self.extraction))
137        })
138    }
139
140    /// Returns this node's extracted value as a single-element vector or an empty one.
141    pub fn get_all(&self) -> Vec<String> {
142        self.get().into_iter().collect()
143    }
144
145    /// Returns the named attribute from the selected element.
146    pub fn attrib(&self, name: &str) -> Option<String> {
147        with_document(self.document_hash, &self.document_html, |document| {
148            self.element_ref(document)
149                .and_then(|element| element.attr(name).map(ToOwned::to_owned))
150        })
151    }
152
153    /// Returns the concatenated text content of the selected element.
154    pub fn text_content(&self) -> Option<String> {
155        with_document(self.document_hash, &self.document_html, |document| {
156            self.element_ref(document)
157                .map(|element| element.text().collect::<String>())
158        })
159    }
160
161    /// Returns `true` when this element has any descendant matching `query`.
162    ///
163    /// # Errors
164    ///
165    /// Returns [`SpiderError::HtmlParseError`] when the selector is invalid or
166    /// when called on a text/attribute extraction.
167    pub fn has_css(&self, query: &str) -> Result<bool, SpiderError> {
168        Ok(!self.css(query)?.is_empty())
169    }
170
171    /// Returns `true` when any ancestor of this element matches `query`.
172    ///
173    /// # Errors
174    ///
175    /// Returns [`SpiderError::HtmlParseError`] when the selector is invalid or
176    /// when called on a text/attribute extraction.
177    pub fn has_ancestor(&self, query: &str) -> Result<bool, SpiderError> {
178        let selector =
179            Selector::parse(query).map_err(|e| SpiderError::HtmlParseError(e.to_string()))?;
180        with_document(
181            self.document_hash,
182            &self.document_html,
183            |document| -> Result<bool, SpiderError> {
184                let Some(element) = self.element_ref(document) else {
185                    return Ok(false);
186                };
187
188                Ok(element
189                    .ancestors()
190                    .filter_map(ElementRef::wrap)
191                    .any(|ancestor| selector.matches(&ancestor)))
192            },
193        )
194    }
195
196    fn element_ref<'a>(&self, document: &'a Html) -> Option<ElementRef<'a>> {
197        element_ref_by_path(document, &self.path)
198    }
199}
200
201impl SelectorList {
202    pub(crate) fn new(
203        document_html: Arc<str>,
204        document_hash: u64,
205        paths: Vec<Arc<[usize]>>,
206        extraction: ExtractionKind,
207    ) -> Self {
208        Self {
209            document_html,
210            document_hash,
211            paths,
212            extraction,
213        }
214    }
215
216    pub(crate) fn from_document_query(
217        document_html: Arc<str>,
218        document_hash: u64,
219        query: &str,
220    ) -> Result<Self, SpiderError> {
221        let compiled = get_cached_compiled_selector(query)?;
222        with_document(
223            document_hash,
224            &document_html,
225            |document| -> Result<Self, SpiderError> {
226                let paths = document
227                    .select(compiled.selector())
228                    .map(|element| node_path(document, element.id()))
229                    .collect();
230
231                Ok(Self::new(
232                    document_html.clone(),
233                    document_hash,
234                    paths,
235                    compiled.extraction().clone(),
236                ))
237            },
238        )
239    }
240
241    pub(crate) fn empty(
242        document_html: Arc<str>,
243        document_hash: u64,
244        extraction: ExtractionKind,
245    ) -> Self {
246        Self::new(document_html, document_hash, Vec::new(), extraction)
247    }
248
249    /// Applies a CSS selector relative to every node in the list.
250    ///
251    /// # Errors
252    ///
253    /// Returns [`SpiderError::HtmlParseError`] when the selector is invalid or
254    /// when chaining from a text/attribute extraction.
255    pub fn css(&self, query: &str) -> Result<Self, SpiderError> {
256        if self.extraction != ExtractionKind::Element {
257            return Err(SpiderError::HtmlParseError(
258                "css() can only be chained from element selections".to_string(),
259            ));
260        }
261
262        let compiled = get_cached_compiled_selector(query)?;
263        let mut seen = HashSet::new();
264        with_document(
265            self.document_hash,
266            &self.document_html,
267            |document| -> Result<Self, SpiderError> {
268                let mut paths = Vec::new();
269
270                for path in &self.paths {
271                    let Some(scope) = element_ref_by_path(document, path) else {
272                        continue;
273                    };
274
275                    for element in scope.select(compiled.selector()) {
276                        let path = node_path(document, element.id());
277                        if seen.insert(path.clone()) {
278                            paths.push(path);
279                        }
280                    }
281                }
282
283                Ok(Self::new(
284                    self.document_html.clone(),
285                    self.document_hash,
286                    paths,
287                    compiled.extraction().clone(),
288                ))
289            },
290        )
291    }
292
293    /// Returns the first extracted value in the selection.
294    pub fn get(&self) -> Option<String> {
295        self.first().and_then(|node| node.get())
296    }
297
298    /// Returns all extracted values in the selection.
299    pub fn get_all(&self) -> Vec<String> {
300        with_document(self.document_hash, &self.document_html, |document| {
301            self.paths
302                .iter()
303                .filter_map(|path| {
304                    element_ref_by_path(document, path)
305                        .and_then(|element| extract_element_value(element, &self.extraction))
306                })
307                .collect()
308        })
309    }
310
311    /// Returns the named attribute from the first selected element.
312    pub fn attrib(&self, name: &str) -> Option<String> {
313        self.first().and_then(|node| node.attrib(name))
314    }
315
316    /// Returns the first selected node.
317    pub fn first(&self) -> Option<SelectorNode> {
318        self.paths.first().cloned().map(|path| {
319            SelectorNode::new(
320                self.document_html.clone(),
321                self.document_hash,
322                path,
323                self.extraction.clone(),
324            )
325        })
326    }
327
328    /// Returns the number of matched nodes.
329    pub fn len(&self) -> usize {
330        self.paths.len()
331    }
332
333    /// Returns `true` when the selection has no matched nodes.
334    pub fn is_empty(&self) -> bool {
335        self.paths.is_empty()
336    }
337}
338
339impl IntoIterator for SelectorList {
340    type Item = SelectorNode;
341    type IntoIter = std::vec::IntoIter<SelectorNode>;
342
343    fn into_iter(self) -> Self::IntoIter {
344        self.paths
345            .into_iter()
346            .map(|path| {
347                SelectorNode::new(
348                    self.document_html.clone(),
349                    self.document_hash,
350                    path,
351                    self.extraction.clone(),
352                )
353            })
354            .collect::<Vec<_>>()
355            .into_iter()
356    }
357}
358
359/// Returns a compiled selector from the cache, compiling it on first use.
360pub fn get_cached_selector(selector_str: &str) -> Option<Selector> {
361    {
362        let cache = SELECTOR_CACHE.read();
363        if let Some(cached) = cache.get(selector_str) {
364            return Some(cached.clone());
365        }
366    }
367
368    match Selector::parse(selector_str) {
369        Ok(selector) => {
370            {
371                let mut cache = SELECTOR_CACHE.write();
372                if let Some(cached) = cache.get(selector_str) {
373                    return Some(cached.clone());
374                }
375                cache.insert(selector_str.to_string(), selector.clone());
376            }
377            Some(selector)
378        }
379        Err(_) => None,
380    }
381}
382
383pub(crate) fn get_cached_compiled_selector(query: &str) -> Result<CompiledSelector, SpiderError> {
384    {
385        let cache = COMPILED_SELECTOR_CACHE.read();
386        if let Some(cached) = cache.get(query) {
387            return Ok(cached.clone());
388        }
389    }
390
391    let compiled = parse_compiled_selector(query)?;
392
393    {
394        let mut cache = COMPILED_SELECTOR_CACHE.write();
395        if let Some(cached) = cache.get(query) {
396            return Ok(cached.clone());
397        }
398        cache.insert(query.to_string(), compiled.clone());
399    }
400
401    Ok(compiled)
402}
403
404/// Pre-warms the selector cache with a small set of common selectors.
405pub fn prewarm_cache() {
406    let common_selectors = vec![
407        "a[href]",
408        "link[href]",
409        "script[src]",
410        "img[src]",
411        "audio[src]",
412        "video[src]",
413        "source[src]",
414        "form[action]",
415        "iframe[src]",
416        "frame[src]",
417        "embed[src]",
418        "object[data]",
419    ];
420
421    for selector_str in common_selectors {
422        get_cached_selector(selector_str);
423        let _ = get_cached_compiled_selector(selector_str);
424    }
425}
426
427fn parse_compiled_selector(query: &str) -> Result<CompiledSelector, SpiderError> {
428    let query = query.trim();
429    if query.is_empty() {
430        return Err(SpiderError::HtmlParseError(
431            "selector query cannot be empty".to_string(),
432        ));
433    }
434
435    let (selector_str, extraction) = parse_selector_parts(query)?;
436    let selector =
437        Selector::parse(selector_str).map_err(|e| SpiderError::HtmlParseError(e.to_string()))?;
438
439    Ok(CompiledSelector {
440        selector,
441        extraction,
442    })
443}
444
445fn parse_selector_parts(query: &str) -> Result<(&str, ExtractionKind), SpiderError> {
446    if let Some(selector) = query.strip_suffix("::text") {
447        let selector = selector.trim_end();
448        if selector.is_empty() {
449            return Err(SpiderError::HtmlParseError(
450                "selector cannot be empty before ::text".to_string(),
451            ));
452        }
453        return Ok((selector, ExtractionKind::Text));
454    }
455
456    if let Some(start) = query.rfind("::attr(")
457        && query.ends_with(')')
458    {
459        let selector = query[..start].trim_end();
460        let attr = query[start + "::attr(".len()..query.len() - 1].trim();
461        if selector.is_empty() {
462            return Err(SpiderError::HtmlParseError(
463                "selector cannot be empty before ::attr(...)".to_string(),
464            ));
465        }
466        if attr.is_empty() {
467            return Err(SpiderError::HtmlParseError(
468                "attribute name cannot be empty in ::attr(...)".to_string(),
469            ));
470        }
471
472        return Ok((selector, ExtractionKind::Attr(attr.to_string())));
473    }
474
475    Ok((query, ExtractionKind::Element))
476}
477
478fn with_document<T>(document_hash: u64, document_html: &Arc<str>, f: impl FnOnce(&Html) -> T) -> T {
479    DOCUMENT_CACHE.with(|cache| {
480        let mut cache = cache.borrow_mut();
481        let parsed = match cache.get(&document_hash) {
482            Some((cached_html, parsed)) if cached_html.as_ref() == document_html.as_ref() => {
483                parsed.clone()
484            }
485            _ => {
486                let parsed = Arc::new(Html::parse_document(document_html.as_ref()));
487                cache.insert(document_hash, (document_html.clone(), parsed.clone()));
488                parsed
489            }
490        };
491        drop(cache);
492        f(parsed.as_ref())
493    })
494}
495
/// Looks `node_id` up in the document tree and returns it as an element.
///
/// Returns `None` when the lookup fails or when `ElementRef::wrap` rejects the
/// node (presumably because it is not an element node — confirm against the
/// scraper documentation).
fn element_ref_by_id(document: &Html, node_id: NodeId) -> Option<ElementRef<'_>> {
    document.tree.get(node_id).and_then(ElementRef::wrap)
}
499
500fn element_ref_by_path<'a>(document: &'a Html, path: &[usize]) -> Option<ElementRef<'a>> {
501    let mut current = document.tree.root().id();
502
503    for child_index in path {
504        current = nth_child(document.tree.get(current)?.children(), *child_index)?.id();
505    }
506
507    element_ref_by_id(document, current)
508}
509
510fn node_path(document: &Html, node_id: NodeId) -> Arc<[usize]> {
511    let mut path = Vec::new();
512    let mut current = node_id;
513
514    while let Some(node) = document.tree.get(current) {
515        let Some(parent) = node.parent() else {
516            break;
517        };
518        let parent_id = parent.id();
519
520        let mut child_index = 0usize;
521        for child in parent.children() {
522            if child.id() == current {
523                break;
524            }
525            child_index += 1;
526        }
527
528        path.push(child_index);
529        current = parent_id;
530    }
531
532    path.reverse();
533    Arc::from(path)
534}
535
/// Returns the child at zero-based position `child_index` of the given
/// children iterator, or `None` when there are not that many children.
fn nth_child<'a>(
    mut children: Children<'a, scraper::node::Node>,
    child_index: usize,
) -> Option<ego_tree::NodeRef<'a, scraper::node::Node>> {
    // The iterator is taken by value, so advancing it here does not affect the caller.
    children.nth(child_index)
}
542
543fn extract_element_value(element: ElementRef<'_>, extraction: &ExtractionKind) -> Option<String> {
544    match extraction {
545        ExtractionKind::Element => Some(element.html()),
546        ExtractionKind::Text => Some(element.text().collect::<String>()),
547        ExtractionKind::Attr(attr) => element.attr(attr).map(ToOwned::to_owned),
548    }
549}