mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 13:23:15 +00:00
Improve Code Tool, Sandbox and Eval (#1120)
# Improve Code Tool, Sandbox - Improve code gen chat actor to output code in inline md code blocks - Stop code sandbox on request timeout to allow sandbox process restarts - Use tenacity retry decorator to retry executing code in sandbox - Add retry logic to code execution and add health check to sandbox container - Add E2B as an optional code sandbox provider # Improve Gemini Chat Models - Default to non-zero temperature for all queries to Gemini models - Default to Gemini 2.0 flash instead of 1.5 flash on setup - Set default chat model to KHOJ_CHAT_MODEL env var if set
This commit is contained in:
23
.github/workflows/run_evals.yml
vendored
23
.github/workflows/run_evals.yml
vendored
@@ -32,6 +32,19 @@ on:
|
||||
required: false
|
||||
default: 200
|
||||
type: number
|
||||
sandbox:
|
||||
description: 'Code sandbox to use'
|
||||
required: false
|
||||
default: 'terrarium'
|
||||
type: choice
|
||||
options:
|
||||
- terrarium
|
||||
- e2b
|
||||
chat_model:
|
||||
description: 'Chat model to use'
|
||||
required: false
|
||||
default: 'gemini-2.0-flash'
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
eval:
|
||||
@@ -40,7 +53,7 @@ jobs:
|
||||
matrix:
|
||||
# Use input from manual trigger if available, else run all combinations
|
||||
khoj_mode: ${{ github.event_name == 'workflow_dispatch' && fromJSON(format('["{0}"]', inputs.khoj_mode)) || fromJSON('["general", "default", "research"]') }}
|
||||
dataset: ${{ github.event_name == 'workflow_dispatch' && fromJSON(format('["{0}"]', inputs.dataset)) || fromJSON('["frames", "simpleqa"]') }}
|
||||
dataset: ${{ github.event_name == 'workflow_dispatch' && fromJSON(format('["{0}"]', inputs.dataset)) || fromJSON('["frames", "simpleqa", "gpqa"]') }}
|
||||
|
||||
services:
|
||||
postgres:
|
||||
@@ -95,11 +108,14 @@ jobs:
|
||||
BATCH_SIZE: "20"
|
||||
RANDOMIZE: "True"
|
||||
KHOJ_URL: "http://localhost:42110"
|
||||
KHOJ_DEFAULT_CHAT_MODEL: ${{ github.event_name == 'workflow_dispatch' && inputs.chat_model || 'gemini-2.0-flash' }}
|
||||
KHOJ_LLM_SEED: "42"
|
||||
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
|
||||
SERPER_DEV_API_KEY: ${{ matrix.dataset != 'math500' && secrets.SERPER_DEV_API_KEY }}
|
||||
OLOSTEP_API_KEY: ${{ matrix.dataset != 'math500' && secrets.OLOSTEP_API_KEY }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
E2B_API_KEY: ${{ inputs.sandbox == 'e2b' && secrets.E2B_API_KEY }}
|
||||
E2B_TEMPLATE: ${{ vars.E2B_TEMPLATE }}
|
||||
KHOJ_ADMIN_EMAIL: khoj
|
||||
KHOJ_ADMIN_PASSWORD: khoj
|
||||
POSTGRES_HOST: localhost
|
||||
@@ -114,7 +130,7 @@ jobs:
|
||||
|
||||
# Start code sandbox
|
||||
npm install -g pm2
|
||||
npm run ci --prefix terrarium
|
||||
NODE_ENV=production npm run ci --prefix terrarium
|
||||
|
||||
# Wait for server to be ready
|
||||
timeout=120
|
||||
@@ -147,7 +163,8 @@ jobs:
|
||||
echo "## Evaluation Summary of Khoj on ${{ matrix.dataset }} in ${{ matrix.khoj_mode }} mode" >> $GITHUB_STEP_SUMMARY
|
||||
echo "**$(head -n 1 *_evaluation_summary_*.txt)**" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Khoj Version: ${{ steps.hatch.outputs.version }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Chat Model: Gemini 1.5 Flash 002" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Chat Model: ${{ inputs.chat_model }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Code Sandbox: ${{ inputs.sandbox}}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
|
||||
tail -n +2 *_evaluation_summary_*.txt >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
@@ -14,6 +14,11 @@ services:
|
||||
retries: 5
|
||||
sandbox:
|
||||
image: ghcr.io/khoj-ai/terrarium:latest
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "curl -f http://localhost:8080/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 2
|
||||
search:
|
||||
image: docker.io/searxng/searxng:latest
|
||||
volumes:
|
||||
@@ -53,8 +58,10 @@ services:
|
||||
- KHOJ_DEBUG=False
|
||||
- KHOJ_ADMIN_EMAIL=username@example.com
|
||||
- KHOJ_ADMIN_PASSWORD=password
|
||||
# Default URL of Terrarium, the Python sandbox used by Khoj to run code. Its container is specified above
|
||||
# Default URL of Terrarium, the default Python sandbox used by Khoj to run code. Its container is specified above
|
||||
- KHOJ_TERRARIUM_URL=http://sandbox:8080
|
||||
# Uncomment line below to have Khoj run code in remote E2B code sandbox instead of the self-hosted Terrarium sandbox above. Get your E2B API key from https://e2b.dev/.
|
||||
# - E2B_API_KEY=your_e2b_api_key
|
||||
# Default URL of SearxNG, the default web search engine used by Khoj. Its container is specified above
|
||||
- KHOJ_SEARXNG_URL=http://search:8080
|
||||
# Uncomment line below to use with Ollama running on your local machine at localhost:11434.
|
||||
|
||||
@@ -3,22 +3,23 @@
|
||||
|
||||
# Code Execution
|
||||
|
||||
Khoj can generate and run very simple Python code snippets as well. This is useful if you want to generate a plot, run a simple calculation, or do some basic data manipulation. LLMs by default aren't skilled at complex quantitative tasks. Code generation & execution can come in handy for such tasks.
|
||||
Khoj can generate and run simple Python code as well. This is useful if you want to have Khoj do some data analysis, generate plots and reports. LLMs by default aren't skilled at complex quantitative tasks. Code generation & execution can come in handy for such tasks.
|
||||
|
||||
Just use `/code` in your chat command.
|
||||
Khoj automatically infers when to use the code tool. You can also tell it explicitly to use the code tool or use the `/code` [slash command](https://docs.khoj.dev/features/chat/#commands) in your chat.
|
||||
|
||||
### Setup (Self-Hosting)
|
||||
Run [Cohere's Terrarium](https://github.com/cohere-ai/cohere-terrarium) on your machine to enable code generation and execution.
|
||||
## Setup (Self-Hosting)
|
||||
### Terrarium Sandbox
|
||||
Use [Cohere's Terrarium](https://github.com/cohere-ai/cohere-terrarium) to host the code sandbox locally on your machine for free.
|
||||
|
||||
Check the [instructions](https://github.com/cohere-ai/cohere-terrarium?tab=readme-ov-file#development) for running from source.
|
||||
|
||||
For running with Docker, you can use our [docker-compose.yml](https://github.com/khoj-ai/khoj/blob/master/docker-compose.yml), or start it manually like this:
|
||||
To run with Docker, use our [docker-compose.yml](https://github.com/khoj-ai/khoj/blob/master/docker-compose.yml) to automatically setup the Terrarium code sandbox, or start it manually like this:
|
||||
|
||||
```bash
|
||||
docker pull ghcr.io/khoj-ai/terrarium:latest
|
||||
docker run -d -p 8080:8080 ghcr.io/khoj-ai/terrarium:latest
|
||||
```
|
||||
|
||||
To run from source, check [these instructions](https://github.com/khoj-ai/cohere-terrarium?tab=readme-ov-file#development).
|
||||
|
||||
#### Verify
|
||||
Verify that it's running, by evaluating a simple Python expression:
|
||||
|
||||
@@ -28,3 +29,12 @@ curl -X POST -H "Content-Type: application/json" \
|
||||
--data-raw '{"code": "1 + 1"}' \
|
||||
--no-buffer
|
||||
```
|
||||
|
||||
### E2B Sandbox
|
||||
[E2B](https://e2b.dev/) allows Khoj to run code on a remote but versatile sandbox with support for more python libraries. This is [not free](https://e2b.dev/pricing).
|
||||
|
||||
To have Khoj use E2B as the code sandbox:
|
||||
1. Generate an API key on [their dashboard](https://e2b.dev/dashboard).
|
||||
2. Set the `E2B_API_KEY` environment variable to it on the machine running your Khoj server.
|
||||
- When using our [docker-compose.yml](https://github.com/khoj-ai/khoj/blob/master/docker-compose.yml), uncomment and set the `E2B_API_KEY` env var in the `docker-compose.yml` file.
|
||||
3. Now restart your Khoj server to switch to using the E2B code sandbox.
|
||||
|
||||
@@ -333,7 +333,7 @@ Using Ollama? See the [Ollama Integration](/advanced/ollama) section for more cu
|
||||
- Add your [Gemini API key](https://aistudio.google.com/app/apikey)
|
||||
- Give the configuration a friendly name like `Gemini`. Do not configure the API base url.
|
||||
2. Create a new [chat model](http://localhost:42110/server/admin/database/chatmodel/add)
|
||||
- Set the `chat-model` field to a [Google Gemini chat model](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models). Example: `gemini-1.5-flash`.
|
||||
- Set the `chat-model` field to a [Google Gemini chat model](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models). Example: `gemini-2.0-flash`.
|
||||
- Set the `model-type` field to `Google`.
|
||||
- Set the `ai model api` field to the Gemini AI Model API you created in step 1.
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ dependencies = [
|
||||
"authlib == 1.2.1",
|
||||
"llama-cpp-python == 0.2.88",
|
||||
"itsdangerous == 2.1.2",
|
||||
"httpx == 0.25.0",
|
||||
"httpx == 0.27.2",
|
||||
"pgvector == 0.2.4",
|
||||
"psycopg2-binary == 2.9.9",
|
||||
"lxml == 4.9.3",
|
||||
@@ -92,6 +92,7 @@ dependencies = [
|
||||
"pyjson5 == 1.6.7",
|
||||
"resend == 1.0.1",
|
||||
"email-validator == 2.2.0",
|
||||
"e2b-code-interpreter ~= 1.0.0",
|
||||
]
|
||||
dynamic = ["version"]
|
||||
|
||||
|
||||
@@ -1107,6 +1107,12 @@ class ConversationAdapters:
|
||||
return config.setting
|
||||
return ConversationAdapters.aget_advanced_chat_model(user)
|
||||
|
||||
@staticmethod
|
||||
def get_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None):
|
||||
if ai_model_api_name:
|
||||
return ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).first()
|
||||
return ChatModel.objects.filter(name=chat_model_name).first()
|
||||
|
||||
@staticmethod
|
||||
async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
|
||||
voice_model_config = await UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
||||
@@ -1205,6 +1211,15 @@ class ConversationAdapters:
|
||||
return server_chat_settings.chat_advanced
|
||||
return await ConversationAdapters.aget_default_chat_model(user)
|
||||
|
||||
@staticmethod
|
||||
def set_default_chat_model(chat_model: ChatModel):
|
||||
server_chat_settings = ServerChatSettings.objects.first()
|
||||
if server_chat_settings:
|
||||
server_chat_settings.chat_default = chat_model
|
||||
server_chat_settings.save()
|
||||
else:
|
||||
ServerChatSettings.objects.create(chat_default=chat_model)
|
||||
|
||||
@staticmethod
|
||||
async def aget_server_webscraper():
|
||||
server_chat_settings = await ServerChatSettings.objects.filter().prefetch_related("web_scraper").afirst()
|
||||
|
||||
@@ -31,10 +31,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def extract_questions_gemini(
|
||||
text,
|
||||
model: Optional[str] = "gemini-1.5-flash",
|
||||
model: Optional[str] = "gemini-2.0-flash",
|
||||
conversation_log={},
|
||||
api_key=None,
|
||||
temperature=0,
|
||||
temperature=0.2,
|
||||
max_tokens=None,
|
||||
location_data: LocationData = None,
|
||||
user: KhojUser = None,
|
||||
@@ -121,7 +121,7 @@ def gemini_send_message_to_model(
|
||||
api_key,
|
||||
model,
|
||||
response_type="text",
|
||||
temperature=0,
|
||||
temperature=0.2,
|
||||
model_kwargs=None,
|
||||
tracer={},
|
||||
):
|
||||
@@ -132,9 +132,9 @@ def gemini_send_message_to_model(
|
||||
|
||||
model_kwargs = {}
|
||||
|
||||
# Sometimes, this causes unwanted behavior and terminates response early. Disable for now while it's flaky.
|
||||
# if response_type == "json_object":
|
||||
# model_kwargs["response_mime_type"] = "application/json"
|
||||
# This caused unwanted behavior and terminates response early for gemini 1.5 series. Monitor for flakiness with 2.0 series.
|
||||
if response_type == "json_object" and model in ["gemini-2.0-flash"]:
|
||||
model_kwargs["response_mime_type"] = "application/json"
|
||||
|
||||
# Get Response from Gemini
|
||||
return gemini_completion_with_backoff(
|
||||
@@ -154,7 +154,7 @@ def converse_gemini(
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
code_results: Optional[Dict[str, Dict]] = None,
|
||||
conversation_log={},
|
||||
model: Optional[str] = "gemini-1.5-flash",
|
||||
model: Optional[str] = "gemini-2.0-flash",
|
||||
api_key: Optional[str] = None,
|
||||
temperature: float = 0.2,
|
||||
completion_func=None,
|
||||
|
||||
@@ -974,11 +974,9 @@ Khoj:
|
||||
python_code_generation_prompt = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an advanced python programmer. You are tasked with constructing a python program to best answer the user query.
|
||||
- The python program will run in a pyodide python sandbox with no network access.
|
||||
- The python program will run in a sandbox with no network access.
|
||||
- You can write programs to run complex calculations, analyze data, create charts, generate documents to meticulously answer the query.
|
||||
- The sandbox has access to the standard library, matplotlib, panda, numpy, scipy, bs4 and sympy packages. The requests, torch, catboost, tensorflow and tkinter packages are not available.
|
||||
- List known file paths to required user documents in "input_files" and known links to required documents from the web in the "input_links" field.
|
||||
- The python program should be self-contained. It can only read data generated by the program itself and from provided input_files, input_links by their basename (i.e filename excluding file path).
|
||||
- The python program should be self-contained. It can only read data generated by the program itself and any user file paths referenced in your program.
|
||||
- Do not try display images or plots in the code directly. The code should save the image or plot to a file instead.
|
||||
- Write any document, charts etc. to be shared with the user to file. These files can be seen by the user.
|
||||
- Use as much context from the previous questions and answers as required to generate your code.
|
||||
@@ -989,24 +987,99 @@ Current Date: {current_date}
|
||||
User's Location: {location}
|
||||
{username}
|
||||
|
||||
The response JSON schema is of the form {{"code": "<python_code>", "input_files": ["file_path_1", "file_path_2"], "input_links": ["link_1", "link_2"]}}
|
||||
Examples:
|
||||
Your response should contain python code wrapped in markdown code blocks (i.e starting with```python and ending with ```)
|
||||
Example 1:
|
||||
---
|
||||
{{
|
||||
"code": "# Input values\\nprincipal = 43235\\nrate = 5.24\\nyears = 5\\n\\n# Convert rate to decimal\\nrate_decimal = rate / 100\\n\\n# Calculate final amount\\nfinal_amount = principal * (1 + rate_decimal) ** years\\n\\n# Calculate interest earned\\ninterest_earned = final_amount - principal\\n\\n# Print results with formatting\\nprint(f"Interest Earned: ${{interest_earned:,.2f}}")\\nprint(f"Final Amount: ${{final_amount:,.2f}}")"
|
||||
}}
|
||||
Q: Calculate the interest earned and final amount for a principal of $43,235 invested at a rate of 5.24 percent for 5 years.
|
||||
A: Ok, to calculate the interest earned and final amount, we can use the formula for compound interest: $T = P(1 + r/n)^{{nt}}$,
|
||||
where T: total amount, P: principal, r: interest rate, n: number of times interest is compounded per year, and t: time in years.
|
||||
|
||||
{{
|
||||
"code": "import re\\n\\n# Read org file\\nfile_path = 'tasks.org'\\nwith open(file_path, 'r') as f:\\n content = f.read()\\n\\n# Get today's date in YYYY-MM-DD format\\ntoday = datetime.now().strftime('%Y-%m-%d')\\npattern = r'\*+\s+.*\\n.*SCHEDULED:\s+<' + today + r'.*>'\\n\\n# Find all matches using multiline mode\\nmatches = re.findall(pattern, content, re.MULTILINE)\\ncount = len(matches)\\n\\n# Display count\\nprint(f'Count of scheduled tasks for today: {{count}}')",
|
||||
"input_files": ["/home/linux/tasks.org"]
|
||||
}}
|
||||
Let's write the Python program to calculate this.
|
||||
|
||||
{{
|
||||
"code": "import pandas as pd\\nimport matplotlib.pyplot as plt\\n\\n# Load the CSV file\\ndf = pd.read_csv('world_population_by_year.csv')\\n\\n# Plot the data\\nplt.figure(figsize=(10, 6))\\nplt.plot(df['Year'], df['Population'], marker='o')\\n\\n# Add titles and labels\\nplt.title('Population by Year')\\nplt.xlabel('Year')\\nplt.ylabel('Population')\\n\\n# Save the plot to a file\\nplt.savefig('population_by_year_plot.png')",
|
||||
"input_links": ["https://population.un.org/world_population_by_year.csv"]
|
||||
}}
|
||||
```python
|
||||
# Input values
|
||||
principal = 43235
|
||||
rate = 5.24
|
||||
years = 5
|
||||
|
||||
# Convert rate to decimal
|
||||
rate_decimal = rate / 100
|
||||
|
||||
# Calculate final amount
|
||||
final_amount = principal * (1 + rate_decimal) ** years
|
||||
|
||||
# Calculate interest earned
|
||||
interest_earned = final_amount - principal
|
||||
|
||||
# Print results with formatting
|
||||
print(f"Interest Earned: ${{interest_earned:,.2f}}")
|
||||
print(f"Final Amount: ${{final_amount:,.2f}}")
|
||||
```
|
||||
|
||||
Example 2:
|
||||
---
|
||||
Q: Simplify first, then evaluate: $-7x+2(x^{{2}}-1)-(2x^{{2}}-x+3)$, where $x=1$.
|
||||
A: Certainly! Let's break down the problem step-by-step and utilize Python with SymPy to simplify and evaluate the expression.
|
||||
|
||||
1. **Expression Simplification:**
|
||||
We start with the expression \\(-7x + 2(x^2 - 1) - (2x^2 - x + 3)\\).
|
||||
|
||||
2. **Substitute \\(x=1\\) into the simplified expression:**
|
||||
Once simplified, we will substitute \\(x=1\\) into the expression to find its value.
|
||||
|
||||
Let's implement this in Python using SymPy (as the package is available in the sandbox):
|
||||
|
||||
```python
|
||||
import sympy as sp
|
||||
|
||||
# Define the variable
|
||||
x = sp.symbols('x')
|
||||
|
||||
# Define the expression
|
||||
expression = -7*x + 2*(x**2 - 1) - (2*x**2 - x + 3)
|
||||
|
||||
# Simplify the expression
|
||||
simplified_expression = sp.simplify(expression)
|
||||
|
||||
# Substitute x = 1 into the simplified expression
|
||||
evaluated_expression = simplified_expression.subs(x, 1)
|
||||
|
||||
# Print the simplified expression and its evaluated value
|
||||
print(\"Simplified Expression:\", simplified_expression)
|
||||
print(\"Evaluated Expression at x=1:\", evaluated_expression)
|
||||
```
|
||||
|
||||
Example 3:
|
||||
---
|
||||
Q: Plot the world population growth over the years, given this year, world population world tuples: [(2000, 6), (2001, 7), (2002, 8), (2003, 9), (2004, 10)].
|
||||
A: Absolutely! We can utilize the Pandas and Matplotlib libraries (as both are available in the sandbox) to create the world population growth plot.
|
||||
```python
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Create a DataFrame of world population from the provided data
|
||||
data = {{
|
||||
'Year': [2000, 2001, 2002, 2003, 2004],
|
||||
'Population': [6, 7, 8, 9, 10]
|
||||
}}
|
||||
df = pd.DataFrame(data)
|
||||
|
||||
# Plot the data
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(df['Year'], df['Population'], marker='o')
|
||||
|
||||
# Add titles and labels
|
||||
plt.title('Population by Year')
|
||||
plt.xlabel('Year')
|
||||
plt.ylabel('Population')
|
||||
|
||||
# Save the plot to a file
|
||||
plt.savefig('population_by_year_plot.png')
|
||||
```
|
||||
|
||||
Now it's your turn to construct a python program to answer the user's query using the provided context and coversation provided below.
|
||||
Ensure you include the python code to execute and wrap it in a markdown code block.
|
||||
|
||||
Now it's your turn to construct a python program to answer the user's question. Provide the code, required input files and input links in a JSON object. Do not say anything else.
|
||||
Context:
|
||||
---
|
||||
{context}
|
||||
@@ -1015,8 +1088,9 @@ Chat History:
|
||||
---
|
||||
{chat_history}
|
||||
|
||||
User: {query}
|
||||
Khoj:
|
||||
User Query:
|
||||
---
|
||||
{query}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
@@ -1030,6 +1104,13 @@ Code Execution Results:
|
||||
""".strip()
|
||||
)
|
||||
|
||||
e2b_sandbox_context = """
|
||||
- The sandbox has access to only the standard library, matplotlib, pandas, numpy, scipy, bs4, sympy, einops, biopython, shapely, plotly and rdkit packages. The requests, torch, catboost, tensorflow and tkinter packages are not available.
|
||||
""".strip()
|
||||
|
||||
terrarium_sandbox_context = """
|
||||
The sandbox has access to the standard library, matplotlib, pandas, numpy, scipy, bs4 and sympy packages. The requests, torch, catboost, tensorflow, rdkit and tkinter packages are not available.
|
||||
""".strip()
|
||||
|
||||
# Automations
|
||||
# --
|
||||
|
||||
@@ -1,12 +1,23 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import datetime
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, NamedTuple, Optional
|
||||
|
||||
import aiohttp
|
||||
from asgiref.sync import sync_to_async
|
||||
from httpx import RemoteProtocolError
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from khoj.database.adapters import FileObjectAdapters
|
||||
from khoj.database.models import Agent, FileObject, KhojUser
|
||||
@@ -15,22 +26,26 @@ from khoj.processor.conversation.utils import (
|
||||
ChatEvent,
|
||||
clean_code_python,
|
||||
construct_chat_history,
|
||||
load_complex_json,
|
||||
)
|
||||
from khoj.routers.helpers import send_message_to_model_wrapper
|
||||
from khoj.utils.helpers import is_none_or_empty, timer, truncate_code_context
|
||||
from khoj.utils.helpers import (
|
||||
is_e2b_code_sandbox_enabled,
|
||||
is_none_or_empty,
|
||||
timer,
|
||||
truncate_code_context,
|
||||
)
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SANDBOX_URL = os.getenv("KHOJ_TERRARIUM_URL", "http://localhost:8080")
|
||||
DEFAULT_E2B_TEMPLATE = "pmt2o0ghpang8gbiys57"
|
||||
|
||||
|
||||
class GeneratedCode(NamedTuple):
|
||||
code: str
|
||||
input_files: List[str]
|
||||
input_links: List[str]
|
||||
input_files: List[FileObject]
|
||||
|
||||
|
||||
async def run_code(
|
||||
@@ -68,13 +83,10 @@ async def run_code(
|
||||
|
||||
# Prepare Input Data
|
||||
input_data = []
|
||||
user_input_files: List[FileObject] = []
|
||||
for input_file in generated_code.input_files:
|
||||
user_input_files += await FileObjectAdapters.aget_file_objects_by_name(user, input_file)
|
||||
for f in user_input_files:
|
||||
for f in generated_code.input_files:
|
||||
input_data.append(
|
||||
{
|
||||
"filename": os.path.basename(f.file_name),
|
||||
"filename": f.file_name,
|
||||
"b64_data": base64.b64encode(f.raw_text.encode("utf-8")).decode("utf-8"),
|
||||
}
|
||||
)
|
||||
@@ -90,6 +102,14 @@ async def run_code(
|
||||
cleaned_result = truncate_code_context({"cleaned": {"results": result}})["cleaned"]["results"]
|
||||
logger.info(f"Executed Code\n----\n{code}\n----\nResult\n----\n{cleaned_result}\n----")
|
||||
yield {query: {"code": code, "results": result}}
|
||||
except asyncio.TimeoutError as e:
|
||||
# Call the sandbox_url/stop GET API endpoint to stop the code sandbox
|
||||
error = f"Failed to run code for {query} with Timeout error: {e}"
|
||||
try:
|
||||
await aiohttp.ClientSession().get(f"{sandbox_url}/stop", timeout=5)
|
||||
except Exception as e:
|
||||
error += f"\n\nFailed to stop code sandbox with error: {e}"
|
||||
raise ValueError(error)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to run code for {query} with error: {e}")
|
||||
|
||||
@@ -114,6 +134,12 @@ async def generate_python_code(
|
||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||
)
|
||||
|
||||
# add sandbox specific context like available packages
|
||||
sandbox_context = (
|
||||
prompts.e2b_sandbox_context if is_e2b_code_sandbox_enabled() else prompts.terrarium_sandbox_context
|
||||
)
|
||||
personality_context = f"{sandbox_context}\n{personality_context}"
|
||||
|
||||
code_generation_prompt = prompts.python_code_generation_prompt.format(
|
||||
current_date=utc_date,
|
||||
query=q,
|
||||
@@ -127,23 +153,50 @@ async def generate_python_code(
|
||||
response = await send_message_to_model_wrapper(
|
||||
code_generation_prompt,
|
||||
query_images=query_images,
|
||||
response_type="json_object",
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
query_files=query_files,
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list
|
||||
response = load_complex_json(response)
|
||||
code = response.get("code", "").strip()
|
||||
input_files = response.get("input_files", [])
|
||||
input_links = response.get("input_links", [])
|
||||
# Extract python code wrapped in markdown code blocks from the response
|
||||
code_blocks = re.findall(r"```(?:python)?\n(.*?)\n```", response, re.DOTALL)
|
||||
|
||||
if not code_blocks:
|
||||
raise ValueError("No Python code blocks found in response")
|
||||
|
||||
# Join multiple code blocks with newlines and strip any leading/trailing whitespace
|
||||
code = "\n".join(code_blocks).strip()
|
||||
|
||||
if not isinstance(code, str) or is_none_or_empty(code):
|
||||
raise ValueError
|
||||
return GeneratedCode(code, input_files, input_links)
|
||||
|
||||
# Infer user files required in sandbox based on user file paths mentioned in code
|
||||
input_files: List[FileObject] = []
|
||||
user_files = await sync_to_async(set)(FileObjectAdapters.get_all_file_objects(user))
|
||||
for user_file in user_files:
|
||||
if user_file.file_name in code:
|
||||
# Replace references to full file path used in code with just the file basename to ease reference in sandbox
|
||||
file_basename = os.path.basename(user_file.file_name)
|
||||
code = code.replace(user_file.file_name, file_basename)
|
||||
user_file.file_name = file_basename
|
||||
input_files.append(user_file)
|
||||
|
||||
return GeneratedCode(code, input_files)
|
||||
|
||||
|
||||
@retry(
|
||||
retry=(
|
||||
retry_if_exception_type(aiohttp.ClientError)
|
||||
| retry_if_exception_type(aiohttp.ClientTimeout)
|
||||
| retry_if_exception_type(asyncio.TimeoutError)
|
||||
| retry_if_exception_type(ConnectionError)
|
||||
| retry_if_exception_type(RemoteProtocolError)
|
||||
),
|
||||
wait=wait_random_exponential(min=1, max=5),
|
||||
stop=stop_after_attempt(3),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
async def execute_sandboxed_python(code: str, input_data: list[dict], sandbox_url: str = SANDBOX_URL) -> dict[str, Any]:
|
||||
"""
|
||||
Takes code to run as a string and calls the terrarium API to execute it.
|
||||
@@ -152,15 +205,104 @@ async def execute_sandboxed_python(code: str, input_data: list[dict], sandbox_ur
|
||||
Reference data i/o format based on Terrarium example client code at:
|
||||
https://github.com/cohere-ai/cohere-terrarium/blob/main/example-clients/python/terrarium_client.py
|
||||
"""
|
||||
headers = {"Content-Type": "application/json"}
|
||||
cleaned_code = clean_code_python(code)
|
||||
data = {"code": cleaned_code, "files": input_data}
|
||||
if is_e2b_code_sandbox_enabled():
|
||||
try:
|
||||
return await execute_e2b(cleaned_code, input_data)
|
||||
except ImportError:
|
||||
pass
|
||||
return await execute_terrarium(cleaned_code, input_data, sandbox_url)
|
||||
|
||||
|
||||
async def execute_e2b(code: str, input_files: list[dict]) -> dict[str, Any]:
|
||||
"""Execute code and handle file I/O in e2b sandbox"""
|
||||
from e2b_code_interpreter import AsyncSandbox
|
||||
|
||||
sandbox = await AsyncSandbox.create(
|
||||
api_key=os.getenv("E2B_API_KEY"),
|
||||
template=os.getenv("E2B_TEMPLATE", DEFAULT_E2B_TEMPLATE),
|
||||
timeout=120,
|
||||
request_timeout=30,
|
||||
)
|
||||
|
||||
try:
|
||||
# Upload input files in parallel
|
||||
upload_tasks = [
|
||||
sandbox.files.write(path=file["filename"], data=base64.b64decode(file["b64_data"]), request_timeout=30)
|
||||
for file in input_files
|
||||
]
|
||||
await asyncio.gather(*upload_tasks)
|
||||
|
||||
# Note stored files before execution to identify new files created during execution
|
||||
E2bFile = NamedTuple("E2bFile", [("name", str), ("path", str)])
|
||||
original_files = {E2bFile(f.name, f.path) for f in await sandbox.files.list("~")}
|
||||
|
||||
# Execute code from main.py file
|
||||
execution = await sandbox.run_code(code=code, timeout=60)
|
||||
|
||||
# Collect output files
|
||||
output_files = []
|
||||
|
||||
# Identify new files created during execution
|
||||
new_files = set(E2bFile(f.name, f.path) for f in await sandbox.files.list("~")) - original_files
|
||||
# Read newly created files in parallel
|
||||
download_tasks = [sandbox.files.read(f.path, request_timeout=30) for f in new_files]
|
||||
downloaded_files = await asyncio.gather(*download_tasks)
|
||||
for f, content in zip(new_files, downloaded_files):
|
||||
if isinstance(content, bytes):
|
||||
# Binary files like PNG - encode as base64
|
||||
b64_data = base64.b64encode(content).decode("utf-8")
|
||||
elif Path(f.name).suffix in [".png", ".jpeg", ".jpg", ".svg"]:
|
||||
# Ignore image files as they are extracted from execution results below for inline display
|
||||
continue
|
||||
else:
|
||||
# Text files - encode utf-8 string as base64
|
||||
b64_data = base64.b64encode(content.encode("utf-8")).decode("utf-8")
|
||||
output_files.append({"filename": f.name, "b64_data": b64_data})
|
||||
|
||||
# Collect output files from execution results
|
||||
for idx, result in enumerate(execution.results):
|
||||
for result_type in {"png", "jpeg", "svg", "text", "markdown", "json"}:
|
||||
if b64_data := getattr(result, result_type, None):
|
||||
output_files.append({"filename": f"{idx}.{result_type}", "b64_data": b64_data})
|
||||
break
|
||||
|
||||
# collect logs
|
||||
success = not execution.error and not execution.logs.stderr
|
||||
stdout = "\n".join(execution.logs.stdout)
|
||||
errors = "\n".join(execution.logs.stderr)
|
||||
if execution.error:
|
||||
errors = f"{execution.error}\n{errors}"
|
||||
|
||||
return {
|
||||
"code": code,
|
||||
"success": success,
|
||||
"std_out": stdout,
|
||||
"std_err": errors,
|
||||
"output_files": output_files,
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"code": code,
|
||||
"success": False,
|
||||
"std_err": f"Sandbox failed to execute code: {str(e)}",
|
||||
"output_files": [],
|
||||
}
|
||||
|
||||
|
||||
async def execute_terrarium(
|
||||
code: str,
|
||||
input_data: list[dict],
|
||||
sandbox_url: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute code using Terrarium sandbox"""
|
||||
headers = {"Content-Type": "application/json"}
|
||||
data = {"code": code, "files": input_data}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(sandbox_url, json=data, headers=headers) as response:
|
||||
async with session.post(sandbox_url, json=data, headers=headers, timeout=30) as response:
|
||||
if response.status == 200:
|
||||
result: dict[str, Any] = await response.json()
|
||||
result["code"] = cleaned_code
|
||||
result["code"] = code
|
||||
# Store decoded output files
|
||||
result["output_files"] = result.get("output_files", [])
|
||||
for output_file in result["output_files"]:
|
||||
@@ -172,7 +314,7 @@ async def execute_sandboxed_python(code: str, input_data: list[dict], sandbox_ur
|
||||
return result
|
||||
else:
|
||||
return {
|
||||
"code": cleaned_code,
|
||||
"code": code,
|
||||
"success": False,
|
||||
"std_err": f"Failed to execute code with {response.status}",
|
||||
"output_files": [],
|
||||
|
||||
@@ -18,7 +18,7 @@ default_offline_chat_models = [
|
||||
"bartowski/Qwen2.5-14B-Instruct-GGUF",
|
||||
]
|
||||
default_openai_chat_models = ["gpt-4o-mini", "gpt-4o"]
|
||||
default_gemini_chat_models = ["gemini-1.5-flash", "gemini-1.5-pro"]
|
||||
default_gemini_chat_models = ["gemini-2.0-flash", "gemini-1.5-pro"]
|
||||
default_anthropic_chat_models = ["claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022"]
|
||||
|
||||
empty_config = {
|
||||
@@ -46,6 +46,7 @@ model_to_cost: Dict[str, Dict[str, float]] = {
|
||||
"gemini-1.5-flash-002": {"input": 0.075, "output": 0.30},
|
||||
"gemini-1.5-pro": {"input": 1.25, "output": 5.00},
|
||||
"gemini-1.5-pro-002": {"input": 1.25, "output": 5.00},
|
||||
"gemini-2.0-flash": {"input": 0.10, "output": 0.40},
|
||||
# Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_
|
||||
"claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0},
|
||||
"claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0},
|
||||
|
||||
@@ -321,6 +321,12 @@ def get_device() -> torch.device:
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
def is_e2b_code_sandbox_enabled():
|
||||
"""Check if E2B code sandbox is enabled.
|
||||
Set E2B_API_KEY environment variable to use it."""
|
||||
return not is_none_or_empty(os.getenv("E2B_API_KEY"))
|
||||
|
||||
|
||||
class ConversationCommand(str, Enum):
|
||||
Default = "default"
|
||||
General = "general"
|
||||
@@ -362,20 +368,23 @@ command_descriptions_for_agent = {
|
||||
ConversationCommand.Code: "Agent can run Python code to parse information, run complex calculations, create documents and charts.",
|
||||
}
|
||||
|
||||
e2b_tool_description = "To run Python code in a E2B sandbox with no network access. Helpful to parse complex information, run calculations, create text documents and create charts with quantitative data. Only matplotlib, pandas, numpy, scipy, bs4, sympy, einops, biopython, shapely and rdkit external packages are available."
|
||||
terrarium_tool_description = "To run Python code in a Terrarium, Pyodide sandbox with no network access. Helpful to parse complex information, run complex calculations, create plaintext documents and create charts with quantitative data. Only matplotlib, panda, numpy, scipy, bs4 and sympy external packages are available."
|
||||
|
||||
tool_descriptions_for_llm = {
|
||||
ConversationCommand.Default: "To use a mix of your internal knowledge and the user's personal knowledge, or if you don't entirely understand the query.",
|
||||
ConversationCommand.General: "To use when you can answer the question without any outside information or personal knowledge",
|
||||
ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
|
||||
ConversationCommand.Online: "To search for the latest, up-to-date information from the internet. Note: **Questions about Khoj should always use this data source**",
|
||||
ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.",
|
||||
ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse complex information, run complex calculations, create plaintext documents, and create charts with quantitative data. Only matplotlib, panda, numpy, scipy, bs4 and sympy external packages are available.",
|
||||
ConversationCommand.Code: e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description,
|
||||
}
|
||||
|
||||
function_calling_description_for_llm = {
|
||||
ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
|
||||
ConversationCommand.Online: "To search the internet for information. Useful to get a quick, broad overview from the internet. Provide all relevant context to ensure new searches, not in previous iterations, are performed.",
|
||||
ConversationCommand.Webpage: "To extract information from webpages. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share the webpage links and information to extract in your query.",
|
||||
ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse complex information, run complex calculations, create plaintext documents, and create charts with quantitative data. Only matplotlib, panda, numpy, scipy, bs4 and sympy external packages are available.",
|
||||
ConversationCommand.Code: e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description,
|
||||
}
|
||||
|
||||
mode_descriptions_for_llm = {
|
||||
|
||||
@@ -185,16 +185,18 @@ def initialization(interactive: bool = True):
|
||||
)
|
||||
provider_name = provider_name or model_type.name.capitalize()
|
||||
|
||||
default_use_model = {True: "y", False: "n"}[default_api_key is not None]
|
||||
|
||||
# If not in interactive mode & in the offline setting, it's most likely that we're running in a containerized environment. This usually means there's not enough RAM to load offline models directly within the application. In such cases, we default to not using the model -- it's recommended to use another service like Ollama to host the model locally in that case.
|
||||
default_use_model = {True: "n", False: default_use_model}[is_offline]
|
||||
default_use_model = default_api_key is not None
|
||||
# If not in interactive mode & in the offline setting, it's most likely that we're running in a containerized environment.
|
||||
# This usually means there's not enough RAM to load offline models directly within the application.
|
||||
# In such cases, we default to not using the model -- it's recommended to use another service like Ollama to host the model locally in that case.
|
||||
if is_offline:
|
||||
default_use_model = False
|
||||
|
||||
use_model_provider = (
|
||||
default_use_model if not interactive else input(f"Add {provider_name} chat models? (y/n): ")
|
||||
default_use_model if not interactive else input(f"Add {provider_name} chat models? (y/n): ") == "y"
|
||||
)
|
||||
|
||||
if use_model_provider != "y":
|
||||
if not use_model_provider:
|
||||
return False, None
|
||||
|
||||
logger.info(f"️💬 Setting up your {provider_name} chat configuration")
|
||||
@@ -303,4 +305,19 @@ def initialization(interactive: bool = True):
|
||||
logger.error(f"🚨 Failed to create chat configuration: {e}", exc_info=True)
|
||||
else:
|
||||
_update_chat_model_options()
|
||||
logger.info("🗣️ Chat model configuration updated")
|
||||
logger.info("🗣️ Chat model options updated")
|
||||
|
||||
# Update the default chat model if it doesn't match
|
||||
chat_config = ConversationAdapters.get_default_chat_model()
|
||||
env_default_chat_model = os.getenv("KHOJ_DEFAULT_CHAT_MODEL")
|
||||
if not chat_config or not env_default_chat_model:
|
||||
return
|
||||
if chat_config.name != env_default_chat_model:
|
||||
chat_model = ConversationAdapters.get_chat_model_by_name(env_default_chat_model)
|
||||
if not chat_model:
|
||||
logger.error(
|
||||
f"🚨 Not setting default chat model. Chat model {env_default_chat_model} not found in existing chat model options."
|
||||
)
|
||||
return
|
||||
ConversationAdapters.set_default_chat_model(chat_model)
|
||||
logger.info(f"🗣️ Default chat model set to {chat_model.name}")
|
||||
|
||||
@@ -315,7 +315,7 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa
|
||||
if chat_provider == ChatModel.ModelType.OPENAI:
|
||||
online_chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai")
|
||||
elif chat_provider == ChatModel.ModelType.GOOGLE:
|
||||
online_chat_model = ChatModelFactory(name="gemini-1.5-flash", model_type="google")
|
||||
online_chat_model = ChatModelFactory(name="gemini-2.0-flash", model_type="google")
|
||||
elif chat_provider == ChatModel.ModelType.ANTHROPIC:
|
||||
online_chat_model = ChatModelFactory(name="claude-3-5-haiku-20241022", model_type="anthropic")
|
||||
if online_chat_model:
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import argparse
|
||||
import base64
|
||||
import concurrent.futures
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from io import StringIO
|
||||
@@ -553,6 +551,7 @@ def process_batch(batch, batch_start, results, dataset_length, response_evaluato
|
||||
---------
|
||||
Decision: {colored_decision}
|
||||
Accuracy: {running_accuracy:.2%}
|
||||
Progress: {running_total_count.get()/dataset_length:.2%}
|
||||
Question: {prompt}
|
||||
Expected Answer: {answer}
|
||||
Agent Answer: {agent_response}
|
||||
@@ -630,7 +629,7 @@ def main():
|
||||
response_evaluator = evaluate_response_with_mcq_match
|
||||
elif args.dataset == "math500":
|
||||
response_evaluator = partial(
|
||||
evaluate_response_with_gemini, eval_model=os.getenv("GEMINI_EVAL_MODEL", "gemini-1.5-flash-002")
|
||||
evaluate_response_with_gemini, eval_model=os.getenv("GEMINI_EVAL_MODEL", "gemini-2.0-flash-001")
|
||||
)
|
||||
elif args.dataset == "frames_ir":
|
||||
response_evaluator = evaluate_response_for_ir
|
||||
@@ -667,7 +666,7 @@ def main():
|
||||
colored_accuracy_str = f"Overall Accuracy: {colored_accuracy} on {args.dataset.title()} dataset."
|
||||
accuracy_str = f"Overall Accuracy: {accuracy:.2%} on {args.dataset}."
|
||||
accuracy_by_reasoning = f"Accuracy by Reasoning Type:\n{reasoning_type_accuracy}"
|
||||
cost = f"Total Cost: ${running_cost.get():.5f}."
|
||||
cost = f"Total Cost: ${running_cost.get():.5f} to evaluate {running_total_count.get()} results."
|
||||
sample_type = f"Sampling Type: {SAMPLE_SIZE} samples." if SAMPLE_SIZE else "Whole dataset."
|
||||
sample_type += " Randomized." if RANDOMIZE else ""
|
||||
logger.info(f"\n{colored_accuracy_str}\n\n{accuracy_by_reasoning}\n\n{cost}\n\n{sample_type}\n")
|
||||
|
||||
Reference in New Issue
Block a user