diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 775b8b99..0079cf8c 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -1,4 +1,3 @@ -import asyncio import logging from datetime import datetime, timedelta from typing import AsyncGenerator, Dict, List, Optional @@ -146,7 +145,6 @@ async def converse_anthropic( model: Optional[str] = "claude-3-7-sonnet-latest", api_key: Optional[str] = None, api_base_url: Optional[str] = None, - completion_func=None, conversation_commands=[ConversationCommand.Default], max_prompt_size=None, tokenizer_name=None, @@ -161,7 +159,7 @@ async def converse_anthropic( generated_asset_results: Dict[str, Dict] = {}, deepthought: Optional[bool] = False, tracer: dict = {}, -) -> AsyncGenerator[str | ResponseWithThought, None]: +) -> AsyncGenerator[ResponseWithThought, None]: """ Converse with user using Anthropic's Claude """ @@ -192,15 +190,11 @@ async def converse_anthropic( # Get Conversation Primer appropriate to Conversation Type if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): response = prompts.no_notes_found.format() - if completion_func: - asyncio.create_task(completion_func(chat_response=response)) - yield response + yield ResponseWithThought(response=response) return elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): response = prompts.no_online_results_found.format() - if completion_func: - asyncio.create_task(completion_func(chat_response=response)) - yield response + yield ResponseWithThought(response=response) return context_message = "" @@ -241,7 +235,6 @@ async def converse_anthropic( logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}") # Get Response from Claude - full_response = "" async for chunk in anthropic_chat_completion_with_backoff( messages=messages, model_name=model, @@ -253,10 +246,4 @@ async def converse_anthropic( deepthought=deepthought, tracer=tracer, ): - if chunk.response: - full_response += chunk.response yield chunk - - # Call completion_func once finish streaming and we have the full response - if completion_func: - asyncio.create_task(completion_func(chat_response=full_response)) diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index b2f48c81..78cd6fa4 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -1,4 +1,3 @@ -import asyncio import logging from datetime import datetime, timedelta from typing import AsyncGenerator, Dict, List, Optional @@ -15,6 +14,7 @@ from khoj.processor.conversation.google.utils import ( ) from khoj.processor.conversation.utils import ( OperatorRun, + ResponseWithThought, clean_json, construct_question_history, construct_structured_message, @@ -168,7 +168,6 @@ async def converse_gemini( api_key: Optional[str] = None, api_base_url: Optional[str] = None, temperature: float = 1.0, - completion_func=None, conversation_commands=[ConversationCommand.Default], max_prompt_size=None, tokenizer_name=None, @@ -183,7 +182,7 @@ async def converse_gemini( program_execution_context: List[str] = None, deepthought: Optional[bool] = False, tracer={}, -) -> AsyncGenerator[str, None]: +) -> AsyncGenerator[ResponseWithThought, None]: """ Converse with user using Google's Gemini """ @@ -215,15 +214,11 @@ async def converse_gemini( # Get Conversation Primer appropriate to Conversation Type if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): response = prompts.no_notes_found.format() - if completion_func: - asyncio.create_task(completion_func(chat_response=response)) - yield response + yield ResponseWithThought(response=response) return elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): response = prompts.no_online_results_found.format() - if completion_func: - asyncio.create_task(completion_func(chat_response=response)) - yield response + yield ResponseWithThought(response=response) return context_message = "" @@ -264,7 +259,6 @@ async def converse_gemini( logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}") # Get Response from Google AI - full_response = "" async for chunk in gemini_chat_completion_with_backoff( messages=messages, model_name=model, @@ -275,10 +269,4 @@ async def converse_gemini( deepthought=deepthought, tracer=tracer, ): - if chunk.response: - full_response += chunk.response yield chunk - - # Call completion_func once finish streaming and we have the full response - if completion_func: - asyncio.create_task(completion_func(chat_response=full_response)) diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 27cd9a9e..639b1c8b 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -14,6 +14,7 @@ from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.utils import download_model from khoj.processor.conversation.utils import ( + ResponseWithThought, clean_json, commit_conversation_trace, construct_question_history, @@ -150,7 +151,6 @@ async def converse_offline( chat_history: list[ChatMessageModel] = [], model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", loaded_model: Union[Any, None] = None, - completion_func=None, conversation_commands=[ConversationCommand.Default], max_prompt_size=None, tokenizer_name=None, @@ -162,7 +162,7 @@ async def converse_offline( additional_context: List[str] = None, generated_asset_results: Dict[str, Dict] = {}, tracer: dict = {}, -) -> AsyncGenerator[str, None]: +) -> AsyncGenerator[ResponseWithThought, None]: """ Converse with user using Llama (Async Version) """ @@ -196,15 +196,11 @@ async def converse_offline( # Get Conversation Primer appropriate to Conversation Type if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): response = prompts.no_notes_found.format() - if completion_func: - asyncio.create_task(completion_func(chat_response=response)) - yield response + yield ResponseWithThought(response=response) return elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): response = prompts.no_online_results_found.format() - if completion_func: - asyncio.create_task(completion_func(chat_response=response)) - yield response + yield ResponseWithThought(response=response) return context_message = "" @@ -243,9 +239,8 @@ async def converse_offline( logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}") # Use asyncio.Queue and a thread to bridge sync iterator - queue: asyncio.Queue = asyncio.Queue() + queue: asyncio.Queue[ResponseWithThought] = asyncio.Queue() stop_phrases = ["", "INST]", "Notes:"] - aggregated_response_container = {"response": ""} def _sync_llm_thread(): """Synchronous function to run in a separate thread.""" @@ -262,7 +257,7 @@ async def converse_offline( tracer=tracer, ) for response in response_iterator: - response_delta = response["choices"][0]["delta"].get("content", "") + response_delta: str = response["choices"][0]["delta"].get("content", "") # Log the time taken to start response if aggregated_response == "" and response_delta != "": logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") @@ -270,12 +265,12 @@ async def converse_offline( aggregated_response += response_delta # Put chunk into the asyncio queue (non-blocking) try: - queue.put_nowait(response_delta) + queue.put_nowait(ResponseWithThought(response=response_delta)) except asyncio.QueueFull: # Should not happen with default queue size unless consumer is very slow logger.warning("Asyncio queue full during offline LLM streaming.") # Potentially block here or handle differently if needed - asyncio.run(queue.put(response_delta)) + asyncio.run(queue.put(ResponseWithThought(response=response_delta))) # Log the time taken to stream the entire response logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") @@ -291,7 +286,6 @@ async def converse_offline( state.chat_lock.release() # Signal end of stream queue.put_nowait(None) - aggregated_response_container["response"] = aggregated_response # Start the synchronous thread thread = Thread(target=_sync_llm_thread) @@ -310,10 +304,6 @@ async def converse_offline( loop = asyncio.get_running_loop() await loop.run_in_executor(None, thread.join) - # Call the completion function after streaming is done - if completion_func: - asyncio.create_task(completion_func(chat_response=aggregated_response_container["response"])) - def send_message_to_model_offline( messages: List[ChatMessage], diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 20500458..f49030b9 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -1,4 +1,3 @@ -import asyncio import logging from datetime import datetime, timedelta from typing import AsyncGenerator, Dict, List, Optional @@ -171,7 +170,6 @@ async def converse_openai( api_key: Optional[str] = None, api_base_url: Optional[str] = None, temperature: float = 0.4, - completion_func=None, conversation_commands=[ConversationCommand.Default], max_prompt_size=None, tokenizer_name=None, @@ -186,7 +184,7 @@ async def converse_openai( program_execution_context: List[str] = None, deepthought: Optional[bool] = False, tracer: dict = {}, -) -> AsyncGenerator[str | ResponseWithThought, None]: +) -> AsyncGenerator[ResponseWithThought, None]: """ Converse with user using OpenAI's ChatGPT """ @@ -217,15 +215,11 @@ async def converse_openai( # Get Conversation Primer appropriate to Conversation Type if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): response = prompts.no_notes_found.format() - if completion_func: - asyncio.create_task(completion_func(chat_response=response)) - yield response + yield ResponseWithThought(response=response) return elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): response = prompts.no_online_results_found.format() - if completion_func: - asyncio.create_task(completion_func(chat_response=response)) - yield response + yield ResponseWithThought(response=response) return context_message = "" @@ -267,7 +261,6 @@ async def converse_openai( logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}") # Get Response from GPT - full_response = "" async for chunk in chat_completion_with_backoff( messages=messages, model_name=model, @@ -277,14 +270,8 @@ async def converse_openai( deepthought=deepthought, tracer=tracer, ): - if chunk.response: - full_response += chunk.response yield chunk - # Call completion_func once finish streaming and we have the full response - if completion_func: - asyncio.create_task(completion_func(chat_response=full_response)) - def clean_response_schema(schema: BaseModel | dict) -> dict: """ diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index df3245fe..83adf4de 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1463,33 +1463,30 @@ async def chat( code_results, operator_results, research_results, - inferred_queries, conversation_commands, user, - request.user.client_app, location, user_name, uploaded_images, - train_of_thought, attached_file_context, - raw_query_files, - generated_images, generated_files, - generated_mermaidjs_diagram, program_execution_context, generated_asset_results, is_subscribed, tracer, ) + full_response = "" async for item in llm_response: - # Should not happen with async generator, end is signaled by loop exit. Skip. - if item is None: + # Should not happen with async generator. Skip. + if item is None or not isinstance(item, ResponseWithThought): + logger.warning(f"Unexpected item type in LLM response: {type(item)}. Skipping.") continue if cancellation_event.is_set(): break - message = item.response if isinstance(item, ResponseWithThought) else item - if isinstance(item, ResponseWithThought) and item.thought: + message = item.response + full_response += message if message else "" + if item.thought: async for result in send_event(ChatEvent.THOUGHT, item.thought): yield result continue @@ -1506,6 +1503,31 @@ async def chat( logger.warning(f"Error during streaming. Stopping send: {e}") break + # Save conversation once finish streaming + asyncio.create_task( + save_to_conversation_log( + q, + chat_response=full_response, + user=user, + chat_history=chat_history, + compiled_references=compiled_references, + online_results=online_results, + code_results=code_results, + operator_results=operator_results, + research_results=research_results, + inferred_queries=inferred_queries, + client_application=request.user.client_app, + conversation_id=str(conversation.id), + query_images=uploaded_images, + train_of_thought=train_of_thought, + raw_query_files=raw_query_files, + generated_images=generated_images, + raw_generated_files=generated_files, + generated_mermaidjs_diagram=generated_mermaidjs_diagram, + tracer=tracer, + ) + ) + # Signal end of LLM response after the loop finishes if not cancellation_event.is_set(): async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 3eaefd8c..6723029a 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -6,7 +6,6 @@ import math import os import re from datetime import datetime, timedelta, timezone -from functools import partial from random import random from typing import ( Annotated, @@ -102,7 +101,6 @@ from khoj.processor.conversation.utils import ( clean_mermaidjs, construct_chat_history, generate_chatml_messages_with_context, - save_to_conversation_log, ) from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled from khoj.routers.email import is_resend_enabled, send_task_email @@ -1350,54 +1348,26 @@ async def agenerate_chat_response( code_results: Dict[str, Dict] = {}, operator_results: List[OperatorRun] = [], research_results: List[ResearchIteration] = [], - inferred_queries: List[str] = [], conversation_commands: List[ConversationCommand] = [ConversationCommand.Default], user: KhojUser = None, - client_application: ClientApplication = None, location_data: LocationData = None, user_name: Optional[str] = None, query_images: Optional[List[str]] = None, - train_of_thought: List[Any] = [], query_files: str = None, - raw_query_files: List[FileAttachment] = None, - generated_images: List[str] = None, raw_generated_files: List[FileAttachment] = [], - generated_mermaidjs_diagram: str = None, program_execution_context: List[str] = [], generated_asset_results: Dict[str, Dict] = {}, is_subscribed: bool = False, tracer: dict = {}, -) -> Tuple[AsyncGenerator[str | ResponseWithThought, None], Dict[str, str]]: +) -> Tuple[AsyncGenerator[ResponseWithThought, None], Dict[str, str]]: # Initialize Variables - chat_response_generator: AsyncGenerator[str | ResponseWithThought, None] = None + chat_response_generator: AsyncGenerator[ResponseWithThought, None] = None logger.debug(f"Conversation Types: {conversation_commands}") metadata = {} agent = await AgentAdapters.aget_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None try: - partial_completion = partial( - save_to_conversation_log, - q, - user=user, - chat_history=chat_history, - compiled_references=compiled_references, - online_results=online_results, - code_results=code_results, - operator_results=operator_results, - research_results=research_results, - inferred_queries=inferred_queries, - client_application=client_application, - conversation_id=str(conversation.id), - query_images=query_images, - train_of_thought=train_of_thought, - raw_query_files=raw_query_files, - generated_images=generated_images, - raw_generated_files=raw_generated_files, - generated_mermaidjs_diagram=generated_mermaidjs_diagram, - tracer=tracer, - ) - query_to_run = q deepthought = False if research_results: @@ -1426,7 +1396,6 @@ async def agenerate_chat_response( online_results=online_results, loaded_model=loaded_model, chat_history=chat_history, - completion_func=partial_completion, conversation_commands=conversation_commands, model_name=chat_model.name, max_prompt_size=chat_model.max_prompt_size, @@ -1455,7 +1424,6 @@ async def agenerate_chat_response( model=chat_model_name, api_key=api_key, api_base_url=openai_chat_config.api_base_url, - completion_func=partial_completion, conversation_commands=conversation_commands, max_prompt_size=chat_model.max_prompt_size, tokenizer_name=chat_model.tokenizer, @@ -1485,7 +1453,6 @@ async def agenerate_chat_response( model=chat_model.name, api_key=api_key, api_base_url=api_base_url, - completion_func=partial_completion, conversation_commands=conversation_commands, max_prompt_size=chat_model.max_prompt_size, tokenizer_name=chat_model.tokenizer, @@ -1513,7 +1480,6 @@ async def agenerate_chat_response( model=chat_model.name, api_key=api_key, api_base_url=api_base_url, - completion_func=partial_completion, conversation_commands=conversation_commands, max_prompt_size=chat_model.max_prompt_size, tokenizer_name=chat_model.tokenizer,