diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py index 28999369..4b9b54ef 100644 --- a/src/database/adapters/__init__.py +++ b/src/database/adapters/__init__.py @@ -1,3 +1,4 @@ +import math from typing import Optional, Type, TypeVar, List from datetime import date, datetime, timedelta import secrets @@ -437,12 +438,19 @@ class EntryAdapters: @staticmethod def search_with_embeddings( - user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None + user: KhojUser, + embeddings: Tensor, + max_results: int = 10, + file_type_filter: str = None, + raw_query: str = None, + max_distance: float = math.inf, ): relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter) relevant_entries = relevant_entries.filter(user=user).annotate( distance=CosineDistance("embeddings", embeddings) ) + relevant_entries = relevant_entries.filter(distance__lte=max_distance) + if file_type_filter: relevant_entries = relevant_entries.filter(file_type=file_type_filter) relevant_entries = relevant_entries.order_by("distance") diff --git a/src/interface/desktop/search.html b/src/interface/desktop/search.html index 315e6972..aa8aa662 100644 --- a/src/interface/desktop/search.html +++ b/src/interface/desktop/search.html @@ -188,7 +188,6 @@ fetch(url, { headers }) .then(response => response.json()) .then(data => { - console.log(data); document.getElementById("results").innerHTML = render_results(data, query, type); }); } diff --git a/src/khoj/interface/web/base_config.html b/src/khoj/interface/web/base_config.html index 05119fad..d9546249 100644 --- a/src/khoj/interface/web/base_config.html +++ b/src/khoj/interface/web/base_config.html @@ -121,7 +121,7 @@ div.finalize-buttons { display: grid; gap: 8px; - padding: 24px 16px; + padding: 32px 0px 0px; width: 320px; border-radius: 4px; overflow: hidden; @@ -274,7 +274,9 @@ 100% { transform: rotate(360deg); } } - + #status { + padding-top: 32px; + } 
div.finalize-actions { grid-auto-flow: column; grid-gap: 24px; @@ -347,6 +349,12 @@ width: auto; } + #status { + padding-top: 12px; + } + div.finalize-actions { + padding: 12px 0 0; + } div.finalize-buttons { padding: 0; } diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index f15489ec..e1c9979d 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -417,6 +417,9 @@ To get started, just start typing below. You can also type / to see a list of co display: block; } + div.references { + padding-top: 8px; + } div.reference { display: grid; grid-template-rows: auto; diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 34a4f642..497dd31a 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -104,6 +104,19 @@ +
+ +
+
+
+
+ +
+
+ +
+
+

Features

@@ -221,23 +234,7 @@
{% endif %} -
-
- - -
- -
-
-
-
- -
-
- -
-
-
+
{% endblock %} diff --git a/src/khoj/interface/web/content_source_notion_input.html b/src/khoj/interface/web/content_source_notion_input.html index 18eb5a7f..d5427ab3 100644 --- a/src/khoj/interface/web/content_source_notion_input.html +++ b/src/khoj/interface/web/content_source_notion_input.html @@ -41,6 +41,11 @@ return; } + const submitButton = document.getElementById("submit"); + submitButton.disabled = true; + submitButton.innerHTML = "Saving..."; + + // Save Notion config on server const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1]; fetch('/api/config/data/content-source/notion', { method: 'POST', @@ -53,15 +58,39 @@ }) }) .then(response => response.json()) + .then(data => { data["status"] === "ok" ? data : Promise.reject(data) }) + .catch(error => { + document.getElementById("success").innerHTML = "⚠️ Failed to save Notion settings."; + document.getElementById("success").style.display = "block"; + submitButton.innerHTML = "⚠️ Failed to save settings"; + setTimeout(function() { + submitButton.innerHTML = "Save"; + submitButton.disabled = false; + }, 2000); + return; + }); + + // Index Notion content on server + fetch('/api/update?t=notion') + .then(response => response.json()) + .then(data => { data["status"] == "ok" ? data : Promise.reject(data) }) .then(data => { - if (data["status"] == "ok") { - document.getElementById("success").innerHTML = "✅ Successfully updated. 
Go to your settings page to complete setup."; - document.getElementById("success").style.display = "block"; - } else { - document.getElementById("success").innerHTML = "⚠️ Failed to update settings."; - document.getElementById("success").style.display = "block"; - } + document.getElementById("success").style.display = "none"; + submitButton.innerHTML = "✅ Successfully updated"; + setTimeout(function() { + submitButton.innerHTML = "Save"; + submitButton.disabled = false; + }, 2000); }) + .catch(error => { + document.getElementById("success").innerHTML = "⚠️ Failed to save Notion content."; + document.getElementById("success").style.display = "block"; + submitButton.innerHTML = "⚠️ Failed to save content"; + setTimeout(function() { + submitButton.innerHTML = "Save"; + submitButton.disabled = false; + }, 2000); + }); }); {% endblock %} diff --git a/src/khoj/interface/web/search.html b/src/khoj/interface/web/search.html index dcd98ede..5331ea92 100644 --- a/src/khoj/interface/web/search.html +++ b/src/khoj/interface/web/search.html @@ -189,7 +189,6 @@ }) .then(response => response.json()) .then(data => { - console.log(data); document.getElementById("results").innerHTML = render_results(data, query, type); }); } diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 71088817..b86ebc6b 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -111,15 +111,13 @@ def converse( return iter([prompts.no_notes_found.format()]) elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references): conversation_primer = prompts.general_conversation.format(query=user_query) - personality = prompts.personality.format(current_date=current_date) else: - conversation_primer = prompts.general_conversation.format(query=user_query) - personality = prompts.personality_with_notes.format(current_date=current_date, references=compiled_references) + 
conversation_primer = prompts.notes_conversation.format(query=user_query, references=compiled_references) # Setup Prompt with Primer or Conversation History messages = generate_chatml_messages_with_context( conversation_primer, - personality, + prompts.personality.format(current_date=current_date), conversation_log, model, max_prompt_size, @@ -136,4 +134,5 @@ def converse( temperature=temperature, openai_api_key=api_key, completion_func=completion_func, + model_kwargs={"stop": ["Notes:\n["]}, ) diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 130532e0..dce72e1f 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -69,15 +69,15 @@ def completion_with_backoff(**kwargs): reraise=True, ) def chat_completion_with_backoff( - messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None + messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None, model_kwargs=None ): g = ThreadedGenerator(compiled_references, completion_func=completion_func) - t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key)) + t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, model_kwargs)) t.start() return g -def llm_thread(g, messages, model_name, temperature, openai_api_key=None): +def llm_thread(g, messages, model_name, temperature, openai_api_key=None, model_kwargs=None): callback_handler = StreamingChatCallbackHandler(g) chat = ChatOpenAI( streaming=True, @@ -86,6 +86,7 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None): model_name=model_name, # type: ignore temperature=temperature, openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"), + model_kwargs=model_kwargs, request_timeout=20, max_retries=1, client=None, diff --git a/src/khoj/processor/conversation/prompts.py 
b/src/khoj/processor/conversation/prompts.py index 78a42995..c11c38ba 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -13,7 +13,7 @@ You were created by Khoj Inc. with the following capabilities: - You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you. - You cannot set reminders. - Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question. -- You ask friendly, inquisitive follow-up QUESTIONS to collect more detail about their experiences and better understand the user's intent. These questions end with a question mark and seek to better understand the user. +- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations. - Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay". Note: More information about you, the company or other Khoj apps can be found at https://khoj.dev. @@ -21,25 +21,6 @@ Today is {current_date} in UTC. """.strip() ) -personality_with_notes = PromptTemplate.from_template( - """ -You are Khoj, a smart, inquisitive and helpful personal assistant. -Use your general knowledge and the past conversation with the user as context to inform your responses. -You were created by Khoj Inc. with the following capabilities: - -- You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you. -- You cannot set reminders. -- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question. -- You ask friendly, inquisitive follow-up QUESTIONS to collect more detail about their experiences and better understand the user's intent. These questions end with a question mark and seek to better understand the user. 
-- Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay". - -Note: More information about you, the company or other Khoj apps can be found at https://khoj.dev. -Today is {current_date} in UTC. - -User's Notes: -{references} -""".strip() -) ## General Conversation ## -- general_conversation = PromptTemplate.from_template( @@ -108,14 +89,13 @@ conversation_llamav2 = PromptTemplate.from_template( ## -- notes_conversation = PromptTemplate.from_template( """ -Using my personal notes and our past conversations as context, answer the following question. -Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations. -These questions should end with a question mark. +Use my personal notes and our past conversations to inform your response. +Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the provided notes or past conversations. 
Notes: {references} -Question: {query} +Query: {query} """.strip() ) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index ddfe9bc1..fbdfbd63 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -177,11 +177,15 @@ async def set_content_config_github_data( user = request.user.object - await adapters.set_user_github_config( - user=user, - pat_token=updated_config.pat_token, - repos=updated_config.repos, - ) + try: + await adapters.set_user_github_config( + user=user, + pat_token=updated_config.pat_token, + repos=updated_config.repos, + ) + except Exception as e: + logger.error(e, exc_info=True) + raise HTTPException(status_code=500, detail="Failed to set Github config") update_telemetry_state( request=request, @@ -205,10 +209,14 @@ async def set_content_config_notion_data( user = request.user.object - await adapters.set_notion_config( - user=user, - token=updated_config.token, - ) + try: + await adapters.set_notion_config( + user=user, + token=updated_config.token, + ) + except Exception as e: + logger.error(e, exc_info=True) + raise HTTPException(status_code=500, detail="Failed to set Notion config") update_telemetry_state( request=request, @@ -348,7 +356,7 @@ async def search( n: Optional[int] = 5, t: Optional[SearchType] = SearchType.All, r: Optional[bool] = False, - score_threshold: Optional[Union[float, None]] = None, + max_distance: Optional[Union[float, None]] = None, dedupe: Optional[bool] = True, client: Optional[str] = None, user_agent: Optional[str] = Header(None), @@ -367,12 +375,12 @@ async def search( # initialize variables user_query = q.strip() results_count = n or 5 - score_threshold = score_threshold if score_threshold is not None else -math.inf + max_distance = max_distance if max_distance is not None else math.inf search_futures: List[concurrent.futures.Future] = [] # return cached results, if available if user: - query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}" + query_cache_key = 
f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}" if query_cache_key in state.query_cache[user.uuid]: logger.debug(f"Return response from query cache") return state.query_cache[user.uuid][query_cache_key] @@ -410,7 +418,7 @@ async def search( t, question_embedding=encoded_asymmetric_query, rank_results=r or False, - score_threshold=score_threshold, + max_distance=max_distance, ) ] @@ -423,7 +431,6 @@ async def search( results_count, state.search_models.image_search, state.content_index.image, - score_threshold=score_threshold, ) ] @@ -446,11 +453,10 @@ async def search( # Collate results results += text_search.collate_results(hits, dedupe=dedupe) - if r: - results = text_search.rerank_and_sort_results(results, query=defiltered_query)[:results_count] - else: # Sort results across all content types and take top results - results = sorted(results, key=lambda x: float(x.score))[:results_count] + results = text_search.rerank_and_sort_results(results, query=defiltered_query, rank_results=r)[ + :results_count + ] # Cache results if user: @@ -575,6 +581,7 @@ async def chat( request: Request, q: str, n: Optional[int] = 5, + d: Optional[float] = 0.15, client: Optional[str] = None, stream: Optional[bool] = False, user_agent: Optional[str] = Header(None), @@ -591,7 +598,7 @@ async def chat( meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( - request, meta_log, q, (n or 5), conversation_command + request, meta_log, q, (n or 5), (d or math.inf), conversation_command ) if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references): @@ -655,6 +662,7 @@ async def extract_references_and_questions( meta_log: dict, q: str, n: int, + d: float, conversation_type: ConversationCommand = ConversationCommand.Default, ): user = request.user.object if request.user.is_authenticated else None @@ -715,7 +723,7 @@ async def 
extract_references_and_questions( request=request, n=n_items, r=True, - score_threshold=-5.0, + max_distance=d, dedupe=False, ) ) diff --git a/src/khoj/search_type/image_search.py b/src/khoj/search_type/image_search.py index d7f486af..214118fc 100644 --- a/src/khoj/search_type/image_search.py +++ b/src/khoj/search_type/image_search.py @@ -146,7 +146,7 @@ def extract_metadata(image_name): async def query( - raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf + raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = math.inf ): # Set query to image content if query is of form file:/path/to/file.png if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file(): @@ -167,7 +167,8 @@ async def query( # Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings. with timer("Search Time", logger): image_hits = { - result["corpus_id"]: {"image_score": result["score"], "score": result["score"]} + # Map scores to distance metric by multiplying by -1 + result["corpus_id"]: {"image_score": -1 * result["score"], "score": -1 * result["score"]} for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0] } @@ -204,7 +205,7 @@ async def query( ] # Filter results by score threshold - hits = [hit for hit in hits if hit["image_score"] >= score_threshold] + hits = [hit for hit in hits if hit["image_score"] <= score_threshold] # Sort the images based on their combined metadata, image scores return sorted(hits, key=lambda hit: hit["score"], reverse=True) diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index ba2fc9ec..041c385f 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -105,7 +105,7 @@ async def query( type: SearchType = SearchType.All, question_embedding: Union[torch.Tensor, None] = None, rank_results: bool = False, - 
score_threshold: float = -math.inf, + max_distance: float = math.inf, ) -> Tuple[List[dict], List[Entry]]: "Search for entries that answer the query" @@ -127,6 +127,7 @@ async def query( max_results=top_k, file_type_filter=file_type, raw_query=raw_query, + max_distance=max_distance, ).all() hits = await sync_to_async(list)(hits) # type: ignore[call-arg] @@ -177,12 +178,16 @@ def deduplicated_search_responses(hits: List[SearchResponse]): ) -def rerank_and_sort_results(hits, query): +def rerank_and_sort_results(hits, query, rank_results): + # If we have more than one result and reranking is enabled + rank_results = rank_results and len(list(hits)) > 1 + # Score all retrieved entries using the cross-encoder - hits = cross_encoder_score(query, hits) + if rank_results: + hits = cross_encoder_score(query, hits) # Sort results by cross-encoder score followed by bi-encoder score - hits = sort_results(rank_results=True, hits=hits) + hits = sort_results(rank_results=rank_results, hits=hits) return hits @@ -217,9 +222,9 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe with timer("Cross-Encoder Predict Time", logger, state.device): cross_scores = state.cross_encoder_model.predict(query, hits) - # Store cross-encoder scores in results dictionary for ranking + # Convert cross-encoder scores to distances and pass in hits for reranking for idx in range(len(cross_scores)): - hits[idx]["cross_score"] = cross_scores[idx] + hits[idx]["cross_score"] = -1 * cross_scores[idx] return hits @@ -227,7 +232,7 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]: """Order results by cross-encoder score followed by bi-encoder score""" with timer("Rank Time", logger, state.device): - hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score + hits.sort(key=lambda x: x["score"]) # sort by bi-encoder score if rank_results: - hits.sort(key=lambda 
x: x["cross_score"], reverse=True) # sort by cross-encoder score + hits.sort(key=lambda x: x["cross_score"]) # sort by cross-encoder score return hits diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py index c7d2e0ec..a8c85787 100644 --- a/tests/test_openai_chat_director.py +++ b/tests/test_openai_chat_director.py @@ -308,6 +308,7 @@ def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_ "which of namita's sons", "the birth order", "provide more context", + "provide me with more context", ] assert response.status_code == 200 assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (