From 99f16df7e2e7e6d5131b171ddf3394b3cd58e5c3 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 26 Nov 2025 16:56:35 -0800 Subject: [PATCH] Use fast model in default mode and for most chat actors What -- - Default to using fast model for most chat actors. Specifically in this change we default to using fast model for doc, web search chat actors - Only research chat director uses the deep chat model. - Make using fast model by chat actors configurable via func argument Code chat actor continues to use deep chat model and webpage reader continues to use fast chat model. Deep, fast chat models can be configured via ServerChatSettings on the admin panel. Why -- Modern models are good enough at instruction following. So defaulting most chat actors to use the fast model should improve chat speed with acceptable response quality. The option to fall back to research mode for higher quality responses or deeper research always exists. --- src/khoj/processor/image/generate.py | 4 +-- src/khoj/processor/tools/online_search.py | 8 +++-- src/khoj/processor/tools/run_code.py | 2 +- src/khoj/routers/api_chat.py | 16 +++++----- src/khoj/routers/helpers.py | 37 +++++++++++++---------- src/khoj/routers/research.py | 11 +++---- 6 files changed, 43 insertions(+), 35 deletions(-) diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py index 4ab1f241..6938cdf5 100644 --- a/src/khoj/processor/image/generate.py +++ b/src/khoj/processor/image/generate.py @@ -46,8 +46,8 @@ async def text_to_image( online_results: Dict[str, Any], send_status_func: Optional[Callable] = None, query_images: Optional[List[str]] = None, - agent: Agent = None, query_files: str = None, + agent: Agent = None, tracer: dict = {}, ): status_code = 200 @@ -90,9 +90,9 @@ async def text_to_image( online_results=online_results, model_type=text_to_image_config.model_type, query_images=query_images, + query_files=query_files, user=user, agent=agent, - query_files=query_files, tracer=tracer, )
image_prompt = image_prompt_response["description"] diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 01623efa..beeb640c 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -64,6 +64,7 @@ async def search_online( query_images: List[str] = None, query_files: str = None, previous_subqueries: Set = set(), + fast_model: bool = True, agent: Agent = None, tracer: dict = {}, ): @@ -82,6 +83,7 @@ async def search_online( query_images=query_images, query_files=query_files, max_queries=max_online_searches, + fast_model=fast_model, agent=agent, tracer=tracer, ) @@ -432,9 +434,10 @@ async def read_webpages( user: KhojUser, send_status_func: Optional[Callable] = None, query_images: List[str] = None, + query_files: str = None, + fast_model: bool = True, agent: Agent = None, max_webpages_to_read: int = 1, - query_files: str = None, tracer: dict = {}, ): "Infer web pages to read from the query and extract relevant information from them" @@ -446,8 +449,9 @@ async def read_webpages( location, user, query_images, - agent=agent, query_files=query_files, + fast_model=fast_model, + agent=agent, tracer=tracer, ) async for result in read_webpages_content( diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index c62c710e..15e77d9c 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -56,9 +56,9 @@ async def run_code( user: KhojUser, send_status_func: Optional[Callable] = None, query_images: List[str] = None, + query_files: str = None, agent: Agent = None, sandbox_url: str = SANDBOX_URL, - query_files: str = None, tracer: dict = {}, ): # Generate Code diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 68e7ea43..d622936f 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1023,14 +1023,14 @@ async def event_generator( 
conversation_history=chat_history, previous_iterations=list(research_results), query_images=uploaded_images, - agent=agent, - send_status_func=partial(send_event, ChatEvent.STATUS), + query_files=attached_file_context, user_name=user_name, location=location, - query_files=attached_file_context, + send_status_func=partial(send_event, ChatEvent.STATUS), cancellation_event=cancellation_event, interrupt_queue=child_interrupt_queue, abort_message=ChatEvent.END_EVENT.value, + agent=agent, tracer=tracer, ): if isinstance(research_result, ResearchIteration): @@ -1080,8 +1080,8 @@ async def event_generator( location, partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, - agent=agent, query_files=attached_file_context, + agent=agent, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -1156,8 +1156,8 @@ async def event_generator( partial(send_event, ChatEvent.STATUS), max_webpages_to_read=1, query_images=uploaded_images, - agent=agent, query_files=attached_file_context, + agent=agent, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -1197,8 +1197,8 @@ async def event_generator( user, partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, - agent=agent, query_files=attached_file_context, + agent=agent, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -1275,8 +1275,8 @@ async def event_generator( online_results=online_results, send_status_func=partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, - agent=agent, query_files=attached_file_context, + agent=agent, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -1316,10 +1316,10 @@ async def event_generator( note_references=compiled_references, online_results=online_results, query_images=uploaded_images, + query_files=attached_file_context, user=user, agent=agent, send_status_func=partial(send_event, ChatEvent.STATUS), - query_files=attached_file_context, 
tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index d93e8a49..8e64c1ea 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -415,7 +415,7 @@ async def aget_data_sources_and_output_format( query_images=query_images, response_type="json_object", response_schema=PickTools, - fast_model=False, + fast_model=True, agent_chat_model=agent_chat_model, user=user, tracer=tracer, @@ -467,8 +467,9 @@ async def infer_webpage_urls( location_data: LocationData, user: KhojUser, query_images: List[str] = None, - agent: Agent = None, query_files: str = None, + fast_model: bool = True, + agent: Agent = None, tracer: dict = {}, ) -> List[str]: """ @@ -505,7 +506,7 @@ async def infer_webpage_urls( query_images=query_images, response_type="json_object", response_schema=WebpageUrls, - fast_model=False, + fast_model=fast_model, agent_chat_model=agent_chat_model, user=user, tracer=tracer, @@ -534,6 +535,7 @@ async def generate_online_subqueries( query_images: List[str] = None, query_files: str = None, max_queries: int = 3, + fast_model: bool = True, agent: Agent = None, tracer: dict = {}, ) -> Set[str]: @@ -571,7 +573,7 @@ async def generate_online_subqueries( query_images=query_images, response_type="json_object", response_schema=OnlineQueries, - fast_model=False, + fast_model=fast_model, agent_chat_model=agent_chat_model, user=user, tracer=tracer, @@ -737,9 +739,9 @@ async def generate_summary_from_files( file_filters: List[str], chat_history: List[ChatMessageModel] = [], query_images: List[str] = None, + query_files: str = None, agent: Agent = None, send_status_func: Optional[Callable] = None, - query_files: str = None, tracer: dict = {}, ): try: @@ -797,10 +799,10 @@ async def generate_excalidraw_diagram( note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, query_images: List[str] = None, + query_files: str = None, user: KhojUser = 
None, agent: Agent = None, send_status_func: Optional[Callable] = None, - query_files: str = None, tracer: dict = {}, ): if send_status_func: @@ -814,9 +816,9 @@ async def generate_excalidraw_diagram( note_references=note_references, online_results=online_results, query_images=query_images, + query_files=query_files, user=user, agent=agent, - query_files=query_files, tracer=tracer, ) @@ -849,9 +851,9 @@ async def generate_better_diagram_description( note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, query_images: List[str] = None, + query_files: str = None, user: KhojUser = None, agent: Agent = None, - query_files: str = None, tracer: dict = {}, ) -> str: """ @@ -959,10 +961,10 @@ async def generate_mermaidjs_diagram( note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, query_images: List[str] = None, + query_files: str = None, user: KhojUser = None, agent: Agent = None, send_status_func: Optional[Callable] = None, - query_files: str = None, tracer: dict = {}, ): if send_status_func: @@ -976,9 +978,9 @@ async def generate_mermaidjs_diagram( note_references=note_references, online_results=online_results, query_images=query_images, + query_files=query_files, user=user, agent=agent, - query_files=query_files, tracer=tracer, ) @@ -1005,9 +1007,9 @@ async def generate_better_mermaidjs_diagram_description( note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, query_images: List[str] = None, + query_files: str = None, user: KhojUser = None, agent: Agent = None, - query_files: str = None, tracer: dict = {}, ) -> str: """ @@ -1099,9 +1101,9 @@ async def generate_better_image_prompt( online_results: Optional[dict] = None, model_type: Optional[str] = None, query_images: Optional[List[str]] = None, + query_files: str = "", user: KhojUser = None, agent: Agent = None, - query_files: str = "", tracer: dict = {}, ) -> dict: """ @@ -1148,7 +1150,7 @@ async def generate_better_image_prompt( 
system_message=enhance_image_system_message, response_type="json_object", response_schema=ImagePromptResponse, - fast_model=False, + fast_model=True, agent_chat_model=agent_chat_model, user=user, tracer=tracer, @@ -1175,9 +1177,10 @@ async def search_documents( location_data: LocationData = None, send_status_func: Optional[Callable] = None, query_images: Optional[List[str]] = None, - previous_inferred_queries: Set = set(), - agent: Agent = None, query_files: str = None, + previous_inferred_queries: Set = set(), + fast_model: bool = True, + agent: Agent = None, tracer: dict = {}, ): # Initialize Variables @@ -1228,6 +1231,7 @@ async def search_documents( personality_context=personality_context, location_data=location_data, chat_history=chat_history, + fast_model=fast_model, agent=agent, tracer=tracer, ) @@ -1280,6 +1284,7 @@ async def extract_questions( location_data: LocationData = None, chat_history: List[ChatMessageModel] = [], max_queries: int = 5, + fast_model: bool = True, agent: Agent = None, tracer: dict = {}, ): @@ -1334,7 +1339,7 @@ async def extract_questions( system_message=system_prompt, response_type="json_object", response_schema=DocumentQueries, - fast_model=False, + fast_model=fast_model, agent_chat_model=agent_chat_model, user=user, tracer=tracer, diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 63e6b71d..ba053711 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -243,14 +243,14 @@ async def research( conversation_history: List[ChatMessageModel], previous_iterations: List[ResearchIteration], query_images: List[str], - agent: Agent = None, - send_status_func: Optional[Callable] = None, + query_files: str = None, user_name: str = None, location: LocationData = None, - query_files: str = None, + send_status_func: Optional[Callable] = None, cancellation_event: Optional[asyncio.Event] = None, interrupt_queue: Optional[asyncio.Queue] = None, abort_message: str = ChatEvent.END_EVENT.value, + 
agent: Agent = None, tracer: dict = {}, ): max_document_searches = 7 @@ -356,10 +356,10 @@ async def research( location_data=location, send_status_func=send_status_func, query_images=query_images, + query_files=query_files, previous_inferred_queries=previous_inferred_queries, agent=agent, tracer=tracer, - query_files=query_files, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -424,7 +424,6 @@ async def research( **this_iteration.query.args, user=user, send_status_func=send_status_func, - # max_webpages_to_read=max_webpages_to_read, agent=agent, tracer=tracer, ): @@ -459,8 +458,8 @@ async def research( user=user, send_status_func=send_status_func, query_images=query_images, - agent=agent, query_files=query_files, + agent=agent, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: