Operate Computer with Khoj Operator (#1190)

## Summary
- Enable Khoj to operate computers: Add experimental computer operator
functionality that allows Khoj to interact with desktop environments,
browsers, and terminals to accomplish complex tasks
- Multi-environment support: Implement computer environments with GUI,
file system, and terminal access. Can control either the host computer or a
Dockerized computer container

## Key Features
### Computer Operation Capabilities
- Desktop control (screenshots, clicking, typing, keyboard shortcuts)
- File editing and management
- Terminal/bash command execution
- Web browser automation
- Visual feedback via train-of-thought video playback

### Infrastructure & Architecture:
- Docker container (ghcr.io/khoj-ai/computer:latest) with Ubuntu 24.04,
XFCE desktop, VNC access
- Local computer environment support with pyautogui
- Modular operator agent system supporting multiple environment types
- Trajectory compression and context management for long-running tasks

### Model Integration:
- Anthropic models only (Claude Sonnet 4, Claude 3.7 Sonnet, Claude Opus
4)
- OpenAI and binary operator agents temporarily disabled
- Enhanced caching and context management for operator conversations

### User Experience:
- Use the `/operator` command, or just ask Khoj to use the operator tool,
to invoke computer operation
- Integrate with research mode for extended 30+ minute task execution
- Video of computer operation in train of thought for transparency

### Configuration
- Set `KHOJ_OPERATOR_ENABLED=True` in `docker-compose.yml`
- Requires Anthropic API key
- Computer container runs on port 5900 (VNC)
This commit is contained in:
Debanjum
2025-05-31 22:04:12 -07:00
committed by GitHub
34 changed files with 2862 additions and 675 deletions

View File

@@ -12,6 +12,7 @@ on:
- pyproject.toml
- Dockerfile
- prod.Dockerfile
- computer.Dockerfile
- docker-compose.yml
- .github/workflows/dockerize.yml
workflow_dispatch:
@@ -27,6 +28,10 @@ on:
description: 'Build Khoj cloud docker image'
type: boolean
default: true
khoj-computer:
description: 'Build computer for Khoj'
type: boolean
default: true
env:
# Tag Image with tag name on release
@@ -115,6 +120,21 @@ jobs:
org.opencontainers.image.description=Khoj AI Cloud - Your second brain powered by LLMs and Neural Search
org.opencontainers.image.source=${{ github.server_url }}/${{ github.repository }}
- name: 📦️️💻 Build and Push Computer for Khoj
uses: docker/build-push-action@v4
if: github.event_name == 'workflow_dispatch' && github.event.inputs.khoj-computer == 'true'
with:
context: .
file: computer.Dockerfile
push: true
tags: |
ghcr.io/${{ github.repository }}-computer:${{ env.DOCKER_IMAGE_TAG }}-${{ matrix.platform == 'linux/amd64' && 'amd64' || 'arm64' }}
cache-from: type=gha,scope=computer-${{ matrix.platform }}
cache-to: type=gha,mode=max,scope=computer-${{ matrix.platform }}
labels: |
org.opencontainers.image.description=Khoj AI Computer - A computer for your second brain to operate
org.opencontainers.image.source=${{ github.server_url }}/${{ github.repository }}
manifest:
needs: build
runs-on: ubuntu-latest
@@ -147,3 +167,12 @@ jobs:
-t ghcr.io/${{ github.repository }}-cloud:${{ github.ref_type == 'tag' && 'latest' || env.DOCKER_IMAGE_TAG }} \
ghcr.io/${{ github.repository }}-cloud:${{ env.DOCKER_IMAGE_TAG }}-amd64 \
ghcr.io/${{ github.repository }}-cloud:${{ env.DOCKER_IMAGE_TAG }}-arm64
- name: Create and Push Computer Manifest
if: github.event.inputs.khoj-computer == 'true'
run: |
docker buildx imagetools create \
-t ghcr.io/${{ github.repository }}-computer:${{ env.DOCKER_IMAGE_TAG }} \
-t ghcr.io/${{ github.repository }}-computer:${{ github.ref_type == 'tag' && 'latest' || env.DOCKER_IMAGE_TAG }} \
ghcr.io/${{ github.repository }}-computer:${{ env.DOCKER_IMAGE_TAG }}-amd64 \
ghcr.io/${{ github.repository }}-computer:${{ env.DOCKER_IMAGE_TAG }}-arm64

129
computer.Dockerfile Normal file
View File

@@ -0,0 +1,129 @@
# Base image for the Khoj "computer": a containerized Linux desktop the
# operator agent can see and control (VNC + XFCE on a virtual display).
FROM ubuntu:24.04
# Suppress interactive prompts from apt during image build.
ENV DEBIAN_FRONTEND=noninteractive
# Install System Dependencies
RUN apt update \
    && apt install -y \
    ca-certificates \
    gnupg \
    # GUI stack: XFCE desktop rendered on a virtual X display (Xvfb),
    # exposed over VNC (x11vnc)
    xfce4 \
    xfce4-goodies \
    x11vnc \
    xvfb \
    # Programmatic input and screenshots for the operator agent
    xdotool \
    imagemagick \
    x11-apps \
    dbus-x11 \
    sudo \
    python3-pip \
    python3-tk \
    python3-dev \
    build-essential \
    scrot \
    gnome-screenshot \
    net-tools \
    # X11 development headers (needed to build pyautogui's native deps)
    libx11-dev \
    libxext-dev \
    libxtst-dev \
    libxinerama-dev \
    libxmu-dev \
    libxrandr-dev \
    libxfixes-dev \
    software-properties-common \
    # Mozilla team PPA provides firefox-esr as a plain deb (not a snap,
    # which would not run in this container)
    && add-apt-repository ppa:mozillateam/ppa && apt update \
    && apt install -y --no-install-recommends \
    # Desktop apps
    firefox-esr \
    libreoffice \
    x11-apps \
    xpdf \
    gedit \
    xpaint \
    tint2 \
    galculator \
    pcmanfm \
    unzip \
    # Terminal apps like file editors, viewers, git, wget/curl etc.
    less \
    nano \
    neovim \
    vim \
    git \
    curl \
    wget \
    procps \
    # Python/pyenv dependencies
    libssl-dev \
    zlib1g-dev \
    libbz2-dev \
    libreadline-dev \
    libsqlite3-dev \
    libncursesw5-dev \
    xz-utils \
    tk-dev \
    libxml2-dev \
    libxmlsec1-dev \
    libffi-dev \
    liblzma-dev \
    # set default browser
    && update-alternatives --set x-www-browser /usr/bin/firefox-esr \
    && apt-get clean && rm -rf /var/lib/apt/lists/* \
    # remove screen locks, power managers
    && apt remove -y light-locker xfce4-screensaver xfce4-power-manager || true
# Create Computer User
# Non-root user that owns the desktop session, with passwordless sudo so
# the operator agent can install packages and administer the container.
ENV USERNAME=operator
ENV HOME=/home/$USERNAME
# Create the group explicitly: `useradd -g` requires the named group to
# already exist, and stock ubuntu:24.04 has no "operator" group, so without
# groupadd the build fails here.
RUN groupadd $USERNAME \
    && useradd -m -s /bin/bash -d $HOME -g $USERNAME $USERNAME \
    && echo "${USERNAME} ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers
USER $USERNAME
WORKDIR $HOME
# Setup Python
# Build pyenv from source and append its shell hooks to ~/.bashrc so
# interactive terminal sessions inside the computer also pick it up.
RUN git clone https://github.com/pyenv/pyenv.git ~/.pyenv && \
    cd ~/.pyenv && src/configure && make -C src && cd .. && \
    echo 'export PYENV_ROOT="$HOME/.pyenv"' >> ~/.bashrc && \
    echo 'command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH"' >> ~/.bashrc && \
    echo 'eval "$(pyenv init -)"' >> ~/.bashrc
ENV PYENV_ROOT="$HOME/.pyenv"
ENV PATH="$PYENV_ROOT/bin:$PATH"
# Pin the interpreter (3.11.6) so the environment is reproducible,
# independent of the distro's python3.
ENV PYENV_VERSION_MAJOR=3
ENV PYENV_VERSION_MINOR=11
ENV PYENV_VERSION_PATCH=6
ENV PYENV_VERSION=$PYENV_VERSION_MAJOR.$PYENV_VERSION_MINOR.$PYENV_VERSION_PATCH
# Install the pinned CPython and make it the global default.
RUN eval "$(pyenv init -)" && \
    pyenv install $PYENV_VERSION && \
    pyenv global $PYENV_VERSION && \
    pyenv rehash
# Put pyenv shims first so `python3` resolves to the pinned interpreter.
ENV PATH="$HOME/.pyenv/shims:$HOME/.pyenv/bin:$PATH"
# Install Python Packages
# GUI automation libraries used to drive the desktop (mouse, keyboard,
# clipboard, window management, screenshots).
RUN python3 -m pip install --no-cache-dir \
    pyautogui \
    Pillow \
    pyperclip \
    pygetwindow
# Setup VNC
# NOTE(review): VNC password is hardcoded to "secret". Acceptable for a
# local dev container, but it should be overridable (e.g. via build ARG or
# runtime secret) before port 5900 is exposed beyond localhost.
RUN x11vnc -storepasswd secret /home/operator/.vncpass
# Virtual display geometry; overridable at build time.
ARG WIDTH=1024
ARG HEIGHT=768
ARG DISPLAY_NUM=99
ENV WIDTH=$WIDTH
ENV HEIGHT=$HEIGHT
ENV DISPLAY_NUM=$DISPLAY_NUM
ENV DISPLAY=":$DISPLAY_NUM"
# Expose VNC on port 5900
# run Xvfb, x11vnc, Xfce (no login manager)
# Startup sequence: create the XDG runtime dir, start the virtual X server,
# authorize the display, start the VNC server and the XFCE session, then
# tail /dev/null to keep PID 1 alive.
EXPOSE 5900
CMD ["/bin/sh", "-c", " export XDG_RUNTIME_DIR=/run/user/$(id -u); \
mkdir -p $XDG_RUNTIME_DIR && chown $USERNAME:$USERNAME $XDG_RUNTIME_DIR && chmod 0700 $XDG_RUNTIME_DIR; \
Xvfb $DISPLAY -screen 0 ${WIDTH}x${HEIGHT}x24 -dpi 96 -auth /home/$USERNAME/.Xauthority >/dev/null 2>&1 & \
sleep 1; \
xauth add $DISPLAY . $(mcookie); \
x11vnc -display $DISPLAY -forever -rfbauth /home/$USERNAME/.vncpass -listen 0.0.0.0 -rfbport 5900 >/dev/null 2>&1 & \
eval $(dbus-launch --sh-syntax) && \
startxfce4 & \
sleep 2 && echo 'Container running!' && \
tail -f /dev/null "]

View File

@@ -25,6 +25,14 @@ services:
- khoj_search:/etc/searxng
environment:
- SEARXNG_BASE_URL=http://localhost:8080/
# Creates Computer for Khoj to use.
# Set KHOJ_OPERATOR_ENABLED=True in the server service environment variable to enable.
computer:
image: ghcr.io/khoj-ai/computer:latest
ports:
- "5900:5900"
volumes:
- khoj_computer:/home/operator
server:
depends_on:
database:
@@ -75,6 +83,8 @@ services:
# - GEMINI_API_KEY=your_gemini_api_key
# - ANTHROPIC_API_KEY=your_anthropic_api_key
#
# Uncomment line below to enable Khoj to use its computer.
# - KHOJ_OPERATOR_ENABLED=True
# Uncomment appropriate lines below to enable web results with Khoj
# Ensure you set your provider specific API keys.
# ---
@@ -112,3 +122,4 @@ volumes:
khoj_db:
khoj_models:
khoj_search:
khoj_computer:

View File

@@ -114,6 +114,7 @@ prod = [
local = [
"pgserver == 0.1.4",
"playwright >= 1.49.0",
"pyautogui == 0.9.54",
]
dev = [
"khoj[prod,local]",

View File

@@ -8,7 +8,9 @@ import ChatMessage, {
ChatHistoryData,
StreamMessage,
TrainOfThought,
TrainOfThoughtObject,
} from "../chatMessage/chatMessage";
import TrainOfThoughtVideoPlayer from "../../../components/trainOfThoughtVideoPlayer/trainOfThoughtVideoPlayer";
import { ScrollArea } from "@/components/ui/scroll-area";
@@ -41,17 +43,108 @@ interface ChatHistoryProps {
setIsOwner?: (isOwner: boolean) => void;
}
// One screenshot-bearing step of the operator's train of thought,
// rendered as a single frame in the train-of-thought video player.
interface TrainOfThoughtFrame {
    text: string; // thought text with the embedded screenshot JSON stripped
    image?: string; // screenshot as a data URI, when present
    timestamp: number; // index of this thought within the trainOfThought list
}

// A contiguous run of thoughts, rendered either as a video of screenshot
// frames or as plain text entries.
interface TrainOfThoughtGroup {
    type: 'video' | 'text';
    frames?: TrainOfThoughtFrame[]; // populated when type === 'video'
    textEntries?: TrainOfThoughtObject[]; // populated when type === 'text'
}
interface TrainOfThoughtComponentProps {
trainOfThought: string[];
trainOfThought: string[] | TrainOfThoughtObject[];
lastMessage: boolean;
agentColor: string;
keyId: string;
completed?: boolean;
}
/**
 * Partition a train of thought into ordered, alternating groups of
 * screenshot-bearing "video" frames and plain "text" entries.
 *
 * A thought whose data embeds a JSON payload with an `image` field becomes a
 * video frame; all other thoughts accumulate into text groups. Whenever the
 * kind of thought flips, the accumulated run of the other kind is flushed as
 * a group, preserving the original ordering.
 */
function extractTrainOfThoughtGroups(trainOfThought?: TrainOfThoughtObject[]): TrainOfThoughtGroup[] {
    if (!trainOfThought) return [];

    const groups: TrainOfThoughtGroup[] = [];
    let currentVideoFrames: TrainOfThoughtFrame[] = [];
    let currentTextEntries: TrainOfThoughtObject[] = [];

    trainOfThought.forEach((thought, index) => {
        let text = thought.data;
        let hasImage = false;

        // Extract screenshot image from the thought data.
        // NOTE(review): greedy single-line regex; assumes the screenshot JSON
        // payload sits on one line of the thought data — TODO confirm against
        // the server's operator message format.
        try {
            const jsonMatch = text.match(
                /\{.*(\"action\": \"screenshot\"|\"type\": \"screenshot\"|\"image\": \"data:image\/.*\").*\}/,
            );
            if (jsonMatch) {
                const jsonMessage = JSON.parse(jsonMatch[0]);
                if (jsonMessage.image) {
                    hasImage = true;
                    // Clean up the text to remove the JSON action
                    text = text.replace(`:\n**Action**: ${jsonMatch[0]}`, "");
                    if (jsonMessage.text) {
                        text += `\n\n${jsonMessage.text}`;
                    }
                    // If we have accumulated text entries, add them as a text group
                    if (currentTextEntries.length > 0) {
                        groups.push({
                            type: 'text',
                            textEntries: [...currentTextEntries]
                        });
                        currentTextEntries = [];
                    }
                    // Add to current video frames
                    currentVideoFrames.push({
                        text: text,
                        image: jsonMessage.image,
                        timestamp: index,
                    });
                }
            }
        } catch (e) {
            // Malformed JSON in the matched span: fall through and treat the
            // thought as plain text below (hasImage stays false).
            console.error("Failed to parse screenshot data", e);
        }
        if (!hasImage) {
            // If we have accumulated video frames, add them as a video group
            if (currentVideoFrames.length > 0) {
                groups.push({
                    type: 'video',
                    frames: [...currentVideoFrames]
                });
                currentVideoFrames = [];
            }
            // Add to current text entries
            currentTextEntries.push(thought);
        }
    });

    // Flush any trailing run that was still accumulating when input ended.
    if (currentVideoFrames.length > 0) {
        groups.push({
            type: 'video',
            frames: currentVideoFrames
        });
    }
    if (currentTextEntries.length > 0) {
        groups.push({
            type: 'text',
            textEntries: currentTextEntries
        });
    }
    return groups;
}
function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) {
const lastIndex = props.trainOfThought.length - 1;
const [collapsed, setCollapsed] = useState(props.completed);
const [trainOfThoughtGroups, setTrainOfThoughtGroups] = useState<TrainOfThoughtGroup[]>([]);
const variants = {
open: {
@@ -72,6 +165,29 @@ function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) {
}
}, [props.completed]);
useEffect(() => {
// Handle empty array case
if (!props.trainOfThought || props.trainOfThought.length === 0) {
setTrainOfThoughtGroups([]);
return;
}
// Convert string array to TrainOfThoughtObject array if needed
let trainOfThoughtObjects: TrainOfThoughtObject[];
if (typeof props.trainOfThought[0] === 'string') {
trainOfThoughtObjects = (props.trainOfThought as string[]).map((data, index) => ({
type: 'text',
data: data
}));
} else {
trainOfThoughtObjects = props.trainOfThought as TrainOfThoughtObject[];
}
const groups = extractTrainOfThoughtGroups(trainOfThoughtObjects);
setTrainOfThoughtGroups(groups);
}, [props.trainOfThought]);
return (
<div
className={`${!collapsed ? styles.trainOfThought + " border" : ""} rounded-lg`}
@@ -101,15 +217,31 @@ function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) {
<AnimatePresence initial={false}>
{!collapsed && (
<motion.div initial="closed" animate="open" exit="closed" variants={variants}>
{props.trainOfThought.map((train, index) => (
<TrainOfThought
key={`train-${index}`}
message={train}
primary={
index === lastIndex && props.lastMessage && !props.completed
}
agentColor={props.agentColor}
/>
{trainOfThoughtGroups.map((group, groupIndex) => (
<div key={`train-group-${groupIndex}`}>
{group.type === 'video' && group.frames && group.frames.length > 0 && (
<TrainOfThoughtVideoPlayer
frames={group.frames}
autoPlay={false}
playbackSpeed={1500}
/>
)}
{group.type === 'text' && group.textEntries && group.textEntries.map((entry, entryIndex) => {
const lastIndex = trainOfThoughtGroups.length - 1;
const isLastGroup = groupIndex === lastIndex;
const isLastEntry = entryIndex === group.textEntries!.length - 1;
const isPrimaryEntry = isLastGroup && isLastEntry && props.lastMessage && !props.completed;
return (
<TrainOfThought
key={`train-text-${groupIndex}-${entryIndex}-${entry.data.length}`}
message={entry.data}
primary={isPrimaryEntry}
agentColor={props.agentColor}
/>
);
})}
</div>
))}
</motion.div>
)}
@@ -125,6 +257,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
const [currentTurnId, setCurrentTurnId] = useState<string | null>(null);
const sentinelRef = useRef<HTMLDivElement | null>(null);
const scrollAreaRef = useRef<HTMLDivElement | null>(null);
const scrollableContentWrapperRef = useRef<HTMLDivElement | null>(null);
const latestUserMessageRef = useRef<HTMLDivElement | null>(null);
const latestFetchedMessageRef = useRef<HTMLDivElement | null>(null);
@@ -151,16 +284,46 @@ export default function ChatHistory(props: ChatHistoryProps) {
};
scrollAreaEl.addEventListener("scroll", detectIsNearBottom);
detectIsNearBottom(); // Initial check
return () => scrollAreaEl.removeEventListener("scroll", detectIsNearBottom);
}, []);
}, [scrollAreaRef]);
// Auto scroll while incoming message is streamed
useEffect(() => {
if (props.incomingMessages && props.incomingMessages.length > 0 && isNearBottom) {
scrollToBottom();
scrollToBottom(true);
}
}, [props.incomingMessages, isNearBottom]);
// ResizeObserver to handle content height changes (e.g., images loading)
useEffect(() => {
const contentWrapper = scrollableContentWrapperRef.current;
const scrollViewport = scrollAreaRef.current?.querySelector<HTMLElement>(scrollAreaSelector);
if (!contentWrapper || !scrollViewport) return;
const observer = new ResizeObserver(() => {
// Check current scroll position to decide if auto-scroll is warranted
const { scrollTop, scrollHeight, clientHeight } = scrollViewport;
const bottomThreshold = 50;
const currentlyNearBottom = (scrollHeight - (scrollTop + clientHeight)) <= bottomThreshold;
if (currentlyNearBottom) {
// Only auto-scroll if there are incoming messages being processed
if (props.incomingMessages && props.incomingMessages.length > 0) {
const lastMessage = props.incomingMessages[props.incomingMessages.length - 1];
// If the last message is not completed, or it just completed (indicated by incompleteIncomingMessageIndex still being set)
if (!lastMessage.completed || (lastMessage.completed && incompleteIncomingMessageIndex !== null)) {
scrollToBottom(true); // Use instant scroll
}
}
}
});
observer.observe(contentWrapper);
return () => observer.disconnect();
}, [props.incomingMessages, incompleteIncomingMessageIndex, scrollAreaRef]); // Dependencies
// Scroll to most recent user message after the first page of chat messages is loaded.
useEffect(() => {
if (data && data.chat && data.chat.length > 0 && currentPage < 2) {
@@ -297,7 +460,10 @@ export default function ChatHistory(props: ChatHistoryProps) {
behavior: instant ? "auto" : "smooth",
});
});
setIsNearBottom(true);
// Optimistically set, the scroll listener will verify
if (instant || scrollAreaEl && (scrollAreaEl.scrollHeight - (scrollAreaEl.scrollTop + scrollAreaEl.clientHeight)) < 5) {
setIsNearBottom(true);
}
};
function constructAgentLink() {
@@ -356,7 +522,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
`}
ref={scrollAreaRef}
>
<div>
<div ref={scrollableContentWrapperRef}>
<div className={`${styles.chatHistory} ${props.customClassName}`}>
<div ref={sentinelRef} style={{ height: "1px" }}>
{fetchingData && <InlineLoading className="opacity-50" />}
@@ -367,9 +533,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
<React.Fragment key={`chatMessage-${index}`}>
{chatMessage.trainOfThought && chatMessage.by === "khoj" && (
<TrainOfThoughtComponent
trainOfThought={chatMessage.trainOfThought?.map(
(train) => train.data,
)}
trainOfThought={chatMessage.trainOfThought}
lastMessage={false}
agentColor={data?.agent?.color || "orange"}
key={`${index}trainOfThought`}
@@ -428,12 +592,12 @@ export default function ChatHistory(props: ChatHistoryProps) {
conversationId={props.conversationId}
turnId={messageTurnId}
/>
{message.trainOfThought && (
{message.trainOfThought && message.trainOfThought.length > 0 && (
<TrainOfThoughtComponent
trainOfThought={message.trainOfThought}
lastMessage={index === incompleteIncomingMessageIndex}
agentColor={data?.agent?.color || "orange"}
key={`${index}trainOfThought`}
key={`${index}trainOfThought-${message.trainOfThought.length}-${message.trainOfThought.map(t => t.length).join('-')}`}
keyId={`${index}trainOfThought`}
completed={message.completed}
/>
@@ -519,7 +683,6 @@ export default function ChatHistory(props: ChatHistoryProps) {
className="absolute bottom-0 right-0 bg-white dark:bg-[hsl(var(--background))] text-neutral-500 dark:text-white p-2 rounded-full shadow-xl"
onClick={() => {
scrollToBottom();
setIsNearBottom(true);
}}
>
<ArrowDown size={24} />

View File

@@ -144,7 +144,7 @@ interface Intent {
"inferred-queries": string[];
}
interface TrainOfThoughtObject {
export interface TrainOfThoughtObject {
type: string;
data: string;
}

View File

@@ -0,0 +1,170 @@
/* Styles for the train-of-thought video player: a faux video UI that plays
   back operator screenshots with a caption overlay, a seekable timeline with
   per-frame markers, and transport controls. Colors come from the app's HSL
   CSS variables so the player follows the active theme. */

/* Outer shell of the player. */
.videoPlayer {
    border: 1px solid hsl(var(--border));
    border-radius: 8px;
    background-color: hsl(var(--background));
    margin: 16px 0;
    overflow: hidden;
}

/* The "screen": centers the current frame's screenshot. */
.screen {
    position: relative;
    background-color: hsl(var(--muted));
    min-height: 300px;
    display: flex;
    align-items: center;
    justify-content: center;
}

.screenImage {
    max-width: 100%;
    max-height: 400px;
    object-fit: contain;
    border-radius: 4px;
}

/* Caption strip pinned to the bottom of the screen; gradient keeps white
   text readable over arbitrary screenshots. */
.textOverlay {
    position: absolute;
    bottom: 0;
    left: 0;
    right: 0;
    background: linear-gradient(transparent, rgba(0, 0, 0, 0.8));
    padding: 20px 16px 12px;
    color: white;
}

/* Scrolls when a thought's text exceeds the overlay height. */
.thoughtText {
    font-size: 14px;
    line-height: 1.4;
    max-height: 100px;
    overflow-y: auto;
}

/* Control bar below the screen. */
.controls {
    padding: 12px 16px;
    background-color: hsl(var(--card));
    border-top: 1px solid hsl(var(--border));
}

.timeline {
    position: relative;
    margin-bottom: 12px;
}

/* Native range input restyled as a slim seek bar; thumb styling needs
   vendor-specific pseudo-elements below. */
.timelineSlider {
    width: 100%;
    height: 4px;
    background-color: hsl(var(--muted));
    border-radius: 2px;
    outline: none;
    cursor: pointer;
    -webkit-appearance: none;
    appearance: none;
}

.timelineSlider::-webkit-slider-thumb {
    -webkit-appearance: none;
    appearance: none;
    width: 16px;
    height: 16px;
    border-radius: 50%;
    background-color: hsl(var(--primary));
    cursor: pointer;
    border: 2px solid white;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
}

.timelineSlider::-moz-range-thumb {
    width: 16px;
    height: 16px;
    border-radius: 50%;
    background-color: hsl(var(--primary));
    cursor: pointer;
    border: 2px solid white;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
}

/* Per-frame tick marks overlaid on the timeline; container ignores pointer
   events so slider drags pass through, individual marks re-enable them to
   stay clickable. */
.frameMarkers {
    position: absolute;
    top: -2px;
    left: 0;
    right: 0;
    height: 8px;
    display: flex;
    justify-content: space-between;
    pointer-events: none;
}

.frameMarker {
    width: 6px;
    height: 8px;
    border-radius: 1px;
    cursor: pointer;
    pointer-events: auto;
    transition: all 0.2s ease;
}

/* Marker color distinguishes screenshot frames from text-only ones. */
.frameMarker.hasImage {
    background-color: hsl(var(--primary));
}

.frameMarker.textOnly {
    background-color: hsl(var(--muted-foreground));
}

.frameMarker.active {
    background-color: hsl(var(--accent)) !important;
    transform: scaleY(1.5);
}

.frameMarker:hover {
    transform: scaleY(1.2);
}

/* Transport buttons (prev / play-pause / next / live). */
.controlButtons {
    display: flex;
    align-items: center;
    gap: 8px;
    margin-bottom: 8px;
}

.controlButton {
    display: flex;
    align-items: center;
    justify-content: center;
    width: 32px;
    height: 32px;
    border: 1px solid hsl(var(--border));
    border-radius: 4px;
    background-color: hsl(var(--background));
    color: hsl(var(--foreground));
    cursor: pointer;
    transition: all 0.2s ease;
}

.controlButton:hover:not(:disabled) {
    background-color: hsl(var(--muted));
}

.controlButton:disabled {
    opacity: 0.5;
    cursor: not-allowed;
}

/* Highlights the "Live" button while auto-tracking the latest frame. */
.controlButton.active {
    background-color: hsl(var(--primary));
    color: hsl(var(--primary-foreground));
}

/* "current / total" frame counter. */
.frameInfo {
    display: flex;
    justify-content: center;
    font-size: 12px;
    color: hsl(var(--muted-foreground));
}

/* Dark mode adjustments */
@media (prefers-color-scheme: dark) {
    .textOverlay {
        background: linear-gradient(transparent, rgba(0, 0, 0, 0.9));
    }
}

View File

@@ -0,0 +1,186 @@
"use client";
import React, { useState, useRef, useEffect } from "react";
import { Play, Pause, FastForward, Rewind } from "@phosphor-icons/react";
import styles from "./trainOfThoughtVideoPlayer.module.css";
// One frame of the playback: the thought text, an optional screenshot, and
// the frame's position within the original train of thought.
interface TrainOfThoughtFrame {
    text: string;
    image?: string; // screenshot as a data URI, when present
    timestamp: number; // index of this thought in the source list
}

interface TrainOfThoughtVideoPlayerProps {
    frames: TrainOfThoughtFrame[];
    autoPlay?: boolean; // start playing as soon as the component mounts
    playbackSpeed?: number; // milliseconds each frame stays on screen
}
/**
 * Plays back the operator's train of thought as a faux video: one screenshot
 * frame at a time with a caption overlay, a seekable timeline with per-frame
 * markers, transport controls, and a "Live" mode that tracks the newest frame
 * as it streams in.
 *
 * Fix: the original stopped playback by calling `setIsPlaying(false)` inside
 * the `setCurrentFrameIndex` updater. State updaters must be pure — under
 * React StrictMode they are invoked twice and must not trigger side effects —
 * so end-of-playback is now detected in a dedicated effect instead.
 */
export default function TrainOfThoughtVideoPlayer({
    frames,
    autoPlay = true,
    playbackSpeed = 1000, // ms per frame
}: TrainOfThoughtVideoPlayerProps) {
    const [currentFrameIndex, setCurrentFrameIndex] = useState(0);
    const [isPlaying, setIsPlaying] = useState(autoPlay);
    const [isAutoTracking, setIsAutoTracking] = useState(true);
    const intervalRef = useRef<NodeJS.Timeout | null>(null);

    // Auto-advance to the latest frame when new frames stream in (live mode).
    useEffect(() => {
        if (isAutoTracking && frames.length > 0) {
            setCurrentFrameIndex(frames.length - 1);
        }
    }, [frames.length, isAutoTracking]);

    // Stop playback once the final frame is showing. Kept out of the interval
    // callback's state updater so the updater stays pure.
    useEffect(() => {
        if (isPlaying && currentFrameIndex >= frames.length - 1) {
            setIsPlaying(false);
        }
    }, [isPlaying, currentFrameIndex, frames.length]);

    // Drive playback: advance one frame per tick while playing, clamped to
    // the last frame. Interval is torn down on pause/unmount/speed change.
    useEffect(() => {
        if (isPlaying && frames.length > 1) {
            intervalRef.current = setInterval(() => {
                setCurrentFrameIndex((prev) => Math.min(prev + 1, frames.length - 1));
            }, playbackSpeed);
        } else if (intervalRef.current) {
            clearInterval(intervalRef.current);
            intervalRef.current = null;
        }
        return () => {
            if (intervalRef.current) {
                clearInterval(intervalRef.current);
            }
        };
    }, [isPlaying, frames.length, playbackSpeed]);

    const currentFrame = frames[currentFrameIndex];

    // Manual interactions (seek/step/play) take the player out of live mode.
    const handleSeek = (index: number) => {
        setCurrentFrameIndex(index);
        setIsAutoTracking(false);
        setIsPlaying(false);
    };

    const handlePlay = () => {
        setIsPlaying(!isPlaying);
        setIsAutoTracking(false);
    };

    const handlePrevious = () => {
        if (currentFrameIndex > 0) {
            setCurrentFrameIndex(currentFrameIndex - 1);
            setIsAutoTracking(false);
            setIsPlaying(false);
        }
    };

    const handleNext = () => {
        if (currentFrameIndex < frames.length - 1) {
            setCurrentFrameIndex(currentFrameIndex + 1);
            setIsAutoTracking(false);
            setIsPlaying(false);
        }
    };

    // Jump to the newest frame and keep tracking as more arrive.
    const handleAutoTrack = () => {
        setIsAutoTracking(true);
        setCurrentFrameIndex(frames.length - 1);
        setIsPlaying(false);
    };

    if (!frames.length) {
        return null;
    }

    return (
        <div className={styles.videoPlayer}>
            <div className={styles.screen}>
                {currentFrame?.image && (
                    <img
                        src={currentFrame.image}
                        alt={`Train of thought frame ${currentFrameIndex + 1}`}
                        className={styles.screenImage}
                    />
                )}
                <div className={styles.textOverlay}>
                    <div className={styles.thoughtText}>{currentFrame?.text}</div>
                </div>
            </div>
            <div className={styles.controls}>
                <div className={styles.timeline}>
                    <input
                        type="range"
                        min={0}
                        max={Math.max(0, frames.length - 1)}
                        value={currentFrameIndex}
                        onChange={(e) => handleSeek(parseInt(e.target.value, 10))}
                        className={styles.timelineSlider}
                    />
                    <div className={styles.frameMarkers}>
                        {frames.map((frame, index) => (
                            <div
                                key={index}
                                className={`${styles.frameMarker} ${
                                    frame.image ? styles.hasImage : styles.textOnly
                                } ${index === currentFrameIndex ? styles.active : ""}`}
                                onClick={() => handleSeek(index)}
                                title={`Frame ${index + 1}: ${frame.text.slice(0, 50)}...`}
                            />
                        ))}
                    </div>
                </div>
                <div className={styles.controlButtons}>
                    <button
                        onClick={handlePrevious}
                        disabled={currentFrameIndex === 0}
                        title="Previous frame"
                        className={styles.controlButton}
                    >
                        <Rewind size={16} />
                    </button>
                    <button
                        onClick={handlePlay}
                        disabled={frames.length <= 1}
                        title={isPlaying ? "Pause" : "Play"}
                        className={styles.controlButton}
                    >
                        {isPlaying ? <Pause size={16} /> : <Play size={16} />}
                    </button>
                    <button
                        onClick={handleNext}
                        disabled={currentFrameIndex === frames.length - 1}
                        title="Next frame"
                        className={styles.controlButton}
                    >
                        <FastForward size={16} />
                    </button>
                    <button
                        onClick={handleAutoTrack}
                        className={`${styles.controlButton} ${isAutoTracking ? styles.active : ""}`}
                        title="Auto-track latest"
                    >
                        Live
                    </button>
                </div>
                <div className={styles.frameInfo}>
                    <span>
                        {currentFrameIndex + 1} / {frames.length}
                    </span>
                </div>
            </div>
        </div>
    );
}

View File

@@ -1312,6 +1312,26 @@ class ConversationAdapters:
else:
ServerChatSettings.objects.create(chat_default=chat_model, chat_advanced=chat_model)
@staticmethod
def get_max_context_size(chat_model: ChatModel, user: KhojUser) -> int | None:
    """Return the max prompt context size (in tokens) for this user on the given chat model.

    Subscribed users get the model's subscribed limit when one is configured;
    everyone else gets the model's default limit.
    """
    has_subscription = bool(user) and is_user_subscribed(user)
    if has_subscription and chat_model.subscribed_max_prompt_size:
        return chat_model.subscribed_max_prompt_size
    return chat_model.max_prompt_size
@staticmethod
async def aget_max_context_size(chat_model: ChatModel, user: KhojUser) -> int | None:
    """Async variant: return the max prompt context size (in tokens) for this user.

    Subscribed users get the model's subscribed limit when one is configured;
    everyone else gets the model's default limit.
    """
    # Short-circuit: skip the subscription lookup entirely for anonymous users.
    if user and await ais_user_subscribed(user) and chat_model.subscribed_max_prompt_size:
        return chat_model.subscribed_max_prompt_size
    return chat_model.max_prompt_size
@staticmethod
async def aget_server_webscraper():
server_chat_settings = await ServerChatSettings.objects.filter().prefetch_related("web_scraper").afirst()

View File

@@ -107,7 +107,7 @@ class ChatMessage(PydanticBaseModel):
onlineContext: Dict[str, OnlineContext] = {}
codeContext: Dict[str, CodeContextData] = {}
researchContext: Optional[List] = None
operatorContext: Optional[Dict[str, str]] = None
operatorContext: Optional[List] = None
created: str
images: Optional[List[str]] = None
queryFiles: Optional[List[Dict]] = None

View File

@@ -14,8 +14,10 @@ from khoj.processor.conversation.anthropic.utils import (
format_messages_for_anthropic,
)
from khoj.processor.conversation.utils import (
OperatorRun,
ResponseWithThought,
clean_json,
construct_question_history,
construct_structured_message,
generate_chatml_messages_with_context,
messages_to_print,
@@ -53,13 +55,7 @@ def extract_questions_anthropic(
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log
chat_history = "".join(
[
f'User: {chat["intent"]["query"]}\nAssistant: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
for chat in conversation_log.get("chat", [])[-4:]
if chat["by"] == "khoj"
]
)
chat_history = construct_question_history(conversation_log, query_prefix="User", agent_name="Assistant")
# Get dates relative to today for prompt creation
today = datetime.today()
@@ -144,7 +140,7 @@ async def converse_anthropic(
references: list[dict],
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[Dict[str, str]] = None,
operator_results: Optional[List[OperatorRun]] = None,
conversation_log={},
model: Optional[str] = "claude-3-7-sonnet-latest",
api_key: Optional[str] = None,
@@ -216,8 +212,11 @@ async def converse_anthropic(
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
)
if ConversationCommand.Operator in conversation_commands and not is_none_or_empty(operator_results):
operator_content = [
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
]
context_message += (
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_results))}\n\n"
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
)
context_message = context_message.strip()

View File

@@ -60,7 +60,7 @@ def anthropic_completion_with_backoff(
client = get_anthropic_client(api_key, api_base_url)
anthropic_clients[api_key] = client
formatted_messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
formatted_messages, system = format_messages_for_anthropic(messages, system_prompt)
aggregated_response = ""
if response_type == "json_object" and not deepthought:
@@ -70,8 +70,8 @@ def anthropic_completion_with_backoff(
final_message = None
model_kwargs = model_kwargs or dict()
if system_prompt:
model_kwargs["system"] = system_prompt
if system:
model_kwargs["system"] = system
max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC
if deepthought and is_reasoning_model(model_name):
@@ -146,7 +146,7 @@ async def anthropic_chat_completion_with_backoff(
# Temperature control not supported when using extended thinking
temperature = 1.0
formatted_messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
formatted_messages, system = format_messages_for_anthropic(messages, system_prompt)
aggregated_response = ""
response_started = False
@@ -156,7 +156,7 @@ async def anthropic_chat_completion_with_backoff(
messages=formatted_messages,
model=model_name, # type: ignore
temperature=temperature,
system=system_prompt,
system=system,
timeout=20,
max_tokens=max_tokens,
**model_kwargs,
@@ -231,7 +231,10 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
else:
system_prompt += message.content
messages.remove(message)
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
if not is_none_or_empty(system_prompt):
system = [{"type": "text", "text": system_prompt, "cache_control": {"type": "ephemeral"}}]
else:
system = None
# Anthropic requires the first message to be a 'user' message
if len(messages) == 1:
@@ -274,12 +277,32 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
logger.error(f"Drop message with empty content as not supported:\n{message}")
messages.remove(message)
continue
if isinstance(message.content, str):
message.content = [{"type": "text", "text": message.content}]
# Add cache control to enable prompt caching for conversations with sufficient context
# Only add caching if we have multiple messages to make it worthwhile
if len(messages) > 2:
# Remove any existing cache controls from previous messages
for message in messages: # All except the last message
if isinstance(message.content, list):
for block in message.content:
if isinstance(block, dict) and "cache_control" in block:
del block["cache_control"]
# Add cache control to the last content block of second to last message.
# In research mode, this message content is list of iterations, updated after each research iteration.
# Caching it should improve research efficiency.
cache_message = messages[-2]
if isinstance(cache_message.content, list) and cache_message.content:
# Add cache control to the last content block
cache_message.content[-1]["cache_control"] = {"type": "ephemeral"}
formatted_messages: List[anthropic.types.MessageParam] = [
anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
]
return formatted_messages, system_prompt
return formatted_messages, system
def is_reasoning_model(model_name: str) -> bool:

View File

@@ -14,7 +14,9 @@ from khoj.processor.conversation.google.utils import (
gemini_completion_with_backoff,
)
from khoj.processor.conversation.utils import (
OperatorRun,
clean_json,
construct_question_history,
construct_structured_message,
generate_chatml_messages_with_context,
messages_to_print,
@@ -53,13 +55,7 @@ def extract_questions_gemini(
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log
chat_history = "".join(
[
f'User: {chat["intent"]["query"]}\nAssistant: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
for chat in conversation_log.get("chat", [])[-4:]
if chat["by"] == "khoj"
]
)
chat_history = construct_question_history(conversation_log, query_prefix="User", agent_name="Assistant")
# Get dates relative to today for prompt creation
today = datetime.today()
@@ -166,7 +162,7 @@ async def converse_gemini(
references: list[dict],
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[Dict[str, str]] = None,
operator_results: Optional[List[OperatorRun]] = None,
conversation_log={},
model: Optional[str] = "gemini-2.0-flash",
api_key: Optional[str] = None,
@@ -240,8 +236,11 @@ async def converse_gemini(
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
)
if ConversationCommand.Operator in conversation_commands and not is_none_or_empty(operator_results):
operator_content = [
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
]
context_message += (
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_results))}\n\n"
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
)
context_message = context_message.strip()
@@ -276,7 +275,8 @@ async def converse_gemini(
deepthought=deepthought,
tracer=tracer,
):
full_response += chunk
if chunk.response:
full_response += chunk.response
yield chunk
# Call completion_func once finish streaming and we have the full response

View File

@@ -21,6 +21,7 @@ from tenacity import (
)
from khoj.processor.conversation.utils import (
ResponseWithThought,
commit_conversation_trace,
get_image_from_base64,
get_image_from_url,
@@ -102,7 +103,7 @@ def gemini_completion_with_backoff(
client = get_gemini_client(api_key, api_base_url)
gemini_clients[api_key] = client
formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
formatted_messages, system_instruction = format_messages_for_gemini(messages, system_prompt)
# format model response schema
response_schema = None
@@ -110,12 +111,12 @@ def gemini_completion_with_backoff(
response_schema = clean_response_schema(model_kwargs["response_schema"])
thinking_config = None
if deepthought and model_name.startswith("gemini-2-5"):
if deepthought and model_name.startswith("gemini-2.5"):
thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
config = gtypes.GenerateContentConfig(
system_instruction=system_prompt,
system_instruction=system_instruction,
temperature=temperature,
thinking_config=thinking_config,
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
@@ -178,21 +179,21 @@ async def gemini_chat_completion_with_backoff(
model_kwargs=None,
deepthought=False,
tracer: dict = {},
) -> AsyncGenerator[str, None]:
) -> AsyncGenerator[ResponseWithThought, None]:
client = gemini_clients.get(api_key)
if not client:
client = get_gemini_client(api_key, api_base_url)
gemini_clients[api_key] = client
formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
formatted_messages, system_instruction = format_messages_for_gemini(messages, system_prompt)
thinking_config = None
if deepthought and model_name.startswith("gemini-2-5"):
thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
if deepthought and model_name.startswith("gemini-2.5"):
thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI, include_thoughts=True)
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
config = gtypes.GenerateContentConfig(
system_instruction=system_prompt,
system_instruction=system_instruction,
temperature=temperature,
thinking_config=thinking_config,
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
@@ -216,18 +217,25 @@ async def gemini_chat_completion_with_backoff(
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Keep track of the last chunk for usage data
final_chunk = chunk
# Handle streamed response chunk
# handle safety, rate-limit, other finish reasons
stop_message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
message = stop_message or chunk.text
aggregated_response += message
yield message
if stopped:
yield ResponseWithThought(response=stop_message)
logger.warning(
f"LLM Response Prevented for {model_name}: {stop_message}.\n"
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
)
break
# emit thought vs response parts
for part in chunk.candidates[0].content.parts:
if part.text:
aggregated_response += part.text
yield ResponseWithThought(response=part.text)
if part.thought:
yield ResponseWithThought(thought=part.text)
# Calculate cost of chat
input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0

View File

@@ -16,6 +16,7 @@ from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import (
clean_json,
commit_conversation_trace,
construct_question_history,
generate_chatml_messages_with_context,
messages_to_print,
)
@@ -64,13 +65,7 @@ def extract_questions_offline(
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log
chat_history = ""
if use_history:
for chat in conversation_log.get("chat", [])[-4:]:
if chat["by"] == "khoj":
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"Khoj: {chat['message']}\n\n"
chat_history = construct_question_history(conversation_log, include_query=False) if use_history else ""
# Get dates relative to today for prompt creation
today = datetime.today()

View File

@@ -17,8 +17,10 @@ from khoj.processor.conversation.openai.utils import (
)
from khoj.processor.conversation.utils import (
JsonSupport,
OperatorRun,
ResponseWithThought,
clean_json,
construct_question_history,
construct_structured_message,
generate_chatml_messages_with_context,
messages_to_print,
@@ -55,13 +57,7 @@ def extract_questions(
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log
chat_history = "".join(
[
f'Q: {chat["intent"]["query"]}\nKhoj: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
for chat in conversation_log.get("chat", [])[-4:]
if chat["by"] == "khoj" and "to-image" not in chat["intent"].get("type")
]
)
chat_history = construct_question_history(conversation_log)
# Get dates relative to today for prompt creation
today = datetime.today()
@@ -169,7 +165,7 @@ async def converse_openai(
references: list[dict],
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[Dict[str, str]] = None,
operator_results: Optional[List[OperatorRun]] = None,
conversation_log={},
model: str = "gpt-4o-mini",
api_key: Optional[str] = None,
@@ -242,8 +238,11 @@ async def converse_openai(
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
)
if not is_none_or_empty(operator_results):
operator_content = [
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
]
context_message += (
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_results))}\n\n"
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
)
context_message = context_message.strip()

View File

@@ -10,7 +10,7 @@ from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from io import BytesIO
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Literal, Optional, Union
import PIL.Image
import pyjson5
@@ -20,6 +20,7 @@ import yaml
from langchain_core.messages.chat import ChatMessage
from llama_cpp import LlamaTokenizer
from llama_cpp.llama import Llama
from pydantic import BaseModel
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from khoj.database.adapters import ConversationAdapters
@@ -73,9 +74,9 @@ model_to_prompt_size = {
"claude-3-7-sonnet-20250219": 60000,
"claude-3-7-sonnet-latest": 60000,
"claude-3-5-haiku-20241022": 60000,
"claude-sonnet-4": 60000,
"claude-sonnet-4-0": 60000,
"claude-sonnet-4-20250514": 60000,
"claude-opus-4": 60000,
"claude-opus-4-0": 60000,
"claude-opus-4-20250514": 60000,
# Offline Models
"bartowski/Qwen2.5-14B-Instruct-GGUF": 20000,
@@ -87,7 +88,49 @@ model_to_prompt_size = {
model_to_tokenizer: Dict[str, str] = {}
class InformationCollectionIteration:
class AgentMessage(BaseModel):
    """A single message in an operator agent's conversation trajectory.

    Besides the standard chat roles, the extra "environment" role is reserved for
    messages originating from the operated environment (presumably action results
    and observations — confirm against the operator agent implementations).
    """

    role: Literal["user", "assistant", "system", "environment"]
    # Plain text, or a list of structured content blocks
    # (e.g. {"type": "text", "text": ...} dicts as built by construct_chat_history_for_operator).
    content: Union[str, List]
class OperatorRun:
    """Record of a single operator task run.

    Holds the task query, the agent's message trajectory, the final response,
    and any webpages visited during the run.
    """

    def __init__(
        self,
        query: str,
        trajectory: list[AgentMessage] | list[dict] = None,
        response: str = None,
        webpages: list[dict] = None,
    ):
        self.query = query
        self.response = response
        self.webpages = webpages or []
        # Normalize trajectory entries to AgentMessage instances. Dicts are parsed;
        # AgentMessage-like objects (duck-typed on role/content) pass through; anything
        # else is dropped with a warning.
        self.trajectory: list[AgentMessage] = []
        for entry in trajectory or []:
            if isinstance(entry, dict):
                self.trajectory.append(AgentMessage(**entry))
            elif hasattr(entry, "role") and hasattr(entry, "content"):  # AgentMessage-like object
                self.trajectory.append(entry)
            else:
                logger.warning(f"Unexpected item type in trajectory: {type(entry)}")

    def to_dict(self) -> dict:
        """Serialize the run to a plain dict; pydantic messages in the trajectory become dicts."""
        trajectory_dicts = []
        for message in self.trajectory:
            if hasattr(message, "model_dump"):  # Pydantic model
                trajectory_dicts.append(message.model_dump())
            elif isinstance(message, dict):
                trajectory_dicts.append(message)  # Already serialized
        return {
            "query": self.query,
            "response": self.response,
            "trajectory": trajectory_dicts,
            "webpages": self.webpages,
        }
class ResearchIteration:
def __init__(
self,
tool: str,
@@ -95,7 +138,7 @@ class InformationCollectionIteration:
context: list = None,
onlineContext: dict = None,
codeContext: dict = None,
operatorContext: dict[str, str] = None,
operatorContext: dict | OperatorRun = None,
summarizedResult: str = None,
warning: str = None,
):
@@ -104,13 +147,18 @@ class InformationCollectionIteration:
self.context = context
self.onlineContext = onlineContext
self.codeContext = codeContext
self.operatorContext = operatorContext
self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else operatorContext
self.summarizedResult = summarizedResult
self.warning = warning
def to_dict(self) -> dict:
data = vars(self).copy()
data["operatorContext"] = self.operatorContext.to_dict() if self.operatorContext else None
return data
def construct_iteration_history(
previous_iterations: List[InformationCollectionIteration],
previous_iterations: List[ResearchIteration],
previous_iteration_prompt: str,
query: str = None,
) -> list[dict]:
@@ -143,11 +191,8 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
chat_history = ""
for chat in conversation_history.get("chat", [])[-n:]:
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
chat_history += f"User: {chat['intent']['query']}\n"
if chat["intent"].get("inferred-queries"):
chat_history += f'{agent_name}: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'
chat_history += f"{agent_name}: {chat['message']}\n\n"
elif chat["by"] == "khoj" and chat.get("images"):
chat_history += f"User: {chat['intent']['query']}\n"
@@ -156,6 +201,7 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
elif chat["by"] == "you":
chat_history += f"User: {chat['message']}\n"
raw_query_files = chat.get("queryFiles")
if raw_query_files:
query_files: Dict[str, str] = {}
@@ -168,8 +214,74 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
return chat_history
def construct_question_history(
    conversation_log: dict,
    include_query: bool = True,
    lookback: int = 6,
    query_prefix: str = "Q",
    agent_name: str = "Khoj",
) -> str:
    """
    Constructs a chat history string formatted for query extraction purposes.

    Walks the last `lookback` chat entries, pairing each user query with the
    agent's reply. When `include_query` is set, the agent's inferred queries
    are rendered as a JSON-ish dict before the answer text.
    """
    history = ""
    pending_query = None
    for turn in conversation_log.get("chat", [])[-lookback:]:
        if turn["by"] == "you":
            pending_query = turn.get("message")
            history += f"{query_prefix}: {pending_query}\n"
        if turn["by"] == "khoj":
            # Agent replies without a preceding user query carry no context; skip them.
            if pending_query is None:
                continue
            reply = turn.get("message", "")
            # Default inferred queries to the original user query; normalize strings to a list.
            inferred = turn.get("intent", {}).get("inferred-queries") or [pending_query]
            if isinstance(inferred, str):
                inferred = [inferred]
            if not include_query:
                history += f"{agent_name}: {reply}\n\n"
            else:
                # Image-generation turns carry no useful inferred search queries.
                if "to-image" not in turn.get("intent", {}).get("type", ""):
                    history += f'{agent_name}: {{"queries": {inferred}}}\n'
                history += f"A: {reply}\n\n"
            # This turn is consumed; the next agent reply needs a fresh user query.
            pending_query = None
    return history
def construct_chat_history_for_operator(conversation_history: dict, n: int = 6) -> list[AgentMessage]:
    """
    Construct chat history for the operator agent from the conversation log.

    Only include the last n completed turns (i.e. a user message followed by a
    khoj reply), flattened into alternating [user, assistant, ...] messages.

    Args:
        conversation_history: Conversation log dict with a "chat" list of entries.
        n: Maximum number of completed turns to include.

    Returns:
        Flat list of AgentMessage objects, two per completed turn.
    """
    turns: list[tuple[AgentMessage, AgentMessage]] = []
    user_message: Optional[AgentMessage] = None
    for chat in conversation_history.get("chat", []):
        if chat["by"] == "you" and chat.get("message"):
            # Bundle the user query and any attached file contents into one structured message
            content = [{"type": "text", "text": chat["message"]}]
            for file in chat.get("queryFiles", []):
                content += [{"type": "text", "text": f'## File: {file["name"]}\n\n{file["content"]}'}]
            user_message = AgentMessage(role="user", content=content)
        elif chat["by"] == "khoj" and chat.get("message"):
            # Only pair a khoj reply with a preceding, unconsumed user message.
            # Prevents None entries when a khoj reply has no matching user turn.
            if user_message is not None:
                turns.append((user_message, AgentMessage(role="assistant", content=chat["message"])))
                # Reset so a stale user message is never paired with a later reply
                user_message = None
    # Keep only the last n completed turns, per the function contract
    chat_history: list[AgentMessage] = []
    for user_msg, assistant_msg in turns[-n:]:
        chat_history += [user_msg, assistant_msg]
    return chat_history
def construct_tool_chat_history(
previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None
previous_iterations: List[ResearchIteration], tool: ConversationCommand = None
) -> Dict[str, list]:
"""
Construct chat history from previous iterations for a specific tool
@@ -178,8 +290,8 @@ def construct_tool_chat_history(
If no tool is provided inferred query for all tools used are added.
"""
chat_history: list = []
base_extractor: Callable[[InformationCollectionIteration], List[str]] = lambda x: []
extract_inferred_query_map: Dict[ConversationCommand, Callable[[InformationCollectionIteration], List[str]]] = {
base_extractor: Callable[[ResearchIteration], List[str]] = lambda iteration: []
extract_inferred_query_map: Dict[ConversationCommand, Callable[[ResearchIteration], List[str]]] = {
ConversationCommand.Notes: (
lambda iteration: [c["query"] for c in iteration.context] if iteration.context else []
),
@@ -192,9 +304,6 @@ def construct_tool_chat_history(
ConversationCommand.Code: (
lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else []
),
ConversationCommand.Operator: (
lambda iteration: list(iteration.operatorContext.keys()) if iteration.operatorContext else []
),
}
for iteration in previous_iterations:
# If a tool is provided use the inferred query extractor for that tool if available
@@ -273,7 +382,7 @@ async def save_to_conversation_log(
compiled_references: List[Dict[str, Any]] = [],
online_results: Dict[str, Any] = {},
code_results: Dict[str, Any] = {},
operator_results: Dict[str, str] = {},
operator_results: List[OperatorRun] = None,
inferred_queries: List[str] = [],
intent_type: str = "remember",
client_application: ClientApplication = None,
@@ -284,7 +393,7 @@ async def save_to_conversation_log(
generated_images: List[str] = [],
raw_generated_files: List[FileAttachment] = [],
generated_mermaidjs_diagram: str = None,
research_results: Optional[List[InformationCollectionIteration]] = None,
research_results: Optional[List[ResearchIteration]] = None,
train_of_thought: List[Any] = [],
tracer: Dict[str, Any] = {},
):
@@ -301,8 +410,8 @@ async def save_to_conversation_log(
"intent": {"inferred-queries": inferred_queries, "type": intent_type},
"onlineContext": online_results,
"codeContext": code_results,
"operatorContext": operator_results,
"researchContext": [vars(r) for r in research_results] if research_results and not chat_response else None,
"operatorContext": [o.to_dict() for o in operator_results] if operator_results and not chat_response else None,
"researchContext": [r.to_dict() for r in research_results] if research_results and not chat_response else None,
"automationId": automation_id,
"trainOfThought": train_of_thought,
"turnId": turn_id,
@@ -459,10 +568,12 @@ def generate_chatml_messages_with_context(
]
if not is_none_or_empty(chat.get("operatorContext")):
operator_context = chat.get("operatorContext")
operator_content = "\n\n".join([f'## Task: {oc["query"]}\n{oc["response"]}\n' for oc in operator_context])
message_context += [
{
"type": "text",
"text": f"{prompts.operator_execution_context.format(operator_results=chat.get('operatorContext'))}",
"text": f"{prompts.operator_execution_context.format(operator_results=operator_content)}",
}
]

View File

@@ -0,0 +1,59 @@
# Khoj Operator (Experimental)
## Overview
Give Khoj its own computer to operate in a transparent, controlled manner. Accomplish tasks that require visual browsing, file editing and terminal access. Operator with research mode can work for 30+ minutes to accomplish more substantial tasks like feature development, travel planning, shopping etc.
## Setup
### Prerequisites
- Docker and Docker Compose installed
- Anthropic API key (required - only Anthropic models currently enabled)
### Installation Steps
1. Download the Khoj docker-compose.yml file
```shell
mkdir ~/.khoj && cd ~/.khoj
wget https://raw.githubusercontent.com/khoj-ai/khoj/master/docker-compose.yml
```
2. Configure environment variables in `docker-compose.yml`
- Set `ANTHROPIC_API_KEY` to your [Anthropic API key](https://console.anthropic.com/settings/keys)
- Uncomment `KHOJ_OPERATOR_ENABLED=True` to enable the operator tool
3. Start Khoj services
```shell
docker-compose up
```
4. Access the web app at http://localhost:42110
Ensure you're using a Claude 3.7+ model on your [settings page](http://localhost:42110/settings)
## Usage
Use the `/operator` command or ask Khoj in normal or research mode to use the operator tool to have it operate its computer:
**Examples:**
- `/operator Find flights from Bangkok to Mexico City with no US layover`
- `/research Clone the khoj repo and tell me how the operator tool is implemented`
## Supported Models
Currently enables **only Anthropic models**:
- Claude Sonnet 4
- Claude 3.7 Sonnet
- Claude Opus 4
*Note: OpenAI and other operator models are disabled while in development.*
## Capabilities
The operator can:
- **Computer Control**: Take screenshots, click, type, navigate desktop
- **File Operations**: Create, edit, and manage files
- **Terminal Access**: Execute bash commands and scripts
- **Web Browsing**: Navigate websites and documents, and extract information
## Architecture
- **Environments**: Operator Computer and Browser environments
- **Models**: Enable Vision Language Models (VLM) to operate computer
- **Execution**: Containerize computer environment for security and isolation

View File

@@ -6,13 +6,23 @@ from typing import Callable, List, Optional
from khoj.database.adapters import AgentAdapters, ConversationAdapters
from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.processor.conversation.utils import (
OperatorRun,
construct_chat_history,
construct_chat_history_for_operator,
)
from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_agent_anthropic import AnthropicOperatorAgent
from khoj.processor.operator.operator_agent_base import OperatorAgent
from khoj.processor.operator.operator_agent_binary import BinaryOperatorAgent
from khoj.processor.operator.operator_agent_openai import OpenAIOperatorAgent
from khoj.processor.operator.operator_environment_base import EnvStepResult
from khoj.processor.operator.operator_environment_base import (
Environment,
EnvironmentType,
EnvStepResult,
)
from khoj.processor.operator.operator_environment_browser import BrowserEnvironment
from khoj.processor.operator.operator_environment_computer import ComputerEnvironment
from khoj.routers.helpers import ChatEvent
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import LocationData
@@ -20,12 +30,14 @@ from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
# --- Browser Operator Function ---
async def operate_browser(
# --- Main Operator Entrypoint ---
async def operate_environment(
query: str,
user: KhojUser,
conversation_log: dict,
location_data: LocationData,
previous_trajectory: Optional[OperatorRun] = None,
environment_type: EnvironmentType = EnvironmentType.COMPUTER,
send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None, # TODO: Handle query images
agent: Agent = None,
@@ -33,8 +45,11 @@ async def operate_browser(
cancellation_event: Optional[asyncio.Event] = None,
tracer: dict = {},
):
response, summary_message, user_input_message = None, None, None
environment: Optional[BrowserEnvironment] = None
response, user_input_message = None, None
# Only use partial previous trajectories to continue existing task
if previous_trajectory and previous_trajectory.response:
previous_trajectory = None
# Get the agent chat model
agent_chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) if agent else None
@@ -42,16 +57,40 @@ async def operate_browser(
if not reasoning_model or not reasoning_model.vision_enabled:
reasoning_model = await ConversationAdapters.aget_vision_enabled_config()
if not reasoning_model:
raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate browser.")
raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate environment.")
# Create conversation history from conversation log
chat_history = construct_chat_history_for_operator(conversation_log)
# Initialize Agent
max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 40))
max_context = await ConversationAdapters.aget_max_context_size(reasoning_model, user) or 20000
max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 100))
operator_agent: OperatorAgent
if is_operator_model(reasoning_model.name) == ChatModel.ModelType.OPENAI:
operator_agent = OpenAIOperatorAgent(query, reasoning_model, max_iterations, tracer)
elif is_operator_model(reasoning_model.name) == ChatModel.ModelType.ANTHROPIC:
operator_agent = AnthropicOperatorAgent(query, reasoning_model, max_iterations, tracer)
else:
if is_operator_model(reasoning_model.name) == ChatModel.ModelType.ANTHROPIC:
operator_agent = AnthropicOperatorAgent(
query,
reasoning_model,
environment_type,
max_iterations,
max_context,
chat_history,
previous_trajectory,
tracer,
)
# TODO: Remove once OpenAI Operator Agent is useful
elif is_operator_model(reasoning_model.name) == ChatModel.ModelType.OPENAI and False:
operator_agent = OpenAIOperatorAgent(
query,
reasoning_model,
environment_type,
max_iterations,
max_context,
chat_history,
previous_trajectory,
tracer,
)
# TODO: Remove once Binary Operator Agent is useful
elif False:
grounding_model_name = "ui-tars-1.5"
grounding_model = await ConversationAdapters.aget_chat_model_by_name(grounding_model_name)
if (
@@ -59,41 +98,62 @@ async def operate_browser(
or not grounding_model.vision_enabled
or not grounding_model.model_type == ChatModel.ModelType.OPENAI
):
raise ValueError("No supported visual grounding model for binary operator agent found.")
operator_agent = BinaryOperatorAgent(query, reasoning_model, grounding_model, max_iterations, tracer)
raise ValueError("Binary operator agent needs ui-tars-1.5 served over an OpenAI compatible API.")
operator_agent = BinaryOperatorAgent(
query,
reasoning_model,
grounding_model,
environment_type,
max_iterations,
max_context,
chat_history,
previous_trajectory,
tracer,
)
else:
raise ValueError(
f"Unsupported operator model: {reasoning_model.name}. "
"Please use a supported operator model. Only Anthropic models are currently supported."
)
# Initialize Environment
if send_status_func:
async for event in send_status_func(f"**Launching Browser**"):
async for event in send_status_func(f"**Launching {environment_type.value}**"):
yield {ChatEvent.STATUS: event}
environment = BrowserEnvironment()
if environment_type == EnvironmentType.BROWSER:
environment: Environment = BrowserEnvironment()
else:
environment = ComputerEnvironment(provider="docker")
await environment.start(width=1024, height=768)
# Start Operator Loop
try:
summarize_prompt = f"Use the results of our research to provide a comprehensive, self-contained answer for the target query:\n{query}."
task_completed = False
iterations = 0
operator_run = OperatorRun(query=query, trajectory=operator_agent.messages, response=response)
yield operator_run
with timer(f"Operating browser with {reasoning_model.model_type} {reasoning_model.name}", logger):
with timer(
f"Operating {environment_type.value} with {reasoning_model.model_type} {reasoning_model.name}", logger
):
while iterations < max_iterations and not task_completed:
if cancellation_event and cancellation_event.is_set():
logger.debug(f"Browser operator cancelled by client disconnect")
logger.debug(f"{environment_type.value} operator cancelled by client disconnect")
break
iterations += 1
# 1. Get current environment state
browser_state = await environment.get_state()
env_state = await environment.get_state()
# 2. Agent decides action(s)
agent_result = await operator_agent.act(browser_state)
agent_result = await operator_agent.act(env_state)
# 3. Execute actions in the environment
env_steps: List[EnvStepResult] = []
for action in agent_result.actions:
if cancellation_event and cancellation_event.is_set():
logger.debug(f"Browser operator cancelled by client disconnect")
logger.debug(f"{environment_type.value} operator cancelled by client disconnect")
break
# Handle request for user action and break the loop
if isinstance(action, RequestUserAction):
@@ -106,12 +166,14 @@ async def operate_browser(
env_steps.append(env_step)
# Render status update
latest_screenshot = f"data:image/webp;base64,{env_steps[-1].screenshot_base64 if env_steps else browser_state.screenshot}"
latest_screenshot = (
f"data:image/webp;base64,{env_steps[-1].screenshot_base64 if env_steps else env_state.screenshot}"
)
render_payload = agent_result.rendered_response
render_payload["image"] = latest_screenshot
render_content = f"**Action**: {json.dumps(render_payload)}"
if send_status_func:
async for event in send_status_func(f"**Operating Browser**:\n{render_content}"):
async for event in send_status_func(f"**Operating {environment_type.value}**:\n{render_content}"):
yield {ChatEvent.STATUS: event}
# Check if termination conditions are met
@@ -123,31 +185,33 @@ async def operate_browser(
if task_completed or trigger_iteration_limit:
# Summarize results of operator run on last iteration
operator_agent.add_action_results(env_steps, agent_result)
summary_message = await operator_agent.summarize(summarize_prompt, browser_state)
summary_message = await operator_agent.summarize(env_state)
logger.info(f"Task completed: {task_completed}, Iteration limit: {trigger_iteration_limit}")
break
# 4. Update agent on the results of its action on the environment
operator_agent.add_action_results(env_steps, agent_result)
operator_run.trajectory = operator_agent.messages
# Determine final response message
if user_input_message:
response = user_input_message
operator_run.response = user_input_message
elif task_completed:
response = summary_message
operator_run.response = summary_message
elif cancellation_event and cancellation_event.is_set():
operator_run.response = None
else: # Hit iteration limit
response = f"Operator hit iteration limit ({max_iterations}). If the results seem incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{summary_message}"
operator_run.response = f"Operator hit iteration limit ({max_iterations}). If the results seem incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{summary_message}"
finally:
if environment and not user_input_message: # Don't close browser if user input required
if environment and not user_input_message: # Don't close environment if user input required
await environment.close()
if operator_agent:
operator_agent.reset()
yield {
"query": query,
"result": user_input_message or response,
"webpages": [{"link": url, "snippet": ""} for url in environment.visited_urls],
}
if environment_type == EnvironmentType.BROWSER and hasattr(environment, "visited_urls"):
operator_run.webpages = [{"link": url, "snippet": ""} for url in environment.visited_urls]
yield operator_run
def is_operator_model(model: str) -> ChatModel.ModelType | None:

View File

@@ -1,5 +1,6 @@
import json
import logging
from textwrap import dedent
from openai import AzureOpenAI, OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionMessage
@@ -8,7 +9,7 @@ from khoj.database.models import ChatModel
from khoj.processor.conversation.utils import construct_structured_message
from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_agent_base import AgentActResult
from khoj.processor.operator.operator_environment_base import EnvState
from khoj.processor.operator.operator_environment_base import EnvironmentType, EnvState
from khoj.utils.helpers import get_chat_usage_metrics
logger = logging.getLogger(__name__)
@@ -18,6 +19,7 @@ class GroundingAgent:
def __init__(
self,
model: ChatModel,
environment_type: EnvironmentType,
client: OpenAI | AzureOpenAI,
max_iterations: int,
tracer: dict = None,
@@ -26,9 +28,211 @@ class GroundingAgent:
self.client = client
self.max_iterations = max_iterations
self.tracer = tracer
self.environment_type = environment_type
self.action_tools = self.get_tools(self.environment_type)
# Define tools for the grounding LLM (OpenAI format)
self.action_tools = [
async def act(self, instruction: str, current_state: EnvState) -> tuple[str, list[OperatorAction]]:
    """Call the grounding LLM to get the next action based on the current state and instruction.

    Args:
        instruction: Natural-language action instruction to ground into concrete UI actions.
        current_state: Current environment state (includes the latest screenshot).

    Returns:
        Tuple of (rendered response string for display, list of parsed OperatorAction objects).
        On any API/parsing failure, returns an error string and an empty action list.
    """
    # Format the message for the API call
    messages_for_api = self._format_message_for_api(instruction, current_state)
    try:
        # NOTE(review): self.client is annotated as OpenAI | AzureOpenAI but the call is
        # awaited — presumably an async client is supplied at runtime; confirm.
        grounding_response: ChatCompletion = await self.client.chat.completions.create(
            messages=messages_for_api,
            model=self.model.name,
            tools=self.action_tools,
            tool_choice="required",  # force a tool call so a parseable action is always produced
            temperature=0.0,  # Grounding should be precise
            max_completion_tokens=1000,  # Allow for thoughts + actions
        )
        if not isinstance(grounding_response, ChatCompletion):
            raise ValueError("Grounding LLM response is not of type ChatCompletion.")
        logger.debug(f"Grounding LLM response: {grounding_response.model_dump_json()}")

        # Parse tool calls
        grounding_message = grounding_response.choices[0].message
        rendered_response, actions = self._parse_action(grounding_message, instruction, current_state)

        # Update usage by grounding model
        self.tracer["usage"] = get_chat_usage_metrics(
            self.model.name,
            input_tokens=grounding_response.usage.prompt_tokens,
            output_tokens=grounding_response.usage.completion_tokens,
            usage=self.tracer.get("usage"),
        )
    except Exception as e:
        # Fail soft: surface the error in the rendered response instead of propagating
        logger.error(f"Error calling Grounding LLM: {e}")
        rendered_response = f"**Error**: Error contacting Grounding LLM: {e}"
        actions = []

    return rendered_response, actions
def _format_message_for_api(self, instruction: str, current_state: EnvState) -> List:
    """Build the single-turn user message payload for the grounding LLM call.

    Only the latest instruction plus the current screenshot are sent. No
    conversation history is included, since grounding depends solely on the
    *current* environment state and the natural-language action.
    """
    prompt = self.get_instruction(instruction, self.environment_type)
    screenshot_data_uri = f"data:image/webp;base64,{current_state.screenshot}"
    content = construct_structured_message(prompt, [screenshot_data_uri], self.model.name, vision_enabled=True)
    return [{"role": "user", "content": content}]
def _parse_action(
    self, grounding_message: ChatCompletionMessage, instruction: str, current_state: EnvState
) -> tuple[str, list[OperatorAction]]:
    """Parse the tool calls from the grounding LLM response and convert them to action objects.

    Returns (rendered_response, actions): a markdown summary of the parsed tool
    calls and the corresponding OperatorAction objects. Malformed or unhandled
    tool calls are logged and noted in the summary but do not abort parsing of
    the remaining tool calls.
    """
    actions: List[OperatorAction] = []
    action_results: List[dict] = []
    if grounding_message.tool_calls:
        rendered_parts = []
        for tool_call in grounding_message.tool_calls:
            function_name = tool_call.function.name
            try:
                arguments = json.loads(tool_call.function.arguments)
                action_to_run: Optional[OperatorAction] = None
                action_render_str = f"**Action ({function_name})**: {tool_call.function.arguments}"
                # Map each tool name emitted by the model to its OperatorAction type
                if function_name == "click":
                    action_to_run = ClickAction(**arguments)
                elif function_name == "left_double":
                    action_to_run = DoubleClickAction(**arguments)
                elif function_name == "right_single":
                    action_to_run = ClickAction(button="right", **arguments)
                elif function_name == "type":
                    content = arguments.get("content")
                    action_to_run = TypeAction(text=content)
                elif function_name == "scroll":
                    direction = arguments.get("direction", "down")
                    amount = 3  # fixed scroll step; the model only chooses the direction
                    action_to_run = ScrollAction(scroll_direction=direction, scroll_amount=amount, **arguments)
                elif function_name == "hotkey":
                    action_to_run = KeypressAction(**arguments)
                elif function_name == "goto":
                    action_to_run = GotoAction(**arguments)
                elif function_name == "back":
                    action_to_run = BackAction(**arguments)
                elif function_name == "wait":
                    action_to_run = WaitAction(**arguments)
                elif function_name == "screenshot":
                    action_to_run = ScreenshotAction(**arguments)
                elif function_name == "drag":
                    # Need to convert list of dicts to list of Point objects
                    path_dicts = arguments.get("path", [])
                    path_points = [Point(**p) for p in path_dicts]
                    if path_points:
                        action_to_run = DragAction(path=path_points)
                    else:
                        logger.warning(f"Drag action called with empty path: {arguments}")
                        action_render_str += " [Skipped - empty path]"
                elif function_name == "finished":
                    # Task-completion signal: no environment action to execute
                    action_to_run = None
                else:
                    logger.warning(f"Grounding LLM called unhandled tool: {function_name}")
                    action_render_str += " [Unhandled]"
                if action_to_run:
                    actions.append(action_to_run)
                    action_results.append(
                        {
                            "type": "tool_result",
                            "tool_call_id": tool_call.id,
                            "content": None,  # Updated after environment step
                        }
                    )
                rendered_parts.append(action_render_str)
            except (json.JSONDecodeError, TypeError, ValueError) as arg_err:
                # Bad arguments for one tool call: record the error and continue with the rest
                logger.error(
                    f"Error parsing arguments for tool {function_name}: {arg_err} - Args: {tool_call.function.arguments}"
                )
                rendered_parts.append(f"**Error**: Failed to parse arguments for {function_name}")
        rendered_response = "\n- ".join(rendered_parts)
    else:
        # Grounding LLM responded but didn't call a tool
        logger.warning("Grounding LLM did not produce a tool call.")
        rendered_response = f"{grounding_message.content or 'No action required.'}"
    # Render the response
    return rendered_response, actions
def get_instruction(self, instruction: str, environment_type: EnvironmentType) -> str:
    """
    Get the instruction for the agent based on the environment type.

    Builds the UI-TARS style grounding prompt (thought + action format) with the
    user instruction interpolated, prefixed for either the computer or browser
    environment. Raises ValueError for any other environment type.

    NOTE: The prompt bodies below are runtime strings sent to the model — their
    exact text (including escapes and special tokens) must not be edited casually.
    """
    UITARS_COMPUTER_PREFIX_PROMPT = """
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
"""
    UITARS_BROWSER_PREFIX_PROMPT = """
You are a GUI agent. You are given a task and a screenshot of the web browser tab you operate. You need to decide the next action to complete the task.
You control a single tab in a Chromium browser. You cannot access the OS, filesystem or the application window.
Always use the `goto` function to navigate to a specific URL. Ctrl+t, Ctrl+w, Ctrl+q, Ctrl+Shift+T, Ctrl+Shift+W are not allowed.
"""
    # f-string: interpolates `instruction` at call time
    UITARS_USR_COMPUTER_PROMPT_THOUGHT = f"""
Try fulfill the user instruction to the best of your ability, especially when the instruction is given multiple times. Do not ignore the instruction.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait(duration='time') # Sleep for specified time. Default is 1s and take a screenshot to check for any changes.
## Note
- Use English in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
## User Instruction
{instruction}
"""
    # f-string: browser variant additionally exposes goto()/back() navigation helpers
    UITARS_USR_BROWSER_PROMPT_THOUGHT = f"""
Try fulfill the user instruction to the best of your ability, especially when the instruction is given multiple times. Do not ignore the instruction.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait(duration='time') # Sleep for specified time. Default is 1s and take a screenshot to check for any changes.
goto(url='xxx') # Always use this to navigate to a specific URL. Use escape characters \\', \\", and \\n in url part to ensure we can parse the url in normal python string format.
back() # Use this to go back to the previous page.
## Note
- Use English in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
## User Instruction
{instruction}
"""
    if environment_type == EnvironmentType.BROWSER:
        return dedent(UITARS_BROWSER_PREFIX_PROMPT + UITARS_USR_BROWSER_PROMPT_THOUGHT).lstrip()
    elif environment_type == EnvironmentType.COMPUTER:
        return dedent(UITARS_COMPUTER_PREFIX_PROMPT + UITARS_USR_COMPUTER_PROMPT_THOUGHT).lstrip()
    else:
        raise ValueError(f"Expected environment type: Computer or Browser. Got {environment_type}.")
def get_tools(self, environment_type: EnvironmentType) -> list[dict]:
"""Get tools for the grounding LLM, in OpenAI API tool format"""
tools = [
{
"type": "function",
"function": {
@@ -163,182 +367,32 @@ class GroundingAgent:
},
},
},
{
"type": "function",
"function": {
"name": "goto",
"description": "Navigate to a specific URL.",
"parameters": {
"type": "object",
"properties": {"url": {"type": "string", "description": "Fully qualified URL"}},
"required": ["url"],
]
if environment_type == EnvironmentType.BROWSER:
tools += [
{
"type": "function",
"function": {
"name": "goto",
"description": "Navigate to a specific URL.",
"parameters": {
"type": "object",
"properties": {"url": {"type": "string", "description": "Fully qualified URL"}},
"required": ["url"],
},
},
},
},
{
"type": "function",
"function": {
"name": "back",
"description": "navigate back to the previous page.",
"parameters": {"type": "object", "properties": {}},
{
"type": "function",
"function": {
"name": "back",
"description": "navigate back to the previous page.",
"parameters": {"type": "object", "properties": {}},
},
},
},
]
]
async def act(self, instruction: str, current_state: EnvState) -> tuple[str, list[OperatorAction]]:
    """Call the grounding LLM to get the next action based on the current state and instruction.

    Args:
        instruction: Natural-language description of the next step to ground into concrete actions.
        current_state: Current environment state; its screenshot is sent to the vision model.

    Returns:
        Tuple of (rendered_response, actions): a markdown summary of the model's
        chosen action(s) (or an error message) and the parsed OperatorAction list.
        On any API failure the error is captured in rendered_response and actions is empty.
    """
    # Format the message for the API call
    messages_for_api = self._format_message_for_api(instruction, current_state)
    try:
        grounding_response: ChatCompletion = await self.client.chat.completions.create(
            messages=messages_for_api,
            model=self.model.name,
            tools=self.action_tools,
            tool_choice="required",  # force the model to emit a tool call every turn
            temperature=0.0,  # Grounding should be precise
            max_completion_tokens=1000,  # Allow for thoughts + actions
        )
        if not isinstance(grounding_response, ChatCompletion):
            raise ValueError("Grounding LLM response is not of type ChatCompletion.")
        logger.debug(f"Grounding LLM response: {grounding_response.model_dump_json()}")
        # Parse tool calls
        grounding_message = grounding_response.choices[0].message
        rendered_response, actions = self._parse_action(grounding_message, instruction, current_state)
        # Update usage by grounding model
        self.tracer["usage"] = get_chat_usage_metrics(
            self.model.name,
            input_tokens=grounding_response.usage.prompt_tokens,
            output_tokens=grounding_response.usage.completion_tokens,
            usage=self.tracer.get("usage"),
        )
    except Exception as e:
        # Broad catch is deliberate: any API or parsing failure is surfaced to the
        # caller as a rendered error instead of crashing the operator loop.
        logger.error(f"Error calling Grounding LLM: {e}")
        rendered_response = f"**Error**: Error contacting Grounding LLM: {e}"
        actions = []
    return rendered_response, actions
def _format_message_for_api(self, instruction: str, current_state: EnvState) -> List:
    """Format the message for the API call.

    Builds a single user message containing the browser-grounding prompt (UI-TARS
    thought/action format with the instruction interpolated) plus the current
    screenshot as a data URI. No conversation history is included.

    NOTE: The prompt body is a runtime string sent to the model — its exact text
    must not be edited casually.
    """
    grounding_user_prompt = f"""
You are a GUI agent. You are given a task and a screenshot of the web browser tab you operate. You need to decide the next action to complete the task.
You control a single tab in a Chromium browser. You cannot access the OS, filesystem or the application window.
Always use the `goto` function to navigate to a specific URL. Ctrl+t, Ctrl+w, Ctrl+q, Ctrl+Shift+T, Ctrl+Shift+W are not allowed.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait(duration='time') # Sleep for specified time. Default is 1s and take a screenshot to check for any changes.
goto(url='xxx') # Always use this to navigate to a specific URL. Use escape characters \\', \\", and \\n in url part to ensure we can parse the url in normal python string format.
back() # Use this to go back to the previous page.
## Note
- Use English in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
## User Instruction
{instruction}
""".lstrip()
    # Construct grounding LLM input (using only the latest user prompt + image)
    # We don't pass the full history here, as grounding depends on the *current* state + NL action
    screenshots = [f"data:image/webp;base64,{current_state.screenshot}"]
    grounding_messages_content = construct_structured_message(
        grounding_user_prompt, screenshots, self.model.name, vision_enabled=True
    )
    return [{"role": "user", "content": grounding_messages_content}]
def _parse_action(
    self, grounding_message: ChatCompletionMessage, instruction: str, current_state: EnvState
) -> tuple[str, list[OperatorAction]]:
    """Parse the tool calls from the grounding LLM response and convert them to action objects.

    Returns (rendered_response, actions): a markdown summary of the parsed tool
    calls and the corresponding OperatorAction objects. Malformed or unhandled
    tool calls are logged and noted in the summary but do not abort parsing of
    the remaining tool calls.
    """
    actions: List[OperatorAction] = []
    action_results: List[dict] = []
    if grounding_message.tool_calls:
        rendered_parts = []
        for tool_call in grounding_message.tool_calls:
            function_name = tool_call.function.name
            try:
                arguments = json.loads(tool_call.function.arguments)
                action_to_run: Optional[OperatorAction] = None
                action_render_str = f"**Action ({function_name})**: {tool_call.function.arguments}"
                # Map each tool name emitted by the model to its OperatorAction type
                if function_name == "click":
                    action_to_run = ClickAction(**arguments)
                elif function_name == "left_double":
                    action_to_run = DoubleClickAction(**arguments)
                elif function_name == "right_single":
                    action_to_run = ClickAction(button="right", **arguments)
                elif function_name == "type":
                    content = arguments.get("content")
                    action_to_run = TypeAction(text=content)
                elif function_name == "scroll":
                    direction = arguments.get("direction", "down")
                    amount = 3  # fixed scroll step; the model only chooses the direction
                    action_to_run = ScrollAction(scroll_direction=direction, scroll_amount=amount, **arguments)
                elif function_name == "hotkey":
                    action_to_run = KeypressAction(**arguments)
                elif function_name == "goto":
                    action_to_run = GotoAction(**arguments)
                elif function_name == "back":
                    action_to_run = BackAction(**arguments)
                elif function_name == "wait":
                    action_to_run = WaitAction(**arguments)
                elif function_name == "screenshot":
                    action_to_run = ScreenshotAction(**arguments)
                elif function_name == "drag":
                    # Need to convert list of dicts to list of Point objects
                    path_dicts = arguments.get("path", [])
                    path_points = [Point(**p) for p in path_dicts]
                    if path_points:
                        action_to_run = DragAction(path=path_points)
                    else:
                        logger.warning(f"Drag action called with empty path: {arguments}")
                        action_render_str += " [Skipped - empty path]"
                elif function_name == "finished":
                    # Task-completion signal: no environment action to execute
                    action_to_run = None
                else:
                    logger.warning(f"Grounding LLM called unhandled tool: {function_name}")
                    action_render_str += " [Unhandled]"
                if action_to_run:
                    actions.append(action_to_run)
                    action_results.append(
                        {
                            "type": "tool_result",
                            "tool_call_id": tool_call.id,
                            "content": None,  # Updated after environment step
                        }
                    )
                rendered_parts.append(action_render_str)
            except (json.JSONDecodeError, TypeError, ValueError) as arg_err:
                # Bad arguments for one tool call: record the error and continue with the rest
                logger.error(
                    f"Error parsing arguments for tool {function_name}: {arg_err} - Args: {tool_call.function.arguments}"
                )
                rendered_parts.append(f"**Error**: Failed to parse arguments for {function_name}")
        rendered_response = "\n- ".join(rendered_parts)
    else:
        # Grounding LLM responded but didn't call a tool
        logger.warning("Grounding LLM did not produce a tool call.")
        rendered_response = f"{grounding_message.content or 'No action required.'}"
    # Render the response
    return rendered_response, actions
return tools
def reset(self):
"""Reset the agent state."""

View File

@@ -10,6 +10,7 @@ import logging
import math
import re
from io import BytesIO
from textwrap import dedent
from typing import Any, List
import numpy as np
@@ -18,7 +19,7 @@ from openai.types.chat import ChatCompletion
from PIL import Image
from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_environment_base import EnvState
from khoj.processor.operator.operator_environment_base import EnvironmentType, EnvState
from khoj.utils.helpers import get_chat_usage_metrics
logger = logging.getLogger(__name__)
@@ -35,29 +36,8 @@ class GroundingAgentUitars:
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
UITARS_USR_PROMPT_THOUGHT = """
You are a GUI agent. You are given a task and a screenshot of the web browser tab you operate. You need to perform the next action to complete the task.
You control a single tab in a Chromium browser. You cannot access the OS, filesystem, the application window or the addressbar.
Try fulfill the user instruction to the best of your ability, especially when the instruction is given multiple times. Do not ignore the instruction.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
{action_space}
## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
## User Instruction
{instruction}
"""
UITARS_NORMAL_ACTION_SPACE = """
UITARS_NORMAL_ACTION_SPACE = dedent(
"""
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
@@ -67,14 +47,15 @@ class GroundingAgentUitars:
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
""".lstrip()
"""
).lstrip()
def __init__(
self,
model_name: str,
environment_type: EnvironmentType,
client: AsyncOpenAI | AsyncAzureOpenAI,
max_iterations=50,
environment_type: Literal["computer", "web"] = "computer",
runtime_conf: dict = {
"infer_mode": "qwen25vl_normal",
"prompt_style": "qwen25vl_normal",
@@ -94,7 +75,7 @@ class GroundingAgentUitars:
self.model_name = model_name
self.client = client
self.tracer = tracer
self.environment_type = environment_type
self.environment = environment_type
self.max_iterations = max_iterations
self.runtime_conf = runtime_conf
@@ -116,7 +97,7 @@ class GroundingAgentUitars:
self.history_images: list[bytes] = []
self.history_responses: list[str] = []
self.prompt_template = self.UITARS_USR_PROMPT_THOUGHT
self.prompt_template = self.get_instruction(self.environment)
self.prompt_action_space = self.UITARS_NORMAL_ACTION_SPACE
if "history_n" in self.runtime_conf:
@@ -126,11 +107,11 @@ class GroundingAgentUitars:
self.cur_callusr_count = 0
async def act(self, instruction: str, env_state: EnvState) -> tuple[str, list[OperatorAction]]:
async def act(self, instruction: str, current_state: EnvState) -> tuple[str, list[OperatorAction]]:
"""
Suggest the next action(s) based on the instruction and current environment.
"""
messages = self._format_messages_for_api(instruction, env_state)
messages = self._format_messages_for_api(instruction, current_state)
recent_screenshot = Image.open(BytesIO(self.history_images[-1]))
origin_resized_height = recent_screenshot.height
@@ -145,9 +126,11 @@ class GroundingAgentUitars:
try_times = 3
while not parsed_responses:
if try_times <= 0:
print(f"Reach max retry times to fetch response from client, as error flag.")
logger.warning(f"Reach max retry times to fetch response from client, as error flag.")
return "client error\nFAIL", []
try:
message_content = "\n".join([msg["content"][0].get("text") or "[image]" for msg in messages])
logger.debug(f"User message content: {message_content}")
response: ChatCompletion = await self.client.chat.completions.create(
model="ui-tars",
messages=messages,
@@ -228,20 +211,9 @@ class GroundingAgentUitars:
self.actions.append(actions)
return f"{prediction}\nFAIL", []
if self.environment_type == "web":
actions.extend(
self.parsing_response_to_action(parsed_response, obs_image_height, obs_image_width, self.input_swap)
)
else:
pass
# TODO: Add PyautoguiAction when enable computer environment
# actions.append(
# PyautoguiAction(code=
# self.parsing_response_to_pyautogui_code(
# parsed_response, obs_image_height, obs_image_width, self.input_swap
# )
# )
# )
actions.extend(
self.parsing_response_to_action(parsed_response, obs_image_height, obs_image_width, self.input_swap)
)
self.actions.append(actions)
@@ -252,13 +224,52 @@ class GroundingAgentUitars:
return prediction or "", actions
def _format_messages_for_api(self, instruction: str, env_state: EnvState):
def get_instruction(self, environment_type: EnvironmentType) -> str:
    """
    Get the instruction for the agent based on the environment type.

    Returns a prompt *template* string: the {action_space}, {language} and
    {instruction} placeholders are left unformatted here and filled in later
    via str.format. Raises ValueError for unsupported environment types.

    NOTE: The prompt bodies below are runtime strings sent to the model — their
    exact text must not be edited casually.
    """
    UITARS_COMPUTER_PREFIX_PROMPT = """
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
"""
    UITARS_BROWSER_PREFIX_PROMPT = """
You are a GUI agent. You are given a task and a screenshot of the web browser tab you operate. You need to perform the next action to complete the task.
You control a single tab in a Chromium browser. You cannot access the OS, filesystem, the application window or the addressbar.
"""
    UITARS_USR_PROMPT_THOUGHT = """
Try fulfill the user instruction to the best of your ability, especially when the instruction is given multiple times. Do not ignore the instruction.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
{action_space}
## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
## User Instruction
{instruction}
"""
    if environment_type == EnvironmentType.BROWSER:
        return dedent(UITARS_BROWSER_PREFIX_PROMPT + UITARS_USR_PROMPT_THOUGHT).lstrip()
    elif environment_type == EnvironmentType.COMPUTER:
        return dedent(UITARS_COMPUTER_PREFIX_PROMPT + UITARS_USR_PROMPT_THOUGHT).lstrip()
    else:
        raise ValueError(f"Unsupported environment type: {environment_type}")
def _format_messages_for_api(self, instruction: str, current_state: EnvState):
assert len(self.observations) == len(self.actions) and len(self.actions) == len(
self.thoughts
), "The number of observations and actions should be the same."
self.history_images.append(base64.b64decode(env_state.screenshot))
self.observations.append({"screenshot": env_state.screenshot, "accessibility_tree": None})
self.history_images.append(base64.b64decode(current_state.screenshot))
self.observations.append({"screenshot": current_state.screenshot, "accessibility_tree": None})
user_prompt = self.prompt_template.format(
instruction=instruction, action_space=self.prompt_action_space, language=self.language

View File

@@ -125,6 +125,49 @@ class NoopAction(BaseAction):
type: Literal["noop"] = "noop"
# --- Text Editor Actions ---
class TextEditorViewAction(BaseAction):
    """View contents of a file."""

    type: Literal["text_editor_view"] = "text_editor_view"
    path: str  # path of the file to view
    view_range: Optional[List[int]] = None  # [start_line, end_line]
class TextEditorCreateAction(BaseAction):
    """Create a new file with specified contents."""

    type: Literal["text_editor_create"] = "text_editor_create"
    path: str  # path of the file to create
    file_text: str  # full contents to write into the new file
class TextEditorStrReplaceAction(BaseAction):
    """Execute an exact string match replacement on a file."""

    type: Literal["text_editor_str_replace"] = "text_editor_str_replace"
    path: str  # path of the file to edit
    old_str: str  # exact text to find in the file
    new_str: str  # replacement text
class TextEditorInsertAction(BaseAction):
    """Insert new text after a specified line number."""

    type: Literal["text_editor_insert"] = "text_editor_insert"
    path: str  # path of the file to edit
    insert_line: int  # line number after which new_str is inserted
    new_str: str  # text to insert
class TerminalAction(BaseAction):
    """Run a command in the environment's terminal.

    (Original docstring was copy-pasted from TextEditorInsertAction; corrected here.)
    """

    type: Literal["terminal"] = "terminal"
    command: str  # shell command to execute
    restart: bool = False  # presumably restarts the terminal session before running — TODO confirm against environment handler
OperatorAction = Union[
ClickAction,
DoubleClickAction,
@@ -146,4 +189,9 @@ OperatorAction = Union[
BackAction,
RequestUserAction,
NoopAction,
TextEditorViewAction,
TextEditorCreateAction,
TextEditorStrReplaceAction,
TextEditorInsertAction,
TerminalAction,
]

View File

@@ -3,18 +3,21 @@ import json
import logging
from copy import deepcopy
from datetime import datetime
from typing import List, Optional, cast
from textwrap import dedent
from typing import List, Literal, Optional, cast
from anthropic.types.beta import BetaContentBlock
from anthropic.types.beta import BetaContentBlock, BetaTextBlock, BetaToolUseBlock
from khoj.database.models import ChatModel
from khoj.processor.conversation.anthropic.utils import is_reasoning_model
from khoj.processor.conversation.utils import AgentMessage
from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_agent_base import (
AgentActResult,
AgentMessage,
OperatorAgent,
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
from khoj.processor.operator.operator_environment_base import (
EnvironmentType,
EnvState,
EnvStepResult,
)
from khoj.processor.operator.operator_environment_base import EnvState, EnvStepResult
from khoj.utils.helpers import get_anthropic_async_client, is_none_or_empty
logger = logging.getLogger(__name__)
@@ -23,81 +26,34 @@ logger = logging.getLogger(__name__)
# --- Anthropic Operator Agent ---
class AnthropicOperatorAgent(OperatorAgent):
async def act(self, current_state: EnvState) -> AgentActResult:
client = get_anthropic_async_client(
self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url
)
betas = self.model_default_headers()
temperature = 1.0
actions: List[OperatorAction] = []
action_results: List[dict] = []
self._commit_trace() # Commit trace before next action
system_prompt = f"""<SYSTEM_CAPABILITY>
* You are Khoj, a smart web browser operating assistant. You help the users accomplish tasks using a web browser.
* You operate a Chromium browser using Playwright via the 'computer' tool.
* You cannot access the OS or filesystem.
* You can interact with the web browser to perform tasks like clicking, typing, scrolling, and more.
* You can use the additional back() and goto() helper functions to ease navigating the browser. If you see nothing, try goto duckduckgo.com
* When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
* The current URL is {current_state.url}.
</SYSTEM_CAPABILITY>
system_prompt = self.get_instructions(self.environment_type, current_state)
tools = self.get_tools(self.environment_type, current_state)
<IMPORTANT>
* You are allowed upto {self.max_iterations} iterations to complete the task.
* Do not loop on wait, screenshot for too many turns without taking any action.
* After initialization if the browser is blank, enter a website URL using the goto() function instead of waiting
</IMPORTANT>
"""
if is_none_or_empty(self.messages):
self.messages = [AgentMessage(role="user", content=self.query)]
tools = [
{
"type": self.model_default_tool("computer"),
"name": "computer",
"display_width_px": 1024,
"display_height_px": 768,
}, # TODO: Get from env
{
"name": "back",
"description": "Go back to the previous page.",
"input_schema": {"type": "object", "properties": {}},
},
{
"name": "goto",
"description": "Go to a specific URL.",
"input_schema": {
"type": "object",
"properties": {"url": {"type": "string", "description": "Fully qualified URL to navigate to."}},
"required": ["url"],
},
},
]
# Trigger trajectory compression if exceed size limit
if len(self.messages) > self.message_limit:
logger.debug("Compacting operator trajectory.")
await self._compress()
thinking: dict[str, str | int] = {"type": "disabled"}
if is_reasoning_model(self.vision_model.name):
thinking = {"type": "enabled", "budget_tokens": 1024}
messages_for_api = self._format_message_for_api(self.messages)
response = await client.beta.messages.create(
messages=messages_for_api,
model=self.vision_model.name,
system=system_prompt,
response_content = await self._call_model(
messages=self.messages,
model=self.vision_model,
system_prompt=system_prompt,
tools=tools,
betas=betas,
thinking=thinking,
max_tokens=4096, # TODO: Make configurable?
temperature=temperature,
headers=self.model_default_headers(),
)
logger.debug(f"Anthropic response: {response.model_dump_json()}")
self.messages.append(AgentMessage(role="assistant", content=response.content))
rendered_response = self._render_response(response.content, current_state.screenshot)
self.messages.append(AgentMessage(role="assistant", content=response_content))
rendered_response = self._render_response(response_content, current_state.screenshot)
for block in response.content:
# Parse actions from response
for block in response_content:
if block.type == "tool_use":
content = None
is_error = False
@@ -179,6 +135,40 @@ class AnthropicOperatorAgent(OperatorAgent):
logger.warning("Goto tool called without URL.")
elif tool_name == "back":
action_to_run = BackAction()
elif tool_name == self.model_default_tool("terminal")["name"]:
command = tool_input.get("command")
restart = tool_input.get("restart", False)
if command:
action_to_run = TerminalAction(command=command, restart=restart)
elif tool_name == "str_replace_based_edit_tool":
# Handle text editor tool calls
command = tool_input.get("command")
if command == "view":
path = tool_input.get("path")
view_range = tool_input.get("view_range")
if path:
action_to_run = TextEditorViewAction(path=path, view_range=view_range)
elif command == "create":
path = tool_input.get("path")
file_text = tool_input.get("file_text", "")
if path:
action_to_run = TextEditorCreateAction(path=path, file_text=file_text)
elif command == "str_replace":
path = tool_input.get("path")
old_str = tool_input.get("old_str")
new_str = tool_input.get("new_str")
if path and old_str is not None and new_str is not None:
action_to_run = TextEditorStrReplaceAction(path=path, old_str=old_str, new_str=new_str)
elif command == "insert":
path = tool_input.get("path")
insert_line = tool_input.get("insert_line")
new_str = tool_input.get("new_str")
if path and insert_line is not None and new_str is not None:
action_to_run = TextEditorInsertAction(
path=path, insert_line=insert_line, new_str=new_str
)
else:
logger.warning(f"Unsupported text editor command: {command}")
else:
logger.warning(f"Unsupported Anthropic computer action type: {tool_name}")
@@ -200,14 +190,6 @@ class AnthropicOperatorAgent(OperatorAgent):
}
)
self._update_usage(
response.usage.input_tokens,
response.usage.output_tokens,
response.usage.cache_read_input_tokens,
response.usage.cache_creation_input_tokens,
)
self.tracer["temperature"] = temperature
return AgentActResult(
actions=actions,
action_results=action_results,
@@ -240,18 +222,19 @@ class AnthropicOperatorAgent(OperatorAgent):
if env_step.error:
action_result["is_error"] = True
# Append tool results to the message history
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]
# Mark the final tool result as a cache break point
agent_action.action_results[-1]["cache_control"] = {"type": "ephemeral"}
# Remove previous cache controls
for msg in self.messages:
if msg.role == "environment" and isinstance(msg.content, list):
if isinstance(msg.content, list):
for block in msg.content:
if isinstance(block, dict) and "cache_control" in block:
del block["cache_control"]
# Mark the final tool result as a cache break point
agent_action.action_results[-1]["cache_control"] = {"type": "ephemeral"}
# Append tool results to the message history
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]
def _format_message_for_api(self, messages: list[AgentMessage]) -> list[dict]:
"""Format Anthropic response into a single string."""
formatted_messages = []
@@ -270,7 +253,7 @@ class AnthropicOperatorAgent(OperatorAgent):
)
return formatted_messages
def compile_response(self, response_content: list[BetaContentBlock | dict] | str) -> str:
def _compile_response(self, response_content: list[BetaContentBlock | dict] | str) -> str:
"""Compile Anthropic response into a single string."""
if isinstance(response_content, str):
return response_content
@@ -288,7 +271,11 @@ class AnthropicOperatorAgent(OperatorAgent):
compiled_response.append(block.text)
elif block.type == "tool_use":
block_input = {"action": block.name}
if block.name == "computer":
if block.name in (
self.model_default_tool("computer")["name"],
self.model_default_tool("editor")["name"],
self.model_default_tool("terminal")["name"],
):
block_input = block.input # Computer action details are in input dict
elif block.name == "goto":
block_input["url"] = block.input.get("url", "[Missing URL]")
@@ -345,7 +332,34 @@ class AnthropicOperatorAgent(OperatorAgent):
else:
# Handle other actions
render_texts += [f"{action.capitalize()}"]
elif block.name == self.model_default_tool("editor")["name"]:
# Handle text editor actions
command = block.input.get("command")
if command == "view":
path = block.input.get("path")
view_range = block.input.get("view_range")
if path:
render_texts += [f"View file: {path} (lines {view_range})"]
elif command == "create":
path = block.input.get("path")
file_text = block.input.get("file_text", "")
if path:
render_texts += [f"Create file: {path} with content:\n{file_text}"]
elif command == "str_replace":
path = block.input.get("path")
old_str = block.input.get("old_str")
new_str = block.input.get("new_str")
if path and old_str is not None and new_str is not None:
render_texts += [f"File: {path}\n**Find**\n{old_str}\n**Replace**\n{new_str}'"]
elif command == "insert":
path = block.input.get("path")
insert_line = block.input.get("insert_line")
new_str = block.input.get("new_str")
if path and insert_line is not None and new_str is not None:
render_texts += [f"In file: {path} at line {insert_line} insert\n{new_str}"]
render_texts += [f"Edit file: {block.input['path']}"]
elif block.name == self.model_default_tool("terminal")["name"]:
render_texts += [f"Run command:\n{block.input['command']}"]
# If screenshot is not available when screenshot action was requested
if isinstance(block.input, dict) and block.input.get("action") == "screenshot" and not screenshot:
render_texts += ["Failed to get screenshot"]
@@ -365,6 +379,107 @@ class AnthropicOperatorAgent(OperatorAgent):
return render_payload
async def _call_model(
    self,
    messages: list[AgentMessage],
    model: ChatModel,
    system_prompt: str,
    tools: Optional[list[dict]] = None,
    headers: Optional[list[str]] = None,
    temperature: float = 1.0,
    max_tokens: int = 4096,
) -> list[BetaContentBlock]:
    """Call the Anthropic messages API and return the response content blocks.

    Args:
        messages: Agent message history to send to the model.
        model: Chat model config providing model name and API credentials.
        system_prompt: System instructions, sent as a cacheable system block.
        tools: Tool definitions to expose to the model. The last tool is
            marked as an ephemeral cache break point (mutates the caller's
            last tool dict).
        headers: Anthropic beta feature flags to enable for this call.
        temperature: Sampling temperature passed to the API.
        max_tokens: Maximum tokens the model may generate.

    Returns:
        The response content blocks. On API failure a single text block
        containing the error message is returned instead of raising.
    """
    # Avoid mutable default arguments; normalize to fresh lists per call.
    tools = tools if tools is not None else []
    headers = headers if headers is not None else []
    client = get_anthropic_async_client(model.ai_model_api.api_key, model.ai_model_api.api_base_url)
    thinking: dict[str, str | int] = {"type": "disabled"}
    # Cache the system prompt to cut repeated prompt-token costs across turns.
    system = [{"type": "text", "text": system_prompt, "cache_control": {"type": "ephemeral"}}]
    kwargs: dict = {}
    if is_reasoning_model(model.name):
        thinking = {"type": "enabled", "budget_tokens": 1024}
    if headers:
        kwargs["betas"] = headers
    if tools:
        tools[-1]["cache_control"] = {"type": "ephemeral"}  # Mark last tool as cache break point
        kwargs["tools"] = tools
    messages_for_api = self._format_message_for_api(messages)
    try:
        response = await client.beta.messages.create(
            messages=messages_for_api,
            model=model.name,
            system=system,
            thinking=thinking,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs,
        )
        response_content = response.content
    except Exception as e:
        # Degrade gracefully: surface the API error to the agent loop as a text block.
        logger.error(f"Error during Anthropic API call: {e}")
        error_str = e.message if hasattr(e, "message") else str(e)
        response = None
        response_content = [BetaTextBlock(text=f"Communication Error: {error_str}", type="text")]
    if response:
        logger.debug(f"Anthropic response: {response.model_dump_json()}")
        self._update_usage(
            response.usage.input_tokens,
            response.usage.output_tokens,
            response.usage.cache_read_input_tokens,
            response.usage.cache_creation_input_tokens,
        )
    self.tracer["temperature"] = temperature
    return response_content
async def _compress(self):
    """Compress the oldest operator trajectory messages into a summary.

    Summarizes the first `self.compress_length` messages via the vision
    model, then rebuilds `self.messages` as: the original task message,
    the summary exchange, and the uncompressed recent trajectory. This
    bounds the message history within the model's context budget.
    """
    # 1. Prepare messages for compression
    original_messages = list(self.messages)
    messages_to_summarize = self.messages[: self.compress_length]
    # Drop a trailing assistant tool call request: a tool_use block without
    # its matching tool_result is rejected by the API. Content blocks may be
    # SDK objects or plain dicts, so check both shapes safely.
    if (
        messages_to_summarize
        and messages_to_summarize[-1].role == "assistant"
        and any(
            isinstance(block, BetaToolUseBlock)
            or (isinstance(block, dict) and block.get("type") == "tool_use")
            for block in messages_to_summarize[-1].content
        )
    ):
        messages_to_summarize.pop()
    summarize_prompt = f"Summarize your research and computer use till now to help answer my query:\n{self.query}"
    summarize_message = AgentMessage(role="user", content=summarize_prompt)
    system_prompt = dedent(
        """
        You are a computer operator with meticulous communication skills. You can condense your partial computer use traces and research into an appropriately detailed summary.
        When requested summarize your key actions, results and findings until now to achieve the user specified task.
        Your summary should help you remember the key information required to both complete the task and later generate a final report.
        """
    )
    # 2. Get summary of operation trajectory
    try:
        response_content = await self._call_model(
            messages=messages_to_summarize + [summarize_message],
            model=self.vision_model,
            system_prompt=system_prompt,
            max_tokens=8192,
        )
    except Exception as e:
        # Degrade gracefully: record the failure message as the summary content.
        logger.error(f"Error during Anthropic API call: {e}")
        error_str = e.message if hasattr(e, "message") else str(e)
        response_content = [BetaTextBlock(text=f"Communication Error: {error_str}", type="text")]
    summary_message = AgentMessage(role="assistant", content=response_content)
    # 3. Rebuild message history with condensed trajectory
    primary_task = [original_messages.pop(0)]
    condensed_trajectory = [summarize_message, summary_message]
    recent_trajectory = original_messages[self.compress_length - 1 :]  # -1 since we popped the first message
    # Drop a leading tool result: a tool_result without its preceding
    # tool_use is rejected by the API.
    if (
        recent_trajectory
        and recent_trajectory[0].role == "environment"
        and any(
            isinstance(block, dict) and block.get("type") == "tool_result"
            for block in recent_trajectory[0].content
        )
    ):
        recent_trajectory.pop(0)
    self.messages = primary_task + condensed_trajectory + recent_trajectory
def get_coordinates(self, tool_input: dict, key: str = "coordinate") -> Optional[list | tuple]:
"""Get coordinates from tool input."""
raw_coord = tool_input.get(key)
@@ -382,14 +497,22 @@ class AnthropicOperatorAgent(OperatorAgent):
return coord
def model_default_tool(self, tool_type: Literal["computer", "editor", "terminal"]) -> str:
def model_default_tool(self, tool_type: Literal["computer", "editor", "terminal"]) -> dict[str, str]:
"""Get the default tool of specified type for the given model."""
if self.vision_model.name.startswith("claude-3-7-sonnet"):
if tool_type == "computer":
return "computer_20250124"
return {"name": "computer", "type": "computer_20250124"}
elif tool_type == "editor":
return {"name": "str_replace_editor", "type": "text_editor_20250124"}
elif tool_type == "terminal":
return {"name": "bash_20250124", "type": "bash"}
elif self.vision_model.name.startswith("claude-sonnet-4") or self.vision_model.name.startswith("claude-opus-4"):
if tool_type == "computer":
return "computer_20250124"
return {"name": "computer", "type": "computer_20250124"}
elif tool_type == "editor":
return {"name": "str_replace_based_edit_tool", "type": "text_editor_20250429"}
elif tool_type == "terminal":
return {"name": "bash", "type": "bash_20250124"}
raise ValueError(f"Unsupported tool type for model '{self.vision_model.name}': {tool_type}")
def model_default_headers(self) -> list[str]:
@@ -400,3 +523,88 @@ class AnthropicOperatorAgent(OperatorAgent):
return ["computer-use-2025-01-24"]
else:
return []
def get_instructions(self, environment_type: EnvironmentType, current_state: EnvState) -> str:
    """Return system instructions for the Anthropic operator.

    Builds an environment-specific system prompt embedding today's date,
    the iteration budget, and (for browsers) the current URL from the
    environment state.

    Raises:
        ValueError: If the environment type is neither BROWSER nor COMPUTER.
    """
    if environment_type == EnvironmentType.BROWSER:
        # Browser prompt: operator controls a single Chromium tab; no OS/filesystem access.
        return dedent(
            f"""
            <SYSTEM_CAPABILITY>
            * You are Khoj, a smart web browser operating assistant. You help the users accomplish tasks using a web browser.
            * You operate a Chromium browser using Playwright via the 'computer' tool.
            * You cannot access the OS or filesystem.
            * You can interact with the web browser to perform tasks like clicking, typing, scrolling, and more.
            * You can use the additional back() and goto() helper functions to ease navigating the browser. If you see nothing, try goto duckduckgo.com
            * When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
            * When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
            * Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
            * The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
            * The current URL is {current_state.url}.
            </SYSTEM_CAPABILITY>
            <IMPORTANT>
            * You are allowed upto {self.max_iterations} iterations to complete the task.
            * Do not loop on wait, screenshot for too many turns without taking any action.
            * After initialization if the browser is blank, enter a website URL using the goto() function instead of waiting
            </IMPORTANT>
            """
        ).lstrip()
    elif environment_type == EnvironmentType.COMPUTER:
        # Computer prompt: full desktop control; no current-URL context available.
        return dedent(
            f"""
            <SYSTEM_CAPABILITY>
            * You are Khoj, a smart computer operating assistant. You help the users accomplish tasks using a computer.
            * You can interact with the computer to perform tasks like clicking, typing, scrolling, and more.
            * When viewing a document or webpage it can be helpful to zoom out or scroll down to ensure you see everything before deciding something isn't available.
            * When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
            * Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
            * Do not loop on wait, screenshot for too many turns without taking any action.
            * You are allowed upto {self.max_iterations} iterations to complete the task.
            </SYSTEM_CAPABILITY>
            <CONTEXT>
            * The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
            </CONTEXT>
            """
        ).lstrip()
    else:
        raise ValueError(f"Unsupported environment type for Anthropic operator: {environment_type}")
def get_tools(self, environment: EnvironmentType, current_state: EnvState) -> list[dict]:
    """Return the tools available for the Anthropic operator.

    Always includes the model-appropriate computer, editor and terminal
    tools; browser environments additionally get back() and goto()
    navigation helpers.

    Args:
        environment: The environment type the operator runs in.
        current_state: Current environment state, used for display size.
    """
    # Hoist tool lookups: each is a per-model {"name", "type"} mapping.
    computer_tool = self.model_default_tool("computer")
    editor_tool = self.model_default_tool("editor")
    terminal_tool = self.model_default_tool("terminal")
    tools: list[dict] = [
        {
            "type": computer_tool["type"],
            "name": computer_tool["name"],
            "display_width_px": current_state.width,
            "display_height_px": current_state.height,
        },
        {
            "type": editor_tool["type"],
            "name": editor_tool["name"],
        },
        {
            "type": terminal_tool["type"],
            "name": terminal_tool["name"],
        },
    ]
    # Compare against the enum member, not a raw string, consistent with the
    # EnvironmentType checks elsewhere (a plain-Enum value never equals "browser").
    if environment == EnvironmentType.BROWSER:
        tools += [
            {
                "name": "back",
                "description": "Go back to the previous page.",
                "input_schema": {"type": "object", "properties": {}},
            },
            {
                "name": "goto",
                "description": "Go to a specific URL.",
                "input_schema": {
                    "type": "object",
                    "properties": {"url": {"type": "string", "description": "Fully qualified URL to navigate to."}},
                    "required": ["url"],
                },
            },
        ]
    return tools

View File

@@ -5,9 +5,17 @@ from typing import List, Literal, Optional, Union
from pydantic import BaseModel
from khoj.database.models import ChatModel
from khoj.processor.conversation.utils import commit_conversation_trace
from khoj.processor.conversation.utils import (
AgentMessage,
OperatorRun,
commit_conversation_trace,
)
from khoj.processor.operator.operator_actions import OperatorAction
from khoj.processor.operator.operator_environment_base import EnvState, EnvStepResult
from khoj.processor.operator.operator_environment_base import (
EnvironmentType,
EnvState,
EnvStepResult,
)
from khoj.utils.helpers import get_chat_usage_metrics, is_promptrace_enabled
logger = logging.getLogger(__name__)
@@ -19,18 +27,41 @@ class AgentActResult(BaseModel):
rendered_response: Optional[dict] = None
class AgentMessage(BaseModel):
role: Literal["user", "assistant", "system", "environment"]
content: Union[str, List]
class OperatorAgent(ABC):
def __init__(self, query: str, vision_model: ChatModel, max_iterations: int, tracer: dict):
def __init__(
self,
query: str,
vision_model: ChatModel,
environment_type: EnvironmentType,
max_iterations: int,
max_context: int,
chat_history: List[AgentMessage] = [],
previous_trajectory: Optional[OperatorRun] = None,
tracer: dict = {},
):
self.query = query
self.vision_model = vision_model
self.environment_type = environment_type
self.max_iterations = max_iterations
self.tracer = tracer
self.messages: List[AgentMessage] = []
self.summarize_prompt = f"Use the results of our research to provide a comprehensive, self-contained answer for the target query:\n{query}."
self.messages: List[AgentMessage] = chat_history
if previous_trajectory:
# Remove tool call from previous trajectory as tool call w/o result not supported
if previous_trajectory.trajectory and previous_trajectory.trajectory[-1].role == "assistant":
previous_trajectory.trajectory.pop()
self.messages += previous_trajectory.trajectory
self.messages += [AgentMessage(role="user", content=query)]
# Context compression parameters
self.context_compress_trigger = 2e3 # heuristic to determine compression trigger
# turns after which compression triggered. scales with model max context size. Minimum 5 turns.
self.message_limit = 2 * max(5, int(max_context / self.context_compress_trigger))
# compression ratio determines how many messages to compress down to one
# e.g. if 5 messages, a compress ratio of 4/5 means compress 5 messages into 1 + keep 1 uncompressed
self.message_compress_ratio = 4 / 5
self.compress_length = int(self.message_limit * self.message_compress_ratio)
@abstractmethod
async def act(self, current_state: EnvState) -> AgentActResult:
@@ -41,16 +72,17 @@ class OperatorAgent(ABC):
"""Track results of agent actions on the environment."""
pass
async def summarize(self, summarize_prompt: str, current_state: EnvState) -> str:
async def summarize(self, current_state: EnvState, summarize_prompt: str = None) -> str:
"""Summarize the agent's actions and results."""
summarize_prompt = summarize_prompt or self.summarize_prompt
self.messages.append(AgentMessage(role="user", content=summarize_prompt))
await self.act(current_state)
if not self.messages:
return "No actions to summarize."
return self.compile_response(self.messages[-1].content)
return self._compile_response(self.messages[-1].content)
@abstractmethod
def compile_response(self, response: List | str) -> str:
def _compile_response(self, response: List | str) -> str:
pass
@abstractmethod
@@ -65,13 +97,12 @@ class OperatorAgent(ABC):
self.tracer["usage"] = get_chat_usage_metrics(
self.vision_model.name, input_tokens, output_tokens, cache_read, cache_write, usage=self.tracer.get("usage")
)
logger.debug(f"Operator usage by {self.vision_model.model_type}: {self.tracer['usage']}")
def _commit_trace(self):
self.tracer["chat_model"] = self.vision_model.name
if is_promptrace_enabled() and len(self.messages) > 1:
compiled_messages = [
AgentMessage(role=msg.role, content=self.compile_response(msg.content)) for msg in self.messages
AgentMessage(role=msg.role, content=self._compile_response(msg.content)) for msg in self.messages
]
commit_conversation_trace(compiled_messages[:-1], compiled_messages[-1].content, self.tracer)

View File

@@ -1,21 +1,24 @@
import json
import logging
from datetime import datetime
from textwrap import dedent
from typing import List, Optional
from openai.types.chat import ChatCompletion
from khoj.database.models import ChatModel
from khoj.processor.conversation.utils import construct_structured_message
from khoj.processor.conversation.utils import (
AgentMessage,
OperatorRun,
construct_structured_message,
)
from khoj.processor.operator.grounding_agent import GroundingAgent
from khoj.processor.operator.grounding_agent_uitars import GroundingAgentUitars
from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_agent_base import (
AgentActResult,
AgentMessage,
OperatorAgent,
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
from khoj.processor.operator.operator_environment_base import (
EnvironmentType,
EnvState,
EnvStepResult,
)
from khoj.processor.operator.operator_environment_base import EnvState, EnvStepResult
from khoj.routers.helpers import send_message_to_model_wrapper
from khoj.utils.helpers import get_openai_async_client, is_none_or_empty
@@ -27,7 +30,7 @@ class BinaryOperatorAgent(OperatorAgent):
"""
An OperatorAgent that uses two LLMs:
1. Reasoning LLM: Determines the next high-level action based on the objective and current visual reasoning trajectory.
2. Grounding LLM: Converts the high-level action into specific, executable browser actions.
2. Grounding LLM: Converts the high-level action into specific actions executable on the environment.
"""
def __init__(
@@ -35,10 +38,23 @@ class BinaryOperatorAgent(OperatorAgent):
query: str,
reasoning_model: ChatModel,
grounding_model: ChatModel,
environment_type: EnvironmentType,
max_iterations: int,
tracer: dict,
max_context: int,
chat_history: List[AgentMessage] = [],
previous_trajectory: Optional[OperatorRun] = None,
tracer: dict = {},
):
super().__init__(query, reasoning_model, max_iterations, tracer) # Use reasoning model for primary tracking
super().__init__(
query,
reasoning_model,
environment_type,
max_iterations,
max_context,
chat_history,
previous_trajectory,
tracer,
) # Use reasoning model for primary tracking
self.reasoning_model = reasoning_model
self.grounding_model = grounding_model
# Initialize openai api compatible client for grounding model
@@ -49,10 +65,12 @@ class BinaryOperatorAgent(OperatorAgent):
self.grounding_agent: GroundingAgent | GroundingAgentUitars = None
if "ui-tars-1.5" in grounding_model.name:
self.grounding_agent = GroundingAgentUitars(
grounding_model.name, grounding_client, max_iterations, environment_type="web", tracer=tracer
grounding_model.name, self.environment_type, grounding_client, max_iterations, tracer=tracer
)
else:
self.grounding_agent = GroundingAgent(grounding_model.name, grounding_client, max_iterations, tracer=tracer)
self.grounding_agent = GroundingAgent(
grounding_model.name, self.environment_type, grounding_client, max_iterations, tracer=tracer
)
async def act(self, current_state: EnvState) -> AgentActResult:
"""
@@ -84,48 +102,7 @@ class BinaryOperatorAgent(OperatorAgent):
"""
Uses the reasoning LLM to determine the next high-level action based on the operation trajectory.
"""
reasoning_system_prompt = f"""
# Introduction
* You are Khoj, a smart and resourceful web browsing assistant. You help the user accomplish their task using a web browser.
* You are given the user's query and screenshots of the browser's state transitions.
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
* The current URL is {current_state.url}.
# Your Task
* First look at the screenshots carefully to notice all pertinent information.
* Then instruct a tool AI to perform the next action that will help you progress towards the user's goal.
* Make sure you scroll down to see everything before deciding something isn't available.
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
* Use your creativity to find alternate ways to make progress if you get stuck at any point.
# Tool AI Capabilities
* The tool AI only has access to the current screenshot and your instructions. It uses your instructions to perform the next action on the page.
* It can interact with the web browser with these actions: click, right click, double click, type, scroll, drag, wait, goto url and go back to previous page.
* It cannot access the OS, filesystem or application window. It just controls a single Chromium browser tab via Playwright.
# IMPORTANT
* You are allowed upto {self.max_iterations} iterations to complete the task.
* To navigate to a specific URL, put "GOTO <URL>" (without quotes) on the last line of your response.
* To navigate back to the previous page, end your response with "BACK" (without quotes).
* Once you've verified that the main objective has been achieved, end your response with "DONE" (without quotes).
# Examples
## Example 1
GOTO https://example.com
## Example 2
click the blue login button located at the top right corner
## Example 3
scroll down the page
## Example 4
type the username example@email.com into the input field labeled Username
## Example 5
DONE
# Instructions
Now describe a single high-level action to take next to progress towards the user's goal in detail.
Focus on the visual action and provide all necessary context.
""".strip()
reasoning_system_prompt = self.get_instruction(self.environment_type, current_state)
if is_none_or_empty(self.messages):
query_text = f"**Main Objective**: {self.query}"
query_screenshot = [f"data:image/webp;base64,{current_state.screenshot}"]
@@ -259,7 +236,8 @@ Focus on the visual action and provide all necessary context.
action_results_content.extend(action_result["content"])
self.messages.append(AgentMessage(role="environment", content=action_results_content))
async def summarize(self, summarize_prompt: str, env_state: EnvState) -> str:
async def summarize(self, env_state: EnvState, summarize_prompt: str = None) -> str:
summarize_prompt = summarize_prompt or self.summarize_prompt
conversation_history = {"chat": self._format_message_for_api(self.messages)}
try:
summary = await send_message_to_model_wrapper(
@@ -282,7 +260,7 @@ Focus on the visual action and provide all necessary context.
return summary
def compile_response(self, response_content: str | List) -> str:
def _compile_response(self, response_content: str | List) -> str:
"""Compile response content into a string, handling OpenAI message structures."""
if isinstance(response_content, str):
return response_content
@@ -330,6 +308,96 @@ Focus on the visual action and provide all necessary context.
]
return formatted_messages
def get_instruction(self, environment_type: EnvironmentType, env_state: EnvState) -> str:
    """Get the system instruction for the reasoning agent.

    Builds an environment-specific prompt for the reasoning LLM, embedding
    today's date, the iteration budget, and (for browsers) the current URL.

    Raises:
        ValueError: If the environment type is neither BROWSER nor COMPUTER.
    """
    if environment_type == EnvironmentType.BROWSER:
        # Browser prompt: the grounding "tool AI" controls one Chromium tab via
        # Playwright and supports GOTO/BACK/DONE control keywords.
        return dedent(
            f"""
            # Introduction
            * You are Khoj, a smart and resourceful web browsing assistant. You help the user accomplish their task using a web browser.
            * You are given the user's query and screenshots of the browser's state transitions.
            * The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
            * The current URL is {env_state.url}.

            # Your Task
            * First look at the screenshots carefully to notice all pertinent information.
            * Then instruct a tool AI to perform the next action that will help you progress towards the user's goal.
            * Make sure you scroll down to see everything before deciding something isn't available.
            * Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
            * Use your creativity to find alternate ways to make progress if you get stuck at any point.

            # Tool AI Capabilities
            * The tool AI only has access to the current screenshot and your instructions. It uses your instructions to perform the next action on the page.
            * It can interact with the web browser with these actions: click, right click, double click, type, scroll, drag, wait, goto url and go back to previous page.
            * It cannot access the OS, filesystem or application window. It just controls a single Chromium browser tab via Playwright.

            # IMPORTANT
            * You are allowed upto {self.max_iterations} iterations to complete the task.
            * To navigate to a specific URL, put "GOTO <URL>" (without quotes) on the last line of your response.
            * To navigate back to the previous page, end your response with "BACK" (without quotes).
            * Once you've verified that the main objective has been achieved, end your response with "DONE" (without quotes).

            # Examples
            ## Example 1
            GOTO https://example.com

            ## Example 2
            click the blue login button located at the top right corner

            ## Example 3
            scroll down the page

            ## Example 4
            type the username example@email.com into the input field labeled Username

            ## Example 5
            DONE

            # Instructions
            Now describe a single high-level action to take next to progress towards the user's goal in detail.
            Focus on the visual action and provide all necessary context.
            """
        ).strip()
    elif environment_type == EnvironmentType.COMPUTER:
        # Computer prompt: the tool AI controls the desktop; no URL navigation keywords.
        return dedent(
            f"""
            # Introduction
            * You are Khoj, a smart and resourceful computer assistant. You help the user accomplish their task using a computer.
            * You are given the user's query and screenshots of the computer's state transitions.
            * The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.

            # Your Task
            * First look at the screenshots carefully to notice all pertinent information.
            * Then instruct a tool AI to perform the next action that will help you progress towards the user's goal.
            * Make sure you scroll down to see everything before deciding something isn't available.
            * Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
            * Use your creativity to find alternate ways to make progress if you get stuck at any point.

            # Tool AI Capabilities
            * The tool AI only has access to the current screenshot and your instructions. It uses your instructions to perform the next action on the page.
            * It can interact with the computer with these actions: click, right click, double click, type, scroll, drag, wait to previous page.

            # IMPORTANT
            * You are allowed upto {self.max_iterations} iterations to complete the task.
            * Once you've verified that the main objective has been achieved, end your response with "DONE" (without quotes).

            # Examples
            ## Example 1
            type https://example.com into the address bar and press Enter

            ## Example 2
            click the blue login button located at the top right corner

            ## Example 3
            scroll down the page

            ## Example 4
            type the username example@email.com into the input field labeled Username

            ## Example 5
            DONE

            # Instructions
            Now describe a single high-level action to take next to progress towards the user's goal in detail.
            Focus on the visual action and provide all necessary context.
            """
        ).strip()
    else:
        raise ValueError(f"Expected environment type: Computer or Browser. Got {environment_type}.")
def reset(self):
    """Reset the agent state."""
    # NOTE(review): only base-class state is reset here; the grounding agent's
    # own state is not cleared — confirm that is intended.
    super().reset()

View File

@@ -1,18 +1,22 @@
import json
import logging
import platform
from copy import deepcopy
from datetime import datetime
from textwrap import dedent
from typing import List, Optional, cast
from openai.types.responses import Response, ResponseOutputItem
from khoj.database.models import ChatModel
from khoj.processor.conversation.utils import AgentMessage
from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_agent_base import (
AgentActResult,
AgentMessage,
OperatorAgent,
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
from khoj.processor.operator.operator_environment_base import (
EnvironmentType,
EnvState,
EnvStepResult,
)
from khoj.processor.operator.operator_environment_base import EnvState, EnvStepResult
from khoj.utils.helpers import get_openai_async_client, is_none_or_empty
logger = logging.getLogger(__name__)
@@ -21,80 +25,18 @@ logger = logging.getLogger(__name__)
# --- Anthropic Operator Agent ---
class OpenAIOperatorAgent(OperatorAgent):
async def act(self, current_state: EnvState) -> AgentActResult:
client = get_openai_async_client(
self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url
)
safety_check_prefix = "Say 'continue' after resolving the following safety checks to proceed:"
safety_check_message = None
actions: List[OperatorAction] = []
action_results: List[dict] = []
self._commit_trace() # Commit trace before next action
system_prompt = f"""<SYSTEM_CAPABILITY>
* You are Khoj, a smart web browser operating assistant. You help the users accomplish tasks using a web browser.
* You operate a single Chromium browser page using Playwright.
* You cannot access the OS or filesystem.
* You can interact with the web browser to perform tasks like clicking, typing, scrolling, and more using the computer_use_preview tool.
* You can use the additional back() and goto() functions to navigate the browser.
* Always use the goto() function to navigate to a specific URL. If you see nothing, try goto duckduckgo.com
* When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
* The current URL is {current_state.url}.
</SYSTEM_CAPABILITY>
<IMPORTANT>
* You are allowed upto {self.max_iterations} iterations to complete the task.
* After initialization if the browser is blank, enter a website URL using the goto() function instead of waiting
</IMPORTANT>
"""
tools = [
{
"type": "computer_use_preview",
"display_width": 1024, # TODO: Get from env
"display_height": 768, # TODO: Get from env
"environment": "browser",
},
{
"type": "function",
"name": "back",
"description": "Go back to the previous page.",
"parameters": {},
},
{
"type": "function",
"name": "goto",
"description": "Go to a specific URL.",
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "Fully qualified URL to navigate to.",
},
},
"additionalProperties": False,
"required": ["url"],
},
},
]
system_prompt = self.get_instructions(self.environment_type, current_state)
tools = self.get_tools(self.environment_type, current_state)
if is_none_or_empty(self.messages):
self.messages = [AgentMessage(role="user", content=self.query)]
messages_for_api = self._format_message_for_api(self.messages)
response: Response = await client.responses.create(
model="computer-use-preview",
input=messages_for_api,
instructions=system_prompt,
tools=tools,
parallel_tool_calls=False, # Keep sequential for now
max_output_tokens=4096, # TODO: Make configurable?
truncation="auto",
)
logger.debug(f"Openai response: {response.model_dump_json()}")
self.messages += [AgentMessage(role="environment", content=response.output)]
response = await self._call_model(self.vision_model, system_prompt, tools)
self.messages += [AgentMessage(role="assistant", content=response.output)]
rendered_response = self._render_response(response.output, current_state.screenshot)
last_call_id = None
@@ -174,6 +116,9 @@ class OpenAIOperatorAgent(OperatorAgent):
"summary": [],
}
)
else:
logger.warning(f"Unsupported response block type: {block.type}")
content = f"Unsupported response block type: {block.type}"
if action_to_run or content:
actions.append(action_to_run)
if action_to_run or content:
@@ -220,6 +165,10 @@ class OpenAIOperatorAgent(OperatorAgent):
elif action_result["type"] == "reasoning":
items_to_pop.append(idx) # Mark placeholder reasoning action result for removal
continue
elif action_result["type"] == "computer_call" and action_result["status"] == "in_progress":
if isinstance(result_content, dict):
result_content["status"] = "completed" # Mark in-progress actions as completed
action_result["output"] = result_content
else:
# Add text data
action_result["output"] = result_content
@@ -229,11 +178,45 @@ class OpenAIOperatorAgent(OperatorAgent):
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]
async def summarize(self, current_state: EnvState, summarize_prompt: str = None) -> str:
    """Ask the model to summarize the operator's actions and results so far."""
    prompt = summarize_prompt or self.summarize_prompt
    # Record the summary request, query the model without tools, then record its reply.
    self.messages.append(AgentMessage(role="user", content=prompt))
    response = await self._call_model(self.vision_model, prompt, [])
    self.messages.append(AgentMessage(role="assistant", content=response.output))
    if not self.messages:
        return "No actions to summarize."
    return self._compile_response(self.messages[-1].content)
async def _call_model(self, model: ChatModel, system_prompt, tools) -> Response:
    """Send the current message history to the OpenAI Responses API and return the raw response.

    When tools are supplied, the dedicated computer-use model is used instead of
    the configured chat model; otherwise (e.g. for summarization) the configured
    model is used and computer-call items are flattened into plain chat messages.
    """
    client = get_openai_async_client(model.ai_model_api.api_key, model.ai_model_api.api_base_url)

    # Operating the computer requires OpenAI's dedicated computer-use model.
    selected_model = "computer-use-preview" if tools else model.name

    # Format messages for OpenAI API
    api_messages = self._format_message_for_api(self.messages)
    if selected_model != "computer-use-preview":
        # Non computer-use models cannot ingest computer_call items; flatten them.
        api_messages = self._format_messages_for_summary(api_messages)

    response: Response = await client.responses.create(
        model=selected_model,
        input=api_messages,
        instructions=system_prompt,
        tools=tools,
        parallel_tool_calls=False,
        truncation="auto",
    )
    logger.debug(f"Openai response: {response.model_dump_json()}")
    return response
def _format_message_for_api(self, messages: list[AgentMessage]) -> list:
"""Format the message for OpenAI API."""
formatted_messages: list = []
for message in messages:
if message.role == "environment":
if message.role == "assistant":
if isinstance(message.content, list):
# Remove reasoning message if not followed by computer call
if (
@@ -252,18 +235,23 @@ class OpenAIOperatorAgent(OperatorAgent):
message.content.pop(0)
formatted_messages.extend(message.content)
else:
logger.warning(f"Expected message content list from environment, got {type(message.content)}")
logger.warning(f"Expected message content list from assistant, got {type(message.content)}")
elif message.role == "environment":
formatted_messages.extend(message.content)
else:
if isinstance(message.content, list):
message.content = "\n".join([part["text"] for part in message.content if part["type"] == "text"])
formatted_messages.append(
{
"role": message.role,
"content": message.content,
}
)
return formatted_messages
def compile_response(self, response_content: str | list[dict | ResponseOutputItem]) -> str:
"""Compile the response from model into a single string."""
def _compile_response(self, response_content: str | list[dict | ResponseOutputItem]) -> str:
"""Compile the response from model into a single string for prompt tracing."""
# Handle case where response content is a string.
# This is the case when response content is a user query
if isinstance(response_content, str):
@@ -347,3 +335,123 @@ class OpenAIOperatorAgent(OperatorAgent):
}
return render_payload
def get_instructions(self, environment_type: EnvironmentType, current_state: EnvState) -> str:
    """Return system instructions for the OpenAI operator.

    Raises ValueError for environment types other than BROWSER or COMPUTER.
    """
    # Guard clause: fail fast on unsupported environments.
    if environment_type not in (EnvironmentType.BROWSER, EnvironmentType.COMPUTER):
        raise ValueError(f"Unsupported environment type: {environment_type}")

    if environment_type == EnvironmentType.BROWSER:
        raw_instructions = f"""
        <SYSTEM_CAPABILITY>
        * You are Khoj, a smart web browser operating assistant. You help the users accomplish tasks using a web browser.
        * You operate a single Chromium browser page using Playwright.
        * You cannot access the OS or filesystem.
        * You can interact with the web browser to perform tasks like clicking, typing, scrolling, and more using the computer_use_preview tool.
        * You can use the additional back() and goto() functions to navigate the browser.
        * Always use the goto() function to navigate to a specific URL. If you see nothing, try goto duckduckgo.com
        * When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
        * When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
        * Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
        * The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
        * The current URL is {current_state.url}.
        </SYSTEM_CAPABILITY>

        <IMPORTANT>
        * You are allowed upto {self.max_iterations} iterations to complete the task.
        * After initialization if the browser is blank, enter a website URL using the goto() function instead of waiting
        </IMPORTANT>
        """
    else:
        raw_instructions = f"""
        <SYSTEM_CAPABILITY>
        * You are Khoj, a smart computer operating assistant. You help the users accomplish their tasks using a computer.
        * You can interact with the computer to perform tasks like clicking, typing, scrolling, and more using the computer_use_preview tool.
        * When viewing a document or webpage it can be helpful to zoom out or scroll down to ensure you see everything before deciding something isn't available.
        * When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
        * Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
        * You are allowed upto {self.max_iterations} iterations to complete the task.
        </SYSTEM_CAPABILITY>

        <CONTEXT>
        * The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
        </CONTEXT>
        """

    return dedent(raw_instructions).lstrip()
def get_tools(self, environment_type: EnvironmentType, current_state: EnvState) -> list[dict]:
    """Return the tool definitions available to the OpenAI operator.

    All environments get the computer_use_preview tool; browser environments
    additionally get back() and goto() navigation functions.
    """
    # The computer-use tool must be told which OS it is driving.
    # TODO: Detect the OS from the environment; assume Linux for now.
    # environment_os = "mac" if platform.system() == "Darwin" else "windows" if platform.system() == "Windows" else "linux"
    display_environment = "linux" if environment_type == EnvironmentType.COMPUTER else "browser"

    computer_tool = {
        "type": "computer_use_preview",
        "display_width": current_state.width,
        "display_height": current_state.height,
        "environment": display_environment,
    }
    if environment_type != EnvironmentType.BROWSER:
        return [computer_tool]

    # Browser environments also expose simple navigation helper functions.
    back_tool = {
        "type": "function",
        "name": "back",
        "description": "Go back to the previous page.",
        "parameters": {},
    }
    goto_tool = {
        "type": "function",
        "name": "goto",
        "description": "Go to a specific URL.",
        "parameters": {
            "type": "object",
            "properties": {
                "url": {
                    "type": "string",
                    "description": "Fully qualified URL to navigate to.",
                },
            },
            "additionalProperties": False,
            "required": ["url"],
        },
    }
    return [computer_tool, back_tool, goto_tool]
def _format_messages_for_summary(self, formatted_messages: List[dict]) -> List[dict]:
    """Rewrite API-formatted messages so non computer-use models can consume them.

    Mutates and returns the given list: tool outputs become user messages,
    bare strings are wrapped as input_text, assistant items are flattened to
    output_text, and reasoning-only items with no renderable text are dropped.
    """
    drop_indices = []  # Indices of reasoning-only items to remove afterwards
    for i, item in enumerate(formatted_messages):
        if isinstance(item, dict) and "content" in item:
            # Already a plain chat message; leave untouched.
            continue
        if isinstance(item, dict) and "output" in item:
            # Tool call output: drop current_url as it is unsupported outside computer use
            if "current_url" in item["output"]:
                del item["output"]["current_url"]
            formatted_messages[i] = {"role": "user", "content": [item["output"]]}
        elif isinstance(item, str):
            formatted_messages[i] = {"role": "user", "content": [{"type": "input_text", "text": item}]}
        else:
            rendered = self._compile_response([item])
            if not rendered:
                drop_indices.append(i)  # Reasoning item with no text; mark for removal
            else:
                formatted_messages[i] = {
                    "role": "assistant",
                    "content": [{"type": "output_text", "text": rendered}],
                }
    # Remove reasoning messages for non-computer use models (reverse to keep indices valid)
    for i in reversed(drop_indices):
        formatted_messages.pop(i)
    return formatted_messages

View File

@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Literal, Optional
from pydantic import BaseModel
@@ -6,9 +7,18 @@ from pydantic import BaseModel
from khoj.processor.operator.operator_actions import OperatorAction
class EnvironmentType(Enum):
"""Type of environment to operate."""
COMPUTER = "computer"
BROWSER = "browser"
class EnvState(BaseModel):
url: str
height: int
width: int
screenshot: Optional[str] = None
url: Optional[str] = None
class EnvStepResult(BaseModel):

View File

@@ -5,7 +5,7 @@ import logging
import os
from typing import Optional, Set, Union
from khoj.processor.operator.operator_actions import OperatorAction, Point
from khoj.processor.operator.operator_actions import DragAction, OperatorAction, Point
from khoj.processor.operator.operator_environment_base import (
Environment,
EnvState,
@@ -124,10 +124,10 @@ class BrowserEnvironment(Environment):
async def get_state(self) -> EnvState:
if not self.page or self.page.is_closed():
return EnvState(url="about:blank", screenshot=None)
return EnvState(url="about:blank", screenshot=None, height=self.height, width=self.width)
url = self.page.url
screenshot = await self._get_screenshot()
return EnvState(url=url, screenshot=screenshot)
return EnvState(url=url, screenshot=screenshot, height=self.height, width=self.width)
async def step(self, action: OperatorAction) -> EnvStepResult:
if not self.page or self.page.is_closed():
@@ -246,6 +246,8 @@ class BrowserEnvironment(Environment):
logger.debug(f"Action: {action.type} to ({x},{y})")
case "drag":
if not isinstance(action, DragAction):
raise TypeError(f"Invalid action type for drag")
path = action.path
if not path:
error = "Missing path for drag action"

View File

@@ -0,0 +1,658 @@
import ast
import asyncio
import base64
import io
import logging
import platform
import subprocess
from typing import Literal, Optional, Union
from PIL import Image, ImageDraw
from khoj.processor.operator.operator_actions import DragAction, OperatorAction, Point
from khoj.processor.operator.operator_environment_base import (
Environment,
EnvState,
EnvStepResult,
)
from khoj.utils.helpers import convert_image_to_webp
logger = logging.getLogger(__name__)
# --- Concrete Computer Environment ---
class ComputerEnvironment(Environment):
def __init__(
self,
provider: Literal["local", "docker"] = "local",
docker_display: str = ":99",
docker_container_name: str = "khoj-computer",
):
self.provider = provider
self.docker_display = docker_display
self.docker_container_name = docker_container_name
self.width: int = 0
self.height: int = 0
self.mouse_pos: Point = Point(x=0, y=0)
async def _execute(self, func_name, *args, **kwargs):
"""
Executes a pyautogui function, abstracting the execution context.
Currently runs locally using asyncio.to_thread.
"""
python_command_str = self.generate_pyautogui_command(func_name, *args, **kwargs)
# Docker execution
if self.provider == "docker":
try:
output_str = await self.docker_execute(python_command_str)
except RuntimeError as e: # Catch other Docker execution errors
logger.error(f"Error during Docker execution of {func_name}: {e}")
raise # Re-raise as a general error for the caller to handle
# Local execution
else:
process = await asyncio.to_thread(
subprocess.run,
["python3", "-c", python_command_str],
capture_output=True,
text=True,
check=False, # We check returncode manually
)
output_str = process.stdout.strip()
if process.returncode != 0:
if "FailSafeException" in process.stderr or "FailSafeException" in process.stdout:
# Extract the message if possible, otherwise use generic
fs_msg = process.stderr or process.stdout
raise KeyboardInterrupt(fs_msg)
else:
error_msg = (
f'Local script execution failed:\nCmd: python3 -c "{python_command_str[:200]}...{python_command_str[-200:]}\n'
f"Return Code: {process.returncode}\nStderr: {process.stderr}\nStdout: {process.stdout}"
)
logger.error(error_msg)
raise RuntimeError(f"Local script execution error: {process.stderr or process.stdout}")
if not output_str or output_str == "None":
return None
try:
return ast.literal_eval(output_str)
except (ValueError, SyntaxError):
# If not a literal (e.g., some other string output), return as is
return output_str
async def start(self, width: int, height: int) -> None:
"""
Initializes the computer environment.
The width and height parameters are logged, but actual screen dimensions are used.
"""
screen_width, screen_height = await self._execute("size")
self.width = screen_width
self.height = screen_height
# Initialize mouse position to center, or current if available
try:
current_x, current_y = await self._execute("position")
self.mouse_pos = Point(x=current_x, y=current_y)
except Exception: # Fallback if position cannot be obtained initially
self.mouse_pos = Point(x=self.width / 2, y=self.height / 2)
logger.info(
f"Computer environment started. Screen size: {self.width}x{self.height}. "
f"Input width/height ({width}x{height}) are noted but screen dimensioning uses actual screen size. "
f"Initial mouse position: ({self.mouse_pos.x},{self.mouse_pos.y})"
)
async def _get_screenshot(self) -> Optional[str]:
try:
# Get screenshot
base64_png_str = await self._execute("screenshot")
screenshot_bytes = base64.b64decode(base64_png_str)
# Get current mouse position
current_mouse_x, current_mouse_y = await self._execute("position")
draw_pos = Point(x=current_mouse_x, y=current_mouse_y)
# Add mouse position to screenshot
screenshot_bytes_with_mouse = await self._draw_mouse_position(screenshot_bytes, draw_pos)
screenshot_webp_bytes = convert_image_to_webp(screenshot_bytes_with_mouse)
return base64.b64encode(screenshot_webp_bytes).decode("utf-8")
except KeyboardInterrupt: # Propagate keyboard interrupts
raise
except Exception as e:
logger.error(f"Failed to get screenshot: {e}", exc_info=True)
return None
async def _draw_mouse_position(self, screenshot_bytes: bytes, mouse_pos: Point) -> bytes:
if Image is None or ImageDraw is None:
return screenshot_bytes
try:
image = Image.open(io.BytesIO(screenshot_bytes))
draw = ImageDraw.Draw(image)
radius = 8
# Red circle with black border for better visibility
draw.ellipse(
(mouse_pos.x - radius, mouse_pos.y - radius, mouse_pos.x + radius, mouse_pos.y + radius),
outline="black",
fill="red",
width=2,
)
output_buffer = io.BytesIO()
image.save(output_buffer, format="PNG")
return output_buffer.getvalue()
except Exception as e:
logger.error(f"Failed to draw mouse position: {e}")
return screenshot_bytes
async def get_state(self) -> EnvState:
screenshot = await self._get_screenshot()
return EnvState(screenshot=screenshot, height=self.height, width=self.width)
async def step(self, action: OperatorAction) -> EnvStepResult:
output: Optional[Union[str, dict]] = None
error: Optional[str] = None
step_type: str = "text"
try:
match action.type:
case "click":
x, y, button_name = action.x, action.y, action.button
modifiers_to_press = self.parse_key_combination(action.modifiers) if action.modifiers else []
for mod_key in modifiers_to_press:
await self._execute("keyDown", mod_key)
if button_name == "wheel":
# Perform a small scroll action at this position (e.g., one "tick" down)
# Pyautogui scroll: positive up, negative down.
await self._execute("scroll", -1, x=x, y=y)
output = f"Scrolled wheel at ({x}, {y})"
else:
pyautogui_button = button_name.lower() if button_name else "left"
await self._execute("click", x=x, y=y, button=pyautogui_button)
output = f"{button_name.capitalize() if button_name else 'Left'} clicked at ({x}, {y})"
for mod_key in reversed(modifiers_to_press):
await self._execute("keyUp", mod_key)
self.mouse_pos = Point(x=x, y=y)
logger.debug(f"Action: {action.type} {button_name} at ({x},{y}) with modifiers {action.modifiers}")
case "double_click":
x, y = action.x, action.y
await self._execute("doubleClick", x=x, y=y)
self.mouse_pos = Point(x=x, y=y)
output = f"Double clicked at ({x}, {y})"
logger.debug(f"Action: {action.type} at ({x},{y})")
case "triple_click":
x, y = action.x, action.y
await self._execute("click", x=x, y=y, clicks=3)
self.mouse_pos = Point(x=x, y=y)
output = f"Triple clicked at ({x}, {y})"
logger.debug(f"Action: {action.type} at ({x},{y})")
case "scroll":
current_x_pos, current_y_pos = await self._execute("position")
target_x = action.x if action.x is not None else current_x_pos
target_y = action.y if action.y is not None else current_y_pos
if target_x != current_x_pos or target_y != current_y_pos:
await self._execute("moveTo", target_x, target_y)
self.mouse_pos = Point(x=target_x, y=target_y) # Update mouse pos to scroll location
if action.scroll_x is not None or action.scroll_y is not None:
scroll_x_amount = action.scroll_x or 0
scroll_y_amount = action.scroll_y or 0
if scroll_x_amount != 0:
await self._execute("hscroll", scroll_x_amount)
if scroll_y_amount != 0:
# pyautogui scroll: positive up, so negate for typical "scroll down" meaning positive y
await self._execute("scroll", -scroll_y_amount)
output = f"Scrolled by (x:{scroll_x_amount}, y:{scroll_y_amount}) at ({target_x}, {target_y})"
elif action.scroll_direction:
# Define scroll unit (number of pyautogui scroll 'clicks')
# This might need tuning based on desired sensitivity.
pyautogui_scroll_clicks_per_unit = 1
amount = action.scroll_amount or 1
total_scroll_clicks = pyautogui_scroll_clicks_per_unit * amount
if action.scroll_direction == "up":
await self._execute("scroll", total_scroll_clicks)
elif action.scroll_direction == "down":
await self._execute("scroll", -total_scroll_clicks)
elif action.scroll_direction == "left":
await self._execute("hscroll", -total_scroll_clicks)
elif action.scroll_direction == "right":
await self._execute("hscroll", total_scroll_clicks)
output = f"Scrolled {action.scroll_direction} by {amount} units at ({target_x}, {target_y})"
else:
error = "Scroll action requires either scroll_x/y or scroll_direction"
logger.debug(f"Action: {action.type} details: {output or error}")
case "keypress":
mapped_keys = [self.CUA_KEY_TO_PYAUTOGUI_KEY.get(k.lower(), k.lower()) for k in action.keys]
key_string = "N/A"
if not mapped_keys:
error = "Keypress action requires at least one key"
elif len(mapped_keys) > 1:
await self._execute("hotkey", *mapped_keys)
key_string = "+".join(mapped_keys)
else:
await self._execute("press", mapped_keys[0])
key_string = mapped_keys[0]
if not error:
output = f"Pressed key(s): {key_string}"
logger.debug(f"Action: {action.type} '{key_string}'")
case "type":
text_to_type = action.text
await self._execute("typewrite", text_to_type, interval=0.02) # Small interval
output = f"Typed text: {text_to_type}"
logger.debug(f"Action: {action.type} '{text_to_type}'")
case "wait":
duration = action.duration
await asyncio.sleep(duration)
output = f"Waited for {duration} seconds"
logger.debug(f"Action: {action.type} for {duration}s")
case "screenshot":
step_type = "image"
# The actual screenshot data is added from after_state later
output = {"message": "Screenshot captured", "url": "desktop"}
logger.debug(f"Action: {action.type}")
case "move":
x, y = action.x, action.y
await self._execute("moveTo", x, y, duration=0.2) # Small duration for smooth move
self.mouse_pos = Point(x=x, y=y)
output = f"Moved mouse to ({x}, {y})"
logger.debug(f"Action: {action.type} to ({x},{y})")
case "drag":
if not isinstance(action, DragAction):
raise TypeError("Invalid action type for drag")
drag_path = action.path
if not drag_path:
error = "Missing path for drag action"
else:
start_x, start_y = drag_path[0].x, drag_path[0].y
await self._execute("moveTo", start_x, start_y, duration=0.1)
await self._execute("mouseDown")
for point in drag_path[1:]:
await self._execute("moveTo", point.x, point.y, duration=0.05)
await self._execute("mouseUp")
self.mouse_pos = Point(x=drag_path[-1].x, y=drag_path[-1].y)
output = f"Drag along path starting at ({start_x},{start_y})"
logger.debug(f"Action: {action.type} with {len(drag_path)} points")
case "mouse_down":
pyautogui_button = action.button.lower() if action.button else "left"
await self._execute("mouseDown", button=pyautogui_button)
output = f"{action.button.capitalize() if action.button else 'Left'} mouse button down"
logger.debug(f"Action: {action.type} {action.button}")
case "mouse_up":
pyautogui_button = action.button.lower() if action.button else "left"
await self._execute("mouseUp", button=pyautogui_button)
output = f"{action.button.capitalize() if action.button else 'Left'} mouse button up"
logger.debug(f"Action: {action.type} {action.button}")
case "hold_key":
keys_to_hold_str = action.text
duration = action.duration
parsed_keys = self.parse_key_combination(keys_to_hold_str)
if not parsed_keys:
error = f"No valid keys found in '{keys_to_hold_str}' for hold_key"
else:
for key_to_hold in parsed_keys:
await self._execute("keyDown", key_to_hold)
await asyncio.sleep(duration) # Non-pyautogui, direct sleep
for key_to_hold in reversed(parsed_keys): # Release in reverse order
await self._execute("keyUp", key_to_hold)
output = (
f"Held key{'s' if len(parsed_keys) > 1 else ''} {keys_to_hold_str} for {duration} seconds"
)
logger.debug(f"Action: {action.type} '{keys_to_hold_str}' for {duration}s")
case "key_down":
key_to_press = self.CUA_KEY_TO_PYAUTOGUI_KEY.get(action.key.lower(), action.key)
await self._execute("keyDown", key_to_press)
output = f"Key down: {key_to_press}"
logger.debug(f"Action: {action.type} {key_to_press}")
case "key_up":
key_to_release = self.CUA_KEY_TO_PYAUTOGUI_KEY.get(action.key.lower(), action.key)
await self._execute("keyUp", key_to_release)
output = f"Key up: {key_to_release}"
logger.debug(f"Action: {action.type} {key_to_release}")
case "cursor_position":
pos_x, pos_y = await self._execute("position")
self.mouse_pos = Point(x=pos_x, y=pos_y)
output = f"Cursor position is ({pos_x}, {pos_y})"
logger.debug(f"Action: {action.type}, position: ({pos_x},{pos_y})")
case "goto":
output = f"Goto action (URL: {action.url}) is not applicable for ComputerEnvironment."
logger.warning(f"Unsupported action: {action.type} for ComputerEnvironment.")
case "back":
output = "Back action is not applicable for ComputerEnvironment."
logger.warning(f"Unsupported action: {action.type} for ComputerEnvironment.")
case "terminal":
# Execute terminal command
result = await self._execute_shell_command(action.command)
if result["success"]:
output = f"Command executed successfully:\n{result['output']}"
else:
error = f"Command execution failed: {result['error']}"
logger.debug(f"Action: {action.type} with command '{action.command}'")
case "text_editor_view":
# View file contents
file_path = action.path
view_range = action.view_range
# Type guard: path should be str for text editor actions
if not isinstance(file_path, str):
raise TypeError("Invalid path type for text editor view action")
escaped_path = file_path.replace("'", "'\"'\"'")
is_dir = await self._execute("os.path.isdir", escaped_path)
if is_dir:
cmd = rf"find {escaped_path} -maxdepth 2 -not -path '*/\.*'"
elif view_range:
# Use head/tail to view specific line range
start_line, end_line = view_range
lines_to_show = end_line - start_line + 1
cmd = f"head -n {end_line} '{escaped_path}' | tail -n {lines_to_show}"
else:
# View entire file
cmd = f"cat '{escaped_path}'"
result = await self._execute_shell_command(cmd)
MAX_OUTPUT_LENGTH = 15000 # Limit output length to avoid excessive data
if len(result["output"]) > MAX_OUTPUT_LENGTH:
result["output"] = f"{result['output'][:MAX_OUTPUT_LENGTH]}..."
if result["success"]:
if is_dir:
output = f"Here's the files and directories up to 2 levels deep in {file_path}, excluding hidden items:\n{result['output']}"
else:
output = f"File contents of {file_path}:\n{result['output']}"
else:
error = f"Failed to view file {file_path}: {result['error']}"
logger.debug(f"Action: {action.type} for file {file_path}")
case "text_editor_create":
# Create new file with contents
file_path = action.path
file_text = action.file_text
# Type guard: path should be str for text editor actions
if not isinstance(file_path, str):
raise TypeError("Invalid path type for text editor create action")
escaped_path = file_path.replace("'", "'\"'\"'")
escaped_content = file_text.replace("\t", " ").replace(
"'", "'\"'\"'"
) # Escape single quotes for shell
cmd = f"echo '{escaped_content}' > '{escaped_path}'"
result = await self._execute_shell_command(cmd)
if result["success"]:
output = f"Created file {file_path} with {len(file_text)} characters"
else:
error = f"Failed to create file {file_path}: {result['error']}"
logger.debug(f"Action: {action.type} created file {file_path}")
case "text_editor_str_replace":
# Execute string replacement
file_path = action.path
old_str = action.old_str
new_str = action.new_str
# Type guard: path should be str for text editor actions
if not isinstance(file_path, str):
raise TypeError("Invalid path type for text editor str_replace action")
# Use sed for string replacement, escaping special characters
escaped_path = file_path.replace("'", "'\"'\"'")
escaped_old = (
old_str.replace("\t", " ")
.replace("\\", "\\\\")
.replace("\n", "\\n")
.replace("/", "\\/")
.replace("'", "'\"'\"'")
)
escaped_new = (
new_str.replace("\t", " ")
.replace("\\", "\\\\")
.replace("\n", "\\n")
.replace("&", "\\&")
.replace("/", "\\/")
.replace("'", "'\"'\"'")
)
cmd = f"sed -i.bak 's/{escaped_old}/{escaped_new}/g' '{escaped_path}'"
result = await self._execute_shell_command(cmd)
if result["success"]:
output = f"Replaced '{old_str[:50]}...' with '{new_str[:50]}...' in {file_path}"
else:
error = f"Failed to replace text in {file_path}: {result['error']}"
logger.debug(f"Action: {action.type} in file {file_path}")
case "text_editor_insert":
# Insert text after specified line
file_path = action.path
insert_line = action.insert_line
new_str = action.new_str
# Type guard: path should be str for text editor actions
if not isinstance(file_path, str):
error = "Invalid path type for text editor insert action.\n"
error += f"Failed to insert text in {file_path}: {result['error']}"
raise TypeError(error)
escaped_path = file_path.replace("'", "'\"'\"'")
escaped_content = (
new_str.replace("\t", " ")
.replace("\\", "\\\\")
.replace("'", "'\"'\"'")
.replace("\n", "\\\n")
)
cmd = f"sed -i.bak '{insert_line}a\\{escaped_content}' '{escaped_path}'"
result = await self._execute_shell_command(cmd)
if result["success"]:
output = f"Inserted text after line {insert_line} in {file_path}"
else:
error = f"Failed to insert text in {file_path}: {result['error']}"
logger.debug(f"Action: {action.type} at line {insert_line} in file {file_path}")
case _:
error = f"Unrecognized action type: {action.type}"
logger.warning(error)
except KeyboardInterrupt:
error = "User interrupt. Operation aborted."
logger.error(error)
except TypeError as e:
logger.error(f"Error executing action {action.type}: {e}")
except Exception as e:
error = f"Unexpected error executing action {action.type}: {str(e)}"
logger.exception(
f"Unexpected error during step execution for action: {action.model_dump_json(exclude_none=True)}"
)
after_state = await self.get_state()
if action.type == "screenshot" and step_type == "image":
output = {"image": after_state.screenshot, "url": after_state.url}
return EnvStepResult(
type=step_type,
output=output,
error=error,
current_url=after_state.url,
screenshot_base64=after_state.screenshot,
)
async def _execute_shell_command(self, command: str, new: bool = True) -> dict:
"""Execute a shell command and return the result."""
try:
if self.provider == "docker":
# Execute command in Docker container
docker_args = [
"docker",
"exec",
self.docker_container_name,
"bash",
"-c",
command, # The command string is passed as a single argument to bash -c
]
process = await asyncio.to_thread(
subprocess.run,
docker_args,
capture_output=True,
text=True,
check=False,
timeout=120,
)
else:
# Execute command locally
process = await asyncio.to_thread(
subprocess.run,
command,
shell=True,
capture_output=True,
text=True,
check=False,
start_new_session=new,
timeout=120,
)
if process.returncode == 0:
return {"success": True, "output": process.stdout, "error": None}
else:
return {"success": False, "output": process.stdout, "error": process.stderr}
except asyncio.TimeoutError:
return {"success": False, "output": "", "error": f"Command timed out after 120 seconds."}
except Exception as e:
return {"success": False, "output": "", "error": str(e)}
async def close(self) -> None:
logger.debug("Computer environment closed. No specific resources to release for PyAutoGUI.")
CUA_KEY_TO_PYAUTOGUI_KEY = {
# Modifiers
"option": "alt",
"control": "ctrl",
"cmd": "command",
"super": "win",
"meta": "command" if platform.system() == "Darwin" else "win",
# Navigation & Editing
"arrowdown": "down",
"arrowleft": "left",
"arrowright": "right",
"arrowup": "up",
"caps_lock": "capslock",
"del": "delete",
"return": "enter",
"esc": "escape",
"pgdn": "pagedown",
"pgup": "pageup",
" ": "space",
# Numpad keys (example, pyautogui uses 'num0', 'add', 'subtract', etc.)
"numpad0": "num0",
"numpad_0": "num0",
}
@staticmethod
def parse_key_combination(text: str) -> list[str]:
if not text:
return []
keys_str_list = text.lower().split("+")
mapped_keys = []
for k_str in keys_str_list:
# Use the mapped key if found, otherwise use the string itself (e.g. 'a', '1')
mapped_keys.append(ComputerEnvironment.CUA_KEY_TO_PYAUTOGUI_KEY.get(k_str.strip(), k_str.strip()))
return mapped_keys
def generate_pyautogui_command(self, func_name: str, *args, **kwargs) -> str:
args_repr = [repr(arg) for arg in args]
kwargs_repr = [f"{k}={repr(v)}" for k, v in kwargs.items()]
all_params_repr = ", ".join(args_repr + kwargs_repr)
# Base script setup
script_lines = [
"import os",
"import pyautogui",
]
if self.provider == "docker":
script_lines.extend(
[
# Display export for Docker.
f"os.environ['DISPLAY']='{self.docker_display}'",
# Disable failsafe in Docker to avoid accidental exits
"pyautogui.FAILSAFE = False",
]
)
# Function-specific logic
if func_name == "screenshot":
script_lines.extend(
[
"import io",
"import base64",
"img = pyautogui.screenshot()",
"buf = io.BytesIO()",
"img.save(buf, format='PNG')",
"print(base64.b64encode(buf.getvalue()).decode('utf-8'))",
]
)
elif func_name == "size":
script_lines.extend(["size = pyautogui.size()", "print(f'({size.width}, {size.height})')"])
elif func_name == "position":
script_lines.extend(["pos = pyautogui.position()", "print(f'({pos.x}, {pos.y})')"])
else: # General command structure
script_lines.extend(
[f"result = pyautogui.{func_name}({all_params_repr})", "print(result if result is not None else '')"]
)
return "; ".join(script_lines)
async def docker_execute(self, python_command_str: str) -> Optional[str]:
if not self.docker_container_name or not self.docker_display:
logger.error("Container name or Docker display not set for Docker execution.")
return None
safe_python_cmd = python_command_str.replace('"', '\\"')
docker_full_cmd = (
f'docker exec -e DISPLAY={self.docker_display} "{self.docker_container_name}" '
f'python3 -c "{safe_python_cmd}"'
)
try:
process = await asyncio.to_thread(
subprocess.run,
docker_full_cmd,
shell=True,
capture_output=True,
text=True,
check=False, # We check returncode manually
)
if process.returncode != 0:
if "FailSafeException" in process.stderr or "FailSafeException" in process.stdout:
raise KeyboardInterrupt(process.stderr or process.stdout)
else:
error_msg = (
f"Docker command failed:\nCmd: {docker_full_cmd}\n"
f"Return Code: {process.returncode}\nStderr: {process.stderr}\nStdout: {process.stdout}"
)
logger.error(error_msg)
raise RuntimeError(f"Docker exec error: {process.stderr or process.stdout}")
return process.stdout.strip()
except KeyboardInterrupt: # Re-raise if caught from above
raise
except Exception as e:
logger.error(f"Unexpected error running command in Docker '{docker_full_cmd}': {e}")
# Encapsulate as RuntimeError to avoid leaking subprocess errors directly
raise RuntimeError(f"Unexpected Docker error: {e}") from e

View File

@@ -26,12 +26,13 @@ from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import (
OperatorRun,
ResponseWithThought,
defilter_query,
save_to_conversation_log,
)
from khoj.processor.image.generate import text_to_image
from khoj.processor.operator.operate_browser import operate_browser
from khoj.processor.operator import operate_environment
from khoj.processor.speech.text_to_speech import generate_text_to_speech
from khoj.processor.tools.online_search import (
deduplicate_organic_results,
@@ -65,10 +66,7 @@ from khoj.routers.helpers import (
update_telemetry_state,
validate_chat_model,
)
from khoj.routers.research import (
InformationCollectionIteration,
execute_information_collection,
)
from khoj.routers.research import ResearchIteration, research
from khoj.routers.storage import upload_user_image_to_bucket
from khoj.utils import state
from khoj.utils.helpers import (
@@ -722,10 +720,10 @@ async def chat(
for file in raw_query_files:
query_files[file.name] = file.content
research_results: List[InformationCollectionIteration] = []
research_results: List[ResearchIteration] = []
online_results: Dict = dict()
code_results: Dict = dict()
operator_results: Dict[str, str] = {}
operator_results: List[OperatorRun] = []
compiled_references: List[Any] = []
inferred_queries: List[Any] = []
attached_file_context = gather_raw_query_files(query_files)
@@ -960,11 +958,10 @@ async def chat(
last_message = conversation.messages[-1]
online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
operator_results = last_message.operatorContext or {}
compiled_references = [ref.model_dump() for ref in last_message.context or []]
research_results = [
InformationCollectionIteration(**iter_dict) for iter_dict in last_message.researchContext or []
]
research_results = [ResearchIteration(**iter_dict) for iter_dict in last_message.researchContext or []]
operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []]
train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []]
# Drop the interrupted message from conversation history
meta_log["chat"].pop()
logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
@@ -1009,12 +1006,12 @@ async def chat(
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
if conversation_commands == [ConversationCommand.Research]:
async for research_result in execute_information_collection(
async for research_result in research(
user=user,
query=defiltered_query,
conversation_id=conversation_id,
conversation_history=meta_log,
previous_iterations=research_results,
previous_iterations=list(research_results),
query_images=uploaded_images,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
@@ -1025,7 +1022,7 @@ async def chat(
tracer=tracer,
cancellation_event=cancellation_event,
):
if isinstance(research_result, InformationCollectionIteration):
if isinstance(research_result, ResearchIteration):
if research_result.summarizedResult:
if research_result.onlineContext:
online_results.update(research_result.onlineContext)
@@ -1033,13 +1030,26 @@ async def chat(
code_results.update(research_result.codeContext)
if research_result.context:
compiled_references.extend(research_result.context)
if research_result.operatorContext:
operator_results.update(research_result.operatorContext)
if not research_results or research_results[-1] is not research_result:
research_results.append(research_result)
else:
yield research_result
# Track operator results across research and operator iterations
# This relies on two conditions:
# 1. Check to append new (partial) operator results
# Relies on triggering this check on every status update.
# Status updates cascade up from operator to research to chat api on every step.
# 2. Keep operator results in sync with each research operator step
# Relies on python object references to ensure operator results
# are implicitly kept in sync after the initial append
if (
research_results
and research_results[-1].operatorContext
and (not operator_results or operator_results[-1] is not research_results[-1].operatorContext)
):
operator_results.append(research_results[-1].operatorContext)
# researched_results = await extract_relevant_info(q, researched_results, agent)
if state.verbose > 1:
logger.debug(f'Researched Results: {"".join(r.summarizedResult for r in research_results)}')
@@ -1292,11 +1302,12 @@ async def chat(
)
if ConversationCommand.Operator in conversation_commands:
try:
async for result in operate_browser(
async for result in operate_environment(
defiltered_query,
user,
meta_log,
location,
list(operator_results)[-1] if operator_results else None,
query_images=uploaded_images,
query_files=attached_file_context,
send_status_func=partial(send_event, ChatEvent.STATUS),
@@ -1306,16 +1317,17 @@ async def chat(
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
operator_results = {result["query"]: result["result"]}
elif isinstance(result, OperatorRun):
if not operator_results or operator_results[-1] is not result:
operator_results.append(result)
# Add webpages visited while operating browser to references
if result.get("webpages"):
if result.webpages:
if not online_results.get(defiltered_query):
online_results[defiltered_query] = {"webpages": result["webpages"]}
online_results[defiltered_query] = {"webpages": result.webpages}
elif not online_results[defiltered_query].get("webpages"):
online_results[defiltered_query]["webpages"] = result["webpages"]
online_results[defiltered_query]["webpages"] = result.webpages
else:
online_results[defiltered_query]["webpages"] += result["webpages"]
online_results[defiltered_query]["webpages"] += result.webpages
except ValueError as e:
program_execution_context.append(f"Browser operation error: {e}")
logger.warning(f"Failed to operate browser with {e}", exc_info=True)
@@ -1333,7 +1345,6 @@ async def chat(
"context": compiled_references,
"onlineContext": unique_online_results,
"codeContext": code_results,
"operatorContext": operator_results,
},
):
yield result

View File

@@ -94,7 +94,8 @@ from khoj.processor.conversation.openai.gpt import (
)
from khoj.processor.conversation.utils import (
ChatEvent,
InformationCollectionIteration,
OperatorRun,
ResearchIteration,
ResponseWithThought,
clean_json,
clean_mermaidjs,
@@ -385,7 +386,7 @@ async def aget_data_sources_and_output_format(
if len(agent_outputs) == 0 or output.value in agent_outputs:
output_options_str += f'- "{output.value}": "{description}"\n'
chat_history = construct_chat_history(conversation_history)
chat_history = construct_chat_history(conversation_history, n=6)
if query_images:
query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
@@ -1174,12 +1175,7 @@ async def send_message_to_model_wrapper(
if vision_available and query_images:
logger.info(f"Using {chat_model.name} model to understand {len(query_images)} images.")
subscribed = await ais_user_subscribed(user) if user else False
max_tokens = (
chat_model.subscribed_max_prompt_size
if subscribed and chat_model.subscribed_max_prompt_size
else chat_model.max_prompt_size
)
max_tokens = await ConversationAdapters.aget_max_context_size(chat_model, user)
chat_model_name = chat_model.name
tokenizer = chat_model.tokenizer
model_type = chat_model.model_type
@@ -1271,12 +1267,7 @@ def send_message_to_model_wrapper_sync(
if chat_model is None:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
subscribed = is_user_subscribed(user) if user else False
max_tokens = (
chat_model.subscribed_max_prompt_size
if subscribed and chat_model.subscribed_max_prompt_size
else chat_model.max_prompt_size
)
max_tokens = ConversationAdapters.get_max_context_size(chat_model, user)
chat_model_name = chat_model.name
model_type = chat_model.model_type
vision_available = chat_model.vision_enabled
@@ -1355,8 +1346,8 @@ async def agenerate_chat_response(
compiled_references: List[Dict] = [],
online_results: Dict[str, Dict] = {},
code_results: Dict[str, Dict] = {},
operator_results: Dict[str, str] = {},
research_results: List[InformationCollectionIteration] = [],
operator_results: List[OperatorRun] = [],
research_results: List[ResearchIteration] = [],
inferred_queries: List[str] = [],
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
user: KhojUser = None,
@@ -1414,7 +1405,7 @@ async def agenerate_chat_response(
compiled_references = []
online_results = {}
code_results = {}
operator_results = {}
operator_results = []
deepthought = True
chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed)

View File

@@ -13,12 +13,13 @@ from khoj.database.adapters import AgentAdapters, EntryAdapters
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
InformationCollectionIteration,
OperatorRun,
ResearchIteration,
construct_iteration_history,
construct_tool_chat_history,
load_complex_json,
)
from khoj.processor.operator.operate_browser import operate_browser
from khoj.processor.operator import operate_environment
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.processor.tools.run_code import run_code
from khoj.routers.api import extract_references_and_questions
@@ -83,7 +84,7 @@ async def apick_next_tool(
location: LocationData = None,
user_name: str = None,
agent: Agent = None,
previous_iterations: List[InformationCollectionIteration] = [],
previous_iterations: List[ResearchIteration] = [],
max_iterations: int = 5,
query_images: List[str] = [],
query_files: str = None,
@@ -95,6 +96,24 @@ async def apick_next_tool(
):
"""Given a query, determine which of the available tools the agent should use in order to answer appropriately."""
# Continue with previous iteration if a multi-step tool use is in progress
if (
previous_iterations
and previous_iterations[-1].tool == ConversationCommand.Operator
and not previous_iterations[-1].summarizedResult
):
previous_iteration = previous_iterations[-1]
yield ResearchIteration(
tool=previous_iteration.tool,
query=query,
context=previous_iteration.context,
onlineContext=previous_iteration.onlineContext,
codeContext=previous_iteration.codeContext,
operatorContext=previous_iteration.operatorContext,
warning=previous_iteration.warning,
)
return
# Construct tool options for the agent to choose from
tool_options = dict()
tool_options_str = ""
@@ -165,7 +184,7 @@ async def apick_next_tool(
)
except Exception as e:
logger.error(f"Failed to infer information sources to refer: {e}", exc_info=True)
yield InformationCollectionIteration(
yield ResearchIteration(
tool=None,
query=None,
warning="Failed to infer information sources to refer. Skipping iteration. Try again.",
@@ -194,26 +213,26 @@ async def apick_next_tool(
async for event in send_status_func(f"{scratchpad}"):
yield {ChatEvent.STATUS: event}
yield InformationCollectionIteration(
yield ResearchIteration(
tool=selected_tool,
query=generated_query,
warning=warning,
)
except Exception as e:
logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
yield InformationCollectionIteration(
yield ResearchIteration(
tool=None,
query=None,
warning=f"Invalid response for determining relevant tools: {response}. Skipping iteration. Fix error: {e}",
)
async def execute_information_collection(
async def research(
user: KhojUser,
query: str,
conversation_id: str,
conversation_history: dict,
previous_iterations: List[InformationCollectionIteration],
previous_iterations: List[ResearchIteration],
query_images: List[str],
agent: Agent = None,
send_status_func: Optional[Callable] = None,
@@ -248,9 +267,9 @@ async def execute_information_collection(
online_results: Dict = dict()
code_results: Dict = dict()
document_results: List[Dict[str, str]] = []
operator_results: Dict[str, str] = {}
operator_results: OperatorRun = None
summarize_files: str = ""
this_iteration = InformationCollectionIteration(tool=None, query=query)
this_iteration = ResearchIteration(tool=None, query=query)
async for result in apick_next_tool(
query,
@@ -271,8 +290,9 @@ async def execute_information_collection(
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
elif isinstance(result, InformationCollectionIteration):
elif isinstance(result, ResearchIteration):
this_iteration = result
yield this_iteration
# Skip running iteration if warning present in iteration
if this_iteration.warning:
@@ -417,12 +437,13 @@ async def execute_information_collection(
elif this_iteration.tool == ConversationCommand.Operator:
try:
async for result in operate_browser(
async for result in operate_environment(
this_iteration.query,
user,
construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
location,
send_status_func,
previous_iterations[-1].operatorContext if previous_iterations else None,
send_status_func=send_status_func,
query_images=query_images,
agent=agent,
query_files=query_files,
@@ -431,17 +452,17 @@ async def execute_information_collection(
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
operator_results = {result["query"]: result["result"]}
elif isinstance(result, OperatorRun):
operator_results = result
this_iteration.operatorContext = operator_results
# Add webpages visited while operating browser to references
if result.get("webpages"):
if result.webpages:
if not online_results.get(this_iteration.query):
online_results[this_iteration.query] = {"webpages": result["webpages"]}
online_results[this_iteration.query] = {"webpages": result.webpages}
elif not online_results[this_iteration.query].get("webpages"):
online_results[this_iteration.query]["webpages"] = result["webpages"]
online_results[this_iteration.query]["webpages"] = result.webpages
else:
online_results[this_iteration.query]["webpages"] += result["webpages"]
online_results[this_iteration.query]["webpages"] += result.webpages
this_iteration.onlineContext = online_results
except Exception as e:
this_iteration.warning = f"Error operating browser: {e}"
@@ -489,7 +510,9 @@ async def execute_information_collection(
if code_results:
results_data += f"\n<code_results>\n{yaml.dump(truncate_code_context(code_results), allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</code_results>"
if operator_results:
results_data += f"\n<browser_operator_results>\n{next(iter(operator_results.values()))}\n</browser_operator_results>"
results_data += (
f"\n<browser_operator_results>\n{operator_results.response}\n</browser_operator_results>"
)
if summarize_files:
results_data += f"\n<summarized_files>\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</summarized_files>"
if this_iteration.warning:

View File

@@ -18,8 +18,8 @@ default_offline_chat_models = [
"bartowski/Qwen2.5-14B-Instruct-GGUF",
]
default_openai_chat_models = ["gpt-4o-mini", "gpt-4.1"]
default_gemini_chat_models = ["gemini-2.0-flash", "gemini-2.5-flash-preview-04-17", "gemini-2.5-pro-preview-03-25"]
default_anthropic_chat_models = ["claude-3-7-sonnet-latest", "claude-3-5-haiku-latest"]
default_gemini_chat_models = ["gemini-2.0-flash", "gemini-2.5-flash-preview-05-20", "gemini-2.5-pro-preview-05-06"]
default_anthropic_chat_models = ["claude-sonnet-4-0", "claude-3-5-haiku-latest"]
empty_config = {
"search-type": {
@@ -63,10 +63,10 @@ model_to_cost: Dict[str, Dict[str, float]] = {
"claude-3-7-sonnet-20250219": {"input": 3.0, "output": 15.0, "cache_read": 0.3, "cache_write": 3.75},
"claude-3-7-sonnet@20250219": {"input": 3.0, "output": 15.0, "cache_read": 0.3, "cache_write": 3.75},
"claude-3-7-sonnet-latest": {"input": 3.0, "output": 15.0, "cache_read": 0.3, "cache_write": 3.75},
"claude-sonnet-4": {"input": 3.0, "output": 15.0, "cache_read": 0.3, "cache_write": 3.75},
"claude-sonnet-4-0": {"input": 3.0, "output": 15.0, "cache_read": 0.3, "cache_write": 3.75},
"claude-sonnet-4-20250514": {"input": 3.0, "output": 15.0, "cache_read": 0.3, "cache_write": 3.75},
"claude-sonnet-4@20250514": {"input": 3.0, "output": 15.0, "cache_read": 0.3, "cache_write": 3.75},
"claude-opus-4": {"input": 15.0, "output": 75.0, "cache_read": 1.50, "cache_write": 18.75},
"claude-opus-4-0": {"input": 15.0, "output": 75.0, "cache_read": 1.50, "cache_write": 18.75},
"claude-opus-4-20250514": {"input": 15.0, "output": 75.0, "cache_read": 1.50, "cache_write": 18.75},
"claude-opus-4@20250514": {"input": 15.0, "output": 75.0, "cache_read": 1.50, "cache_write": 18.75},
# Grok pricing: https://docs.x.ai/docs/models

View File

@@ -46,6 +46,7 @@ if TYPE_CHECKING:
from khoj.utils.models import BaseEncoder
from khoj.utils.rawconfig import AppConfig
logger = logging.getLogger(__name__)
# Initialize Magika for file type identification
magika = Magika()
@@ -364,7 +365,7 @@ command_descriptions = {
ConversationCommand.Summarize: "Get help with a question pertaining to an entire document.",
ConversationCommand.Diagram: "Draw a flowchart, diagram, or any other visual representation best expressed with primitives like lines, rectangles, and text.",
ConversationCommand.Research: "Do deep research on a topic. This will take longer than usual, but give a more detailed, comprehensive answer.",
ConversationCommand.Operator: "Operate and perform tasks using a GUI web browser.",
ConversationCommand.Operator: "Operate and perform tasks using a computer.",
}
command_descriptions_for_agent = {
@@ -373,12 +374,12 @@ command_descriptions_for_agent = {
ConversationCommand.Online: "Agent can search the internet for information.",
ConversationCommand.Webpage: "Agent can read suggested web pages for information.",
ConversationCommand.Research: "Agent can do deep research on a topic.",
ConversationCommand.Code: "Agent can run Python code to parse information, run complex calculations, create documents and charts.",
ConversationCommand.Operator: "Agent can operate and perform actions using a GUI web browser to complete a task.",
ConversationCommand.Code: "Agent can run a Python script to parse information, run complex calculations, create documents and charts.",
ConversationCommand.Operator: "Agent can operate a computer to complete tasks.",
}
e2b_tool_description = "To run Python code in a E2B sandbox with no network access. Helpful to parse complex information, run calculations, create text documents and create charts with quantitative data. Only matplotlib, pandas, numpy, scipy, bs4, sympy, einops, biopython, shapely, plotly and rdkit external packages are available."
terrarium_tool_description = "To run Python code in a Terrarium, Pyodide sandbox with no network access. Helpful to parse complex information, run complex calculations, create plaintext documents and create charts with quantitative data. Only matplotlib, panda, numpy, scipy, bs4 and sympy external packages are available."
e2b_tool_description = "To run a Python script in a E2B sandbox with no network access. Helpful to parse complex information, run calculations, create text documents and create charts with quantitative data. Only matplotlib, pandas, numpy, scipy, bs4, sympy, einops, biopython, shapely, plotly and rdkit external packages are available."
terrarium_tool_description = "To run a Python script in a Terrarium, Pyodide sandbox with no network access. Helpful to parse complex information, run complex calculations, create plaintext documents and create charts with quantitative data. Only matplotlib, panda, numpy, scipy, bs4 and sympy external packages are available."
tool_descriptions_for_llm = {
ConversationCommand.Default: "To use a mix of your internal knowledge and the user's personal knowledge, or if you don't entirely understand the query.",
@@ -387,7 +388,7 @@ tool_descriptions_for_llm = {
ConversationCommand.Online: "To search for the latest, up-to-date information from the internet. Note: **Questions about Khoj should always use this data source**",
ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.",
ConversationCommand.Code: e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description,
ConversationCommand.Operator: "To use when you need to operate and take actions using a GUI web browser.",
ConversationCommand.Operator: "To use when you need to operate a computer to complete the task.",
}
tool_description_for_research_llm = {
@@ -396,7 +397,7 @@ tool_description_for_research_llm = {
ConversationCommand.Webpage: "To extract information from webpages. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share upto {max_webpages_to_read} webpage links and what information to extract from them in your query.",
ConversationCommand.Code: e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description,
ConversationCommand.Text: "To respond to the user once you've completed your research and have the required information.",
ConversationCommand.Operator: "To operate and take actions using a GUI web browser.",
ConversationCommand.Operator: "To operate a computer to complete the task.",
}
mode_descriptions_for_llm = {
@@ -493,13 +494,7 @@ def is_promptrace_enabled():
def is_operator_enabled():
"""Check if Khoj can operate GUI applications.
Set KHOJ_OPERATOR_ENABLED env var to true to enable it."""
try:
import playwright
is_playwright_installed = True
except ImportError:
is_playwright_installed = False
return is_env_var_true("KHOJ_OPERATOR_ENABLED") and is_playwright_installed
return is_env_var_true("KHOJ_OPERATOR_ENABLED")
def is_valid_url(url: str) -> bool:
@@ -686,7 +681,7 @@ def get_chat_usage_metrics(
"cache_write_tokens": 0,
"cost": 0.0,
}
return {
current_usage = {
"input_tokens": prev_usage["input_tokens"] + input_tokens,
"output_tokens": prev_usage["output_tokens"] + output_tokens,
"thought_tokens": prev_usage.get("thought_tokens", 0) + thought_tokens,
@@ -703,6 +698,8 @@ def get_chat_usage_metrics(
prev_cost=prev_usage["cost"],
),
}
logger.debug(f"AI API usage by {model_name}: {current_usage}")
return current_usage
class AiApiInfo(NamedTuple):