diff --git a/gunicorn-config.py b/gunicorn-config.py index 1760ae38..bfed49e7 100644 --- a/gunicorn-config.py +++ b/gunicorn-config.py @@ -1,7 +1,7 @@ import multiprocessing bind = "0.0.0.0:42110" -workers = 4 +workers = 8 worker_class = "uvicorn.workers.UvicornWorker" timeout = 120 keep_alive = 60 diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index cc081da7..94cde782 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -794,7 +794,6 @@ chatBody.dataset.conversationId = ""; chatBody.dataset.conversationTitle = ""; loadChat(); - flashStatusInChatInput("🗑 Cleared previous conversation history"); }) .catch(err => { flashStatusInChatInput("⛔️ Failed to clear conversation history"); @@ -856,28 +855,6 @@ let conversationMenu = document.createElement('div'); conversationMenu.classList.add("conversation-menu"); - let deleteButton = document.createElement('button'); - deleteButton.innerHTML = "Delete"; - deleteButton.classList.add("delete-conversation-button"); - deleteButton.classList.add("three-dot-menu-button-item"); - deleteButton.addEventListener('click', function() { - let deleteURL = `/api/chat/history?client=web&conversation_id=${incomingConversationId}`; - fetch(`${hostURL}${deleteURL}` , { method: "DELETE", headers }) - .then(response => response.ok ? response.json() : Promise.reject(response)) - .then(data => { - let chatBody = document.getElementById("chat-body"); - chatBody.innerHTML = ""; - chatBody.dataset.conversationId = ""; - chatBody.dataset.conversationTitle = ""; - loadChat(); - }) - .catch(err => { - return; - }); - }); - conversationMenu.appendChild(deleteButton); - threeDotMenu.appendChild(conversationMenu); - let editTitleButton = document.createElement('button'); editTitleButton.innerHTML = "Rename"; editTitleButton.classList.add("edit-title-button"); @@ -903,12 +880,13 @@ conversationTitleInput.addEventListener('click', function(event) { event.stopPropagation(); + }); + conversationTitleInput.addEventListener('keydown', function(event) { if (event.key === "Enter") { event.preventDefault(); conversationTitleInputButton.click(); } }); - conversationTitleInputBox.appendChild(conversationTitleInput); let conversationTitleInputButton = document.createElement('button'); conversationTitleInputButton.innerHTML = "Save"; @@ -918,7 +896,7 @@ let newTitle = conversationTitleInput.value; if (newTitle != null) { let editURL = `/api/chat/title?client=web&conversation_id=${incomingConversationId}&title=${newTitle}`; - fetch(`${hostURL}${editURL}` , { method: "PATCH" }) + fetch(`${hostURL}${editURL}` , { method: "PATCH", headers }) .then(response => response.ok ? response.json() : Promise.reject(response)) .then(data => { conversationButton.textContent = newTitle; @@ -931,8 +909,35 @@ conversationTitleInputBox.appendChild(conversationTitleInputButton); conversationMenu.appendChild(conversationTitleInputBox); }); + conversationMenu.appendChild(editTitleButton); threeDotMenu.appendChild(conversationMenu); + + let deleteButton = document.createElement('button'); + deleteButton.innerHTML = "Delete"; + deleteButton.classList.add("delete-conversation-button"); + deleteButton.classList.add("three-dot-menu-button-item"); + deleteButton.addEventListener('click', function() { + // Ask for confirmation before deleting chat session + let confirmation = confirm('Are you sure you want to delete this chat session?'); + if (!confirmation) return; + let deleteURL = `/api/chat/history?client=web&conversation_id=${incomingConversationId}`; + fetch(`${hostURL}${deleteURL}` , { method: "DELETE", headers }) + .then(response => response.ok ? response.json() : Promise.reject(response)) + .then(data => { + let chatBody = document.getElementById("chat-body"); + chatBody.innerHTML = ""; + chatBody.dataset.conversationId = ""; + chatBody.dataset.conversationTitle = ""; + loadChat(); + }) + .catch(err => { + return; + }); + }); + + conversationMenu.appendChild(deleteButton); + threeDotMenu.appendChild(conversationMenu); }); threeDotMenu.appendChild(threeDotMenuButton); conversationButton.appendChild(threeDotMenu); diff --git a/src/khoj/configure.py b/src/khoj/configure.py index ab0b2649..0adbe889 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -363,4 +363,4 @@ def upload_telemetry(): @schedule.repeat(schedule.every(31).minutes) def delete_old_user_requests(): num_deleted = delete_user_requests() - logger.info(f"🗑️ Deleted {num_deleted[0]} day-old user requests") + logger.debug(f"🗑️ Deleted {num_deleted[0]} day-old user requests") diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 87314726..1dce9c8e 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -470,7 +470,7 @@ class ConversationAdapters: @staticmethod def get_conversation_by_user( user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None - ): + ) -> Optional[Conversation]: if conversation_id: conversation = ( Conversation.objects.filter(user=user, client=client_application, id=conversation_id) @@ -518,19 +518,21 @@ class ConversationAdapters: @staticmethod async def aget_conversation_by_user( - user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None, slug: str = None - ): + user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None, title: str = None + ) -> Optional[Conversation]: if conversation_id: - conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id) - elif slug: - conversation = Conversation.objects.filter(user=user, client=client_application, slug=slug) + return await Conversation.objects.filter(user=user, client=client_application, id=conversation_id).afirst() + elif title: + return await Conversation.objects.filter(user=user, client=client_application, title=title).afirst() else: conversation = Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at") if await conversation.aexists(): return await conversation.prefetch_related("agent").afirst() - return await Conversation.objects.acreate(user=user, client=client_application, slug=slug) + return await ( + Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst() + ) or await Conversation.objects.acreate(user=user, client=client_application) @staticmethod async def adelete_conversation_by_user( diff --git a/src/khoj/database/migrations/0031_alter_googleuser_locale.py b/src/khoj/database/migrations/0031_alter_googleuser_locale.py new file mode 100644 index 00000000..99c4573a --- /dev/null +++ b/src/khoj/database/migrations/0031_alter_googleuser_locale.py @@ -0,0 +1,17 @@ +# Generated by Django 4.2.10 on 2024-03-15 10:04 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0030_conversation_slug_and_title"), + ] + + operations = [ + migrations.AlterField( + model_name="googleuser", + name="locale", + field=models.CharField(blank=True, default=None, max_length=200, null=True), + ), + ] diff --git a/src/khoj/database/migrations/0032_merge_20240322_0427.py b/src/khoj/database/migrations/0032_merge_20240322_0427.py new file mode 100644 index 00000000..aee557c0 --- /dev/null +++ b/src/khoj/database/migrations/0032_merge_20240322_0427.py @@ -0,0 +1,14 @@ +# Generated by Django 4.2.10 on 2024-03-22 04:27 + +from typing import List + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0031_agent_conversation_agent"), + ("database", "0031_alter_googleuser_locale"), + ] + + operations: List[str] = [] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 846b3318..3d9cdfc6 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -47,7 +47,7 @@ class GoogleUser(models.Model): given_name = models.CharField(max_length=200, null=True, default=None, blank=True) family_name = models.CharField(max_length=200, null=True, default=None, blank=True) picture = models.CharField(max_length=200, null=True, default=None) - locale = models.CharField(max_length=200) + locale = models.CharField(max_length=200, null=True, default=None, blank=True) def __str__(self): return self.name diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 50e7c0f9..6a00f6cb 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -1334,6 +1334,8 @@ To get started, just start typing below. You can also type / to see a list of co conversationTitleInput.addEventListener('click', function(event) { event.stopPropagation(); + }); + conversationTitleInput.addEventListener('keydown', function(event) { if (event.key === "Enter") { event.preventDefault(); conversationTitleInputButton.click(); @@ -1370,6 +1372,9 @@ To get started, just start typing below. You can also type / to see a list of co deleteButton.classList.add("delete-conversation-button"); deleteButton.classList.add("three-dot-menu-button-item"); deleteButton.addEventListener('click', function() { + // Ask for confirmation before deleting chat session + let confirmation = confirm('Are you sure you want to delete this chat session?'); + if (!confirmation) return; let deleteURL = `/api/chat/history?client=web&conversation_id=${incomingConversationId}`; fetch(deleteURL , { method: "DELETE" }) .then(response => response.ok ? response.json() : Promise.reject(response)) @@ -1379,7 +1384,6 @@ To get started, just start typing below. You can also type / to see a list of co chatBody.dataset.conversationId = ""; chatBody.dataset.conversationTitle = ""; loadChat(); - flashStatusInChatInput("🗑 Cleared previous conversation history"); }) .catch(err => { flashStatusInChatInput("⛔️ Failed to clear conversation history"); diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 05dea460..52e7fd59 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -47,7 +47,8 @@ def extract_questions( last_new_year = current_new_year.replace(year=today.year - 1) prompt = prompts.extract_questions.format( - current_date=today.strftime("%A, %Y-%m-%d"), + current_date=today.strftime("%Y-%m-%d"), + day_of_week=today.strftime("%A"), last_new_year=last_new_year.strftime("%Y"), last_new_year_date=last_new_year.strftime("%Y-%m-%d"), current_new_year_date=current_new_year.strftime("%Y-%m-%d"), diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 0013cbfa..be6b2751 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -253,8 +253,8 @@ You are Khoj, an extremely smart and helpful search assistant with the ability t - Break messages into multiple search queries when required to retrieve the relevant information. - Add date filters to your search queries from questions and answers when required to retrieve the relevant information. -What searches will you need to perform to answer the users question? Respond with search queries as list of strings in a JSON object. -Current Date: {current_date} +What searches will you perform to answer the users question? Respond with search queries as list of strings in a JSON object. +Current Date: {day_of_week}, {current_date} User's Location: {location} Q: How was my trip to Cambodia? @@ -418,7 +418,7 @@ You are Khoj, an advanced google search assistant. You are tasked with construct - You will receive the conversation history as context. - Add as much context from the previous questions and answers as required into your search queries. - Break messages into multiple search queries when required to retrieve the relevant information. -- Use site: and after: google search operators when appropriate +- Use site: google search operators when appropriate - You have access to the the whole internet to retrieve information. - Official, up-to-date information about you, Khoj, is available at site:khoj.dev @@ -433,7 +433,7 @@ User: I like to use Hacker News to get my tech news. AI: Hacker News is an online forum for sharing and discussing the latest tech news. It is a great place to learn about new technologies and startups. Q: Summarize posts about vector databases on Hacker News since Feb 2024 -Khoj: {{"queries": ["site:news.ycombinator.com after:2024/02/01 vector database"]}} +Khoj: {{"queries": ["site:news.ycombinator.com vector database since 1 February 2024"]}} History: User: I'm currently living in New York but I'm thinking about moving to San Francisco. diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index b384ad7a..15a4970e 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -199,19 +199,26 @@ def truncate_messages( f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing." ) - system_message = messages.pop() - assert type(system_message.content) == str - system_message_tokens = len(encoder.encode(system_message.content)) + # Extract system message from messages + system_message = None + for idx, message in enumerate(messages): + if message.role == "system": + system_message = messages.pop(idx) + break + + system_message_tokens = ( + len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0 + ) tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str]) + + # Drop older messages until under max supported prompt size by model while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1: messages.pop() - assert type(system_message.content) == str tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str]) # Truncate current message if still over max supported prompt size by model if (tokens + system_message_tokens) > max_prompt_size: - assert type(system_message.content) == str current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else "" original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else "" original_question = f"\n{original_question}" @@ -223,7 +230,7 @@ def truncate_messages( ) messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)] - return messages + [system_message] + return messages + [system_message] if system_message else messages def reciprocal_conversation_to_chatml(message_pair): diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index d91dd596..9498d181 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -369,6 +369,7 @@ async def extract_references_and_questions( # Collate search results as context for GPT with timer("Searching knowledge base took", logger): result_list = [] + logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}") for query in inferred_queries: n_items = min(n, 3) if using_offline_chat else n result_list.extend( diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 150733b7..884b62a4 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -455,7 +455,7 @@ async def chat( n: Optional[int] = 5, d: Optional[float] = 0.18, stream: Optional[bool] = False, - slug: Optional[str] = None, + title: Optional[str] = None, conversation_id: Optional[int] = None, city: Optional[str] = None, region: Optional[str] = None, @@ -482,10 +482,14 @@ async def chat( return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200) conversation = await ConversationAdapters.aget_conversation_by_user( - user, request.user.client_app, conversation_id, slug + user, request.user.client_app, conversation_id, title ) - - meta_log = conversation.conversation_log + if not conversation: + return Response( + content=f"No conversation found with requested id, title", media_type="text/plain", status_code=400 + ) + else: + meta_log = conversation.conversation_log if conversation_commands == [ConversationCommand.Default]: conversation_commands = await aget_relevant_information_sources(q, meta_log) @@ -557,7 +561,7 @@ async def chat( intent_type=intent_type, inferred_queries=[improved_image_prompt], client_application=request.user.client_app, - conversation_id=conversation_id, + conversation_id=conversation.id, compiled_references=compiled_references, online_results=online_results, ) @@ -575,7 +579,7 @@ async def chat( conversation_commands, user, request.user.client_app, - conversation_id, + conversation.id, location, user_name, ) diff --git a/tests/test_conversation_utils.py b/tests/test_conversation_utils.py index 52db0002..bc8c5315 100644 --- a/tests/test_conversation_utils.py +++ b/tests/test_conversation_utils.py @@ -19,49 +19,80 @@ class TestTruncateMessage: encoder = tiktoken.encoding_for_model(model_name) def test_truncate_message_all_small(self): - chat_messages = ChatMessageFactory.build_batch(500) + # Arrange + chat_history = ChatMessageFactory.build_batch(500) - prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name) - tokens = sum([len(self.encoder.encode(message.content)) for message in prompt]) + # Act + truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name) + tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history]) + # Assert # The original object has been modified. Verify certain properties - assert len(chat_messages) < 500 - assert len(chat_messages) > 1 + assert len(chat_history) < 500 + assert len(chat_history) > 1 assert tokens <= self.max_prompt_size def test_truncate_message_first_large(self): - chat_messages = ChatMessageFactory.build_batch(25) + # Arrange + chat_history = ChatMessageFactory.build_batch(25) big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000)) big_chat_message.content = big_chat_message.content + "\n" + "Question?" copy_big_chat_message = big_chat_message.copy() - chat_messages.insert(0, big_chat_message) - tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages]) + chat_history.insert(0, big_chat_message) + tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history]) - prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name) - tokens = sum([len(self.encoder.encode(message.content)) for message in prompt]) + # Act + truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name) + tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history]) + # Assert # The original object has been modified. Verify certain properties - assert len(chat_messages) == 1 - assert prompt[0] != copy_big_chat_message + assert len(chat_history) == 1 + assert truncated_chat_history[0] != copy_big_chat_message assert tokens <= self.max_prompt_size def test_truncate_message_last_large(self): - chat_messages = ChatMessageFactory.build_batch(25) + # Arrange + chat_history = ChatMessageFactory.build_batch(25) + chat_history[0].role = "system" # Mark the first message as system message big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000)) big_chat_message.content = big_chat_message.content + "\n" + "Question?" copy_big_chat_message = big_chat_message.copy() - chat_messages.insert(0, big_chat_message) - tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages]) + chat_history.insert(0, big_chat_message) + initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history]) - prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name) - tokens = sum([len(self.encoder.encode(message.content)) for message in prompt]) + # Act + truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name) + final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history]) + # Assert # The original object has been modified. Verify certain properties. - assert len(prompt) == ( - len(chat_messages) + 1 + assert len(truncated_chat_history) == ( + len(chat_history) + 1 ) # Because the system_prompt is popped off from the chat_messages lsit - assert len(prompt) < 26 - assert len(prompt) > 1 - assert prompt[0] != copy_big_chat_message - assert tokens <= self.max_prompt_size + assert len(truncated_chat_history) < 26 + assert len(truncated_chat_history) > 1 + assert truncated_chat_history[0] != copy_big_chat_message + assert initial_tokens > self.max_prompt_size + assert final_tokens <= self.max_prompt_size + + def test_truncate_single_large_non_system_message(self): + # Arrange + big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000)) + big_chat_message.content = big_chat_message.content + "\n" + "Question?" + big_chat_message.role = "user" + copy_big_chat_message = big_chat_message.copy() + chat_messages = [big_chat_message] + initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages]) + + # Act + truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name) + final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history]) + + # Assert + # The original object has been modified. Verify certain properties + assert initial_tokens > self.max_prompt_size + assert final_tokens <= self.max_prompt_size + assert len(chat_messages) == 1 + assert truncated_chat_history[0] != copy_big_chat_message