diff --git a/.github/workflows/run_evals.yml b/.github/workflows/run_evals.yml index 71123acf..21870c04 100644 --- a/.github/workflows/run_evals.yml +++ b/.github/workflows/run_evals.yml @@ -40,6 +40,11 @@ on: options: - terrarium - e2b + chat_model: + description: 'Chat model to use' + required: false + default: 'gemini-2.0-flash' + type: string jobs: eval: @@ -48,7 +53,7 @@ jobs: matrix: # Use input from manual trigger if available, else run all combinations khoj_mode: ${{ github.event_name == 'workflow_dispatch' && fromJSON(format('["{0}"]', inputs.khoj_mode)) || fromJSON('["general", "default", "research"]') }} - dataset: ${{ github.event_name == 'workflow_dispatch' && fromJSON(format('["{0}"]', inputs.dataset)) || fromJSON('["frames", "simpleqa"]') }} + dataset: ${{ github.event_name == 'workflow_dispatch' && fromJSON(format('["{0}"]', inputs.dataset)) || fromJSON('["frames", "simpleqa", "gpqa"]') }} services: postgres: @@ -103,6 +108,7 @@ jobs: BATCH_SIZE: "20" RANDOMIZE: "True" KHOJ_URL: "http://localhost:42110" + KHOJ_CHAT_MODEL: ${{ github.event_name == 'workflow_dispatch' && inputs.chat_model || 'gemini-2.0-flash' }} KHOJ_LLM_SEED: "42" GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} SERPER_DEV_API_KEY: ${{ matrix.dataset != 'math500' && secrets.SERPER_DEV_API_KEY }} @@ -157,7 +163,7 @@ jobs: echo "## Evaluation Summary of Khoj on ${{ matrix.dataset }} in ${{ matrix.khoj_mode }} mode" >> $GITHUB_STEP_SUMMARY echo "**$(head -n 1 *_evaluation_summary_*.txt)**" >> $GITHUB_STEP_SUMMARY echo "- Khoj Version: ${{ steps.hatch.outputs.version }}" >> $GITHUB_STEP_SUMMARY - echo "- Chat Model: Gemini 2.0 Flash" >> $GITHUB_STEP_SUMMARY + echo "- Chat Model: ${{ github.event_name == 'workflow_dispatch' && inputs.chat_model || 'gemini-2.0-flash' }}" >> $GITHUB_STEP_SUMMARY echo "- Code Sandbox: ${{ inputs.sandbox}}" >> $GITHUB_STEP_SUMMARY echo "\`\`\`" >> $GITHUB_STEP_SUMMARY tail -n +2 *_evaluation_summary_*.txt >> $GITHUB_STEP_SUMMARY diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 
33e879aa..058017d2 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1107,6 +1107,12 @@ class ConversationAdapters: return config.setting return ConversationAdapters.aget_advanced_chat_model(user) + @staticmethod + def get_chat_model_by_name(chat_model_name: str, ai_model_api_name: Optional[str] = None): + if ai_model_api_name: + return ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).first() + return ChatModel.objects.filter(name=chat_model_name).first() + @staticmethod async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]: voice_model_config = await UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").afirst() @@ -1205,6 +1211,15 @@ class ConversationAdapters: return server_chat_settings.chat_advanced return await ConversationAdapters.aget_default_chat_model(user) + @staticmethod + def set_default_chat_model(chat_model: ChatModel): + server_chat_settings = ServerChatSettings.objects.first() + if server_chat_settings: + server_chat_settings.chat_default = chat_model + server_chat_settings.save() + else: + ServerChatSettings.objects.create(chat_default=chat_model) + @staticmethod async def aget_server_webscraper(): server_chat_settings = await ServerChatSettings.objects.filter().prefetch_related("web_scraper").afirst() diff --git a/src/khoj/utils/initialization.py b/src/khoj/utils/initialization.py index 5f4254b5..3ea73891 100644 --- a/src/khoj/utils/initialization.py +++ b/src/khoj/utils/initialization.py @@ -185,16 +185,18 @@ def initialization(interactive: bool = True): ) provider_name = provider_name or model_type.name.capitalize() - default_use_model = {True: "y", False: "n"}[default_api_key is not None] - - # If not in interactive mode & in the offline setting, it's most likely that we're running in a containerized environment. This usually means there's not enough RAM to load offline models directly within the application. 
In such cases, we default to not using the model -- it's recommended to use another service like Ollama to host the model locally in that case. - default_use_model = {True: "n", False: default_use_model}[is_offline] + default_use_model = default_api_key is not None + # If not in interactive mode & in the offline setting, it's most likely that we're running in a containerized environment. + # This usually means there's not enough RAM to load offline models directly within the application. + # In such cases, we default to not using the model -- it's recommended to use another service like Ollama to host the model locally in that case. + if is_offline: + default_use_model = False use_model_provider = ( - default_use_model if not interactive else input(f"Add {provider_name} chat models? (y/n): ") + default_use_model if not interactive else input(f"Add {provider_name} chat models? (y/n): ") == "y" ) - if use_model_provider != "y": + if not use_model_provider: return False, None logger.info(f"️💬 Setting up your {provider_name} chat configuration") @@ -303,4 +305,19 @@ def initialization(interactive: bool = True): logger.error(f"🚨 Failed to create chat configuration: {e}", exc_info=True) else: _update_chat_model_options() - logger.info("🗣️ Chat model configuration updated") + logger.info("🗣️ Chat model options updated") + + # Update the default chat model if it doesn't match + chat_config = ConversationAdapters.get_default_chat_model() + env_default_chat_model = os.getenv("KHOJ_CHAT_MODEL") + if not chat_config or not env_default_chat_model: + return + if chat_config.name != env_default_chat_model: + chat_model = ConversationAdapters.get_chat_model_by_name(env_default_chat_model) + if not chat_model: + logger.error( + f"🚨 Not setting default chat model. Chat model {env_default_chat_model} not found in existing chat model options." 
+ ) + return + ConversationAdapters.set_default_chat_model(chat_model) + logger.info(f"🗣️ Default chat model set to {chat_model.name}") diff --git a/tests/evals/eval.py b/tests/evals/eval.py index 0c95996f..e9d56f03 100644 --- a/tests/evals/eval.py +++ b/tests/evals/eval.py @@ -666,7 +666,7 @@ def main(): colored_accuracy_str = f"Overall Accuracy: {colored_accuracy} on {args.dataset.title()} dataset." accuracy_str = f"Overall Accuracy: {accuracy:.2%} on {args.dataset}." accuracy_by_reasoning = f"Accuracy by Reasoning Type:\n{reasoning_type_accuracy}" - cost = f"Total Cost: ${running_cost.get():.5f}." + cost = f"Total Cost: ${running_cost.get():.5f} to evaluate {running_total_count.get()} results." sample_type = f"Sampling Type: {SAMPLE_SIZE} samples." if SAMPLE_SIZE else "Whole dataset." sample_type += " Randomized." if RANDOMIZE else "" logger.info(f"\n{colored_accuracy_str}\n\n{accuracy_by_reasoning}\n\n{cost}\n\n{sample_type}\n")