mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Simplify AI Chat Response Streaming (#1167)
Reason
---
- Simplify the code and logic for streaming chat responses by relying solely on the asyncio event loop.
- Reduce the overhead of managing threads to increase efficiency and throughput (where possible).

Details
---
- Use async/await with no threading when generating chat responses via the OpenAI, Gemini, and Anthropic AI model APIs.
- Use threading for the offline chat model, as llama-cpp doesn't support async streaming yet.
This commit is contained in:
@@ -763,9 +763,9 @@ class AgentAdapters:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_conversation_agent_by_id(agent_id: int):
|
||||
agent = Agent.objects.filter(id=agent_id).first()
|
||||
if agent == AgentAdapters.get_default_agent():
|
||||
async def aget_conversation_agent_by_id(agent_id: int):
|
||||
agent = await Agent.objects.filter(id=agent_id).afirst()
|
||||
if agent == await AgentAdapters.aget_default_agent():
|
||||
# If the agent is set to the default agent, then return None and let the default application code be used
|
||||
return None
|
||||
return agent
|
||||
@@ -1109,14 +1109,6 @@ class ConversationAdapters:
|
||||
async def aget_all_chat_models():
|
||||
return await sync_to_async(list)(ChatModel.objects.prefetch_related("ai_model_api").all())
|
||||
|
||||
@staticmethod
|
||||
def get_vision_enabled_config():
|
||||
chat_models = ConversationAdapters.get_all_chat_models()
|
||||
for config in chat_models:
|
||||
if config.vision_enabled:
|
||||
return config
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def aget_vision_enabled_config():
|
||||
chat_models = await ConversationAdapters.aget_all_chat_models()
|
||||
@@ -1171,7 +1163,11 @@ class ConversationAdapters:
|
||||
@staticmethod
|
||||
async def aget_chat_model(user: KhojUser):
|
||||
subscribed = await ais_user_subscribed(user)
|
||||
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
||||
config = (
|
||||
await UserConversationConfig.objects.filter(user=user)
|
||||
.prefetch_related("setting", "setting__ai_model_api")
|
||||
.afirst()
|
||||
)
|
||||
if subscribed:
|
||||
# Subscribed users can use any available chat model
|
||||
if config:
|
||||
@@ -1387,7 +1383,7 @@ class ConversationAdapters:
|
||||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def save_conversation(
|
||||
async def save_conversation(
|
||||
user: KhojUser,
|
||||
conversation_log: dict,
|
||||
client_application: ClientApplication = None,
|
||||
@@ -1396,19 +1392,21 @@ class ConversationAdapters:
|
||||
):
|
||||
slug = user_message.strip()[:200] if user_message else None
|
||||
if conversation_id:
|
||||
conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first()
|
||||
conversation = await Conversation.objects.filter(
|
||||
user=user, client=client_application, id=conversation_id
|
||||
).afirst()
|
||||
else:
|
||||
conversation = (
|
||||
Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").first()
|
||||
await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
|
||||
)
|
||||
|
||||
if conversation:
|
||||
conversation.conversation_log = conversation_log
|
||||
conversation.slug = slug
|
||||
conversation.updated_at = datetime.now(tz=timezone.utc)
|
||||
conversation.save()
|
||||
await conversation.asave()
|
||||
else:
|
||||
Conversation.objects.create(
|
||||
await Conversation.objects.acreate(
|
||||
user=user, conversation_log=conversation_log, client=client_application, slug=slug
|
||||
)
|
||||
|
||||
@@ -1455,17 +1453,21 @@ class ConversationAdapters:
|
||||
return random.sample(all_questions, max_results)
|
||||
|
||||
@staticmethod
|
||||
def get_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool):
|
||||
async def aget_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool):
|
||||
agent: Agent = (
|
||||
conversation.agent if is_subscribed and AgentAdapters.get_default_agent() != conversation.agent else None
|
||||
conversation.agent
|
||||
if is_subscribed and await AgentAdapters.aget_default_agent() != conversation.agent
|
||||
else None
|
||||
)
|
||||
if agent and agent.chat_model:
|
||||
chat_model = conversation.agent.chat_model
|
||||
chat_model = await ChatModel.objects.select_related("ai_model_api").aget(
|
||||
pk=conversation.agent.chat_model.pk
|
||||
)
|
||||
else:
|
||||
chat_model = ConversationAdapters.get_chat_model(user)
|
||||
chat_model = await ConversationAdapters.aget_chat_model(user)
|
||||
|
||||
if chat_model is None:
|
||||
chat_model = ConversationAdapters.get_default_chat_model()
|
||||
chat_model = await ConversationAdapters.aget_default_chat_model()
|
||||
|
||||
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
|
||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional
|
||||
from typing import AsyncGenerator, Dict, List, Optional
|
||||
|
||||
import pyjson5
|
||||
from langchain.schema import ChatMessage
|
||||
@@ -137,7 +137,7 @@ def anthropic_send_message_to_model(
|
||||
)
|
||||
|
||||
|
||||
def converse_anthropic(
|
||||
async def converse_anthropic(
|
||||
references,
|
||||
user_query,
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
@@ -161,7 +161,7 @@ def converse_anthropic(
|
||||
generated_asset_results: Dict[str, Dict] = {},
|
||||
deepthought: Optional[bool] = False,
|
||||
tracer: dict = {},
|
||||
):
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Converse with user using Anthropic's Claude
|
||||
"""
|
||||
@@ -191,11 +191,17 @@ def converse_anthropic(
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||
completion_func(chat_response=prompts.no_notes_found.format())
|
||||
return iter([prompts.no_notes_found.format()])
|
||||
response = prompts.no_notes_found.format()
|
||||
if completion_func:
|
||||
await completion_func(chat_response=response)
|
||||
yield response
|
||||
return
|
||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
||||
return iter([prompts.no_online_results_found.format()])
|
||||
response = prompts.no_online_results_found.format()
|
||||
if completion_func:
|
||||
await completion_func(chat_response=response)
|
||||
yield response
|
||||
return
|
||||
|
||||
context_message = ""
|
||||
if not is_none_or_empty(references):
|
||||
@@ -228,17 +234,21 @@ def converse_anthropic(
|
||||
logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}")
|
||||
|
||||
# Get Response from Claude
|
||||
return anthropic_chat_completion_with_backoff(
|
||||
full_response = ""
|
||||
async for chunk in anthropic_chat_completion_with_backoff(
|
||||
messages=messages,
|
||||
compiled_references=references,
|
||||
online_results=online_results,
|
||||
model_name=model,
|
||||
temperature=0.2,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
system_prompt=system_prompt,
|
||||
completion_func=completion_func,
|
||||
max_prompt_size=max_prompt_size,
|
||||
deepthought=deepthought,
|
||||
tracer=tracer,
|
||||
)
|
||||
):
|
||||
full_response += chunk
|
||||
yield chunk
|
||||
|
||||
# Call completion_func once streaming has finished and we have the full response
|
||||
if completion_func:
|
||||
await completion_func(chat_response=full_response)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from threading import Thread
|
||||
from time import perf_counter
|
||||
from typing import Dict, List
|
||||
|
||||
import anthropic
|
||||
@@ -13,13 +13,13 @@ from tenacity import (
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
get_image_from_base64,
|
||||
get_image_from_url,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
get_ai_api_info,
|
||||
get_anthropic_async_client,
|
||||
get_anthropic_client,
|
||||
get_chat_usage_metrics,
|
||||
is_none_or_empty,
|
||||
is_promptrace_enabled,
|
||||
@@ -28,24 +28,12 @@ from khoj.utils.helpers import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {}
|
||||
anthropic_async_clients: Dict[str, anthropic.AsyncAnthropic | anthropic.AsyncAnthropicVertex] = {}
|
||||
|
||||
DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
|
||||
MAX_REASONING_TOKENS_ANTHROPIC = 12000
|
||||
|
||||
|
||||
def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
|
||||
api_info = get_ai_api_info(api_key, api_base_url)
|
||||
if api_info.api_key:
|
||||
client = anthropic.Anthropic(api_key=api_info.api_key)
|
||||
else:
|
||||
client = anthropic.AnthropicVertex(
|
||||
region=api_info.region,
|
||||
project_id=api_info.project,
|
||||
credentials=api_info.credentials,
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=10),
|
||||
stop=stop_after_attempt(2),
|
||||
@@ -126,60 +114,23 @@ def anthropic_completion_with_backoff(
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
def anthropic_chat_completion_with_backoff(
|
||||
async def anthropic_chat_completion_with_backoff(
|
||||
messages: list[ChatMessage],
|
||||
compiled_references,
|
||||
online_results,
|
||||
model_name,
|
||||
temperature,
|
||||
api_key,
|
||||
api_base_url,
|
||||
system_prompt: str,
|
||||
max_prompt_size=None,
|
||||
completion_func=None,
|
||||
deepthought=False,
|
||||
model_kwargs=None,
|
||||
tracer={},
|
||||
):
|
||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||
t = Thread(
|
||||
target=anthropic_llm_thread,
|
||||
args=(
|
||||
g,
|
||||
messages,
|
||||
system_prompt,
|
||||
model_name,
|
||||
temperature,
|
||||
api_key,
|
||||
api_base_url,
|
||||
max_prompt_size,
|
||||
deepthought,
|
||||
model_kwargs,
|
||||
tracer,
|
||||
),
|
||||
)
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def anthropic_llm_thread(
|
||||
g,
|
||||
messages: list[ChatMessage],
|
||||
system_prompt: str,
|
||||
model_name: str,
|
||||
temperature,
|
||||
api_key,
|
||||
api_base_url=None,
|
||||
max_prompt_size=None,
|
||||
deepthought=False,
|
||||
model_kwargs=None,
|
||||
tracer={},
|
||||
):
|
||||
try:
|
||||
client = anthropic_clients.get(api_key)
|
||||
client = anthropic_async_clients.get(api_key)
|
||||
if not client:
|
||||
client = get_anthropic_client(api_key, api_base_url)
|
||||
anthropic_clients[api_key] = client
|
||||
client = get_anthropic_async_client(api_key, api_base_url)
|
||||
anthropic_async_clients[api_key] = client
|
||||
|
||||
model_kwargs = model_kwargs or dict()
|
||||
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
|
||||
@@ -193,7 +144,8 @@ def anthropic_llm_thread(
|
||||
|
||||
aggregated_response = ""
|
||||
final_message = None
|
||||
with client.messages.stream(
|
||||
start_time = perf_counter()
|
||||
async with client.messages.stream(
|
||||
messages=formatted_messages,
|
||||
model=model_name, # type: ignore
|
||||
temperature=temperature,
|
||||
@@ -202,10 +154,17 @@ def anthropic_llm_thread(
|
||||
max_tokens=max_tokens,
|
||||
**model_kwargs,
|
||||
) as stream:
|
||||
for text in stream.text_stream:
|
||||
async for text in stream.text_stream:
|
||||
# Log the time taken to start response
|
||||
if aggregated_response == "":
|
||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||
# Handle streamed response chunk
|
||||
aggregated_response += text
|
||||
g.send(text)
|
||||
final_message = stream.get_final_message()
|
||||
yield text
|
||||
final_message = await stream.get_final_message()
|
||||
|
||||
# Log the time taken to stream the entire response
|
||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||
|
||||
# Calculate cost of chat
|
||||
input_tokens = final_message.usage.input_tokens
|
||||
@@ -222,9 +181,7 @@ def anthropic_llm_thread(
|
||||
if is_promptrace_enabled():
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
|
||||
finally:
|
||||
g.close()
|
||||
logger.error(f"Error in anthropic_chat_completion_with_backoff stream: {e}", exc_info=True)
|
||||
|
||||
|
||||
def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional
|
||||
from typing import AsyncGenerator, Dict, List, Optional
|
||||
|
||||
import pyjson5
|
||||
from langchain.schema import ChatMessage
|
||||
@@ -160,7 +160,7 @@ def gemini_send_message_to_model(
|
||||
)
|
||||
|
||||
|
||||
def converse_gemini(
|
||||
async def converse_gemini(
|
||||
references,
|
||||
user_query,
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
@@ -185,7 +185,7 @@ def converse_gemini(
|
||||
program_execution_context: List[str] = None,
|
||||
deepthought: Optional[bool] = False,
|
||||
tracer={},
|
||||
):
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Converse with user using Google's Gemini
|
||||
"""
|
||||
@@ -216,11 +216,17 @@ def converse_gemini(
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||
completion_func(chat_response=prompts.no_notes_found.format())
|
||||
return iter([prompts.no_notes_found.format()])
|
||||
response = prompts.no_notes_found.format()
|
||||
if completion_func:
|
||||
await completion_func(chat_response=response)
|
||||
yield response
|
||||
return
|
||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
||||
return iter([prompts.no_online_results_found.format()])
|
||||
response = prompts.no_online_results_found.format()
|
||||
if completion_func:
|
||||
await completion_func(chat_response=response)
|
||||
yield response
|
||||
return
|
||||
|
||||
context_message = ""
|
||||
if not is_none_or_empty(references):
|
||||
@@ -253,16 +259,20 @@ def converse_gemini(
|
||||
logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
|
||||
|
||||
# Get Response from Google AI
|
||||
return gemini_chat_completion_with_backoff(
|
||||
full_response = ""
|
||||
async for chunk in gemini_chat_completion_with_backoff(
|
||||
messages=messages,
|
||||
compiled_references=references,
|
||||
online_results=online_results,
|
||||
model_name=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
system_prompt=system_prompt,
|
||||
completion_func=completion_func,
|
||||
deepthought=deepthought,
|
||||
tracer=tracer,
|
||||
)
|
||||
):
|
||||
full_response += chunk
|
||||
yield chunk
|
||||
|
||||
# Call completion_func once streaming has finished and we have the full response
|
||||
if completion_func:
|
||||
await completion_func(chat_response=full_response)
|
||||
|
||||
@@ -2,8 +2,8 @@ import logging
|
||||
import os
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from threading import Thread
|
||||
from typing import Dict
|
||||
from time import perf_counter
|
||||
from typing import AsyncGenerator, AsyncIterator, Dict
|
||||
|
||||
from google import genai
|
||||
from google.genai import errors as gerrors
|
||||
@@ -19,14 +19,13 @@ from tenacity import (
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
get_image_from_base64,
|
||||
get_image_from_url,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
get_ai_api_info,
|
||||
get_chat_usage_metrics,
|
||||
get_gemini_client,
|
||||
is_none_or_empty,
|
||||
is_promptrace_enabled,
|
||||
)
|
||||
@@ -62,17 +61,6 @@ SAFETY_SETTINGS = [
|
||||
]
|
||||
|
||||
|
||||
def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
|
||||
api_info = get_ai_api_info(api_key, api_base_url)
|
||||
return genai.Client(
|
||||
location=api_info.region,
|
||||
project=api_info.project,
|
||||
credentials=api_info.credentials,
|
||||
api_key=api_info.api_key,
|
||||
vertexai=api_info.api_key is None,
|
||||
)
|
||||
|
||||
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=10),
|
||||
stop=stop_after_attempt(2),
|
||||
@@ -132,8 +120,8 @@ def gemini_completion_with_backoff(
|
||||
)
|
||||
|
||||
# Aggregate cost of chat
|
||||
input_tokens = response.usage_metadata.prompt_token_count if response else 0
|
||||
output_tokens = response.usage_metadata.candidates_token_count if response else 0
|
||||
input_tokens = response.usage_metadata.prompt_token_count or 0 if response else 0
|
||||
output_tokens = response.usage_metadata.candidates_token_count or 0 if response else 0
|
||||
thought_tokens = response.usage_metadata.thoughts_token_count or 0 if response else 0
|
||||
tracer["usage"] = get_chat_usage_metrics(
|
||||
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
|
||||
@@ -154,52 +142,17 @@ def gemini_completion_with_backoff(
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
def gemini_chat_completion_with_backoff(
|
||||
async def gemini_chat_completion_with_backoff(
|
||||
messages,
|
||||
compiled_references,
|
||||
online_results,
|
||||
model_name,
|
||||
temperature,
|
||||
api_key,
|
||||
api_base_url,
|
||||
system_prompt,
|
||||
completion_func=None,
|
||||
model_kwargs=None,
|
||||
deepthought=False,
|
||||
tracer: dict = {},
|
||||
):
|
||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||
t = Thread(
|
||||
target=gemini_llm_thread,
|
||||
args=(
|
||||
g,
|
||||
messages,
|
||||
system_prompt,
|
||||
model_name,
|
||||
temperature,
|
||||
api_key,
|
||||
api_base_url,
|
||||
model_kwargs,
|
||||
deepthought,
|
||||
tracer,
|
||||
),
|
||||
)
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def gemini_llm_thread(
|
||||
g,
|
||||
messages,
|
||||
system_prompt,
|
||||
model_name,
|
||||
temperature,
|
||||
api_key,
|
||||
api_base_url=None,
|
||||
model_kwargs=None,
|
||||
deepthought=False,
|
||||
tracer: dict = {},
|
||||
):
|
||||
) -> AsyncGenerator[str, None]:
|
||||
try:
|
||||
client = gemini_clients.get(api_key)
|
||||
if not client:
|
||||
@@ -224,21 +177,32 @@ def gemini_llm_thread(
|
||||
)
|
||||
|
||||
aggregated_response = ""
|
||||
|
||||
for chunk in client.models.generate_content_stream(
|
||||
final_chunk = None
|
||||
start_time = perf_counter()
|
||||
chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream(
|
||||
model=model_name, config=config, contents=formatted_messages
|
||||
):
|
||||
)
|
||||
async for chunk in chat_stream:
|
||||
# Log the time taken to start response
|
||||
if final_chunk is None:
|
||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||
# Keep track of the last chunk for usage data
|
||||
final_chunk = chunk
|
||||
# Handle streamed response chunk
|
||||
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
||||
message = message or chunk.text
|
||||
aggregated_response += message
|
||||
g.send(message)
|
||||
yield message
|
||||
if stopped:
|
||||
raise ValueError(message)
|
||||
|
||||
# Log the time taken to stream the entire response
|
||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||
|
||||
# Calculate cost of chat
|
||||
input_tokens = chunk.usage_metadata.prompt_token_count
|
||||
output_tokens = chunk.usage_metadata.candidates_token_count
|
||||
thought_tokens = chunk.usage_metadata.thoughts_token_count or 0
|
||||
input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
|
||||
output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
|
||||
thought_tokens = final_chunk.usage_metadata.thoughts_token_count or 0 if final_chunk else 0
|
||||
tracer["usage"] = get_chat_usage_metrics(
|
||||
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
|
||||
)
|
||||
@@ -254,9 +218,7 @@ def gemini_llm_thread(
|
||||
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in gemini_llm_thread: {e}", exc_info=True)
|
||||
finally:
|
||||
g.close()
|
||||
logger.error(f"Error in gemini_chat_completion_with_backoff stream: {e}", exc_info=True)
|
||||
|
||||
|
||||
def handle_gemini_response(
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, Iterator, List, Optional, Union
|
||||
from time import perf_counter
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
|
||||
import pyjson5
|
||||
from langchain.schema import ChatMessage
|
||||
@@ -13,7 +14,6 @@ from khoj.database.models import Agent, ChatModel, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.offline.utils import download_model
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
clean_json,
|
||||
commit_conversation_trace,
|
||||
generate_chatml_messages_with_context,
|
||||
@@ -147,7 +147,7 @@ def filter_questions(questions: List[str]):
|
||||
return list(filtered_questions)
|
||||
|
||||
|
||||
def converse_offline(
|
||||
async def converse_offline(
|
||||
user_query,
|
||||
references=[],
|
||||
online_results={},
|
||||
@@ -167,9 +167,9 @@ def converse_offline(
|
||||
additional_context: List[str] = None,
|
||||
generated_asset_results: Dict[str, Dict] = {},
|
||||
tracer: dict = {},
|
||||
) -> Union[ThreadedGenerator, Iterator[str]]:
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Converse with user using Llama
|
||||
Converse with user using Llama (Async Version)
|
||||
"""
|
||||
# Initialize Variables
|
||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||
@@ -200,10 +200,17 @@ def converse_offline(
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||
return iter([prompts.no_notes_found.format()])
|
||||
response = prompts.no_notes_found.format()
|
||||
if completion_func:
|
||||
await completion_func(chat_response=response)
|
||||
yield response
|
||||
return
|
||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
||||
return iter([prompts.no_online_results_found.format()])
|
||||
response = prompts.no_online_results_found.format()
|
||||
if completion_func:
|
||||
await completion_func(chat_response=response)
|
||||
yield response
|
||||
return
|
||||
|
||||
context_message = ""
|
||||
if not is_none_or_empty(references):
|
||||
@@ -240,33 +247,77 @@ def converse_offline(
|
||||
|
||||
logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")
|
||||
|
||||
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
|
||||
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
|
||||
# Use asyncio.Queue and a thread to bridge sync iterator
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
stop_phrases = ["<s>", "INST]", "Notes:"]
|
||||
aggregated_response = ""
|
||||
aggregated_response_container = {"response": ""}
|
||||
|
||||
state.chat_lock.acquire()
|
||||
try:
|
||||
response_iterator = send_message_to_model_offline(
|
||||
messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
|
||||
)
|
||||
for response in response_iterator:
|
||||
response_delta = response["choices"][0]["delta"].get("content", "")
|
||||
aggregated_response += response_delta
|
||||
g.send(response_delta)
|
||||
def _sync_llm_thread():
|
||||
"""Synchronous function to run in a separate thread."""
|
||||
aggregated_response = ""
|
||||
start_time = perf_counter()
|
||||
state.chat_lock.acquire()
|
||||
try:
|
||||
response_iterator = send_message_to_model_offline(
|
||||
messages,
|
||||
loaded_model=offline_chat_model,
|
||||
stop=stop_phrases,
|
||||
max_prompt_size=max_prompt_size,
|
||||
streaming=True,
|
||||
tracer=tracer,
|
||||
)
|
||||
for response in response_iterator:
|
||||
response_delta = response["choices"][0]["delta"].get("content", "")
|
||||
# Log the time taken to start response
|
||||
if aggregated_response == "" and response_delta != "":
|
||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||
# Handle response chunk
|
||||
aggregated_response += response_delta
|
||||
# Put chunk into the asyncio queue (non-blocking)
|
||||
try:
|
||||
queue.put_nowait(response_delta)
|
||||
except asyncio.QueueFull:
|
||||
# Cannot happen while the queue is created unbounded (default maxsize=0); kept as a safeguard in case a maxsize is set later
|
||||
logger.warning("Asyncio queue full during offline LLM streaming.")
|
||||
# Potentially block here or handle differently if needed
|
||||
asyncio.run(queue.put(response_delta))
|
||||
|
||||
# Save conversation trace
|
||||
if is_promptrace_enabled():
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
# Log the time taken to stream the entire response
|
||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||
|
||||
finally:
|
||||
state.chat_lock.release()
|
||||
g.close()
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
if is_promptrace_enabled():
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in offline LLM thread: {e}", exc_info=True)
|
||||
finally:
|
||||
state.chat_lock.release()
|
||||
# Signal end of stream
|
||||
queue.put_nowait(None)
|
||||
aggregated_response_container["response"] = aggregated_response
|
||||
|
||||
# Start the synchronous thread
|
||||
thread = Thread(target=_sync_llm_thread)
|
||||
thread.start()
|
||||
|
||||
# Asynchronously consume from the queue
|
||||
while True:
|
||||
chunk = await queue.get()
|
||||
if chunk is None: # End of stream signal
|
||||
queue.task_done()
|
||||
break
|
||||
yield chunk
|
||||
queue.task_done()
|
||||
|
||||
# Wait for the thread to finish (optional, ensures cleanup)
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(None, thread.join)
|
||||
|
||||
# Call the completion function after streaming is done
|
||||
if completion_func:
|
||||
await completion_func(chat_response=aggregated_response_container["response"])
|
||||
|
||||
|
||||
def send_message_to_model_offline(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional
|
||||
from typing import AsyncGenerator, Dict, List, Optional
|
||||
|
||||
import pyjson5
|
||||
from langchain.schema import ChatMessage
|
||||
@@ -162,7 +162,7 @@ def send_message_to_model(
|
||||
)
|
||||
|
||||
|
||||
def converse_openai(
|
||||
async def converse_openai(
|
||||
references,
|
||||
user_query,
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
@@ -187,7 +187,7 @@ def converse_openai(
|
||||
program_execution_context: List[str] = None,
|
||||
deepthought: Optional[bool] = False,
|
||||
tracer: dict = {},
|
||||
):
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Converse with user using OpenAI's ChatGPT
|
||||
"""
|
||||
@@ -217,11 +217,17 @@ def converse_openai(
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||
completion_func(chat_response=prompts.no_notes_found.format())
|
||||
return iter([prompts.no_notes_found.format()])
|
||||
response = prompts.no_notes_found.format()
|
||||
if completion_func:
|
||||
await completion_func(chat_response=response)
|
||||
yield response
|
||||
return
|
||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
||||
return iter([prompts.no_online_results_found.format()])
|
||||
response = prompts.no_online_results_found.format()
|
||||
if completion_func:
|
||||
await completion_func(chat_response=response)
|
||||
yield response
|
||||
return
|
||||
|
||||
context_message = ""
|
||||
if not is_none_or_empty(references):
|
||||
@@ -255,19 +261,23 @@ def converse_openai(
|
||||
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
|
||||
|
||||
# Get Response from GPT
|
||||
return chat_completion_with_backoff(
|
||||
full_response = ""
|
||||
async for chunk in chat_completion_with_backoff(
|
||||
messages=messages,
|
||||
compiled_references=references,
|
||||
online_results=online_results,
|
||||
model_name=model,
|
||||
temperature=temperature,
|
||||
openai_api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
completion_func=completion_func,
|
||||
deepthought=deepthought,
|
||||
model_kwargs={"stop": ["Notes:\n["]},
|
||||
tracer=tracer,
|
||||
)
|
||||
):
|
||||
full_response += chunk
|
||||
yield chunk
|
||||
|
||||
# Call completion_func once streaming has finished and we have the full response
|
||||
if completion_func:
|
||||
await completion_func(chat_response=full_response)
|
||||
|
||||
|
||||
def clean_response_schema(schema: BaseModel | dict) -> dict:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
from threading import Thread
|
||||
from typing import Dict, List
|
||||
from time import perf_counter
|
||||
from typing import AsyncGenerator, Dict, List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import openai
|
||||
@@ -16,13 +16,10 @@ from tenacity import (
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import (
|
||||
JsonSupport,
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
)
|
||||
from khoj.processor.conversation.utils import JsonSupport, commit_conversation_trace
|
||||
from khoj.utils.helpers import (
|
||||
get_chat_usage_metrics,
|
||||
get_openai_async_client,
|
||||
get_openai_client,
|
||||
is_promptrace_enabled,
|
||||
)
|
||||
@@ -30,6 +27,7 @@ from khoj.utils.helpers import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
openai_clients: Dict[str, openai.OpenAI] = {}
|
||||
openai_async_clients: Dict[str, openai.AsyncOpenAI] = {}
|
||||
|
||||
|
||||
@retry(
|
||||
@@ -124,45 +122,22 @@ def completion_with_backoff(
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
def chat_completion_with_backoff(
|
||||
async def chat_completion_with_backoff(
|
||||
messages,
|
||||
compiled_references,
|
||||
online_results,
|
||||
model_name,
|
||||
temperature,
|
||||
openai_api_key=None,
|
||||
api_base_url=None,
|
||||
completion_func=None,
|
||||
deepthought=False,
|
||||
model_kwargs=None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||
t = Thread(
|
||||
target=llm_thread,
|
||||
args=(g, messages, model_name, temperature, openai_api_key, api_base_url, deepthought, model_kwargs, tracer),
|
||||
)
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def llm_thread(
|
||||
g,
|
||||
messages,
|
||||
model_name: str,
|
||||
temperature,
|
||||
openai_api_key=None,
|
||||
api_base_url=None,
|
||||
deepthought=False,
|
||||
model_kwargs: dict = {},
|
||||
tracer: dict = {},
|
||||
):
|
||||
) -> AsyncGenerator[str, None]:
|
||||
try:
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
client = openai_clients.get(client_key)
|
||||
client = openai_async_clients.get(client_key)
|
||||
if not client:
|
||||
client = get_openai_client(openai_api_key, api_base_url)
|
||||
openai_clients[client_key] = client
|
||||
client = get_openai_async_client(openai_api_key, api_base_url)
|
||||
openai_async_clients[client_key] = client
|
||||
|
||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||
|
||||
@@ -207,53 +182,58 @@ def llm_thread(
|
||||
if os.getenv("KHOJ_LLM_SEED"):
|
||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||
|
||||
chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
|
||||
messages=formatted_messages,
|
||||
model=model_name, # type: ignore
|
||||
aggregated_response = ""
|
||||
final_chunk = None
|
||||
start_time = perf_counter()
|
||||
chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
|
||||
messages=formatted_messages, # type: ignore
|
||||
model=model_name,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
timeout=20,
|
||||
**model_kwargs,
|
||||
)
|
||||
async for chunk in chat_stream:
|
||||
# Log the time taken to start response
|
||||
if final_chunk is None:
|
||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||
# Keep track of the last chunk for usage data
|
||||
final_chunk = chunk
|
||||
# Handle streamed response chunk
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
delta_chunk = chunk.choices[0].delta
|
||||
text_chunk = ""
|
||||
if isinstance(delta_chunk, str):
|
||||
text_chunk = delta_chunk
|
||||
elif delta_chunk and delta_chunk.content:
|
||||
text_chunk = delta_chunk.content
|
||||
if text_chunk:
|
||||
aggregated_response += text_chunk
|
||||
yield text_chunk
|
||||
|
||||
aggregated_response = ""
|
||||
if not stream:
|
||||
chunk = chat
|
||||
aggregated_response = chunk.choices[0].message.content
|
||||
g.send(aggregated_response)
|
||||
else:
|
||||
for chunk in chat:
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
delta_chunk = chunk.choices[0].delta
|
||||
text_chunk = ""
|
||||
if isinstance(delta_chunk, str):
|
||||
text_chunk = delta_chunk
|
||||
elif delta_chunk.content:
|
||||
text_chunk = delta_chunk.content
|
||||
if text_chunk:
|
||||
aggregated_response += text_chunk
|
||||
g.send(text_chunk)
|
||||
# Log the time taken to stream the entire response
|
||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||
|
||||
# Calculate cost of chat
|
||||
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
||||
output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
||||
cost = (
|
||||
chunk.usage.model_extra.get("estimated_cost", 0) if hasattr(chunk, "usage") and chunk.usage else 0
|
||||
) # Estimated costs returned by DeepInfra API
|
||||
tracer["usage"] = get_chat_usage_metrics(
|
||||
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
|
||||
)
|
||||
# Calculate cost of chat after stream finishes
|
||||
input_tokens, output_tokens, cost = 0, 0, 0
|
||||
if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage:
|
||||
input_tokens = final_chunk.usage.prompt_tokens
|
||||
output_tokens = final_chunk.usage.completion_tokens
|
||||
# Estimated costs returned by DeepInfra API
|
||||
if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra:
|
||||
cost = final_chunk.usage.model_extra.get("estimated_cost", 0)
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
tracer["usage"] = get_chat_usage_metrics(
|
||||
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
|
||||
)
|
||||
if is_promptrace_enabled():
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in llm_thread: {e}", exc_info=True)
|
||||
finally:
|
||||
g.close()
|
||||
logger.error(f"Error in chat_completion_with_backoff stream: {e}", exc_info=True)
|
||||
|
||||
|
||||
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
|
||||
|
||||
@@ -77,42 +77,6 @@ model_to_prompt_size = {
|
||||
model_to_tokenizer: Dict[str, str] = {}
|
||||
|
||||
|
||||
class ThreadedGenerator:
    """Queue-backed blocking iterator used to hand streamed chat response chunks to a consumer.

    The producer calls `send()` once per chunk and `close()` when it is done.
    The consumer iterates over the instance; every `next()` blocks until a
    chunk or the end-of-stream marker is available on the queue.
    """

    def __init__(self, compiled_references, online_results, completion_func=None):
        self.queue = queue.Queue()
        self.compiled_references = compiled_references
        self.online_results = online_results
        # Optional callback invoked with the aggregated response when the stream ends
        self.completion_func = completion_func
        # Aggregate of every chunk passed to send() so far
        self.response = ""
        self.start_time = perf_counter()
        # Set once the end-of-stream marker has been consumed
        self._exhausted = False

    def __iter__(self):
        return self

    def __next__(self):
        # An exhausted iterator must keep raising StopIteration rather than
        # block forever waiting on a queue nobody will write to again.
        if self._exhausted:
            raise StopIteration

        item = self.queue.get()
        if item is StopIteration:
            self._exhausted = True
            time_to_response = perf_counter() - self.start_time
            logger.info(f"Chat streaming took: {time_to_response:.3f} seconds")
            if self.completion_func:
                # The completion func effectively acts as a callback.
                # It adds the aggregated response to the conversation history.
                self.completion_func(chat_response=self.response)
            raise StopIteration
        return item

    def send(self, data):
        """Append `data` to the aggregated response and queue it for the consumer."""
        # Nothing aggregated yet, so treat this as the first chunk and log time to first response
        if self.response == "":
            time_to_first_response = perf_counter() - self.start_time
            logger.info(f"First response took: {time_to_first_response:.3f} seconds")

        self.response += data
        self.queue.put(data)

    def close(self):
        """Queue the end-of-stream marker; the consumer stops once it reaches it."""
        # The StopIteration class itself is used as the sentinel value
        self.queue.put(StopIteration)
|
||||
|
||||
|
||||
class InformationCollectionIteration:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -254,7 +218,7 @@ def message_to_log(
|
||||
return conversation_log
|
||||
|
||||
|
||||
def save_to_conversation_log(
|
||||
async def save_to_conversation_log(
|
||||
q: str,
|
||||
chat_response: str,
|
||||
user: KhojUser,
|
||||
@@ -306,7 +270,7 @@ def save_to_conversation_log(
|
||||
khoj_message_metadata=khoj_message_metadata,
|
||||
conversation_log=meta_log.get("chat", []),
|
||||
)
|
||||
ConversationAdapters.save_conversation(
|
||||
await ConversationAdapters.save_conversation(
|
||||
user,
|
||||
{"chat": updated_conversation},
|
||||
client_application=client_application,
|
||||
|
||||
@@ -67,7 +67,6 @@ from khoj.routers.research import (
|
||||
from khoj.routers.storage import upload_user_image_to_bucket
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import (
|
||||
AsyncIteratorWrapper,
|
||||
ConversationCommand,
|
||||
command_descriptions,
|
||||
convert_image_to_webp,
|
||||
@@ -999,7 +998,7 @@ async def chat(
|
||||
return
|
||||
|
||||
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
await save_to_conversation_log(
|
||||
q,
|
||||
llm_response,
|
||||
user,
|
||||
@@ -1308,26 +1307,31 @@ async def chat(
|
||||
yield result
|
||||
|
||||
continue_stream = True
|
||||
iterator = AsyncIteratorWrapper(llm_response)
|
||||
async for item in iterator:
|
||||
async for item in llm_response:
|
||||
# Should not happen with async generator, end is signaled by loop exit. Skip.
|
||||
if item is None:
|
||||
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
||||
yield result
|
||||
# Send Usage Metadata once llm interactions are complete
|
||||
async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
|
||||
yield event
|
||||
async for result in send_event(ChatEvent.END_RESPONSE, ""):
|
||||
yield result
|
||||
logger.debug("Finished streaming response")
|
||||
return
|
||||
continue
|
||||
if not connection_alive or not continue_stream:
|
||||
# Drain the generator if disconnected but keep processing internally
|
||||
continue
|
||||
try:
|
||||
async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
|
||||
yield result
|
||||
except Exception as e:
|
||||
continue_stream = False
|
||||
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
|
||||
logger.info(f"User {user} disconnected or error during streaming. Stopping send: {e}")
|
||||
|
||||
# Signal end of LLM response after the loop finishes
|
||||
if connection_alive:
|
||||
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
||||
yield result
|
||||
# Send Usage Metadata once llm interactions are complete
|
||||
if tracer.get("usage"):
|
||||
async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
|
||||
yield event
|
||||
async for result in send_event(ChatEvent.END_RESPONSE, ""):
|
||||
yield result
|
||||
logger.debug("Finished streaming response")
|
||||
|
||||
## Stream Text Response
|
||||
if stream:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
@@ -6,9 +5,7 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from random import random
|
||||
from typing import (
|
||||
@@ -17,7 +14,6 @@ from typing import (
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
@@ -97,7 +93,6 @@ from khoj.processor.conversation.openai.gpt import (
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
ChatEvent,
|
||||
ThreadedGenerator,
|
||||
clean_json,
|
||||
clean_mermaidjs,
|
||||
construct_chat_history,
|
||||
@@ -126,8 +121,6 @@ from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, Loca
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
|
||||
NOTION_OAUTH_CLIENT_ID = os.getenv("NOTION_OAUTH_CLIENT_ID")
|
||||
NOTION_OAUTH_CLIENT_SECRET = os.getenv("NOTION_OAUTH_CLIENT_SECRET")
|
||||
@@ -262,11 +255,6 @@ def get_conversation_command(query: str) -> ConversationCommand:
|
||||
return ConversationCommand.Default
|
||||
|
||||
|
||||
async def agenerate_chat_response(*args):
    """Run the synchronous `generate_chat_response` with `args` on the module executor and await its result."""
    # Inside a coroutine a loop is always running; get_running_loop() states that explicitly
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, generate_chat_response, *args)
|
||||
|
||||
|
||||
def gather_raw_query_files(
|
||||
query_files: Dict[str, str],
|
||||
):
|
||||
@@ -1418,7 +1406,7 @@ def send_message_to_model_wrapper_sync(
|
||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||
|
||||
|
||||
def generate_chat_response(
|
||||
async def agenerate_chat_response(
|
||||
q: str,
|
||||
meta_log: dict,
|
||||
conversation: Conversation,
|
||||
@@ -1444,13 +1432,14 @@ def generate_chat_response(
|
||||
generated_asset_results: Dict[str, Dict] = {},
|
||||
is_subscribed: bool = False,
|
||||
tracer: dict = {},
|
||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||
) -> Tuple[AsyncGenerator[str, None], Dict[str, str]]:
|
||||
# Initialize Variables
|
||||
chat_response = None
|
||||
chat_response_generator = None
|
||||
logger.debug(f"Conversation Types: {conversation_commands}")
|
||||
|
||||
metadata = {}
|
||||
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
|
||||
agent = await AgentAdapters.aget_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
|
||||
|
||||
try:
|
||||
partial_completion = partial(
|
||||
save_to_conversation_log,
|
||||
@@ -1481,17 +1470,17 @@ def generate_chat_response(
|
||||
code_results = {}
|
||||
deepthought = True
|
||||
|
||||
chat_model = ConversationAdapters.get_valid_chat_model(user, conversation, is_subscribed)
|
||||
chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed)
|
||||
vision_available = chat_model.vision_enabled
|
||||
if not vision_available and query_images:
|
||||
vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
|
||||
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
|
||||
if vision_enabled_config:
|
||||
chat_model = vision_enabled_config
|
||||
vision_available = True
|
||||
|
||||
if chat_model.model_type == "offline":
|
||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||
chat_response = converse_offline(
|
||||
chat_response_generator = converse_offline(
|
||||
user_query=query_to_run,
|
||||
references=compiled_references,
|
||||
online_results=online_results,
|
||||
@@ -1515,7 +1504,7 @@ def generate_chat_response(
|
||||
openai_chat_config = chat_model.ai_model_api
|
||||
api_key = openai_chat_config.api_key
|
||||
chat_model_name = chat_model.name
|
||||
chat_response = converse_openai(
|
||||
chat_response_generator = converse_openai(
|
||||
compiled_references,
|
||||
query_to_run,
|
||||
query_images=query_images,
|
||||
@@ -1544,7 +1533,7 @@ def generate_chat_response(
|
||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
chat_response = converse_anthropic(
|
||||
chat_response_generator = converse_anthropic(
|
||||
compiled_references,
|
||||
query_to_run,
|
||||
query_images=query_images,
|
||||
@@ -1572,7 +1561,7 @@ def generate_chat_response(
|
||||
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
chat_response = converse_gemini(
|
||||
chat_response_generator = converse_gemini(
|
||||
compiled_references,
|
||||
query_to_run,
|
||||
online_results,
|
||||
@@ -1604,7 +1593,8 @@ def generate_chat_response(
|
||||
logger.error(e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
return chat_response, metadata
|
||||
# Return the generator directly
|
||||
return chat_response_generator, metadata
|
||||
|
||||
|
||||
class DeleteMessageRequestBody(BaseModel):
|
||||
|
||||
@@ -23,6 +23,7 @@ from time import perf_counter
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
|
||||
from urllib.parse import ParseResult, urlparse
|
||||
|
||||
import anthropic
|
||||
import openai
|
||||
import psutil
|
||||
import pyjson5
|
||||
@@ -30,6 +31,7 @@ import requests
|
||||
import torch
|
||||
from asgiref.sync import sync_to_async
|
||||
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
|
||||
from google import genai
|
||||
from google.auth.credentials import Credentials
|
||||
from google.oauth2 import service_account
|
||||
from magika import Magika
|
||||
@@ -729,6 +731,60 @@ def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, o
|
||||
return client
|
||||
|
||||
|
||||
def get_openai_async_client(api_key: str, api_base_url: str) -> Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]:
    """Get async OpenAI or async Azure OpenAI client based on the API Base URL.

    An Azure client is returned when the hostname of `api_base_url` ends with
    ".openai.azure.com"; otherwise a regular async OpenAI client is returned
    with `api_base_url` as its base URL.
    """
    parsed_url = urlparse(api_base_url)
    if parsed_url.hostname and parsed_url.hostname.endswith(".openai.azure.com"):
        client = openai.AsyncAzureOpenAI(
            api_key=api_key,
            azure_endpoint=api_base_url,
            api_version="2024-10-21",
        )
    else:
        client = openai.AsyncOpenAI(
            api_key=api_key,
            base_url=api_base_url,
        )
    return client
|
||||
|
||||
|
||||
def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
    """Get a synchronous Anthropic client for the given API key and base URL.

    The connection details are resolved with `get_ai_api_info`. When it yields
    an API key, a direct `anthropic.Anthropic` client is returned; otherwise an
    `anthropic.AnthropicVertex` client is built from the resolved region,
    project and credentials.
    """
    api_info = get_ai_api_info(api_key, api_base_url)
    if api_info.api_key:
        client = anthropic.Anthropic(api_key=api_info.api_key)
    else:
        client = anthropic.AnthropicVertex(
            region=api_info.region,
            project_id=api_info.project,
            credentials=api_info.credentials,
        )
    return client
|
||||
|
||||
|
||||
def get_anthropic_async_client(api_key, api_base_url=None) -> anthropic.AsyncAnthropic | anthropic.AsyncAnthropicVertex:
    """Get an async Anthropic client for the given API key and base URL.

    The connection details are resolved with `get_ai_api_info`. When it yields
    an API key, a direct `anthropic.AsyncAnthropic` client is returned;
    otherwise an `anthropic.AsyncAnthropicVertex` client is built from the
    resolved region, project and credentials.
    """
    api_info = get_ai_api_info(api_key, api_base_url)
    if api_info.api_key:
        client = anthropic.AsyncAnthropic(api_key=api_info.api_key)
    else:
        client = anthropic.AsyncAnthropicVertex(
            region=api_info.region,
            project_id=api_info.project,
            credentials=api_info.credentials,
        )
    return client
|
||||
|
||||
|
||||
def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
    """Get a Gemini client for the given API key and base URL.

    The connection details are resolved with `get_ai_api_info` and passed to
    `genai.Client`. Vertex AI mode is enabled when no API key is resolved.
    """
    api_info = get_ai_api_info(api_key, api_base_url)
    return genai.Client(
        location=api_info.region,
        project=api_info.project,
        credentials=api_info.credentials,
        api_key=api_info.api_key,
        # Without an API key the client is configured for Vertex AI
        vertexai=api_info.api_key is None,
    )
|
||||
|
||||
|
||||
def normalize_email(email: str, check_deliverability=False) -> tuple[str, bool]:
|
||||
"""Normalize, validate and check deliverability of email address"""
|
||||
lower_email = email.lower()
|
||||
|
||||
Reference in New Issue
Block a user