diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html
index 51c3f0b6..9bdc7cef 100644
--- a/src/interface/desktop/chat.html
+++ b/src/interface/desktop/chat.html
@@ -197,13 +197,18 @@
}
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
- if (intentType === "text-to-image") {
- let imageMarkdown = ``;
- const inferredQuery = inferredQueries?.[0];
- if (inferredQuery) {
- imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
+ if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
+ if (intentType === "text-to-image") {
+ let imageMarkdown = ``;
+ const inferredQuery = inferredQueries?.[0];
+ if (inferredQuery) {
+ imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
+ }
+ renderMessage(imageMarkdown, by, dt);
+ return;
}
- renderMessage(imageMarkdown, by, dt);
+
+ renderMessage(message, by, dt);
return;
}
@@ -261,6 +266,16 @@
references.appendChild(referenceSection);
+ if (intentType === "text-to-image") {
+ let imageMarkdown = ``;
+ const inferredQuery = inferredQueries?.[0];
+ if (inferredQuery) {
+ imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
+ }
+ renderMessage(imageMarkdown, by, dt, references);
+ return;
+ }
+
renderMessage(message, by, dt, references);
}
@@ -324,6 +339,46 @@
return element
}
+ function createReferenceSection(references) {
+ let referenceSection = document.createElement('div');
+ referenceSection.classList.add("reference-section");
+ referenceSection.classList.add("collapsed");
+
+ let numReferences = 0;
+
+ if (Array.isArray(references)) {
+ numReferences = references.length;
+
+ references.forEach((reference, index) => {
+ let polishedReference = generateReference(reference, index);
+ referenceSection.appendChild(polishedReference);
+ });
+ } else {
+ numReferences += processOnlineReferences(referenceSection, references);
+ }
+
+ let referenceExpandButton = document.createElement('button');
+ referenceExpandButton.classList.add("reference-expand-button");
+ referenceExpandButton.innerHTML = numReferences == 1 ? "1 reference" : `${numReferences} references`;
+
+ referenceExpandButton.addEventListener('click', function() {
+ if (referenceSection.classList.contains("collapsed")) {
+ referenceSection.classList.remove("collapsed");
+ referenceSection.classList.add("expanded");
+ } else {
+ referenceSection.classList.add("collapsed");
+ referenceSection.classList.remove("expanded");
+ }
+ });
+
+ let referencesDiv = document.createElement('div');
+ referencesDiv.classList.add("references");
+ referencesDiv.appendChild(referenceExpandButton);
+ referencesDiv.appendChild(referenceSection);
+
+ return referencesDiv;
+ }
+
async function chat() {
// Extract required fields for search from form
let query = document.getElementById("chat-input").value.trim();
@@ -382,6 +437,7 @@
// Call Khoj chat API
let response = await fetch(chatApi, { headers });
let rawResponse = "";
+ let references = null;
const contentType = response.headers.get("content-type");
if (contentType === "application/json") {
@@ -396,6 +452,10 @@
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQueries}`;
}
}
+ if (responseAsJson.context) {
+ const rawReferenceAsJson = responseAsJson.context;
+ references = createReferenceSection(rawReferenceAsJson);
+ }
if (responseAsJson.detail) {
// If response has detail field, response is an error message.
rawResponse += responseAsJson.detail;
@@ -407,6 +467,10 @@
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
+ if (references != null) {
+ newResponseText.appendChild(references);
+ }
+
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
}
@@ -441,45 +505,7 @@
const rawReference = chunk.split("### compiled references:")[1];
const rawReferenceAsJson = JSON.parse(rawReference);
- references = document.createElement('div');
- references.classList.add("references");
-
- let referenceExpandButton = document.createElement('button');
- referenceExpandButton.classList.add("reference-expand-button");
-
- let referenceSection = document.createElement('div');
- referenceSection.classList.add("reference-section");
- referenceSection.classList.add("collapsed");
-
- let numReferences = 0;
-
- // If rawReferenceAsJson is a list, then count the length
- if (Array.isArray(rawReferenceAsJson)) {
- numReferences = rawReferenceAsJson.length;
-
- rawReferenceAsJson.forEach((reference, index) => {
- let polishedReference = generateReference(reference, index);
- referenceSection.appendChild(polishedReference);
- });
- } else {
- numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
- }
-
- references.appendChild(referenceExpandButton);
-
- referenceExpandButton.addEventListener('click', function() {
- if (referenceSection.classList.contains("collapsed")) {
- referenceSection.classList.remove("collapsed");
- referenceSection.classList.add("expanded");
- } else {
- referenceSection.classList.add("collapsed");
- referenceSection.classList.remove("expanded");
- }
- });
-
- let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
- referenceExpandButton.innerHTML = expandButtonText;
- references.appendChild(referenceSection);
+ references = createReferenceSection(rawReferenceAsJson);
readStream();
} else {
// Display response from Khoj
diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html
index e1b19183..ca6064e2 100644
--- a/src/khoj/interface/web/chat.html
+++ b/src/khoj/interface/web/chat.html
@@ -209,17 +209,17 @@ To get started, just start typing below. You can also type / to see a list of co
}
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
- if (intentType === "text-to-image") {
- let imageMarkdown = ``;
- const inferredQuery = inferredQueries?.[0];
- if (inferredQuery) {
- imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
+ if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
+ if (intentType === "text-to-image") {
+ let imageMarkdown = ``;
+ const inferredQuery = inferredQueries?.[0];
+ if (inferredQuery) {
+ imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
+ }
+ renderMessage(imageMarkdown, by, dt);
+ return;
}
- renderMessage(imageMarkdown, by, dt);
- return;
- }
- if (context == null && onlineContext == null) {
renderMessage(message, by, dt);
return;
}
@@ -273,6 +273,16 @@ To get started, just start typing below. You can also type / to see a list of co
references.appendChild(referenceSection);
+ if (intentType === "text-to-image") {
+ let imageMarkdown = ``;
+ const inferredQuery = inferredQueries?.[0];
+ if (inferredQuery) {
+ imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
+ }
+ renderMessage(imageMarkdown, by, dt, references);
+ return;
+ }
+
renderMessage(message, by, dt, references);
}
@@ -336,6 +346,46 @@ To get started, just start typing below. You can also type / to see a list of co
return element
}
+ function createReferenceSection(references) {
+ let referenceSection = document.createElement('div');
+ referenceSection.classList.add("reference-section");
+ referenceSection.classList.add("collapsed");
+
+ let numReferences = 0;
+
+ if (Array.isArray(references)) {
+ numReferences = references.length;
+
+ references.forEach((reference, index) => {
+ let polishedReference = generateReference(reference, index);
+ referenceSection.appendChild(polishedReference);
+ });
+ } else {
+ numReferences += processOnlineReferences(referenceSection, references);
+ }
+
+ let referenceExpandButton = document.createElement('button');
+ referenceExpandButton.classList.add("reference-expand-button");
+ referenceExpandButton.innerHTML = numReferences == 1 ? "1 reference" : `${numReferences} references`;
+
+ referenceExpandButton.addEventListener('click', function() {
+ if (referenceSection.classList.contains("collapsed")) {
+ referenceSection.classList.remove("collapsed");
+ referenceSection.classList.add("expanded");
+ } else {
+ referenceSection.classList.add("collapsed");
+ referenceSection.classList.remove("expanded");
+ }
+ });
+
+ let referencesDiv = document.createElement('div');
+ referencesDiv.classList.add("references");
+ referencesDiv.appendChild(referenceExpandButton);
+ referencesDiv.appendChild(referenceSection);
+
+ return referencesDiv;
+ }
+
async function chat() {
// Extract required fields for search from form
let query = document.getElementById("chat-input").value.trim();
@@ -390,6 +440,7 @@ To get started, just start typing below. You can also type / to see a list of co
// Call specified Khoj API
let response = await fetch(url);
let rawResponse = "";
+ let references = null;
const contentType = response.headers.get("content-type");
if (contentType === "application/json") {
@@ -404,6 +455,10 @@ To get started, just start typing below. You can also type / to see a list of co
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
}
+ if (responseAsJson.context && responseAsJson.context.length > 0) {
+ const rawReferenceAsJson = responseAsJson.context;
+ references = createReferenceSection(rawReferenceAsJson);
+ }
if (responseAsJson.detail) {
// If response has detail field, response is an error message.
rawResponse += responseAsJson.detail;
@@ -415,6 +470,10 @@ To get started, just start typing below. You can also type / to see a list of co
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
+ if (references != null) {
+ newResponseText.appendChild(references);
+ }
+
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
}
@@ -449,45 +508,7 @@ To get started, just start typing below. You can also type / to see a list of co
const rawReference = chunk.split("### compiled references:")[1];
const rawReferenceAsJson = JSON.parse(rawReference);
- references = document.createElement('div');
- references.classList.add("references");
-
- let referenceExpandButton = document.createElement('button');
- referenceExpandButton.classList.add("reference-expand-button");
-
- let referenceSection = document.createElement('div');
- referenceSection.classList.add("reference-section");
- referenceSection.classList.add("collapsed");
-
- let numReferences = 0;
-
- // If rawReferenceAsJson is a list, then count the length
- if (Array.isArray(rawReferenceAsJson)) {
- numReferences = rawReferenceAsJson.length;
-
- rawReferenceAsJson.forEach((reference, index) => {
- let polishedReference = generateReference(reference, index);
- referenceSection.appendChild(polishedReference);
- });
- } else {
- numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
- }
-
- references.appendChild(referenceExpandButton);
-
- referenceExpandButton.addEventListener('click', function() {
- if (referenceSection.classList.contains("collapsed")) {
- referenceSection.classList.remove("collapsed");
- referenceSection.classList.add("expanded");
- } else {
- referenceSection.classList.add("collapsed");
- referenceSection.classList.remove("expanded");
- }
- });
-
- let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
- referenceExpandButton.innerHTML = expandButtonText;
- references.appendChild(referenceSection);
+ references = createReferenceSection(rawReferenceAsJson);
readStream();
} else {
// Display response from Khoj
diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index 9ad85afb..d124e8f0 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -120,16 +120,23 @@ User's Notes:
image_generation_improve_prompt = PromptTemplate.from_template(
"""
-You are a talented creator. Generate a detailed prompt to generate an image based on the following description. Update the query below to improve the image generation. Add additional context to the query to improve the image generation. Make sure to retain any important information from the query. Use the conversation log to inform your response.
+You are a talented creator. Generate a detailed prompt to generate an image based on the following description. Update the query below to improve the image generation. Add additional context to the query to improve the image generation. Make sure to retain any important information originally from the query. You are provided with the following information to help you generate the prompt:
Today's Date: {current_date}
User's Location: {location}
+User's Notes:
+{references}
+
+Online References:
+{online_results}
+
Conversation Log:
{chat_history}
Query: {query}
+Remember, now you are generating a prompt to improve the image generation. Add additional context to the query to improve the image generation. Make sure to retain any important information originally from the query. Use the additional context from the user's notes, online references and conversation log to improve the image generation.
Improved Query:"""
)
@@ -294,6 +301,40 @@ Collate the relevant information from the website to answer the target query.
""".strip()
)
+pick_relevant_output_mode = PromptTemplate.from_template(
+ """
+You are Khoj, an excellent analyst for selecting the correct way to respond to a user's query. You have access to a limited set of modes for your response. You can only use one of these modes.
+
+{modes}
+
+Here are some example responses:
+
+Example:
+Chat History:
+User: I just visited Jerusalem for the first time. Pull up my notes from the trip.
+AI: You mention visiting Masjid Al-Aqsa and the Western Wall. You also mention trying the local cuisine and visiting the Dead Sea.
+
+Q: Draw a picture of my trip to Jerusalem.
+Khoj: image
+
+Example:
+Chat History:
+User: I'm having trouble deciding which laptop to get. I want something with at least 16 GB of RAM and a 1 TB SSD.
+AI: I can help with that. I see online that there is a new model of the Dell XPS 15 that meets your requirements.
+
+Q: What are the specs of the new Dell XPS 15?
+Khoj: default
+
+Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a string.
+
+Chat History:
+{chat_history}
+
+Q: {query}
+Khoj:
+""".strip()
+)
+
pick_relevant_information_collection_tools = PromptTemplate.from_template(
"""
You are Khoj, a smart and helpful personal assistant. You have access to a variety of data sources to help you answer the user's question. You can use the data sources listed below to collect more relevant information. You can use any combination of these data sources to answer the user's question. Tell me which data sources you would like to use to answer the user's question.
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index 809645a5..093e8cc9 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -22,6 +22,7 @@ from khoj.routers.helpers import (
ConversationCommandRateLimiter,
agenerate_chat_response,
aget_relevant_information_sources,
+ aget_relevant_output_modes,
get_conversation_command,
is_ready_to_chat,
text_to_image,
@@ -250,6 +251,9 @@ async def chat(
if conversation_commands == [ConversationCommand.Default]:
conversation_commands = await aget_relevant_information_sources(q, meta_log)
+ mode = await aget_relevant_output_modes(q, meta_log)
+ if mode not in conversation_commands:
+ conversation_commands.append(mode)
for cmd in conversation_commands:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
@@ -287,7 +291,8 @@ async def chat(
media_type="text/event-stream",
status_code=200,
)
- elif conversation_commands == [ConversationCommand.Image]:
+
+ if ConversationCommand.Image in conversation_commands:
update_telemetry_state(
request=request,
telemetry_type="api",
@@ -295,7 +300,9 @@ async def chat(
metadata={"conversation_command": conversation_commands[0].value},
**common.__dict__,
)
- image, status_code, improved_image_prompt = await text_to_image(q, meta_log, location_data=location)
+ image, status_code, improved_image_prompt = await text_to_image(
+ q, meta_log, location_data=location, references=compiled_references, online_results=online_results
+ )
if image is None:
content_obj = {"image": image, "intentType": "text-to-image", "detail": improved_image_prompt}
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
@@ -308,8 +315,10 @@ async def chat(
inferred_queries=[improved_image_prompt],
client_application=request.user.client_app,
conversation_id=conversation_id,
+ compiled_references=compiled_references,
+ online_results=online_results,
)
- content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt]} # type: ignore
+ content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt], "context": compiled_references, "online_results": online_results} # type: ignore
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
# Get the (streamed) chat response from the LLM of choice.
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 5dfd73cc..f24ff688 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -38,6 +38,7 @@ from khoj.utils.helpers import (
ConversationCommand,
is_none_or_empty,
log_telemetry,
+ mode_descriptions_for_llm,
tool_descriptions_for_llm,
)
from khoj.utils.rawconfig import LocationData
@@ -117,6 +118,9 @@ def construct_chat_history(conversation_history: dict, n: int = 4) -> str:
if chat["by"] == "khoj" and chat["intent"].get("type") == "remember":
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"Khoj: {chat['message']}\n"
+ elif chat["by"] == "khoj" and chat["intent"].get("type") == "text-to-image":
+ chat_history += f"User: {chat['intent']['query']}\n"
+ chat_history += f"Khoj: [generated image redacted for space]\n"
return chat_history
@@ -185,6 +189,42 @@ async def aget_relevant_information_sources(query: str, conversation_history: di
return [ConversationCommand.Default]
+async def aget_relevant_output_modes(query: str, conversation_history: dict):
+ """
+ Given a query, determine which of the available tools the agent should use in order to answer appropriately.
+ """
+
+ mode_options = dict()
+
+ for mode, description in mode_descriptions_for_llm.items():
+ mode_options[mode.value] = description
+
+ chat_history = construct_chat_history(conversation_history)
+
+ relevant_mode_prompt = prompts.pick_relevant_output_mode.format(
+ query=query,
+ modes=str(mode_options),
+ chat_history=chat_history,
+ )
+
+ response = await send_message_to_model_wrapper(relevant_mode_prompt)
+
+ try:
+ response = response.strip()
+
+ if is_none_or_empty(response):
+ return ConversationCommand.Default
+
+ if response in mode_options.keys():
+ # Check whether the tool exists as a valid ConversationCommand
+ return ConversationCommand(response)
+
+ return ConversationCommand.Default
+ except Exception as e:
+ logger.error(f"Invalid response for determining relevant mode: {response}")
+ return ConversationCommand.Default
+
+
async def generate_online_subqueries(q: str, conversation_history: dict, location_data: LocationData) -> List[str]:
"""
Generate subqueries from the given query
@@ -234,7 +274,13 @@ async def extract_relevant_info(q: str, corpus: dict) -> List[str]:
return response.strip()
-async def generate_better_image_prompt(q: str, conversation_history: str, location_data: LocationData) -> str:
+async def generate_better_image_prompt(
+ q: str,
+ conversation_history: str,
+ location_data: LocationData,
+ note_references: List[str],
+ online_results: Optional[dict] = None,
+) -> str:
"""
Generate a better image prompt from the given query
"""
@@ -242,11 +288,26 @@ async def generate_better_image_prompt(q: str, conversation_history: str, locati
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
+ location_prompt = prompts.user_location.format(location=location)
+
+ user_references = "\n\n".join([f"# {item}" for item in note_references])
+
+ simplified_online_results = {}
+
+ if online_results:
+ for result in online_results:
+ if online_results[result].get("answerBox"):
+ simplified_online_results[result] = online_results[result]["answerBox"]
+ elif online_results[result].get("extracted_content"):
+ simplified_online_results[result] = online_results[result]["extracted_content"]
+
image_prompt = prompts.image_generation_improve_prompt.format(
query=q,
chat_history=conversation_history,
- location=location,
+ location=location_prompt,
current_date=today_date,
+ references=user_references,
+ online_results=simplified_online_results,
)
response = await send_message_to_model_wrapper(image_prompt)
@@ -377,7 +438,11 @@ def generate_chat_response(
async def text_to_image(
- message: str, conversation_log: dict, location_data: LocationData
+ message: str,
+ conversation_log: dict,
+ location_data: LocationData,
+ references: List[str],
+ online_results: Dict[str, Any],
) -> Tuple[Optional[str], int, Optional[str]]:
status_code = 200
image = None
@@ -396,7 +461,13 @@ async def text_to_image(
if chat["by"] == "khoj" and chat["intent"].get("type") == "remember":
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: {chat['message']}\n"
- improved_image_prompt = await generate_better_image_prompt(message, chat_history, location_data=location_data)
+ improved_image_prompt = await generate_better_image_prompt(
+ message,
+ chat_history,
+ location_data=location_data,
+ note_references=references,
+ online_results=online_results,
+ )
try:
response = state.openai_client.images.generate(
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py
index ac8c482d..d2b64296 100644
--- a/src/khoj/utils/helpers.py
+++ b/src/khoj/utils/helpers.py
@@ -289,6 +289,11 @@ tool_descriptions_for_llm = {
ConversationCommand.Online: "Use this when you would like to look up information on the internet",
}
+mode_descriptions_for_llm = {
+ ConversationCommand.Image: "Use this if you think the user is requesting an image or visual response to their query.",
+ ConversationCommand.Default: "Use this if the other response modes don't seem to fit the query.",
+}
+
def generate_random_name():
# List of adjectives and nouns to choose from
diff --git a/tests/test_gpt4all_chat_actors.py b/tests/test_gpt4all_chat_actors.py
index 3b1d7249..5cce4fc5 100644
--- a/tests/test_gpt4all_chat_actors.py
+++ b/tests/test_gpt4all_chat_actors.py
@@ -24,6 +24,7 @@ from khoj.processor.conversation.offline.chat_model import (
)
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import message_to_log
+from khoj.routers.helpers import aget_relevant_output_modes
MODEL_NAME = "mistral-7b-instruct-v0.1.Q4_0.gguf"
@@ -497,6 +498,34 @@ def test_filter_questions():
assert filtered_questions[0] == "Who is on the basketball team?"
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.anyio
+@pytest.mark.django_db(transaction=True)
+async def test_use_default_response_mode(client_offline_chat):
+ # Arrange
+ user_query = "What's the latest in the Israel/Palestine conflict?"
+
+ # Act
+ mode = await aget_relevant_output_modes(user_query, {})
+
+ # Assert
+ assert mode.value == "default"
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.anyio
+@pytest.mark.django_db(transaction=True)
+async def test_use_image_response_mode(client_offline_chat):
+ # Arrange
+ user_query = "Paint a picture of the scenery in Timbuktu in the winter"
+
+ # Act
+ mode = await aget_relevant_output_modes(user_query, {})
+
+ # Assert
+ assert mode.value == "image"
+
+
# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):
diff --git a/tests/test_openai_chat_actors.py b/tests/test_openai_chat_actors.py
index 201e3fef..183ffb24 100644
--- a/tests/test_openai_chat_actors.py
+++ b/tests/test_openai_chat_actors.py
@@ -7,6 +7,7 @@ from freezegun import freeze_time
from khoj.processor.conversation.openai.gpt import converse, extract_questions
from khoj.processor.conversation.utils import message_to_log
+from khoj.routers.helpers import aget_relevant_output_modes
# Initialize variables for tests
api_key = os.getenv("OPENAI_API_KEY")
@@ -434,6 +435,34 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
)
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.anyio
+@pytest.mark.django_db(transaction=True)
+async def test_use_default_response_mode(chat_client):
+ # Arrange
+ user_query = "What's the latest in the Israel/Palestine conflict?"
+
+ # Act
+ mode = await aget_relevant_output_modes(user_query, {})
+
+ # Assert
+ assert mode.value == "default"
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.anyio
+@pytest.mark.django_db(transaction=True)
+async def test_use_image_response_mode(chat_client):
+ # Arrange
+ user_query = "Paint a picture of the scenery in Timbuktu in the winter"
+
+ # Act
+ mode = await aget_relevant_output_modes(user_query, {})
+
+ # Assert
+ assert mode.value == "image"
+
+
# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):
diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py
index 33547b48..105ec033 100644
--- a/tests/test_openai_chat_director.py
+++ b/tests/test_openai_chat_director.py
@@ -8,7 +8,10 @@ from freezegun import freeze_time
from khoj.database.models import KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import message_to_log
-from khoj.routers.helpers import aget_relevant_information_sources
+from khoj.routers.helpers import (
+ aget_relevant_information_sources,
+ aget_relevant_output_modes,
+)
from tests.helpers import ConversationFactory
# Initialize variables for tests