Pick the chat model to generate text responses
@@ -1169,7 +1177,7 @@ export default function SettingsView() {Pick the search model to find your documents
@@ -1190,7 +1198,7 @@ export default function SettingsView() {Pick the paint model to generate image responses
@@ -1217,7 +1225,7 @@ export default function SettingsView() {Pick the voice model to generate speech responses diff --git a/src/khoj/configure.py b/src/khoj/configure.py index f4b1c9f4..5d3c7123 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -32,10 +32,9 @@ from khoj.database.adapters import ( ClientApplicationAdapters, ConversationAdapters, ProcessLockAdapters, - SubscriptionState, aget_or_create_user_by_phone_number, aget_user_by_phone_number, - aget_user_subscription_state, + ais_user_subscribed, delete_user_requests, get_all_users, get_or_create_search_models, @@ -119,15 +118,7 @@ class UserAuthenticationBackend(AuthenticationBackend): .afirst() ) if user: - if not state.billing_enabled: - return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user) - - subscription_state = await aget_user_subscription_state(user) - subscribed = ( - subscription_state == SubscriptionState.SUBSCRIBED.value - or subscription_state == SubscriptionState.TRIAL.value - or subscription_state == SubscriptionState.UNSUBSCRIBED.value - ) + subscribed = await ais_user_subscribed(user) if subscribed: return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) @@ -144,15 +135,7 @@ class UserAuthenticationBackend(AuthenticationBackend): .afirst() ) if user_with_token: - if not state.billing_enabled: - return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user) - - subscription_state = await aget_user_subscription_state(user_with_token.user) - subscribed = ( - subscription_state == SubscriptionState.SUBSCRIBED.value - or subscription_state == SubscriptionState.TRIAL.value - or subscription_state == SubscriptionState.UNSUBSCRIBED.value - ) + subscribed = await ais_user_subscribed(user_with_token.user) if subscribed: return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) @@ -189,20 +172,10 @@ class UserAuthenticationBackend(AuthenticationBackend): if user is None: return AuthCredentials(), UnauthenticatedUser() - if not state.billing_enabled: - return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user, client_application) + subscribed = await ais_user_subscribed(user) - subscription_state = await aget_user_subscription_state(user) - subscribed = ( - subscription_state == SubscriptionState.SUBSCRIBED.value - or subscription_state == SubscriptionState.TRIAL.value - or subscription_state == SubscriptionState.UNSUBSCRIBED.value - ) if subscribed: - return ( - AuthCredentials(["authenticated", "premium"]), - AuthenticatedKhojUser(user, client_application), - ) + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user, client_application) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user, client_application) # No auth required if server in anonymous mode diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index dba42094..ec09ecac 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -300,6 +300,38 @@ async def aget_user_subscription_state(user: KhojUser) -> str: return subscription_to_state(user_subscription) +async def ais_user_subscribed(user: KhojUser) -> bool: + """ + Get whether the user is subscribed + """ + if not state.billing_enabled or state.anonymous_mode: + return True + + subscription_state = await aget_user_subscription_state(user) + subscribed = ( + subscription_state == SubscriptionState.SUBSCRIBED.value + or subscription_state == SubscriptionState.TRIAL.value + or subscription_state == SubscriptionState.UNSUBSCRIBED.value + ) + return subscribed + + +def is_user_subscribed(user: KhojUser) -> bool: + """ + Get whether the user is subscribed + """ + if not state.billing_enabled or state.anonymous_mode: + return True + + subscription_state = get_user_subscription_state(user.email) + subscribed = ( + subscription_state == SubscriptionState.SUBSCRIBED.value + or subscription_state == SubscriptionState.TRIAL.value + or subscription_state == SubscriptionState.UNSUBSCRIBED.value + ) + return subscribed + + async def get_user_by_email(email: str) -> KhojUser: return await KhojUser.objects.filter(email=email).afirst() @@ -751,17 +783,23 @@ class ConversationAdapters: @staticmethod def get_conversation_config(user: KhojUser): + subscribed = is_user_subscribed(user) + if not subscribed: + return ConversationAdapters.get_default_conversation_config() config = UserConversationConfig.objects.filter(user=user).first() - if not config: - return None - return config.setting + if config: + return config.setting + return ConversationAdapters.get_advanced_conversation_config() @staticmethod async def aget_conversation_config(user: KhojUser): + subscribed = await ais_user_subscribed(user) + if not subscribed: + return await ConversationAdapters.aget_default_conversation_config() config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst() - if not config: - return None - return config.setting + if config: + return config.setting + return ConversationAdapters.aget_advanced_conversation_config() @staticmethod async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]: @@ -784,35 +822,38 @@ class ConversationAdapters: @staticmethod def get_default_conversation_config(): server_chat_settings = ServerChatSettings.objects.first() - if server_chat_settings is None or server_chat_settings.default_model is None: + if server_chat_settings is None or server_chat_settings.chat_default is None: return ChatModelOptions.objects.filter().first() - return server_chat_settings.default_model + return server_chat_settings.chat_default @staticmethod async def aget_default_conversation_config(): server_chat_settings: ServerChatSettings = ( await ServerChatSettings.objects.filter() - .prefetch_related("default_model", "default_model__openai_config") + .prefetch_related("chat_default", "chat_default__openai_config") .afirst() ) - if server_chat_settings is None or server_chat_settings.default_model is None: + if server_chat_settings is None or server_chat_settings.chat_default is None: return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst() - return server_chat_settings.default_model + return server_chat_settings.chat_default @staticmethod - async def aget_summarizer_conversation_config(): + def get_advanced_conversation_config(): + server_chat_settings = ServerChatSettings.objects.first() + if server_chat_settings is None or server_chat_settings.chat_advanced is None: + return ConversationAdapters.get_default_conversation_config() + return server_chat_settings.chat_advanced + + @staticmethod + async def aget_advanced_conversation_config(): server_chat_settings: ServerChatSettings = ( await ServerChatSettings.objects.filter() - .prefetch_related( - "summarizer_model", "default_model", "default_model__openai_config", "summarizer_model__openai_config" - ) + .prefetch_related("chat_advanced", "chat_advanced__openai_config") .afirst() ) - if server_chat_settings is None or ( - server_chat_settings.summarizer_model is None and server_chat_settings.default_model is None - ): - return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst() - return server_chat_settings.summarizer_model or server_chat_settings.default_model + if server_chat_settings is None or server_chat_settings.chat_advanced is None: + return await ConversationAdapters.aget_default_conversation_config() + return server_chat_settings.chat_advanced @staticmethod def create_conversation_from_public_conversation( diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index 9b7d1e04..f6222006 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -26,6 +26,7 @@ from khoj.database.models import ( SpeechToTextModelOptions, Subscription, TextToImageModelConfig, + UserConversationConfig, UserSearchModelConfig, UserVoiceModelConfig, VoiceModelOption, @@ -101,6 +102,7 @@ admin.site.register(GithubConfig) admin.site.register(NotionConfig) admin.site.register(UserVoiceModelConfig) admin.site.register(VoiceModelOption) +admin.site.register(UserConversationConfig) @admin.register(Agent) @@ -191,8 +193,8 @@ class SearchModelConfigAdmin(admin.ModelAdmin): @admin.register(ServerChatSettings) class ServerChatSettingsAdmin(admin.ModelAdmin): list_display = ( - "default_model", - "summarizer_model", + "chat_default", + "chat_advanced", ) diff --git a/src/khoj/database/migrations/0057_remove_serverchatsettings_default_model_and_more.py b/src/khoj/database/migrations/0057_remove_serverchatsettings_default_model_and_more.py new file mode 100644 index 00000000..52a8d854 --- /dev/null +++ b/src/khoj/database/migrations/0057_remove_serverchatsettings_default_model_and_more.py @@ -0,0 +1,51 @@ +# Generated by Django 5.0.7 on 2024-08-16 18:18 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0056_searchmodelconfig_cross_encoder_model_config"), + ] + + operations = [ + migrations.RenameField( + model_name="serverchatsettings", + old_name="default_model", + new_name="chat_default", + ), + migrations.RemoveField( + model_name="serverchatsettings", + name="summarizer_model", + ), + migrations.AddField( + model_name="chatmodeloptions", + name="subscribed_max_prompt_size", + field=models.IntegerField(blank=True, default=None, null=True), + ), + migrations.AddField( + model_name="serverchatsettings", + name="chat_advanced", + field=models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="chat_advanced", + to="database.chatmodeloptions", + ), + ), + migrations.AlterField( + model_name="serverchatsettings", + name="chat_default", + field=models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="chat_default", + to="database.chatmodeloptions", + ), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index bf16b781..72c93157 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -89,6 +89,7 @@ class ChatModelOptions(BaseModel): ANTHROPIC = "anthropic" max_prompt_size = models.IntegerField(default=None, null=True, blank=True) + subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True) tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True) chat_model = models.CharField(max_length=200, default="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF") model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) @@ -205,11 +206,11 @@ class GithubRepoConfig(BaseModel): class ServerChatSettings(BaseModel): - default_model = models.ForeignKey( - ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="default_model" + chat_default = models.ForeignKey( + ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default" ) - summarizer_model = models.ForeignKey( - ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="summarizer_model" + chat_advanced = models.ForeignKey( + ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced" ) diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 70e17630..8678e2bb 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -53,6 +53,7 @@ async def search_online( conversation_history: dict, location: LocationData, user: KhojUser, + subscribed: bool = False, send_status_func: Optional[Callable] = None, custom_filters: List[str] = [], ): @@ -91,12 +92,15 @@ async def search_online( # Read, extract relevant info from the retrieved web pages if webpages: webpage_links = [link for link, _, _ in webpages] - logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}") + logger.info(f"Reading web pages at: {list(webpage_links)}") if send_status_func: webpage_links_str = "\n- " + "\n- ".join(list(webpage_links)) async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"): yield {ChatEvent.STATUS: event} - tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages] + tasks = [ + read_webpage_and_extract_content(subquery, link, content, subscribed=subscribed) + for link, subquery, content in webpages + ] results = await asyncio.gather(*tasks) # Collect extracted info from the retrieved web pages @@ -132,6 +136,7 @@ async def read_webpages( conversation_history: dict, location: LocationData, user: KhojUser, + subscribed: bool = False, send_status_func: Optional[Callable] = None, ): "Infer web pages to read from the query and extract relevant information from them" @@ -146,7 +151,7 @@ async def read_webpages( webpage_links_str = "\n- " + "\n- ".join(list(urls)) async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"): yield {ChatEvent.STATUS: event} - tasks = [read_webpage_and_extract_content(query, url) for url in urls] + tasks = [read_webpage_and_extract_content(query, url, subscribed=subscribed) for url in urls] results = await asyncio.gather(*tasks) response: Dict[str, Dict] = defaultdict(dict) @@ -157,14 +162,14 @@ async def read_webpages( async def read_webpage_and_extract_content( - subquery: str, url: str, content: str = None + subquery: str, url: str, content: str = None, subscribed: bool = False ) -> Tuple[str, Union[None, str], str]: try: if is_none_or_empty(content): with timer(f"Reading web page at '{url}' took", logger): content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url) with timer(f"Extracting relevant information from web page at '{url}' took", logger): - extracted_info = await extract_relevant_info(subquery, content) + extracted_info = await extract_relevant_info(subquery, content, subscribed=subscribed) return subquery, extracted_info, url except Exception as e: logger.error(f"Failed to read web page at '{url}' with {e}") diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index d515006c..f36183be 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -4,14 +4,14 @@ import logging import time from datetime import datetime from functools import partial -from typing import Any, Dict, List, Optional +from typing import Dict, Optional from urllib.parse import unquote from asgiref.sync import sync_to_async from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.requests import Request from fastapi.responses import Response, StreamingResponse -from starlette.authentication import requires +from starlette.authentication import has_required_scope, requires from khoj.app.settings import ALLOWED_HOSTS from khoj.database.adapters import ( @@ -59,7 +59,7 @@ from khoj.utils.rawconfig import FileFilterRequest, FilesFilterRequest, Location # Initialize Router logger = logging.getLogger(__name__) conversation_command_rate_limiter = ConversationCommandRateLimiter( - trial_rate_limit=2, subscribed_rate_limit=100, slug="command" + trial_rate_limit=100, subscribed_rate_limit=100, slug="command" ) @@ -532,10 +532,10 @@ async def chat( country: Optional[str] = None, timezone: Optional[str] = None, rate_limiter_per_minute=Depends( - ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute") + ApiUserRateLimiter(requests=60, subscribed_requests=60, window=60, slug="chat_minute") ), rate_limiter_per_day=Depends( - ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") + ApiUserRateLimiter(requests=600, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") ), ): async def event_generator(q: str): @@ -544,6 +544,7 @@ async def chat( chat_metadata: dict = {} connection_alive = True user: KhojUser = request.user.object + subscribed: bool = has_required_scope(request, ["premium"]) event_delimiter = "␃🔚␗" q = unquote(q) @@ -632,7 +633,9 @@ async def chat( is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] if conversation_commands == [ConversationCommand.Default] or is_automated_task: - conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task) + conversation_commands = await aget_relevant_information_sources( + q, meta_log, is_automated_task, subscribed=subscribed + ) conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) async for result in send_event( ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}" @@ -687,7 +690,7 @@ async def chat( ): yield result - response = await extract_relevant_summary(q, contextual_data) + response = await extract_relevant_summary(q, contextual_data, subscribed=subscribed) response_log = str(response) async for result in send_llm_response(response_log): yield result @@ -792,7 +795,13 @@ async def chat( if ConversationCommand.Online in conversation_commands: try: async for result in search_online( - defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS), custom_filters + defiltered_query, + meta_log, + location, + user, + subscribed, + partial(send_event, ChatEvent.STATUS), + custom_filters, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -809,7 +818,7 @@ async def chat( if ConversationCommand.Webpage in conversation_commands: try: async for result in read_webpages( - defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS) + defiltered_query, meta_log, location, user, subscribed, partial(send_event, ChatEvent.STATUS) ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -853,6 +862,7 @@ async def chat( location_data=location, references=compiled_references, online_results=online_results, + subscribed=subscribed, send_status_func=partial(send_event, ChatEvent.STATUS), ): if isinstance(result, dict) and ChatEvent.STATUS in result: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 31b8d9b5..4e4f5a56 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -252,7 +252,7 @@ async def acreate_title_from_query(query: str) -> str: return response.strip() -async def aget_relevant_information_sources(query: str, conversation_history: dict, is_task: bool): +async def aget_relevant_information_sources(query: str, conversation_history: dict, is_task: bool, subscribed: bool): """ Given a query, determine which of the available tools the agent should use in order to answer appropriately. """ @@ -273,7 +273,9 @@ async def aget_relevant_information_sources(query: str, conversation_history: di ) with timer("Chat actor: Infer information sources to refer", logger): - response = await send_message_to_model_wrapper(relevant_tools_prompt, response_type="json_object") + response = await send_message_to_model_wrapper( + relevant_tools_prompt, response_type="json_object", subscribed=subscribed + ) try: response = response.strip() @@ -434,7 +436,7 @@ async def schedule_query(q: str, conversation_history: dict) -> Tuple[str, ...]: raise AssertionError(f"Invalid response for scheduling query: {raw_response}") -async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]: +async def extract_relevant_info(q: str, corpus: str, subscribed: bool) -> Union[str, None]: """ Extract relevant information for a given query from the target corpus """ @@ -447,18 +449,19 @@ async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]: corpus=corpus.strip(), ) - summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config() + chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config() with timer("Chat actor: Extract relevant information from data", logger): response = await send_message_to_model_wrapper( extract_relevant_information, prompts.system_prompt_extract_relevant_information, - chat_model_option=summarizer_model, + chat_model_option=chat_model, + subscribed=subscribed, ) return response.strip() -async def extract_relevant_summary(q: str, corpus: str) -> Union[str, None]: +async def extract_relevant_summary(q: str, corpus: str, subscribed: bool = False) -> Union[str, None]: """ Extract relevant information for a given query from the target corpus """ @@ -471,13 +474,14 @@ async def extract_relevant_summary(q: str, corpus: str) -> Union[str, None]: corpus=corpus.strip(), ) - summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config() + chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config() with timer("Chat actor: Extract relevant information from data", logger): response = await send_message_to_model_wrapper( extract_relevant_information, prompts.system_prompt_extract_relevant_summary, - chat_model_option=summarizer_model, + chat_model_option=chat_model, + subscribed=subscribed, ) return response.strip() @@ -489,6 +493,7 @@ async def generate_better_image_prompt( note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, model_type: Optional[str] = None, + subscribed: bool = False, ) -> str: """ Generate a better image prompt from the given query @@ -533,10 +538,12 @@ async def generate_better_image_prompt( online_results=simplified_online_results, ) - summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config() + chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config() with timer("Chat actor: Generate contextual image prompt", logger): - response = await send_message_to_model_wrapper(image_prompt, chat_model_option=summarizer_model) + response = await send_message_to_model_wrapper( + image_prompt, chat_model_option=chat_model, subscribed=subscribed + ) response = response.strip() if response.startswith(('"', "'")) and response.endswith(('"', "'")): response = response[1:-1] @@ -549,13 +556,18 @@ async def send_message_to_model_wrapper( system_message: str = "", response_type: str = "text", chat_model_option: ChatModelOptions = None, + subscribed: bool = False, ): conversation_config: ChatModelOptions = ( chat_model_option or await ConversationAdapters.aget_default_conversation_config() ) chat_model = conversation_config.chat_model - max_tokens = conversation_config.max_prompt_size + max_tokens = ( + conversation_config.subscribed_max_prompt_size + if subscribed and conversation_config.subscribed_max_prompt_size + else conversation_config.max_prompt_size + ) tokenizer = conversation_config.tokenizer if conversation_config.model_type == "offline": @@ -786,6 +798,7 @@ async def text_to_image( location_data: LocationData, references: List[Dict[str, Any]], online_results: Dict[str, Any], + subscribed: bool = False, send_status_func: Optional[Callable] = None, ): status_code = 200 @@ -822,6 +835,7 @@ async def text_to_image( note_references=references, online_results=online_results, model_type=text_to_image_config.model_type, + subscribed=subscribed, ) if send_status_func: @@ -1359,7 +1373,9 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False) current_notion_config = get_user_notion_config(user) notion_token = current_notion_config.token if current_notion_config else "" - selected_chat_model_config = ConversationAdapters.get_conversation_config(user) + selected_chat_model_config = ( + ConversationAdapters.get_conversation_config(user) or ConversationAdapters.get_default_conversation_config() + ) chat_models = ConversationAdapters.get_conversation_processor_options().all() chat_model_options = list() for chat_model in chat_models: