mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Simplify ai provider converse methods
- Add context based on information provided rather than conversation commands. Let caller handle passing appropriate context to ai provider converse methods
This commit is contained in:
@@ -21,11 +21,7 @@ from khoj.processor.conversation.utils import (
|
||||
generate_chatml_messages_with_context,
|
||||
messages_to_print,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
ConversationCommand,
|
||||
is_none_or_empty,
|
||||
truncate_code_context,
|
||||
)
|
||||
from khoj.utils.helpers import is_none_or_empty, truncate_code_context
|
||||
from khoj.utils.rawconfig import FileAttachment, LocationData
|
||||
from khoj.utils.yaml import yaml_dump
|
||||
|
||||
@@ -136,27 +132,29 @@ def anthropic_send_message_to_model(
|
||||
|
||||
|
||||
async def converse_anthropic(
|
||||
# Query
|
||||
user_query: str,
|
||||
# Context
|
||||
references: list[dict],
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
code_results: Optional[Dict[str, Dict]] = None,
|
||||
operator_results: Optional[List[OperatorRun]] = None,
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
model: Optional[str] = "claude-3-7-sonnet-latest",
|
||||
api_key: Optional[str] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
conversation_commands=[ConversationCommand.Default],
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
agent: Agent = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_available: bool = False,
|
||||
query_files: str = None,
|
||||
generated_files: List[FileAttachment] = None,
|
||||
program_execution_context: Optional[List[str]] = None,
|
||||
generated_asset_results: Dict[str, Dict] = {},
|
||||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
# Model
|
||||
model: Optional[str] = "claude-3-7-sonnet-latest",
|
||||
api_key: Optional[str] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
agent: Agent = None,
|
||||
vision_available: bool = False,
|
||||
deepthought: Optional[bool] = False,
|
||||
tracer: dict = {},
|
||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||
@@ -187,26 +185,16 @@ async def converse_anthropic(
|
||||
user_name_prompt = prompts.user_name.format(name=user_name)
|
||||
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||
response = prompts.no_notes_found.format()
|
||||
yield ResponseWithThought(response=response)
|
||||
return
|
||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||
response = prompts.no_online_results_found.format()
|
||||
yield ResponseWithThought(response=response)
|
||||
return
|
||||
|
||||
context_message = ""
|
||||
if not is_none_or_empty(references):
|
||||
context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n"
|
||||
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
|
||||
if not is_none_or_empty(online_results):
|
||||
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
|
||||
if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results):
|
||||
if not is_none_or_empty(code_results):
|
||||
context_message += (
|
||||
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
|
||||
)
|
||||
if ConversationCommand.Operator in conversation_commands and not is_none_or_empty(operator_results):
|
||||
if not is_none_or_empty(operator_results):
|
||||
operator_content = [
|
||||
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
|
||||
]
|
||||
|
||||
@@ -21,11 +21,7 @@ from khoj.processor.conversation.utils import (
|
||||
generate_chatml_messages_with_context,
|
||||
messages_to_print,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
ConversationCommand,
|
||||
is_none_or_empty,
|
||||
truncate_code_context,
|
||||
)
|
||||
from khoj.utils.helpers import is_none_or_empty, truncate_code_context
|
||||
from khoj.utils.rawconfig import FileAttachment, LocationData
|
||||
from khoj.utils.yaml import yaml_dump
|
||||
|
||||
@@ -158,28 +154,30 @@ def gemini_send_message_to_model(
|
||||
|
||||
|
||||
async def converse_gemini(
|
||||
# Query
|
||||
user_query: str,
|
||||
# Context
|
||||
references: list[dict],
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
code_results: Optional[Dict[str, Dict]] = None,
|
||||
operator_results: Optional[List[OperatorRun]] = None,
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
model: Optional[str] = "gemini-2.0-flash",
|
||||
api_key: Optional[str] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
conversation_commands=[ConversationCommand.Default],
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
agent: Agent = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_available: bool = False,
|
||||
query_files: str = None,
|
||||
generated_files: List[FileAttachment] = None,
|
||||
generated_asset_results: Dict[str, Dict] = {},
|
||||
program_execution_context: List[str] = None,
|
||||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
# Model
|
||||
model: Optional[str] = "gemini-2.0-flash",
|
||||
api_key: Optional[str] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
agent: Agent = None,
|
||||
vision_available: bool = False,
|
||||
deepthought: Optional[bool] = False,
|
||||
tracer={},
|
||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||
@@ -211,26 +209,16 @@ async def converse_gemini(
|
||||
user_name_prompt = prompts.user_name.format(name=user_name)
|
||||
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||
response = prompts.no_notes_found.format()
|
||||
yield ResponseWithThought(response=response)
|
||||
return
|
||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||
response = prompts.no_online_results_found.format()
|
||||
yield ResponseWithThought(response=response)
|
||||
return
|
||||
|
||||
context_message = ""
|
||||
if not is_none_or_empty(references):
|
||||
context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n"
|
||||
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
|
||||
if not is_none_or_empty(online_results):
|
||||
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
|
||||
if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results):
|
||||
if not is_none_or_empty(code_results):
|
||||
context_message += (
|
||||
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
|
||||
)
|
||||
if ConversationCommand.Operator in conversation_commands and not is_none_or_empty(operator_results):
|
||||
if not is_none_or_empty(operator_results):
|
||||
operator_content = [
|
||||
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
|
||||
]
|
||||
|
||||
@@ -24,7 +24,6 @@ from khoj.processor.conversation.utils import (
|
||||
from khoj.utils import state
|
||||
from khoj.utils.constants import empty_escape_sequences
|
||||
from khoj.utils.helpers import (
|
||||
ConversationCommand,
|
||||
is_none_or_empty,
|
||||
is_promptrace_enabled,
|
||||
truncate_code_context,
|
||||
@@ -144,23 +143,25 @@ def filter_questions(questions: List[str]):
|
||||
|
||||
|
||||
async def converse_offline(
|
||||
# Query
|
||||
user_query: str,
|
||||
# Context
|
||||
references: list[dict] = [],
|
||||
online_results={},
|
||||
code_results={},
|
||||
chat_history: list[ChatMessageModel] = [],
|
||||
model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
|
||||
loaded_model: Union[Any, None] = None,
|
||||
conversation_commands=[ConversationCommand.Default],
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
agent: Agent = None,
|
||||
query_files: str = None,
|
||||
generated_files: List[FileAttachment] = None,
|
||||
additional_context: List[str] = None,
|
||||
generated_asset_results: Dict[str, Dict] = {},
|
||||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
chat_history: list[ChatMessageModel] = [],
|
||||
# Model
|
||||
model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
|
||||
loaded_model: Union[Any, None] = None,
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||
"""
|
||||
@@ -194,26 +195,17 @@ async def converse_offline(
|
||||
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||
response = prompts.no_notes_found.format()
|
||||
yield ResponseWithThought(response=response)
|
||||
return
|
||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||
response = prompts.no_online_results_found.format()
|
||||
yield ResponseWithThought(response=response)
|
||||
return
|
||||
|
||||
context_message = ""
|
||||
if not is_none_or_empty(references):
|
||||
context_message = f"{prompts.notes_conversation_offline.format(references=yaml_dump(references))}\n\n"
|
||||
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
|
||||
if not is_none_or_empty(online_results):
|
||||
simplified_online_results = online_results.copy()
|
||||
for result in online_results:
|
||||
if online_results[result].get("webpages"):
|
||||
simplified_online_results[result] = online_results[result]["webpages"]
|
||||
|
||||
context_message += f"{prompts.online_search_conversation_offline.format(online_results=yaml_dump(simplified_online_results))}\n\n"
|
||||
if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results):
|
||||
if not is_none_or_empty(code_results):
|
||||
context_message += (
|
||||
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
|
||||
)
|
||||
|
||||
@@ -24,11 +24,7 @@ from khoj.processor.conversation.utils import (
|
||||
generate_chatml_messages_with_context,
|
||||
messages_to_print,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
ConversationCommand,
|
||||
is_none_or_empty,
|
||||
truncate_code_context,
|
||||
)
|
||||
from khoj.utils.helpers import is_none_or_empty, truncate_code_context
|
||||
from khoj.utils.rawconfig import FileAttachment, LocationData
|
||||
from khoj.utils.yaml import yaml_dump
|
||||
|
||||
@@ -160,28 +156,29 @@ def send_message_to_model(
|
||||
|
||||
|
||||
async def converse_openai(
|
||||
# Query
|
||||
user_query: str,
|
||||
# Context
|
||||
references: list[dict],
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
code_results: Optional[Dict[str, Dict]] = None,
|
||||
operator_results: Optional[List[OperatorRun]] = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
query_files: str = None,
|
||||
generated_files: List[FileAttachment] = None,
|
||||
generated_asset_results: Dict[str, Dict] = {},
|
||||
program_execution_context: List[str] = None,
|
||||
location_data: LocationData = None,
|
||||
chat_history: list[ChatMessageModel] = [],
|
||||
model: str = "gpt-4o-mini",
|
||||
api_key: Optional[str] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
temperature: float = 0.4,
|
||||
conversation_commands=[ConversationCommand.Default],
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
agent: Agent = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_available: bool = False,
|
||||
query_files: str = None,
|
||||
generated_files: List[FileAttachment] = None,
|
||||
generated_asset_results: Dict[str, Dict] = {},
|
||||
program_execution_context: List[str] = None,
|
||||
deepthought: Optional[bool] = False,
|
||||
tracer: dict = {},
|
||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||
@@ -212,16 +209,6 @@ async def converse_openai(
|
||||
user_name_prompt = prompts.user_name.format(name=user_name)
|
||||
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||
response = prompts.no_notes_found.format()
|
||||
yield ResponseWithThought(response=response)
|
||||
return
|
||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||
response = prompts.no_online_results_found.format()
|
||||
yield ResponseWithThought(response=response)
|
||||
return
|
||||
|
||||
context_message = ""
|
||||
if not is_none_or_empty(references):
|
||||
context_message = f"{prompts.notes_conversation.format(references=yaml_dump(references))}\n\n"
|
||||
|
||||
@@ -1463,7 +1463,6 @@ async def chat(
|
||||
code_results,
|
||||
operator_results,
|
||||
research_results,
|
||||
conversation_commands,
|
||||
user,
|
||||
location,
|
||||
user_name,
|
||||
|
||||
@@ -1348,7 +1348,6 @@ async def agenerate_chat_response(
|
||||
code_results: Dict[str, Dict] = {},
|
||||
operator_results: List[OperatorRun] = [],
|
||||
research_results: List[ResearchIteration] = [],
|
||||
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
|
||||
user: KhojUser = None,
|
||||
location_data: LocationData = None,
|
||||
user_name: Optional[str] = None,
|
||||
@@ -1362,7 +1361,6 @@ async def agenerate_chat_response(
|
||||
) -> Tuple[AsyncGenerator[ResponseWithThought, None], Dict[str, str]]:
|
||||
# Initialize Variables
|
||||
chat_response_generator: AsyncGenerator[ResponseWithThought, None] = None
|
||||
logger.debug(f"Conversation Types: {conversation_commands}")
|
||||
|
||||
metadata = {}
|
||||
agent = await AgentAdapters.aget_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
|
||||
@@ -1391,21 +1389,23 @@ async def agenerate_chat_response(
|
||||
if chat_model.model_type == "offline":
|
||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||
chat_response_generator = converse_offline(
|
||||
# Query
|
||||
user_query=query_to_run,
|
||||
# Context
|
||||
references=compiled_references,
|
||||
online_results=online_results,
|
||||
loaded_model=loaded_model,
|
||||
generated_files=raw_generated_files,
|
||||
generated_asset_results=generated_asset_results,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
query_files=query_files,
|
||||
chat_history=chat_history,
|
||||
conversation_commands=conversation_commands,
|
||||
# Model
|
||||
loaded_model=loaded_model,
|
||||
model_name=chat_model.name,
|
||||
max_prompt_size=chat_model.max_prompt_size,
|
||||
tokenizer_name=chat_model.tokenizer,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
agent=agent,
|
||||
query_files=query_files,
|
||||
generated_files=raw_generated_files,
|
||||
generated_asset_results=generated_asset_results,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
@@ -1414,27 +1414,29 @@ async def agenerate_chat_response(
|
||||
api_key = openai_chat_config.api_key
|
||||
chat_model_name = chat_model.name
|
||||
chat_response_generator = converse_openai(
|
||||
# Query
|
||||
query_to_run,
|
||||
compiled_references,
|
||||
query_images=query_images,
|
||||
# Context
|
||||
references=compiled_references,
|
||||
online_results=online_results,
|
||||
code_results=code_results,
|
||||
operator_results=operator_results,
|
||||
chat_history=chat_history,
|
||||
model=chat_model_name,
|
||||
api_key=api_key,
|
||||
api_base_url=openai_chat_config.api_base_url,
|
||||
conversation_commands=conversation_commands,
|
||||
max_prompt_size=chat_model.max_prompt_size,
|
||||
tokenizer_name=chat_model.tokenizer,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
agent=agent,
|
||||
vision_available=vision_available,
|
||||
query_images=query_images,
|
||||
query_files=query_files,
|
||||
generated_files=raw_generated_files,
|
||||
generated_asset_results=generated_asset_results,
|
||||
program_execution_context=program_execution_context,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
chat_history=chat_history,
|
||||
# Model
|
||||
model=chat_model_name,
|
||||
api_key=api_key,
|
||||
api_base_url=openai_chat_config.api_base_url,
|
||||
max_prompt_size=chat_model.max_prompt_size,
|
||||
tokenizer_name=chat_model.tokenizer,
|
||||
agent=agent,
|
||||
vision_available=vision_available,
|
||||
deepthought=deepthought,
|
||||
tracer=tracer,
|
||||
)
|
||||
@@ -1443,27 +1445,29 @@ async def agenerate_chat_response(
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
chat_response_generator = converse_anthropic(
|
||||
# Query
|
||||
query_to_run,
|
||||
compiled_references,
|
||||
query_images=query_images,
|
||||
# Context
|
||||
references=compiled_references,
|
||||
online_results=online_results,
|
||||
code_results=code_results,
|
||||
operator_results=operator_results,
|
||||
chat_history=chat_history,
|
||||
model=chat_model.name,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
conversation_commands=conversation_commands,
|
||||
max_prompt_size=chat_model.max_prompt_size,
|
||||
tokenizer_name=chat_model.tokenizer,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
agent=agent,
|
||||
vision_available=vision_available,
|
||||
query_images=query_images,
|
||||
query_files=query_files,
|
||||
generated_files=raw_generated_files,
|
||||
generated_asset_results=generated_asset_results,
|
||||
program_execution_context=program_execution_context,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
chat_history=chat_history,
|
||||
# Model
|
||||
model=chat_model.name,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
max_prompt_size=chat_model.max_prompt_size,
|
||||
tokenizer_name=chat_model.tokenizer,
|
||||
agent=agent,
|
||||
vision_available=vision_available,
|
||||
deepthought=deepthought,
|
||||
tracer=tracer,
|
||||
)
|
||||
@@ -1471,27 +1475,29 @@ async def agenerate_chat_response(
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
chat_response_generator = converse_gemini(
|
||||
# Query
|
||||
query_to_run,
|
||||
compiled_references,
|
||||
# Context
|
||||
references=compiled_references,
|
||||
online_results=online_results,
|
||||
code_results=code_results,
|
||||
operator_results=operator_results,
|
||||
chat_history=chat_history,
|
||||
model=chat_model.name,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
conversation_commands=conversation_commands,
|
||||
max_prompt_size=chat_model.max_prompt_size,
|
||||
tokenizer_name=chat_model.tokenizer,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
agent=agent,
|
||||
query_images=query_images,
|
||||
vision_available=vision_available,
|
||||
query_files=query_files,
|
||||
generated_files=raw_generated_files,
|
||||
generated_asset_results=generated_asset_results,
|
||||
program_execution_context=program_execution_context,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
chat_history=chat_history,
|
||||
# Model
|
||||
model=chat_model.name,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
max_prompt_size=chat_model.max_prompt_size,
|
||||
tokenizer_name=chat_model.tokenizer,
|
||||
agent=agent,
|
||||
vision_available=vision_available,
|
||||
deepthought=deepthought,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user