diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py index 4ab1f241..6938cdf5 100644 --- a/src/khoj/processor/image/generate.py +++ b/src/khoj/processor/image/generate.py @@ -46,8 +46,8 @@ async def text_to_image( online_results: Dict[str, Any], send_status_func: Optional[Callable] = None, query_images: Optional[List[str]] = None, - agent: Agent = None, query_files: str = None, + agent: Agent = None, tracer: dict = {}, ): status_code = 200 @@ -90,9 +90,9 @@ async def text_to_image( online_results=online_results, model_type=text_to_image_config.model_type, query_images=query_images, + query_files=query_files, user=user, agent=agent, - query_files=query_files, tracer=tracer, ) image_prompt = image_prompt_response["description"] diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 01623efa..beeb640c 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -64,6 +64,7 @@ async def search_online( query_images: List[str] = None, query_files: str = None, previous_subqueries: Set = set(), + fast_model: bool = True, agent: Agent = None, tracer: dict = {}, ): @@ -82,6 +83,7 @@ async def search_online( query_images=query_images, query_files=query_files, max_queries=max_online_searches, + fast_model=fast_model, agent=agent, tracer=tracer, ) @@ -432,9 +434,10 @@ async def read_webpages( user: KhojUser, send_status_func: Optional[Callable] = None, query_images: List[str] = None, + query_files: str = None, + fast_model: bool = True, agent: Agent = None, max_webpages_to_read: int = 1, - query_files: str = None, tracer: dict = {}, ): "Infer web pages to read from the query and extract relevant information from them" @@ -446,8 +449,9 @@ async def read_webpages( location, user, query_images, - agent=agent, query_files=query_files, + fast_model=fast_model, + agent=agent, tracer=tracer, ) async for result in read_webpages_content( diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index c62c710e..15e77d9c 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -56,9 +56,9 @@ async def run_code( user: KhojUser, send_status_func: Optional[Callable] = None, query_images: List[str] = None, + query_files: str = None, agent: Agent = None, sandbox_url: str = SANDBOX_URL, - query_files: str = None, tracer: dict = {}, ): # Generate Code diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 68e7ea43..d622936f 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1023,14 +1023,14 @@ async def event_generator( conversation_history=chat_history, previous_iterations=list(research_results), query_images=uploaded_images, - agent=agent, - send_status_func=partial(send_event, ChatEvent.STATUS), + query_files=attached_file_context, user_name=user_name, location=location, - query_files=attached_file_context, + send_status_func=partial(send_event, ChatEvent.STATUS), cancellation_event=cancellation_event, interrupt_queue=child_interrupt_queue, abort_message=ChatEvent.END_EVENT.value, + agent=agent, tracer=tracer, ): if isinstance(research_result, ResearchIteration): @@ -1080,8 +1080,8 @@ async def event_generator( location, partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, - agent=agent, query_files=attached_file_context, + agent=agent, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -1156,8 +1156,8 @@ async def event_generator( partial(send_event, ChatEvent.STATUS), max_webpages_to_read=1, query_images=uploaded_images, - agent=agent, query_files=attached_file_context, + agent=agent, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -1197,8 +1197,8 @@ async def event_generator( user, partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, - agent=agent, query_files=attached_file_context, + agent=agent, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -1275,8 +1275,8 @@ async def event_generator( online_results=online_results, send_status_func=partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, - agent=agent, query_files=attached_file_context, + agent=agent, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -1316,10 +1316,10 @@ async def event_generator( note_references=compiled_references, online_results=online_results, query_images=uploaded_images, + query_files=attached_file_context, user=user, agent=agent, send_status_func=partial(send_event, ChatEvent.STATUS), - query_files=attached_file_context, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index d93e8a49..8e64c1ea 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -415,7 +415,7 @@ async def aget_data_sources_and_output_format( query_images=query_images, response_type="json_object", response_schema=PickTools, - fast_model=False, + fast_model=True, agent_chat_model=agent_chat_model, user=user, tracer=tracer, @@ -467,8 +467,9 @@ async def infer_webpage_urls( location_data: LocationData, user: KhojUser, query_images: List[str] = None, - agent: Agent = None, query_files: str = None, + fast_model: bool = True, + agent: Agent = None, tracer: dict = {}, ) -> List[str]: """ @@ -505,7 +506,7 @@ async def infer_webpage_urls( query_images=query_images, response_type="json_object", response_schema=WebpageUrls, - fast_model=False, + fast_model=fast_model, agent_chat_model=agent_chat_model, user=user, tracer=tracer, @@ -534,6 +535,7 @@ async def generate_online_subqueries( query_images: List[str] = None, query_files: str = None, max_queries: int = 3, + fast_model: bool = True, agent: Agent = None, tracer: dict = {}, ) -> Set[str]: @@ -571,7 +573,7 @@ async def generate_online_subqueries( query_images=query_images, response_type="json_object", response_schema=OnlineQueries, - fast_model=False, + fast_model=fast_model, agent_chat_model=agent_chat_model, user=user, tracer=tracer, @@ -737,9 +739,9 @@ async def generate_summary_from_files( file_filters: List[str], chat_history: List[ChatMessageModel] = [], query_images: List[str] = None, + query_files: str = None, agent: Agent = None, send_status_func: Optional[Callable] = None, - query_files: str = None, tracer: dict = {}, ): try: @@ -797,10 +799,10 @@ async def generate_excalidraw_diagram( note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, query_images: List[str] = None, + query_files: str = None, user: KhojUser = None, agent: Agent = None, send_status_func: Optional[Callable] = None, - query_files: str = None, tracer: dict = {}, ): if send_status_func: @@ -814,9 +816,9 @@ async def generate_excalidraw_diagram( note_references=note_references, online_results=online_results, query_images=query_images, + query_files=query_files, user=user, agent=agent, - query_files=query_files, tracer=tracer, ) @@ -849,9 +851,9 @@ async def generate_better_diagram_description( note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, query_images: List[str] = None, + query_files: str = None, user: KhojUser = None, agent: Agent = None, - query_files: str = None, tracer: dict = {}, ) -> str: """ @@ -959,10 +961,10 @@ async def generate_mermaidjs_diagram( note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, query_images: List[str] = None, + query_files: str = None, user: KhojUser = None, agent: Agent = None, send_status_func: Optional[Callable] = None, - query_files: str = None, tracer: dict = {}, ): if send_status_func: @@ -976,9 +978,9 @@ async def generate_mermaidjs_diagram( note_references=note_references, online_results=online_results, query_images=query_images, + query_files=query_files, user=user, agent=agent, - query_files=query_files, tracer=tracer, ) @@ -1005,9 +1007,9 @@ async def generate_better_mermaidjs_diagram_description( note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, query_images: List[str] = None, + query_files: str = None, user: KhojUser = None, agent: Agent = None, - query_files: str = None, tracer: dict = {}, ) -> str: """ @@ -1099,9 +1101,9 @@ async def generate_better_image_prompt( online_results: Optional[dict] = None, model_type: Optional[str] = None, query_images: Optional[List[str]] = None, + query_files: str = "", user: KhojUser = None, agent: Agent = None, - query_files: str = "", tracer: dict = {}, ) -> dict: """ @@ -1148,7 +1150,7 @@ async def generate_better_image_prompt( system_message=enhance_image_system_message, response_type="json_object", response_schema=ImagePromptResponse, - fast_model=False, + fast_model=True, agent_chat_model=agent_chat_model, user=user, tracer=tracer, @@ -1175,9 +1177,10 @@ async def search_documents( location_data: LocationData = None, send_status_func: Optional[Callable] = None, query_images: Optional[List[str]] = None, - previous_inferred_queries: Set = set(), - agent: Agent = None, query_files: str = None, + previous_inferred_queries: Set = set(), + fast_model: bool = True, + agent: Agent = None, tracer: dict = {}, ): # Initialize Variables @@ -1228,6 +1231,7 @@ async def search_documents( personality_context=personality_context, location_data=location_data, chat_history=chat_history, + fast_model=fast_model, agent=agent, tracer=tracer, ) @@ -1280,6 +1284,7 @@ async def extract_questions( location_data: LocationData = None, chat_history: List[ChatMessageModel] = [], max_queries: int = 5, + fast_model: bool = True, agent: Agent = None, tracer: dict = {}, ): @@ -1334,7 +1339,7 @@ async def extract_questions( system_message=system_prompt, response_type="json_object", response_schema=DocumentQueries, - fast_model=False, + fast_model=fast_model, agent_chat_model=agent_chat_model, user=user, tracer=tracer, diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 63e6b71d..ba053711 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -243,14 +243,14 @@ async def research( conversation_history: List[ChatMessageModel], previous_iterations: List[ResearchIteration], query_images: List[str], - agent: Agent = None, - send_status_func: Optional[Callable] = None, + query_files: str = None, user_name: str = None, location: LocationData = None, - query_files: str = None, + send_status_func: Optional[Callable] = None, cancellation_event: Optional[asyncio.Event] = None, interrupt_queue: Optional[asyncio.Queue] = None, abort_message: str = ChatEvent.END_EVENT.value, + agent: Agent = None, tracer: dict = {}, ): max_document_searches = 7 @@ -356,10 +356,10 @@ async def research( location_data=location, send_status_func=send_status_func, query_images=query_images, + query_files=query_files, previous_inferred_queries=previous_inferred_queries, agent=agent, tracer=tracer, - query_files=query_files, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -424,7 +424,6 @@ async def research( **this_iteration.query.args, user=user, send_status_func=send_status_func, - # max_webpages_to_read=max_webpages_to_read, agent=agent, tracer=tracer, ): @@ -459,8 +458,8 @@ async def research( user=user, send_status_func=send_status_func, query_images=query_images, - agent=agent, query_files=query_files, + agent=agent, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: