Limit vision_enabled image formatting to OpenAI APIs and send vision to extract_questions query

This commit is contained in:
sabaimran
2024-09-10 20:08:14 -07:00
parent aa31d041f3
commit 8d40fc0aef
7 changed files with 54 additions and 25 deletions

View File

@@ -5,13 +5,16 @@ from typing import Dict, Optional
from langchain.schema import ChatMessage
from khoj.database.models import Agent, KhojUser
from khoj.database.models import Agent, ChatModelOptions, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.openai.utils import (
chat_completion_with_backoff,
completion_with_backoff,
)
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
from khoj.processor.conversation.utils import (
construct_structured_message,
generate_chatml_messages_with_context,
)
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData
@@ -24,9 +27,10 @@ def extract_questions(
conversation_log={},
api_key=None,
api_base_url=None,
temperature=0.7,
location_data: LocationData = None,
user: KhojUser = None,
uploaded_image_url: Optional[str] = None,
vision_enabled: bool = False,
):
"""
Infer search queries to retrieve relevant notes to answer user query
@@ -63,17 +67,17 @@ def extract_questions(
location=location,
username=username,
)
prompt = construct_structured_message(
message=prompt,
image_url=uploaded_image_url,
model_type=ChatModelOptions.ModelType.OPENAI,
vision_enabled=vision_enabled,
)
messages = [ChatMessage(content=prompt, role="user")]
# Get Response from GPT
response = completion_with_backoff(
messages=messages,
model=model,
temperature=temperature,
api_base_url=api_base_url,
model_kwargs={"response_format": {"type": "json_object"}},
openai_api_key=api_key,
)
response = send_message_to_model(messages, api_key, model, response_type="json_object", api_base_url=api_base_url)
# Extract, Clean Message from GPT's Response
try:
@@ -182,6 +186,7 @@ def converse(
tokenizer_name=tokenizer_name,
uploaded_image_url=image_url,
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.OPENAI,
)
truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
logger.debug(f"Conversation Context for GPT: {truncated_messages}")