Simplify AI Chat Response Streaming (#1167)

Reason
---
- Simplify the code and logic for streaming chat responses by relying
solely on the asyncio event loop.
- Reduce the overhead of managing threads to improve efficiency and
throughput (where possible).

Details
---
- Use async/await, with no threading, when generating chat responses via
the OpenAI, Gemini and Anthropic model APIs (see the before/after sketch
below)
- Keep threading for the offline chat model, as llama-cpp doesn't
support async streaming yet
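
For orientation, here is a minimal, self-contained sketch of the change in streaming pattern (hypothetical names, not code from this diff): the old approach pushed chunks from a worker thread through a queue-backed generator (the ThreadedGenerator removed further down), while the new approach yields chunks from a native async generator running on the event loop.

```python
import asyncio
import queue
import threading
from typing import AsyncGenerator


# Old pattern (simplified): a worker thread pushes chunks into a queue and
# the caller iterates the queue synchronously until a sentinel arrives.
def threaded_stream(chunks: list[str]):
    q: queue.Queue = queue.Queue()

    def worker():
        for chunk in chunks:
            q.put(chunk)
        q.put(StopIteration)  # sentinel to end iteration

    threading.Thread(target=worker).start()
    while (item := q.get()) is not StopIteration:
        yield item


# New pattern (simplified): a native async generator yields chunks directly
# on the event loop; no thread or queue management is needed.
async def async_stream(chunks: list[str]) -> AsyncGenerator[str, None]:
    for chunk in chunks:
        await asyncio.sleep(0)  # stand-in for awaiting the model API
        yield chunk


async def main():
    async for chunk in async_stream(["Hello", ", ", "world"]):
        print(chunk, end="")


if __name__ == "__main__":
    asyncio.run(main())
```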
Authored by Debanjum on 2025-04-21 14:28:02 +05:30; committed via GitHub.
12 changed files with 357 additions and 361 deletions

View File

@@ -763,9 +763,9 @@ class AgentAdapters:
return False
@staticmethod
def get_conversation_agent_by_id(agent_id: int):
agent = Agent.objects.filter(id=agent_id).first()
if agent == AgentAdapters.get_default_agent():
async def aget_conversation_agent_by_id(agent_id: int):
agent = await Agent.objects.filter(id=agent_id).afirst()
if agent == await AgentAdapters.aget_default_agent():
# If the agent is set to the default agent, then return None and let the default application code be used
return None
return agent
@@ -1109,14 +1109,6 @@ class ConversationAdapters:
async def aget_all_chat_models():
return await sync_to_async(list)(ChatModel.objects.prefetch_related("ai_model_api").all())
@staticmethod
def get_vision_enabled_config():
chat_models = ConversationAdapters.get_all_chat_models()
for config in chat_models:
if config.vision_enabled:
return config
return None
@staticmethod
async def aget_vision_enabled_config():
chat_models = await ConversationAdapters.aget_all_chat_models()
@@ -1171,7 +1163,11 @@ class ConversationAdapters:
@staticmethod
async def aget_chat_model(user: KhojUser):
subscribed = await ais_user_subscribed(user)
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
config = (
await UserConversationConfig.objects.filter(user=user)
.prefetch_related("setting", "setting__ai_model_api")
.afirst()
)
if subscribed:
# Subscribed users can use any available chat model
if config:
@@ -1387,7 +1383,7 @@ class ConversationAdapters:
@staticmethod
@require_valid_user
def save_conversation(
async def save_conversation(
user: KhojUser,
conversation_log: dict,
client_application: ClientApplication = None,
@@ -1396,19 +1392,21 @@ class ConversationAdapters:
):
slug = user_message.strip()[:200] if user_message else None
if conversation_id:
conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first()
conversation = await Conversation.objects.filter(
user=user, client=client_application, id=conversation_id
).afirst()
else:
conversation = (
Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").first()
await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
)
if conversation:
conversation.conversation_log = conversation_log
conversation.slug = slug
conversation.updated_at = datetime.now(tz=timezone.utc)
conversation.save()
await conversation.asave()
else:
Conversation.objects.create(
await Conversation.objects.acreate(
user=user, conversation_log=conversation_log, client=client_application, slug=slug
)
@@ -1455,17 +1453,21 @@ class ConversationAdapters:
return random.sample(all_questions, max_results)
@staticmethod
def get_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool):
async def aget_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool):
agent: Agent = (
conversation.agent if is_subscribed and AgentAdapters.get_default_agent() != conversation.agent else None
conversation.agent
if is_subscribed and await AgentAdapters.aget_default_agent() != conversation.agent
else None
)
if agent and agent.chat_model:
chat_model = conversation.agent.chat_model
chat_model = await ChatModel.objects.select_related("ai_model_api").aget(
pk=conversation.agent.chat_model.pk
)
else:
chat_model = ConversationAdapters.get_chat_model(user)
chat_model = await ConversationAdapters.aget_chat_model(user)
if chat_model is None:
chat_model = ConversationAdapters.get_default_chat_model()
chat_model = await ConversationAdapters.aget_default_chat_model()
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
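
The adapter changes above lean on Django's native async ORM entry points (afirst, aget, acreate, asave, available in Django 4.2+) and prefetch related rows up front. A minimal sketch of the idiom, using a hypothetical Conversation model as a stand-in:

```python
# Sketch of the Django async ORM idiom used above (Django 4.2+),
# with a hypothetical Conversation model standing in for the real one.
from asgiref.sync import sync_to_async
from django.utils import timezone


async def latest_conversation(user):
    # Select related rows up front: lazy relation access from async code
    # would raise SynchronousOnlyOperation.
    conversation = (
        await Conversation.objects.filter(user=user)
        .select_related("agent")
        .order_by("-updated_at")
        .afirst()
    )
    if conversation is None:
        return await Conversation.objects.acreate(user=user)
    conversation.updated_at = timezone.now()
    await conversation.asave()
    return conversation


# Queryset evaluation without a native async method can be wrapped instead:
async def all_conversations(user):
    return await sync_to_async(list)(Conversation.objects.filter(user=user))
```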

View File

@@ -1,6 +1,6 @@
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional
from typing import AsyncGenerator, Dict, List, Optional
import pyjson5
from langchain.schema import ChatMessage
@@ -137,7 +137,7 @@ def anthropic_send_message_to_model(
)
def converse_anthropic(
async def converse_anthropic(
references,
user_query,
online_results: Optional[Dict[str, Dict]] = None,
@@ -161,7 +161,7 @@ def converse_anthropic(
generated_asset_results: Dict[str, Dict] = {},
deepthought: Optional[bool] = False,
tracer: dict = {},
):
) -> AsyncGenerator[str, None]:
"""
Converse with user using Anthropic's Claude
"""
@@ -191,11 +191,17 @@ def converse_anthropic(
# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()])
response = prompts.no_notes_found.format()
if completion_func:
await completion_func(chat_response=response)
yield response
return
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])
response = prompts.no_online_results_found.format()
if completion_func:
await completion_func(chat_response=response)
yield response
return
context_message = ""
if not is_none_or_empty(references):
@@ -228,17 +234,21 @@ def converse_anthropic(
logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}")
# Get Response from Claude
return anthropic_chat_completion_with_backoff(
full_response = ""
async for chunk in anthropic_chat_completion_with_backoff(
messages=messages,
compiled_references=references,
online_results=online_results,
model_name=model,
temperature=0.2,
api_key=api_key,
api_base_url=api_base_url,
system_prompt=system_prompt,
completion_func=completion_func,
max_prompt_size=max_prompt_size,
deepthought=deepthought,
tracer=tracer,
)
):
full_response += chunk
yield chunk
# Call completion_func once streaming has finished and we have the full response
if completion_func:
await completion_func(chat_response=full_response)

View File

@@ -1,5 +1,5 @@
import logging
from threading import Thread
from time import perf_counter
from typing import Dict, List
import anthropic
@@ -13,13 +13,13 @@ from tenacity import (
)
from khoj.processor.conversation.utils import (
ThreadedGenerator,
commit_conversation_trace,
get_image_from_base64,
get_image_from_url,
)
from khoj.utils.helpers import (
get_ai_api_info,
get_anthropic_async_client,
get_anthropic_client,
get_chat_usage_metrics,
is_none_or_empty,
is_promptrace_enabled,
@@ -28,24 +28,12 @@ from khoj.utils.helpers import (
logger = logging.getLogger(__name__)
anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {}
anthropic_async_clients: Dict[str, anthropic.AsyncAnthropic | anthropic.AsyncAnthropicVertex] = {}
DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
MAX_REASONING_TOKENS_ANTHROPIC = 12000
def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
api_info = get_ai_api_info(api_key, api_base_url)
if api_info.api_key:
client = anthropic.Anthropic(api_key=api_info.api_key)
else:
client = anthropic.AnthropicVertex(
region=api_info.region,
project_id=api_info.project,
credentials=api_info.credentials,
)
return client
@retry(
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(2),
@@ -126,60 +114,23 @@ def anthropic_completion_with_backoff(
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def anthropic_chat_completion_with_backoff(
async def anthropic_chat_completion_with_backoff(
messages: list[ChatMessage],
compiled_references,
online_results,
model_name,
temperature,
api_key,
api_base_url,
system_prompt: str,
max_prompt_size=None,
completion_func=None,
deepthought=False,
model_kwargs=None,
tracer={},
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=anthropic_llm_thread,
args=(
g,
messages,
system_prompt,
model_name,
temperature,
api_key,
api_base_url,
max_prompt_size,
deepthought,
model_kwargs,
tracer,
),
)
t.start()
return g
def anthropic_llm_thread(
g,
messages: list[ChatMessage],
system_prompt: str,
model_name: str,
temperature,
api_key,
api_base_url=None,
max_prompt_size=None,
deepthought=False,
model_kwargs=None,
tracer={},
):
try:
client = anthropic_clients.get(api_key)
client = anthropic_async_clients.get(api_key)
if not client:
client = get_anthropic_client(api_key, api_base_url)
anthropic_clients[api_key] = client
client = get_anthropic_async_client(api_key, api_base_url)
anthropic_async_clients[api_key] = client
model_kwargs = model_kwargs or dict()
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
@@ -193,7 +144,8 @@ def anthropic_llm_thread(
aggregated_response = ""
final_message = None
with client.messages.stream(
start_time = perf_counter()
async with client.messages.stream(
messages=formatted_messages,
model=model_name, # type: ignore
temperature=temperature,
@@ -202,10 +154,17 @@ def anthropic_llm_thread(
max_tokens=max_tokens,
**model_kwargs,
) as stream:
for text in stream.text_stream:
async for text in stream.text_stream:
# Log the time taken to start response
if aggregated_response == "":
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Handle streamed response chunk
aggregated_response += text
g.send(text)
final_message = stream.get_final_message()
yield text
final_message = await stream.get_final_message()
# Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
# Calculate cost of chat
input_tokens = final_message.usage.input_tokens
@@ -222,9 +181,7 @@ def anthropic_llm_thread(
if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e:
logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
finally:
g.close()
logger.error(f"Error in anthropic_chat_completion_with_backoff stream: {e}", exc_info=True)
def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None):
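
As a standalone illustration of the async Anthropic streaming idiom adopted above (the model name and prompt are placeholder assumptions, not values from this diff):

```python
import anthropic


async def stream_claude(prompt: str, api_key: str) -> str:
    # AsyncAnthropic exposes messages.stream() as an async context manager;
    # text_stream yields text deltas and get_final_message() carries usage data.
    client = anthropic.AsyncAnthropic(api_key=api_key)
    aggregated = ""
    async with client.messages.stream(
        model="claude-3-5-sonnet-latest",  # placeholder model name
        max_tokens=1024,
        messages=[{"role": "user", "content": prompt}],
    ) as stream:
        async for text in stream.text_stream:
            aggregated += text
        final_message = await stream.get_final_message()
    print(final_message.usage.input_tokens, final_message.usage.output_tokens)
    return aggregated
```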

View File

@@ -1,6 +1,6 @@
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional
from typing import AsyncGenerator, Dict, List, Optional
import pyjson5
from langchain.schema import ChatMessage
@@ -160,7 +160,7 @@ def gemini_send_message_to_model(
)
def converse_gemini(
async def converse_gemini(
references,
user_query,
online_results: Optional[Dict[str, Dict]] = None,
@@ -185,7 +185,7 @@ def converse_gemini(
program_execution_context: List[str] = None,
deepthought: Optional[bool] = False,
tracer={},
):
) -> AsyncGenerator[str, None]:
"""
Converse with user using Google's Gemini
"""
@@ -216,11 +216,17 @@ def converse_gemini(
# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()])
response = prompts.no_notes_found.format()
if completion_func:
await completion_func(chat_response=response)
yield response
return
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])
response = prompts.no_online_results_found.format()
if completion_func:
await completion_func(chat_response=response)
yield response
return
context_message = ""
if not is_none_or_empty(references):
@@ -253,16 +259,20 @@ def converse_gemini(
logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
# Get Response from Google AI
return gemini_chat_completion_with_backoff(
full_response = ""
async for chunk in gemini_chat_completion_with_backoff(
messages=messages,
compiled_references=references,
online_results=online_results,
model_name=model,
temperature=temperature,
api_key=api_key,
api_base_url=api_base_url,
system_prompt=system_prompt,
completion_func=completion_func,
deepthought=deepthought,
tracer=tracer,
)
):
full_response += chunk
yield chunk
# Call completion_func once streaming has finished and we have the full response
if completion_func:
await completion_func(chat_response=full_response)

View File

@@ -2,8 +2,8 @@ import logging
import os
import random
from copy import deepcopy
from threading import Thread
from typing import Dict
from time import perf_counter
from typing import AsyncGenerator, AsyncIterator, Dict
from google import genai
from google.genai import errors as gerrors
@@ -19,14 +19,13 @@ from tenacity import (
)
from khoj.processor.conversation.utils import (
ThreadedGenerator,
commit_conversation_trace,
get_image_from_base64,
get_image_from_url,
)
from khoj.utils.helpers import (
get_ai_api_info,
get_chat_usage_metrics,
get_gemini_client,
is_none_or_empty,
is_promptrace_enabled,
)
@@ -62,17 +61,6 @@ SAFETY_SETTINGS = [
]
def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
api_info = get_ai_api_info(api_key, api_base_url)
return genai.Client(
location=api_info.region,
project=api_info.project,
credentials=api_info.credentials,
api_key=api_info.api_key,
vertexai=api_info.api_key is None,
)
@retry(
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(2),
@@ -132,8 +120,8 @@ def gemini_completion_with_backoff(
)
# Aggregate cost of chat
input_tokens = response.usage_metadata.prompt_token_count if response else 0
output_tokens = response.usage_metadata.candidates_token_count if response else 0
input_tokens = response.usage_metadata.prompt_token_count or 0 if response else 0
output_tokens = response.usage_metadata.candidates_token_count or 0 if response else 0
thought_tokens = response.usage_metadata.thoughts_token_count or 0 if response else 0
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
@@ -154,52 +142,17 @@ def gemini_completion_with_backoff(
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def gemini_chat_completion_with_backoff(
async def gemini_chat_completion_with_backoff(
messages,
compiled_references,
online_results,
model_name,
temperature,
api_key,
api_base_url,
system_prompt,
completion_func=None,
model_kwargs=None,
deepthought=False,
tracer: dict = {},
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=gemini_llm_thread,
args=(
g,
messages,
system_prompt,
model_name,
temperature,
api_key,
api_base_url,
model_kwargs,
deepthought,
tracer,
),
)
t.start()
return g
def gemini_llm_thread(
g,
messages,
system_prompt,
model_name,
temperature,
api_key,
api_base_url=None,
model_kwargs=None,
deepthought=False,
tracer: dict = {},
):
) -> AsyncGenerator[str, None]:
try:
client = gemini_clients.get(api_key)
if not client:
@@ -224,21 +177,32 @@ def gemini_llm_thread(
)
aggregated_response = ""
for chunk in client.models.generate_content_stream(
final_chunk = None
start_time = perf_counter()
chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream(
model=model_name, config=config, contents=formatted_messages
):
)
async for chunk in chat_stream:
# Log the time taken to start response
if final_chunk is None:
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Keep track of the last chunk for usage data
final_chunk = chunk
# Handle streamed response chunk
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
message = message or chunk.text
aggregated_response += message
g.send(message)
yield message
if stopped:
raise ValueError(message)
# Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
# Calculate cost of chat
input_tokens = chunk.usage_metadata.prompt_token_count
output_tokens = chunk.usage_metadata.candidates_token_count
thought_tokens = chunk.usage_metadata.thoughts_token_count or 0
input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
thought_tokens = final_chunk.usage_metadata.thoughts_token_count or 0 if final_chunk else 0
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
)
@@ -254,9 +218,7 @@ def gemini_llm_thread(
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
)
except Exception as e:
logger.error(f"Error in gemini_llm_thread: {e}", exc_info=True)
finally:
g.close()
logger.error(f"Error in gemini_chat_completion_with_backoff stream: {e}", exc_info=True)
def handle_gemini_response(
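
Similarly, a minimal standalone sketch of the google-genai async streaming call adopted above (the model name is a placeholder assumption):

```python
from google import genai


async def stream_gemini(prompt: str, api_key: str) -> str:
    # client.aio exposes the async surface of the google-genai SDK;
    # generate_content_stream is awaited once, then iterated asynchronously.
    client = genai.Client(api_key=api_key)
    aggregated = ""
    stream = await client.aio.models.generate_content_stream(
        model="gemini-2.0-flash",  # placeholder model name
        contents=prompt,
    )
    async for chunk in stream:
        if chunk.text:
            aggregated += chunk.text
    return aggregated
```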

View File

@@ -1,9 +1,10 @@
import json
import asyncio
import logging
import os
from datetime import datetime, timedelta
from threading import Thread
from typing import Any, Dict, Iterator, List, Optional, Union
from time import perf_counter
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
import pyjson5
from langchain.schema import ChatMessage
@@ -13,7 +14,6 @@ from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import (
ThreadedGenerator,
clean_json,
commit_conversation_trace,
generate_chatml_messages_with_context,
@@ -147,7 +147,7 @@ def filter_questions(questions: List[str]):
return list(filtered_questions)
def converse_offline(
async def converse_offline(
user_query,
references=[],
online_results={},
@@ -167,9 +167,9 @@ def converse_offline(
additional_context: List[str] = None,
generated_asset_results: Dict[str, Dict] = {},
tracer: dict = {},
) -> Union[ThreadedGenerator, Iterator[str]]:
) -> AsyncGenerator[str, None]:
"""
Converse with user using Llama
Converse with user using Llama (Async Version)
"""
# Initialize Variables
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
@@ -200,10 +200,17 @@ def converse_offline(
# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
return iter([prompts.no_notes_found.format()])
response = prompts.no_notes_found.format()
if completion_func:
await completion_func(chat_response=response)
yield response
return
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])
response = prompts.no_online_results_found.format()
if completion_func:
await completion_func(chat_response=response)
yield response
return
context_message = ""
if not is_none_or_empty(references):
@@ -240,33 +247,77 @@ def converse_offline(
logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
t.start()
return g
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
# Use asyncio.Queue and a thread to bridge sync iterator
queue: asyncio.Queue = asyncio.Queue()
stop_phrases = ["<s>", "INST]", "Notes:"]
aggregated_response = ""
aggregated_response_container = {"response": ""}
state.chat_lock.acquire()
try:
response_iterator = send_message_to_model_offline(
messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
)
for response in response_iterator:
response_delta = response["choices"][0]["delta"].get("content", "")
aggregated_response += response_delta
g.send(response_delta)
def _sync_llm_thread():
"""Synchronous function to run in a separate thread."""
aggregated_response = ""
start_time = perf_counter()
state.chat_lock.acquire()
try:
response_iterator = send_message_to_model_offline(
messages,
loaded_model=offline_chat_model,
stop=stop_phrases,
max_prompt_size=max_prompt_size,
streaming=True,
tracer=tracer,
)
for response in response_iterator:
response_delta = response["choices"][0]["delta"].get("content", "")
# Log the time taken to start response
if aggregated_response == "" and response_delta != "":
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Handle response chunk
aggregated_response += response_delta
# Put chunk into the asyncio queue (non-blocking)
try:
queue.put_nowait(response_delta)
except asyncio.QueueFull:
# Should not happen with default queue size unless consumer is very slow
logger.warning("Asyncio queue full during offline LLM streaming.")
# Potentially block here or handle differently if needed
asyncio.run(queue.put(response_delta))
# Save conversation trace
if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer)
# Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
finally:
state.chat_lock.release()
g.close()
# Save conversation trace
tracer["chat_model"] = model_name
if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e:
logger.error(f"Error in offline LLM thread: {e}", exc_info=True)
finally:
state.chat_lock.release()
# Signal end of stream
queue.put_nowait(None)
aggregated_response_container["response"] = aggregated_response
# Start the synchronous thread
thread = Thread(target=_sync_llm_thread)
thread.start()
# Asynchronously consume from the queue
while True:
chunk = await queue.get()
if chunk is None: # End of stream signal
queue.task_done()
break
yield chunk
queue.task_done()
# Wait for the thread to finish (optional, ensures cleanup)
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, thread.join)
# Call the completion function after streaming is done
if completion_func:
await completion_func(chat_response=aggregated_response_container["response"])
def send_message_to_model_offline(
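
Since llama-cpp only exposes a blocking iterator, converse_offline above bridges it onto the event loop with a background thread and an asyncio.Queue. A generic, self-contained sketch of that bridge follows; the blocking_chunks source is hypothetical, and this variant hands chunks to the loop via call_soon_threadsafe to keep the queue handoff thread-safe:

```python
import asyncio
import time
from threading import Thread
from typing import AsyncGenerator, Iterable


def blocking_chunks() -> Iterable[str]:
    # Stand-in for a synchronous llama-cpp streaming iterator.
    for chunk in ["Hello", ", ", "offline ", "world"]:
        time.sleep(0.1)
        yield chunk


async def bridge_blocking_stream() -> AsyncGenerator[str, None]:
    loop = asyncio.get_running_loop()
    chunk_queue: asyncio.Queue = asyncio.Queue()

    def worker():
        try:
            for chunk in blocking_chunks():
                # Hand each chunk to the event loop thread safely.
                loop.call_soon_threadsafe(chunk_queue.put_nowait, chunk)
        finally:
            # Signal end of stream with a sentinel.
            loop.call_soon_threadsafe(chunk_queue.put_nowait, None)

    thread = Thread(target=worker, daemon=True)
    thread.start()
    while (chunk := await chunk_queue.get()) is not None:
        yield chunk
    # Join the worker off the event loop so other tasks keep running.
    await loop.run_in_executor(None, thread.join)


async def main():
    async for chunk in bridge_blocking_stream():
        print(chunk, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())
```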

View File

@@ -1,6 +1,6 @@
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional
from typing import AsyncGenerator, Dict, List, Optional
import pyjson5
from langchain.schema import ChatMessage
@@ -162,7 +162,7 @@ def send_message_to_model(
)
def converse_openai(
async def converse_openai(
references,
user_query,
online_results: Optional[Dict[str, Dict]] = None,
@@ -187,7 +187,7 @@ def converse_openai(
program_execution_context: List[str] = None,
deepthought: Optional[bool] = False,
tracer: dict = {},
):
) -> AsyncGenerator[str, None]:
"""
Converse with user using OpenAI's ChatGPT
"""
@@ -217,11 +217,17 @@ def converse_openai(
# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()])
response = prompts.no_notes_found.format()
if completion_func:
await completion_func(chat_response=response)
yield response
return
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])
response = prompts.no_online_results_found.format()
if completion_func:
await completion_func(chat_response=response)
yield response
return
context_message = ""
if not is_none_or_empty(references):
@@ -255,19 +261,23 @@ def converse_openai(
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
# Get Response from GPT
return chat_completion_with_backoff(
full_response = ""
async for chunk in chat_completion_with_backoff(
messages=messages,
compiled_references=references,
online_results=online_results,
model_name=model,
temperature=temperature,
openai_api_key=api_key,
api_base_url=api_base_url,
completion_func=completion_func,
deepthought=deepthought,
model_kwargs={"stop": ["Notes:\n["]},
tracer=tracer,
)
):
full_response += chunk
yield chunk
# Call completion_func once streaming has finished and we have the full response
if completion_func:
await completion_func(chat_response=full_response)
def clean_response_schema(schema: BaseModel | dict) -> dict:

View File

@@ -1,7 +1,7 @@
import logging
import os
from threading import Thread
from typing import Dict, List
from time import perf_counter
from typing import AsyncGenerator, Dict, List
from urllib.parse import urlparse
import openai
@@ -16,13 +16,10 @@ from tenacity import (
wait_random_exponential,
)
from khoj.processor.conversation.utils import (
JsonSupport,
ThreadedGenerator,
commit_conversation_trace,
)
from khoj.processor.conversation.utils import JsonSupport, commit_conversation_trace
from khoj.utils.helpers import (
get_chat_usage_metrics,
get_openai_async_client,
get_openai_client,
is_promptrace_enabled,
)
@@ -30,6 +27,7 @@ from khoj.utils.helpers import (
logger = logging.getLogger(__name__)
openai_clients: Dict[str, openai.OpenAI] = {}
openai_async_clients: Dict[str, openai.AsyncOpenAI] = {}
@retry(
@@ -124,45 +122,22 @@ def completion_with_backoff(
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def chat_completion_with_backoff(
async def chat_completion_with_backoff(
messages,
compiled_references,
online_results,
model_name,
temperature,
openai_api_key=None,
api_base_url=None,
completion_func=None,
deepthought=False,
model_kwargs=None,
tracer: dict = {},
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=llm_thread,
args=(g, messages, model_name, temperature, openai_api_key, api_base_url, deepthought, model_kwargs, tracer),
)
t.start()
return g
def llm_thread(
g,
messages,
model_name: str,
temperature,
openai_api_key=None,
api_base_url=None,
deepthought=False,
model_kwargs: dict = {},
tracer: dict = {},
):
) -> AsyncGenerator[str, None]:
try:
client_key = f"{openai_api_key}--{api_base_url}"
client = openai_clients.get(client_key)
client = openai_async_clients.get(client_key)
if not client:
client = get_openai_client(openai_api_key, api_base_url)
openai_clients[client_key] = client
client = get_openai_async_client(openai_api_key, api_base_url)
openai_async_clients[client_key] = client
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
@@ -207,53 +182,58 @@ def llm_thread(
if os.getenv("KHOJ_LLM_SEED"):
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
messages=formatted_messages,
model=model_name, # type: ignore
aggregated_response = ""
final_chunk = None
start_time = perf_counter()
chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
messages=formatted_messages, # type: ignore
model=model_name,
stream=stream,
temperature=temperature,
timeout=20,
**model_kwargs,
)
async for chunk in chat_stream:
# Log the time taken to start response
if final_chunk is None:
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Keep track of the last chunk for usage data
final_chunk = chunk
# Handle streamed response chunk
if len(chunk.choices) == 0:
continue
delta_chunk = chunk.choices[0].delta
text_chunk = ""
if isinstance(delta_chunk, str):
text_chunk = delta_chunk
elif delta_chunk and delta_chunk.content:
text_chunk = delta_chunk.content
if text_chunk:
aggregated_response += text_chunk
yield text_chunk
aggregated_response = ""
if not stream:
chunk = chat
aggregated_response = chunk.choices[0].message.content
g.send(aggregated_response)
else:
for chunk in chat:
if len(chunk.choices) == 0:
continue
delta_chunk = chunk.choices[0].delta
text_chunk = ""
if isinstance(delta_chunk, str):
text_chunk = delta_chunk
elif delta_chunk.content:
text_chunk = delta_chunk.content
if text_chunk:
aggregated_response += text_chunk
g.send(text_chunk)
# Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
# Calculate cost of chat
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
cost = (
chunk.usage.model_extra.get("estimated_cost", 0) if hasattr(chunk, "usage") and chunk.usage else 0
) # Estimated costs returned by DeepInfra API
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
)
# Calculate cost of chat after stream finishes
input_tokens, output_tokens, cost = 0, 0, 0
if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage:
input_tokens = final_chunk.usage.prompt_tokens
output_tokens = final_chunk.usage.completion_tokens
# Estimated costs returned by DeepInfra API
if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra:
cost = final_chunk.usage.model_extra.get("estimated_cost", 0)
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
)
if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e:
logger.error(f"Error in llm_thread: {e}", exc_info=True)
finally:
g.close()
logger.error(f"Error in chat_completion_with_backoff stream: {e}", exc_info=True)
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:

View File

@@ -77,42 +77,6 @@ model_to_prompt_size = {
model_to_tokenizer: Dict[str, str] = {}
class ThreadedGenerator:
def __init__(self, compiled_references, online_results, completion_func=None):
self.queue = queue.Queue()
self.compiled_references = compiled_references
self.online_results = online_results
self.completion_func = completion_func
self.response = ""
self.start_time = perf_counter()
def __iter__(self):
return self
def __next__(self):
item = self.queue.get()
if item is StopIteration:
time_to_response = perf_counter() - self.start_time
logger.info(f"Chat streaming took: {time_to_response:.3f} seconds")
if self.completion_func:
# The completion func effectively acts as a callback.
# It adds the aggregated response to the conversation history.
self.completion_func(chat_response=self.response)
raise StopIteration
return item
def send(self, data):
if self.response == "":
time_to_first_response = perf_counter() - self.start_time
logger.info(f"First response took: {time_to_first_response:.3f} seconds")
self.response += data
self.queue.put(data)
def close(self):
self.queue.put(StopIteration)
class InformationCollectionIteration:
def __init__(
self,
@@ -254,7 +218,7 @@ def message_to_log(
return conversation_log
def save_to_conversation_log(
async def save_to_conversation_log(
q: str,
chat_response: str,
user: KhojUser,
@@ -306,7 +270,7 @@ def save_to_conversation_log(
khoj_message_metadata=khoj_message_metadata,
conversation_log=meta_log.get("chat", []),
)
ConversationAdapters.save_conversation(
await ConversationAdapters.save_conversation(
user,
{"chat": updated_conversation},
client_application=client_application,

View File

@@ -67,7 +67,6 @@ from khoj.routers.research import (
from khoj.routers.storage import upload_user_image_to_bucket
from khoj.utils import state
from khoj.utils.helpers import (
AsyncIteratorWrapper,
ConversationCommand,
command_descriptions,
convert_image_to_webp,
@@ -999,7 +998,7 @@ async def chat(
return
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
await sync_to_async(save_to_conversation_log)(
await save_to_conversation_log(
q,
llm_response,
user,
@@ -1308,26 +1307,31 @@ async def chat(
yield result
continue_stream = True
iterator = AsyncIteratorWrapper(llm_response)
async for item in iterator:
async for item in llm_response:
# Should not happen with async generator, end is signaled by loop exit. Skip.
if item is None:
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
yield result
# Send Usage Metadata once llm interactions are complete
async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
yield event
async for result in send_event(ChatEvent.END_RESPONSE, ""):
yield result
logger.debug("Finished streaming response")
return
continue
if not connection_alive or not continue_stream:
# Drain the generator if disconnected but keep processing internally
continue
try:
async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
yield result
except Exception as e:
continue_stream = False
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
logger.info(f"User {user} disconnected or error during streaming. Stopping send: {e}")
# Signal end of LLM response after the loop finishes
if connection_alive:
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
yield result
# Send Usage Metadata once llm interactions are complete
if tracer.get("usage"):
async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
yield event
async for result in send_event(ChatEvent.END_RESPONSE, ""):
yield result
logger.debug("Finished streaming response")
## Stream Text Response
if stream:
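
With the converse functions now returning async generators, the route can hand chunks straight to the response without the AsyncIteratorWrapper. A minimal sketch of that shape, with a hypothetical fetch_chunks generator and endpoint rather than the actual route:

```python
import asyncio
from typing import AsyncGenerator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


async def fetch_chunks(query: str) -> AsyncGenerator[str, None]:
    # Stand-in for an async converse_* generator.
    for word in f"Echoing: {query}".split():
        await asyncio.sleep(0.05)
        yield word + " "


@app.get("/stream-chat")
async def stream_chat(q: str):
    # StreamingResponse drains the async generator on the event loop,
    # so no thread or iterator wrapper is needed.
    return StreamingResponse(fetch_chunks(q), media_type="text/plain")
```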

View File

@@ -1,4 +1,3 @@
import asyncio
import base64
import hashlib
import json
@@ -6,9 +5,7 @@ import logging
import math
import os
import re
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta, timezone
from enum import Enum
from functools import partial
from random import random
from typing import (
@@ -17,7 +14,6 @@ from typing import (
AsyncGenerator,
Callable,
Dict,
Iterator,
List,
Optional,
Set,
@@ -97,7 +93,6 @@ from khoj.processor.conversation.openai.gpt import (
)
from khoj.processor.conversation.utils import (
ChatEvent,
ThreadedGenerator,
clean_json,
clean_mermaidjs,
construct_chat_history,
@@ -126,8 +121,6 @@ from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, Loca
logger = logging.getLogger(__name__)
executor = ThreadPoolExecutor(max_workers=1)
NOTION_OAUTH_CLIENT_ID = os.getenv("NOTION_OAUTH_CLIENT_ID")
NOTION_OAUTH_CLIENT_SECRET = os.getenv("NOTION_OAUTH_CLIENT_SECRET")
@@ -262,11 +255,6 @@ def get_conversation_command(query: str) -> ConversationCommand:
return ConversationCommand.Default
async def agenerate_chat_response(*args):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(executor, generate_chat_response, *args)
def gather_raw_query_files(
query_files: Dict[str, str],
):
@@ -1418,7 +1406,7 @@ def send_message_to_model_wrapper_sync(
raise HTTPException(status_code=500, detail="Invalid conversation config")
def generate_chat_response(
async def agenerate_chat_response(
q: str,
meta_log: dict,
conversation: Conversation,
@@ -1444,13 +1432,14 @@ def generate_chat_response(
generated_asset_results: Dict[str, Dict] = {},
is_subscribed: bool = False,
tracer: dict = {},
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
) -> Tuple[AsyncGenerator[str, None], Dict[str, str]]:
# Initialize Variables
chat_response = None
chat_response_generator = None
logger.debug(f"Conversation Types: {conversation_commands}")
metadata = {}
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
agent = await AgentAdapters.aget_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
try:
partial_completion = partial(
save_to_conversation_log,
@@ -1481,17 +1470,17 @@ def generate_chat_response(
code_results = {}
deepthought = True
chat_model = ConversationAdapters.get_valid_chat_model(user, conversation, is_subscribed)
chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed)
vision_available = chat_model.vision_enabled
if not vision_available and query_images:
vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
if vision_enabled_config:
chat_model = vision_enabled_config
vision_available = True
if chat_model.model_type == "offline":
loaded_model = state.offline_chat_processor_config.loaded_model
chat_response = converse_offline(
chat_response_generator = converse_offline(
user_query=query_to_run,
references=compiled_references,
online_results=online_results,
@@ -1515,7 +1504,7 @@ def generate_chat_response(
openai_chat_config = chat_model.ai_model_api
api_key = openai_chat_config.api_key
chat_model_name = chat_model.name
chat_response = converse_openai(
chat_response_generator = converse_openai(
compiled_references,
query_to_run,
query_images=query_images,
@@ -1544,7 +1533,7 @@ def generate_chat_response(
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_response = converse_anthropic(
chat_response_generator = converse_anthropic(
compiled_references,
query_to_run,
query_images=query_images,
@@ -1572,7 +1561,7 @@ def generate_chat_response(
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_response = converse_gemini(
chat_response_generator = converse_gemini(
compiled_references,
query_to_run,
online_results,
@@ -1604,7 +1593,8 @@ def generate_chat_response(
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
return chat_response, metadata
# Return the generator directly
return chat_response_generator, metadata
class DeleteMessageRequestBody(BaseModel):

View File

@@ -23,6 +23,7 @@ from time import perf_counter
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
from urllib.parse import ParseResult, urlparse
import anthropic
import openai
import psutil
import pyjson5
@@ -30,6 +31,7 @@ import requests
import torch
from asgiref.sync import sync_to_async
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
from google import genai
from google.auth.credentials import Credentials
from google.oauth2 import service_account
from magika import Magika
@@ -729,6 +731,60 @@ def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, o
return client
def get_openai_async_client(api_key: str, api_base_url: str) -> Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]:
"""Get OpenAI or AzureOpenAI client based on the API Base URL"""
parsed_url = urlparse(api_base_url)
if parsed_url.hostname and parsed_url.hostname.endswith(".openai.azure.com"):
client = openai.AsyncAzureOpenAI(
api_key=api_key,
azure_endpoint=api_base_url,
api_version="2024-10-21",
)
else:
client = openai.AsyncOpenAI(
api_key=api_key,
base_url=api_base_url,
)
return client
def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
api_info = get_ai_api_info(api_key, api_base_url)
if api_info.api_key:
client = anthropic.Anthropic(api_key=api_info.api_key)
else:
client = anthropic.AnthropicVertex(
region=api_info.region,
project_id=api_info.project,
credentials=api_info.credentials,
)
return client
def get_anthropic_async_client(api_key, api_base_url=None) -> anthropic.AsyncAnthropic | anthropic.AsyncAnthropicVertex:
api_info = get_ai_api_info(api_key, api_base_url)
if api_info.api_key:
client = anthropic.AsyncAnthropic(api_key=api_info.api_key)
else:
client = anthropic.AsyncAnthropicVertex(
region=api_info.region,
project_id=api_info.project,
credentials=api_info.credentials,
)
return client
def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
api_info = get_ai_api_info(api_key, api_base_url)
return genai.Client(
location=api_info.region,
project=api_info.project,
credentials=api_info.credentials,
api_key=api_info.api_key,
vertexai=api_info.api_key is None,
)
def normalize_email(email: str, check_deliverability=False) -> tuple[str, bool]:
"""Normalize, validate and check deliverability of email address"""
lower_email = email.lower()
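
The new async client factories above can be used like their sync counterparts; for example, a minimal streaming call through the AsyncOpenAI client (the model name is a placeholder assumption):

```python
import openai


async def stream_openai(prompt: str, api_key: str, api_base_url: str = None) -> str:
    # get_openai_async_client returns AsyncAzureOpenAI for *.openai.azure.com
    # base URLs and AsyncOpenAI otherwise; both share this call shape.
    client: openai.AsyncOpenAI = get_openai_async_client(api_key, api_base_url)
    aggregated = ""
    stream = await client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )
    async for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.content:
            aggregated += chunk.choices[0].delta.content
    return aggregated
```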