mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Enable free tier users to have unlimited chats with the default chat model (#886)
- Allow free tier users to have unlimited chats with default chat model. It'll only be rate-limited and at the same rate as subscribed users - In the server chat settings, replace the concept of default/summarizer models with default/advanced chat models. Use the advanced models as a default for subscribed users. - For each `ChatModelOption' configuration, allow the admin to specify a separate value of `max_tokens' for subscribed users. This allows server admins to configure different max token limits for unsubscribed and subscribed users - Show error message in web app when hit rate limit or other server errors
This commit is contained in:
@@ -222,7 +222,20 @@ export default function Chat() {
|
|||||||
try {
|
try {
|
||||||
await readChatStream(response);
|
await readChatStream(response);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.log(err);
|
console.error(err);
|
||||||
|
// Retrieve latest message being processed
|
||||||
|
const currentMessage = messages.find((message) => !message.completed);
|
||||||
|
if (!currentMessage) return;
|
||||||
|
|
||||||
|
// Render error message as current message
|
||||||
|
const errorMessage = (err as Error).message;
|
||||||
|
currentMessage.rawResponse = `Encountered Error: ${errorMessage}. Please try again later.`;
|
||||||
|
|
||||||
|
// Complete message streaming teardown properly
|
||||||
|
currentMessage.completed = true;
|
||||||
|
setMessages([...messages]);
|
||||||
|
setQueryToProcess("");
|
||||||
|
setProcessQuerySignal(false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -386,8 +386,6 @@ export default function ChatMessage(props: ChatMessageProps) {
|
|||||||
preElement.prepend(copyButton);
|
preElement.prepend(copyButton);
|
||||||
});
|
});
|
||||||
|
|
||||||
console.log("render katex within the chat message");
|
|
||||||
|
|
||||||
renderMathInElement(messageRef.current, {
|
renderMathInElement(messageRef.current, {
|
||||||
delimiters: [
|
delimiters: [
|
||||||
{ left: "$$", right: "$$", display: true },
|
{ left: "$$", right: "$$", display: true },
|
||||||
|
|||||||
@@ -672,7 +672,15 @@ export default function SettingsView() {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const updateModel = (name: string) => async (id: string) => {
|
const updateModel = (name: string) => async (id: string) => {
|
||||||
if (!userConfig?.is_active && name !== "search") return;
|
if (!userConfig?.is_active && name !== "search") {
|
||||||
|
toast({
|
||||||
|
title: `Model Update`,
|
||||||
|
description: `You need to be subscribed to update ${name} models`,
|
||||||
|
variant: "destructive",
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const response = await fetch(`/api/model/${name}?id=` + id, {
|
const response = await fetch(`/api/model/${name}?id=` + id, {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
@@ -1144,7 +1152,7 @@ export default function SettingsView() {
|
|||||||
<ChatCircleText className="h-7 w-7 mr-2" />
|
<ChatCircleText className="h-7 w-7 mr-2" />
|
||||||
Chat
|
Chat
|
||||||
</CardHeader>
|
</CardHeader>
|
||||||
<CardContent className="overflow-hidden pb-12 grid gap-8">
|
<CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
|
||||||
<p className="text-gray-400">
|
<p className="text-gray-400">
|
||||||
Pick the chat model to generate text responses
|
Pick the chat model to generate text responses
|
||||||
</p>
|
</p>
|
||||||
@@ -1169,7 +1177,7 @@ export default function SettingsView() {
|
|||||||
<FileMagnifyingGlass className="h-7 w-7 mr-2" />
|
<FileMagnifyingGlass className="h-7 w-7 mr-2" />
|
||||||
Search
|
Search
|
||||||
</CardHeader>
|
</CardHeader>
|
||||||
<CardContent className="overflow-hidden pb-12 grid gap-8">
|
<CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
|
||||||
<p className="text-gray-400">
|
<p className="text-gray-400">
|
||||||
Pick the search model to find your documents
|
Pick the search model to find your documents
|
||||||
</p>
|
</p>
|
||||||
@@ -1190,7 +1198,7 @@ export default function SettingsView() {
|
|||||||
<Palette className="h-7 w-7 mr-2" />
|
<Palette className="h-7 w-7 mr-2" />
|
||||||
Paint
|
Paint
|
||||||
</CardHeader>
|
</CardHeader>
|
||||||
<CardContent className="overflow-hidden pb-12 grid gap-8">
|
<CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
|
||||||
<p className="text-gray-400">
|
<p className="text-gray-400">
|
||||||
Pick the paint model to generate image responses
|
Pick the paint model to generate image responses
|
||||||
</p>
|
</p>
|
||||||
@@ -1217,7 +1225,7 @@ export default function SettingsView() {
|
|||||||
<Waveform className="h-7 w-7 mr-2" />
|
<Waveform className="h-7 w-7 mr-2" />
|
||||||
Voice
|
Voice
|
||||||
</CardHeader>
|
</CardHeader>
|
||||||
<CardContent className="overflow-hidden pb-12 grid gap-8">
|
<CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
|
||||||
<p className="text-gray-400">
|
<p className="text-gray-400">
|
||||||
Pick the voice model to generate speech
|
Pick the voice model to generate speech
|
||||||
responses
|
responses
|
||||||
|
|||||||
@@ -32,10 +32,9 @@ from khoj.database.adapters import (
|
|||||||
ClientApplicationAdapters,
|
ClientApplicationAdapters,
|
||||||
ConversationAdapters,
|
ConversationAdapters,
|
||||||
ProcessLockAdapters,
|
ProcessLockAdapters,
|
||||||
SubscriptionState,
|
|
||||||
aget_or_create_user_by_phone_number,
|
aget_or_create_user_by_phone_number,
|
||||||
aget_user_by_phone_number,
|
aget_user_by_phone_number,
|
||||||
aget_user_subscription_state,
|
ais_user_subscribed,
|
||||||
delete_user_requests,
|
delete_user_requests,
|
||||||
get_all_users,
|
get_all_users,
|
||||||
get_or_create_search_models,
|
get_or_create_search_models,
|
||||||
@@ -119,15 +118,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
|||||||
.afirst()
|
.afirst()
|
||||||
)
|
)
|
||||||
if user:
|
if user:
|
||||||
if not state.billing_enabled:
|
subscribed = await ais_user_subscribed(user)
|
||||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
|
|
||||||
|
|
||||||
subscription_state = await aget_user_subscription_state(user)
|
|
||||||
subscribed = (
|
|
||||||
subscription_state == SubscriptionState.SUBSCRIBED.value
|
|
||||||
or subscription_state == SubscriptionState.TRIAL.value
|
|
||||||
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
|
||||||
)
|
|
||||||
if subscribed:
|
if subscribed:
|
||||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
|
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
|
||||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||||
@@ -144,15 +135,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
|||||||
.afirst()
|
.afirst()
|
||||||
)
|
)
|
||||||
if user_with_token:
|
if user_with_token:
|
||||||
if not state.billing_enabled:
|
subscribed = await ais_user_subscribed(user_with_token.user)
|
||||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
|
|
||||||
|
|
||||||
subscription_state = await aget_user_subscription_state(user_with_token.user)
|
|
||||||
subscribed = (
|
|
||||||
subscription_state == SubscriptionState.SUBSCRIBED.value
|
|
||||||
or subscription_state == SubscriptionState.TRIAL.value
|
|
||||||
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
|
||||||
)
|
|
||||||
if subscribed:
|
if subscribed:
|
||||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
|
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
|
||||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
|
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
|
||||||
@@ -189,20 +172,10 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
|||||||
if user is None:
|
if user is None:
|
||||||
return AuthCredentials(), UnauthenticatedUser()
|
return AuthCredentials(), UnauthenticatedUser()
|
||||||
|
|
||||||
if not state.billing_enabled:
|
subscribed = await ais_user_subscribed(user)
|
||||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user, client_application)
|
|
||||||
|
|
||||||
subscription_state = await aget_user_subscription_state(user)
|
|
||||||
subscribed = (
|
|
||||||
subscription_state == SubscriptionState.SUBSCRIBED.value
|
|
||||||
or subscription_state == SubscriptionState.TRIAL.value
|
|
||||||
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
|
||||||
)
|
|
||||||
if subscribed:
|
if subscribed:
|
||||||
return (
|
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user, client_application)
|
||||||
AuthCredentials(["authenticated", "premium"]),
|
|
||||||
AuthenticatedKhojUser(user, client_application),
|
|
||||||
)
|
|
||||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user, client_application)
|
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user, client_application)
|
||||||
|
|
||||||
# No auth required if server in anonymous mode
|
# No auth required if server in anonymous mode
|
||||||
|
|||||||
@@ -300,6 +300,38 @@ async def aget_user_subscription_state(user: KhojUser) -> str:
|
|||||||
return subscription_to_state(user_subscription)
|
return subscription_to_state(user_subscription)
|
||||||
|
|
||||||
|
|
||||||
|
async def ais_user_subscribed(user: KhojUser) -> bool:
|
||||||
|
"""
|
||||||
|
Get whether the user is subscribed
|
||||||
|
"""
|
||||||
|
if not state.billing_enabled or state.anonymous_mode:
|
||||||
|
return True
|
||||||
|
|
||||||
|
subscription_state = await aget_user_subscription_state(user)
|
||||||
|
subscribed = (
|
||||||
|
subscription_state == SubscriptionState.SUBSCRIBED.value
|
||||||
|
or subscription_state == SubscriptionState.TRIAL.value
|
||||||
|
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
||||||
|
)
|
||||||
|
return subscribed
|
||||||
|
|
||||||
|
|
||||||
|
def is_user_subscribed(user: KhojUser) -> bool:
|
||||||
|
"""
|
||||||
|
Get whether the user is subscribed
|
||||||
|
"""
|
||||||
|
if not state.billing_enabled or state.anonymous_mode:
|
||||||
|
return True
|
||||||
|
|
||||||
|
subscription_state = get_user_subscription_state(user.email)
|
||||||
|
subscribed = (
|
||||||
|
subscription_state == SubscriptionState.SUBSCRIBED.value
|
||||||
|
or subscription_state == SubscriptionState.TRIAL.value
|
||||||
|
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
||||||
|
)
|
||||||
|
return subscribed
|
||||||
|
|
||||||
|
|
||||||
async def get_user_by_email(email: str) -> KhojUser:
|
async def get_user_by_email(email: str) -> KhojUser:
|
||||||
return await KhojUser.objects.filter(email=email).afirst()
|
return await KhojUser.objects.filter(email=email).afirst()
|
||||||
|
|
||||||
@@ -751,17 +783,23 @@ class ConversationAdapters:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_conversation_config(user: KhojUser):
|
def get_conversation_config(user: KhojUser):
|
||||||
|
subscribed = is_user_subscribed(user)
|
||||||
|
if not subscribed:
|
||||||
|
return ConversationAdapters.get_default_conversation_config()
|
||||||
config = UserConversationConfig.objects.filter(user=user).first()
|
config = UserConversationConfig.objects.filter(user=user).first()
|
||||||
if not config:
|
if config:
|
||||||
return None
|
return config.setting
|
||||||
return config.setting
|
return ConversationAdapters.get_advanced_conversation_config()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_conversation_config(user: KhojUser):
|
async def aget_conversation_config(user: KhojUser):
|
||||||
|
subscribed = await ais_user_subscribed(user)
|
||||||
|
if not subscribed:
|
||||||
|
return await ConversationAdapters.aget_default_conversation_config()
|
||||||
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
||||||
if not config:
|
if config:
|
||||||
return None
|
return config.setting
|
||||||
return config.setting
|
return ConversationAdapters.aget_advanced_conversation_config()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
|
async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
|
||||||
@@ -784,35 +822,38 @@ class ConversationAdapters:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def get_default_conversation_config():
|
def get_default_conversation_config():
|
||||||
server_chat_settings = ServerChatSettings.objects.first()
|
server_chat_settings = ServerChatSettings.objects.first()
|
||||||
if server_chat_settings is None or server_chat_settings.default_model is None:
|
if server_chat_settings is None or server_chat_settings.chat_default is None:
|
||||||
return ChatModelOptions.objects.filter().first()
|
return ChatModelOptions.objects.filter().first()
|
||||||
return server_chat_settings.default_model
|
return server_chat_settings.chat_default
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_default_conversation_config():
|
async def aget_default_conversation_config():
|
||||||
server_chat_settings: ServerChatSettings = (
|
server_chat_settings: ServerChatSettings = (
|
||||||
await ServerChatSettings.objects.filter()
|
await ServerChatSettings.objects.filter()
|
||||||
.prefetch_related("default_model", "default_model__openai_config")
|
.prefetch_related("chat_default", "chat_default__openai_config")
|
||||||
.afirst()
|
.afirst()
|
||||||
)
|
)
|
||||||
if server_chat_settings is None or server_chat_settings.default_model is None:
|
if server_chat_settings is None or server_chat_settings.chat_default is None:
|
||||||
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
|
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
|
||||||
return server_chat_settings.default_model
|
return server_chat_settings.chat_default
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_summarizer_conversation_config():
|
def get_advanced_conversation_config():
|
||||||
|
server_chat_settings = ServerChatSettings.objects.first()
|
||||||
|
if server_chat_settings is None or server_chat_settings.chat_advanced is None:
|
||||||
|
return ConversationAdapters.get_default_conversation_config()
|
||||||
|
return server_chat_settings.chat_advanced
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def aget_advanced_conversation_config():
|
||||||
server_chat_settings: ServerChatSettings = (
|
server_chat_settings: ServerChatSettings = (
|
||||||
await ServerChatSettings.objects.filter()
|
await ServerChatSettings.objects.filter()
|
||||||
.prefetch_related(
|
.prefetch_related("chat_advanced", "chat_advanced__openai_config")
|
||||||
"summarizer_model", "default_model", "default_model__openai_config", "summarizer_model__openai_config"
|
|
||||||
)
|
|
||||||
.afirst()
|
.afirst()
|
||||||
)
|
)
|
||||||
if server_chat_settings is None or (
|
if server_chat_settings is None or server_chat_settings.chat_advanced is None:
|
||||||
server_chat_settings.summarizer_model is None and server_chat_settings.default_model is None
|
return await ConversationAdapters.aget_default_conversation_config()
|
||||||
):
|
return server_chat_settings.chat_advanced
|
||||||
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
|
|
||||||
return server_chat_settings.summarizer_model or server_chat_settings.default_model
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_conversation_from_public_conversation(
|
def create_conversation_from_public_conversation(
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from khoj.database.models import (
|
|||||||
SpeechToTextModelOptions,
|
SpeechToTextModelOptions,
|
||||||
Subscription,
|
Subscription,
|
||||||
TextToImageModelConfig,
|
TextToImageModelConfig,
|
||||||
|
UserConversationConfig,
|
||||||
UserSearchModelConfig,
|
UserSearchModelConfig,
|
||||||
UserVoiceModelConfig,
|
UserVoiceModelConfig,
|
||||||
VoiceModelOption,
|
VoiceModelOption,
|
||||||
@@ -101,6 +102,7 @@ admin.site.register(GithubConfig)
|
|||||||
admin.site.register(NotionConfig)
|
admin.site.register(NotionConfig)
|
||||||
admin.site.register(UserVoiceModelConfig)
|
admin.site.register(UserVoiceModelConfig)
|
||||||
admin.site.register(VoiceModelOption)
|
admin.site.register(VoiceModelOption)
|
||||||
|
admin.site.register(UserConversationConfig)
|
||||||
|
|
||||||
|
|
||||||
@admin.register(Agent)
|
@admin.register(Agent)
|
||||||
@@ -191,8 +193,8 @@ class SearchModelConfigAdmin(admin.ModelAdmin):
|
|||||||
@admin.register(ServerChatSettings)
|
@admin.register(ServerChatSettings)
|
||||||
class ServerChatSettingsAdmin(admin.ModelAdmin):
|
class ServerChatSettingsAdmin(admin.ModelAdmin):
|
||||||
list_display = (
|
list_display = (
|
||||||
"default_model",
|
"chat_default",
|
||||||
"summarizer_model",
|
"chat_advanced",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,51 @@
|
|||||||
|
# Generated by Django 5.0.7 on 2024-08-16 18:18
|
||||||
|
|
||||||
|
import django.db.models.deletion
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0056_searchmodelconfig_cross_encoder_model_config"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.RenameField(
|
||||||
|
model_name="serverchatsettings",
|
||||||
|
old_name="default_model",
|
||||||
|
new_name="chat_default",
|
||||||
|
),
|
||||||
|
migrations.RemoveField(
|
||||||
|
model_name="serverchatsettings",
|
||||||
|
name="summarizer_model",
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="chatmodeloptions",
|
||||||
|
name="subscribed_max_prompt_size",
|
||||||
|
field=models.IntegerField(blank=True, default=None, null=True),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="serverchatsettings",
|
||||||
|
name="chat_advanced",
|
||||||
|
field=models.ForeignKey(
|
||||||
|
blank=True,
|
||||||
|
default=None,
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
related_name="chat_advanced",
|
||||||
|
to="database.chatmodeloptions",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="serverchatsettings",
|
||||||
|
name="chat_default",
|
||||||
|
field=models.ForeignKey(
|
||||||
|
blank=True,
|
||||||
|
default=None,
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
related_name="chat_default",
|
||||||
|
to="database.chatmodeloptions",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
@@ -89,6 +89,7 @@ class ChatModelOptions(BaseModel):
|
|||||||
ANTHROPIC = "anthropic"
|
ANTHROPIC = "anthropic"
|
||||||
|
|
||||||
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
||||||
|
subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
||||||
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
|
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
chat_model = models.CharField(max_length=200, default="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF")
|
chat_model = models.CharField(max_length=200, default="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF")
|
||||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
||||||
@@ -205,11 +206,11 @@ class GithubRepoConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ServerChatSettings(BaseModel):
|
class ServerChatSettings(BaseModel):
|
||||||
default_model = models.ForeignKey(
|
chat_default = models.ForeignKey(
|
||||||
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="default_model"
|
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
|
||||||
)
|
)
|
||||||
summarizer_model = models.ForeignKey(
|
chat_advanced = models.ForeignKey(
|
||||||
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="summarizer_model"
|
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ async def search_online(
|
|||||||
conversation_history: dict,
|
conversation_history: dict,
|
||||||
location: LocationData,
|
location: LocationData,
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
|
subscribed: bool = False,
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
custom_filters: List[str] = [],
|
custom_filters: List[str] = [],
|
||||||
):
|
):
|
||||||
@@ -91,12 +92,15 @@ async def search_online(
|
|||||||
# Read, extract relevant info from the retrieved web pages
|
# Read, extract relevant info from the retrieved web pages
|
||||||
if webpages:
|
if webpages:
|
||||||
webpage_links = [link for link, _, _ in webpages]
|
webpage_links = [link for link, _, _ in webpages]
|
||||||
logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}")
|
logger.info(f"Reading web pages at: {list(webpage_links)}")
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
|
webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
|
||||||
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
||||||
yield {ChatEvent.STATUS: event}
|
yield {ChatEvent.STATUS: event}
|
||||||
tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages]
|
tasks = [
|
||||||
|
read_webpage_and_extract_content(subquery, link, content, subscribed=subscribed)
|
||||||
|
for link, subquery, content in webpages
|
||||||
|
]
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
# Collect extracted info from the retrieved web pages
|
# Collect extracted info from the retrieved web pages
|
||||||
@@ -132,6 +136,7 @@ async def read_webpages(
|
|||||||
conversation_history: dict,
|
conversation_history: dict,
|
||||||
location: LocationData,
|
location: LocationData,
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
|
subscribed: bool = False,
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
):
|
):
|
||||||
"Infer web pages to read from the query and extract relevant information from them"
|
"Infer web pages to read from the query and extract relevant information from them"
|
||||||
@@ -146,7 +151,7 @@ async def read_webpages(
|
|||||||
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
||||||
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
||||||
yield {ChatEvent.STATUS: event}
|
yield {ChatEvent.STATUS: event}
|
||||||
tasks = [read_webpage_and_extract_content(query, url) for url in urls]
|
tasks = [read_webpage_and_extract_content(query, url, subscribed=subscribed) for url in urls]
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
response: Dict[str, Dict] = defaultdict(dict)
|
response: Dict[str, Dict] = defaultdict(dict)
|
||||||
@@ -157,14 +162,14 @@ async def read_webpages(
|
|||||||
|
|
||||||
|
|
||||||
async def read_webpage_and_extract_content(
|
async def read_webpage_and_extract_content(
|
||||||
subquery: str, url: str, content: str = None
|
subquery: str, url: str, content: str = None, subscribed: bool = False
|
||||||
) -> Tuple[str, Union[None, str], str]:
|
) -> Tuple[str, Union[None, str], str]:
|
||||||
try:
|
try:
|
||||||
if is_none_or_empty(content):
|
if is_none_or_empty(content):
|
||||||
with timer(f"Reading web page at '{url}' took", logger):
|
with timer(f"Reading web page at '{url}' took", logger):
|
||||||
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
|
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
|
||||||
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
|
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
|
||||||
extracted_info = await extract_relevant_info(subquery, content)
|
extracted_info = await extract_relevant_info(subquery, content, subscribed=subscribed)
|
||||||
return subquery, extracted_info, url
|
return subquery, extracted_info, url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to read web page at '{url}' with {e}")
|
logger.error(f"Failed to read web page at '{url}' with {e}")
|
||||||
|
|||||||
@@ -4,14 +4,14 @@ import logging
|
|||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Dict, Optional
|
||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
|
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from fastapi.requests import Request
|
from fastapi.requests import Request
|
||||||
from fastapi.responses import Response, StreamingResponse
|
from fastapi.responses import Response, StreamingResponse
|
||||||
from starlette.authentication import requires
|
from starlette.authentication import has_required_scope, requires
|
||||||
|
|
||||||
from khoj.app.settings import ALLOWED_HOSTS
|
from khoj.app.settings import ALLOWED_HOSTS
|
||||||
from khoj.database.adapters import (
|
from khoj.database.adapters import (
|
||||||
@@ -59,7 +59,7 @@ from khoj.utils.rawconfig import FileFilterRequest, FilesFilterRequest, Location
|
|||||||
# Initialize Router
|
# Initialize Router
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
conversation_command_rate_limiter = ConversationCommandRateLimiter(
|
conversation_command_rate_limiter = ConversationCommandRateLimiter(
|
||||||
trial_rate_limit=2, subscribed_rate_limit=100, slug="command"
|
trial_rate_limit=100, subscribed_rate_limit=100, slug="command"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -532,10 +532,10 @@ async def chat(
|
|||||||
country: Optional[str] = None,
|
country: Optional[str] = None,
|
||||||
timezone: Optional[str] = None,
|
timezone: Optional[str] = None,
|
||||||
rate_limiter_per_minute=Depends(
|
rate_limiter_per_minute=Depends(
|
||||||
ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
|
ApiUserRateLimiter(requests=60, subscribed_requests=60, window=60, slug="chat_minute")
|
||||||
),
|
),
|
||||||
rate_limiter_per_day=Depends(
|
rate_limiter_per_day=Depends(
|
||||||
ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
|
ApiUserRateLimiter(requests=600, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
async def event_generator(q: str):
|
async def event_generator(q: str):
|
||||||
@@ -544,6 +544,7 @@ async def chat(
|
|||||||
chat_metadata: dict = {}
|
chat_metadata: dict = {}
|
||||||
connection_alive = True
|
connection_alive = True
|
||||||
user: KhojUser = request.user.object
|
user: KhojUser = request.user.object
|
||||||
|
subscribed: bool = has_required_scope(request, ["premium"])
|
||||||
event_delimiter = "␃🔚␗"
|
event_delimiter = "␃🔚␗"
|
||||||
q = unquote(q)
|
q = unquote(q)
|
||||||
|
|
||||||
@@ -632,7 +633,9 @@ async def chat(
|
|||||||
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
|
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
|
||||||
|
|
||||||
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
||||||
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
|
conversation_commands = await aget_relevant_information_sources(
|
||||||
|
q, meta_log, is_automated_task, subscribed=subscribed
|
||||||
|
)
|
||||||
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
||||||
async for result in send_event(
|
async for result in send_event(
|
||||||
ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}"
|
ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}"
|
||||||
@@ -687,7 +690,7 @@ async def chat(
|
|||||||
):
|
):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
response = await extract_relevant_summary(q, contextual_data)
|
response = await extract_relevant_summary(q, contextual_data, subscribed=subscribed)
|
||||||
response_log = str(response)
|
response_log = str(response)
|
||||||
async for result in send_llm_response(response_log):
|
async for result in send_llm_response(response_log):
|
||||||
yield result
|
yield result
|
||||||
@@ -792,7 +795,13 @@ async def chat(
|
|||||||
if ConversationCommand.Online in conversation_commands:
|
if ConversationCommand.Online in conversation_commands:
|
||||||
try:
|
try:
|
||||||
async for result in search_online(
|
async for result in search_online(
|
||||||
defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS), custom_filters
|
defiltered_query,
|
||||||
|
meta_log,
|
||||||
|
location,
|
||||||
|
user,
|
||||||
|
subscribed,
|
||||||
|
partial(send_event, ChatEvent.STATUS),
|
||||||
|
custom_filters,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
@@ -809,7 +818,7 @@ async def chat(
|
|||||||
if ConversationCommand.Webpage in conversation_commands:
|
if ConversationCommand.Webpage in conversation_commands:
|
||||||
try:
|
try:
|
||||||
async for result in read_webpages(
|
async for result in read_webpages(
|
||||||
defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS)
|
defiltered_query, meta_log, location, user, subscribed, partial(send_event, ChatEvent.STATUS)
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
@@ -853,6 +862,7 @@ async def chat(
|
|||||||
location_data=location,
|
location_data=location,
|
||||||
references=compiled_references,
|
references=compiled_references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
|
subscribed=subscribed,
|
||||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
|
|||||||
@@ -252,7 +252,7 @@ async def acreate_title_from_query(query: str) -> str:
|
|||||||
return response.strip()
|
return response.strip()
|
||||||
|
|
||||||
|
|
||||||
async def aget_relevant_information_sources(query: str, conversation_history: dict, is_task: bool):
|
async def aget_relevant_information_sources(query: str, conversation_history: dict, is_task: bool, subscribed: bool):
|
||||||
"""
|
"""
|
||||||
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
||||||
"""
|
"""
|
||||||
@@ -273,7 +273,9 @@ async def aget_relevant_information_sources(query: str, conversation_history: di
|
|||||||
)
|
)
|
||||||
|
|
||||||
with timer("Chat actor: Infer information sources to refer", logger):
|
with timer("Chat actor: Infer information sources to refer", logger):
|
||||||
response = await send_message_to_model_wrapper(relevant_tools_prompt, response_type="json_object")
|
response = await send_message_to_model_wrapper(
|
||||||
|
relevant_tools_prompt, response_type="json_object", subscribed=subscribed
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
@@ -434,7 +436,7 @@ async def schedule_query(q: str, conversation_history: dict) -> Tuple[str, ...]:
|
|||||||
raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
|
raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
|
||||||
|
|
||||||
|
|
||||||
async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]:
|
async def extract_relevant_info(q: str, corpus: str, subscribed: bool) -> Union[str, None]:
|
||||||
"""
|
"""
|
||||||
Extract relevant information for a given query from the target corpus
|
Extract relevant information for a given query from the target corpus
|
||||||
"""
|
"""
|
||||||
@@ -447,18 +449,19 @@ async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]:
|
|||||||
corpus=corpus.strip(),
|
corpus=corpus.strip(),
|
||||||
)
|
)
|
||||||
|
|
||||||
summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config()
|
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
|
||||||
|
|
||||||
with timer("Chat actor: Extract relevant information from data", logger):
|
with timer("Chat actor: Extract relevant information from data", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
extract_relevant_information,
|
extract_relevant_information,
|
||||||
prompts.system_prompt_extract_relevant_information,
|
prompts.system_prompt_extract_relevant_information,
|
||||||
chat_model_option=summarizer_model,
|
chat_model_option=chat_model,
|
||||||
|
subscribed=subscribed,
|
||||||
)
|
)
|
||||||
return response.strip()
|
return response.strip()
|
||||||
|
|
||||||
|
|
||||||
async def extract_relevant_summary(q: str, corpus: str) -> Union[str, None]:
|
async def extract_relevant_summary(q: str, corpus: str, subscribed: bool = False) -> Union[str, None]:
|
||||||
"""
|
"""
|
||||||
Extract relevant information for a given query from the target corpus
|
Extract relevant information for a given query from the target corpus
|
||||||
"""
|
"""
|
||||||
@@ -471,13 +474,14 @@ async def extract_relevant_summary(q: str, corpus: str) -> Union[str, None]:
|
|||||||
corpus=corpus.strip(),
|
corpus=corpus.strip(),
|
||||||
)
|
)
|
||||||
|
|
||||||
summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config()
|
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
|
||||||
|
|
||||||
with timer("Chat actor: Extract relevant information from data", logger):
|
with timer("Chat actor: Extract relevant information from data", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
extract_relevant_information,
|
extract_relevant_information,
|
||||||
prompts.system_prompt_extract_relevant_summary,
|
prompts.system_prompt_extract_relevant_summary,
|
||||||
chat_model_option=summarizer_model,
|
chat_model_option=chat_model,
|
||||||
|
subscribed=subscribed,
|
||||||
)
|
)
|
||||||
return response.strip()
|
return response.strip()
|
||||||
|
|
||||||
@@ -489,6 +493,7 @@ async def generate_better_image_prompt(
|
|||||||
note_references: List[Dict[str, Any]],
|
note_references: List[Dict[str, Any]],
|
||||||
online_results: Optional[dict] = None,
|
online_results: Optional[dict] = None,
|
||||||
model_type: Optional[str] = None,
|
model_type: Optional[str] = None,
|
||||||
|
subscribed: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Generate a better image prompt from the given query
|
Generate a better image prompt from the given query
|
||||||
@@ -533,10 +538,12 @@ async def generate_better_image_prompt(
|
|||||||
online_results=simplified_online_results,
|
online_results=simplified_online_results,
|
||||||
)
|
)
|
||||||
|
|
||||||
summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config()
|
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
|
||||||
|
|
||||||
with timer("Chat actor: Generate contextual image prompt", logger):
|
with timer("Chat actor: Generate contextual image prompt", logger):
|
||||||
response = await send_message_to_model_wrapper(image_prompt, chat_model_option=summarizer_model)
|
response = await send_message_to_model_wrapper(
|
||||||
|
image_prompt, chat_model_option=chat_model, subscribed=subscribed
|
||||||
|
)
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||||
response = response[1:-1]
|
response = response[1:-1]
|
||||||
@@ -549,13 +556,18 @@ async def send_message_to_model_wrapper(
|
|||||||
system_message: str = "",
|
system_message: str = "",
|
||||||
response_type: str = "text",
|
response_type: str = "text",
|
||||||
chat_model_option: ChatModelOptions = None,
|
chat_model_option: ChatModelOptions = None,
|
||||||
|
subscribed: bool = False,
|
||||||
):
|
):
|
||||||
conversation_config: ChatModelOptions = (
|
conversation_config: ChatModelOptions = (
|
||||||
chat_model_option or await ConversationAdapters.aget_default_conversation_config()
|
chat_model_option or await ConversationAdapters.aget_default_conversation_config()
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_model = conversation_config.chat_model
|
chat_model = conversation_config.chat_model
|
||||||
max_tokens = conversation_config.max_prompt_size
|
max_tokens = (
|
||||||
|
conversation_config.subscribed_max_prompt_size
|
||||||
|
if subscribed and conversation_config.subscribed_max_prompt_size
|
||||||
|
else conversation_config.max_prompt_size
|
||||||
|
)
|
||||||
tokenizer = conversation_config.tokenizer
|
tokenizer = conversation_config.tokenizer
|
||||||
|
|
||||||
if conversation_config.model_type == "offline":
|
if conversation_config.model_type == "offline":
|
||||||
@@ -786,6 +798,7 @@ async def text_to_image(
|
|||||||
location_data: LocationData,
|
location_data: LocationData,
|
||||||
references: List[Dict[str, Any]],
|
references: List[Dict[str, Any]],
|
||||||
online_results: Dict[str, Any],
|
online_results: Dict[str, Any],
|
||||||
|
subscribed: bool = False,
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
):
|
):
|
||||||
status_code = 200
|
status_code = 200
|
||||||
@@ -822,6 +835,7 @@ async def text_to_image(
|
|||||||
note_references=references,
|
note_references=references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
model_type=text_to_image_config.model_type,
|
model_type=text_to_image_config.model_type,
|
||||||
|
subscribed=subscribed,
|
||||||
)
|
)
|
||||||
|
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
@@ -1359,7 +1373,9 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
|
|||||||
current_notion_config = get_user_notion_config(user)
|
current_notion_config = get_user_notion_config(user)
|
||||||
notion_token = current_notion_config.token if current_notion_config else ""
|
notion_token = current_notion_config.token if current_notion_config else ""
|
||||||
|
|
||||||
selected_chat_model_config = ConversationAdapters.get_conversation_config(user)
|
selected_chat_model_config = (
|
||||||
|
ConversationAdapters.get_conversation_config(user) or ConversationAdapters.get_default_conversation_config()
|
||||||
|
)
|
||||||
chat_models = ConversationAdapters.get_conversation_processor_options().all()
|
chat_models = ConversationAdapters.get_conversation_processor_options().all()
|
||||||
chat_model_options = list()
|
chat_model_options = list()
|
||||||
for chat_model in chat_models:
|
for chat_model in chat_models:
|
||||||
|
|||||||
Reference in New Issue
Block a user