diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py index fa37aa99..69a3c1f4 100644 --- a/src/database/adapters/__init__.py +++ b/src/database/adapters/__init__.py @@ -287,13 +287,21 @@ class EntryAdapters: return deleted_count @staticmethod - def delete_all_entries(user: KhojUser, file_type: str = None): + def delete_all_entries_by_type(user: KhojUser, file_type: str = None): if file_type is None: deleted_count, _ = Entry.objects.filter(user=user).delete() else: deleted_count, _ = Entry.objects.filter(user=user, file_type=file_type).delete() return deleted_count + @staticmethod + def delete_all_entries_by_source(user: KhojUser, file_source: str = None): + if file_source is None: + deleted_count, _ = Entry.objects.filter(user=user).delete() + else: + deleted_count, _ = Entry.objects.filter(user=user, file_source=file_source).delete() + return deleted_count + @staticmethod def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str): return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True) @@ -318,8 +326,12 @@ class EntryAdapters: return await Entry.objects.filter(user=user, file_path=file_path).adelete() @staticmethod - def aget_all_filenames(user: KhojUser): - return Entry.objects.filter(user=user).distinct("file_path").values_list("file_path", flat=True) + def aget_all_filenames_by_source(user: KhojUser, file_source: str): + return ( + Entry.objects.filter(user=user, file_source=file_source) + .distinct("file_path") + .values_list("file_path", flat=True) + ) @staticmethod async def adelete_all_entries(user: KhojUser): @@ -384,3 +396,7 @@ class EntryAdapters: @staticmethod def get_unique_file_types(user: KhojUser): return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct() + + @staticmethod + def get_unique_file_source(user: KhojUser): + return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct() diff --git a/src/database/migrations/0012_entry_file_source.py b/src/database/migrations/0012_entry_file_source.py new file mode 100644 index 00000000..187136ae --- /dev/null +++ b/src/database/migrations/0012_entry_file_source.py @@ -0,0 +1,21 @@ +# Generated by Django 4.2.5 on 2023-11-07 07:24 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0011_merge_20231102_0138"), + ] + + operations = [ + migrations.AddField( + model_name="entry", + name="file_source", + field=models.CharField( + choices=[("computer", "Computer"), ("notion", "Notion"), ("github", "Github")], + default="computer", + max_length=30, + ), + ), + ] diff --git a/src/database/models/__init__.py b/src/database/models/__init__.py index 5dd9622b..b1be9ded 100644 --- a/src/database/models/__init__.py +++ b/src/database/models/__init__.py @@ -131,11 +131,17 @@ class Entry(BaseModel): GITHUB = "github" CONVERSATION = "conversation" + class EntrySource(models.TextChoices): + COMPUTER = "computer" + NOTION = "notion" + GITHUB = "github" + user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True) embeddings = VectorField(dimensions=384) raw = models.TextField() compiled = models.TextField() heading = models.CharField(max_length=1000, default=None, null=True, blank=True) + file_source = models.CharField(max_length=30, choices=EntrySource.choices, default=EntrySource.COMPUTER) file_type = models.CharField(max_length=30, choices=EntryType.choices, default=EntryType.PLAINTEXT) file_path = models.CharField(max_length=400, default=None, null=True, blank=True) file_name = models.CharField(max_length=400, default=None, null=True, blank=True) diff --git a/src/khoj/processor/github/github_to_entries.py b/src/khoj/processor/github/github_to_entries.py index 14e9b696..56279453 100644 --- a/src/khoj/processor/github/github_to_entries.py +++ b/src/khoj/processor/github/github_to_entries.py @@ -104,7 +104,12 @@ class GithubToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( - current_entries, DbEntry.EntryType.GITHUB, key="compiled", logger=logger, user=user + current_entries, + DbEntry.EntryType.GITHUB, + DbEntry.EntrySource.GITHUB, + key="compiled", + logger=logger, + user=user, ) return num_new_embeddings, num_deleted_embeddings diff --git a/src/khoj/processor/markdown/markdown_to_entries.py b/src/khoj/processor/markdown/markdown_to_entries.py index e0b76368..0dd71740 100644 --- a/src/khoj/processor/markdown/markdown_to_entries.py +++ b/src/khoj/processor/markdown/markdown_to_entries.py @@ -47,6 +47,7 @@ class MarkdownToEntries(TextToEntries): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( current_entries, DbEntry.EntryType.MARKDOWN, + DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, diff --git a/src/khoj/processor/notion/notion_to_entries.py b/src/khoj/processor/notion/notion_to_entries.py index a4b15d4e..7a88e2a1 100644 --- a/src/khoj/processor/notion/notion_to_entries.py +++ b/src/khoj/processor/notion/notion_to_entries.py @@ -250,7 +250,12 @@ class NotionToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( - current_entries, DbEntry.EntryType.NOTION, key="compiled", logger=logger, user=user + current_entries, + DbEntry.EntryType.NOTION, + DbEntry.EntrySource.NOTION, + key="compiled", + logger=logger, + user=user, ) return num_new_embeddings, num_deleted_embeddings diff --git a/src/khoj/processor/org_mode/org_to_entries.py b/src/khoj/processor/org_mode/org_to_entries.py index bf6df6dc..04ce97e4 100644 --- a/src/khoj/processor/org_mode/org_to_entries.py +++ b/src/khoj/processor/org_mode/org_to_entries.py @@ -48,6 +48,7 @@ class OrgToEntries(TextToEntries): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( current_entries, DbEntry.EntryType.ORG, + DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, diff --git a/src/khoj/processor/pdf/pdf_to_entries.py b/src/khoj/processor/pdf/pdf_to_entries.py index 81c2250f..3a47096a 100644 --- a/src/khoj/processor/pdf/pdf_to_entries.py +++ b/src/khoj/processor/pdf/pdf_to_entries.py @@ -46,6 +46,7 @@ class PdfToEntries(TextToEntries): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( current_entries, DbEntry.EntryType.PDF, + DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, diff --git a/src/khoj/processor/plaintext/plaintext_to_entries.py b/src/khoj/processor/plaintext/plaintext_to_entries.py index fd5e1de2..d42dae30 100644 --- a/src/khoj/processor/plaintext/plaintext_to_entries.py +++ b/src/khoj/processor/plaintext/plaintext_to_entries.py @@ -56,6 +56,7 @@ class PlaintextToEntries(TextToEntries): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( current_entries, DbEntry.EntryType.PLAINTEXT, + DbEntry.EntrySource.COMPUTER, key="compiled", logger=logger, deletion_filenames=deletion_file_names, diff --git a/src/khoj/processor/text_to_entries.py b/src/khoj/processor/text_to_entries.py index 4661fd9b..3d79e02e 100644 --- a/src/khoj/processor/text_to_entries.py +++ b/src/khoj/processor/text_to_entries.py @@ -78,6 +78,7 @@ class TextToEntries(ABC): self, current_entries: List[Entry], file_type: str, + file_source: str, key="compiled", logger: logging.Logger = None, deletion_filenames: Set[str] = None, @@ -95,7 +96,7 @@ class TextToEntries(ABC): if regenerate: with timer("Cleared existing dataset for regeneration in", logger): logger.debug(f"Deleting all entries for file type {file_type}") - num_deleted_entries = EntryAdapters.delete_all_entries(user, file_type) + num_deleted_entries = EntryAdapters.delete_all_entries_by_type(user, file_type) hashes_to_process = set() with timer("Identified entries to add to database in", logger): @@ -132,6 +133,7 @@ class TextToEntries(ABC): compiled=entry.compiled, heading=entry.heading[:1000], # Truncate to max chars of field allowed file_path=entry.file, + file_source=file_source, file_type=file_type, hashed_value=entry_hash, corpus_id=entry.corpus_id, diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index 14f5b770..ba2fc9ec 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -204,11 +204,12 @@ def setup( files=files, full_corpus=full_corpus, user=user, regenerate=regenerate ) - file_names = [file_name for file_name in files] + if files: + file_names = [file_name for file_name in files] - logger.info( - f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}" - ) + logger.info( + f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}" + ) def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchResponse]: diff --git a/tests/test_text_search.py b/tests/test_text_search.py index 7d8c30fb..3d729ab5 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -58,7 +58,7 @@ def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, defaul # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -def test_text_search_setup_with_empty_file_raises_error( +def test_text_search_setup_with_empty_file_creates_no_entries( org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog ): # Arrange