diff --git a/.github/workflows/run_evals.yml b/.github/workflows/run_evals.yml index 914a0835..2c8e9688 100644 --- a/.github/workflows/run_evals.yml +++ b/.github/workflows/run_evals.yml @@ -32,6 +32,19 @@ on: required: false default: 200 type: number + sandbox: + description: 'Code sandbox to use' + required: false + default: 'terrarium' + type: choice + options: + - terrarium + - e2b + chat_model: + description: 'Chat model to use' + required: false + default: 'gemini-2.0-flash' + type: string jobs: eval: @@ -40,7 +53,7 @@ jobs: matrix: # Use input from manual trigger if available, else run all combinations khoj_mode: ${{ github.event_name == 'workflow_dispatch' && fromJSON(format('["{0}"]', inputs.khoj_mode)) || fromJSON('["general", "default", "research"]') }} - dataset: ${{ github.event_name == 'workflow_dispatch' && fromJSON(format('["{0}"]', inputs.dataset)) || fromJSON('["frames", "simpleqa"]') }} + dataset: ${{ github.event_name == 'workflow_dispatch' && fromJSON(format('["{0}"]', inputs.dataset)) || fromJSON('["frames", "simpleqa", "gpqa"]') }} services: postgres: @@ -95,11 +108,14 @@ jobs: BATCH_SIZE: "20" RANDOMIZE: "True" KHOJ_URL: "http://localhost:42110" + KHOJ_DEFAULT_CHAT_MODEL: ${{ github.event_name == 'workflow_dispatch' && inputs.chat_model || 'gemini-2.0-flash' }} KHOJ_LLM_SEED: "42" GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} SERPER_DEV_API_KEY: ${{ matrix.dataset != 'math500' && secrets.SERPER_DEV_API_KEY }} OLOSTEP_API_KEY: ${{ matrix.dataset != 'math500' && secrets.OLOSTEP_API_KEY }} HF_TOKEN: ${{ secrets.HF_TOKEN }} + E2B_API_KEY: ${{ inputs.sandbox == 'e2b' && secrets.E2B_API_KEY }} + E2B_TEMPLATE: ${{ vars.E2B_TEMPLATE }} KHOJ_ADMIN_EMAIL: khoj KHOJ_ADMIN_PASSWORD: khoj POSTGRES_HOST: localhost @@ -114,7 +130,7 @@ jobs: # Start code sandbox npm install -g pm2 - npm run ci --prefix terrarium + NODE_ENV=production npm run ci --prefix terrarium # Wait for server to be ready timeout=120 @@ -147,7 +163,8 @@ jobs: echo "## Evaluation Summary 
of Khoj on ${{ matrix.dataset }} in ${{ matrix.khoj_mode }} mode" >> $GITHUB_STEP_SUMMARY echo "**$(head -n 1 *_evaluation_summary_*.txt)**" >> $GITHUB_STEP_SUMMARY echo "- Khoj Version: ${{ steps.hatch.outputs.version }}" >> $GITHUB_STEP_SUMMARY - echo "- Chat Model: Gemini 1.5 Flash 002" >> $GITHUB_STEP_SUMMARY + echo "- Chat Model: ${{ inputs.chat_model }}" >> $GITHUB_STEP_SUMMARY + echo "- Code Sandbox: ${{ inputs.sandbox }}" >> $GITHUB_STEP_SUMMARY echo "\`\`\`" >> $GITHUB_STEP_SUMMARY tail -n +2 *_evaluation_summary_*.txt >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY diff --git a/docker-compose.yml b/docker-compose.yml index 22371182..3ff21b11 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -14,6 +14,11 @@ services: retries: 5 sandbox: image: ghcr.io/khoj-ai/terrarium:latest + healthcheck: + test: ["CMD-SHELL", "curl -f http://localhost:8080/health"] + interval: 30s + timeout: 10s + retries: 2 search: image: docker.io/searxng/searxng:latest volumes: @@ -53,8 +58,10 @@ services: - KHOJ_DEBUG=False - KHOJ_ADMIN_EMAIL=username@example.com - KHOJ_ADMIN_PASSWORD=password - # Default URL of Terrarium, the Python sandbox used by Khoj to run code. Its container is specified above + # Default URL of Terrarium, the default Python sandbox used by Khoj to run code. Its container is specified above - KHOJ_TERRARIUM_URL=http://sandbox:8080 + # Uncomment line below to have Khoj run code in remote E2B code sandbox instead of the self-hosted Terrarium sandbox above. Get your E2B API key from https://e2b.dev/. + # - E2B_API_KEY=your_e2b_api_key # Default URL of SearxNG, the default web search engine used by Khoj. Its container is specified above - KHOJ_SEARXNG_URL=http://search:8080 # Uncomment line below to use with Ollama running on your local machine at localhost:11434. 
diff --git a/documentation/docs/features/code_execution.md b/documentation/docs/features/code_execution.md index 8403d466..05c994d7 100644 --- a/documentation/docs/features/code_execution.md +++ b/documentation/docs/features/code_execution.md @@ -3,22 +3,23 @@ # Code Execution -Khoj can generate and run very simple Python code snippets as well. This is useful if you want to generate a plot, run a simple calculation, or do some basic data manipulation. LLMs by default aren't skilled at complex quantitative tasks. Code generation & execution can come in handy for such tasks. +Khoj can generate and run simple Python code as well. This is useful if you want to have Khoj do some data analysis, generate plots and reports. LLMs by default aren't skilled at complex quantitative tasks. Code generation & execution can come in handy for such tasks. -Just use `/code` in your chat command. +Khoj automatically infers when to use the code tool. You can also tell it explicitly to use the code tool or use the `/code` [slash command](https://docs.khoj.dev/features/chat/#commands) in your chat. -### Setup (Self-Hosting) -Run [Cohere's Terrarium](https://github.com/cohere-ai/cohere-terrarium) on your machine to enable code generation and execution. +## Setup (Self-Hosting) +### Terrarium Sandbox +Use [Cohere's Terrarium](https://github.com/cohere-ai/cohere-terrarium) to host the code sandbox locally on your machine for free. -Check the [instructions](https://github.com/cohere-ai/cohere-terrarium?tab=readme-ov-file#development) for running from source. 
- -For running with Docker, you can use our [docker-compose.yml](https://github.com/khoj-ai/khoj/blob/master/docker-compose.yml), or start it manually like this: +To run with Docker, use our [docker-compose.yml](https://github.com/khoj-ai/khoj/blob/master/docker-compose.yml) to automatically setup the Terrarium code sandbox, or start it manually like this: ```bash docker pull ghcr.io/khoj-ai/terrarium:latest docker run -d -p 8080:8080 ghcr.io/khoj-ai/terrarium:latest ``` +To run from source, check [these instructions](https://github.com/khoj-ai/cohere-terrarium?tab=readme-ov-file#development). + #### Verify Verify that it's running, by evaluating a simple Python expression: @@ -28,3 +29,12 @@ curl -X POST -H "Content-Type: application/json" \ --data-raw '{"code": "1 + 1"}' \ --no-buffer ``` + +### E2B Sandbox +[E2B](https://e2b.dev/) allows Khoj to run code on a remote but versatile sandbox with support for more python libraries. This is [not free](https://e2b.dev/pricing). + +To have Khoj use E2B as the code sandbox: +1. Generate an API key on [their dashboard](https://e2b.dev/dashboard). +2. Set the `E2B_API_KEY` environment variable to it on the machine running your Khoj server. + - When using our [docker-compose.yml](https://github.com/khoj-ai/khoj/blob/master/docker-compose.yml), uncomment and set the `E2B_API_KEY` env var in the `docker-compose.yml` file. +3. Now restart your Khoj server to switch to using the E2B code sandbox. diff --git a/documentation/docs/get-started/setup.mdx b/documentation/docs/get-started/setup.mdx index c6cdec42..fb8e9f4c 100644 --- a/documentation/docs/get-started/setup.mdx +++ b/documentation/docs/get-started/setup.mdx @@ -333,7 +333,7 @@ Using Ollama? See the [Ollama Integration](/advanced/ollama) section for more cu - Add your [Gemini API key](https://aistudio.google.com/app/apikey) - Give the configuration a friendly name like `Gemini`. Do not configure the API base url. 2. 
Create a new [chat model](http://localhost:42110/server/admin/database/chatmodel/add) - - Set the `chat-model` field to a [Google Gemini chat model](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models). Example: `gemini-1.5-flash`. + - Set the `chat-model` field to a [Google Gemini chat model](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models). Example: `gemini-2.0-flash`. - Set the `model-type` field to `Google`. - Set the `ai model api` field to the Gemini AI Model API you created in step 1. diff --git a/pyproject.toml b/pyproject.toml index 1ed9426e..5093fee7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,7 +68,7 @@ dependencies = [ "authlib == 1.2.1", "llama-cpp-python == 0.2.88", "itsdangerous == 2.1.2", - "httpx == 0.25.0", + "httpx == 0.27.2", "pgvector == 0.2.4", "psycopg2-binary == 2.9.9", "lxml == 4.9.3", @@ -92,6 +92,7 @@ dependencies = [ "pyjson5 == 1.6.7", "resend == 1.0.1", "email-validator == 2.2.0", + "e2b-code-interpreter ~= 1.0.0", ] dynamic = ["version"] diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 33e879aa..058017d2 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1107,6 +1107,12 @@ class ConversationAdapters: return config.setting return ConversationAdapters.aget_advanced_chat_model(user) + @staticmethod + def get_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None): + if ai_model_api_name: + return ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).first() + return ChatModel.objects.filter(name=chat_model_name).first() + @staticmethod async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]: voice_model_config = await UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").afirst() @@ -1205,6 +1211,15 @@ class ConversationAdapters: return server_chat_settings.chat_advanced 
return await ConversationAdapters.aget_default_chat_model(user) + @staticmethod + def set_default_chat_model(chat_model: ChatModel): + server_chat_settings = ServerChatSettings.objects.first() + if server_chat_settings: + server_chat_settings.chat_default = chat_model + server_chat_settings.save() + else: + ServerChatSettings.objects.create(chat_default=chat_model) + @staticmethod async def aget_server_webscraper(): server_chat_settings = await ServerChatSettings.objects.filter().prefetch_related("web_scraper").afirst() diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 6fd95ccd..f4e52914 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -31,10 +31,10 @@ logger = logging.getLogger(__name__) def extract_questions_gemini( text, - model: Optional[str] = "gemini-1.5-flash", + model: Optional[str] = "gemini-2.0-flash", conversation_log={}, api_key=None, - temperature=0, + temperature=0.2, max_tokens=None, location_data: LocationData = None, user: KhojUser = None, @@ -121,7 +121,7 @@ def gemini_send_message_to_model( api_key, model, response_type="text", - temperature=0, + temperature=0.2, model_kwargs=None, tracer={}, ): @@ -132,9 +132,9 @@ def gemini_send_message_to_model( model_kwargs = {} - # Sometimes, this causes unwanted behavior and terminates response early. Disable for now while it's flaky. - # if response_type == "json_object": - # model_kwargs["response_mime_type"] = "application/json" + # This caused unwanted behavior and terminates response early for gemini 1.5 series. Monitor for flakiness with 2.0 series. 
+ if response_type == "json_object" and model in ["gemini-2.0-flash"]: + model_kwargs["response_mime_type"] = "application/json" # Get Response from Gemini return gemini_completion_with_backoff( @@ -154,7 +154,7 @@ def converse_gemini( online_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None, conversation_log={}, - model: Optional[str] = "gemini-1.5-flash", + model: Optional[str] = "gemini-2.0-flash", api_key: Optional[str] = None, temperature: float = 0.2, completion_func=None, diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 61060572..0c2b3bbe 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -974,11 +974,9 @@ Khoj: python_code_generation_prompt = PromptTemplate.from_template( """ You are Khoj, an advanced python programmer. You are tasked with constructing a python program to best answer the user query. -- The python program will run in a pyodide python sandbox with no network access. +- The python program will run in a sandbox with no network access. - You can write programs to run complex calculations, analyze data, create charts, generate documents to meticulously answer the query. -- The sandbox has access to the standard library, matplotlib, panda, numpy, scipy, bs4 and sympy packages. The requests, torch, catboost, tensorflow and tkinter packages are not available. -- List known file paths to required user documents in "input_files" and known links to required documents from the web in the "input_links" field. -- The python program should be self-contained. It can only read data generated by the program itself and from provided input_files, input_links by their basename (i.e filename excluding file path). +- The python program should be self-contained. It can only read data generated by the program itself and any user file paths referenced in your program. 
- Do not try display images or plots in the code directly. The code should save the image or plot to a file instead. - Write any document, charts etc. to be shared with the user to file. These files can be seen by the user. - Use as much context from the previous questions and answers as required to generate your code. @@ -989,24 +987,99 @@ Current Date: {current_date} User's Location: {location} {username} -The response JSON schema is of the form {{"code": "", "input_files": ["file_path_1", "file_path_2"], "input_links": ["link_1", "link_2"]}} -Examples: +Your response should contain python code wrapped in markdown code blocks (i.e starting with```python and ending with ```) +Example 1: --- -{{ -"code": "# Input values\\nprincipal = 43235\\nrate = 5.24\\nyears = 5\\n\\n# Convert rate to decimal\\nrate_decimal = rate / 100\\n\\n# Calculate final amount\\nfinal_amount = principal * (1 + rate_decimal) ** years\\n\\n# Calculate interest earned\\ninterest_earned = final_amount - principal\\n\\n# Print results with formatting\\nprint(f"Interest Earned: ${{interest_earned:,.2f}}")\\nprint(f"Final Amount: ${{final_amount:,.2f}}")" -}} +Q: Calculate the interest earned and final amount for a principal of $43,235 invested at a rate of 5.24 percent for 5 years. +A: Ok, to calculate the interest earned and final amount, we can use the formula for compound interest: $T = P(1 + r/n)^{{nt}}$, +where T: total amount, P: principal, r: interest rate, n: number of times interest is compounded per year, and t: time in years. 
-{{ -"code": "import re\\n\\n# Read org file\\nfile_path = 'tasks.org'\\nwith open(file_path, 'r') as f:\\n content = f.read()\\n\\n# Get today's date in YYYY-MM-DD format\\ntoday = datetime.now().strftime('%Y-%m-%d')\\npattern = r'\*+\s+.*\\n.*SCHEDULED:\s+<' + today + r'.*>'\\n\\n# Find all matches using multiline mode\\nmatches = re.findall(pattern, content, re.MULTILINE)\\ncount = len(matches)\\n\\n# Display count\\nprint(f'Count of scheduled tasks for today: {{count}}')", -"input_files": ["/home/linux/tasks.org"] -}} +Let's write the Python program to calculate this. -{{ -"code": "import pandas as pd\\nimport matplotlib.pyplot as plt\\n\\n# Load the CSV file\\ndf = pd.read_csv('world_population_by_year.csv')\\n\\n# Plot the data\\nplt.figure(figsize=(10, 6))\\nplt.plot(df['Year'], df['Population'], marker='o')\\n\\n# Add titles and labels\\nplt.title('Population by Year')\\nplt.xlabel('Year')\\nplt.ylabel('Population')\\n\\n# Save the plot to a file\\nplt.savefig('population_by_year_plot.png')", -"input_links": ["https://population.un.org/world_population_by_year.csv"] -}} +```python +# Input values +principal = 43235 +rate = 5.24 +years = 5 + +# Convert rate to decimal +rate_decimal = rate / 100 + +# Calculate final amount +final_amount = principal * (1 + rate_decimal) ** years + +# Calculate interest earned +interest_earned = final_amount - principal + +# Print results with formatting +print(f"Interest Earned: ${{interest_earned:,.2f}}") +print(f"Final Amount: ${{final_amount:,.2f}}") +``` + +Example 2: +--- +Q: Simplify first, then evaluate: $-7x+2(x^{{2}}-1)-(2x^{{2}}-x+3)$, where $x=1$. +A: Certainly! Let's break down the problem step-by-step and utilize Python with SymPy to simplify and evaluate the expression. + +1. **Expression Simplification:** + We start with the expression \\(-7x + 2(x^2 - 1) - (2x^2 - x + 3)\\). + +2. 
**Substitute \\(x=1\\) into the simplified expression:** Once simplified, we will substitute \\(x=1\\) into the expression to find its value. + +Let's implement this in Python using SymPy (as the package is available in the sandbox): + +```python +import sympy as sp + +# Define the variable +x = sp.symbols('x') + +# Define the expression +expression = -7*x + 2*(x**2 - 1) - (2*x**2 - x + 3) + +# Simplify the expression +simplified_expression = sp.simplify(expression) + +# Substitute x = 1 into the simplified expression +evaluated_expression = simplified_expression.subs(x, 1) + +# Print the simplified expression and its evaluated value +print(\"Simplified Expression:\", simplified_expression) +print(\"Evaluated Expression at x=1:\", evaluated_expression) +``` + +Example 3: +--- +Q: Plot the world population growth over the years, given this year, world population world tuples: [(2000, 6), (2001, 7), (2002, 8), (2003, 9), (2004, 10)]. +A: Absolutely! We can utilize the Pandas and Matplotlib libraries (as both are available in the sandbox) to create the world population growth plot. +```python +import pandas as pd +import matplotlib.pyplot as plt + +# Create a DataFrame of world population from the provided data +data = {{ + 'Year': [2000, 2001, 2002, 2003, 2004], + 'Population': [6, 7, 8, 9, 10] +}} +df = pd.DataFrame(data) + +# Plot the data +plt.figure(figsize=(10, 6)) +plt.plot(df['Year'], df['Population'], marker='o') + +# Add titles and labels +plt.title('Population by Year') +plt.xlabel('Year') +plt.ylabel('Population') + +# Save the plot to a file +plt.savefig('population_by_year_plot.png') +``` + +Now it's your turn to construct a python program to answer the user's query using the context and conversation provided below. +Ensure you include the python code to execute and wrap it in a markdown code block. -Now it's your turn to construct a python program to answer the user's question. 
Provide the code, required input files and input links in a JSON object. Do not say anything else. Context: --- {context} @@ -1015,8 +1088,9 @@ Chat History: --- {chat_history} -User: {query} -Khoj: +User Query: +--- +{query} """.strip() ) @@ -1030,6 +1104,13 @@ Code Execution Results: """.strip() ) +e2b_sandbox_context = """ +- The sandbox has access to only the standard library, matplotlib, pandas, numpy, scipy, bs4, sympy, einops, biopython, shapely, plotly and rdkit packages. The requests, torch, catboost, tensorflow and tkinter packages are not available. +""".strip() + +terrarium_sandbox_context = """ +The sandbox has access to the standard library, matplotlib, pandas, numpy, scipy, bs4 and sympy packages. The requests, torch, catboost, tensorflow, rdkit and tkinter packages are not available. +""".strip() # Automations # -- diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index 6c1eb125..12e65670 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -1,12 +1,23 @@ +import asyncio import base64 import datetime import logging import mimetypes import os +import re from pathlib import Path from typing import Any, Callable, List, NamedTuple, Optional import aiohttp +from asgiref.sync import sync_to_async +from httpx import RemoteProtocolError +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_random_exponential, +) from khoj.database.adapters import FileObjectAdapters from khoj.database.models import Agent, FileObject, KhojUser @@ -15,22 +26,26 @@ from khoj.processor.conversation.utils import ( ChatEvent, clean_code_python, construct_chat_history, - load_complex_json, ) from khoj.routers.helpers import send_message_to_model_wrapper -from khoj.utils.helpers import is_none_or_empty, timer, truncate_code_context +from khoj.utils.helpers import ( + is_e2b_code_sandbox_enabled, + is_none_or_empty, + timer, + truncate_code_context, 
+) from khoj.utils.rawconfig import LocationData logger = logging.getLogger(__name__) SANDBOX_URL = os.getenv("KHOJ_TERRARIUM_URL", "http://localhost:8080") +DEFAULT_E2B_TEMPLATE = "pmt2o0ghpang8gbiys57" class GeneratedCode(NamedTuple): code: str - input_files: List[str] - input_links: List[str] + input_files: List[FileObject] async def run_code( @@ -68,13 +83,10 @@ async def run_code( # Prepare Input Data input_data = [] - user_input_files: List[FileObject] = [] - for input_file in generated_code.input_files: - user_input_files += await FileObjectAdapters.aget_file_objects_by_name(user, input_file) - for f in user_input_files: + for f in generated_code.input_files: input_data.append( { - "filename": os.path.basename(f.file_name), + "filename": f.file_name, "b64_data": base64.b64encode(f.raw_text.encode("utf-8")).decode("utf-8"), } ) @@ -90,6 +102,14 @@ async def run_code( cleaned_result = truncate_code_context({"cleaned": {"results": result}})["cleaned"]["results"] logger.info(f"Executed Code\n----\n{code}\n----\nResult\n----\n{cleaned_result}\n----") yield {query: {"code": code, "results": result}} + except asyncio.TimeoutError as e: + # Call the sandbox_url/stop GET API endpoint to stop the code sandbox + error = f"Failed to run code for {query} with Timeout error: {e}" + try: + await aiohttp.ClientSession().get(f"{sandbox_url}/stop", timeout=5) + except Exception as e: + error += f"\n\nFailed to stop code sandbox with error: {e}" + raise ValueError(error) except Exception as e: raise ValueError(f"Failed to run code for {query} with error: {e}") @@ -114,6 +134,12 @@ async def generate_python_code( prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" ) + # add sandbox specific context like available packages + sandbox_context = ( + prompts.e2b_sandbox_context if is_e2b_code_sandbox_enabled() else prompts.terrarium_sandbox_context + ) + personality_context = f"{sandbox_context}\n{personality_context}" + 
code_generation_prompt = prompts.python_code_generation_prompt.format( current_date=utc_date, query=q, @@ -127,23 +153,50 @@ async def generate_python_code( response = await send_message_to_model_wrapper( code_generation_prompt, query_images=query_images, - response_type="json_object", user=user, tracer=tracer, query_files=query_files, ) - # Validate that the response is a non-empty, JSON-serializable list - response = load_complex_json(response) - code = response.get("code", "").strip() - input_files = response.get("input_files", []) - input_links = response.get("input_links", []) + # Extract python code wrapped in markdown code blocks from the response + code_blocks = re.findall(r"```(?:python)?\n(.*?)\n```", response, re.DOTALL) + + if not code_blocks: + raise ValueError("No Python code blocks found in response") + + # Join multiple code blocks with newlines and strip any leading/trailing whitespace + code = "\n".join(code_blocks).strip() if not isinstance(code, str) or is_none_or_empty(code): raise ValueError - return GeneratedCode(code, input_files, input_links) + + # Infer user files required in sandbox based on user file paths mentioned in code + input_files: List[FileObject] = [] + user_files = await sync_to_async(set)(FileObjectAdapters.get_all_file_objects(user)) + for user_file in user_files: + if user_file.file_name in code: + # Replace references to full file path used in code with just the file basename to ease reference in sandbox + file_basename = os.path.basename(user_file.file_name) + code = code.replace(user_file.file_name, file_basename) + user_file.file_name = file_basename + input_files.append(user_file) + + return GeneratedCode(code, input_files) +@retry( + retry=( + retry_if_exception_type(aiohttp.ClientError) + | retry_if_exception_type(aiohttp.ClientTimeout) + | retry_if_exception_type(asyncio.TimeoutError) + | retry_if_exception_type(ConnectionError) + | retry_if_exception_type(RemoteProtocolError) + ), + 
wait=wait_random_exponential(min=1, max=5), + stop=stop_after_attempt(3), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, +) async def execute_sandboxed_python(code: str, input_data: list[dict], sandbox_url: str = SANDBOX_URL) -> dict[str, Any]: """ Takes code to run as a string and calls the terrarium API to execute it. @@ -152,15 +205,104 @@ async def execute_sandboxed_python(code: str, input_data: list[dict], sandbox_ur Reference data i/o format based on Terrarium example client code at: https://github.com/cohere-ai/cohere-terrarium/blob/main/example-clients/python/terrarium_client.py """ - headers = {"Content-Type": "application/json"} cleaned_code = clean_code_python(code) - data = {"code": cleaned_code, "files": input_data} + if is_e2b_code_sandbox_enabled(): + try: + return await execute_e2b(cleaned_code, input_data) + except ImportError: + pass + return await execute_terrarium(cleaned_code, input_data, sandbox_url) + +async def execute_e2b(code: str, input_files: list[dict]) -> dict[str, Any]: + """Execute code and handle file I/O in e2b sandbox""" + from e2b_code_interpreter import AsyncSandbox + + sandbox = await AsyncSandbox.create( + api_key=os.getenv("E2B_API_KEY"), + template=os.getenv("E2B_TEMPLATE", DEFAULT_E2B_TEMPLATE), + timeout=120, + request_timeout=30, + ) + + try: + # Upload input files in parallel + upload_tasks = [ + sandbox.files.write(path=file["filename"], data=base64.b64decode(file["b64_data"]), request_timeout=30) + for file in input_files + ] + await asyncio.gather(*upload_tasks) + + # Note stored files before execution to identify new files created during execution + E2bFile = NamedTuple("E2bFile", [("name", str), ("path", str)]) + original_files = {E2bFile(f.name, f.path) for f in await sandbox.files.list("~")} + + # Execute code from main.py file + execution = await sandbox.run_code(code=code, timeout=60) + + # Collect output files + output_files = [] + + # Identify new files created during execution + 
new_files = set(E2bFile(f.name, f.path) for f in await sandbox.files.list("~")) - original_files + # Read newly created files in parallel + download_tasks = [sandbox.files.read(f.path, request_timeout=30) for f in new_files] + downloaded_files = await asyncio.gather(*download_tasks) + for f, content in zip(new_files, downloaded_files): + if isinstance(content, bytes): + # Binary files like PNG - encode as base64 + b64_data = base64.b64encode(content).decode("utf-8") + elif Path(f.name).suffix in [".png", ".jpeg", ".jpg", ".svg"]: + # Ignore image files as they are extracted from execution results below for inline display + continue + else: + # Text files - encode utf-8 string as base64 + b64_data = base64.b64encode(content.encode("utf-8")).decode("utf-8") + output_files.append({"filename": f.name, "b64_data": b64_data}) + + # Collect output files from execution results + for idx, result in enumerate(execution.results): + for result_type in {"png", "jpeg", "svg", "text", "markdown", "json"}: + if b64_data := getattr(result, result_type, None): + output_files.append({"filename": f"{idx}.{result_type}", "b64_data": b64_data}) + break + + # collect logs + success = not execution.error and not execution.logs.stderr + stdout = "\n".join(execution.logs.stdout) + errors = "\n".join(execution.logs.stderr) + if execution.error: + errors = f"{execution.error}\n{errors}" + + return { + "code": code, + "success": success, + "std_out": stdout, + "std_err": errors, + "output_files": output_files, + } + except Exception as e: + return { + "code": code, + "success": False, + "std_err": f"Sandbox failed to execute code: {str(e)}", + "output_files": [], + } + + +async def execute_terrarium( + code: str, + input_data: list[dict], + sandbox_url: str, +) -> dict[str, Any]: + """Execute code using Terrarium sandbox""" + headers = {"Content-Type": "application/json"} + data = {"code": code, "files": input_data} async with aiohttp.ClientSession() as session: - async with 
session.post(sandbox_url, json=data, headers=headers) as response: + async with session.post(sandbox_url, json=data, headers=headers, timeout=30) as response: if response.status == 200: result: dict[str, Any] = await response.json() - result["code"] = cleaned_code + result["code"] = code # Store decoded output files result["output_files"] = result.get("output_files", []) for output_file in result["output_files"]: @@ -172,7 +314,7 @@ async def execute_sandboxed_python(code: str, input_data: list[dict], sandbox_ur return result else: return { - "code": cleaned_code, + "code": code, "success": False, "std_err": f"Failed to execute code with {response.status}", "output_files": [], diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py index b3ff1f97..74c06172 100644 --- a/src/khoj/utils/constants.py +++ b/src/khoj/utils/constants.py @@ -18,7 +18,7 @@ default_offline_chat_models = [ "bartowski/Qwen2.5-14B-Instruct-GGUF", ] default_openai_chat_models = ["gpt-4o-mini", "gpt-4o"] -default_gemini_chat_models = ["gemini-1.5-flash", "gemini-1.5-pro"] +default_gemini_chat_models = ["gemini-2.0-flash", "gemini-1.5-pro"] default_anthropic_chat_models = ["claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022"] empty_config = { @@ -46,6 +46,7 @@ model_to_cost: Dict[str, Dict[str, float]] = { "gemini-1.5-flash-002": {"input": 0.075, "output": 0.30}, "gemini-1.5-pro": {"input": 1.25, "output": 5.00}, "gemini-1.5-pro-002": {"input": 1.25, "output": 5.00}, + "gemini-2.0-flash": {"input": 0.10, "output": 0.40}, # Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_ "claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0}, "claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0}, diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index b48436c6..4723403e 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -321,6 +321,12 @@ def get_device() -> torch.device: return torch.device("cpu") +def 
is_e2b_code_sandbox_enabled(): + """Check if E2B code sandbox is enabled. + Set E2B_API_KEY environment variable to use it.""" + return not is_none_or_empty(os.getenv("E2B_API_KEY")) + + class ConversationCommand(str, Enum): Default = "default" General = "general" @@ -362,20 +368,23 @@ command_descriptions_for_agent = { ConversationCommand.Code: "Agent can run Python code to parse information, run complex calculations, create documents and charts.", } +e2b_tool_description = "To run Python code in an E2B sandbox with no network access. Helpful to parse complex information, run calculations, create text documents and create charts with quantitative data. Only matplotlib, pandas, numpy, scipy, bs4, sympy, einops, biopython, shapely and rdkit external packages are available." +terrarium_tool_description = "To run Python code in a Terrarium, Pyodide sandbox with no network access. Helpful to parse complex information, run complex calculations, create plaintext documents and create charts with quantitative data. Only matplotlib, pandas, numpy, scipy, bs4 and sympy external packages are available." + tool_descriptions_for_llm = { ConversationCommand.Default: "To use a mix of your internal knowledge and the user's personal knowledge, or if you don't entirely understand the query.", ConversationCommand.General: "To use when you can answer the question without any outside information or personal knowledge", ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.", ConversationCommand.Online: "To search for the latest, up-to-date information from the internet. Note: **Questions about Khoj should always use this data source**", ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.", - ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. 
Helpful when need to parse complex information, run complex calculations, create plaintext documents, and create charts with quantitative data. Only matplotlib, panda, numpy, scipy, bs4 and sympy external packages are available.", + ConversationCommand.Code: e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description, } function_calling_description_for_llm = { ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.", ConversationCommand.Online: "To search the internet for information. Useful to get a quick, broad overview from the internet. Provide all relevant context to ensure new searches, not in previous iterations, are performed.", ConversationCommand.Webpage: "To extract information from webpages. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share the webpage links and information to extract in your query.", - ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse complex information, run complex calculations, create plaintext documents, and create charts with quantitative data. 
Only matplotlib, panda, numpy, scipy, bs4 and sympy external packages are available.", + ConversationCommand.Code: e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description, } mode_descriptions_for_llm = { diff --git a/src/khoj/utils/initialization.py b/src/khoj/utils/initialization.py index 5f4254b5..b5c661c4 100644 --- a/src/khoj/utils/initialization.py +++ b/src/khoj/utils/initialization.py @@ -185,16 +185,18 @@ def initialization(interactive: bool = True): ) provider_name = provider_name or model_type.name.capitalize() - default_use_model = {True: "y", False: "n"}[default_api_key is not None] - - # If not in interactive mode & in the offline setting, it's most likely that we're running in a containerized environment. This usually means there's not enough RAM to load offline models directly within the application. In such cases, we default to not using the model -- it's recommended to use another service like Ollama to host the model locally in that case. - default_use_model = {True: "n", False: default_use_model}[is_offline] + default_use_model = default_api_key is not None + # If not in interactive mode & in the offline setting, it's most likely that we're running in a containerized environment. + # This usually means there's not enough RAM to load offline models directly within the application. + # In such cases, we default to not using the model -- it's recommended to use another service like Ollama to host the model locally in that case. + if is_offline: + default_use_model = False use_model_provider = ( - default_use_model if not interactive else input(f"Add {provider_name} chat models? (y/n): ") + default_use_model if not interactive else input(f"Add {provider_name} chat models? 
(y/n): ") == "y" ) - if use_model_provider != "y": + if not use_model_provider: return False, None logger.info(f"️💬 Setting up your {provider_name} chat configuration") @@ -303,4 +305,19 @@ def initialization(interactive: bool = True): logger.error(f"🚨 Failed to create chat configuration: {e}", exc_info=True) else: _update_chat_model_options() - logger.info("🗣️ Chat model configuration updated") + logger.info("🗣️ Chat model options updated") + + # Update the default chat model if it doesn't match + chat_config = ConversationAdapters.get_default_chat_model() + env_default_chat_model = os.getenv("KHOJ_DEFAULT_CHAT_MODEL") + if not chat_config or not env_default_chat_model: + return + if chat_config.name != env_default_chat_model: + chat_model = ConversationAdapters.get_chat_model_by_name(env_default_chat_model) + if not chat_model: + logger.error( + f"🚨 Not setting default chat model. Chat model {env_default_chat_model} not found in existing chat model options." + ) + return + ConversationAdapters.set_default_chat_model(chat_model) + logger.info(f"🗣️ Default chat model set to {chat_model.name}") diff --git a/tests/conftest.py b/tests/conftest.py index 1795b340..e5ab3a8e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -315,7 +315,7 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa if chat_provider == ChatModel.ModelType.OPENAI: online_chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai") elif chat_provider == ChatModel.ModelType.GOOGLE: - online_chat_model = ChatModelFactory(name="gemini-1.5-flash", model_type="google") + online_chat_model = ChatModelFactory(name="gemini-2.0-flash", model_type="google") elif chat_provider == ChatModel.ModelType.ANTHROPIC: online_chat_model = ChatModelFactory(name="claude-3-5-haiku-20241022", model_type="anthropic") if online_chat_model: diff --git a/tests/evals/eval.py b/tests/evals/eval.py index 20a6051e..e9d56f03 100644 --- a/tests/evals/eval.py +++ b/tests/evals/eval.py 
@@ -1,13 +1,11 @@ import argparse import base64 import concurrent.futures -import hashlib import json import logging import os import re import time -import uuid from datetime import datetime from functools import partial from io import StringIO @@ -553,6 +551,7 @@ def process_batch(batch, batch_start, results, dataset_length, response_evaluato --------- Decision: {colored_decision} Accuracy: {running_accuracy:.2%} +Progress: {running_total_count.get()/dataset_length:.2%} Question: {prompt} Expected Answer: {answer} Agent Answer: {agent_response} @@ -630,7 +629,7 @@ def main(): response_evaluator = evaluate_response_with_mcq_match elif args.dataset == "math500": response_evaluator = partial( - evaluate_response_with_gemini, eval_model=os.getenv("GEMINI_EVAL_MODEL", "gemini-1.5-flash-002") + evaluate_response_with_gemini, eval_model=os.getenv("GEMINI_EVAL_MODEL", "gemini-2.0-flash-001") ) elif args.dataset == "frames_ir": response_evaluator = evaluate_response_for_ir @@ -667,7 +666,7 @@ def main(): colored_accuracy_str = f"Overall Accuracy: {colored_accuracy} on {args.dataset.title()} dataset." accuracy_str = f"Overall Accuracy: {accuracy:.2%} on {args.dataset}." accuracy_by_reasoning = f"Accuracy by Reasoning Type:\n{reasoning_type_accuracy}" - cost = f"Total Cost: ${running_cost.get():.5f}." + cost = f"Total Cost: ${running_cost.get():.5f} to evaluate {running_total_count.get()} results." sample_type = f"Sampling Type: {SAMPLE_SIZE} samples." if SAMPLE_SIZE else "Whole dataset." sample_type += " Randomized." if RANDOMIZE else "" logger.info(f"\n{colored_accuracy_str}\n\n{accuracy_by_reasoning}\n\n{cost}\n\n{sample_type}\n")