mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Save conversation in the common chat API function instead of in each AI provider
This commit is contained in:
@@ -1,4 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import AsyncGenerator, Dict, List, Optional
|
from typing import AsyncGenerator, Dict, List, Optional
|
||||||
@@ -146,7 +145,6 @@ async def converse_anthropic(
|
|||||||
model: Optional[str] = "claude-3-7-sonnet-latest",
|
model: Optional[str] = "claude-3-7-sonnet-latest",
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base_url: Optional[str] = None,
|
api_base_url: Optional[str] = None,
|
||||||
completion_func=None,
|
|
||||||
conversation_commands=[ConversationCommand.Default],
|
conversation_commands=[ConversationCommand.Default],
|
||||||
max_prompt_size=None,
|
max_prompt_size=None,
|
||||||
tokenizer_name=None,
|
tokenizer_name=None,
|
||||||
@@ -161,7 +159,7 @@ async def converse_anthropic(
|
|||||||
generated_asset_results: Dict[str, Dict] = {},
|
generated_asset_results: Dict[str, Dict] = {},
|
||||||
deepthought: Optional[bool] = False,
|
deepthought: Optional[bool] = False,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> AsyncGenerator[str | ResponseWithThought, None]:
|
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||||
"""
|
"""
|
||||||
Converse with user using Anthropic's Claude
|
Converse with user using Anthropic's Claude
|
||||||
"""
|
"""
|
||||||
@@ -192,15 +190,11 @@ async def converse_anthropic(
|
|||||||
# Get Conversation Primer appropriate to Conversation Type
|
# Get Conversation Primer appropriate to Conversation Type
|
||||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||||
response = prompts.no_notes_found.format()
|
response = prompts.no_notes_found.format()
|
||||||
if completion_func:
|
yield ResponseWithThought(response=response)
|
||||||
asyncio.create_task(completion_func(chat_response=response))
|
|
||||||
yield response
|
|
||||||
return
|
return
|
||||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||||
response = prompts.no_online_results_found.format()
|
response = prompts.no_online_results_found.format()
|
||||||
if completion_func:
|
yield ResponseWithThought(response=response)
|
||||||
asyncio.create_task(completion_func(chat_response=response))
|
|
||||||
yield response
|
|
||||||
return
|
return
|
||||||
|
|
||||||
context_message = ""
|
context_message = ""
|
||||||
@@ -241,7 +235,6 @@ async def converse_anthropic(
|
|||||||
logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}")
|
logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}")
|
||||||
|
|
||||||
# Get Response from Claude
|
# Get Response from Claude
|
||||||
full_response = ""
|
|
||||||
async for chunk in anthropic_chat_completion_with_backoff(
|
async for chunk in anthropic_chat_completion_with_backoff(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
@@ -253,10 +246,4 @@ async def converse_anthropic(
|
|||||||
deepthought=deepthought,
|
deepthought=deepthought,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if chunk.response:
|
|
||||||
full_response += chunk.response
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
# Call completion_func once finish streaming and we have the full response
|
|
||||||
if completion_func:
|
|
||||||
asyncio.create_task(completion_func(chat_response=full_response))
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import AsyncGenerator, Dict, List, Optional
|
from typing import AsyncGenerator, Dict, List, Optional
|
||||||
@@ -15,6 +14,7 @@ from khoj.processor.conversation.google.utils import (
|
|||||||
)
|
)
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
OperatorRun,
|
OperatorRun,
|
||||||
|
ResponseWithThought,
|
||||||
clean_json,
|
clean_json,
|
||||||
construct_question_history,
|
construct_question_history,
|
||||||
construct_structured_message,
|
construct_structured_message,
|
||||||
@@ -168,7 +168,6 @@ async def converse_gemini(
|
|||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base_url: Optional[str] = None,
|
api_base_url: Optional[str] = None,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
completion_func=None,
|
|
||||||
conversation_commands=[ConversationCommand.Default],
|
conversation_commands=[ConversationCommand.Default],
|
||||||
max_prompt_size=None,
|
max_prompt_size=None,
|
||||||
tokenizer_name=None,
|
tokenizer_name=None,
|
||||||
@@ -183,7 +182,7 @@ async def converse_gemini(
|
|||||||
program_execution_context: List[str] = None,
|
program_execution_context: List[str] = None,
|
||||||
deepthought: Optional[bool] = False,
|
deepthought: Optional[bool] = False,
|
||||||
tracer={},
|
tracer={},
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||||
"""
|
"""
|
||||||
Converse with user using Google's Gemini
|
Converse with user using Google's Gemini
|
||||||
"""
|
"""
|
||||||
@@ -215,15 +214,11 @@ async def converse_gemini(
|
|||||||
# Get Conversation Primer appropriate to Conversation Type
|
# Get Conversation Primer appropriate to Conversation Type
|
||||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||||
response = prompts.no_notes_found.format()
|
response = prompts.no_notes_found.format()
|
||||||
if completion_func:
|
yield ResponseWithThought(response=response)
|
||||||
asyncio.create_task(completion_func(chat_response=response))
|
|
||||||
yield response
|
|
||||||
return
|
return
|
||||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||||
response = prompts.no_online_results_found.format()
|
response = prompts.no_online_results_found.format()
|
||||||
if completion_func:
|
yield ResponseWithThought(response=response)
|
||||||
asyncio.create_task(completion_func(chat_response=response))
|
|
||||||
yield response
|
|
||||||
return
|
return
|
||||||
|
|
||||||
context_message = ""
|
context_message = ""
|
||||||
@@ -264,7 +259,6 @@ async def converse_gemini(
|
|||||||
logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
|
logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
|
||||||
|
|
||||||
# Get Response from Google AI
|
# Get Response from Google AI
|
||||||
full_response = ""
|
|
||||||
async for chunk in gemini_chat_completion_with_backoff(
|
async for chunk in gemini_chat_completion_with_backoff(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
@@ -275,10 +269,4 @@ async def converse_gemini(
|
|||||||
deepthought=deepthought,
|
deepthought=deepthought,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if chunk.response:
|
|
||||||
full_response += chunk.response
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
# Call completion_func once finish streaming and we have the full response
|
|
||||||
if completion_func:
|
|
||||||
asyncio.create_task(completion_func(chat_response=full_response))
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
|
|||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.offline.utils import download_model
|
from khoj.processor.conversation.offline.utils import download_model
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
|
ResponseWithThought,
|
||||||
clean_json,
|
clean_json,
|
||||||
commit_conversation_trace,
|
commit_conversation_trace,
|
||||||
construct_question_history,
|
construct_question_history,
|
||||||
@@ -150,7 +151,6 @@ async def converse_offline(
|
|||||||
chat_history: list[ChatMessageModel] = [],
|
chat_history: list[ChatMessageModel] = [],
|
||||||
model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
|
model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
|
||||||
loaded_model: Union[Any, None] = None,
|
loaded_model: Union[Any, None] = None,
|
||||||
completion_func=None,
|
|
||||||
conversation_commands=[ConversationCommand.Default],
|
conversation_commands=[ConversationCommand.Default],
|
||||||
max_prompt_size=None,
|
max_prompt_size=None,
|
||||||
tokenizer_name=None,
|
tokenizer_name=None,
|
||||||
@@ -162,7 +162,7 @@ async def converse_offline(
|
|||||||
additional_context: List[str] = None,
|
additional_context: List[str] = None,
|
||||||
generated_asset_results: Dict[str, Dict] = {},
|
generated_asset_results: Dict[str, Dict] = {},
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||||
"""
|
"""
|
||||||
Converse with user using Llama (Async Version)
|
Converse with user using Llama (Async Version)
|
||||||
"""
|
"""
|
||||||
@@ -196,15 +196,11 @@ async def converse_offline(
|
|||||||
# Get Conversation Primer appropriate to Conversation Type
|
# Get Conversation Primer appropriate to Conversation Type
|
||||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||||
response = prompts.no_notes_found.format()
|
response = prompts.no_notes_found.format()
|
||||||
if completion_func:
|
yield ResponseWithThought(response=response)
|
||||||
asyncio.create_task(completion_func(chat_response=response))
|
|
||||||
yield response
|
|
||||||
return
|
return
|
||||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||||
response = prompts.no_online_results_found.format()
|
response = prompts.no_online_results_found.format()
|
||||||
if completion_func:
|
yield ResponseWithThought(response=response)
|
||||||
asyncio.create_task(completion_func(chat_response=response))
|
|
||||||
yield response
|
|
||||||
return
|
return
|
||||||
|
|
||||||
context_message = ""
|
context_message = ""
|
||||||
@@ -243,9 +239,8 @@ async def converse_offline(
|
|||||||
logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")
|
logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")
|
||||||
|
|
||||||
# Use asyncio.Queue and a thread to bridge sync iterator
|
# Use asyncio.Queue and a thread to bridge sync iterator
|
||||||
queue: asyncio.Queue = asyncio.Queue()
|
queue: asyncio.Queue[ResponseWithThought] = asyncio.Queue()
|
||||||
stop_phrases = ["<s>", "INST]", "Notes:"]
|
stop_phrases = ["<s>", "INST]", "Notes:"]
|
||||||
aggregated_response_container = {"response": ""}
|
|
||||||
|
|
||||||
def _sync_llm_thread():
|
def _sync_llm_thread():
|
||||||
"""Synchronous function to run in a separate thread."""
|
"""Synchronous function to run in a separate thread."""
|
||||||
@@ -262,7 +257,7 @@ async def converse_offline(
|
|||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
for response in response_iterator:
|
for response in response_iterator:
|
||||||
response_delta = response["choices"][0]["delta"].get("content", "")
|
response_delta: str = response["choices"][0]["delta"].get("content", "")
|
||||||
# Log the time taken to start response
|
# Log the time taken to start response
|
||||||
if aggregated_response == "" and response_delta != "":
|
if aggregated_response == "" and response_delta != "":
|
||||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||||
@@ -270,12 +265,12 @@ async def converse_offline(
|
|||||||
aggregated_response += response_delta
|
aggregated_response += response_delta
|
||||||
# Put chunk into the asyncio queue (non-blocking)
|
# Put chunk into the asyncio queue (non-blocking)
|
||||||
try:
|
try:
|
||||||
queue.put_nowait(response_delta)
|
queue.put_nowait(ResponseWithThought(response=response_delta))
|
||||||
except asyncio.QueueFull:
|
except asyncio.QueueFull:
|
||||||
# Should not happen with default queue size unless consumer is very slow
|
# Should not happen with default queue size unless consumer is very slow
|
||||||
logger.warning("Asyncio queue full during offline LLM streaming.")
|
logger.warning("Asyncio queue full during offline LLM streaming.")
|
||||||
# Potentially block here or handle differently if needed
|
# Potentially block here or handle differently if needed
|
||||||
asyncio.run(queue.put(response_delta))
|
asyncio.run(queue.put(ResponseWithThought(response=response_delta)))
|
||||||
|
|
||||||
# Log the time taken to stream the entire response
|
# Log the time taken to stream the entire response
|
||||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||||
@@ -291,7 +286,6 @@ async def converse_offline(
|
|||||||
state.chat_lock.release()
|
state.chat_lock.release()
|
||||||
# Signal end of stream
|
# Signal end of stream
|
||||||
queue.put_nowait(None)
|
queue.put_nowait(None)
|
||||||
aggregated_response_container["response"] = aggregated_response
|
|
||||||
|
|
||||||
# Start the synchronous thread
|
# Start the synchronous thread
|
||||||
thread = Thread(target=_sync_llm_thread)
|
thread = Thread(target=_sync_llm_thread)
|
||||||
@@ -310,10 +304,6 @@ async def converse_offline(
|
|||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
await loop.run_in_executor(None, thread.join)
|
await loop.run_in_executor(None, thread.join)
|
||||||
|
|
||||||
# Call the completion function after streaming is done
|
|
||||||
if completion_func:
|
|
||||||
asyncio.create_task(completion_func(chat_response=aggregated_response_container["response"]))
|
|
||||||
|
|
||||||
|
|
||||||
def send_message_to_model_offline(
|
def send_message_to_model_offline(
|
||||||
messages: List[ChatMessage],
|
messages: List[ChatMessage],
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import AsyncGenerator, Dict, List, Optional
|
from typing import AsyncGenerator, Dict, List, Optional
|
||||||
@@ -171,7 +170,6 @@ async def converse_openai(
|
|||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base_url: Optional[str] = None,
|
api_base_url: Optional[str] = None,
|
||||||
temperature: float = 0.4,
|
temperature: float = 0.4,
|
||||||
completion_func=None,
|
|
||||||
conversation_commands=[ConversationCommand.Default],
|
conversation_commands=[ConversationCommand.Default],
|
||||||
max_prompt_size=None,
|
max_prompt_size=None,
|
||||||
tokenizer_name=None,
|
tokenizer_name=None,
|
||||||
@@ -186,7 +184,7 @@ async def converse_openai(
|
|||||||
program_execution_context: List[str] = None,
|
program_execution_context: List[str] = None,
|
||||||
deepthought: Optional[bool] = False,
|
deepthought: Optional[bool] = False,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> AsyncGenerator[str | ResponseWithThought, None]:
|
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||||
"""
|
"""
|
||||||
Converse with user using OpenAI's ChatGPT
|
Converse with user using OpenAI's ChatGPT
|
||||||
"""
|
"""
|
||||||
@@ -217,15 +215,11 @@ async def converse_openai(
|
|||||||
# Get Conversation Primer appropriate to Conversation Type
|
# Get Conversation Primer appropriate to Conversation Type
|
||||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||||
response = prompts.no_notes_found.format()
|
response = prompts.no_notes_found.format()
|
||||||
if completion_func:
|
yield ResponseWithThought(response=response)
|
||||||
asyncio.create_task(completion_func(chat_response=response))
|
|
||||||
yield response
|
|
||||||
return
|
return
|
||||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||||
response = prompts.no_online_results_found.format()
|
response = prompts.no_online_results_found.format()
|
||||||
if completion_func:
|
yield ResponseWithThought(response=response)
|
||||||
asyncio.create_task(completion_func(chat_response=response))
|
|
||||||
yield response
|
|
||||||
return
|
return
|
||||||
|
|
||||||
context_message = ""
|
context_message = ""
|
||||||
@@ -267,7 +261,6 @@ async def converse_openai(
|
|||||||
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
|
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
|
||||||
|
|
||||||
# Get Response from GPT
|
# Get Response from GPT
|
||||||
full_response = ""
|
|
||||||
async for chunk in chat_completion_with_backoff(
|
async for chunk in chat_completion_with_backoff(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
@@ -277,14 +270,8 @@ async def converse_openai(
|
|||||||
deepthought=deepthought,
|
deepthought=deepthought,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if chunk.response:
|
|
||||||
full_response += chunk.response
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
# Call completion_func once finish streaming and we have the full response
|
|
||||||
if completion_func:
|
|
||||||
asyncio.create_task(completion_func(chat_response=full_response))
|
|
||||||
|
|
||||||
|
|
||||||
def clean_response_schema(schema: BaseModel | dict) -> dict:
|
def clean_response_schema(schema: BaseModel | dict) -> dict:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1463,33 +1463,30 @@ async def chat(
|
|||||||
code_results,
|
code_results,
|
||||||
operator_results,
|
operator_results,
|
||||||
research_results,
|
research_results,
|
||||||
inferred_queries,
|
|
||||||
conversation_commands,
|
conversation_commands,
|
||||||
user,
|
user,
|
||||||
request.user.client_app,
|
|
||||||
location,
|
location,
|
||||||
user_name,
|
user_name,
|
||||||
uploaded_images,
|
uploaded_images,
|
||||||
train_of_thought,
|
|
||||||
attached_file_context,
|
attached_file_context,
|
||||||
raw_query_files,
|
|
||||||
generated_images,
|
|
||||||
generated_files,
|
generated_files,
|
||||||
generated_mermaidjs_diagram,
|
|
||||||
program_execution_context,
|
program_execution_context,
|
||||||
generated_asset_results,
|
generated_asset_results,
|
||||||
is_subscribed,
|
is_subscribed,
|
||||||
tracer,
|
tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
full_response = ""
|
||||||
async for item in llm_response:
|
async for item in llm_response:
|
||||||
# Should not happen with async generator, end is signaled by loop exit. Skip.
|
# Should not happen with async generator. Skip.
|
||||||
if item is None:
|
if item is None or not isinstance(item, ResponseWithThought):
|
||||||
|
logger.warning(f"Unexpected item type in LLM response: {type(item)}. Skipping.")
|
||||||
continue
|
continue
|
||||||
if cancellation_event.is_set():
|
if cancellation_event.is_set():
|
||||||
break
|
break
|
||||||
message = item.response if isinstance(item, ResponseWithThought) else item
|
message = item.response
|
||||||
if isinstance(item, ResponseWithThought) and item.thought:
|
full_response += message if message else ""
|
||||||
|
if item.thought:
|
||||||
async for result in send_event(ChatEvent.THOUGHT, item.thought):
|
async for result in send_event(ChatEvent.THOUGHT, item.thought):
|
||||||
yield result
|
yield result
|
||||||
continue
|
continue
|
||||||
@@ -1506,6 +1503,31 @@ async def chat(
|
|||||||
logger.warning(f"Error during streaming. Stopping send: {e}")
|
logger.warning(f"Error during streaming. Stopping send: {e}")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Save conversation once finish streaming
|
||||||
|
asyncio.create_task(
|
||||||
|
save_to_conversation_log(
|
||||||
|
q,
|
||||||
|
chat_response=full_response,
|
||||||
|
user=user,
|
||||||
|
chat_history=chat_history,
|
||||||
|
compiled_references=compiled_references,
|
||||||
|
online_results=online_results,
|
||||||
|
code_results=code_results,
|
||||||
|
operator_results=operator_results,
|
||||||
|
research_results=research_results,
|
||||||
|
inferred_queries=inferred_queries,
|
||||||
|
client_application=request.user.client_app,
|
||||||
|
conversation_id=str(conversation.id),
|
||||||
|
query_images=uploaded_images,
|
||||||
|
train_of_thought=train_of_thought,
|
||||||
|
raw_query_files=raw_query_files,
|
||||||
|
generated_images=generated_images,
|
||||||
|
raw_generated_files=generated_files,
|
||||||
|
generated_mermaidjs_diagram=generated_mermaidjs_diagram,
|
||||||
|
tracer=tracer,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Signal end of LLM response after the loop finishes
|
# Signal end of LLM response after the loop finishes
|
||||||
if not cancellation_event.is_set():
|
if not cancellation_event.is_set():
|
||||||
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import math
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from functools import partial
|
|
||||||
from random import random
|
from random import random
|
||||||
from typing import (
|
from typing import (
|
||||||
Annotated,
|
Annotated,
|
||||||
@@ -102,7 +101,6 @@ from khoj.processor.conversation.utils import (
|
|||||||
clean_mermaidjs,
|
clean_mermaidjs,
|
||||||
construct_chat_history,
|
construct_chat_history,
|
||||||
generate_chatml_messages_with_context,
|
generate_chatml_messages_with_context,
|
||||||
save_to_conversation_log,
|
|
||||||
)
|
)
|
||||||
from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled
|
from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled
|
||||||
from khoj.routers.email import is_resend_enabled, send_task_email
|
from khoj.routers.email import is_resend_enabled, send_task_email
|
||||||
@@ -1350,54 +1348,26 @@ async def agenerate_chat_response(
|
|||||||
code_results: Dict[str, Dict] = {},
|
code_results: Dict[str, Dict] = {},
|
||||||
operator_results: List[OperatorRun] = [],
|
operator_results: List[OperatorRun] = [],
|
||||||
research_results: List[ResearchIteration] = [],
|
research_results: List[ResearchIteration] = [],
|
||||||
inferred_queries: List[str] = [],
|
|
||||||
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
|
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
client_application: ClientApplication = None,
|
|
||||||
location_data: LocationData = None,
|
location_data: LocationData = None,
|
||||||
user_name: Optional[str] = None,
|
user_name: Optional[str] = None,
|
||||||
query_images: Optional[List[str]] = None,
|
query_images: Optional[List[str]] = None,
|
||||||
train_of_thought: List[Any] = [],
|
|
||||||
query_files: str = None,
|
query_files: str = None,
|
||||||
raw_query_files: List[FileAttachment] = None,
|
|
||||||
generated_images: List[str] = None,
|
|
||||||
raw_generated_files: List[FileAttachment] = [],
|
raw_generated_files: List[FileAttachment] = [],
|
||||||
generated_mermaidjs_diagram: str = None,
|
|
||||||
program_execution_context: List[str] = [],
|
program_execution_context: List[str] = [],
|
||||||
generated_asset_results: Dict[str, Dict] = {},
|
generated_asset_results: Dict[str, Dict] = {},
|
||||||
is_subscribed: bool = False,
|
is_subscribed: bool = False,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> Tuple[AsyncGenerator[str | ResponseWithThought, None], Dict[str, str]]:
|
) -> Tuple[AsyncGenerator[ResponseWithThought, None], Dict[str, str]]:
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
chat_response_generator: AsyncGenerator[str | ResponseWithThought, None] = None
|
chat_response_generator: AsyncGenerator[ResponseWithThought, None] = None
|
||||||
logger.debug(f"Conversation Types: {conversation_commands}")
|
logger.debug(f"Conversation Types: {conversation_commands}")
|
||||||
|
|
||||||
metadata = {}
|
metadata = {}
|
||||||
agent = await AgentAdapters.aget_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
|
agent = await AgentAdapters.aget_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
partial_completion = partial(
|
|
||||||
save_to_conversation_log,
|
|
||||||
q,
|
|
||||||
user=user,
|
|
||||||
chat_history=chat_history,
|
|
||||||
compiled_references=compiled_references,
|
|
||||||
online_results=online_results,
|
|
||||||
code_results=code_results,
|
|
||||||
operator_results=operator_results,
|
|
||||||
research_results=research_results,
|
|
||||||
inferred_queries=inferred_queries,
|
|
||||||
client_application=client_application,
|
|
||||||
conversation_id=str(conversation.id),
|
|
||||||
query_images=query_images,
|
|
||||||
train_of_thought=train_of_thought,
|
|
||||||
raw_query_files=raw_query_files,
|
|
||||||
generated_images=generated_images,
|
|
||||||
raw_generated_files=raw_generated_files,
|
|
||||||
generated_mermaidjs_diagram=generated_mermaidjs_diagram,
|
|
||||||
tracer=tracer,
|
|
||||||
)
|
|
||||||
|
|
||||||
query_to_run = q
|
query_to_run = q
|
||||||
deepthought = False
|
deepthought = False
|
||||||
if research_results:
|
if research_results:
|
||||||
@@ -1426,7 +1396,6 @@ async def agenerate_chat_response(
|
|||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
loaded_model=loaded_model,
|
loaded_model=loaded_model,
|
||||||
chat_history=chat_history,
|
chat_history=chat_history,
|
||||||
completion_func=partial_completion,
|
|
||||||
conversation_commands=conversation_commands,
|
conversation_commands=conversation_commands,
|
||||||
model_name=chat_model.name,
|
model_name=chat_model.name,
|
||||||
max_prompt_size=chat_model.max_prompt_size,
|
max_prompt_size=chat_model.max_prompt_size,
|
||||||
@@ -1455,7 +1424,6 @@ async def agenerate_chat_response(
|
|||||||
model=chat_model_name,
|
model=chat_model_name,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_base_url=openai_chat_config.api_base_url,
|
api_base_url=openai_chat_config.api_base_url,
|
||||||
completion_func=partial_completion,
|
|
||||||
conversation_commands=conversation_commands,
|
conversation_commands=conversation_commands,
|
||||||
max_prompt_size=chat_model.max_prompt_size,
|
max_prompt_size=chat_model.max_prompt_size,
|
||||||
tokenizer_name=chat_model.tokenizer,
|
tokenizer_name=chat_model.tokenizer,
|
||||||
@@ -1485,7 +1453,6 @@ async def agenerate_chat_response(
|
|||||||
model=chat_model.name,
|
model=chat_model.name,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
completion_func=partial_completion,
|
|
||||||
conversation_commands=conversation_commands,
|
conversation_commands=conversation_commands,
|
||||||
max_prompt_size=chat_model.max_prompt_size,
|
max_prompt_size=chat_model.max_prompt_size,
|
||||||
tokenizer_name=chat_model.tokenizer,
|
tokenizer_name=chat_model.tokenizer,
|
||||||
@@ -1513,7 +1480,6 @@ async def agenerate_chat_response(
|
|||||||
model=chat_model.name,
|
model=chat_model.name,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
completion_func=partial_completion,
|
|
||||||
conversation_commands=conversation_commands,
|
conversation_commands=conversation_commands,
|
||||||
max_prompt_size=chat_model.max_prompt_size,
|
max_prompt_size=chat_model.max_prompt_size,
|
||||||
tokenizer_name=chat_model.tokenizer,
|
tokenizer_name=chat_model.tokenizer,
|
||||||
|
|||||||
Reference in New Issue
Block a user