Improve response quality with Gemini. Improve evaluation harness (#1150)

### Improve Gemini usage
- Allow text tool to give agent ability to terminate research
- Set default context window for Gemini 2.0 Flash models
  Double the context window for small, commercial models to 120K
- Set default temperature of Gemini models to 1.0 to reduce repetition

### Improve evaluation harness
- Add more knobs to control eval workflow
  - Allow running eval with any chat model served over an OpenAI-compatible API
  - Control random sampling from eval set
  - Auto read web page
- Use embedded postgres instead of postgres server for eval workflow
- Use Gemini 2.0 Flash as the evaluator. Set a seed for the evaluator to reduce decision variance
This commit is contained in:
Debanjum
2025-04-04 20:17:36 +05:30
committed by GitHub
10 changed files with 61 additions and 25 deletions

View File

@@ -51,4 +51,4 @@ body:
description: "Provide a link to the first message of feature request's discussion on Discord or Github.\n
This will help to keep history of why this feature request exists."
validations:
required: false
required: false

View File

@@ -50,11 +50,32 @@ on:
required: false
default: 5
type: number
openai_api_key:
description: 'OpenAI API key'
required: false
default: ''
type: string
openai_base_url:
description: 'Base URL of OpenAI compatible API'
required: false
default: ''
type: string
auto_read_webpage:
description: 'Auto read webpage on online search'
required: false
default: 'false'
type: choice
options:
- 'false'
- 'true'
randomize:
description: 'Randomize the sample of questions'
required: false
default: 'true'
type: choice
options:
- 'false'
- 'true'
jobs:
eval:
@@ -92,15 +113,21 @@ jobs:
- name: Get App Version
id: hatch
run: echo "version=$(pipx run hatch version)" >> $GITHUB_OUTPUT
run: |
# Mask relevant workflow inputs as secret early
OPENAI_API_KEY=$(jq -r '.inputs.openai_api_key' $GITHUB_EVENT_PATH)
echo ::add-mask::$OPENAI_API_KEY
echo OPENAI_API_KEY="$OPENAI_API_KEY" >> $GITHUB_ENV
# Get app version from hatch
echo "version=$(pipx run hatch version)" >> $GITHUB_OUTPUT
- name: ⏬️ Install Dependencies
env:
DEBIAN_FRONTEND: noninteractive
run: |
# install postgres and other dependencies
# install dependencies
sudo apt update && sudo apt install -y git python3-pip libegl1 sqlite3 libsqlite3-dev libsqlite3-0 ffmpeg libsm6 libxext6
sudo apt install -y postgresql postgresql-client && sudo apt install -y postgresql-server-dev-16
# upgrade pip
python -m ensurepip --upgrade && python -m pip install --upgrade pip
# install terrarium for code sandbox
@@ -116,13 +143,13 @@ jobs:
KHOJ_MODE: ${{ matrix.khoj_mode }}
SAMPLE_SIZE: ${{ github.event_name == 'workflow_dispatch' && inputs.sample_size || 200 }}
BATCH_SIZE: "20"
RANDOMIZE: "True"
RANDOMIZE: ${{ github.event_name == 'workflow_dispatch' && inputs.randomize || 'true' }}
KHOJ_URL: "http://localhost:42110"
KHOJ_DEFAULT_CHAT_MODEL: ${{ github.event_name == 'workflow_dispatch' && inputs.chat_model || 'gemini-2.0-flash' }}
KHOJ_LLM_SEED: "42"
KHOJ_DEFAULT_CHAT_MODEL: ${{ github.event_name == 'workflow_dispatch' && inputs.chat_model || 'gemini-2.0-flash' }}
KHOJ_RESEARCH_ITERATIONS: ${{ github.event_name == 'workflow_dispatch' && inputs.max_research_iterations || 5 }}
KHOJ_AUTO_READ_WEBPAGE: ${{ github.event_name == 'workflow_dispatch' && inputs.auto_read_webpage || 'false' }}
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
OPENAI_BASE_URL: ${{ github.event_name == 'workflow_dispatch' && inputs.openai_base_url || '' }}
SERPER_DEV_API_KEY: ${{ matrix.dataset != 'math500' && secrets.SERPER_DEV_API_KEY || '' }}
OLOSTEP_API_KEY: ${{ matrix.dataset != 'math500' && secrets.OLOSTEP_API_KEY || ''}}
@@ -137,6 +164,7 @@ jobs:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: postgres
USE_EMBEDDED_DB: "true"
KHOJ_TELEMETRY_DISABLE: "True" # To disable telemetry for tests
run: |
# Start Khoj server in background

View File

@@ -166,7 +166,7 @@ def converse_gemini(
model: Optional[str] = "gemini-2.0-flash",
api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
temperature: float = 0.4,
temperature: float = 1.0,
completion_func=None,
conversation_commands=[ConversationCommand.Default],
max_prompt_size=None,

View File

@@ -78,7 +78,7 @@ def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
reraise=True,
)
def gemini_completion_with_backoff(
messages, system_prompt, model_name, temperature=0.8, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
messages, system_prompt, model_name, temperature=1.0, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
) -> str:
client = gemini_clients.get(api_key)
if not client:

View File

@@ -735,7 +735,7 @@ Create a multi-step plan and intelligently iterate on the plan based on the retr
- Ensure that all required context is passed to the tool AIs for successful execution. They only know the context provided in your query.
- Think step by step to come up with creative strategies when the previous iteration did not yield useful results.
- You are allowed upto {max_iterations} iterations to use the help of the provided tool AIs to answer the user's question.
- Stop when you have the required information by returning a JSON object with an empty "tool" field. E.g., {{scratchpad: "I have all I need", tool: "", query: ""}}
- Stop when you have the required information by returning a JSON object with the "tool" field set to "text" and "query" field empty. E.g., {{"scratchpad": "I have all I need", "tool": "text", "query": ""}}
# Examples
Assuming you can search the user's notes and the internet.
@@ -770,7 +770,7 @@ Which of the tool AIs listed below would you use to answer the user's question?
Return the next tool AI to use and the query to ask it. Your response should always be a valid JSON object. Do not say anything else.
Response format:
{{"scratchpad": "<your_scratchpad_to_reason_about_which_tool_to_use>", "query": "<your_detailed_query_for_the_tool_ai>", "tool": "<name_of_tool_ai>"}}
{{"scratchpad": "<your_scratchpad_to_reason_about_which_tool_to_use>", "tool": "<name_of_tool_ai>", "query": "<your_detailed_query_for_the_tool_ai>"}}
""".strip()
)
@@ -917,7 +917,7 @@ User's Location: {location}
Here are some examples:
Example Chat History:
User: I like to use Hacker News to get my tech news.
Khoj: {{queries: ["what is Hacker News?", "Hacker News website for tech news"]}}
Khoj: {{"queries": ["what is Hacker News?", "Hacker News website for tech news"]}}
AI: Hacker News is an online forum for sharing and discussing the latest tech news. It is a great place to learn about new technologies and startups.
User: Summarize the top posts on HackerNews

View File

@@ -52,12 +52,14 @@ except ImportError:
model_to_prompt_size = {
# OpenAI Models
"gpt-4o": 60000,
"gpt-4o-mini": 60000,
"gpt-4o-mini": 120000,
"o1": 20000,
"o1-mini": 60000,
"o3-mini": 60000,
# Google Models
"gemini-1.5-flash": 60000,
"gemini-2.0-flash": 120000,
"gemini-2.0-flash-lite": 120000,
"gemini-1.5-flash": 120000,
"gemini-1.5-pro": 60000,
# Anthropic Models
"claude-3-5-sonnet-20241022": 60000,

View File

@@ -2,7 +2,6 @@ import asyncio
import json
import logging
import os
import urllib.parse
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
@@ -33,7 +32,7 @@ logger = logging.getLogger(__name__)
GOOGLE_SEARCH_API_KEY = os.getenv("GOOGLE_SEARCH_API_KEY")
GOOGLE_SEARCH_ENGINE_ID = os.getenv("GOOGLE_SEARCH_ENGINE_ID")
SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY")
AUTO_READ_WEBPAGE = is_env_var_true("AUTO_READ_WEBPAGE")
AUTO_READ_WEBPAGE = is_env_var_true("KHOJ_AUTO_READ_WEBPAGE")
SERPER_DEV_URL = "https://google.serper.dev/search"
JINA_SEARCH_API_URL = "https://s.jina.ai/"
@@ -113,7 +112,6 @@ async def search_online(
search_engine = "Searxng"
search_engines.append((search_engine, search_with_searxng))
logger.info(f"🌐 Searching the Internet for {subqueries}")
if send_status_func:
subqueries_str = "\n- " + "\n- ".join(subqueries)
async for event in send_status_func(f"**Searching the Internet for**: {subqueries_str}"):
@@ -121,6 +119,7 @@ async def search_online(
response_dict = {}
for search_engine, search_func in search_engines:
logger.info(f"🌐 Searching the Internet with {search_engine} for {subqueries}")
with timer(f"Internet searches with {search_engine} for {subqueries} took", logger):
try:
search_tasks = [search_func(subquery, location) for subquery in subqueries]

View File

@@ -41,11 +41,9 @@ logger = logging.getLogger(__name__)
class PlanningResponse(BaseModel):
"""
Schema for the response from planning agent when deciding the next tool to pick.
The tool field is dynamically validated based on available tools.
"""
scratchpad: str = Field(..., description="Reasoning about which tool to use next")
query: str = Field(..., description="Detailed query for the selected tool")
scratchpad: str = Field(..., description="Scratchpad to reason about which tool to use next")
class Config:
arbitrary_types_allowed = True
@@ -56,6 +54,9 @@ class PlanningResponse(BaseModel):
Factory method that creates a customized PlanningResponse model
with a properly typed tool field based on available tools.
The tool field is dynamically generated based on available tools.
The query field should be filled by the model after the tool field for a more logical reasoning flow.
Args:
tool_options: Dictionary mapping tool names to values
@@ -68,6 +69,7 @@ class PlanningResponse(BaseModel):
# Create and return a customized response model with the enum
class PlanningResponseWithTool(PlanningResponse):
tool: tool_enum = Field(..., description="Name of the tool to use")
query: str = Field(..., description="Detailed query for the selected tool")
return PlanningResponseWithTool
@@ -97,6 +99,7 @@ async def apick_next_tool(
# Skip showing Notes tool as an option if user has no entries
if tool == ConversationCommand.Notes and not user_has_entries:
continue
# Add tool if agent does not have any tools defined or the tool is supported by the agent.
if len(agent_tools) == 0 or tool.value in agent_tools:
tool_options[tool.name] = tool.value
tool_options_str += f'- "{tool.value}": "{description}"\n'
@@ -168,7 +171,9 @@ async def apick_next_tool(
# Only send client status updates if we'll execute this iteration
elif send_status_func:
determined_tool_message = "**Determined Tool**: "
determined_tool_message += f"{selected_tool}({generated_query})." if selected_tool else "respond."
determined_tool_message += (
f"{selected_tool}({generated_query})." if selected_tool != ConversationCommand.Text else "respond."
)
determined_tool_message += f"\nReason: {scratchpad}" if scratchpad else ""
async for event in send_status_func(f"{scratchpad}"):
yield {ChatEvent.STATUS: event}
@@ -235,8 +240,8 @@ async def execute_information_collection(
if this_iteration.warning:
logger.warning(f"Research mode: {this_iteration.warning}.")
# Terminate research if query, tool not set for next iteration
elif not this_iteration.query or not this_iteration.tool:
# Terminate research if selected text tool or query, tool not set for next iteration
elif not this_iteration.query or not this_iteration.tool or this_iteration.tool == ConversationCommand.Text:
current_iteration = MAX_ITERATIONS
elif this_iteration.tool == ConversationCommand.Notes:

View File

@@ -389,6 +389,7 @@ function_calling_description_for_llm = {
ConversationCommand.Online: "To search the internet for information. Useful to get a quick, broad overview from the internet. Provide all relevant context to ensure new searches, not in previous iterations, are performed.",
ConversationCommand.Webpage: "To extract information from webpages. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share the webpage links and information to extract in your query.",
ConversationCommand.Code: e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description,
ConversationCommand.Text: "To respond to the user once you've completed your research and have the required information.",
}
mode_descriptions_for_llm = {

View File

@@ -37,8 +37,9 @@ KHOJ_API_KEY = os.getenv("KHOJ_API_KEY")
KHOJ_MODE = os.getenv("KHOJ_MODE", "default").lower() # E.g research, general, notes etc.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
GEMINI_EVAL_MODEL = os.getenv("GEMINI_EVAL_MODEL", "gemini-1.5-pro-002")
GEMINI_EVAL_MODEL = os.getenv("GEMINI_EVAL_MODEL", "gemini-2.0-flash-001")
LLM_SEED = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
SAMPLE_SIZE = os.getenv("SAMPLE_SIZE") # Number of examples to evaluate
RANDOMIZE = os.getenv("RANDOMIZE", "false").lower() == "true" # Randomize examples
BATCH_SIZE = int(
@@ -469,7 +470,7 @@ def evaluate_response_with_gemini(
headers={"Content-Type": "application/json"},
json={
"contents": [{"parts": [{"text": evaluation_prompt}]}],
"generationConfig": {"response_mime_type": "application/json"},
"generationConfig": {"response_mime_type": "application/json", "seed": LLM_SEED},
},
)
response.raise_for_status()