mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 05:29:12 +00:00
Use fast model in default mode and for most chat actors
What -- - Default to using fast model for most chat actors. Specifically in this change we default to using fast model for doc, web search chat actors - Only research chat director uses the deep chat model. - Make using fast model by chat actors configurable via func argument Code chat actor continues to use deep chat model and webpage reader continues to use fast chat model. Deep, fast chat models can be configured via ServerChatSettings on the admin panel. Why -- Modern models are good enough at instruction following. So defaulting most chat actor to use the fast model should improve chat speed with acceptable response quality. The option to fallback to research mode for higher quality responses or deeper research always exists.
This commit is contained in:
@@ -46,8 +46,8 @@ async def text_to_image(
|
||||
online_results: Dict[str, Any],
|
||||
send_status_func: Optional[Callable] = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
agent: Agent = None,
|
||||
query_files: str = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
status_code = 200
|
||||
@@ -90,9 +90,9 @@ async def text_to_image(
|
||||
online_results=online_results,
|
||||
model_type=text_to_image_config.model_type,
|
||||
query_images=query_images,
|
||||
query_files=query_files,
|
||||
user=user,
|
||||
agent=agent,
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
image_prompt = image_prompt_response["description"]
|
||||
|
||||
@@ -64,6 +64,7 @@ async def search_online(
|
||||
query_images: List[str] = None,
|
||||
query_files: str = None,
|
||||
previous_subqueries: Set = set(),
|
||||
fast_model: bool = True,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
@@ -82,6 +83,7 @@ async def search_online(
|
||||
query_images=query_images,
|
||||
query_files=query_files,
|
||||
max_queries=max_online_searches,
|
||||
fast_model=fast_model,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
@@ -432,9 +434,10 @@ async def read_webpages(
|
||||
user: KhojUser,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
query_images: List[str] = None,
|
||||
query_files: str = None,
|
||||
fast_model: bool = True,
|
||||
agent: Agent = None,
|
||||
max_webpages_to_read: int = 1,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"Infer web pages to read from the query and extract relevant information from them"
|
||||
@@ -446,8 +449,9 @@ async def read_webpages(
|
||||
location,
|
||||
user,
|
||||
query_images,
|
||||
agent=agent,
|
||||
query_files=query_files,
|
||||
fast_model=fast_model,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
async for result in read_webpages_content(
|
||||
|
||||
@@ -56,9 +56,9 @@ async def run_code(
|
||||
user: KhojUser,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
query_images: List[str] = None,
|
||||
query_files: str = None,
|
||||
agent: Agent = None,
|
||||
sandbox_url: str = SANDBOX_URL,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
# Generate Code
|
||||
|
||||
@@ -1023,14 +1023,14 @@ async def event_generator(
|
||||
conversation_history=chat_history,
|
||||
previous_iterations=list(research_results),
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
query_files=attached_file_context,
|
||||
user_name=user_name,
|
||||
location=location,
|
||||
query_files=attached_file_context,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
cancellation_event=cancellation_event,
|
||||
interrupt_queue=child_interrupt_queue,
|
||||
abort_message=ChatEvent.END_EVENT.value,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(research_result, ResearchIteration):
|
||||
@@ -1080,8 +1080,8 @@ async def event_generator(
|
||||
location,
|
||||
partial(send_event, ChatEvent.STATUS),
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
query_files=attached_file_context,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
@@ -1156,8 +1156,8 @@ async def event_generator(
|
||||
partial(send_event, ChatEvent.STATUS),
|
||||
max_webpages_to_read=1,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
query_files=attached_file_context,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
@@ -1197,8 +1197,8 @@ async def event_generator(
|
||||
user,
|
||||
partial(send_event, ChatEvent.STATUS),
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
query_files=attached_file_context,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
@@ -1275,8 +1275,8 @@ async def event_generator(
|
||||
online_results=online_results,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
query_files=attached_file_context,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
@@ -1316,10 +1316,10 @@ async def event_generator(
|
||||
note_references=compiled_references,
|
||||
online_results=online_results,
|
||||
query_images=uploaded_images,
|
||||
query_files=attached_file_context,
|
||||
user=user,
|
||||
agent=agent,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
query_files=attached_file_context,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
|
||||
@@ -415,7 +415,7 @@ async def aget_data_sources_and_output_format(
|
||||
query_images=query_images,
|
||||
response_type="json_object",
|
||||
response_schema=PickTools,
|
||||
fast_model=False,
|
||||
fast_model=True,
|
||||
agent_chat_model=agent_chat_model,
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
@@ -467,8 +467,9 @@ async def infer_webpage_urls(
|
||||
location_data: LocationData,
|
||||
user: KhojUser,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
query_files: str = None,
|
||||
fast_model: bool = True,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> List[str]:
|
||||
"""
|
||||
@@ -505,7 +506,7 @@ async def infer_webpage_urls(
|
||||
query_images=query_images,
|
||||
response_type="json_object",
|
||||
response_schema=WebpageUrls,
|
||||
fast_model=False,
|
||||
fast_model=fast_model,
|
||||
agent_chat_model=agent_chat_model,
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
@@ -534,6 +535,7 @@ async def generate_online_subqueries(
|
||||
query_images: List[str] = None,
|
||||
query_files: str = None,
|
||||
max_queries: int = 3,
|
||||
fast_model: bool = True,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> Set[str]:
|
||||
@@ -571,7 +573,7 @@ async def generate_online_subqueries(
|
||||
query_images=query_images,
|
||||
response_type="json_object",
|
||||
response_schema=OnlineQueries,
|
||||
fast_model=False,
|
||||
fast_model=fast_model,
|
||||
agent_chat_model=agent_chat_model,
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
@@ -737,9 +739,9 @@ async def generate_summary_from_files(
|
||||
file_filters: List[str],
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
query_images: List[str] = None,
|
||||
query_files: str = None,
|
||||
agent: Agent = None,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
try:
|
||||
@@ -797,10 +799,10 @@ async def generate_excalidraw_diagram(
|
||||
note_references: List[Dict[str, Any]],
|
||||
online_results: Optional[dict] = None,
|
||||
query_images: List[str] = None,
|
||||
query_files: str = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
if send_status_func:
|
||||
@@ -814,9 +816,9 @@ async def generate_excalidraw_diagram(
|
||||
note_references=note_references,
|
||||
online_results=online_results,
|
||||
query_images=query_images,
|
||||
query_files=query_files,
|
||||
user=user,
|
||||
agent=agent,
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
@@ -849,9 +851,9 @@ async def generate_better_diagram_description(
|
||||
note_references: List[Dict[str, Any]],
|
||||
online_results: Optional[dict] = None,
|
||||
query_images: List[str] = None,
|
||||
query_files: str = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
) -> str:
|
||||
"""
|
||||
@@ -959,10 +961,10 @@ async def generate_mermaidjs_diagram(
|
||||
note_references: List[Dict[str, Any]],
|
||||
online_results: Optional[dict] = None,
|
||||
query_images: List[str] = None,
|
||||
query_files: str = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
if send_status_func:
|
||||
@@ -976,9 +978,9 @@ async def generate_mermaidjs_diagram(
|
||||
note_references=note_references,
|
||||
online_results=online_results,
|
||||
query_images=query_images,
|
||||
query_files=query_files,
|
||||
user=user,
|
||||
agent=agent,
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
@@ -1005,9 +1007,9 @@ async def generate_better_mermaidjs_diagram_description(
|
||||
note_references: List[Dict[str, Any]],
|
||||
online_results: Optional[dict] = None,
|
||||
query_images: List[str] = None,
|
||||
query_files: str = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
) -> str:
|
||||
"""
|
||||
@@ -1099,9 +1101,9 @@ async def generate_better_image_prompt(
|
||||
online_results: Optional[dict] = None,
|
||||
model_type: Optional[str] = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
query_files: str = "",
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
query_files: str = "",
|
||||
tracer: dict = {},
|
||||
) -> dict:
|
||||
"""
|
||||
@@ -1148,7 +1150,7 @@ async def generate_better_image_prompt(
|
||||
system_message=enhance_image_system_message,
|
||||
response_type="json_object",
|
||||
response_schema=ImagePromptResponse,
|
||||
fast_model=False,
|
||||
fast_model=True,
|
||||
agent_chat_model=agent_chat_model,
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
@@ -1175,9 +1177,10 @@ async def search_documents(
|
||||
location_data: LocationData = None,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
previous_inferred_queries: Set = set(),
|
||||
agent: Agent = None,
|
||||
query_files: str = None,
|
||||
previous_inferred_queries: Set = set(),
|
||||
fast_model: bool = True,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
# Initialize Variables
|
||||
@@ -1228,6 +1231,7 @@ async def search_documents(
|
||||
personality_context=personality_context,
|
||||
location_data=location_data,
|
||||
chat_history=chat_history,
|
||||
fast_model=fast_model,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
@@ -1280,6 +1284,7 @@ async def extract_questions(
|
||||
location_data: LocationData = None,
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
max_queries: int = 5,
|
||||
fast_model: bool = True,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
@@ -1334,7 +1339,7 @@ async def extract_questions(
|
||||
system_message=system_prompt,
|
||||
response_type="json_object",
|
||||
response_schema=DocumentQueries,
|
||||
fast_model=False,
|
||||
fast_model=fast_model,
|
||||
agent_chat_model=agent_chat_model,
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
|
||||
@@ -243,14 +243,14 @@ async def research(
|
||||
conversation_history: List[ChatMessageModel],
|
||||
previous_iterations: List[ResearchIteration],
|
||||
query_images: List[str],
|
||||
agent: Agent = None,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
query_files: str = None,
|
||||
user_name: str = None,
|
||||
location: LocationData = None,
|
||||
query_files: str = None,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
cancellation_event: Optional[asyncio.Event] = None,
|
||||
interrupt_queue: Optional[asyncio.Queue] = None,
|
||||
abort_message: str = ChatEvent.END_EVENT.value,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
max_document_searches = 7
|
||||
@@ -356,10 +356,10 @@ async def research(
|
||||
location_data=location,
|
||||
send_status_func=send_status_func,
|
||||
query_images=query_images,
|
||||
query_files=query_files,
|
||||
previous_inferred_queries=previous_inferred_queries,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
query_files=query_files,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
@@ -424,7 +424,6 @@ async def research(
|
||||
**this_iteration.query.args,
|
||||
user=user,
|
||||
send_status_func=send_status_func,
|
||||
# max_webpages_to_read=max_webpages_to_read,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
@@ -459,8 +458,8 @@ async def research(
|
||||
user=user,
|
||||
send_status_func=send_status_func,
|
||||
query_images=query_images,
|
||||
agent=agent,
|
||||
query_files=query_files,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
|
||||
Reference in New Issue
Block a user