Merge branch 'fix/imports-and-references' of github.com:khoj-ai/khoj into fix/imports-and-references

This commit is contained in:
sabaimran
2023-11-11 12:59:31 -08:00
15 changed files with 183 additions and 126 deletions

View File

@@ -1,3 +1,4 @@
import math
from typing import Optional, Type, TypeVar, List
from datetime import date, datetime, timedelta
import secrets
@@ -437,12 +438,19 @@ class EntryAdapters:
@staticmethod
def search_with_embeddings(
user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None
user: KhojUser,
embeddings: Tensor,
max_results: int = 10,
file_type_filter: str = None,
raw_query: str = None,
max_distance: float = math.inf,
):
relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
relevant_entries = relevant_entries.filter(user=user).annotate(
distance=CosineDistance("embeddings", embeddings)
)
relevant_entries = relevant_entries.filter(distance__lte=max_distance)
if file_type_filter:
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
relevant_entries = relevant_entries.order_by("distance")

View File

@@ -188,7 +188,6 @@
fetch(url, { headers })
.then(response => response.json())
.then(data => {
console.log(data);
document.getElementById("results").innerHTML = render_results(data, query, type);
});
}

View File

@@ -121,7 +121,7 @@
div.finalize-buttons {
display: grid;
gap: 8px;
padding: 24px 16px;
padding: 32px 0px 0px;
width: 320px;
border-radius: 4px;
overflow: hidden;
@@ -274,7 +274,9 @@
100% { transform: rotate(360deg); }
}
#status {
padding-top: 32px;
}
div.finalize-actions {
grid-auto-flow: column;
grid-gap: 24px;
@@ -347,6 +349,12 @@
width: auto;
}
#status {
padding-top: 12px;
}
div.finalize-actions {
padding: 12px 0 0;
}
div.finalize-buttons {
padding: 0;
}

View File

@@ -417,6 +417,9 @@ To get started, just start typing below. You can also type / to see a list of co
display: block;
}
div.references {
padding-top: 8px;
}
div.reference {
display: grid;
grid-template-rows: auto;

View File

@@ -104,6 +104,19 @@
</div>
</div>
</div>
<div class="general-settings section">
<div id="status" style="display: none;"></div>
</div>
<div class="section finalize-actions general-settings">
<div class="section-cards">
<div class="finalize-buttons">
<button id="configure" type="submit" title="Update index with the latest changes">💾 Save All</button>
</div>
<div class="finalize-buttons">
<button id="reinitialize" type="submit" title="Regenerate index from scratch">🔄 Reinitialize</button>
</div>
</div>
</div>
</div>
<div class="section">
<h2 class="section-title">Features</h2>
@@ -221,23 +234,7 @@
</div>
</div>
{% endif %}
<div class="section general-settings">
<div id="results-count" title="Number of items to show in search and use for chat response">
<label for="results-count-slider">Results Count: <span id="results-count-value">5</span></label>
<input type="range" id="results-count-slider" name="results-count-slider" min="1" max="10" step="1" value="5">
</div>
<div id="status" style="display: none;"></div>
</div>
<div class="section finalize-actions general-settings">
<div class="section-cards">
<div class="finalize-buttons">
<button id="configure" type="submit" title="Update index with the latest changes">⚙️ Configure</button>
</div>
<div class="finalize-buttons">
<button id="reinitialize" type="submit" title="Regenerate index from scratch">🔄 Reinitialize</button>
</div>
</div>
</div>
<div class="section"></div>
</div>
<script>
@@ -329,11 +326,11 @@
event.preventDefault();
updateIndex(
force=false,
successText="Configured successfully!",
successText="Saved!",
errorText="Unable to configure. Raise issue on Khoj <a href='https://github.com/khoj-ai/khoj/issues'>Github</a> or <a href='https://discord.gg/BDgyabRM6e'>Discord</a>.",
button=configure,
loadingText="Configuring...",
emoji="⚙️");
loadingText="Saving...",
emoji="💾");
});
var reinitialize = document.getElementById("reinitialize");
@@ -341,7 +338,7 @@
event.preventDefault();
updateIndex(
force=true,
successText="Reinitialized successfully!",
successText="Reinitialized!",
errorText="Unable to reinitialize. Raise issue on Khoj <a href='https://github.com/khoj-ai/khoj/issues'>Github</a> or <a href='https://discord.gg/BDgyabRM6e'>Discord</a>.",
button=reinitialize,
loadingText="Reinitializing...",
@@ -350,6 +347,7 @@
function updateIndex(force, successText, errorText, button, loadingText, emoji) {
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
const original_html = button.innerHTML;
button.disabled = true;
button.innerHTML = emoji + " " + loadingText;
fetch('/api/update?&client=web&force=' + force, {
@@ -361,15 +359,17 @@
})
.then(response => response.json())
.then(data => {
console.log('Success:', data);
if (data.detail != null) {
throw new Error(data.detail);
}
document.getElementById("status").innerHTML = emoji + " " + successText;
document.getElementById("status").style.display = "block";
document.getElementById("status").style.display = "none";
button.disabled = false;
button.innerHTML = '✅ Done!';
button.innerHTML = `${successText}`;
setTimeout(function() {
button.innerHTML = original_html;
}, 2000);
})
.catch((error) => {
console.error('Error:', error);
@@ -377,6 +377,9 @@
document.getElementById("status").style.display = "block";
button.disabled = false;
button.innerHTML = '⚠️ Unsuccessful';
setTimeout(function() {
button.innerHTML = original_html;
}, 2000);
});
content_sources = ["computer", "github", "notion"];
@@ -400,26 +403,6 @@
});
}
// Setup the results count slider
const resultsCountSlider = document.getElementById('results-count-slider');
const resultsCountValue = document.getElementById('results-count-value');
// Set the initial value of the slider
resultsCountValue.textContent = resultsCountSlider.value;
// Store the slider value in localStorage when it changes
resultsCountSlider.addEventListener('input', () => {
resultsCountValue.textContent = resultsCountSlider.value;
localStorage.setItem('khojResultsCount', resultsCountSlider.value);
});
// Get the slider value from localStorage on page load
const storedResultsCount = localStorage.getItem('khojResultsCount');
if (storedResultsCount) {
resultsCountSlider.value = storedResultsCount;
resultsCountValue.textContent = storedResultsCount;
}
function generateAPIKey() {
const apiKeyList = document.getElementById("api-key-list");
fetch('/auth/token', {

View File

@@ -46,6 +46,9 @@
</div>
</div>
<style>
td {
padding: 10px 0;
}
div.repo {
width: 100%;
height: 100%;
@@ -124,6 +127,11 @@
return;
}
const submitButton = document.getElementById("submit");
submitButton.disabled = true;
submitButton.innerHTML = "Saving...";
// Save Github config on server
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
fetch('/api/config/data/content-source/github', {
method: 'POST',
@@ -137,15 +145,40 @@
})
})
.then(response => response.json())
.then(data => { data["status"] === "ok" ? data : Promise.reject(data) })
.catch(error => {
document.getElementById("success").innerHTML = "⚠️ Failed to save Github settings.";
document.getElementById("success").style.display = "block";
submitButton.innerHTML = "⚠️ Failed to save settings";
setTimeout(function() {
submitButton.innerHTML = "Save";
submitButton.disabled = false;
}, 2000);
return;
});
// Index Github content on server
fetch('/api/update?t=github')
.then(response => response.json())
.then(data => { data["status"] == "ok" ? data : Promise.reject(data) })
.then(data => {
if (data["status"] == "ok") {
document.getElementById("success").innerHTML = "✅ Successfully updated. Go to your <a href='/config'>settings page</a> to complete setup.";
document.getElementById("success").style.display = "block";
} else {
document.getElementById("success").innerHTML = "⚠️ Failed to update settings.";
document.getElementById("success").style.display = "block";
}
document.getElementById("success").style.display = "none";
submitButton.innerHTML = "✅ Successfully updated";
setTimeout(function() {
submitButton.innerHTML = "Save";
submitButton.disabled = false;
}, 2000);
})
.catch(error => {
document.getElementById("success").innerHTML = "⚠️ Failed to save Github content.";
document.getElementById("success").style.display = "block";
submitButton.innerHTML = "⚠️ Failed to save content";
setTimeout(function() {
submitButton.innerHTML = "Save";
submitButton.disabled = false;
}, 2000);
});
});
</script>
{% endblock %}

View File

@@ -41,6 +41,11 @@
return;
}
const submitButton = document.getElementById("submit");
submitButton.disabled = true;
submitButton.innerHTML = "Saving...";
// Save Notion config on server
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
fetch('/api/config/data/content-source/notion', {
method: 'POST',
@@ -53,15 +58,39 @@
})
})
.then(response => response.json())
.then(data => { data["status"] === "ok" ? data : Promise.reject(data) })
.catch(error => {
document.getElementById("success").innerHTML = "⚠️ Failed to save Notion settings.";
document.getElementById("success").style.display = "block";
submitButton.innerHTML = "⚠️ Failed to save settings";
setTimeout(function() {
submitButton.innerHTML = "Save";
submitButton.disabled = false;
}, 2000);
return;
});
// Index Notion content on server
fetch('/api/update?t=notion')
.then(response => response.json())
.then(data => { data["status"] == "ok" ? data : Promise.reject(data) })
.then(data => {
if (data["status"] == "ok") {
document.getElementById("success").innerHTML = "✅ Successfully updated. Go to your <a href='/config'>settings page</a> to complete setup.";
document.getElementById("success").style.display = "block";
} else {
document.getElementById("success").innerHTML = "⚠️ Failed to update settings.";
document.getElementById("success").style.display = "block";
}
document.getElementById("success").style.display = "none";
submitButton.innerHTML = "✅ Successfully updated";
setTimeout(function() {
submitButton.innerHTML = "Save";
submitButton.disabled = false;
}, 2000);
})
.catch(error => {
document.getElementById("success").innerHTML = "⚠️ Failed to save Notion content.";
document.getElementById("success").style.display = "block";
submitButton.innerHTML = "⚠️ Failed to save content";
setTimeout(function() {
submitButton.innerHTML = "Save";
submitButton.disabled = false;
}, 2000);
});
});
</script>
{% endblock %}

View File

@@ -189,7 +189,6 @@
})
.then(response => response.json())
.then(data => {
console.log(data);
document.getElementById("results").innerHTML = render_results(data, query, type);
});
}

View File

@@ -111,15 +111,13 @@ def converse(
return iter([prompts.no_notes_found.format()])
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references):
conversation_primer = prompts.general_conversation.format(query=user_query)
personality = prompts.personality.format(current_date=current_date)
else:
conversation_primer = prompts.general_conversation.format(query=user_query)
personality = prompts.personality_with_notes.format(current_date=current_date, references=compiled_references)
conversation_primer = prompts.notes_conversation.format(query=user_query, references=compiled_references)
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(
conversation_primer,
personality,
prompts.personality.format(current_date=current_date),
conversation_log,
model,
max_prompt_size,
@@ -136,4 +134,5 @@ def converse(
temperature=temperature,
openai_api_key=api_key,
completion_func=completion_func,
model_kwargs={"stop": ["Notes:\n["]},
)

View File

@@ -69,15 +69,15 @@ def completion_with_backoff(**kwargs):
reraise=True,
)
def chat_completion_with_backoff(
    messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None, model_kwargs=None
):
    """Stream a chat completion for `messages` on a background thread.

    Returns a ThreadedGenerator immediately; `llm_thread` fills it with tokens
    as the model produces them, so callers can iterate the response while it
    is still being generated.

    Args:
        messages: ChatML-style message list to send to the model.
        compiled_references: Reference notes attached to the generator so the
            client can render sources alongside the streamed response.
        model_name: OpenAI chat model identifier.
        temperature: Sampling temperature for the completion.
        openai_api_key: API key; `llm_thread` falls back to the
            OPENAI_API_KEY environment variable when None.
        completion_func: Optional callback invoked when streaming completes.
        model_kwargs: Extra keyword arguments forwarded to ChatOpenAI
            (e.g. {"stop": [...]}); None means no extras.

    Returns:
        ThreadedGenerator yielding streamed response chunks.
    """
    # NOTE: diff residue resolved — the old and new variants of the signature
    # and the Thread(...) call were interleaved; this is the post-change form
    # that threads model_kwargs through to llm_thread/ChatOpenAI.
    g = ThreadedGenerator(compiled_references, completion_func=completion_func)
    t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, model_kwargs))
    t.start()
    return g
def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
def llm_thread(g, messages, model_name, temperature, openai_api_key=None, model_kwargs=None):
callback_handler = StreamingChatCallbackHandler(g)
chat = ChatOpenAI(
streaming=True,
@@ -86,6 +86,7 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
model_name=model_name, # type: ignore
temperature=temperature,
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
model_kwargs=model_kwargs,
request_timeout=20,
max_retries=1,
client=None,

View File

@@ -13,7 +13,7 @@ You were created by Khoj Inc. with the following capabilities:
- You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you.
- You cannot set reminders.
- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
- You ask friendly, inquisitive follow-up QUESTIONS to collect more detail about their experiences and better understand the user's intent. These questions end with a question mark and seek to better understand the user.
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
- Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
Note: More information about you, the company or other Khoj apps can be found at https://khoj.dev.
@@ -21,25 +21,6 @@ Today is {current_date} in UTC.
""".strip()
)
personality_with_notes = PromptTemplate.from_template(
"""
You are Khoj, a smart, inquisitive and helpful personal assistant.
Use your general knowledge and the past conversation with the user as context to inform your responses.
You were created by Khoj Inc. with the following capabilities:
- You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you.
- You cannot set reminders.
- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
- You ask friendly, inquisitive follow-up QUESTIONS to collect more detail about their experiences and better understand the user's intent. These questions end with a question mark and seek to better understand the user.
- Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
Note: More information about you, the company or other Khoj apps can be found at https://khoj.dev.
Today is {current_date} in UTC.
User's Notes:
{references}
""".strip()
)
## General Conversation
## --
general_conversation = PromptTemplate.from_template(
@@ -108,14 +89,13 @@ conversation_llamav2 = PromptTemplate.from_template(
## --
notes_conversation = PromptTemplate.from_template(
"""
Using my personal notes and our past conversations as context, answer the following question.
Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
These questions should end with a question mark.
Use my personal notes and our past conversations to inform your response.
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the provided notes or past conversations.
Notes:
{references}
Question: {query}
Query: {query}
""".strip()
)

View File

@@ -177,11 +177,15 @@ async def set_content_config_github_data(
user = request.user.object
await adapters.set_user_github_config(
user=user,
pat_token=updated_config.pat_token,
repos=updated_config.repos,
)
try:
await adapters.set_user_github_config(
user=user,
pat_token=updated_config.pat_token,
repos=updated_config.repos,
)
except Exception as e:
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail="Failed to set Github config")
update_telemetry_state(
request=request,
@@ -205,10 +209,14 @@ async def set_content_config_notion_data(
user = request.user.object
await adapters.set_notion_config(
user=user,
token=updated_config.token,
)
try:
await adapters.set_notion_config(
user=user,
token=updated_config.token,
)
except Exception as e:
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail="Failed to set Notion config")
update_telemetry_state(
request=request,
@@ -348,7 +356,7 @@ async def search(
n: Optional[int] = 5,
t: Optional[SearchType] = SearchType.All,
r: Optional[bool] = False,
score_threshold: Optional[Union[float, None]] = None,
max_distance: Optional[Union[float, None]] = None,
dedupe: Optional[bool] = True,
client: Optional[str] = None,
user_agent: Optional[str] = Header(None),
@@ -367,12 +375,12 @@ async def search(
# initialize variables
user_query = q.strip()
results_count = n or 5
score_threshold = score_threshold if score_threshold is not None else -math.inf
max_distance = max_distance if max_distance is not None else math.inf
search_futures: List[concurrent.futures.Future] = []
# return cached results, if available
if user:
query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}"
if query_cache_key in state.query_cache[user.uuid]:
logger.debug(f"Return response from query cache")
return state.query_cache[user.uuid][query_cache_key]
@@ -410,7 +418,7 @@ async def search(
t,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
max_distance=max_distance,
)
]
@@ -423,7 +431,6 @@ async def search(
results_count,
state.search_models.image_search,
state.content_index.image,
score_threshold=score_threshold,
)
]
@@ -446,11 +453,10 @@ async def search(
# Collate results
results += text_search.collate_results(hits, dedupe=dedupe)
if r:
results = text_search.rerank_and_sort_results(results, query=defiltered_query)[:results_count]
else:
# Sort results across all content types and take top results
results = sorted(results, key=lambda x: float(x.score))[:results_count]
results = text_search.rerank_and_sort_results(results, query=defiltered_query, rank_results=r)[
:results_count
]
# Cache results
if user:
@@ -575,6 +581,7 @@ async def chat(
request: Request,
q: str,
n: Optional[int] = 5,
d: Optional[float] = 0.15,
client: Optional[str] = None,
stream: Optional[bool] = False,
user_agent: Optional[str] = Header(None),
@@ -591,7 +598,7 @@ async def chat(
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, meta_log, q, (n or 5), conversation_command
request, meta_log, q, (n or 5), (d or math.inf), conversation_command
)
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
@@ -655,6 +662,7 @@ async def extract_references_and_questions(
meta_log: dict,
q: str,
n: int,
d: float,
conversation_type: ConversationCommand = ConversationCommand.Default,
):
user = request.user.object if request.user.is_authenticated else None
@@ -715,7 +723,7 @@ async def extract_references_and_questions(
request=request,
n=n_items,
r=True,
score_threshold=-5.0,
max_distance=d,
dedupe=False,
)
)

View File

@@ -146,7 +146,7 @@ def extract_metadata(image_name):
async def query(
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = math.inf
):
# Set query to image content if query is of form file:/path/to/file.png
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
@@ -167,7 +167,8 @@ async def query(
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
with timer("Search Time", logger):
image_hits = {
result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
# Map scores to distance metric by multiplying by -1
result["corpus_id"]: {"image_score": -1 * result["score"], "score": -1 * result["score"]}
for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
}
@@ -204,7 +205,7 @@ async def query(
]
# Filter results by score threshold
hits = [hit for hit in hits if hit["image_score"] >= score_threshold]
hits = [hit for hit in hits if hit["image_score"] <= score_threshold]
# Sort the images based on their combined metadata, image scores
return sorted(hits, key=lambda hit: hit["score"], reverse=True)

View File

@@ -105,7 +105,7 @@ async def query(
type: SearchType = SearchType.All,
question_embedding: Union[torch.Tensor, None] = None,
rank_results: bool = False,
score_threshold: float = -math.inf,
max_distance: float = math.inf,
) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query"
@@ -127,6 +127,7 @@ async def query(
max_results=top_k,
file_type_filter=file_type,
raw_query=raw_query,
max_distance=max_distance,
).all()
hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
@@ -177,12 +178,16 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
)
def rerank_and_sort_results(hits, query, rank_results):
    """Optionally rerank `hits` with the cross-encoder, then sort them.

    Args:
        hits: Search results to order; each must expose a bi-encoder "score"
            (and, after reranking, a "cross_score").
        query: The user query the cross-encoder scores hits against.
        rank_results: Whether cross-encoder reranking was requested.

    Returns:
        The hits ordered by cross-encoder score (when reranked) followed by
        bi-encoder score.
    """
    # NOTE: diff residue resolved — old (unconditional rerank) and new
    # (rerank only when requested) lines were interleaved; this is the
    # post-change form that takes rank_results as a parameter.
    # Only rerank when requested AND there is more than one result to reorder.
    rank_results = rank_results and len(list(hits)) > 1

    # Score all retrieved entries using the cross-encoder
    if rank_results:
        hits = cross_encoder_score(query, hits)

    # Sort results by cross-encoder score followed by bi-encoder score
    hits = sort_results(rank_results=rank_results, hits=hits)

    return hits
@@ -217,9 +222,9 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
with timer("Cross-Encoder Predict Time", logger, state.device):
cross_scores = state.cross_encoder_model.predict(query, hits)
# Store cross-encoder scores in results dictionary for ranking
# Convert cross-encoder scores to distances and pass in hits for reranking
for idx in range(len(cross_scores)):
hits[idx]["cross_score"] = cross_scores[idx]
hits[idx]["cross_score"] = -1 * cross_scores[idx]
return hits
@@ -227,7 +232,7 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
    """Order results by cross-encoder score followed by bi-encoder score.

    Scores are distance metrics in this commit (bi-encoder cosine distance;
    cross-encoder scores negated upstream), so both sorts are ascending —
    smaller is better.

    Args:
        rank_results: When True, additionally sort by the cross-encoder
            "cross_score" (which takes precedence over the bi-encoder sort).
        hits: Result dicts carrying "score" and, if reranked, "cross_score".

    Returns:
        The same list, sorted in place and returned for convenience.
    """
    # NOTE: diff residue resolved — the old descending (similarity) sorts and
    # new ascending (distance) sorts were interleaved; this is the
    # post-change ascending form.
    with timer("Rank Time", logger, state.device):
        hits.sort(key=lambda x: x["score"])  # sort by bi-encoder distance
        if rank_results:
            hits.sort(key=lambda x: x["cross_score"])  # sort by cross-encoder distance
    return hits