Rename DbModels Embeddings, EmbeddingsAdapter to Entry, EntryAdapter

Improves readability as name has closer match to underlying
constructs

- Entry is any atomic item indexed by Khoj. This can be an org-mode
  entry, a markdown section, a PDF or Notion page etc.

- Embeddings are semantic vectors generated by the search ML model
  that encodes for meaning contained in an entries text.

- An "Entry" contains "Embeddings" vectors but also other metadata
  about the entry like filename etc.
This commit is contained in:
Debanjum Singh Solanky
2023-10-31 18:50:54 -07:00
parent 54a387326c
commit bcbee05a9e
15 changed files with 115 additions and 87 deletions

View File

@@ -27,7 +27,7 @@ from database.models import (
KhojApiUser,
NotionConfig,
GithubConfig,
Embeddings,
Entry,
GithubRepoConfig,
Conversation,
ConversationProcessorConfig,
@@ -286,54 +286,54 @@ class ConversationAdapters:
return await OpenAIProcessorConversationConfig.objects.filter(user=user).afirst()
class EmbeddingsAdapters:
class EntryAdapters:
word_filer = WordFilter()
file_filter = FileFilter()
date_filter = DateFilter()
@staticmethod
def does_embedding_exist(user: KhojUser, hashed_value: str) -> bool:
return Embeddings.objects.filter(user=user, hashed_value=hashed_value).exists()
def does_entry_exist(user: KhojUser, hashed_value: str) -> bool:
return Entry.objects.filter(user=user, hashed_value=hashed_value).exists()
@staticmethod
def delete_embedding_by_file(user: KhojUser, file_path: str):
deleted_count, _ = Embeddings.objects.filter(user=user, file_path=file_path).delete()
def delete_entry_by_file(user: KhojUser, file_path: str):
deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete()
return deleted_count
@staticmethod
def delete_all_embeddings(user: KhojUser, file_type: str):
deleted_count, _ = Embeddings.objects.filter(user=user, file_type=file_type).delete()
def delete_all_entries(user: KhojUser, file_type: str):
deleted_count, _ = Entry.objects.filter(user=user, file_type=file_type).delete()
return deleted_count
@staticmethod
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
return Embeddings.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
@staticmethod
def delete_embedding_by_hash(user: KhojUser, hashed_values: List[str]):
Embeddings.objects.filter(user=user, hashed_value__in=hashed_values).delete()
def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]):
Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete()
@staticmethod
def get_embeddings_by_date_filter(embeddings: BaseManager[Embeddings], start_date: date, end_date: date):
return embeddings.filter(
embeddingsdates__date__gte=start_date,
embeddingsdates__date__lte=end_date,
def get_entries_by_date_filter(entry: BaseManager[Entry], start_date: date, end_date: date):
return entry.filter(
entrydates__date__gte=start_date,
entrydates__date__lte=end_date,
)
@staticmethod
async def user_has_embeddings(user: KhojUser):
return await Embeddings.objects.filter(user=user).aexists()
async def user_has_entries(user: KhojUser):
return await Entry.objects.filter(user=user).aexists()
@staticmethod
def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
q_filter_terms = Q()
explicit_word_terms = EmbeddingsAdapters.word_filer.get_filter_terms(query)
file_filters = EmbeddingsAdapters.file_filter.get_filter_terms(query)
date_filters = EmbeddingsAdapters.date_filter.get_query_date_range(query)
explicit_word_terms = EntryAdapters.word_filer.get_filter_terms(query)
file_filters = EntryAdapters.file_filter.get_filter_terms(query)
date_filters = EntryAdapters.date_filter.get_query_date_range(query)
if len(explicit_word_terms) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
return Embeddings.objects.filter(user=user)
return Entry.objects.filter(user=user)
for term in explicit_word_terms:
if term.startswith("+"):
@@ -354,32 +354,32 @@ class EmbeddingsAdapters:
if min_date is not None:
# Convert the min_date timestamp to yyyy-mm-dd format
formatted_min_date = date.fromtimestamp(min_date).strftime("%Y-%m-%d")
q_filter_terms &= Q(embeddings_dates__date__gte=formatted_min_date)
q_filter_terms &= Q(entry_dates__date__gte=formatted_min_date)
if max_date is not None:
# Convert the max_date timestamp to yyyy-mm-dd format
formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d")
q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date)
q_filter_terms &= Q(entry_dates__date__lte=formatted_max_date)
relevant_embeddings = Embeddings.objects.filter(user=user).filter(
relevant_entries = Entry.objects.filter(user=user).filter(
q_filter_terms,
)
if file_type_filter:
relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter)
return relevant_embeddings
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
return relevant_entries
@staticmethod
def search_with_embeddings(
user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None
):
relevant_embeddings = EmbeddingsAdapters.apply_filters(user, raw_query, file_type_filter)
relevant_embeddings = relevant_embeddings.filter(user=user).annotate(
relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
relevant_entries = relevant_entries.filter(user=user).annotate(
distance=CosineDistance("embeddings", embeddings)
)
if file_type_filter:
relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter)
relevant_embeddings = relevant_embeddings.order_by("distance")
return relevant_embeddings[:max_results]
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
relevant_entries = relevant_entries.order_by("distance")
return relevant_entries[:max_results]
@staticmethod
def get_unique_file_types(user: KhojUser):
return Embeddings.objects.filter(user=user).values_list("file_type", flat=True).distinct()
return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct()