diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index da40f41f..cb8cef89 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -378,6 +378,13 @@ async def aset_user_search_model(user: KhojUser, search_model_config_id: int): return new_config +async def aget_user_search_model(user: KhojUser): + config = await UserSearchModelConfig.objects.filter(user=user).prefetch_related("setting").afirst() + if not config: + return None + return config.setting + + class ClientApplicationAdapters: @staticmethod async def aget_client_application_by_id(client_id: str, client_secret: str): @@ -639,6 +646,12 @@ class EntryAdapters: deleted_count, _ = Entry.objects.filter(user=user, file_source=file_source).delete() return deleted_count + @staticmethod + async def adelete_all_entries(user: KhojUser, file_source: str = None): + if file_source is None: + return await Entry.objects.filter(user=user).adelete() + return await Entry.objects.filter(user=user, file_source=file_source).adelete() + @staticmethod def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str): return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True) @@ -674,10 +687,6 @@ class EntryAdapters: .values_list("file_path", flat=True) ) - @staticmethod - async def adelete_all_entries(user: KhojUser): - return await Entry.objects.filter(user=user).adelete() - @staticmethod def get_size_of_indexed_data_in_mb(user: KhojUser): entries = Entry.objects.filter(user=user).iterator() diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 90d3c522..1f1d1f1e 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -376,6 +376,11 @@ }; function updateSearchModel() { + let confirmation = window.confirm("All your existing data will be deleted, and you will have to regenerate it. Are you sure you want to continue?"); + if (!confirmation) { + return; + } + const searchModel = document.getElementById("search-models").value; const saveSearchModelButton = document.getElementById("save-search-model"); saveSearchModelButton.disabled = true; @@ -398,7 +403,7 @@ } let notificationBanner = document.getElementById("notification-banner"); - notificationBanner.innerHTML = "When updating the language model, be sure to delete all your saved content and re-initialize."; + notificationBanner.innerHTML = "Khoj can now better understand the language of your content! Manually sync your data from one of the Khoj clients to update your knowledge base."; notificationBanner.style.display = "block"; setTimeout(function() { notificationBanner.style.display = "none"; diff --git a/src/khoj/processor/content/notion/notion_to_entries.py b/src/khoj/processor/content/notion/notion_to_entries.py index 2ed62fd3..541a1732 100644 --- a/src/khoj/processor/content/notion/notion_to_entries.py +++ b/src/khoj/processor/content/notion/notion_to_entries.py @@ -93,7 +93,7 @@ class NotionToEntries(TextToEntries): json=self.body_params, ).json() responses.append(result) - if result["has_more"] == False: + if result.get("has_more", False) == False: break else: self.body_params.update({"start_cursor": result["next_cursor"]}) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index f0234503..5ef3ff03 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -441,12 +441,14 @@ You are using the **{model}** model on the **{device}**. # -- user_location = PromptTemplate.from_template( """ +Mention the user's location only if it's relevant to the conversation. User's Location: {location} """.strip() ) user_name = PromptTemplate.from_template( """ +Mention the user's name only if it's relevant to the conversation. User's Name: {name} """.strip() ) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index fbfa17b0..6524da9d 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -11,7 +11,7 @@ from transformers import AutoTokenizer from khoj.database.adapters import ConversationAdapters from khoj.database.models import ClientApplication, KhojUser -from khoj.utils.helpers import merge_dicts +from khoj.utils.helpers import is_none_or_empty, merge_dicts logger = logging.getLogger(__name__) model_to_prompt_size = { @@ -159,10 +159,13 @@ def generate_chatml_messages_with_context( rest_backnforths += reciprocal_conversation_to_chatml([user_msg, assistant_msg])[::-1] # Format user and system messages to chatml format - system_chatml_message = [ChatMessage(content=system_message, role="system")] - user_chatml_message = [ChatMessage(content=user_message, role="user")] - - messages = user_chatml_message + rest_backnforths + system_chatml_message + messages = [] + if not is_none_or_empty(user_message): + messages.append(ChatMessage(content=user_message, role="user")) + if len(rest_backnforths) > 0: + messages += rest_backnforths + if not is_none_or_empty(system_message): + messages.append(ChatMessage(content=system_message, role="system")) # Truncate oldest messages from conversation history until under max supported prompt size by model messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name) diff --git a/src/khoj/routers/api_config.py b/src/khoj/routers/api_config.py index c4498f5c..d23f131b 100644 --- a/src/khoj/routers/api_config.py +++ b/src/khoj/routers/api_config.py @@ -262,8 +262,12 @@ async def update_search_model( ): user = request.user.object + prev_config = await adapters.aget_user_search_model(user) new_config = await adapters.aset_user_search_model(user, int(id)) + if int(id) != prev_config.id: + await EntryAdapters.adelete_all_entries(user) + if new_config is None: return {"status": "error", "message": "Model not found"} else: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 0d4561b7..c92a9793 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -208,11 +208,11 @@ async def generate_online_subqueries(q: str, conversation_history: dict, locatio response = json.loads(response) response = [q.strip() for q in response if q.strip()] if not isinstance(response, list) or not response or len(response) == 0: - logger.error(f"Invalid response for constructing subqueries: {response}") + logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}") return [q] return response except Exception as e: - logger.error(f"Invalid response for constructing subqueries: {response}") + logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}") return [q]