diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index ef8539b3..94d8df03 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -6,7 +6,7 @@ from typing import Dict, Optional from langchain.schema import ChatMessage -from khoj.database.models import Agent, KhojUser +from khoj.database.models import Agent, ChatModelOptions, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.anthropic.utils import ( anthropic_chat_completion_with_backoff, @@ -188,6 +188,7 @@ def converse_anthropic( model_name=model, max_prompt_size=max_prompt_size, tokenizer_name=tokenizer_name, + model_type=ChatModelOptions.ModelType.ANTHROPIC, ) if len(messages) > 1: diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 8a58e181..febe3786 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -7,7 +7,7 @@ from typing import Any, Iterator, List, Union from langchain.schema import ChatMessage from llama_cpp import Llama -from khoj.database.models import Agent, KhojUser +from khoj.database.models import Agent, ChatModelOptions, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.utils import download_model from khoj.processor.conversation.utils import ( @@ -76,7 +76,11 @@ def extract_questions_offline( ) messages = generate_chatml_messages_with_context( - example_questions, model_name=model, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size + example_questions, + model_name=model, + loaded_model=offline_chat_model, + max_prompt_size=max_prompt_size, + model_type=ChatModelOptions.ModelType.OFFLINE, ) state.chat_lock.acquire() @@ -201,6 +205,7 @@ def converse_offline( loaded_model=offline_chat_model, max_prompt_size=max_prompt_size, tokenizer_name=tokenizer_name, + model_type=ChatModelOptions.ModelType.OFFLINE, ) truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages}) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index f2e87884..619bf9d6 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -5,13 +5,16 @@ from typing import Dict, Optional from langchain.schema import ChatMessage -from khoj.database.models import Agent, KhojUser +from khoj.database.models import Agent, ChatModelOptions, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.openai.utils import ( chat_completion_with_backoff, completion_with_backoff, ) -from khoj.processor.conversation.utils import generate_chatml_messages_with_context +from khoj.processor.conversation.utils import ( + construct_structured_message, + generate_chatml_messages_with_context, +) from khoj.utils.helpers import ConversationCommand, is_none_or_empty from khoj.utils.rawconfig import LocationData @@ -24,9 +27,10 @@ def extract_questions( conversation_log={}, api_key=None, api_base_url=None, - temperature=0.7, location_data: LocationData = None, user: KhojUser = None, + uploaded_image_url: Optional[str] = None, + vision_enabled: bool = False, ): """ Infer search queries to retrieve relevant notes to answer user query @@ -63,17 +67,17 @@ def extract_questions( location=location, username=username, ) + + prompt = construct_structured_message( + message=prompt, + image_url=uploaded_image_url, + model_type=ChatModelOptions.ModelType.OPENAI, + vision_enabled=vision_enabled, + ) + messages = [ChatMessage(content=prompt, role="user")] - # Get Response from GPT - response = completion_with_backoff( - messages=messages, - model=model, - temperature=temperature, - api_base_url=api_base_url, - model_kwargs={"response_format": {"type": "json_object"}}, - openai_api_key=api_key, - ) + response = send_message_to_model(messages, api_key, model, response_type="json_object", api_base_url=api_base_url) # Extract, Clean Message from GPT's Response try: @@ -182,6 +186,7 @@ def converse( tokenizer_name=tokenizer_name, uploaded_image_url=image_url, vision_enabled=vision_available, + model_type=ChatModelOptions.ModelType.OPENAI, ) truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages}) logger.debug(f"Conversation Context for GPT: {truncated_messages}") diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 8b4e34e6..54bf6334 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -12,7 +12,7 @@ from llama_cpp.llama import Llama from transformers import AutoTokenizer from khoj.database.adapters import ConversationAdapters -from khoj.database.models import ClientApplication, KhojUser +from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens from khoj.utils import state from khoj.utils.helpers import is_none_or_empty, merge_dicts @@ -137,6 +137,13 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response} ) +# Format user and system messages to chatml format +def construct_structured_message(message, image_url, model_type, vision_enabled): + if image_url and vision_enabled and model_type == ChatModelOptions.ModelType.OPENAI: + return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}] + return message + + def generate_chatml_messages_with_context( user_message, system_message=None, @@ -147,6 +154,7 @@ def generate_chatml_messages_with_context( tokenizer_name=None, uploaded_image_url=None, vision_enabled=False, + model_type="", ): """Generate messages for ChatGPT with context from previous conversation""" # Set max prompt size from user config or based on pre-configured for model and machine specs @@ -156,12 +164,6 @@ def generate_chatml_messages_with_context( else: max_prompt_size = model_to_prompt_size.get(model_name, 2000) - # Format user and system messages to chatml format - def construct_structured_message(message, image_url): - if image_url and vision_enabled: - return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}] - return message - # Scale lookback turns proportional to max prompt size supported by model lookback_turns = max_prompt_size // 750 @@ -174,7 +176,9 @@ def generate_chatml_messages_with_context( message_content = chat["message"] + message_notes if chat.get("uploadedImageData") and vision_enabled: - message_content = construct_structured_message(message_content, chat.get("uploadedImageData")) + message_content = construct_structured_message( + message_content, chat.get("uploadedImageData"), model_type, vision_enabled + ) reconstructed_message = ChatMessage(content=message_content, role=role) @@ -186,7 +190,10 @@ def generate_chatml_messages_with_context( messages = [] if not is_none_or_empty(user_message): messages.append( - ChatMessage(content=construct_structured_message(user_message, uploaded_image_url), role="user") + ChatMessage( + content=construct_structured_message(user_message, uploaded_image_url, model_type, vision_enabled), + role="user", + ) ) if len(chatml_messages) > 0: messages += chatml_messages diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 40c65ce2..f00c362d 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -331,6 +331,7 @@ async def extract_references_and_questions( conversation_commands: List[ConversationCommand] = [ConversationCommand.Default], location_data: LocationData = None, send_status_func: Optional[Callable] = None, + uploaded_image_url: Optional[str] = None, ): user = request.user.object if request.user.is_authenticated else None @@ -370,6 +371,7 @@ async def extract_references_and_questions( with timer("Extracting search queries took", logger): # If we've reached here, either the user has enabled offline chat or the openai model is enabled. conversation_config = await ConversationAdapters.aget_default_conversation_config() + vision_enabled = conversation_config.vision_enabled if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE: using_offline_chat = True @@ -403,6 +405,8 @@ async def extract_references_and_questions( conversation_log=meta_log, location_data=location_data, user=user, + uploaded_image_url=uploaded_image_url, + vision_enabled=vision_enabled, ) elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: api_key = conversation_config.openai_config.api_key diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 0e83b9c2..0915b180 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -807,6 +807,7 @@ async def chat( conversation_commands, location, partial(send_event, ChatEvent.STATUS), + uploaded_image_url=uploaded_image_url, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index da688ddb..8dd84873 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -330,7 +330,7 @@ async def aget_relevant_output_modes( chat_history = construct_chat_history(conversation_history) if uploaded_image_url: - query = f"[placeholder for image attached to this message]\n{query}" + query = f" \n{query}" relevant_mode_prompt = prompts.pick_relevant_output_mode.format( query=query, @@ -622,6 +622,7 @@ async def send_message_to_model_wrapper( tokenizer_name=tokenizer, max_prompt_size=max_tokens, vision_enabled=vision_available, + model_type=conversation_config.model_type, ) return send_message_to_model_offline( @@ -644,6 +645,7 @@ async def send_message_to_model_wrapper( tokenizer_name=tokenizer, vision_enabled=vision_available, uploaded_image_url=uploaded_image_url, + model_type=conversation_config.model_type, ) openai_response = send_message_to_model( @@ -664,6 +666,7 @@ async def send_message_to_model_wrapper( max_prompt_size=max_tokens, tokenizer_name=tokenizer, vision_enabled=vision_available, + model_type=conversation_config.model_type, ) return anthropic_send_message_to_model( @@ -700,6 +703,7 @@ def send_message_to_model_wrapper_sync( model_name=chat_model, loaded_model=loaded_model, vision_enabled=vision_available, + model_type=conversation_config.model_type, ) return send_message_to_model_offline( @@ -717,6 +721,7 @@ def send_message_to_model_wrapper_sync( system_message=system_message, model_name=chat_model, vision_enabled=vision_available, + model_type=conversation_config.model_type, ) openai_response = send_message_to_model( @@ -733,6 +738,7 @@ def send_message_to_model_wrapper_sync( model_name=chat_model, max_prompt_size=max_tokens, vision_enabled=vision_available, + model_type=conversation_config.model_type, ) return anthropic_send_message_to_model(