mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Limit vision_enabled image formatting to OpenAI APIs and send vision to extract_questions query
This commit is contained in:
@@ -6,7 +6,7 @@ from typing import Dict, Optional
|
|||||||
|
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
|
|
||||||
from khoj.database.models import Agent, KhojUser
|
from khoj.database.models import Agent, ChatModelOptions, KhojUser
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.anthropic.utils import (
|
from khoj.processor.conversation.anthropic.utils import (
|
||||||
anthropic_chat_completion_with_backoff,
|
anthropic_chat_completion_with_backoff,
|
||||||
@@ -188,6 +188,7 @@ def converse_anthropic(
|
|||||||
model_name=model,
|
model_name=model,
|
||||||
max_prompt_size=max_prompt_size,
|
max_prompt_size=max_prompt_size,
|
||||||
tokenizer_name=tokenizer_name,
|
tokenizer_name=tokenizer_name,
|
||||||
|
model_type=ChatModelOptions.ModelType.ANTHROPIC,
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(messages) > 1:
|
if len(messages) > 1:
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import Any, Iterator, List, Union
|
|||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
from llama_cpp import Llama
|
from llama_cpp import Llama
|
||||||
|
|
||||||
from khoj.database.models import Agent, KhojUser
|
from khoj.database.models import Agent, ChatModelOptions, KhojUser
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.offline.utils import download_model
|
from khoj.processor.conversation.offline.utils import download_model
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
@@ -76,7 +76,11 @@ def extract_questions_offline(
|
|||||||
)
|
)
|
||||||
|
|
||||||
messages = generate_chatml_messages_with_context(
|
messages = generate_chatml_messages_with_context(
|
||||||
example_questions, model_name=model, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
|
example_questions,
|
||||||
|
model_name=model,
|
||||||
|
loaded_model=offline_chat_model,
|
||||||
|
max_prompt_size=max_prompt_size,
|
||||||
|
model_type=ChatModelOptions.ModelType.OFFLINE,
|
||||||
)
|
)
|
||||||
|
|
||||||
state.chat_lock.acquire()
|
state.chat_lock.acquire()
|
||||||
@@ -201,6 +205,7 @@ def converse_offline(
|
|||||||
loaded_model=offline_chat_model,
|
loaded_model=offline_chat_model,
|
||||||
max_prompt_size=max_prompt_size,
|
max_prompt_size=max_prompt_size,
|
||||||
tokenizer_name=tokenizer_name,
|
tokenizer_name=tokenizer_name,
|
||||||
|
model_type=ChatModelOptions.ModelType.OFFLINE,
|
||||||
)
|
)
|
||||||
|
|
||||||
truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
|
truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
|
||||||
|
|||||||
@@ -5,13 +5,16 @@ from typing import Dict, Optional
|
|||||||
|
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
|
|
||||||
from khoj.database.models import Agent, KhojUser
|
from khoj.database.models import Agent, ChatModelOptions, KhojUser
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.openai.utils import (
|
from khoj.processor.conversation.openai.utils import (
|
||||||
chat_completion_with_backoff,
|
chat_completion_with_backoff,
|
||||||
completion_with_backoff,
|
completion_with_backoff,
|
||||||
)
|
)
|
||||||
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
|
from khoj.processor.conversation.utils import (
|
||||||
|
construct_structured_message,
|
||||||
|
generate_chatml_messages_with_context,
|
||||||
|
)
|
||||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
||||||
from khoj.utils.rawconfig import LocationData
|
from khoj.utils.rawconfig import LocationData
|
||||||
|
|
||||||
@@ -24,9 +27,10 @@ def extract_questions(
|
|||||||
conversation_log={},
|
conversation_log={},
|
||||||
api_key=None,
|
api_key=None,
|
||||||
api_base_url=None,
|
api_base_url=None,
|
||||||
temperature=0.7,
|
|
||||||
location_data: LocationData = None,
|
location_data: LocationData = None,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
|
uploaded_image_url: Optional[str] = None,
|
||||||
|
vision_enabled: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Infer search queries to retrieve relevant notes to answer user query
|
Infer search queries to retrieve relevant notes to answer user query
|
||||||
@@ -63,17 +67,17 @@ def extract_questions(
|
|||||||
location=location,
|
location=location,
|
||||||
username=username,
|
username=username,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
prompt = construct_structured_message(
|
||||||
|
message=prompt,
|
||||||
|
image_url=uploaded_image_url,
|
||||||
|
model_type=ChatModelOptions.ModelType.OPENAI,
|
||||||
|
vision_enabled=vision_enabled,
|
||||||
|
)
|
||||||
|
|
||||||
messages = [ChatMessage(content=prompt, role="user")]
|
messages = [ChatMessage(content=prompt, role="user")]
|
||||||
|
|
||||||
# Get Response from GPT
|
response = send_message_to_model(messages, api_key, model, response_type="json_object", api_base_url=api_base_url)
|
||||||
response = completion_with_backoff(
|
|
||||||
messages=messages,
|
|
||||||
model=model,
|
|
||||||
temperature=temperature,
|
|
||||||
api_base_url=api_base_url,
|
|
||||||
model_kwargs={"response_format": {"type": "json_object"}},
|
|
||||||
openai_api_key=api_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract, Clean Message from GPT's Response
|
# Extract, Clean Message from GPT's Response
|
||||||
try:
|
try:
|
||||||
@@ -182,6 +186,7 @@ def converse(
|
|||||||
tokenizer_name=tokenizer_name,
|
tokenizer_name=tokenizer_name,
|
||||||
uploaded_image_url=image_url,
|
uploaded_image_url=image_url,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
|
model_type=ChatModelOptions.ModelType.OPENAI,
|
||||||
)
|
)
|
||||||
truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
|
truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
|
||||||
logger.debug(f"Conversation Context for GPT: {truncated_messages}")
|
logger.debug(f"Conversation Context for GPT: {truncated_messages}")
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from llama_cpp.llama import Llama
|
|||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from khoj.database.adapters import ConversationAdapters
|
from khoj.database.adapters import ConversationAdapters
|
||||||
from khoj.database.models import ClientApplication, KhojUser
|
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
|
||||||
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.helpers import is_none_or_empty, merge_dicts
|
from khoj.utils.helpers import is_none_or_empty, merge_dicts
|
||||||
@@ -137,6 +137,13 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Format user and system messages to chatml format
|
||||||
|
def construct_structured_message(message, image_url, model_type, vision_enabled):
|
||||||
|
if image_url and vision_enabled and model_type == ChatModelOptions.ModelType.OPENAI:
|
||||||
|
return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}]
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
def generate_chatml_messages_with_context(
|
def generate_chatml_messages_with_context(
|
||||||
user_message,
|
user_message,
|
||||||
system_message=None,
|
system_message=None,
|
||||||
@@ -147,6 +154,7 @@ def generate_chatml_messages_with_context(
|
|||||||
tokenizer_name=None,
|
tokenizer_name=None,
|
||||||
uploaded_image_url=None,
|
uploaded_image_url=None,
|
||||||
vision_enabled=False,
|
vision_enabled=False,
|
||||||
|
model_type="",
|
||||||
):
|
):
|
||||||
"""Generate messages for ChatGPT with context from previous conversation"""
|
"""Generate messages for ChatGPT with context from previous conversation"""
|
||||||
# Set max prompt size from user config or based on pre-configured for model and machine specs
|
# Set max prompt size from user config or based on pre-configured for model and machine specs
|
||||||
@@ -156,12 +164,6 @@ def generate_chatml_messages_with_context(
|
|||||||
else:
|
else:
|
||||||
max_prompt_size = model_to_prompt_size.get(model_name, 2000)
|
max_prompt_size = model_to_prompt_size.get(model_name, 2000)
|
||||||
|
|
||||||
# Format user and system messages to chatml format
|
|
||||||
def construct_structured_message(message, image_url):
|
|
||||||
if image_url and vision_enabled:
|
|
||||||
return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}]
|
|
||||||
return message
|
|
||||||
|
|
||||||
# Scale lookback turns proportional to max prompt size supported by model
|
# Scale lookback turns proportional to max prompt size supported by model
|
||||||
lookback_turns = max_prompt_size // 750
|
lookback_turns = max_prompt_size // 750
|
||||||
|
|
||||||
@@ -174,7 +176,9 @@ def generate_chatml_messages_with_context(
|
|||||||
message_content = chat["message"] + message_notes
|
message_content = chat["message"] + message_notes
|
||||||
|
|
||||||
if chat.get("uploadedImageData") and vision_enabled:
|
if chat.get("uploadedImageData") and vision_enabled:
|
||||||
message_content = construct_structured_message(message_content, chat.get("uploadedImageData"))
|
message_content = construct_structured_message(
|
||||||
|
message_content, chat.get("uploadedImageData"), model_type, vision_enabled
|
||||||
|
)
|
||||||
|
|
||||||
reconstructed_message = ChatMessage(content=message_content, role=role)
|
reconstructed_message = ChatMessage(content=message_content, role=role)
|
||||||
|
|
||||||
@@ -186,7 +190,10 @@ def generate_chatml_messages_with_context(
|
|||||||
messages = []
|
messages = []
|
||||||
if not is_none_or_empty(user_message):
|
if not is_none_or_empty(user_message):
|
||||||
messages.append(
|
messages.append(
|
||||||
ChatMessage(content=construct_structured_message(user_message, uploaded_image_url), role="user")
|
ChatMessage(
|
||||||
|
content=construct_structured_message(user_message, uploaded_image_url, model_type, vision_enabled),
|
||||||
|
role="user",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if len(chatml_messages) > 0:
|
if len(chatml_messages) > 0:
|
||||||
messages += chatml_messages
|
messages += chatml_messages
|
||||||
|
|||||||
@@ -331,6 +331,7 @@ async def extract_references_and_questions(
|
|||||||
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
|
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
|
||||||
location_data: LocationData = None,
|
location_data: LocationData = None,
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
|
uploaded_image_url: Optional[str] = None,
|
||||||
):
|
):
|
||||||
user = request.user.object if request.user.is_authenticated else None
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
|
|
||||||
@@ -370,6 +371,7 @@ async def extract_references_and_questions(
|
|||||||
with timer("Extracting search queries took", logger):
|
with timer("Extracting search queries took", logger):
|
||||||
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
||||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||||
|
vision_enabled = conversation_config.vision_enabled
|
||||||
|
|
||||||
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||||
using_offline_chat = True
|
using_offline_chat = True
|
||||||
@@ -403,6 +405,8 @@ async def extract_references_and_questions(
|
|||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
user=user,
|
user=user,
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
vision_enabled=vision_enabled,
|
||||||
)
|
)
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||||
api_key = conversation_config.openai_config.api_key
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
|||||||
@@ -807,6 +807,7 @@ async def chat(
|
|||||||
conversation_commands,
|
conversation_commands,
|
||||||
location,
|
location,
|
||||||
partial(send_event, ChatEvent.STATUS),
|
partial(send_event, ChatEvent.STATUS),
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
|
|||||||
@@ -330,7 +330,7 @@ async def aget_relevant_output_modes(
|
|||||||
chat_history = construct_chat_history(conversation_history)
|
chat_history = construct_chat_history(conversation_history)
|
||||||
|
|
||||||
if uploaded_image_url:
|
if uploaded_image_url:
|
||||||
query = f"[placeholder for image attached to this message]\n{query}"
|
query = f"<user uploaded content redacted> \n{query}"
|
||||||
|
|
||||||
relevant_mode_prompt = prompts.pick_relevant_output_mode.format(
|
relevant_mode_prompt = prompts.pick_relevant_output_mode.format(
|
||||||
query=query,
|
query=query,
|
||||||
@@ -622,6 +622,7 @@ async def send_message_to_model_wrapper(
|
|||||||
tokenizer_name=tokenizer,
|
tokenizer_name=tokenizer,
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
|
model_type=conversation_config.model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
return send_message_to_model_offline(
|
return send_message_to_model_offline(
|
||||||
@@ -644,6 +645,7 @@ async def send_message_to_model_wrapper(
|
|||||||
tokenizer_name=tokenizer,
|
tokenizer_name=tokenizer,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
uploaded_image_url=uploaded_image_url,
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
model_type=conversation_config.model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
openai_response = send_message_to_model(
|
openai_response = send_message_to_model(
|
||||||
@@ -664,6 +666,7 @@ async def send_message_to_model_wrapper(
|
|||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
tokenizer_name=tokenizer,
|
tokenizer_name=tokenizer,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
|
model_type=conversation_config.model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
return anthropic_send_message_to_model(
|
return anthropic_send_message_to_model(
|
||||||
@@ -700,6 +703,7 @@ def send_message_to_model_wrapper_sync(
|
|||||||
model_name=chat_model,
|
model_name=chat_model,
|
||||||
loaded_model=loaded_model,
|
loaded_model=loaded_model,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
|
model_type=conversation_config.model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
return send_message_to_model_offline(
|
return send_message_to_model_offline(
|
||||||
@@ -717,6 +721,7 @@ def send_message_to_model_wrapper_sync(
|
|||||||
system_message=system_message,
|
system_message=system_message,
|
||||||
model_name=chat_model,
|
model_name=chat_model,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
|
model_type=conversation_config.model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
openai_response = send_message_to_model(
|
openai_response = send_message_to_model(
|
||||||
@@ -733,6 +738,7 @@ def send_message_to_model_wrapper_sync(
|
|||||||
model_name=chat_model,
|
model_name=chat_model,
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
|
model_type=conversation_config.model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
return anthropic_send_message_to_model(
|
return anthropic_send_message_to_model(
|
||||||
|
|||||||
Reference in New Issue
Block a user