mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-05 13:21:18 +00:00
Merge branch 'fix/imports-and-references' of github.com:khoj-ai/khoj into fix/imports-and-references
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import math
|
||||
from typing import Optional, Type, TypeVar, List
|
||||
from datetime import date, datetime, timedelta
|
||||
import secrets
|
||||
@@ -437,12 +438,19 @@ class EntryAdapters:
|
||||
|
||||
@staticmethod
|
||||
def search_with_embeddings(
|
||||
user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None
|
||||
user: KhojUser,
|
||||
embeddings: Tensor,
|
||||
max_results: int = 10,
|
||||
file_type_filter: str = None,
|
||||
raw_query: str = None,
|
||||
max_distance: float = math.inf,
|
||||
):
|
||||
relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
|
||||
relevant_entries = relevant_entries.filter(user=user).annotate(
|
||||
distance=CosineDistance("embeddings", embeddings)
|
||||
)
|
||||
relevant_entries = relevant_entries.filter(distance__lte=max_distance)
|
||||
|
||||
if file_type_filter:
|
||||
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
|
||||
relevant_entries = relevant_entries.order_by("distance")
|
||||
|
||||
@@ -188,7 +188,6 @@
|
||||
fetch(url, { headers })
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
console.log(data);
|
||||
document.getElementById("results").innerHTML = render_results(data, query, type);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@
|
||||
div.finalize-buttons {
|
||||
display: grid;
|
||||
gap: 8px;
|
||||
padding: 24px 16px;
|
||||
padding: 32px 0px 0px;
|
||||
width: 320px;
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
@@ -274,7 +274,9 @@
|
||||
100% { transform: rotate(360deg); }
|
||||
}
|
||||
|
||||
|
||||
#status {
|
||||
padding-top: 32px;
|
||||
}
|
||||
div.finalize-actions {
|
||||
grid-auto-flow: column;
|
||||
grid-gap: 24px;
|
||||
@@ -347,6 +349,12 @@
|
||||
width: auto;
|
||||
}
|
||||
|
||||
#status {
|
||||
padding-top: 12px;
|
||||
}
|
||||
div.finalize-actions {
|
||||
padding: 12px 0 0;
|
||||
}
|
||||
div.finalize-buttons {
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
@@ -417,6 +417,9 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||
display: block;
|
||||
}
|
||||
|
||||
div.references {
|
||||
padding-top: 8px;
|
||||
}
|
||||
div.reference {
|
||||
display: grid;
|
||||
grid-template-rows: auto;
|
||||
|
||||
@@ -104,6 +104,19 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="general-settings section">
|
||||
<div id="status" style="display: none;"></div>
|
||||
</div>
|
||||
<div class="section finalize-actions general-settings">
|
||||
<div class="section-cards">
|
||||
<div class="finalize-buttons">
|
||||
<button id="configure" type="submit" title="Update index with the latest changes">💾 Save All</button>
|
||||
</div>
|
||||
<div class="finalize-buttons">
|
||||
<button id="reinitialize" type="submit" title="Regenerate index from scratch">🔄 Reinitialize</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="section">
|
||||
<h2 class="section-title">Features</h2>
|
||||
@@ -221,23 +234,7 @@
|
||||
</div>
|
||||
</div>
|
||||
{% endif %}
|
||||
<div class="section general-settings">
|
||||
<div id="results-count" title="Number of items to show in search and use for chat response">
|
||||
<label for="results-count-slider">Results Count: <span id="results-count-value">5</span></label>
|
||||
<input type="range" id="results-count-slider" name="results-count-slider" min="1" max="10" step="1" value="5">
|
||||
</div>
|
||||
<div id="status" style="display: none;"></div>
|
||||
</div>
|
||||
<div class="section finalize-actions general-settings">
|
||||
<div class="section-cards">
|
||||
<div class="finalize-buttons">
|
||||
<button id="configure" type="submit" title="Update index with the latest changes">⚙️ Configure</button>
|
||||
</div>
|
||||
<div class="finalize-buttons">
|
||||
<button id="reinitialize" type="submit" title="Regenerate index from scratch">🔄 Reinitialize</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="section"></div>
|
||||
</div>
|
||||
<script>
|
||||
|
||||
@@ -329,11 +326,11 @@
|
||||
event.preventDefault();
|
||||
updateIndex(
|
||||
force=false,
|
||||
successText="Configured successfully!",
|
||||
successText="Saved!",
|
||||
errorText="Unable to configure. Raise issue on Khoj <a href='https://github.com/khoj-ai/khoj/issues'>Github</a> or <a href='https://discord.gg/BDgyabRM6e'>Discord</a>.",
|
||||
button=configure,
|
||||
loadingText="Configuring...",
|
||||
emoji="⚙️");
|
||||
loadingText="Saving...",
|
||||
emoji="💾");
|
||||
});
|
||||
|
||||
var reinitialize = document.getElementById("reinitialize");
|
||||
@@ -341,7 +338,7 @@
|
||||
event.preventDefault();
|
||||
updateIndex(
|
||||
force=true,
|
||||
successText="Reinitialized successfully!",
|
||||
successText="Reinitialized!",
|
||||
errorText="Unable to reinitialize. Raise issue on Khoj <a href='https://github.com/khoj-ai/khoj/issues'>Github</a> or <a href='https://discord.gg/BDgyabRM6e'>Discord</a>.",
|
||||
button=reinitialize,
|
||||
loadingText="Reinitializing...",
|
||||
@@ -350,6 +347,7 @@
|
||||
|
||||
function updateIndex(force, successText, errorText, button, loadingText, emoji) {
|
||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||
const original_html = button.innerHTML;
|
||||
button.disabled = true;
|
||||
button.innerHTML = emoji + " " + loadingText;
|
||||
fetch('/api/update?&client=web&force=' + force, {
|
||||
@@ -361,15 +359,17 @@
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
console.log('Success:', data);
|
||||
if (data.detail != null) {
|
||||
throw new Error(data.detail);
|
||||
}
|
||||
|
||||
document.getElementById("status").innerHTML = emoji + " " + successText;
|
||||
document.getElementById("status").style.display = "block";
|
||||
document.getElementById("status").style.display = "none";
|
||||
|
||||
button.disabled = false;
|
||||
button.innerHTML = '✅ Done!';
|
||||
button.innerHTML = `✅ ${successText}`;
|
||||
setTimeout(function() {
|
||||
button.innerHTML = original_html;
|
||||
}, 2000);
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Error:', error);
|
||||
@@ -377,6 +377,9 @@
|
||||
document.getElementById("status").style.display = "block";
|
||||
button.disabled = false;
|
||||
button.innerHTML = '⚠️ Unsuccessful';
|
||||
setTimeout(function() {
|
||||
button.innerHTML = original_html;
|
||||
}, 2000);
|
||||
});
|
||||
|
||||
content_sources = ["computer", "github", "notion"];
|
||||
@@ -400,26 +403,6 @@
|
||||
});
|
||||
}
|
||||
|
||||
// Setup the results count slider
|
||||
const resultsCountSlider = document.getElementById('results-count-slider');
|
||||
const resultsCountValue = document.getElementById('results-count-value');
|
||||
|
||||
// Set the initial value of the slider
|
||||
resultsCountValue.textContent = resultsCountSlider.value;
|
||||
|
||||
// Store the slider value in localStorage when it changes
|
||||
resultsCountSlider.addEventListener('input', () => {
|
||||
resultsCountValue.textContent = resultsCountSlider.value;
|
||||
localStorage.setItem('khojResultsCount', resultsCountSlider.value);
|
||||
});
|
||||
|
||||
// Get the slider value from localStorage on page load
|
||||
const storedResultsCount = localStorage.getItem('khojResultsCount');
|
||||
if (storedResultsCount) {
|
||||
resultsCountSlider.value = storedResultsCount;
|
||||
resultsCountValue.textContent = storedResultsCount;
|
||||
}
|
||||
|
||||
function generateAPIKey() {
|
||||
const apiKeyList = document.getElementById("api-key-list");
|
||||
fetch('/auth/token', {
|
||||
|
||||
@@ -46,6 +46,9 @@
|
||||
</div>
|
||||
</div>
|
||||
<style>
|
||||
td {
|
||||
padding: 10px 0;
|
||||
}
|
||||
div.repo {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
@@ -124,6 +127,11 @@
|
||||
return;
|
||||
}
|
||||
|
||||
const submitButton = document.getElementById("submit");
|
||||
submitButton.disabled = true;
|
||||
submitButton.innerHTML = "Saving...";
|
||||
|
||||
// Save Github config on server
|
||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||
fetch('/api/config/data/content-source/github', {
|
||||
method: 'POST',
|
||||
@@ -137,15 +145,40 @@
|
||||
})
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => { data["status"] === "ok" ? data : Promise.reject(data) })
|
||||
.catch(error => {
|
||||
document.getElementById("success").innerHTML = "⚠️ Failed to save Github settings.";
|
||||
document.getElementById("success").style.display = "block";
|
||||
submitButton.innerHTML = "⚠️ Failed to save settings";
|
||||
setTimeout(function() {
|
||||
submitButton.innerHTML = "Save";
|
||||
submitButton.disabled = false;
|
||||
}, 2000);
|
||||
return;
|
||||
});
|
||||
|
||||
// Index Github content on server
|
||||
fetch('/api/update?t=github')
|
||||
.then(response => response.json())
|
||||
.then(data => { data["status"] == "ok" ? data : Promise.reject(data) })
|
||||
.then(data => {
|
||||
if (data["status"] == "ok") {
|
||||
document.getElementById("success").innerHTML = "✅ Successfully updated. Go to your <a href='/config'>settings page</a> to complete setup.";
|
||||
document.getElementById("success").style.display = "block";
|
||||
} else {
|
||||
document.getElementById("success").innerHTML = "⚠️ Failed to update settings.";
|
||||
document.getElementById("success").style.display = "block";
|
||||
}
|
||||
document.getElementById("success").style.display = "none";
|
||||
submitButton.innerHTML = "✅ Successfully updated";
|
||||
setTimeout(function() {
|
||||
submitButton.innerHTML = "Save";
|
||||
submitButton.disabled = false;
|
||||
}, 2000);
|
||||
})
|
||||
.catch(error => {
|
||||
document.getElementById("success").innerHTML = "⚠️ Failed to save Github content.";
|
||||
document.getElementById("success").style.display = "block";
|
||||
submitButton.innerHTML = "⚠️ Failed to save content";
|
||||
setTimeout(function() {
|
||||
submitButton.innerHTML = "Save";
|
||||
submitButton.disabled = false;
|
||||
}, 2000);
|
||||
});
|
||||
|
||||
});
|
||||
</script>
|
||||
{% endblock %}
|
||||
|
||||
@@ -41,6 +41,11 @@
|
||||
return;
|
||||
}
|
||||
|
||||
const submitButton = document.getElementById("submit");
|
||||
submitButton.disabled = true;
|
||||
submitButton.innerHTML = "Saving...";
|
||||
|
||||
// Save Notion config on server
|
||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||
fetch('/api/config/data/content-source/notion', {
|
||||
method: 'POST',
|
||||
@@ -53,15 +58,39 @@
|
||||
})
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => { data["status"] === "ok" ? data : Promise.reject(data) })
|
||||
.catch(error => {
|
||||
document.getElementById("success").innerHTML = "⚠️ Failed to save Notion settings.";
|
||||
document.getElementById("success").style.display = "block";
|
||||
submitButton.innerHTML = "⚠️ Failed to save settings";
|
||||
setTimeout(function() {
|
||||
submitButton.innerHTML = "Save";
|
||||
submitButton.disabled = false;
|
||||
}, 2000);
|
||||
return;
|
||||
});
|
||||
|
||||
// Index Notion content on server
|
||||
fetch('/api/update?t=notion')
|
||||
.then(response => response.json())
|
||||
.then(data => { data["status"] == "ok" ? data : Promise.reject(data) })
|
||||
.then(data => {
|
||||
if (data["status"] == "ok") {
|
||||
document.getElementById("success").innerHTML = "✅ Successfully updated. Go to your <a href='/config'>settings page</a> to complete setup.";
|
||||
document.getElementById("success").style.display = "block";
|
||||
} else {
|
||||
document.getElementById("success").innerHTML = "⚠️ Failed to update settings.";
|
||||
document.getElementById("success").style.display = "block";
|
||||
}
|
||||
document.getElementById("success").style.display = "none";
|
||||
submitButton.innerHTML = "✅ Successfully updated";
|
||||
setTimeout(function() {
|
||||
submitButton.innerHTML = "Save";
|
||||
submitButton.disabled = false;
|
||||
}, 2000);
|
||||
})
|
||||
.catch(error => {
|
||||
document.getElementById("success").innerHTML = "⚠️ Failed to save Notion content.";
|
||||
document.getElementById("success").style.display = "block";
|
||||
submitButton.innerHTML = "⚠️ Failed to save content";
|
||||
setTimeout(function() {
|
||||
submitButton.innerHTML = "Save";
|
||||
submitButton.disabled = false;
|
||||
}, 2000);
|
||||
});
|
||||
});
|
||||
</script>
|
||||
{% endblock %}
|
||||
|
||||
@@ -189,7 +189,6 @@
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
console.log(data);
|
||||
document.getElementById("results").innerHTML = render_results(data, query, type);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -111,15 +111,13 @@ def converse(
|
||||
return iter([prompts.no_notes_found.format()])
|
||||
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references):
|
||||
conversation_primer = prompts.general_conversation.format(query=user_query)
|
||||
personality = prompts.personality.format(current_date=current_date)
|
||||
else:
|
||||
conversation_primer = prompts.general_conversation.format(query=user_query)
|
||||
personality = prompts.personality_with_notes.format(current_date=current_date, references=compiled_references)
|
||||
conversation_primer = prompts.notes_conversation.format(query=user_query, references=compiled_references)
|
||||
|
||||
# Setup Prompt with Primer or Conversation History
|
||||
messages = generate_chatml_messages_with_context(
|
||||
conversation_primer,
|
||||
personality,
|
||||
prompts.personality.format(current_date=current_date),
|
||||
conversation_log,
|
||||
model,
|
||||
max_prompt_size,
|
||||
@@ -136,4 +134,5 @@ def converse(
|
||||
temperature=temperature,
|
||||
openai_api_key=api_key,
|
||||
completion_func=completion_func,
|
||||
model_kwargs={"stop": ["Notes:\n["]},
|
||||
)
|
||||
|
||||
@@ -69,15 +69,15 @@ def completion_with_backoff(**kwargs):
|
||||
reraise=True,
|
||||
)
|
||||
def chat_completion_with_backoff(
|
||||
messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None
|
||||
messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None, model_kwargs=None
|
||||
):
|
||||
g = ThreadedGenerator(compiled_references, completion_func=completion_func)
|
||||
t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
|
||||
t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, model_kwargs))
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
|
||||
def llm_thread(g, messages, model_name, temperature, openai_api_key=None, model_kwargs=None):
|
||||
callback_handler = StreamingChatCallbackHandler(g)
|
||||
chat = ChatOpenAI(
|
||||
streaming=True,
|
||||
@@ -86,6 +86,7 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
|
||||
model_name=model_name, # type: ignore
|
||||
temperature=temperature,
|
||||
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
|
||||
model_kwargs=model_kwargs,
|
||||
request_timeout=20,
|
||||
max_retries=1,
|
||||
client=None,
|
||||
|
||||
@@ -13,7 +13,7 @@ You were created by Khoj Inc. with the following capabilities:
|
||||
- You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you.
|
||||
- You cannot set reminders.
|
||||
- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
|
||||
- You ask friendly, inquisitive follow-up QUESTIONS to collect more detail about their experiences and better understand the user's intent. These questions end with a question mark and seek to better understand the user.
|
||||
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
|
||||
- Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
|
||||
|
||||
Note: More information about you, the company or other Khoj apps can be found at https://khoj.dev.
|
||||
@@ -21,25 +21,6 @@ Today is {current_date} in UTC.
|
||||
""".strip()
|
||||
)
|
||||
|
||||
personality_with_notes = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, a smart, inquisitive and helpful personal assistant.
|
||||
Use your general knowledge and the past conversation with the user as context to inform your responses.
|
||||
You were created by Khoj Inc. with the following capabilities:
|
||||
|
||||
- You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you.
|
||||
- You cannot set reminders.
|
||||
- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
|
||||
- You ask friendly, inquisitive follow-up QUESTIONS to collect more detail about their experiences and better understand the user's intent. These questions end with a question mark and seek to better understand the user.
|
||||
- Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
|
||||
|
||||
Note: More information about you, the company or other Khoj apps can be found at https://khoj.dev.
|
||||
Today is {current_date} in UTC.
|
||||
|
||||
User's Notes:
|
||||
{references}
|
||||
""".strip()
|
||||
)
|
||||
## General Conversation
|
||||
## --
|
||||
general_conversation = PromptTemplate.from_template(
|
||||
@@ -108,14 +89,13 @@ conversation_llamav2 = PromptTemplate.from_template(
|
||||
## --
|
||||
notes_conversation = PromptTemplate.from_template(
|
||||
"""
|
||||
Using my personal notes and our past conversations as context, answer the following question.
|
||||
Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
|
||||
These questions should end with a question mark.
|
||||
Use my personal notes and our past conversations to inform your response.
|
||||
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the provided notes or past conversations.
|
||||
|
||||
Notes:
|
||||
{references}
|
||||
|
||||
Question: {query}
|
||||
Query: {query}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
@@ -177,11 +177,15 @@ async def set_content_config_github_data(
|
||||
|
||||
user = request.user.object
|
||||
|
||||
await adapters.set_user_github_config(
|
||||
user=user,
|
||||
pat_token=updated_config.pat_token,
|
||||
repos=updated_config.repos,
|
||||
)
|
||||
try:
|
||||
await adapters.set_user_github_config(
|
||||
user=user,
|
||||
pat_token=updated_config.pat_token,
|
||||
repos=updated_config.repos,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Failed to set Github config")
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
@@ -205,10 +209,14 @@ async def set_content_config_notion_data(
|
||||
|
||||
user = request.user.object
|
||||
|
||||
await adapters.set_notion_config(
|
||||
user=user,
|
||||
token=updated_config.token,
|
||||
)
|
||||
try:
|
||||
await adapters.set_notion_config(
|
||||
user=user,
|
||||
token=updated_config.token,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Failed to set Github config")
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
@@ -348,7 +356,7 @@ async def search(
|
||||
n: Optional[int] = 5,
|
||||
t: Optional[SearchType] = SearchType.All,
|
||||
r: Optional[bool] = False,
|
||||
score_threshold: Optional[Union[float, None]] = None,
|
||||
max_distance: Optional[Union[float, None]] = None,
|
||||
dedupe: Optional[bool] = True,
|
||||
client: Optional[str] = None,
|
||||
user_agent: Optional[str] = Header(None),
|
||||
@@ -367,12 +375,12 @@ async def search(
|
||||
# initialize variables
|
||||
user_query = q.strip()
|
||||
results_count = n or 5
|
||||
score_threshold = score_threshold if score_threshold is not None else -math.inf
|
||||
max_distance = max_distance if max_distance is not None else math.inf
|
||||
search_futures: List[concurrent.futures.Future] = []
|
||||
|
||||
# return cached results, if available
|
||||
if user:
|
||||
query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
|
||||
query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}"
|
||||
if query_cache_key in state.query_cache[user.uuid]:
|
||||
logger.debug(f"Return response from query cache")
|
||||
return state.query_cache[user.uuid][query_cache_key]
|
||||
@@ -410,7 +418,7 @@ async def search(
|
||||
t,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
max_distance=max_distance,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -423,7 +431,6 @@ async def search(
|
||||
results_count,
|
||||
state.search_models.image_search,
|
||||
state.content_index.image,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -446,11 +453,10 @@ async def search(
|
||||
# Collate results
|
||||
results += text_search.collate_results(hits, dedupe=dedupe)
|
||||
|
||||
if r:
|
||||
results = text_search.rerank_and_sort_results(results, query=defiltered_query)[:results_count]
|
||||
else:
|
||||
# Sort results across all content types and take top results
|
||||
results = sorted(results, key=lambda x: float(x.score))[:results_count]
|
||||
results = text_search.rerank_and_sort_results(results, query=defiltered_query, rank_results=r)[
|
||||
:results_count
|
||||
]
|
||||
|
||||
# Cache results
|
||||
if user:
|
||||
@@ -575,6 +581,7 @@ async def chat(
|
||||
request: Request,
|
||||
q: str,
|
||||
n: Optional[int] = 5,
|
||||
d: Optional[float] = 0.15,
|
||||
client: Optional[str] = None,
|
||||
stream: Optional[bool] = False,
|
||||
user_agent: Optional[str] = Header(None),
|
||||
@@ -591,7 +598,7 @@ async def chat(
|
||||
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
|
||||
|
||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||
request, meta_log, q, (n or 5), conversation_command
|
||||
request, meta_log, q, (n or 5), (d or math.inf), conversation_command
|
||||
)
|
||||
|
||||
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
|
||||
@@ -655,6 +662,7 @@ async def extract_references_and_questions(
|
||||
meta_log: dict,
|
||||
q: str,
|
||||
n: int,
|
||||
d: float,
|
||||
conversation_type: ConversationCommand = ConversationCommand.Default,
|
||||
):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
@@ -715,7 +723,7 @@ async def extract_references_and_questions(
|
||||
request=request,
|
||||
n=n_items,
|
||||
r=True,
|
||||
score_threshold=-5.0,
|
||||
max_distance=d,
|
||||
dedupe=False,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -146,7 +146,7 @@ def extract_metadata(image_name):
|
||||
|
||||
|
||||
async def query(
|
||||
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf
|
||||
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = math.inf
|
||||
):
|
||||
# Set query to image content if query is of form file:/path/to/file.png
|
||||
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
|
||||
@@ -167,7 +167,8 @@ async def query(
|
||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
|
||||
with timer("Search Time", logger):
|
||||
image_hits = {
|
||||
result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
|
||||
# Map scores to distance metric by multiplying by -1
|
||||
result["corpus_id"]: {"image_score": -1 * result["score"], "score": -1 * result["score"]}
|
||||
for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
|
||||
}
|
||||
|
||||
@@ -204,7 +205,7 @@ async def query(
|
||||
]
|
||||
|
||||
# Filter results by score threshold
|
||||
hits = [hit for hit in hits if hit["image_score"] >= score_threshold]
|
||||
hits = [hit for hit in hits if hit["image_score"] <= score_threshold]
|
||||
|
||||
# Sort the images based on their combined metadata, image scores
|
||||
return sorted(hits, key=lambda hit: hit["score"], reverse=True)
|
||||
|
||||
@@ -105,7 +105,7 @@ async def query(
|
||||
type: SearchType = SearchType.All,
|
||||
question_embedding: Union[torch.Tensor, None] = None,
|
||||
rank_results: bool = False,
|
||||
score_threshold: float = -math.inf,
|
||||
max_distance: float = math.inf,
|
||||
) -> Tuple[List[dict], List[Entry]]:
|
||||
"Search for entries that answer the query"
|
||||
|
||||
@@ -127,6 +127,7 @@ async def query(
|
||||
max_results=top_k,
|
||||
file_type_filter=file_type,
|
||||
raw_query=raw_query,
|
||||
max_distance=max_distance,
|
||||
).all()
|
||||
hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
|
||||
|
||||
@@ -177,12 +178,16 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
|
||||
)
|
||||
|
||||
|
||||
def rerank_and_sort_results(hits, query):
|
||||
def rerank_and_sort_results(hits, query, rank_results):
|
||||
# If we have more than one result and reranking is enabled
|
||||
rank_results = rank_results and len(list(hits)) > 1
|
||||
|
||||
# Score all retrieved entries using the cross-encoder
|
||||
hits = cross_encoder_score(query, hits)
|
||||
if rank_results:
|
||||
hits = cross_encoder_score(query, hits)
|
||||
|
||||
# Sort results by cross-encoder score followed by bi-encoder score
|
||||
hits = sort_results(rank_results=True, hits=hits)
|
||||
hits = sort_results(rank_results=rank_results, hits=hits)
|
||||
|
||||
return hits
|
||||
|
||||
@@ -217,9 +222,9 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
|
||||
with timer("Cross-Encoder Predict Time", logger, state.device):
|
||||
cross_scores = state.cross_encoder_model.predict(query, hits)
|
||||
|
||||
# Store cross-encoder scores in results dictionary for ranking
|
||||
# Convert cross-encoder scores to distances and pass in hits for reranking
|
||||
for idx in range(len(cross_scores)):
|
||||
hits[idx]["cross_score"] = cross_scores[idx]
|
||||
hits[idx]["cross_score"] = -1 * cross_scores[idx]
|
||||
|
||||
return hits
|
||||
|
||||
@@ -227,7 +232,7 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
|
||||
def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
|
||||
"""Order results by cross-encoder score followed by bi-encoder score"""
|
||||
with timer("Rank Time", logger, state.device):
|
||||
hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score
|
||||
hits.sort(key=lambda x: x["score"]) # sort by bi-encoder score
|
||||
if rank_results:
|
||||
hits.sort(key=lambda x: x["cross_score"], reverse=True) # sort by cross-encoder score
|
||||
hits.sort(key=lambda x: x["cross_score"]) # sort by cross-encoder score
|
||||
return hits
|
||||
|
||||
Reference in New Issue
Block a user