Simplify, modularize and add type hints to online search functions

- Simplify the content arg to the `extract_relevant_info` function.
  Validate and clean the content arg inside `extract_relevant_info`

- Extract the `search_with_google` function out of the parent function
- Rename the parent function to the more appropriate `search_online`
  instead of `search_with_google`
- Simplify the `search_with_google` function using a dict comprehension.
  Drop empty search result fields from the chat model context to reduce
  cost and response latency

- No need to show the stack trace when unable to read a webpage; a basic
  error message is enough
- Add type hints to online search functions to catch issues with mypy
This commit is contained in:
Debanjum Singh Solanky
2024-03-10 02:09:11 +05:30
parent 88f096977b
commit d136a6be44
5 changed files with 38 additions and 43 deletions

View File

@@ -247,7 +247,7 @@ def llm_thread(g, messages: List[ChatMessage], model: Any):
def send_message_to_model_offline( def send_message_to_model_offline(
message, loaded_model=None, model="mistral-7b-instruct-v0.1.Q4_0.gguf", streaming=False, system_message="" message, loaded_model=None, model="mistral-7b-instruct-v0.1.Q4_0.gguf", streaming=False, system_message=""
): ) -> str:
try: try:
from gpt4all import GPT4All from gpt4all import GPT4All
except ModuleNotFoundError as e: except ModuleNotFoundError as e:

View File

@@ -43,7 +43,7 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True, reraise=True,
) )
def completion_with_backoff(**kwargs): def completion_with_backoff(**kwargs) -> str:
messages = kwargs.pop("messages") messages = kwargs.pop("messages")
if not "openai_api_key" in kwargs: if not "openai_api_key" in kwargs:
kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY") kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")

View File

@@ -2,7 +2,7 @@ import asyncio
import json import json
import logging import logging
import os import os
from typing import Dict, Union from typing import Dict, Tuple, Union
import aiohttp import aiohttp
import requests import requests
@@ -16,12 +16,10 @@ from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY") SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY")
OLOSTEP_API_KEY = os.getenv("OLOSTEP_API_KEY")
SERPER_DEV_URL = "https://google.serper.dev/search" SERPER_DEV_URL = "https://google.serper.dev/search"
OLOSTEP_API_KEY = os.getenv("OLOSTEP_API_KEY")
OLOSTEP_API_URL = "https://agent.olostep.com/olostep-p2p-incomingAPI" OLOSTEP_API_URL = "https://agent.olostep.com/olostep-p2p-incomingAPI"
OLOSTEP_QUERY_PARAMS = { OLOSTEP_QUERY_PARAMS = {
"timeout": 35, # seconds "timeout": 35, # seconds
"waitBeforeScraping": 1, # seconds "waitBeforeScraping": 1, # seconds
@@ -39,31 +37,7 @@ OLOSTEP_QUERY_PARAMS = {
MAX_WEBPAGES_TO_READ = 1 MAX_WEBPAGES_TO_READ = 1
async def search_with_google(query: str, conversation_history: dict, location: LocationData): async def search_online(query: str, conversation_history: dict, location: LocationData):
def _search_with_google(subquery: str):
payload = json.dumps(
{
"q": subquery,
}
)
headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
response = requests.request("POST", SERPER_DEV_URL, headers=headers, data=payload)
if response.status_code != 200:
logger.error(response.text)
return {}
json_response = response.json()
sub_response_dict = {}
sub_response_dict["knowledgeGraph"] = json_response.get("knowledgeGraph", {})
sub_response_dict["organic"] = json_response.get("organic", [])
sub_response_dict["answerBox"] = json_response.get("answerBox", [])
sub_response_dict["peopleAlsoAsk"] = json_response.get("peopleAlsoAsk", [])
return sub_response_dict
if SERPER_DEV_API_KEY is None: if SERPER_DEV_API_KEY is None:
logger.warn("SERPER_DEV_API_KEY is not set") logger.warn("SERPER_DEV_API_KEY is not set")
return {} return {}
@@ -74,14 +48,14 @@ async def search_with_google(query: str, conversation_history: dict, location: L
for subquery in subqueries: for subquery in subqueries:
logger.info(f"Searching with Google for '{subquery}'") logger.info(f"Searching with Google for '{subquery}'")
response_dict[subquery] = _search_with_google(subquery) response_dict[subquery] = search_with_google(subquery)
# Gather distinct web pages from organic search results of each subquery without an instant answer # Gather distinct web pages from organic search results of each subquery without an instant answer
webpage_links = { webpage_links = {
result["link"] result["link"]
for subquery in response_dict for subquery in response_dict
for result in response_dict[subquery].get("organic")[:MAX_WEBPAGES_TO_READ] for result in response_dict[subquery].get("organic", [])[:MAX_WEBPAGES_TO_READ]
if is_none_or_empty(response_dict[subquery].get("answerBox")) if "answerBox" not in response_dict[subquery]
} }
# Read, extract relevant info from the retrieved web pages # Read, extract relevant info from the retrieved web pages
@@ -100,15 +74,34 @@ async def search_with_google(query: str, conversation_history: dict, location: L
return response_dict return response_dict
async def read_webpage_and_extract_content(subquery, url): def search_with_google(subquery: str):
payload = json.dumps({"q": subquery})
headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
response = requests.request("POST", SERPER_DEV_URL, headers=headers, data=payload)
if response.status_code != 200:
logger.error(response.text)
return {}
json_response = response.json()
extraction_fields = ["organic", "answerBox", "peopleAlsoAsk", "knowledgeGraph"]
extracted_search_result = {
field: json_response[field] for field in extraction_fields if not is_none_or_empty(json_response.get(field))
}
return extracted_search_result
async def read_webpage_and_extract_content(subquery: str, url: str) -> Tuple[str, Union[None, str]]:
try: try:
with timer(f"Reading web page at '{url}' took", logger): with timer(f"Reading web page at '{url}' took", logger):
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage(url) content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage(url)
with timer(f"Extracting relevant information from web page at '{url}' took", logger): with timer(f"Extracting relevant information from web page at '{url}' took", logger):
extracted_info = await extract_relevant_info(subquery, {subquery: [content.strip()]}) if content else None extracted_info = await extract_relevant_info(subquery, content)
return subquery, extracted_info return subquery, extracted_info
except Exception as e: except Exception as e:
logger.error(f"Failed to read web page at '{url}': {e}", exc_info=True) logger.error(f"Failed to read web page at '{url}' with {e}")
return subquery, None return subquery, None

View File

@@ -14,7 +14,7 @@ from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_use
from khoj.database.models import KhojUser from khoj.database.models import KhojUser
from khoj.processor.conversation.prompts import help_message, no_entries_found from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import save_to_conversation_log from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.tools.online_search import search_with_google from khoj.processor.tools.online_search import search_online
from khoj.routers.api import extract_references_and_questions from khoj.routers.api import extract_references_and_questions
from khoj.routers.helpers import ( from khoj.routers.helpers import (
ApiUserRateLimiter, ApiUserRateLimiter,
@@ -284,7 +284,7 @@ async def chat(
if ConversationCommand.Online in conversation_commands: if ConversationCommand.Online in conversation_commands:
try: try:
online_results = await search_with_google(defiltered_query, meta_log, location) online_results = await search_online(defiltered_query, meta_log, location)
except ValueError as e: except ValueError as e:
return StreamingResponse( return StreamingResponse(
iter(["Please set your SERPER_DEV_API_KEY to get started with online searches 🌐"]), iter(["Please set your SERPER_DEV_API_KEY to get started with online searches 🌐"]),

View File

@@ -256,15 +256,17 @@ async def generate_online_subqueries(q: str, conversation_history: dict, locatio
return [q] return [q]
async def extract_relevant_info(q: str, corpus: dict) -> List[str]: async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]:
""" """
Given a target corpus, extract the most relevant info given a query Extract relevant information for a given query from the target corpus
""" """
key = list(corpus.keys())[0] if is_none_or_empty(corpus) or is_none_or_empty(q):
return None
extract_relevant_information = prompts.extract_relevant_information.format( extract_relevant_information = prompts.extract_relevant_information.format(
query=q, query=q,
corpus=corpus[key], corpus=corpus.strip(),
) )
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(