Refactor OpenAI chat response to stream async, no separate thread

- Refactor chat API to use async/await for OpenAI streaming
- Fix and clean OpenAI chat response async streaming
This commit is contained in:
Debanjum
2025-04-19 21:36:45 +05:30
parent c93c0d982e
commit 0751f2ea30
6 changed files with 132 additions and 141 deletions

View File

@@ -763,9 +763,9 @@ class AgentAdapters:
return False
@staticmethod
def get_conversation_agent_by_id(agent_id: int):
agent = Agent.objects.filter(id=agent_id).first()
if agent == AgentAdapters.get_default_agent():
async def aget_conversation_agent_by_id(agent_id: int):
agent = await Agent.objects.filter(id=agent_id).afirst()
if agent == await AgentAdapters.aget_default_agent():
# If the agent is set to the default agent, then return None and let the default application code be used
return None
return agent
@@ -1109,14 +1109,6 @@ class ConversationAdapters:
async def aget_all_chat_models():
return await sync_to_async(list)(ChatModel.objects.prefetch_related("ai_model_api").all())
@staticmethod
def get_vision_enabled_config():
chat_models = ConversationAdapters.get_all_chat_models()
for config in chat_models:
if config.vision_enabled:
return config
return None
@staticmethod
async def aget_vision_enabled_config():
chat_models = await ConversationAdapters.aget_all_chat_models()
@@ -1171,7 +1163,11 @@ class ConversationAdapters:
@staticmethod
async def aget_chat_model(user: KhojUser):
subscribed = await ais_user_subscribed(user)
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
config = (
await UserConversationConfig.objects.filter(user=user)
.prefetch_related("setting", "setting__ai_model_api")
.afirst()
)
if subscribed:
# Subscribed users can use any available chat model
if config:
@@ -1387,7 +1383,7 @@ class ConversationAdapters:
@staticmethod
@require_valid_user
def save_conversation(
async def save_conversation(
user: KhojUser,
conversation_log: dict,
client_application: ClientApplication = None,
@@ -1396,19 +1392,21 @@ class ConversationAdapters:
):
slug = user_message.strip()[:200] if user_message else None
if conversation_id:
conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first()
conversation = await Conversation.objects.filter(
user=user, client=client_application, id=conversation_id
).afirst()
else:
conversation = (
Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").first()
await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
)
if conversation:
conversation.conversation_log = conversation_log
conversation.slug = slug
conversation.updated_at = datetime.now(tz=timezone.utc)
conversation.save()
await conversation.asave()
else:
Conversation.objects.create(
await Conversation.objects.acreate(
user=user, conversation_log=conversation_log, client=client_application, slug=slug
)
@@ -1455,17 +1453,21 @@ class ConversationAdapters:
return random.sample(all_questions, max_results)
@staticmethod
def get_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool):
async def aget_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool):
agent: Agent = (
conversation.agent if is_subscribed and AgentAdapters.get_default_agent() != conversation.agent else None
conversation.agent
if is_subscribed and await AgentAdapters.aget_default_agent() != conversation.agent
else None
)
if agent and agent.chat_model:
chat_model = conversation.agent.chat_model
chat_model = await ChatModel.objects.select_related("ai_model_api").aget(
pk=conversation.agent.chat_model.pk
)
else:
chat_model = ConversationAdapters.get_chat_model(user)
chat_model = await ConversationAdapters.aget_chat_model(user)
if chat_model is None:
chat_model = ConversationAdapters.get_default_chat_model()
chat_model = await ConversationAdapters.aget_default_chat_model()
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:

View File

@@ -1,6 +1,6 @@
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional
from typing import AsyncGenerator, Dict, List, Optional
import pyjson5
from langchain.schema import ChatMessage
@@ -162,7 +162,7 @@ def send_message_to_model(
)
def converse_openai(
async def converse_openai(
references,
user_query,
online_results: Optional[Dict[str, Dict]] = None,
@@ -187,7 +187,7 @@ def converse_openai(
program_execution_context: List[str] = None,
deepthought: Optional[bool] = False,
tracer: dict = {},
):
) -> AsyncGenerator[str, None]:
"""
Converse with user using OpenAI's ChatGPT
"""
@@ -217,11 +217,17 @@ def converse_openai(
# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()])
response = prompts.no_notes_found.format()
if completion_func:
await completion_func(chat_response=response)
yield response
return
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])
response = prompts.no_online_results_found.format()
if completion_func:
await completion_func(chat_response=response)
yield response
return
context_message = ""
if not is_none_or_empty(references):
@@ -255,19 +261,23 @@ def converse_openai(
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
# Get Response from GPT
return chat_completion_with_backoff(
full_response = ""
async for chunk in chat_completion_with_backoff(
messages=messages,
compiled_references=references,
online_results=online_results,
model_name=model,
temperature=temperature,
openai_api_key=api_key,
api_base_url=api_base_url,
completion_func=completion_func,
deepthought=deepthought,
model_kwargs={"stop": ["Notes:\n["]},
tracer=tracer,
)
):
full_response += chunk
yield chunk
# Call completion_func once finish streaming and we have the full response
if completion_func:
await completion_func(chat_response=full_response)
def clean_response_schema(schema: BaseModel | dict) -> dict:

View File

@@ -1,7 +1,7 @@
import logging
import os
from threading import Thread
from typing import Dict, List
from time import perf_counter
from typing import AsyncGenerator, Dict, List
from urllib.parse import urlparse
import openai
@@ -16,13 +16,10 @@ from tenacity import (
wait_random_exponential,
)
from khoj.processor.conversation.utils import (
JsonSupport,
ThreadedGenerator,
commit_conversation_trace,
)
from khoj.processor.conversation.utils import JsonSupport, commit_conversation_trace
from khoj.utils.helpers import (
get_chat_usage_metrics,
get_openai_async_client,
get_openai_client,
is_promptrace_enabled,
)
@@ -30,6 +27,7 @@ from khoj.utils.helpers import (
logger = logging.getLogger(__name__)
openai_clients: Dict[str, openai.OpenAI] = {}
openai_async_clients: Dict[str, openai.AsyncOpenAI] = {}
@retry(
@@ -124,45 +122,22 @@ def completion_with_backoff(
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def chat_completion_with_backoff(
async def chat_completion_with_backoff(
messages,
compiled_references,
online_results,
model_name,
temperature,
openai_api_key=None,
api_base_url=None,
completion_func=None,
deepthought=False,
model_kwargs=None,
tracer: dict = {},
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=llm_thread,
args=(g, messages, model_name, temperature, openai_api_key, api_base_url, deepthought, model_kwargs, tracer),
)
t.start()
return g
def llm_thread(
g,
messages,
model_name: str,
temperature,
openai_api_key=None,
api_base_url=None,
deepthought=False,
model_kwargs: dict = {},
tracer: dict = {},
):
) -> AsyncGenerator[str, None]:
try:
client_key = f"{openai_api_key}--{api_base_url}"
client = openai_clients.get(client_key)
client = openai_async_clients.get(client_key)
if not client:
client = get_openai_client(openai_api_key, api_base_url)
openai_clients[client_key] = client
client = get_openai_async_client(openai_api_key, api_base_url)
openai_async_clients[client_key] = client
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
@@ -207,53 +182,58 @@ def llm_thread(
if os.getenv("KHOJ_LLM_SEED"):
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
messages=formatted_messages,
model=model_name, # type: ignore
aggregated_response = ""
final_chunk = None
start_time = perf_counter()
chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
messages=formatted_messages, # type: ignore
model=model_name,
stream=stream,
temperature=temperature,
timeout=20,
**model_kwargs,
)
async for chunk in chat_stream:
# Log the time taken to start response
if final_chunk is None:
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Keep track of the last chunk for usage data
final_chunk = chunk
# Handle streamed response chunk
if len(chunk.choices) == 0:
continue
delta_chunk = chunk.choices[0].delta
text_chunk = ""
if isinstance(delta_chunk, str):
text_chunk = delta_chunk
elif delta_chunk and delta_chunk.content:
text_chunk = delta_chunk.content
if text_chunk:
aggregated_response += text_chunk
yield text_chunk
aggregated_response = ""
if not stream:
chunk = chat
aggregated_response = chunk.choices[0].message.content
g.send(aggregated_response)
else:
for chunk in chat:
if len(chunk.choices) == 0:
continue
delta_chunk = chunk.choices[0].delta
text_chunk = ""
if isinstance(delta_chunk, str):
text_chunk = delta_chunk
elif delta_chunk.content:
text_chunk = delta_chunk.content
if text_chunk:
aggregated_response += text_chunk
g.send(text_chunk)
# Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
# Calculate cost of chat
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
cost = (
chunk.usage.model_extra.get("estimated_cost", 0) if hasattr(chunk, "usage") and chunk.usage else 0
) # Estimated costs returned by DeepInfra API
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
)
# Calculate cost of chat after stream finishes
input_tokens, output_tokens, cost = 0, 0, 0
if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage:
input_tokens = final_chunk.usage.prompt_tokens
output_tokens = final_chunk.usage.completion_tokens
# Estimated costs returned by DeepInfra API
if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra:
cost = final_chunk.usage.model_extra.get("estimated_cost", 0)
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
)
if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e:
logger.error(f"Error in llm_thread: {e}", exc_info=True)
finally:
g.close()
logger.error(f"Error in chat_completion_with_backoff stream: {e}", exc_info=True)
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:

View File

@@ -254,7 +254,7 @@ def message_to_log(
return conversation_log
def save_to_conversation_log(
async def save_to_conversation_log(
q: str,
chat_response: str,
user: KhojUser,
@@ -306,7 +306,7 @@ def save_to_conversation_log(
khoj_message_metadata=khoj_message_metadata,
conversation_log=meta_log.get("chat", []),
)
ConversationAdapters.save_conversation(
await ConversationAdapters.save_conversation(
user,
{"chat": updated_conversation},
client_application=client_application,

View File

@@ -67,7 +67,6 @@ from khoj.routers.research import (
from khoj.routers.storage import upload_user_image_to_bucket
from khoj.utils import state
from khoj.utils.helpers import (
AsyncIteratorWrapper,
ConversationCommand,
command_descriptions,
convert_image_to_webp,
@@ -999,7 +998,7 @@ async def chat(
return
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
await sync_to_async(save_to_conversation_log)(
await save_to_conversation_log(
q,
llm_response,
user,
@@ -1308,26 +1307,31 @@ async def chat(
yield result
continue_stream = True
iterator = AsyncIteratorWrapper(llm_response)
async for item in iterator:
async for item in llm_response:
# Should not happen with async generator, end is signaled by loop exit. Skip.
if item is None:
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
yield result
# Send Usage Metadata once llm interactions are complete
async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
yield event
async for result in send_event(ChatEvent.END_RESPONSE, ""):
yield result
logger.debug("Finished streaming response")
return
continue
if not connection_alive or not continue_stream:
# Drain the generator if disconnected but keep processing internally
continue
try:
async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
yield result
except Exception as e:
continue_stream = False
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
logger.info(f"User {user} disconnected or error during streaming. Stopping send: {e}")
# Signal end of LLM response after the loop finishes
if connection_alive:
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
yield result
# Send Usage Metadata once llm interactions are complete
if tracer.get("usage"):
async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
yield event
async for result in send_event(ChatEvent.END_RESPONSE, ""):
yield result
logger.debug("Finished streaming response")
## Stream Text Response
if stream:

View File

@@ -1,4 +1,3 @@
import asyncio
import base64
import hashlib
import json
@@ -6,9 +5,7 @@ import logging
import math
import os
import re
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta, timezone
from enum import Enum
from functools import partial
from random import random
from typing import (
@@ -17,7 +14,6 @@ from typing import (
AsyncGenerator,
Callable,
Dict,
Iterator,
List,
Optional,
Set,
@@ -126,8 +122,6 @@ from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, Loca
logger = logging.getLogger(__name__)
executor = ThreadPoolExecutor(max_workers=1)
NOTION_OAUTH_CLIENT_ID = os.getenv("NOTION_OAUTH_CLIENT_ID")
NOTION_OAUTH_CLIENT_SECRET = os.getenv("NOTION_OAUTH_CLIENT_SECRET")
@@ -262,11 +256,6 @@ def get_conversation_command(query: str) -> ConversationCommand:
return ConversationCommand.Default
async def agenerate_chat_response(*args):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(executor, generate_chat_response, *args)
def gather_raw_query_files(
query_files: Dict[str, str],
):
@@ -1418,7 +1407,7 @@ def send_message_to_model_wrapper_sync(
raise HTTPException(status_code=500, detail="Invalid conversation config")
def generate_chat_response(
async def agenerate_chat_response(
q: str,
meta_log: dict,
conversation: Conversation,
@@ -1444,13 +1433,14 @@ def generate_chat_response(
generated_asset_results: Dict[str, Dict] = {},
is_subscribed: bool = False,
tracer: dict = {},
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
) -> Tuple[AsyncGenerator[str, None], Dict[str, str]]:
# Initialize Variables
chat_response = None
chat_response_generator = None
logger.debug(f"Conversation Types: {conversation_commands}")
metadata = {}
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
agent = await AgentAdapters.aget_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
try:
partial_completion = partial(
save_to_conversation_log,
@@ -1481,23 +1471,25 @@ def generate_chat_response(
code_results = {}
deepthought = True
chat_model = ConversationAdapters.get_valid_chat_model(user, conversation, is_subscribed)
chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed)
vision_available = chat_model.vision_enabled
if not vision_available and query_images:
vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
if vision_enabled_config:
chat_model = vision_enabled_config
vision_available = True
if chat_model.model_type == "offline":
# Assuming converse_offline remains sync or is refactored separately
loaded_model = state.offline_chat_processor_config.loaded_model
chat_response = converse_offline(
# If converse_offline returns an iterator, wrap it if needed, or refactor it to async generator
chat_response_generator = converse_offline( # Needs adaptation if it becomes async
user_query=query_to_run,
references=compiled_references,
online_results=online_results,
loaded_model=loaded_model,
conversation_log=meta_log,
completion_func=partial_completion,
completion_func=partial_completion, # Pass the async wrapper
conversation_commands=conversation_commands,
model_name=chat_model.name,
max_prompt_size=chat_model.max_prompt_size,
@@ -1515,7 +1507,7 @@ def generate_chat_response(
openai_chat_config = chat_model.ai_model_api
api_key = openai_chat_config.api_key
chat_model_name = chat_model.name
chat_response = converse_openai(
chat_response_generator = converse_openai(
compiled_references,
query_to_run,
query_images=query_images,
@@ -1542,9 +1534,10 @@ def generate_chat_response(
)
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
# Assuming converse_anthropic remains sync or is refactored separately
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_response = converse_anthropic(
chat_response_generator = converse_anthropic( # Needs adaptation if it becomes async
compiled_references,
query_to_run,
query_images=query_images,
@@ -1570,9 +1563,10 @@ def generate_chat_response(
tracer=tracer,
)
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
# Assuming converse_gemini remains sync or is refactored separately
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_response = converse_gemini(
chat_response_generator = converse_gemini( # Needs adaptation if it becomes async
compiled_references,
query_to_run,
online_results,
@@ -1604,7 +1598,8 @@ def generate_chat_response(
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
return chat_response, metadata
# Return the generator directly
return chat_response_generator, metadata
class DeleteMessageRequestBody(BaseModel):