Use fast model in default mode and for most chat actors

What
--
- Default to using fast model for most chat actors. Specifically in this
  change we default to using fast model for doc, web search chat actors
- Only research chat director uses the deep chat model.
- Make chat actors' use of the fast model configurable via a function argument

Code chat actor continues to use deep chat model and webpage reader
continues to use fast chat model.

Deep, fast chat models can be configured via ServerChatSettings on the
admin panel.

Why
--
Modern models are good enough at instruction following. So defaulting
most chat actors to use the fast model should improve chat speed with
acceptable response quality.

The option to fall back to research mode for higher quality
responses or deeper research always exists.
This commit is contained in:
Debanjum
2025-11-26 16:56:35 -08:00
parent da493be417
commit 99f16df7e2
6 changed files with 43 additions and 35 deletions

View File

@@ -46,8 +46,8 @@ async def text_to_image(
online_results: Dict[str, Any],
send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None,
agent: Agent = None,
query_files: str = None,
agent: Agent = None,
tracer: dict = {},
):
status_code = 200
@@ -90,9 +90,9 @@ async def text_to_image(
online_results=online_results,
model_type=text_to_image_config.model_type,
query_images=query_images,
query_files=query_files,
user=user,
agent=agent,
query_files=query_files,
tracer=tracer,
)
image_prompt = image_prompt_response["description"]

View File

@@ -64,6 +64,7 @@ async def search_online(
query_images: List[str] = None,
query_files: str = None,
previous_subqueries: Set = set(),
fast_model: bool = True,
agent: Agent = None,
tracer: dict = {},
):
@@ -82,6 +83,7 @@ async def search_online(
query_images=query_images,
query_files=query_files,
max_queries=max_online_searches,
fast_model=fast_model,
agent=agent,
tracer=tracer,
)
@@ -432,9 +434,10 @@ async def read_webpages(
user: KhojUser,
send_status_func: Optional[Callable] = None,
query_images: List[str] = None,
query_files: str = None,
fast_model: bool = True,
agent: Agent = None,
max_webpages_to_read: int = 1,
query_files: str = None,
tracer: dict = {},
):
"Infer web pages to read from the query and extract relevant information from them"
@@ -446,8 +449,9 @@ async def read_webpages(
location,
user,
query_images,
agent=agent,
query_files=query_files,
fast_model=fast_model,
agent=agent,
tracer=tracer,
)
async for result in read_webpages_content(

View File

@@ -56,9 +56,9 @@ async def run_code(
user: KhojUser,
send_status_func: Optional[Callable] = None,
query_images: List[str] = None,
query_files: str = None,
agent: Agent = None,
sandbox_url: str = SANDBOX_URL,
query_files: str = None,
tracer: dict = {},
):
# Generate Code

View File

@@ -1023,14 +1023,14 @@ async def event_generator(
conversation_history=chat_history,
previous_iterations=list(research_results),
query_images=uploaded_images,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
query_files=attached_file_context,
user_name=user_name,
location=location,
query_files=attached_file_context,
send_status_func=partial(send_event, ChatEvent.STATUS),
cancellation_event=cancellation_event,
interrupt_queue=child_interrupt_queue,
abort_message=ChatEvent.END_EVENT.value,
agent=agent,
tracer=tracer,
):
if isinstance(research_result, ResearchIteration):
@@ -1080,8 +1080,8 @@ async def event_generator(
location,
partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
query_files=attached_file_context,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -1156,8 +1156,8 @@ async def event_generator(
partial(send_event, ChatEvent.STATUS),
max_webpages_to_read=1,
query_images=uploaded_images,
agent=agent,
query_files=attached_file_context,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -1197,8 +1197,8 @@ async def event_generator(
user,
partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
query_files=attached_file_context,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -1275,8 +1275,8 @@ async def event_generator(
online_results=online_results,
send_status_func=partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
query_files=attached_file_context,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -1316,10 +1316,10 @@ async def event_generator(
note_references=compiled_references,
online_results=online_results,
query_images=uploaded_images,
query_files=attached_file_context,
user=user,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
query_files=attached_file_context,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:

View File

@@ -415,7 +415,7 @@ async def aget_data_sources_and_output_format(
query_images=query_images,
response_type="json_object",
response_schema=PickTools,
fast_model=False,
fast_model=True,
agent_chat_model=agent_chat_model,
user=user,
tracer=tracer,
@@ -467,8 +467,9 @@ async def infer_webpage_urls(
location_data: LocationData,
user: KhojUser,
query_images: List[str] = None,
agent: Agent = None,
query_files: str = None,
fast_model: bool = True,
agent: Agent = None,
tracer: dict = {},
) -> List[str]:
"""
@@ -505,7 +506,7 @@ async def infer_webpage_urls(
query_images=query_images,
response_type="json_object",
response_schema=WebpageUrls,
fast_model=False,
fast_model=fast_model,
agent_chat_model=agent_chat_model,
user=user,
tracer=tracer,
@@ -534,6 +535,7 @@ async def generate_online_subqueries(
query_images: List[str] = None,
query_files: str = None,
max_queries: int = 3,
fast_model: bool = True,
agent: Agent = None,
tracer: dict = {},
) -> Set[str]:
@@ -571,7 +573,7 @@ async def generate_online_subqueries(
query_images=query_images,
response_type="json_object",
response_schema=OnlineQueries,
fast_model=False,
fast_model=fast_model,
agent_chat_model=agent_chat_model,
user=user,
tracer=tracer,
@@ -737,9 +739,9 @@ async def generate_summary_from_files(
file_filters: List[str],
chat_history: List[ChatMessageModel] = [],
query_images: List[str] = None,
query_files: str = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
query_files: str = None,
tracer: dict = {},
):
try:
@@ -797,10 +799,10 @@ async def generate_excalidraw_diagram(
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
query_images: List[str] = None,
query_files: str = None,
user: KhojUser = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
query_files: str = None,
tracer: dict = {},
):
if send_status_func:
@@ -814,9 +816,9 @@ async def generate_excalidraw_diagram(
note_references=note_references,
online_results=online_results,
query_images=query_images,
query_files=query_files,
user=user,
agent=agent,
query_files=query_files,
tracer=tracer,
)
@@ -849,9 +851,9 @@ async def generate_better_diagram_description(
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
query_images: List[str] = None,
query_files: str = None,
user: KhojUser = None,
agent: Agent = None,
query_files: str = None,
tracer: dict = {},
) -> str:
"""
@@ -959,10 +961,10 @@ async def generate_mermaidjs_diagram(
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
query_images: List[str] = None,
query_files: str = None,
user: KhojUser = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
query_files: str = None,
tracer: dict = {},
):
if send_status_func:
@@ -976,9 +978,9 @@ async def generate_mermaidjs_diagram(
note_references=note_references,
online_results=online_results,
query_images=query_images,
query_files=query_files,
user=user,
agent=agent,
query_files=query_files,
tracer=tracer,
)
@@ -1005,9 +1007,9 @@ async def generate_better_mermaidjs_diagram_description(
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
query_images: List[str] = None,
query_files: str = None,
user: KhojUser = None,
agent: Agent = None,
query_files: str = None,
tracer: dict = {},
) -> str:
"""
@@ -1099,9 +1101,9 @@ async def generate_better_image_prompt(
online_results: Optional[dict] = None,
model_type: Optional[str] = None,
query_images: Optional[List[str]] = None,
query_files: str = "",
user: KhojUser = None,
agent: Agent = None,
query_files: str = "",
tracer: dict = {},
) -> dict:
"""
@@ -1148,7 +1150,7 @@ async def generate_better_image_prompt(
system_message=enhance_image_system_message,
response_type="json_object",
response_schema=ImagePromptResponse,
fast_model=False,
fast_model=True,
agent_chat_model=agent_chat_model,
user=user,
tracer=tracer,
@@ -1175,9 +1177,10 @@ async def search_documents(
location_data: LocationData = None,
send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None,
previous_inferred_queries: Set = set(),
agent: Agent = None,
query_files: str = None,
previous_inferred_queries: Set = set(),
fast_model: bool = True,
agent: Agent = None,
tracer: dict = {},
):
# Initialize Variables
@@ -1228,6 +1231,7 @@ async def search_documents(
personality_context=personality_context,
location_data=location_data,
chat_history=chat_history,
fast_model=fast_model,
agent=agent,
tracer=tracer,
)
@@ -1280,6 +1284,7 @@ async def extract_questions(
location_data: LocationData = None,
chat_history: List[ChatMessageModel] = [],
max_queries: int = 5,
fast_model: bool = True,
agent: Agent = None,
tracer: dict = {},
):
@@ -1334,7 +1339,7 @@ async def extract_questions(
system_message=system_prompt,
response_type="json_object",
response_schema=DocumentQueries,
fast_model=False,
fast_model=fast_model,
agent_chat_model=agent_chat_model,
user=user,
tracer=tracer,

View File

@@ -243,14 +243,14 @@ async def research(
conversation_history: List[ChatMessageModel],
previous_iterations: List[ResearchIteration],
query_images: List[str],
agent: Agent = None,
send_status_func: Optional[Callable] = None,
query_files: str = None,
user_name: str = None,
location: LocationData = None,
query_files: str = None,
send_status_func: Optional[Callable] = None,
cancellation_event: Optional[asyncio.Event] = None,
interrupt_queue: Optional[asyncio.Queue] = None,
abort_message: str = ChatEvent.END_EVENT.value,
agent: Agent = None,
tracer: dict = {},
):
max_document_searches = 7
@@ -356,10 +356,10 @@ async def research(
location_data=location,
send_status_func=send_status_func,
query_images=query_images,
query_files=query_files,
previous_inferred_queries=previous_inferred_queries,
agent=agent,
tracer=tracer,
query_files=query_files,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -424,7 +424,6 @@ async def research(
**this_iteration.query.args,
user=user,
send_status_func=send_status_func,
# max_webpages_to_read=max_webpages_to_read,
agent=agent,
tracer=tracer,
):
@@ -459,8 +458,8 @@ async def research(
user=user,
send_status_func=send_status_func,
query_images=query_images,
agent=agent,
query_files=query_files,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result: