Pass user name to document and online search actors prompts

This should improve the quality of personal information extraction
from document and online sources. The user name is only included in
the prompt when it is set.
This commit is contained in:
Debanjum Singh Solanky
2024-07-26 22:56:24 +05:30
parent eb5af38f33
commit a47a54f207
9 changed files with 42 additions and 11 deletions

View File

@@ -6,7 +6,7 @@ from typing import Dict, Optional
from langchain.schema import ChatMessage from langchain.schema import ChatMessage
from khoj.database.models import Agent from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.utils import ( from khoj.processor.conversation.anthropic.utils import (
anthropic_chat_completion_with_backoff, anthropic_chat_completion_with_backoff,
@@ -26,12 +26,14 @@ def extract_questions_anthropic(
api_key=None, api_key=None,
temperature=0, temperature=0,
location_data: LocationData = None, location_data: LocationData = None,
user: KhojUser = None,
): ):
""" """
Infer search queries to retrieve relevant notes to answer user query Infer search queries to retrieve relevant notes to answer user query
""" """
# Extract Past User Message and Inferred Questions from Conversation Log # Extract Past User Message and Inferred Questions from Conversation Log
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log # Extract Past User Message and Inferred Questions from Conversation Log
chat_history = "".join( chat_history = "".join(
@@ -55,6 +57,7 @@ def extract_questions_anthropic(
current_new_year_date=current_new_year.strftime("%Y-%m-%d"), current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"), yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
location=location, location=location,
username=username,
) )
prompt = prompts.extract_questions_anthropic_user_message.format( prompt = prompts.extract_questions_anthropic_user_message.format(

View File

@@ -7,7 +7,7 @@ from typing import Any, Iterator, List, Union
from langchain.schema import ChatMessage from langchain.schema import ChatMessage
from llama_cpp import Llama from llama_cpp import Llama
from khoj.database.models import Agent from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
@@ -30,6 +30,7 @@ def extract_questions_offline(
use_history: bool = True, use_history: bool = True,
should_extract_questions: bool = True, should_extract_questions: bool = True,
location_data: LocationData = None, location_data: LocationData = None,
user: KhojUser = None,
max_prompt_size: int = None, max_prompt_size: int = None,
) -> List[str]: ) -> List[str]:
""" """
@@ -45,6 +46,7 @@ def extract_questions_offline(
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size) offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log # Extract Past User Message and Inferred Questions from Conversation Log
chat_history = "" chat_history = ""
@@ -68,6 +70,7 @@ def extract_questions_offline(
last_year=last_year, last_year=last_year,
this_year=today.year, this_year=today.year,
location=location, location=location,
username=username,
) )
messages = generate_chatml_messages_with_context( messages = generate_chatml_messages_with_context(

View File

@@ -5,7 +5,7 @@ from typing import Dict, Optional
from langchain.schema import ChatMessage from langchain.schema import ChatMessage
from khoj.database.models import Agent from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.openai.utils import ( from khoj.processor.conversation.openai.utils import (
chat_completion_with_backoff, chat_completion_with_backoff,
@@ -27,11 +27,13 @@ def extract_questions(
temperature=0, temperature=0,
max_tokens=100, max_tokens=100,
location_data: LocationData = None, location_data: LocationData = None,
user: KhojUser = None,
): ):
""" """
Infer search queries to retrieve relevant notes to answer user query Infer search queries to retrieve relevant notes to answer user query
""" """
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log # Extract Past User Message and Inferred Questions from Conversation Log
chat_history = "".join( chat_history = "".join(
@@ -59,6 +61,7 @@ def extract_questions(
text=text, text=text,
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"), yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
location=location, location=location,
username=username,
) )
messages = [ChatMessage(content=prompt, role="user")] messages = [ChatMessage(content=prompt, role="user")]

View File

@@ -36,7 +36,7 @@ def completion_with_backoff(
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None
) -> str: ) -> str:
client_key = f"{openai_api_key}--{api_base_url}" client_key = f"{openai_api_key}--{api_base_url}"
client: openai.OpenAI = openai_clients.get(client_key) client: openai.OpenAI | None = openai_clients.get(client_key)
if not client: if not client:
client = openai.OpenAI( client = openai.OpenAI(
api_key=openai_api_key, api_key=openai_api_key,

View File

@@ -212,6 +212,7 @@ Construct search queries to retrieve relevant information to answer the user's q
Current Date: {day_of_week}, {current_date} Current Date: {day_of_week}, {current_date}
User's Location: {location} User's Location: {location}
{username}
Examples: Examples:
Q: How was my trip to Cambodia? Q: How was my trip to Cambodia?
@@ -258,6 +259,7 @@ Construct search queries to retrieve relevant information to answer the user's q
What searches will you perform to answer the users question? Respond with search queries as list of strings in a JSON object. What searches will you perform to answer the users question? Respond with search queries as list of strings in a JSON object.
Current Date: {day_of_week}, {current_date} Current Date: {day_of_week}, {current_date}
User's Location: {location} User's Location: {location}
{username}
Q: How was my trip to Cambodia? Q: How was my trip to Cambodia?
Khoj: {{"queries": ["How was my trip to Cambodia?"]}} Khoj: {{"queries": ["How was my trip to Cambodia?"]}}
@@ -310,6 +312,7 @@ What searches will you perform to answer the users question? Respond with a JSON
Current Date: {day_of_week}, {current_date} Current Date: {day_of_week}, {current_date}
User's Location: {location} User's Location: {location}
{username}
Here are some examples of how you can construct search queries to answer the user's question: Here are some examples of how you can construct search queries to answer the user's question:
@@ -525,6 +528,7 @@ Which webpages will you need to read to answer the user's question?
Provide web page links as a list of strings in a JSON object. Provide web page links as a list of strings in a JSON object.
Current Date: {current_date} Current Date: {current_date}
User's Location: {location} User's Location: {location}
{username}
Here are some examples: Here are some examples:
History: History:
@@ -571,6 +575,7 @@ What Google searches, if any, will you need to perform to answer the user's ques
Provide search queries as a list of strings in a JSON object. Do not wrap the json in a codeblock. Provide search queries as a list of strings in a JSON object. Do not wrap the json in a codeblock.
Current Date: {current_date} Current Date: {current_date}
User's Location: {location} User's Location: {location}
{username}
Here are some examples: Here are some examples:
History: History:

View File

@@ -10,6 +10,7 @@ import aiohttp
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from markdownify import markdownify from markdownify import markdownify
from khoj.database.models import KhojUser
from khoj.routers.helpers import ( from khoj.routers.helpers import (
ChatEvent, ChatEvent,
extract_relevant_info, extract_relevant_info,
@@ -51,6 +52,7 @@ async def search_online(
query: str, query: str,
conversation_history: dict, conversation_history: dict,
location: LocationData, location: LocationData,
user: KhojUser,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
custom_filters: List[str] = [], custom_filters: List[str] = [],
): ):
@@ -61,7 +63,7 @@ async def search_online(
return return
# Breakdown the query into subqueries to get the correct answer # Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(query, conversation_history, location) subqueries = await generate_online_subqueries(query, conversation_history, location, user)
response_dict = {} response_dict = {}
if subqueries: if subqueries:
@@ -126,14 +128,18 @@ async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
async def read_webpages( async def read_webpages(
query: str, conversation_history: dict, location: LocationData, send_status_func: Optional[Callable] = None query: str,
conversation_history: dict,
location: LocationData,
user: KhojUser,
send_status_func: Optional[Callable] = None,
): ):
"Infer web pages to read from the query and extract relevant information from them" "Infer web pages to read from the query and extract relevant information from them"
logger.info(f"Inferring web pages to read") logger.info(f"Inferring web pages to read")
if send_status_func: if send_status_func:
async for event in send_status_func(f"**🧐 Inferring web pages to read**"): async for event in send_status_func(f"**🧐 Inferring web pages to read**"):
yield {ChatEvent.STATUS: event} yield {ChatEvent.STATUS: event}
urls = await infer_webpage_urls(query, conversation_history, location) urls = await infer_webpage_urls(query, conversation_history, location, user)
logger.info(f"Reading web pages at: {urls}") logger.info(f"Reading web pages at: {urls}")
if send_status_func: if send_status_func:

View File

@@ -343,6 +343,7 @@ async def extract_references_and_questions(
conversation_log=meta_log, conversation_log=meta_log,
should_extract_questions=True, should_extract_questions=True,
location_data=location_data, location_data=location_data,
user=user,
max_prompt_size=conversation_config.max_prompt_size, max_prompt_size=conversation_config.max_prompt_size,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
@@ -357,6 +358,7 @@ async def extract_references_and_questions(
api_base_url=base_url, api_base_url=base_url,
conversation_log=meta_log, conversation_log=meta_log,
location_data=location_data, location_data=location_data,
user=user,
max_tokens=conversation_config.max_prompt_size, max_tokens=conversation_config.max_prompt_size,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
@@ -368,6 +370,7 @@ async def extract_references_and_questions(
api_key=api_key, api_key=api_key,
conversation_log=meta_log, conversation_log=meta_log,
location_data=location_data, location_data=location_data,
user=user,
) )
# Collate search results as context for GPT # Collate search results as context for GPT

View File

@@ -800,7 +800,7 @@ async def chat(
if ConversationCommand.Online in conversation_commands: if ConversationCommand.Online in conversation_commands:
try: try:
async for result in search_online( async for result in search_online(
defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS), custom_filters defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS), custom_filters
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@@ -817,7 +817,7 @@ async def chat(
if ConversationCommand.Webpage in conversation_commands: if ConversationCommand.Webpage in conversation_commands:
try: try:
async for result in read_webpages( async for result in read_webpages(
defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS) defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS)
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]

View File

@@ -315,11 +315,14 @@ async def aget_relevant_output_modes(query: str, conversation_history: dict, is_
return ConversationCommand.Text return ConversationCommand.Text
async def infer_webpage_urls(q: str, conversation_history: dict, location_data: LocationData) -> List[str]: async def infer_webpage_urls(
q: str, conversation_history: dict, location_data: LocationData, user: KhojUser
) -> List[str]:
""" """
Infer webpage links from the given query Infer webpage links from the given query
""" """
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history) chat_history = construct_chat_history(conversation_history)
utc_date = datetime.utcnow().strftime("%Y-%m-%d") utc_date = datetime.utcnow().strftime("%Y-%m-%d")
@@ -328,6 +331,7 @@ async def infer_webpage_urls(q: str, conversation_history: dict, location_data:
query=q, query=q,
chat_history=chat_history, chat_history=chat_history,
location=location, location=location,
username=username,
) )
with timer("Chat actor: Infer webpage urls to read", logger): with timer("Chat actor: Infer webpage urls to read", logger):
@@ -345,11 +349,14 @@ async def infer_webpage_urls(q: str, conversation_history: dict, location_data:
raise ValueError(f"Invalid list of urls: {response}") raise ValueError(f"Invalid list of urls: {response}")
async def generate_online_subqueries(q: str, conversation_history: dict, location_data: LocationData) -> List[str]: async def generate_online_subqueries(
q: str, conversation_history: dict, location_data: LocationData, user: KhojUser
) -> List[str]:
""" """
Generate subqueries from the given query Generate subqueries from the given query
""" """
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history) chat_history = construct_chat_history(conversation_history)
utc_date = datetime.utcnow().strftime("%Y-%m-%d") utc_date = datetime.utcnow().strftime("%Y-%m-%d")
@@ -358,6 +365,7 @@ async def generate_online_subqueries(q: str, conversation_history: dict, locatio
query=q, query=q,
chat_history=chat_history, chat_history=chat_history,
location=location, location=location,
username=username,
) )
with timer("Chat actor: Generate online search subqueries", logger): with timer("Chat actor: Generate online search subqueries", logger):