mirror of
https://github.com/khoaliber/khoj.git
synced 2026-04-19 17:14:35 +00:00
Operate Computer with Khoj Operator (#1190)
## Summary - Enable Khoj to operate computers: Add experimental computer operator functionality that allows Khoj to interact with desktop environments, browsers, and terminals to accomplish complex tasks - Multi-environment support: Implement computer environments with GUI, file system, and terminal access. Can control host computer or Docker container computer ## Key Features ### Computer Operation Capabilities - Desktop control (screenshots, clicking, typing, keyboard shortcuts) - File editing and management - Terminal/bash command execution - Web browser automation - Visual feedback via train-of-thought video playback ### Infrastructure & Architecture: - Docker container (ghcr.io/khoj-ai/computer:latest) with Ubuntu 24.04, XFCE desktop, VNC access - Local computer environment support with pyautogui - Modular operator agent system supporting multiple environment types - Trajectory compression and context management for long-running tasks ### Model Integration: - Anthropic models only (Claude Sonnet 4, Claude 3.7 Sonnet, Claude Opus 4) - OpenAI and binary operator agents temporarily disabled - Enhanced caching and context management for operator conversations ### User Experience: - `/operator` command or just ask Khoj to use operator tool to invoke computer operation - Integrate with research mode for extended 30+ minute task execution - Video of computer operation in train of thought for transparency ### Configuration - Set `KHOJ_OPERATOR_ENABLED=True` in `docker-compose.yml` - Requires Anthropic API key - Computer container runs on port 5900 (VNC)
This commit is contained in:
29
.github/workflows/dockerize.yml
vendored
29
.github/workflows/dockerize.yml
vendored
@@ -12,6 +12,7 @@ on:
|
||||
- pyproject.toml
|
||||
- Dockerfile
|
||||
- prod.Dockerfile
|
||||
- computer.Dockerfile
|
||||
- docker-compose.yml
|
||||
- .github/workflows/dockerize.yml
|
||||
workflow_dispatch:
|
||||
@@ -27,6 +28,10 @@ on:
|
||||
description: 'Build Khoj cloud docker image'
|
||||
type: boolean
|
||||
default: true
|
||||
khoj-computer:
|
||||
description: 'Build computer for Khoj'
|
||||
type: boolean
|
||||
default: true
|
||||
|
||||
env:
|
||||
# Tag Image with tag name on release
|
||||
@@ -115,6 +120,21 @@ jobs:
|
||||
org.opencontainers.image.description=Khoj AI Cloud - Your second brain powered by LLMs and Neural Search
|
||||
org.opencontainers.image.source=${{ github.server_url }}/${{ github.repository }}
|
||||
|
||||
- name: 📦️️💻 Build and Push Computer for Khoj
|
||||
uses: docker/build-push-action@v4
|
||||
if: github.event_name == 'workflow_dispatch' && github.event.inputs.khoj-computer == 'true'
|
||||
with:
|
||||
context: .
|
||||
file: computer.Dockerfile
|
||||
push: true
|
||||
tags: |
|
||||
ghcr.io/${{ github.repository }}-computer:${{ env.DOCKER_IMAGE_TAG }}-${{ matrix.platform == 'linux/amd64' && 'amd64' || 'arm64' }}
|
||||
cache-from: type=gha,scope=computer-${{ matrix.platform }}
|
||||
cache-to: type=gha,mode=max,scope=computer-${{ matrix.platform }}
|
||||
labels: |
|
||||
org.opencontainers.image.description=Khoj AI Computer - A computer for your second brain to operate
|
||||
org.opencontainers.image.source=${{ github.server_url }}/${{ github.repository }}
|
||||
|
||||
manifest:
|
||||
needs: build
|
||||
runs-on: ubuntu-latest
|
||||
@@ -147,3 +167,12 @@ jobs:
|
||||
-t ghcr.io/${{ github.repository }}-cloud:${{ github.ref_type == 'tag' && 'latest' || env.DOCKER_IMAGE_TAG }} \
|
||||
ghcr.io/${{ github.repository }}-cloud:${{ env.DOCKER_IMAGE_TAG }}-amd64 \
|
||||
ghcr.io/${{ github.repository }}-cloud:${{ env.DOCKER_IMAGE_TAG }}-arm64
|
||||
|
||||
- name: Create and Push Computer Manifest
|
||||
if: github.event.inputs.khoj-computer == 'true'
|
||||
run: |
|
||||
docker buildx imagetools create \
|
||||
-t ghcr.io/${{ github.repository }}-computer:${{ env.DOCKER_IMAGE_TAG }} \
|
||||
-t ghcr.io/${{ github.repository }}-computer:${{ github.ref_type == 'tag' && 'latest' || env.DOCKER_IMAGE_TAG }} \
|
||||
ghcr.io/${{ github.repository }}-computer:${{ env.DOCKER_IMAGE_TAG }}-amd64 \
|
||||
ghcr.io/${{ github.repository }}-computer:${{ env.DOCKER_IMAGE_TAG }}-arm64
|
||||
|
||||
129
computer.Dockerfile
Normal file
129
computer.Dockerfile
Normal file
@@ -0,0 +1,129 @@
|
||||
FROM ubuntu:24.04
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install System Dependencies
|
||||
RUN apt update \
|
||||
&& apt install -y \
|
||||
ca-certificates \
|
||||
gnupg \
|
||||
xfce4 \
|
||||
xfce4-goodies \
|
||||
x11vnc \
|
||||
xvfb \
|
||||
xdotool \
|
||||
imagemagick \
|
||||
x11-apps \
|
||||
dbus-x11 \
|
||||
sudo \
|
||||
python3-pip \
|
||||
python3-tk \
|
||||
python3-dev \
|
||||
build-essential \
|
||||
scrot \
|
||||
gnome-screenshot \
|
||||
net-tools \
|
||||
libx11-dev \
|
||||
libxext-dev \
|
||||
libxtst-dev \
|
||||
libxinerama-dev \
|
||||
libxmu-dev \
|
||||
libxrandr-dev \
|
||||
libxfixes-dev \
|
||||
software-properties-common \
|
||||
&& add-apt-repository ppa:mozillateam/ppa && apt update \
|
||||
&& apt install -y --no-install-recommends \
|
||||
# Desktop apps
|
||||
firefox-esr \
|
||||
libreoffice \
|
||||
x11-apps \
|
||||
xpdf \
|
||||
gedit \
|
||||
xpaint \
|
||||
tint2 \
|
||||
galculator \
|
||||
pcmanfm \
|
||||
unzip \
|
||||
# Terminal apps like file editors, viewers, git, wget/curl etc.
|
||||
less \
|
||||
nano \
|
||||
neovim \
|
||||
vim \
|
||||
git \
|
||||
curl \
|
||||
wget \
|
||||
procps \
|
||||
# Python/pyenv dependencies
|
||||
libssl-dev \
|
||||
zlib1g-dev \
|
||||
libbz2-dev \
|
||||
libreadline-dev \
|
||||
libsqlite3-dev \
|
||||
libncursesw5-dev \
|
||||
xz-utils \
|
||||
tk-dev \
|
||||
libxml2-dev \
|
||||
libxmlsec1-dev \
|
||||
libffi-dev \
|
||||
liblzma-dev \
|
||||
# set default browser
|
||||
&& update-alternatives --set x-www-browser /usr/bin/firefox-esr \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/* \
|
||||
# remove screen locks, power managers
|
||||
&& apt remove -y light-locker xfce4-screensaver xfce4-power-manager || true
|
||||
|
||||
# Create Computer User
|
||||
ENV USERNAME=operator
|
||||
ENV HOME=/home/$USERNAME
|
||||
RUN useradd -m -s /bin/bash -d $HOME -g $USERNAME $USERNAME && echo "${USERNAME} ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers
|
||||
USER $USERNAME
|
||||
WORKDIR $HOME
|
||||
|
||||
# Setup Python
|
||||
RUN git clone https://github.com/pyenv/pyenv.git ~/.pyenv && \
|
||||
cd ~/.pyenv && src/configure && make -C src && cd .. && \
|
||||
echo 'export PYENV_ROOT="$HOME/.pyenv"' >> ~/.bashrc && \
|
||||
echo 'command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH"' >> ~/.bashrc && \
|
||||
echo 'eval "$(pyenv init -)"' >> ~/.bashrc
|
||||
ENV PYENV_ROOT="$HOME/.pyenv"
|
||||
ENV PATH="$PYENV_ROOT/bin:$PATH"
|
||||
ENV PYENV_VERSION_MAJOR=3
|
||||
ENV PYENV_VERSION_MINOR=11
|
||||
ENV PYENV_VERSION_PATCH=6
|
||||
ENV PYENV_VERSION=$PYENV_VERSION_MAJOR.$PYENV_VERSION_MINOR.$PYENV_VERSION_PATCH
|
||||
RUN eval "$(pyenv init -)" && \
|
||||
pyenv install $PYENV_VERSION && \
|
||||
pyenv global $PYENV_VERSION && \
|
||||
pyenv rehash
|
||||
ENV PATH="$HOME/.pyenv/shims:$HOME/.pyenv/bin:$PATH"
|
||||
|
||||
# Install Python Packages
|
||||
RUN python3 -m pip install --no-cache-dir \
|
||||
pyautogui \
|
||||
Pillow \
|
||||
pyperclip \
|
||||
pygetwindow
|
||||
|
||||
# Setup VNC
|
||||
RUN x11vnc -storepasswd secret /home/operator/.vncpass
|
||||
|
||||
ARG WIDTH=1024
|
||||
ARG HEIGHT=768
|
||||
ARG DISPLAY_NUM=99
|
||||
ENV WIDTH=$WIDTH
|
||||
ENV HEIGHT=$HEIGHT
|
||||
ENV DISPLAY_NUM=$DISPLAY_NUM
|
||||
ENV DISPLAY=":$DISPLAY_NUM"
|
||||
|
||||
# Expose VNC on port 5900
|
||||
# run Xvfb, x11vnc, Xfce (no login manager)
|
||||
EXPOSE 5900
|
||||
CMD ["/bin/sh", "-c", " export XDG_RUNTIME_DIR=/run/user/$(id -u); \
|
||||
mkdir -p $XDG_RUNTIME_DIR && chown $USERNAME:$USERNAME $XDG_RUNTIME_DIR && chmod 0700 $XDG_RUNTIME_DIR; \
|
||||
Xvfb $DISPLAY -screen 0 ${WIDTH}x${HEIGHT}x24 -dpi 96 -auth /home/$USERNAME/.Xauthority >/dev/null 2>&1 & \
|
||||
sleep 1; \
|
||||
xauth add $DISPLAY . $(mcookie); \
|
||||
x11vnc -display $DISPLAY -forever -rfbauth /home/$USERNAME/.vncpass -listen 0.0.0.0 -rfbport 5900 >/dev/null 2>&1 & \
|
||||
eval $(dbus-launch --sh-syntax) && \
|
||||
startxfce4 & \
|
||||
sleep 2 && echo 'Container running!' && \
|
||||
tail -f /dev/null "]
|
||||
@@ -25,6 +25,14 @@ services:
|
||||
- khoj_search:/etc/searxng
|
||||
environment:
|
||||
- SEARXNG_BASE_URL=http://localhost:8080/
|
||||
# Creates Computer for Khoj to use.
|
||||
# Set KHOJ_OPERATOR_ENABLED=True in the server service environment variable to enable.
|
||||
computer:
|
||||
image: ghcr.io/khoj-ai/computer:latest
|
||||
ports:
|
||||
- "5900:5900"
|
||||
volumes:
|
||||
- khoj_computer:/home/operator
|
||||
server:
|
||||
depends_on:
|
||||
database:
|
||||
@@ -75,6 +83,8 @@ services:
|
||||
# - GEMINI_API_KEY=your_gemini_api_key
|
||||
# - ANTHROPIC_API_KEY=your_anthropic_api_key
|
||||
#
|
||||
# Uncomment line below to enable Khoj to use its computer.
|
||||
# - KHOJ_OPERATOR_ENABLED=True
|
||||
# Uncomment appropriate lines below to enable web results with Khoj
|
||||
# Ensure you set your provider specific API keys.
|
||||
# ---
|
||||
@@ -112,3 +122,4 @@ volumes:
|
||||
khoj_db:
|
||||
khoj_models:
|
||||
khoj_search:
|
||||
khoj_computer:
|
||||
|
||||
@@ -114,6 +114,7 @@ prod = [
|
||||
local = [
|
||||
"pgserver == 0.1.4",
|
||||
"playwright >= 1.49.0",
|
||||
"pyautogui == 0.9.54",
|
||||
]
|
||||
dev = [
|
||||
"khoj[prod,local]",
|
||||
|
||||
@@ -8,7 +8,9 @@ import ChatMessage, {
|
||||
ChatHistoryData,
|
||||
StreamMessage,
|
||||
TrainOfThought,
|
||||
TrainOfThoughtObject,
|
||||
} from "../chatMessage/chatMessage";
|
||||
import TrainOfThoughtVideoPlayer from "../../../components/trainOfThoughtVideoPlayer/trainOfThoughtVideoPlayer";
|
||||
|
||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
|
||||
@@ -41,17 +43,108 @@ interface ChatHistoryProps {
|
||||
setIsOwner?: (isOwner: boolean) => void;
|
||||
}
|
||||
|
||||
interface TrainOfThoughtFrame {
|
||||
text: string;
|
||||
image?: string;
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
interface TrainOfThoughtGroup {
|
||||
type: 'video' | 'text';
|
||||
frames?: TrainOfThoughtFrame[];
|
||||
textEntries?: TrainOfThoughtObject[];
|
||||
}
|
||||
|
||||
interface TrainOfThoughtComponentProps {
|
||||
trainOfThought: string[];
|
||||
trainOfThought: string[] | TrainOfThoughtObject[];
|
||||
lastMessage: boolean;
|
||||
agentColor: string;
|
||||
keyId: string;
|
||||
completed?: boolean;
|
||||
}
|
||||
|
||||
function extractTrainOfThoughtGroups(trainOfThought?: TrainOfThoughtObject[]): TrainOfThoughtGroup[] {
|
||||
if (!trainOfThought) return [];
|
||||
|
||||
const groups: TrainOfThoughtGroup[] = [];
|
||||
let currentVideoFrames: TrainOfThoughtFrame[] = [];
|
||||
let currentTextEntries: TrainOfThoughtObject[] = [];
|
||||
|
||||
trainOfThought.forEach((thought, index) => {
|
||||
let text = thought.data;
|
||||
let hasImage = false;
|
||||
|
||||
// Extract screenshot image from the thought data
|
||||
try {
|
||||
const jsonMatch = text.match(
|
||||
/\{.*(\"action\": \"screenshot\"|\"type\": \"screenshot\"|\"image\": \"data:image\/.*\").*\}/,
|
||||
);
|
||||
if (jsonMatch) {
|
||||
const jsonMessage = JSON.parse(jsonMatch[0]);
|
||||
if (jsonMessage.image) {
|
||||
hasImage = true;
|
||||
// Clean up the text to remove the JSON action
|
||||
text = text.replace(`:\n**Action**: ${jsonMatch[0]}`, "");
|
||||
if (jsonMessage.text) {
|
||||
text += `\n\n${jsonMessage.text}`;
|
||||
}
|
||||
|
||||
// If we have accumulated text entries, add them as a text group
|
||||
if (currentTextEntries.length > 0) {
|
||||
groups.push({
|
||||
type: 'text',
|
||||
textEntries: [...currentTextEntries]
|
||||
});
|
||||
currentTextEntries = [];
|
||||
}
|
||||
|
||||
// Add to current video frames
|
||||
currentVideoFrames.push({
|
||||
text: text,
|
||||
image: jsonMessage.image,
|
||||
timestamp: index,
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error("Failed to parse screenshot data", e);
|
||||
}
|
||||
|
||||
if (!hasImage) {
|
||||
// If we have accumulated video frames, add them as a video group
|
||||
if (currentVideoFrames.length > 0) {
|
||||
groups.push({
|
||||
type: 'video',
|
||||
frames: [...currentVideoFrames]
|
||||
});
|
||||
currentVideoFrames = [];
|
||||
}
|
||||
|
||||
// Add to current text entries
|
||||
currentTextEntries.push(thought);
|
||||
}
|
||||
});
|
||||
|
||||
// Add any remaining frames/entries
|
||||
if (currentVideoFrames.length > 0) {
|
||||
groups.push({
|
||||
type: 'video',
|
||||
frames: currentVideoFrames
|
||||
});
|
||||
}
|
||||
if (currentTextEntries.length > 0) {
|
||||
groups.push({
|
||||
type: 'text',
|
||||
textEntries: currentTextEntries
|
||||
});
|
||||
}
|
||||
|
||||
return groups;
|
||||
}
|
||||
|
||||
function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) {
|
||||
const lastIndex = props.trainOfThought.length - 1;
|
||||
const [collapsed, setCollapsed] = useState(props.completed);
|
||||
const [trainOfThoughtGroups, setTrainOfThoughtGroups] = useState<TrainOfThoughtGroup[]>([]);
|
||||
|
||||
const variants = {
|
||||
open: {
|
||||
@@ -72,6 +165,29 @@ function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) {
|
||||
}
|
||||
}, [props.completed]);
|
||||
|
||||
useEffect(() => {
|
||||
// Handle empty array case
|
||||
if (!props.trainOfThought || props.trainOfThought.length === 0) {
|
||||
setTrainOfThoughtGroups([]);
|
||||
return;
|
||||
}
|
||||
|
||||
// Convert string array to TrainOfThoughtObject array if needed
|
||||
let trainOfThoughtObjects: TrainOfThoughtObject[];
|
||||
|
||||
if (typeof props.trainOfThought[0] === 'string') {
|
||||
trainOfThoughtObjects = (props.trainOfThought as string[]).map((data, index) => ({
|
||||
type: 'text',
|
||||
data: data
|
||||
}));
|
||||
} else {
|
||||
trainOfThoughtObjects = props.trainOfThought as TrainOfThoughtObject[];
|
||||
}
|
||||
|
||||
const groups = extractTrainOfThoughtGroups(trainOfThoughtObjects);
|
||||
setTrainOfThoughtGroups(groups);
|
||||
}, [props.trainOfThought]);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`${!collapsed ? styles.trainOfThought + " border" : ""} rounded-lg`}
|
||||
@@ -101,15 +217,31 @@ function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) {
|
||||
<AnimatePresence initial={false}>
|
||||
{!collapsed && (
|
||||
<motion.div initial="closed" animate="open" exit="closed" variants={variants}>
|
||||
{props.trainOfThought.map((train, index) => (
|
||||
<TrainOfThought
|
||||
key={`train-${index}`}
|
||||
message={train}
|
||||
primary={
|
||||
index === lastIndex && props.lastMessage && !props.completed
|
||||
}
|
||||
agentColor={props.agentColor}
|
||||
/>
|
||||
{trainOfThoughtGroups.map((group, groupIndex) => (
|
||||
<div key={`train-group-${groupIndex}`}>
|
||||
{group.type === 'video' && group.frames && group.frames.length > 0 && (
|
||||
<TrainOfThoughtVideoPlayer
|
||||
frames={group.frames}
|
||||
autoPlay={false}
|
||||
playbackSpeed={1500}
|
||||
/>
|
||||
)}
|
||||
{group.type === 'text' && group.textEntries && group.textEntries.map((entry, entryIndex) => {
|
||||
const lastIndex = trainOfThoughtGroups.length - 1;
|
||||
const isLastGroup = groupIndex === lastIndex;
|
||||
const isLastEntry = entryIndex === group.textEntries!.length - 1;
|
||||
const isPrimaryEntry = isLastGroup && isLastEntry && props.lastMessage && !props.completed;
|
||||
|
||||
return (
|
||||
<TrainOfThought
|
||||
key={`train-text-${groupIndex}-${entryIndex}-${entry.data.length}`}
|
||||
message={entry.data}
|
||||
primary={isPrimaryEntry}
|
||||
agentColor={props.agentColor}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
))}
|
||||
</motion.div>
|
||||
)}
|
||||
@@ -125,6 +257,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
||||
const [currentTurnId, setCurrentTurnId] = useState<string | null>(null);
|
||||
const sentinelRef = useRef<HTMLDivElement | null>(null);
|
||||
const scrollAreaRef = useRef<HTMLDivElement | null>(null);
|
||||
const scrollableContentWrapperRef = useRef<HTMLDivElement | null>(null);
|
||||
const latestUserMessageRef = useRef<HTMLDivElement | null>(null);
|
||||
const latestFetchedMessageRef = useRef<HTMLDivElement | null>(null);
|
||||
|
||||
@@ -151,16 +284,46 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
||||
};
|
||||
|
||||
scrollAreaEl.addEventListener("scroll", detectIsNearBottom);
|
||||
detectIsNearBottom(); // Initial check
|
||||
return () => scrollAreaEl.removeEventListener("scroll", detectIsNearBottom);
|
||||
}, []);
|
||||
}, [scrollAreaRef]);
|
||||
|
||||
// Auto scroll while incoming message is streamed
|
||||
useEffect(() => {
|
||||
if (props.incomingMessages && props.incomingMessages.length > 0 && isNearBottom) {
|
||||
scrollToBottom();
|
||||
scrollToBottom(true);
|
||||
}
|
||||
}, [props.incomingMessages, isNearBottom]);
|
||||
|
||||
// ResizeObserver to handle content height changes (e.g., images loading)
|
||||
useEffect(() => {
|
||||
const contentWrapper = scrollableContentWrapperRef.current;
|
||||
const scrollViewport = scrollAreaRef.current?.querySelector<HTMLElement>(scrollAreaSelector);
|
||||
|
||||
if (!contentWrapper || !scrollViewport) return;
|
||||
|
||||
const observer = new ResizeObserver(() => {
|
||||
// Check current scroll position to decide if auto-scroll is warranted
|
||||
const { scrollTop, scrollHeight, clientHeight } = scrollViewport;
|
||||
const bottomThreshold = 50;
|
||||
const currentlyNearBottom = (scrollHeight - (scrollTop + clientHeight)) <= bottomThreshold;
|
||||
|
||||
if (currentlyNearBottom) {
|
||||
// Only auto-scroll if there are incoming messages being processed
|
||||
if (props.incomingMessages && props.incomingMessages.length > 0) {
|
||||
const lastMessage = props.incomingMessages[props.incomingMessages.length - 1];
|
||||
// If the last message is not completed, or it just completed (indicated by incompleteIncomingMessageIndex still being set)
|
||||
if (!lastMessage.completed || (lastMessage.completed && incompleteIncomingMessageIndex !== null)) {
|
||||
scrollToBottom(true); // Use instant scroll
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
observer.observe(contentWrapper);
|
||||
return () => observer.disconnect();
|
||||
}, [props.incomingMessages, incompleteIncomingMessageIndex, scrollAreaRef]); // Dependencies
|
||||
|
||||
// Scroll to most recent user message after the first page of chat messages is loaded.
|
||||
useEffect(() => {
|
||||
if (data && data.chat && data.chat.length > 0 && currentPage < 2) {
|
||||
@@ -297,7 +460,10 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
||||
behavior: instant ? "auto" : "smooth",
|
||||
});
|
||||
});
|
||||
setIsNearBottom(true);
|
||||
// Optimistically set, the scroll listener will verify
|
||||
if (instant || scrollAreaEl && (scrollAreaEl.scrollHeight - (scrollAreaEl.scrollTop + scrollAreaEl.clientHeight)) < 5) {
|
||||
setIsNearBottom(true);
|
||||
}
|
||||
};
|
||||
|
||||
function constructAgentLink() {
|
||||
@@ -356,7 +522,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
||||
`}
|
||||
ref={scrollAreaRef}
|
||||
>
|
||||
<div>
|
||||
<div ref={scrollableContentWrapperRef}>
|
||||
<div className={`${styles.chatHistory} ${props.customClassName}`}>
|
||||
<div ref={sentinelRef} style={{ height: "1px" }}>
|
||||
{fetchingData && <InlineLoading className="opacity-50" />}
|
||||
@@ -367,9 +533,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
||||
<React.Fragment key={`chatMessage-${index}`}>
|
||||
{chatMessage.trainOfThought && chatMessage.by === "khoj" && (
|
||||
<TrainOfThoughtComponent
|
||||
trainOfThought={chatMessage.trainOfThought?.map(
|
||||
(train) => train.data,
|
||||
)}
|
||||
trainOfThought={chatMessage.trainOfThought}
|
||||
lastMessage={false}
|
||||
agentColor={data?.agent?.color || "orange"}
|
||||
key={`${index}trainOfThought`}
|
||||
@@ -428,12 +592,12 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
||||
conversationId={props.conversationId}
|
||||
turnId={messageTurnId}
|
||||
/>
|
||||
{message.trainOfThought && (
|
||||
{message.trainOfThought && message.trainOfThought.length > 0 && (
|
||||
<TrainOfThoughtComponent
|
||||
trainOfThought={message.trainOfThought}
|
||||
lastMessage={index === incompleteIncomingMessageIndex}
|
||||
agentColor={data?.agent?.color || "orange"}
|
||||
key={`${index}trainOfThought`}
|
||||
key={`${index}trainOfThought-${message.trainOfThought.length}-${message.trainOfThought.map(t => t.length).join('-')}`}
|
||||
keyId={`${index}trainOfThought`}
|
||||
completed={message.completed}
|
||||
/>
|
||||
@@ -519,7 +683,6 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
||||
className="absolute bottom-0 right-0 bg-white dark:bg-[hsl(var(--background))] text-neutral-500 dark:text-white p-2 rounded-full shadow-xl"
|
||||
onClick={() => {
|
||||
scrollToBottom();
|
||||
setIsNearBottom(true);
|
||||
}}
|
||||
>
|
||||
<ArrowDown size={24} />
|
||||
|
||||
@@ -144,7 +144,7 @@ interface Intent {
|
||||
"inferred-queries": string[];
|
||||
}
|
||||
|
||||
interface TrainOfThoughtObject {
|
||||
export interface TrainOfThoughtObject {
|
||||
type: string;
|
||||
data: string;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
.videoPlayer {
|
||||
border: 1px solid hsl(var(--border));
|
||||
border-radius: 8px;
|
||||
background-color: hsl(var(--background));
|
||||
margin: 16px 0;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.screen {
|
||||
position: relative;
|
||||
background-color: hsl(var(--muted));
|
||||
min-height: 300px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.screenImage {
|
||||
max-width: 100%;
|
||||
max-height: 400px;
|
||||
object-fit: contain;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.textOverlay {
|
||||
position: absolute;
|
||||
bottom: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
background: linear-gradient(transparent, rgba(0, 0, 0, 0.8));
|
||||
padding: 20px 16px 12px;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.thoughtText {
|
||||
font-size: 14px;
|
||||
line-height: 1.4;
|
||||
max-height: 100px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.controls {
|
||||
padding: 12px 16px;
|
||||
background-color: hsl(var(--card));
|
||||
border-top: 1px solid hsl(var(--border));
|
||||
}
|
||||
|
||||
.timeline {
|
||||
position: relative;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
|
||||
.timelineSlider {
|
||||
width: 100%;
|
||||
height: 4px;
|
||||
background-color: hsl(var(--muted));
|
||||
border-radius: 2px;
|
||||
outline: none;
|
||||
cursor: pointer;
|
||||
-webkit-appearance: none;
|
||||
appearance: none;
|
||||
}
|
||||
|
||||
.timelineSlider::-webkit-slider-thumb {
|
||||
-webkit-appearance: none;
|
||||
appearance: none;
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
border-radius: 50%;
|
||||
background-color: hsl(var(--primary));
|
||||
cursor: pointer;
|
||||
border: 2px solid white;
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
|
||||
}
|
||||
|
||||
.timelineSlider::-moz-range-thumb {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
border-radius: 50%;
|
||||
background-color: hsl(var(--primary));
|
||||
cursor: pointer;
|
||||
border: 2px solid white;
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
|
||||
}
|
||||
|
||||
.frameMarkers {
|
||||
position: absolute;
|
||||
top: -2px;
|
||||
left: 0;
|
||||
right: 0;
|
||||
height: 8px;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
.frameMarker {
|
||||
width: 6px;
|
||||
height: 8px;
|
||||
border-radius: 1px;
|
||||
cursor: pointer;
|
||||
pointer-events: auto;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.frameMarker.hasImage {
|
||||
background-color: hsl(var(--primary));
|
||||
}
|
||||
|
||||
.frameMarker.textOnly {
|
||||
background-color: hsl(var(--muted-foreground));
|
||||
}
|
||||
|
||||
.frameMarker.active {
|
||||
background-color: hsl(var(--accent)) !important;
|
||||
transform: scaleY(1.5);
|
||||
}
|
||||
|
||||
.frameMarker:hover {
|
||||
transform: scaleY(1.2);
|
||||
}
|
||||
|
||||
.controlButtons {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.controlButton {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
width: 32px;
|
||||
height: 32px;
|
||||
border: 1px solid hsl(var(--border));
|
||||
border-radius: 4px;
|
||||
background-color: hsl(var(--background));
|
||||
color: hsl(var(--foreground));
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.controlButton:hover:not(:disabled) {
|
||||
background-color: hsl(var(--muted));
|
||||
}
|
||||
|
||||
.controlButton:disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.controlButton.active {
|
||||
background-color: hsl(var(--primary));
|
||||
color: hsl(var(--primary-foreground));
|
||||
}
|
||||
|
||||
.frameInfo {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
font-size: 12px;
|
||||
color: hsl(var(--muted-foreground));
|
||||
}
|
||||
|
||||
/* Dark mode adjustments */
|
||||
@media (prefers-color-scheme: dark) {
|
||||
.textOverlay {
|
||||
background: linear-gradient(transparent, rgba(0, 0, 0, 0.9));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
"use client";
|
||||
|
||||
import React, { useState, useRef, useEffect } from "react";
|
||||
import { Play, Pause, FastForward, Rewind } from "@phosphor-icons/react";
|
||||
import styles from "./trainOfThoughtVideoPlayer.module.css";
|
||||
|
||||
interface TrainOfThoughtFrame {
|
||||
text: string;
|
||||
image?: string;
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
interface TrainOfThoughtVideoPlayerProps {
|
||||
frames: TrainOfThoughtFrame[];
|
||||
autoPlay?: boolean;
|
||||
playbackSpeed?: number;
|
||||
}
|
||||
|
||||
export default function TrainOfThoughtVideoPlayer({
|
||||
frames,
|
||||
autoPlay = true,
|
||||
playbackSpeed = 1000, // ms per frame
|
||||
}: TrainOfThoughtVideoPlayerProps) {
|
||||
const [currentFrameIndex, setCurrentFrameIndex] = useState(0);
|
||||
const [isPlaying, setIsPlaying] = useState(autoPlay);
|
||||
const [isAutoTracking, setIsAutoTracking] = useState(true);
|
||||
const intervalRef = useRef<NodeJS.Timeout | null>(null);
|
||||
|
||||
// Auto-advance to latest frame when new frames are added
|
||||
useEffect(() => {
|
||||
if (isAutoTracking && frames.length > 0) {
|
||||
setCurrentFrameIndex(frames.length - 1);
|
||||
}
|
||||
}, [frames.length, isAutoTracking]);
|
||||
|
||||
// Handle playback
|
||||
useEffect(() => {
|
||||
if (isPlaying && frames.length > 1) {
|
||||
intervalRef.current = setInterval(() => {
|
||||
setCurrentFrameIndex((prev) => {
|
||||
const next = prev + 1;
|
||||
if (next >= frames.length) {
|
||||
setIsPlaying(false);
|
||||
return prev;
|
||||
}
|
||||
return next;
|
||||
});
|
||||
}, playbackSpeed);
|
||||
} else {
|
||||
if (intervalRef.current) {
|
||||
clearInterval(intervalRef.current);
|
||||
intervalRef.current = null;
|
||||
}
|
||||
}
|
||||
|
||||
return () => {
|
||||
if (intervalRef.current) {
|
||||
clearInterval(intervalRef.current);
|
||||
}
|
||||
};
|
||||
}, [isPlaying, frames.length, playbackSpeed]);
|
||||
|
||||
const currentFrame = frames[currentFrameIndex];
|
||||
|
||||
const handleSeek = (index: number) => {
|
||||
setCurrentFrameIndex(index);
|
||||
setIsAutoTracking(false);
|
||||
setIsPlaying(false);
|
||||
};
|
||||
|
||||
const handlePlay = () => {
|
||||
setIsPlaying(!isPlaying);
|
||||
setIsAutoTracking(false);
|
||||
};
|
||||
|
||||
const handlePrevious = () => {
|
||||
if (currentFrameIndex > 0) {
|
||||
setCurrentFrameIndex(currentFrameIndex - 1);
|
||||
setIsAutoTracking(false);
|
||||
setIsPlaying(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleNext = () => {
|
||||
if (currentFrameIndex < frames.length - 1) {
|
||||
setCurrentFrameIndex(currentFrameIndex + 1);
|
||||
setIsAutoTracking(false);
|
||||
setIsPlaying(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleAutoTrack = () => {
|
||||
setIsAutoTracking(true);
|
||||
setCurrentFrameIndex(frames.length - 1);
|
||||
setIsPlaying(false);
|
||||
};
|
||||
|
||||
if (!frames.length) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={styles.videoPlayer}>
|
||||
<div className={styles.screen}>
|
||||
{currentFrame?.image && (
|
||||
<img
|
||||
src={currentFrame.image}
|
||||
alt={`Train of thought frame ${currentFrameIndex + 1}`}
|
||||
className={styles.screenImage}
|
||||
/>
|
||||
)}
|
||||
<div className={styles.textOverlay}>
|
||||
<div className={styles.thoughtText}>{currentFrame?.text}</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className={styles.controls}>
|
||||
<div className={styles.timeline}>
|
||||
<input
|
||||
type="range"
|
||||
min={0}
|
||||
max={Math.max(0, frames.length - 1)}
|
||||
value={currentFrameIndex}
|
||||
onChange={(e) => handleSeek(parseInt(e.target.value))}
|
||||
className={styles.timelineSlider}
|
||||
/>
|
||||
<div className={styles.frameMarkers}>
|
||||
{frames.map((frame, index) => (
|
||||
<div
|
||||
key={index}
|
||||
className={`${styles.frameMarker} ${
|
||||
frame.image ? styles.hasImage : styles.textOnly
|
||||
} ${index === currentFrameIndex ? styles.active : ""}`}
|
||||
onClick={() => handleSeek(index)}
|
||||
title={`Frame ${index + 1}: ${frame.text.slice(0, 50)}...`}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className={styles.controlButtons}>
|
||||
<button
|
||||
onClick={handlePrevious}
|
||||
disabled={currentFrameIndex === 0}
|
||||
title="Previous frame"
|
||||
className={styles.controlButton}
|
||||
>
|
||||
<Rewind size={16} />
|
||||
</button>
|
||||
|
||||
<button
|
||||
onClick={handlePlay}
|
||||
disabled={frames.length <= 1}
|
||||
title={isPlaying ? "Pause" : "Play"}
|
||||
className={styles.controlButton}
|
||||
>
|
||||
{isPlaying ? <Pause size={16} /> : <Play size={16} />}
|
||||
</button>
|
||||
|
||||
<button
|
||||
onClick={handleNext}
|
||||
disabled={currentFrameIndex === frames.length - 1}
|
||||
title="Next frame"
|
||||
className={styles.controlButton}
|
||||
>
|
||||
<FastForward size={16} />
|
||||
</button>
|
||||
|
||||
<button
|
||||
onClick={handleAutoTrack}
|
||||
className={`${styles.controlButton} ${isAutoTracking ? styles.active : ""}`}
|
||||
title="Auto-track latest"
|
||||
>
|
||||
Live
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div className={styles.frameInfo}>
|
||||
<span>
|
||||
{currentFrameIndex + 1} / {frames.length}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1312,6 +1312,26 @@ class ConversationAdapters:
|
||||
else:
|
||||
ServerChatSettings.objects.create(chat_default=chat_model, chat_advanced=chat_model)
|
||||
|
||||
@staticmethod
|
||||
def get_max_context_size(chat_model: ChatModel, user: KhojUser) -> int | None:
|
||||
"""Get the max context size for the user based on the chat model."""
|
||||
subscribed = is_user_subscribed(user) if user else False
|
||||
if subscribed and chat_model.subscribed_max_prompt_size:
|
||||
max_tokens = chat_model.subscribed_max_prompt_size
|
||||
else:
|
||||
max_tokens = chat_model.max_prompt_size
|
||||
return max_tokens
|
||||
|
||||
@staticmethod
|
||||
async def aget_max_context_size(chat_model: ChatModel, user: KhojUser) -> int | None:
|
||||
"""Get the max context size for the user based on the chat model."""
|
||||
subscribed = await ais_user_subscribed(user) if user else False
|
||||
if subscribed and chat_model.subscribed_max_prompt_size:
|
||||
max_tokens = chat_model.subscribed_max_prompt_size
|
||||
else:
|
||||
max_tokens = chat_model.max_prompt_size
|
||||
return max_tokens
|
||||
|
||||
@staticmethod
|
||||
async def aget_server_webscraper():
|
||||
server_chat_settings = await ServerChatSettings.objects.filter().prefetch_related("web_scraper").afirst()
|
||||
|
||||
@@ -107,7 +107,7 @@ class ChatMessage(PydanticBaseModel):
|
||||
onlineContext: Dict[str, OnlineContext] = {}
|
||||
codeContext: Dict[str, CodeContextData] = {}
|
||||
researchContext: Optional[List] = None
|
||||
operatorContext: Optional[Dict[str, str]] = None
|
||||
operatorContext: Optional[List] = None
|
||||
created: str
|
||||
images: Optional[List[str]] = None
|
||||
queryFiles: Optional[List[Dict]] = None
|
||||
|
||||
@@ -14,8 +14,10 @@ from khoj.processor.conversation.anthropic.utils import (
|
||||
format_messages_for_anthropic,
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
OperatorRun,
|
||||
ResponseWithThought,
|
||||
clean_json,
|
||||
construct_question_history,
|
||||
construct_structured_message,
|
||||
generate_chatml_messages_with_context,
|
||||
messages_to_print,
|
||||
@@ -53,13 +55,7 @@ def extract_questions_anthropic(
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
chat_history = "".join(
|
||||
[
|
||||
f'User: {chat["intent"]["query"]}\nAssistant: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
|
||||
for chat in conversation_log.get("chat", [])[-4:]
|
||||
if chat["by"] == "khoj"
|
||||
]
|
||||
)
|
||||
chat_history = construct_question_history(conversation_log, query_prefix="User", agent_name="Assistant")
|
||||
|
||||
# Get dates relative to today for prompt creation
|
||||
today = datetime.today()
|
||||
@@ -144,7 +140,7 @@ async def converse_anthropic(
|
||||
references: list[dict],
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
code_results: Optional[Dict[str, Dict]] = None,
|
||||
operator_results: Optional[Dict[str, str]] = None,
|
||||
operator_results: Optional[List[OperatorRun]] = None,
|
||||
conversation_log={},
|
||||
model: Optional[str] = "claude-3-7-sonnet-latest",
|
||||
api_key: Optional[str] = None,
|
||||
@@ -216,8 +212,11 @@ async def converse_anthropic(
|
||||
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
|
||||
)
|
||||
if ConversationCommand.Operator in conversation_commands and not is_none_or_empty(operator_results):
|
||||
operator_content = [
|
||||
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
|
||||
]
|
||||
context_message += (
|
||||
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_results))}\n\n"
|
||||
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
|
||||
)
|
||||
context_message = context_message.strip()
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ def anthropic_completion_with_backoff(
|
||||
client = get_anthropic_client(api_key, api_base_url)
|
||||
anthropic_clients[api_key] = client
|
||||
|
||||
formatted_messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
|
||||
formatted_messages, system = format_messages_for_anthropic(messages, system_prompt)
|
||||
|
||||
aggregated_response = ""
|
||||
if response_type == "json_object" and not deepthought:
|
||||
@@ -70,8 +70,8 @@ def anthropic_completion_with_backoff(
|
||||
|
||||
final_message = None
|
||||
model_kwargs = model_kwargs or dict()
|
||||
if system_prompt:
|
||||
model_kwargs["system"] = system_prompt
|
||||
if system:
|
||||
model_kwargs["system"] = system
|
||||
|
||||
max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC
|
||||
if deepthought and is_reasoning_model(model_name):
|
||||
@@ -146,7 +146,7 @@ async def anthropic_chat_completion_with_backoff(
|
||||
# Temperature control not supported when using extended thinking
|
||||
temperature = 1.0
|
||||
|
||||
formatted_messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
|
||||
formatted_messages, system = format_messages_for_anthropic(messages, system_prompt)
|
||||
|
||||
aggregated_response = ""
|
||||
response_started = False
|
||||
@@ -156,7 +156,7 @@ async def anthropic_chat_completion_with_backoff(
|
||||
messages=formatted_messages,
|
||||
model=model_name, # type: ignore
|
||||
temperature=temperature,
|
||||
system=system_prompt,
|
||||
system=system,
|
||||
timeout=20,
|
||||
max_tokens=max_tokens,
|
||||
**model_kwargs,
|
||||
@@ -231,7 +231,10 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
|
||||
else:
|
||||
system_prompt += message.content
|
||||
messages.remove(message)
|
||||
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
|
||||
if not is_none_or_empty(system_prompt):
|
||||
system = [{"type": "text", "text": system_prompt, "cache_control": {"type": "ephemeral"}}]
|
||||
else:
|
||||
system = None
|
||||
|
||||
# Anthropic requires the first message to be a 'user' message
|
||||
if len(messages) == 1:
|
||||
@@ -274,12 +277,32 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
|
||||
logger.error(f"Drop message with empty content as not supported:\n{message}")
|
||||
messages.remove(message)
|
||||
continue
|
||||
if isinstance(message.content, str):
|
||||
message.content = [{"type": "text", "text": message.content}]
|
||||
|
||||
# Add cache control to enable prompt caching for conversations with sufficient context
|
||||
# Only add caching if we have multiple messages to make it worthwhile
|
||||
if len(messages) > 2:
|
||||
# Remove any existing cache controls from previous messages
|
||||
for message in messages: # All except the last message
|
||||
if isinstance(message.content, list):
|
||||
for block in message.content:
|
||||
if isinstance(block, dict) and "cache_control" in block:
|
||||
del block["cache_control"]
|
||||
|
||||
# Add cache control to the last content block of second to last message.
|
||||
# In research mode, this message content is list of iterations, updated after each research iteration.
|
||||
# Caching it should improve research efficiency.
|
||||
cache_message = messages[-2]
|
||||
if isinstance(cache_message.content, list) and cache_message.content:
|
||||
# Add cache control to the last content block
|
||||
cache_message.content[-1]["cache_control"] = {"type": "ephemeral"}
|
||||
|
||||
formatted_messages: List[anthropic.types.MessageParam] = [
|
||||
anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
|
||||
]
|
||||
|
||||
return formatted_messages, system_prompt
|
||||
return formatted_messages, system
|
||||
|
||||
|
||||
def is_reasoning_model(model_name: str) -> bool:
|
||||
|
||||
@@ -14,7 +14,9 @@ from khoj.processor.conversation.google.utils import (
|
||||
gemini_completion_with_backoff,
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
OperatorRun,
|
||||
clean_json,
|
||||
construct_question_history,
|
||||
construct_structured_message,
|
||||
generate_chatml_messages_with_context,
|
||||
messages_to_print,
|
||||
@@ -53,13 +55,7 @@ def extract_questions_gemini(
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
chat_history = "".join(
|
||||
[
|
||||
f'User: {chat["intent"]["query"]}\nAssistant: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
|
||||
for chat in conversation_log.get("chat", [])[-4:]
|
||||
if chat["by"] == "khoj"
|
||||
]
|
||||
)
|
||||
chat_history = construct_question_history(conversation_log, query_prefix="User", agent_name="Assistant")
|
||||
|
||||
# Get dates relative to today for prompt creation
|
||||
today = datetime.today()
|
||||
@@ -166,7 +162,7 @@ async def converse_gemini(
|
||||
references: list[dict],
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
code_results: Optional[Dict[str, Dict]] = None,
|
||||
operator_results: Optional[Dict[str, str]] = None,
|
||||
operator_results: Optional[List[OperatorRun]] = None,
|
||||
conversation_log={},
|
||||
model: Optional[str] = "gemini-2.0-flash",
|
||||
api_key: Optional[str] = None,
|
||||
@@ -240,8 +236,11 @@ async def converse_gemini(
|
||||
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
|
||||
)
|
||||
if ConversationCommand.Operator in conversation_commands and not is_none_or_empty(operator_results):
|
||||
operator_content = [
|
||||
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
|
||||
]
|
||||
context_message += (
|
||||
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_results))}\n\n"
|
||||
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
|
||||
)
|
||||
context_message = context_message.strip()
|
||||
|
||||
@@ -276,7 +275,8 @@ async def converse_gemini(
|
||||
deepthought=deepthought,
|
||||
tracer=tracer,
|
||||
):
|
||||
full_response += chunk
|
||||
if chunk.response:
|
||||
full_response += chunk.response
|
||||
yield chunk
|
||||
|
||||
# Call completion_func once finish streaming and we have the full response
|
||||
|
||||
@@ -21,6 +21,7 @@ from tenacity import (
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import (
|
||||
ResponseWithThought,
|
||||
commit_conversation_trace,
|
||||
get_image_from_base64,
|
||||
get_image_from_url,
|
||||
@@ -102,7 +103,7 @@ def gemini_completion_with_backoff(
|
||||
client = get_gemini_client(api_key, api_base_url)
|
||||
gemini_clients[api_key] = client
|
||||
|
||||
formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
||||
formatted_messages, system_instruction = format_messages_for_gemini(messages, system_prompt)
|
||||
|
||||
# format model response schema
|
||||
response_schema = None
|
||||
@@ -110,12 +111,12 @@ def gemini_completion_with_backoff(
|
||||
response_schema = clean_response_schema(model_kwargs["response_schema"])
|
||||
|
||||
thinking_config = None
|
||||
if deepthought and model_name.startswith("gemini-2-5"):
|
||||
if deepthought and model_name.startswith("gemini-2.5"):
|
||||
thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
|
||||
|
||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||
config = gtypes.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
system_instruction=system_instruction,
|
||||
temperature=temperature,
|
||||
thinking_config=thinking_config,
|
||||
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
|
||||
@@ -178,21 +179,21 @@ async def gemini_chat_completion_with_backoff(
|
||||
model_kwargs=None,
|
||||
deepthought=False,
|
||||
tracer: dict = {},
|
||||
) -> AsyncGenerator[str, None]:
|
||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||
client = gemini_clients.get(api_key)
|
||||
if not client:
|
||||
client = get_gemini_client(api_key, api_base_url)
|
||||
gemini_clients[api_key] = client
|
||||
|
||||
formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
||||
formatted_messages, system_instruction = format_messages_for_gemini(messages, system_prompt)
|
||||
|
||||
thinking_config = None
|
||||
if deepthought and model_name.startswith("gemini-2-5"):
|
||||
thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
|
||||
if deepthought and model_name.startswith("gemini-2.5"):
|
||||
thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI, include_thoughts=True)
|
||||
|
||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||
config = gtypes.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
system_instruction=system_instruction,
|
||||
temperature=temperature,
|
||||
thinking_config=thinking_config,
|
||||
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
|
||||
@@ -216,18 +217,25 @@ async def gemini_chat_completion_with_backoff(
|
||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||
# Keep track of the last chunk for usage data
|
||||
final_chunk = chunk
|
||||
# Handle streamed response chunk
|
||||
|
||||
# handle safety, rate-limit, other finish reasons
|
||||
stop_message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
||||
message = stop_message or chunk.text
|
||||
aggregated_response += message
|
||||
yield message
|
||||
if stopped:
|
||||
yield ResponseWithThought(response=stop_message)
|
||||
logger.warning(
|
||||
f"LLM Response Prevented for {model_name}: {stop_message}.\n"
|
||||
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
||||
)
|
||||
break
|
||||
|
||||
# emit thought vs response parts
|
||||
for part in chunk.candidates[0].content.parts:
|
||||
if part.text:
|
||||
aggregated_response += part.text
|
||||
yield ResponseWithThought(response=part.text)
|
||||
if part.thought:
|
||||
yield ResponseWithThought(thought=part.text)
|
||||
|
||||
# Calculate cost of chat
|
||||
input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
|
||||
output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
|
||||
|
||||
@@ -16,6 +16,7 @@ from khoj.processor.conversation.offline.utils import download_model
|
||||
from khoj.processor.conversation.utils import (
|
||||
clean_json,
|
||||
commit_conversation_trace,
|
||||
construct_question_history,
|
||||
generate_chatml_messages_with_context,
|
||||
messages_to_print,
|
||||
)
|
||||
@@ -64,13 +65,7 @@ def extract_questions_offline(
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
chat_history = ""
|
||||
|
||||
if use_history:
|
||||
for chat in conversation_log.get("chat", [])[-4:]:
|
||||
if chat["by"] == "khoj":
|
||||
chat_history += f"Q: {chat['intent']['query']}\n"
|
||||
chat_history += f"Khoj: {chat['message']}\n\n"
|
||||
chat_history = construct_question_history(conversation_log, include_query=False) if use_history else ""
|
||||
|
||||
# Get dates relative to today for prompt creation
|
||||
today = datetime.today()
|
||||
|
||||
@@ -17,8 +17,10 @@ from khoj.processor.conversation.openai.utils import (
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
JsonSupport,
|
||||
OperatorRun,
|
||||
ResponseWithThought,
|
||||
clean_json,
|
||||
construct_question_history,
|
||||
construct_structured_message,
|
||||
generate_chatml_messages_with_context,
|
||||
messages_to_print,
|
||||
@@ -55,13 +57,7 @@ def extract_questions(
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
chat_history = "".join(
|
||||
[
|
||||
f'Q: {chat["intent"]["query"]}\nKhoj: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
|
||||
for chat in conversation_log.get("chat", [])[-4:]
|
||||
if chat["by"] == "khoj" and "to-image" not in chat["intent"].get("type")
|
||||
]
|
||||
)
|
||||
chat_history = construct_question_history(conversation_log)
|
||||
|
||||
# Get dates relative to today for prompt creation
|
||||
today = datetime.today()
|
||||
@@ -169,7 +165,7 @@ async def converse_openai(
|
||||
references: list[dict],
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
code_results: Optional[Dict[str, Dict]] = None,
|
||||
operator_results: Optional[Dict[str, str]] = None,
|
||||
operator_results: Optional[List[OperatorRun]] = None,
|
||||
conversation_log={},
|
||||
model: str = "gpt-4o-mini",
|
||||
api_key: Optional[str] = None,
|
||||
@@ -242,8 +238,11 @@ async def converse_openai(
|
||||
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
|
||||
)
|
||||
if not is_none_or_empty(operator_results):
|
||||
operator_content = [
|
||||
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
|
||||
]
|
||||
context_message += (
|
||||
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_results))}\n\n"
|
||||
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
|
||||
)
|
||||
|
||||
context_message = context_message.strip()
|
||||
|
||||
@@ -10,7 +10,7 @@ from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Union
|
||||
|
||||
import PIL.Image
|
||||
import pyjson5
|
||||
@@ -20,6 +20,7 @@ import yaml
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
from llama_cpp import LlamaTokenizer
|
||||
from llama_cpp.llama import Llama
|
||||
from pydantic import BaseModel
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters
|
||||
@@ -73,9 +74,9 @@ model_to_prompt_size = {
|
||||
"claude-3-7-sonnet-20250219": 60000,
|
||||
"claude-3-7-sonnet-latest": 60000,
|
||||
"claude-3-5-haiku-20241022": 60000,
|
||||
"claude-sonnet-4": 60000,
|
||||
"claude-sonnet-4-0": 60000,
|
||||
"claude-sonnet-4-20250514": 60000,
|
||||
"claude-opus-4": 60000,
|
||||
"claude-opus-4-0": 60000,
|
||||
"claude-opus-4-20250514": 60000,
|
||||
# Offline Models
|
||||
"bartowski/Qwen2.5-14B-Instruct-GGUF": 20000,
|
||||
@@ -87,7 +88,49 @@ model_to_prompt_size = {
|
||||
model_to_tokenizer: Dict[str, str] = {}
|
||||
|
||||
|
||||
class InformationCollectionIteration:
|
||||
class AgentMessage(BaseModel):
|
||||
role: Literal["user", "assistant", "system", "environment"]
|
||||
content: Union[str, List]
|
||||
|
||||
|
||||
class OperatorRun:
|
||||
def __init__(
|
||||
self,
|
||||
query: str,
|
||||
trajectory: list[AgentMessage] | list[dict] = None,
|
||||
response: str = None,
|
||||
webpages: list[dict] = None,
|
||||
):
|
||||
self.query = query
|
||||
self.response = response
|
||||
self.webpages = webpages or []
|
||||
self.trajectory: list[AgentMessage] = []
|
||||
if trajectory:
|
||||
for item in trajectory:
|
||||
if isinstance(item, dict):
|
||||
self.trajectory.append(AgentMessage(**item))
|
||||
elif hasattr(item, "role") and hasattr(item, "content"): # Heuristic for AgentMessage like object
|
||||
self.trajectory.append(item)
|
||||
else:
|
||||
logger.warning(f"Unexpected item type in trajectory: {type(item)}")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
# Ensure AgentMessage instances in trajectory are also dicts
|
||||
serialized_trajectory = []
|
||||
for msg in self.trajectory:
|
||||
if hasattr(msg, "model_dump"): # Check if it's a Pydantic model
|
||||
serialized_trajectory.append(msg.model_dump())
|
||||
elif isinstance(msg, dict):
|
||||
serialized_trajectory.append(msg) # Already a dict
|
||||
return {
|
||||
"query": self.query,
|
||||
"response": self.response,
|
||||
"trajectory": serialized_trajectory,
|
||||
"webpages": self.webpages,
|
||||
}
|
||||
|
||||
|
||||
class ResearchIteration:
|
||||
def __init__(
|
||||
self,
|
||||
tool: str,
|
||||
@@ -95,7 +138,7 @@ class InformationCollectionIteration:
|
||||
context: list = None,
|
||||
onlineContext: dict = None,
|
||||
codeContext: dict = None,
|
||||
operatorContext: dict[str, str] = None,
|
||||
operatorContext: dict | OperatorRun = None,
|
||||
summarizedResult: str = None,
|
||||
warning: str = None,
|
||||
):
|
||||
@@ -104,13 +147,18 @@ class InformationCollectionIteration:
|
||||
self.context = context
|
||||
self.onlineContext = onlineContext
|
||||
self.codeContext = codeContext
|
||||
self.operatorContext = operatorContext
|
||||
self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else operatorContext
|
||||
self.summarizedResult = summarizedResult
|
||||
self.warning = warning
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
data = vars(self).copy()
|
||||
data["operatorContext"] = self.operatorContext.to_dict() if self.operatorContext else None
|
||||
return data
|
||||
|
||||
|
||||
def construct_iteration_history(
|
||||
previous_iterations: List[InformationCollectionIteration],
|
||||
previous_iterations: List[ResearchIteration],
|
||||
previous_iteration_prompt: str,
|
||||
query: str = None,
|
||||
) -> list[dict]:
|
||||
@@ -143,11 +191,8 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
|
||||
chat_history = ""
|
||||
for chat in conversation_history.get("chat", [])[-n:]:
|
||||
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
|
||||
if chat["intent"].get("inferred-queries"):
|
||||
chat_history += f'{agent_name}: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'
|
||||
|
||||
chat_history += f"{agent_name}: {chat['message']}\n\n"
|
||||
elif chat["by"] == "khoj" and chat.get("images"):
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
@@ -156,6 +201,7 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
|
||||
elif chat["by"] == "you":
|
||||
chat_history += f"User: {chat['message']}\n"
|
||||
raw_query_files = chat.get("queryFiles")
|
||||
if raw_query_files:
|
||||
query_files: Dict[str, str] = {}
|
||||
@@ -168,8 +214,74 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
|
||||
return chat_history
|
||||
|
||||
|
||||
def construct_question_history(
|
||||
conversation_log: dict,
|
||||
include_query: bool = True,
|
||||
lookback: int = 6,
|
||||
query_prefix: str = "Q",
|
||||
agent_name: str = "Khoj",
|
||||
) -> str:
|
||||
"""
|
||||
Constructs a chat history string formatted for query extraction purposes.
|
||||
"""
|
||||
history_parts = ""
|
||||
original_query = None
|
||||
for chat in conversation_log.get("chat", [])[-lookback:]:
|
||||
if chat["by"] == "you":
|
||||
original_query = chat.get("message")
|
||||
history_parts += f"{query_prefix}: {original_query}\n"
|
||||
if chat["by"] == "khoj":
|
||||
if original_query is None:
|
||||
continue
|
||||
|
||||
message = chat.get("message", "")
|
||||
inferred_queries_list = chat.get("intent", {}).get("inferred-queries")
|
||||
|
||||
# Ensure inferred_queries_list is a list, defaulting to the original query in a list
|
||||
if not inferred_queries_list:
|
||||
inferred_queries_list = [original_query]
|
||||
# If it's a string (though unlikely based on usage), wrap it in a list
|
||||
elif isinstance(inferred_queries_list, str):
|
||||
inferred_queries_list = [inferred_queries_list]
|
||||
|
||||
if include_query:
|
||||
# Ensure 'type' exists and is a string before checking 'to-image'
|
||||
intent_type = chat.get("intent", {}).get("type", "")
|
||||
if "to-image" not in intent_type:
|
||||
history_parts += f'{agent_name}: {{"queries": {inferred_queries_list}}}\n'
|
||||
history_parts += f"A: {message}\n\n"
|
||||
else:
|
||||
history_parts += f"{agent_name}: {message}\n\n"
|
||||
|
||||
# Reset original_query for the next turn
|
||||
original_query = None
|
||||
|
||||
return history_parts
|
||||
|
||||
|
||||
def construct_chat_history_for_operator(conversation_history: dict, n: int = 6) -> list[AgentMessage]:
|
||||
"""
|
||||
Construct chat history for operator agent in conversation log.
|
||||
Only include last n completed turns (i.e with user and khoj message).
|
||||
"""
|
||||
chat_history: list[AgentMessage] = []
|
||||
user_message: Optional[AgentMessage] = None
|
||||
|
||||
for chat in conversation_history.get("chat", []):
|
||||
if len(chat_history) >= n:
|
||||
break
|
||||
if chat["by"] == "you" and chat.get("message"):
|
||||
content = [{"type": "text", "text": chat["message"]}]
|
||||
for file in chat.get("queryFiles", []):
|
||||
content += [{"type": "text", "text": f'## File: {file["name"]}\n\n{file["content"]}'}]
|
||||
user_message = AgentMessage(role="user", content=content)
|
||||
elif chat["by"] == "khoj" and chat.get("message"):
|
||||
chat_history += [user_message, AgentMessage(role="assistant", content=chat["message"])]
|
||||
return chat_history
|
||||
|
||||
|
||||
def construct_tool_chat_history(
|
||||
previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None
|
||||
previous_iterations: List[ResearchIteration], tool: ConversationCommand = None
|
||||
) -> Dict[str, list]:
|
||||
"""
|
||||
Construct chat history from previous iterations for a specific tool
|
||||
@@ -178,8 +290,8 @@ def construct_tool_chat_history(
|
||||
If no tool is provided inferred query for all tools used are added.
|
||||
"""
|
||||
chat_history: list = []
|
||||
base_extractor: Callable[[InformationCollectionIteration], List[str]] = lambda x: []
|
||||
extract_inferred_query_map: Dict[ConversationCommand, Callable[[InformationCollectionIteration], List[str]]] = {
|
||||
base_extractor: Callable[[ResearchIteration], List[str]] = lambda iteration: []
|
||||
extract_inferred_query_map: Dict[ConversationCommand, Callable[[ResearchIteration], List[str]]] = {
|
||||
ConversationCommand.Notes: (
|
||||
lambda iteration: [c["query"] for c in iteration.context] if iteration.context else []
|
||||
),
|
||||
@@ -192,9 +304,6 @@ def construct_tool_chat_history(
|
||||
ConversationCommand.Code: (
|
||||
lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else []
|
||||
),
|
||||
ConversationCommand.Operator: (
|
||||
lambda iteration: list(iteration.operatorContext.keys()) if iteration.operatorContext else []
|
||||
),
|
||||
}
|
||||
for iteration in previous_iterations:
|
||||
# If a tool is provided use the inferred query extractor for that tool if available
|
||||
@@ -273,7 +382,7 @@ async def save_to_conversation_log(
|
||||
compiled_references: List[Dict[str, Any]] = [],
|
||||
online_results: Dict[str, Any] = {},
|
||||
code_results: Dict[str, Any] = {},
|
||||
operator_results: Dict[str, str] = {},
|
||||
operator_results: List[OperatorRun] = None,
|
||||
inferred_queries: List[str] = [],
|
||||
intent_type: str = "remember",
|
||||
client_application: ClientApplication = None,
|
||||
@@ -284,7 +393,7 @@ async def save_to_conversation_log(
|
||||
generated_images: List[str] = [],
|
||||
raw_generated_files: List[FileAttachment] = [],
|
||||
generated_mermaidjs_diagram: str = None,
|
||||
research_results: Optional[List[InformationCollectionIteration]] = None,
|
||||
research_results: Optional[List[ResearchIteration]] = None,
|
||||
train_of_thought: List[Any] = [],
|
||||
tracer: Dict[str, Any] = {},
|
||||
):
|
||||
@@ -301,8 +410,8 @@ async def save_to_conversation_log(
|
||||
"intent": {"inferred-queries": inferred_queries, "type": intent_type},
|
||||
"onlineContext": online_results,
|
||||
"codeContext": code_results,
|
||||
"operatorContext": operator_results,
|
||||
"researchContext": [vars(r) for r in research_results] if research_results and not chat_response else None,
|
||||
"operatorContext": [o.to_dict() for o in operator_results] if operator_results and not chat_response else None,
|
||||
"researchContext": [r.to_dict() for r in research_results] if research_results and not chat_response else None,
|
||||
"automationId": automation_id,
|
||||
"trainOfThought": train_of_thought,
|
||||
"turnId": turn_id,
|
||||
@@ -459,10 +568,12 @@ def generate_chatml_messages_with_context(
|
||||
]
|
||||
|
||||
if not is_none_or_empty(chat.get("operatorContext")):
|
||||
operator_context = chat.get("operatorContext")
|
||||
operator_content = "\n\n".join([f'## Task: {oc["query"]}\n{oc["response"]}\n' for oc in operator_context])
|
||||
message_context += [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"{prompts.operator_execution_context.format(operator_results=chat.get('operatorContext'))}",
|
||||
"text": f"{prompts.operator_execution_context.format(operator_results=operator_content)}",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
59
src/khoj/processor/operator/README.md
Normal file
59
src/khoj/processor/operator/README.md
Normal file
@@ -0,0 +1,59 @@
|
||||
# Khoj Operator (Experimental)
|
||||
|
||||
## Overview
|
||||
Give Khoj its own computer to operate in a transparent, controlled manner. Accomplish tasks that require visual browsing, file editing and terminal access. Operator with research mode can work for 30+ minutes to accomplish more substantial tasks like feature development, travel planning, shopping etc.
|
||||
|
||||
## Setup
|
||||
|
||||
### Prerequisites
|
||||
- Docker and Docker Compose installed
|
||||
- Anthropic API key (required - only Anthropic models currently enabled)
|
||||
|
||||
### Installation Steps
|
||||
1. Download the Khoj docker-compose.yml file
|
||||
```shell
|
||||
mkdir ~/.khoj && cd ~/.khoj
|
||||
wget https://raw.githubusercontent.com/khoj-ai/khoj/master/docker-compose.yml
|
||||
```
|
||||
|
||||
2. Configure environment variables in `docker-compose.yml`
|
||||
- Set `ANTHROPIC_API_KEY` to your [Anthropic API key](https://console.anthropic.com/settings/keys)
|
||||
- Uncomment `KHOJ_OPERATOR_ENABLED=True` to enable the operator tool
|
||||
|
||||
3. Start Khoj services
|
||||
```shell
|
||||
docker-compose up
|
||||
```
|
||||
|
||||
4. Access the web app at http://localhost:42110
|
||||
   Ensure you're using a Claude 3.7+ model on your [settings page](http://localhost:42110/settings)
|
||||
|
||||
## Usage
|
||||
Use the `/operator` command or ask Khoj in normal or research mode to use the operator tool to have it operate its computer:
|
||||
|
||||
**Examples:**
|
||||
- `/operator Find flights from Bangkok to Mexico City with no US layover`
|
||||
- `/research Clone the khoj repo and tell me how the operator tool is implemented`
|
||||
|
||||
## Supported Models
|
||||
|
||||
Currently enables **only Anthropic models**:
|
||||
- Claude Sonnet 4
|
||||
- Claude 3.7 Sonnet
|
||||
- Claude Opus 4
|
||||
|
||||
*Note: OpenAI and other operator models are disabled while in development.*
|
||||
|
||||
## Capabilities
|
||||
|
||||
The operator can:
|
||||
- **Computer Control**: Take screenshots, click, type, navigate desktop
|
||||
- **File Operations**: Create, edit, and manage files
|
||||
- **Terminal Access**: Execute bash commands and scripts
|
||||
- **Web Browsing**: Navigate websites, documents and extract information
|
||||
|
||||
## Architecture
|
||||
|
||||
- **Environments**: Operator Computer and Browser environments
|
||||
- **Models**: Enable Vision Language Models (VLM) to operate computer
|
||||
- **Execution**: Containerize computer environment for security and isolation
|
||||
@@ -6,13 +6,23 @@ from typing import Callable, List, Optional
|
||||
|
||||
from khoj.database.adapters import AgentAdapters, ConversationAdapters
|
||||
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||
from khoj.processor.conversation.utils import (
|
||||
OperatorRun,
|
||||
construct_chat_history,
|
||||
construct_chat_history_for_operator,
|
||||
)
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
from khoj.processor.operator.operator_agent_anthropic import AnthropicOperatorAgent
|
||||
from khoj.processor.operator.operator_agent_base import OperatorAgent
|
||||
from khoj.processor.operator.operator_agent_binary import BinaryOperatorAgent
|
||||
from khoj.processor.operator.operator_agent_openai import OpenAIOperatorAgent
|
||||
from khoj.processor.operator.operator_environment_base import EnvStepResult
|
||||
from khoj.processor.operator.operator_environment_base import (
|
||||
Environment,
|
||||
EnvironmentType,
|
||||
EnvStepResult,
|
||||
)
|
||||
from khoj.processor.operator.operator_environment_browser import BrowserEnvironment
|
||||
from khoj.processor.operator.operator_environment_computer import ComputerEnvironment
|
||||
from khoj.routers.helpers import ChatEvent
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
@@ -20,12 +30,14 @@ from khoj.utils.rawconfig import LocationData
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# --- Browser Operator Function ---
|
||||
async def operate_browser(
|
||||
# --- Main Operator Entrypoint ---
|
||||
async def operate_environment(
|
||||
query: str,
|
||||
user: KhojUser,
|
||||
conversation_log: dict,
|
||||
location_data: LocationData,
|
||||
previous_trajectory: Optional[OperatorRun] = None,
|
||||
environment_type: EnvironmentType = EnvironmentType.COMPUTER,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
query_images: Optional[List[str]] = None, # TODO: Handle query images
|
||||
agent: Agent = None,
|
||||
@@ -33,8 +45,11 @@ async def operate_browser(
|
||||
cancellation_event: Optional[asyncio.Event] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
response, summary_message, user_input_message = None, None, None
|
||||
environment: Optional[BrowserEnvironment] = None
|
||||
response, user_input_message = None, None
|
||||
|
||||
# Only use partial previous trajectories to continue existing task
|
||||
if previous_trajectory and previous_trajectory.response:
|
||||
previous_trajectory = None
|
||||
|
||||
# Get the agent chat model
|
||||
agent_chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) if agent else None
|
||||
@@ -42,16 +57,40 @@ async def operate_browser(
|
||||
if not reasoning_model or not reasoning_model.vision_enabled:
|
||||
reasoning_model = await ConversationAdapters.aget_vision_enabled_config()
|
||||
if not reasoning_model:
|
||||
raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate browser.")
|
||||
raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate environment.")
|
||||
|
||||
# Create conversation history from conversation log
|
||||
chat_history = construct_chat_history_for_operator(conversation_log)
|
||||
|
||||
# Initialize Agent
|
||||
max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 40))
|
||||
max_context = await ConversationAdapters.aget_max_context_size(reasoning_model, user) or 20000
|
||||
max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 100))
|
||||
operator_agent: OperatorAgent
|
||||
if is_operator_model(reasoning_model.name) == ChatModel.ModelType.OPENAI:
|
||||
operator_agent = OpenAIOperatorAgent(query, reasoning_model, max_iterations, tracer)
|
||||
elif is_operator_model(reasoning_model.name) == ChatModel.ModelType.ANTHROPIC:
|
||||
operator_agent = AnthropicOperatorAgent(query, reasoning_model, max_iterations, tracer)
|
||||
else:
|
||||
if is_operator_model(reasoning_model.name) == ChatModel.ModelType.ANTHROPIC:
|
||||
operator_agent = AnthropicOperatorAgent(
|
||||
query,
|
||||
reasoning_model,
|
||||
environment_type,
|
||||
max_iterations,
|
||||
max_context,
|
||||
chat_history,
|
||||
previous_trajectory,
|
||||
tracer,
|
||||
)
|
||||
# TODO: Remove once OpenAI Operator Agent is useful
|
||||
elif is_operator_model(reasoning_model.name) == ChatModel.ModelType.OPENAI and False:
|
||||
operator_agent = OpenAIOperatorAgent(
|
||||
query,
|
||||
reasoning_model,
|
||||
environment_type,
|
||||
max_iterations,
|
||||
max_context,
|
||||
chat_history,
|
||||
previous_trajectory,
|
||||
tracer,
|
||||
)
|
||||
# TODO: Remove once Binary Operator Agent is useful
|
||||
elif False:
|
||||
grounding_model_name = "ui-tars-1.5"
|
||||
grounding_model = await ConversationAdapters.aget_chat_model_by_name(grounding_model_name)
|
||||
if (
|
||||
@@ -59,41 +98,62 @@ async def operate_browser(
|
||||
or not grounding_model.vision_enabled
|
||||
or not grounding_model.model_type == ChatModel.ModelType.OPENAI
|
||||
):
|
||||
raise ValueError("No supported visual grounding model for binary operator agent found.")
|
||||
operator_agent = BinaryOperatorAgent(query, reasoning_model, grounding_model, max_iterations, tracer)
|
||||
raise ValueError("Binary operator agent needs ui-tars-1.5 served over an OpenAI compatible API.")
|
||||
operator_agent = BinaryOperatorAgent(
|
||||
query,
|
||||
reasoning_model,
|
||||
grounding_model,
|
||||
environment_type,
|
||||
max_iterations,
|
||||
max_context,
|
||||
chat_history,
|
||||
previous_trajectory,
|
||||
tracer,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported operator model: {reasoning_model.name}. "
|
||||
"Please use a supported operator model. Only Anthropic models are currently supported."
|
||||
)
|
||||
|
||||
# Initialize Environment
|
||||
if send_status_func:
|
||||
async for event in send_status_func(f"**Launching Browser**"):
|
||||
async for event in send_status_func(f"**Launching {environment_type.value}**"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
environment = BrowserEnvironment()
|
||||
if environment_type == EnvironmentType.BROWSER:
|
||||
environment: Environment = BrowserEnvironment()
|
||||
else:
|
||||
environment = ComputerEnvironment(provider="docker")
|
||||
await environment.start(width=1024, height=768)
|
||||
|
||||
# Start Operator Loop
|
||||
try:
|
||||
summarize_prompt = f"Use the results of our research to provide a comprehensive, self-contained answer for the target query:\n{query}."
|
||||
task_completed = False
|
||||
iterations = 0
|
||||
operator_run = OperatorRun(query=query, trajectory=operator_agent.messages, response=response)
|
||||
yield operator_run
|
||||
|
||||
with timer(f"Operating browser with {reasoning_model.model_type} {reasoning_model.name}", logger):
|
||||
with timer(
|
||||
f"Operating {environment_type.value} with {reasoning_model.model_type} {reasoning_model.name}", logger
|
||||
):
|
||||
while iterations < max_iterations and not task_completed:
|
||||
if cancellation_event and cancellation_event.is_set():
|
||||
logger.debug(f"Browser operator cancelled by client disconnect")
|
||||
logger.debug(f"{environment_type.value} operator cancelled by client disconnect")
|
||||
break
|
||||
|
||||
iterations += 1
|
||||
|
||||
# 1. Get current environment state
|
||||
browser_state = await environment.get_state()
|
||||
env_state = await environment.get_state()
|
||||
|
||||
# 2. Agent decides action(s)
|
||||
agent_result = await operator_agent.act(browser_state)
|
||||
agent_result = await operator_agent.act(env_state)
|
||||
|
||||
# 3. Execute actions in the environment
|
||||
env_steps: List[EnvStepResult] = []
|
||||
for action in agent_result.actions:
|
||||
if cancellation_event and cancellation_event.is_set():
|
||||
logger.debug(f"Browser operator cancelled by client disconnect")
|
||||
logger.debug(f"{environment_type.value} operator cancelled by client disconnect")
|
||||
break
|
||||
# Handle request for user action and break the loop
|
||||
if isinstance(action, RequestUserAction):
|
||||
@@ -106,12 +166,14 @@ async def operate_browser(
|
||||
env_steps.append(env_step)
|
||||
|
||||
# Render status update
|
||||
latest_screenshot = f"data:image/webp;base64,{env_steps[-1].screenshot_base64 if env_steps else browser_state.screenshot}"
|
||||
latest_screenshot = (
|
||||
f"data:image/webp;base64,{env_steps[-1].screenshot_base64 if env_steps else env_state.screenshot}"
|
||||
)
|
||||
render_payload = agent_result.rendered_response
|
||||
render_payload["image"] = latest_screenshot
|
||||
render_content = f"**Action**: {json.dumps(render_payload)}"
|
||||
if send_status_func:
|
||||
async for event in send_status_func(f"**Operating Browser**:\n{render_content}"):
|
||||
async for event in send_status_func(f"**Operating {environment_type.value}**:\n{render_content}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
|
||||
# Check if termination conditions are met
|
||||
@@ -123,31 +185,33 @@ async def operate_browser(
|
||||
if task_completed or trigger_iteration_limit:
|
||||
# Summarize results of operator run on last iteration
|
||||
operator_agent.add_action_results(env_steps, agent_result)
|
||||
summary_message = await operator_agent.summarize(summarize_prompt, browser_state)
|
||||
summary_message = await operator_agent.summarize(env_state)
|
||||
logger.info(f"Task completed: {task_completed}, Iteration limit: {trigger_iteration_limit}")
|
||||
break
|
||||
|
||||
# 4. Update agent on the results of its action on the environment
|
||||
operator_agent.add_action_results(env_steps, agent_result)
|
||||
operator_run.trajectory = operator_agent.messages
|
||||
|
||||
# Determine final response message
|
||||
if user_input_message:
|
||||
response = user_input_message
|
||||
operator_run.response = user_input_message
|
||||
elif task_completed:
|
||||
response = summary_message
|
||||
operator_run.response = summary_message
|
||||
elif cancellation_event and cancellation_event.is_set():
|
||||
operator_run.response = None
|
||||
else: # Hit iteration limit
|
||||
response = f"Operator hit iteration limit ({max_iterations}). If the results seem incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{summary_message}"
|
||||
operator_run.response = f"Operator hit iteration limit ({max_iterations}). If the results seem incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{summary_message}"
|
||||
finally:
|
||||
if environment and not user_input_message: # Don't close browser if user input required
|
||||
if environment and not user_input_message: # Don't close environment if user input required
|
||||
await environment.close()
|
||||
if operator_agent:
|
||||
operator_agent.reset()
|
||||
|
||||
yield {
|
||||
"query": query,
|
||||
"result": user_input_message or response,
|
||||
"webpages": [{"link": url, "snippet": ""} for url in environment.visited_urls],
|
||||
}
|
||||
if environment_type == EnvironmentType.BROWSER and hasattr(environment, "visited_urls"):
|
||||
operator_run.webpages = [{"link": url, "snippet": ""} for url in environment.visited_urls]
|
||||
|
||||
yield operator_run
|
||||
|
||||
|
||||
def is_operator_model(model: str) -> ChatModel.ModelType | None:
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from textwrap import dedent
|
||||
|
||||
from openai import AzureOpenAI, OpenAI
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
||||
@@ -8,7 +9,7 @@ from khoj.database.models import ChatModel
|
||||
from khoj.processor.conversation.utils import construct_structured_message
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
from khoj.processor.operator.operator_agent_base import AgentActResult
|
||||
from khoj.processor.operator.operator_environment_base import EnvState
|
||||
from khoj.processor.operator.operator_environment_base import EnvironmentType, EnvState
|
||||
from khoj.utils.helpers import get_chat_usage_metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -18,6 +19,7 @@ class GroundingAgent:
|
||||
def __init__(
|
||||
self,
|
||||
model: ChatModel,
|
||||
environment_type: EnvironmentType,
|
||||
client: OpenAI | AzureOpenAI,
|
||||
max_iterations: int,
|
||||
tracer: dict = None,
|
||||
@@ -26,9 +28,211 @@ class GroundingAgent:
|
||||
self.client = client
|
||||
self.max_iterations = max_iterations
|
||||
self.tracer = tracer
|
||||
self.environment_type = environment_type
|
||||
self.action_tools = self.get_tools(self.environment_type)
|
||||
|
||||
# Define tools for the grounding LLM (OpenAI format)
|
||||
self.action_tools = [
|
||||
async def act(self, instruction: str, current_state: EnvState) -> tuple[str, list[OperatorAction]]:
|
||||
"""Call the grounding LLM to get the next action based on the current state and instruction."""
|
||||
# Format the message for the API call
|
||||
messages_for_api = self._format_message_for_api(instruction, current_state)
|
||||
try:
|
||||
grounding_response: ChatCompletion = await self.client.chat.completions.create(
|
||||
messages=messages_for_api,
|
||||
model=self.model.name,
|
||||
tools=self.action_tools,
|
||||
tool_choice="required",
|
||||
temperature=0.0, # Grounding should be precise
|
||||
max_completion_tokens=1000, # Allow for thoughts + actions
|
||||
)
|
||||
if not isinstance(grounding_response, ChatCompletion):
|
||||
raise ValueError("Grounding LLM response is not of type ChatCompletion.")
|
||||
logger.debug(f"Grounding LLM response: {grounding_response.model_dump_json()}")
|
||||
|
||||
# Parse tool calls
|
||||
grounding_message = grounding_response.choices[0].message
|
||||
rendered_response, actions = self._parse_action(grounding_message, instruction, current_state)
|
||||
|
||||
# Update usage by grounding model
|
||||
self.tracer["usage"] = get_chat_usage_metrics(
|
||||
self.model.name,
|
||||
input_tokens=grounding_response.usage.prompt_tokens,
|
||||
output_tokens=grounding_response.usage.completion_tokens,
|
||||
usage=self.tracer.get("usage"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling Grounding LLM: {e}")
|
||||
rendered_response = f"**Error**: Error contacting Grounding LLM: {e}"
|
||||
actions = []
|
||||
|
||||
return rendered_response, actions
|
||||
|
||||
def _format_message_for_api(self, instruction: str, current_state: EnvState) -> List:
|
||||
"""Format the message for the API call."""
|
||||
# Construct grounding LLM input (using only the latest user prompt + image)
|
||||
# We don't pass the full history here, as grounding depends on the *current* state + NL action
|
||||
grounding_user_prompt = self.get_instruction(instruction, self.environment_type)
|
||||
screenshots = [f"data:image/webp;base64,{current_state.screenshot}"]
|
||||
grounding_messages_content = construct_structured_message(
|
||||
grounding_user_prompt, screenshots, self.model.name, vision_enabled=True
|
||||
)
|
||||
return [{"role": "user", "content": grounding_messages_content}]
|
||||
|
||||
def _parse_action(
|
||||
self, grounding_message: ChatCompletionMessage, instruction: str, current_state: EnvState
|
||||
) -> tuple[str, list[OperatorAction]]:
|
||||
"""Parse the tool calls from the grounding LLM response and convert them to action objects."""
|
||||
actions: List[OperatorAction] = []
|
||||
action_results: List[dict] = []
|
||||
|
||||
if grounding_message.tool_calls:
|
||||
rendered_parts = []
|
||||
for tool_call in grounding_message.tool_calls:
|
||||
function_name = tool_call.function.name
|
||||
try:
|
||||
arguments = json.loads(tool_call.function.arguments)
|
||||
action_to_run: Optional[OperatorAction] = None
|
||||
action_render_str = f"**Action ({function_name})**: {tool_call.function.arguments}"
|
||||
|
||||
if function_name == "click":
|
||||
action_to_run = ClickAction(**arguments)
|
||||
elif function_name == "left_double":
|
||||
action_to_run = DoubleClickAction(**arguments)
|
||||
elif function_name == "right_single":
|
||||
action_to_run = ClickAction(button="right", **arguments)
|
||||
elif function_name == "type":
|
||||
content = arguments.get("content")
|
||||
action_to_run = TypeAction(text=content)
|
||||
elif function_name == "scroll":
|
||||
direction = arguments.get("direction", "down")
|
||||
amount = 3
|
||||
action_to_run = ScrollAction(scroll_direction=direction, scroll_amount=amount, **arguments)
|
||||
elif function_name == "hotkey":
|
||||
action_to_run = KeypressAction(**arguments)
|
||||
elif function_name == "goto":
|
||||
action_to_run = GotoAction(**arguments)
|
||||
elif function_name == "back":
|
||||
action_to_run = BackAction(**arguments)
|
||||
elif function_name == "wait":
|
||||
action_to_run = WaitAction(**arguments)
|
||||
elif function_name == "screenshot":
|
||||
action_to_run = ScreenshotAction(**arguments)
|
||||
elif function_name == "drag":
|
||||
# Need to convert list of dicts to list of Point objects
|
||||
path_dicts = arguments.get("path", [])
|
||||
path_points = [Point(**p) for p in path_dicts]
|
||||
if path_points:
|
||||
action_to_run = DragAction(path=path_points)
|
||||
else:
|
||||
logger.warning(f"Drag action called with empty path: {arguments}")
|
||||
action_render_str += " [Skipped - empty path]"
|
||||
elif function_name == "finished":
|
||||
action_to_run = None
|
||||
else:
|
||||
logger.warning(f"Grounding LLM called unhandled tool: {function_name}")
|
||||
action_render_str += " [Unhandled]"
|
||||
|
||||
if action_to_run:
|
||||
actions.append(action_to_run)
|
||||
action_results.append(
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": None, # Updated after environment step
|
||||
}
|
||||
)
|
||||
rendered_parts.append(action_render_str)
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as arg_err:
|
||||
logger.error(
|
||||
f"Error parsing arguments for tool {function_name}: {arg_err} - Args: {tool_call.function.arguments}"
|
||||
)
|
||||
rendered_parts.append(f"**Error**: Failed to parse arguments for {function_name}")
|
||||
rendered_response = "\n- ".join(rendered_parts)
|
||||
else:
|
||||
# Grounding LLM responded but didn't call a tool
|
||||
logger.warning("Grounding LLM did not produce a tool call.")
|
||||
rendered_response = f"{grounding_message.content or 'No action required.'}"
|
||||
|
||||
# Render the response
|
||||
return rendered_response, actions
|
||||
|
||||
def get_instruction(self, instruction: str, environment_type: EnvironmentType) -> str:
|
||||
"""
|
||||
Get the instruction for the agent based on the environment type.
|
||||
"""
|
||||
UITARS_COMPUTER_PREFIX_PROMPT = """
|
||||
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||
"""
|
||||
UITARS_BROWSER_PREFIX_PROMPT = """
|
||||
You are a GUI agent. You are given a task and a screenshot of the web browser tab you operate. You need to decide the next action to complete the task.
|
||||
You control a single tab in a Chromium browser. You cannot access the OS, filesystem or the application window.
|
||||
Always use the `goto` function to navigate to a specific URL. Ctrl+t, Ctrl+w, Ctrl+q, Ctrl+Shift+T, Ctrl+Shift+W are not allowed.
|
||||
"""
|
||||
|
||||
UITARS_USR_COMPUTER_PROMPT_THOUGHT = f"""
|
||||
Try fulfill the user instruction to the best of your ability, especially when the instruction is given multiple times. Do not ignore the instruction.
|
||||
|
||||
## Output Format
|
||||
```
|
||||
Thought: ...
|
||||
Action: ...
|
||||
```
|
||||
|
||||
## Action Space
|
||||
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
|
||||
hotkey(key='')
|
||||
type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
|
||||
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
|
||||
wait(duration='time') # Sleep for specified time. Default is 1s and take a screenshot to check for any changes.
|
||||
|
||||
## Note
|
||||
- Use English in `Thought` part.
|
||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
|
||||
|
||||
## User Instruction
|
||||
{instruction}
|
||||
"""
|
||||
UITARS_USR_BROWSER_PROMPT_THOUGHT = f"""
|
||||
Try fulfill the user instruction to the best of your ability, especially when the instruction is given multiple times. Do not ignore the instruction.
|
||||
|
||||
## Output Format
|
||||
```
|
||||
Thought: ...
|
||||
Action: ...
|
||||
```
|
||||
|
||||
## Action Space
|
||||
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
|
||||
hotkey(key='')
|
||||
type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
|
||||
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
|
||||
wait(duration='time') # Sleep for specified time. Default is 1s and take a screenshot to check for any changes.
|
||||
goto(url='xxx') # Always use this to navigate to a specific URL. Use escape characters \\', \\", and \\n in url part to ensure we can parse the url in normal python string format.
|
||||
back() # Use this to go back to the previous page.
|
||||
|
||||
## Note
|
||||
- Use English in `Thought` part.
|
||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
|
||||
|
||||
## User Instruction
|
||||
{instruction}
|
||||
"""
|
||||
|
||||
if environment_type == EnvironmentType.BROWSER:
|
||||
return dedent(UITARS_BROWSER_PREFIX_PROMPT + UITARS_USR_BROWSER_PROMPT_THOUGHT).lstrip()
|
||||
elif environment_type == EnvironmentType.COMPUTER:
|
||||
return dedent(UITARS_COMPUTER_PREFIX_PROMPT + UITARS_USR_COMPUTER_PROMPT_THOUGHT).lstrip()
|
||||
else:
|
||||
raise ValueError(f"Expected environment type: Computer or Browser. Got {environment_type}.")
|
||||
|
||||
def get_tools(self, environment_type: EnvironmentType) -> list[dict]:
|
||||
"""Get tools for the grounding LLM, in OpenAI API tool format"""
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
@@ -163,182 +367,32 @@ class GroundingAgent:
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "goto",
|
||||
"description": "Navigate to a specific URL.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"url": {"type": "string", "description": "Fully qualified URL"}},
|
||||
"required": ["url"],
|
||||
]
|
||||
if environment_type == EnvironmentType.BROWSER:
|
||||
tools += [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "goto",
|
||||
"description": "Navigate to a specific URL.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"url": {"type": "string", "description": "Fully qualified URL"}},
|
||||
"required": ["url"],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "back",
|
||||
"description": "navigate back to the previous page.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "back",
|
||||
"description": "navigate back to the previous page.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
async def act(self, instruction: str, current_state: EnvState) -> tuple[str, list[OperatorAction]]:
|
||||
"""Call the grounding LLM to get the next action based on the current state and instruction."""
|
||||
# Format the message for the API call
|
||||
messages_for_api = self._format_message_for_api(instruction, current_state)
|
||||
try:
|
||||
grounding_response: ChatCompletion = await self.client.chat.completions.create(
|
||||
messages=messages_for_api,
|
||||
model=self.model.name,
|
||||
tools=self.action_tools,
|
||||
tool_choice="required",
|
||||
temperature=0.0, # Grounding should be precise
|
||||
max_completion_tokens=1000, # Allow for thoughts + actions
|
||||
)
|
||||
if not isinstance(grounding_response, ChatCompletion):
|
||||
raise ValueError("Grounding LLM response is not of type ChatCompletion.")
|
||||
logger.debug(f"Grounding LLM response: {grounding_response.model_dump_json()}")
|
||||
|
||||
# Parse tool calls
|
||||
grounding_message = grounding_response.choices[0].message
|
||||
rendered_response, actions = self._parse_action(grounding_message, instruction, current_state)
|
||||
|
||||
# Update usage by grounding model
|
||||
self.tracer["usage"] = get_chat_usage_metrics(
|
||||
self.model.name,
|
||||
input_tokens=grounding_response.usage.prompt_tokens,
|
||||
output_tokens=grounding_response.usage.completion_tokens,
|
||||
usage=self.tracer.get("usage"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling Grounding LLM: {e}")
|
||||
rendered_response = f"**Error**: Error contacting Grounding LLM: {e}"
|
||||
actions = []
|
||||
|
||||
return rendered_response, actions
|
||||
|
||||
def _format_message_for_api(self, instruction: str, current_state: EnvState) -> List:
|
||||
"""Format the message for the API call."""
|
||||
grounding_user_prompt = f"""
|
||||
You are a GUI agent. You are given a task and a screenshot of the web browser tab you operate. You need to decide the next action to complete the task.
|
||||
You control a single tab in a Chromium browser. You cannot access the OS, filesystem or the application window.
|
||||
Always use the `goto` function to navigate to a specific URL. Ctrl+t, Ctrl+w, Ctrl+q, Ctrl+Shift+T, Ctrl+Shift+W are not allowed.
|
||||
|
||||
## Output Format
|
||||
```
|
||||
Thought: ...
|
||||
Action: ...
|
||||
```
|
||||
|
||||
## Action Space
|
||||
|
||||
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
|
||||
hotkey(key='')
|
||||
type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
|
||||
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
|
||||
wait(duration='time') # Sleep for specified time. Default is 1s and take a screenshot to check for any changes.
|
||||
goto(url='xxx') # Always use this to navigate to a specific URL. Use escape characters \\', \\", and \\n in url part to ensure we can parse the url in normal python string format.
|
||||
back() # Use this to go back to the previous page.
|
||||
|
||||
## Note
|
||||
- Use English in `Thought` part.
|
||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
|
||||
|
||||
## User Instruction
|
||||
{instruction}
|
||||
""".lstrip()
|
||||
|
||||
# Construct grounding LLM input (using only the latest user prompt + image)
|
||||
# We don't pass the full history here, as grounding depends on the *current* state + NL action
|
||||
screenshots = [f"data:image/webp;base64,{current_state.screenshot}"]
|
||||
grounding_messages_content = construct_structured_message(
|
||||
grounding_user_prompt, screenshots, self.model.name, vision_enabled=True
|
||||
)
|
||||
return [{"role": "user", "content": grounding_messages_content}]
|
||||
|
||||
def _parse_action(
|
||||
self, grounding_message: ChatCompletionMessage, instruction: str, current_state: EnvState
|
||||
) -> tuple[str, list[OperatorAction]]:
|
||||
"""Parse the tool calls from the grounding LLM response and convert them to action objects."""
|
||||
actions: List[OperatorAction] = []
|
||||
action_results: List[dict] = []
|
||||
|
||||
if grounding_message.tool_calls:
|
||||
rendered_parts = []
|
||||
for tool_call in grounding_message.tool_calls:
|
||||
function_name = tool_call.function.name
|
||||
try:
|
||||
arguments = json.loads(tool_call.function.arguments)
|
||||
action_to_run: Optional[OperatorAction] = None
|
||||
action_render_str = f"**Action ({function_name})**: {tool_call.function.arguments}"
|
||||
|
||||
if function_name == "click":
|
||||
action_to_run = ClickAction(**arguments)
|
||||
elif function_name == "left_double":
|
||||
action_to_run = DoubleClickAction(**arguments)
|
||||
elif function_name == "right_single":
|
||||
action_to_run = ClickAction(button="right", **arguments)
|
||||
elif function_name == "type":
|
||||
content = arguments.get("content")
|
||||
action_to_run = TypeAction(text=content)
|
||||
elif function_name == "scroll":
|
||||
direction = arguments.get("direction", "down")
|
||||
amount = 3
|
||||
action_to_run = ScrollAction(scroll_direction=direction, scroll_amount=amount, **arguments)
|
||||
elif function_name == "hotkey":
|
||||
action_to_run = KeypressAction(**arguments)
|
||||
elif function_name == "goto":
|
||||
action_to_run = GotoAction(**arguments)
|
||||
elif function_name == "back":
|
||||
action_to_run = BackAction(**arguments)
|
||||
elif function_name == "wait":
|
||||
action_to_run = WaitAction(**arguments)
|
||||
elif function_name == "screenshot":
|
||||
action_to_run = ScreenshotAction(**arguments)
|
||||
elif function_name == "drag":
|
||||
# Need to convert list of dicts to list of Point objects
|
||||
path_dicts = arguments.get("path", [])
|
||||
path_points = [Point(**p) for p in path_dicts]
|
||||
if path_points:
|
||||
action_to_run = DragAction(path=path_points)
|
||||
else:
|
||||
logger.warning(f"Drag action called with empty path: {arguments}")
|
||||
action_render_str += " [Skipped - empty path]"
|
||||
elif function_name == "finished":
|
||||
action_to_run = None
|
||||
else:
|
||||
logger.warning(f"Grounding LLM called unhandled tool: {function_name}")
|
||||
action_render_str += " [Unhandled]"
|
||||
|
||||
if action_to_run:
|
||||
actions.append(action_to_run)
|
||||
action_results.append(
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": None, # Updated after environment step
|
||||
}
|
||||
)
|
||||
rendered_parts.append(action_render_str)
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as arg_err:
|
||||
logger.error(
|
||||
f"Error parsing arguments for tool {function_name}: {arg_err} - Args: {tool_call.function.arguments}"
|
||||
)
|
||||
rendered_parts.append(f"**Error**: Failed to parse arguments for {function_name}")
|
||||
rendered_response = "\n- ".join(rendered_parts)
|
||||
else:
|
||||
# Grounding LLM responded but didn't call a tool
|
||||
logger.warning("Grounding LLM did not produce a tool call.")
|
||||
rendered_response = f"{grounding_message.content or 'No action required.'}"
|
||||
|
||||
# Render the response
|
||||
return rendered_response, actions
|
||||
return tools
|
||||
|
||||
def reset(self):
|
||||
"""Reset the agent state."""
|
||||
|
||||
@@ -10,6 +10,7 @@ import logging
|
||||
import math
|
||||
import re
|
||||
from io import BytesIO
|
||||
from textwrap import dedent
|
||||
from typing import Any, List
|
||||
|
||||
import numpy as np
|
||||
@@ -18,7 +19,7 @@ from openai.types.chat import ChatCompletion
|
||||
from PIL import Image
|
||||
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
from khoj.processor.operator.operator_environment_base import EnvState
|
||||
from khoj.processor.operator.operator_environment_base import EnvironmentType, EnvState
|
||||
from khoj.utils.helpers import get_chat_usage_metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -35,29 +36,8 @@ class GroundingAgentUitars:
|
||||
MAX_PIXELS = 16384 * 28 * 28
|
||||
MAX_RATIO = 200
|
||||
|
||||
UITARS_USR_PROMPT_THOUGHT = """
|
||||
You are a GUI agent. You are given a task and a screenshot of the web browser tab you operate. You need to perform the next action to complete the task.
|
||||
You control a single tab in a Chromium browser. You cannot access the OS, filesystem, the application window or the addressbar.
|
||||
Try fulfill the user instruction to the best of your ability, especially when the instruction is given multiple times. Do not ignore the instruction.
|
||||
|
||||
## Output Format
|
||||
```
|
||||
Thought: ...
|
||||
Action: ...
|
||||
```
|
||||
|
||||
## Action Space
|
||||
{action_space}
|
||||
|
||||
## Note
|
||||
- Use {language} in `Thought` part.
|
||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
|
||||
|
||||
## User Instruction
|
||||
{instruction}
|
||||
"""
|
||||
|
||||
UITARS_NORMAL_ACTION_SPACE = """
|
||||
UITARS_NORMAL_ACTION_SPACE = dedent(
|
||||
"""
|
||||
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
@@ -67,14 +47,15 @@ class GroundingAgentUitars:
|
||||
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
|
||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
|
||||
""".lstrip()
|
||||
"""
|
||||
).lstrip()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
environment_type: EnvironmentType,
|
||||
client: AsyncOpenAI | AsyncAzureOpenAI,
|
||||
max_iterations=50,
|
||||
environment_type: Literal["computer", "web"] = "computer",
|
||||
runtime_conf: dict = {
|
||||
"infer_mode": "qwen25vl_normal",
|
||||
"prompt_style": "qwen25vl_normal",
|
||||
@@ -94,7 +75,7 @@ class GroundingAgentUitars:
|
||||
self.model_name = model_name
|
||||
self.client = client
|
||||
self.tracer = tracer
|
||||
self.environment_type = environment_type
|
||||
self.environment = environment_type
|
||||
|
||||
self.max_iterations = max_iterations
|
||||
self.runtime_conf = runtime_conf
|
||||
@@ -116,7 +97,7 @@ class GroundingAgentUitars:
|
||||
self.history_images: list[bytes] = []
|
||||
self.history_responses: list[str] = []
|
||||
|
||||
self.prompt_template = self.UITARS_USR_PROMPT_THOUGHT
|
||||
self.prompt_template = self.get_instruction(self.environment)
|
||||
self.prompt_action_space = self.UITARS_NORMAL_ACTION_SPACE
|
||||
|
||||
if "history_n" in self.runtime_conf:
|
||||
@@ -126,11 +107,11 @@ class GroundingAgentUitars:
|
||||
|
||||
self.cur_callusr_count = 0
|
||||
|
||||
async def act(self, instruction: str, env_state: EnvState) -> tuple[str, list[OperatorAction]]:
|
||||
async def act(self, instruction: str, current_state: EnvState) -> tuple[str, list[OperatorAction]]:
|
||||
"""
|
||||
Suggest the next action(s) based on the instruction and current environment.
|
||||
"""
|
||||
messages = self._format_messages_for_api(instruction, env_state)
|
||||
messages = self._format_messages_for_api(instruction, current_state)
|
||||
|
||||
recent_screenshot = Image.open(BytesIO(self.history_images[-1]))
|
||||
origin_resized_height = recent_screenshot.height
|
||||
@@ -145,9 +126,11 @@ class GroundingAgentUitars:
|
||||
try_times = 3
|
||||
while not parsed_responses:
|
||||
if try_times <= 0:
|
||||
print(f"Reach max retry times to fetch response from client, as error flag.")
|
||||
logger.warning(f"Reach max retry times to fetch response from client, as error flag.")
|
||||
return "client error\nFAIL", []
|
||||
try:
|
||||
message_content = "\n".join([msg["content"][0].get("text") or "[image]" for msg in messages])
|
||||
logger.debug(f"User message content: {message_content}")
|
||||
response: ChatCompletion = await self.client.chat.completions.create(
|
||||
model="ui-tars",
|
||||
messages=messages,
|
||||
@@ -228,20 +211,9 @@ class GroundingAgentUitars:
|
||||
self.actions.append(actions)
|
||||
return f"{prediction}\nFAIL", []
|
||||
|
||||
if self.environment_type == "web":
|
||||
actions.extend(
|
||||
self.parsing_response_to_action(parsed_response, obs_image_height, obs_image_width, self.input_swap)
|
||||
)
|
||||
else:
|
||||
pass
|
||||
# TODO: Add PyautoguiAction when enable computer environment
|
||||
# actions.append(
|
||||
# PyautoguiAction(code=
|
||||
# self.parsing_response_to_pyautogui_code(
|
||||
# parsed_response, obs_image_height, obs_image_width, self.input_swap
|
||||
# )
|
||||
# )
|
||||
# )
|
||||
actions.extend(
|
||||
self.parsing_response_to_action(parsed_response, obs_image_height, obs_image_width, self.input_swap)
|
||||
)
|
||||
|
||||
self.actions.append(actions)
|
||||
|
||||
@@ -252,13 +224,52 @@ class GroundingAgentUitars:
|
||||
|
||||
return prediction or "", actions
|
||||
|
||||
def _format_messages_for_api(self, instruction: str, env_state: EnvState):
|
||||
def get_instruction(self, environment_type: EnvironmentType) -> str:
|
||||
"""
|
||||
Get the instruction for the agent based on the environment type.
|
||||
"""
|
||||
UITARS_COMPUTER_PREFIX_PROMPT = """
|
||||
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||
"""
|
||||
UITARS_BROWSER_PREFIX_PROMPT = """
|
||||
You are a GUI agent. You are given a task and a screenshot of the web browser tab you operate. You need to perform the next action to complete the task.
|
||||
You control a single tab in a Chromium browser. You cannot access the OS, filesystem, the application window or the addressbar.
|
||||
"""
|
||||
|
||||
UITARS_USR_PROMPT_THOUGHT = """
|
||||
Try fulfill the user instruction to the best of your ability, especially when the instruction is given multiple times. Do not ignore the instruction.
|
||||
|
||||
## Output Format
|
||||
```
|
||||
Thought: ...
|
||||
Action: ...
|
||||
```
|
||||
|
||||
## Action Space
|
||||
{action_space}
|
||||
|
||||
## Note
|
||||
- Use {language} in `Thought` part.
|
||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
|
||||
|
||||
## User Instruction
|
||||
{instruction}
|
||||
"""
|
||||
|
||||
if environment_type == EnvironmentType.BROWSER:
|
||||
return dedent(UITARS_BROWSER_PREFIX_PROMPT + UITARS_USR_PROMPT_THOUGHT).lstrip()
|
||||
elif environment_type == EnvironmentType.COMPUTER:
|
||||
return dedent(UITARS_COMPUTER_PREFIX_PROMPT + UITARS_USR_PROMPT_THOUGHT).lstrip()
|
||||
else:
|
||||
raise ValueError(f"Unsupported environment type: {environment_type}")
|
||||
|
||||
def _format_messages_for_api(self, instruction: str, current_state: EnvState):
|
||||
assert len(self.observations) == len(self.actions) and len(self.actions) == len(
|
||||
self.thoughts
|
||||
), "The number of observations and actions should be the same."
|
||||
|
||||
self.history_images.append(base64.b64decode(env_state.screenshot))
|
||||
self.observations.append({"screenshot": env_state.screenshot, "accessibility_tree": None})
|
||||
self.history_images.append(base64.b64decode(current_state.screenshot))
|
||||
self.observations.append({"screenshot": current_state.screenshot, "accessibility_tree": None})
|
||||
|
||||
user_prompt = self.prompt_template.format(
|
||||
instruction=instruction, action_space=self.prompt_action_space, language=self.language
|
||||
|
||||
@@ -125,6 +125,49 @@ class NoopAction(BaseAction):
|
||||
type: Literal["noop"] = "noop"
|
||||
|
||||
|
||||
# --- Text Editor Actions ---
|
||||
class TextEditorViewAction(BaseAction):
|
||||
"""View contents of a file."""
|
||||
|
||||
type: Literal["text_editor_view"] = "text_editor_view"
|
||||
path: str
|
||||
view_range: Optional[List[int]] = None # [start_line, end_line]
|
||||
|
||||
|
||||
class TextEditorCreateAction(BaseAction):
|
||||
"""Create a new file with specified contents."""
|
||||
|
||||
type: Literal["text_editor_create"] = "text_editor_create"
|
||||
path: str
|
||||
file_text: str
|
||||
|
||||
|
||||
class TextEditorStrReplaceAction(BaseAction):
|
||||
"""Execute an exact string match replacement on a file."""
|
||||
|
||||
type: Literal["text_editor_str_replace"] = "text_editor_str_replace"
|
||||
path: str
|
||||
old_str: str
|
||||
new_str: str
|
||||
|
||||
|
||||
class TextEditorInsertAction(BaseAction):
|
||||
"""Insert new text after a specified line number."""
|
||||
|
||||
type: Literal["text_editor_insert"] = "text_editor_insert"
|
||||
path: str
|
||||
insert_line: int
|
||||
new_str: str
|
||||
|
||||
|
||||
class TerminalAction(BaseAction):
|
||||
"""Insert new text after a specified line number."""
|
||||
|
||||
type: Literal["terminal"] = "terminal"
|
||||
command: str
|
||||
restart: bool = False
|
||||
|
||||
|
||||
OperatorAction = Union[
|
||||
ClickAction,
|
||||
DoubleClickAction,
|
||||
@@ -146,4 +189,9 @@ OperatorAction = Union[
|
||||
BackAction,
|
||||
RequestUserAction,
|
||||
NoopAction,
|
||||
TextEditorViewAction,
|
||||
TextEditorCreateAction,
|
||||
TextEditorStrReplaceAction,
|
||||
TextEditorInsertAction,
|
||||
TerminalAction,
|
||||
]
|
||||
|
||||
@@ -3,18 +3,21 @@ import json
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, cast
|
||||
from textwrap import dedent
|
||||
from typing import List, Literal, Optional, cast
|
||||
|
||||
from anthropic.types.beta import BetaContentBlock
|
||||
from anthropic.types.beta import BetaContentBlock, BetaTextBlock, BetaToolUseBlock
|
||||
|
||||
from khoj.database.models import ChatModel
|
||||
from khoj.processor.conversation.anthropic.utils import is_reasoning_model
|
||||
from khoj.processor.conversation.utils import AgentMessage
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
from khoj.processor.operator.operator_agent_base import (
|
||||
AgentActResult,
|
||||
AgentMessage,
|
||||
OperatorAgent,
|
||||
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
|
||||
from khoj.processor.operator.operator_environment_base import (
|
||||
EnvironmentType,
|
||||
EnvState,
|
||||
EnvStepResult,
|
||||
)
|
||||
from khoj.processor.operator.operator_environment_base import EnvState, EnvStepResult
|
||||
from khoj.utils.helpers import get_anthropic_async_client, is_none_or_empty
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -23,81 +26,34 @@ logger = logging.getLogger(__name__)
|
||||
# --- Anthropic Operator Agent ---
|
||||
class AnthropicOperatorAgent(OperatorAgent):
|
||||
async def act(self, current_state: EnvState) -> AgentActResult:
|
||||
client = get_anthropic_async_client(
|
||||
self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url
|
||||
)
|
||||
betas = self.model_default_headers()
|
||||
temperature = 1.0
|
||||
actions: List[OperatorAction] = []
|
||||
action_results: List[dict] = []
|
||||
self._commit_trace() # Commit trace before next action
|
||||
|
||||
system_prompt = f"""<SYSTEM_CAPABILITY>
|
||||
* You are Khoj, a smart web browser operating assistant. You help the users accomplish tasks using a web browser.
|
||||
* You operate a Chromium browser using Playwright via the 'computer' tool.
|
||||
* You cannot access the OS or filesystem.
|
||||
* You can interact with the web browser to perform tasks like clicking, typing, scrolling, and more.
|
||||
* You can use the additional back() and goto() helper functions to ease navigating the browser. If you see nothing, try goto duckduckgo.com
|
||||
* When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
|
||||
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
|
||||
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
|
||||
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
|
||||
* The current URL is {current_state.url}.
|
||||
</SYSTEM_CAPABILITY>
|
||||
system_prompt = self.get_instructions(self.environment_type, current_state)
|
||||
tools = self.get_tools(self.environment_type, current_state)
|
||||
|
||||
<IMPORTANT>
|
||||
* You are allowed upto {self.max_iterations} iterations to complete the task.
|
||||
* Do not loop on wait, screenshot for too many turns without taking any action.
|
||||
* After initialization if the browser is blank, enter a website URL using the goto() function instead of waiting
|
||||
</IMPORTANT>
|
||||
"""
|
||||
if is_none_or_empty(self.messages):
|
||||
self.messages = [AgentMessage(role="user", content=self.query)]
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": self.model_default_tool("computer"),
|
||||
"name": "computer",
|
||||
"display_width_px": 1024,
|
||||
"display_height_px": 768,
|
||||
}, # TODO: Get from env
|
||||
{
|
||||
"name": "back",
|
||||
"description": "Go back to the previous page.",
|
||||
"input_schema": {"type": "object", "properties": {}},
|
||||
},
|
||||
{
|
||||
"name": "goto",
|
||||
"description": "Go to a specific URL.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {"url": {"type": "string", "description": "Fully qualified URL to navigate to."}},
|
||||
"required": ["url"],
|
||||
},
|
||||
},
|
||||
]
|
||||
# Trigger trajectory compression if exceed size limit
|
||||
if len(self.messages) > self.message_limit:
|
||||
logger.debug("Compacting operator trajectory.")
|
||||
await self._compress()
|
||||
|
||||
thinking: dict[str, str | int] = {"type": "disabled"}
|
||||
if is_reasoning_model(self.vision_model.name):
|
||||
thinking = {"type": "enabled", "budget_tokens": 1024}
|
||||
|
||||
messages_for_api = self._format_message_for_api(self.messages)
|
||||
response = await client.beta.messages.create(
|
||||
messages=messages_for_api,
|
||||
model=self.vision_model.name,
|
||||
system=system_prompt,
|
||||
response_content = await self._call_model(
|
||||
messages=self.messages,
|
||||
model=self.vision_model,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
thinking=thinking,
|
||||
max_tokens=4096, # TODO: Make configurable?
|
||||
temperature=temperature,
|
||||
headers=self.model_default_headers(),
|
||||
)
|
||||
|
||||
logger.debug(f"Anthropic response: {response.model_dump_json()}")
|
||||
self.messages.append(AgentMessage(role="assistant", content=response.content))
|
||||
rendered_response = self._render_response(response.content, current_state.screenshot)
|
||||
self.messages.append(AgentMessage(role="assistant", content=response_content))
|
||||
rendered_response = self._render_response(response_content, current_state.screenshot)
|
||||
|
||||
for block in response.content:
|
||||
# Parse actions from response
|
||||
for block in response_content:
|
||||
if block.type == "tool_use":
|
||||
content = None
|
||||
is_error = False
|
||||
@@ -179,6 +135,40 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
logger.warning("Goto tool called without URL.")
|
||||
elif tool_name == "back":
|
||||
action_to_run = BackAction()
|
||||
elif tool_name == self.model_default_tool("terminal")["name"]:
|
||||
command = tool_input.get("command")
|
||||
restart = tool_input.get("restart", False)
|
||||
if command:
|
||||
action_to_run = TerminalAction(command=command, restart=restart)
|
||||
elif tool_name == "str_replace_based_edit_tool":
|
||||
# Handle text editor tool calls
|
||||
command = tool_input.get("command")
|
||||
if command == "view":
|
||||
path = tool_input.get("path")
|
||||
view_range = tool_input.get("view_range")
|
||||
if path:
|
||||
action_to_run = TextEditorViewAction(path=path, view_range=view_range)
|
||||
elif command == "create":
|
||||
path = tool_input.get("path")
|
||||
file_text = tool_input.get("file_text", "")
|
||||
if path:
|
||||
action_to_run = TextEditorCreateAction(path=path, file_text=file_text)
|
||||
elif command == "str_replace":
|
||||
path = tool_input.get("path")
|
||||
old_str = tool_input.get("old_str")
|
||||
new_str = tool_input.get("new_str")
|
||||
if path and old_str is not None and new_str is not None:
|
||||
action_to_run = TextEditorStrReplaceAction(path=path, old_str=old_str, new_str=new_str)
|
||||
elif command == "insert":
|
||||
path = tool_input.get("path")
|
||||
insert_line = tool_input.get("insert_line")
|
||||
new_str = tool_input.get("new_str")
|
||||
if path and insert_line is not None and new_str is not None:
|
||||
action_to_run = TextEditorInsertAction(
|
||||
path=path, insert_line=insert_line, new_str=new_str
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Unsupported text editor command: {command}")
|
||||
else:
|
||||
logger.warning(f"Unsupported Anthropic computer action type: {tool_name}")
|
||||
|
||||
@@ -200,14 +190,6 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
}
|
||||
)
|
||||
|
||||
self._update_usage(
|
||||
response.usage.input_tokens,
|
||||
response.usage.output_tokens,
|
||||
response.usage.cache_read_input_tokens,
|
||||
response.usage.cache_creation_input_tokens,
|
||||
)
|
||||
self.tracer["temperature"] = temperature
|
||||
|
||||
return AgentActResult(
|
||||
actions=actions,
|
||||
action_results=action_results,
|
||||
@@ -240,18 +222,19 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
if env_step.error:
|
||||
action_result["is_error"] = True
|
||||
|
||||
# Append tool results to the message history
|
||||
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]
|
||||
|
||||
# Mark the final tool result as a cache break point
|
||||
agent_action.action_results[-1]["cache_control"] = {"type": "ephemeral"}
|
||||
# Remove previous cache controls
|
||||
for msg in self.messages:
|
||||
if msg.role == "environment" and isinstance(msg.content, list):
|
||||
if isinstance(msg.content, list):
|
||||
for block in msg.content:
|
||||
if isinstance(block, dict) and "cache_control" in block:
|
||||
del block["cache_control"]
|
||||
|
||||
# Mark the final tool result as a cache break point
|
||||
agent_action.action_results[-1]["cache_control"] = {"type": "ephemeral"}
|
||||
|
||||
# Append tool results to the message history
|
||||
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]
|
||||
|
||||
def _format_message_for_api(self, messages: list[AgentMessage]) -> list[dict]:
|
||||
"""Format Anthropic response into a single string."""
|
||||
formatted_messages = []
|
||||
@@ -270,7 +253,7 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
)
|
||||
return formatted_messages
|
||||
|
||||
def compile_response(self, response_content: list[BetaContentBlock | dict] | str) -> str:
|
||||
def _compile_response(self, response_content: list[BetaContentBlock | dict] | str) -> str:
|
||||
"""Compile Anthropic response into a single string."""
|
||||
if isinstance(response_content, str):
|
||||
return response_content
|
||||
@@ -288,7 +271,11 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
compiled_response.append(block.text)
|
||||
elif block.type == "tool_use":
|
||||
block_input = {"action": block.name}
|
||||
if block.name == "computer":
|
||||
if block.name in (
|
||||
self.model_default_tool("computer")["name"],
|
||||
self.model_default_tool("editor")["name"],
|
||||
self.model_default_tool("terminal")["name"],
|
||||
):
|
||||
block_input = block.input # Computer action details are in input dict
|
||||
elif block.name == "goto":
|
||||
block_input["url"] = block.input.get("url", "[Missing URL]")
|
||||
@@ -345,7 +332,34 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
else:
|
||||
# Handle other actions
|
||||
render_texts += [f"{action.capitalize()}"]
|
||||
|
||||
elif block.name == self.model_default_tool("editor")["name"]:
|
||||
# Handle text editor actions
|
||||
command = block.input.get("command")
|
||||
if command == "view":
|
||||
path = block.input.get("path")
|
||||
view_range = block.input.get("view_range")
|
||||
if path:
|
||||
render_texts += [f"View file: {path} (lines {view_range})"]
|
||||
elif command == "create":
|
||||
path = block.input.get("path")
|
||||
file_text = block.input.get("file_text", "")
|
||||
if path:
|
||||
render_texts += [f"Create file: {path} with content:\n{file_text}"]
|
||||
elif command == "str_replace":
|
||||
path = block.input.get("path")
|
||||
old_str = block.input.get("old_str")
|
||||
new_str = block.input.get("new_str")
|
||||
if path and old_str is not None and new_str is not None:
|
||||
render_texts += [f"File: {path}\n**Find**\n{old_str}\n**Replace**\n{new_str}'"]
|
||||
elif command == "insert":
|
||||
path = block.input.get("path")
|
||||
insert_line = block.input.get("insert_line")
|
||||
new_str = block.input.get("new_str")
|
||||
if path and insert_line is not None and new_str is not None:
|
||||
render_texts += [f"In file: {path} at line {insert_line} insert\n{new_str}"]
|
||||
render_texts += [f"Edit file: {block.input['path']}"]
|
||||
elif block.name == self.model_default_tool("terminal")["name"]:
|
||||
render_texts += [f"Run command:\n{block.input['command']}"]
|
||||
# If screenshot is not available when screenshot action was requested
|
||||
if isinstance(block.input, dict) and block.input.get("action") == "screenshot" and not screenshot:
|
||||
render_texts += ["Failed to get screenshot"]
|
||||
@@ -365,6 +379,107 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
|
||||
return render_payload
|
||||
|
||||
async def _call_model(
|
||||
self,
|
||||
messages: list[AgentMessage],
|
||||
model: ChatModel,
|
||||
system_prompt: str,
|
||||
tools: list[dict] = [],
|
||||
headers: list[str] = [],
|
||||
temperature: float = 1.0,
|
||||
max_tokens: int = 4096,
|
||||
) -> list[BetaContentBlock]:
|
||||
client = get_anthropic_async_client(model.ai_model_api.api_key, model.ai_model_api.api_base_url)
|
||||
thinking: dict[str, str | int] = {"type": "disabled"}
|
||||
system = [{"type": "text", "text": system_prompt, "cache_control": {"type": "ephemeral"}}]
|
||||
kwargs: dict = {}
|
||||
if is_reasoning_model(model.name):
|
||||
thinking = {"type": "enabled", "budget_tokens": 1024}
|
||||
if headers:
|
||||
kwargs["betas"] = headers
|
||||
if tools:
|
||||
tools[-1]["cache_control"] = {"type": "ephemeral"} # Mark last tool as cache break point
|
||||
kwargs["tools"] = tools
|
||||
|
||||
messages_for_api = self._format_message_for_api(messages)
|
||||
try:
|
||||
response = await client.beta.messages.create(
|
||||
messages=messages_for_api,
|
||||
model=model.name,
|
||||
system=system,
|
||||
thinking=thinking,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
**kwargs,
|
||||
)
|
||||
response_content = response.content
|
||||
except Exception as e:
|
||||
# create a response block with error message
|
||||
logger.error(f"Error during Anthropic API call: {e}")
|
||||
error_str = e.message if hasattr(e, "message") else str(e)
|
||||
response = None
|
||||
response_content = [BetaTextBlock(text=f"Communication Error: {error_str}", type="text")]
|
||||
|
||||
if response:
|
||||
logger.debug(f"Anthropic response: {response.model_dump_json()}")
|
||||
self._update_usage(
|
||||
response.usage.input_tokens,
|
||||
response.usage.output_tokens,
|
||||
response.usage.cache_read_input_tokens,
|
||||
response.usage.cache_creation_input_tokens,
|
||||
)
|
||||
self.tracer["temperature"] = temperature
|
||||
return response_content
|
||||
|
||||
async def _compress(self):
|
||||
# 1. Prepare messages for compression
|
||||
original_messages = list(self.messages)
|
||||
messages_to_summarize = self.messages[: self.compress_length]
|
||||
# ensure last message isn't a tool call request
|
||||
if messages_to_summarize[-1].role == "assistant" and (
|
||||
any(isinstance(block, BetaToolUseBlock) for block in messages_to_summarize[-1].content)
|
||||
or any(block["type"] == "tool_use" for block in messages_to_summarize[-1].content)
|
||||
):
|
||||
messages_to_summarize.pop()
|
||||
|
||||
summarize_prompt = f"Summarize your research and computer use till now to help answer my query:\n{self.query}"
|
||||
summarize_message = AgentMessage(role="user", content=summarize_prompt)
|
||||
system_prompt = dedent(
|
||||
"""
|
||||
You are a computer operator with meticulous communication skills. You can condense your partial computer use traces and research into an appropriately detailed summary.
|
||||
When requested summarize your key actions, results and findings until now to achieve the user specified task.
|
||||
Your summary should help you remember the key information required to both complete the task and later generate a final report.
|
||||
"""
|
||||
)
|
||||
|
||||
# 2. Get summary of operation trajectory
|
||||
try:
|
||||
response_content = await self._call_model(
|
||||
messages=messages_to_summarize + [summarize_message],
|
||||
model=self.vision_model,
|
||||
system_prompt=system_prompt,
|
||||
max_tokens=8192,
|
||||
)
|
||||
except Exception as e:
|
||||
# create a response block with error message
|
||||
logger.error(f"Error during Anthropic API call: {e}")
|
||||
error_str = e.message if hasattr(e, "message") else str(e)
|
||||
response_content = [BetaTextBlock(text=f"Communication Error: {error_str}", type="text")]
|
||||
|
||||
summary_message = AgentMessage(role="assistant", content=response_content)
|
||||
|
||||
# 3. Rebuild message history with condensed trajectory
|
||||
primary_task = [original_messages.pop(0)]
|
||||
condensed_trajectory = [summarize_message, summary_message]
|
||||
recent_trajectory = original_messages[self.compress_length - 1 :] # -1 since we popped the first message
|
||||
# ensure first message isn't a tool result
|
||||
if recent_trajectory[0].role == "environment" and any(
|
||||
block["type"] == "tool_result" for block in recent_trajectory[0].content
|
||||
):
|
||||
recent_trajectory.pop(0)
|
||||
|
||||
self.messages = primary_task + condensed_trajectory + recent_trajectory
|
||||
|
||||
def get_coordinates(self, tool_input: dict, key: str = "coordinate") -> Optional[list | tuple]:
|
||||
"""Get coordinates from tool input."""
|
||||
raw_coord = tool_input.get(key)
|
||||
@@ -382,14 +497,22 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
|
||||
return coord
|
||||
|
||||
def model_default_tool(self, tool_type: Literal["computer", "editor", "terminal"]) -> str:
|
||||
def model_default_tool(self, tool_type: Literal["computer", "editor", "terminal"]) -> dict[str, str]:
|
||||
"""Get the default tool of specified type for the given model."""
|
||||
if self.vision_model.name.startswith("claude-3-7-sonnet"):
|
||||
if tool_type == "computer":
|
||||
return "computer_20250124"
|
||||
return {"name": "computer", "type": "computer_20250124"}
|
||||
elif tool_type == "editor":
|
||||
return {"name": "str_replace_editor", "type": "text_editor_20250124"}
|
||||
elif tool_type == "terminal":
|
||||
return {"name": "bash_20250124", "type": "bash"}
|
||||
elif self.vision_model.name.startswith("claude-sonnet-4") or self.vision_model.name.startswith("claude-opus-4"):
|
||||
if tool_type == "computer":
|
||||
return "computer_20250124"
|
||||
return {"name": "computer", "type": "computer_20250124"}
|
||||
elif tool_type == "editor":
|
||||
return {"name": "str_replace_based_edit_tool", "type": "text_editor_20250429"}
|
||||
elif tool_type == "terminal":
|
||||
return {"name": "bash", "type": "bash_20250124"}
|
||||
raise ValueError(f"Unsupported tool type for model '{self.vision_model.name}': {tool_type}")
|
||||
|
||||
def model_default_headers(self) -> list[str]:
|
||||
@@ -400,3 +523,88 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
return ["computer-use-2025-01-24"]
|
||||
else:
|
||||
return []
|
||||
|
||||
def get_instructions(self, environment_type: EnvironmentType, current_state: EnvState) -> str:
|
||||
"""Return system instructions for the Anthropic operator."""
|
||||
if environment_type == EnvironmentType.BROWSER:
|
||||
return dedent(
|
||||
f"""
|
||||
<SYSTEM_CAPABILITY>
|
||||
* You are Khoj, a smart web browser operating assistant. You help the users accomplish tasks using a web browser.
|
||||
* You operate a Chromium browser using Playwright via the 'computer' tool.
|
||||
* You cannot access the OS or filesystem.
|
||||
* You can interact with the web browser to perform tasks like clicking, typing, scrolling, and more.
|
||||
* You can use the additional back() and goto() helper functions to ease navigating the browser. If you see nothing, try goto duckduckgo.com
|
||||
* When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
|
||||
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
|
||||
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
|
||||
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
|
||||
* The current URL is {current_state.url}.
|
||||
</SYSTEM_CAPABILITY>
|
||||
|
||||
<IMPORTANT>
|
||||
* You are allowed upto {self.max_iterations} iterations to complete the task.
|
||||
* Do not loop on wait, screenshot for too many turns without taking any action.
|
||||
* After initialization if the browser is blank, enter a website URL using the goto() function instead of waiting
|
||||
</IMPORTANT>
|
||||
"""
|
||||
).lstrip()
|
||||
elif environment_type == EnvironmentType.COMPUTER:
|
||||
return dedent(
|
||||
f"""
|
||||
<SYSTEM_CAPABILITY>
|
||||
* You are Khoj, a smart computer operating assistant. You help the users accomplish tasks using a computer.
|
||||
* You can interact with the computer to perform tasks like clicking, typing, scrolling, and more.
|
||||
* When viewing a document or webpage it can be helpful to zoom out or scroll down to ensure you see everything before deciding something isn't available.
|
||||
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
|
||||
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
|
||||
* Do not loop on wait, screenshot for too many turns without taking any action.
|
||||
* You are allowed upto {self.max_iterations} iterations to complete the task.
|
||||
</SYSTEM_CAPABILITY>
|
||||
|
||||
<CONTEXT>
|
||||
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
|
||||
</CONTEXT>
|
||||
"""
|
||||
).lstrip()
|
||||
else:
|
||||
raise ValueError(f"Unsupported environment type for Anthropic operator: {environment_type}")
|
||||
|
||||
def get_tools(self, environment: EnvironmentType, current_state: EnvState) -> list[dict]:
|
||||
"""Return the tools available for the Anthropic operator."""
|
||||
tools: list[dict] = [
|
||||
{
|
||||
"type": self.model_default_tool("computer")["type"],
|
||||
"name": "computer",
|
||||
"display_width_px": current_state.width,
|
||||
"display_height_px": current_state.height,
|
||||
},
|
||||
{
|
||||
"type": self.model_default_tool("editor")["type"],
|
||||
"name": self.model_default_tool("editor")["name"],
|
||||
},
|
||||
{
|
||||
"type": self.model_default_tool("terminal")["type"],
|
||||
"name": self.model_default_tool("terminal")["name"],
|
||||
},
|
||||
]
|
||||
|
||||
if environment == "browser":
|
||||
tools += [
|
||||
{
|
||||
"name": "back",
|
||||
"description": "Go back to the previous page.",
|
||||
"input_schema": {"type": "object", "properties": {}},
|
||||
},
|
||||
{
|
||||
"name": "goto",
|
||||
"description": "Go to a specific URL.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {"url": {"type": "string", "description": "Fully qualified URL to navigate to."}},
|
||||
"required": ["url"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
return tools
|
||||
|
||||
@@ -5,9 +5,17 @@ from typing import List, Literal, Optional, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from khoj.database.models import ChatModel
|
||||
from khoj.processor.conversation.utils import commit_conversation_trace
|
||||
from khoj.processor.conversation.utils import (
|
||||
AgentMessage,
|
||||
OperatorRun,
|
||||
commit_conversation_trace,
|
||||
)
|
||||
from khoj.processor.operator.operator_actions import OperatorAction
|
||||
from khoj.processor.operator.operator_environment_base import EnvState, EnvStepResult
|
||||
from khoj.processor.operator.operator_environment_base import (
|
||||
EnvironmentType,
|
||||
EnvState,
|
||||
EnvStepResult,
|
||||
)
|
||||
from khoj.utils.helpers import get_chat_usage_metrics, is_promptrace_enabled
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -19,18 +27,41 @@ class AgentActResult(BaseModel):
|
||||
rendered_response: Optional[dict] = None
|
||||
|
||||
|
||||
class AgentMessage(BaseModel):
|
||||
role: Literal["user", "assistant", "system", "environment"]
|
||||
content: Union[str, List]
|
||||
|
||||
|
||||
class OperatorAgent(ABC):
|
||||
def __init__(self, query: str, vision_model: ChatModel, max_iterations: int, tracer: dict):
|
||||
def __init__(
|
||||
self,
|
||||
query: str,
|
||||
vision_model: ChatModel,
|
||||
environment_type: EnvironmentType,
|
||||
max_iterations: int,
|
||||
max_context: int,
|
||||
chat_history: List[AgentMessage] = [],
|
||||
previous_trajectory: Optional[OperatorRun] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
self.query = query
|
||||
self.vision_model = vision_model
|
||||
self.environment_type = environment_type
|
||||
self.max_iterations = max_iterations
|
||||
self.tracer = tracer
|
||||
self.messages: List[AgentMessage] = []
|
||||
self.summarize_prompt = f"Use the results of our research to provide a comprehensive, self-contained answer for the target query:\n{query}."
|
||||
|
||||
self.messages: List[AgentMessage] = chat_history
|
||||
if previous_trajectory:
|
||||
# Remove tool call from previous trajectory as tool call w/o result not supported
|
||||
if previous_trajectory.trajectory and previous_trajectory.trajectory[-1].role == "assistant":
|
||||
previous_trajectory.trajectory.pop()
|
||||
self.messages += previous_trajectory.trajectory
|
||||
self.messages += [AgentMessage(role="user", content=query)]
|
||||
|
||||
# Context compression parameters
|
||||
self.context_compress_trigger = 2e3 # heuristic to determine compression trigger
|
||||
# turns after which compression triggered. scales with model max context size. Minimum 5 turns.
|
||||
self.message_limit = 2 * max(5, int(max_context / self.context_compress_trigger))
|
||||
# compression ratio determines how many messages to compress down to one
|
||||
# e.g. if 5 messages, a compress ratio of 4/5 means compress 5 messages into 1 + keep 1 uncompressed
|
||||
self.message_compress_ratio = 4 / 5
|
||||
self.compress_length = int(self.message_limit * self.message_compress_ratio)
|
||||
|
||||
@abstractmethod
|
||||
async def act(self, current_state: EnvState) -> AgentActResult:
|
||||
@@ -41,16 +72,17 @@ class OperatorAgent(ABC):
|
||||
"""Track results of agent actions on the environment."""
|
||||
pass
|
||||
|
||||
async def summarize(self, summarize_prompt: str, current_state: EnvState) -> str:
|
||||
async def summarize(self, current_state: EnvState, summarize_prompt: str = None) -> str:
|
||||
"""Summarize the agent's actions and results."""
|
||||
summarize_prompt = summarize_prompt or self.summarize_prompt
|
||||
self.messages.append(AgentMessage(role="user", content=summarize_prompt))
|
||||
await self.act(current_state)
|
||||
if not self.messages:
|
||||
return "No actions to summarize."
|
||||
return self.compile_response(self.messages[-1].content)
|
||||
return self._compile_response(self.messages[-1].content)
|
||||
|
||||
@abstractmethod
|
||||
def compile_response(self, response: List | str) -> str:
|
||||
def _compile_response(self, response: List | str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -65,13 +97,12 @@ class OperatorAgent(ABC):
|
||||
self.tracer["usage"] = get_chat_usage_metrics(
|
||||
self.vision_model.name, input_tokens, output_tokens, cache_read, cache_write, usage=self.tracer.get("usage")
|
||||
)
|
||||
logger.debug(f"Operator usage by {self.vision_model.model_type}: {self.tracer['usage']}")
|
||||
|
||||
def _commit_trace(self):
|
||||
self.tracer["chat_model"] = self.vision_model.name
|
||||
if is_promptrace_enabled() and len(self.messages) > 1:
|
||||
compiled_messages = [
|
||||
AgentMessage(role=msg.role, content=self.compile_response(msg.content)) for msg in self.messages
|
||||
AgentMessage(role=msg.role, content=self._compile_response(msg.content)) for msg in self.messages
|
||||
]
|
||||
commit_conversation_trace(compiled_messages[:-1], compiled_messages[-1].content, self.tracer)
|
||||
|
||||
|
||||
@@ -1,21 +1,24 @@
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from textwrap import dedent
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat import ChatCompletion
|
||||
|
||||
from khoj.database.models import ChatModel
|
||||
from khoj.processor.conversation.utils import construct_structured_message
|
||||
from khoj.processor.conversation.utils import (
|
||||
AgentMessage,
|
||||
OperatorRun,
|
||||
construct_structured_message,
|
||||
)
|
||||
from khoj.processor.operator.grounding_agent import GroundingAgent
|
||||
from khoj.processor.operator.grounding_agent_uitars import GroundingAgentUitars
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
from khoj.processor.operator.operator_agent_base import (
|
||||
AgentActResult,
|
||||
AgentMessage,
|
||||
OperatorAgent,
|
||||
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
|
||||
from khoj.processor.operator.operator_environment_base import (
|
||||
EnvironmentType,
|
||||
EnvState,
|
||||
EnvStepResult,
|
||||
)
|
||||
from khoj.processor.operator.operator_environment_base import EnvState, EnvStepResult
|
||||
from khoj.routers.helpers import send_message_to_model_wrapper
|
||||
from khoj.utils.helpers import get_openai_async_client, is_none_or_empty
|
||||
|
||||
@@ -27,7 +30,7 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
"""
|
||||
An OperatorAgent that uses two LLMs:
|
||||
1. Reasoning LLM: Determines the next high-level action based on the objective and current visual reasoning trajectory.
|
||||
2. Grounding LLM: Converts the high-level action into specific, executable browser actions.
|
||||
2. Grounding LLM: Converts the high-level action into specific, actions executable on the environment.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -35,10 +38,23 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
query: str,
|
||||
reasoning_model: ChatModel,
|
||||
grounding_model: ChatModel,
|
||||
environment_type: EnvironmentType,
|
||||
max_iterations: int,
|
||||
tracer: dict,
|
||||
max_context: int,
|
||||
chat_history: List[AgentMessage] = [],
|
||||
previous_trajectory: Optional[OperatorRun] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
super().__init__(query, reasoning_model, max_iterations, tracer) # Use reasoning model for primary tracking
|
||||
super().__init__(
|
||||
query,
|
||||
reasoning_model,
|
||||
environment_type,
|
||||
max_iterations,
|
||||
max_context,
|
||||
chat_history,
|
||||
previous_trajectory,
|
||||
tracer,
|
||||
) # Use reasoning model for primary tracking
|
||||
self.reasoning_model = reasoning_model
|
||||
self.grounding_model = grounding_model
|
||||
# Initialize openai api compatible client for grounding model
|
||||
@@ -49,10 +65,12 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
self.grounding_agent: GroundingAgent | GroundingAgentUitars = None
|
||||
if "ui-tars-1.5" in grounding_model.name:
|
||||
self.grounding_agent = GroundingAgentUitars(
|
||||
grounding_model.name, grounding_client, max_iterations, environment_type="web", tracer=tracer
|
||||
grounding_model.name, self.environment_type, grounding_client, max_iterations, tracer=tracer
|
||||
)
|
||||
else:
|
||||
self.grounding_agent = GroundingAgent(grounding_model.name, grounding_client, max_iterations, tracer=tracer)
|
||||
self.grounding_agent = GroundingAgent(
|
||||
grounding_model.name, self.environment_type, grounding_client, max_iterations, tracer=tracer
|
||||
)
|
||||
|
||||
async def act(self, current_state: EnvState) -> AgentActResult:
|
||||
"""
|
||||
@@ -84,48 +102,7 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
"""
|
||||
Uses the reasoning LLM to determine the next high-level action based on the operation trajectory.
|
||||
"""
|
||||
reasoning_system_prompt = f"""
|
||||
# Introduction
|
||||
* You are Khoj, a smart and resourceful web browsing assistant. You help the user accomplish their task using a web browser.
|
||||
* You are given the user's query and screenshots of the browser's state transitions.
|
||||
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
|
||||
* The current URL is {current_state.url}.
|
||||
|
||||
# Your Task
|
||||
* First look at the screenshots carefully to notice all pertinent information.
|
||||
* Then instruct a tool AI to perform the next action that will help you progress towards the user's goal.
|
||||
* Make sure you scroll down to see everything before deciding something isn't available.
|
||||
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
|
||||
* Use your creativity to find alternate ways to make progress if you get stuck at any point.
|
||||
|
||||
# Tool AI Capabilities
|
||||
* The tool AI only has access to the current screenshot and your instructions. It uses your instructions to perform the next action on the page.
|
||||
* It can interact with the web browser with these actions: click, right click, double click, type, scroll, drag, wait, goto url and go back to previous page.
|
||||
* It cannot access the OS, filesystem or application window. It just controls a single Chromium browser tab via Playwright.
|
||||
|
||||
# IMPORTANT
|
||||
* You are allowed upto {self.max_iterations} iterations to complete the task.
|
||||
* To navigate to a specific URL, put "GOTO <URL>" (without quotes) on the last line of your response.
|
||||
* To navigate back to the previous page, end your response with "BACK" (without quotes).
|
||||
* Once you've verified that the main objective has been achieved, end your response with "DONE" (without quotes).
|
||||
|
||||
# Examples
|
||||
## Example 1
|
||||
GOTO https://example.com
|
||||
## Example 2
|
||||
click the blue login button located at the top right corner
|
||||
## Example 3
|
||||
scroll down the page
|
||||
## Example 4
|
||||
type the username example@email.com into the input field labeled Username
|
||||
## Example 5
|
||||
DONE
|
||||
|
||||
# Instructions
|
||||
Now describe a single high-level action to take next to progress towards the user's goal in detail.
|
||||
Focus on the visual action and provide all necessary context.
|
||||
""".strip()
|
||||
|
||||
reasoning_system_prompt = self.get_instruction(self.environment_type, current_state)
|
||||
if is_none_or_empty(self.messages):
|
||||
query_text = f"**Main Objective**: {self.query}"
|
||||
query_screenshot = [f"data:image/webp;base64,{current_state.screenshot}"]
|
||||
@@ -259,7 +236,8 @@ Focus on the visual action and provide all necessary context.
|
||||
action_results_content.extend(action_result["content"])
|
||||
self.messages.append(AgentMessage(role="environment", content=action_results_content))
|
||||
|
||||
async def summarize(self, summarize_prompt: str, env_state: EnvState) -> str:
|
||||
async def summarize(self, env_state: EnvState, summarize_prompt: str = None) -> str:
|
||||
summarize_prompt = summarize_prompt or self.summarize_prompt
|
||||
conversation_history = {"chat": self._format_message_for_api(self.messages)}
|
||||
try:
|
||||
summary = await send_message_to_model_wrapper(
|
||||
@@ -282,7 +260,7 @@ Focus on the visual action and provide all necessary context.
|
||||
|
||||
return summary
|
||||
|
||||
def compile_response(self, response_content: str | List) -> str:
|
||||
def _compile_response(self, response_content: str | List) -> str:
|
||||
"""Compile response content into a string, handling OpenAI message structures."""
|
||||
if isinstance(response_content, str):
|
||||
return response_content
|
||||
@@ -330,6 +308,96 @@ Focus on the visual action and provide all necessary context.
|
||||
]
|
||||
return formatted_messages
|
||||
|
||||
def get_instruction(self, environment_type: EnvironmentType, env_state: EnvState) -> str:
|
||||
"""Get the system instruction for the reasoning agent."""
|
||||
if environment_type == EnvironmentType.BROWSER:
|
||||
return dedent(
|
||||
f"""
|
||||
# Introduction
|
||||
* You are Khoj, a smart and resourceful web browsing assistant. You help the user accomplish their task using a web browser.
|
||||
* You are given the user's query and screenshots of the browser's state transitions.
|
||||
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
|
||||
* The current URL is {env_state.url}.
|
||||
|
||||
# Your Task
|
||||
* First look at the screenshots carefully to notice all pertinent information.
|
||||
* Then instruct a tool AI to perform the next action that will help you progress towards the user's goal.
|
||||
* Make sure you scroll down to see everything before deciding something isn't available.
|
||||
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
|
||||
* Use your creativity to find alternate ways to make progress if you get stuck at any point.
|
||||
|
||||
# Tool AI Capabilities
|
||||
* The tool AI only has access to the current screenshot and your instructions. It uses your instructions to perform the next action on the page.
|
||||
* It can interact with the web browser with these actions: click, right click, double click, type, scroll, drag, wait, goto url and go back to previous page.
|
||||
* It cannot access the OS, filesystem or application window. It just controls a single Chromium browser tab via Playwright.
|
||||
|
||||
# IMPORTANT
|
||||
* You are allowed upto {self.max_iterations} iterations to complete the task.
|
||||
* To navigate to a specific URL, put "GOTO <URL>" (without quotes) on the last line of your response.
|
||||
* To navigate back to the previous page, end your response with "BACK" (without quotes).
|
||||
* Once you've verified that the main objective has been achieved, end your response with "DONE" (without quotes).
|
||||
|
||||
# Examples
|
||||
## Example 1
|
||||
GOTO https://example.com
|
||||
## Example 2
|
||||
click the blue login button located at the top right corner
|
||||
## Example 3
|
||||
scroll down the page
|
||||
## Example 4
|
||||
type the username example@email.com into the input field labeled Username
|
||||
## Example 5
|
||||
DONE
|
||||
|
||||
# Instructions
|
||||
Now describe a single high-level action to take next to progress towards the user's goal in detail.
|
||||
Focus on the visual action and provide all necessary context.
|
||||
"""
|
||||
).strip()
|
||||
|
||||
elif environment_type == EnvironmentType.COMPUTER:
|
||||
return dedent(
|
||||
f"""
|
||||
# Introduction
|
||||
* You are Khoj, a smart and resourceful computer assistant. You help the user accomplish their task using a computer.
|
||||
* You are given the user's query and screenshots of the computer's state transitions.
|
||||
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
|
||||
|
||||
# Your Task
|
||||
* First look at the screenshots carefully to notice all pertinent information.
|
||||
* Then instruct a tool AI to perform the next action that will help you progress towards the user's goal.
|
||||
* Make sure you scroll down to see everything before deciding something isn't available.
|
||||
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
|
||||
* Use your creativity to find alternate ways to make progress if you get stuck at any point.
|
||||
|
||||
# Tool AI Capabilities
|
||||
* The tool AI only has access to the current screenshot and your instructions. It uses your instructions to perform the next action on the page.
|
||||
* It can interact with the computer with these actions: click, right click, double click, type, scroll, drag, wait to previous page.
|
||||
|
||||
# IMPORTANT
|
||||
* You are allowed upto {self.max_iterations} iterations to complete the task.
|
||||
* Once you've verified that the main objective has been achieved, end your response with "DONE" (without quotes).
|
||||
|
||||
# Examples
|
||||
## Example 1
|
||||
type https://example.com into the address bar and press Enter
|
||||
## Example 2
|
||||
click the blue login button located at the top right corner
|
||||
## Example 3
|
||||
scroll down the page
|
||||
## Example 4
|
||||
type the username example@email.com into the input field labeled Username
|
||||
## Example 5
|
||||
DONE
|
||||
|
||||
# Instructions
|
||||
Now describe a single high-level action to take next to progress towards the user's goal in detail.
|
||||
Focus on the visual action and provide all necessary context.
|
||||
"""
|
||||
).strip()
|
||||
else:
|
||||
raise ValueError(f"Expected environment type: Computer or Browser. Got {environment_type}.")
|
||||
|
||||
def reset(self):
|
||||
"""Reset the agent state."""
|
||||
super().reset()
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
import json
|
||||
import logging
|
||||
import platform
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from textwrap import dedent
|
||||
from typing import List, Optional, cast
|
||||
|
||||
from openai.types.responses import Response, ResponseOutputItem
|
||||
|
||||
from khoj.database.models import ChatModel
|
||||
from khoj.processor.conversation.utils import AgentMessage
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
from khoj.processor.operator.operator_agent_base import (
|
||||
AgentActResult,
|
||||
AgentMessage,
|
||||
OperatorAgent,
|
||||
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
|
||||
from khoj.processor.operator.operator_environment_base import (
|
||||
EnvironmentType,
|
||||
EnvState,
|
||||
EnvStepResult,
|
||||
)
|
||||
from khoj.processor.operator.operator_environment_base import EnvState, EnvStepResult
|
||||
from khoj.utils.helpers import get_openai_async_client, is_none_or_empty
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -21,80 +25,18 @@ logger = logging.getLogger(__name__)
|
||||
# --- Anthropic Operator Agent ---
|
||||
class OpenAIOperatorAgent(OperatorAgent):
|
||||
async def act(self, current_state: EnvState) -> AgentActResult:
|
||||
client = get_openai_async_client(
|
||||
self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url
|
||||
)
|
||||
safety_check_prefix = "Say 'continue' after resolving the following safety checks to proceed:"
|
||||
safety_check_message = None
|
||||
actions: List[OperatorAction] = []
|
||||
action_results: List[dict] = []
|
||||
self._commit_trace() # Commit trace before next action
|
||||
system_prompt = f"""<SYSTEM_CAPABILITY>
|
||||
* You are Khoj, a smart web browser operating assistant. You help the users accomplish tasks using a web browser.
|
||||
* You operate a single Chromium browser page using Playwright.
|
||||
* You cannot access the OS or filesystem.
|
||||
* You can interact with the web browser to perform tasks like clicking, typing, scrolling, and more using the computer_use_preview tool.
|
||||
* You can use the additional back() and goto() functions to navigate the browser.
|
||||
* Always use the goto() function to navigate to a specific URL. If you see nothing, try goto duckduckgo.com
|
||||
* When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
|
||||
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
|
||||
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
|
||||
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
|
||||
* The current URL is {current_state.url}.
|
||||
</SYSTEM_CAPABILITY>
|
||||
|
||||
<IMPORTANT>
|
||||
* You are allowed upto {self.max_iterations} iterations to complete the task.
|
||||
* After initialization if the browser is blank, enter a website URL using the goto() function instead of waiting
|
||||
</IMPORTANT>
|
||||
"""
|
||||
tools = [
|
||||
{
|
||||
"type": "computer_use_preview",
|
||||
"display_width": 1024, # TODO: Get from env
|
||||
"display_height": 768, # TODO: Get from env
|
||||
"environment": "browser",
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "back",
|
||||
"description": "Go back to the previous page.",
|
||||
"parameters": {},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "goto",
|
||||
"description": "Go to a specific URL.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "Fully qualified URL to navigate to.",
|
||||
},
|
||||
},
|
||||
"additionalProperties": False,
|
||||
"required": ["url"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
system_prompt = self.get_instructions(self.environment_type, current_state)
|
||||
tools = self.get_tools(self.environment_type, current_state)
|
||||
if is_none_or_empty(self.messages):
|
||||
self.messages = [AgentMessage(role="user", content=self.query)]
|
||||
|
||||
messages_for_api = self._format_message_for_api(self.messages)
|
||||
response: Response = await client.responses.create(
|
||||
model="computer-use-preview",
|
||||
input=messages_for_api,
|
||||
instructions=system_prompt,
|
||||
tools=tools,
|
||||
parallel_tool_calls=False, # Keep sequential for now
|
||||
max_output_tokens=4096, # TODO: Make configurable?
|
||||
truncation="auto",
|
||||
)
|
||||
|
||||
logger.debug(f"Openai response: {response.model_dump_json()}")
|
||||
self.messages += [AgentMessage(role="environment", content=response.output)]
|
||||
response = await self._call_model(self.vision_model, system_prompt, tools)
|
||||
self.messages += [AgentMessage(role="assistant", content=response.output)]
|
||||
rendered_response = self._render_response(response.output, current_state.screenshot)
|
||||
|
||||
last_call_id = None
|
||||
@@ -174,6 +116,9 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
"summary": [],
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Unsupported response block type: {block.type}")
|
||||
content = f"Unsupported response block type: {block.type}"
|
||||
if action_to_run or content:
|
||||
actions.append(action_to_run)
|
||||
if action_to_run or content:
|
||||
@@ -220,6 +165,10 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
elif action_result["type"] == "reasoning":
|
||||
items_to_pop.append(idx) # Mark placeholder reasoning action result for removal
|
||||
continue
|
||||
elif action_result["type"] == "computer_call" and action_result["status"] == "in_progress":
|
||||
if isinstance(result_content, dict):
|
||||
result_content["status"] = "completed" # Mark in-progress actions as completed
|
||||
action_result["output"] = result_content
|
||||
else:
|
||||
# Add text data
|
||||
action_result["output"] = result_content
|
||||
@@ -229,11 +178,45 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
|
||||
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]
|
||||
|
||||
async def summarize(self, current_state: EnvState, summarize_prompt: str = None) -> str:
|
||||
summarize_prompt = summarize_prompt or self.summarize_prompt
|
||||
self.messages.append(AgentMessage(role="user", content=summarize_prompt))
|
||||
response = await self._call_model(self.vision_model, summarize_prompt, [])
|
||||
self.messages += [AgentMessage(role="assistant", content=response.output)]
|
||||
if not self.messages:
|
||||
return "No actions to summarize."
|
||||
return self._compile_response(self.messages[-1].content)
|
||||
|
||||
async def _call_model(self, model: ChatModel, system_prompt, tools) -> Response:
|
||||
client = get_openai_async_client(model.ai_model_api.api_key, model.ai_model_api.api_base_url)
|
||||
if tools:
|
||||
model_name = "computer-use-preview"
|
||||
else:
|
||||
model_name = model.name
|
||||
|
||||
# Format messages for OpenAI API
|
||||
messages_for_api = self._format_message_for_api(self.messages)
|
||||
# format messages for summary if model is not computer-use-preview
|
||||
if model_name != "computer-use-preview":
|
||||
messages_for_api = self._format_messages_for_summary(messages_for_api)
|
||||
|
||||
response: Response = await client.responses.create(
|
||||
model=model_name,
|
||||
input=messages_for_api,
|
||||
instructions=system_prompt,
|
||||
tools=tools,
|
||||
parallel_tool_calls=False,
|
||||
truncation="auto",
|
||||
)
|
||||
|
||||
logger.debug(f"Openai response: {response.model_dump_json()}")
|
||||
return response
|
||||
|
||||
def _format_message_for_api(self, messages: list[AgentMessage]) -> list:
|
||||
"""Format the message for OpenAI API."""
|
||||
formatted_messages: list = []
|
||||
for message in messages:
|
||||
if message.role == "environment":
|
||||
if message.role == "assistant":
|
||||
if isinstance(message.content, list):
|
||||
# Remove reasoning message if not followed by computer call
|
||||
if (
|
||||
@@ -252,18 +235,23 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
message.content.pop(0)
|
||||
formatted_messages.extend(message.content)
|
||||
else:
|
||||
logger.warning(f"Expected message content list from environment, got {type(message.content)}")
|
||||
logger.warning(f"Expected message content list from assistant, got {type(message.content)}")
|
||||
elif message.role == "environment":
|
||||
formatted_messages.extend(message.content)
|
||||
else:
|
||||
if isinstance(message.content, list):
|
||||
message.content = "\n".join([part["text"] for part in message.content if part["type"] == "text"])
|
||||
formatted_messages.append(
|
||||
{
|
||||
"role": message.role,
|
||||
"content": message.content,
|
||||
}
|
||||
)
|
||||
|
||||
return formatted_messages
|
||||
|
||||
def compile_response(self, response_content: str | list[dict | ResponseOutputItem]) -> str:
|
||||
"""Compile the response from model into a single string."""
|
||||
def _compile_response(self, response_content: str | list[dict | ResponseOutputItem]) -> str:
|
||||
"""Compile the response from model into a single string for prompt tracing."""
|
||||
# Handle case where response content is a string.
|
||||
# This is the case when response content is a user query
|
||||
if isinstance(response_content, str):
|
||||
@@ -347,3 +335,123 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
}
|
||||
|
||||
return render_payload
|
||||
|
||||
def get_instructions(self, environment_type: EnvironmentType, current_state: EnvState) -> str:
|
||||
"""Return system instructions for the OpenAI operator."""
|
||||
if environment_type == EnvironmentType.BROWSER:
|
||||
return dedent(
|
||||
f"""
|
||||
<SYSTEM_CAPABILITY>
|
||||
* You are Khoj, a smart web browser operating assistant. You help the users accomplish tasks using a web browser.
|
||||
* You operate a single Chromium browser page using Playwright.
|
||||
* You cannot access the OS or filesystem.
|
||||
* You can interact with the web browser to perform tasks like clicking, typing, scrolling, and more using the computer_use_preview tool.
|
||||
* You can use the additional back() and goto() functions to navigate the browser.
|
||||
* Always use the goto() function to navigate to a specific URL. If you see nothing, try goto duckduckgo.com
|
||||
* When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
|
||||
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
|
||||
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
|
||||
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
|
||||
* The current URL is {current_state.url}.
|
||||
</SYSTEM_CAPABILITY>
|
||||
|
||||
<IMPORTANT>
|
||||
* You are allowed upto {self.max_iterations} iterations to complete the task.
|
||||
* After initialization if the browser is blank, enter a website URL using the goto() function instead of waiting
|
||||
</IMPORTANT>
|
||||
"""
|
||||
).lstrip()
|
||||
elif environment_type == EnvironmentType.COMPUTER:
|
||||
return dedent(
|
||||
f"""
|
||||
<SYSTEM_CAPABILITY>
|
||||
* You are Khoj, a smart computer operating assistant. You help the users accomplish their tasks using a computer.
|
||||
* You can interact with the computer to perform tasks like clicking, typing, scrolling, and more using the computer_use_preview tool.
|
||||
* When viewing a document or webpage it can be helpful to zoom out or scroll down to ensure you see everything before deciding something isn't available.
|
||||
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
|
||||
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
|
||||
* You are allowed upto {self.max_iterations} iterations to complete the task.
|
||||
</SYSTEM_CAPABILITY>
|
||||
|
||||
<CONTEXT>
|
||||
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
|
||||
</CONTEXT>
|
||||
"""
|
||||
).lstrip()
|
||||
else:
|
||||
raise ValueError(f"Unsupported environment type: {environment_type}")
|
||||
|
||||
def get_tools(self, environment_type: EnvironmentType, current_state: EnvState) -> list[dict]:
    """Return the tools available for the OpenAI operator.

    Every environment gets the ``computer_use_preview`` tool sized to the
    current screen; browser environments additionally get ``back`` and
    ``goto`` navigation functions.
    """
    # TODO: Get OS info from the environment.
    # For now, report Linux when operating a full computer and the virtual
    # "browser" environment otherwise.
    if environment_type == EnvironmentType.COMPUTER:
        environment_os = "linux"
        # environment = "mac" if platform.system() == "Darwin" else "windows" if platform.system() == "Windows" else "linux"
    else:
        environment_os = "browser"

    computer_tool = {
        "type": "computer_use_preview",
        "display_width": current_state.width,
        "display_height": current_state.height,
        "environment": environment_os,
    }
    tools: list[dict] = [computer_tool]

    if environment_type == EnvironmentType.BROWSER:
        back_tool = {
            "type": "function",
            "name": "back",
            "description": "Go back to the previous page.",
            "parameters": {},
        }
        goto_tool = {
            "type": "function",
            "name": "goto",
            "description": "Go to a specific URL.",
            "parameters": {
                "type": "object",
                "properties": {
                    "url": {
                        "type": "string",
                        "description": "Fully qualified URL to navigate to.",
                    },
                },
                "additionalProperties": False,
                "required": ["url"],
            },
        }
        tools.extend([back_tool, goto_tool])
    return tools
|
||||
|
||||
def _format_messages_for_summary(self, formatted_messages: List[dict]) -> List[dict]:
|
||||
"""Format messages for summary."""
|
||||
# Format messages to interact with non computer use AI models
|
||||
items_to_drop = [] # Track indices to drop reasoning messages
|
||||
for idx, msg in enumerate(formatted_messages):
|
||||
if isinstance(msg, dict) and "content" in msg:
|
||||
continue
|
||||
elif isinstance(msg, dict) and "output" in msg:
|
||||
# Drop current_url from output as not supported for non computer operations
|
||||
if "current_url" in msg["output"]:
|
||||
del msg["output"]["current_url"]
|
||||
formatted_messages[idx] = {"role": "user", "content": [msg["output"]]}
|
||||
elif isinstance(msg, str):
|
||||
formatted_messages[idx] = {"role": "user", "content": [{"type": "input_text", "text": msg}]}
|
||||
else:
|
||||
text = self._compile_response([msg])
|
||||
if not text:
|
||||
items_to_drop.append(idx) # Track index to drop reasoning message
|
||||
else:
|
||||
formatted_messages[idx] = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": text}],
|
||||
}
|
||||
|
||||
# Remove reasoning messages for non-computer use models
|
||||
for idx in reversed(items_to_drop):
|
||||
formatted_messages.pop(idx)
|
||||
|
||||
return formatted_messages
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -6,9 +7,18 @@ from pydantic import BaseModel
|
||||
from khoj.processor.operator.operator_actions import OperatorAction
|
||||
|
||||
|
||||
class EnvironmentType(Enum):
    """Type of environment to operate."""

    COMPUTER = "computer"  # full desktop: GUI, file system and terminal access
    BROWSER = "browser"  # web browser automation only
|
||||
|
||||
|
||||
class EnvState(BaseModel):
    """Snapshot of the operated environment at a point in time.

    Both fields below are optional because not every environment provides
    them: computer environments have no URL, and screenshot capture may fail.
    """

    height: int
    width: int
    # Base64-encoded screenshot of the environment, when capture succeeded.
    screenshot: Optional[str] = None
    # Current page URL; None for non-browser (computer) environments.
    # FIX: removed a duplicate `url: str` annotation that was silently
    # overridden by this optional declaration and contradicted its contract.
    url: Optional[str] = None
|
||||
|
||||
|
||||
class EnvStepResult(BaseModel):
|
||||
|
||||
@@ -5,7 +5,7 @@ import logging
|
||||
import os
|
||||
from typing import Optional, Set, Union
|
||||
|
||||
from khoj.processor.operator.operator_actions import OperatorAction, Point
|
||||
from khoj.processor.operator.operator_actions import DragAction, OperatorAction, Point
|
||||
from khoj.processor.operator.operator_environment_base import (
|
||||
Environment,
|
||||
EnvState,
|
||||
@@ -124,10 +124,10 @@ class BrowserEnvironment(Environment):
|
||||
|
||||
async def get_state(self) -> EnvState:
|
||||
if not self.page or self.page.is_closed():
|
||||
return EnvState(url="about:blank", screenshot=None)
|
||||
return EnvState(url="about:blank", screenshot=None, height=self.height, width=self.width)
|
||||
url = self.page.url
|
||||
screenshot = await self._get_screenshot()
|
||||
return EnvState(url=url, screenshot=screenshot)
|
||||
return EnvState(url=url, screenshot=screenshot, height=self.height, width=self.width)
|
||||
|
||||
async def step(self, action: OperatorAction) -> EnvStepResult:
|
||||
if not self.page or self.page.is_closed():
|
||||
@@ -246,6 +246,8 @@ class BrowserEnvironment(Environment):
|
||||
logger.debug(f"Action: {action.type} to ({x},{y})")
|
||||
|
||||
case "drag":
|
||||
if not isinstance(action, DragAction):
|
||||
raise TypeError(f"Invalid action type for drag")
|
||||
path = action.path
|
||||
if not path:
|
||||
error = "Missing path for drag action"
|
||||
|
||||
658
src/khoj/processor/operator/operator_environment_computer.py
Normal file
658
src/khoj/processor/operator/operator_environment_computer.py
Normal file
@@ -0,0 +1,658 @@
|
||||
import ast
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import platform
|
||||
import subprocess
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from khoj.processor.operator.operator_actions import DragAction, OperatorAction, Point
|
||||
from khoj.processor.operator.operator_environment_base import (
|
||||
Environment,
|
||||
EnvState,
|
||||
EnvStepResult,
|
||||
)
|
||||
from khoj.utils.helpers import convert_image_to_webp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# --- Concrete Computer Environment ---
|
||||
class ComputerEnvironment(Environment):
|
||||
def __init__(
    self,
    provider: Literal["local", "docker"] = "local",
    docker_display: str = ":99",
    docker_container_name: str = "khoj-computer",
):
    """Configure where pyautogui commands will run.

    Args:
        provider: "local" runs commands on the host; "docker" runs them
            inside the named container via `docker exec`.
        docker_display: X display inside the container (docker provider only).
        docker_container_name: Name of the container to exec into.
    """
    self.provider = provider
    self.docker_display = docker_display
    self.docker_container_name = docker_container_name

    # Screen geometry and pointer location; populated by start().
    self.width: int = 0
    self.height: int = 0
    self.mouse_pos: Point = Point(x=0, y=0)
|
||||
|
||||
async def _execute(self, func_name, *args, **kwargs):
    """Run one pyautogui call out-of-process and return its parsed result.

    Builds a one-shot python3 script for ``func_name`` and executes it
    either inside the configured Docker container or locally on a worker
    thread. Script stdout is parsed back into Python literals when possible.

    Raises:
        KeyboardInterrupt: when pyautogui's fail-safe aborts the action.
        RuntimeError: when the script exits non-zero for any other reason.
    """
    script = self.generate_pyautogui_command(func_name, *args, **kwargs)

    # Docker execution
    if self.provider == "docker":
        try:
            raw_output = await self.docker_execute(script)
        except RuntimeError as exc:  # Catch other Docker execution errors
            logger.error(f"Error during Docker execution of {func_name}: {exc}")
            raise  # Re-raise as a general error for the caller to handle
    # Local execution
    else:
        completed = await asyncio.to_thread(
            subprocess.run,
            ["python3", "-c", script],
            capture_output=True,
            text=True,
            check=False,  # return code is inspected manually below
        )
        raw_output = completed.stdout.strip()
        if completed.returncode != 0:
            # pyautogui's fail-safe (pointer flung to a screen corner)
            # surfaces as a FailSafeException; treat it as a user abort.
            if "FailSafeException" in completed.stderr or "FailSafeException" in completed.stdout:
                fs_msg = completed.stderr or completed.stdout
                raise KeyboardInterrupt(fs_msg)
            logger.error(
                f'Local script execution failed:\nCmd: python3 -c "{script[:200]}...{script[-200:]}\n'
                f"Return Code: {completed.returncode}\nStderr: {completed.stderr}\nStdout: {completed.stdout}"
            )
            raise RuntimeError(f"Local script execution error: {completed.stderr or completed.stdout}")

    if not raw_output or raw_output == "None":
        return None

    try:
        # Outputs like "(800, 600)" become real tuples; other strings pass through.
        return ast.literal_eval(raw_output)
    except (ValueError, SyntaxError):
        return raw_output
|
||||
|
||||
async def start(self, width: int, height: int) -> None:
    """Initialize the computer environment.

    The requested width/height are logged for reference only; the actual
    screen dimensions reported by pyautogui are what get stored.
    """
    actual_width, actual_height = await self._execute("size")
    self.width = actual_width
    self.height = actual_height

    # Seed the tracked mouse position from the real pointer, falling back
    # to the screen centre if it cannot be queried.
    try:
        pointer_x, pointer_y = await self._execute("position")
        self.mouse_pos = Point(x=pointer_x, y=pointer_y)
    except Exception:
        self.mouse_pos = Point(x=self.width / 2, y=self.height / 2)

    logger.info(
        f"Computer environment started. Screen size: {self.width}x{self.height}. "
        f"Input width/height ({width}x{height}) are noted but screen dimensioning uses actual screen size. "
        f"Initial mouse position: ({self.mouse_pos.x},{self.mouse_pos.y})"
    )
|
||||
|
||||
async def _get_screenshot(self) -> Optional[str]:
    """Capture the screen, overlay the cursor, and return base64-encoded webp.

    Returns None on any non-interrupt failure so callers can degrade gracefully.
    """
    try:
        # Capture raw PNG from pyautogui (comes back base64-encoded).
        png_b64 = await self._execute("screenshot")
        png_bytes = base64.b64decode(png_b64)

        # Overlay the live pointer location so the model can see the cursor.
        cursor_x, cursor_y = await self._execute("position")
        annotated_bytes = await self._draw_mouse_position(png_bytes, Point(x=cursor_x, y=cursor_y))

        # Convert to webp to shrink the payload sent to the model.
        webp_bytes = convert_image_to_webp(annotated_bytes)
        return base64.b64encode(webp_bytes).decode("utf-8")
    except KeyboardInterrupt:
        raise  # propagate user aborts
    except Exception as e:
        logger.error(f"Failed to get screenshot: {e}", exc_info=True)
        return None
|
||||
|
||||
async def _draw_mouse_position(self, screenshot_bytes: bytes, mouse_pos: Point) -> bytes:
    """Stamp a cursor marker onto a PNG screenshot; best-effort.

    Returns the original bytes untouched if Pillow is unavailable or
    drawing fails for any reason.
    """
    if Image is None or ImageDraw is None:
        return screenshot_bytes
    try:
        canvas = Image.open(io.BytesIO(screenshot_bytes))
        pen = ImageDraw.Draw(canvas)
        marker_radius = 8
        # Red circle with black border for better visibility
        pen.ellipse(
            (
                mouse_pos.x - marker_radius,
                mouse_pos.y - marker_radius,
                mouse_pos.x + marker_radius,
                mouse_pos.y + marker_radius,
            ),
            outline="black",
            fill="red",
            width=2,
        )
        out_buffer = io.BytesIO()
        canvas.save(out_buffer, format="PNG")
        return out_buffer.getvalue()
    except Exception as e:
        logger.error(f"Failed to draw mouse position: {e}")
        return screenshot_bytes
|
||||
|
||||
async def get_state(self) -> EnvState:
    """Return the current screen dimensions plus a fresh annotated screenshot."""
    latest_screenshot = await self._get_screenshot()
    return EnvState(screenshot=latest_screenshot, height=self.height, width=self.width)
|
||||
|
||||
async def step(self, action: OperatorAction) -> EnvStepResult:
|
||||
output: Optional[Union[str, dict]] = None
|
||||
error: Optional[str] = None
|
||||
step_type: str = "text"
|
||||
|
||||
try:
|
||||
match action.type:
|
||||
case "click":
|
||||
x, y, button_name = action.x, action.y, action.button
|
||||
modifiers_to_press = self.parse_key_combination(action.modifiers) if action.modifiers else []
|
||||
for mod_key in modifiers_to_press:
|
||||
await self._execute("keyDown", mod_key)
|
||||
|
||||
if button_name == "wheel":
|
||||
# Perform a small scroll action at this position (e.g., one "tick" down)
|
||||
# Pyautogui scroll: positive up, negative down.
|
||||
await self._execute("scroll", -1, x=x, y=y)
|
||||
output = f"Scrolled wheel at ({x}, {y})"
|
||||
else:
|
||||
pyautogui_button = button_name.lower() if button_name else "left"
|
||||
await self._execute("click", x=x, y=y, button=pyautogui_button)
|
||||
output = f"{button_name.capitalize() if button_name else 'Left'} clicked at ({x}, {y})"
|
||||
|
||||
for mod_key in reversed(modifiers_to_press):
|
||||
await self._execute("keyUp", mod_key)
|
||||
|
||||
self.mouse_pos = Point(x=x, y=y)
|
||||
logger.debug(f"Action: {action.type} {button_name} at ({x},{y}) with modifiers {action.modifiers}")
|
||||
|
||||
case "double_click":
|
||||
x, y = action.x, action.y
|
||||
await self._execute("doubleClick", x=x, y=y)
|
||||
self.mouse_pos = Point(x=x, y=y)
|
||||
output = f"Double clicked at ({x}, {y})"
|
||||
logger.debug(f"Action: {action.type} at ({x},{y})")
|
||||
|
||||
case "triple_click":
|
||||
x, y = action.x, action.y
|
||||
await self._execute("click", x=x, y=y, clicks=3)
|
||||
self.mouse_pos = Point(x=x, y=y)
|
||||
output = f"Triple clicked at ({x}, {y})"
|
||||
logger.debug(f"Action: {action.type} at ({x},{y})")
|
||||
|
||||
case "scroll":
|
||||
current_x_pos, current_y_pos = await self._execute("position")
|
||||
target_x = action.x if action.x is not None else current_x_pos
|
||||
target_y = action.y if action.y is not None else current_y_pos
|
||||
|
||||
if target_x != current_x_pos or target_y != current_y_pos:
|
||||
await self._execute("moveTo", target_x, target_y)
|
||||
|
||||
self.mouse_pos = Point(x=target_x, y=target_y) # Update mouse pos to scroll location
|
||||
|
||||
if action.scroll_x is not None or action.scroll_y is not None:
|
||||
scroll_x_amount = action.scroll_x or 0
|
||||
scroll_y_amount = action.scroll_y or 0
|
||||
|
||||
if scroll_x_amount != 0:
|
||||
await self._execute("hscroll", scroll_x_amount)
|
||||
if scroll_y_amount != 0:
|
||||
# pyautogui scroll: positive up, so negate for typical "scroll down" meaning positive y
|
||||
await self._execute("scroll", -scroll_y_amount)
|
||||
output = f"Scrolled by (x:{scroll_x_amount}, y:{scroll_y_amount}) at ({target_x}, {target_y})"
|
||||
elif action.scroll_direction:
|
||||
# Define scroll unit (number of pyautogui scroll 'clicks')
|
||||
# This might need tuning based on desired sensitivity.
|
||||
pyautogui_scroll_clicks_per_unit = 1
|
||||
amount = action.scroll_amount or 1
|
||||
total_scroll_clicks = pyautogui_scroll_clicks_per_unit * amount
|
||||
|
||||
if action.scroll_direction == "up":
|
||||
await self._execute("scroll", total_scroll_clicks)
|
||||
elif action.scroll_direction == "down":
|
||||
await self._execute("scroll", -total_scroll_clicks)
|
||||
elif action.scroll_direction == "left":
|
||||
await self._execute("hscroll", -total_scroll_clicks)
|
||||
elif action.scroll_direction == "right":
|
||||
await self._execute("hscroll", total_scroll_clicks)
|
||||
output = f"Scrolled {action.scroll_direction} by {amount} units at ({target_x}, {target_y})"
|
||||
else:
|
||||
error = "Scroll action requires either scroll_x/y or scroll_direction"
|
||||
logger.debug(f"Action: {action.type} details: {output or error}")
|
||||
|
||||
case "keypress":
|
||||
mapped_keys = [self.CUA_KEY_TO_PYAUTOGUI_KEY.get(k.lower(), k.lower()) for k in action.keys]
|
||||
key_string = "N/A"
|
||||
if not mapped_keys:
|
||||
error = "Keypress action requires at least one key"
|
||||
elif len(mapped_keys) > 1:
|
||||
await self._execute("hotkey", *mapped_keys)
|
||||
key_string = "+".join(mapped_keys)
|
||||
else:
|
||||
await self._execute("press", mapped_keys[0])
|
||||
key_string = mapped_keys[0]
|
||||
if not error:
|
||||
output = f"Pressed key(s): {key_string}"
|
||||
logger.debug(f"Action: {action.type} '{key_string}'")
|
||||
|
||||
case "type":
|
||||
text_to_type = action.text
|
||||
await self._execute("typewrite", text_to_type, interval=0.02) # Small interval
|
||||
output = f"Typed text: {text_to_type}"
|
||||
logger.debug(f"Action: {action.type} '{text_to_type}'")
|
||||
|
||||
case "wait":
|
||||
duration = action.duration
|
||||
await asyncio.sleep(duration)
|
||||
output = f"Waited for {duration} seconds"
|
||||
logger.debug(f"Action: {action.type} for {duration}s")
|
||||
|
||||
case "screenshot":
|
||||
step_type = "image"
|
||||
# The actual screenshot data is added from after_state later
|
||||
output = {"message": "Screenshot captured", "url": "desktop"}
|
||||
logger.debug(f"Action: {action.type}")
|
||||
|
||||
case "move":
|
||||
x, y = action.x, action.y
|
||||
await self._execute("moveTo", x, y, duration=0.2) # Small duration for smooth move
|
||||
self.mouse_pos = Point(x=x, y=y)
|
||||
output = f"Moved mouse to ({x}, {y})"
|
||||
logger.debug(f"Action: {action.type} to ({x},{y})")
|
||||
|
||||
case "drag":
|
||||
if not isinstance(action, DragAction):
|
||||
raise TypeError("Invalid action type for drag")
|
||||
drag_path = action.path
|
||||
if not drag_path:
|
||||
error = "Missing path for drag action"
|
||||
else:
|
||||
start_x, start_y = drag_path[0].x, drag_path[0].y
|
||||
await self._execute("moveTo", start_x, start_y, duration=0.1)
|
||||
await self._execute("mouseDown")
|
||||
for point in drag_path[1:]:
|
||||
await self._execute("moveTo", point.x, point.y, duration=0.05)
|
||||
await self._execute("mouseUp")
|
||||
self.mouse_pos = Point(x=drag_path[-1].x, y=drag_path[-1].y)
|
||||
output = f"Drag along path starting at ({start_x},{start_y})"
|
||||
logger.debug(f"Action: {action.type} with {len(drag_path)} points")
|
||||
|
||||
case "mouse_down":
|
||||
pyautogui_button = action.button.lower() if action.button else "left"
|
||||
await self._execute("mouseDown", button=pyautogui_button)
|
||||
output = f"{action.button.capitalize() if action.button else 'Left'} mouse button down"
|
||||
logger.debug(f"Action: {action.type} {action.button}")
|
||||
|
||||
case "mouse_up":
|
||||
pyautogui_button = action.button.lower() if action.button else "left"
|
||||
await self._execute("mouseUp", button=pyautogui_button)
|
||||
output = f"{action.button.capitalize() if action.button else 'Left'} mouse button up"
|
||||
logger.debug(f"Action: {action.type} {action.button}")
|
||||
|
||||
case "hold_key":
|
||||
keys_to_hold_str = action.text
|
||||
duration = action.duration
|
||||
parsed_keys = self.parse_key_combination(keys_to_hold_str)
|
||||
if not parsed_keys:
|
||||
error = f"No valid keys found in '{keys_to_hold_str}' for hold_key"
|
||||
else:
|
||||
for key_to_hold in parsed_keys:
|
||||
await self._execute("keyDown", key_to_hold)
|
||||
await asyncio.sleep(duration) # Non-pyautogui, direct sleep
|
||||
for key_to_hold in reversed(parsed_keys): # Release in reverse order
|
||||
await self._execute("keyUp", key_to_hold)
|
||||
output = (
|
||||
f"Held key{'s' if len(parsed_keys) > 1 else ''} {keys_to_hold_str} for {duration} seconds"
|
||||
)
|
||||
logger.debug(f"Action: {action.type} '{keys_to_hold_str}' for {duration}s")
|
||||
|
||||
case "key_down":
|
||||
key_to_press = self.CUA_KEY_TO_PYAUTOGUI_KEY.get(action.key.lower(), action.key)
|
||||
await self._execute("keyDown", key_to_press)
|
||||
output = f"Key down: {key_to_press}"
|
||||
logger.debug(f"Action: {action.type} {key_to_press}")
|
||||
|
||||
case "key_up":
|
||||
key_to_release = self.CUA_KEY_TO_PYAUTOGUI_KEY.get(action.key.lower(), action.key)
|
||||
await self._execute("keyUp", key_to_release)
|
||||
output = f"Key up: {key_to_release}"
|
||||
logger.debug(f"Action: {action.type} {key_to_release}")
|
||||
|
||||
case "cursor_position":
|
||||
pos_x, pos_y = await self._execute("position")
|
||||
self.mouse_pos = Point(x=pos_x, y=pos_y)
|
||||
output = f"Cursor position is ({pos_x}, {pos_y})"
|
||||
logger.debug(f"Action: {action.type}, position: ({pos_x},{pos_y})")
|
||||
|
||||
case "goto":
|
||||
output = f"Goto action (URL: {action.url}) is not applicable for ComputerEnvironment."
|
||||
logger.warning(f"Unsupported action: {action.type} for ComputerEnvironment.")
|
||||
|
||||
case "back":
|
||||
output = "Back action is not applicable for ComputerEnvironment."
|
||||
logger.warning(f"Unsupported action: {action.type} for ComputerEnvironment.")
|
||||
|
||||
case "terminal":
|
||||
# Execute terminal command
|
||||
result = await self._execute_shell_command(action.command)
|
||||
if result["success"]:
|
||||
output = f"Command executed successfully:\n{result['output']}"
|
||||
else:
|
||||
error = f"Command execution failed: {result['error']}"
|
||||
logger.debug(f"Action: {action.type} with command '{action.command}'")
|
||||
|
||||
case "text_editor_view":
|
||||
# View file contents
|
||||
file_path = action.path
|
||||
view_range = action.view_range
|
||||
# Type guard: path should be str for text editor actions
|
||||
if not isinstance(file_path, str):
|
||||
raise TypeError("Invalid path type for text editor view action")
|
||||
escaped_path = file_path.replace("'", "'\"'\"'")
|
||||
is_dir = await self._execute("os.path.isdir", escaped_path)
|
||||
if is_dir:
|
||||
cmd = rf"find {escaped_path} -maxdepth 2 -not -path '*/\.*'"
|
||||
elif view_range:
|
||||
# Use head/tail to view specific line range
|
||||
start_line, end_line = view_range
|
||||
lines_to_show = end_line - start_line + 1
|
||||
cmd = f"head -n {end_line} '{escaped_path}' | tail -n {lines_to_show}"
|
||||
else:
|
||||
# View entire file
|
||||
cmd = f"cat '{escaped_path}'"
|
||||
|
||||
result = await self._execute_shell_command(cmd)
|
||||
MAX_OUTPUT_LENGTH = 15000 # Limit output length to avoid excessive data
|
||||
if len(result["output"]) > MAX_OUTPUT_LENGTH:
|
||||
result["output"] = f"{result['output'][:MAX_OUTPUT_LENGTH]}..."
|
||||
if result["success"]:
|
||||
if is_dir:
|
||||
output = f"Here's the files and directories up to 2 levels deep in {file_path}, excluding hidden items:\n{result['output']}"
|
||||
else:
|
||||
output = f"File contents of {file_path}:\n{result['output']}"
|
||||
else:
|
||||
error = f"Failed to view file {file_path}: {result['error']}"
|
||||
logger.debug(f"Action: {action.type} for file {file_path}")
|
||||
|
||||
case "text_editor_create":
|
||||
# Create new file with contents
|
||||
file_path = action.path
|
||||
file_text = action.file_text
|
||||
# Type guard: path should be str for text editor actions
|
||||
if not isinstance(file_path, str):
|
||||
raise TypeError("Invalid path type for text editor create action")
|
||||
escaped_path = file_path.replace("'", "'\"'\"'")
|
||||
escaped_content = file_text.replace("\t", " ").replace(
|
||||
"'", "'\"'\"'"
|
||||
) # Escape single quotes for shell
|
||||
cmd = f"echo '{escaped_content}' > '{escaped_path}'"
|
||||
|
||||
result = await self._execute_shell_command(cmd)
|
||||
if result["success"]:
|
||||
output = f"Created file {file_path} with {len(file_text)} characters"
|
||||
else:
|
||||
error = f"Failed to create file {file_path}: {result['error']}"
|
||||
logger.debug(f"Action: {action.type} created file {file_path}")
|
||||
|
||||
case "text_editor_str_replace":
|
||||
# Execute string replacement
|
||||
file_path = action.path
|
||||
old_str = action.old_str
|
||||
new_str = action.new_str
|
||||
|
||||
# Type guard: path should be str for text editor actions
|
||||
if not isinstance(file_path, str):
|
||||
raise TypeError("Invalid path type for text editor str_replace action")
|
||||
# Use sed for string replacement, escaping special characters
|
||||
escaped_path = file_path.replace("'", "'\"'\"'")
|
||||
escaped_old = (
|
||||
old_str.replace("\t", " ")
|
||||
.replace("\\", "\\\\")
|
||||
.replace("\n", "\\n")
|
||||
.replace("/", "\\/")
|
||||
.replace("'", "'\"'\"'")
|
||||
)
|
||||
escaped_new = (
|
||||
new_str.replace("\t", " ")
|
||||
.replace("\\", "\\\\")
|
||||
.replace("\n", "\\n")
|
||||
.replace("&", "\\&")
|
||||
.replace("/", "\\/")
|
||||
.replace("'", "'\"'\"'")
|
||||
)
|
||||
cmd = f"sed -i.bak 's/{escaped_old}/{escaped_new}/g' '{escaped_path}'"
|
||||
|
||||
result = await self._execute_shell_command(cmd)
|
||||
if result["success"]:
|
||||
output = f"Replaced '{old_str[:50]}...' with '{new_str[:50]}...' in {file_path}"
|
||||
else:
|
||||
error = f"Failed to replace text in {file_path}: {result['error']}"
|
||||
logger.debug(f"Action: {action.type} in file {file_path}")
|
||||
|
||||
case "text_editor_insert":
|
||||
# Insert text after specified line
|
||||
file_path = action.path
|
||||
insert_line = action.insert_line
|
||||
new_str = action.new_str
|
||||
|
||||
# Type guard: path should be str for text editor actions
|
||||
if not isinstance(file_path, str):
|
||||
error = "Invalid path type for text editor insert action.\n"
|
||||
error += f"Failed to insert text in {file_path}: {result['error']}"
|
||||
raise TypeError(error)
|
||||
escaped_path = file_path.replace("'", "'\"'\"'")
|
||||
escaped_content = (
|
||||
new_str.replace("\t", " ")
|
||||
.replace("\\", "\\\\")
|
||||
.replace("'", "'\"'\"'")
|
||||
.replace("\n", "\\\n")
|
||||
)
|
||||
cmd = f"sed -i.bak '{insert_line}a\\{escaped_content}' '{escaped_path}'"
|
||||
|
||||
result = await self._execute_shell_command(cmd)
|
||||
if result["success"]:
|
||||
output = f"Inserted text after line {insert_line} in {file_path}"
|
||||
else:
|
||||
error = f"Failed to insert text in {file_path}: {result['error']}"
|
||||
logger.debug(f"Action: {action.type} at line {insert_line} in file {file_path}")
|
||||
|
||||
case _:
|
||||
error = f"Unrecognized action type: {action.type}"
|
||||
logger.warning(error)
|
||||
except KeyboardInterrupt:
|
||||
error = "User interrupt. Operation aborted."
|
||||
logger.error(error)
|
||||
except TypeError as e:
|
||||
logger.error(f"Error executing action {action.type}: {e}")
|
||||
except Exception as e:
|
||||
error = f"Unexpected error executing action {action.type}: {str(e)}"
|
||||
logger.exception(
|
||||
f"Unexpected error during step execution for action: {action.model_dump_json(exclude_none=True)}"
|
||||
)
|
||||
|
||||
after_state = await self.get_state()
|
||||
|
||||
if action.type == "screenshot" and step_type == "image":
|
||||
output = {"image": after_state.screenshot, "url": after_state.url}
|
||||
|
||||
return EnvStepResult(
|
||||
type=step_type,
|
||||
output=output,
|
||||
error=error,
|
||||
current_url=after_state.url,
|
||||
screenshot_base64=after_state.screenshot,
|
||||
)
|
||||
|
||||
async def _execute_shell_command(self, command: str, new: bool = True) -> dict:
|
||||
"""Execute a shell command and return the result."""
|
||||
try:
|
||||
if self.provider == "docker":
|
||||
# Execute command in Docker container
|
||||
docker_args = [
|
||||
"docker",
|
||||
"exec",
|
||||
self.docker_container_name,
|
||||
"bash",
|
||||
"-c",
|
||||
command, # The command string is passed as a single argument to bash -c
|
||||
]
|
||||
process = await asyncio.to_thread(
|
||||
subprocess.run,
|
||||
docker_args,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
timeout=120,
|
||||
)
|
||||
else:
|
||||
# Execute command locally
|
||||
process = await asyncio.to_thread(
|
||||
subprocess.run,
|
||||
command,
|
||||
shell=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
start_new_session=new,
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
if process.returncode == 0:
|
||||
return {"success": True, "output": process.stdout, "error": None}
|
||||
else:
|
||||
return {"success": False, "output": process.stdout, "error": process.stderr}
|
||||
except asyncio.TimeoutError:
|
||||
return {"success": False, "output": "", "error": f"Command timed out after 120 seconds."}
|
||||
except Exception as e:
|
||||
return {"success": False, "output": "", "error": str(e)}
|
||||
|
||||
async def close(self) -> None:
    """Tear down the environment; pyautogui holds no resources to release."""
    logger.debug("Computer environment closed. No specific resources to release for PyAutoGUI.")
|
||||
|
||||
# Map computer-use API (CUA) key names to their pyautogui equivalents.
# Keys absent from this table are passed through to pyautogui unchanged
# (see parse_key_combination and the keypress/key_down/key_up actions).
CUA_KEY_TO_PYAUTOGUI_KEY = {
    # Modifiers
    "option": "alt",
    "control": "ctrl",
    "cmd": "command",
    "super": "win",
    # "meta" differs per platform: Command on macOS, Windows key elsewhere.
    "meta": "command" if platform.system() == "Darwin" else "win",
    # Navigation & Editing
    "arrowdown": "down",
    "arrowleft": "left",
    "arrowright": "right",
    "arrowup": "up",
    "caps_lock": "capslock",
    "del": "delete",
    "return": "enter",
    "esc": "escape",
    "pgdn": "pagedown",
    "pgup": "pageup",
    " ": "space",
    # Numpad keys (example, pyautogui uses 'num0', 'add', 'subtract', etc.)
    "numpad0": "num0",
    "numpad_0": "num0",
}
|
||||
|
||||
@staticmethod
def parse_key_combination(text: str) -> list[str]:
    """Translate a '+'-separated key combo (e.g. 'ctrl+shift+t') to pyautogui key names.

    Unknown keys (plain letters, digits) pass through unchanged; an empty
    or falsy input yields an empty list.
    """
    if not text:
        return []
    key_map = ComputerEnvironment.CUA_KEY_TO_PYAUTOGUI_KEY
    return [key_map.get(part.strip(), part.strip()) for part in text.lower().split("+")]
|
||||
|
||||
def generate_pyautogui_command(self, func_name: str, *args, **kwargs) -> str:
    """Build a single-line python3 script that runs one pyautogui call.

    The script prints its result to stdout so the caller can parse it back.
    'screenshot', 'size' and 'position' get dedicated print formats; any
    other function is invoked generically with the given args/kwargs.
    """
    call_params = ", ".join(
        [repr(arg) for arg in args] + [f"{key}={repr(val)}" for key, val in kwargs.items()]
    )

    # Base script setup
    statements: list[str] = ["import os", "import pyautogui"]

    if self.provider == "docker":
        # Point at the container's X display and disable the corner
        # fail-safe, which would otherwise abort scripted mouse moves.
        statements.append(f"os.environ['DISPLAY']='{self.docker_display}'")
        statements.append("pyautogui.FAILSAFE = False")

    # Function-specific logic
    if func_name == "screenshot":
        statements += [
            "import io",
            "import base64",
            "img = pyautogui.screenshot()",
            "buf = io.BytesIO()",
            "img.save(buf, format='PNG')",
            "print(base64.b64encode(buf.getvalue()).decode('utf-8'))",
        ]
    elif func_name == "size":
        statements += ["size = pyautogui.size()", "print(f'({size.width}, {size.height})')"]
    elif func_name == "position":
        statements += ["pos = pyautogui.position()", "print(f'({pos.x}, {pos.y})')"]
    else:  # General command structure
        statements += [
            f"result = pyautogui.{func_name}({call_params})",
            "print(result if result is not None else '')",
        ]

    return "; ".join(statements)
|
||||
|
||||
async def docker_execute(self, python_command_str: str) -> Optional[str]:
    """Run a python3 one-liner inside the configured Docker container.

    Args:
        python_command_str: Script to pass to ``python3 -c`` in-container.

    Returns:
        The command's stripped stdout, or None when the container name or
        display is not configured.

    Raises:
        KeyboardInterrupt: when pyautogui's fail-safe triggered in-container.
        RuntimeError: for a non-zero exit or any other execution failure.
    """
    if not self.docker_container_name or not self.docker_display:
        logger.error("Container name or Docker display not set for Docker execution.")
        return None

    safe_python_cmd = python_command_str.replace('"', '\\"')
    docker_full_cmd = (
        f'docker exec -e DISPLAY={self.docker_display} "{self.docker_container_name}" '
        f'python3 -c "{safe_python_cmd}"'
    )

    try:
        process = await asyncio.to_thread(
            subprocess.run,
            docker_full_cmd,
            shell=True,
            capture_output=True,
            text=True,
            check=False,  # We check returncode manually
        )
        if process.returncode != 0:
            if "FailSafeException" in process.stderr or "FailSafeException" in process.stdout:
                raise KeyboardInterrupt(process.stderr or process.stdout)
            else:
                error_msg = (
                    f"Docker command failed:\nCmd: {docker_full_cmd}\n"
                    f"Return Code: {process.returncode}\nStderr: {process.stderr}\nStdout: {process.stdout}"
                )
                logger.error(error_msg)
                raise RuntimeError(f"Docker exec error: {process.stderr or process.stdout}")
        return process.stdout.strip()
    # FIX: the "Docker exec error" RuntimeError raised above was previously
    # swallowed by the generic Exception handler below and re-logged/re-wrapped
    # as "Unexpected Docker error"; re-raise it (and user interrupts) unchanged.
    except (KeyboardInterrupt, RuntimeError):
        raise
    except Exception as e:
        logger.error(f"Unexpected error running command in Docker '{docker_full_cmd}': {e}")
        # Encapsulate as RuntimeError to avoid leaking subprocess errors directly
        raise RuntimeError(f"Unexpected Docker error: {e}") from e
|
||||
@@ -26,12 +26,13 @@ from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
||||
from khoj.processor.conversation.utils import (
|
||||
OperatorRun,
|
||||
ResponseWithThought,
|
||||
defilter_query,
|
||||
save_to_conversation_log,
|
||||
)
|
||||
from khoj.processor.image.generate import text_to_image
|
||||
from khoj.processor.operator.operate_browser import operate_browser
|
||||
from khoj.processor.operator import operate_environment
|
||||
from khoj.processor.speech.text_to_speech import generate_text_to_speech
|
||||
from khoj.processor.tools.online_search import (
|
||||
deduplicate_organic_results,
|
||||
@@ -65,10 +66,7 @@ from khoj.routers.helpers import (
|
||||
update_telemetry_state,
|
||||
validate_chat_model,
|
||||
)
|
||||
from khoj.routers.research import (
|
||||
InformationCollectionIteration,
|
||||
execute_information_collection,
|
||||
)
|
||||
from khoj.routers.research import ResearchIteration, research
|
||||
from khoj.routers.storage import upload_user_image_to_bucket
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import (
|
||||
@@ -722,10 +720,10 @@ async def chat(
|
||||
for file in raw_query_files:
|
||||
query_files[file.name] = file.content
|
||||
|
||||
research_results: List[InformationCollectionIteration] = []
|
||||
research_results: List[ResearchIteration] = []
|
||||
online_results: Dict = dict()
|
||||
code_results: Dict = dict()
|
||||
operator_results: Dict[str, str] = {}
|
||||
operator_results: List[OperatorRun] = []
|
||||
compiled_references: List[Any] = []
|
||||
inferred_queries: List[Any] = []
|
||||
attached_file_context = gather_raw_query_files(query_files)
|
||||
@@ -960,11 +958,10 @@ async def chat(
|
||||
last_message = conversation.messages[-1]
|
||||
online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
|
||||
code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
|
||||
operator_results = last_message.operatorContext or {}
|
||||
compiled_references = [ref.model_dump() for ref in last_message.context or []]
|
||||
research_results = [
|
||||
InformationCollectionIteration(**iter_dict) for iter_dict in last_message.researchContext or []
|
||||
]
|
||||
research_results = [ResearchIteration(**iter_dict) for iter_dict in last_message.researchContext or []]
|
||||
operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []]
|
||||
train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []]
|
||||
# Drop the interrupted message from conversation history
|
||||
meta_log["chat"].pop()
|
||||
logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
|
||||
@@ -1009,12 +1006,12 @@ async def chat(
|
||||
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
|
||||
|
||||
if conversation_commands == [ConversationCommand.Research]:
|
||||
async for research_result in execute_information_collection(
|
||||
async for research_result in research(
|
||||
user=user,
|
||||
query=defiltered_query,
|
||||
conversation_id=conversation_id,
|
||||
conversation_history=meta_log,
|
||||
previous_iterations=research_results,
|
||||
previous_iterations=list(research_results),
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
@@ -1025,7 +1022,7 @@ async def chat(
|
||||
tracer=tracer,
|
||||
cancellation_event=cancellation_event,
|
||||
):
|
||||
if isinstance(research_result, InformationCollectionIteration):
|
||||
if isinstance(research_result, ResearchIteration):
|
||||
if research_result.summarizedResult:
|
||||
if research_result.onlineContext:
|
||||
online_results.update(research_result.onlineContext)
|
||||
@@ -1033,13 +1030,26 @@ async def chat(
|
||||
code_results.update(research_result.codeContext)
|
||||
if research_result.context:
|
||||
compiled_references.extend(research_result.context)
|
||||
if research_result.operatorContext:
|
||||
operator_results.update(research_result.operatorContext)
|
||||
if not research_results or research_results[-1] is not research_result:
|
||||
research_results.append(research_result)
|
||||
|
||||
else:
|
||||
yield research_result
|
||||
|
||||
# Track operator results across research and operator iterations
|
||||
# This relies on two conditions:
|
||||
# 1. Check to append new (partial) operator results
|
||||
# Relies on triggering this check on every status updates.
|
||||
# Status updates cascade up from operator to research to chat api on every step.
|
||||
# 2. Keep operator results in sync with each research operator step
|
||||
# Relies on python object references to ensure operator results
|
||||
# are implicitly kept in sync after the initial append
|
||||
if (
|
||||
research_results
|
||||
and research_results[-1].operatorContext
|
||||
and (not operator_results or operator_results[-1] is not research_results[-1].operatorContext)
|
||||
):
|
||||
operator_results.append(research_results[-1].operatorContext)
|
||||
|
||||
# researched_results = await extract_relevant_info(q, researched_results, agent)
|
||||
if state.verbose > 1:
|
||||
logger.debug(f'Researched Results: {"".join(r.summarizedResult for r in research_results)}')
|
||||
@@ -1292,11 +1302,12 @@ async def chat(
|
||||
)
|
||||
if ConversationCommand.Operator in conversation_commands:
|
||||
try:
|
||||
async for result in operate_browser(
|
||||
async for result in operate_environment(
|
||||
defiltered_query,
|
||||
user,
|
||||
meta_log,
|
||||
location,
|
||||
list(operator_results)[-1] if operator_results else None,
|
||||
query_images=uploaded_images,
|
||||
query_files=attached_file_context,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
@@ -1306,16 +1317,17 @@ async def chat(
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
else:
|
||||
operator_results = {result["query"]: result["result"]}
|
||||
elif isinstance(result, OperatorRun):
|
||||
if not operator_results or operator_results[-1] is not result:
|
||||
operator_results.append(result)
|
||||
# Add webpages visited while operating browser to references
|
||||
if result.get("webpages"):
|
||||
if result.webpages:
|
||||
if not online_results.get(defiltered_query):
|
||||
online_results[defiltered_query] = {"webpages": result["webpages"]}
|
||||
online_results[defiltered_query] = {"webpages": result.webpages}
|
||||
elif not online_results[defiltered_query].get("webpages"):
|
||||
online_results[defiltered_query]["webpages"] = result["webpages"]
|
||||
online_results[defiltered_query]["webpages"] = result.webpages
|
||||
else:
|
||||
online_results[defiltered_query]["webpages"] += result["webpages"]
|
||||
online_results[defiltered_query]["webpages"] += result.webpages
|
||||
except ValueError as e:
|
||||
program_execution_context.append(f"Browser operation error: {e}")
|
||||
logger.warning(f"Failed to operate browser with {e}", exc_info=True)
|
||||
@@ -1333,7 +1345,6 @@ async def chat(
|
||||
"context": compiled_references,
|
||||
"onlineContext": unique_online_results,
|
||||
"codeContext": code_results,
|
||||
"operatorContext": operator_results,
|
||||
},
|
||||
):
|
||||
yield result
|
||||
|
||||
@@ -94,7 +94,8 @@ from khoj.processor.conversation.openai.gpt import (
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
ChatEvent,
|
||||
InformationCollectionIteration,
|
||||
OperatorRun,
|
||||
ResearchIteration,
|
||||
ResponseWithThought,
|
||||
clean_json,
|
||||
clean_mermaidjs,
|
||||
@@ -385,7 +386,7 @@ async def aget_data_sources_and_output_format(
|
||||
if len(agent_outputs) == 0 or output.value in agent_outputs:
|
||||
output_options_str += f'- "{output.value}": "{description}"\n'
|
||||
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
chat_history = construct_chat_history(conversation_history, n=6)
|
||||
|
||||
if query_images:
|
||||
query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
|
||||
@@ -1174,12 +1175,7 @@ async def send_message_to_model_wrapper(
|
||||
if vision_available and query_images:
|
||||
logger.info(f"Using {chat_model.name} model to understand {len(query_images)} images.")
|
||||
|
||||
subscribed = await ais_user_subscribed(user) if user else False
|
||||
max_tokens = (
|
||||
chat_model.subscribed_max_prompt_size
|
||||
if subscribed and chat_model.subscribed_max_prompt_size
|
||||
else chat_model.max_prompt_size
|
||||
)
|
||||
max_tokens = await ConversationAdapters.aget_max_context_size(chat_model, user)
|
||||
chat_model_name = chat_model.name
|
||||
tokenizer = chat_model.tokenizer
|
||||
model_type = chat_model.model_type
|
||||
@@ -1271,12 +1267,7 @@ def send_message_to_model_wrapper_sync(
|
||||
if chat_model is None:
|
||||
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
||||
|
||||
subscribed = is_user_subscribed(user) if user else False
|
||||
max_tokens = (
|
||||
chat_model.subscribed_max_prompt_size
|
||||
if subscribed and chat_model.subscribed_max_prompt_size
|
||||
else chat_model.max_prompt_size
|
||||
)
|
||||
max_tokens = ConversationAdapters.get_max_context_size(chat_model, user)
|
||||
chat_model_name = chat_model.name
|
||||
model_type = chat_model.model_type
|
||||
vision_available = chat_model.vision_enabled
|
||||
@@ -1355,8 +1346,8 @@ async def agenerate_chat_response(
|
||||
compiled_references: List[Dict] = [],
|
||||
online_results: Dict[str, Dict] = {},
|
||||
code_results: Dict[str, Dict] = {},
|
||||
operator_results: Dict[str, str] = {},
|
||||
research_results: List[InformationCollectionIteration] = [],
|
||||
operator_results: List[OperatorRun] = [],
|
||||
research_results: List[ResearchIteration] = [],
|
||||
inferred_queries: List[str] = [],
|
||||
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
|
||||
user: KhojUser = None,
|
||||
@@ -1414,7 +1405,7 @@ async def agenerate_chat_response(
|
||||
compiled_references = []
|
||||
online_results = {}
|
||||
code_results = {}
|
||||
operator_results = {}
|
||||
operator_results = []
|
||||
deepthought = True
|
||||
|
||||
chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed)
|
||||
|
||||
@@ -13,12 +13,13 @@ from khoj.database.adapters import AgentAdapters, EntryAdapters
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import (
|
||||
InformationCollectionIteration,
|
||||
OperatorRun,
|
||||
ResearchIteration,
|
||||
construct_iteration_history,
|
||||
construct_tool_chat_history,
|
||||
load_complex_json,
|
||||
)
|
||||
from khoj.processor.operator.operate_browser import operate_browser
|
||||
from khoj.processor.operator import operate_environment
|
||||
from khoj.processor.tools.online_search import read_webpages, search_online
|
||||
from khoj.processor.tools.run_code import run_code
|
||||
from khoj.routers.api import extract_references_and_questions
|
||||
@@ -83,7 +84,7 @@ async def apick_next_tool(
|
||||
location: LocationData = None,
|
||||
user_name: str = None,
|
||||
agent: Agent = None,
|
||||
previous_iterations: List[InformationCollectionIteration] = [],
|
||||
previous_iterations: List[ResearchIteration] = [],
|
||||
max_iterations: int = 5,
|
||||
query_images: List[str] = [],
|
||||
query_files: str = None,
|
||||
@@ -95,6 +96,24 @@ async def apick_next_tool(
|
||||
):
|
||||
"""Given a query, determine which of the available tools the agent should use in order to answer appropriately."""
|
||||
|
||||
# Continue with previous iteration if a multi-step tool use is in progress
|
||||
if (
|
||||
previous_iterations
|
||||
and previous_iterations[-1].tool == ConversationCommand.Operator
|
||||
and not previous_iterations[-1].summarizedResult
|
||||
):
|
||||
previous_iteration = previous_iterations[-1]
|
||||
yield ResearchIteration(
|
||||
tool=previous_iteration.tool,
|
||||
query=query,
|
||||
context=previous_iteration.context,
|
||||
onlineContext=previous_iteration.onlineContext,
|
||||
codeContext=previous_iteration.codeContext,
|
||||
operatorContext=previous_iteration.operatorContext,
|
||||
warning=previous_iteration.warning,
|
||||
)
|
||||
return
|
||||
|
||||
# Construct tool options for the agent to choose from
|
||||
tool_options = dict()
|
||||
tool_options_str = ""
|
||||
@@ -165,7 +184,7 @@ async def apick_next_tool(
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to infer information sources to refer: {e}", exc_info=True)
|
||||
yield InformationCollectionIteration(
|
||||
yield ResearchIteration(
|
||||
tool=None,
|
||||
query=None,
|
||||
warning="Failed to infer information sources to refer. Skipping iteration. Try again.",
|
||||
@@ -194,26 +213,26 @@ async def apick_next_tool(
|
||||
async for event in send_status_func(f"{scratchpad}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
|
||||
yield InformationCollectionIteration(
|
||||
yield ResearchIteration(
|
||||
tool=selected_tool,
|
||||
query=generated_query,
|
||||
warning=warning,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
|
||||
yield InformationCollectionIteration(
|
||||
yield ResearchIteration(
|
||||
tool=None,
|
||||
query=None,
|
||||
warning=f"Invalid response for determining relevant tools: {response}. Skipping iteration. Fix error: {e}",
|
||||
)
|
||||
|
||||
|
||||
async def execute_information_collection(
|
||||
async def research(
|
||||
user: KhojUser,
|
||||
query: str,
|
||||
conversation_id: str,
|
||||
conversation_history: dict,
|
||||
previous_iterations: List[InformationCollectionIteration],
|
||||
previous_iterations: List[ResearchIteration],
|
||||
query_images: List[str],
|
||||
agent: Agent = None,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
@@ -248,9 +267,9 @@ async def execute_information_collection(
|
||||
online_results: Dict = dict()
|
||||
code_results: Dict = dict()
|
||||
document_results: List[Dict[str, str]] = []
|
||||
operator_results: Dict[str, str] = {}
|
||||
operator_results: OperatorRun = None
|
||||
summarize_files: str = ""
|
||||
this_iteration = InformationCollectionIteration(tool=None, query=query)
|
||||
this_iteration = ResearchIteration(tool=None, query=query)
|
||||
|
||||
async for result in apick_next_tool(
|
||||
query,
|
||||
@@ -271,8 +290,9 @@ async def execute_information_collection(
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
elif isinstance(result, InformationCollectionIteration):
|
||||
elif isinstance(result, ResearchIteration):
|
||||
this_iteration = result
|
||||
yield this_iteration
|
||||
|
||||
# Skip running iteration if warning present in iteration
|
||||
if this_iteration.warning:
|
||||
@@ -417,12 +437,13 @@ async def execute_information_collection(
|
||||
|
||||
elif this_iteration.tool == ConversationCommand.Operator:
|
||||
try:
|
||||
async for result in operate_browser(
|
||||
async for result in operate_environment(
|
||||
this_iteration.query,
|
||||
user,
|
||||
construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
|
||||
location,
|
||||
send_status_func,
|
||||
previous_iterations[-1].operatorContext if previous_iterations else None,
|
||||
send_status_func=send_status_func,
|
||||
query_images=query_images,
|
||||
agent=agent,
|
||||
query_files=query_files,
|
||||
@@ -431,17 +452,17 @@ async def execute_information_collection(
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
else:
|
||||
operator_results = {result["query"]: result["result"]}
|
||||
elif isinstance(result, OperatorRun):
|
||||
operator_results = result
|
||||
this_iteration.operatorContext = operator_results
|
||||
# Add webpages visited while operating browser to references
|
||||
if result.get("webpages"):
|
||||
if result.webpages:
|
||||
if not online_results.get(this_iteration.query):
|
||||
online_results[this_iteration.query] = {"webpages": result["webpages"]}
|
||||
online_results[this_iteration.query] = {"webpages": result.webpages}
|
||||
elif not online_results[this_iteration.query].get("webpages"):
|
||||
online_results[this_iteration.query]["webpages"] = result["webpages"]
|
||||
online_results[this_iteration.query]["webpages"] = result.webpages
|
||||
else:
|
||||
online_results[this_iteration.query]["webpages"] += result["webpages"]
|
||||
online_results[this_iteration.query]["webpages"] += result.webpages
|
||||
this_iteration.onlineContext = online_results
|
||||
except Exception as e:
|
||||
this_iteration.warning = f"Error operating browser: {e}"
|
||||
@@ -489,7 +510,9 @@ async def execute_information_collection(
|
||||
if code_results:
|
||||
results_data += f"\n<code_results>\n{yaml.dump(truncate_code_context(code_results), allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</code_results>"
|
||||
if operator_results:
|
||||
results_data += f"\n<browser_operator_results>\n{next(iter(operator_results.values()))}\n</browser_operator_results>"
|
||||
results_data += (
|
||||
f"\n<browser_operator_results>\n{operator_results.response}\n</browser_operator_results>"
|
||||
)
|
||||
if summarize_files:
|
||||
results_data += f"\n<summarized_files>\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</summarized_files>"
|
||||
if this_iteration.warning:
|
||||
|
||||
@@ -18,8 +18,8 @@ default_offline_chat_models = [
|
||||
"bartowski/Qwen2.5-14B-Instruct-GGUF",
|
||||
]
|
||||
default_openai_chat_models = ["gpt-4o-mini", "gpt-4.1"]
|
||||
default_gemini_chat_models = ["gemini-2.0-flash", "gemini-2.5-flash-preview-04-17", "gemini-2.5-pro-preview-03-25"]
|
||||
default_anthropic_chat_models = ["claude-3-7-sonnet-latest", "claude-3-5-haiku-latest"]
|
||||
default_gemini_chat_models = ["gemini-2.0-flash", "gemini-2.5-flash-preview-05-20", "gemini-2.5-pro-preview-05-06"]
|
||||
default_anthropic_chat_models = ["claude-sonnet-4-0", "claude-3-5-haiku-latest"]
|
||||
|
||||
empty_config = {
|
||||
"search-type": {
|
||||
@@ -63,10 +63,10 @@ model_to_cost: Dict[str, Dict[str, float]] = {
|
||||
"claude-3-7-sonnet-20250219": {"input": 3.0, "output": 15.0, "cache_read": 0.3, "cache_write": 3.75},
|
||||
"claude-3-7-sonnet@20250219": {"input": 3.0, "output": 15.0, "cache_read": 0.3, "cache_write": 3.75},
|
||||
"claude-3-7-sonnet-latest": {"input": 3.0, "output": 15.0, "cache_read": 0.3, "cache_write": 3.75},
|
||||
"claude-sonnet-4": {"input": 3.0, "output": 15.0, "cache_read": 0.3, "cache_write": 3.75},
|
||||
"claude-sonnet-4-0": {"input": 3.0, "output": 15.0, "cache_read": 0.3, "cache_write": 3.75},
|
||||
"claude-sonnet-4-20250514": {"input": 3.0, "output": 15.0, "cache_read": 0.3, "cache_write": 3.75},
|
||||
"claude-sonnet-4@20250514": {"input": 3.0, "output": 15.0, "cache_read": 0.3, "cache_write": 3.75},
|
||||
"claude-opus-4": {"input": 15.0, "output": 75.0, "cache_read": 1.50, "cache_write": 18.75},
|
||||
"claude-opus-4-0": {"input": 15.0, "output": 75.0, "cache_read": 1.50, "cache_write": 18.75},
|
||||
"claude-opus-4-20250514": {"input": 15.0, "output": 75.0, "cache_read": 1.50, "cache_write": 18.75},
|
||||
"claude-opus-4@20250514": {"input": 15.0, "output": 75.0, "cache_read": 1.50, "cache_write": 18.75},
|
||||
# Grok pricing: https://docs.x.ai/docs/models
|
||||
|
||||
@@ -46,6 +46,7 @@ if TYPE_CHECKING:
|
||||
from khoj.utils.models import BaseEncoder
|
||||
from khoj.utils.rawconfig import AppConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize Magika for file type identification
|
||||
magika = Magika()
|
||||
@@ -364,7 +365,7 @@ command_descriptions = {
|
||||
ConversationCommand.Summarize: "Get help with a question pertaining to an entire document.",
|
||||
ConversationCommand.Diagram: "Draw a flowchart, diagram, or any other visual representation best expressed with primitives like lines, rectangles, and text.",
|
||||
ConversationCommand.Research: "Do deep research on a topic. This will take longer than usual, but give a more detailed, comprehensive answer.",
|
||||
ConversationCommand.Operator: "Operate and perform tasks using a GUI web browser.",
|
||||
ConversationCommand.Operator: "Operate and perform tasks using a computer.",
|
||||
}
|
||||
|
||||
command_descriptions_for_agent = {
|
||||
@@ -373,12 +374,12 @@ command_descriptions_for_agent = {
|
||||
ConversationCommand.Online: "Agent can search the internet for information.",
|
||||
ConversationCommand.Webpage: "Agent can read suggested web pages for information.",
|
||||
ConversationCommand.Research: "Agent can do deep research on a topic.",
|
||||
ConversationCommand.Code: "Agent can run Python code to parse information, run complex calculations, create documents and charts.",
|
||||
ConversationCommand.Operator: "Agent can operate and perform actions using a GUI web browser to complete a task.",
|
||||
ConversationCommand.Code: "Agent can run a Python script to parse information, run complex calculations, create documents and charts.",
|
||||
ConversationCommand.Operator: "Agent can operate a computer to complete tasks.",
|
||||
}
|
||||
|
||||
e2b_tool_description = "To run Python code in a E2B sandbox with no network access. Helpful to parse complex information, run calculations, create text documents and create charts with quantitative data. Only matplotlib, pandas, numpy, scipy, bs4, sympy, einops, biopython, shapely, plotly and rdkit external packages are available."
|
||||
terrarium_tool_description = "To run Python code in a Terrarium, Pyodide sandbox with no network access. Helpful to parse complex information, run complex calculations, create plaintext documents and create charts with quantitative data. Only matplotlib, panda, numpy, scipy, bs4 and sympy external packages are available."
|
||||
e2b_tool_description = "To run a Python script in a E2B sandbox with no network access. Helpful to parse complex information, run calculations, create text documents and create charts with quantitative data. Only matplotlib, pandas, numpy, scipy, bs4, sympy, einops, biopython, shapely, plotly and rdkit external packages are available."
|
||||
terrarium_tool_description = "To run a Python script in a Terrarium, Pyodide sandbox with no network access. Helpful to parse complex information, run complex calculations, create plaintext documents and create charts with quantitative data. Only matplotlib, panda, numpy, scipy, bs4 and sympy external packages are available."
|
||||
|
||||
tool_descriptions_for_llm = {
|
||||
ConversationCommand.Default: "To use a mix of your internal knowledge and the user's personal knowledge, or if you don't entirely understand the query.",
|
||||
@@ -387,7 +388,7 @@ tool_descriptions_for_llm = {
|
||||
ConversationCommand.Online: "To search for the latest, up-to-date information from the internet. Note: **Questions about Khoj should always use this data source**",
|
||||
ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.",
|
||||
ConversationCommand.Code: e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description,
|
||||
ConversationCommand.Operator: "To use when you need to operate and take actions using a GUI web browser.",
|
||||
ConversationCommand.Operator: "To use when you need to operate a computer to complete the task.",
|
||||
}
|
||||
|
||||
tool_description_for_research_llm = {
|
||||
@@ -396,7 +397,7 @@ tool_description_for_research_llm = {
|
||||
ConversationCommand.Webpage: "To extract information from webpages. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share upto {max_webpages_to_read} webpage links and what information to extract from them in your query.",
|
||||
ConversationCommand.Code: e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description,
|
||||
ConversationCommand.Text: "To respond to the user once you've completed your research and have the required information.",
|
||||
ConversationCommand.Operator: "To operate and take actions using a GUI web browser.",
|
||||
ConversationCommand.Operator: "To operate a computer to complete the task.",
|
||||
}
|
||||
|
||||
mode_descriptions_for_llm = {
|
||||
@@ -493,13 +494,7 @@ def is_promptrace_enabled():
|
||||
def is_operator_enabled():
|
||||
"""Check if Khoj can operate GUI applications.
|
||||
Set KHOJ_OPERATOR_ENABLED env var to true and install playwright to enable it."""
|
||||
try:
|
||||
import playwright
|
||||
|
||||
is_playwright_installed = True
|
||||
except ImportError:
|
||||
is_playwright_installed = False
|
||||
return is_env_var_true("KHOJ_OPERATOR_ENABLED") and is_playwright_installed
|
||||
return is_env_var_true("KHOJ_OPERATOR_ENABLED")
|
||||
|
||||
|
||||
def is_valid_url(url: str) -> bool:
|
||||
@@ -686,7 +681,7 @@ def get_chat_usage_metrics(
|
||||
"cache_write_tokens": 0,
|
||||
"cost": 0.0,
|
||||
}
|
||||
return {
|
||||
current_usage = {
|
||||
"input_tokens": prev_usage["input_tokens"] + input_tokens,
|
||||
"output_tokens": prev_usage["output_tokens"] + output_tokens,
|
||||
"thought_tokens": prev_usage.get("thought_tokens", 0) + thought_tokens,
|
||||
@@ -703,6 +698,8 @@ def get_chat_usage_metrics(
|
||||
prev_cost=prev_usage["cost"],
|
||||
),
|
||||
}
|
||||
logger.debug(f"AI API usage by {model_name}: {current_usage}")
|
||||
return current_usage
|
||||
|
||||
|
||||
class AiApiInfo(NamedTuple):
|
||||
|
||||
Reference in New Issue
Block a user