From a47a54f2079ac754df5bf5bd3a321870071268b2 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 26 Jul 2024 22:56:24 +0530 Subject: [PATCH] Pass user name to document and online search actors prompts This should improve the quality of personal information extraction from document and online sources. The user name is only used when it is set --- .../conversation/anthropic/anthropic_chat.py | 5 ++++- .../processor/conversation/offline/chat_model.py | 5 ++++- src/khoj/processor/conversation/openai/gpt.py | 5 ++++- src/khoj/processor/conversation/openai/utils.py | 2 +- src/khoj/processor/conversation/prompts.py | 5 +++++ src/khoj/processor/tools/online_search.py | 12 +++++++++--- src/khoj/routers/api.py | 3 +++ src/khoj/routers/api_chat.py | 4 ++-- src/khoj/routers/helpers.py | 12 ++++++++++-- 9 files changed, 42 insertions(+), 11 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index d5778885..72bf9250 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -6,7 +6,7 @@ from typing import Dict, Optional from langchain.schema import ChatMessage -from khoj.database.models import Agent +from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.anthropic.utils import ( anthropic_chat_completion_with_backoff, @@ -26,12 +26,14 @@ def extract_questions_anthropic( api_key=None, temperature=0, location_data: LocationData = None, + user: KhojUser = None, ): """ Infer search queries to retrieve relevant notes to answer user query """ # Extract Past User Message and Inferred Questions from Conversation Log location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" # Extract Past User Message and Inferred Questions from Conversation Log chat_history = "".join( @@ -55,6 +57,7 @@ def extract_questions_anthropic( current_new_year_date=current_new_year.strftime("%Y-%m-%d"), yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"), location=location, + username=username, ) prompt = prompts.extract_questions_anthropic_user_message.format( diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 2da0c186..2fabddc7 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -7,7 +7,7 @@ from typing import Any, Iterator, List, Union from langchain.schema import ChatMessage from llama_cpp import Llama -from khoj.database.models import Agent +from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.utils import download_model from khoj.processor.conversation.utils import ( @@ -30,6 +30,7 @@ def extract_questions_offline( use_history: bool = True, should_extract_questions: bool = True, location_data: LocationData = None, + user: KhojUser = None, max_prompt_size: int = None, ) -> List[str]: """ @@ -45,6 +46,7 @@ def extract_questions_offline( offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size) location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" # Extract Past User Message and Inferred Questions from Conversation Log chat_history = "" @@ -68,6 +70,7 @@ def extract_questions_offline( last_year=last_year, this_year=today.year, location=location, + username=username, ) messages = generate_chatml_messages_with_context( diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index f1608fba..7649a6df 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -5,7 +5,7 @@ from typing import Dict, Optional from langchain.schema import ChatMessage -from khoj.database.models import Agent +from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.openai.utils import ( chat_completion_with_backoff, @@ -27,11 +27,13 @@ def extract_questions( temperature=0, max_tokens=100, location_data: LocationData = None, + user: KhojUser = None, ): """ Infer search queries to retrieve relevant notes to answer user query """ location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" # Extract Past User Message and Inferred Questions from Conversation Log chat_history = "".join( @@ -59,6 +61,7 @@ def extract_questions( text=text, yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"), location=location, + username=username, ) messages = [ChatMessage(content=prompt, role="user")] diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index b17e5c3d..841e6aa7 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -36,7 +36,7 @@ def completion_with_backoff( messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None ) -> str: client_key = f"{openai_api_key}--{api_base_url}" - client: openai.OpenAI = openai_clients.get(client_key) + client: openai.OpenAI | None = openai_clients.get(client_key) if not client: client = openai.OpenAI( api_key=openai_api_key, diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 289bafbc..6412e232 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -212,6 +212,7 @@ Construct search queries to retrieve relevant information to answer the user's q Current Date: {day_of_week}, {current_date} User's Location: {location} +{username} Examples: Q: How was my trip to Cambodia? @@ -258,6 +259,7 @@ Construct search queries to retrieve relevant information to answer the user's q What searches will you perform to answer the users question? Respond with search queries as list of strings in a JSON object. Current Date: {day_of_week}, {current_date} User's Location: {location} +{username} Q: How was my trip to Cambodia? Khoj: {{"queries": ["How was my trip to Cambodia?"]}} @@ -310,6 +312,7 @@ What searches will you perform to answer the users question? Respond with a JSON Current Date: {day_of_week}, {current_date} User's Location: {location} +{username} Here are some examples of how you can construct search queries to answer the user's question: @@ -525,6 +528,7 @@ Which webpages will you need to read to answer the user's question? Provide web page links as a list of strings in a JSON object. Current Date: {current_date} User's Location: {location} +{username} Here are some examples: History: @@ -571,6 +575,7 @@ What Google searches, if any, will you need to perform to answer the user's ques Provide search queries as a list of strings in a JSON object. Do not wrap the json in a codeblock. Current Date: {current_date} User's Location: {location} +{username} Here are some examples: History: diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index c087de70..29b0f850 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -10,6 +10,7 @@ import aiohttp from bs4 import BeautifulSoup from markdownify import markdownify +from khoj.database.models import KhojUser from khoj.routers.helpers import ( ChatEvent, extract_relevant_info, @@ -51,6 +52,7 @@ async def search_online( query: str, conversation_history: dict, location: LocationData, + user: KhojUser, send_status_func: Optional[Callable] = None, custom_filters: List[str] = [], ): @@ -61,7 +63,7 @@ async def search_online( return # Breakdown the query into subqueries to get the correct answer - subqueries = await generate_online_subqueries(query, conversation_history, location) + subqueries = await generate_online_subqueries(query, conversation_history, location, user) response_dict = {} if subqueries: @@ -126,14 +128,18 @@ async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]: async def read_webpages( - query: str, conversation_history: dict, location: LocationData, send_status_func: Optional[Callable] = None + query: str, + conversation_history: dict, + location: LocationData, + user: KhojUser, + send_status_func: Optional[Callable] = None, ): "Infer web pages to read from the query and extract relevant information from them" logger.info(f"Inferring web pages to read") if send_status_func: async for event in send_status_func(f"**🧐 Inferring web pages to read**"): yield {ChatEvent.STATUS: event} - urls = await infer_webpage_urls(query, conversation_history, location) + urls = await infer_webpage_urls(query, conversation_history, location, user) logger.info(f"Reading web pages at: {urls}") if send_status_func: diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 15d7cbc7..5f89cb72 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -343,6 +343,7 @@ async def extract_references_and_questions( conversation_log=meta_log, should_extract_questions=True, location_data=location_data, + user=user, max_prompt_size=conversation_config.max_prompt_size, ) elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: @@ -357,6 +358,7 @@ async def extract_references_and_questions( api_base_url=base_url, conversation_log=meta_log, location_data=location_data, + user=user, max_tokens=conversation_config.max_prompt_size, ) elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: @@ -368,6 +370,7 @@ async def extract_references_and_questions( api_key=api_key, conversation_log=meta_log, location_data=location_data, + user=user, ) # Collate search results as context for GPT diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 63529b8e..d9223192 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -800,7 +800,7 @@ async def chat( if ConversationCommand.Online in conversation_commands: try: async for result in search_online( - defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS), custom_filters + defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS), custom_filters ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -817,7 +817,7 @@ async def chat( if ConversationCommand.Webpage in conversation_commands: try: async for result in read_webpages( - defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS) + defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS) ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 846f5c8f..69ef74b1 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -315,11 +315,14 @@ async def aget_relevant_output_modes(query: str, conversation_history: dict, is_ return ConversationCommand.Text -async def infer_webpage_urls(q: str, conversation_history: dict, location_data: LocationData) -> List[str]: +async def infer_webpage_urls( + q: str, conversation_history: dict, location_data: LocationData, user: KhojUser +) -> List[str]: """ Infer webpage links from the given query """ location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" chat_history = construct_chat_history(conversation_history) utc_date = datetime.utcnow().strftime("%Y-%m-%d") @@ -328,6 +331,7 @@ async def infer_webpage_urls(q: str, conversation_history: dict, location_data: query=q, chat_history=chat_history, location=location, + username=username, ) with timer("Chat actor: Infer webpage urls to read", logger): @@ -345,11 +349,14 @@ async def infer_webpage_urls(q: str, conversation_history: dict, location_data: raise ValueError(f"Invalid list of urls: {response}") -async def generate_online_subqueries(q: str, conversation_history: dict, location_data: LocationData) -> List[str]: +async def generate_online_subqueries( + q: str, conversation_history: dict, location_data: LocationData, user: KhojUser +) -> List[str]: """ Generate subqueries from the given query """ location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" chat_history = construct_chat_history(conversation_history) utc_date = datetime.utcnow().strftime("%Y-%m-%d") @@ -358,6 +365,7 @@ async def generate_online_subqueries(q: str, conversation_history: dict, locatio query=q, chat_history=chat_history, location=location, + username=username, ) with timer("Chat actor: Generate online search subqueries", logger):