mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Set default chat model to KHOJ_CHAT_MODEL env var if set
Simplify code log to set default_use_model during init for readability
This commit is contained in:
10
.github/workflows/run_evals.yml
vendored
10
.github/workflows/run_evals.yml
vendored
@@ -40,6 +40,11 @@ on:
|
|||||||
options:
|
options:
|
||||||
- terrarium
|
- terrarium
|
||||||
- e2b
|
- e2b
|
||||||
|
chat_model:
|
||||||
|
description: 'Chat model to use'
|
||||||
|
required: false
|
||||||
|
default: 'gemini-2.0-flash'
|
||||||
|
type: string
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
eval:
|
eval:
|
||||||
@@ -48,7 +53,7 @@ jobs:
|
|||||||
matrix:
|
matrix:
|
||||||
# Use input from manual trigger if available, else run all combinations
|
# Use input from manual trigger if available, else run all combinations
|
||||||
khoj_mode: ${{ github.event_name == 'workflow_dispatch' && fromJSON(format('["{0}"]', inputs.khoj_mode)) || fromJSON('["general", "default", "research"]') }}
|
khoj_mode: ${{ github.event_name == 'workflow_dispatch' && fromJSON(format('["{0}"]', inputs.khoj_mode)) || fromJSON('["general", "default", "research"]') }}
|
||||||
dataset: ${{ github.event_name == 'workflow_dispatch' && fromJSON(format('["{0}"]', inputs.dataset)) || fromJSON('["frames", "simpleqa"]') }}
|
dataset: ${{ github.event_name == 'workflow_dispatch' && fromJSON(format('["{0}"]', inputs.dataset)) || fromJSON('["frames", "simpleqa", "gpqa"]') }}
|
||||||
|
|
||||||
services:
|
services:
|
||||||
postgres:
|
postgres:
|
||||||
@@ -103,6 +108,7 @@ jobs:
|
|||||||
BATCH_SIZE: "20"
|
BATCH_SIZE: "20"
|
||||||
RANDOMIZE: "True"
|
RANDOMIZE: "True"
|
||||||
KHOJ_URL: "http://localhost:42110"
|
KHOJ_URL: "http://localhost:42110"
|
||||||
|
KHOJ_CHAT_MODEL: ${{ github.event_name == 'workflow_dispatch' && inputs.chat_model || 'gemini-2.0-flash' }}
|
||||||
KHOJ_LLM_SEED: "42"
|
KHOJ_LLM_SEED: "42"
|
||||||
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
|
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
|
||||||
SERPER_DEV_API_KEY: ${{ matrix.dataset != 'math500' && secrets.SERPER_DEV_API_KEY }}
|
SERPER_DEV_API_KEY: ${{ matrix.dataset != 'math500' && secrets.SERPER_DEV_API_KEY }}
|
||||||
@@ -157,7 +163,7 @@ jobs:
|
|||||||
echo "## Evaluation Summary of Khoj on ${{ matrix.dataset }} in ${{ matrix.khoj_mode }} mode" >> $GITHUB_STEP_SUMMARY
|
echo "## Evaluation Summary of Khoj on ${{ matrix.dataset }} in ${{ matrix.khoj_mode }} mode" >> $GITHUB_STEP_SUMMARY
|
||||||
echo "**$(head -n 1 *_evaluation_summary_*.txt)**" >> $GITHUB_STEP_SUMMARY
|
echo "**$(head -n 1 *_evaluation_summary_*.txt)**" >> $GITHUB_STEP_SUMMARY
|
||||||
echo "- Khoj Version: ${{ steps.hatch.outputs.version }}" >> $GITHUB_STEP_SUMMARY
|
echo "- Khoj Version: ${{ steps.hatch.outputs.version }}" >> $GITHUB_STEP_SUMMARY
|
||||||
echo "- Chat Model: Gemini 2.0 Flash" >> $GITHUB_STEP_SUMMARY
|
echo "- Chat Model: ${{ inputs.chat_model }}" >> $GITHUB_STEP_SUMMARY
|
||||||
echo "- Code Sandbox: ${{ inputs.sandbox}}" >> $GITHUB_STEP_SUMMARY
|
echo "- Code Sandbox: ${{ inputs.sandbox}}" >> $GITHUB_STEP_SUMMARY
|
||||||
echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
|
echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
|
||||||
tail -n +2 *_evaluation_summary_*.txt >> $GITHUB_STEP_SUMMARY
|
tail -n +2 *_evaluation_summary_*.txt >> $GITHUB_STEP_SUMMARY
|
||||||
|
|||||||
@@ -1107,6 +1107,12 @@ class ConversationAdapters:
|
|||||||
return config.setting
|
return config.setting
|
||||||
return ConversationAdapters.aget_advanced_chat_model(user)
|
return ConversationAdapters.aget_advanced_chat_model(user)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None):
|
||||||
|
if ai_model_api_name:
|
||||||
|
return ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).first()
|
||||||
|
return ChatModel.objects.filter(name=chat_model_name).first()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
|
async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
|
||||||
voice_model_config = await UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
voice_model_config = await UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
||||||
@@ -1205,6 +1211,15 @@ class ConversationAdapters:
|
|||||||
return server_chat_settings.chat_advanced
|
return server_chat_settings.chat_advanced
|
||||||
return await ConversationAdapters.aget_default_chat_model(user)
|
return await ConversationAdapters.aget_default_chat_model(user)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_default_chat_model(chat_model: ChatModel):
|
||||||
|
server_chat_settings = ServerChatSettings.objects.first()
|
||||||
|
if server_chat_settings:
|
||||||
|
server_chat_settings.chat_default = chat_model
|
||||||
|
server_chat_settings.save()
|
||||||
|
else:
|
||||||
|
ServerChatSettings.objects.create(chat_default=chat_model)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_server_webscraper():
|
async def aget_server_webscraper():
|
||||||
server_chat_settings = await ServerChatSettings.objects.filter().prefetch_related("web_scraper").afirst()
|
server_chat_settings = await ServerChatSettings.objects.filter().prefetch_related("web_scraper").afirst()
|
||||||
|
|||||||
@@ -185,16 +185,18 @@ def initialization(interactive: bool = True):
|
|||||||
)
|
)
|
||||||
provider_name = provider_name or model_type.name.capitalize()
|
provider_name = provider_name or model_type.name.capitalize()
|
||||||
|
|
||||||
default_use_model = {True: "y", False: "n"}[default_api_key is not None]
|
default_use_model = default_api_key is not None
|
||||||
|
# If not in interactive mode & in the offline setting, it's most likely that we're running in a containerized environment.
|
||||||
# If not in interactive mode & in the offline setting, it's most likely that we're running in a containerized environment. This usually means there's not enough RAM to load offline models directly within the application. In such cases, we default to not using the model -- it's recommended to use another service like Ollama to host the model locally in that case.
|
# This usually means there's not enough RAM to load offline models directly within the application.
|
||||||
default_use_model = {True: "n", False: default_use_model}[is_offline]
|
# In such cases, we default to not using the model -- it's recommended to use another service like Ollama to host the model locally in that case.
|
||||||
|
if is_offline:
|
||||||
|
default_use_model = False
|
||||||
|
|
||||||
use_model_provider = (
|
use_model_provider = (
|
||||||
default_use_model if not interactive else input(f"Add {provider_name} chat models? (y/n): ")
|
default_use_model if not interactive else input(f"Add {provider_name} chat models? (y/n): ") == "y"
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_model_provider != "y":
|
if not use_model_provider:
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
logger.info(f"️💬 Setting up your {provider_name} chat configuration")
|
logger.info(f"️💬 Setting up your {provider_name} chat configuration")
|
||||||
@@ -303,4 +305,19 @@ def initialization(interactive: bool = True):
|
|||||||
logger.error(f"🚨 Failed to create chat configuration: {e}", exc_info=True)
|
logger.error(f"🚨 Failed to create chat configuration: {e}", exc_info=True)
|
||||||
else:
|
else:
|
||||||
_update_chat_model_options()
|
_update_chat_model_options()
|
||||||
logger.info("🗣️ Chat model configuration updated")
|
logger.info("🗣️ Chat model options updated")
|
||||||
|
|
||||||
|
# Update the default chat model if it doesn't match
|
||||||
|
chat_config = ConversationAdapters.get_default_chat_model()
|
||||||
|
env_default_chat_model = os.getenv("KHOJ_CHAT_MODEL")
|
||||||
|
if not chat_config or not env_default_chat_model:
|
||||||
|
return
|
||||||
|
if chat_config.name != env_default_chat_model:
|
||||||
|
chat_model = ConversationAdapters.get_chat_model_by_name(env_default_chat_model)
|
||||||
|
if not chat_model:
|
||||||
|
logger.error(
|
||||||
|
f"🚨 Not setting default chat model. Chat model {env_default_chat_model} not found in existing chat model options."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
ConversationAdapters.set_default_chat_model(chat_model)
|
||||||
|
logger.info(f"🗣️ Default chat model set to {chat_model.name}")
|
||||||
|
|||||||
@@ -666,7 +666,7 @@ def main():
|
|||||||
colored_accuracy_str = f"Overall Accuracy: {colored_accuracy} on {args.dataset.title()} dataset."
|
colored_accuracy_str = f"Overall Accuracy: {colored_accuracy} on {args.dataset.title()} dataset."
|
||||||
accuracy_str = f"Overall Accuracy: {accuracy:.2%} on {args.dataset}."
|
accuracy_str = f"Overall Accuracy: {accuracy:.2%} on {args.dataset}."
|
||||||
accuracy_by_reasoning = f"Accuracy by Reasoning Type:\n{reasoning_type_accuracy}"
|
accuracy_by_reasoning = f"Accuracy by Reasoning Type:\n{reasoning_type_accuracy}"
|
||||||
cost = f"Total Cost: ${running_cost.get():.5f}."
|
cost = f"Total Cost: ${running_cost.get():.5f} to evaluate {running_total_count.get()} results."
|
||||||
sample_type = f"Sampling Type: {SAMPLE_SIZE} samples." if SAMPLE_SIZE else "Whole dataset."
|
sample_type = f"Sampling Type: {SAMPLE_SIZE} samples." if SAMPLE_SIZE else "Whole dataset."
|
||||||
sample_type += " Randomized." if RANDOMIZE else ""
|
sample_type += " Randomized." if RANDOMIZE else ""
|
||||||
logger.info(f"\n{colored_accuracy_str}\n\n{accuracy_by_reasoning}\n\n{cost}\n\n{sample_type}\n")
|
logger.info(f"\n{colored_accuracy_str}\n\n{accuracy_by_reasoning}\n\n{cost}\n\n{sample_type}\n")
|
||||||
|
|||||||
Reference in New Issue
Block a user