diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index 182ce701..28946557 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -1,6 +1,7 @@
 import json
 import logging
 import math
+import os
 import random
 import re
 import secrets
@@ -10,7 +11,6 @@ from enum import Enum
 from typing import Callable, Iterable, List, Optional, Type
 
 import cron_descriptor
-import django
 from apscheduler.job import Job
 from asgiref.sync import sync_to_async
 from django.contrib.sessions.backends.db import SessionStore
@@ -52,6 +52,7 @@ from khoj.database.models import (
     UserTextToImageModelConfig,
     UserVoiceModelConfig,
     VoiceModelOption,
+    WebScraper,
 )
 from khoj.processor.conversation import prompts
 from khoj.search_filter.date_filter import DateFilter
@@ -59,7 +60,12 @@ from khoj.search_filter.file_filter import FileFilter
 from khoj.search_filter.word_filter import WordFilter
 from khoj.utils import state
 from khoj.utils.config import OfflineChatProcessorModel
-from khoj.utils.helpers import generate_random_name, is_none_or_empty, timer
+from khoj.utils.helpers import (
+    generate_random_name,
+    in_debug_mode,
+    is_none_or_empty,
+    timer,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -1031,6 +1037,70 @@ class ConversationAdapters:
             return server_chat_settings.chat_advanced
         return await ConversationAdapters.aget_default_conversation_config(user)
 
+    @staticmethod
+    async def aget_server_webscraper():
+        server_chat_settings = await ServerChatSettings.objects.filter().prefetch_related("web_scraper").afirst()
+        if server_chat_settings is not None and server_chat_settings.web_scraper is not None:
+            return server_chat_settings.web_scraper
+        return None
+
+    @staticmethod
+    async def aget_enabled_webscrapers() -> list[WebScraper]:
+        enabled_scrapers: list[WebScraper] = []
+        server_webscraper = await ConversationAdapters.aget_server_webscraper()
+        if server_webscraper:
+            # Only use the web scraper set in the server chat settings
+            enabled_scrapers = [server_webscraper]
+        if not enabled_scrapers:
+            # Use the enabled web scrapers, ordered by priority, until web page content is retrieved
+            enabled_scrapers = [scraper async for scraper in WebScraper.objects.all().order_by("priority").aiterator()]
+        if not enabled_scrapers:
+            # Use scrapers enabled via environment variables
+            if os.getenv("FIRECRAWL_API_KEY"):
+                api_url = os.getenv("FIRECRAWL_API_URL", "https://api.firecrawl.dev")
+                enabled_scrapers.append(
+                    WebScraper(
+                        type=WebScraper.WebScraperType.FIRECRAWL,
+                        name=WebScraper.WebScraperType.FIRECRAWL.capitalize(),
+                        api_key=os.getenv("FIRECRAWL_API_KEY"),
+                        api_url=api_url,
+                    )
+                )
+            if os.getenv("OLOSTEP_API_KEY"):
+                api_url = os.getenv("OLOSTEP_API_URL", "https://agent.olostep.com/olostep-p2p-incomingAPI")
+                enabled_scrapers.append(
+                    WebScraper(
+                        type=WebScraper.WebScraperType.OLOSTEP,
+                        name=WebScraper.WebScraperType.OLOSTEP.capitalize(),
+                        api_key=os.getenv("OLOSTEP_API_KEY"),
+                        api_url=api_url,
+                    )
+                )
+            # Jina is the default fallback scraper to use as it does not require an API key
+            api_url = os.getenv("JINA_READER_API_URL", "https://r.jina.ai/")
+            enabled_scrapers.append(
+                WebScraper(
+                    type=WebScraper.WebScraperType.JINA,
+                    name=WebScraper.WebScraperType.JINA.capitalize(),
+                    api_key=os.getenv("JINA_API_KEY"),
+                    api_url=api_url,
+                )
+            )
+
+            # Only enable the direct web page scraper by default in self-hosted single user setups.
+            # Useful for reading webpages on your intranet.
+            if state.anonymous_mode or in_debug_mode():
+                enabled_scrapers.append(
+                    WebScraper(
+                        type=WebScraper.WebScraperType.DIRECT,
+                        name=WebScraper.WebScraperType.DIRECT.capitalize(),
+                        api_key=None,
+                        api_url=None,
+                    )
+                )
+
+        return enabled_scrapers
+
     @staticmethod
     def create_conversation_from_public_conversation(
         user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication
diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py
index 3e192952..5aa9204b 100644
--- a/src/khoj/database/admin.py
+++ b/src/khoj/database/admin.py
@@ -31,6 +31,7 @@ from khoj.database.models import (
     UserSearchModelConfig,
     UserVoiceModelConfig,
     VoiceModelOption,
+    WebScraper,
 )
 from khoj.utils.helpers import ImageIntentType
 
@@ -198,9 +199,24 @@ class ServerChatSettingsAdmin(admin.ModelAdmin):
     list_display = (
         "chat_default",
         "chat_advanced",
+        "web_scraper",
     )
 
 
+@admin.register(WebScraper)
+class WebScraperAdmin(admin.ModelAdmin):
+    list_display = (
+        "priority",
+        "name",
+        "type",
+        "api_key",
+        "api_url",
+        "created_at",
+    )
+    search_fields = ("name", "api_key", "api_url", "type")
+    ordering = ("priority",)
+
+
 @admin.register(Conversation)
 class ConversationAdmin(admin.ModelAdmin):
     list_display = (
diff --git a/src/khoj/database/migrations/0069_webscraper_serverchatsettings_web_scraper.py b/src/khoj/database/migrations/0069_webscraper_serverchatsettings_web_scraper.py
new file mode 100644
index 00000000..3ea8ebe3
--- /dev/null
+++ b/src/khoj/database/migrations/0069_webscraper_serverchatsettings_web_scraper.py
@@ -0,0 +1,89 @@
+# Generated by Django 5.0.8 on 2024-10-18 00:41
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("database", "0068_alter_agent_output_modes"),
+    ]
+
+    operations = [
+        migrations.CreateModel(
+            name="WebScraper",
+            fields=[
+                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
+                ("created_at", models.DateTimeField(auto_now_add=True)),
+                ("updated_at", models.DateTimeField(auto_now=True)),
+                (
+                    "name",
+                    models.CharField(
+                        blank=True,
+                        default=None,
+                        help_text="Friendly name. If not set, it will be set to the type of the scraper.",
+                        max_length=200,
+                        null=True,
+                        unique=True,
+                    ),
+                ),
+                (
+                    "type",
+                    models.CharField(
+                        choices=[
+                            ("Firecrawl", "Firecrawl"),
+                            ("Olostep", "Olostep"),
+                            ("Jina", "Jina"),
+                            ("Direct", "Direct"),
+                        ],
+                        default="Jina",
+                        max_length=20,
+                    ),
+                ),
+                (
+                    "api_key",
+                    models.CharField(
+                        blank=True,
+                        default=None,
+                        help_text="API key of the web scraper. Only set if scraper service requires an API key. Default is set from env var.",
+                        max_length=200,
+                        null=True,
+                    ),
+                ),
+                (
+                    "api_url",
+                    models.URLField(
+                        blank=True,
+                        default=None,
+                        help_text="API URL of the web scraper. Only set if scraper service on non-default URL.",
+                        null=True,
+                    ),
+                ),
+                (
+                    "priority",
+                    models.IntegerField(
+                        blank=True,
+                        default=None,
+                        help_text="Priority of the web scraper. Lower numbers run first.",
+                        null=True,
+                        unique=True,
+                    ),
+                ),
+            ],
+            options={
+                "abstract": False,
+            },
+        ),
+        migrations.AddField(
+            model_name="serverchatsettings",
+            name="web_scraper",
+            field=models.ForeignKey(
+                blank=True,
+                default=None,
+                null=True,
+                on_delete=django.db.models.deletion.CASCADE,
+                related_name="web_scraper",
+                to="database.webscraper",
+            ),
+        ),
+    ]
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index ec4b61d1..2b2fde2d 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -1,3 +1,4 @@
+import os
 import re
 import uuid
 from random import choice
@@ -11,8 +12,6 @@ from django.dispatch import receiver
 from pgvector.django import VectorField
 from phonenumber_field.modelfields import PhoneNumberField
 
-from khoj.utils.helpers import ConversationCommand
-
 
 class BaseModel(models.Model):
     created_at = models.DateTimeField(auto_now_add=True)
@@ -244,6 +243,79 @@ class GithubRepoConfig(BaseModel):
     github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE, related_name="githubrepoconfig")
 
 
+class WebScraper(BaseModel):
+    class WebScraperType(models.TextChoices):
+        FIRECRAWL = "Firecrawl"
+        OLOSTEP = "Olostep"
+        JINA = "Jina"
+        DIRECT = "Direct"
+
+    name = models.CharField(
+        max_length=200,
+        default=None,
+        null=True,
+        blank=True,
+        unique=True,
+        help_text="Friendly name. If not set, it will be set to the type of the scraper.",
+    )
+    type = models.CharField(max_length=20, choices=WebScraperType.choices, default=WebScraperType.JINA)
+    api_key = models.CharField(
+        max_length=200,
+        default=None,
+        null=True,
+        blank=True,
+        help_text="API key of the web scraper. Only set if scraper service requires an API key. Default is set from env var.",
+    )
+    api_url = models.URLField(
+        max_length=200,
+        default=None,
+        null=True,
+        blank=True,
+        help_text="API URL of the web scraper. Only set if scraper service on non-default URL.",
+    )
+    priority = models.IntegerField(
+        default=None,
+        null=True,
+        blank=True,
+        unique=True,
+        help_text="Priority of the web scraper. Lower numbers run first.",
+    )
+
+    def clean(self):
+        error = {}
+        if self.name is None:
+            self.name = self.type.capitalize()
+        if self.api_url is None:
+            if self.type == self.WebScraperType.FIRECRAWL:
+                self.api_url = os.getenv("FIRECRAWL_API_URL", "https://api.firecrawl.dev")
+            elif self.type == self.WebScraperType.OLOSTEP:
+                self.api_url = os.getenv("OLOSTEP_API_URL", "https://agent.olostep.com/olostep-p2p-incomingAPI")
+            elif self.type == self.WebScraperType.JINA:
+                self.api_url = os.getenv("JINA_READER_API_URL", "https://r.jina.ai/")
+        if self.api_key is None:
+            if self.type == self.WebScraperType.FIRECRAWL:
+                self.api_key = os.getenv("FIRECRAWL_API_KEY")
+                if not self.api_key and self.api_url == "https://api.firecrawl.dev":
+                    error["api_key"] = "Set API key to use default Firecrawl. Get API key from https://firecrawl.dev."
+            elif self.type == self.WebScraperType.OLOSTEP:
+                self.api_key = os.getenv("OLOSTEP_API_KEY")
+                if self.api_key is None:
+                    error["api_key"] = "Set API key to use Olostep. Get API key from https://olostep.com/."
+            elif self.type == self.WebScraperType.JINA:
+                self.api_key = os.getenv("JINA_API_KEY")
+        if error:
+            raise ValidationError(error)
+
+    def save(self, *args, **kwargs):
+        self.clean()
+
+        if self.priority is None:
+            max_priority = WebScraper.objects.aggregate(models.Max("priority"))["priority__max"]
+            self.priority = max_priority + 1 if max_priority else 1
+
+        super().save(*args, **kwargs)
+
+
 class ServerChatSettings(BaseModel):
     chat_default = models.ForeignKey(
         ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
@@ -251,6 +323,9 @@ class ServerChatSettings(BaseModel):
     chat_advanced = models.ForeignKey(
         ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
     )
+    web_scraper = models.ForeignKey(
+        WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper"
+    )
 
 
 class LocalOrgConfig(BaseModel):
diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py
index 71af5b7d..a19d85fa 100644
--- a/src/khoj/processor/embeddings.py
+++ b/src/khoj/processor/embeddings.py
@@ -114,6 +114,7 @@ class CrossEncoderModel:
             payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
             headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
             response = requests.post(target_url, json=payload, headers=headers)
+            response.raise_for_status()
             return response.json()["scores"]
 
         cross_inp = [[query, hit.additional[key]] for hit in hits]
diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py
index f5cb3c12..70972eac 100644
--- a/src/khoj/processor/tools/online_search.py
+++ b/src/khoj/processor/tools/online_search.py
@@ -10,14 +10,22 @@ import aiohttp
 from bs4 import BeautifulSoup
 from markdownify import markdownify
 
-from khoj.database.models import Agent, KhojUser
+from khoj.database.adapters import ConversationAdapters
+from khoj.database.models import Agent, KhojUser, WebScraper
+from khoj.processor.conversation import prompts
 from khoj.routers.helpers import (
     ChatEvent,
     extract_relevant_info,
     generate_online_subqueries,
     infer_webpage_urls,
 )
-from khoj.utils.helpers import is_internet_connected, is_none_or_empty, timer
+from khoj.utils.helpers import (
+    is_env_var_true,
+    is_internal_url,
+    is_internet_connected,
+    is_none_or_empty,
+    timer,
+)
 from khoj.utils.rawconfig import LocationData
 
 logger = logging.getLogger(__name__)
@@ -25,12 +33,11 @@ logger = logging.getLogger(__name__)
 SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY")
 SERPER_DEV_URL = "https://google.serper.dev/search"
 
-JINA_READER_API_URL = "https://r.jina.ai/"
 JINA_SEARCH_API_URL = "https://s.jina.ai/"
 JINA_API_KEY = os.getenv("JINA_API_KEY")
 
-OLOSTEP_API_KEY = os.getenv("OLOSTEP_API_KEY")
-OLOSTEP_API_URL = "https://agent.olostep.com/olostep-p2p-incomingAPI"
+FIRECRAWL_USE_LLM_EXTRACT = is_env_var_true("FIRECRAWL_USE_LLM_EXTRACT")
+
 OLOSTEP_QUERY_PARAMS = {
     "timeout": 35,  # seconds
     "waitBeforeScraping": 1,  # seconds
@@ -83,33 +90,36 @@ async def search_online(
     search_results = await asyncio.gather(*search_tasks)
     response_dict = {subquery: search_result for subquery, search_result in search_results}
 
-    # Gather distinct web page data from organic results of each subquery without an instant answer.
+    # Gather distinct web pages from organic results for subqueries without an instant answer.
     # Content of web pages is directly available when Jina is used for search.
-    webpages = {
-        (organic.get("link"), subquery, organic.get("content"))
-        for subquery in response_dict
-        for organic in response_dict[subquery].get("organic", [])[:MAX_WEBPAGES_TO_READ]
-        if "answerBox" not in response_dict[subquery]
-    }
+    webpages: Dict[str, Dict] = {}
+    for subquery in response_dict:
+        if "answerBox" in response_dict[subquery]:
+            continue
+        for organic in response_dict[subquery].get("organic", [])[:MAX_WEBPAGES_TO_READ]:
+            link = organic.get("link")
+            if link in webpages:
+                webpages[link]["queries"].add(subquery)
+            else:
+                webpages[link] = {"queries": {subquery}, "content": organic.get("content")}
 
     # Read, extract relevant info from the retrieved web pages
     if webpages:
-        webpage_links = set([link for link, _, _ in webpages])
-        logger.info(f"Reading web pages at: {list(webpage_links)}")
+        logger.info(f"Reading web pages at: {webpages.keys()}")
         if send_status_func:
-            webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
+            webpage_links_str = "\n- " + "\n- ".join(webpages.keys())
             async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
                 yield {ChatEvent.STATUS: event}
         tasks = [
-            read_webpage_and_extract_content(subquery, link, content, user=user, agent=agent)
-            for link, subquery, content in webpages
+            read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent)
+            for link, data in webpages.items()
         ]
         results = await asyncio.gather(*tasks)
 
         # Collect extracted info from the retrieved web pages
-        for subquery, webpage_extract, url in results:
+        for subqueries, url, webpage_extract in results:
             if webpage_extract is not None:
-                response_dict[subquery]["webpages"] = {"link": url, "snippet": webpage_extract}
+                response_dict[subqueries.pop()]["webpages"] = {"link": url, "snippet": webpage_extract}
 
     yield response_dict
 
@@ -156,29 +166,66 @@ async def read_webpages(
     webpage_links_str = "\n- " + "\n- ".join(list(urls))
     async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
         yield {ChatEvent.STATUS: event}
-    tasks = [read_webpage_and_extract_content(query, url, user=user, agent=agent) for url in urls]
+    tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent) for url in urls]
     results = await asyncio.gather(*tasks)
 
     response: Dict[str, Dict] = defaultdict(dict)
     response[query]["webpages"] = [
-        {"query": q, "link": url, "snippet": web_extract} for q, web_extract, url in results if web_extract is not None
+        {"query": qs.pop(), "link": url, "snippet": extract} for qs, url, extract in results if extract is not None
     ]
     yield response
 
 
+async def read_webpage(
+    url, scraper_type=None, api_key=None, api_url=None, subqueries=None, agent=None
+) -> Tuple[str | None, str | None]:
+    if scraper_type == WebScraper.WebScraperType.FIRECRAWL and FIRECRAWL_USE_LLM_EXTRACT:
+        return None, await query_webpage_with_firecrawl(url, subqueries, api_key, api_url, agent)
+    elif scraper_type == WebScraper.WebScraperType.FIRECRAWL:
+        return await read_webpage_with_firecrawl(url, api_key, api_url), None
+    elif scraper_type == WebScraper.WebScraperType.OLOSTEP:
+        return await read_webpage_with_olostep(url, api_key, api_url), None
+    elif scraper_type == WebScraper.WebScraperType.JINA:
+        return await read_webpage_with_jina(url, api_key, api_url), None
+    else:
+        return await read_webpage_at_url(url), None
+
+
 async def read_webpage_and_extract_content(
-    subquery: str, url: str, content: str = None, user: KhojUser = None, agent: Agent = None
-) -> Tuple[str, Union[None, str], str]:
-    try:
-        if is_none_or_empty(content):
-            with timer(f"Reading web page at '{url}' took", logger):
-                content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
-        with timer(f"Extracting relevant information from web page at '{url}' took", logger):
-            extracted_info = await extract_relevant_info(subquery, content, user=user, agent=agent)
-        return subquery, extracted_info, url
-    except Exception as e:
-        logger.error(f"Failed to read web page at '{url}' with {e}")
-        return subquery, None, url
+    subqueries: set[str], url: str, content: str = None, user: KhojUser = None, agent: Agent = None
+) -> Tuple[set[str], str, Union[None, str]]:
+    # Select the web scrapers to use for reading the web page
+    web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
+    # Only use the direct web scraper for internal URLs
+    if is_internal_url(url):
+        web_scrapers = [scraper for scraper in web_scrapers if scraper.type == WebScraper.WebScraperType.DIRECT]
+
+    # Fall back through enabled web scrapers until we successfully read the web page
+    extracted_info = None
+    for scraper in web_scrapers:
+        try:
+            # Read the web page
+            if is_none_or_empty(content):
+                with timer(f"Reading web page with {scraper.type} at '{url}' took", logger, log_level=logging.INFO):
+                    content, extracted_info = await read_webpage(
+                        url, scraper.type, scraper.api_key, scraper.api_url, subqueries, agent
+                    )
+
+            # Extract relevant information from the web page
+            if is_none_or_empty(extracted_info):
+                with timer(f"Extracting relevant information from web page at '{url}' took", logger):
+                    extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent)
+
+            # If we successfully extracted information, break the loop
+            if not is_none_or_empty(extracted_info):
+                break
+        except Exception as e:
+            logger.warning(f"Failed to read web page with {scraper.type} at '{url}' with {e}")
+            # If this is the last web scraper in the list, log an error
+            if scraper.name == web_scrapers[-1].name:
+                logger.error(f"All web scrapers failed for '{url}'")
+
+    return subqueries, url, extracted_info
 
 
 async def read_webpage_at_url(web_url: str) -> str:
@@ -195,23 +242,23 @@ async def read_webpage_at_url(web_url: str) -> str:
     return markdownify(body)
 
 
-async def read_webpage_with_olostep(web_url: str) -> str:
-    headers = {"Authorization": f"Bearer {OLOSTEP_API_KEY}"}
+async def read_webpage_with_olostep(web_url: str, api_key: str, api_url: str) -> str:
+    headers = {"Authorization": f"Bearer {api_key}"}
     web_scraping_params: Dict[str, Union[str, int, bool]] = OLOSTEP_QUERY_PARAMS.copy()  # type: ignore
     web_scraping_params["url"] = web_url
 
     async with aiohttp.ClientSession() as session:
-        async with session.get(OLOSTEP_API_URL, params=web_scraping_params, headers=headers) as response:
+        async with session.get(api_url, params=web_scraping_params, headers=headers) as response:
             response.raise_for_status()
             response_json = await response.json()
             return response_json["markdown_content"]
 
 
-async def read_webpage_with_jina(web_url: str) -> str:
-    jina_reader_api_url = f"{JINA_READER_API_URL}/{web_url}"
+async def read_webpage_with_jina(web_url: str, api_key: str, api_url: str) -> str:
+    jina_reader_api_url = f"{api_url}/{web_url}"
     headers = {"Accept": "application/json", "X-Timeout": "30"}
-    if JINA_API_KEY:
-        headers["Authorization"] = f"Bearer {JINA_API_KEY}"
+    if api_key:
+        headers["Authorization"] = f"Bearer {api_key}"
 
     async with aiohttp.ClientSession() as session:
         async with session.get(jina_reader_api_url, headers=headers) as response:
@@ -220,6 +267,54 @@ async def read_webpage_with_jina(web_url: str) -> str:
             return response_json["data"]["content"]
 
 
+async def read_webpage_with_firecrawl(web_url: str, api_key: str, api_url: str) -> str:
+    firecrawl_api_url = f"{api_url}/v1/scrape"
+    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
+    params = {"url": web_url, "formats": ["markdown"], "excludeTags": ["script", ".ad"]}
+
+    async with aiohttp.ClientSession() as session:
+        async with session.post(firecrawl_api_url, json=params, headers=headers) as response:
+            response.raise_for_status()
+            response_json = await response.json()
+            return response_json["data"]["markdown"]
+
+
+async def query_webpage_with_firecrawl(
+    web_url: str, queries: set[str], api_key: str, api_url: str, agent: Agent = None
+) -> str:
+    firecrawl_api_url = f"{api_url}/v1/scrape"
+    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
+    schema = {
+        "type": "object",
+        "properties": {
+            "relevant_extract": {"type": "string"},
+        },
+        "required": [
+            "relevant_extract",
+        ],
+    }
+
+    personality_context = (
+        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
+    )
+    system_prompt = f"""
+{prompts.system_prompt_extract_relevant_information}
+
+{personality_context}
+User Query: {", ".join(queries)}
+
+Collate only relevant information from the website to answer the target query and in the provided JSON schema.
+""".strip()
+
+    params = {"url": web_url, "formats": ["extract"], "extract": {"systemPrompt": system_prompt, "schema": schema}}
+
+    async with aiohttp.ClientSession() as session:
+        async with session.post(firecrawl_api_url, json=params, headers=headers) as response:
+            response.raise_for_status()
+            response_json = await response.json()
+            return response_json["data"]["extract"]["relevant_extract"]
+
+
 async def search_with_jina(query: str, location: LocationData) -> Tuple[str, Dict[str, List[Dict]]]:
     encoded_query = urllib.parse.quote(query)
     jina_search_api_url = f"{JINA_SEARCH_API_URL}/{encoded_query}"
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index 228d081c..0d367029 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -3,7 +3,6 @@ import base64
 import json
 import logging
 import time
-import warnings
 from datetime import datetime
 from functools import partial
 from typing import Dict, Optional
@@ -574,7 +573,6 @@ async def chat(
         chat_metadata: dict = {}
         connection_alive = True
         user: KhojUser = request.user.object
-        subscribed: bool = has_required_scope(request, ["premium"])
         event_delimiter = "␃🔚␗"
         q = unquote(q)
         nonlocal conversation_id
@@ -641,7 +639,7 @@ async def chat(
            request=request,
            telemetry_type="api",
            api="chat",
-           client=request.user.client_app,
+           client=common.client,
            user_agent=request.headers.get("user-agent"),
            host=request.headers.get("host"),
            metadata=chat_metadata,
@@ -840,25 +838,33 @@ async def chat(
         # Gather Context
         ## Extract Document References
         compiled_references, inferred_queries, defiltered_query = [], [], None
-        async for result in extract_references_and_questions(
-            request,
-            meta_log,
-            q,
-            (n or 7),
-            d,
-            conversation_id,
-            conversation_commands,
-            location,
-            partial(send_event, ChatEvent.STATUS),
-            uploaded_image_url=uploaded_image_url,
-            agent=agent,
-        ):
-            if isinstance(result, dict) and ChatEvent.STATUS in result:
-                yield result[ChatEvent.STATUS]
-            else:
-                compiled_references.extend(result[0])
-                inferred_queries.extend(result[1])
-                defiltered_query = result[2]
+        try:
+            async for result in extract_references_and_questions(
+                request,
+                meta_log,
+                q,
+                (n or 7),
+                d,
+                conversation_id,
+                conversation_commands,
+                location,
+                partial(send_event, ChatEvent.STATUS),
+                uploaded_image_url=uploaded_image_url,
+                agent=agent,
+            ):
+                if isinstance(result, dict) and ChatEvent.STATUS in result:
+                    yield result[ChatEvent.STATUS]
+                else:
+                    compiled_references.extend(result[0])
+                    inferred_queries.extend(result[1])
+                    defiltered_query = result[2]
+        except Exception as e:
+            error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
+            logger.warning(error_message)
+            async for result in send_event(
+                ChatEvent.STATUS, "Document search failed. I'll try to respond without document references"
+            ):
+                yield result
 
         if not is_none_or_empty(compiled_references):
             headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
@@ -894,12 +900,13 @@ async def chat(
                         yield result[ChatEvent.STATUS]
                     else:
                         online_results = result
-            except ValueError as e:
+            except Exception as e:
                 error_message = f"Error searching online: {e}. Attempting to respond without online results"
                 logger.warning(error_message)
-                async for result in send_llm_response(error_message):
+                async for result in send_event(
+                    ChatEvent.STATUS, "Online search failed. I'll try to respond without online references"
+                ):
                     yield result
-                return
 
         ## Gather Webpage References
         if ConversationCommand.Webpage in conversation_commands:
@@ -928,11 +935,15 @@ async def chat(
                        webpages.append(webpage["link"])
                async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"):
                    yield result
-            except ValueError as e:
+            except Exception as e:
                 logger.warning(
-                    f"Error directly reading webpages: {e}. Attempting to respond without online results",
+                    f"Error reading webpages: {e}. Attempting to respond without webpage results",
                     exc_info=True,
                 )
+                async for result in send_event(
+                    ChatEvent.STATUS, "Webpage read failed. I'll try to respond without webpage references"
+                ):
+                    yield result
 
         ## Send Gathered References
         async for result in send_event(
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 245fdf09..c3d997e9 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -353,13 +353,13 @@ async def aget_relevant_information_sources(
                 final_response = [ConversationCommand.Default]
             else:
                 final_response = [ConversationCommand.General]
-        return final_response
-    except Exception as e:
+    except Exception:
         logger.error(f"Invalid response for determining relevant tools: {response}")
         if len(agent_tools) == 0:
             final_response = [ConversationCommand.Default]
         else:
             final_response = agent_tools
+    return final_response
 
 
 async def aget_relevant_output_modes(
@@ -551,12 +551,14 @@ async def schedule_query(
     raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
 
 
-async def extract_relevant_info(q: str, corpus: str, user: KhojUser = None, agent: Agent = None) -> Union[str, None]:
+async def extract_relevant_info(
+    qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None
+) -> Union[str, None]:
     """
     Extract relevant information for a given query from the target corpus
     """
 
-    if is_none_or_empty(corpus) or is_none_or_empty(q):
+    if is_none_or_empty(corpus) or is_none_or_empty(qs):
         return None
 
     personality_context = (
@@ -564,17 +566,16 @@ async def extract_relevant_info(q: str, corpus: str, user: KhojUser = None, agen
     )
 
     extract_relevant_information = prompts.extract_relevant_information.format(
-        query=q,
+        query=", ".join(qs),
         corpus=corpus.strip(),
         personality_context=personality_context,
     )
 
-    with timer("Chat actor: Extract relevant information from data", logger):
-        response = await send_message_to_model_wrapper(
-            extract_relevant_information,
-            prompts.system_prompt_extract_relevant_information,
-            user=user,
-        )
+    response = await send_message_to_model_wrapper(
+        extract_relevant_information,
+        prompts.system_prompt_extract_relevant_information,
+        user=user,
+    )
     return response.strip()
 
diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py
index 52e23f29..b67132e4 100644
--- a/src/khoj/search_type/text_search.py
+++ b/src/khoj/search_type/text_search.py
@@ -3,6 +3,7 @@ import math
 from pathlib import Path
 from typing import List, Optional, Tuple, Type, Union
 
+import requests
 import torch
 from asgiref.sync import sync_to_async
 from sentence_transformers import util
@@ -231,8 +232,12 @@ def setup(
 
 def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
     """Score all retrieved entries using the cross-encoder"""
-    with timer("Cross-Encoder Predict Time", logger, state.device):
-        cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits)
+    try:
+        with timer("Cross-Encoder Predict Time", logger, state.device):
+            cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits)
+    except requests.exceptions.HTTPError as e:
+        logger.error(f"Failed to rerank documents using the inference endpoint. Error: {e}.", exc_info=True)
+        cross_scores = [0.0] * len(hits)
 
     # Convert cross-encoder scores to distances and pass in hits for reranking
     for idx in range(len(cross_scores)):
diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py
index e0908e51..7006d7d4 100644
--- a/src/khoj/utils/helpers.py
+++ b/src/khoj/utils/helpers.py
@@ -2,10 +2,12 @@ from __future__ import annotations  # to avoid quoting type hints
 
 import datetime
 import io
+import ipaddress
 import logging
 import os
 import platform
 import random
+import urllib.parse
 import uuid
 from collections import OrderedDict
 from enum import Enum
@@ -164,9 +166,9 @@ def get_class_by_name(name: str) -> object:
 class timer:
     """Context manager to log time taken for a block of code to run"""
 
-    def __init__(self, message: str, logger: logging.Logger, device: torch.device = None):
+    def __init__(self, message: str, logger: logging.Logger, device: torch.device = None, log_level=logging.DEBUG):
         self.message = message
-        self.logger = logger
+        self.logger = logger.debug if log_level == logging.DEBUG else logger.info
         self.device = device
 
     def __enter__(self):
@@ -176,9 +178,9 @@ class timer:
     def __exit__(self, *_):
         elapsed = perf_counter() - self.start
         if self.device is None:
-            self.logger.debug(f"{self.message}: {elapsed:.3f} seconds")
+            self.logger(f"{self.message}: {elapsed:.3f} seconds")
         else:
-            self.logger.debug(f"{self.message}: {elapsed:.3f} seconds on device: {self.device}")
+            self.logger(f"{self.message}: {elapsed:.3f} seconds on device: {self.device}")
 
 
 class LRU(OrderedDict):
@@ -436,6 +438,46 @@ def is_internet_connected():
         return False
 
 
+def is_internal_url(url: str) -> bool:
+    """
+    Check if a URL is likely to be internal/non-public.
+
+    Args:
+        url (str): The URL to check.
+
+    Returns:
+        bool: True if the URL is likely internal, False otherwise.
+    """
+    try:
+        parsed_url = urllib.parse.urlparse(url)
+        hostname = parsed_url.hostname
+
+        # Check for localhost
+        if hostname in ["localhost", "127.0.0.1", "::1"]:
+            return True
+
+        # Check for IP addresses in private ranges
+        try:
+            ip = ipaddress.ip_address(hostname)
+            return ip.is_private
+        except ValueError:
+            pass  # Not an IP address, continue with other checks
+
+        # Check for common internal TLDs
+        internal_tlds = [".local", ".internal", ".private", ".corp", ".home", ".lan"]
+        if any(hostname.endswith(tld) for tld in internal_tlds):
+            return True
+
+        # Check for URLs without a TLD
+        if "." not in hostname:
+            return True
+
+        return False
+    except Exception:
+        # If we can't parse the URL or something else goes wrong, assume it's not internal
+        return False
+
+
 def convert_image_to_webp(image_bytes):
     """Convert image bytes to webp format for faster loading"""
     image_io = io.BytesIO(image_bytes)