Enable free tier users to have unlimited chats with the default chat model (#886)

- Allow free tier users to have unlimited chats with the default chat model. They will only be rate-limited, at the same rate as subscribed users
- In the server chat settings, replace the concept of default/summarizer models with default/advanced chat models. Use the advanced models as a default for subscribed users.
- For each `ChatModelOption` configuration, allow the admin to specify a separate `max_tokens` value for subscribed users. This lets server admins configure different max token limits for unsubscribed and subscribed users
- Show an error message in the web app when a rate limit or other server error is hit
This commit is contained in:
sabaimran
2024-08-16 12:14:44 -07:00
committed by GitHub
parent 8dad9362e7
commit c0316a6b5d
11 changed files with 210 additions and 92 deletions

View File

@@ -222,7 +222,20 @@ export default function Chat() {
try { try {
await readChatStream(response); await readChatStream(response);
} catch (err) { } catch (err) {
console.log(err); console.error(err);
// Retrieve latest message being processed
const currentMessage = messages.find((message) => !message.completed);
if (!currentMessage) return;
// Render error message as current message
const errorMessage = (err as Error).message;
currentMessage.rawResponse = `Encountered Error: ${errorMessage}. Please try again later.`;
// Complete message streaming teardown properly
currentMessage.completed = true;
setMessages([...messages]);
setQueryToProcess("");
setProcessQuerySignal(false);
} }
} }

View File

@@ -386,8 +386,6 @@ export default function ChatMessage(props: ChatMessageProps) {
preElement.prepend(copyButton); preElement.prepend(copyButton);
}); });
console.log("render katex within the chat message");
renderMathInElement(messageRef.current, { renderMathInElement(messageRef.current, {
delimiters: [ delimiters: [
{ left: "$$", right: "$$", display: true }, { left: "$$", right: "$$", display: true },

View File

@@ -672,7 +672,15 @@ export default function SettingsView() {
}; };
const updateModel = (name: string) => async (id: string) => { const updateModel = (name: string) => async (id: string) => {
if (!userConfig?.is_active && name !== "search") return; if (!userConfig?.is_active && name !== "search") {
toast({
title: `Model Update`,
description: `You need to be subscribed to update ${name} models`,
variant: "destructive",
});
return;
}
try { try {
const response = await fetch(`/api/model/${name}?id=` + id, { const response = await fetch(`/api/model/${name}?id=` + id, {
method: "POST", method: "POST",
@@ -1144,7 +1152,7 @@ export default function SettingsView() {
<ChatCircleText className="h-7 w-7 mr-2" /> <ChatCircleText className="h-7 w-7 mr-2" />
Chat Chat
</CardHeader> </CardHeader>
<CardContent className="overflow-hidden pb-12 grid gap-8"> <CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
<p className="text-gray-400"> <p className="text-gray-400">
Pick the chat model to generate text responses Pick the chat model to generate text responses
</p> </p>
@@ -1169,7 +1177,7 @@ export default function SettingsView() {
<FileMagnifyingGlass className="h-7 w-7 mr-2" /> <FileMagnifyingGlass className="h-7 w-7 mr-2" />
Search Search
</CardHeader> </CardHeader>
<CardContent className="overflow-hidden pb-12 grid gap-8"> <CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
<p className="text-gray-400"> <p className="text-gray-400">
Pick the search model to find your documents Pick the search model to find your documents
</p> </p>
@@ -1190,7 +1198,7 @@ export default function SettingsView() {
<Palette className="h-7 w-7 mr-2" /> <Palette className="h-7 w-7 mr-2" />
Paint Paint
</CardHeader> </CardHeader>
<CardContent className="overflow-hidden pb-12 grid gap-8"> <CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
<p className="text-gray-400"> <p className="text-gray-400">
Pick the paint model to generate image responses Pick the paint model to generate image responses
</p> </p>
@@ -1217,7 +1225,7 @@ export default function SettingsView() {
<Waveform className="h-7 w-7 mr-2" /> <Waveform className="h-7 w-7 mr-2" />
Voice Voice
</CardHeader> </CardHeader>
<CardContent className="overflow-hidden pb-12 grid gap-8"> <CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
<p className="text-gray-400"> <p className="text-gray-400">
Pick the voice model to generate speech Pick the voice model to generate speech
responses responses

View File

@@ -32,10 +32,9 @@ from khoj.database.adapters import (
ClientApplicationAdapters, ClientApplicationAdapters,
ConversationAdapters, ConversationAdapters,
ProcessLockAdapters, ProcessLockAdapters,
SubscriptionState,
aget_or_create_user_by_phone_number, aget_or_create_user_by_phone_number,
aget_user_by_phone_number, aget_user_by_phone_number,
aget_user_subscription_state, ais_user_subscribed,
delete_user_requests, delete_user_requests,
get_all_users, get_all_users,
get_or_create_search_models, get_or_create_search_models,
@@ -119,15 +118,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
.afirst() .afirst()
) )
if user: if user:
if not state.billing_enabled: subscribed = await ais_user_subscribed(user)
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
subscription_state = await aget_user_subscription_state(user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed: if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user) return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
@@ -144,15 +135,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
.afirst() .afirst()
) )
if user_with_token: if user_with_token:
if not state.billing_enabled: subscribed = await ais_user_subscribed(user_with_token.user)
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
subscription_state = await aget_user_subscription_state(user_with_token.user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed: if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user) return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
@@ -189,20 +172,10 @@ class UserAuthenticationBackend(AuthenticationBackend):
if user is None: if user is None:
return AuthCredentials(), UnauthenticatedUser() return AuthCredentials(), UnauthenticatedUser()
if not state.billing_enabled: subscribed = await ais_user_subscribed(user)
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user, client_application)
subscription_state = await aget_user_subscription_state(user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed: if subscribed:
return ( return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user, client_application)
AuthCredentials(["authenticated", "premium"]),
AuthenticatedKhojUser(user, client_application),
)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user, client_application) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user, client_application)
# No auth required if server in anonymous mode # No auth required if server in anonymous mode

View File

@@ -300,6 +300,38 @@ async def aget_user_subscription_state(user: KhojUser) -> str:
return subscription_to_state(user_subscription) return subscription_to_state(user_subscription)
async def ais_user_subscribed(user: KhojUser) -> bool:
    """
    Return whether the user should be treated as subscribed.

    Everyone counts as subscribed when billing is disabled or the server
    runs in anonymous mode. Note that UNSUBSCRIBED is deliberately included
    in the accepted states here (matches the sync variant below).
    """
    if not state.billing_enabled or state.anonymous_mode:
        return True
    current_state = await aget_user_subscription_state(user)
    return current_state in (
        SubscriptionState.SUBSCRIBED.value,
        SubscriptionState.TRIAL.value,
        SubscriptionState.UNSUBSCRIBED.value,
    )
def is_user_subscribed(user: KhojUser) -> bool:
    """
    Return whether the user should be treated as subscribed (sync variant).

    Everyone counts as subscribed when billing is disabled or the server
    runs in anonymous mode. Note that UNSUBSCRIBED is deliberately included
    in the accepted states, mirroring ais_user_subscribed.
    """
    if not state.billing_enabled or state.anonymous_mode:
        return True
    current_state = get_user_subscription_state(user.email)
    return current_state in (
        SubscriptionState.SUBSCRIBED.value,
        SubscriptionState.TRIAL.value,
        SubscriptionState.UNSUBSCRIBED.value,
    )
async def get_user_by_email(email: str) -> KhojUser: async def get_user_by_email(email: str) -> KhojUser:
return await KhojUser.objects.filter(email=email).afirst() return await KhojUser.objects.filter(email=email).afirst()
@@ -751,17 +783,23 @@ class ConversationAdapters:
@staticmethod @staticmethod
def get_conversation_config(user: KhojUser): def get_conversation_config(user: KhojUser):
subscribed = is_user_subscribed(user)
if not subscribed:
return ConversationAdapters.get_default_conversation_config()
config = UserConversationConfig.objects.filter(user=user).first() config = UserConversationConfig.objects.filter(user=user).first()
if not config: if config:
return None return config.setting
return config.setting return ConversationAdapters.get_advanced_conversation_config()
@staticmethod @staticmethod
async def aget_conversation_config(user: KhojUser): async def aget_conversation_config(user: KhojUser):
subscribed = await ais_user_subscribed(user)
if not subscribed:
return await ConversationAdapters.aget_default_conversation_config()
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst() config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
if not config: if config:
return None return config.setting
return config.setting return ConversationAdapters.aget_advanced_conversation_config()
@staticmethod @staticmethod
async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]: async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
@@ -784,35 +822,38 @@ class ConversationAdapters:
@staticmethod @staticmethod
def get_default_conversation_config(): def get_default_conversation_config():
server_chat_settings = ServerChatSettings.objects.first() server_chat_settings = ServerChatSettings.objects.first()
if server_chat_settings is None or server_chat_settings.default_model is None: if server_chat_settings is None or server_chat_settings.chat_default is None:
return ChatModelOptions.objects.filter().first() return ChatModelOptions.objects.filter().first()
return server_chat_settings.default_model return server_chat_settings.chat_default
@staticmethod @staticmethod
async def aget_default_conversation_config(): async def aget_default_conversation_config():
server_chat_settings: ServerChatSettings = ( server_chat_settings: ServerChatSettings = (
await ServerChatSettings.objects.filter() await ServerChatSettings.objects.filter()
.prefetch_related("default_model", "default_model__openai_config") .prefetch_related("chat_default", "chat_default__openai_config")
.afirst() .afirst()
) )
if server_chat_settings is None or server_chat_settings.default_model is None: if server_chat_settings is None or server_chat_settings.chat_default is None:
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst() return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
return server_chat_settings.default_model return server_chat_settings.chat_default
@staticmethod @staticmethod
async def aget_summarizer_conversation_config(): def get_advanced_conversation_config():
server_chat_settings = ServerChatSettings.objects.first()
if server_chat_settings is None or server_chat_settings.chat_advanced is None:
return ConversationAdapters.get_default_conversation_config()
return server_chat_settings.chat_advanced
@staticmethod
async def aget_advanced_conversation_config():
server_chat_settings: ServerChatSettings = ( server_chat_settings: ServerChatSettings = (
await ServerChatSettings.objects.filter() await ServerChatSettings.objects.filter()
.prefetch_related( .prefetch_related("chat_advanced", "chat_advanced__openai_config")
"summarizer_model", "default_model", "default_model__openai_config", "summarizer_model__openai_config"
)
.afirst() .afirst()
) )
if server_chat_settings is None or ( if server_chat_settings is None or server_chat_settings.chat_advanced is None:
server_chat_settings.summarizer_model is None and server_chat_settings.default_model is None return await ConversationAdapters.aget_default_conversation_config()
): return server_chat_settings.chat_advanced
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
return server_chat_settings.summarizer_model or server_chat_settings.default_model
@staticmethod @staticmethod
def create_conversation_from_public_conversation( def create_conversation_from_public_conversation(

View File

@@ -26,6 +26,7 @@ from khoj.database.models import (
SpeechToTextModelOptions, SpeechToTextModelOptions,
Subscription, Subscription,
TextToImageModelConfig, TextToImageModelConfig,
UserConversationConfig,
UserSearchModelConfig, UserSearchModelConfig,
UserVoiceModelConfig, UserVoiceModelConfig,
VoiceModelOption, VoiceModelOption,
@@ -101,6 +102,7 @@ admin.site.register(GithubConfig)
admin.site.register(NotionConfig) admin.site.register(NotionConfig)
admin.site.register(UserVoiceModelConfig) admin.site.register(UserVoiceModelConfig)
admin.site.register(VoiceModelOption) admin.site.register(VoiceModelOption)
admin.site.register(UserConversationConfig)
@admin.register(Agent) @admin.register(Agent)
@@ -191,8 +193,8 @@ class SearchModelConfigAdmin(admin.ModelAdmin):
@admin.register(ServerChatSettings) @admin.register(ServerChatSettings)
class ServerChatSettingsAdmin(admin.ModelAdmin): class ServerChatSettingsAdmin(admin.ModelAdmin):
list_display = ( list_display = (
"default_model", "chat_default",
"summarizer_model", "chat_advanced",
) )

View File

@@ -0,0 +1,51 @@
# Generated by Django 5.0.7 on 2024-08-16 18:18

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):
    # Schema changes backing the default/advanced chat model split:
    # renames ServerChatSettings.default_model -> chat_default, drops the old
    # summarizer_model field, adds ChatModelOptions.subscribed_max_prompt_size,
    # and adds the new ServerChatSettings.chat_advanced foreign key.

    dependencies = [
        ("database", "0056_searchmodelconfig_cross_encoder_model_config"),
    ]

    operations = [
        # Rename the default chat model field to the new naming scheme.
        migrations.RenameField(
            model_name="serverchatsettings",
            old_name="default_model",
            new_name="chat_default",
        ),
        # The separate summarizer model is removed (replaced by chat_advanced).
        migrations.RemoveField(
            model_name="serverchatsettings",
            name="summarizer_model",
        ),
        # Optional per-model max prompt size that applies to subscribed users;
        # null/None means "no separate limit configured".
        migrations.AddField(
            model_name="chatmodeloptions",
            name="subscribed_max_prompt_size",
            field=models.IntegerField(blank=True, default=None, null=True),
        ),
        # New optional FK to the chat model used for subscribed users.
        migrations.AddField(
            model_name="serverchatsettings",
            name="chat_advanced",
            field=models.ForeignKey(
                blank=True,
                default=None,
                null=True,
                on_delete=django.db.models.deletion.CASCADE,
                related_name="chat_advanced",
                to="database.chatmodeloptions",
            ),
        ),
        # Re-declare the renamed FK so its related_name matches the new field name.
        migrations.AlterField(
            model_name="serverchatsettings",
            name="chat_default",
            field=models.ForeignKey(
                blank=True,
                default=None,
                null=True,
                on_delete=django.db.models.deletion.CASCADE,
                related_name="chat_default",
                to="database.chatmodeloptions",
            ),
        ),
    ]

View File

@@ -89,6 +89,7 @@ class ChatModelOptions(BaseModel):
ANTHROPIC = "anthropic" ANTHROPIC = "anthropic"
max_prompt_size = models.IntegerField(default=None, null=True, blank=True) max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True) tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
chat_model = models.CharField(max_length=200, default="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF") chat_model = models.CharField(max_length=200, default="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
@@ -205,11 +206,11 @@ class GithubRepoConfig(BaseModel):
class ServerChatSettings(BaseModel): class ServerChatSettings(BaseModel):
default_model = models.ForeignKey( chat_default = models.ForeignKey(
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="default_model" ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
) )
summarizer_model = models.ForeignKey( chat_advanced = models.ForeignKey(
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="summarizer_model" ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
) )

View File

@@ -53,6 +53,7 @@ async def search_online(
conversation_history: dict, conversation_history: dict,
location: LocationData, location: LocationData,
user: KhojUser, user: KhojUser,
subscribed: bool = False,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
custom_filters: List[str] = [], custom_filters: List[str] = [],
): ):
@@ -91,12 +92,15 @@ async def search_online(
# Read, extract relevant info from the retrieved web pages # Read, extract relevant info from the retrieved web pages
if webpages: if webpages:
webpage_links = [link for link, _, _ in webpages] webpage_links = [link for link, _, _ in webpages]
logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}") logger.info(f"Reading web pages at: {list(webpage_links)}")
if send_status_func: if send_status_func:
webpage_links_str = "\n- " + "\n- ".join(list(webpage_links)) webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"): async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event} yield {ChatEvent.STATUS: event}
tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages] tasks = [
read_webpage_and_extract_content(subquery, link, content, subscribed=subscribed)
for link, subquery, content in webpages
]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
# Collect extracted info from the retrieved web pages # Collect extracted info from the retrieved web pages
@@ -132,6 +136,7 @@ async def read_webpages(
conversation_history: dict, conversation_history: dict,
location: LocationData, location: LocationData,
user: KhojUser, user: KhojUser,
subscribed: bool = False,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
): ):
"Infer web pages to read from the query and extract relevant information from them" "Infer web pages to read from the query and extract relevant information from them"
@@ -146,7 +151,7 @@ async def read_webpages(
webpage_links_str = "\n- " + "\n- ".join(list(urls)) webpage_links_str = "\n- " + "\n- ".join(list(urls))
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"): async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event} yield {ChatEvent.STATUS: event}
tasks = [read_webpage_and_extract_content(query, url) for url in urls] tasks = [read_webpage_and_extract_content(query, url, subscribed=subscribed) for url in urls]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
response: Dict[str, Dict] = defaultdict(dict) response: Dict[str, Dict] = defaultdict(dict)
@@ -157,14 +162,14 @@ async def read_webpages(
async def read_webpage_and_extract_content( async def read_webpage_and_extract_content(
subquery: str, url: str, content: str = None subquery: str, url: str, content: str = None, subscribed: bool = False
) -> Tuple[str, Union[None, str], str]: ) -> Tuple[str, Union[None, str], str]:
try: try:
if is_none_or_empty(content): if is_none_or_empty(content):
with timer(f"Reading web page at '{url}' took", logger): with timer(f"Reading web page at '{url}' took", logger):
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url) content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
with timer(f"Extracting relevant information from web page at '{url}' took", logger): with timer(f"Extracting relevant information from web page at '{url}' took", logger):
extracted_info = await extract_relevant_info(subquery, content) extracted_info = await extract_relevant_info(subquery, content, subscribed=subscribed)
return subquery, extracted_info, url return subquery, extracted_info, url
except Exception as e: except Exception as e:
logger.error(f"Failed to read web page at '{url}' with {e}") logger.error(f"Failed to read web page at '{url}' with {e}")

View File

@@ -4,14 +4,14 @@ import logging
import time import time
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from typing import Any, Dict, List, Optional from typing import Dict, Optional
from urllib.parse import unquote from urllib.parse import unquote
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.requests import Request from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires from starlette.authentication import has_required_scope, requires
from khoj.app.settings import ALLOWED_HOSTS from khoj.app.settings import ALLOWED_HOSTS
from khoj.database.adapters import ( from khoj.database.adapters import (
@@ -59,7 +59,7 @@ from khoj.utils.rawconfig import FileFilterRequest, FilesFilterRequest, Location
# Initialize Router # Initialize Router
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
conversation_command_rate_limiter = ConversationCommandRateLimiter( conversation_command_rate_limiter = ConversationCommandRateLimiter(
trial_rate_limit=2, subscribed_rate_limit=100, slug="command" trial_rate_limit=100, subscribed_rate_limit=100, slug="command"
) )
@@ -532,10 +532,10 @@ async def chat(
country: Optional[str] = None, country: Optional[str] = None,
timezone: Optional[str] = None, timezone: Optional[str] = None,
rate_limiter_per_minute=Depends( rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute") ApiUserRateLimiter(requests=60, subscribed_requests=60, window=60, slug="chat_minute")
), ),
rate_limiter_per_day=Depends( rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") ApiUserRateLimiter(requests=600, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
), ),
): ):
async def event_generator(q: str): async def event_generator(q: str):
@@ -544,6 +544,7 @@ async def chat(
chat_metadata: dict = {} chat_metadata: dict = {}
connection_alive = True connection_alive = True
user: KhojUser = request.user.object user: KhojUser = request.user.object
subscribed: bool = has_required_scope(request, ["premium"])
event_delimiter = "␃🔚␗" event_delimiter = "␃🔚␗"
q = unquote(q) q = unquote(q)
@@ -632,7 +633,9 @@ async def chat(
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
if conversation_commands == [ConversationCommand.Default] or is_automated_task: if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task) conversation_commands = await aget_relevant_information_sources(
q, meta_log, is_automated_task, subscribed=subscribed
)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
async for result in send_event( async for result in send_event(
ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}" ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}"
@@ -687,7 +690,7 @@ async def chat(
): ):
yield result yield result
response = await extract_relevant_summary(q, contextual_data) response = await extract_relevant_summary(q, contextual_data, subscribed=subscribed)
response_log = str(response) response_log = str(response)
async for result in send_llm_response(response_log): async for result in send_llm_response(response_log):
yield result yield result
@@ -792,7 +795,13 @@ async def chat(
if ConversationCommand.Online in conversation_commands: if ConversationCommand.Online in conversation_commands:
try: try:
async for result in search_online( async for result in search_online(
defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS), custom_filters defiltered_query,
meta_log,
location,
user,
subscribed,
partial(send_event, ChatEvent.STATUS),
custom_filters,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@@ -809,7 +818,7 @@ async def chat(
if ConversationCommand.Webpage in conversation_commands: if ConversationCommand.Webpage in conversation_commands:
try: try:
async for result in read_webpages( async for result in read_webpages(
defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS) defiltered_query, meta_log, location, user, subscribed, partial(send_event, ChatEvent.STATUS)
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@@ -853,6 +862,7 @@ async def chat(
location_data=location, location_data=location,
references=compiled_references, references=compiled_references,
online_results=online_results, online_results=online_results,
subscribed=subscribed,
send_status_func=partial(send_event, ChatEvent.STATUS), send_status_func=partial(send_event, ChatEvent.STATUS),
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:

View File

@@ -252,7 +252,7 @@ async def acreate_title_from_query(query: str) -> str:
return response.strip() return response.strip()
async def aget_relevant_information_sources(query: str, conversation_history: dict, is_task: bool): async def aget_relevant_information_sources(query: str, conversation_history: dict, is_task: bool, subscribed: bool):
""" """
Given a query, determine which of the available tools the agent should use in order to answer appropriately. Given a query, determine which of the available tools the agent should use in order to answer appropriately.
""" """
@@ -273,7 +273,9 @@ async def aget_relevant_information_sources(query: str, conversation_history: di
) )
with timer("Chat actor: Infer information sources to refer", logger): with timer("Chat actor: Infer information sources to refer", logger):
response = await send_message_to_model_wrapper(relevant_tools_prompt, response_type="json_object") response = await send_message_to_model_wrapper(
relevant_tools_prompt, response_type="json_object", subscribed=subscribed
)
try: try:
response = response.strip() response = response.strip()
@@ -434,7 +436,7 @@ async def schedule_query(q: str, conversation_history: dict) -> Tuple[str, ...]:
raise AssertionError(f"Invalid response for scheduling query: {raw_response}") raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]: async def extract_relevant_info(q: str, corpus: str, subscribed: bool) -> Union[str, None]:
""" """
Extract relevant information for a given query from the target corpus Extract relevant information for a given query from the target corpus
""" """
@@ -447,18 +449,19 @@ async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]:
corpus=corpus.strip(), corpus=corpus.strip(),
) )
summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config() chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
with timer("Chat actor: Extract relevant information from data", logger): with timer("Chat actor: Extract relevant information from data", logger):
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(
extract_relevant_information, extract_relevant_information,
prompts.system_prompt_extract_relevant_information, prompts.system_prompt_extract_relevant_information,
chat_model_option=summarizer_model, chat_model_option=chat_model,
subscribed=subscribed,
) )
return response.strip() return response.strip()
async def extract_relevant_summary(q: str, corpus: str) -> Union[str, None]: async def extract_relevant_summary(q: str, corpus: str, subscribed: bool = False) -> Union[str, None]:
""" """
Extract relevant information for a given query from the target corpus Extract relevant information for a given query from the target corpus
""" """
@@ -471,13 +474,14 @@ async def extract_relevant_summary(q: str, corpus: str) -> Union[str, None]:
corpus=corpus.strip(), corpus=corpus.strip(),
) )
summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config() chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
with timer("Chat actor: Extract relevant information from data", logger): with timer("Chat actor: Extract relevant information from data", logger):
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(
extract_relevant_information, extract_relevant_information,
prompts.system_prompt_extract_relevant_summary, prompts.system_prompt_extract_relevant_summary,
chat_model_option=summarizer_model, chat_model_option=chat_model,
subscribed=subscribed,
) )
return response.strip() return response.strip()
@@ -489,6 +493,7 @@ async def generate_better_image_prompt(
note_references: List[Dict[str, Any]], note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None, online_results: Optional[dict] = None,
model_type: Optional[str] = None, model_type: Optional[str] = None,
subscribed: bool = False,
) -> str: ) -> str:
""" """
Generate a better image prompt from the given query Generate a better image prompt from the given query
@@ -533,10 +538,12 @@ async def generate_better_image_prompt(
online_results=simplified_online_results, online_results=simplified_online_results,
) )
summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config() chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
with timer("Chat actor: Generate contextual image prompt", logger): with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper(image_prompt, chat_model_option=summarizer_model) response = await send_message_to_model_wrapper(
image_prompt, chat_model_option=chat_model, subscribed=subscribed
)
response = response.strip() response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")): if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1] response = response[1:-1]
@@ -549,13 +556,18 @@ async def send_message_to_model_wrapper(
system_message: str = "", system_message: str = "",
response_type: str = "text", response_type: str = "text",
chat_model_option: ChatModelOptions = None, chat_model_option: ChatModelOptions = None,
subscribed: bool = False,
): ):
conversation_config: ChatModelOptions = ( conversation_config: ChatModelOptions = (
chat_model_option or await ConversationAdapters.aget_default_conversation_config() chat_model_option or await ConversationAdapters.aget_default_conversation_config()
) )
chat_model = conversation_config.chat_model chat_model = conversation_config.chat_model
max_tokens = conversation_config.max_prompt_size max_tokens = (
conversation_config.subscribed_max_prompt_size
if subscribed and conversation_config.subscribed_max_prompt_size
else conversation_config.max_prompt_size
)
tokenizer = conversation_config.tokenizer tokenizer = conversation_config.tokenizer
if conversation_config.model_type == "offline": if conversation_config.model_type == "offline":
@@ -786,6 +798,7 @@ async def text_to_image(
location_data: LocationData, location_data: LocationData,
references: List[Dict[str, Any]], references: List[Dict[str, Any]],
online_results: Dict[str, Any], online_results: Dict[str, Any],
subscribed: bool = False,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
): ):
status_code = 200 status_code = 200
@@ -822,6 +835,7 @@ async def text_to_image(
note_references=references, note_references=references,
online_results=online_results, online_results=online_results,
model_type=text_to_image_config.model_type, model_type=text_to_image_config.model_type,
subscribed=subscribed,
) )
if send_status_func: if send_status_func:
@@ -1359,7 +1373,9 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
current_notion_config = get_user_notion_config(user) current_notion_config = get_user_notion_config(user)
notion_token = current_notion_config.token if current_notion_config else "" notion_token = current_notion_config.token if current_notion_config else ""
selected_chat_model_config = ConversationAdapters.get_conversation_config(user) selected_chat_model_config = (
ConversationAdapters.get_conversation_config(user) or ConversationAdapters.get_default_conversation_config()
)
chat_models = ConversationAdapters.get_conversation_processor_options().all() chat_models = ConversationAdapters.get_conversation_processor_options().all()
chat_model_options = list() chat_model_options = list()
for chat_model in chat_models: for chat_model in chat_models: