diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
index 6989f4c1..fac8dfa4 100644
--- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py
+++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
@@ -83,13 +83,11 @@ def extract_questions_anthropic(
         images=query_images,
         model_type=ChatModelOptions.ModelType.ANTHROPIC,
         vision_enabled=vision_enabled,
+        attached_file_context=attached_files,
     )
 
     messages = []
 
-    if attached_files:
-        messages.append(ChatMessage(content=attached_files, role="user"))
-
     messages.append(ChatMessage(content=prompt, role="user"))
 
     messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py
index e4de609f..b7a7739f 100644
--- a/src/khoj/processor/conversation/google/gemini_chat.py
+++ b/src/khoj/processor/conversation/google/gemini_chat.py
@@ -84,13 +84,11 @@ def extract_questions_gemini(
         images=query_images,
         model_type=ChatModelOptions.ModelType.GOOGLE,
         vision_enabled=vision_enabled,
+        attached_file_context=attached_files,
     )
 
     messages = []
 
-    if attached_files:
-        messages.append(ChatMessage(content=attached_files, role="user"))
-
     messages.append(ChatMessage(content=prompt, role="user"))
     messages.append(ChatMessage(content=system_prompt, role="system"))
 
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index f2919afb..70d208d8 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -80,13 +80,10 @@ def extract_questions(
         images=query_images,
         model_type=ChatModelOptions.ModelType.OPENAI,
         vision_enabled=vision_enabled,
+        attached_file_context=attached_files,
     )
 
     messages = []
-
-    if attached_files:
-        messages.append(ChatMessage(content=attached_files, role="user"))
-
     messages.append(ChatMessage(content=prompt, role="user"))
 
     response = send_message_to_model(
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 27e23a88..7187acdb 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -271,23 +271,41 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
     )
 
 
-def construct_structured_message(message: str, images: list[str], model_type: str, vision_enabled: bool):
+def construct_structured_message(
+    message: str, images: list[str], model_type: str, vision_enabled: bool, attached_file_context: str = None
+):
     """
     Format messages into appropriate multimedia format for supported chat model types
+
+    Attached file context is sent along with the message whether or not images are present.
+    Returns the plain message string when there is nothing to structure, a list of
+    content parts for supported chat model types, or the file context prepended to
+    the message text for any other chat model type.
     """
-    if not images or not vision_enabled:
-        return message
+    has_file_context = not is_none_or_empty(attached_file_context)
+    has_images = bool(images) and vision_enabled
 
     if model_type in [
         ChatModelOptions.ModelType.OPENAI,
         ChatModelOptions.ModelType.GOOGLE,
         ChatModelOptions.ModelType.ANTHROPIC,
     ]:
-        return [
-            {"type": "text", "text": message},
-            *[{"type": "image_url", "image_url": {"url": image}} for image in images],
-        ]
-    return message
+        # Nothing to structure: keep plain text messages as plain strings
+        if not has_file_context and not has_images:
+            return message
+
+        constructed_messages = [{"type": "text", "text": message}]
+        # Attach file context even when there are no images to send
+        if has_file_context:
+            constructed_messages.append({"type": "text", "text": attached_file_context})
+        if has_images:
+            constructed_messages.extend({"type": "image_url", "image_url": {"url": image}} for image in images)
+        return constructed_messages
+
+    # Other chat model types keep getting plain text, as before; prepend file context rather than drop it
+    if has_file_context:
+        return f"{attached_file_context}\n\n{message}"
+    return message
 
 
 def gather_raw_attached_files(
@@ -374,16 +392,15 @@ def generate_chatml_messages_with_context(
     if not is_none_or_empty(user_message):
         messages.append(
             ChatMessage(
-                content=construct_structured_message(user_message, query_images, model_type, vision_enabled),
+                content=construct_structured_message(
+                    user_message, query_images, model_type, vision_enabled, attached_file_context=attached_files
+                ),
                 role="user",
             )
         )
     if not is_none_or_empty(context_message):
         messages.append(ChatMessage(content=context_message, role="user"))
 
-    if not is_none_or_empty(attached_files):
-        messages.append(ChatMessage(content=attached_files, role="user"))
-
     if len(chatml_messages) > 0:
         messages += chatml_messages
 