diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index a9a6f09f..cc69930e 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -45,7 +45,7 @@ from khoj.routers.helpers import ( aget_relevant_output_modes, construct_automation_created_message, create_automation, - gather_attached_files, + gather_raw_attached_files, generate_excalidraw_diagram, generate_summary_from_files, get_conversation_command, @@ -71,7 +71,12 @@ from khoj.utils.helpers import ( get_device, is_none_or_empty, ) -from khoj.utils.rawconfig import FileFilterRequest, FilesFilterRequest, LocationData +from khoj.utils.rawconfig import ( + ChatRequestBody, + FileFilterRequest, + FilesFilterRequest, + LocationData, +) # Initialize Router logger = logging.getLogger(__name__) @@ -566,6 +571,7 @@ async def chat( country_code = body.country_code or get_country_code_from_timezone(body.timezone) timezone = body.timezone raw_images = body.images + raw_attached_files = body.files async def event_generator(q: str, images: list[str]): start_time = time.perf_counter() @@ -577,6 +583,7 @@ async def chat( q = unquote(q) train_of_thought = [] nonlocal conversation_id + nonlocal raw_attached_files tracer: dict = { "mid": turn_id, @@ -596,6 +603,11 @@ async def chat( if uploaded_image: uploaded_images.append(uploaded_image) + attached_files: Dict[str, str] = {} + if raw_attached_files: + for file in raw_attached_files: + attached_files[file.name] = file.content + async def send_event(event_type: ChatEvent, data: str | dict): nonlocal connection_alive, ttft, train_of_thought if not connection_alive or await request.is_disconnected(): @@ -707,7 +719,7 @@ async def chat( compiled_references: List[Any] = [] inferred_queries: List[Any] = [] file_filters = conversation.file_filters if conversation and conversation.file_filters else [] - attached_file_context = await gather_attached_files(user, file_filters) + attached_file_context = gather_raw_attached_files(attached_files) if conversation_commands == [ConversationCommand.Default] or is_automated_task: conversation_commands = await aget_relevant_information_sources( @@ -833,6 +845,7 @@ async def chat( query_images=uploaded_images, tracer=tracer, train_of_thought=train_of_thought, + raw_attached_files=raw_attached_files, ) return @@ -878,6 +891,7 @@ async def chat( query_images=uploaded_images, tracer=tracer, train_of_thought=train_of_thought, + raw_attached_files=raw_attached_files, ) async for result in send_llm_response(llm_response): yield result @@ -900,6 +914,7 @@ async def chat( query_images=uploaded_images, agent=agent, tracer=tracer, + attached_files=attached_file_context, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -1085,6 +1100,8 @@ async def chat( query_images=uploaded_images, tracer=tracer, train_of_thought=train_of_thought, + attached_file_context=attached_file_context, + raw_attached_files=raw_attached_files, ) content_obj = { "intentType": intent_type, @@ -1144,6 +1161,8 @@ async def chat( query_images=uploaded_images, tracer=tracer, train_of_thought=train_of_thought, + attached_file_context=attached_file_context, + raw_attached_files=raw_attached_files, ) async for result in send_llm_response(json.dumps(content_obj)): @@ -1172,6 +1191,7 @@ async def chat( tracer, train_of_thought, attached_file_context, + raw_attached_files, ) # Send Response