mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Share context builder for chat final response across model types
The context building logic was nearly identical across all model types. This change extracts that logic into a shared function and calls it once in `agenerate_chat_response`, the entrypoint to the converse methods for all 3 model types. Main differences handled are - Gemini system prompt had additional verbosity instructions. Keep it - Pass system message via chatml messages list to anthropic, gemini models as well (like openai models) instead of passing it as a separate arg to chat_completion_* funcs. The model specific message formatters for both already extract system instruction from the messages list. So system messages will be automatically extracted by the chat_completion_* funcs to pass as the separate arg required by anthropic, gemini api libraries.
This commit is contained in:
@@ -1,22 +1,16 @@
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from typing import AsyncGenerator, List, Optional
|
||||||
from typing import AsyncGenerator, Dict, List, Optional
|
|
||||||
|
from langchain_core.messages.chat import ChatMessage
|
||||||
|
|
||||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel
|
|
||||||
from khoj.processor.conversation import prompts
|
|
||||||
from khoj.processor.conversation.anthropic.utils import (
|
from khoj.processor.conversation.anthropic.utils import (
|
||||||
anthropic_chat_completion_with_backoff,
|
anthropic_chat_completion_with_backoff,
|
||||||
anthropic_completion_with_backoff,
|
anthropic_completion_with_backoff,
|
||||||
)
|
)
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
OperatorRun,
|
|
||||||
ResponseWithThought,
|
ResponseWithThought,
|
||||||
generate_chatml_messages_with_context,
|
|
||||||
messages_to_print,
|
messages_to_print,
|
||||||
)
|
)
|
||||||
from khoj.utils.helpers import is_none_or_empty, truncate_code_context
|
|
||||||
from khoj.utils.rawconfig import LocationData
|
|
||||||
from khoj.utils.yaml import yaml_dump
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -52,91 +46,17 @@ def anthropic_send_message_to_model(
|
|||||||
|
|
||||||
async def converse_anthropic(
|
async def converse_anthropic(
|
||||||
# Query
|
# Query
|
||||||
user_query: str,
|
messages: List[ChatMessage],
|
||||||
# Context
|
|
||||||
references: list[dict],
|
|
||||||
online_results: Optional[Dict[str, Dict]] = None,
|
|
||||||
code_results: Optional[Dict[str, Dict]] = None,
|
|
||||||
operator_results: Optional[List[OperatorRun]] = None,
|
|
||||||
query_images: Optional[list[str]] = None,
|
|
||||||
query_files: str = None,
|
|
||||||
program_execution_context: Optional[List[str]] = None,
|
|
||||||
generated_asset_results: Dict[str, Dict] = {},
|
|
||||||
location_data: LocationData = None,
|
|
||||||
user_name: str = None,
|
|
||||||
chat_history: List[ChatMessageModel] = [],
|
|
||||||
# Model
|
# Model
|
||||||
model: Optional[str] = "claude-3-7-sonnet-latest",
|
model: Optional[str] = "claude-3-7-sonnet-latest",
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base_url: Optional[str] = None,
|
api_base_url: Optional[str] = None,
|
||||||
max_prompt_size=None,
|
|
||||||
tokenizer_name=None,
|
|
||||||
agent: Agent = None,
|
|
||||||
vision_available: bool = False,
|
|
||||||
deepthought: Optional[bool] = False,
|
deepthought: Optional[bool] = False,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||||
"""
|
"""
|
||||||
Converse with user using Anthropic's Claude
|
Converse with user using Anthropic's Claude
|
||||||
"""
|
"""
|
||||||
# Initialize Variables
|
|
||||||
current_date = datetime.now()
|
|
||||||
|
|
||||||
if agent and agent.personality:
|
|
||||||
system_prompt = prompts.custom_personality.format(
|
|
||||||
name=agent.name,
|
|
||||||
bio=agent.personality,
|
|
||||||
current_date=current_date.strftime("%Y-%m-%d"),
|
|
||||||
day_of_week=current_date.strftime("%A"),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
system_prompt = prompts.personality.format(
|
|
||||||
current_date=current_date.strftime("%Y-%m-%d"),
|
|
||||||
day_of_week=current_date.strftime("%A"),
|
|
||||||
)
|
|
||||||
|
|
||||||
if location_data:
|
|
||||||
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
|
||||||
system_prompt = f"{system_prompt}\n{location_prompt}"
|
|
||||||
|
|
||||||
if user_name:
|
|
||||||
user_name_prompt = prompts.user_name.format(name=user_name)
|
|
||||||
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
|
||||||
|
|
||||||
context_message = ""
|
|
||||||
if not is_none_or_empty(references):
|
|
||||||
context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n"
|
|
||||||
if not is_none_or_empty(online_results):
|
|
||||||
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
|
|
||||||
if not is_none_or_empty(code_results):
|
|
||||||
context_message += (
|
|
||||||
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
|
|
||||||
)
|
|
||||||
if not is_none_or_empty(operator_results):
|
|
||||||
operator_content = [
|
|
||||||
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
|
|
||||||
]
|
|
||||||
context_message += (
|
|
||||||
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
|
|
||||||
)
|
|
||||||
context_message = context_message.strip()
|
|
||||||
|
|
||||||
# Setup Prompt with Primer or Conversation History
|
|
||||||
messages = generate_chatml_messages_with_context(
|
|
||||||
user_query,
|
|
||||||
context_message=context_message,
|
|
||||||
chat_history=chat_history,
|
|
||||||
model_name=model,
|
|
||||||
max_prompt_size=max_prompt_size,
|
|
||||||
tokenizer_name=tokenizer_name,
|
|
||||||
query_images=query_images,
|
|
||||||
vision_enabled=vision_available,
|
|
||||||
model_type=ChatModel.ModelType.ANTHROPIC,
|
|
||||||
query_files=query_files,
|
|
||||||
generated_asset_results=generated_asset_results,
|
|
||||||
program_execution_context=program_execution_context,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}")
|
logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}")
|
||||||
|
|
||||||
# Get Response from Claude
|
# Get Response from Claude
|
||||||
@@ -146,8 +66,6 @@ async def converse_anthropic(
|
|||||||
temperature=0.2,
|
temperature=0.2,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
system_prompt=system_prompt,
|
|
||||||
max_prompt_size=max_prompt_size,
|
|
||||||
deepthought=deepthought,
|
deepthought=deepthought,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -184,8 +184,7 @@ async def anthropic_chat_completion_with_backoff(
|
|||||||
temperature: float,
|
temperature: float,
|
||||||
api_key: str | None,
|
api_key: str | None,
|
||||||
api_base_url: str,
|
api_base_url: str,
|
||||||
system_prompt: str,
|
system_prompt: str = "",
|
||||||
max_prompt_size: int | None = None,
|
|
||||||
deepthought: bool = False,
|
deepthought: bool = False,
|
||||||
model_kwargs: dict | None = None,
|
model_kwargs: dict | None = None,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
|
|||||||
@@ -1,22 +1,16 @@
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from typing import AsyncGenerator, List, Optional
|
||||||
from typing import AsyncGenerator, Dict, List, Optional
|
|
||||||
|
from langchain_core.messages.chat import ChatMessage
|
||||||
|
|
||||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel
|
|
||||||
from khoj.processor.conversation import prompts
|
|
||||||
from khoj.processor.conversation.google.utils import (
|
from khoj.processor.conversation.google.utils import (
|
||||||
gemini_chat_completion_with_backoff,
|
gemini_chat_completion_with_backoff,
|
||||||
gemini_completion_with_backoff,
|
gemini_completion_with_backoff,
|
||||||
)
|
)
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
OperatorRun,
|
|
||||||
ResponseWithThought,
|
ResponseWithThought,
|
||||||
generate_chatml_messages_with_context,
|
|
||||||
messages_to_print,
|
messages_to_print,
|
||||||
)
|
)
|
||||||
from khoj.utils.helpers import is_none_or_empty, truncate_code_context
|
|
||||||
from khoj.utils.rawconfig import LocationData
|
|
||||||
from khoj.utils.yaml import yaml_dump
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -61,93 +55,18 @@ def gemini_send_message_to_model(
|
|||||||
|
|
||||||
async def converse_gemini(
|
async def converse_gemini(
|
||||||
# Query
|
# Query
|
||||||
user_query: str,
|
messages: List[ChatMessage],
|
||||||
# Context
|
|
||||||
references: list[dict],
|
|
||||||
online_results: Optional[Dict[str, Dict]] = None,
|
|
||||||
code_results: Optional[Dict[str, Dict]] = None,
|
|
||||||
operator_results: Optional[List[OperatorRun]] = None,
|
|
||||||
query_images: Optional[list[str]] = None,
|
|
||||||
query_files: str = None,
|
|
||||||
generated_asset_results: Dict[str, Dict] = {},
|
|
||||||
program_execution_context: List[str] = None,
|
|
||||||
location_data: LocationData = None,
|
|
||||||
user_name: str = None,
|
|
||||||
chat_history: List[ChatMessageModel] = [],
|
|
||||||
# Model
|
# Model
|
||||||
model: Optional[str] = "gemini-2.5-flash",
|
model: Optional[str] = "gemini-2.5-flash",
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base_url: Optional[str] = None,
|
api_base_url: Optional[str] = None,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
max_prompt_size=None,
|
|
||||||
tokenizer_name=None,
|
|
||||||
agent: Agent = None,
|
|
||||||
vision_available: bool = False,
|
|
||||||
deepthought: Optional[bool] = False,
|
deepthought: Optional[bool] = False,
|
||||||
tracer={},
|
tracer={},
|
||||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||||
"""
|
"""
|
||||||
Converse with user using Google's Gemini
|
Converse with user using Google's Gemini
|
||||||
"""
|
"""
|
||||||
# Initialize Variables
|
|
||||||
current_date = datetime.now()
|
|
||||||
|
|
||||||
if agent and agent.personality:
|
|
||||||
system_prompt = prompts.custom_personality.format(
|
|
||||||
name=agent.name,
|
|
||||||
bio=agent.personality,
|
|
||||||
current_date=current_date.strftime("%Y-%m-%d"),
|
|
||||||
day_of_week=current_date.strftime("%A"),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
system_prompt = prompts.personality.format(
|
|
||||||
current_date=current_date.strftime("%Y-%m-%d"),
|
|
||||||
day_of_week=current_date.strftime("%A"),
|
|
||||||
)
|
|
||||||
|
|
||||||
system_prompt += f"\n\n{prompts.gemini_verbose_language_personality}"
|
|
||||||
if location_data:
|
|
||||||
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
|
||||||
system_prompt += f"\n{location_prompt}"
|
|
||||||
|
|
||||||
if user_name:
|
|
||||||
user_name_prompt = prompts.user_name.format(name=user_name)
|
|
||||||
system_prompt += f"\n{user_name_prompt}"
|
|
||||||
|
|
||||||
context_message = ""
|
|
||||||
if not is_none_or_empty(references):
|
|
||||||
context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n"
|
|
||||||
if not is_none_or_empty(online_results):
|
|
||||||
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
|
|
||||||
if not is_none_or_empty(code_results):
|
|
||||||
context_message += (
|
|
||||||
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
|
|
||||||
)
|
|
||||||
if not is_none_or_empty(operator_results):
|
|
||||||
operator_content = [
|
|
||||||
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
|
|
||||||
]
|
|
||||||
context_message += (
|
|
||||||
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
|
|
||||||
)
|
|
||||||
context_message = context_message.strip()
|
|
||||||
|
|
||||||
# Setup Prompt with Primer or Conversation History
|
|
||||||
messages = generate_chatml_messages_with_context(
|
|
||||||
user_query,
|
|
||||||
context_message=context_message,
|
|
||||||
chat_history=chat_history,
|
|
||||||
model_name=model,
|
|
||||||
max_prompt_size=max_prompt_size,
|
|
||||||
tokenizer_name=tokenizer_name,
|
|
||||||
query_images=query_images,
|
|
||||||
vision_enabled=vision_available,
|
|
||||||
model_type=ChatModel.ModelType.GOOGLE,
|
|
||||||
query_files=query_files,
|
|
||||||
generated_asset_results=generated_asset_results,
|
|
||||||
program_execution_context=program_execution_context,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
|
logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
|
||||||
|
|
||||||
# Get Response from Google AI
|
# Get Response from Google AI
|
||||||
@@ -157,7 +76,6 @@ async def converse_gemini(
|
|||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
system_prompt=system_prompt,
|
|
||||||
deepthought=deepthought,
|
deepthought=deepthought,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -308,7 +308,7 @@ async def gemini_chat_completion_with_backoff(
|
|||||||
temperature: float,
|
temperature: float,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
api_base_url: str,
|
api_base_url: str,
|
||||||
system_prompt: str,
|
system_prompt: str = "",
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
deepthought=False,
|
deepthought=False,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
|
|
||||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel
|
from langchain_core.messages.chat import ChatMessage
|
||||||
from khoj.processor.conversation import prompts
|
|
||||||
from khoj.processor.conversation.openai.utils import (
|
from khoj.processor.conversation.openai.utils import (
|
||||||
chat_completion_with_backoff,
|
chat_completion_with_backoff,
|
||||||
clean_response_schema,
|
clean_response_schema,
|
||||||
@@ -15,15 +14,11 @@ from khoj.processor.conversation.openai.utils import (
|
|||||||
to_openai_tools,
|
to_openai_tools,
|
||||||
)
|
)
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
OperatorRun,
|
|
||||||
ResponseWithThought,
|
ResponseWithThought,
|
||||||
StructuredOutputSupport,
|
StructuredOutputSupport,
|
||||||
generate_chatml_messages_with_context,
|
|
||||||
messages_to_print,
|
messages_to_print,
|
||||||
)
|
)
|
||||||
from khoj.utils.helpers import ToolDefinition, is_none_or_empty, truncate_code_context
|
from khoj.utils.helpers import ToolDefinition
|
||||||
from khoj.utils.rawconfig import LocationData
|
|
||||||
from khoj.utils.yaml import yaml_dump
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -96,92 +91,18 @@ def send_message_to_model(
|
|||||||
|
|
||||||
async def converse_openai(
|
async def converse_openai(
|
||||||
# Query
|
# Query
|
||||||
user_query: str,
|
messages: List[ChatMessage],
|
||||||
# Context
|
# Model
|
||||||
references: list[dict],
|
|
||||||
online_results: Optional[Dict[str, Dict]] = None,
|
|
||||||
code_results: Optional[Dict[str, Dict]] = None,
|
|
||||||
operator_results: Optional[List[OperatorRun]] = None,
|
|
||||||
query_images: Optional[list[str]] = None,
|
|
||||||
query_files: str = None,
|
|
||||||
generated_asset_results: Dict[str, Dict] = {},
|
|
||||||
program_execution_context: List[str] = None,
|
|
||||||
location_data: LocationData = None,
|
|
||||||
chat_history: list[ChatMessageModel] = [],
|
|
||||||
model: str = "gpt-4.1-mini",
|
model: str = "gpt-4.1-mini",
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base_url: Optional[str] = None,
|
api_base_url: Optional[str] = None,
|
||||||
temperature: float = 0.6,
|
temperature: float = 0.6,
|
||||||
max_prompt_size=None,
|
|
||||||
tokenizer_name=None,
|
|
||||||
user_name: str = None,
|
|
||||||
agent: Agent = None,
|
|
||||||
vision_available: bool = False,
|
|
||||||
deepthought: Optional[bool] = False,
|
deepthought: Optional[bool] = False,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||||
"""
|
"""
|
||||||
Converse with user using OpenAI's ChatGPT
|
Converse with user using OpenAI's ChatGPT
|
||||||
"""
|
"""
|
||||||
# Initialize Variables
|
|
||||||
current_date = datetime.now()
|
|
||||||
|
|
||||||
if agent and agent.personality:
|
|
||||||
system_prompt = prompts.custom_personality.format(
|
|
||||||
name=agent.name,
|
|
||||||
bio=agent.personality,
|
|
||||||
current_date=current_date.strftime("%Y-%m-%d"),
|
|
||||||
day_of_week=current_date.strftime("%A"),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
system_prompt = prompts.personality.format(
|
|
||||||
current_date=current_date.strftime("%Y-%m-%d"),
|
|
||||||
day_of_week=current_date.strftime("%A"),
|
|
||||||
)
|
|
||||||
|
|
||||||
if location_data:
|
|
||||||
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
|
||||||
system_prompt = f"{system_prompt}\n{location_prompt}"
|
|
||||||
|
|
||||||
if user_name:
|
|
||||||
user_name_prompt = prompts.user_name.format(name=user_name)
|
|
||||||
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
|
||||||
|
|
||||||
context_message = ""
|
|
||||||
if not is_none_or_empty(references):
|
|
||||||
context_message = f"{prompts.notes_conversation.format(references=yaml_dump(references))}\n\n"
|
|
||||||
if not is_none_or_empty(online_results):
|
|
||||||
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
|
|
||||||
if not is_none_or_empty(code_results):
|
|
||||||
context_message += (
|
|
||||||
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
|
|
||||||
)
|
|
||||||
if not is_none_or_empty(operator_results):
|
|
||||||
operator_content = [
|
|
||||||
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
|
|
||||||
]
|
|
||||||
context_message += (
|
|
||||||
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
context_message = context_message.strip()
|
|
||||||
|
|
||||||
# Setup Prompt with Primer or Conversation History
|
|
||||||
messages = generate_chatml_messages_with_context(
|
|
||||||
user_query,
|
|
||||||
system_prompt,
|
|
||||||
chat_history,
|
|
||||||
context_message=context_message,
|
|
||||||
model_name=model,
|
|
||||||
max_prompt_size=max_prompt_size,
|
|
||||||
tokenizer_name=tokenizer_name,
|
|
||||||
query_images=query_images,
|
|
||||||
vision_enabled=vision_available,
|
|
||||||
model_type=ChatModel.ModelType.OPENAI,
|
|
||||||
query_files=query_files,
|
|
||||||
generated_asset_results=generated_asset_results,
|
|
||||||
program_execution_context=program_execution_context,
|
|
||||||
)
|
|
||||||
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
|
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
|
||||||
|
|
||||||
# Get Response from GPT
|
# Get Response from GPT
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from apscheduler.triggers.cron import CronTrigger
|
|||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from django.utils import timezone as django_timezone
|
from django.utils import timezone as django_timezone
|
||||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile, WebSocket
|
from fastapi import Depends, Header, HTTPException, Request, UploadFile, WebSocket
|
||||||
|
from langchain_core.messages.chat import ChatMessage
|
||||||
from pydantic import BaseModel, EmailStr, Field
|
from pydantic import BaseModel, EmailStr, Field
|
||||||
from starlette.authentication import has_required_scope
|
from starlette.authentication import has_required_scope
|
||||||
from starlette.requests import URL
|
from starlette.requests import URL
|
||||||
@@ -124,6 +125,7 @@ from khoj.utils.helpers import (
|
|||||||
mode_descriptions_for_llm,
|
mode_descriptions_for_llm,
|
||||||
timer,
|
timer,
|
||||||
tool_descriptions_for_llm,
|
tool_descriptions_for_llm,
|
||||||
|
truncate_code_context,
|
||||||
)
|
)
|
||||||
from khoj.utils.rawconfig import (
|
from khoj.utils.rawconfig import (
|
||||||
ChatRequestBody,
|
ChatRequestBody,
|
||||||
@@ -132,6 +134,7 @@ from khoj.utils.rawconfig import (
|
|||||||
SearchResponse,
|
SearchResponse,
|
||||||
)
|
)
|
||||||
from khoj.utils.state import SearchType
|
from khoj.utils.state import SearchType
|
||||||
|
from khoj.utils.yaml import yaml_dump
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -1598,6 +1601,105 @@ def send_message_to_model_wrapper_sync(
|
|||||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||||
|
|
||||||
|
|
||||||
|
def build_conversation_context(
    # Query and Context
    user_query: str,
    references: List[Dict],
    online_results: Dict[str, Dict],
    code_results: Dict[str, Dict],
    operator_results: List[OperatorRun],
    query_files: str = None,
    query_images: Optional[List[str]] = None,
    generated_asset_results: Optional[Dict[str, Dict]] = None,
    program_execution_context: Optional[List[str]] = None,
    chat_history: Optional[List[ChatMessageModel]] = None,
    location_data: LocationData = None,
    user_name: str = None,
    # Model config
    agent: Agent = None,
    model_name: str = None,
    model_type: ChatModel.ModelType = None,
    max_prompt_size: int = None,
    tokenizer_name: str = None,
    vision_available: bool = False,
) -> List[ChatMessage]:
    """
    Construct system, context and chatml messages for chat response.

    Share common logic across different model types: the system prompt is
    built from the agent personality (or the default personality), enriched
    with location and user-name context, and the retrieved references plus
    tool results (online search, code execution, operator runs) are folded
    into a single context message before the chatml messages are generated.

    Returns:
        List of ChatMessages with context
    """
    # Normalize mutable-default parameters so no dict/list is shared across calls.
    generated_asset_results = generated_asset_results if generated_asset_results is not None else {}
    chat_history = chat_history if chat_history is not None else []

    # Initialize Variables
    current_date = datetime.now()

    # Build system prompt
    if agent and agent.personality:
        system_prompt = prompts.custom_personality.format(
            name=agent.name,
            bio=agent.personality,
            current_date=current_date.strftime("%Y-%m-%d"),
            day_of_week=current_date.strftime("%A"),
        )
    else:
        system_prompt = prompts.personality.format(
            current_date=current_date.strftime("%Y-%m-%d"),
            day_of_week=current_date.strftime("%A"),
        )

    # Add Gemini-specific personality enhancement (extra verbosity instructions)
    if model_type == ChatModel.ModelType.GOOGLE:
        system_prompt += f"\n\n{prompts.gemini_verbose_language_personality}"

    # Add location context if available
    if location_data:
        location_prompt = prompts.user_location.format(location=f"{location_data}")
        system_prompt += f"\n{location_prompt}"

    # Add user name context if available
    if user_name:
        user_name_prompt = prompts.user_name.format(name=user_name)
        system_prompt += f"\n{user_name_prompt}"

    # Build context message from notes, online, code and operator results
    context_message = ""
    if not is_none_or_empty(references):
        context_message = f"{prompts.notes_conversation.format(references=yaml_dump(references))}\n\n"
    if not is_none_or_empty(online_results):
        context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
    if not is_none_or_empty(code_results):
        context_message += (
            f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
        )
    if not is_none_or_empty(operator_results):
        # Flatten operator runs into plain dicts for YAML serialization
        operator_content = [
            {"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
        ]
        context_message += (
            f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
        )
    context_message = context_message.strip()

    # Generate the chatml messages
    messages = generate_chatml_messages_with_context(
        user_message=user_query,
        query_files=query_files,
        query_images=query_images,
        context_message=context_message,
        generated_asset_results=generated_asset_results,
        program_execution_context=program_execution_context,
        chat_history=chat_history,
        system_message=system_prompt,
        model_name=model_name,
        model_type=model_type,
        max_prompt_size=max_prompt_size,
        tokenizer_name=tokenizer_name,
        vision_enabled=vision_available,
    )

    return messages
|
||||||
|
|
||||||
|
|
||||||
async def agenerate_chat_response(
|
async def agenerate_chat_response(
|
||||||
q: str,
|
q: str,
|
||||||
chat_history: List[ChatMessageModel],
|
chat_history: List[ChatMessageModel],
|
||||||
@@ -1645,33 +1747,39 @@ async def agenerate_chat_response(
|
|||||||
chat_model = vision_enabled_config
|
chat_model = vision_enabled_config
|
||||||
vision_available = True
|
vision_available = True
|
||||||
|
|
||||||
|
# Build shared conversation context and generate chatml messages
|
||||||
|
messages = build_conversation_context(
|
||||||
|
user_query=query_to_run,
|
||||||
|
references=compiled_references,
|
||||||
|
online_results=online_results,
|
||||||
|
code_results=code_results,
|
||||||
|
operator_results=operator_results,
|
||||||
|
query_files=query_files,
|
||||||
|
query_images=query_images,
|
||||||
|
generated_asset_results=generated_asset_results,
|
||||||
|
program_execution_context=program_execution_context,
|
||||||
|
chat_history=chat_history,
|
||||||
|
location_data=location_data,
|
||||||
|
user_name=user_name,
|
||||||
|
agent=agent,
|
||||||
|
model_type=chat_model.model_type,
|
||||||
|
model_name=chat_model.name,
|
||||||
|
max_prompt_size=max_prompt_size,
|
||||||
|
tokenizer_name=chat_model.tokenizer,
|
||||||
|
vision_available=vision_available,
|
||||||
|
)
|
||||||
|
|
||||||
if chat_model.model_type == ChatModel.ModelType.OPENAI:
|
if chat_model.model_type == ChatModel.ModelType.OPENAI:
|
||||||
openai_chat_config = chat_model.ai_model_api
|
openai_chat_config = chat_model.ai_model_api
|
||||||
api_key = openai_chat_config.api_key
|
api_key = openai_chat_config.api_key
|
||||||
chat_model_name = chat_model.name
|
chat_model_name = chat_model.name
|
||||||
chat_response_generator = converse_openai(
|
chat_response_generator = converse_openai(
|
||||||
# Query
|
# Query + Context Messages
|
||||||
query_to_run,
|
messages,
|
||||||
# Context
|
|
||||||
references=compiled_references,
|
|
||||||
online_results=online_results,
|
|
||||||
code_results=code_results,
|
|
||||||
operator_results=operator_results,
|
|
||||||
query_images=query_images,
|
|
||||||
query_files=query_files,
|
|
||||||
generated_asset_results=generated_asset_results,
|
|
||||||
program_execution_context=program_execution_context,
|
|
||||||
location_data=location_data,
|
|
||||||
user_name=user_name,
|
|
||||||
chat_history=chat_history,
|
|
||||||
# Model
|
# Model
|
||||||
model=chat_model_name,
|
model=chat_model_name,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_base_url=openai_chat_config.api_base_url,
|
api_base_url=openai_chat_config.api_base_url,
|
||||||
max_prompt_size=max_prompt_size,
|
|
||||||
tokenizer_name=chat_model.tokenizer,
|
|
||||||
agent=agent,
|
|
||||||
vision_available=vision_available,
|
|
||||||
deepthought=deepthought,
|
deepthought=deepthought,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
@@ -1680,28 +1788,12 @@ async def agenerate_chat_response(
|
|||||||
api_key = chat_model.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
api_base_url = chat_model.ai_model_api.api_base_url
|
api_base_url = chat_model.ai_model_api.api_base_url
|
||||||
chat_response_generator = converse_anthropic(
|
chat_response_generator = converse_anthropic(
|
||||||
# Query
|
# Query + Context Messages
|
||||||
query_to_run,
|
messages,
|
||||||
# Context
|
|
||||||
references=compiled_references,
|
|
||||||
online_results=online_results,
|
|
||||||
code_results=code_results,
|
|
||||||
operator_results=operator_results,
|
|
||||||
query_images=query_images,
|
|
||||||
query_files=query_files,
|
|
||||||
generated_asset_results=generated_asset_results,
|
|
||||||
program_execution_context=program_execution_context,
|
|
||||||
location_data=location_data,
|
|
||||||
user_name=user_name,
|
|
||||||
chat_history=chat_history,
|
|
||||||
# Model
|
# Model
|
||||||
model=chat_model.name,
|
model=chat_model.name,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
max_prompt_size=max_prompt_size,
|
|
||||||
tokenizer_name=chat_model.tokenizer,
|
|
||||||
agent=agent,
|
|
||||||
vision_available=vision_available,
|
|
||||||
deepthought=deepthought,
|
deepthought=deepthought,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
@@ -1709,28 +1801,12 @@ async def agenerate_chat_response(
|
|||||||
api_key = chat_model.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
api_base_url = chat_model.ai_model_api.api_base_url
|
api_base_url = chat_model.ai_model_api.api_base_url
|
||||||
chat_response_generator = converse_gemini(
|
chat_response_generator = converse_gemini(
|
||||||
# Query
|
# Query + Context Messages
|
||||||
query_to_run,
|
messages,
|
||||||
# Context
|
|
||||||
references=compiled_references,
|
|
||||||
online_results=online_results,
|
|
||||||
code_results=code_results,
|
|
||||||
operator_results=operator_results,
|
|
||||||
query_images=query_images,
|
|
||||||
query_files=query_files,
|
|
||||||
generated_asset_results=generated_asset_results,
|
|
||||||
program_execution_context=program_execution_context,
|
|
||||||
location_data=location_data,
|
|
||||||
user_name=user_name,
|
|
||||||
chat_history=chat_history,
|
|
||||||
# Model
|
# Model
|
||||||
model=chat_model.name,
|
model=chat_model.name,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
max_prompt_size=max_prompt_size,
|
|
||||||
tokenizer_name=chat_model.tokenizer,
|
|
||||||
agent=agent,
|
|
||||||
vision_available=vision_available,
|
|
||||||
deepthought=deepthought,
|
deepthought=deepthought,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user