Add current cursor position to browser screenshots for AI and human view

This commit is contained in:
Debanjum
2025-05-09 14:46:03 -06:00
parent 1be3986537
commit a1d712e031

View File

@@ -1,10 +1,11 @@
import asyncio
import base64
import io
import logging
import os
from typing import Optional, Set
from khoj.processor.operator.operator_actions import OperatorAction
from khoj.processor.operator.operator_actions import OperatorAction, Point
from khoj.processor.operator.operator_environment_base import (
Environment,
EnvState,
@@ -33,6 +34,7 @@ class BrowserEnvironment(Environment):
self.visited_urls: Set[str] = set()
self.excluded_urls = {"about:blank", "https://duckduckgo.com", "https://www.bing.com", "https://www.google.com"}
self.navigation_history: list[str] = []
self.mouse_pos = Point(x=self.width / 2, y=self.height / 2)
async def start(self, width: int = 1024, height: int = 768) -> None:
self.width = width
@@ -93,12 +95,33 @@ class BrowserEnvironment(Environment):
return None
try:
screenshot_bytes = await self.page.screenshot(caret="initial", full_page=False, type="png")
# Draw mouse position on the screenshot image
if self.mouse_pos:
screenshot_bytes = await self._draw_mouse_position(screenshot_bytes, self.mouse_pos)
screenshot_webp_bytes = convert_image_to_webp(screenshot_bytes)
return base64.b64encode(screenshot_webp_bytes).decode("utf-8")
except Exception as e:
logger.error(f"Failed to get screenshot: {e}")
return None
async def _draw_mouse_position(self, screenshot_bytes: bytes, mouse_pos: Point) -> bytes:
    """Return a copy of the screenshot with the mouse position marked.

    The encoded image in ``screenshot_bytes`` is opened with Pillow, a filled
    red circle of radius 5 is painted centred on ``(mouse_pos.x, mouse_pos.y)``
    and the result is re-encoded as PNG.

    Args:
        screenshot_bytes: Encoded image data readable by ``PIL.Image.open``.
        mouse_pos: Centre of the marker; presumably expressed in the same
            pixel coordinates as the screenshot (confirm against the caller).

    Returns:
        PNG-encoded bytes of the annotated image.
    """
    # Pillow is imported lazily, only when a marker actually has to be drawn.
    from PIL import Image, ImageDraw

    # Decode the screenshot and get a drawing handle onto it.
    canvas = Image.open(io.BytesIO(screenshot_bytes))
    painter = ImageDraw.Draw(canvas)

    # Bounding box (left, top, right, bottom) of the circular marker.
    marker_radius = 5
    bounding_box = (
        mouse_pos.x - marker_radius,
        mouse_pos.y - marker_radius,
        mouse_pos.x + marker_radius,
        mouse_pos.y + marker_radius,
    )
    painter.ellipse(bounding_box, fill="red")

    # Serialise the annotated image back into an in-memory PNG.
    encoded = io.BytesIO()
    canvas.save(encoded, format="PNG")
    return encoded.getvalue()
async def get_state(self) -> EnvState:
if not self.page or self.page.is_closed():
return "about:blank", None
@@ -127,17 +150,20 @@ class BrowserEnvironment(Environment):
for modifier in reversed(modifiers):
await self.page.keyboard.up(modifier)
output = f"{button.capitalize()} clicked at ({x}, {y})"
self.mouse_pos = Point(x=x, y=y)
logger.debug(f"Action: {action.type} {button} at ({x},{y})")
case "double_click":
x, y = action.x, action.y
await self.page.mouse.dblclick(x, y)
self.mouse_pos = Point(x=x, y=y)
output = f"Double clicked at ({x}, {y})"
logger.debug(f"Action: {action.type} at ({x},{y})")
case "triple_click":
x, y = action.x, action.y
await self.page.mouse.click(x, y, click_count=3)
self.mouse_pos = Point(x=x, y=y)
output = f"Triple clicked at ({x}, {y})"
logger.debug(f"Action: {action.type} at ({x},{y})")
@@ -148,6 +174,7 @@ class BrowserEnvironment(Environment):
scroll_y = action.scroll_y or 0
if action.x is not None and action.y is not None:
await self.page.mouse.move(action.x, action.y)
self.mouse_pos = Point(x=action.x, y=action.y)
await self.page.evaluate(f"window.scrollBy({scroll_x}, {scroll_y})")
output = f"Scrolled by ({scroll_x}, {scroll_y})"
logger.debug(f"Action: {action.type} by ({scroll_x},{scroll_y}) at ({action.x},{action.y})")
@@ -166,6 +193,7 @@ class BrowserEnvironment(Environment):
if action.x is not None and action.y is not None:
await self.page.mouse.move(action.x, action.y)
self.mouse_pos = Point(x=action.x, y=action.y)
await self.page.mouse.wheel(dx, dy)
output = f"Scrolled {action.scroll_direction} by {amount}"
logger.debug(
@@ -210,6 +238,7 @@ class BrowserEnvironment(Environment):
case "move":
x, y = action.x, action.y
await self.page.mouse.move(x, y)
self.mouse_pos = Point(x=x, y=y)
output = f"Moved mouse to ({x}, {y})"
logger.debug(f"Action: {action.type} to ({x},{y})")
@@ -223,6 +252,7 @@ class BrowserEnvironment(Environment):
for point in path[1:]:
await self.page.mouse.move(point.x, point.y)
await self.page.mouse.up()
self.mouse_pos = Point(x=path[-1].x, y=path[-1].y)
output = f"Drag along path starting at ({path[0].x},{path[0].y})"
logger.debug(f"Action: {action.type} with {len(path)} points")