Include additional user context in the image generation flow (#660)

* Make major improvements to the image generation flow

- Include user context from online references and personal notes for generating images
- Dynamically select the modality that the LLM should respond with
- Retun the inferred context in the query response for the dekstop, web chat views to read

* Add unit tests for retrieving response modes via LLM

* Move output mode unit tests to the actor suite, rather than director

* Only show the references button if there is at least one available

* Rename aget_relevant_modes to aget_relevant_output_modes

* Use a shared method for generating reference sections, simplify some of the prompting logic

* Make out of space errors in the desktop client more obvious
This commit is contained in:
sabaimran
2024-03-06 13:48:41 +05:30
committed by GitHub
parent 3cbc5b0d52
commit e323a6d69b
9 changed files with 336 additions and 102 deletions

View File

@@ -24,6 +24,7 @@ from khoj.processor.conversation.offline.chat_model import (
)
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import aget_relevant_output_modes
MODEL_NAME = "mistral-7b-instruct-v0.1.Q4_0.gguf"
@@ -497,6 +498,34 @@ def test_filter_questions():
assert filtered_questions[0] == "Who is on the basketball team?"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_use_default_response_mode(client_offline_chat):
# Arrange
user_query = "What's the latest in the Israel/Palestine conflict?"
# Act
mode = await aget_relevant_output_modes(user_query, {})
# Assert
assert mode.value == "default"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_use_image_response_mode(client_offline_chat):
# Arrange
user_query = "Paint a picture of the scenery in Timbuktu in the winter"
# Act
mode = await aget_relevant_output_modes(user_query, {})
# Assert
assert mode.value == "image"
# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):

View File

@@ -7,6 +7,7 @@ from freezegun import freeze_time
from khoj.processor.conversation.openai.gpt import converse, extract_questions
from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import aget_relevant_output_modes
# Initialize variables for tests
api_key = os.getenv("OPENAI_API_KEY")
@@ -434,6 +435,34 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_use_default_response_mode(chat_client):
# Arrange
user_query = "What's the latest in the Israel/Palestine conflict?"
# Act
mode = await aget_relevant_output_modes(user_query, {})
# Assert
assert mode.value == "default"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_use_image_response_mode(chat_client):
# Arrange
user_query = "Paint a picture of the scenery in Timbuktu in the winter"
# Act
mode = await aget_relevant_output_modes(user_query, {})
# Assert
assert mode.value == "image"
# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):

View File

@@ -8,7 +8,10 @@ from freezegun import freeze_time
from khoj.database.models import KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import aget_relevant_information_sources
from khoj.routers.helpers import (
aget_relevant_information_sources,
aget_relevant_output_modes,
)
from tests.helpers import ConversationFactory
# Initialize variables for tests