mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Improve Chat Session UX, Fix Login, Chat Message Truncation (#677)
### Improve
- Improve the delete and rename chat session UX in the Desktop and Web apps
- Get conversation by title when requested via the chat API

### Fix
- Allow an unset locale for users authenticating with Google
- Handle truncation when there is only a single long non-system chat message
- Fix setting the chat session title from the Desktop app
- Only create a new chat on get if a specific chat id or slug isn't requested
This commit is contained in:
@@ -794,7 +794,6 @@
|
||||
chatBody.dataset.conversationId = "";
|
||||
chatBody.dataset.conversationTitle = "";
|
||||
loadChat();
|
||||
flashStatusInChatInput("🗑 Cleared previous conversation history");
|
||||
})
|
||||
.catch(err => {
|
||||
flashStatusInChatInput("⛔️ Failed to clear conversation history");
|
||||
@@ -856,28 +855,6 @@
|
||||
let conversationMenu = document.createElement('div');
|
||||
conversationMenu.classList.add("conversation-menu");
|
||||
|
||||
let deleteButton = document.createElement('button');
|
||||
deleteButton.innerHTML = "Delete";
|
||||
deleteButton.classList.add("delete-conversation-button");
|
||||
deleteButton.classList.add("three-dot-menu-button-item");
|
||||
deleteButton.addEventListener('click', function() {
|
||||
let deleteURL = `/api/chat/history?client=web&conversation_id=${incomingConversationId}`;
|
||||
fetch(`${hostURL}${deleteURL}` , { method: "DELETE", headers })
|
||||
.then(response => response.ok ? response.json() : Promise.reject(response))
|
||||
.then(data => {
|
||||
let chatBody = document.getElementById("chat-body");
|
||||
chatBody.innerHTML = "";
|
||||
chatBody.dataset.conversationId = "";
|
||||
chatBody.dataset.conversationTitle = "";
|
||||
loadChat();
|
||||
})
|
||||
.catch(err => {
|
||||
return;
|
||||
});
|
||||
});
|
||||
conversationMenu.appendChild(deleteButton);
|
||||
threeDotMenu.appendChild(conversationMenu);
|
||||
|
||||
let editTitleButton = document.createElement('button');
|
||||
editTitleButton.innerHTML = "Rename";
|
||||
editTitleButton.classList.add("edit-title-button");
|
||||
@@ -903,12 +880,13 @@
|
||||
|
||||
conversationTitleInput.addEventListener('click', function(event) {
|
||||
event.stopPropagation();
|
||||
});
|
||||
conversationTitleInput.addEventListener('keydown', function(event) {
|
||||
if (event.key === "Enter") {
|
||||
event.preventDefault();
|
||||
conversationTitleInputButton.click();
|
||||
}
|
||||
});
|
||||
|
||||
conversationTitleInputBox.appendChild(conversationTitleInput);
|
||||
let conversationTitleInputButton = document.createElement('button');
|
||||
conversationTitleInputButton.innerHTML = "Save";
|
||||
@@ -918,7 +896,7 @@
|
||||
let newTitle = conversationTitleInput.value;
|
||||
if (newTitle != null) {
|
||||
let editURL = `/api/chat/title?client=web&conversation_id=${incomingConversationId}&title=${newTitle}`;
|
||||
fetch(`${hostURL}${editURL}` , { method: "PATCH" })
|
||||
fetch(`${hostURL}${editURL}` , { method: "PATCH", headers })
|
||||
.then(response => response.ok ? response.json() : Promise.reject(response))
|
||||
.then(data => {
|
||||
conversationButton.textContent = newTitle;
|
||||
@@ -931,8 +909,35 @@
|
||||
conversationTitleInputBox.appendChild(conversationTitleInputButton);
|
||||
conversationMenu.appendChild(conversationTitleInputBox);
|
||||
});
|
||||
|
||||
conversationMenu.appendChild(editTitleButton);
|
||||
threeDotMenu.appendChild(conversationMenu);
|
||||
|
||||
let deleteButton = document.createElement('button');
|
||||
deleteButton.innerHTML = "Delete";
|
||||
deleteButton.classList.add("delete-conversation-button");
|
||||
deleteButton.classList.add("three-dot-menu-button-item");
|
||||
deleteButton.addEventListener('click', function() {
|
||||
// Ask for confirmation before deleting chat session
|
||||
let confirmation = confirm('Are you sure you want to delete this chat session?');
|
||||
if (!confirmation) return;
|
||||
let deleteURL = `/api/chat/history?client=web&conversation_id=${incomingConversationId}`;
|
||||
fetch(`${hostURL}${deleteURL}` , { method: "DELETE", headers })
|
||||
.then(response => response.ok ? response.json() : Promise.reject(response))
|
||||
.then(data => {
|
||||
let chatBody = document.getElementById("chat-body");
|
||||
chatBody.innerHTML = "";
|
||||
chatBody.dataset.conversationId = "";
|
||||
chatBody.dataset.conversationTitle = "";
|
||||
loadChat();
|
||||
})
|
||||
.catch(err => {
|
||||
return;
|
||||
});
|
||||
});
|
||||
|
||||
conversationMenu.appendChild(deleteButton);
|
||||
threeDotMenu.appendChild(conversationMenu);
|
||||
});
|
||||
threeDotMenu.appendChild(threeDotMenuButton);
|
||||
conversationButton.appendChild(threeDotMenu);
|
||||
|
||||
@@ -436,19 +436,16 @@ class ConversationAdapters:
|
||||
|
||||
@staticmethod
|
||||
async def aget_conversation_by_user(
|
||||
user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None, slug: str = None
|
||||
user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None, title: str = None
|
||||
):
|
||||
if conversation_id:
|
||||
conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
|
||||
elif slug:
|
||||
conversation = Conversation.objects.filter(user=user, client=client_application, slug=slug)
|
||||
return await Conversation.objects.filter(user=user, client=client_application, id=conversation_id).afirst()
|
||||
elif title:
|
||||
return await Conversation.objects.filter(user=user, client=client_application, title=title).afirst()
|
||||
else:
|
||||
conversation = Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at")
|
||||
|
||||
if await conversation.aexists():
|
||||
return await conversation.afirst()
|
||||
|
||||
return await Conversation.objects.acreate(user=user, client=client_application, slug=slug)
|
||||
return await (
|
||||
Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
|
||||
) or Conversation.objects.acreate(user=user, client=client_application)
|
||||
|
||||
@staticmethod
|
||||
async def adelete_conversation_by_user(
|
||||
|
||||
17
src/khoj/database/migrations/0031_alter_googleuser_locale.py
Normal file
17
src/khoj/database/migrations/0031_alter_googleuser_locale.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# Generated by Django 4.2.10 on 2024-03-15 10:04
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0030_conversation_slug_and_title"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="googleuser",
|
||||
name="locale",
|
||||
field=models.CharField(blank=True, default=None, max_length=200, null=True),
|
||||
),
|
||||
]
|
||||
@@ -43,7 +43,7 @@ class GoogleUser(models.Model):
|
||||
given_name = models.CharField(max_length=200, null=True, default=None, blank=True)
|
||||
family_name = models.CharField(max_length=200, null=True, default=None, blank=True)
|
||||
picture = models.CharField(max_length=200, null=True, default=None)
|
||||
locale = models.CharField(max_length=200)
|
||||
locale = models.CharField(max_length=200, null=True, default=None, blank=True)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
@@ -1008,6 +1008,8 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||
|
||||
conversationTitleInput.addEventListener('click', function(event) {
|
||||
event.stopPropagation();
|
||||
});
|
||||
conversationTitleInput.addEventListener('keydown', function(event) {
|
||||
if (event.key === "Enter") {
|
||||
event.preventDefault();
|
||||
conversationTitleInputButton.click();
|
||||
@@ -1044,6 +1046,9 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||
deleteButton.classList.add("delete-conversation-button");
|
||||
deleteButton.classList.add("three-dot-menu-button-item");
|
||||
deleteButton.addEventListener('click', function() {
|
||||
// Ask for confirmation before deleting chat session
|
||||
let confirmation = confirm('Are you sure you want to delete this chat session?');
|
||||
if (!confirmation) return;
|
||||
let deleteURL = `/api/chat/history?client=web&conversation_id=${incomingConversationId}`;
|
||||
fetch(deleteURL , { method: "DELETE" })
|
||||
.then(response => response.ok ? response.json() : Promise.reject(response))
|
||||
@@ -1053,7 +1058,6 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||
chatBody.dataset.conversationId = "";
|
||||
chatBody.dataset.conversationTitle = "";
|
||||
loadChat();
|
||||
flashStatusInChatInput("🗑 Cleared previous conversation history");
|
||||
})
|
||||
.catch(err => {
|
||||
flashStatusInChatInput("⛔️ Failed to clear conversation history");
|
||||
|
||||
@@ -199,19 +199,26 @@ def truncate_messages(
|
||||
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
|
||||
)
|
||||
|
||||
system_message = messages.pop()
|
||||
assert type(system_message.content) == str
|
||||
system_message_tokens = len(encoder.encode(system_message.content))
|
||||
# Extract system message from messages
|
||||
system_message = None
|
||||
for idx, message in enumerate(messages):
|
||||
if message.role == "system":
|
||||
system_message = messages.pop(idx)
|
||||
break
|
||||
|
||||
system_message_tokens = (
|
||||
len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
|
||||
)
|
||||
|
||||
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
|
||||
|
||||
# Drop older messages until under max supported prompt size by model
|
||||
while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
|
||||
messages.pop()
|
||||
assert type(system_message.content) == str
|
||||
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
|
||||
|
||||
# Truncate current message if still over max supported prompt size by model
|
||||
if (tokens + system_message_tokens) > max_prompt_size:
|
||||
assert type(system_message.content) == str
|
||||
current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else ""
|
||||
original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
|
||||
original_question = f"\n{original_question}"
|
||||
@@ -223,7 +230,7 @@ def truncate_messages(
|
||||
)
|
||||
messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
|
||||
|
||||
return messages + [system_message]
|
||||
return messages + [system_message] if system_message else messages
|
||||
|
||||
|
||||
def reciprocal_conversation_to_chatml(message_pair):
|
||||
|
||||
@@ -224,7 +224,7 @@ async def chat(
|
||||
n: Optional[int] = 5,
|
||||
d: Optional[float] = 0.18,
|
||||
stream: Optional[bool] = False,
|
||||
slug: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
conversation_id: Optional[int] = None,
|
||||
city: Optional[str] = None,
|
||||
region: Optional[str] = None,
|
||||
@@ -250,9 +250,15 @@ async def chat(
|
||||
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
|
||||
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
|
||||
|
||||
meta_log = (
|
||||
await ConversationAdapters.aget_conversation_by_user(user, request.user.client_app, conversation_id, slug)
|
||||
).conversation_log
|
||||
conversation = await ConversationAdapters.aget_conversation_by_user(
|
||||
user, request.user.client_app, conversation_id, title
|
||||
)
|
||||
if not conversation:
|
||||
return Response(
|
||||
content=f"No conversation found with requested id, title", media_type="text/plain", status_code=400
|
||||
)
|
||||
else:
|
||||
meta_log = conversation.conversation_log
|
||||
|
||||
if conversation_commands == [ConversationCommand.Default]:
|
||||
conversation_commands = await aget_relevant_information_sources(q, meta_log)
|
||||
|
||||
@@ -19,49 +19,80 @@ class TestTruncateMessage:
|
||||
encoder = tiktoken.encoding_for_model(model_name)
|
||||
|
||||
def test_truncate_message_all_small(self):
|
||||
chat_messages = ChatMessageFactory.build_batch(500)
|
||||
# Arrange
|
||||
chat_history = ChatMessageFactory.build_batch(500)
|
||||
|
||||
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
|
||||
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
|
||||
# Act
|
||||
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
|
||||
tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
|
||||
|
||||
# Assert
|
||||
# The original object has been modified. Verify certain properties
|
||||
assert len(chat_messages) < 500
|
||||
assert len(chat_messages) > 1
|
||||
assert len(chat_history) < 500
|
||||
assert len(chat_history) > 1
|
||||
assert tokens <= self.max_prompt_size
|
||||
|
||||
def test_truncate_message_first_large(self):
|
||||
chat_messages = ChatMessageFactory.build_batch(25)
|
||||
# Arrange
|
||||
chat_history = ChatMessageFactory.build_batch(25)
|
||||
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
|
||||
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
|
||||
copy_big_chat_message = big_chat_message.copy()
|
||||
chat_messages.insert(0, big_chat_message)
|
||||
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
|
||||
chat_history.insert(0, big_chat_message)
|
||||
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
|
||||
|
||||
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
|
||||
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
|
||||
# Act
|
||||
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
|
||||
tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
|
||||
|
||||
# Assert
|
||||
# The original object has been modified. Verify certain properties
|
||||
assert len(chat_messages) == 1
|
||||
assert prompt[0] != copy_big_chat_message
|
||||
assert len(chat_history) == 1
|
||||
assert truncated_chat_history[0] != copy_big_chat_message
|
||||
assert tokens <= self.max_prompt_size
|
||||
|
||||
def test_truncate_message_last_large(self):
|
||||
chat_messages = ChatMessageFactory.build_batch(25)
|
||||
# Arrange
|
||||
chat_history = ChatMessageFactory.build_batch(25)
|
||||
chat_history[0].role = "system" # Mark the first message as system message
|
||||
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
|
||||
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
|
||||
copy_big_chat_message = big_chat_message.copy()
|
||||
|
||||
chat_messages.insert(0, big_chat_message)
|
||||
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
|
||||
chat_history.insert(0, big_chat_message)
|
||||
initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
|
||||
|
||||
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
|
||||
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
|
||||
# Act
|
||||
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
|
||||
final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
|
||||
|
||||
# Assert
|
||||
# The original object has been modified. Verify certain properties.
|
||||
assert len(prompt) == (
|
||||
len(chat_messages) + 1
|
||||
assert len(truncated_chat_history) == (
|
||||
len(chat_history) + 1
|
||||
) # Because the system_prompt is popped off from the chat_messages list
|
||||
assert len(prompt) < 26
|
||||
assert len(prompt) > 1
|
||||
assert prompt[0] != copy_big_chat_message
|
||||
assert tokens <= self.max_prompt_size
|
||||
assert len(truncated_chat_history) < 26
|
||||
assert len(truncated_chat_history) > 1
|
||||
assert truncated_chat_history[0] != copy_big_chat_message
|
||||
assert initial_tokens > self.max_prompt_size
|
||||
assert final_tokens <= self.max_prompt_size
|
||||
|
||||
def test_truncate_single_large_non_system_message(self):
|
||||
# Arrange
|
||||
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
|
||||
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
|
||||
big_chat_message.role = "user"
|
||||
copy_big_chat_message = big_chat_message.copy()
|
||||
chat_messages = [big_chat_message]
|
||||
initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
|
||||
|
||||
# Act
|
||||
truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
|
||||
final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
|
||||
|
||||
# Assert
|
||||
# The original object has been modified. Verify certain properties
|
||||
assert initial_tokens > self.max_prompt_size
|
||||
assert final_tokens <= self.max_prompt_size
|
||||
assert len(chat_messages) == 1
|
||||
assert truncated_chat_history[0] != copy_big_chat_message
|
||||
|
||||
Reference in New Issue
Block a user