Simplify AI provider converse methods

- Add context based on the information provided rather than on conversation
  commands. Let the caller handle passing the appropriate context to the
  AI provider converse methods.
This commit is contained in:
Debanjum
2025-06-04 20:28:21 -07:00
parent bfd4695705
commit 38fa34a861
6 changed files with 110 additions and 150 deletions

View File

@@ -21,11 +21,7 @@ from khoj.processor.conversation.utils import (
generate_chatml_messages_with_context,
messages_to_print,
)
from khoj.utils.helpers import (
ConversationCommand,
is_none_or_empty,
truncate_code_context,
)
from khoj.utils.helpers import is_none_or_empty, truncate_code_context
from khoj.utils.rawconfig import FileAttachment, LocationData
from khoj.utils.yaml import yaml_dump
@@ -136,27 +132,29 @@ def anthropic_send_message_to_model(
async def converse_anthropic(
# Query
user_query: str,
# Context
references: list[dict],
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[List[OperatorRun]] = None,
chat_history: List[ChatMessageModel] = [],
model: Optional[str] = "claude-3-7-sonnet-latest",
api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
conversation_commands=[ConversationCommand.Default],
max_prompt_size=None,
tokenizer_name=None,
location_data: LocationData = None,
user_name: str = None,
agent: Agent = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
query_files: str = None,
generated_files: List[FileAttachment] = None,
program_execution_context: Optional[List[str]] = None,
generated_asset_results: Dict[str, Dict] = {},
location_data: LocationData = None,
user_name: str = None,
chat_history: List[ChatMessageModel] = [],
# Model
model: Optional[str] = "claude-3-7-sonnet-latest",
api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
max_prompt_size=None,
tokenizer_name=None,
agent: Agent = None,
vision_available: bool = False,
deepthought: Optional[bool] = False,
tracer: dict = {},
) -> AsyncGenerator[ResponseWithThought, None]:
@@ -187,26 +185,16 @@ async def converse_anthropic(
user_name_prompt = prompts.user_name.format(name=user_name)
system_prompt = f"{system_prompt}\n{user_name_prompt}"
# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
response = prompts.no_notes_found.format()
yield ResponseWithThought(response=response)
return
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
response = prompts.no_online_results_found.format()
yield ResponseWithThought(response=response)
return
context_message = ""
if not is_none_or_empty(references):
context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n"
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
if not is_none_or_empty(online_results):
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results):
if not is_none_or_empty(code_results):
context_message += (
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
)
if ConversationCommand.Operator in conversation_commands and not is_none_or_empty(operator_results):
if not is_none_or_empty(operator_results):
operator_content = [
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
]

View File

@@ -21,11 +21,7 @@ from khoj.processor.conversation.utils import (
generate_chatml_messages_with_context,
messages_to_print,
)
from khoj.utils.helpers import (
ConversationCommand,
is_none_or_empty,
truncate_code_context,
)
from khoj.utils.helpers import is_none_or_empty, truncate_code_context
from khoj.utils.rawconfig import FileAttachment, LocationData
from khoj.utils.yaml import yaml_dump
@@ -158,28 +154,30 @@ def gemini_send_message_to_model(
async def converse_gemini(
# Query
user_query: str,
# Context
references: list[dict],
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[List[OperatorRun]] = None,
chat_history: List[ChatMessageModel] = [],
model: Optional[str] = "gemini-2.0-flash",
api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
temperature: float = 1.0,
conversation_commands=[ConversationCommand.Default],
max_prompt_size=None,
tokenizer_name=None,
location_data: LocationData = None,
user_name: str = None,
agent: Agent = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
query_files: str = None,
generated_files: List[FileAttachment] = None,
generated_asset_results: Dict[str, Dict] = {},
program_execution_context: List[str] = None,
location_data: LocationData = None,
user_name: str = None,
chat_history: List[ChatMessageModel] = [],
# Model
model: Optional[str] = "gemini-2.0-flash",
api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
temperature: float = 1.0,
max_prompt_size=None,
tokenizer_name=None,
agent: Agent = None,
vision_available: bool = False,
deepthought: Optional[bool] = False,
tracer={},
) -> AsyncGenerator[ResponseWithThought, None]:
@@ -211,26 +209,16 @@ async def converse_gemini(
user_name_prompt = prompts.user_name.format(name=user_name)
system_prompt = f"{system_prompt}\n{user_name_prompt}"
# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
response = prompts.no_notes_found.format()
yield ResponseWithThought(response=response)
return
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
response = prompts.no_online_results_found.format()
yield ResponseWithThought(response=response)
return
context_message = ""
if not is_none_or_empty(references):
context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n"
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
if not is_none_or_empty(online_results):
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results):
if not is_none_or_empty(code_results):
context_message += (
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
)
if ConversationCommand.Operator in conversation_commands and not is_none_or_empty(operator_results):
if not is_none_or_empty(operator_results):
operator_content = [
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
]

View File

@@ -24,7 +24,6 @@ from khoj.processor.conversation.utils import (
from khoj.utils import state
from khoj.utils.constants import empty_escape_sequences
from khoj.utils.helpers import (
ConversationCommand,
is_none_or_empty,
is_promptrace_enabled,
truncate_code_context,
@@ -144,23 +143,25 @@ def filter_questions(questions: List[str]):
async def converse_offline(
# Query
user_query: str,
# Context
references: list[dict] = [],
online_results={},
code_results={},
chat_history: list[ChatMessageModel] = [],
model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
loaded_model: Union[Any, None] = None,
conversation_commands=[ConversationCommand.Default],
max_prompt_size=None,
tokenizer_name=None,
location_data: LocationData = None,
user_name: str = None,
agent: Agent = None,
query_files: str = None,
generated_files: List[FileAttachment] = None,
additional_context: List[str] = None,
generated_asset_results: Dict[str, Dict] = {},
location_data: LocationData = None,
user_name: str = None,
chat_history: list[ChatMessageModel] = [],
# Model
model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
loaded_model: Union[Any, None] = None,
max_prompt_size=None,
tokenizer_name=None,
agent: Agent = None,
tracer: dict = {},
) -> AsyncGenerator[ResponseWithThought, None]:
"""
@@ -194,26 +195,17 @@ async def converse_offline(
system_prompt = f"{system_prompt}\n{user_name_prompt}"
# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
response = prompts.no_notes_found.format()
yield ResponseWithThought(response=response)
return
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
response = prompts.no_online_results_found.format()
yield ResponseWithThought(response=response)
return
context_message = ""
if not is_none_or_empty(references):
context_message = f"{prompts.notes_conversation_offline.format(references=yaml_dump(references))}\n\n"
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
if not is_none_or_empty(online_results):
simplified_online_results = online_results.copy()
for result in online_results:
if online_results[result].get("webpages"):
simplified_online_results[result] = online_results[result]["webpages"]
context_message += f"{prompts.online_search_conversation_offline.format(online_results=yaml_dump(simplified_online_results))}\n\n"
if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results):
if not is_none_or_empty(code_results):
context_message += (
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
)

View File

@@ -24,11 +24,7 @@ from khoj.processor.conversation.utils import (
generate_chatml_messages_with_context,
messages_to_print,
)
from khoj.utils.helpers import (
ConversationCommand,
is_none_or_empty,
truncate_code_context,
)
from khoj.utils.helpers import is_none_or_empty, truncate_code_context
from khoj.utils.rawconfig import FileAttachment, LocationData
from khoj.utils.yaml import yaml_dump
@@ -160,28 +156,29 @@ def send_message_to_model(
async def converse_openai(
# Query
user_query: str,
# Context
references: list[dict],
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[List[OperatorRun]] = None,
query_images: Optional[list[str]] = None,
query_files: str = None,
generated_files: List[FileAttachment] = None,
generated_asset_results: Dict[str, Dict] = {},
program_execution_context: List[str] = None,
location_data: LocationData = None,
chat_history: list[ChatMessageModel] = [],
model: str = "gpt-4o-mini",
api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
temperature: float = 0.4,
conversation_commands=[ConversationCommand.Default],
max_prompt_size=None,
tokenizer_name=None,
location_data: LocationData = None,
user_name: str = None,
agent: Agent = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
query_files: str = None,
generated_files: List[FileAttachment] = None,
generated_asset_results: Dict[str, Dict] = {},
program_execution_context: List[str] = None,
deepthought: Optional[bool] = False,
tracer: dict = {},
) -> AsyncGenerator[ResponseWithThought, None]:
@@ -212,16 +209,6 @@ async def converse_openai(
user_name_prompt = prompts.user_name.format(name=user_name)
system_prompt = f"{system_prompt}\n{user_name_prompt}"
# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
response = prompts.no_notes_found.format()
yield ResponseWithThought(response=response)
return
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
response = prompts.no_online_results_found.format()
yield ResponseWithThought(response=response)
return
context_message = ""
if not is_none_or_empty(references):
context_message = f"{prompts.notes_conversation.format(references=yaml_dump(references))}\n\n"

View File

@@ -1463,7 +1463,6 @@ async def chat(
code_results,
operator_results,
research_results,
conversation_commands,
user,
location,
user_name,

View File

@@ -1348,7 +1348,6 @@ async def agenerate_chat_response(
code_results: Dict[str, Dict] = {},
operator_results: List[OperatorRun] = [],
research_results: List[ResearchIteration] = [],
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
user: KhojUser = None,
location_data: LocationData = None,
user_name: Optional[str] = None,
@@ -1362,7 +1361,6 @@ async def agenerate_chat_response(
) -> Tuple[AsyncGenerator[ResponseWithThought, None], Dict[str, str]]:
# Initialize Variables
chat_response_generator: AsyncGenerator[ResponseWithThought, None] = None
logger.debug(f"Conversation Types: {conversation_commands}")
metadata = {}
agent = await AgentAdapters.aget_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
@@ -1391,21 +1389,23 @@ async def agenerate_chat_response(
if chat_model.model_type == "offline":
loaded_model = state.offline_chat_processor_config.loaded_model
chat_response_generator = converse_offline(
# Query
user_query=query_to_run,
# Context
references=compiled_references,
online_results=online_results,
loaded_model=loaded_model,
generated_files=raw_generated_files,
generated_asset_results=generated_asset_results,
location_data=location_data,
user_name=user_name,
query_files=query_files,
chat_history=chat_history,
conversation_commands=conversation_commands,
# Model
loaded_model=loaded_model,
model_name=chat_model.name,
max_prompt_size=chat_model.max_prompt_size,
tokenizer_name=chat_model.tokenizer,
location_data=location_data,
user_name=user_name,
agent=agent,
query_files=query_files,
generated_files=raw_generated_files,
generated_asset_results=generated_asset_results,
tracer=tracer,
)
@@ -1414,27 +1414,29 @@ async def agenerate_chat_response(
api_key = openai_chat_config.api_key
chat_model_name = chat_model.name
chat_response_generator = converse_openai(
# Query
query_to_run,
compiled_references,
query_images=query_images,
# Context
references=compiled_references,
online_results=online_results,
code_results=code_results,
operator_results=operator_results,
chat_history=chat_history,
model=chat_model_name,
api_key=api_key,
api_base_url=openai_chat_config.api_base_url,
conversation_commands=conversation_commands,
max_prompt_size=chat_model.max_prompt_size,
tokenizer_name=chat_model.tokenizer,
location_data=location_data,
user_name=user_name,
agent=agent,
vision_available=vision_available,
query_images=query_images,
query_files=query_files,
generated_files=raw_generated_files,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
location_data=location_data,
user_name=user_name,
chat_history=chat_history,
# Model
model=chat_model_name,
api_key=api_key,
api_base_url=openai_chat_config.api_base_url,
max_prompt_size=chat_model.max_prompt_size,
tokenizer_name=chat_model.tokenizer,
agent=agent,
vision_available=vision_available,
deepthought=deepthought,
tracer=tracer,
)
@@ -1443,27 +1445,29 @@ async def agenerate_chat_response(
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_response_generator = converse_anthropic(
# Query
query_to_run,
compiled_references,
query_images=query_images,
# Context
references=compiled_references,
online_results=online_results,
code_results=code_results,
operator_results=operator_results,
chat_history=chat_history,
model=chat_model.name,
api_key=api_key,
api_base_url=api_base_url,
conversation_commands=conversation_commands,
max_prompt_size=chat_model.max_prompt_size,
tokenizer_name=chat_model.tokenizer,
location_data=location_data,
user_name=user_name,
agent=agent,
vision_available=vision_available,
query_images=query_images,
query_files=query_files,
generated_files=raw_generated_files,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
location_data=location_data,
user_name=user_name,
chat_history=chat_history,
# Model
model=chat_model.name,
api_key=api_key,
api_base_url=api_base_url,
max_prompt_size=chat_model.max_prompt_size,
tokenizer_name=chat_model.tokenizer,
agent=agent,
vision_available=vision_available,
deepthought=deepthought,
tracer=tracer,
)
@@ -1471,27 +1475,29 @@ async def agenerate_chat_response(
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_response_generator = converse_gemini(
# Query
query_to_run,
compiled_references,
# Context
references=compiled_references,
online_results=online_results,
code_results=code_results,
operator_results=operator_results,
chat_history=chat_history,
model=chat_model.name,
api_key=api_key,
api_base_url=api_base_url,
conversation_commands=conversation_commands,
max_prompt_size=chat_model.max_prompt_size,
tokenizer_name=chat_model.tokenizer,
location_data=location_data,
user_name=user_name,
agent=agent,
query_images=query_images,
vision_available=vision_available,
query_files=query_files,
generated_files=raw_generated_files,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
location_data=location_data,
user_name=user_name,
chat_history=chat_history,
# Model
model=chat_model.name,
api_key=api_key,
api_base_url=api_base_url,
max_prompt_size=chat_model.max_prompt_size,
tokenizer_name=chat_model.tokenizer,
agent=agent,
vision_available=vision_available,
deepthought=deepthought,
tracer=tracer,
)