mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-08 05:39:13 +00:00
Refactor Openai chat response to stream async, no separate thread
- Refactor chat API to use async/await for Openai streaming - Fix and clean Openai chat response async streaming
This commit is contained in:
@@ -763,9 +763,9 @@ class AgentAdapters:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_conversation_agent_by_id(agent_id: int):
|
async def aget_conversation_agent_by_id(agent_id: int):
|
||||||
agent = Agent.objects.filter(id=agent_id).first()
|
agent = await Agent.objects.filter(id=agent_id).afirst()
|
||||||
if agent == AgentAdapters.get_default_agent():
|
if agent == await AgentAdapters.aget_default_agent():
|
||||||
# If the agent is set to the default agent, then return None and let the default application code be used
|
# If the agent is set to the default agent, then return None and let the default application code be used
|
||||||
return None
|
return None
|
||||||
return agent
|
return agent
|
||||||
@@ -1109,14 +1109,6 @@ class ConversationAdapters:
|
|||||||
async def aget_all_chat_models():
|
async def aget_all_chat_models():
|
||||||
return await sync_to_async(list)(ChatModel.objects.prefetch_related("ai_model_api").all())
|
return await sync_to_async(list)(ChatModel.objects.prefetch_related("ai_model_api").all())
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_vision_enabled_config():
|
|
||||||
chat_models = ConversationAdapters.get_all_chat_models()
|
|
||||||
for config in chat_models:
|
|
||||||
if config.vision_enabled:
|
|
||||||
return config
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_vision_enabled_config():
|
async def aget_vision_enabled_config():
|
||||||
chat_models = await ConversationAdapters.aget_all_chat_models()
|
chat_models = await ConversationAdapters.aget_all_chat_models()
|
||||||
@@ -1171,7 +1163,11 @@ class ConversationAdapters:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_chat_model(user: KhojUser):
|
async def aget_chat_model(user: KhojUser):
|
||||||
subscribed = await ais_user_subscribed(user)
|
subscribed = await ais_user_subscribed(user)
|
||||||
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
config = (
|
||||||
|
await UserConversationConfig.objects.filter(user=user)
|
||||||
|
.prefetch_related("setting", "setting__ai_model_api")
|
||||||
|
.afirst()
|
||||||
|
)
|
||||||
if subscribed:
|
if subscribed:
|
||||||
# Subscibed users can use any available chat model
|
# Subscibed users can use any available chat model
|
||||||
if config:
|
if config:
|
||||||
@@ -1387,7 +1383,7 @@ class ConversationAdapters:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@require_valid_user
|
@require_valid_user
|
||||||
def save_conversation(
|
async def save_conversation(
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
conversation_log: dict,
|
conversation_log: dict,
|
||||||
client_application: ClientApplication = None,
|
client_application: ClientApplication = None,
|
||||||
@@ -1396,19 +1392,21 @@ class ConversationAdapters:
|
|||||||
):
|
):
|
||||||
slug = user_message.strip()[:200] if user_message else None
|
slug = user_message.strip()[:200] if user_message else None
|
||||||
if conversation_id:
|
if conversation_id:
|
||||||
conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first()
|
conversation = await Conversation.objects.filter(
|
||||||
|
user=user, client=client_application, id=conversation_id
|
||||||
|
).afirst()
|
||||||
else:
|
else:
|
||||||
conversation = (
|
conversation = (
|
||||||
Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").first()
|
await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
|
||||||
)
|
)
|
||||||
|
|
||||||
if conversation:
|
if conversation:
|
||||||
conversation.conversation_log = conversation_log
|
conversation.conversation_log = conversation_log
|
||||||
conversation.slug = slug
|
conversation.slug = slug
|
||||||
conversation.updated_at = datetime.now(tz=timezone.utc)
|
conversation.updated_at = datetime.now(tz=timezone.utc)
|
||||||
conversation.save()
|
await conversation.asave()
|
||||||
else:
|
else:
|
||||||
Conversation.objects.create(
|
await Conversation.objects.acreate(
|
||||||
user=user, conversation_log=conversation_log, client=client_application, slug=slug
|
user=user, conversation_log=conversation_log, client=client_application, slug=slug
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1455,17 +1453,21 @@ class ConversationAdapters:
|
|||||||
return random.sample(all_questions, max_results)
|
return random.sample(all_questions, max_results)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool):
|
async def aget_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool):
|
||||||
agent: Agent = (
|
agent: Agent = (
|
||||||
conversation.agent if is_subscribed and AgentAdapters.get_default_agent() != conversation.agent else None
|
conversation.agent
|
||||||
|
if is_subscribed and await AgentAdapters.aget_default_agent() != conversation.agent
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
if agent and agent.chat_model:
|
if agent and agent.chat_model:
|
||||||
chat_model = conversation.agent.chat_model
|
chat_model = await ChatModel.objects.select_related("ai_model_api").aget(
|
||||||
|
pk=conversation.agent.chat_model.pk
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
chat_model = ConversationAdapters.get_chat_model(user)
|
chat_model = await ConversationAdapters.aget_chat_model(user)
|
||||||
|
|
||||||
if chat_model is None:
|
if chat_model is None:
|
||||||
chat_model = ConversationAdapters.get_default_chat_model()
|
chat_model = await ConversationAdapters.aget_default_chat_model()
|
||||||
|
|
||||||
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
|
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
|
||||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict, List, Optional
|
from typing import AsyncGenerator, Dict, List, Optional
|
||||||
|
|
||||||
import pyjson5
|
import pyjson5
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
@@ -162,7 +162,7 @@ def send_message_to_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def converse_openai(
|
async def converse_openai(
|
||||||
references,
|
references,
|
||||||
user_query,
|
user_query,
|
||||||
online_results: Optional[Dict[str, Dict]] = None,
|
online_results: Optional[Dict[str, Dict]] = None,
|
||||||
@@ -187,7 +187,7 @@ def converse_openai(
|
|||||||
program_execution_context: List[str] = None,
|
program_execution_context: List[str] = None,
|
||||||
deepthought: Optional[bool] = False,
|
deepthought: Optional[bool] = False,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
) -> AsyncGenerator[str, None]:
|
||||||
"""
|
"""
|
||||||
Converse with user using OpenAI's ChatGPT
|
Converse with user using OpenAI's ChatGPT
|
||||||
"""
|
"""
|
||||||
@@ -217,11 +217,17 @@ def converse_openai(
|
|||||||
|
|
||||||
# Get Conversation Primer appropriate to Conversation Type
|
# Get Conversation Primer appropriate to Conversation Type
|
||||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||||
completion_func(chat_response=prompts.no_notes_found.format())
|
response = prompts.no_notes_found.format()
|
||||||
return iter([prompts.no_notes_found.format()])
|
if completion_func:
|
||||||
|
await completion_func(chat_response=response)
|
||||||
|
yield response
|
||||||
|
return
|
||||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
response = prompts.no_online_results_found.format()
|
||||||
return iter([prompts.no_online_results_found.format()])
|
if completion_func:
|
||||||
|
await completion_func(chat_response=response)
|
||||||
|
yield response
|
||||||
|
return
|
||||||
|
|
||||||
context_message = ""
|
context_message = ""
|
||||||
if not is_none_or_empty(references):
|
if not is_none_or_empty(references):
|
||||||
@@ -255,19 +261,23 @@ def converse_openai(
|
|||||||
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
|
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
|
||||||
|
|
||||||
# Get Response from GPT
|
# Get Response from GPT
|
||||||
return chat_completion_with_backoff(
|
full_response = ""
|
||||||
|
async for chunk in chat_completion_with_backoff(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
compiled_references=references,
|
|
||||||
online_results=online_results,
|
|
||||||
model_name=model,
|
model_name=model,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
openai_api_key=api_key,
|
openai_api_key=api_key,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
completion_func=completion_func,
|
|
||||||
deepthought=deepthought,
|
deepthought=deepthought,
|
||||||
model_kwargs={"stop": ["Notes:\n["]},
|
model_kwargs={"stop": ["Notes:\n["]},
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
):
|
||||||
|
full_response += chunk
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
# Call completion_func once finish streaming and we have the full response
|
||||||
|
if completion_func:
|
||||||
|
await completion_func(chat_response=full_response)
|
||||||
|
|
||||||
|
|
||||||
def clean_response_schema(schema: BaseModel | dict) -> dict:
|
def clean_response_schema(schema: BaseModel | dict) -> dict:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from threading import Thread
|
from time import perf_counter
|
||||||
from typing import Dict, List
|
from typing import AsyncGenerator, Dict, List
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
@@ -16,13 +16,10 @@ from tenacity import (
|
|||||||
wait_random_exponential,
|
wait_random_exponential,
|
||||||
)
|
)
|
||||||
|
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import JsonSupport, commit_conversation_trace
|
||||||
JsonSupport,
|
|
||||||
ThreadedGenerator,
|
|
||||||
commit_conversation_trace,
|
|
||||||
)
|
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
get_chat_usage_metrics,
|
get_chat_usage_metrics,
|
||||||
|
get_openai_async_client,
|
||||||
get_openai_client,
|
get_openai_client,
|
||||||
is_promptrace_enabled,
|
is_promptrace_enabled,
|
||||||
)
|
)
|
||||||
@@ -30,6 +27,7 @@ from khoj.utils.helpers import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
openai_clients: Dict[str, openai.OpenAI] = {}
|
openai_clients: Dict[str, openai.OpenAI] = {}
|
||||||
|
openai_async_clients: Dict[str, openai.AsyncOpenAI] = {}
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
@@ -124,45 +122,22 @@ def completion_with_backoff(
|
|||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def chat_completion_with_backoff(
|
async def chat_completion_with_backoff(
|
||||||
messages,
|
messages,
|
||||||
compiled_references,
|
|
||||||
online_results,
|
|
||||||
model_name,
|
model_name,
|
||||||
temperature,
|
temperature,
|
||||||
openai_api_key=None,
|
openai_api_key=None,
|
||||||
api_base_url=None,
|
api_base_url=None,
|
||||||
completion_func=None,
|
|
||||||
deepthought=False,
|
|
||||||
model_kwargs=None,
|
|
||||||
tracer: dict = {},
|
|
||||||
):
|
|
||||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
|
||||||
t = Thread(
|
|
||||||
target=llm_thread,
|
|
||||||
args=(g, messages, model_name, temperature, openai_api_key, api_base_url, deepthought, model_kwargs, tracer),
|
|
||||||
)
|
|
||||||
t.start()
|
|
||||||
return g
|
|
||||||
|
|
||||||
|
|
||||||
def llm_thread(
|
|
||||||
g,
|
|
||||||
messages,
|
|
||||||
model_name: str,
|
|
||||||
temperature,
|
|
||||||
openai_api_key=None,
|
|
||||||
api_base_url=None,
|
|
||||||
deepthought=False,
|
deepthought=False,
|
||||||
model_kwargs: dict = {},
|
model_kwargs: dict = {},
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
) -> AsyncGenerator[str, None]:
|
||||||
try:
|
try:
|
||||||
client_key = f"{openai_api_key}--{api_base_url}"
|
client_key = f"{openai_api_key}--{api_base_url}"
|
||||||
client = openai_clients.get(client_key)
|
client = openai_async_clients.get(client_key)
|
||||||
if not client:
|
if not client:
|
||||||
client = get_openai_client(openai_api_key, api_base_url)
|
client = get_openai_async_client(openai_api_key, api_base_url)
|
||||||
openai_clients[client_key] = client
|
openai_async_clients[client_key] = client
|
||||||
|
|
||||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||||
|
|
||||||
@@ -207,53 +182,58 @@ def llm_thread(
|
|||||||
if os.getenv("KHOJ_LLM_SEED"):
|
if os.getenv("KHOJ_LLM_SEED"):
|
||||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||||
|
|
||||||
chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
|
aggregated_response = ""
|
||||||
messages=formatted_messages,
|
final_chunk = None
|
||||||
model=model_name, # type: ignore
|
start_time = perf_counter()
|
||||||
|
chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
|
||||||
|
messages=formatted_messages, # type: ignore
|
||||||
|
model=model_name,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
timeout=20,
|
timeout=20,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
async for chunk in chat_stream:
|
||||||
aggregated_response = ""
|
# Log the time taken to start response
|
||||||
if not stream:
|
if final_chunk is None:
|
||||||
chunk = chat
|
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||||
aggregated_response = chunk.choices[0].message.content
|
# Keep track of the last chunk for usage data
|
||||||
g.send(aggregated_response)
|
final_chunk = chunk
|
||||||
else:
|
# Handle streamed response chunk
|
||||||
for chunk in chat:
|
|
||||||
if len(chunk.choices) == 0:
|
if len(chunk.choices) == 0:
|
||||||
continue
|
continue
|
||||||
delta_chunk = chunk.choices[0].delta
|
delta_chunk = chunk.choices[0].delta
|
||||||
text_chunk = ""
|
text_chunk = ""
|
||||||
if isinstance(delta_chunk, str):
|
if isinstance(delta_chunk, str):
|
||||||
text_chunk = delta_chunk
|
text_chunk = delta_chunk
|
||||||
elif delta_chunk.content:
|
elif delta_chunk and delta_chunk.content:
|
||||||
text_chunk = delta_chunk.content
|
text_chunk = delta_chunk.content
|
||||||
if text_chunk:
|
if text_chunk:
|
||||||
aggregated_response += text_chunk
|
aggregated_response += text_chunk
|
||||||
g.send(text_chunk)
|
yield text_chunk
|
||||||
|
|
||||||
# Calculate cost of chat
|
# Log the time taken to stream the entire response
|
||||||
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||||
output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
|
||||||
cost = (
|
# Calculate cost of chat after stream finishes
|
||||||
chunk.usage.model_extra.get("estimated_cost", 0) if hasattr(chunk, "usage") and chunk.usage else 0
|
input_tokens, output_tokens, cost = 0, 0, 0
|
||||||
) # Estimated costs returned by DeepInfra API
|
if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage:
|
||||||
tracer["usage"] = get_chat_usage_metrics(
|
input_tokens = final_chunk.usage.prompt_tokens
|
||||||
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
|
output_tokens = final_chunk.usage.completion_tokens
|
||||||
)
|
# Estimated costs returned by DeepInfra API
|
||||||
|
if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra:
|
||||||
|
cost = final_chunk.usage.model_extra.get("estimated_cost", 0)
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
tracer["chat_model"] = model_name
|
tracer["chat_model"] = model_name
|
||||||
tracer["temperature"] = temperature
|
tracer["temperature"] = temperature
|
||||||
|
tracer["usage"] = get_chat_usage_metrics(
|
||||||
|
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
|
||||||
|
)
|
||||||
if is_promptrace_enabled():
|
if is_promptrace_enabled():
|
||||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in llm_thread: {e}", exc_info=True)
|
logger.error(f"Error in chat_completion_with_backoff stream: {e}", exc_info=True)
|
||||||
finally:
|
|
||||||
g.close()
|
|
||||||
|
|
||||||
|
|
||||||
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
|
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
|
||||||
|
|||||||
@@ -254,7 +254,7 @@ def message_to_log(
|
|||||||
return conversation_log
|
return conversation_log
|
||||||
|
|
||||||
|
|
||||||
def save_to_conversation_log(
|
async def save_to_conversation_log(
|
||||||
q: str,
|
q: str,
|
||||||
chat_response: str,
|
chat_response: str,
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
@@ -306,7 +306,7 @@ def save_to_conversation_log(
|
|||||||
khoj_message_metadata=khoj_message_metadata,
|
khoj_message_metadata=khoj_message_metadata,
|
||||||
conversation_log=meta_log.get("chat", []),
|
conversation_log=meta_log.get("chat", []),
|
||||||
)
|
)
|
||||||
ConversationAdapters.save_conversation(
|
await ConversationAdapters.save_conversation(
|
||||||
user,
|
user,
|
||||||
{"chat": updated_conversation},
|
{"chat": updated_conversation},
|
||||||
client_application=client_application,
|
client_application=client_application,
|
||||||
|
|||||||
@@ -67,7 +67,6 @@ from khoj.routers.research import (
|
|||||||
from khoj.routers.storage import upload_user_image_to_bucket
|
from khoj.routers.storage import upload_user_image_to_bucket
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
AsyncIteratorWrapper,
|
|
||||||
ConversationCommand,
|
ConversationCommand,
|
||||||
command_descriptions,
|
command_descriptions,
|
||||||
convert_image_to_webp,
|
convert_image_to_webp,
|
||||||
@@ -999,7 +998,7 @@ async def chat(
|
|||||||
return
|
return
|
||||||
|
|
||||||
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
|
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
|
||||||
await sync_to_async(save_to_conversation_log)(
|
await save_to_conversation_log(
|
||||||
q,
|
q,
|
||||||
llm_response,
|
llm_response,
|
||||||
user,
|
user,
|
||||||
@@ -1308,26 +1307,31 @@ async def chat(
|
|||||||
yield result
|
yield result
|
||||||
|
|
||||||
continue_stream = True
|
continue_stream = True
|
||||||
iterator = AsyncIteratorWrapper(llm_response)
|
async for item in llm_response:
|
||||||
async for item in iterator:
|
# Should not happen with async generator, end is signaled by loop exit. Skip.
|
||||||
if item is None:
|
if item is None:
|
||||||
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
continue
|
||||||
yield result
|
|
||||||
# Send Usage Metadata once llm interactions are complete
|
|
||||||
async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
|
|
||||||
yield event
|
|
||||||
async for result in send_event(ChatEvent.END_RESPONSE, ""):
|
|
||||||
yield result
|
|
||||||
logger.debug("Finished streaming response")
|
|
||||||
return
|
|
||||||
if not connection_alive or not continue_stream:
|
if not connection_alive or not continue_stream:
|
||||||
|
# Drain the generator if disconnected but keep processing internally
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
|
async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
|
||||||
yield result
|
yield result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
continue_stream = False
|
continue_stream = False
|
||||||
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
|
logger.info(f"User {user} disconnected or error during streaming. Stopping send: {e}")
|
||||||
|
|
||||||
|
# Signal end of LLM response after the loop finishes
|
||||||
|
if connection_alive:
|
||||||
|
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
||||||
|
yield result
|
||||||
|
# Send Usage Metadata once llm interactions are complete
|
||||||
|
if tracer.get("usage"):
|
||||||
|
async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
|
||||||
|
yield event
|
||||||
|
async for result in send_event(ChatEvent.END_RESPONSE, ""):
|
||||||
|
yield result
|
||||||
|
logger.debug("Finished streaming response")
|
||||||
|
|
||||||
## Stream Text Response
|
## Stream Text Response
|
||||||
if stream:
|
if stream:
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
import base64
|
import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
@@ -6,9 +5,7 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from enum import Enum
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from random import random
|
from random import random
|
||||||
from typing import (
|
from typing import (
|
||||||
@@ -17,7 +14,6 @@ from typing import (
|
|||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Iterator,
|
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
@@ -126,8 +122,6 @@ from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, Loca
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
executor = ThreadPoolExecutor(max_workers=1)
|
|
||||||
|
|
||||||
|
|
||||||
NOTION_OAUTH_CLIENT_ID = os.getenv("NOTION_OAUTH_CLIENT_ID")
|
NOTION_OAUTH_CLIENT_ID = os.getenv("NOTION_OAUTH_CLIENT_ID")
|
||||||
NOTION_OAUTH_CLIENT_SECRET = os.getenv("NOTION_OAUTH_CLIENT_SECRET")
|
NOTION_OAUTH_CLIENT_SECRET = os.getenv("NOTION_OAUTH_CLIENT_SECRET")
|
||||||
@@ -262,11 +256,6 @@ def get_conversation_command(query: str) -> ConversationCommand:
|
|||||||
return ConversationCommand.Default
|
return ConversationCommand.Default
|
||||||
|
|
||||||
|
|
||||||
async def agenerate_chat_response(*args):
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
return await loop.run_in_executor(executor, generate_chat_response, *args)
|
|
||||||
|
|
||||||
|
|
||||||
def gather_raw_query_files(
|
def gather_raw_query_files(
|
||||||
query_files: Dict[str, str],
|
query_files: Dict[str, str],
|
||||||
):
|
):
|
||||||
@@ -1418,7 +1407,7 @@ def send_message_to_model_wrapper_sync(
|
|||||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||||
|
|
||||||
|
|
||||||
def generate_chat_response(
|
async def agenerate_chat_response(
|
||||||
q: str,
|
q: str,
|
||||||
meta_log: dict,
|
meta_log: dict,
|
||||||
conversation: Conversation,
|
conversation: Conversation,
|
||||||
@@ -1444,13 +1433,14 @@ def generate_chat_response(
|
|||||||
generated_asset_results: Dict[str, Dict] = {},
|
generated_asset_results: Dict[str, Dict] = {},
|
||||||
is_subscribed: bool = False,
|
is_subscribed: bool = False,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
) -> Tuple[AsyncGenerator[str, None], Dict[str, str]]:
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
chat_response = None
|
chat_response_generator = None
|
||||||
logger.debug(f"Conversation Types: {conversation_commands}")
|
logger.debug(f"Conversation Types: {conversation_commands}")
|
||||||
|
|
||||||
metadata = {}
|
metadata = {}
|
||||||
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
|
agent = await AgentAdapters.aget_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
partial_completion = partial(
|
partial_completion = partial(
|
||||||
save_to_conversation_log,
|
save_to_conversation_log,
|
||||||
@@ -1481,23 +1471,25 @@ def generate_chat_response(
|
|||||||
code_results = {}
|
code_results = {}
|
||||||
deepthought = True
|
deepthought = True
|
||||||
|
|
||||||
chat_model = ConversationAdapters.get_valid_chat_model(user, conversation, is_subscribed)
|
chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed)
|
||||||
vision_available = chat_model.vision_enabled
|
vision_available = chat_model.vision_enabled
|
||||||
if not vision_available and query_images:
|
if not vision_available and query_images:
|
||||||
vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
|
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
|
||||||
if vision_enabled_config:
|
if vision_enabled_config:
|
||||||
chat_model = vision_enabled_config
|
chat_model = vision_enabled_config
|
||||||
vision_available = True
|
vision_available = True
|
||||||
|
|
||||||
if chat_model.model_type == "offline":
|
if chat_model.model_type == "offline":
|
||||||
|
# Assuming converse_offline remains sync or is refactored separately
|
||||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||||
chat_response = converse_offline(
|
# If converse_offline returns an iterator, wrap it if needed, or refactor it to async generator
|
||||||
|
chat_response_generator = converse_offline( # Needs adaptation if it becomes async
|
||||||
user_query=query_to_run,
|
user_query=query_to_run,
|
||||||
references=compiled_references,
|
references=compiled_references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
loaded_model=loaded_model,
|
loaded_model=loaded_model,
|
||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
completion_func=partial_completion,
|
completion_func=partial_completion, # Pass the async wrapper
|
||||||
conversation_commands=conversation_commands,
|
conversation_commands=conversation_commands,
|
||||||
model_name=chat_model.name,
|
model_name=chat_model.name,
|
||||||
max_prompt_size=chat_model.max_prompt_size,
|
max_prompt_size=chat_model.max_prompt_size,
|
||||||
@@ -1515,7 +1507,7 @@ def generate_chat_response(
|
|||||||
openai_chat_config = chat_model.ai_model_api
|
openai_chat_config = chat_model.ai_model_api
|
||||||
api_key = openai_chat_config.api_key
|
api_key = openai_chat_config.api_key
|
||||||
chat_model_name = chat_model.name
|
chat_model_name = chat_model.name
|
||||||
chat_response = converse_openai(
|
chat_response_generator = converse_openai(
|
||||||
compiled_references,
|
compiled_references,
|
||||||
query_to_run,
|
query_to_run,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
@@ -1542,9 +1534,10 @@ def generate_chat_response(
|
|||||||
)
|
)
|
||||||
|
|
||||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||||
|
# Assuming converse_anthropic remains sync or is refactored separately
|
||||||
api_key = chat_model.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
api_base_url = chat_model.ai_model_api.api_base_url
|
api_base_url = chat_model.ai_model_api.api_base_url
|
||||||
chat_response = converse_anthropic(
|
chat_response_generator = converse_anthropic( # Needs adaptation if it becomes async
|
||||||
compiled_references,
|
compiled_references,
|
||||||
query_to_run,
|
query_to_run,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
@@ -1570,9 +1563,10 @@ def generate_chat_response(
|
|||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
||||||
|
# Assuming converse_gemini remains sync or is refactored separately
|
||||||
api_key = chat_model.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
api_base_url = chat_model.ai_model_api.api_base_url
|
api_base_url = chat_model.ai_model_api.api_base_url
|
||||||
chat_response = converse_gemini(
|
chat_response_generator = converse_gemini( # Needs adaptation if it becomes async
|
||||||
compiled_references,
|
compiled_references,
|
||||||
query_to_run,
|
query_to_run,
|
||||||
online_results,
|
online_results,
|
||||||
@@ -1604,7 +1598,8 @@ def generate_chat_response(
|
|||||||
logger.error(e, exc_info=True)
|
logger.error(e, exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
return chat_response, metadata
|
# Return the generator directly
|
||||||
|
return chat_response_generator, metadata
|
||||||
|
|
||||||
|
|
||||||
class DeleteMessageRequestBody(BaseModel):
|
class DeleteMessageRequestBody(BaseModel):
|
||||||
|
|||||||
Reference in New Issue
Block a user