From 38fa34a861115a8b6dd99837a2fd7359ca49b82f Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 4 Jun 2025 20:28:21 -0700 Subject: [PATCH] Simplify ai provider converse methods - Add context based on information provided rather than conversation commands. Let caller handle passing appropriate context to ai provider converse methods --- .../conversation/anthropic/anthropic_chat.py | 46 +++----- .../conversation/google/gemini_chat.py | 48 ++++----- .../conversation/offline/chat_model.py | 34 +++--- src/khoj/processor/conversation/openai/gpt.py | 31 ++---- src/khoj/routers/api_chat.py | 1 - src/khoj/routers/helpers.py | 100 ++++++++++-------- 6 files changed, 110 insertions(+), 150 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 0079cf8c..8c94fc7f 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -21,11 +21,7 @@ from khoj.processor.conversation.utils import ( generate_chatml_messages_with_context, messages_to_print, ) -from khoj.utils.helpers import ( - ConversationCommand, - is_none_or_empty, - truncate_code_context, -) +from khoj.utils.helpers import is_none_or_empty, truncate_code_context from khoj.utils.rawconfig import FileAttachment, LocationData from khoj.utils.yaml import yaml_dump @@ -136,27 +132,29 @@ def anthropic_send_message_to_model( async def converse_anthropic( + # Query user_query: str, + # Context references: list[dict], online_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None, operator_results: Optional[List[OperatorRun]] = None, - chat_history: List[ChatMessageModel] = [], - model: Optional[str] = "claude-3-7-sonnet-latest", - api_key: Optional[str] = None, - api_base_url: Optional[str] = None, - conversation_commands=[ConversationCommand.Default], - max_prompt_size=None, - tokenizer_name=None, - location_data: LocationData = None, - user_name: str = None, - agent: Agent = None, query_images: Optional[list[str]] = None, - vision_available: bool = False, query_files: str = None, generated_files: List[FileAttachment] = None, program_execution_context: Optional[List[str]] = None, generated_asset_results: Dict[str, Dict] = {}, + location_data: LocationData = None, + user_name: str = None, + chat_history: List[ChatMessageModel] = [], + # Model + model: Optional[str] = "claude-3-7-sonnet-latest", + api_key: Optional[str] = None, + api_base_url: Optional[str] = None, + max_prompt_size=None, + tokenizer_name=None, + agent: Agent = None, + vision_available: bool = False, deepthought: Optional[bool] = False, tracer: dict = {}, ) -> AsyncGenerator[ResponseWithThought, None]: @@ -187,26 +185,16 @@ async def converse_anthropic( user_name_prompt = prompts.user_name.format(name=user_name) system_prompt = f"{system_prompt}\n{user_name_prompt}" - # Get Conversation Primer appropriate to Conversation Type - if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): - response = prompts.no_notes_found.format() - yield ResponseWithThought(response=response) - return - elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): - response = prompts.no_online_results_found.format() - yield ResponseWithThought(response=response) - return - context_message = "" if not is_none_or_empty(references): context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n" - if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands: + if not is_none_or_empty(online_results): context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n" - if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results): + if not is_none_or_empty(code_results): context_message += ( f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n" ) - if ConversationCommand.Operator in conversation_commands and not is_none_or_empty(operator_results): + if not is_none_or_empty(operator_results): operator_content = [ {"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results ] diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 78cd6fa4..556086aa 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -21,11 +21,7 @@ from khoj.processor.conversation.utils import ( generate_chatml_messages_with_context, messages_to_print, ) -from khoj.utils.helpers import ( - ConversationCommand, - is_none_or_empty, - truncate_code_context, -) +from khoj.utils.helpers import is_none_or_empty, truncate_code_context from khoj.utils.rawconfig import FileAttachment, LocationData from khoj.utils.yaml import yaml_dump @@ -158,28 +154,30 @@ def gemini_send_message_to_model( async def converse_gemini( + # Query user_query: str, + # Context references: list[dict], online_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None, operator_results: Optional[List[OperatorRun]] = None, - chat_history: List[ChatMessageModel] = [], - model: Optional[str] = "gemini-2.0-flash", - api_key: Optional[str] = None, - api_base_url: Optional[str] = None, - temperature: float = 1.0, - conversation_commands=[ConversationCommand.Default], - max_prompt_size=None, - tokenizer_name=None, - location_data: LocationData = None, - user_name: str = None, - agent: Agent = None, query_images: Optional[list[str]] = None, - vision_available: bool = False, query_files: str = None, generated_files: List[FileAttachment] = None, generated_asset_results: Dict[str, Dict] = {}, program_execution_context: List[str] = None, + location_data: LocationData = None, + user_name: str = None, + chat_history: List[ChatMessageModel] = [], + # Model + model: Optional[str] = "gemini-2.0-flash", + api_key: Optional[str] = None, + api_base_url: Optional[str] = None, + temperature: float = 1.0, + max_prompt_size=None, + tokenizer_name=None, + agent: Agent = None, + vision_available: bool = False, deepthought: Optional[bool] = False, tracer={}, ) -> AsyncGenerator[ResponseWithThought, None]: @@ -211,26 +209,16 @@ async def converse_gemini( user_name_prompt = prompts.user_name.format(name=user_name) system_prompt = f"{system_prompt}\n{user_name_prompt}" - # Get Conversation Primer appropriate to Conversation Type - if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): - response = prompts.no_notes_found.format() - yield ResponseWithThought(response=response) - return - elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): - response = prompts.no_online_results_found.format() - yield ResponseWithThought(response=response) - return - context_message = "" if not is_none_or_empty(references): context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n" - if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands: + if not is_none_or_empty(online_results): context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n" - if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results): + if not is_none_or_empty(code_results): context_message += ( f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n" ) - if ConversationCommand.Operator in conversation_commands and not is_none_or_empty(operator_results): + if not is_none_or_empty(operator_results): operator_content = [ {"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results ] diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 639b1c8b..61b9f358 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -24,7 +24,6 @@ from khoj.processor.conversation.utils import ( from khoj.utils import state from khoj.utils.constants import empty_escape_sequences from khoj.utils.helpers import ( - ConversationCommand, is_none_or_empty, is_promptrace_enabled, truncate_code_context, @@ -144,23 +143,25 @@ def filter_questions(questions: List[str]): async def converse_offline( + # Query user_query: str, + # Context references: list[dict] = [], online_results={}, code_results={}, - chat_history: list[ChatMessageModel] = [], - model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", - loaded_model: Union[Any, None] = None, - conversation_commands=[ConversationCommand.Default], - max_prompt_size=None, - tokenizer_name=None, - location_data: LocationData = None, - user_name: str = None, - agent: Agent = None, query_files: str = None, generated_files: List[FileAttachment] = None, additional_context: List[str] = None, generated_asset_results: Dict[str, Dict] = {}, + location_data: LocationData = None, + user_name: str = None, + chat_history: list[ChatMessageModel] = [], + # Model + model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", + loaded_model: Union[Any, None] = None, + max_prompt_size=None, + tokenizer_name=None, + agent: Agent = None, tracer: dict = {}, ) -> AsyncGenerator[ResponseWithThought, None]: """ @@ -194,26 +195,17 @@ async def converse_offline( system_prompt = f"{system_prompt}\n{user_name_prompt}" # Get Conversation Primer appropriate to Conversation Type - if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): - response = prompts.no_notes_found.format() - yield ResponseWithThought(response=response) - return - elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): - response = prompts.no_online_results_found.format() - yield ResponseWithThought(response=response) - return - context_message = "" if not is_none_or_empty(references): context_message = f"{prompts.notes_conversation_offline.format(references=yaml_dump(references))}\n\n" - if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands: + if not is_none_or_empty(online_results): simplified_online_results = online_results.copy() for result in online_results: if online_results[result].get("webpages"): simplified_online_results[result] = online_results[result]["webpages"] context_message += f"{prompts.online_search_conversation_offline.format(online_results=yaml_dump(simplified_online_results))}\n\n" - if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results): + if not is_none_or_empty(code_results): context_message += ( f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n" ) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index f49030b9..55cbace4 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -24,11 +24,7 @@ from khoj.processor.conversation.utils import ( generate_chatml_messages_with_context, messages_to_print, ) -from khoj.utils.helpers import ( - ConversationCommand, - is_none_or_empty, - truncate_code_context, -) +from khoj.utils.helpers import is_none_or_empty, truncate_code_context from khoj.utils.rawconfig import FileAttachment, LocationData from khoj.utils.yaml import yaml_dump @@ -160,28 +156,29 @@ def send_message_to_model( async def converse_openai( + # Query user_query: str, + # Context references: list[dict], online_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None, operator_results: Optional[List[OperatorRun]] = None, + query_images: Optional[list[str]] = None, + query_files: str = None, + generated_files: List[FileAttachment] = None, + generated_asset_results: Dict[str, Dict] = {}, + program_execution_context: List[str] = None, + location_data: LocationData = None, chat_history: list[ChatMessageModel] = [], model: str = "gpt-4o-mini", api_key: Optional[str] = None, api_base_url: Optional[str] = None, temperature: float = 0.4, - conversation_commands=[ConversationCommand.Default], max_prompt_size=None, tokenizer_name=None, - location_data: LocationData = None, user_name: str = None, agent: Agent = None, - query_images: Optional[list[str]] = None, vision_available: bool = False, - query_files: str = None, - generated_files: List[FileAttachment] = None, - generated_asset_results: Dict[str, Dict] = {}, - program_execution_context: List[str] = None, deepthought: Optional[bool] = False, tracer: dict = {}, ) -> AsyncGenerator[ResponseWithThought, None]: @@ -212,16 +209,6 @@ async def converse_openai( user_name_prompt = prompts.user_name.format(name=user_name) system_prompt = f"{system_prompt}\n{user_name_prompt}" - # Get Conversation Primer appropriate to Conversation Type - if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): - response = prompts.no_notes_found.format() - yield ResponseWithThought(response=response) - return - elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): - response = prompts.no_online_results_found.format() - yield ResponseWithThought(response=response) - return - context_message = "" if not is_none_or_empty(references): context_message = f"{prompts.notes_conversation.format(references=yaml_dump(references))}\n\n" diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 83adf4de..8f9c9921 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1463,7 +1463,6 @@ async def chat( code_results, operator_results, research_results, - conversation_commands, user, location, user_name, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 6723029a..4d5082d9 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1348,7 +1348,6 @@ async def agenerate_chat_response( code_results: Dict[str, Dict] = {}, operator_results: List[OperatorRun] = [], research_results: List[ResearchIteration] = [], - conversation_commands: List[ConversationCommand] = [ConversationCommand.Default], user: KhojUser = None, location_data: LocationData = None, user_name: Optional[str] = None, @@ -1362,7 +1361,6 @@ async def agenerate_chat_response( ) -> Tuple[AsyncGenerator[ResponseWithThought, None], Dict[str, str]]: # Initialize Variables chat_response_generator: AsyncGenerator[ResponseWithThought, None] = None - logger.debug(f"Conversation Types: {conversation_commands}") metadata = {} agent = await AgentAdapters.aget_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None @@ -1391,21 +1389,23 @@ async def agenerate_chat_response( if chat_model.model_type == "offline": loaded_model = state.offline_chat_processor_config.loaded_model chat_response_generator = converse_offline( + # Query user_query=query_to_run, + # Context references=compiled_references, online_results=online_results, - loaded_model=loaded_model, + generated_files=raw_generated_files, + generated_asset_results=generated_asset_results, + location_data=location_data, + user_name=user_name, + query_files=query_files, chat_history=chat_history, - conversation_commands=conversation_commands, + # Model + loaded_model=loaded_model, model_name=chat_model.name, max_prompt_size=chat_model.max_prompt_size, tokenizer_name=chat_model.tokenizer, - location_data=location_data, - user_name=user_name, agent=agent, - query_files=query_files, - generated_files=raw_generated_files, - generated_asset_results=generated_asset_results, tracer=tracer, ) @@ -1414,27 +1414,29 @@ async def agenerate_chat_response( api_key = openai_chat_config.api_key chat_model_name = chat_model.name chat_response_generator = converse_openai( + # Query query_to_run, - compiled_references, - query_images=query_images, + # Context + references=compiled_references, online_results=online_results, code_results=code_results, operator_results=operator_results, - chat_history=chat_history, - model=chat_model_name, - api_key=api_key, - api_base_url=openai_chat_config.api_base_url, - conversation_commands=conversation_commands, - max_prompt_size=chat_model.max_prompt_size, - tokenizer_name=chat_model.tokenizer, - location_data=location_data, - user_name=user_name, - agent=agent, - vision_available=vision_available, + query_images=query_images, query_files=query_files, generated_files=raw_generated_files, generated_asset_results=generated_asset_results, program_execution_context=program_execution_context, + location_data=location_data, + user_name=user_name, + chat_history=chat_history, + # Model + model=chat_model_name, + api_key=api_key, + api_base_url=openai_chat_config.api_base_url, + max_prompt_size=chat_model.max_prompt_size, + tokenizer_name=chat_model.tokenizer, + agent=agent, + vision_available=vision_available, deepthought=deepthought, tracer=tracer, ) @@ -1443,27 +1445,29 @@ async def agenerate_chat_response( api_key = chat_model.ai_model_api.api_key api_base_url = chat_model.ai_model_api.api_base_url chat_response_generator = converse_anthropic( + # Query query_to_run, - compiled_references, - query_images=query_images, + # Context + references=compiled_references, online_results=online_results, code_results=code_results, operator_results=operator_results, - chat_history=chat_history, - model=chat_model.name, - api_key=api_key, - api_base_url=api_base_url, - conversation_commands=conversation_commands, - max_prompt_size=chat_model.max_prompt_size, - tokenizer_name=chat_model.tokenizer, - location_data=location_data, - user_name=user_name, - agent=agent, - vision_available=vision_available, + query_images=query_images, query_files=query_files, generated_files=raw_generated_files, generated_asset_results=generated_asset_results, program_execution_context=program_execution_context, + location_data=location_data, + user_name=user_name, + chat_history=chat_history, + # Model + model=chat_model.name, + api_key=api_key, + api_base_url=api_base_url, + max_prompt_size=chat_model.max_prompt_size, + tokenizer_name=chat_model.tokenizer, + agent=agent, + vision_available=vision_available, deepthought=deepthought, tracer=tracer, ) @@ -1471,27 +1475,29 @@ async def agenerate_chat_response( api_key = chat_model.ai_model_api.api_key api_base_url = chat_model.ai_model_api.api_base_url chat_response_generator = converse_gemini( + # Query query_to_run, - compiled_references, + # Context + references=compiled_references, online_results=online_results, code_results=code_results, operator_results=operator_results, - chat_history=chat_history, - model=chat_model.name, - api_key=api_key, - api_base_url=api_base_url, - conversation_commands=conversation_commands, - max_prompt_size=chat_model.max_prompt_size, - tokenizer_name=chat_model.tokenizer, - location_data=location_data, - user_name=user_name, - agent=agent, query_images=query_images, - vision_available=vision_available, query_files=query_files, generated_files=raw_generated_files, generated_asset_results=generated_asset_results, program_execution_context=program_execution_context, + location_data=location_data, + user_name=user_name, + chat_history=chat_history, + # Model + model=chat_model.name, + api_key=api_key, + api_base_url=api_base_url, + max_prompt_size=chat_model.max_prompt_size, + tokenizer_name=chat_model.tokenizer, + agent=agent, + vision_available=vision_available, deepthought=deepthought, tracer=tracer, )