Limit vision_enabled image formatting to OpenAI APIs and send vision to extract_questions query

This commit is contained in:
sabaimran
2024-09-10 20:08:14 -07:00
parent aa31d041f3
commit 8d40fc0aef
7 changed files with 54 additions and 25 deletions

View File

@@ -6,7 +6,7 @@ from typing import Dict, Optional
from langchain.schema import ChatMessage from langchain.schema import ChatMessage
from khoj.database.models import Agent, KhojUser from khoj.database.models import Agent, ChatModelOptions, KhojUser
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.utils import ( from khoj.processor.conversation.anthropic.utils import (
anthropic_chat_completion_with_backoff, anthropic_chat_completion_with_backoff,
@@ -188,6 +188,7 @@ def converse_anthropic(
model_name=model, model_name=model,
max_prompt_size=max_prompt_size, max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name, tokenizer_name=tokenizer_name,
model_type=ChatModelOptions.ModelType.ANTHROPIC,
) )
if len(messages) > 1: if len(messages) > 1:

View File

@@ -7,7 +7,7 @@ from typing import Any, Iterator, List, Union
from langchain.schema import ChatMessage from langchain.schema import ChatMessage
from llama_cpp import Llama from llama_cpp import Llama
from khoj.database.models import Agent, KhojUser from khoj.database.models import Agent, ChatModelOptions, KhojUser
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
@@ -76,7 +76,11 @@ def extract_questions_offline(
) )
messages = generate_chatml_messages_with_context( messages = generate_chatml_messages_with_context(
example_questions, model_name=model, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size example_questions,
model_name=model,
loaded_model=offline_chat_model,
max_prompt_size=max_prompt_size,
model_type=ChatModelOptions.ModelType.OFFLINE,
) )
state.chat_lock.acquire() state.chat_lock.acquire()
@@ -201,6 +205,7 @@ def converse_offline(
loaded_model=offline_chat_model, loaded_model=offline_chat_model,
max_prompt_size=max_prompt_size, max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name, tokenizer_name=tokenizer_name,
model_type=ChatModelOptions.ModelType.OFFLINE,
) )
truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages}) truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})

View File

@@ -5,13 +5,16 @@ from typing import Dict, Optional
from langchain.schema import ChatMessage from langchain.schema import ChatMessage
from khoj.database.models import Agent, KhojUser from khoj.database.models import Agent, ChatModelOptions, KhojUser
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.openai.utils import ( from khoj.processor.conversation.openai.utils import (
chat_completion_with_backoff, chat_completion_with_backoff,
completion_with_backoff, completion_with_backoff,
) )
from khoj.processor.conversation.utils import generate_chatml_messages_with_context from khoj.processor.conversation.utils import (
construct_structured_message,
generate_chatml_messages_with_context,
)
from khoj.utils.helpers import ConversationCommand, is_none_or_empty from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData from khoj.utils.rawconfig import LocationData
@@ -24,9 +27,10 @@ def extract_questions(
conversation_log={}, conversation_log={},
api_key=None, api_key=None,
api_base_url=None, api_base_url=None,
temperature=0.7,
location_data: LocationData = None, location_data: LocationData = None,
user: KhojUser = None, user: KhojUser = None,
uploaded_image_url: Optional[str] = None,
vision_enabled: bool = False,
): ):
""" """
Infer search queries to retrieve relevant notes to answer user query Infer search queries to retrieve relevant notes to answer user query
@@ -63,17 +67,17 @@ def extract_questions(
location=location, location=location,
username=username, username=username,
) )
prompt = construct_structured_message(
message=prompt,
image_url=uploaded_image_url,
model_type=ChatModelOptions.ModelType.OPENAI,
vision_enabled=vision_enabled,
)
messages = [ChatMessage(content=prompt, role="user")] messages = [ChatMessage(content=prompt, role="user")]
# Get Response from GPT response = send_message_to_model(messages, api_key, model, response_type="json_object", api_base_url=api_base_url)
response = completion_with_backoff(
messages=messages,
model=model,
temperature=temperature,
api_base_url=api_base_url,
model_kwargs={"response_format": {"type": "json_object"}},
openai_api_key=api_key,
)
# Extract, Clean Message from GPT's Response # Extract, Clean Message from GPT's Response
try: try:
@@ -182,6 +186,7 @@ def converse(
tokenizer_name=tokenizer_name, tokenizer_name=tokenizer_name,
uploaded_image_url=image_url, uploaded_image_url=image_url,
vision_enabled=vision_available, vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.OPENAI,
) )
truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages}) truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
logger.debug(f"Conversation Context for GPT: {truncated_messages}") logger.debug(f"Conversation Context for GPT: {truncated_messages}")

View File

@@ -12,7 +12,7 @@ from llama_cpp.llama import Llama
from transformers import AutoTokenizer from transformers import AutoTokenizer
from khoj.database.adapters import ConversationAdapters from khoj.database.adapters import ConversationAdapters
from khoj.database.models import ClientApplication, KhojUser from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
from khoj.utils import state from khoj.utils import state
from khoj.utils.helpers import is_none_or_empty, merge_dicts from khoj.utils.helpers import is_none_or_empty, merge_dicts
@@ -137,6 +137,13 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
) )
# Format user and system messages to chatml format
def construct_structured_message(message, image_url, model_type, vision_enabled):
    """Format a user message for the target chat model.

    When an image is attached, vision is enabled, and the target model is an
    OpenAI one, wrap the text and image into the OpenAI multi-part content
    format; otherwise return the plain text message unchanged.
    """
    # Evaluation order matters: only consult the model type once we know an
    # image is attached and vision is on (mirrors the original short-circuit).
    attach_image = bool(image_url) and vision_enabled and model_type == ChatModelOptions.ModelType.OPENAI
    if not attach_image:
        return message
    return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}]
def generate_chatml_messages_with_context( def generate_chatml_messages_with_context(
user_message, user_message,
system_message=None, system_message=None,
@@ -147,6 +154,7 @@ def generate_chatml_messages_with_context(
tokenizer_name=None, tokenizer_name=None,
uploaded_image_url=None, uploaded_image_url=None,
vision_enabled=False, vision_enabled=False,
model_type="",
): ):
"""Generate messages for ChatGPT with context from previous conversation""" """Generate messages for ChatGPT with context from previous conversation"""
# Set max prompt size from user config or based on pre-configured for model and machine specs # Set max prompt size from user config or based on pre-configured for model and machine specs
@@ -156,12 +164,6 @@ def generate_chatml_messages_with_context(
else: else:
max_prompt_size = model_to_prompt_size.get(model_name, 2000) max_prompt_size = model_to_prompt_size.get(model_name, 2000)
# Format user and system messages to chatml format
def construct_structured_message(message, image_url):
if image_url and vision_enabled:
return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}]
return message
# Scale lookback turns proportional to max prompt size supported by model # Scale lookback turns proportional to max prompt size supported by model
lookback_turns = max_prompt_size // 750 lookback_turns = max_prompt_size // 750
@@ -174,7 +176,9 @@ def generate_chatml_messages_with_context(
message_content = chat["message"] + message_notes message_content = chat["message"] + message_notes
if chat.get("uploadedImageData") and vision_enabled: if chat.get("uploadedImageData") and vision_enabled:
message_content = construct_structured_message(message_content, chat.get("uploadedImageData")) message_content = construct_structured_message(
message_content, chat.get("uploadedImageData"), model_type, vision_enabled
)
reconstructed_message = ChatMessage(content=message_content, role=role) reconstructed_message = ChatMessage(content=message_content, role=role)
@@ -186,7 +190,10 @@ def generate_chatml_messages_with_context(
messages = [] messages = []
if not is_none_or_empty(user_message): if not is_none_or_empty(user_message):
messages.append( messages.append(
ChatMessage(content=construct_structured_message(user_message, uploaded_image_url), role="user") ChatMessage(
content=construct_structured_message(user_message, uploaded_image_url, model_type, vision_enabled),
role="user",
)
) )
if len(chatml_messages) > 0: if len(chatml_messages) > 0:
messages += chatml_messages messages += chatml_messages

View File

@@ -331,6 +331,7 @@ async def extract_references_and_questions(
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default], conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
location_data: LocationData = None, location_data: LocationData = None,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None,
): ):
user = request.user.object if request.user.is_authenticated else None user = request.user.object if request.user.is_authenticated else None
@@ -370,6 +371,7 @@ async def extract_references_and_questions(
with timer("Extracting search queries took", logger): with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled. # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
conversation_config = await ConversationAdapters.aget_default_conversation_config() conversation_config = await ConversationAdapters.aget_default_conversation_config()
vision_enabled = conversation_config.vision_enabled
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE: if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
using_offline_chat = True using_offline_chat = True
@@ -403,6 +405,8 @@ async def extract_references_and_questions(
conversation_log=meta_log, conversation_log=meta_log,
location_data=location_data, location_data=location_data,
user=user, user=user,
uploaded_image_url=uploaded_image_url,
vision_enabled=vision_enabled,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key api_key = conversation_config.openai_config.api_key

View File

@@ -807,6 +807,7 @@ async def chat(
conversation_commands, conversation_commands,
location, location,
partial(send_event, ChatEvent.STATUS), partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]

View File

@@ -330,7 +330,7 @@ async def aget_relevant_output_modes(
chat_history = construct_chat_history(conversation_history) chat_history = construct_chat_history(conversation_history)
if uploaded_image_url: if uploaded_image_url:
query = f"[placeholder for image attached to this message]\n{query}" query = f"<user uploaded content redacted> \n{query}"
relevant_mode_prompt = prompts.pick_relevant_output_mode.format( relevant_mode_prompt = prompts.pick_relevant_output_mode.format(
query=query, query=query,
@@ -622,6 +622,7 @@ async def send_message_to_model_wrapper(
tokenizer_name=tokenizer, tokenizer_name=tokenizer,
max_prompt_size=max_tokens, max_prompt_size=max_tokens,
vision_enabled=vision_available, vision_enabled=vision_available,
model_type=conversation_config.model_type,
) )
return send_message_to_model_offline( return send_message_to_model_offline(
@@ -644,6 +645,7 @@ async def send_message_to_model_wrapper(
tokenizer_name=tokenizer, tokenizer_name=tokenizer,
vision_enabled=vision_available, vision_enabled=vision_available,
uploaded_image_url=uploaded_image_url, uploaded_image_url=uploaded_image_url,
model_type=conversation_config.model_type,
) )
openai_response = send_message_to_model( openai_response = send_message_to_model(
@@ -664,6 +666,7 @@ async def send_message_to_model_wrapper(
max_prompt_size=max_tokens, max_prompt_size=max_tokens,
tokenizer_name=tokenizer, tokenizer_name=tokenizer,
vision_enabled=vision_available, vision_enabled=vision_available,
model_type=conversation_config.model_type,
) )
return anthropic_send_message_to_model( return anthropic_send_message_to_model(
@@ -700,6 +703,7 @@ def send_message_to_model_wrapper_sync(
model_name=chat_model, model_name=chat_model,
loaded_model=loaded_model, loaded_model=loaded_model,
vision_enabled=vision_available, vision_enabled=vision_available,
model_type=conversation_config.model_type,
) )
return send_message_to_model_offline( return send_message_to_model_offline(
@@ -717,6 +721,7 @@ def send_message_to_model_wrapper_sync(
system_message=system_message, system_message=system_message,
model_name=chat_model, model_name=chat_model,
vision_enabled=vision_available, vision_enabled=vision_available,
model_type=conversation_config.model_type,
) )
openai_response = send_message_to_model( openai_response = send_message_to_model(
@@ -733,6 +738,7 @@ def send_message_to_model_wrapper_sync(
model_name=chat_model, model_name=chat_model,
max_prompt_size=max_tokens, max_prompt_size=max_tokens,
vision_enabled=vision_available, vision_enabled=vision_available,
model_type=conversation_config.model_type,
) )
return anthropic_send_message_to_model( return anthropic_send_message_to_model(