1use crate::error::SpiderError;
7use ego_tree::NodeId;
8use ego_tree::iter::Children;
9use once_cell::sync::Lazy;
10use parking_lot::RwLock;
11use scraper::{ElementRef, Html, Selector};
12use std::cell::RefCell;
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15
/// Process-wide cache of parsed plain CSS selectors, keyed by the exact string passed to
/// `get_cached_selector`. Only successful parses are stored.
static SELECTOR_CACHE: Lazy<RwLock<HashMap<String, Selector>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
/// Process-wide cache of compiled selector queries (CSS selector plus `::text` /
/// `::attr(...)` extraction mode), keyed by the exact string passed to
/// `get_cached_compiled_selector`. Only successful parses are stored.
static COMPILED_SELECTOR_CACHE: Lazy<RwLock<HashMap<String, CompiledSelector>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

thread_local! {
    // Per-thread cache of parsed documents, keyed by the caller-supplied document hash and
    // populated by `with_document`. The source HTML is stored next to the parse so that a
    // hash collision can be detected by comparing contents before reusing the parse.
    static DOCUMENT_CACHE: RefCell<HashMap<u64, (Arc<str>, Arc<Html>)>> = RefCell::new(HashMap::new());
}
25
/// What a selector query yields for each matched element.
///
/// Chosen from the query's suffix by `parse_selector_parts`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum ExtractionKind {
    /// No suffix: the element serialized as HTML (`ElementRef::html`).
    Element,
    /// `::text` suffix: all text nodes under the element, concatenated.
    Text,
    /// `::attr(name)` suffix: the value of attribute `name`, if the element has it.
    Attr(String),
}
32
/// A parsed selector query: the CSS selector plus what to extract from each match.
#[derive(Debug, Clone)]
pub(crate) struct CompiledSelector {
    /// The CSS part of the query, with any `::text` / `::attr(...)` suffix removed.
    selector: Selector,
    /// What to extract from each element that `selector` matches.
    extraction: ExtractionKind,
}
38
39impl CompiledSelector {
40 pub(crate) fn selector(&self) -> &Selector {
41 &self.selector
42 }
43
44 pub(crate) fn extraction(&self) -> &ExtractionKind {
45 &self.extraction
46 }
47}
48
/// A handle to a single matched element in an HTML document.
///
/// Instead of borrowing a parsed tree, the node stores the document source and the
/// element's position; every accessor re-resolves the element through the per-thread
/// parsed-document cache (`with_document`).
#[derive(Debug, Clone)]
pub struct SelectorNode {
    /// Full HTML source of the document this node belongs to.
    document_html: Arc<str>,
    /// Key into the per-thread parsed-document cache for `document_html`.
    document_hash: u64,
    /// Child indices from the tree root down to this element (see `node_path`).
    path: Arc<[usize]>,
    /// What `get` / `get_all` return for this node.
    extraction: ExtractionKind,
}
57
/// An ordered list of matched elements in one HTML document, produced by a selector query.
///
/// Like `SelectorNode`, it stores the document source plus one root-relative path per
/// matched element, and re-resolves elements through the per-thread parse cache on demand.
#[derive(Debug, Clone)]
pub struct SelectorList {
    /// Full HTML source of the document the matches belong to.
    document_html: Arc<str>,
    /// Key into the per-thread parsed-document cache for `document_html`.
    document_hash: u64,
    /// One child-index path per matched element, in match order.
    paths: Vec<Arc<[usize]>>,
    /// What `get` / `get_all` extract from each matched element.
    extraction: ExtractionKind,
}
66
67fn assert_selector_types_are_send_sync() {
68 fn assert_traits<T: Send + Sync>() {}
69
70 assert_traits::<SelectorNode>();
71 assert_traits::<SelectorList>();
72}
73
74const _: fn() = assert_selector_types_are_send_sync;
75
76impl SelectorNode {
77 pub(crate) fn new(
78 document_html: Arc<str>,
79 document_hash: u64,
80 path: Arc<[usize]>,
81 extraction: ExtractionKind,
82 ) -> Self {
83 Self {
84 document_html,
85 document_hash,
86 path,
87 extraction,
88 }
89 }
90
91 pub fn css(&self, query: &str) -> Result<SelectorList, SpiderError> {
98 if self.extraction != ExtractionKind::Element {
99 return Err(SpiderError::HtmlParseError(
100 "css() can only be chained from element selections".to_string(),
101 ));
102 }
103
104 let compiled = get_cached_compiled_selector(query)?;
105 with_document(
106 self.document_hash,
107 &self.document_html,
108 |document| -> Result<SelectorList, SpiderError> {
109 let Some(scope) = self.element_ref(document) else {
110 return Ok(SelectorList::empty(
111 self.document_html.clone(),
112 self.document_hash,
113 compiled.extraction().clone(),
114 ));
115 };
116
117 let paths = scope
118 .select(compiled.selector())
119 .map(|element| node_path(document, element.id()))
120 .collect();
121
122 Ok(SelectorList::new(
123 self.document_html.clone(),
124 self.document_hash,
125 paths,
126 compiled.extraction().clone(),
127 ))
128 },
129 )
130 }
131
132 pub fn get(&self) -> Option<String> {
134 with_document(self.document_hash, &self.document_html, |document| {
135 self.element_ref(document)
136 .and_then(|element| extract_element_value(element, &self.extraction))
137 })
138 }
139
140 pub fn get_all(&self) -> Vec<String> {
142 self.get().into_iter().collect()
143 }
144
145 pub fn attrib(&self, name: &str) -> Option<String> {
147 with_document(self.document_hash, &self.document_html, |document| {
148 self.element_ref(document)
149 .and_then(|element| element.attr(name).map(ToOwned::to_owned))
150 })
151 }
152
153 pub fn text_content(&self) -> Option<String> {
155 with_document(self.document_hash, &self.document_html, |document| {
156 self.element_ref(document)
157 .map(|element| element.text().collect::<String>())
158 })
159 }
160
161 pub fn has_css(&self, query: &str) -> Result<bool, SpiderError> {
168 Ok(!self.css(query)?.is_empty())
169 }
170
171 pub fn has_ancestor(&self, query: &str) -> Result<bool, SpiderError> {
178 let selector =
179 Selector::parse(query).map_err(|e| SpiderError::HtmlParseError(e.to_string()))?;
180 with_document(
181 self.document_hash,
182 &self.document_html,
183 |document| -> Result<bool, SpiderError> {
184 let Some(element) = self.element_ref(document) else {
185 return Ok(false);
186 };
187
188 Ok(element
189 .ancestors()
190 .filter_map(ElementRef::wrap)
191 .any(|ancestor| selector.matches(&ancestor)))
192 },
193 )
194 }
195
196 fn element_ref<'a>(&self, document: &'a Html) -> Option<ElementRef<'a>> {
197 element_ref_by_path(document, &self.path)
198 }
199}
200
201impl SelectorList {
202 pub(crate) fn new(
203 document_html: Arc<str>,
204 document_hash: u64,
205 paths: Vec<Arc<[usize]>>,
206 extraction: ExtractionKind,
207 ) -> Self {
208 Self {
209 document_html,
210 document_hash,
211 paths,
212 extraction,
213 }
214 }
215
216 pub(crate) fn from_document_query(
217 document_html: Arc<str>,
218 document_hash: u64,
219 query: &str,
220 ) -> Result<Self, SpiderError> {
221 let compiled = get_cached_compiled_selector(query)?;
222 with_document(
223 document_hash,
224 &document_html,
225 |document| -> Result<Self, SpiderError> {
226 let paths = document
227 .select(compiled.selector())
228 .map(|element| node_path(document, element.id()))
229 .collect();
230
231 Ok(Self::new(
232 document_html.clone(),
233 document_hash,
234 paths,
235 compiled.extraction().clone(),
236 ))
237 },
238 )
239 }
240
241 pub(crate) fn empty(
242 document_html: Arc<str>,
243 document_hash: u64,
244 extraction: ExtractionKind,
245 ) -> Self {
246 Self::new(document_html, document_hash, Vec::new(), extraction)
247 }
248
249 pub fn css(&self, query: &str) -> Result<Self, SpiderError> {
256 if self.extraction != ExtractionKind::Element {
257 return Err(SpiderError::HtmlParseError(
258 "css() can only be chained from element selections".to_string(),
259 ));
260 }
261
262 let compiled = get_cached_compiled_selector(query)?;
263 let mut seen = HashSet::new();
264 with_document(
265 self.document_hash,
266 &self.document_html,
267 |document| -> Result<Self, SpiderError> {
268 let mut paths = Vec::new();
269
270 for path in &self.paths {
271 let Some(scope) = element_ref_by_path(document, path) else {
272 continue;
273 };
274
275 for element in scope.select(compiled.selector()) {
276 let path = node_path(document, element.id());
277 if seen.insert(path.clone()) {
278 paths.push(path);
279 }
280 }
281 }
282
283 Ok(Self::new(
284 self.document_html.clone(),
285 self.document_hash,
286 paths,
287 compiled.extraction().clone(),
288 ))
289 },
290 )
291 }
292
293 pub fn get(&self) -> Option<String> {
295 self.first().and_then(|node| node.get())
296 }
297
298 pub fn get_all(&self) -> Vec<String> {
300 with_document(self.document_hash, &self.document_html, |document| {
301 self.paths
302 .iter()
303 .filter_map(|path| {
304 element_ref_by_path(document, path)
305 .and_then(|element| extract_element_value(element, &self.extraction))
306 })
307 .collect()
308 })
309 }
310
311 pub fn attrib(&self, name: &str) -> Option<String> {
313 self.first().and_then(|node| node.attrib(name))
314 }
315
316 pub fn first(&self) -> Option<SelectorNode> {
318 self.paths.first().cloned().map(|path| {
319 SelectorNode::new(
320 self.document_html.clone(),
321 self.document_hash,
322 path,
323 self.extraction.clone(),
324 )
325 })
326 }
327
328 pub fn len(&self) -> usize {
330 self.paths.len()
331 }
332
333 pub fn is_empty(&self) -> bool {
335 self.paths.is_empty()
336 }
337}
338
339impl IntoIterator for SelectorList {
340 type Item = SelectorNode;
341 type IntoIter = std::vec::IntoIter<SelectorNode>;
342
343 fn into_iter(self) -> Self::IntoIter {
344 self.paths
345 .into_iter()
346 .map(|path| {
347 SelectorNode::new(
348 self.document_html.clone(),
349 self.document_hash,
350 path,
351 self.extraction.clone(),
352 )
353 })
354 .collect::<Vec<_>>()
355 .into_iter()
356 }
357}
358
359pub fn get_cached_selector(selector_str: &str) -> Option<Selector> {
361 {
362 let cache = SELECTOR_CACHE.read();
363 if let Some(cached) = cache.get(selector_str) {
364 return Some(cached.clone());
365 }
366 }
367
368 match Selector::parse(selector_str) {
369 Ok(selector) => {
370 {
371 let mut cache = SELECTOR_CACHE.write();
372 if let Some(cached) = cache.get(selector_str) {
373 return Some(cached.clone());
374 }
375 cache.insert(selector_str.to_string(), selector.clone());
376 }
377 Some(selector)
378 }
379 Err(_) => None,
380 }
381}
382
383pub(crate) fn get_cached_compiled_selector(query: &str) -> Result<CompiledSelector, SpiderError> {
384 {
385 let cache = COMPILED_SELECTOR_CACHE.read();
386 if let Some(cached) = cache.get(query) {
387 return Ok(cached.clone());
388 }
389 }
390
391 let compiled = parse_compiled_selector(query)?;
392
393 {
394 let mut cache = COMPILED_SELECTOR_CACHE.write();
395 if let Some(cached) = cache.get(query) {
396 return Ok(cached.clone());
397 }
398 cache.insert(query.to_string(), compiled.clone());
399 }
400
401 Ok(compiled)
402}
403
404pub fn prewarm_cache() {
406 let common_selectors = vec![
407 "a[href]",
408 "link[href]",
409 "script[src]",
410 "img[src]",
411 "audio[src]",
412 "video[src]",
413 "source[src]",
414 "form[action]",
415 "iframe[src]",
416 "frame[src]",
417 "embed[src]",
418 "object[data]",
419 ];
420
421 for selector_str in common_selectors {
422 get_cached_selector(selector_str);
423 let _ = get_cached_compiled_selector(selector_str);
424 }
425}
426
427fn parse_compiled_selector(query: &str) -> Result<CompiledSelector, SpiderError> {
428 let query = query.trim();
429 if query.is_empty() {
430 return Err(SpiderError::HtmlParseError(
431 "selector query cannot be empty".to_string(),
432 ));
433 }
434
435 let (selector_str, extraction) = parse_selector_parts(query)?;
436 let selector =
437 Selector::parse(selector_str).map_err(|e| SpiderError::HtmlParseError(e.to_string()))?;
438
439 Ok(CompiledSelector {
440 selector,
441 extraction,
442 })
443}
444
445fn parse_selector_parts(query: &str) -> Result<(&str, ExtractionKind), SpiderError> {
446 if let Some(selector) = query.strip_suffix("::text") {
447 let selector = selector.trim_end();
448 if selector.is_empty() {
449 return Err(SpiderError::HtmlParseError(
450 "selector cannot be empty before ::text".to_string(),
451 ));
452 }
453 return Ok((selector, ExtractionKind::Text));
454 }
455
456 if let Some(start) = query.rfind("::attr(")
457 && query.ends_with(')')
458 {
459 let selector = query[..start].trim_end();
460 let attr = query[start + "::attr(".len()..query.len() - 1].trim();
461 if selector.is_empty() {
462 return Err(SpiderError::HtmlParseError(
463 "selector cannot be empty before ::attr(...)".to_string(),
464 ));
465 }
466 if attr.is_empty() {
467 return Err(SpiderError::HtmlParseError(
468 "attribute name cannot be empty in ::attr(...)".to_string(),
469 ));
470 }
471
472 return Ok((selector, ExtractionKind::Attr(attr.to_string())));
473 }
474
475 Ok((query, ExtractionKind::Element))
476}
477
478fn with_document<T>(document_hash: u64, document_html: &Arc<str>, f: impl FnOnce(&Html) -> T) -> T {
479 DOCUMENT_CACHE.with(|cache| {
480 let mut cache = cache.borrow_mut();
481 let parsed = match cache.get(&document_hash) {
482 Some((cached_html, parsed)) if cached_html.as_ref() == document_html.as_ref() => {
483 parsed.clone()
484 }
485 _ => {
486 let parsed = Arc::new(Html::parse_document(document_html.as_ref()));
487 cache.insert(document_hash, (document_html.clone(), parsed.clone()));
488 parsed
489 }
490 };
491 drop(cache);
492 f(parsed.as_ref())
493 })
494}
495
496fn element_ref_by_id(document: &Html, node_id: NodeId) -> Option<ElementRef<'_>> {
497 document.tree.get(node_id).and_then(ElementRef::wrap)
498}
499
500fn element_ref_by_path<'a>(document: &'a Html, path: &[usize]) -> Option<ElementRef<'a>> {
501 let mut current = document.tree.root().id();
502
503 for child_index in path {
504 current = nth_child(document.tree.get(current)?.children(), *child_index)?.id();
505 }
506
507 element_ref_by_id(document, current)
508}
509
510fn node_path(document: &Html, node_id: NodeId) -> Arc<[usize]> {
511 let mut path = Vec::new();
512 let mut current = node_id;
513
514 while let Some(node) = document.tree.get(current) {
515 let Some(parent) = node.parent() else {
516 break;
517 };
518 let parent_id = parent.id();
519
520 let mut child_index = 0usize;
521 for child in parent.children() {
522 if child.id() == current {
523 break;
524 }
525 child_index += 1;
526 }
527
528 path.push(child_index);
529 current = parent_id;
530 }
531
532 path.reverse();
533 Arc::from(path)
534}
535
536fn nth_child<'a>(
537 mut children: Children<'a, scraper::node::Node>,
538 child_index: usize,
539) -> Option<ego_tree::NodeRef<'a, scraper::node::Node>> {
540 children.nth(child_index)
541}
542
543fn extract_element_value(element: ElementRef<'_>, extraction: &ExtractionKind) -> Option<String> {
544 match extraction {
545 ExtractionKind::Element => Some(element.html()),
546 ExtractionKind::Text => Some(element.text().collect::<String>()),
547 ExtractionKind::Attr(attr) => element.attr(attr).map(ToOwned::to_owned),
548 }
549}