Get max context for the (user, operator model) pair to use in context compression

This commit is contained in:
Debanjum
2025-05-30 16:44:01 -07:00
parent 7eaf0e80c5
commit ded1db642c
3 changed files with 22 additions and 5 deletions

View File

@@ -63,15 +63,30 @@ async def operate_environment(
chat_history = construct_chat_history_for_operator(conversation_log)
# Initialize Agent
max_context = await ConversationAdapters.aget_max_context_size(reasoning_model, user) or 20000
max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 100))
operator_agent: OperatorAgent
if is_operator_model(reasoning_model.name) == ChatModel.ModelType.OPENAI:
operator_agent = OpenAIOperatorAgent(
query, reasoning_model, environment_type, max_iterations, chat_history, previous_trajectory, tracer
query,
reasoning_model,
environment_type,
max_iterations,
max_context,
chat_history,
previous_trajectory,
tracer,
)
elif is_operator_model(reasoning_model.name) == ChatModel.ModelType.ANTHROPIC:
operator_agent = AnthropicOperatorAgent(
query, reasoning_model, environment_type, max_iterations, chat_history, previous_trajectory, tracer
query,
reasoning_model,
environment_type,
max_iterations,
max_context,
chat_history,
previous_trajectory,
tracer,
)
else:
grounding_model_name = "ui-tars-1.5"
@@ -88,6 +103,7 @@ async def operate_environment(
grounding_model,
environment_type,
max_iterations,
max_context,
chat_history,
previous_trajectory,
tracer,

View File

@@ -34,6 +34,7 @@ class OperatorAgent(ABC):
vision_model: ChatModel,
environment_type: EnvironmentType,
max_iterations: int,
max_context: int,
chat_history: List[AgentMessage] = [],
previous_trajectory: Optional[OperatorRun] = None,
tracer: dict = {},
@@ -56,9 +57,7 @@ class OperatorAgent(ABC):
# Context compression parameters
self.context_compress_trigger = 2e3 # heuristic to determine compression trigger
# Number of turns after which compression is triggered. Scales with the model's max context size. Minimum 5 turns.
self.message_limit = 2 * max(
5, int(self.vision_model.subscribed_max_prompt_size / self.context_compress_trigger)
)
self.message_limit = 2 * max(5, int(max_context / self.context_compress_trigger))
# compression ratio determines how many messages to compress down to one
# e.g. with 5 messages, a compress ratio of 4/5 means the 5 messages become 1 compressed message + 1 kept uncompressed
self.message_compress_ratio = 4 / 5

View File

@@ -40,6 +40,7 @@ class BinaryOperatorAgent(OperatorAgent):
grounding_model: ChatModel,
environment_type: EnvironmentType,
max_iterations: int,
max_context: int,
chat_history: List[AgentMessage] = [],
previous_trajectory: Optional[OperatorRun] = None,
tracer: dict = {},
@@ -49,6 +50,7 @@ class BinaryOperatorAgent(OperatorAgent):
reasoning_model,
environment_type,
max_iterations,
max_context,
chat_history,
previous_trajectory,
tracer,