diff --git a/src/khoj/processor/operator/operator_environment_browser.py b/src/khoj/processor/operator/operator_environment_browser.py index d4595a5e..64db38ab 100644 --- a/src/khoj/processor/operator/operator_environment_browser.py +++ b/src/khoj/processor/operator/operator_environment_browser.py @@ -1,10 +1,11 @@ import asyncio import base64 +import io import logging import os from typing import Optional, Set -from khoj.processor.operator.operator_actions import OperatorAction +from khoj.processor.operator.operator_actions import OperatorAction, Point from khoj.processor.operator.operator_environment_base import ( Environment, EnvState, @@ -33,6 +34,7 @@ class BrowserEnvironment(Environment): self.visited_urls: Set[str] = set() self.excluded_urls = {"about:blank", "https://duckduckgo.com", "https://www.bing.com", "https://www.google.com"} self.navigation_history: list[str] = [] + self.mouse_pos = Point(x=self.width / 2, y=self.height / 2) async def start(self, width: int = 1024, height: int = 768) -> None: self.width = width @@ -93,12 +95,33 @@ class BrowserEnvironment(Environment): return None try: screenshot_bytes = await self.page.screenshot(caret="initial", full_page=False, type="png") + # Draw mouse position on the screenshot image + if self.mouse_pos: + screenshot_bytes = await self._draw_mouse_position(screenshot_bytes, self.mouse_pos) screenshot_webp_bytes = convert_image_to_webp(screenshot_bytes) return base64.b64encode(screenshot_webp_bytes).decode("utf-8") except Exception as e: logger.error(f"Failed to get screenshot: {e}") return None + async def _draw_mouse_position(self, screenshot_bytes: bytes, mouse_pos: Point) -> bytes: + from PIL import Image, ImageDraw + + # Load the screenshot into a PIL image + image = Image.open(io.BytesIO(screenshot_bytes)) + + # Draw a red circle at the mouse position + draw = ImageDraw.Draw(image) + radius = 5 + draw.ellipse( + (mouse_pos.x - radius, mouse_pos.y - radius, mouse_pos.x + radius, mouse_pos.y + radius), fill="red" + ) + + # Save the modified image to a bytes buffer + output_buffer = io.BytesIO() + image.save(output_buffer, format="PNG") + return output_buffer.getvalue() + async def get_state(self) -> EnvState: if not self.page or self.page.is_closed(): return "about:blank", None @@ -127,17 +150,20 @@ class BrowserEnvironment(Environment): for modifier in reversed(modifiers): await self.page.keyboard.up(modifier) output = f"{button.capitalize()} clicked at ({x}, {y})" + self.mouse_pos = Point(x=x, y=y) logger.debug(f"Action: {action.type} {button} at ({x},{y})") case "double_click": x, y = action.x, action.y await self.page.mouse.dblclick(x, y) + self.mouse_pos = Point(x=x, y=y) output = f"Double clicked at ({x}, {y})" logger.debug(f"Action: {action.type} at ({x},{y})") case "triple_click": x, y = action.x, action.y await self.page.mouse.click(x, y, click_count=3) + self.mouse_pos = Point(x=x, y=y) output = f"Triple clicked at ({x}, {y})" logger.debug(f"Action: {action.type} at ({x},{y})") @@ -148,6 +174,7 @@ class BrowserEnvironment(Environment): scroll_y = action.scroll_y or 0 if action.x is not None and action.y is not None: await self.page.mouse.move(action.x, action.y) + self.mouse_pos = Point(x=action.x, y=action.y) await self.page.evaluate(f"window.scrollBy({scroll_x}, {scroll_y})") output = f"Scrolled by ({scroll_x}, {scroll_y})" logger.debug(f"Action: {action.type} by ({scroll_x},{scroll_y}) at ({action.x},{action.y})") @@ -166,6 +193,7 @@ class BrowserEnvironment(Environment): if action.x is not None and action.y is not None: await self.page.mouse.move(action.x, action.y) + self.mouse_pos = Point(x=action.x, y=action.y) await self.page.mouse.wheel(dx, dy) output = f"Scrolled {action.scroll_direction} by {amount}" logger.debug( @@ -210,6 +238,7 @@ class BrowserEnvironment(Environment): case "move": x, y = action.x, action.y await self.page.mouse.move(x, y) + self.mouse_pos = Point(x=x, y=y) output = f"Moved mouse to ({x}, {y})" logger.debug(f"Action: {action.type} to ({x},{y})") @@ -223,6 +252,7 @@ class BrowserEnvironment(Environment): for point in path[1:]: await self.page.mouse.move(point.x, point.y) await self.page.mouse.up() + self.mouse_pos = Point(x=path[-1].x, y=path[-1].y) output = f"Drag along path starting at ({path[0].x},{path[0].y})" logger.debug(f"Action: {action.type} with {len(path)} points")