mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 13:25:11 +00:00
Simplify AI Chat Response Streaming (#1167)
Reason --- - Simplify code and logic to stream chat response by solely relying on asyncio event loop. - Reduce overhead of managing threads to increase efficiency and throughput (where possible). Details --- - Use async/await with no threading when generating chat response via OpenAI, Gemini, Anthropic AI model APIs - Use threading for offline chat model as llama-cpp doesn't support async streaming yet
This commit is contained in:
@@ -763,9 +763,9 @@ class AgentAdapters:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_conversation_agent_by_id(agent_id: int):
|
async def aget_conversation_agent_by_id(agent_id: int):
|
||||||
agent = Agent.objects.filter(id=agent_id).first()
|
agent = await Agent.objects.filter(id=agent_id).afirst()
|
||||||
if agent == AgentAdapters.get_default_agent():
|
if agent == await AgentAdapters.aget_default_agent():
|
||||||
# If the agent is set to the default agent, then return None and let the default application code be used
|
# If the agent is set to the default agent, then return None and let the default application code be used
|
||||||
return None
|
return None
|
||||||
return agent
|
return agent
|
||||||
@@ -1109,14 +1109,6 @@ class ConversationAdapters:
|
|||||||
async def aget_all_chat_models():
|
async def aget_all_chat_models():
|
||||||
return await sync_to_async(list)(ChatModel.objects.prefetch_related("ai_model_api").all())
|
return await sync_to_async(list)(ChatModel.objects.prefetch_related("ai_model_api").all())
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_vision_enabled_config():
|
|
||||||
chat_models = ConversationAdapters.get_all_chat_models()
|
|
||||||
for config in chat_models:
|
|
||||||
if config.vision_enabled:
|
|
||||||
return config
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_vision_enabled_config():
|
async def aget_vision_enabled_config():
|
||||||
chat_models = await ConversationAdapters.aget_all_chat_models()
|
chat_models = await ConversationAdapters.aget_all_chat_models()
|
||||||
@@ -1171,7 +1163,11 @@ class ConversationAdapters:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_chat_model(user: KhojUser):
|
async def aget_chat_model(user: KhojUser):
|
||||||
subscribed = await ais_user_subscribed(user)
|
subscribed = await ais_user_subscribed(user)
|
||||||
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
config = (
|
||||||
|
await UserConversationConfig.objects.filter(user=user)
|
||||||
|
.prefetch_related("setting", "setting__ai_model_api")
|
||||||
|
.afirst()
|
||||||
|
)
|
||||||
if subscribed:
|
if subscribed:
|
||||||
# Subscibed users can use any available chat model
|
# Subscibed users can use any available chat model
|
||||||
if config:
|
if config:
|
||||||
@@ -1387,7 +1383,7 @@ class ConversationAdapters:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@require_valid_user
|
@require_valid_user
|
||||||
def save_conversation(
|
async def save_conversation(
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
conversation_log: dict,
|
conversation_log: dict,
|
||||||
client_application: ClientApplication = None,
|
client_application: ClientApplication = None,
|
||||||
@@ -1396,19 +1392,21 @@ class ConversationAdapters:
|
|||||||
):
|
):
|
||||||
slug = user_message.strip()[:200] if user_message else None
|
slug = user_message.strip()[:200] if user_message else None
|
||||||
if conversation_id:
|
if conversation_id:
|
||||||
conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first()
|
conversation = await Conversation.objects.filter(
|
||||||
|
user=user, client=client_application, id=conversation_id
|
||||||
|
).afirst()
|
||||||
else:
|
else:
|
||||||
conversation = (
|
conversation = (
|
||||||
Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").first()
|
await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
|
||||||
)
|
)
|
||||||
|
|
||||||
if conversation:
|
if conversation:
|
||||||
conversation.conversation_log = conversation_log
|
conversation.conversation_log = conversation_log
|
||||||
conversation.slug = slug
|
conversation.slug = slug
|
||||||
conversation.updated_at = datetime.now(tz=timezone.utc)
|
conversation.updated_at = datetime.now(tz=timezone.utc)
|
||||||
conversation.save()
|
await conversation.asave()
|
||||||
else:
|
else:
|
||||||
Conversation.objects.create(
|
await Conversation.objects.acreate(
|
||||||
user=user, conversation_log=conversation_log, client=client_application, slug=slug
|
user=user, conversation_log=conversation_log, client=client_application, slug=slug
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1455,17 +1453,21 @@ class ConversationAdapters:
|
|||||||
return random.sample(all_questions, max_results)
|
return random.sample(all_questions, max_results)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool):
|
async def aget_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool):
|
||||||
agent: Agent = (
|
agent: Agent = (
|
||||||
conversation.agent if is_subscribed and AgentAdapters.get_default_agent() != conversation.agent else None
|
conversation.agent
|
||||||
|
if is_subscribed and await AgentAdapters.aget_default_agent() != conversation.agent
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
if agent and agent.chat_model:
|
if agent and agent.chat_model:
|
||||||
chat_model = conversation.agent.chat_model
|
chat_model = await ChatModel.objects.select_related("ai_model_api").aget(
|
||||||
|
pk=conversation.agent.chat_model.pk
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
chat_model = ConversationAdapters.get_chat_model(user)
|
chat_model = await ConversationAdapters.aget_chat_model(user)
|
||||||
|
|
||||||
if chat_model is None:
|
if chat_model is None:
|
||||||
chat_model = ConversationAdapters.get_default_chat_model()
|
chat_model = await ConversationAdapters.aget_default_chat_model()
|
||||||
|
|
||||||
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
|
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
|
||||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict, List, Optional
|
from typing import AsyncGenerator, Dict, List, Optional
|
||||||
|
|
||||||
import pyjson5
|
import pyjson5
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
@@ -137,7 +137,7 @@ def anthropic_send_message_to_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def converse_anthropic(
|
async def converse_anthropic(
|
||||||
references,
|
references,
|
||||||
user_query,
|
user_query,
|
||||||
online_results: Optional[Dict[str, Dict]] = None,
|
online_results: Optional[Dict[str, Dict]] = None,
|
||||||
@@ -161,7 +161,7 @@ def converse_anthropic(
|
|||||||
generated_asset_results: Dict[str, Dict] = {},
|
generated_asset_results: Dict[str, Dict] = {},
|
||||||
deepthought: Optional[bool] = False,
|
deepthought: Optional[bool] = False,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
) -> AsyncGenerator[str, None]:
|
||||||
"""
|
"""
|
||||||
Converse with user using Anthropic's Claude
|
Converse with user using Anthropic's Claude
|
||||||
"""
|
"""
|
||||||
@@ -191,11 +191,17 @@ def converse_anthropic(
|
|||||||
|
|
||||||
# Get Conversation Primer appropriate to Conversation Type
|
# Get Conversation Primer appropriate to Conversation Type
|
||||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||||
completion_func(chat_response=prompts.no_notes_found.format())
|
response = prompts.no_notes_found.format()
|
||||||
return iter([prompts.no_notes_found.format()])
|
if completion_func:
|
||||||
|
await completion_func(chat_response=response)
|
||||||
|
yield response
|
||||||
|
return
|
||||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
response = prompts.no_online_results_found.format()
|
||||||
return iter([prompts.no_online_results_found.format()])
|
if completion_func:
|
||||||
|
await completion_func(chat_response=response)
|
||||||
|
yield response
|
||||||
|
return
|
||||||
|
|
||||||
context_message = ""
|
context_message = ""
|
||||||
if not is_none_or_empty(references):
|
if not is_none_or_empty(references):
|
||||||
@@ -228,17 +234,21 @@ def converse_anthropic(
|
|||||||
logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}")
|
logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}")
|
||||||
|
|
||||||
# Get Response from Claude
|
# Get Response from Claude
|
||||||
return anthropic_chat_completion_with_backoff(
|
full_response = ""
|
||||||
|
async for chunk in anthropic_chat_completion_with_backoff(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
compiled_references=references,
|
|
||||||
online_results=online_results,
|
|
||||||
model_name=model,
|
model_name=model,
|
||||||
temperature=0.2,
|
temperature=0.2,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
completion_func=completion_func,
|
|
||||||
max_prompt_size=max_prompt_size,
|
max_prompt_size=max_prompt_size,
|
||||||
deepthought=deepthought,
|
deepthought=deepthought,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
):
|
||||||
|
full_response += chunk
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
# Call completion_func once finish streaming and we have the full response
|
||||||
|
if completion_func:
|
||||||
|
await completion_func(chat_response=full_response)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from threading import Thread
|
from time import perf_counter
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
import anthropic
|
import anthropic
|
||||||
@@ -13,13 +13,13 @@ from tenacity import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
ThreadedGenerator,
|
|
||||||
commit_conversation_trace,
|
commit_conversation_trace,
|
||||||
get_image_from_base64,
|
get_image_from_base64,
|
||||||
get_image_from_url,
|
get_image_from_url,
|
||||||
)
|
)
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
get_ai_api_info,
|
get_anthropic_async_client,
|
||||||
|
get_anthropic_client,
|
||||||
get_chat_usage_metrics,
|
get_chat_usage_metrics,
|
||||||
is_none_or_empty,
|
is_none_or_empty,
|
||||||
is_promptrace_enabled,
|
is_promptrace_enabled,
|
||||||
@@ -28,24 +28,12 @@ from khoj.utils.helpers import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {}
|
anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {}
|
||||||
|
anthropic_async_clients: Dict[str, anthropic.AsyncAnthropic | anthropic.AsyncAnthropicVertex] = {}
|
||||||
|
|
||||||
DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
|
DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
|
||||||
MAX_REASONING_TOKENS_ANTHROPIC = 12000
|
MAX_REASONING_TOKENS_ANTHROPIC = 12000
|
||||||
|
|
||||||
|
|
||||||
def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
|
|
||||||
api_info = get_ai_api_info(api_key, api_base_url)
|
|
||||||
if api_info.api_key:
|
|
||||||
client = anthropic.Anthropic(api_key=api_info.api_key)
|
|
||||||
else:
|
|
||||||
client = anthropic.AnthropicVertex(
|
|
||||||
region=api_info.region,
|
|
||||||
project_id=api_info.project,
|
|
||||||
credentials=api_info.credentials,
|
|
||||||
)
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
wait=wait_random_exponential(min=1, max=10),
|
wait=wait_random_exponential(min=1, max=10),
|
||||||
stop=stop_after_attempt(2),
|
stop=stop_after_attempt(2),
|
||||||
@@ -126,60 +114,23 @@ def anthropic_completion_with_backoff(
|
|||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def anthropic_chat_completion_with_backoff(
|
async def anthropic_chat_completion_with_backoff(
|
||||||
messages: list[ChatMessage],
|
messages: list[ChatMessage],
|
||||||
compiled_references,
|
|
||||||
online_results,
|
|
||||||
model_name,
|
model_name,
|
||||||
temperature,
|
temperature,
|
||||||
api_key,
|
api_key,
|
||||||
api_base_url,
|
api_base_url,
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
max_prompt_size=None,
|
max_prompt_size=None,
|
||||||
completion_func=None,
|
|
||||||
deepthought=False,
|
|
||||||
model_kwargs=None,
|
|
||||||
tracer={},
|
|
||||||
):
|
|
||||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
|
||||||
t = Thread(
|
|
||||||
target=anthropic_llm_thread,
|
|
||||||
args=(
|
|
||||||
g,
|
|
||||||
messages,
|
|
||||||
system_prompt,
|
|
||||||
model_name,
|
|
||||||
temperature,
|
|
||||||
api_key,
|
|
||||||
api_base_url,
|
|
||||||
max_prompt_size,
|
|
||||||
deepthought,
|
|
||||||
model_kwargs,
|
|
||||||
tracer,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
t.start()
|
|
||||||
return g
|
|
||||||
|
|
||||||
|
|
||||||
def anthropic_llm_thread(
|
|
||||||
g,
|
|
||||||
messages: list[ChatMessage],
|
|
||||||
system_prompt: str,
|
|
||||||
model_name: str,
|
|
||||||
temperature,
|
|
||||||
api_key,
|
|
||||||
api_base_url=None,
|
|
||||||
max_prompt_size=None,
|
|
||||||
deepthought=False,
|
deepthought=False,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
tracer={},
|
tracer={},
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
client = anthropic_clients.get(api_key)
|
client = anthropic_async_clients.get(api_key)
|
||||||
if not client:
|
if not client:
|
||||||
client = get_anthropic_client(api_key, api_base_url)
|
client = get_anthropic_async_client(api_key, api_base_url)
|
||||||
anthropic_clients[api_key] = client
|
anthropic_async_clients[api_key] = client
|
||||||
|
|
||||||
model_kwargs = model_kwargs or dict()
|
model_kwargs = model_kwargs or dict()
|
||||||
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
|
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
|
||||||
@@ -193,7 +144,8 @@ def anthropic_llm_thread(
|
|||||||
|
|
||||||
aggregated_response = ""
|
aggregated_response = ""
|
||||||
final_message = None
|
final_message = None
|
||||||
with client.messages.stream(
|
start_time = perf_counter()
|
||||||
|
async with client.messages.stream(
|
||||||
messages=formatted_messages,
|
messages=formatted_messages,
|
||||||
model=model_name, # type: ignore
|
model=model_name, # type: ignore
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
@@ -202,10 +154,17 @@ def anthropic_llm_thread(
|
|||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) as stream:
|
) as stream:
|
||||||
for text in stream.text_stream:
|
async for text in stream.text_stream:
|
||||||
|
# Log the time taken to start response
|
||||||
|
if aggregated_response == "":
|
||||||
|
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||||
|
# Handle streamed response chunk
|
||||||
aggregated_response += text
|
aggregated_response += text
|
||||||
g.send(text)
|
yield text
|
||||||
final_message = stream.get_final_message()
|
final_message = await stream.get_final_message()
|
||||||
|
|
||||||
|
# Log the time taken to stream the entire response
|
||||||
|
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||||
|
|
||||||
# Calculate cost of chat
|
# Calculate cost of chat
|
||||||
input_tokens = final_message.usage.input_tokens
|
input_tokens = final_message.usage.input_tokens
|
||||||
@@ -222,9 +181,7 @@ def anthropic_llm_thread(
|
|||||||
if is_promptrace_enabled():
|
if is_promptrace_enabled():
|
||||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
|
logger.error(f"Error in anthropic_chat_completion_with_backoff stream: {e}", exc_info=True)
|
||||||
finally:
|
|
||||||
g.close()
|
|
||||||
|
|
||||||
|
|
||||||
def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None):
|
def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None):
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict, List, Optional
|
from typing import AsyncGenerator, Dict, List, Optional
|
||||||
|
|
||||||
import pyjson5
|
import pyjson5
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
@@ -160,7 +160,7 @@ def gemini_send_message_to_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def converse_gemini(
|
async def converse_gemini(
|
||||||
references,
|
references,
|
||||||
user_query,
|
user_query,
|
||||||
online_results: Optional[Dict[str, Dict]] = None,
|
online_results: Optional[Dict[str, Dict]] = None,
|
||||||
@@ -185,7 +185,7 @@ def converse_gemini(
|
|||||||
program_execution_context: List[str] = None,
|
program_execution_context: List[str] = None,
|
||||||
deepthought: Optional[bool] = False,
|
deepthought: Optional[bool] = False,
|
||||||
tracer={},
|
tracer={},
|
||||||
):
|
) -> AsyncGenerator[str, None]:
|
||||||
"""
|
"""
|
||||||
Converse with user using Google's Gemini
|
Converse with user using Google's Gemini
|
||||||
"""
|
"""
|
||||||
@@ -216,11 +216,17 @@ def converse_gemini(
|
|||||||
|
|
||||||
# Get Conversation Primer appropriate to Conversation Type
|
# Get Conversation Primer appropriate to Conversation Type
|
||||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||||
completion_func(chat_response=prompts.no_notes_found.format())
|
response = prompts.no_notes_found.format()
|
||||||
return iter([prompts.no_notes_found.format()])
|
if completion_func:
|
||||||
|
await completion_func(chat_response=response)
|
||||||
|
yield response
|
||||||
|
return
|
||||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
response = prompts.no_online_results_found.format()
|
||||||
return iter([prompts.no_online_results_found.format()])
|
if completion_func:
|
||||||
|
await completion_func(chat_response=response)
|
||||||
|
yield response
|
||||||
|
return
|
||||||
|
|
||||||
context_message = ""
|
context_message = ""
|
||||||
if not is_none_or_empty(references):
|
if not is_none_or_empty(references):
|
||||||
@@ -253,16 +259,20 @@ def converse_gemini(
|
|||||||
logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
|
logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
|
||||||
|
|
||||||
# Get Response from Google AI
|
# Get Response from Google AI
|
||||||
return gemini_chat_completion_with_backoff(
|
full_response = ""
|
||||||
|
async for chunk in gemini_chat_completion_with_backoff(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
compiled_references=references,
|
|
||||||
online_results=online_results,
|
|
||||||
model_name=model,
|
model_name=model,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
completion_func=completion_func,
|
|
||||||
deepthought=deepthought,
|
deepthought=deepthought,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
):
|
||||||
|
full_response += chunk
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
# Call completion_func once finish streaming and we have the full response
|
||||||
|
if completion_func:
|
||||||
|
await completion_func(chat_response=full_response)
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from threading import Thread
|
from time import perf_counter
|
||||||
from typing import Dict
|
from typing import AsyncGenerator, AsyncIterator, Dict
|
||||||
|
|
||||||
from google import genai
|
from google import genai
|
||||||
from google.genai import errors as gerrors
|
from google.genai import errors as gerrors
|
||||||
@@ -19,14 +19,13 @@ from tenacity import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
ThreadedGenerator,
|
|
||||||
commit_conversation_trace,
|
commit_conversation_trace,
|
||||||
get_image_from_base64,
|
get_image_from_base64,
|
||||||
get_image_from_url,
|
get_image_from_url,
|
||||||
)
|
)
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
get_ai_api_info,
|
|
||||||
get_chat_usage_metrics,
|
get_chat_usage_metrics,
|
||||||
|
get_gemini_client,
|
||||||
is_none_or_empty,
|
is_none_or_empty,
|
||||||
is_promptrace_enabled,
|
is_promptrace_enabled,
|
||||||
)
|
)
|
||||||
@@ -62,17 +61,6 @@ SAFETY_SETTINGS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
|
|
||||||
api_info = get_ai_api_info(api_key, api_base_url)
|
|
||||||
return genai.Client(
|
|
||||||
location=api_info.region,
|
|
||||||
project=api_info.project,
|
|
||||||
credentials=api_info.credentials,
|
|
||||||
api_key=api_info.api_key,
|
|
||||||
vertexai=api_info.api_key is None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
wait=wait_random_exponential(min=1, max=10),
|
wait=wait_random_exponential(min=1, max=10),
|
||||||
stop=stop_after_attempt(2),
|
stop=stop_after_attempt(2),
|
||||||
@@ -132,8 +120,8 @@ def gemini_completion_with_backoff(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Aggregate cost of chat
|
# Aggregate cost of chat
|
||||||
input_tokens = response.usage_metadata.prompt_token_count if response else 0
|
input_tokens = response.usage_metadata.prompt_token_count or 0 if response else 0
|
||||||
output_tokens = response.usage_metadata.candidates_token_count if response else 0
|
output_tokens = response.usage_metadata.candidates_token_count or 0 if response else 0
|
||||||
thought_tokens = response.usage_metadata.thoughts_token_count or 0 if response else 0
|
thought_tokens = response.usage_metadata.thoughts_token_count or 0 if response else 0
|
||||||
tracer["usage"] = get_chat_usage_metrics(
|
tracer["usage"] = get_chat_usage_metrics(
|
||||||
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
|
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
|
||||||
@@ -154,52 +142,17 @@ def gemini_completion_with_backoff(
|
|||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def gemini_chat_completion_with_backoff(
|
async def gemini_chat_completion_with_backoff(
|
||||||
messages,
|
messages,
|
||||||
compiled_references,
|
|
||||||
online_results,
|
|
||||||
model_name,
|
model_name,
|
||||||
temperature,
|
temperature,
|
||||||
api_key,
|
api_key,
|
||||||
api_base_url,
|
api_base_url,
|
||||||
system_prompt,
|
system_prompt,
|
||||||
completion_func=None,
|
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
deepthought=False,
|
deepthought=False,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
) -> AsyncGenerator[str, None]:
|
||||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
|
||||||
t = Thread(
|
|
||||||
target=gemini_llm_thread,
|
|
||||||
args=(
|
|
||||||
g,
|
|
||||||
messages,
|
|
||||||
system_prompt,
|
|
||||||
model_name,
|
|
||||||
temperature,
|
|
||||||
api_key,
|
|
||||||
api_base_url,
|
|
||||||
model_kwargs,
|
|
||||||
deepthought,
|
|
||||||
tracer,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
t.start()
|
|
||||||
return g
|
|
||||||
|
|
||||||
|
|
||||||
def gemini_llm_thread(
|
|
||||||
g,
|
|
||||||
messages,
|
|
||||||
system_prompt,
|
|
||||||
model_name,
|
|
||||||
temperature,
|
|
||||||
api_key,
|
|
||||||
api_base_url=None,
|
|
||||||
model_kwargs=None,
|
|
||||||
deepthought=False,
|
|
||||||
tracer: dict = {},
|
|
||||||
):
|
|
||||||
try:
|
try:
|
||||||
client = gemini_clients.get(api_key)
|
client = gemini_clients.get(api_key)
|
||||||
if not client:
|
if not client:
|
||||||
@@ -224,21 +177,32 @@ def gemini_llm_thread(
|
|||||||
)
|
)
|
||||||
|
|
||||||
aggregated_response = ""
|
aggregated_response = ""
|
||||||
|
final_chunk = None
|
||||||
for chunk in client.models.generate_content_stream(
|
start_time = perf_counter()
|
||||||
|
chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream(
|
||||||
model=model_name, config=config, contents=formatted_messages
|
model=model_name, config=config, contents=formatted_messages
|
||||||
):
|
)
|
||||||
|
async for chunk in chat_stream:
|
||||||
|
# Log the time taken to start response
|
||||||
|
if final_chunk is None:
|
||||||
|
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||||
|
# Keep track of the last chunk for usage data
|
||||||
|
final_chunk = chunk
|
||||||
|
# Handle streamed response chunk
|
||||||
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
||||||
message = message or chunk.text
|
message = message or chunk.text
|
||||||
aggregated_response += message
|
aggregated_response += message
|
||||||
g.send(message)
|
yield message
|
||||||
if stopped:
|
if stopped:
|
||||||
raise ValueError(message)
|
raise ValueError(message)
|
||||||
|
|
||||||
|
# Log the time taken to stream the entire response
|
||||||
|
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||||
|
|
||||||
# Calculate cost of chat
|
# Calculate cost of chat
|
||||||
input_tokens = chunk.usage_metadata.prompt_token_count
|
input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
|
||||||
output_tokens = chunk.usage_metadata.candidates_token_count
|
output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
|
||||||
thought_tokens = chunk.usage_metadata.thoughts_token_count or 0
|
thought_tokens = final_chunk.usage_metadata.thoughts_token_count or 0 if final_chunk else 0
|
||||||
tracer["usage"] = get_chat_usage_metrics(
|
tracer["usage"] = get_chat_usage_metrics(
|
||||||
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
|
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
|
||||||
)
|
)
|
||||||
@@ -254,9 +218,7 @@ def gemini_llm_thread(
|
|||||||
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in gemini_llm_thread: {e}", exc_info=True)
|
logger.error(f"Error in gemini_chat_completion_with_backoff stream: {e}", exc_info=True)
|
||||||
finally:
|
|
||||||
g.close()
|
|
||||||
|
|
||||||
|
|
||||||
def handle_gemini_response(
|
def handle_gemini_response(
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import json
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Any, Dict, Iterator, List, Optional, Union
|
from time import perf_counter
|
||||||
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||||
|
|
||||||
import pyjson5
|
import pyjson5
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
@@ -13,7 +14,6 @@ from khoj.database.models import Agent, ChatModel, KhojUser
|
|||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.offline.utils import download_model
|
from khoj.processor.conversation.offline.utils import download_model
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
ThreadedGenerator,
|
|
||||||
clean_json,
|
clean_json,
|
||||||
commit_conversation_trace,
|
commit_conversation_trace,
|
||||||
generate_chatml_messages_with_context,
|
generate_chatml_messages_with_context,
|
||||||
@@ -147,7 +147,7 @@ def filter_questions(questions: List[str]):
|
|||||||
return list(filtered_questions)
|
return list(filtered_questions)
|
||||||
|
|
||||||
|
|
||||||
def converse_offline(
|
async def converse_offline(
|
||||||
user_query,
|
user_query,
|
||||||
references=[],
|
references=[],
|
||||||
online_results={},
|
online_results={},
|
||||||
@@ -167,9 +167,9 @@ def converse_offline(
|
|||||||
additional_context: List[str] = None,
|
additional_context: List[str] = None,
|
||||||
generated_asset_results: Dict[str, Dict] = {},
|
generated_asset_results: Dict[str, Dict] = {},
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> Union[ThreadedGenerator, Iterator[str]]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""
|
"""
|
||||||
Converse with user using Llama
|
Converse with user using Llama (Async Version)
|
||||||
"""
|
"""
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||||
@@ -200,10 +200,17 @@ def converse_offline(
|
|||||||
|
|
||||||
# Get Conversation Primer appropriate to Conversation Type
|
# Get Conversation Primer appropriate to Conversation Type
|
||||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||||
return iter([prompts.no_notes_found.format()])
|
response = prompts.no_notes_found.format()
|
||||||
|
if completion_func:
|
||||||
|
await completion_func(chat_response=response)
|
||||||
|
yield response
|
||||||
|
return
|
||||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
response = prompts.no_online_results_found.format()
|
||||||
return iter([prompts.no_online_results_found.format()])
|
if completion_func:
|
||||||
|
await completion_func(chat_response=response)
|
||||||
|
yield response
|
||||||
|
return
|
||||||
|
|
||||||
context_message = ""
|
context_message = ""
|
||||||
if not is_none_or_empty(references):
|
if not is_none_or_empty(references):
|
||||||
@@ -240,33 +247,77 @@ def converse_offline(
|
|||||||
|
|
||||||
logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")
|
logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")
|
||||||
|
|
||||||
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
|
# Use asyncio.Queue and a thread to bridge sync iterator
|
||||||
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
|
queue: asyncio.Queue = asyncio.Queue()
|
||||||
t.start()
|
|
||||||
return g
|
|
||||||
|
|
||||||
|
|
||||||
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
|
|
||||||
stop_phrases = ["<s>", "INST]", "Notes:"]
|
stop_phrases = ["<s>", "INST]", "Notes:"]
|
||||||
aggregated_response = ""
|
aggregated_response_container = {"response": ""}
|
||||||
|
|
||||||
|
def _sync_llm_thread():
|
||||||
|
"""Synchronous function to run in a separate thread."""
|
||||||
|
aggregated_response = ""
|
||||||
|
start_time = perf_counter()
|
||||||
state.chat_lock.acquire()
|
state.chat_lock.acquire()
|
||||||
try:
|
try:
|
||||||
response_iterator = send_message_to_model_offline(
|
response_iterator = send_message_to_model_offline(
|
||||||
messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
|
messages,
|
||||||
|
loaded_model=offline_chat_model,
|
||||||
|
stop=stop_phrases,
|
||||||
|
max_prompt_size=max_prompt_size,
|
||||||
|
streaming=True,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
for response in response_iterator:
|
for response in response_iterator:
|
||||||
response_delta = response["choices"][0]["delta"].get("content", "")
|
response_delta = response["choices"][0]["delta"].get("content", "")
|
||||||
|
# Log the time taken to start response
|
||||||
|
if aggregated_response == "" and response_delta != "":
|
||||||
|
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||||
|
# Handle response chunk
|
||||||
aggregated_response += response_delta
|
aggregated_response += response_delta
|
||||||
g.send(response_delta)
|
# Put chunk into the asyncio queue (non-blocking)
|
||||||
|
try:
|
||||||
|
queue.put_nowait(response_delta)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
# Should not happen with default queue size unless consumer is very slow
|
||||||
|
logger.warning("Asyncio queue full during offline LLM streaming.")
|
||||||
|
# Potentially block here or handle differently if needed
|
||||||
|
asyncio.run(queue.put(response_delta))
|
||||||
|
|
||||||
|
# Log the time taken to stream the entire response
|
||||||
|
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
|
tracer["chat_model"] = model_name
|
||||||
if is_promptrace_enabled():
|
if is_promptrace_enabled():
|
||||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in offline LLM thread: {e}", exc_info=True)
|
||||||
finally:
|
finally:
|
||||||
state.chat_lock.release()
|
state.chat_lock.release()
|
||||||
g.close()
|
# Signal end of stream
|
||||||
|
queue.put_nowait(None)
|
||||||
|
aggregated_response_container["response"] = aggregated_response
|
||||||
|
|
||||||
|
# Start the synchronous thread
|
||||||
|
thread = Thread(target=_sync_llm_thread)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
# Asynchronously consume from the queue
|
||||||
|
while True:
|
||||||
|
chunk = await queue.get()
|
||||||
|
if chunk is None: # End of stream signal
|
||||||
|
queue.task_done()
|
||||||
|
break
|
||||||
|
yield chunk
|
||||||
|
queue.task_done()
|
||||||
|
|
||||||
|
# Wait for the thread to finish (optional, ensures cleanup)
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
await loop.run_in_executor(None, thread.join)
|
||||||
|
|
||||||
|
# Call the completion function after streaming is done
|
||||||
|
if completion_func:
|
||||||
|
await completion_func(chat_response=aggregated_response_container["response"])
|
||||||
|
|
||||||
|
|
||||||
def send_message_to_model_offline(
|
def send_message_to_model_offline(
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict, List, Optional
|
from typing import AsyncGenerator, Dict, List, Optional
|
||||||
|
|
||||||
import pyjson5
|
import pyjson5
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
@@ -162,7 +162,7 @@ def send_message_to_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def converse_openai(
|
async def converse_openai(
|
||||||
references,
|
references,
|
||||||
user_query,
|
user_query,
|
||||||
online_results: Optional[Dict[str, Dict]] = None,
|
online_results: Optional[Dict[str, Dict]] = None,
|
||||||
@@ -187,7 +187,7 @@ def converse_openai(
|
|||||||
program_execution_context: List[str] = None,
|
program_execution_context: List[str] = None,
|
||||||
deepthought: Optional[bool] = False,
|
deepthought: Optional[bool] = False,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
) -> AsyncGenerator[str, None]:
|
||||||
"""
|
"""
|
||||||
Converse with user using OpenAI's ChatGPT
|
Converse with user using OpenAI's ChatGPT
|
||||||
"""
|
"""
|
||||||
@@ -217,11 +217,17 @@ def converse_openai(
|
|||||||
|
|
||||||
# Get Conversation Primer appropriate to Conversation Type
|
# Get Conversation Primer appropriate to Conversation Type
|
||||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||||
completion_func(chat_response=prompts.no_notes_found.format())
|
response = prompts.no_notes_found.format()
|
||||||
return iter([prompts.no_notes_found.format()])
|
if completion_func:
|
||||||
|
await completion_func(chat_response=response)
|
||||||
|
yield response
|
||||||
|
return
|
||||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
response = prompts.no_online_results_found.format()
|
||||||
return iter([prompts.no_online_results_found.format()])
|
if completion_func:
|
||||||
|
await completion_func(chat_response=response)
|
||||||
|
yield response
|
||||||
|
return
|
||||||
|
|
||||||
context_message = ""
|
context_message = ""
|
||||||
if not is_none_or_empty(references):
|
if not is_none_or_empty(references):
|
||||||
@@ -255,19 +261,23 @@ def converse_openai(
|
|||||||
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
|
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
|
||||||
|
|
||||||
# Get Response from GPT
|
# Get Response from GPT
|
||||||
return chat_completion_with_backoff(
|
full_response = ""
|
||||||
|
async for chunk in chat_completion_with_backoff(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
compiled_references=references,
|
|
||||||
online_results=online_results,
|
|
||||||
model_name=model,
|
model_name=model,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
openai_api_key=api_key,
|
openai_api_key=api_key,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
completion_func=completion_func,
|
|
||||||
deepthought=deepthought,
|
deepthought=deepthought,
|
||||||
model_kwargs={"stop": ["Notes:\n["]},
|
model_kwargs={"stop": ["Notes:\n["]},
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
):
|
||||||
|
full_response += chunk
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
# Call completion_func once finish streaming and we have the full response
|
||||||
|
if completion_func:
|
||||||
|
await completion_func(chat_response=full_response)
|
||||||
|
|
||||||
|
|
||||||
def clean_response_schema(schema: BaseModel | dict) -> dict:
|
def clean_response_schema(schema: BaseModel | dict) -> dict:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from threading import Thread
|
from time import perf_counter
|
||||||
from typing import Dict, List
|
from typing import AsyncGenerator, Dict, List
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
@@ -16,13 +16,10 @@ from tenacity import (
|
|||||||
wait_random_exponential,
|
wait_random_exponential,
|
||||||
)
|
)
|
||||||
|
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import JsonSupport, commit_conversation_trace
|
||||||
JsonSupport,
|
|
||||||
ThreadedGenerator,
|
|
||||||
commit_conversation_trace,
|
|
||||||
)
|
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
get_chat_usage_metrics,
|
get_chat_usage_metrics,
|
||||||
|
get_openai_async_client,
|
||||||
get_openai_client,
|
get_openai_client,
|
||||||
is_promptrace_enabled,
|
is_promptrace_enabled,
|
||||||
)
|
)
|
||||||
@@ -30,6 +27,7 @@ from khoj.utils.helpers import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
openai_clients: Dict[str, openai.OpenAI] = {}
|
openai_clients: Dict[str, openai.OpenAI] = {}
|
||||||
|
openai_async_clients: Dict[str, openai.AsyncOpenAI] = {}
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
@@ -124,45 +122,22 @@ def completion_with_backoff(
|
|||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def chat_completion_with_backoff(
|
async def chat_completion_with_backoff(
|
||||||
messages,
|
messages,
|
||||||
compiled_references,
|
|
||||||
online_results,
|
|
||||||
model_name,
|
model_name,
|
||||||
temperature,
|
temperature,
|
||||||
openai_api_key=None,
|
openai_api_key=None,
|
||||||
api_base_url=None,
|
api_base_url=None,
|
||||||
completion_func=None,
|
|
||||||
deepthought=False,
|
|
||||||
model_kwargs=None,
|
|
||||||
tracer: dict = {},
|
|
||||||
):
|
|
||||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
|
||||||
t = Thread(
|
|
||||||
target=llm_thread,
|
|
||||||
args=(g, messages, model_name, temperature, openai_api_key, api_base_url, deepthought, model_kwargs, tracer),
|
|
||||||
)
|
|
||||||
t.start()
|
|
||||||
return g
|
|
||||||
|
|
||||||
|
|
||||||
def llm_thread(
|
|
||||||
g,
|
|
||||||
messages,
|
|
||||||
model_name: str,
|
|
||||||
temperature,
|
|
||||||
openai_api_key=None,
|
|
||||||
api_base_url=None,
|
|
||||||
deepthought=False,
|
deepthought=False,
|
||||||
model_kwargs: dict = {},
|
model_kwargs: dict = {},
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
) -> AsyncGenerator[str, None]:
|
||||||
try:
|
try:
|
||||||
client_key = f"{openai_api_key}--{api_base_url}"
|
client_key = f"{openai_api_key}--{api_base_url}"
|
||||||
client = openai_clients.get(client_key)
|
client = openai_async_clients.get(client_key)
|
||||||
if not client:
|
if not client:
|
||||||
client = get_openai_client(openai_api_key, api_base_url)
|
client = get_openai_async_client(openai_api_key, api_base_url)
|
||||||
openai_clients[client_key] = client
|
openai_async_clients[client_key] = client
|
||||||
|
|
||||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||||
|
|
||||||
@@ -207,53 +182,58 @@ def llm_thread(
|
|||||||
if os.getenv("KHOJ_LLM_SEED"):
|
if os.getenv("KHOJ_LLM_SEED"):
|
||||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||||
|
|
||||||
chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
|
aggregated_response = ""
|
||||||
messages=formatted_messages,
|
final_chunk = None
|
||||||
model=model_name, # type: ignore
|
start_time = perf_counter()
|
||||||
|
chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
|
||||||
|
messages=formatted_messages, # type: ignore
|
||||||
|
model=model_name,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
timeout=20,
|
timeout=20,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
async for chunk in chat_stream:
|
||||||
aggregated_response = ""
|
# Log the time taken to start response
|
||||||
if not stream:
|
if final_chunk is None:
|
||||||
chunk = chat
|
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||||
aggregated_response = chunk.choices[0].message.content
|
# Keep track of the last chunk for usage data
|
||||||
g.send(aggregated_response)
|
final_chunk = chunk
|
||||||
else:
|
# Handle streamed response chunk
|
||||||
for chunk in chat:
|
|
||||||
if len(chunk.choices) == 0:
|
if len(chunk.choices) == 0:
|
||||||
continue
|
continue
|
||||||
delta_chunk = chunk.choices[0].delta
|
delta_chunk = chunk.choices[0].delta
|
||||||
text_chunk = ""
|
text_chunk = ""
|
||||||
if isinstance(delta_chunk, str):
|
if isinstance(delta_chunk, str):
|
||||||
text_chunk = delta_chunk
|
text_chunk = delta_chunk
|
||||||
elif delta_chunk.content:
|
elif delta_chunk and delta_chunk.content:
|
||||||
text_chunk = delta_chunk.content
|
text_chunk = delta_chunk.content
|
||||||
if text_chunk:
|
if text_chunk:
|
||||||
aggregated_response += text_chunk
|
aggregated_response += text_chunk
|
||||||
g.send(text_chunk)
|
yield text_chunk
|
||||||
|
|
||||||
# Calculate cost of chat
|
# Log the time taken to stream the entire response
|
||||||
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||||
output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
|
||||||
cost = (
|
# Calculate cost of chat after stream finishes
|
||||||
chunk.usage.model_extra.get("estimated_cost", 0) if hasattr(chunk, "usage") and chunk.usage else 0
|
input_tokens, output_tokens, cost = 0, 0, 0
|
||||||
) # Estimated costs returned by DeepInfra API
|
if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage:
|
||||||
tracer["usage"] = get_chat_usage_metrics(
|
input_tokens = final_chunk.usage.prompt_tokens
|
||||||
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
|
output_tokens = final_chunk.usage.completion_tokens
|
||||||
)
|
# Estimated costs returned by DeepInfra API
|
||||||
|
if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra:
|
||||||
|
cost = final_chunk.usage.model_extra.get("estimated_cost", 0)
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
tracer["chat_model"] = model_name
|
tracer["chat_model"] = model_name
|
||||||
tracer["temperature"] = temperature
|
tracer["temperature"] = temperature
|
||||||
|
tracer["usage"] = get_chat_usage_metrics(
|
||||||
|
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
|
||||||
|
)
|
||||||
if is_promptrace_enabled():
|
if is_promptrace_enabled():
|
||||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in llm_thread: {e}", exc_info=True)
|
logger.error(f"Error in chat_completion_with_backoff stream: {e}", exc_info=True)
|
||||||
finally:
|
|
||||||
g.close()
|
|
||||||
|
|
||||||
|
|
||||||
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
|
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
|
||||||
|
|||||||
@@ -77,42 +77,6 @@ model_to_prompt_size = {
|
|||||||
model_to_tokenizer: Dict[str, str] = {}
|
model_to_tokenizer: Dict[str, str] = {}
|
||||||
|
|
||||||
|
|
||||||
class ThreadedGenerator:
|
|
||||||
def __init__(self, compiled_references, online_results, completion_func=None):
|
|
||||||
self.queue = queue.Queue()
|
|
||||||
self.compiled_references = compiled_references
|
|
||||||
self.online_results = online_results
|
|
||||||
self.completion_func = completion_func
|
|
||||||
self.response = ""
|
|
||||||
self.start_time = perf_counter()
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __next__(self):
|
|
||||||
item = self.queue.get()
|
|
||||||
if item is StopIteration:
|
|
||||||
time_to_response = perf_counter() - self.start_time
|
|
||||||
logger.info(f"Chat streaming took: {time_to_response:.3f} seconds")
|
|
||||||
if self.completion_func:
|
|
||||||
# The completion func effectively acts as a callback.
|
|
||||||
# It adds the aggregated response to the conversation history.
|
|
||||||
self.completion_func(chat_response=self.response)
|
|
||||||
raise StopIteration
|
|
||||||
return item
|
|
||||||
|
|
||||||
def send(self, data):
|
|
||||||
if self.response == "":
|
|
||||||
time_to_first_response = perf_counter() - self.start_time
|
|
||||||
logger.info(f"First response took: {time_to_first_response:.3f} seconds")
|
|
||||||
|
|
||||||
self.response += data
|
|
||||||
self.queue.put(data)
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
self.queue.put(StopIteration)
|
|
||||||
|
|
||||||
|
|
||||||
class InformationCollectionIteration:
|
class InformationCollectionIteration:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -254,7 +218,7 @@ def message_to_log(
|
|||||||
return conversation_log
|
return conversation_log
|
||||||
|
|
||||||
|
|
||||||
def save_to_conversation_log(
|
async def save_to_conversation_log(
|
||||||
q: str,
|
q: str,
|
||||||
chat_response: str,
|
chat_response: str,
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
@@ -306,7 +270,7 @@ def save_to_conversation_log(
|
|||||||
khoj_message_metadata=khoj_message_metadata,
|
khoj_message_metadata=khoj_message_metadata,
|
||||||
conversation_log=meta_log.get("chat", []),
|
conversation_log=meta_log.get("chat", []),
|
||||||
)
|
)
|
||||||
ConversationAdapters.save_conversation(
|
await ConversationAdapters.save_conversation(
|
||||||
user,
|
user,
|
||||||
{"chat": updated_conversation},
|
{"chat": updated_conversation},
|
||||||
client_application=client_application,
|
client_application=client_application,
|
||||||
|
|||||||
@@ -67,7 +67,6 @@ from khoj.routers.research import (
|
|||||||
from khoj.routers.storage import upload_user_image_to_bucket
|
from khoj.routers.storage import upload_user_image_to_bucket
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
AsyncIteratorWrapper,
|
|
||||||
ConversationCommand,
|
ConversationCommand,
|
||||||
command_descriptions,
|
command_descriptions,
|
||||||
convert_image_to_webp,
|
convert_image_to_webp,
|
||||||
@@ -999,7 +998,7 @@ async def chat(
|
|||||||
return
|
return
|
||||||
|
|
||||||
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
|
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
|
||||||
await sync_to_async(save_to_conversation_log)(
|
await save_to_conversation_log(
|
||||||
q,
|
q,
|
||||||
llm_response,
|
llm_response,
|
||||||
user,
|
user,
|
||||||
@@ -1308,26 +1307,31 @@ async def chat(
|
|||||||
yield result
|
yield result
|
||||||
|
|
||||||
continue_stream = True
|
continue_stream = True
|
||||||
iterator = AsyncIteratorWrapper(llm_response)
|
async for item in llm_response:
|
||||||
async for item in iterator:
|
# Should not happen with async generator, end is signaled by loop exit. Skip.
|
||||||
if item is None:
|
if item is None:
|
||||||
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
continue
|
||||||
yield result
|
|
||||||
# Send Usage Metadata once llm interactions are complete
|
|
||||||
async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
|
|
||||||
yield event
|
|
||||||
async for result in send_event(ChatEvent.END_RESPONSE, ""):
|
|
||||||
yield result
|
|
||||||
logger.debug("Finished streaming response")
|
|
||||||
return
|
|
||||||
if not connection_alive or not continue_stream:
|
if not connection_alive or not continue_stream:
|
||||||
|
# Drain the generator if disconnected but keep processing internally
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
|
async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
|
||||||
yield result
|
yield result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
continue_stream = False
|
continue_stream = False
|
||||||
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
|
logger.info(f"User {user} disconnected or error during streaming. Stopping send: {e}")
|
||||||
|
|
||||||
|
# Signal end of LLM response after the loop finishes
|
||||||
|
if connection_alive:
|
||||||
|
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
||||||
|
yield result
|
||||||
|
# Send Usage Metadata once llm interactions are complete
|
||||||
|
if tracer.get("usage"):
|
||||||
|
async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
|
||||||
|
yield event
|
||||||
|
async for result in send_event(ChatEvent.END_RESPONSE, ""):
|
||||||
|
yield result
|
||||||
|
logger.debug("Finished streaming response")
|
||||||
|
|
||||||
## Stream Text Response
|
## Stream Text Response
|
||||||
if stream:
|
if stream:
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
import base64
|
import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
@@ -6,9 +5,7 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from enum import Enum
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from random import random
|
from random import random
|
||||||
from typing import (
|
from typing import (
|
||||||
@@ -17,7 +14,6 @@ from typing import (
|
|||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Iterator,
|
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
@@ -97,7 +93,6 @@ from khoj.processor.conversation.openai.gpt import (
|
|||||||
)
|
)
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
ChatEvent,
|
ChatEvent,
|
||||||
ThreadedGenerator,
|
|
||||||
clean_json,
|
clean_json,
|
||||||
clean_mermaidjs,
|
clean_mermaidjs,
|
||||||
construct_chat_history,
|
construct_chat_history,
|
||||||
@@ -126,8 +121,6 @@ from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, Loca
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
executor = ThreadPoolExecutor(max_workers=1)
|
|
||||||
|
|
||||||
|
|
||||||
NOTION_OAUTH_CLIENT_ID = os.getenv("NOTION_OAUTH_CLIENT_ID")
|
NOTION_OAUTH_CLIENT_ID = os.getenv("NOTION_OAUTH_CLIENT_ID")
|
||||||
NOTION_OAUTH_CLIENT_SECRET = os.getenv("NOTION_OAUTH_CLIENT_SECRET")
|
NOTION_OAUTH_CLIENT_SECRET = os.getenv("NOTION_OAUTH_CLIENT_SECRET")
|
||||||
@@ -262,11 +255,6 @@ def get_conversation_command(query: str) -> ConversationCommand:
|
|||||||
return ConversationCommand.Default
|
return ConversationCommand.Default
|
||||||
|
|
||||||
|
|
||||||
async def agenerate_chat_response(*args):
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
return await loop.run_in_executor(executor, generate_chat_response, *args)
|
|
||||||
|
|
||||||
|
|
||||||
def gather_raw_query_files(
|
def gather_raw_query_files(
|
||||||
query_files: Dict[str, str],
|
query_files: Dict[str, str],
|
||||||
):
|
):
|
||||||
@@ -1418,7 +1406,7 @@ def send_message_to_model_wrapper_sync(
|
|||||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||||
|
|
||||||
|
|
||||||
def generate_chat_response(
|
async def agenerate_chat_response(
|
||||||
q: str,
|
q: str,
|
||||||
meta_log: dict,
|
meta_log: dict,
|
||||||
conversation: Conversation,
|
conversation: Conversation,
|
||||||
@@ -1444,13 +1432,14 @@ def generate_chat_response(
|
|||||||
generated_asset_results: Dict[str, Dict] = {},
|
generated_asset_results: Dict[str, Dict] = {},
|
||||||
is_subscribed: bool = False,
|
is_subscribed: bool = False,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
) -> Tuple[AsyncGenerator[str, None], Dict[str, str]]:
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
chat_response = None
|
chat_response_generator = None
|
||||||
logger.debug(f"Conversation Types: {conversation_commands}")
|
logger.debug(f"Conversation Types: {conversation_commands}")
|
||||||
|
|
||||||
metadata = {}
|
metadata = {}
|
||||||
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
|
agent = await AgentAdapters.aget_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
partial_completion = partial(
|
partial_completion = partial(
|
||||||
save_to_conversation_log,
|
save_to_conversation_log,
|
||||||
@@ -1481,17 +1470,17 @@ def generate_chat_response(
|
|||||||
code_results = {}
|
code_results = {}
|
||||||
deepthought = True
|
deepthought = True
|
||||||
|
|
||||||
chat_model = ConversationAdapters.get_valid_chat_model(user, conversation, is_subscribed)
|
chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed)
|
||||||
vision_available = chat_model.vision_enabled
|
vision_available = chat_model.vision_enabled
|
||||||
if not vision_available and query_images:
|
if not vision_available and query_images:
|
||||||
vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
|
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
|
||||||
if vision_enabled_config:
|
if vision_enabled_config:
|
||||||
chat_model = vision_enabled_config
|
chat_model = vision_enabled_config
|
||||||
vision_available = True
|
vision_available = True
|
||||||
|
|
||||||
if chat_model.model_type == "offline":
|
if chat_model.model_type == "offline":
|
||||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||||
chat_response = converse_offline(
|
chat_response_generator = converse_offline(
|
||||||
user_query=query_to_run,
|
user_query=query_to_run,
|
||||||
references=compiled_references,
|
references=compiled_references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
@@ -1515,7 +1504,7 @@ def generate_chat_response(
|
|||||||
openai_chat_config = chat_model.ai_model_api
|
openai_chat_config = chat_model.ai_model_api
|
||||||
api_key = openai_chat_config.api_key
|
api_key = openai_chat_config.api_key
|
||||||
chat_model_name = chat_model.name
|
chat_model_name = chat_model.name
|
||||||
chat_response = converse_openai(
|
chat_response_generator = converse_openai(
|
||||||
compiled_references,
|
compiled_references,
|
||||||
query_to_run,
|
query_to_run,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
@@ -1544,7 +1533,7 @@ def generate_chat_response(
|
|||||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||||
api_key = chat_model.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
api_base_url = chat_model.ai_model_api.api_base_url
|
api_base_url = chat_model.ai_model_api.api_base_url
|
||||||
chat_response = converse_anthropic(
|
chat_response_generator = converse_anthropic(
|
||||||
compiled_references,
|
compiled_references,
|
||||||
query_to_run,
|
query_to_run,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
@@ -1572,7 +1561,7 @@ def generate_chat_response(
|
|||||||
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
||||||
api_key = chat_model.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
api_base_url = chat_model.ai_model_api.api_base_url
|
api_base_url = chat_model.ai_model_api.api_base_url
|
||||||
chat_response = converse_gemini(
|
chat_response_generator = converse_gemini(
|
||||||
compiled_references,
|
compiled_references,
|
||||||
query_to_run,
|
query_to_run,
|
||||||
online_results,
|
online_results,
|
||||||
@@ -1604,7 +1593,8 @@ def generate_chat_response(
|
|||||||
logger.error(e, exc_info=True)
|
logger.error(e, exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
return chat_response, metadata
|
# Return the generator directly
|
||||||
|
return chat_response_generator, metadata
|
||||||
|
|
||||||
|
|
||||||
class DeleteMessageRequestBody(BaseModel):
|
class DeleteMessageRequestBody(BaseModel):
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from time import perf_counter
|
|||||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
|
||||||
from urllib.parse import ParseResult, urlparse
|
from urllib.parse import ParseResult, urlparse
|
||||||
|
|
||||||
|
import anthropic
|
||||||
import openai
|
import openai
|
||||||
import psutil
|
import psutil
|
||||||
import pyjson5
|
import pyjson5
|
||||||
@@ -30,6 +31,7 @@ import requests
|
|||||||
import torch
|
import torch
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
|
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
|
||||||
|
from google import genai
|
||||||
from google.auth.credentials import Credentials
|
from google.auth.credentials import Credentials
|
||||||
from google.oauth2 import service_account
|
from google.oauth2 import service_account
|
||||||
from magika import Magika
|
from magika import Magika
|
||||||
@@ -729,6 +731,60 @@ def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, o
|
|||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def get_openai_async_client(api_key: str, api_base_url: str) -> Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]:
|
||||||
|
"""Get OpenAI or AzureOpenAI client based on the API Base URL"""
|
||||||
|
parsed_url = urlparse(api_base_url)
|
||||||
|
if parsed_url.hostname and parsed_url.hostname.endswith(".openai.azure.com"):
|
||||||
|
client = openai.AsyncAzureOpenAI(
|
||||||
|
api_key=api_key,
|
||||||
|
azure_endpoint=api_base_url,
|
||||||
|
api_version="2024-10-21",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
client = openai.AsyncOpenAI(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=api_base_url,
|
||||||
|
)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
|
||||||
|
api_info = get_ai_api_info(api_key, api_base_url)
|
||||||
|
if api_info.api_key:
|
||||||
|
client = anthropic.Anthropic(api_key=api_info.api_key)
|
||||||
|
else:
|
||||||
|
client = anthropic.AnthropicVertex(
|
||||||
|
region=api_info.region,
|
||||||
|
project_id=api_info.project,
|
||||||
|
credentials=api_info.credentials,
|
||||||
|
)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def get_anthropic_async_client(api_key, api_base_url=None) -> anthropic.AsyncAnthropic | anthropic.AsyncAnthropicVertex:
|
||||||
|
api_info = get_ai_api_info(api_key, api_base_url)
|
||||||
|
if api_info.api_key:
|
||||||
|
client = anthropic.AsyncAnthropic(api_key=api_info.api_key)
|
||||||
|
else:
|
||||||
|
client = anthropic.AsyncAnthropicVertex(
|
||||||
|
region=api_info.region,
|
||||||
|
project_id=api_info.project,
|
||||||
|
credentials=api_info.credentials,
|
||||||
|
)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
|
||||||
|
api_info = get_ai_api_info(api_key, api_base_url)
|
||||||
|
return genai.Client(
|
||||||
|
location=api_info.region,
|
||||||
|
project=api_info.project,
|
||||||
|
credentials=api_info.credentials,
|
||||||
|
api_key=api_info.api_key,
|
||||||
|
vertexai=api_info.api_key is None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def normalize_email(email: str, check_deliverability=False) -> tuple[str, bool]:
|
def normalize_email(email: str, check_deliverability=False) -> tuple[str, bool]:
|
||||||
"""Normalize, validate and check deliverability of email address"""
|
"""Normalize, validate and check deliverability of email address"""
|
||||||
lower_email = email.lower()
|
lower_email = email.lower()
|
||||||
|
|||||||
Reference in New Issue
Block a user