diff --git a/src/khoj/configure.py b/src/khoj/configure.py index a1f4a7db..002413b8 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -253,7 +253,7 @@ def configure_server( logger.info(message) if not init: - initialize_content(regenerate, search_type, user) + initialize_content(user, regenerate, search_type) except Exception as e: logger.error(f"Failed to load some search models: {e}", exc_info=True) @@ -263,17 +263,17 @@ def setup_default_agent(user: KhojUser): AgentAdapters.create_default_agent(user) -def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, user: KhojUser = None): +def initialize_content(user: KhojUser, regenerate: bool, search_type: Optional[SearchType] = None): # Initialize Content from Config if state.search_models: try: logger.info("📬 Updating content index...") all_files = collect_files(user=user) status = configure_content( + user, all_files, regenerate, search_type, - user=user, ) if not status: raise RuntimeError("Failed to update content index") @@ -338,9 +338,7 @@ def configure_middleware(app): def update_content_index(): for user in get_all_users(): all_files = collect_files(user=user) - success = configure_content(all_files, user=user) - all_files = collect_files(user=None) - success = configure_content(all_files, user=None) + success = configure_content(user, all_files) if not success: raise RuntimeError("Failed to update content index") logger.info("📪 Content index updated via Scheduler") diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 6676eefa..8538e217 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -8,13 +8,22 @@ import secrets import sys from datetime import date, datetime, timedelta, timezone from enum import Enum -from typing import Callable, Iterable, List, Optional, Type +from functools import wraps +from typing import ( + Any, + Callable, + Coroutine, + Iterable, + List, + Optional, + ParamSpec, + TypeVar, +) import cron_descriptor from apscheduler.job import Job from asgiref.sync import sync_to_async from django.contrib.sessions.backends.db import SessionStore -from django.db import models from django.db.models import Prefetch, Q from django.db.models.manager import BaseManager from django.db.utils import IntegrityError @@ -28,7 +37,6 @@ from khoj.database.models import ( ChatModelOptions, ClientApplication, Conversation, - DataStore, Entry, FileObject, GithubConfig, @@ -80,6 +88,45 @@ class SubscriptionState(Enum): INVALID = "invalid" +P = ParamSpec("P") +T = TypeVar("T") + + +def require_valid_user(func: Callable[P, T]) -> Callable[P, T]: + @wraps(func) + def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + # Extract user from args/kwargs + user = next((arg for arg in args if isinstance(arg, KhojUser)), None) + if not user: + user = next((val for val in kwargs.values() if isinstance(val, KhojUser)), None) + + # Throw error if user is not found + if not user: + raise ValueError("Khoj user argument required but not provided.") + + return func(*args, **kwargs) + + return sync_wrapper + + +def arequire_valid_user(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, Coroutine[Any, Any, T]]: + @wraps(func) + async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + # Extract user from args/kwargs + user = next((arg for arg in args if isinstance(arg, KhojUser)), None) + if not user: + user = next((v for v in kwargs.values() if isinstance(v, KhojUser)), None) + + # Throw error if user is not found + if not user: + raise ValueError("Khoj user argument required but not provided.") + + return await func(*args, **kwargs) + + return async_wrapper + + +@arequire_valid_user async def set_notion_config(token: str, user: KhojUser): notion_config = await NotionConfig.objects.filter(user=user).afirst() if not notion_config: @@ -90,6 +137,7 @@ async def set_notion_config(token: str, user: KhojUser): return notion_config +@require_valid_user def create_khoj_token(user: KhojUser, name=None): "Create Khoj API key for user" token = f"kk-{secrets.token_urlsafe(32)}" @@ -97,6 +145,7 @@ def create_khoj_token(user: KhojUser, name=None): return KhojApiUser.objects.create(token=token, user=user, name=name) +@arequire_valid_user async def acreate_khoj_token(user: KhojUser, name=None): "Create Khoj API key for user" token = f"kk-{secrets.token_urlsafe(32)}" @@ -104,11 +153,13 @@ async def acreate_khoj_token(user: KhojUser, name=None): return await KhojApiUser.objects.acreate(token=token, user=user, name=name) +@require_valid_user def get_khoj_tokens(user: KhojUser): "Get all Khoj API keys for user" return list(KhojApiUser.objects.filter(user=user)) +@arequire_valid_user async def delete_khoj_token(user: KhojUser, token: str): "Delete Khoj API Key for user" await KhojApiUser.objects.filter(token=token, user=user).adelete() @@ -132,6 +183,7 @@ async def aget_or_create_user_by_phone_number(phone_number: str) -> tuple[KhojUs return user, is_new +@arequire_valid_user async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser: if is_none_or_empty(phone_number): return None @@ -155,6 +207,7 @@ async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser: return user +@arequire_valid_user async def aremove_phone_number(user: KhojUser) -> KhojUser: user.phone_number = None user.verified_phone_number = False @@ -192,6 +245,7 @@ async def aget_or_create_user_by_email(email: str) -> tuple[KhojUser, bool]: return user, is_new +@arequire_valid_user async def astart_trial_subscription(user: KhojUser) -> Subscription: subscription = await Subscription.objects.filter(user=user).afirst() if not subscription: @@ -246,6 +300,7 @@ async def create_user_by_google_token(token: dict) -> KhojUser: return user +@require_valid_user def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser: user.first_name = first_name user.last_name = last_name @@ -253,6 +308,7 @@ def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser: return user +@require_valid_user def get_user_name(user: KhojUser): full_name = user.get_full_name() if not is_none_or_empty(full_name): @@ -264,6 +320,7 @@ def get_user_name(user: KhojUser): return None +@require_valid_user def get_user_photo(user: KhojUser): google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first() if google_profile: @@ -327,6 +384,7 @@ def get_user_subscription_state(email: str) -> str: return subscription_to_state(user_subscription) +@arequire_valid_user async def aget_user_subscription_state(user: KhojUser) -> str: """Get subscription state of user Valid state transitions: trial -> subscribed <-> unsubscribed OR expired @@ -335,6 +393,7 @@ async def aget_user_subscription_state(user: KhojUser) -> str: return await sync_to_async(subscription_to_state)(user_subscription) +@arequire_valid_user async def ais_user_subscribed(user: KhojUser) -> bool: """ Get whether the user is subscribed @@ -351,6 +410,7 @@ async def ais_user_subscribed(user: KhojUser) -> bool: return subscribed +@require_valid_user def is_user_subscribed(user: KhojUser) -> bool: """ Get whether the user is subscribed @@ -416,11 +476,13 @@ def get_all_users() -> BaseManager[KhojUser]: return KhojUser.objects.all() +@require_valid_user def get_user_github_config(user: KhojUser): config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first() return config +@require_valid_user def get_user_notion_config(user: KhojUser): config = NotionConfig.objects.filter(user=user).first() return config @@ -430,6 +492,7 @@ def delete_user_requests(window: timedelta = timedelta(days=1)): return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete() +@arequire_valid_user async def aget_user_name(user: KhojUser): full_name = user.get_full_name() if not is_none_or_empty(full_name): @@ -441,18 +504,7 @@ async def aget_user_name(user: KhojUser): return None -async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config): - deduped_files = list(set(updated_config.input_files)) if updated_config.input_files else None - deduped_filters = list(set(updated_config.input_filter)) if updated_config.input_filter else None - await object.objects.filter(user=user).adelete() - await object.objects.acreate( - input_files=deduped_files, - input_filter=deduped_filters, - index_heading_entries=updated_config.index_heading_entries, - user=user, - ) - - +@arequire_valid_user async def set_user_github_config(user: KhojUser, pat_token: str, repos: list): config = await GithubConfig.objects.filter(user=user).afirst() @@ -587,8 +639,11 @@ class AgentAdapters: ) @staticmethod + @arequire_valid_user async def adelete_agent_by_slug(agent_slug: str, user: KhojUser): agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user) + if agent.creator != user: + return False async for entry in Entry.objects.filter(agent=agent).aiterator(): await entry.adelete() @@ -712,6 +767,7 @@ class AgentAdapters: return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst() @staticmethod + @arequire_valid_user async def aupdate_agent( user: KhojUser, name: str, @@ -787,19 +843,6 @@ class PublicConversationAdapters: return f"/share/chat/{public_conversation.slug}/" -class DataStoreAdapters: - @staticmethod - async def astore_data(data: dict, key: str, user: KhojUser, private: bool = True): - if await DataStore.objects.filter(key=key).aexists(): - return key - await DataStore.objects.acreate(value=data, key=key, owner=user, private=private) - return key - - @staticmethod - async def aretrieve_public_data(key: str): - return await DataStore.objects.filter(key=key, private=False).afirst() - - class ConversationAdapters: @staticmethod def make_public_conversation_copy(conversation: Conversation): @@ -812,6 +855,7 @@ class ConversationAdapters: ) @staticmethod + @require_valid_user def get_conversation_by_user( user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None ) -> Optional[Conversation]: @@ -830,6 +874,7 @@ class ConversationAdapters: return conversation @staticmethod + @require_valid_user def get_conversation_sessions(user: KhojUser, client_application: ClientApplication = None): return ( Conversation.objects.filter(user=user, client=client_application) @@ -838,6 +883,7 @@ class ConversationAdapters: ) @staticmethod + @arequire_valid_user async def aset_conversation_title( user: KhojUser, client_application: ClientApplication, conversation_id: str, title: str ): @@ -855,6 +901,7 @@ class ConversationAdapters: return Conversation.objects.filter(id=conversation_id).first() @staticmethod + @arequire_valid_user async def acreate_conversation_session( user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None ): @@ -871,6 +918,7 @@ class ConversationAdapters: ) @staticmethod + @require_valid_user def create_conversation_session( user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None, title: str = None ): @@ -883,6 +931,7 @@ class ConversationAdapters: return Conversation.objects.create(user=user, client=client_application, agent=agent, title=title) @staticmethod + @arequire_valid_user async def aget_conversation_by_user( user: KhojUser, client_application: ClientApplication = None, @@ -907,6 +956,7 @@ class ConversationAdapters: ) @staticmethod + @arequire_valid_user async def adelete_conversation_by_user( user: KhojUser, client_application: ClientApplication = None, conversation_id: str = None ): @@ -915,6 +965,7 @@ class ConversationAdapters: return await Conversation.objects.filter(user=user, client=client_application).adelete() @staticmethod + @require_valid_user def has_any_conversation_config(user: KhojUser): return ChatModelOptions.objects.filter(user=user).exists() @@ -951,6 +1002,7 @@ class ConversationAdapters: return OpenAIProcessorConversationConfig.objects.filter().exists() @staticmethod + @arequire_valid_user async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int): config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst() if not config: @@ -959,6 +1011,7 @@ class ConversationAdapters: return new_config @staticmethod + @arequire_valid_user async def aset_user_voice_model(user: KhojUser, model_id: str): config = await VoiceModelOption.objects.filter(model_id=model_id).afirst() if not config: @@ -1143,6 +1196,7 @@ class ConversationAdapters: return enabled_scrapers @staticmethod + @require_valid_user def create_conversation_from_public_conversation( user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication ): @@ -1159,6 +1213,7 @@ class ConversationAdapters: ) @staticmethod + @require_valid_user def save_conversation( user: KhojUser, conversation_log: dict, @@ -1208,6 +1263,7 @@ class ConversationAdapters: return await SpeechToTextModelOptions.objects.filter().afirst() @staticmethod + @arequire_valid_user async def aget_conversation_starters(user: KhojUser, max_results=3): all_questions = [] if await ReflectiveQuestion.objects.filter(user=user).aexists(): @@ -1337,6 +1393,7 @@ class ConversationAdapters: return conversation.file_filters @staticmethod + @require_valid_user def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str): conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id) if not conversation or not conversation.conversation_log or not conversation.conversation_log.get("chat"): @@ -1355,52 +1412,63 @@ class FileObjectAdapters: file_object.save() @staticmethod + @require_valid_user def create_file_object(user: KhojUser, file_name: str, raw_text: str): return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text) @staticmethod + @require_valid_user def get_file_object_by_name(user: KhojUser, file_name: str): return FileObject.objects.filter(user=user, file_name=file_name).first() @staticmethod + @require_valid_user def get_all_file_objects(user: KhojUser): return FileObject.objects.filter(user=user).all() @staticmethod + @require_valid_user def delete_file_object_by_name(user: KhojUser, file_name: str): return FileObject.objects.filter(user=user, file_name=file_name).delete() @staticmethod + @require_valid_user def delete_all_file_objects(user: KhojUser): return FileObject.objects.filter(user=user).delete() @staticmethod - async def async_update_raw_text(file_object: FileObject, new_raw_text: str): + async def aupdate_raw_text(file_object: FileObject, new_raw_text: str): file_object.raw_text = new_raw_text await file_object.asave() @staticmethod - async def async_create_file_object(user: KhojUser, file_name: str, raw_text: str): + @arequire_valid_user + async def acreate_file_object(user: KhojUser, file_name: str, raw_text: str): return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text) @staticmethod - async def async_get_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None): + @arequire_valid_user + async def aget_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None): return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent)) @staticmethod - async def async_get_file_objects_by_names(user: KhojUser, file_names: List[str]): + @arequire_valid_user + async def aget_file_objects_by_names(user: KhojUser, file_names: List[str]): return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names)) @staticmethod - async def async_get_all_file_objects(user: KhojUser): + @arequire_valid_user + async def aget_all_file_objects(user: KhojUser): return await sync_to_async(list)(FileObject.objects.filter(user=user)) @staticmethod - async def async_delete_file_object_by_name(user: KhojUser, file_name: str): + @arequire_valid_user + async def adelete_file_object_by_name(user: KhojUser, file_name: str): return await FileObject.objects.filter(user=user, file_name=file_name).adelete() @staticmethod - async def async_delete_all_file_objects(user: KhojUser): + @arequire_valid_user + async def adelete_all_file_objects(user: KhojUser): return await FileObject.objects.filter(user=user).adelete() @@ -1410,15 +1478,18 @@ class EntryAdapters: date_filter = DateFilter() @staticmethod + @require_valid_user def does_entry_exist(user: KhojUser, hashed_value: str) -> bool: return Entry.objects.filter(user=user, hashed_value=hashed_value).exists() @staticmethod + @require_valid_user def delete_entry_by_file(user: KhojUser, file_path: str): deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete() return deleted_count @staticmethod + @require_valid_user def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None): queryset = Entry.objects.filter(user=user) @@ -1431,6 +1502,7 @@ class EntryAdapters: return queryset @staticmethod + @require_valid_user def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000): deleted_count = 0 queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source) @@ -1442,6 +1514,7 @@ class EntryAdapters: return deleted_count @staticmethod + @arequire_valid_user async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000): deleted_count = 0 queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source) @@ -1453,10 +1526,12 @@ class EntryAdapters: return deleted_count @staticmethod + @require_valid_user def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str): return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True) @staticmethod + @require_valid_user def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]): Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete() @@ -1468,6 +1543,7 @@ class EntryAdapters: ) @staticmethod + @require_valid_user def user_has_entries(user: KhojUser): return Entry.objects.filter(user=user).exists() @@ -1476,6 +1552,7 @@ class EntryAdapters: return Entry.objects.filter(agent=agent).exists() @staticmethod + @arequire_valid_user async def auser_has_entries(user: KhojUser): return await Entry.objects.filter(user=user).aexists() @@ -1486,10 +1563,12 @@ class EntryAdapters: return await Entry.objects.filter(agent=agent).aexists() @staticmethod + @arequire_valid_user async def adelete_entry_by_file(user: KhojUser, file_path: str): return await Entry.objects.filter(user=user, file_path=file_path).adelete() @staticmethod + @arequire_valid_user async def adelete_entries_by_filenames(user: KhojUser, filenames: List[str], batch_size=1000): deleted_count = 0 for i in range(0, len(filenames), batch_size): @@ -1508,6 +1587,7 @@ class EntryAdapters: ) @staticmethod + @require_valid_user def get_all_filenames_by_source(user: KhojUser, file_source: str): return ( Entry.objects.filter(user=user, file_source=file_source) @@ -1516,6 +1596,7 @@ class EntryAdapters: ) @staticmethod + @require_valid_user def get_size_of_indexed_data_in_mb(user: KhojUser): entries = Entry.objects.filter(user=user).iterator() total_size = sum(sys.getsizeof(entry.compiled) for entry in entries) @@ -1536,6 +1617,9 @@ class EntryAdapters: if agent != None: owner_filter |= Q(agent=agent) + if owner_filter == Q(): + return Entry.objects.none() + if len(word_filters) == 0 and len(file_filters) == 0 and len(date_filters) == 0: return Entry.objects.filter(owner_filter) @@ -1610,10 +1694,12 @@ class EntryAdapters: return relevant_entries[:max_results] @staticmethod + @require_valid_user def get_unique_file_types(user: KhojUser): return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct() @staticmethod + @require_valid_user def get_unique_file_sources(user: KhojUser): return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all() diff --git a/src/khoj/processor/content/docx/docx_to_entries.py b/src/khoj/processor/content/docx/docx_to_entries.py index 19d9ba13..35c634f7 100644 --- a/src/khoj/processor/content/docx/docx_to_entries.py +++ b/src/khoj/processor/content/docx/docx_to_entries.py @@ -18,7 +18,7 @@ class DocxToEntries(TextToEntries): super().__init__() # Define Functions - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: # Extract required fields from config deletion_file_names = set([file for file in files if files[file] == b""]) files_to_process = set(files) - deletion_file_names @@ -35,13 +35,13 @@ class DocxToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.DOCX, DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, - user, regenerate=regenerate, file_to_text_map=file_to_text_map, ) diff --git a/src/khoj/processor/content/github/github_to_entries.py b/src/khoj/processor/content/github/github_to_entries.py index 1f3dea00..2381bea8 100644 --- a/src/khoj/processor/content/github/github_to_entries.py +++ b/src/khoj/processor/content/github/github_to_entries.py @@ -48,7 +48,7 @@ class GithubToEntries(TextToEntries): else: return - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: if self.config.pat_token is None or self.config.pat_token == "": logger.error(f"Github PAT token is not set. Skipping github content") raise ValueError("Github PAT token is not set. Skipping github content") @@ -101,12 +101,12 @@ class GithubToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.GITHUB, DbEntry.EntrySource.GITHUB, key="compiled", logger=logger, - user=user, ) return num_new_embeddings, num_deleted_embeddings diff --git a/src/khoj/processor/content/images/image_to_entries.py b/src/khoj/processor/content/images/image_to_entries.py index 87b9a009..134cca52 100644 --- a/src/khoj/processor/content/images/image_to_entries.py +++ b/src/khoj/processor/content/images/image_to_entries.py @@ -18,7 +18,7 @@ class ImageToEntries(TextToEntries): super().__init__() # Define Functions - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: # Extract required fields from config deletion_file_names = set([file for file in files if files[file] == b""]) files_to_process = set(files) - deletion_file_names @@ -35,13 +35,13 @@ class ImageToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.IMAGE, DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, - user, regenerate=regenerate, file_to_text_map=file_to_text_map, ) diff --git a/src/khoj/processor/content/markdown/markdown_to_entries.py b/src/khoj/processor/content/markdown/markdown_to_entries.py index fdb0c549..c4ee03ef 100644 --- a/src/khoj/processor/content/markdown/markdown_to_entries.py +++ b/src/khoj/processor/content/markdown/markdown_to_entries.py @@ -19,7 +19,7 @@ class MarkdownToEntries(TextToEntries): super().__init__() # Define Functions - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: # Extract required fields from config deletion_file_names = set([file for file in files if files[file] == ""]) files_to_process = set(files) - deletion_file_names @@ -37,13 +37,13 @@ class MarkdownToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.MARKDOWN, DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, - user, regenerate=regenerate, file_to_text_map=file_to_text_map, ) diff --git a/src/khoj/processor/content/notion/notion_to_entries.py b/src/khoj/processor/content/notion/notion_to_entries.py index fc6e296f..1e1ab4d3 100644 --- a/src/khoj/processor/content/notion/notion_to_entries.py +++ b/src/khoj/processor/content/notion/notion_to_entries.py @@ -79,7 +79,7 @@ class NotionToEntries(TextToEntries): self.body_params = {"page_size": 100} - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: current_entries = [] # Get all pages @@ -248,12 +248,12 @@ class NotionToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.NOTION, DbEntry.EntrySource.NOTION, key="compiled", logger=logger, - user=user, ) return num_new_embeddings, num_deleted_embeddings diff --git a/src/khoj/processor/content/org_mode/org_to_entries.py b/src/khoj/processor/content/org_mode/org_to_entries.py index 1272da11..cfc17cc0 100644 --- a/src/khoj/processor/content/org_mode/org_to_entries.py +++ b/src/khoj/processor/content/org_mode/org_to_entries.py @@ -20,7 +20,7 @@ class OrgToEntries(TextToEntries): super().__init__() # Define Functions - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: deletion_file_names = set([file for file in files if files[file] == ""]) files_to_process = set(files) - deletion_file_names files = {file: files[file] for file in files_to_process} @@ -36,13 +36,13 @@ class OrgToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.ORG, DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, - user, regenerate=regenerate, file_to_text_map=file_to_text_map, ) diff --git a/src/khoj/processor/content/pdf/pdf_to_entries.py b/src/khoj/processor/content/pdf/pdf_to_entries.py index f1ac5104..7d2bd384 100644 --- a/src/khoj/processor/content/pdf/pdf_to_entries.py +++ b/src/khoj/processor/content/pdf/pdf_to_entries.py @@ -19,7 +19,7 @@ class PdfToEntries(TextToEntries): super().__init__() # Define Functions - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: # Extract required fields from config deletion_file_names = set([file for file in files if files[file] == b""]) files_to_process = set(files) - deletion_file_names @@ -36,13 +36,13 @@ class PdfToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.PDF, DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, - user, regenerate=regenerate, file_to_text_map=file_to_text_map, ) diff --git a/src/khoj/processor/content/plaintext/plaintext_to_entries.py b/src/khoj/processor/content/plaintext/plaintext_to_entries.py index 483e752f..64470c08 100644 --- a/src/khoj/processor/content/plaintext/plaintext_to_entries.py +++ b/src/khoj/processor/content/plaintext/plaintext_to_entries.py @@ -20,7 +20,7 @@ class PlaintextToEntries(TextToEntries): super().__init__() # Define Functions - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: deletion_file_names = set([file for file in files if files[file] == ""]) files_to_process = set(files) - deletion_file_names files = {file: files[file] for file in files_to_process} @@ -36,13 +36,13 @@ class PlaintextToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + user, current_entries, DbEntry.EntryType.PLAINTEXT, DbEntry.EntrySource.COMPUTER, key="compiled", logger=logger, deletion_filenames=deletion_file_names, - user=user, regenerate=regenerate, file_to_text_map=file_to_text_map, ) diff --git a/src/khoj/processor/content/text_to_entries.py b/src/khoj/processor/content/text_to_entries.py index 181eb199..f013b28c 100644 --- a/src/khoj/processor/content/text_to_entries.py +++ b/src/khoj/processor/content/text_to_entries.py @@ -31,7 +31,7 @@ class TextToEntries(ABC): self.date_filter = DateFilter() @abstractmethod - def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: ... @staticmethod @@ -114,13 +114,13 @@ class TextToEntries(ABC): def update_embeddings( self, + user: KhojUser, current_entries: List[Entry], file_type: str, file_source: str, key="compiled", logger: logging.Logger = None, deletion_filenames: Set[str] = None, - user: KhojUser = None, regenerate: bool = False, file_to_text_map: dict[str, str] = None, ): diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index f66fbce8..fc7dfe27 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -212,7 +212,7 @@ def update( logger.warning(error_msg) raise HTTPException(status_code=500, detail=error_msg) try: - initialize_content(regenerate=force, search_type=t, user=user) + initialize_content(user=user, regenerate=force, search_type=t) except Exception as e: error_msg = f"🚨 Failed to update server via API: {e}" logger.error(error_msg, exc_info=True) diff --git a/src/khoj/routers/api_content.py b/src/khoj/routers/api_content.py index 40a1fb78..9ac0db47 100644 --- a/src/khoj/routers/api_content.py +++ b/src/khoj/routers/api_content.py @@ -239,7 +239,7 @@ async def set_content_notion( if updated_config.token: # Trigger an async job to configure_content. Let it run without blocking the response. - background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user) + background_tasks.add_task(run_in_executor, configure_content, user, {}, False, SearchType.Notion) update_telemetry_state( request=request, @@ -512,10 +512,10 @@ async def indexer( success = await loop.run_in_executor( None, configure_content, + user, indexer_input.model_dump(), regenerate, t, - user, ) if not success: raise RuntimeError(f"Failed to {method} {t} data sent by {client} client into content index") diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 2d0bbe29..3a2cb5cf 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -703,7 +703,7 @@ async def generate_summary_from_files( if await EntryAdapters.aagent_has_entries(agent): file_names = await EntryAdapters.aget_agent_entry_filepaths(agent) if len(file_names) > 0: - file_objects = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent) + file_objects = await FileObjectAdapters.aget_file_objects_by_name(None, file_names.pop(), agent) if (file_objects and len(file_objects) == 0 and not query_files) or (not file_objects and not query_files): response_log = "Sorry, I couldn't find anything to summarize." @@ -1975,10 +1975,10 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False) def configure_content( + user: KhojUser, files: Optional[dict[str, dict[str, str]]], regenerate: bool = False, t: Optional[state.SearchType] = state.SearchType.All, - user: KhojUser = None, ) -> bool: success = True if t == None: diff --git a/src/khoj/routers/notion.py b/src/khoj/routers/notion.py index 7d5ed25d..acfd1e2e 100644 --- a/src/khoj/routers/notion.py +++ b/src/khoj/routers/notion.py @@ -80,6 +80,6 @@ async def notion_auth_callback(request: Request, background_tasks: BackgroundTas notion_redirect = str(request.app.url_path_for("config_page")) # Trigger an async job to configure_content. Let it run without blocking the response. - background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user) + background_tasks.add_task(run_in_executor, configure_content, user, {}, False, SearchType.Notion) return RedirectResponse(notion_redirect) diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index eed72b51..6d7667e5 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -208,7 +208,7 @@ def setup( text_to_entries: Type[TextToEntries], files: dict[str, str], regenerate: bool, - user: KhojUser = None, + user: KhojUser, config=None, ) -> Tuple[int, int]: if config: diff --git a/src/khoj/utils/fs_syncer.py b/src/khoj/utils/fs_syncer.py index 475504f1..67e91bc9 100644 --- a/src/khoj/utils/fs_syncer.py +++ b/src/khoj/utils/fs_syncer.py @@ -8,6 +8,7 @@ from bs4 import BeautifulSoup from magika import Magika from khoj.database.models import ( + KhojUser, LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, @@ -21,7 +22,7 @@ logger = logging.getLogger(__name__) magika = Magika() -def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict: +def collect_files(user: KhojUser, search_type: Optional[SearchType] = SearchType.All) -> dict: files: dict[str, dict] = {"docx": {}, "image": {}} if search_type == SearchType.All or search_type == SearchType.Org: diff --git a/tests/conftest.py b/tests/conftest.py index 54b4db86..b91af758 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -304,7 +304,7 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa # Index Markdown Content for Search all_files = fs_syncer.collect_files(user=user) - success = configure_content(all_files, user=user) + configure_content(user, all_files) # Initialize Processor from Config if os.getenv("OPENAI_API_KEY"): @@ -381,7 +381,7 @@ def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser): ) all_files = fs_syncer.collect_files(user=default_user2) - configure_content(all_files, user=default_user2) + configure_content(default_user2, all_files) # Initialize Processor from Config ChatModelOptionsFactory( @@ -432,7 +432,7 @@ def pdf_configured_user1(default_user: KhojUser): ) # Index Markdown Content for Search all_files = fs_syncer.collect_files(user=default_user) - success = configure_content(all_files, user=default_user) + configure_content(default_user, all_files) @pytest.fixture(scope="function") diff --git a/tests/test_client.py b/tests/test_client.py index b8284e4b..f5ed320f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -253,11 +253,11 @@ def test_regenerate_with_github_fails_without_pat(client): # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -def test_get_configured_types_via_api(client, sample_org_data): +def test_get_configured_types_via_api(client, sample_org_data, default_user3: KhojUser): # Act - text_search.setup(OrgToEntries, sample_org_data, regenerate=False) + text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user3) - enabled_types = EntryAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True) + enabled_types = EntryAdapters.get_unique_file_types(user=default_user3).all().values_list("file_type", flat=True) # Assert assert list(enabled_types) == ["org"]