mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Include additional user context in the image generation flow (#660)
* Make major improvements to the image generation flow - Include user context from online references and personal notes for generating images - Dynamically select the modality that the LLM should respond with - Retun the inferred context in the query response for the dekstop, web chat views to read * Add unit tests for retrieving response modes via LLM * Move output mode unit tests to the actor suite, rather than director * Only show the references button if there is at least one available * Rename aget_relevant_modes to aget_relevant_output_modes * Use a shared method for generating reference sections, simplify some of the prompting logic * Make out of space errors in the desktop client more obvious
This commit is contained in:
@@ -24,6 +24,7 @@ from khoj.processor.conversation.offline.chat_model import (
|
||||
)
|
||||
from khoj.processor.conversation.offline.utils import download_model
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.routers.helpers import aget_relevant_output_modes
|
||||
|
||||
MODEL_NAME = "mistral-7b-instruct-v0.1.Q4_0.gguf"
|
||||
|
||||
@@ -497,6 +498,34 @@ def test_filter_questions():
|
||||
assert filtered_questions[0] == "Who is on the basketball team?"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_use_default_response_mode(client_offline_chat):
|
||||
# Arrange
|
||||
user_query = "What's the latest in the Israel/Palestine conflict?"
|
||||
|
||||
# Act
|
||||
mode = await aget_relevant_output_modes(user_query, {})
|
||||
|
||||
# Assert
|
||||
assert mode.value == "default"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_use_image_response_mode(client_offline_chat):
|
||||
# Arrange
|
||||
user_query = "Paint a picture of the scenery in Timbuktu in the winter"
|
||||
|
||||
# Act
|
||||
mode = await aget_relevant_output_modes(user_query, {})
|
||||
|
||||
# Assert
|
||||
assert mode.value == "image"
|
||||
|
||||
|
||||
# Helpers
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def populate_chat_history(message_list):
|
||||
|
||||
@@ -7,6 +7,7 @@ from freezegun import freeze_time
|
||||
|
||||
from khoj.processor.conversation.openai.gpt import converse, extract_questions
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.routers.helpers import aget_relevant_output_modes
|
||||
|
||||
# Initialize variables for tests
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
@@ -434,6 +435,34 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_use_default_response_mode(chat_client):
|
||||
# Arrange
|
||||
user_query = "What's the latest in the Israel/Palestine conflict?"
|
||||
|
||||
# Act
|
||||
mode = await aget_relevant_output_modes(user_query, {})
|
||||
|
||||
# Assert
|
||||
assert mode.value == "default"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_use_image_response_mode(chat_client):
|
||||
# Arrange
|
||||
user_query = "Paint a picture of the scenery in Timbuktu in the winter"
|
||||
|
||||
# Act
|
||||
mode = await aget_relevant_output_modes(user_query, {})
|
||||
|
||||
# Assert
|
||||
assert mode.value == "image"
|
||||
|
||||
|
||||
# Helpers
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def populate_chat_history(message_list):
|
||||
|
||||
@@ -8,7 +8,10 @@ from freezegun import freeze_time
|
||||
from khoj.database.models import KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.routers.helpers import aget_relevant_information_sources
|
||||
from khoj.routers.helpers import (
|
||||
aget_relevant_information_sources,
|
||||
aget_relevant_output_modes,
|
||||
)
|
||||
from tests.helpers import ConversationFactory
|
||||
|
||||
# Initialize variables for tests
|
||||
|
||||
Reference in New Issue
Block a user