mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Merge pull request #896 from khoj-ai/features/add-support-for-custom-confidence
Add support for custom search model-specific thresholds
This commit is contained in:
@@ -0,0 +1,17 @@
|
|||||||
|
# Generated by Django 5.0.7 on 2024-08-24 18:19

from django.db import migrations, models


class Migration(migrations.Migration):
    """Add a per-model bi-encoder confidence threshold to SearchModelConfig.

    Embedding search results with a distance above this threshold are
    considered irrelevant; 0.18 is the default cutoff.
    """

    dependencies = [
        ("database", "0058_alter_chatmodeloptions_chat_model"),
    ]

    operations = [
        migrations.AddField(
            model_name="searchmodelconfig",
            name="bi_encoder_confidence_threshold",
            field=models.FloatField(default=0.18),
        ),
    ]
@@ -270,6 +270,8 @@ class SearchModelConfig(BaseModel):
|
|||||||
cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
# Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server
|
# Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server
|
||||||
cross_encoder_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
cross_encoder_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
|
# The confidence threshold of the bi_encoder model to consider the embeddings as relevant
|
||||||
|
bi_encoder_confidence_threshold = models.FloatField(default=0.18)
|
||||||
|
|
||||||
|
|
||||||
class TextToImageModelConfig(BaseModel):
|
class TextToImageModelConfig(BaseModel):
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ async def search(
|
|||||||
n=n,
|
n=n,
|
||||||
t=t,
|
t=t,
|
||||||
r=r,
|
r=r,
|
||||||
max_distance=max_distance,
|
max_distance=max_distance or math.inf,
|
||||||
dedupe=dedupe,
|
dedupe=dedupe,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -117,7 +117,6 @@ async def execute_search(
|
|||||||
# initialize variables
|
# initialize variables
|
||||||
user_query = q.strip()
|
user_query = q.strip()
|
||||||
results_count = n or 5
|
results_count = n or 5
|
||||||
max_distance = max_distance or math.inf
|
|
||||||
search_futures: List[concurrent.futures.Future] = []
|
search_futures: List[concurrent.futures.Future] = []
|
||||||
|
|
||||||
# return cached results, if available
|
# return cached results, if available
|
||||||
|
|||||||
@@ -524,7 +524,7 @@ async def chat(
|
|||||||
common: CommonQueryParams,
|
common: CommonQueryParams,
|
||||||
q: str,
|
q: str,
|
||||||
n: int = 7,
|
n: int = 7,
|
||||||
d: float = 0.18,
|
d: float = None,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
title: Optional[str] = None,
|
title: Optional[str] = None,
|
||||||
conversation_id: Optional[int] = None,
|
conversation_id: Optional[int] = None,
|
||||||
@@ -764,7 +764,7 @@ async def chat(
|
|||||||
meta_log,
|
meta_log,
|
||||||
q,
|
q,
|
||||||
(n or 7),
|
(n or 7),
|
||||||
(d or 0.18),
|
d,
|
||||||
conversation_id,
|
conversation_id,
|
||||||
conversation_commands,
|
conversation_commands,
|
||||||
location,
|
location,
|
||||||
|
|||||||
@@ -100,18 +100,23 @@ async def query(
|
|||||||
raw_query: str,
|
raw_query: str,
|
||||||
type: SearchType = SearchType.All,
|
type: SearchType = SearchType.All,
|
||||||
question_embedding: Union[torch.Tensor, None] = None,
|
question_embedding: Union[torch.Tensor, None] = None,
|
||||||
max_distance: float = math.inf,
|
max_distance: float = None,
|
||||||
) -> Tuple[List[dict], List[Entry]]:
|
) -> Tuple[List[dict], List[Entry]]:
|
||||||
"Search for entries that answer the query"
|
"Search for entries that answer the query"
|
||||||
|
|
||||||
file_type = search_type_to_embeddings_type[type.value]
|
file_type = search_type_to_embeddings_type[type.value]
|
||||||
|
|
||||||
query = raw_query
|
query = raw_query
|
||||||
|
search_model = await sync_to_async(get_user_search_model_or_default)(user)
|
||||||
|
if not max_distance:
|
||||||
|
if search_model.bi_encoder_confidence_threshold:
|
||||||
|
max_distance = search_model.bi_encoder_confidence_threshold
|
||||||
|
else:
|
||||||
|
max_distance = math.inf
|
||||||
|
|
||||||
# Encode the query using the bi-encoder
|
# Encode the query using the bi-encoder
|
||||||
if question_embedding is None:
|
if question_embedding is None:
|
||||||
with timer("Query Encode Time", logger, state.device):
|
with timer("Query Encode Time", logger, state.device):
|
||||||
search_model = await sync_to_async(get_user_search_model_or_default)(user)
|
|
||||||
question_embedding = state.embeddings_model[search_model.name].embed_query(query)
|
question_embedding = state.embeddings_model[search_model.name].embed_query(query)
|
||||||
|
|
||||||
# Find relevant entries for the query
|
# Find relevant entries for the query
|
||||||
|
|||||||
Reference in New Issue
Block a user