Save conversation in the common chat API function instead of in each AI provider

This commit is contained in:
Debanjum
2025-06-04 18:37:11 -07:00
parent e7584bc29d
commit bfd4695705
6 changed files with 52 additions and 112 deletions

View File

@@ -1,4 +1,3 @@
import asyncio
import logging import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import AsyncGenerator, Dict, List, Optional from typing import AsyncGenerator, Dict, List, Optional
@@ -146,7 +145,6 @@ async def converse_anthropic(
model: Optional[str] = "claude-3-7-sonnet-latest", model: Optional[str] = "claude-3-7-sonnet-latest",
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base_url: Optional[str] = None, api_base_url: Optional[str] = None,
completion_func=None,
conversation_commands=[ConversationCommand.Default], conversation_commands=[ConversationCommand.Default],
max_prompt_size=None, max_prompt_size=None,
tokenizer_name=None, tokenizer_name=None,
@@ -161,7 +159,7 @@ async def converse_anthropic(
generated_asset_results: Dict[str, Dict] = {}, generated_asset_results: Dict[str, Dict] = {},
deepthought: Optional[bool] = False, deepthought: Optional[bool] = False,
tracer: dict = {}, tracer: dict = {},
) -> AsyncGenerator[str | ResponseWithThought, None]: ) -> AsyncGenerator[ResponseWithThought, None]:
""" """
Converse with user using Anthropic's Claude Converse with user using Anthropic's Claude
""" """
@@ -192,15 +190,11 @@ async def converse_anthropic(
# Get Conversation Primer appropriate to Conversation Type # Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
response = prompts.no_notes_found.format() response = prompts.no_notes_found.format()
if completion_func: yield ResponseWithThought(response=response)
asyncio.create_task(completion_func(chat_response=response))
yield response
return return
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
response = prompts.no_online_results_found.format() response = prompts.no_online_results_found.format()
if completion_func: yield ResponseWithThought(response=response)
asyncio.create_task(completion_func(chat_response=response))
yield response
return return
context_message = "" context_message = ""
@@ -241,7 +235,6 @@ async def converse_anthropic(
logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}") logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}")
# Get Response from Claude # Get Response from Claude
full_response = ""
async for chunk in anthropic_chat_completion_with_backoff( async for chunk in anthropic_chat_completion_with_backoff(
messages=messages, messages=messages,
model_name=model, model_name=model,
@@ -253,10 +246,4 @@ async def converse_anthropic(
deepthought=deepthought, deepthought=deepthought,
tracer=tracer, tracer=tracer,
): ):
if chunk.response:
full_response += chunk.response
yield chunk yield chunk
# Call completion_func once finish streaming and we have the full response
if completion_func:
asyncio.create_task(completion_func(chat_response=full_response))

View File

@@ -1,4 +1,3 @@
import asyncio
import logging import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import AsyncGenerator, Dict, List, Optional from typing import AsyncGenerator, Dict, List, Optional
@@ -15,6 +14,7 @@ from khoj.processor.conversation.google.utils import (
) )
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
OperatorRun, OperatorRun,
ResponseWithThought,
clean_json, clean_json,
construct_question_history, construct_question_history,
construct_structured_message, construct_structured_message,
@@ -168,7 +168,6 @@ async def converse_gemini(
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base_url: Optional[str] = None, api_base_url: Optional[str] = None,
temperature: float = 1.0, temperature: float = 1.0,
completion_func=None,
conversation_commands=[ConversationCommand.Default], conversation_commands=[ConversationCommand.Default],
max_prompt_size=None, max_prompt_size=None,
tokenizer_name=None, tokenizer_name=None,
@@ -183,7 +182,7 @@ async def converse_gemini(
program_execution_context: List[str] = None, program_execution_context: List[str] = None,
deepthought: Optional[bool] = False, deepthought: Optional[bool] = False,
tracer={}, tracer={},
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[ResponseWithThought, None]:
""" """
Converse with user using Google's Gemini Converse with user using Google's Gemini
""" """
@@ -215,15 +214,11 @@ async def converse_gemini(
# Get Conversation Primer appropriate to Conversation Type # Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
response = prompts.no_notes_found.format() response = prompts.no_notes_found.format()
if completion_func: yield ResponseWithThought(response=response)
asyncio.create_task(completion_func(chat_response=response))
yield response
return return
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
response = prompts.no_online_results_found.format() response = prompts.no_online_results_found.format()
if completion_func: yield ResponseWithThought(response=response)
asyncio.create_task(completion_func(chat_response=response))
yield response
return return
context_message = "" context_message = ""
@@ -264,7 +259,6 @@ async def converse_gemini(
logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}") logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
# Get Response from Google AI # Get Response from Google AI
full_response = ""
async for chunk in gemini_chat_completion_with_backoff( async for chunk in gemini_chat_completion_with_backoff(
messages=messages, messages=messages,
model_name=model, model_name=model,
@@ -275,10 +269,4 @@ async def converse_gemini(
deepthought=deepthought, deepthought=deepthought,
tracer=tracer, tracer=tracer,
): ):
if chunk.response:
full_response += chunk.response
yield chunk yield chunk
# Call completion_func once finish streaming and we have the full response
if completion_func:
asyncio.create_task(completion_func(chat_response=full_response))

View File

@@ -14,6 +14,7 @@ from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
ResponseWithThought,
clean_json, clean_json,
commit_conversation_trace, commit_conversation_trace,
construct_question_history, construct_question_history,
@@ -150,7 +151,6 @@ async def converse_offline(
chat_history: list[ChatMessageModel] = [], chat_history: list[ChatMessageModel] = [],
model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
loaded_model: Union[Any, None] = None, loaded_model: Union[Any, None] = None,
completion_func=None,
conversation_commands=[ConversationCommand.Default], conversation_commands=[ConversationCommand.Default],
max_prompt_size=None, max_prompt_size=None,
tokenizer_name=None, tokenizer_name=None,
@@ -162,7 +162,7 @@ async def converse_offline(
additional_context: List[str] = None, additional_context: List[str] = None,
generated_asset_results: Dict[str, Dict] = {}, generated_asset_results: Dict[str, Dict] = {},
tracer: dict = {}, tracer: dict = {},
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[ResponseWithThought, None]:
""" """
Converse with user using Llama (Async Version) Converse with user using Llama (Async Version)
""" """
@@ -196,15 +196,11 @@ async def converse_offline(
# Get Conversation Primer appropriate to Conversation Type # Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
response = prompts.no_notes_found.format() response = prompts.no_notes_found.format()
if completion_func: yield ResponseWithThought(response=response)
asyncio.create_task(completion_func(chat_response=response))
yield response
return return
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
response = prompts.no_online_results_found.format() response = prompts.no_online_results_found.format()
if completion_func: yield ResponseWithThought(response=response)
asyncio.create_task(completion_func(chat_response=response))
yield response
return return
context_message = "" context_message = ""
@@ -243,9 +239,8 @@ async def converse_offline(
logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}") logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")
# Use asyncio.Queue and a thread to bridge sync iterator # Use asyncio.Queue and a thread to bridge sync iterator
queue: asyncio.Queue = asyncio.Queue() queue: asyncio.Queue[ResponseWithThought] = asyncio.Queue()
stop_phrases = ["<s>", "INST]", "Notes:"] stop_phrases = ["<s>", "INST]", "Notes:"]
aggregated_response_container = {"response": ""}
def _sync_llm_thread(): def _sync_llm_thread():
"""Synchronous function to run in a separate thread.""" """Synchronous function to run in a separate thread."""
@@ -262,7 +257,7 @@ async def converse_offline(
tracer=tracer, tracer=tracer,
) )
for response in response_iterator: for response in response_iterator:
response_delta = response["choices"][0]["delta"].get("content", "") response_delta: str = response["choices"][0]["delta"].get("content", "")
# Log the time taken to start response # Log the time taken to start response
if aggregated_response == "" and response_delta != "": if aggregated_response == "" and response_delta != "":
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
@@ -270,12 +265,12 @@ async def converse_offline(
aggregated_response += response_delta aggregated_response += response_delta
# Put chunk into the asyncio queue (non-blocking) # Put chunk into the asyncio queue (non-blocking)
try: try:
queue.put_nowait(response_delta) queue.put_nowait(ResponseWithThought(response=response_delta))
except asyncio.QueueFull: except asyncio.QueueFull:
# Should not happen with default queue size unless consumer is very slow # Should not happen with default queue size unless consumer is very slow
logger.warning("Asyncio queue full during offline LLM streaming.") logger.warning("Asyncio queue full during offline LLM streaming.")
# Potentially block here or handle differently if needed # Potentially block here or handle differently if needed
asyncio.run(queue.put(response_delta)) asyncio.run(queue.put(ResponseWithThought(response=response_delta)))
# Log the time taken to stream the entire response # Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
@@ -291,7 +286,6 @@ async def converse_offline(
state.chat_lock.release() state.chat_lock.release()
# Signal end of stream # Signal end of stream
queue.put_nowait(None) queue.put_nowait(None)
aggregated_response_container["response"] = aggregated_response
# Start the synchronous thread # Start the synchronous thread
thread = Thread(target=_sync_llm_thread) thread = Thread(target=_sync_llm_thread)
@@ -310,10 +304,6 @@ async def converse_offline(
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
await loop.run_in_executor(None, thread.join) await loop.run_in_executor(None, thread.join)
# Call the completion function after streaming is done
if completion_func:
asyncio.create_task(completion_func(chat_response=aggregated_response_container["response"]))
def send_message_to_model_offline( def send_message_to_model_offline(
messages: List[ChatMessage], messages: List[ChatMessage],

View File

@@ -1,4 +1,3 @@
import asyncio
import logging import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import AsyncGenerator, Dict, List, Optional from typing import AsyncGenerator, Dict, List, Optional
@@ -171,7 +170,6 @@ async def converse_openai(
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base_url: Optional[str] = None, api_base_url: Optional[str] = None,
temperature: float = 0.4, temperature: float = 0.4,
completion_func=None,
conversation_commands=[ConversationCommand.Default], conversation_commands=[ConversationCommand.Default],
max_prompt_size=None, max_prompt_size=None,
tokenizer_name=None, tokenizer_name=None,
@@ -186,7 +184,7 @@ async def converse_openai(
program_execution_context: List[str] = None, program_execution_context: List[str] = None,
deepthought: Optional[bool] = False, deepthought: Optional[bool] = False,
tracer: dict = {}, tracer: dict = {},
) -> AsyncGenerator[str | ResponseWithThought, None]: ) -> AsyncGenerator[ResponseWithThought, None]:
""" """
Converse with user using OpenAI's ChatGPT Converse with user using OpenAI's ChatGPT
""" """
@@ -217,15 +215,11 @@ async def converse_openai(
# Get Conversation Primer appropriate to Conversation Type # Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
response = prompts.no_notes_found.format() response = prompts.no_notes_found.format()
if completion_func: yield ResponseWithThought(response=response)
asyncio.create_task(completion_func(chat_response=response))
yield response
return return
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
response = prompts.no_online_results_found.format() response = prompts.no_online_results_found.format()
if completion_func: yield ResponseWithThought(response=response)
asyncio.create_task(completion_func(chat_response=response))
yield response
return return
context_message = "" context_message = ""
@@ -267,7 +261,6 @@ async def converse_openai(
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}") logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
# Get Response from GPT # Get Response from GPT
full_response = ""
async for chunk in chat_completion_with_backoff( async for chunk in chat_completion_with_backoff(
messages=messages, messages=messages,
model_name=model, model_name=model,
@@ -277,14 +270,8 @@ async def converse_openai(
deepthought=deepthought, deepthought=deepthought,
tracer=tracer, tracer=tracer,
): ):
if chunk.response:
full_response += chunk.response
yield chunk yield chunk
# Call completion_func once finish streaming and we have the full response
if completion_func:
asyncio.create_task(completion_func(chat_response=full_response))
def clean_response_schema(schema: BaseModel | dict) -> dict: def clean_response_schema(schema: BaseModel | dict) -> dict:
""" """

View File

@@ -1463,33 +1463,30 @@ async def chat(
code_results, code_results,
operator_results, operator_results,
research_results, research_results,
inferred_queries,
conversation_commands, conversation_commands,
user, user,
request.user.client_app,
location, location,
user_name, user_name,
uploaded_images, uploaded_images,
train_of_thought,
attached_file_context, attached_file_context,
raw_query_files,
generated_images,
generated_files, generated_files,
generated_mermaidjs_diagram,
program_execution_context, program_execution_context,
generated_asset_results, generated_asset_results,
is_subscribed, is_subscribed,
tracer, tracer,
) )
full_response = ""
async for item in llm_response: async for item in llm_response:
# Should not happen with async generator, end is signaled by loop exit. Skip. # Should not happen with async generator. Skip.
if item is None: if item is None or not isinstance(item, ResponseWithThought):
logger.warning(f"Unexpected item type in LLM response: {type(item)}. Skipping.")
continue continue
if cancellation_event.is_set(): if cancellation_event.is_set():
break break
message = item.response if isinstance(item, ResponseWithThought) else item message = item.response
if isinstance(item, ResponseWithThought) and item.thought: full_response += message if message else ""
if item.thought:
async for result in send_event(ChatEvent.THOUGHT, item.thought): async for result in send_event(ChatEvent.THOUGHT, item.thought):
yield result yield result
continue continue
@@ -1506,6 +1503,31 @@ async def chat(
logger.warning(f"Error during streaming. Stopping send: {e}") logger.warning(f"Error during streaming. Stopping send: {e}")
break break
# Save conversation once streaming finishes
asyncio.create_task(
save_to_conversation_log(
q,
chat_response=full_response,
user=user,
chat_history=chat_history,
compiled_references=compiled_references,
online_results=online_results,
code_results=code_results,
operator_results=operator_results,
research_results=research_results,
inferred_queries=inferred_queries,
client_application=request.user.client_app,
conversation_id=str(conversation.id),
query_images=uploaded_images,
train_of_thought=train_of_thought,
raw_query_files=raw_query_files,
generated_images=generated_images,
raw_generated_files=generated_files,
generated_mermaidjs_diagram=generated_mermaidjs_diagram,
tracer=tracer,
)
)
# Signal end of LLM response after the loop finishes # Signal end of LLM response after the loop finishes
if not cancellation_event.is_set(): if not cancellation_event.is_set():
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):

View File

@@ -6,7 +6,6 @@ import math
import os import os
import re import re
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from functools import partial
from random import random from random import random
from typing import ( from typing import (
Annotated, Annotated,
@@ -102,7 +101,6 @@ from khoj.processor.conversation.utils import (
clean_mermaidjs, clean_mermaidjs,
construct_chat_history, construct_chat_history,
generate_chatml_messages_with_context, generate_chatml_messages_with_context,
save_to_conversation_log,
) )
from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled
from khoj.routers.email import is_resend_enabled, send_task_email from khoj.routers.email import is_resend_enabled, send_task_email
@@ -1350,54 +1348,26 @@ async def agenerate_chat_response(
code_results: Dict[str, Dict] = {}, code_results: Dict[str, Dict] = {},
operator_results: List[OperatorRun] = [], operator_results: List[OperatorRun] = [],
research_results: List[ResearchIteration] = [], research_results: List[ResearchIteration] = [],
inferred_queries: List[str] = [],
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default], conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
user: KhojUser = None, user: KhojUser = None,
client_application: ClientApplication = None,
location_data: LocationData = None, location_data: LocationData = None,
user_name: Optional[str] = None, user_name: Optional[str] = None,
query_images: Optional[List[str]] = None, query_images: Optional[List[str]] = None,
train_of_thought: List[Any] = [],
query_files: str = None, query_files: str = None,
raw_query_files: List[FileAttachment] = None,
generated_images: List[str] = None,
raw_generated_files: List[FileAttachment] = [], raw_generated_files: List[FileAttachment] = [],
generated_mermaidjs_diagram: str = None,
program_execution_context: List[str] = [], program_execution_context: List[str] = [],
generated_asset_results: Dict[str, Dict] = {}, generated_asset_results: Dict[str, Dict] = {},
is_subscribed: bool = False, is_subscribed: bool = False,
tracer: dict = {}, tracer: dict = {},
) -> Tuple[AsyncGenerator[str | ResponseWithThought, None], Dict[str, str]]: ) -> Tuple[AsyncGenerator[ResponseWithThought, None], Dict[str, str]]:
# Initialize Variables # Initialize Variables
chat_response_generator: AsyncGenerator[str | ResponseWithThought, None] = None chat_response_generator: AsyncGenerator[ResponseWithThought, None] = None
logger.debug(f"Conversation Types: {conversation_commands}") logger.debug(f"Conversation Types: {conversation_commands}")
metadata = {} metadata = {}
agent = await AgentAdapters.aget_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None agent = await AgentAdapters.aget_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
try: try:
partial_completion = partial(
save_to_conversation_log,
q,
user=user,
chat_history=chat_history,
compiled_references=compiled_references,
online_results=online_results,
code_results=code_results,
operator_results=operator_results,
research_results=research_results,
inferred_queries=inferred_queries,
client_application=client_application,
conversation_id=str(conversation.id),
query_images=query_images,
train_of_thought=train_of_thought,
raw_query_files=raw_query_files,
generated_images=generated_images,
raw_generated_files=raw_generated_files,
generated_mermaidjs_diagram=generated_mermaidjs_diagram,
tracer=tracer,
)
query_to_run = q query_to_run = q
deepthought = False deepthought = False
if research_results: if research_results:
@@ -1426,7 +1396,6 @@ async def agenerate_chat_response(
online_results=online_results, online_results=online_results,
loaded_model=loaded_model, loaded_model=loaded_model,
chat_history=chat_history, chat_history=chat_history,
completion_func=partial_completion,
conversation_commands=conversation_commands, conversation_commands=conversation_commands,
model_name=chat_model.name, model_name=chat_model.name,
max_prompt_size=chat_model.max_prompt_size, max_prompt_size=chat_model.max_prompt_size,
@@ -1455,7 +1424,6 @@ async def agenerate_chat_response(
model=chat_model_name, model=chat_model_name,
api_key=api_key, api_key=api_key,
api_base_url=openai_chat_config.api_base_url, api_base_url=openai_chat_config.api_base_url,
completion_func=partial_completion,
conversation_commands=conversation_commands, conversation_commands=conversation_commands,
max_prompt_size=chat_model.max_prompt_size, max_prompt_size=chat_model.max_prompt_size,
tokenizer_name=chat_model.tokenizer, tokenizer_name=chat_model.tokenizer,
@@ -1485,7 +1453,6 @@ async def agenerate_chat_response(
model=chat_model.name, model=chat_model.name,
api_key=api_key, api_key=api_key,
api_base_url=api_base_url, api_base_url=api_base_url,
completion_func=partial_completion,
conversation_commands=conversation_commands, conversation_commands=conversation_commands,
max_prompt_size=chat_model.max_prompt_size, max_prompt_size=chat_model.max_prompt_size,
tokenizer_name=chat_model.tokenizer, tokenizer_name=chat_model.tokenizer,
@@ -1513,7 +1480,6 @@ async def agenerate_chat_response(
model=chat_model.name, model=chat_model.name,
api_key=api_key, api_key=api_key,
api_base_url=api_base_url, api_base_url=api_base_url,
completion_func=partial_completion,
conversation_commands=conversation_commands, conversation_commands=conversation_commands,
max_prompt_size=chat_model.max_prompt_size, max_prompt_size=chat_model.max_prompt_size,
tokenizer_name=chat_model.tokenizer, tokenizer_name=chat_model.tokenizer,