mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Support customization of the OpenAI base url in admin settings (#725)
- Allow self-hosted users to customize their open ai base url. This allows you to easily use a proxy service and extend support for other models. - This also includes a migration that associates any existing openai chat model configuration with an openai processor configuration - Make changing model a paid/subscriber feature - Removes usage of langchain's OpenAI wrapper for better control over parsing input/output
This commit is contained in:
@@ -623,10 +623,6 @@ class ConversationAdapters:
|
|||||||
def get_openai_conversation_config():
|
def get_openai_conversation_config():
|
||||||
return OpenAIProcessorConversationConfig.objects.filter().first()
|
return OpenAIProcessorConversationConfig.objects.filter().first()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def aget_openai_conversation_config():
|
|
||||||
return await OpenAIProcessorConversationConfig.objects.filter().afirst()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def has_valid_openai_conversation_config():
|
def has_valid_openai_conversation_config():
|
||||||
return OpenAIProcessorConversationConfig.objects.filter().exists()
|
return OpenAIProcessorConversationConfig.objects.filter().exists()
|
||||||
@@ -659,7 +655,7 @@ class ConversationAdapters:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_default_conversation_config():
|
async def aget_default_conversation_config():
|
||||||
return await ChatModelOptions.objects.filter().afirst()
|
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def save_conversation(
|
def save_conversation(
|
||||||
@@ -697,29 +693,15 @@ class ConversationAdapters:
|
|||||||
user_conversation_config.setting = new_config
|
user_conversation_config.setting = new_config
|
||||||
user_conversation_config.save()
|
user_conversation_config.save()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_default_offline_llm():
|
|
||||||
return await ChatModelOptions.objects.filter(model_type="offline").afirst()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_user_conversation_config(user: KhojUser):
|
async def aget_user_conversation_config(user: KhojUser):
|
||||||
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
config = (
|
||||||
|
await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__openai_config").afirst()
|
||||||
|
)
|
||||||
if not config:
|
if not config:
|
||||||
return None
|
return None
|
||||||
return config.setting
|
return config.setting
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def has_openai_chat():
|
|
||||||
return await OpenAIProcessorConversationConfig.objects.filter().aexists()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def aget_default_openai_llm():
|
|
||||||
return await ChatModelOptions.objects.filter(model_type="openai").afirst()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_openai_chat_config():
|
|
||||||
return await OpenAIProcessorConversationConfig.objects.filter().afirst()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def get_speech_to_text_config():
|
async def get_speech_to_text_config():
|
||||||
return await SpeechToTextModelOptions.objects.filter().afirst()
|
return await SpeechToTextModelOptions.objects.filter().afirst()
|
||||||
@@ -744,7 +726,8 @@ class ConversationAdapters:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
|
def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
|
||||||
if conversation.agent and conversation.agent.chat_model:
|
agent: Agent = conversation.agent if AgentAdapters.get_default_agent() != conversation.agent else None
|
||||||
|
if agent and agent.chat_model:
|
||||||
conversation_config = conversation.agent.chat_model
|
conversation_config = conversation.agent.chat_model
|
||||||
else:
|
else:
|
||||||
conversation_config = ConversationAdapters.get_conversation_config(user)
|
conversation_config = ConversationAdapters.get_conversation_config(user)
|
||||||
@@ -760,8 +743,7 @@ class ConversationAdapters:
|
|||||||
|
|
||||||
return conversation_config
|
return conversation_config
|
||||||
|
|
||||||
openai_chat_config = ConversationAdapters.get_openai_conversation_config()
|
if conversation_config.model_type == "openai" and conversation_config.openai_config:
|
||||||
if openai_chat_config and conversation_config.model_type == "openai":
|
|
||||||
return conversation_config
|
return conversation_config
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -0,0 +1,51 @@
|
|||||||
|
# Generated by Django 4.2.10 on 2024-04-24 05:46
|
||||||
|
|
||||||
|
import django.db.models.deletion
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
def attach_openai_config(apps, schema_editor):
|
||||||
|
OpenAIProcessorConversationConfig = apps.get_model("database", "OpenAIProcessorConversationConfig")
|
||||||
|
openai_processor_conversation_config = OpenAIProcessorConversationConfig.objects.first()
|
||||||
|
if openai_processor_conversation_config:
|
||||||
|
ChatModelOptions = apps.get_model("database", "ChatModelOptions")
|
||||||
|
for chat_model_option in ChatModelOptions.objects.all():
|
||||||
|
if chat_model_option.model_type == "openai":
|
||||||
|
chat_model_option.openai_config = openai_processor_conversation_config
|
||||||
|
chat_model_option.save()
|
||||||
|
|
||||||
|
|
||||||
|
def separate_openai_config(apps, schema_editor):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0036_delete_offlinechatprocessorconversationconfig"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="chatmodeloptions",
|
||||||
|
name="openai_config",
|
||||||
|
field=models.ForeignKey(
|
||||||
|
blank=True,
|
||||||
|
default=None,
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
to="database.openaiprocessorconversationconfig",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="openaiprocessorconversationconfig",
|
||||||
|
name="api_base_url",
|
||||||
|
field=models.URLField(blank=True, default=None, null=True),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="openaiprocessorconversationconfig",
|
||||||
|
name="name",
|
||||||
|
field=models.CharField(default="default", max_length=200),
|
||||||
|
preserve_default=False,
|
||||||
|
),
|
||||||
|
migrations.RunPython(attach_openai_config, reverse_code=separate_openai_config),
|
||||||
|
]
|
||||||
14
src/khoj/database/migrations/0038_merge_20240425_0857.py
Normal file
14
src/khoj/database/migrations/0038_merge_20240425_0857.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
# Generated by Django 4.2.10 on 2024-04-25 08:57
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from django.db import migrations
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0037_chatmodeloptions_openai_config_and_more"),
|
||||||
|
("database", "0037_searchmodelconfig_bi_encoder_docs_encode_config_and_more"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations: List[str] = []
|
||||||
@@ -73,6 +73,12 @@ class Subscription(BaseModel):
|
|||||||
renewal_date = models.DateTimeField(null=True, default=None, blank=True)
|
renewal_date = models.DateTimeField(null=True, default=None, blank=True)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIProcessorConversationConfig(BaseModel):
|
||||||
|
name = models.CharField(max_length=200)
|
||||||
|
api_key = models.CharField(max_length=200)
|
||||||
|
api_base_url = models.URLField(max_length=200, default=None, blank=True, null=True)
|
||||||
|
|
||||||
|
|
||||||
class ChatModelOptions(BaseModel):
|
class ChatModelOptions(BaseModel):
|
||||||
class ModelType(models.TextChoices):
|
class ModelType(models.TextChoices):
|
||||||
OPENAI = "openai"
|
OPENAI = "openai"
|
||||||
@@ -82,6 +88,9 @@ class ChatModelOptions(BaseModel):
|
|||||||
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
|
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
chat_model = models.CharField(max_length=200, default="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF")
|
chat_model = models.CharField(max_length=200, default="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF")
|
||||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
||||||
|
openai_config = models.ForeignKey(
|
||||||
|
OpenAIProcessorConversationConfig, on_delete=models.CASCADE, default=None, null=True, blank=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Agent(BaseModel):
|
class Agent(BaseModel):
|
||||||
@@ -211,10 +220,6 @@ class TextToImageModelConfig(BaseModel):
|
|||||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIProcessorConversationConfig(BaseModel):
|
|
||||||
api_key = models.CharField(max_length=200)
|
|
||||||
|
|
||||||
|
|
||||||
class SpeechToTextModelOptions(BaseModel):
|
class SpeechToTextModelOptions(BaseModel):
|
||||||
class ModelType(models.TextChoices):
|
class ModelType(models.TextChoices):
|
||||||
OPENAI = "openai"
|
OPENAI = "openai"
|
||||||
|
|||||||
@@ -191,9 +191,15 @@
|
|||||||
</select>
|
</select>
|
||||||
</div>
|
</div>
|
||||||
<div class="card-action-row">
|
<div class="card-action-row">
|
||||||
|
{% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
|
||||||
<button id="save-model" class="card-button happy" onclick="updateChatModel()">
|
<button id="save-model" class="card-button happy" onclick="updateChatModel()">
|
||||||
Save
|
Save
|
||||||
</button>
|
</button>
|
||||||
|
{% else %}
|
||||||
|
<button id="save-model" class="card-button" disabled>
|
||||||
|
Subscribe to use different models
|
||||||
|
</button>
|
||||||
|
{% endif %}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -121,14 +121,16 @@ def migrate_server_pg(args):
|
|||||||
if openai.get("chat-model") is None:
|
if openai.get("chat-model") is None:
|
||||||
openai["chat-model"] = "gpt-3.5-turbo"
|
openai["chat-model"] = "gpt-3.5-turbo"
|
||||||
|
|
||||||
OpenAIProcessorConversationConfig.objects.create(
|
openai_config = OpenAIProcessorConversationConfig.objects.create(
|
||||||
api_key=openai.get("api-key"),
|
api_key=openai.get("api-key"), name="default"
|
||||||
)
|
)
|
||||||
|
|
||||||
ChatModelOptions.objects.create(
|
ChatModelOptions.objects.create(
|
||||||
chat_model=openai.get("chat-model"),
|
chat_model=openai.get("chat-model"),
|
||||||
tokenizer=processor_conversation.get("tokenizer"),
|
tokenizer=processor_conversation.get("tokenizer"),
|
||||||
max_prompt_size=processor_conversation.get("max-prompt-size"),
|
max_prompt_size=processor_conversation.get("max-prompt-size"),
|
||||||
model_type=ChatModelOptions.ModelType.OPENAI,
|
model_type=ChatModelOptions.ModelType.OPENAI,
|
||||||
|
openai_config=openai_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
save_config_to_file(raw_config, args.config_file)
|
save_config_to_file(raw_config, args.config_file)
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ def extract_questions(
|
|||||||
model: Optional[str] = "gpt-4-turbo-preview",
|
model: Optional[str] = "gpt-4-turbo-preview",
|
||||||
conversation_log={},
|
conversation_log={},
|
||||||
api_key=None,
|
api_key=None,
|
||||||
|
api_base_url=None,
|
||||||
temperature=0,
|
temperature=0,
|
||||||
max_tokens=100,
|
max_tokens=100,
|
||||||
location_data: LocationData = None,
|
location_data: LocationData = None,
|
||||||
@@ -64,12 +65,12 @@ def extract_questions(
|
|||||||
# Get Response from GPT
|
# Get Response from GPT
|
||||||
response = completion_with_backoff(
|
response = completion_with_backoff(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
completion_kwargs={"temperature": temperature, "max_tokens": max_tokens},
|
model=model,
|
||||||
model_kwargs={
|
temperature=temperature,
|
||||||
"model_name": model,
|
max_tokens=max_tokens,
|
||||||
"openai_api_key": api_key,
|
api_base_url=api_base_url,
|
||||||
"model_kwargs": {"response_format": {"type": "json_object"}},
|
model_kwargs={"response_format": {"type": "json_object"}},
|
||||||
},
|
openai_api_key=api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract, Clean Message from GPT's Response
|
# Extract, Clean Message from GPT's Response
|
||||||
@@ -89,7 +90,7 @@ def extract_questions(
|
|||||||
return questions
|
return questions
|
||||||
|
|
||||||
|
|
||||||
def send_message_to_model(messages, api_key, model, response_type="text"):
|
def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None):
|
||||||
"""
|
"""
|
||||||
Send message to model
|
Send message to model
|
||||||
"""
|
"""
|
||||||
@@ -97,11 +98,10 @@ def send_message_to_model(messages, api_key, model, response_type="text"):
|
|||||||
# Get Response from GPT
|
# Get Response from GPT
|
||||||
return completion_with_backoff(
|
return completion_with_backoff(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model_kwargs={
|
model=model,
|
||||||
"model_name": model,
|
openai_api_key=api_key,
|
||||||
"openai_api_key": api_key,
|
api_base_url=api_base_url,
|
||||||
"model_kwargs": {"response_format": {"type": response_type}},
|
model_kwargs={"response_format": {"type": response_type}},
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -112,6 +112,7 @@ def converse(
|
|||||||
conversation_log={},
|
conversation_log={},
|
||||||
model: str = "gpt-3.5-turbo",
|
model: str = "gpt-3.5-turbo",
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
|
api_base_url: Optional[str] = None,
|
||||||
temperature: float = 0.2,
|
temperature: float = 0.2,
|
||||||
completion_func=None,
|
completion_func=None,
|
||||||
conversation_commands=[ConversationCommand.Default],
|
conversation_commands=[ConversationCommand.Default],
|
||||||
@@ -181,6 +182,7 @@ def converse(
|
|||||||
model_name=model,
|
model_name=model,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
openai_api_key=api_key,
|
openai_api_key=api_key,
|
||||||
|
api_base_url=api_base_url,
|
||||||
completion_func=completion_func,
|
completion_func=completion_func,
|
||||||
model_kwargs={"stop": ["Notes:\n["]},
|
model_kwargs={"stop": ["Notes:\n["]},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Any
|
from typing import Dict
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
|
||||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
before_sleep_log,
|
before_sleep_log,
|
||||||
retry,
|
retry,
|
||||||
@@ -20,14 +17,7 @@ from khoj.processor.conversation.utils import ThreadedGenerator
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
openai_clients: Dict[str, openai.OpenAI] = {}
|
||||||
class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
|
|
||||||
def __init__(self, gen: ThreadedGenerator):
|
|
||||||
super().__init__()
|
|
||||||
self.gen = gen
|
|
||||||
|
|
||||||
def on_llm_new_token(self, token: str, **kwargs) -> Any:
|
|
||||||
self.gen.send(token)
|
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
@@ -43,13 +33,37 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
|
|||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def completion_with_backoff(messages, model_kwargs={}, completion_kwargs={}) -> str:
|
def completion_with_backoff(
|
||||||
if not "openai_api_key" in model_kwargs:
|
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, max_tokens=None
|
||||||
model_kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
|
) -> str:
|
||||||
llm = ChatOpenAI(**model_kwargs, request_timeout=20, max_retries=1)
|
client_key = f"{openai_api_key}--{api_base_url}"
|
||||||
|
client: openai.OpenAI = openai_clients.get(client_key)
|
||||||
|
if not client:
|
||||||
|
client = openai.OpenAI(
|
||||||
|
api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
|
||||||
|
base_url=api_base_url,
|
||||||
|
)
|
||||||
|
openai_clients[client_key] = client
|
||||||
|
|
||||||
|
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||||
|
|
||||||
|
chat = client.chat.completions.create(
|
||||||
|
stream=True,
|
||||||
|
messages=formatted_messages, # type: ignore
|
||||||
|
model=model, # type: ignore
|
||||||
|
temperature=temperature,
|
||||||
|
timeout=20,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
**(model_kwargs or dict()),
|
||||||
|
)
|
||||||
aggregated_response = ""
|
aggregated_response = ""
|
||||||
for chunk in llm.stream(messages, **completion_kwargs):
|
for chunk in chat:
|
||||||
aggregated_response += chunk.content
|
delta_chunk = chunk.choices[0].delta # type: ignore
|
||||||
|
if isinstance(delta_chunk, str):
|
||||||
|
aggregated_response += delta_chunk
|
||||||
|
elif delta_chunk.content:
|
||||||
|
aggregated_response += delta_chunk.content
|
||||||
|
|
||||||
return aggregated_response
|
return aggregated_response
|
||||||
|
|
||||||
|
|
||||||
@@ -73,30 +87,45 @@ def chat_completion_with_backoff(
|
|||||||
model_name,
|
model_name,
|
||||||
temperature,
|
temperature,
|
||||||
openai_api_key=None,
|
openai_api_key=None,
|
||||||
|
api_base_url=None,
|
||||||
completion_func=None,
|
completion_func=None,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
):
|
):
|
||||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||||
t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, model_kwargs))
|
t = Thread(
|
||||||
|
target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs)
|
||||||
|
)
|
||||||
t.start()
|
t.start()
|
||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
def llm_thread(g, messages, model_name, temperature, openai_api_key=None, model_kwargs=None):
|
def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
|
||||||
callback_handler = StreamingChatCallbackHandler(g)
|
client_key = f"{openai_api_key}--{api_base_url}"
|
||||||
chat = ChatOpenAI(
|
if client_key not in openai_clients:
|
||||||
streaming=True,
|
client: openai.OpenAI = openai.OpenAI(
|
||||||
verbose=True,
|
api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
|
||||||
callback_manager=BaseCallbackManager([callback_handler]),
|
base_url=api_base_url,
|
||||||
model_name=model_name, # type: ignore
|
)
|
||||||
|
openai_clients[client_key] = client
|
||||||
|
else:
|
||||||
|
client: openai.OpenAI = openai_clients[client_key]
|
||||||
|
|
||||||
|
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||||
|
|
||||||
|
chat = client.chat.completions.create(
|
||||||
|
stream=True,
|
||||||
|
messages=formatted_messages,
|
||||||
|
model=model_name, # type: ignore
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
|
timeout=20,
|
||||||
model_kwargs=model_kwargs,
|
**(model_kwargs or dict()),
|
||||||
request_timeout=20,
|
|
||||||
max_retries=1,
|
|
||||||
client=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
chat(messages=messages)
|
for chunk in chat:
|
||||||
|
delta_chunk = chunk.choices[0].delta
|
||||||
|
if isinstance(delta_chunk, str):
|
||||||
|
g.send(delta_chunk)
|
||||||
|
elif delta_chunk.content:
|
||||||
|
g.send(delta_chunk.content)
|
||||||
|
|
||||||
g.close()
|
g.close()
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from transformers import AutoTokenizer
|
|||||||
from khoj.database.adapters import ConversationAdapters
|
from khoj.database.adapters import ConversationAdapters
|
||||||
from khoj.database.models import ClientApplication, KhojUser
|
from khoj.database.models import ClientApplication, KhojUser
|
||||||
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
||||||
|
from khoj.utils import state
|
||||||
from khoj.utils.helpers import is_none_or_empty, merge_dicts
|
from khoj.utils.helpers import is_none_or_empty, merge_dicts
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -186,19 +187,31 @@ def truncate_messages(
|
|||||||
max_prompt_size,
|
max_prompt_size,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
loaded_model: Optional[Llama] = None,
|
loaded_model: Optional[Llama] = None,
|
||||||
tokenizer_name="hf-internal-testing/llama-tokenizer",
|
tokenizer_name=None,
|
||||||
) -> list[ChatMessage]:
|
) -> list[ChatMessage]:
|
||||||
"""Truncate messages to fit within max prompt size supported by model"""
|
"""Truncate messages to fit within max prompt size supported by model"""
|
||||||
|
|
||||||
|
default_tokenizer = "hf-internal-testing/llama-tokenizer"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if loaded_model:
|
if loaded_model:
|
||||||
encoder = loaded_model.tokenizer()
|
encoder = loaded_model.tokenizer()
|
||||||
elif model_name.startswith("gpt-"):
|
elif model_name.startswith("gpt-"):
|
||||||
encoder = tiktoken.encoding_for_model(model_name)
|
encoder = tiktoken.encoding_for_model(model_name)
|
||||||
|
elif tokenizer_name:
|
||||||
|
if tokenizer_name in state.pretrained_tokenizers:
|
||||||
|
encoder = state.pretrained_tokenizers[tokenizer_name]
|
||||||
|
else:
|
||||||
|
encoder = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||||
|
state.pretrained_tokenizers[tokenizer_name] = encoder
|
||||||
else:
|
else:
|
||||||
encoder = download_model(model_name).tokenizer()
|
encoder = download_model(model_name).tokenizer()
|
||||||
except:
|
except:
|
||||||
encoder = AutoTokenizer.from_pretrained(tokenizer_name)
|
if default_tokenizer in state.pretrained_tokenizers:
|
||||||
|
encoder = state.pretrained_tokenizers[default_tokenizer]
|
||||||
|
else:
|
||||||
|
encoder = AutoTokenizer.from_pretrained(default_tokenizer)
|
||||||
|
state.pretrained_tokenizers[default_tokenizer] = encoder
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Fallback to default chat model tokenizer: {tokenizer_name}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
|
f"Fallback to default chat model tokenizer: {tokenizer_name}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -267,7 +267,6 @@ async def transcribe(
|
|||||||
|
|
||||||
async def extract_references_and_questions(
|
async def extract_references_and_questions(
|
||||||
request: Request,
|
request: Request,
|
||||||
common: CommonQueryParams,
|
|
||||||
meta_log: dict,
|
meta_log: dict,
|
||||||
q: str,
|
q: str,
|
||||||
n: int,
|
n: int,
|
||||||
@@ -303,14 +302,12 @@ async def extract_references_and_questions(
|
|||||||
# Infer search queries from user message
|
# Infer search queries from user message
|
||||||
with timer("Extracting search queries took", logger):
|
with timer("Extracting search queries took", logger):
|
||||||
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
||||||
conversation_config = await ConversationAdapters.aget_conversation_config(user)
|
|
||||||
if conversation_config is None:
|
|
||||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||||
|
|
||||||
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||||
using_offline_chat = True
|
using_offline_chat = True
|
||||||
default_offline_llm = await ConversationAdapters.get_default_offline_llm()
|
chat_model = conversation_config.chat_model
|
||||||
chat_model = default_offline_llm.chat_model
|
max_tokens = conversation_config.max_prompt_size
|
||||||
max_tokens = default_offline_llm.max_prompt_size
|
|
||||||
if state.offline_chat_processor_config is None:
|
if state.offline_chat_processor_config is None:
|
||||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||||
|
|
||||||
@@ -324,11 +321,10 @@ async def extract_references_and_questions(
|
|||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
max_prompt_size=conversation_config.max_prompt_size,
|
max_prompt_size=conversation_config.max_prompt_size,
|
||||||
)
|
)
|
||||||
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
openai_chat_config = conversation_config.openai_config
|
||||||
default_openai_llm = await ConversationAdapters.aget_default_openai_llm()
|
|
||||||
api_key = openai_chat_config.api_key
|
api_key = openai_chat_config.api_key
|
||||||
chat_model = default_openai_llm.chat_model
|
chat_model = conversation_config.chat_model
|
||||||
inferred_queries = extract_questions(
|
inferred_queries = extract_questions(
|
||||||
defiltered_query,
|
defiltered_query,
|
||||||
model=chat_model,
|
model=chat_model,
|
||||||
|
|||||||
@@ -380,7 +380,7 @@ async def websocket_endpoint(
|
|||||||
q = q.replace(f"/{cmd.value}", "").strip()
|
q = q.replace(f"/{cmd.value}", "").strip()
|
||||||
|
|
||||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||||
websocket, None, meta_log, q, 7, 0.18, conversation_commands, location, send_status_update
|
websocket, meta_log, q, 7, 0.18, conversation_commands, location, send_status_update
|
||||||
)
|
)
|
||||||
|
|
||||||
if compiled_references:
|
if compiled_references:
|
||||||
@@ -575,7 +575,7 @@ async def chat(
|
|||||||
user_name = await aget_user_name(user)
|
user_name = await aget_user_name(user)
|
||||||
|
|
||||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||||
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_commands, location
|
request, meta_log, q, (n or 5), (d or math.inf), conversation_commands, location
|
||||||
)
|
)
|
||||||
online_results: Dict[str, Dict] = {}
|
online_results: Dict[str, Dict] = {}
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from asgiref.sync import sync_to_async
|
|||||||
from fastapi import APIRouter, HTTPException, Request
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
from fastapi.requests import Request
|
from fastapi.requests import Request
|
||||||
from fastapi.responses import Response
|
from fastapi.responses import Response
|
||||||
from starlette.authentication import requires
|
from starlette.authentication import has_required_scope, requires
|
||||||
|
|
||||||
from khoj.database import adapters
|
from khoj.database import adapters
|
||||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||||
@@ -20,6 +20,7 @@ from khoj.database.models import (
|
|||||||
LocalPdfConfig,
|
LocalPdfConfig,
|
||||||
LocalPlaintextConfig,
|
LocalPlaintextConfig,
|
||||||
NotionConfig,
|
NotionConfig,
|
||||||
|
Subscription,
|
||||||
)
|
)
|
||||||
from khoj.routers.helpers import CommonQueryParams, update_telemetry_state
|
from khoj.routers.helpers import CommonQueryParams, update_telemetry_state
|
||||||
from khoj.utils import constants, state
|
from khoj.utils import constants, state
|
||||||
@@ -236,6 +237,10 @@ async def update_chat_model(
|
|||||||
client: Optional[str] = None,
|
client: Optional[str] = None,
|
||||||
):
|
):
|
||||||
user = request.user.object
|
user = request.user.object
|
||||||
|
subscribed = has_required_scope(request, ["premium"])
|
||||||
|
|
||||||
|
if not subscribed:
|
||||||
|
raise HTTPException(status_code=403, detail="User is not subscribed to premium")
|
||||||
|
|
||||||
new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id))
|
new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id))
|
||||||
|
|
||||||
|
|||||||
@@ -70,13 +70,14 @@ def validate_conversation_config():
|
|||||||
if default_config is None:
|
if default_config is None:
|
||||||
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
||||||
|
|
||||||
if default_config.model_type == "openai" and not ConversationAdapters.has_valid_openai_conversation_config():
|
if default_config.model_type == "openai" and not default_config.openai_config:
|
||||||
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
||||||
|
|
||||||
|
|
||||||
async def is_ready_to_chat(user: KhojUser):
|
async def is_ready_to_chat(user: KhojUser):
|
||||||
has_openai_config = await ConversationAdapters.has_openai_chat()
|
user_conversation_config = (await ConversationAdapters.aget_user_conversation_config(user)) or (
|
||||||
user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
await ConversationAdapters.aget_default_conversation_config()
|
||||||
|
)
|
||||||
|
|
||||||
if user_conversation_config and user_conversation_config.model_type == "offline":
|
if user_conversation_config and user_conversation_config.model_type == "offline":
|
||||||
chat_model = user_conversation_config.chat_model
|
chat_model = user_conversation_config.chat_model
|
||||||
@@ -86,7 +87,13 @@ async def is_ready_to_chat(user: KhojUser):
|
|||||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if not has_openai_config:
|
if (
|
||||||
|
user_conversation_config
|
||||||
|
and user_conversation_config.model_type == "openai"
|
||||||
|
and user_conversation_config.openai_config
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
|
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
|
||||||
|
|
||||||
|
|
||||||
@@ -407,8 +414,9 @@ async def send_message_to_model_wrapper(
|
|||||||
)
|
)
|
||||||
|
|
||||||
elif conversation_config.model_type == "openai":
|
elif conversation_config.model_type == "openai":
|
||||||
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
|
openai_chat_config = conversation_config.openai_config
|
||||||
api_key = openai_chat_config.api_key
|
api_key = openai_chat_config.api_key
|
||||||
|
api_base_url = openai_chat_config.api_base_url
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
user_message=message,
|
user_message=message,
|
||||||
system_message=system_message,
|
system_message=system_message,
|
||||||
@@ -418,7 +426,11 @@ async def send_message_to_model_wrapper(
|
|||||||
)
|
)
|
||||||
|
|
||||||
openai_response = send_message_to_model(
|
openai_response = send_message_to_model(
|
||||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
|
messages=truncated_messages,
|
||||||
|
api_key=api_key,
|
||||||
|
model=chat_model,
|
||||||
|
response_type=response_type,
|
||||||
|
api_base_url=api_base_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
return openai_response
|
return openai_response
|
||||||
@@ -480,7 +492,7 @@ def generate_chat_response(
|
|||||||
)
|
)
|
||||||
|
|
||||||
elif conversation_config.model_type == "openai":
|
elif conversation_config.model_type == "openai":
|
||||||
openai_chat_config = ConversationAdapters.get_openai_conversation_config()
|
openai_chat_config = conversation_config.openai_config
|
||||||
api_key = openai_chat_config.api_key
|
api_key = openai_chat_config.api_key
|
||||||
chat_model = conversation_config.chat_model
|
chat_model = conversation_config.chat_model
|
||||||
chat_response = converse(
|
chat_response = converse(
|
||||||
@@ -490,6 +502,7 @@ def generate_chat_response(
|
|||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
model=chat_model,
|
model=chat_model,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
api_base_url=openai_chat_config.api_base_url,
|
||||||
completion_func=partial_completion,
|
completion_func=partial_completion,
|
||||||
conversation_commands=conversation_commands,
|
conversation_commands=conversation_commands,
|
||||||
max_prompt_size=conversation_config.max_prompt_size,
|
max_prompt_size=conversation_config.max_prompt_size,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import os
|
|||||||
import threading
|
import threading
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from whisper import Whisper
|
from whisper import Whisper
|
||||||
@@ -34,6 +34,7 @@ khoj_version: str = None
|
|||||||
device = get_device()
|
device = get_device()
|
||||||
chat_on_gpu: bool = True
|
chat_on_gpu: bool = True
|
||||||
anonymous_mode: bool = False
|
anonymous_mode: bool = False
|
||||||
|
pretrained_tokenizers: Dict[str, Any] = dict()
|
||||||
billing_enabled: bool = (
|
billing_enabled: bool = (
|
||||||
os.getenv("STRIPE_API_KEY") is not None
|
os.getenv("STRIPE_API_KEY") is not None
|
||||||
and os.getenv("STRIPE_SIGNING_SECRET") is not None
|
and os.getenv("STRIPE_SIGNING_SECRET") is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user