From c93c0d982e23d40c092b7a7a1ec770212b9120bd Mon Sep 17 00:00:00 2001 From: Debanjum Date: Mon, 24 Mar 2025 15:14:19 +0530 Subject: [PATCH 1/6] Create async get anthropic, openai client funcs, move to reusable package This package is where the get openai client functions also reside. --- .../processor/conversation/anthropic/utils.py | 15 +---- .../processor/conversation/google/utils.py | 13 +---- src/khoj/utils/helpers.py | 56 +++++++++++++++++++ 3 files changed, 58 insertions(+), 26 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index 6c2ffb8a..48c6515f 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -19,7 +19,7 @@ from khoj.processor.conversation.utils import ( get_image_from_url, ) from khoj.utils.helpers import ( - get_ai_api_info, + get_anthropic_client, get_chat_usage_metrics, is_none_or_empty, is_promptrace_enabled, @@ -33,19 +33,6 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 8000 MAX_REASONING_TOKENS_ANTHROPIC = 12000 -def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex: - api_info = get_ai_api_info(api_key, api_base_url) - if api_info.api_key: - client = anthropic.Anthropic(api_key=api_info.api_key) - else: - client = anthropic.AnthropicVertex( - region=api_info.region, - project_id=api_info.project, - credentials=api_info.credentials, - ) - return client - - @retry( wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(2), diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index 9a8b4132..b497edec 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -25,8 +25,8 @@ from khoj.processor.conversation.utils import ( get_image_from_url, ) from khoj.utils.helpers import ( - get_ai_api_info, get_chat_usage_metrics, + get_gemini_client, is_none_or_empty, is_promptrace_enabled, ) @@ -62,17 +62,6 @@ SAFETY_SETTINGS = [ ] -def get_gemini_client(api_key, api_base_url=None) -> genai.Client: - api_info = get_ai_api_info(api_key, api_base_url) - return genai.Client( - location=api_info.region, - project=api_info.project, - credentials=api_info.credentials, - api_key=api_info.api_key, - vertexai=api_info.api_key is None, - ) - - @retry( wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(2), diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index f0aa0cd6..4a756dcb 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -23,6 +23,7 @@ from time import perf_counter from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union from urllib.parse import ParseResult, urlparse +import anthropic import openai import psutil import pyjson5 @@ -30,6 +31,7 @@ import requests import torch from asgiref.sync import sync_to_async from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email +from google import genai from google.auth.credentials import Credentials from google.oauth2 import service_account from magika import Magika @@ -729,6 +731,60 @@ def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, o return client +def get_openai_async_client(api_key: str, api_base_url: str) -> Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]: + """Get OpenAI or AzureOpenAI client based on the API Base URL""" + parsed_url = urlparse(api_base_url) + if parsed_url.hostname and parsed_url.hostname.endswith(".openai.azure.com"): + client = openai.AsyncAzureOpenAI( + api_key=api_key, + azure_endpoint=api_base_url, + api_version="2024-10-21", + ) + else: + client = openai.AsyncOpenAI( + api_key=api_key, + base_url=api_base_url, + ) + return client + + +def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex: + api_info = get_ai_api_info(api_key, api_base_url) + if api_info.api_key: + client = anthropic.Anthropic(api_key=api_info.api_key) + else: + client = anthropic.AnthropicVertex( + region=api_info.region, + project_id=api_info.project, + credentials=api_info.credentials, + ) + return client + + +def get_anthropic_async_client(api_key, api_base_url=None) -> anthropic.AsyncAnthropic | anthropic.AsyncAnthropicVertex: + api_info = get_ai_api_info(api_key, api_base_url) + if api_info.api_key: + client = anthropic.AsyncAnthropic(api_key=api_info.api_key) + else: + client = anthropic.AsyncAnthropicVertex( + region=api_info.region, + project_id=api_info.project, + credentials=api_info.credentials, + ) + return client + + +def get_gemini_client(api_key, api_base_url=None) -> genai.Client: + api_info = get_ai_api_info(api_key, api_base_url) + return genai.Client( + location=api_info.region, + project=api_info.project, + credentials=api_info.credentials, + api_key=api_info.api_key, + vertexai=api_info.api_key is None, + ) + + def normalize_email(email: str, check_deliverability=False) -> tuple[str, bool]: """Normalize, validate and check deliverability of email address""" lower_email = email.lower() From 0751f2ea30d4943d88498d1fdfb7fe34696f5d9f Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sat, 19 Apr 2025 21:36:45 +0530 Subject: [PATCH 2/6] Refactor Openai chat response to stream async, no separate thread - Refactor chat API to use async/await for Openai streaming - Fix and clean Openai chat response async streaming --- src/khoj/database/adapters/__init__.py | 46 +++---- src/khoj/processor/conversation/openai/gpt.py | 34 +++-- .../processor/conversation/openai/utils.py | 116 ++++++++---------- src/khoj/processor/conversation/utils.py | 4 +- src/khoj/routers/api_chat.py | 32 ++--- src/khoj/routers/helpers.py | 41 +++---- 6 files changed, 132 insertions(+), 141 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 92846020..248a78e8 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -763,9 +763,9 @@ class AgentAdapters: return False @staticmethod - def get_conversation_agent_by_id(agent_id: int): - agent = Agent.objects.filter(id=agent_id).first() - if agent == AgentAdapters.get_default_agent(): + async def aget_conversation_agent_by_id(agent_id: int): + agent = await Agent.objects.filter(id=agent_id).afirst() + if agent == await AgentAdapters.aget_default_agent(): # If the agent is set to the default agent, then return None and let the default application code be used return None return agent @@ -1109,14 +1109,6 @@ class ConversationAdapters: async def aget_all_chat_models(): return await sync_to_async(list)(ChatModel.objects.prefetch_related("ai_model_api").all()) - @staticmethod - def get_vision_enabled_config(): - chat_models = ConversationAdapters.get_all_chat_models() - for config in chat_models: - if config.vision_enabled: - return config - return None - @staticmethod async def aget_vision_enabled_config(): chat_models = await ConversationAdapters.aget_all_chat_models() @@ -1171,7 +1163,11 @@ class ConversationAdapters: @staticmethod async def aget_chat_model(user: KhojUser): subscribed = await ais_user_subscribed(user) - config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst() + config = ( + await UserConversationConfig.objects.filter(user=user) + .prefetch_related("setting", "setting__ai_model_api") + .afirst() + ) if subscribed: # Subscibed users can use any available chat model if config: @@ -1387,7 +1383,7 @@ class ConversationAdapters: @staticmethod @require_valid_user - def save_conversation( + async def save_conversation( user: KhojUser, conversation_log: dict, client_application: ClientApplication = None, @@ -1396,19 +1392,21 @@ class ConversationAdapters: ): slug = user_message.strip()[:200] if user_message else None if conversation_id: - conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first() + conversation = await Conversation.objects.filter( + user=user, client=client_application, id=conversation_id + ).afirst() else: conversation = ( - Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").first() + await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst() ) if conversation: conversation.conversation_log = conversation_log conversation.slug = slug conversation.updated_at = datetime.now(tz=timezone.utc) - conversation.save() + await conversation.asave() else: - Conversation.objects.create( + await Conversation.objects.acreate( user=user, conversation_log=conversation_log, client=client_application, slug=slug ) @@ -1455,17 +1453,21 @@ class ConversationAdapters: return random.sample(all_questions, max_results) @staticmethod - def get_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool): + async def aget_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool): agent: Agent = ( - conversation.agent if is_subscribed and AgentAdapters.get_default_agent() != conversation.agent else None + conversation.agent + if is_subscribed and await AgentAdapters.aget_default_agent() != conversation.agent + else None ) if agent and agent.chat_model: - chat_model = conversation.agent.chat_model + chat_model = await ChatModel.objects.select_related("ai_model_api").aget( + pk=conversation.agent.chat_model.pk + ) else: - chat_model = ConversationAdapters.get_chat_model(user) + chat_model = await ConversationAdapters.aget_chat_model(user) if chat_model is None: - chat_model = ConversationAdapters.get_default_chat_model() + chat_model = await ConversationAdapters.aget_default_chat_model() if chat_model.model_type == ChatModel.ModelType.OFFLINE: if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 11e6a03d..b5fbdcf2 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -1,6 +1,6 @@ import logging from datetime import datetime, timedelta -from typing import Dict, List, Optional +from typing import AsyncGenerator, Dict, List, Optional import pyjson5 from langchain.schema import ChatMessage @@ -162,7 +162,7 @@ def send_message_to_model( ) -def converse_openai( +async def converse_openai( references, user_query, online_results: Optional[Dict[str, Dict]] = None, @@ -187,7 +187,7 @@ def converse_openai( program_execution_context: List[str] = None, deepthought: Optional[bool] = False, tracer: dict = {}, -): +) -> AsyncGenerator[str, None]: """ Converse with user using OpenAI's ChatGPT """ @@ -217,11 +217,17 @@ def converse_openai( # Get Conversation Primer appropriate to Conversation Type if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): - completion_func(chat_response=prompts.no_notes_found.format()) - return iter([prompts.no_notes_found.format()]) + response = prompts.no_notes_found.format() + if completion_func: + await completion_func(chat_response=response) + yield response + return elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): - completion_func(chat_response=prompts.no_online_results_found.format()) - return iter([prompts.no_online_results_found.format()]) + response = prompts.no_online_results_found.format() + if completion_func: + await completion_func(chat_response=response) + yield response + return context_message = "" if not is_none_or_empty(references): @@ -255,19 +261,23 @@ def converse_openai( logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}") # Get Response from GPT - return chat_completion_with_backoff( + full_response = "" + async for chunk in chat_completion_with_backoff( messages=messages, - compiled_references=references, - online_results=online_results, model_name=model, temperature=temperature, openai_api_key=api_key, api_base_url=api_base_url, - completion_func=completion_func, deepthought=deepthought, model_kwargs={"stop": ["Notes:\n["]}, tracer=tracer, - ) + ): + full_response += chunk + yield chunk + + # Call completion_func once finish streaming and we have the full response + if completion_func: + await completion_func(chat_response=full_response) def clean_response_schema(schema: BaseModel | dict) -> dict: diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index b73903ae..7fab44aa 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -1,7 +1,7 @@ import logging import os -from threading import Thread -from typing import Dict, List +from time import perf_counter +from typing import AsyncGenerator, Dict, List from urllib.parse import urlparse import openai @@ -16,13 +16,10 @@ from tenacity import ( wait_random_exponential, ) -from khoj.processor.conversation.utils import ( - JsonSupport, - ThreadedGenerator, - commit_conversation_trace, -) +from khoj.processor.conversation.utils import JsonSupport, commit_conversation_trace from khoj.utils.helpers import ( get_chat_usage_metrics, + get_openai_async_client, get_openai_client, is_promptrace_enabled, ) @@ -30,6 +27,7 @@ from khoj.utils.helpers import ( logger = logging.getLogger(__name__) openai_clients: Dict[str, openai.OpenAI] = {} +openai_async_clients: Dict[str, openai.AsyncOpenAI] = {} @retry( @@ -124,45 +122,22 @@ def completion_with_backoff( before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) -def chat_completion_with_backoff( +async def chat_completion_with_backoff( messages, - compiled_references, - online_results, model_name, temperature, openai_api_key=None, api_base_url=None, - completion_func=None, - deepthought=False, - model_kwargs=None, - tracer: dict = {}, -): - g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func) - t = Thread( - target=llm_thread, - args=(g, messages, model_name, temperature, openai_api_key, api_base_url, deepthought, model_kwargs, tracer), - ) - t.start() - return g - - -def llm_thread( - g, - messages, - model_name: str, - temperature, - openai_api_key=None, - api_base_url=None, deepthought=False, model_kwargs: dict = {}, tracer: dict = {}, -): +) -> AsyncGenerator[str, None]: try: client_key = f"{openai_api_key}--{api_base_url}" - client = openai_clients.get(client_key) + client = openai_async_clients.get(client_key) if not client: - client = get_openai_client(openai_api_key, api_base_url) - openai_clients[client_key] = client + client = get_openai_async_client(openai_api_key, api_base_url) + openai_async_clients[client_key] = client formatted_messages = [{"role": message.role, "content": message.content} for message in messages] @@ -207,53 +182,58 @@ def llm_thread( if os.getenv("KHOJ_LLM_SEED"): model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED")) - chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create( - messages=formatted_messages, - model=model_name, # type: ignore + aggregated_response = "" + final_chunk = None + start_time = perf_counter() + chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create( + messages=formatted_messages, # type: ignore + model=model_name, stream=stream, temperature=temperature, timeout=20, **model_kwargs, ) + async for chunk in chat_stream: + # Log the time taken to start response + if final_chunk is None: + logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") + # Keep track of the last chunk for usage data + final_chunk = chunk + # Handle streamed response chunk + if len(chunk.choices) == 0: + continue + delta_chunk = chunk.choices[0].delta + text_chunk = "" + if isinstance(delta_chunk, str): + text_chunk = delta_chunk + elif delta_chunk and delta_chunk.content: + text_chunk = delta_chunk.content + if text_chunk: + aggregated_response += text_chunk + yield text_chunk - aggregated_response = "" - if not stream: - chunk = chat - aggregated_response = chunk.choices[0].message.content - g.send(aggregated_response) - else: - for chunk in chat: - if len(chunk.choices) == 0: - continue - delta_chunk = chunk.choices[0].delta - text_chunk = "" - if isinstance(delta_chunk, str): - text_chunk = delta_chunk - elif delta_chunk.content: - text_chunk = delta_chunk.content - if text_chunk: - aggregated_response += text_chunk - g.send(text_chunk) + # Log the time taken to stream the entire response + logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") - # Calculate cost of chat - input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0 - output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0 - cost = ( - chunk.usage.model_extra.get("estimated_cost", 0) if hasattr(chunk, "usage") and chunk.usage else 0 - ) # Estimated costs returned by DeepInfra API - tracer["usage"] = get_chat_usage_metrics( - model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost - ) + # Calculate cost of chat after stream finishes + input_tokens, output_tokens, cost = 0, 0, 0 + if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage: + input_tokens = final_chunk.usage.prompt_tokens + output_tokens = final_chunk.usage.completion_tokens + # Estimated costs returned by DeepInfra API + if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra: + cost = final_chunk.usage.model_extra.get("estimated_cost", 0) # Save conversation trace tracer["chat_model"] = model_name tracer["temperature"] = temperature + tracer["usage"] = get_chat_usage_metrics( + model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost + ) if is_promptrace_enabled(): commit_conversation_trace(messages, aggregated_response, tracer) except Exception as e: - logger.error(f"Error in llm_thread: {e}", exc_info=True) - finally: - g.close() + logger.error(f"Error in chat_completion_with_backoff stream: {e}", exc_info=True) def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport: diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 01c25cf4..2bb6265e 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -254,7 +254,7 @@ def message_to_log( return conversation_log -def save_to_conversation_log( +async def save_to_conversation_log( q: str, chat_response: str, user: KhojUser, @@ -306,7 +306,7 @@ def save_to_conversation_log( khoj_message_metadata=khoj_message_metadata, conversation_log=meta_log.get("chat", []), ) - ConversationAdapters.save_conversation( + await ConversationAdapters.save_conversation( user, {"chat": updated_conversation}, client_application=client_application, diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 03f367dd..dd951238 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -67,7 +67,6 @@ from khoj.routers.research import ( from khoj.routers.storage import upload_user_image_to_bucket from khoj.utils import state from khoj.utils.helpers import ( - AsyncIteratorWrapper, ConversationCommand, command_descriptions, convert_image_to_webp, @@ -999,7 +998,7 @@ async def chat( return llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) - await sync_to_async(save_to_conversation_log)( + await save_to_conversation_log( q, llm_response, user, @@ -1308,26 +1307,31 @@ async def chat( yield result continue_stream = True - iterator = AsyncIteratorWrapper(llm_response) - async for item in iterator: + async for item in llm_response: + # Should not happen with async generator, end is signaled by loop exit. Skip. if item is None: - async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): - yield result - # Send Usage Metadata once llm interactions are complete - async for event in send_event(ChatEvent.USAGE, tracer.get("usage")): - yield event - async for result in send_event(ChatEvent.END_RESPONSE, ""): - yield result - logger.debug("Finished streaming response") - return + continue if not connection_alive or not continue_stream: + # Drain the generator if disconnected but keep processing internally continue try: async for result in send_event(ChatEvent.MESSAGE, f"{item}"): yield result except Exception as e: continue_stream = False - logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") + logger.info(f"User {user} disconnected or error during streaming. Stopping send: {e}") + + # Signal end of LLM response after the loop finishes + if connection_alive: + async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): + yield result + # Send Usage Metadata once llm interactions are complete + if tracer.get("usage"): + async for event in send_event(ChatEvent.USAGE, tracer.get("usage")): + yield event + async for result in send_event(ChatEvent.END_RESPONSE, ""): + yield result + logger.debug("Finished streaming response") ## Stream Text Response if stream: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 742a8708..6042469a 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1,4 +1,3 @@ -import asyncio import base64 import hashlib import json @@ -6,9 +5,7 @@ import logging import math import os import re -from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta, timezone -from enum import Enum from functools import partial from random import random from typing import ( @@ -17,7 +14,6 @@ from typing import ( AsyncGenerator, Callable, Dict, - Iterator, List, Optional, Set, @@ -126,8 +122,6 @@ from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, Loca logger = logging.getLogger(__name__) -executor = ThreadPoolExecutor(max_workers=1) - NOTION_OAUTH_CLIENT_ID = os.getenv("NOTION_OAUTH_CLIENT_ID") NOTION_OAUTH_CLIENT_SECRET = os.getenv("NOTION_OAUTH_CLIENT_SECRET") @@ -262,11 +256,6 @@ def get_conversation_command(query: str) -> ConversationCommand: return ConversationCommand.Default -async def agenerate_chat_response(*args): - loop = asyncio.get_event_loop() - return await loop.run_in_executor(executor, generate_chat_response, *args) - - def gather_raw_query_files( query_files: Dict[str, str], ): @@ -1418,7 +1407,7 @@ def send_message_to_model_wrapper_sync( raise HTTPException(status_code=500, detail="Invalid conversation config") -def generate_chat_response( +async def agenerate_chat_response( q: str, meta_log: dict, conversation: Conversation, @@ -1444,13 +1433,14 @@ def generate_chat_response( generated_asset_results: Dict[str, Dict] = {}, is_subscribed: bool = False, tracer: dict = {}, -) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: +) -> Tuple[AsyncGenerator[str, None], Dict[str, str]]: # Initialize Variables - chat_response = None + chat_response_generator = None logger.debug(f"Conversation Types: {conversation_commands}") metadata = {} - agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None + agent = await AgentAdapters.aget_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None + try: partial_completion = partial( save_to_conversation_log, @@ -1481,23 +1471,25 @@ def generate_chat_response( code_results = {} deepthought = True - chat_model = ConversationAdapters.get_valid_chat_model(user, conversation, is_subscribed) + chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed) vision_available = chat_model.vision_enabled if not vision_available and query_images: - vision_enabled_config = ConversationAdapters.get_vision_enabled_config() + vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config() if vision_enabled_config: chat_model = vision_enabled_config vision_available = True if chat_model.model_type == "offline": + # Assuming converse_offline remains sync or is refactored separately loaded_model = state.offline_chat_processor_config.loaded_model - chat_response = converse_offline( + # If converse_offline returns an iterator, wrap it if needed, or refactor it to async generator + chat_response_generator = converse_offline( # Needs adaptation if it becomes async user_query=query_to_run, references=compiled_references, online_results=online_results, loaded_model=loaded_model, conversation_log=meta_log, - completion_func=partial_completion, + completion_func=partial_completion, # Pass the async wrapper conversation_commands=conversation_commands, model_name=chat_model.name, max_prompt_size=chat_model.max_prompt_size, @@ -1515,7 +1507,7 @@ def generate_chat_response( openai_chat_config = chat_model.ai_model_api api_key = openai_chat_config.api_key chat_model_name = chat_model.name - chat_response = converse_openai( + chat_response_generator = converse_openai( compiled_references, query_to_run, query_images=query_images, @@ -1542,9 +1534,10 @@ def generate_chat_response( ) elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: + # Assuming converse_anthropic remains sync or is refactored separately api_key = chat_model.ai_model_api.api_key api_base_url = chat_model.ai_model_api.api_base_url - chat_response = converse_anthropic( + chat_response_generator = converse_anthropic( # Needs adaptation if it becomes async compiled_references, query_to_run, query_images=query_images, @@ -1570,9 +1563,10 @@ def generate_chat_response( tracer=tracer, ) elif chat_model.model_type == ChatModel.ModelType.GOOGLE: + # Assuming converse_gemini remains sync or is refactored separately api_key = chat_model.ai_model_api.api_key api_base_url = chat_model.ai_model_api.api_base_url - chat_response = converse_gemini( + chat_response_generator = converse_gemini( # Needs adaptation if it becomes async compiled_references, query_to_run, online_results, @@ -1604,7 +1598,8 @@ def generate_chat_response( logger.error(e, exc_info=True) raise HTTPException(status_code=500, detail=str(e)) - return chat_response, metadata + # Return the generator directly + return chat_response_generator, metadata class DeleteMessageRequestBody(BaseModel): From a557031447998a804ef0fa4584d20bcb68de6c5a Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sun, 20 Apr 2025 02:54:56 +0530 Subject: [PATCH 3/6] Refactor Gemini chat response to stream async, no separate thread --- .../conversation/google/gemini_chat.py | 34 +++++--- .../processor/conversation/google/utils.py | 77 ++++++------------- src/khoj/routers/helpers.py | 3 +- 3 files changed, 48 insertions(+), 66 deletions(-) diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 73167ca2..3c42ef06 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -1,6 +1,6 @@ import logging from datetime import datetime, timedelta -from typing import Dict, List, Optional +from typing import AsyncGenerator, Dict, List, Optional import pyjson5 from langchain.schema import ChatMessage @@ -160,7 +160,7 @@ def gemini_send_message_to_model( ) -def converse_gemini( +async def converse_gemini( references, user_query, online_results: Optional[Dict[str, Dict]] = None, @@ -185,7 +185,7 @@ def converse_gemini( program_execution_context: List[str] = None, deepthought: Optional[bool] = False, tracer={}, -): +) -> AsyncGenerator[str, None]: """ Converse with user using Google's Gemini """ @@ -216,11 +216,17 @@ def converse_gemini( # Get Conversation Primer appropriate to Conversation Type if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): - completion_func(chat_response=prompts.no_notes_found.format()) - return iter([prompts.no_notes_found.format()]) + response = prompts.no_notes_found.format() + if completion_func: + await completion_func(chat_response=response) + yield response + return elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): - completion_func(chat_response=prompts.no_online_results_found.format()) - return iter([prompts.no_online_results_found.format()]) + response = prompts.no_online_results_found.format() + if completion_func: + await completion_func(chat_response=response) + yield response + return context_message = "" if not is_none_or_empty(references): @@ -253,16 +259,20 @@ def converse_gemini( logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}") # Get Response from Google AI - return gemini_chat_completion_with_backoff( + full_response = "" + async for chunk in gemini_chat_completion_with_backoff( messages=messages, - compiled_references=references, - online_results=online_results, model_name=model, temperature=temperature, api_key=api_key, api_base_url=api_base_url, system_prompt=system_prompt, - completion_func=completion_func, deepthought=deepthought, tracer=tracer, - ) + ): + full_response += chunk + yield chunk + + # Call completion_func once finish streaming and we have the full response + if completion_func: + await completion_func(chat_response=full_response) diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index b497edec..d85cd09e 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -2,8 +2,8 @@ import logging import os import random from copy import deepcopy -from threading import Thread -from typing import Dict +from time import perf_counter +from typing import AsyncGenerator, AsyncIterator, Dict from google import genai from google.genai import errors as gerrors @@ -19,7 +19,6 @@ from tenacity import ( ) from khoj.processor.conversation.utils import ( - ThreadedGenerator, commit_conversation_trace, get_image_from_base64, get_image_from_url, @@ -121,8 +120,8 @@ def gemini_completion_with_backoff( ) # Aggregate cost of chat - input_tokens = response.usage_metadata.prompt_token_count if response else 0 - output_tokens = response.usage_metadata.candidates_token_count if response else 0 + input_tokens = response.usage_metadata.prompt_token_count or 0 if response else 0 + output_tokens = response.usage_metadata.candidates_token_count or 0 if response else 0 thought_tokens = response.usage_metadata.thoughts_token_count or 0 if response else 0 tracer["usage"] = get_chat_usage_metrics( model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage") @@ -143,52 +142,17 @@ def gemini_completion_with_backoff( before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) -def gemini_chat_completion_with_backoff( +async def gemini_chat_completion_with_backoff( messages, - compiled_references, - online_results, model_name, temperature, api_key, api_base_url, system_prompt, - completion_func=None, model_kwargs=None, deepthought=False, tracer: dict = {}, -): - g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func) - t = Thread( - target=gemini_llm_thread, - args=( - g, - messages, - system_prompt, - model_name, - temperature, - api_key, - api_base_url, - model_kwargs, - deepthought, - tracer, - ), - ) - t.start() - return g - - -def gemini_llm_thread( - g, - messages, - system_prompt, - model_name, - temperature, - api_key, - api_base_url=None, - model_kwargs=None, - deepthought=False, - tracer: dict = {}, -): +) -> AsyncGenerator[str, None]: try: client = gemini_clients.get(api_key) if not client: @@ -213,21 +177,32 @@ def gemini_llm_thread( ) aggregated_response = "" - - for chunk in client.models.generate_content_stream( + final_chunk = None + start_time = perf_counter() + chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream( model=model_name, config=config, contents=formatted_messages - ): + ) + async for chunk in chat_stream: + # Log the time taken to start response + if final_chunk is None: + logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") + # Keep track of the last chunk for usage data + final_chunk = chunk + # Handle streamed response chunk message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback) message = message or chunk.text aggregated_response += message - g.send(message) + yield message if stopped: raise ValueError(message) + # Log the time taken to stream the entire response + logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") + # Calculate cost of chat - input_tokens = chunk.usage_metadata.prompt_token_count - output_tokens = chunk.usage_metadata.candidates_token_count - thought_tokens = chunk.usage_metadata.thoughts_token_count or 0 + input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0 + output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0 + thought_tokens = final_chunk.usage_metadata.thoughts_token_count or 0 if final_chunk else 0 tracer["usage"] = get_chat_usage_metrics( model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage") ) @@ -243,9 +218,7 @@ def gemini_llm_thread( + f"Last Message by {messages[-1].role}: {messages[-1].content}" ) except Exception as e: - logger.error(f"Error in gemini_llm_thread: {e}", exc_info=True) - finally: - g.close() + logger.error(f"Error in gemini_chat_completion_with_backoff stream: {e}", exc_info=True) def handle_gemini_response( diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 6042469a..a5aa415a 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1563,10 +1563,9 @@ async def agenerate_chat_response( tracer=tracer, ) elif chat_model.model_type == ChatModel.ModelType.GOOGLE: - # Assuming converse_gemini remains sync or is refactored separately api_key = chat_model.ai_model_api.api_key api_base_url = chat_model.ai_model_api.api_base_url - chat_response_generator = converse_gemini( # Needs adaptation if it becomes async + chat_response_generator = converse_gemini( compiled_references, query_to_run, online_results, From 932a9615efa87e9125266ef2006bbd889d6ac019 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sun, 20 Apr 2025 03:18:32 +0530 Subject: [PATCH 4/6] Refactor Anthropic chat response to stream async, no separate thread --- .../conversation/anthropic/anthropic_chat.py | 34 +++++---- .../processor/conversation/anthropic/utils.py | 70 ++++++------------- src/khoj/routers/helpers.py | 3 +- 3 files changed, 43 insertions(+), 64 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 977b25de..5bad38ef 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -1,6 +1,6 @@ import logging from datetime import datetime, timedelta -from typing import Dict, List, Optional +from typing import AsyncGenerator, Dict, List, Optional import pyjson5 from langchain.schema import ChatMessage @@ -137,7 +137,7 @@ def anthropic_send_message_to_model( ) -def converse_anthropic( +async def converse_anthropic( references, user_query, online_results: Optional[Dict[str, Dict]] = None, @@ -161,7 +161,7 @@ def converse_anthropic( generated_asset_results: Dict[str, Dict] = {}, deepthought: Optional[bool] = False, tracer: dict = {}, -): +) -> AsyncGenerator[str, None]: """ Converse with user using Anthropic's Claude """ @@ -191,11 +191,17 @@ def converse_anthropic( # Get Conversation Primer appropriate to Conversation Type if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): - completion_func(chat_response=prompts.no_notes_found.format()) - return iter([prompts.no_notes_found.format()]) + response = prompts.no_notes_found.format() + if completion_func: + await completion_func(chat_response=response) + yield response + return elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): - completion_func(chat_response=prompts.no_online_results_found.format()) - return iter([prompts.no_online_results_found.format()]) + response = prompts.no_online_results_found.format() + if completion_func: + await completion_func(chat_response=response) + yield response + return context_message = "" if not is_none_or_empty(references): @@ -228,17 +234,21 @@ def converse_anthropic( logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}") # Get Response from Claude - return anthropic_chat_completion_with_backoff( + full_response = "" + async for chunk in anthropic_chat_completion_with_backoff( messages=messages, - compiled_references=references, - online_results=online_results, model_name=model, temperature=0.2, api_key=api_key, api_base_url=api_base_url, system_prompt=system_prompt, - completion_func=completion_func, max_prompt_size=max_prompt_size, deepthought=deepthought, tracer=tracer, - ) + ): + full_response += chunk + yield chunk + + # Call completion_func once finish streaming and we have the full response + if completion_func: + await completion_func(chat_response=full_response) diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index 48c6515f..442e1cb3 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -1,5 +1,5 @@ import logging -from threading import Thread +from time import perf_counter from typing import Dict, List import anthropic @@ -13,12 +13,12 @@ from tenacity import ( ) from khoj.processor.conversation.utils import ( - ThreadedGenerator, commit_conversation_trace, get_image_from_base64, get_image_from_url, ) from khoj.utils.helpers import ( + get_anthropic_async_client, get_anthropic_client, get_chat_usage_metrics, is_none_or_empty, @@ -28,6 +28,7 @@ from khoj.utils.helpers import ( logger = logging.getLogger(__name__) anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {} +anthropic_async_clients: Dict[str, anthropic.AsyncAnthropic | anthropic.AsyncAnthropicVertex] = {} DEFAULT_MAX_TOKENS_ANTHROPIC = 8000 MAX_REASONING_TOKENS_ANTHROPIC = 12000 @@ -113,60 +114,23 @@ def anthropic_completion_with_backoff( before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) -def anthropic_chat_completion_with_backoff( +async def anthropic_chat_completion_with_backoff( messages: list[ChatMessage], - compiled_references, - online_results, model_name, temperature, api_key, api_base_url, system_prompt: str, max_prompt_size=None, - completion_func=None, - deepthought=False, - model_kwargs=None, - tracer={}, -): - g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func) - t = Thread( - target=anthropic_llm_thread, - args=( - g, - messages, - system_prompt, - model_name, - temperature, - api_key, - api_base_url, - max_prompt_size, - deepthought, - model_kwargs, - tracer, - ), - ) - t.start() - return g - - -def anthropic_llm_thread( - g, - messages: list[ChatMessage], - system_prompt: str, - model_name: str, - temperature, - api_key, - api_base_url=None, - max_prompt_size=None, deepthought=False, model_kwargs=None, tracer={}, ): try: - client = anthropic_clients.get(api_key) + client = anthropic_async_clients.get(api_key) if not client: - client = get_anthropic_client(api_key, api_base_url) - anthropic_clients[api_key] = client + client = get_anthropic_async_client(api_key, api_base_url) + anthropic_async_clients[api_key] = client model_kwargs = model_kwargs or dict() max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC @@ -180,7 +144,8 @@ def anthropic_llm_thread( aggregated_response = "" final_message = None - with client.messages.stream( + start_time = perf_counter() + async with client.messages.stream( messages=formatted_messages, model=model_name, # type: ignore temperature=temperature, @@ -189,10 +154,17 @@ def anthropic_llm_thread( max_tokens=max_tokens, **model_kwargs, ) as stream: - for text in stream.text_stream: + async for text in stream.text_stream: + # Log the time taken to start response + if aggregated_response == "": + logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") + # Handle streamed response chunk aggregated_response += text - g.send(text) - final_message = stream.get_final_message() + yield text + final_message = await stream.get_final_message() + + # Log the time taken to stream the entire response + logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") # Calculate cost of chat input_tokens = final_message.usage.input_tokens @@ -209,9 +181,7 @@ def anthropic_llm_thread( if is_promptrace_enabled(): commit_conversation_trace(messages, aggregated_response, tracer) except Exception as e: - logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True) - finally: - g.close() + logger.error(f"Error in anthropic_chat_completion_with_backoff stream: {e}", exc_info=True) def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None): diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index a5aa415a..3d154059 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1534,10 +1534,9 @@ async def agenerate_chat_response( ) elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: - # Assuming converse_anthropic remains sync or is refactored separately api_key = chat_model.ai_model_api.api_key api_base_url = chat_model.ai_model_api.api_base_url - chat_response_generator = converse_anthropic( # Needs adaptation if it becomes async + chat_response_generator = converse_anthropic( compiled_references, query_to_run, query_images=query_images, From 763fa2fa794e15c5ba4ac17216bd0b46267ad994 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sun, 20 Apr 2025 03:42:04 +0530 Subject: [PATCH 5/6] Refactor Offline chat response to stream async, with separate thread --- .../conversation/offline/chat_model.py | 115 +++++++++++++----- src/khoj/routers/helpers.py | 7 +- 2 files changed, 85 insertions(+), 37 deletions(-) diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index f727fd1d..b7f89c8d 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -1,9 +1,10 @@ -import json +import asyncio import logging import os from datetime import datetime, timedelta from threading import Thread -from typing import Any, Dict, Iterator, List, Optional, Union +from time import perf_counter +from typing import Any, AsyncGenerator, Dict, List, Optional, Union import pyjson5 from langchain.schema import ChatMessage @@ -13,7 +14,6 @@ from khoj.database.models import Agent, ChatModel, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.utils import download_model from khoj.processor.conversation.utils import ( - ThreadedGenerator, clean_json, commit_conversation_trace, generate_chatml_messages_with_context, @@ -147,7 +147,7 @@ def filter_questions(questions: List[str]): return list(filtered_questions) -def converse_offline( +async def converse_offline( user_query, references=[], online_results={}, @@ -167,9 +167,9 @@ def converse_offline( additional_context: List[str] = None, generated_asset_results: Dict[str, Dict] = {}, tracer: dict = {}, -) -> Union[ThreadedGenerator, Iterator[str]]: +) -> AsyncGenerator[str, None]: """ - Converse with user using Llama + Converse with user using Llama (Async Version) """ # Initialize Variables assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" @@ -200,10 +200,17 @@ def converse_offline( # Get Conversation Primer appropriate to Conversation Type if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): - return iter([prompts.no_notes_found.format()]) + response = prompts.no_notes_found.format() + if completion_func: + await completion_func(chat_response=response) + yield response + return elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): - completion_func(chat_response=prompts.no_online_results_found.format()) - return iter([prompts.no_online_results_found.format()]) + response = prompts.no_online_results_found.format() + if completion_func: + await completion_func(chat_response=response) + yield response + return context_message = "" if not is_none_or_empty(references): @@ -240,33 +247,77 @@ def converse_offline( logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}") - g = ThreadedGenerator(references, online_results, completion_func=completion_func) - t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer)) - t.start() - return g - - -def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}): + # Use asyncio.Queue and a thread to bridge sync iterator + queue: asyncio.Queue = asyncio.Queue() stop_phrases = ["", "INST]", "Notes:"] - aggregated_response = "" + aggregated_response_container = {"response": ""} - state.chat_lock.acquire() - try: - response_iterator = send_message_to_model_offline( - messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True - ) - for response in response_iterator: - response_delta = response["choices"][0]["delta"].get("content", "") - aggregated_response += response_delta - g.send(response_delta) + def _sync_llm_thread(): + """Synchronous function to run in a separate thread.""" + aggregated_response = "" + start_time = perf_counter() + state.chat_lock.acquire() + try: + response_iterator = send_message_to_model_offline( + messages, + loaded_model=offline_chat_model, + stop=stop_phrases, + max_prompt_size=max_prompt_size, + streaming=True, + tracer=tracer, + ) + for response in response_iterator: + response_delta = response["choices"][0]["delta"].get("content", "") + # Log the time taken to start response + if aggregated_response == "" and response_delta != "": + logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") + # Handle response chunk + aggregated_response += response_delta + # Put chunk into the asyncio queue (non-blocking) + try: + queue.put_nowait(response_delta) + except asyncio.QueueFull: + # Should not happen with default queue size unless consumer is very slow + logger.warning("Asyncio queue full during offline LLM streaming.") + # Potentially block here or handle differently if needed + asyncio.run(queue.put(response_delta)) - # Save conversation trace - if is_promptrace_enabled(): - commit_conversation_trace(messages, aggregated_response, tracer) + # Log the time taken to stream the entire response + logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") - finally: - state.chat_lock.release() - g.close() + # Save conversation trace + tracer["chat_model"] = model_name + if is_promptrace_enabled(): + commit_conversation_trace(messages, aggregated_response, tracer) + + except Exception as e: + logger.error(f"Error in offline LLM thread: {e}", exc_info=True) + finally: + state.chat_lock.release() + # Signal end of stream + queue.put_nowait(None) + aggregated_response_container["response"] = aggregated_response + + # Start the synchronous thread + thread = Thread(target=_sync_llm_thread) + thread.start() + + # Asynchronously consume from the queue + while True: + chunk = await queue.get() + if chunk is None: # End of stream signal + queue.task_done() + break + yield chunk + queue.task_done() + + # Wait for the thread to finish (optional, ensures cleanup) + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, thread.join) + + # Call the completion function after streaming is done + if completion_func: + await completion_func(chat_response=aggregated_response_container["response"]) def send_message_to_model_offline( diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 3d154059..a0baffb9 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -93,7 +93,6 @@ from khoj.processor.conversation.openai.gpt import ( ) from khoj.processor.conversation.utils import ( ChatEvent, - ThreadedGenerator, clean_json, clean_mermaidjs, construct_chat_history, @@ -1480,16 +1479,14 @@ async def agenerate_chat_response( vision_available = True if chat_model.model_type == "offline": - # Assuming converse_offline remains sync or is refactored separately loaded_model = state.offline_chat_processor_config.loaded_model - # If converse_offline returns an iterator, wrap it if needed, or refactor it to async generator - chat_response_generator = converse_offline( # Needs adaptation if it becomes async + chat_response_generator = converse_offline( user_query=query_to_run, references=compiled_references, online_results=online_results, loaded_model=loaded_model, conversation_log=meta_log, - completion_func=partial_completion, # Pass the async wrapper + completion_func=partial_completion, conversation_commands=conversation_commands, model_name=chat_model.name, max_prompt_size=chat_model.max_prompt_size, From a4b5842ac372ef3a37db08ee37ea80042fa87f6f Mon Sep 17 00:00:00 2001 From: Debanjum Date: Mon, 21 Apr 2025 14:15:06 +0530 Subject: [PATCH 6/6] Remove ThreadedGenerator class, previously used to stream chat response --- src/khoj/processor/conversation/utils.py | 36 ------------------------ 1 file changed, 36 deletions(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 2bb6265e..b8eea907 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -77,42 +77,6 @@ model_to_prompt_size = { model_to_tokenizer: Dict[str, str] = {} -class ThreadedGenerator: - def __init__(self, compiled_references, online_results, completion_func=None): - self.queue = queue.Queue() - self.compiled_references = compiled_references - self.online_results = online_results - self.completion_func = completion_func - self.response = "" - self.start_time = perf_counter() - - def __iter__(self): - return self - - def __next__(self): - item = self.queue.get() - if item is StopIteration: - time_to_response = perf_counter() - self.start_time - logger.info(f"Chat streaming took: {time_to_response:.3f} seconds") - if self.completion_func: - # The completion func effectively acts as a callback. - # It adds the aggregated response to the conversation history. - self.completion_func(chat_response=self.response) - raise StopIteration - return item - - def send(self, data): - if self.response == "": - time_to_first_response = perf_counter() - self.start_time - logger.info(f"First response took: {time_to_first_response:.3f} seconds") - - self.response += data - self.queue.put(data) - - def close(self): - self.queue.put(StopIteration) - - class InformationCollectionIteration: def __init__( self,