diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html
index cc081da7..94cde782 100644
--- a/src/interface/desktop/chat.html
+++ b/src/interface/desktop/chat.html
@@ -794,7 +794,6 @@
chatBody.dataset.conversationId = "";
chatBody.dataset.conversationTitle = "";
loadChat();
- flashStatusInChatInput("🗑 Cleared previous conversation history");
})
.catch(err => {
flashStatusInChatInput("⛔️ Failed to clear conversation history");
@@ -856,28 +855,6 @@
let conversationMenu = document.createElement('div');
conversationMenu.classList.add("conversation-menu");
- let deleteButton = document.createElement('button');
- deleteButton.innerHTML = "Delete";
- deleteButton.classList.add("delete-conversation-button");
- deleteButton.classList.add("three-dot-menu-button-item");
- deleteButton.addEventListener('click', function() {
- let deleteURL = `/api/chat/history?client=web&conversation_id=${incomingConversationId}`;
- fetch(`${hostURL}${deleteURL}` , { method: "DELETE", headers })
- .then(response => response.ok ? response.json() : Promise.reject(response))
- .then(data => {
- let chatBody = document.getElementById("chat-body");
- chatBody.innerHTML = "";
- chatBody.dataset.conversationId = "";
- chatBody.dataset.conversationTitle = "";
- loadChat();
- })
- .catch(err => {
- return;
- });
- });
- conversationMenu.appendChild(deleteButton);
- threeDotMenu.appendChild(conversationMenu);
-
let editTitleButton = document.createElement('button');
editTitleButton.innerHTML = "Rename";
editTitleButton.classList.add("edit-title-button");
@@ -903,12 +880,13 @@
conversationTitleInput.addEventListener('click', function(event) {
event.stopPropagation();
+ });
+ conversationTitleInput.addEventListener('keydown', function(event) {
if (event.key === "Enter") {
event.preventDefault();
conversationTitleInputButton.click();
}
});
-
conversationTitleInputBox.appendChild(conversationTitleInput);
let conversationTitleInputButton = document.createElement('button');
conversationTitleInputButton.innerHTML = "Save";
@@ -918,7 +896,7 @@
let newTitle = conversationTitleInput.value;
if (newTitle != null) {
let editURL = `/api/chat/title?client=web&conversation_id=${incomingConversationId}&title=${newTitle}`;
- fetch(`${hostURL}${editURL}` , { method: "PATCH" })
+ fetch(`${hostURL}${editURL}` , { method: "PATCH", headers })
.then(response => response.ok ? response.json() : Promise.reject(response))
.then(data => {
conversationButton.textContent = newTitle;
@@ -931,8 +909,35 @@
conversationTitleInputBox.appendChild(conversationTitleInputButton);
conversationMenu.appendChild(conversationTitleInputBox);
});
+
conversationMenu.appendChild(editTitleButton);
threeDotMenu.appendChild(conversationMenu);
+
+ let deleteButton = document.createElement('button');
+ deleteButton.innerHTML = "Delete";
+ deleteButton.classList.add("delete-conversation-button");
+ deleteButton.classList.add("three-dot-menu-button-item");
+ deleteButton.addEventListener('click', function() {
+ // Ask for confirmation before deleting chat session
+ let confirmation = confirm('Are you sure you want to delete this chat session?');
+ if (!confirmation) return;
+ let deleteURL = `/api/chat/history?client=web&conversation_id=${incomingConversationId}`;
+ fetch(`${hostURL}${deleteURL}` , { method: "DELETE", headers })
+ .then(response => response.ok ? response.json() : Promise.reject(response))
+ .then(data => {
+ let chatBody = document.getElementById("chat-body");
+ chatBody.innerHTML = "";
+ chatBody.dataset.conversationId = "";
+ chatBody.dataset.conversationTitle = "";
+ loadChat();
+ })
+ .catch(err => {
+ return;
+ });
+ });
+
+ conversationMenu.appendChild(deleteButton);
+ threeDotMenu.appendChild(conversationMenu);
});
threeDotMenu.appendChild(threeDotMenuButton);
conversationButton.appendChild(threeDotMenu);
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index 9d74df05..e6927d27 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -436,19 +436,16 @@ class ConversationAdapters:
@staticmethod
async def aget_conversation_by_user(
- user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None, slug: str = None
+ user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None, title: str = None
):
if conversation_id:
- conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
- elif slug:
- conversation = Conversation.objects.filter(user=user, client=client_application, slug=slug)
+ return await Conversation.objects.filter(user=user, client=client_application, id=conversation_id).afirst()
+ elif title:
+ return await Conversation.objects.filter(user=user, client=client_application, title=title).afirst()
else:
- conversation = Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at")
-
- if await conversation.aexists():
- return await conversation.afirst()
-
- return await Conversation.objects.acreate(user=user, client=client_application, slug=slug)
+ return await (
+ Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
+        ) or await Conversation.objects.acreate(user=user, client=client_application)
@staticmethod
async def adelete_conversation_by_user(
diff --git a/src/khoj/database/migrations/0031_alter_googleuser_locale.py b/src/khoj/database/migrations/0031_alter_googleuser_locale.py
new file mode 100644
index 00000000..99c4573a
--- /dev/null
+++ b/src/khoj/database/migrations/0031_alter_googleuser_locale.py
@@ -0,0 +1,17 @@
+# Generated by Django 4.2.10 on 2024-03-15 10:04
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("database", "0030_conversation_slug_and_title"),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name="googleuser",
+ name="locale",
+ field=models.CharField(blank=True, default=None, max_length=200, null=True),
+ ),
+ ]
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index b74aaa11..3f8f50b4 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -43,7 +43,7 @@ class GoogleUser(models.Model):
given_name = models.CharField(max_length=200, null=True, default=None, blank=True)
family_name = models.CharField(max_length=200, null=True, default=None, blank=True)
picture = models.CharField(max_length=200, null=True, default=None)
- locale = models.CharField(max_length=200)
+ locale = models.CharField(max_length=200, null=True, default=None, blank=True)
def __str__(self):
return self.name
diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html
index c251bff2..35047c31 100644
--- a/src/khoj/interface/web/chat.html
+++ b/src/khoj/interface/web/chat.html
@@ -1008,6 +1008,8 @@ To get started, just start typing below. You can also type / to see a list of co
conversationTitleInput.addEventListener('click', function(event) {
event.stopPropagation();
+ });
+ conversationTitleInput.addEventListener('keydown', function(event) {
if (event.key === "Enter") {
event.preventDefault();
conversationTitleInputButton.click();
@@ -1044,6 +1046,9 @@ To get started, just start typing below. You can also type / to see a list of co
deleteButton.classList.add("delete-conversation-button");
deleteButton.classList.add("three-dot-menu-button-item");
deleteButton.addEventListener('click', function() {
+ // Ask for confirmation before deleting chat session
+ let confirmation = confirm('Are you sure you want to delete this chat session?');
+ if (!confirmation) return;
let deleteURL = `/api/chat/history?client=web&conversation_id=${incomingConversationId}`;
fetch(deleteURL , { method: "DELETE" })
.then(response => response.ok ? response.json() : Promise.reject(response))
@@ -1053,7 +1058,6 @@ To get started, just start typing below. You can also type / to see a list of co
chatBody.dataset.conversationId = "";
chatBody.dataset.conversationTitle = "";
loadChat();
- flashStatusInChatInput("🗑 Cleared previous conversation history");
})
.catch(err => {
flashStatusInChatInput("⛔️ Failed to clear conversation history");
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index b384ad7a..15a4970e 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -199,19 +199,26 @@ def truncate_messages(
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
)
- system_message = messages.pop()
- assert type(system_message.content) == str
- system_message_tokens = len(encoder.encode(system_message.content))
+ # Extract system message from messages
+ system_message = None
+ for idx, message in enumerate(messages):
+ if message.role == "system":
+ system_message = messages.pop(idx)
+ break
+
+ system_message_tokens = (
+ len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
+ )
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
+
+ # Drop older messages until under max supported prompt size by model
while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
messages.pop()
- assert type(system_message.content) == str
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
# Truncate current message if still over max supported prompt size by model
if (tokens + system_message_tokens) > max_prompt_size:
- assert type(system_message.content) == str
current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else ""
original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
original_question = f"\n{original_question}"
@@ -223,7 +230,7 @@ def truncate_messages(
)
messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
- return messages + [system_message]
+    return (messages + [system_message]) if system_message else messages
def reciprocal_conversation_to_chatml(message_pair):
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index 7a99869c..7cd04c98 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -224,7 +224,7 @@ async def chat(
n: Optional[int] = 5,
d: Optional[float] = 0.18,
stream: Optional[bool] = False,
- slug: Optional[str] = None,
+ title: Optional[str] = None,
conversation_id: Optional[int] = None,
city: Optional[str] = None,
region: Optional[str] = None,
@@ -250,9 +250,15 @@ async def chat(
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
- meta_log = (
- await ConversationAdapters.aget_conversation_by_user(user, request.user.client_app, conversation_id, slug)
- ).conversation_log
+ conversation = await ConversationAdapters.aget_conversation_by_user(
+ user, request.user.client_app, conversation_id, title
+ )
+ if not conversation:
+ return Response(
+            content="No conversation found with requested id or title", media_type="text/plain", status_code=400
+ )
+ else:
+ meta_log = conversation.conversation_log
if conversation_commands == [ConversationCommand.Default]:
conversation_commands = await aget_relevant_information_sources(q, meta_log)
diff --git a/tests/test_conversation_utils.py b/tests/test_conversation_utils.py
index 52db0002..bc8c5315 100644
--- a/tests/test_conversation_utils.py
+++ b/tests/test_conversation_utils.py
@@ -19,49 +19,80 @@ class TestTruncateMessage:
encoder = tiktoken.encoding_for_model(model_name)
def test_truncate_message_all_small(self):
- chat_messages = ChatMessageFactory.build_batch(500)
+ # Arrange
+ chat_history = ChatMessageFactory.build_batch(500)
- prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
- tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
+ # Act
+ truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+ tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+ # Assert
# The original object has been modified. Verify certain properties
- assert len(chat_messages) < 500
- assert len(chat_messages) > 1
+ assert len(chat_history) < 500
+ assert len(chat_history) > 1
assert tokens <= self.max_prompt_size
def test_truncate_message_first_large(self):
- chat_messages = ChatMessageFactory.build_batch(25)
+ # Arrange
+ chat_history = ChatMessageFactory.build_batch(25)
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
copy_big_chat_message = big_chat_message.copy()
- chat_messages.insert(0, big_chat_message)
- tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+ chat_history.insert(0, big_chat_message)
+ tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
- prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
- tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
+ # Act
+ truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+ tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+ # Assert
# The original object has been modified. Verify certain properties
- assert len(chat_messages) == 1
- assert prompt[0] != copy_big_chat_message
+ assert len(chat_history) == 1
+ assert truncated_chat_history[0] != copy_big_chat_message
assert tokens <= self.max_prompt_size
def test_truncate_message_last_large(self):
- chat_messages = ChatMessageFactory.build_batch(25)
+ # Arrange
+ chat_history = ChatMessageFactory.build_batch(25)
+ chat_history[0].role = "system" # Mark the first message as system message
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
copy_big_chat_message = big_chat_message.copy()
- chat_messages.insert(0, big_chat_message)
- tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+ chat_history.insert(0, big_chat_message)
+ initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
- prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
- tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
+ # Act
+ truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+ final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+ # Assert
# The original object has been modified. Verify certain properties.
- assert len(prompt) == (
- len(chat_messages) + 1
+ assert len(truncated_chat_history) == (
+ len(chat_history) + 1
) # Because the system_prompt is popped off from the chat_messages lsit
- assert len(prompt) < 26
- assert len(prompt) > 1
- assert prompt[0] != copy_big_chat_message
- assert tokens <= self.max_prompt_size
+ assert len(truncated_chat_history) < 26
+ assert len(truncated_chat_history) > 1
+ assert truncated_chat_history[0] != copy_big_chat_message
+ assert initial_tokens > self.max_prompt_size
+ assert final_tokens <= self.max_prompt_size
+
+ def test_truncate_single_large_non_system_message(self):
+ # Arrange
+ big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
+ big_chat_message.content = big_chat_message.content + "\n" + "Question?"
+ big_chat_message.role = "user"
+ copy_big_chat_message = big_chat_message.copy()
+ chat_messages = [big_chat_message]
+ initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+
+ # Act
+ truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
+ final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+
+ # Assert
+ # The original object has been modified. Verify certain properties
+ assert initial_tokens > self.max_prompt_size
+ assert final_tokens <= self.max_prompt_size
+ assert len(chat_messages) == 1
+ assert truncated_chat_history[0] != copy_big_chat_message