Files
khoj/src/khoj/processor/conversation/utils.py
Debanjum 57d6ebb1b8 Support Google Gemini 3
- Use thinking level for gemini 3 models instead of thinking budget.
- Bump google gemini library
- Add default context, pricing
2025-11-18 14:13:24 -08:00

1155 lines
43 KiB
Python

import base64
import json
import logging
import mimetypes
import os
import re
import uuid
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from io import BytesIO
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import PIL.Image
import pyjson5
import requests
import yaml
from langchain_core.messages.chat import ChatMessage
from pydantic import BaseModel, ConfigDict, ValidationError
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import (
ChatMessageModel,
ClientApplication,
Intent,
KhojUser,
)
from khoj.processor.conversation import prompts
from khoj.search_filter.base_filter import BaseFilter
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.utils.helpers import (
ConversationCommand,
count_tokens,
get_encoder,
is_none_or_empty,
is_promptrace_enabled,
merge_dicts,
)
from khoj.utils.rawconfig import FileAttachment
from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
import importlib.util

# GitPython backs the prompt tracer (traces are stored as git commits);
# warn when tracing is enabled but the library is absent.
if importlib.util.find_spec("git") is None and is_promptrace_enabled():
    logger.warning("GitPython not installed. `pip install gitpython` to use prompt tracer.")
# Map of chat model name -> max prompt size (in tokens) to use for it.
# Used as a fallback when no max prompt size is explicitly configured
# (unknown models default to 10000 tokens at the lookup site).
model_to_prompt_size: Dict[str, int] = {
    # OpenAI Models
    "gpt-4o": 60000,
    "gpt-4o-mini": 60000,
    "gpt-4.1": 60000,
    "gpt-4.1-mini": 120000,
    "gpt-4.1-nano": 120000,
    "o1-mini": 90000,
    "o1": 30000,
    "o3-mini": 90000,
    "o3": 60000,
    "o3-pro": 30000,
    "o4-mini": 90000,
    "gpt-5-2025-08-07": 120000,
    "gpt-5-mini-2025-08-07": 120000,
    "gpt-5-nano-2025-08-07": 120000,
    # Google Models
    "gemini-3-pro-preview": 120000,
    "gemini-2.5-flash": 120000,
    "gemini-2.5-flash-lite": 120000,
    "gemini-2.5-pro": 60000,
    "gemini-2.0-flash": 120000,
    "gemini-2.0-flash-lite": 120000,
    "gemini-1.5-flash": 120000,
    "gemini-1.5-pro": 60000,
    # Anthropic Models
    "claude-3-5-sonnet-20241022": 60000,
    "claude-3-5-sonnet-latest": 60000,
    "claude-3-7-sonnet-20250219": 60000,
    "claude-3-7-sonnet-latest": 60000,
    "claude-3-5-haiku-20241022": 60000,
    "claude-haiku-4-5-20251001": 60000,
    "claude-sonnet-4-0": 60000,
    "claude-sonnet-4-20250514": 60000,
    "claude-opus-4-0": 60000,
    "claude-opus-4-20250514": 60000,
}
# Map of model name -> tokenizer identifier used to count its tokens.
# Empty by default; presumably populated via configuration — confirm against get_encoder usage.
model_to_tokenizer: Dict[str, str] = {}
class AgentMessage(BaseModel):
    """A single message in an agent conversation trajectory."""

    # Message author. "environment" captures feedback fed back to the agent (e.g. tool output).
    role: Literal["user", "assistant", "system", "environment"]
    # Either plain text or a list of structured content parts.
    content: Union[str, List]
class OperatorRun:
    """Record of one operator task: the query, the agent's trajectory, the final response and visited webpages."""

    def __init__(
        self,
        query: str,
        trajectory: list[AgentMessage] | list[dict] = None,
        response: str = None,
        webpages: list[dict] = None,
    ):
        self.query = query
        self.response = response
        self.webpages = webpages or []
        # Normalize trajectory entries into AgentMessage instances.
        self.trajectory: list[AgentMessage] = []
        for entry in trajectory or []:
            if isinstance(entry, dict):
                self.trajectory.append(AgentMessage(**entry))
            elif hasattr(entry, "role") and hasattr(entry, "content"):  # Duck-typed AgentMessage-like object
                self.trajectory.append(entry)
            else:
                logger.warning(f"Unexpected item type in trajectory: {type(entry)}")

    def to_dict(self) -> dict:
        """Serialize this run to a plain dict, converting trajectory messages to dicts."""
        trajectory_dicts = []
        for message in self.trajectory:
            if hasattr(message, "model_dump"):  # Pydantic model
                trajectory_dicts.append(message.model_dump())
            elif isinstance(message, dict):
                trajectory_dicts.append(message)  # Already serialized
        return {
            "query": self.query,
            "response": self.response,
            "trajectory": trajectory_dicts,
            "webpages": self.webpages,
        }
class ToolCall:
    """Lightweight record of a tool invocation: tool name, call arguments and call id."""

    def __init__(self, name: str, args: dict, id: str):
        self.name, self.args, self.id = name, args, id
class ResearchIteration:
    """One step of a research session: the tool query issued plus any context, results or warnings it produced."""

    def __init__(
        self,
        query: ToolCall | dict | str,
        context: list = None,
        onlineContext: dict = None,
        codeContext: dict = None,
        operatorContext: dict | OperatorRun = None,
        summarizedResult: str = None,
        warning: str = None,
        raw_response: list = None,
    ):
        # Revive serialized tool calls / operator runs back into their classes.
        self.query = ToolCall(**query) if isinstance(query, dict) else query
        self.context = context
        self.onlineContext = onlineContext
        self.codeContext = codeContext
        if isinstance(operatorContext, dict):
            self.operatorContext = OperatorRun(**operatorContext)
        else:
            self.operatorContext = operatorContext
        self.summarizedResult = summarizedResult
        self.warning = warning
        self.raw_response = raw_response

    def to_dict(self) -> dict:
        """Serialize this iteration to a plain dict, flattening nested ToolCall/OperatorRun objects."""
        serialized = dict(vars(self))
        if isinstance(self.query, ToolCall):
            serialized["query"] = self.query.__dict__
        serialized["operatorContext"] = self.operatorContext.to_dict() if self.operatorContext else None
        return serialized
def construct_iteration_history(
    previous_iterations: List[ResearchIteration],
    query: str = None,
    query_images: List[str] = None,
    query_files: str = None,
) -> list[ChatMessageModel]:
    """Convert prior research iterations (and the triggering user query) into chat messages."""
    messages: list[ChatMessageModel] = []
    # Lead with the user's original query, images and attached files, if any.
    user_content = construct_structured_message(query, query_images, attached_file_context=query_files)
    if user_content:
        messages.append(ChatMessageModel(by="you", message=user_content))
    for iteration in previous_iterations:
        # Iterations without a structured tool call become plain user messages.
        if not iteration.query or isinstance(iteration.query, str):
            fallback = iteration.summarizedResult or iteration.warning or "Please specify what you want to do next."
            messages.append(ChatMessageModel(by="you", message=fallback))
            continue
        # Otherwise record the assistant's tool call and its paired tool result.
        tool_call_message = ChatMessageModel(
            by="khoj",
            message=iteration.raw_response or [iteration.query.__dict__],
            intent=Intent(type="tool_call", query=query),
        )
        tool_result_message = ChatMessageModel(
            by="you",
            intent=Intent(type="tool_result"),
            message=[
                {
                    "type": "tool_result",
                    "id": iteration.query.id,
                    "name": iteration.query.name,
                    "content": iteration.summarizedResult,
                }
            ],
        )
        messages.extend([tool_call_message, tool_result_message])
    return messages
def construct_chat_history(chat_history: list[ChatMessageModel], n: int = 4, agent_name="AI") -> str:
    """Render the last `n` chat messages as a plain text transcript.

    Args:
        chat_history: Chat messages in chronological order.
        n: Number of trailing messages to include.
        agent_name: Display name used for the assistant's lines.

    Returns:
        Transcript string with one line per message (plus inferred query and attached file lines).
    """
    chat_history_str = ""
    for chat in chat_history[-n:]:
        intent_type = chat.intent.type if chat.intent and chat.intent.type else ""
        inferred_queries = chat.intent.inferred_queries if chat.intent else None
        if chat.by == "khoj" and intent_type in ["remember", "reminder", "summarize"]:
            if inferred_queries:
                chat_history_str += f'{agent_name}: {{"queries": {inferred_queries}}}\n'
            chat_history_str += f"{agent_name}: {chat.message}\n\n"
        elif chat.by == "khoj" and chat.images:
            # Elide generated images to save prompt space.
            chat_history_str += f"{agent_name}: [generated image redacted for space]\n"
        elif chat.by == "khoj" and ("excalidraw" in intent_type):
            # Use the inferred diagram description in place of the raw diagram data.
            # Fix: guard against None/empty inferred queries to avoid crashing on malformed history.
            if inferred_queries:
                chat_history_str += f"{agent_name}: {inferred_queries[0]}\n"
        elif chat.by == "you":
            chat_history_str += f"User: {chat.message}\n"
            raw_query_files = chat.queryFiles
            if raw_query_files:
                query_files: Dict[str, str] = {}
                for file in raw_query_files:
                    query_files[file["name"]] = file["content"]
                query_file_context = gather_raw_query_files(query_files)
                chat_history_str += f"User: {query_file_context}\n"
    return chat_history_str
def construct_question_history(
    conversation_log: list[ChatMessageModel],
    include_query: bool = True,
    lookback: int = 6,
    query_prefix: str = "Q",
    agent_name: str = "Khoj",
) -> str:
    """
    Constructs a chat history string formatted for query extraction purposes.
    """
    history = ""
    pending_query = None
    for chat in conversation_log[-lookback:]:
        if chat.by == "you":
            pending_query = json.dumps(chat.message)
            history += f"{query_prefix}: {pending_query}\n"
        if chat.by == "khoj":
            # Skip assistant messages with no preceding user query in the window.
            if pending_query is None:
                continue
            inferred_queries = chat.intent.inferred_queries or [] if chat.intent else []
            # Default to echoing the user's query when nothing was inferred.
            if not inferred_queries:
                inferred_queries = [pending_query]
            # Defensive: wrap a bare string into a single-element list.
            elif isinstance(inferred_queries, str):
                inferred_queries = [inferred_queries]
            if include_query:
                intent_type = chat.intent.type if chat.intent and chat.intent.type else ""
                # Image generation turns store image prompts, not search queries; skip the queries line.
                if "to-image" not in intent_type:
                    history += f'{agent_name}: {{"queries": {inferred_queries}}}\n'
                history += f"A: {chat.message}\n\n"
            else:
                history += f"{agent_name}: {chat.message}\n\n"
            # A khoj reply completes the turn; require a fresh user query next.
            pending_query = None
    return history
def construct_chat_history_for_operator(conversation_history: List[ChatMessageModel], n: int = 6) -> list[AgentMessage]:
    """
    Construct chat history for operator agent in conversation log.
    Only include last n completed turns (i.e with user and khoj message).

    NOTE(review): iterates from the start of history, so it collects the earliest
    messages up to the limit; confirm against the "last n" wording in callers.
    """
    chat_history: list[AgentMessage] = []
    user_message: Optional[AgentMessage] = None
    for chat in conversation_history:
        if len(chat_history) >= n:
            break
        if chat.by == "you" and chat.message:
            # Collect the user's text and any attached files into one structured message.
            content = [{"type": "text", "text": chat.message}]
            for file in chat.queryFiles or []:
                content += [{"type": "text", "text": f"## File: {file['name']}\n\n{file['content']}"}]
            user_message = AgentMessage(role="user", content=content)
        elif chat.by == "khoj" and chat.message:
            # Fix: only add completed turns. Previously a khoj message without a pending
            # user message inserted None into the history; also reset the pending user
            # message so consecutive khoj messages cannot reuse a stale user turn.
            if user_message is not None:
                chat_history += [user_message, AgentMessage(role="assistant", content=chat.message)]
                user_message = None
    return chat_history
def construct_tool_chat_history(
    previous_iterations: List[ResearchIteration], tool: ConversationCommand = None
) -> List[ChatMessageModel]:
    """
    Construct chat history from previous iterations for a specific tool
    If a tool is provided, only the inferred queries for that tool is added.
    If no tool is provided inferred query for all tools used are added.
    """
    chat_history: list = []

    # Fallback extractor for tools with no dedicated inferred-query extractor.
    def base_extractor(iteration: ResearchIteration) -> List[str]:
        return []

    # Map each tool to a function extracting the queries it inferred during an iteration.
    extract_inferred_query_map: Dict[ConversationCommand, Callable[[ResearchIteration], List[str]]] = {
        ConversationCommand.SemanticSearchFiles: (
            lambda iteration: [c["query"] for c in iteration.context] if iteration.context else []
        ),
        ConversationCommand.SearchWeb: (
            lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else []
        ),
        ConversationCommand.ReadWebpage: (
            lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else []
        ),
        ConversationCommand.PythonCoder: (
            lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else []
        ),
    }
    for iteration in previous_iterations:
        # Iterations without a structured tool call become plain user messages.
        if not iteration.query or isinstance(iteration.query, str):
            chat_history.append(
                ChatMessageModel(
                    by="you",
                    message=iteration.summarizedResult
                    or iteration.warning
                    or "Please specify what you want to do next.",
                )
            )
            continue
        # If a tool is provided use the inferred query extractor for that tool if available
        # If no tool is provided, use inferred query extractor for the tool used in the iteration
        # Fallback to base extractor if the tool does not have an inferred query extractor
        inferred_query_extractor = extract_inferred_query_map.get(
            tool or ConversationCommand(iteration.query.name), base_extractor
        )
        # Record the tool call args as the user turn and the summarized result as the assistant turn.
        chat_history += [
            ChatMessageModel(
                by="you",
                message=yaml.dump(iteration.query.args, default_flow_style=False),
            ),
            ChatMessageModel(
                by="khoj",
                intent=Intent(
                    type="remember",
                    query=yaml.dump(iteration.query.args, default_flow_style=False),
                    inferred_queries=inferred_query_extractor(iteration),
                    memory_type="notes",
                ),
                message=iteration.summarizedResult,
            ),
        ]
    return chat_history
class ChatEvent(Enum):
    """Event types emitted while streaming a chat response to clients."""

    START_LLM_RESPONSE = "start_llm_response"
    END_LLM_RESPONSE = "end_llm_response"
    MESSAGE = "message"
    REFERENCES = "references"
    GENERATED_ASSETS = "generated_assets"
    STATUS = "status"
    THOUGHT = "thought"
    METADATA = "metadata"
    USAGE = "usage"
    END_RESPONSE = "end_response"
    INTERRUPT = "interrupt"
    # Sentinel string; presumably delimits events on the wire — confirm with the stream writer.
    END_EVENT = "␃🔚␗"
def message_to_log(
    user_message,
    chat_response,
    user_message_metadata: Optional[dict] = None,
    khoj_message_metadata: Optional[dict] = None,
    chat_history: Optional[List[ChatMessageModel]] = None,
) -> List[ChatMessageModel]:
    """Create json logs from messages, metadata for conversation log.

    Args:
        user_message: The user's raw message.
        chat_response: The assistant's response message.
        user_message_metadata: Extra fields to merge into the user message log.
        khoj_message_metadata: Extra fields to merge into the khoj message log.
        chat_history: Existing message list to append the new turn to.

    Returns:
        The chat history with the validated user and khoj messages appended.

    Note:
        Fix: mutable default arguments replaced with None. The previous
        `chat_history=[]` default was mutated via `extend`, leaking messages
        across calls that relied on the default. Callers passing explicit
        values are unaffected.
    """
    user_message_metadata = user_message_metadata or {}
    khoj_message_metadata = khoj_message_metadata or {}
    chat_history = chat_history if chat_history is not None else []
    default_khoj_message_metadata = {
        "intent": {"type": "remember", "memory-type": "notes", "query": user_message},
    }
    khoj_response_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Filter out any fields that are set to None
    user_message_metadata = {k: v for k, v in user_message_metadata.items() if v is not None}
    khoj_message_metadata = {k: v for k, v in khoj_message_metadata.items() if v is not None}
    # Create json log from Human's message
    human_log = merge_dicts({"message": user_message, "by": "you"}, user_message_metadata)
    # Create json log from GPT's response
    khoj_log = merge_dicts(khoj_message_metadata, default_khoj_message_metadata)
    khoj_log = merge_dicts({"message": chat_response, "by": "khoj", "created": khoj_response_time}, khoj_log)

    # Validate message logs
    # Only validates top-level fields, not nested fields, defined in ChatMessageModel
    class StrictChatMessageModel(ChatMessageModel):
        model_config = ConfigDict(extra="forbid", strict=True)

    try:
        StrictChatMessageModel(**human_log)
    except ValidationError as e:
        logger.error(f"Validation error in user chat message: {e}\nUser Message: {human_log}\n")
    try:
        StrictChatMessageModel(**khoj_log)
    except ValidationError as e:
        logger.error(f"Validation error in khoj chat message: {e}\nKhoj Message: {khoj_log}\n")
    human_message = ChatMessageModel(**human_log)
    khoj_message = ChatMessageModel(**khoj_log)
    chat_history.extend([human_message, khoj_message])
    return chat_history
async def save_to_conversation_log(
    q: str,
    chat_response: str,
    user: KhojUser,
    user_message_time: str = None,
    compiled_references: List[Dict[str, Any]] = [],
    online_results: Dict[str, Any] = {},
    code_results: Dict[str, Any] = {},
    operator_results: List[OperatorRun] = None,
    inferred_queries: List[str] = [],
    intent_type: str = "remember",
    client_application: ClientApplication = None,
    conversation_id: str = None,
    automation_id: str = None,
    query_images: List[str] = None,
    raw_query_files: List[FileAttachment] = [],
    generated_images: List[str] = [],
    generated_mermaidjs_diagram: str = None,
    research_results: Optional[List[ResearchIteration]] = None,
    train_of_thought: List[Any] = [],
    tracer: Dict[str, Any] = {},
):
    """Persist one completed chat turn (user query + khoj response with metadata) to the database.

    Builds metadata for the user and khoj messages, converts them into chat messages
    and saves them to the user's conversation. Also commits a prompt trace when
    promptrace is enabled. Validation failures are logged and skip the DB save.
    NOTE(review): the mutable default arguments ([] / {}) are only read here, but are
    shared across calls; prefer passing fresh objects from callers.
    """
    user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Reuse the tracer's message id as turn id when present, to correlate with prompt traces.
    turn_id = tracer.get("mid") or str(uuid.uuid4())
    user_message_metadata = {"created": user_message_time, "images": query_images, "turnId": turn_id}
    if raw_query_files and len(raw_query_files) > 0:
        user_message_metadata["queryFiles"] = [file.model_dump(mode="json") for file in raw_query_files]
    khoj_message_metadata = {
        "context": compiled_references,
        "intent": {"inferred-queries": inferred_queries, "type": intent_type},
        "onlineContext": online_results,
        "codeContext": code_results,
        # Operator/research context is only stored for turns without a final chat response.
        "operatorContext": [o.to_dict() for o in operator_results] if operator_results and not chat_response else None,
        "researchContext": [r.to_dict() for r in research_results] if research_results and not chat_response else None,
        "automationId": automation_id,
        "trainOfThought": train_of_thought,
        "turnId": turn_id,
        "images": generated_images,
    }
    if generated_mermaidjs_diagram:
        khoj_message_metadata["mermaidjsDiagram"] = generated_mermaidjs_diagram
    try:
        new_messages = message_to_log(
            user_message=q,
            chat_response=chat_response,
            user_message_metadata=user_message_metadata,
            khoj_message_metadata=khoj_message_metadata,
            chat_history=[],
        )
    except ValidationError as e:
        # Skip saving the turn if the constructed messages fail validation.
        new_messages = None
        logger.error(f"Error constructing chat history: {e}")
    db_conversation = None
    if new_messages:
        db_conversation = await ConversationAdapters.save_conversation(
            user,
            new_messages,
            client_application=client_application,
            conversation_id=conversation_id,
            user_message=q,
        )
    if is_promptrace_enabled():
        merge_message_into_conversation_trace(q, chat_response, tracer)
    logger.info(
        f"""
Saved Conversation Turn ({db_conversation.id if db_conversation else "N/A"}):
You ({user.username}): "{q}"
Khoj: "{chat_response}"
""".strip()
    )
def construct_structured_message(
    message: list[dict] | str,
    images: list[str] = None,
    model_type: str = None,
    vision_enabled: bool = True,
    attached_file_context: str = None,
):
    """
    Format messages into appropriate multimedia format for supported chat model types.
    Assume vision is enabled and chat model provider supports messages in chatml format, unless specified otherwise.
    """
    parts: List[dict[str, Any]] = []
    if not is_none_or_empty(message):
        # Wrap plain strings into a single text part; lists are assumed pre-structured.
        parts += [{"type": "text", "text": message}] if isinstance(message, str) else message
    if not vision_enabled:
        # Chat model cannot see images; drop any image parts passed by the caller.
        parts = [part for part in parts if part.get("type") != "image_url"]
    if not is_none_or_empty(attached_file_context):
        parts.append({"type": "text", "text": attached_file_context})
    if vision_enabled and images:
        parts.extend({"type": "image_url", "image_url": {"url": image}} for image in images)
    return parts
def gather_raw_query_files(
    query_files: Dict[str, str],
):
    """
    Gather contextual data from the given (raw) files
    """
    if not query_files:
        return ""
    # Render each file as a named section, then join sections with a space.
    file_sections = [f"File: {name}\n\n{content}\n\n" for name, content in query_files.items()]
    return f"I have attached the following files:\n\n{' '.join(file_sections)}"
def generate_chatml_messages_with_context(
    # Context
    user_message: str,
    query_files: str = None,
    query_images=None,
    context_message="",
    generated_asset_results: Dict[str, Dict] = {},
    program_execution_context: List[str] = [],
    chat_history: list[ChatMessageModel] = [],
    system_message: str = None,
    # Model Config
    model_name="gpt-4o-mini",
    model_type="",
    max_prompt_size=None,
    tokenizer_name=None,
    vision_enabled=False,
):
    """Generate chat messages with appropriate context from previous conversation to send to the chat model"""
    # Set max prompt size from user config or based on pre-configured for model and machine specs
    if not max_prompt_size:
        max_prompt_size = model_to_prompt_size.get(model_name, 10000)
    # Scale lookback turns proportional to max prompt size supported by model
    lookback_turns = max_prompt_size // 750
    # Extract Chat History for Context
    chatml_messages: List[ChatMessage] = []
    for chat in chat_history:
        message_context = []
        message_attached_files = ""
        generated_assets = {}
        chat_message = chat.message
        role = "user" if chat.by == "you" else "assistant"
        # Legacy code to handle excalidraw diagrams prior to Dec 2024
        if chat.by == "khoj" and chat.intent and "excalidraw" in chat.intent.type:
            chat_message = (chat.intent.inferred_queries or [])[0]
        # Add search and action context
        if not is_none_or_empty(chat.onlineContext):
            message_context += [
                {
                    "type": "text",
                    "text": f"{prompts.online_search_conversation.format(online_results=chat.onlineContext)}",
                }
            ]
        if not is_none_or_empty(chat.codeContext):
            message_context += [
                {
                    "type": "text",
                    "text": f"{prompts.code_executed_context.format(code_results=chat.codeContext)}",
                }
            ]
        if not is_none_or_empty(chat.operatorContext):
            operator_context = chat.operatorContext
            operator_content = "\n\n".join([f"## Task: {oc['query']}\n{oc['response']}\n" for oc in operator_context])
            message_context += [
                {
                    "type": "text",
                    "text": f"{prompts.operator_execution_context.format(operator_results=operator_content)}",
                }
            ]
        if not is_none_or_empty(chat.context):
            # NOTE(review): the isinstance(item, dict) filter paired with attribute access
            # (item.uri / item.compiled) looks contradictory — dicts have no such attributes,
            # so references may always be empty here. Confirm the item type against callers.
            references = "\n\n".join(
                {
                    f"# URI: {item.uri or item.file}\n## {item.compiled}\n"
                    for item in chat.context or []
                    if isinstance(item, dict)
                }
            )
            message_context += [{"type": "text", "text": f"{prompts.notes_conversation.format(references=references)}"}]
        if not is_none_or_empty(message_context):
            reconstructed_context_message = ChatMessage(content=message_context, role="user")
            chatml_messages.append(reconstructed_context_message)
        # Add generated assets
        if not is_none_or_empty(chat.images) and role == "assistant":
            generated_assets["image"] = {
                "description": (chat.intent.inferred_queries or [user_message])[0],
            }
        if not is_none_or_empty(chat.mermaidjsDiagram) and role == "assistant":
            generated_assets["diagram"] = {
                "query": (chat.intent.inferred_queries or [user_message])[0],
            }
        if not is_none_or_empty(generated_assets):
            chatml_messages.append(
                ChatMessage(
                    content=f"{prompts.generated_assets_context.format(generated_assets=yaml_dump(generated_assets))}\n",
                    role="user",
                )
            )
        # Add user query with attached file, images or khoj response
        if chat.queryFiles:
            raw_query_files = chat.queryFiles
            query_files_dict = dict()
            for file in raw_query_files:
                query_files_dict[file["name"]] = file["content"]
            message_attached_files = gather_raw_query_files(query_files_dict)
        message_content = construct_structured_message(
            chat_message, chat.images if role == "user" else [], model_type, vision_enabled, message_attached_files
        )
        reconstructed_message = ChatMessage(
            content=message_content,
            role=role,
            additional_kwargs={"message_type": chat.intent.type if chat.intent else None},
        )
        chatml_messages.append(reconstructed_message)
        # Cap history at roughly 3 constructed messages per lookback turn.
        # NOTE(review): forward iteration keeps the earliest visible messages; confirm intended window.
        if len(chatml_messages) >= 3 * lookback_turns:
            break
    messages: list[ChatMessage] = []
    if not is_none_or_empty(system_message):
        messages.append(ChatMessage(content=system_message, role="system"))
    if len(chatml_messages) > 0:
        messages += chatml_messages
    if program_execution_context:
        program_context_text = "\n".join(program_execution_context)
        context_message += f"{prompts.additional_program_context.format(context=program_context_text)}\n"
    if not is_none_or_empty(context_message):
        messages.append(ChatMessage(content=context_message, role="user"))
    if not is_none_or_empty(generated_asset_results):
        messages.append(
            ChatMessage(
                content=prompts.generated_assets_context.format(generated_assets=yaml_dump(generated_asset_results)),
                role="user",
            ),
        )
    if not is_none_or_empty(user_message):
        messages.append(
            ChatMessage(
                content=construct_structured_message(
                    user_message, query_images, model_type, vision_enabled, query_files
                ),
                role="user",
            )
        )
    # Normalize message content to list of chatml dictionaries
    for message in messages:
        if isinstance(message.content, str):
            message.content = [{"type": "text", "text": message.content}]
    # Truncate oldest messages from conversation history until under max supported prompt size by model
    messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)
    # Return messages in chronological order
    return messages
def count_total_tokens(
    messages: list[ChatMessage], encoder, system_message: Optional[list[ChatMessage]] = None
) -> Tuple[int, int]:
    """Count total tokens in messages including system message"""
    system_message_tokens = 0
    if system_message:
        system_message_tokens = sum(count_tokens(msg.content, encoder) for msg in system_message)
    message_tokens = sum(count_tokens(msg.content, encoder) for msg in messages)
    # Reserve 4 tokens per message to demarcate it (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.)
    per_message_overhead = 4 * len(messages)
    return message_tokens + system_message_tokens + per_message_overhead, system_message_tokens
def truncate_messages(
    messages: list[ChatMessage],
    max_prompt_size: int,
    model_name: str,
    tokenizer_name=None,
) -> list[ChatMessage]:
    """Truncate messages to fit within max prompt size supported by model

    Drops oldest messages (and oldest content parts within the oldest message) first,
    then hard-truncates the single remaining message if it is still over budget.
    System messages are preserved and re-attached at the front of the result.
    """
    encoder = get_encoder(model_name, tokenizer_name)
    # Extract system message from messages
    system_message = []
    non_system_messages = []
    for message in messages:
        if message.role == "system":
            system_message.append(message)
        else:
            non_system_messages.append(message)
    # New message list without system messages
    messages = non_system_messages
    # Drop older messages until under max supported prompt size by model
    total_tokens, system_message_tokens = count_total_tokens(messages, encoder, system_message)
    while total_tokens > max_prompt_size and (len(messages) > 1 or len(messages[0].content) > 1):
        # If the last message has more than one content part, pop the oldest content part.
        # For tool calls, the whole message should dropped, assistant's tool call content being truncated annoys AI APIs.
        if len(messages[0].content) > 1 and messages[0].additional_kwargs.get("message_type") != "tool_call":
            # The oldest content part is earlier in content list. So pop from the front.
            messages[0].content.pop(0)
        # Otherwise, pop the last message if it has only one content part or is a tool call.
        else:
            # The oldest message is the last one. So pop from the back.
            dropped_message = messages.pop(0)
            # Drop tool result pair of tool call, if tool call message has been removed
            if (
                dropped_message.additional_kwargs.get("message_type") == "tool_call"
                and messages
                and messages[0].additional_kwargs.get("message_type") == "tool_result"
            ):
                messages.pop(0)
        # Recount after each drop to decide whether to keep trimming.
        total_tokens, _ = count_total_tokens(messages, encoder, system_message)
    # Truncate current message if still over max supported prompt size by model
    total_tokens, _ = count_total_tokens(messages, encoder, system_message)
    if total_tokens > max_prompt_size:
        # At this point, a single message with a single content part of type dict should remain
        assert len(messages) == 1 and len(messages[0].content) == 1 and isinstance(messages[0].content[0], dict), (
            "Expected a single message with a single content part remaining at this point in truncation"
        )
        # Collate message content into single string to ease truncation
        part = messages[0].content[0]
        message_content: str = part["text"] if part["type"] == "text" else json.dumps(part)
        message_role = messages[0].role
        # Keep the final line (assumed to hold the user's question) intact; truncate the preceding context.
        remaining_context = "\n".join(message_content.split("\n")[:-1])
        original_question = "\n" + "\n".join(message_content.split("\n")[-1:])
        original_question_tokens = count_tokens(original_question, encoder)
        remaining_tokens = max_prompt_size - system_message_tokens
        if remaining_tokens > original_question_tokens:
            # Enough budget: keep the question whole and truncate only the context.
            remaining_tokens -= original_question_tokens
            truncated_context = encoder.decode(encoder.encode(remaining_context)[:remaining_tokens]).strip()
            truncated_content = truncated_context + original_question
        else:
            # Not enough budget even for the question; truncate the question itself.
            truncated_content = encoder.decode(encoder.encode(original_question)[:remaining_tokens]).strip()
        messages = [ChatMessage(content=[{"type": "text", "text": truncated_content}], role=message_role)]
        # Log an abridged snippet of the truncated message for debugging.
        truncated_snippet = (
            f"{truncated_content[:1000]}\n...\n{truncated_content[-1000:]}"
            if len(truncated_content) > 2000
            else truncated_content
        )
        logger.debug(
            f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_snippet}"
        )
    return system_message + messages if system_message else messages
def reciprocal_conversation_to_chatml(message_pair):
    """Convert a single back and forth between user and assistant to chatml format"""
    roles = ("user", "assistant")
    return [ChatMessage(content=content, role=role) for content, role in zip(message_pair, roles)]
def clean_json(response: str):
    """Remove any markdown json codeblock and newline formatting if present. Useful for non schema enforceable models"""
    flattened = response.strip().replace("\n", "")
    return flattened.removeprefix("```json").removesuffix("```")
def clean_mermaidjs(response: str):
    """Remove any markdown mermaidjs codeblock and newline formatting if present. Useful for non schema enforceable models"""
    stripped = response.strip()
    return stripped.removeprefix("```mermaid").removesuffix("```")
def clean_code_python(code: str):
    """Remove any markdown codeblock and newline formatting if present. Useful for non schema enforceable models"""
    trimmed = code.strip()
    return trimmed.removeprefix("```python").removesuffix("```")
def load_complex_json(json_str):
    """
    Preprocess a raw JSON string to
    - escape unescaped double quotes within value strings while preserving the JSON structure and already escaped quotes.
    - remove suffix after the first valid JSON object,

    Returns the parsed object. Raises ValueError with aggregated loader errors
    if neither json nor pyjson5 can parse the preprocessed string.
    """

    def replace_unescaped_quotes(match):
        # Get the content between colons and commas/end braces
        content = match.group(1)
        # Replace unescaped double, single quotes that aren't already escaped
        # Uses negative lookbehind to avoid replacing already escaped quotes
        # Replace " with \"
        processed_dq = re.sub(r'(?<!\\)"', '\\"', content)
        # Replace \' with \\'
        processed_final = re.sub(r"(?<!\\)\\'", r"\\\\'", processed_dq)
        return f': "{processed_final}"'

    # Match content between : and either , or }
    # This pattern looks for ': ' followed by any characters until , or }
    pattern = r':\s*"(.*?)(?<!\\)"(?=[,}])'
    # Process the JSON string
    # rf-string passes the raw input through so backslashes survive cleaning.
    cleaned = clean_json(rf"{json_str}")
    processed = re.sub(pattern, replace_unescaped_quotes, cleaned)
    # See which json loader can load the processed JSON as valid
    errors = []
    json_loaders_to_try = [json.loads, pyjson5.loads]
    for loads in json_loaders_to_try:
        try:
            return loads(processed)
        except (json.JSONDecodeError, pyjson5.Json5Exception) as e_load:
            loader_name = loads.__name__
            errors.append(f"{loader_name} (initial parse): {type(e_load).__name__}: {str(e_load)}")
            # Handle plain text suffixes by slicing at error position
            if hasattr(e_load, "pos") and 0 < e_load.pos < len(processed):
                try:
                    sliced = processed[: e_load.pos].strip()
                    if sliced:
                        return loads(sliced)
                except Exception as e_slice:
                    errors.append(
                        f"{loader_name} after slice at {e_load.pos}: {type(e_slice).__name__}: {str(e_slice)}"
                    )
    # If all loaders fail, raise the aggregated error
    raise ValueError(
        f"Failed to load JSON with errors: {'; '.join(errors)}\n\n"
        f"While attempting to load this cleaned JSON:\n{processed}"
    )
def defilter_query(query: str):
    """Remove any query filters in query"""
    result = query
    # Strip word, file and date filter terms in turn.
    for query_filter in (WordFilter(), FileFilter(), DateFilter()):
        result = query_filter.defilter(result)
    return result
@dataclass
class ImageWithType:
    """An image payload paired with its type string."""

    # Image data; representation depends on the requested format: base64 str, PIL image or raw bytes.
    content: Any
    # Image type, e.g. "image/png"; may be None when retrieval failed.
    type: str
def get_image_from_base64(image: str, type="b64"):
    """Parse an inline base64 data URI into an ImageWithType in the requested format.

    Args:
        image: Data URI of the form "data:<mime-type>;base64,<data>".
        type: Output format: "b64" (base64 string), "pil" (PIL image) or "bytes" (raw bytes).

    Returns:
        ImageWithType with content in the requested format and the image's declared type.

    Raises:
        ValueError: If an unsupported output type is requested.
    """
    # Extract image type and base64 data from inline image data
    image_base64 = image.split(",", 1)[1]
    image_type = image.split(";", 1)[0].split(":", 1)[1]
    # Convert image to desired format
    if type == "b64":
        return ImageWithType(content=image_base64, type=image_type)
    elif type == "pil":
        image_data = base64.b64decode(image_base64)
        image_pil = PIL.Image.open(BytesIO(image_data))
        return ImageWithType(content=image_pil, type=image_type)
    elif type == "bytes":
        image_data = base64.b64decode(image_base64)
        return ImageWithType(content=image_data, type=image_type)
    else:
        # Fix: previously fell through and silently returned None for unknown types;
        # raise instead, consistent with get_image_from_url.
        raise ValueError(f"Invalid image type: {type}")
def get_image_from_url(image_url: str, type="pil"):
    """Download the image at `image_url` and return it as ImageWithType in the requested format ("b64", "pil" or "bytes")."""
    try:
        response = requests.get(image_url)
        response.raise_for_status()  # Surface HTTP errors as RequestException

        # Prefer the server-declared content type; fall back to guessing from the URL.
        content_type = response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp"

        # Convert image to desired format
        raw = response.content
        converted: Any = None
        if type == "b64":
            converted = base64.b64encode(raw).decode("utf-8")
        elif type == "pil":
            converted = PIL.Image.open(BytesIO(raw))
        elif type == "bytes":
            converted = raw
        else:
            raise ValueError(f"Invalid image type: {type}")
        return ImageWithType(content=converted, type=content_type)
    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to get image from URL {image_url}: {e}")
        return ImageWithType(content=None, type=None)
def commit_conversation_trace(
    session: list[ChatMessage],
    response: str | list[dict],
    tracer: dict,
    system_message: str | list[dict] = "",
    repo_path: str = None,
) -> str:
    """
    Save trace of conversation step using git. Useful to visualize, compare and debug traces.
    Returns the path to the repository.

    Traces are laid out as git branches: u_<uid> (user) -> c_<cid> (conversation)
    -> m_<mid> (message). Each step commits the serialized session, response and
    system prompt as the files "query", "response" and "system_prompt".
    Returns None when GitPython is unavailable, no repo path is configured, or
    any git operation fails.
    """
    try:
        from git import Repo
    except ImportError:
        # GitPython is optional; without it tracing is silently disabled
        return None
    # Infer repository path from environment variable or provided path
    repo_path = repo_path if not is_none_or_empty(repo_path) else os.getenv("PROMPTRACE_DIR")
    if not repo_path:
        return None
    # Serialize session, system message and response to yaml
    system_message_yaml = json.dumps(system_message, ensure_ascii=False, sort_keys=False)
    response_yaml = json.dumps(response, ensure_ascii=False, sort_keys=False)
    formatted_session = [{"role": message.role, "content": message.content} for message in session]
    session_yaml = json.dumps(formatted_session, ensure_ascii=False, sort_keys=False)
    # Strip surrounding quotes so the commit subject reads as plain text
    query = (
        json.dumps(session[-1].content, ensure_ascii=False, sort_keys=False).strip().removeprefix("'").removesuffix("'")
    )  # Extract serialized query from chat session
    # Extract chat metadata for session; uid/cid default to "main" branches
    uid, cid, mid = tracer.get("uid", "main"), tracer.get("cid", "main"), tracer.get("mid")
    try:
        # Prepare git repository (init is a no-op on an existing repo)
        os.makedirs(repo_path, exist_ok=True)
        repo = Repo.init(repo_path)
        # Remove post-commit hook if it exists, so tracing commits don't trigger it
        hooks_dir = os.path.join(repo_path, ".git", "hooks")
        post_commit_hook = os.path.join(hooks_dir, "post-commit")
        if os.path.exists(post_commit_hook):
            os.remove(post_commit_hook)
        # Configure git user if not set
        if not repo.config_reader().has_option("user", "email"):
            repo.config_writer().set_value("user", "name", "Prompt Tracer").release()
            repo.config_writer().set_value("user", "email", "promptracer@khoj.dev").release()
        # Create an initial commit if the repository is newly created
        if not repo.head.is_valid():
            repo.index.commit("And then there was a trace")
        # Check out the initial commit
        initial_commit = repo.commit("HEAD~0")
        repo.head.reference = initial_commit
        repo.head.reset(index=True, working_tree=True)
        # Create or switch to user branch from initial commit
        user_branch = f"u_{uid}"
        if user_branch not in repo.branches:
            repo.create_head(user_branch)
        repo.heads[user_branch].checkout()
        # Create or switch to conversation branch from user branch
        conv_branch = f"c_{cid}"
        if conv_branch not in repo.branches:
            repo.create_head(conv_branch)
        repo.heads[conv_branch].checkout()
        # Create or switch to message branch from conversation branch (only if mid is set)
        msg_branch = f"m_{mid}" if mid else None
        if msg_branch and msg_branch not in repo.branches:
            repo.create_head(msg_branch)
        if msg_branch:
            repo.heads[msg_branch].checkout()
        # Include file with content to commit
        files_to_commit = {"query": session_yaml, "response": response_yaml, "system_prompt": system_message_yaml}
        # Write files and stage them
        for filename, content in files_to_commit.items():
            file_path = os.path.join(repo_path, filename)
            # Unescape special characters in content for better readability
            content = content.strip().replace("\\n", "\n").replace("\\t", "\t")
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
            repo.index.add([filename])
        # Create commit with truncated query/response and full tracer metadata
        metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
        commit_message = f"""
{query[:250]}
Response:
---
{response[:500]}...
Metadata
---
{metadata_yaml}
""".strip()
        repo.index.commit(commit_message)
        logger.debug(f"Saved conversation trace to repo at {repo_path}")
        return repo_path
    except Exception as e:
        logger.error(f"Failed to add conversation trace to repo: {str(e)}", exc_info=True)
        return None
def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path=None) -> bool:
    """
    Merge the message branch into its parent conversation branch.

    Args:
        query: User query
        response: Assistant response
        tracer: Dictionary containing uid, cid and mid
        repo_path: Path to the git repository. Falls back to the PROMPTRACE_DIR
            environment variable when not provided.

    Returns:
        bool: True if merge was successful, False otherwise
    """
    try:
        from git import Repo
    except ImportError:
        # GitPython is optional; without it tracing is disabled
        return False
    # Extract branch names with safe defaults before entering the try block,
    # so the failure log below can always reference them (the original
    # tracer['mid'] lookup could raise KeyError, leaving the names unbound
    # and crashing the except handler with a NameError).
    msg_branch = f"m_{tracer.get('mid')}"
    conv_branch = f"c_{tracer.get('cid')}"
    try:
        # Infer repository path from environment variable or provided path
        repo_path = repo_path if not is_none_or_empty(repo_path) else os.getenv("PROMPTRACE_DIR")
        if not repo_path:
            # Return False (not None) to honor the declared bool return type
            return False
        repo = Repo(repo_path)
        # Checkout conversation branch
        repo.heads[conv_branch].checkout()
        # Create commit message with truncated query/response and tracer metadata
        metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
        commit_message = f"""
{query[:250]}
Response:
---
{response[:500]}...
Metadata
---
{metadata_yaml}
""".strip()
        # Merge message branch into conversation branch
        repo.git.merge(msg_branch, no_ff=True, m=commit_message)
        # Delete message branch after merge
        repo.delete_head(msg_branch, force=True)
        logger.debug(f"Successfully merged {msg_branch} into {conv_branch}")
        return True
    except Exception as e:
        logger.error(f"Failed to merge message {msg_branch} into conversation {conv_branch}: {str(e)}", exc_info=True)
        return False
def messages_to_print(messages: list[ChatMessage], max_length: int = 70) -> str:
    """
    Format and truncate messages to print, ensuring JSON serializable content

    Args:
        messages: Chat messages whose content is rendered one per line.
        max_length: Max characters of each serialized message to keep.

    Returns:
        Newline-joined string with one truncated, JSON-encoded entry per message.
    """

    def safe_serialize(content: Any) -> str:
        try:
            # Return content as-is if it is already JSON serializable
            json.dumps(content)
            return content
        except (TypeError, ValueError):
            # json.dumps raises TypeError for unsupported types and ValueError
            # for e.g. circular references. The previously caught
            # json.JSONDecodeError is only raised by json.loads and could
            # never fire here, letting circular structures escape uncaught.
            if hasattr(content, "format") and content.format == "WEBP":
                return "[WebP Image]"
            elif hasattr(content, "__dict__"):
                return str(content.__dict__)
            return str(content)

    return "\n".join([f"{json.dumps(safe_serialize(message.content))[:max_length]}..." for message in messages])
class StructuredOutputSupport(int, Enum):
    """Level of structured output support offered by a chat model API.

    Uses an int mixin so levels are ordered and comparable with <, >=, etc.
    """

    NONE = 0
    OBJECT = 1  # presumably: can emit free-form JSON objects — confirm at call sites
    SCHEMA = 2  # presumably: can honor a provided response schema — confirm at call sites
    TOOL = 3  # presumably: supports tool/function calling — confirm at call sites
class ResponseWithThought:
    """Bundle of a chat model response: final text, its "thought" text, and raw content."""

    def __init__(self, text: Optional[str] = None, thought: Optional[str] = None, raw_content: Optional[list] = None):
        # Final response text returned by the model
        self.text = text
        # Model's intermediate "thinking" text, if any — presumably reasoning output; verify against producers
        self.thought = thought
        # Unprocessed content blocks as returned by the model API, if retained
        self.raw_content = raw_content