Refactor OpenAI chat response to stream asynchronously, without a separate thread

- Refactor chat API to use async/await for OpenAI streaming
- Fix and clean up async streaming of the OpenAI chat response
This commit is contained in:
Debanjum
2025-04-19 21:36:45 +05:30
parent c93c0d982e
commit 0751f2ea30
6 changed files with 132 additions and 141 deletions

View File

@@ -763,9 +763,9 @@ class AgentAdapters:
return False return False
@staticmethod @staticmethod
def get_conversation_agent_by_id(agent_id: int): async def aget_conversation_agent_by_id(agent_id: int):
agent = Agent.objects.filter(id=agent_id).first() agent = await Agent.objects.filter(id=agent_id).afirst()
if agent == AgentAdapters.get_default_agent(): if agent == await AgentAdapters.aget_default_agent():
# If the agent is set to the default agent, then return None and let the default application code be used # If the agent is set to the default agent, then return None and let the default application code be used
return None return None
return agent return agent
@@ -1109,14 +1109,6 @@ class ConversationAdapters:
async def aget_all_chat_models(): async def aget_all_chat_models():
return await sync_to_async(list)(ChatModel.objects.prefetch_related("ai_model_api").all()) return await sync_to_async(list)(ChatModel.objects.prefetch_related("ai_model_api").all())
@staticmethod
def get_vision_enabled_config():
chat_models = ConversationAdapters.get_all_chat_models()
for config in chat_models:
if config.vision_enabled:
return config
return None
@staticmethod @staticmethod
async def aget_vision_enabled_config(): async def aget_vision_enabled_config():
chat_models = await ConversationAdapters.aget_all_chat_models() chat_models = await ConversationAdapters.aget_all_chat_models()
@@ -1171,7 +1163,11 @@ class ConversationAdapters:
@staticmethod @staticmethod
async def aget_chat_model(user: KhojUser): async def aget_chat_model(user: KhojUser):
subscribed = await ais_user_subscribed(user) subscribed = await ais_user_subscribed(user)
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst() config = (
await UserConversationConfig.objects.filter(user=user)
.prefetch_related("setting", "setting__ai_model_api")
.afirst()
)
if subscribed: if subscribed:
# Subscribed users can use any available chat model # Subscribed users can use any available chat model
if config: if config:
@@ -1387,7 +1383,7 @@ class ConversationAdapters:
@staticmethod @staticmethod
@require_valid_user @require_valid_user
def save_conversation( async def save_conversation(
user: KhojUser, user: KhojUser,
conversation_log: dict, conversation_log: dict,
client_application: ClientApplication = None, client_application: ClientApplication = None,
@@ -1396,19 +1392,21 @@ class ConversationAdapters:
): ):
slug = user_message.strip()[:200] if user_message else None slug = user_message.strip()[:200] if user_message else None
if conversation_id: if conversation_id:
conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first() conversation = await Conversation.objects.filter(
user=user, client=client_application, id=conversation_id
).afirst()
else: else:
conversation = ( conversation = (
Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").first() await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
) )
if conversation: if conversation:
conversation.conversation_log = conversation_log conversation.conversation_log = conversation_log
conversation.slug = slug conversation.slug = slug
conversation.updated_at = datetime.now(tz=timezone.utc) conversation.updated_at = datetime.now(tz=timezone.utc)
conversation.save() await conversation.asave()
else: else:
Conversation.objects.create( await Conversation.objects.acreate(
user=user, conversation_log=conversation_log, client=client_application, slug=slug user=user, conversation_log=conversation_log, client=client_application, slug=slug
) )
@@ -1455,17 +1453,21 @@ class ConversationAdapters:
return random.sample(all_questions, max_results) return random.sample(all_questions, max_results)
@staticmethod @staticmethod
def get_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool): async def aget_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool):
agent: Agent = ( agent: Agent = (
conversation.agent if is_subscribed and AgentAdapters.get_default_agent() != conversation.agent else None conversation.agent
if is_subscribed and await AgentAdapters.aget_default_agent() != conversation.agent
else None
) )
if agent and agent.chat_model: if agent and agent.chat_model:
chat_model = conversation.agent.chat_model chat_model = await ChatModel.objects.select_related("ai_model_api").aget(
pk=conversation.agent.chat_model.pk
)
else: else:
chat_model = ConversationAdapters.get_chat_model(user) chat_model = await ConversationAdapters.aget_chat_model(user)
if chat_model is None: if chat_model is None:
chat_model = ConversationAdapters.get_default_chat_model() chat_model = await ConversationAdapters.aget_default_chat_model()
if chat_model.model_type == ChatModel.ModelType.OFFLINE: if chat_model.model_type == ChatModel.ModelType.OFFLINE:
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:

View File

@@ -1,6 +1,6 @@
import logging import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Dict, List, Optional from typing import AsyncGenerator, Dict, List, Optional
import pyjson5 import pyjson5
from langchain.schema import ChatMessage from langchain.schema import ChatMessage
@@ -162,7 +162,7 @@ def send_message_to_model(
) )
def converse_openai( async def converse_openai(
references, references,
user_query, user_query,
online_results: Optional[Dict[str, Dict]] = None, online_results: Optional[Dict[str, Dict]] = None,
@@ -187,7 +187,7 @@ def converse_openai(
program_execution_context: List[str] = None, program_execution_context: List[str] = None,
deepthought: Optional[bool] = False, deepthought: Optional[bool] = False,
tracer: dict = {}, tracer: dict = {},
): ) -> AsyncGenerator[str, None]:
""" """
Converse with user using OpenAI's ChatGPT Converse with user using OpenAI's ChatGPT
""" """
@@ -217,11 +217,17 @@ def converse_openai(
# Get Conversation Primer appropriate to Conversation Type # Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
completion_func(chat_response=prompts.no_notes_found.format()) response = prompts.no_notes_found.format()
return iter([prompts.no_notes_found.format()]) if completion_func:
await completion_func(chat_response=response)
yield response
return
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format()) response = prompts.no_online_results_found.format()
return iter([prompts.no_online_results_found.format()]) if completion_func:
await completion_func(chat_response=response)
yield response
return
context_message = "" context_message = ""
if not is_none_or_empty(references): if not is_none_or_empty(references):
@@ -255,19 +261,23 @@ def converse_openai(
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}") logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
# Get Response from GPT # Get Response from GPT
return chat_completion_with_backoff( full_response = ""
async for chunk in chat_completion_with_backoff(
messages=messages, messages=messages,
compiled_references=references,
online_results=online_results,
model_name=model, model_name=model,
temperature=temperature, temperature=temperature,
openai_api_key=api_key, openai_api_key=api_key,
api_base_url=api_base_url, api_base_url=api_base_url,
completion_func=completion_func,
deepthought=deepthought, deepthought=deepthought,
model_kwargs={"stop": ["Notes:\n["]}, model_kwargs={"stop": ["Notes:\n["]},
tracer=tracer, tracer=tracer,
) ):
full_response += chunk
yield chunk
# Call completion_func once finish streaming and we have the full response
if completion_func:
await completion_func(chat_response=full_response)
def clean_response_schema(schema: BaseModel | dict) -> dict: def clean_response_schema(schema: BaseModel | dict) -> dict:

View File

@@ -1,7 +1,7 @@
import logging import logging
import os import os
from threading import Thread from time import perf_counter
from typing import Dict, List from typing import AsyncGenerator, Dict, List
from urllib.parse import urlparse from urllib.parse import urlparse
import openai import openai
@@ -16,13 +16,10 @@ from tenacity import (
wait_random_exponential, wait_random_exponential,
) )
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import JsonSupport, commit_conversation_trace
JsonSupport,
ThreadedGenerator,
commit_conversation_trace,
)
from khoj.utils.helpers import ( from khoj.utils.helpers import (
get_chat_usage_metrics, get_chat_usage_metrics,
get_openai_async_client,
get_openai_client, get_openai_client,
is_promptrace_enabled, is_promptrace_enabled,
) )
@@ -30,6 +27,7 @@ from khoj.utils.helpers import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
openai_clients: Dict[str, openai.OpenAI] = {} openai_clients: Dict[str, openai.OpenAI] = {}
openai_async_clients: Dict[str, openai.AsyncOpenAI] = {}
@retry( @retry(
@@ -124,45 +122,22 @@ def completion_with_backoff(
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True, reraise=True,
) )
def chat_completion_with_backoff( async def chat_completion_with_backoff(
messages, messages,
compiled_references,
online_results,
model_name, model_name,
temperature, temperature,
openai_api_key=None, openai_api_key=None,
api_base_url=None, api_base_url=None,
completion_func=None,
deepthought=False,
model_kwargs=None,
tracer: dict = {},
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=llm_thread,
args=(g, messages, model_name, temperature, openai_api_key, api_base_url, deepthought, model_kwargs, tracer),
)
t.start()
return g
def llm_thread(
g,
messages,
model_name: str,
temperature,
openai_api_key=None,
api_base_url=None,
deepthought=False, deepthought=False,
model_kwargs: dict = {}, model_kwargs: dict = {},
tracer: dict = {}, tracer: dict = {},
): ) -> AsyncGenerator[str, None]:
try: try:
client_key = f"{openai_api_key}--{api_base_url}" client_key = f"{openai_api_key}--{api_base_url}"
client = openai_clients.get(client_key) client = openai_async_clients.get(client_key)
if not client: if not client:
client = get_openai_client(openai_api_key, api_base_url) client = get_openai_async_client(openai_api_key, api_base_url)
openai_clients[client_key] = client openai_async_clients[client_key] = client
formatted_messages = [{"role": message.role, "content": message.content} for message in messages] formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
@@ -207,53 +182,58 @@ def llm_thread(
if os.getenv("KHOJ_LLM_SEED"): if os.getenv("KHOJ_LLM_SEED"):
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED")) model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create( aggregated_response = ""
messages=formatted_messages, final_chunk = None
model=model_name, # type: ignore start_time = perf_counter()
chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
messages=formatted_messages, # type: ignore
model=model_name,
stream=stream, stream=stream,
temperature=temperature, temperature=temperature,
timeout=20, timeout=20,
**model_kwargs, **model_kwargs,
) )
async for chunk in chat_stream:
aggregated_response = "" # Log the time taken to start response
if not stream: if final_chunk is None:
chunk = chat logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
aggregated_response = chunk.choices[0].message.content # Keep track of the last chunk for usage data
g.send(aggregated_response) final_chunk = chunk
else: # Handle streamed response chunk
for chunk in chat:
if len(chunk.choices) == 0: if len(chunk.choices) == 0:
continue continue
delta_chunk = chunk.choices[0].delta delta_chunk = chunk.choices[0].delta
text_chunk = "" text_chunk = ""
if isinstance(delta_chunk, str): if isinstance(delta_chunk, str):
text_chunk = delta_chunk text_chunk = delta_chunk
elif delta_chunk.content: elif delta_chunk and delta_chunk.content:
text_chunk = delta_chunk.content text_chunk = delta_chunk.content
if text_chunk: if text_chunk:
aggregated_response += text_chunk aggregated_response += text_chunk
g.send(text_chunk) yield text_chunk
# Calculate cost of chat # Log the time taken to stream the entire response
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0 logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
cost = ( # Calculate cost of chat after stream finishes
chunk.usage.model_extra.get("estimated_cost", 0) if hasattr(chunk, "usage") and chunk.usage else 0 input_tokens, output_tokens, cost = 0, 0, 0
) # Estimated costs returned by DeepInfra API if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage:
tracer["usage"] = get_chat_usage_metrics( input_tokens = final_chunk.usage.prompt_tokens
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost output_tokens = final_chunk.usage.completion_tokens
) # Estimated costs returned by DeepInfra API
if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra:
cost = final_chunk.usage.model_extra.get("estimated_cost", 0)
# Save conversation trace # Save conversation trace
tracer["chat_model"] = model_name tracer["chat_model"] = model_name
tracer["temperature"] = temperature tracer["temperature"] = temperature
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
)
if is_promptrace_enabled(): if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer) commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e: except Exception as e:
logger.error(f"Error in llm_thread: {e}", exc_info=True) logger.error(f"Error in chat_completion_with_backoff stream: {e}", exc_info=True)
finally:
g.close()
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport: def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:

View File

@@ -254,7 +254,7 @@ def message_to_log(
return conversation_log return conversation_log
def save_to_conversation_log( async def save_to_conversation_log(
q: str, q: str,
chat_response: str, chat_response: str,
user: KhojUser, user: KhojUser,
@@ -306,7 +306,7 @@ def save_to_conversation_log(
khoj_message_metadata=khoj_message_metadata, khoj_message_metadata=khoj_message_metadata,
conversation_log=meta_log.get("chat", []), conversation_log=meta_log.get("chat", []),
) )
ConversationAdapters.save_conversation( await ConversationAdapters.save_conversation(
user, user,
{"chat": updated_conversation}, {"chat": updated_conversation},
client_application=client_application, client_application=client_application,

View File

@@ -67,7 +67,6 @@ from khoj.routers.research import (
from khoj.routers.storage import upload_user_image_to_bucket from khoj.routers.storage import upload_user_image_to_bucket
from khoj.utils import state from khoj.utils import state
from khoj.utils.helpers import ( from khoj.utils.helpers import (
AsyncIteratorWrapper,
ConversationCommand, ConversationCommand,
command_descriptions, command_descriptions,
convert_image_to_webp, convert_image_to_webp,
@@ -999,7 +998,7 @@ async def chat(
return return
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
await sync_to_async(save_to_conversation_log)( await save_to_conversation_log(
q, q,
llm_response, llm_response,
user, user,
@@ -1308,26 +1307,31 @@ async def chat(
yield result yield result
continue_stream = True continue_stream = True
iterator = AsyncIteratorWrapper(llm_response) async for item in llm_response:
async for item in iterator: # Should not happen with async generator, end is signaled by loop exit. Skip.
if item is None: if item is None:
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): continue
yield result
# Send Usage Metadata once llm interactions are complete
async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
yield event
async for result in send_event(ChatEvent.END_RESPONSE, ""):
yield result
logger.debug("Finished streaming response")
return
if not connection_alive or not continue_stream: if not connection_alive or not continue_stream:
# Drain the generator if disconnected but keep processing internally
continue continue
try: try:
async for result in send_event(ChatEvent.MESSAGE, f"{item}"): async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
yield result yield result
except Exception as e: except Exception as e:
continue_stream = False continue_stream = False
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") logger.info(f"User {user} disconnected or error during streaming. Stopping send: {e}")
# Signal end of LLM response after the loop finishes
if connection_alive:
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
yield result
# Send Usage Metadata once llm interactions are complete
if tracer.get("usage"):
async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
yield event
async for result in send_event(ChatEvent.END_RESPONSE, ""):
yield result
logger.debug("Finished streaming response")
## Stream Text Response ## Stream Text Response
if stream: if stream:

View File

@@ -1,4 +1,3 @@
import asyncio
import base64 import base64
import hashlib import hashlib
import json import json
@@ -6,9 +5,7 @@ import logging
import math import math
import os import os
import re import re
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from enum import Enum
from functools import partial from functools import partial
from random import random from random import random
from typing import ( from typing import (
@@ -17,7 +14,6 @@ from typing import (
AsyncGenerator, AsyncGenerator,
Callable, Callable,
Dict, Dict,
Iterator,
List, List,
Optional, Optional,
Set, Set,
@@ -126,8 +122,6 @@ from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, Loca
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
executor = ThreadPoolExecutor(max_workers=1)
NOTION_OAUTH_CLIENT_ID = os.getenv("NOTION_OAUTH_CLIENT_ID") NOTION_OAUTH_CLIENT_ID = os.getenv("NOTION_OAUTH_CLIENT_ID")
NOTION_OAUTH_CLIENT_SECRET = os.getenv("NOTION_OAUTH_CLIENT_SECRET") NOTION_OAUTH_CLIENT_SECRET = os.getenv("NOTION_OAUTH_CLIENT_SECRET")
@@ -262,11 +256,6 @@ def get_conversation_command(query: str) -> ConversationCommand:
return ConversationCommand.Default return ConversationCommand.Default
async def agenerate_chat_response(*args):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(executor, generate_chat_response, *args)
def gather_raw_query_files( def gather_raw_query_files(
query_files: Dict[str, str], query_files: Dict[str, str],
): ):
@@ -1418,7 +1407,7 @@ def send_message_to_model_wrapper_sync(
raise HTTPException(status_code=500, detail="Invalid conversation config") raise HTTPException(status_code=500, detail="Invalid conversation config")
def generate_chat_response( async def agenerate_chat_response(
q: str, q: str,
meta_log: dict, meta_log: dict,
conversation: Conversation, conversation: Conversation,
@@ -1444,13 +1433,14 @@ def generate_chat_response(
generated_asset_results: Dict[str, Dict] = {}, generated_asset_results: Dict[str, Dict] = {},
is_subscribed: bool = False, is_subscribed: bool = False,
tracer: dict = {}, tracer: dict = {},
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: ) -> Tuple[AsyncGenerator[str, None], Dict[str, str]]:
# Initialize Variables # Initialize Variables
chat_response = None chat_response_generator = None
logger.debug(f"Conversation Types: {conversation_commands}") logger.debug(f"Conversation Types: {conversation_commands}")
metadata = {} metadata = {}
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None agent = await AgentAdapters.aget_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
try: try:
partial_completion = partial( partial_completion = partial(
save_to_conversation_log, save_to_conversation_log,
@@ -1481,23 +1471,25 @@ def generate_chat_response(
code_results = {} code_results = {}
deepthought = True deepthought = True
chat_model = ConversationAdapters.get_valid_chat_model(user, conversation, is_subscribed) chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed)
vision_available = chat_model.vision_enabled vision_available = chat_model.vision_enabled
if not vision_available and query_images: if not vision_available and query_images:
vision_enabled_config = ConversationAdapters.get_vision_enabled_config() vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
if vision_enabled_config: if vision_enabled_config:
chat_model = vision_enabled_config chat_model = vision_enabled_config
vision_available = True vision_available = True
if chat_model.model_type == "offline": if chat_model.model_type == "offline":
# Assuming converse_offline remains sync or is refactored separately
loaded_model = state.offline_chat_processor_config.loaded_model loaded_model = state.offline_chat_processor_config.loaded_model
chat_response = converse_offline( # If converse_offline returns an iterator, wrap it if needed, or refactor it to async generator
chat_response_generator = converse_offline( # Needs adaptation if it becomes async
user_query=query_to_run, user_query=query_to_run,
references=compiled_references, references=compiled_references,
online_results=online_results, online_results=online_results,
loaded_model=loaded_model, loaded_model=loaded_model,
conversation_log=meta_log, conversation_log=meta_log,
completion_func=partial_completion, completion_func=partial_completion, # Pass the async wrapper
conversation_commands=conversation_commands, conversation_commands=conversation_commands,
model_name=chat_model.name, model_name=chat_model.name,
max_prompt_size=chat_model.max_prompt_size, max_prompt_size=chat_model.max_prompt_size,
@@ -1515,7 +1507,7 @@ def generate_chat_response(
openai_chat_config = chat_model.ai_model_api openai_chat_config = chat_model.ai_model_api
api_key = openai_chat_config.api_key api_key = openai_chat_config.api_key
chat_model_name = chat_model.name chat_model_name = chat_model.name
chat_response = converse_openai( chat_response_generator = converse_openai(
compiled_references, compiled_references,
query_to_run, query_to_run,
query_images=query_images, query_images=query_images,
@@ -1542,9 +1534,10 @@ def generate_chat_response(
) )
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
# Assuming converse_anthropic remains sync or is refactored separately
api_key = chat_model.ai_model_api.api_key api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url api_base_url = chat_model.ai_model_api.api_base_url
chat_response = converse_anthropic( chat_response_generator = converse_anthropic( # Needs adaptation if it becomes async
compiled_references, compiled_references,
query_to_run, query_to_run,
query_images=query_images, query_images=query_images,
@@ -1570,9 +1563,10 @@ def generate_chat_response(
tracer=tracer, tracer=tracer,
) )
elif chat_model.model_type == ChatModel.ModelType.GOOGLE: elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
# Assuming converse_gemini remains sync or is refactored separately
api_key = chat_model.ai_model_api.api_key api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url api_base_url = chat_model.ai_model_api.api_base_url
chat_response = converse_gemini( chat_response_generator = converse_gemini( # Needs adaptation if it becomes async
compiled_references, compiled_references,
query_to_run, query_to_run,
online_results, online_results,
@@ -1604,7 +1598,8 @@ def generate_chat_response(
logger.error(e, exc_info=True) logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
return chat_response, metadata # Return the generator directly
return chat_response_generator, metadata
class DeleteMessageRequestBody(BaseModel): class DeleteMessageRequestBody(BaseModel):