+
{% endblock %}
diff --git a/src/khoj/interface/web/content_source_notion_input.html b/src/khoj/interface/web/content_source_notion_input.html
index 18eb5a7f..d5427ab3 100644
--- a/src/khoj/interface/web/content_source_notion_input.html
+++ b/src/khoj/interface/web/content_source_notion_input.html
@@ -41,6 +41,11 @@
return;
}
+ const submitButton = document.getElementById("submit");
+ submitButton.disabled = true;
+ submitButton.innerHTML = "Saving...";
+
+ // Save Notion config on server
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
fetch('/api/config/data/content-source/notion', {
method: 'POST',
@@ -53,15 +58,39 @@
})
})
.then(response => response.json())
+    .then(data => data["status"] === "ok" ? data : Promise.reject(data))
+ .catch(error => {
+ document.getElementById("success").innerHTML = "⚠️ Failed to save Notion settings.";
+ document.getElementById("success").style.display = "block";
+ submitButton.innerHTML = "⚠️ Failed to save settings";
+ setTimeout(function() {
+ submitButton.innerHTML = "Save";
+ submitButton.disabled = false;
+ }, 2000);
+ return;
+ });
+
+ // Index Notion content on server
+ fetch('/api/update?t=notion')
+ .then(response => response.json())
+    .then(data => data["status"] === "ok" ? data : Promise.reject(data))
.then(data => {
- if (data["status"] == "ok") {
- document.getElementById("success").innerHTML = "✅ Successfully updated. Go to your settings page to complete setup.";
- document.getElementById("success").style.display = "block";
- } else {
- document.getElementById("success").innerHTML = "⚠️ Failed to update settings.";
- document.getElementById("success").style.display = "block";
- }
+ document.getElementById("success").style.display = "none";
+ submitButton.innerHTML = "✅ Successfully updated";
+ setTimeout(function() {
+ submitButton.innerHTML = "Save";
+ submitButton.disabled = false;
+ }, 2000);
})
+ .catch(error => {
+ document.getElementById("success").innerHTML = "⚠️ Failed to save Notion content.";
+ document.getElementById("success").style.display = "block";
+ submitButton.innerHTML = "⚠️ Failed to save content";
+ setTimeout(function() {
+ submitButton.innerHTML = "Save";
+ submitButton.disabled = false;
+ }, 2000);
+ });
});
{% endblock %}
diff --git a/src/khoj/interface/web/search.html b/src/khoj/interface/web/search.html
index dcd98ede..5331ea92 100644
--- a/src/khoj/interface/web/search.html
+++ b/src/khoj/interface/web/search.html
@@ -189,7 +189,6 @@
})
.then(response => response.json())
.then(data => {
- console.log(data);
document.getElementById("results").innerHTML = render_results(data, query, type);
});
}
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index 71088817..b86ebc6b 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -111,15 +111,13 @@ def converse(
return iter([prompts.no_notes_found.format()])
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references):
conversation_primer = prompts.general_conversation.format(query=user_query)
- personality = prompts.personality.format(current_date=current_date)
else:
- conversation_primer = prompts.general_conversation.format(query=user_query)
- personality = prompts.personality_with_notes.format(current_date=current_date, references=compiled_references)
+ conversation_primer = prompts.notes_conversation.format(query=user_query, references=compiled_references)
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(
conversation_primer,
- personality,
+ prompts.personality.format(current_date=current_date),
conversation_log,
model,
max_prompt_size,
@@ -136,4 +134,5 @@ def converse(
temperature=temperature,
openai_api_key=api_key,
completion_func=completion_func,
+ model_kwargs={"stop": ["Notes:\n["]},
)
diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index 130532e0..dce72e1f 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -69,15 +69,15 @@ def completion_with_backoff(**kwargs):
reraise=True,
)
def chat_completion_with_backoff(
- messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None
+ messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None, model_kwargs=None
):
g = ThreadedGenerator(compiled_references, completion_func=completion_func)
- t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
+ t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, model_kwargs))
t.start()
return g
-def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
+def llm_thread(g, messages, model_name, temperature, openai_api_key=None, model_kwargs=None):
callback_handler = StreamingChatCallbackHandler(g)
chat = ChatOpenAI(
streaming=True,
@@ -86,6 +86,7 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
model_name=model_name, # type: ignore
temperature=temperature,
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
+ model_kwargs=model_kwargs,
request_timeout=20,
max_retries=1,
client=None,
diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index 78a42995..c11c38ba 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -13,7 +13,7 @@ You were created by Khoj Inc. with the following capabilities:
- You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you.
- You cannot set reminders.
- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
-- You ask friendly, inquisitive follow-up QUESTIONS to collect more detail about their experiences and better understand the user's intent. These questions end with a question mark and seek to better understand the user.
+- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
- Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
Note: More information about you, the company or other Khoj apps can be found at https://khoj.dev.
@@ -21,25 +21,6 @@ Today is {current_date} in UTC.
""".strip()
)
-personality_with_notes = PromptTemplate.from_template(
- """
-You are Khoj, a smart, inquisitive and helpful personal assistant.
-Use your general knowledge and the past conversation with the user as context to inform your responses.
-You were created by Khoj Inc. with the following capabilities:
-
-- You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you.
-- You cannot set reminders.
-- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
-- You ask friendly, inquisitive follow-up QUESTIONS to collect more detail about their experiences and better understand the user's intent. These questions end with a question mark and seek to better understand the user.
-- Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
-
-Note: More information about you, the company or other Khoj apps can be found at https://khoj.dev.
-Today is {current_date} in UTC.
-
-User's Notes:
-{references}
-""".strip()
-)
## General Conversation
## --
general_conversation = PromptTemplate.from_template(
@@ -108,14 +89,13 @@ conversation_llamav2 = PromptTemplate.from_template(
## --
notes_conversation = PromptTemplate.from_template(
"""
-Using my personal notes and our past conversations as context, answer the following question.
-Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
-These questions should end with a question mark.
+Use my personal notes and our past conversations to inform your response.
+Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the provided notes or past conversations.
Notes:
{references}
-Question: {query}
+Query: {query}
""".strip()
)
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index ddfe9bc1..fbdfbd63 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -177,11 +177,15 @@ async def set_content_config_github_data(
user = request.user.object
- await adapters.set_user_github_config(
- user=user,
- pat_token=updated_config.pat_token,
- repos=updated_config.repos,
- )
+ try:
+ await adapters.set_user_github_config(
+ user=user,
+ pat_token=updated_config.pat_token,
+ repos=updated_config.repos,
+ )
+ except Exception as e:
+ logger.error(e, exc_info=True)
+ raise HTTPException(status_code=500, detail="Failed to set Github config")
update_telemetry_state(
request=request,
@@ -205,10 +209,14 @@ async def set_content_config_notion_data(
user = request.user.object
- await adapters.set_notion_config(
- user=user,
- token=updated_config.token,
- )
+ try:
+ await adapters.set_notion_config(
+ user=user,
+ token=updated_config.token,
+ )
+ except Exception as e:
+ logger.error(e, exc_info=True)
+        raise HTTPException(status_code=500, detail="Failed to set Notion config")
update_telemetry_state(
request=request,
@@ -348,7 +356,7 @@ async def search(
n: Optional[int] = 5,
t: Optional[SearchType] = SearchType.All,
r: Optional[bool] = False,
- score_threshold: Optional[Union[float, None]] = None,
+ max_distance: Optional[Union[float, None]] = None,
dedupe: Optional[bool] = True,
client: Optional[str] = None,
user_agent: Optional[str] = Header(None),
@@ -367,12 +375,12 @@ async def search(
# initialize variables
user_query = q.strip()
results_count = n or 5
- score_threshold = score_threshold if score_threshold is not None else -math.inf
+ max_distance = max_distance if max_distance is not None else math.inf
search_futures: List[concurrent.futures.Future] = []
# return cached results, if available
if user:
- query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
+ query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}"
if query_cache_key in state.query_cache[user.uuid]:
logger.debug(f"Return response from query cache")
return state.query_cache[user.uuid][query_cache_key]
@@ -410,7 +418,7 @@ async def search(
t,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
- score_threshold=score_threshold,
+ max_distance=max_distance,
)
]
@@ -423,7 +431,6 @@ async def search(
results_count,
state.search_models.image_search,
state.content_index.image,
- score_threshold=score_threshold,
)
]
@@ -446,11 +453,10 @@ async def search(
# Collate results
results += text_search.collate_results(hits, dedupe=dedupe)
- if r:
- results = text_search.rerank_and_sort_results(results, query=defiltered_query)[:results_count]
- else:
# Sort results across all content types and take top results
- results = sorted(results, key=lambda x: float(x.score))[:results_count]
+ results = text_search.rerank_and_sort_results(results, query=defiltered_query, rank_results=r)[
+ :results_count
+ ]
# Cache results
if user:
@@ -575,6 +581,7 @@ async def chat(
request: Request,
q: str,
n: Optional[int] = 5,
+ d: Optional[float] = 0.15,
client: Optional[str] = None,
stream: Optional[bool] = False,
user_agent: Optional[str] = Header(None),
@@ -591,7 +598,7 @@ async def chat(
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
- request, meta_log, q, (n or 5), conversation_command
+        request, meta_log, q, (n or 5), (d if d is not None else math.inf), conversation_command
)
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
@@ -655,6 +662,7 @@ async def extract_references_and_questions(
meta_log: dict,
q: str,
n: int,
+ d: float,
conversation_type: ConversationCommand = ConversationCommand.Default,
):
user = request.user.object if request.user.is_authenticated else None
@@ -715,7 +723,7 @@ async def extract_references_and_questions(
request=request,
n=n_items,
r=True,
- score_threshold=-5.0,
+ max_distance=d,
dedupe=False,
)
)
diff --git a/src/khoj/search_type/image_search.py b/src/khoj/search_type/image_search.py
index d7f486af..214118fc 100644
--- a/src/khoj/search_type/image_search.py
+++ b/src/khoj/search_type/image_search.py
@@ -146,7 +146,7 @@ def extract_metadata(image_name):
async def query(
- raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf
+ raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = math.inf
):
# Set query to image content if query is of form file:/path/to/file.png
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
@@ -167,7 +167,8 @@ async def query(
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
with timer("Search Time", logger):
image_hits = {
- result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
+ # Map scores to distance metric by multiplying by -1
+ result["corpus_id"]: {"image_score": -1 * result["score"], "score": -1 * result["score"]}
for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
}
@@ -204,7 +205,7 @@ async def query(
]
# Filter results by score threshold
- hits = [hit for hit in hits if hit["image_score"] >= score_threshold]
+ hits = [hit for hit in hits if hit["image_score"] <= score_threshold]
# Sort the images based on their combined metadata, image scores
return sorted(hits, key=lambda hit: hit["score"], reverse=True)
diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py
index ba2fc9ec..041c385f 100644
--- a/src/khoj/search_type/text_search.py
+++ b/src/khoj/search_type/text_search.py
@@ -105,7 +105,7 @@ async def query(
type: SearchType = SearchType.All,
question_embedding: Union[torch.Tensor, None] = None,
rank_results: bool = False,
- score_threshold: float = -math.inf,
+ max_distance: float = math.inf,
) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query"
@@ -127,6 +127,7 @@ async def query(
max_results=top_k,
file_type_filter=file_type,
raw_query=raw_query,
+ max_distance=max_distance,
).all()
hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
@@ -177,12 +178,16 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
)
-def rerank_and_sort_results(hits, query):
+def rerank_and_sort_results(hits, query, rank_results):
+ # If we have more than one result and reranking is enabled
+ rank_results = rank_results and len(list(hits)) > 1
+
# Score all retrieved entries using the cross-encoder
- hits = cross_encoder_score(query, hits)
+ if rank_results:
+ hits = cross_encoder_score(query, hits)
# Sort results by cross-encoder score followed by bi-encoder score
- hits = sort_results(rank_results=True, hits=hits)
+ hits = sort_results(rank_results=rank_results, hits=hits)
return hits
@@ -217,9 +222,9 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
with timer("Cross-Encoder Predict Time", logger, state.device):
cross_scores = state.cross_encoder_model.predict(query, hits)
- # Store cross-encoder scores in results dictionary for ranking
+ # Convert cross-encoder scores to distances and pass in hits for reranking
for idx in range(len(cross_scores)):
- hits[idx]["cross_score"] = cross_scores[idx]
+ hits[idx]["cross_score"] = -1 * cross_scores[idx]
return hits
@@ -227,7 +232,7 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
"""Order results by cross-encoder score followed by bi-encoder score"""
with timer("Rank Time", logger, state.device):
- hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score
+ hits.sort(key=lambda x: x["score"]) # sort by bi-encoder score
if rank_results:
- hits.sort(key=lambda x: x["cross_score"], reverse=True) # sort by cross-encoder score
+ hits.sort(key=lambda x: x["cross_score"]) # sort by cross-encoder score
return hits
diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py
index c7d2e0ec..a8c85787 100644
--- a/tests/test_openai_chat_director.py
+++ b/tests/test_openai_chat_director.py
@@ -308,6 +308,7 @@ def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_
"which of namita's sons",
"the birth order",
"provide more context",
+ "provide me with more context",
]
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (