diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index d5778885..72bf9250 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -6,7 +6,7 @@ from typing import Dict, Optional from langchain.schema import ChatMessage -from khoj.database.models import Agent +from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.anthropic.utils import ( anthropic_chat_completion_with_backoff, @@ -26,12 +26,14 @@ def extract_questions_anthropic( api_key=None, temperature=0, location_data: LocationData = None, + user: KhojUser = None, ): """ Infer search queries to retrieve relevant notes to answer user query """ # Extract Past User Message and Inferred Questions from Conversation Log location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" # Extract Past User Message and Inferred Questions from Conversation Log chat_history = "".join( @@ -55,6 +57,7 @@ def extract_questions_anthropic( current_new_year_date=current_new_year.strftime("%Y-%m-%d"), yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"), location=location, + username=username, ) prompt = prompts.extract_questions_anthropic_user_message.format( diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 2da0c186..2fabddc7 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -7,7 +7,7 @@ from typing import Any, Iterator, List, Union from langchain.schema import ChatMessage from llama_cpp import Llama -from khoj.database.models import Agent +from khoj.database.models 
import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.utils import download_model from khoj.processor.conversation.utils import ( @@ -30,6 +30,7 @@ def extract_questions_offline( use_history: bool = True, should_extract_questions: bool = True, location_data: LocationData = None, + user: KhojUser = None, max_prompt_size: int = None, ) -> List[str]: """ @@ -45,6 +46,7 @@ def extract_questions_offline( offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size) location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" # Extract Past User Message and Inferred Questions from Conversation Log chat_history = "" @@ -68,6 +70,7 @@ def extract_questions_offline( last_year=last_year, this_year=today.year, location=location, + username=username, ) messages = generate_chatml_messages_with_context( diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index f1608fba..7649a6df 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -5,7 +5,7 @@ from typing import Dict, Optional from langchain.schema import ChatMessage -from khoj.database.models import Agent +from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.openai.utils import ( chat_completion_with_backoff, @@ -27,11 +27,13 @@ def extract_questions( temperature=0, max_tokens=100, location_data: LocationData = None, + user: KhojUser = None, ): """ Infer search queries to retrieve relevant notes to answer user query """ location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + username = prompts.user_name.format(name=user.get_full_name()) 
if user and user.get_full_name() else "" # Extract Past User Message and Inferred Questions from Conversation Log chat_history = "".join( @@ -59,6 +61,7 @@ def extract_questions( text=text, yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"), location=location, + username=username, ) messages = [ChatMessage(content=prompt, role="user")] diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index b17e5c3d..841e6aa7 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -36,7 +36,7 @@ def completion_with_backoff( messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None ) -> str: client_key = f"{openai_api_key}--{api_base_url}" - client: openai.OpenAI = openai_clients.get(client_key) + client: openai.OpenAI | None = openai_clients.get(client_key) if not client: client = openai.OpenAI( api_key=openai_api_key, diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 289bafbc..6412e232 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -212,6 +212,7 @@ Construct search queries to retrieve relevant information to answer the user's q Current Date: {day_of_week}, {current_date} User's Location: {location} +{username} Examples: Q: How was my trip to Cambodia? @@ -258,6 +259,7 @@ Construct search queries to retrieve relevant information to answer the user's q What searches will you perform to answer the users question? Respond with search queries as list of strings in a JSON object. Current Date: {day_of_week}, {current_date} User's Location: {location} +{username} Q: How was my trip to Cambodia? Khoj: {{"queries": ["How was my trip to Cambodia?"]}} @@ -310,6 +312,7 @@ What searches will you perform to answer the users question? 
Respond with a JSON Current Date: {day_of_week}, {current_date} User's Location: {location} +{username} Here are some examples of how you can construct search queries to answer the user's question: @@ -525,6 +528,7 @@ Which webpages will you need to read to answer the user's question? Provide web page links as a list of strings in a JSON object. Current Date: {current_date} User's Location: {location} +{username} Here are some examples: History: @@ -571,6 +575,7 @@ What Google searches, if any, will you need to perform to answer the user's ques Provide search queries as a list of strings in a JSON object. Do not wrap the json in a codeblock. Current Date: {current_date} User's Location: {location} +{username} Here are some examples: History: diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index c087de70..29b0f850 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -10,6 +10,7 @@ import aiohttp from bs4 import BeautifulSoup from markdownify import markdownify +from khoj.database.models import KhojUser from khoj.routers.helpers import ( ChatEvent, extract_relevant_info, @@ -51,6 +52,7 @@ async def search_online( query: str, conversation_history: dict, location: LocationData, + user: KhojUser, send_status_func: Optional[Callable] = None, custom_filters: List[str] = [], ): @@ -61,7 +63,7 @@ async def search_online( return # Breakdown the query into subqueries to get the correct answer - subqueries = await generate_online_subqueries(query, conversation_history, location) + subqueries = await generate_online_subqueries(query, conversation_history, location, user) response_dict = {} if subqueries: @@ -126,14 +128,18 @@ async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]: async def read_webpages( - query: str, conversation_history: dict, location: LocationData, send_status_func: Optional[Callable] = None + query: str, + conversation_history: 
dict, + location: LocationData, + user: KhojUser, + send_status_func: Optional[Callable] = None, ): "Infer web pages to read from the query and extract relevant information from them" logger.info(f"Inferring web pages to read") if send_status_func: async for event in send_status_func(f"**🧐 Inferring web pages to read**"): yield {ChatEvent.STATUS: event} - urls = await infer_webpage_urls(query, conversation_history, location) + urls = await infer_webpage_urls(query, conversation_history, location, user) logger.info(f"Reading web pages at: {urls}") if send_status_func: diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 15d7cbc7..5f89cb72 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -343,6 +343,7 @@ async def extract_references_and_questions( conversation_log=meta_log, should_extract_questions=True, location_data=location_data, + user=user, max_prompt_size=conversation_config.max_prompt_size, ) elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: @@ -357,6 +358,7 @@ async def extract_references_and_questions( api_base_url=base_url, conversation_log=meta_log, location_data=location_data, + user=user, max_tokens=conversation_config.max_prompt_size, ) elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: @@ -368,6 +370,7 @@ async def extract_references_and_questions( api_key=api_key, conversation_log=meta_log, location_data=location_data, + user=user, ) # Collate search results as context for GPT diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 63529b8e..d9223192 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -800,7 +800,7 @@ async def chat( if ConversationCommand.Online in conversation_commands: try: async for result in search_online( - defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS), custom_filters + defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS), 
custom_filters ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -817,7 +817,7 @@ if ConversationCommand.Webpage in conversation_commands: try: async for result in read_webpages( - defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS) + defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS) ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 846f5c8f..69ef74b1 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -315,11 +315,14 @@ async def aget_relevant_output_modes(query: str, conversation_history: dict, is_ return ConversationCommand.Text -async def infer_webpage_urls(q: str, conversation_history: dict, location_data: LocationData) -> List[str]: +async def infer_webpage_urls( + q: str, conversation_history: dict, location_data: LocationData, user: KhojUser +) -> List[str]: """ Infer webpage links from the given query """ location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" chat_history = construct_chat_history(conversation_history) utc_date = datetime.utcnow().strftime("%Y-%m-%d") @@ -328,6 +331,7 @@ async def infer_webpage_urls(q: str, conversation_history: dict, location_data: query=q, chat_history=chat_history, location=location, + username=username, ) with timer("Chat actor: Infer webpage urls to read", logger): @@ -345,11 +349,14 @@ async def infer_webpage_urls(q: str, conversation_history: dict, location_data: raise ValueError(f"Invalid list of urls: {response}") -async def generate_online_subqueries(q: str, conversation_history: dict, location_data: LocationData) -> List[str]: +async def generate_online_subqueries( + q: str, conversation_history: 
dict, location_data: LocationData, user: KhojUser +) -> List[str]: """ Generate subqueries from the given query """ location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" chat_history = construct_chat_history(conversation_history) utc_date = datetime.utcnow().strftime("%Y-%m-%d") @@ -358,6 +365,7 @@ async def generate_online_subqueries(q: str, conversation_history: dict, locatio query=q, chat_history=chat_history, location=location, + username=username, ) with timer("Chat actor: Generate online search subqueries", logger):