diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py
index 2c8297d8..01623efa 100644
--- a/src/khoj/processor/tools/online_search.py
+++ b/src/khoj/processor/tools/online_search.py
@@ -158,9 +158,7 @@ async def search_online(
             async for event in send_status_func(f"**Browsing**: {webpage_links_str}"):
                 yield {ChatEvent.STATUS: event}
         tasks = [
-            read_webpage_and_extract_content(
-                data["queries"], link, data.get("content"), user=user, agent=agent, tracer=tracer
-            )
+            extract_from_webpage(link, data["queries"], data.get("content"), user=user, agent=agent, tracer=tracer)
             for link, data in webpages.items()
         ]
         results = await asyncio.gather(*tasks)
@@ -476,7 +474,7 @@ async def read_webpages_content(
     webpage_links_str = "\n- " + "\n- ".join(list(urls))
     async for event in send_status_func(f"**Browsing**: {webpage_links_str}"):
         yield {ChatEvent.STATUS: event}
-    tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent, tracer=tracer) for url in urls]
+    tasks = [extract_from_webpage(url, {query}, user=user, agent=agent, tracer=tracer) for url in urls]
     results = await asyncio.gather(*tasks)
 
     response: Dict[str, Dict] = defaultdict(dict)
@@ -486,49 +484,37 @@ async def read_webpages_content(
     yield response
 
 
-async def read_webpage(url, scraper_type=None, api_key=None, api_url=None) -> Tuple[str | None, str | None]:
+async def scrape_webpage(url, scraper_type=None, api_key=None, api_url=None) -> str | None:
     if scraper_type == WebScraper.WebScraperType.FIRECRAWL:
-        return await read_webpage_with_firecrawl(url, api_key, api_url), None
+        return await read_webpage_with_firecrawl(url, api_key, api_url)
     elif scraper_type == WebScraper.WebScraperType.OLOSTEP:
-        return await read_webpage_with_olostep(url, api_key, api_url), None
+        return await read_webpage_with_olostep(url, api_key, api_url)
    elif scraper_type == WebScraper.WebScraperType.EXA:
-        return await read_webpage_with_exa(url, api_key, api_url), None
+        return await read_webpage_with_exa(url, api_key, api_url)
     else:
-        return await read_webpage_at_url(url), None
+        return await read_webpage_at_url(url)
 
 
-async def read_webpage_and_extract_content(
-    subqueries: set[str],
-    url: str,
-    content: str = None,
-    user: KhojUser = None,
-    agent: Agent = None,
-    tracer: dict = {},
-) -> Tuple[set[str], str, Union[None, str]]:
+async def scrape_webpage_with_fallback(url: str) -> Optional[str]:
+    """
+    Scrape a webpage using enabled web scrapers with fallback logic.
+    Tries all enabled scrapers in order until one succeeds.
+    Returns the content if successful, None otherwise.
+    """
     # Select the web scrapers to use for reading the web page
     web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
     # Only use the direct web scraper for internal URLs
     if is_internal_url(url):
         web_scrapers = [scraper for scraper in web_scrapers if scraper.type == WebScraper.WebScraperType.DIRECT]
 
-    # Fallback through enabled web scrapers until we successfully read the web page
-    extracted_info = None
+    # Read the web page
+    # Fallback through enabled web scrapers until one succeeds
+    content = None
     for scraper in web_scrapers:
         try:
-            # Read the web page
-            if is_none_or_empty(content):
-                with timer(f"Reading web page with {scraper.type} at '{url}' took", logger, log_level=logging.INFO):
-                    content, extracted_info = await read_webpage(url, scraper.type, scraper.api_key, scraper.api_url)
-
-            # Extract relevant information from the web page
-            if is_none_or_empty(extracted_info):
-                with timer(f"Extracting relevant information from web page at '{url}' took", logger):
-                    extracted_info = await extract_relevant_info(
-                        subqueries, content, user=user, agent=agent, tracer=tracer
-                    )
-
-            # If we successfully extracted information, break the loop
-            if not is_none_or_empty(extracted_info):
+            with timer(f"Reading web page with {scraper.type} at '{url}' took", logger, log_level=logging.INFO):
+                content = await scrape_webpage(url, scraper.type, scraper.api_key, scraper.api_url)
+            if not is_none_or_empty(content):
                 break
         except Exception as e:
             logger.warning(f"Failed to read web page with {scraper.type} at '{url}' with {e}")
@@ -536,6 +522,27 @@ async def read_webpage_and_extract_content(
         if scraper.name == web_scrapers[-1].name:
             logger.error(f"All web scrapers failed for '{url}'")
 
+    return content
+
+
+async def extract_from_webpage(
+    url: str,
+    subqueries: set[str] = None,
+    content: str = None,
+    user: KhojUser = None,
+    agent: Agent = None,
+    tracer: dict = {},
+) -> Tuple[set[str], str, Union[None, str]]:
+    # Read the web page, unless cached content was already passed in
+    if is_none_or_empty(content):
+        content = await scrape_webpage_with_fallback(url)
+
+    # Extract relevant information from the web page
+    extracted_info = None
+    if not is_none_or_empty(content):
+        with timer(f"Extracting relevant information from web page at '{url}' took", logger):
+            extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent, tracer=tracer)
+    return subqueries, url, extracted_info