Have Khoj dynamically select conversation command(s) in chat (#641)

* Have Khoj dynamically select which conversation command(s) are to be used in the chat flow
- Intercept the commands if in default mode, and have Khoj dynamically guess which tools would be the most relevant for answering the user's query
* Remove conditional for default to enter online search mode
* Add multiple-tool examples in the prompt, make prompt for tools more specific to info collection
This commit is contained in:
sabaimran
2024-02-11 03:41:32 -08:00
committed by GitHub
parent 69344a6aa6
commit a3eb17b7d4
10 changed files with 373 additions and 63 deletions

View File

@@ -479,7 +479,7 @@ class ConversationAdapters:
conversation_id: int = None,
user_message: str = None,
):
slug = user_message.strip()[:200] if not is_none_or_empty(user_message) else None
slug = user_message.strip()[:200] if user_message else None
if conversation_id:
conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
else:

View File

@@ -130,7 +130,7 @@ def converse_offline(
model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
loaded_model: Union[Any, None] = None,
completion_func=None,
conversation_command=ConversationCommand.Default,
conversation_commands=[ConversationCommand.Default],
max_prompt_size=None,
tokenizer_name=None,
) -> Union[ThreadedGenerator, Iterator[str]]:
@@ -148,27 +148,24 @@ def converse_offline(
# Initialize Variables
compiled_references_message = "\n\n".join({f"{item}" for item in references})
conversation_primer = prompts.query_prompt.format(query=user_query)
# Get Conversation Primer appropriate to Conversation Type
if conversation_command == ConversationCommand.Notes and is_none_or_empty(compiled_references_message):
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references_message):
return iter([prompts.no_notes_found.format()])
elif conversation_command == ConversationCommand.Online and is_none_or_empty(online_results):
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])
elif conversation_command == ConversationCommand.Online:
if ConversationCommand.Online in conversation_commands:
simplified_online_results = online_results.copy()
for result in online_results:
if online_results[result].get("extracted_content"):
simplified_online_results[result] = online_results[result]["extracted_content"]
conversation_primer = prompts.online_search_conversation.format(
query=user_query, online_results=str(simplified_online_results)
)
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references_message):
conversation_primer = user_query
else:
conversation_primer = prompts.notes_conversation_gpt4all.format(
query=user_query, references=compiled_references_message
)
conversation_primer = f"{prompts.online_search_conversation.format(online_results=str(simplified_online_results))}\n{conversation_primer}"
if ConversationCommand.Notes in conversation_commands:
conversation_primer = f"{prompts.notes_conversation_gpt4all.format(references=compiled_references_message)}\n{conversation_primer}"
# Setup Prompt with Primer or Conversation History
current_date = datetime.now().strftime("%Y-%m-%d")

View File

@@ -122,7 +122,7 @@ def converse(
api_key: Optional[str] = None,
temperature: float = 0.2,
completion_func=None,
conversation_command=ConversationCommand.Default,
conversation_commands=[ConversationCommand.Default],
max_prompt_size=None,
tokenizer_name=None,
):
@@ -133,26 +133,25 @@ def converse(
current_date = datetime.now().strftime("%Y-%m-%d")
compiled_references = "\n\n".join({f"# {item}" for item in references})
conversation_primer = prompts.query_prompt.format(query=user_query)
# Get Conversation Primer appropriate to Conversation Type
if conversation_command == ConversationCommand.Notes and is_none_or_empty(compiled_references):
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()])
elif conversation_command == ConversationCommand.Online and is_none_or_empty(online_results):
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])
elif conversation_command == ConversationCommand.Online:
if ConversationCommand.Online in conversation_commands:
simplified_online_results = online_results.copy()
for result in online_results:
if online_results[result].get("extracted_content"):
simplified_online_results[result] = online_results[result]["extracted_content"]
conversation_primer = prompts.online_search_conversation.format(
query=user_query, online_results=str(simplified_online_results)
)
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references):
conversation_primer = prompts.general_conversation.format(query=user_query)
else:
conversation_primer = prompts.notes_conversation.format(query=user_query, references=compiled_references)
conversation_primer = f"{prompts.online_search_conversation.format(online_results=str(simplified_online_results))}\n{conversation_primer}"
if ConversationCommand.Notes in conversation_commands:
conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n{conversation_primer}"
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(

View File

@@ -112,7 +112,6 @@ notes_conversation_gpt4all = PromptTemplate.from_template(
"""
User's Notes:
{references}
Question: {query}
""".strip()
)
@@ -139,7 +138,13 @@ Use this up-to-date information from the internet to inform your response.
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the online data or past conversations.
Information from the internet: {online_results}
""".strip()
)
## Query prompt
## --
query_prompt = PromptTemplate.from_template(
"""
Query: {query}""".strip()
)
@@ -285,6 +290,60 @@ Collate the relevant information from the website to answer the target query.
""".strip()
)
pick_relevant_information_collection_tools = PromptTemplate.from_template(
"""
You are Khoj, a smart and helpful personal assistant. You have access to a variety of data sources to help you answer the user's question. You can use the data sources listed below to collect more relevant information. You can use any combination of these data sources to answer the user's question. Tell me which data sources you would like to use to answer the user's question.
{tools}
Here are some example responses:
Example 1:
Chat History:
User: I'm thinking of moving to a new city. I'm trying to decide between New York and San Francisco.
AI: Moving to a new city can be challenging. Both New York and San Francisco are great cities to live in. New York is known for its diverse culture and San Francisco is known for its tech scene.
Q: What is the population of each of those cities?
Khoj: ["online"]
Example 2:
Chat History:
User: I've been having a hard time at work. I'm thinking of quitting.
AI: I'm sorry to hear that. It's important to take care of your mental health. Have you considered talking to your manager about your concerns?
Q: What are the best ways to quit a job?
Khoj: ["general"]
Example 3:
Chat History:
User: I'm thinking of my next vacation idea. Ideally, I want to see something new and exciting.
AI: Excellent! Taking a vacation is a great way to relax and recharge.
Q: Where did Grandma grow up?
Khoj: ["notes"]
Example 4:
Chat History:
Q: I want to make chocolate cake. What was my recipe?
Khoj: ["notes"]
Example 5:
Chat History:
Q: What's the latest news with the first company I worked for?
Khoj: ["notes", "online"]
Now it's your turn to pick the tools you would like to use to answer the user's question. Provide your response as a list of strings.
Chat History:
{chat_history}
Q: {query}
A:
""".strip()
)
online_search_conversation_subqueries = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful search assistant. You are tasked with constructing **up to three** search queries for Google to answer the user's question.

View File

@@ -274,7 +274,7 @@ async def extract_references_and_questions(
q: str,
n: int,
d: float,
conversation_type: ConversationCommand = ConversationCommand.Default,
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
):
user = request.user.object if request.user.is_authenticated else None
@@ -282,7 +282,7 @@ async def extract_references_and_questions(
compiled_references: List[Any] = []
inferred_queries: List[str] = []
if conversation_type == ConversationCommand.General or conversation_type == ConversationCommand.Online:
if not ConversationCommand.Notes in conversation_commands:
return compiled_references, inferred_queries, q
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):

View File

@@ -21,6 +21,7 @@ from khoj.routers.helpers import (
CommonQueryParams,
ConversationCommandRateLimiter,
agenerate_chat_response,
aget_relevant_information_sources,
get_conversation_command,
is_ready_to_chat,
text_to_image,
@@ -207,7 +208,7 @@ async def set_conversation_title(
)
@api_chat.get("", response_class=Response)
@api_chat.get("/", response_class=Response)
@requires(["authenticated"])
async def chat(
request: Request,
@@ -229,25 +230,9 @@ async def chat(
q = unquote(q)
await is_ready_to_chat(user)
conversation_command = get_conversation_command(query=q, any_references=True)
conversation_commands = [get_conversation_command(query=q, any_references=True)]
await conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command)
q = q.replace(f"/{conversation_command.value}", "").strip()
meta_log = (
await ConversationAdapters.aget_conversation_by_user(user, request.user.client_app, conversation_id, slug)
).conversation_log
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_command
)
online_results: Dict = dict()
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
conversation_command = ConversationCommand.General
elif conversation_command == ConversationCommand.Help:
if conversation_commands == [ConversationCommand.Help]:
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if conversation_config == None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
@@ -255,7 +240,23 @@ async def chat(
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
elif conversation_command == ConversationCommand.Notes and not await EntryAdapters.auser_has_entries(user):
meta_log = (
await ConversationAdapters.aget_conversation_by_user(user, request.user.client_app, conversation_id, slug)
).conversation_log
if conversation_commands == [ConversationCommand.Default]:
conversation_commands = await aget_relevant_information_sources(q, meta_log)
for cmd in conversation_commands:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_commands
)
online_results: Dict = dict()
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
no_entries_found_format = no_entries_found.format()
if stream:
return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
@@ -263,7 +264,10 @@ async def chat(
response_obj = {"response": no_entries_found_format}
return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
elif conversation_command == ConversationCommand.Online:
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
conversation_commands.remove(ConversationCommand.Notes)
if ConversationCommand.Online in conversation_commands:
try:
online_results = await search_with_google(defiltered_query, meta_log)
except ValueError as e:
@@ -272,12 +276,12 @@ async def chat(
media_type="text/event-stream",
status_code=200,
)
elif conversation_command == ConversationCommand.Image:
elif conversation_commands == [ConversationCommand.Image]:
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_command.value},
metadata={"conversation_command": conversation_commands[0].value},
**common.__dict__,
)
image, status_code, improved_image_prompt = await text_to_image(q, meta_log)
@@ -308,13 +312,13 @@ async def chat(
compiled_references,
online_results,
inferred_queries,
conversation_command,
conversation_commands,
user,
request.user.client_app,
conversation_id,
)
chat_metadata.update({"conversation_command": conversation_command.value})
chat_metadata.update({"conversation_command": ",".join([cmd.value for cmd in conversation_commands])})
update_telemetry_state(
request=request,

View File

@@ -34,7 +34,11 @@ from khoj.processor.conversation.utils import (
)
from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import ConversationCommand, log_telemetry
from khoj.utils.helpers import (
ConversationCommand,
log_telemetry,
tool_descriptions_for_llm,
)
logger = logging.getLogger(__name__)
@@ -105,6 +109,15 @@ def update_telemetry_state(
]
def construct_chat_history(conversation_history: dict, n: int = 4) -> str:
chat_history = ""
for chat in conversation_history.get("chat", [])[-n:]:
if chat["by"] == "khoj" and chat["intent"].get("type") == "remember":
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"Khoj: {chat['message']}\n"
return chat_history
def get_conversation_command(query: str, any_references: bool = False) -> ConversationCommand:
if query.startswith("/notes"):
return ConversationCommand.Notes
@@ -128,15 +141,50 @@ async def agenerate_chat_response(*args):
return await loop.run_in_executor(executor, generate_chat_response, *args)
async def aget_relevant_information_sources(query: str, conversation_history: dict):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
"""
tool_options = dict()
for tool, description in tool_descriptions_for_llm.items():
tool_options[tool.value] = description
chat_history = construct_chat_history(conversation_history)
relevant_tools_prompt = prompts.pick_relevant_information_collection_tools.format(
query=query,
tools=str(tool_options),
chat_history=chat_history,
)
response = await send_message_to_model_wrapper(relevant_tools_prompt)
try:
response = response.strip()
response = json.loads(response)
response = [q.strip() for q in response if q.strip()]
if not isinstance(response, list) or not response or len(response) == 0:
logger.error(f"Invalid response for determining relevant tools: {response}")
return tool_options
final_response = []
for llm_suggested_tool in response:
if llm_suggested_tool in tool_options.keys():
# Check whether the tool exists as a valid ConversationCommand
final_response.append(ConversationCommand(llm_suggested_tool))
return final_response
except Exception as e:
logger.error(f"Invalid response for determining relevant tools: {response}")
return [ConversationCommand.Default]
async def generate_online_subqueries(q: str, conversation_history: dict) -> List[str]:
"""
Generate subqueries from the given query
"""
chat_history = ""
for chat in conversation_history.get("chat", [])[-4:]:
if chat["by"] == "khoj" and chat["intent"].get("type") == "remember":
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"Khoj: {chat['message']}\n"
chat_history = construct_chat_history(conversation_history)
utc_date = datetime.utcnow().strftime("%Y-%m-%d")
online_queries_prompt = prompts.online_search_conversation_subqueries.format(
@@ -241,14 +289,14 @@ def generate_chat_response(
compiled_references: List[str] = [],
online_results: Dict[str, Any] = {},
inferred_queries: List[str] = [],
conversation_command: ConversationCommand = ConversationCommand.Default,
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
user: KhojUser = None,
client_application: ClientApplication = None,
conversation_id: int = None,
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables
chat_response = None
logger.debug(f"Conversation Type: {conversation_command.name}")
logger.debug(f"Conversation Types: {conversation_commands}")
metadata = {}
@@ -278,7 +326,7 @@ def generate_chat_response(
loaded_model=loaded_model,
conversation_log=meta_log,
completion_func=partial_completion,
conversation_command=conversation_command,
conversation_commands=conversation_commands,
model=conversation_config.chat_model,
max_prompt_size=conversation_config.max_prompt_size,
tokenizer_name=conversation_config.tokenizer,
@@ -296,7 +344,7 @@ def generate_chat_response(
model=chat_model,
api_key=api_key,
completion_func=partial_completion,
conversation_command=conversation_command,
conversation_commands=conversation_commands,
max_prompt_size=conversation_config.max_prompt_size,
tokenizer_name=conversation_config.tokenizer,
)

View File

@@ -282,6 +282,13 @@ command_descriptions = {
ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
}
tool_descriptions_for_llm = {
ConversationCommand.Default: "Use this if there might be a mix of general and personal knowledge in the question",
ConversationCommand.General: "Use this when you can answer the question without needing any additional online or personal information",
ConversationCommand.Notes: "Use this when you would like to use the user's personal knowledge base to answer the question",
ConversationCommand.Online: "Use this when you would like to look up information on the internet",
}
def generate_random_name():
# List of adjectives and nouns to choose from