mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Share context builder for chat final response across model types
The context building logic was nearly identical across all model types. This change extracts that logic into a shared function and calls it once in the `agenerate_chat_response', the entrypoint to the converse methods for all 3 model types. Main differences handled are - Gemini system prompt had additional verbosity instructions. Keep it - Pass system messsage via chatml messages list to anthropic, gemini models as well (like openai models) instead of passing it as separate arg to chat_completion_* funcs. The model specific message formatters for both already extract system instruction from the messages list. So system messages wil be automatically extracted from the chat_completion_* funcs to pass as separate arg required by anthropic, gemini api libraries.
This commit is contained in:
@@ -1,22 +1,16 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, Dict, List, Optional
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
|
||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.anthropic.utils import (
|
||||
anthropic_chat_completion_with_backoff,
|
||||
anthropic_completion_with_backoff,
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
OperatorRun,
|
||||
ResponseWithThought,
|
||||
generate_chatml_messages_with_context,
|
||||
messages_to_print,
|
||||
)
|
||||
from khoj.utils.helpers import is_none_or_empty, truncate_code_context
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
from khoj.utils.yaml import yaml_dump
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -52,91 +46,17 @@ def anthropic_send_message_to_model(
|
||||
|
||||
async def converse_anthropic(
|
||||
# Query
|
||||
user_query: str,
|
||||
# Context
|
||||
references: list[dict],
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
code_results: Optional[Dict[str, Dict]] = None,
|
||||
operator_results: Optional[List[OperatorRun]] = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
query_files: str = None,
|
||||
program_execution_context: Optional[List[str]] = None,
|
||||
generated_asset_results: Dict[str, Dict] = {},
|
||||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
messages: List[ChatMessage],
|
||||
# Model
|
||||
model: Optional[str] = "claude-3-7-sonnet-latest",
|
||||
api_key: Optional[str] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
agent: Agent = None,
|
||||
vision_available: bool = False,
|
||||
deepthought: Optional[bool] = False,
|
||||
tracer: dict = {},
|
||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||
"""
|
||||
Converse with user using Anthropic's Claude
|
||||
"""
|
||||
# Initialize Variables
|
||||
current_date = datetime.now()
|
||||
|
||||
if agent and agent.personality:
|
||||
system_prompt = prompts.custom_personality.format(
|
||||
name=agent.name,
|
||||
bio=agent.personality,
|
||||
current_date=current_date.strftime("%Y-%m-%d"),
|
||||
day_of_week=current_date.strftime("%A"),
|
||||
)
|
||||
else:
|
||||
system_prompt = prompts.personality.format(
|
||||
current_date=current_date.strftime("%Y-%m-%d"),
|
||||
day_of_week=current_date.strftime("%A"),
|
||||
)
|
||||
|
||||
if location_data:
|
||||
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
||||
system_prompt = f"{system_prompt}\n{location_prompt}"
|
||||
|
||||
if user_name:
|
||||
user_name_prompt = prompts.user_name.format(name=user_name)
|
||||
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
||||
|
||||
context_message = ""
|
||||
if not is_none_or_empty(references):
|
||||
context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n"
|
||||
if not is_none_or_empty(online_results):
|
||||
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
|
||||
if not is_none_or_empty(code_results):
|
||||
context_message += (
|
||||
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
|
||||
)
|
||||
if not is_none_or_empty(operator_results):
|
||||
operator_content = [
|
||||
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
|
||||
]
|
||||
context_message += (
|
||||
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
|
||||
)
|
||||
context_message = context_message.strip()
|
||||
|
||||
# Setup Prompt with Primer or Conversation History
|
||||
messages = generate_chatml_messages_with_context(
|
||||
user_query,
|
||||
context_message=context_message,
|
||||
chat_history=chat_history,
|
||||
model_name=model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tokenizer_name=tokenizer_name,
|
||||
query_images=query_images,
|
||||
vision_enabled=vision_available,
|
||||
model_type=ChatModel.ModelType.ANTHROPIC,
|
||||
query_files=query_files,
|
||||
generated_asset_results=generated_asset_results,
|
||||
program_execution_context=program_execution_context,
|
||||
)
|
||||
|
||||
logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}")
|
||||
|
||||
# Get Response from Claude
|
||||
@@ -146,8 +66,6 @@ async def converse_anthropic(
|
||||
temperature=0.2,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
system_prompt=system_prompt,
|
||||
max_prompt_size=max_prompt_size,
|
||||
deepthought=deepthought,
|
||||
tracer=tracer,
|
||||
):
|
||||
|
||||
@@ -184,8 +184,7 @@ async def anthropic_chat_completion_with_backoff(
|
||||
temperature: float,
|
||||
api_key: str | None,
|
||||
api_base_url: str,
|
||||
system_prompt: str,
|
||||
max_prompt_size: int | None = None,
|
||||
system_prompt: str = "",
|
||||
deepthought: bool = False,
|
||||
model_kwargs: dict | None = None,
|
||||
tracer: dict = {},
|
||||
|
||||
@@ -1,22 +1,16 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, Dict, List, Optional
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
|
||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.google.utils import (
|
||||
gemini_chat_completion_with_backoff,
|
||||
gemini_completion_with_backoff,
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
OperatorRun,
|
||||
ResponseWithThought,
|
||||
generate_chatml_messages_with_context,
|
||||
messages_to_print,
|
||||
)
|
||||
from khoj.utils.helpers import is_none_or_empty, truncate_code_context
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
from khoj.utils.yaml import yaml_dump
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -61,93 +55,18 @@ def gemini_send_message_to_model(
|
||||
|
||||
async def converse_gemini(
|
||||
# Query
|
||||
user_query: str,
|
||||
# Context
|
||||
references: list[dict],
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
code_results: Optional[Dict[str, Dict]] = None,
|
||||
operator_results: Optional[List[OperatorRun]] = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
query_files: str = None,
|
||||
generated_asset_results: Dict[str, Dict] = {},
|
||||
program_execution_context: List[str] = None,
|
||||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
messages: List[ChatMessage],
|
||||
# Model
|
||||
model: Optional[str] = "gemini-2.5-flash",
|
||||
api_key: Optional[str] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
agent: Agent = None,
|
||||
vision_available: bool = False,
|
||||
deepthought: Optional[bool] = False,
|
||||
tracer={},
|
||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||
"""
|
||||
Converse with user using Google's Gemini
|
||||
"""
|
||||
# Initialize Variables
|
||||
current_date = datetime.now()
|
||||
|
||||
if agent and agent.personality:
|
||||
system_prompt = prompts.custom_personality.format(
|
||||
name=agent.name,
|
||||
bio=agent.personality,
|
||||
current_date=current_date.strftime("%Y-%m-%d"),
|
||||
day_of_week=current_date.strftime("%A"),
|
||||
)
|
||||
else:
|
||||
system_prompt = prompts.personality.format(
|
||||
current_date=current_date.strftime("%Y-%m-%d"),
|
||||
day_of_week=current_date.strftime("%A"),
|
||||
)
|
||||
|
||||
system_prompt += f"\n\n{prompts.gemini_verbose_language_personality}"
|
||||
if location_data:
|
||||
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
||||
system_prompt += f"\n{location_prompt}"
|
||||
|
||||
if user_name:
|
||||
user_name_prompt = prompts.user_name.format(name=user_name)
|
||||
system_prompt += f"\n{user_name_prompt}"
|
||||
|
||||
context_message = ""
|
||||
if not is_none_or_empty(references):
|
||||
context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n"
|
||||
if not is_none_or_empty(online_results):
|
||||
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
|
||||
if not is_none_or_empty(code_results):
|
||||
context_message += (
|
||||
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
|
||||
)
|
||||
if not is_none_or_empty(operator_results):
|
||||
operator_content = [
|
||||
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
|
||||
]
|
||||
context_message += (
|
||||
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
|
||||
)
|
||||
context_message = context_message.strip()
|
||||
|
||||
# Setup Prompt with Primer or Conversation History
|
||||
messages = generate_chatml_messages_with_context(
|
||||
user_query,
|
||||
context_message=context_message,
|
||||
chat_history=chat_history,
|
||||
model_name=model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tokenizer_name=tokenizer_name,
|
||||
query_images=query_images,
|
||||
vision_enabled=vision_available,
|
||||
model_type=ChatModel.ModelType.GOOGLE,
|
||||
query_files=query_files,
|
||||
generated_asset_results=generated_asset_results,
|
||||
program_execution_context=program_execution_context,
|
||||
)
|
||||
|
||||
logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
|
||||
|
||||
# Get Response from Google AI
|
||||
@@ -157,7 +76,6 @@ async def converse_gemini(
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
system_prompt=system_prompt,
|
||||
deepthought=deepthought,
|
||||
tracer=tracer,
|
||||
):
|
||||
|
||||
@@ -308,7 +308,7 @@ async def gemini_chat_completion_with_backoff(
|
||||
temperature: float,
|
||||
api_key: str,
|
||||
api_base_url: str,
|
||||
system_prompt: str,
|
||||
system_prompt: str = "",
|
||||
model_kwargs=None,
|
||||
deepthought=False,
|
||||
tracer: dict = {},
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel
|
||||
from khoj.processor.conversation import prompts
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
|
||||
from khoj.processor.conversation.openai.utils import (
|
||||
chat_completion_with_backoff,
|
||||
clean_response_schema,
|
||||
@@ -15,15 +14,11 @@ from khoj.processor.conversation.openai.utils import (
|
||||
to_openai_tools,
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
OperatorRun,
|
||||
ResponseWithThought,
|
||||
StructuredOutputSupport,
|
||||
generate_chatml_messages_with_context,
|
||||
messages_to_print,
|
||||
)
|
||||
from khoj.utils.helpers import ToolDefinition, is_none_or_empty, truncate_code_context
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
from khoj.utils.yaml import yaml_dump
|
||||
from khoj.utils.helpers import ToolDefinition
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -96,92 +91,18 @@ def send_message_to_model(
|
||||
|
||||
async def converse_openai(
|
||||
# Query
|
||||
user_query: str,
|
||||
# Context
|
||||
references: list[dict],
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
code_results: Optional[Dict[str, Dict]] = None,
|
||||
operator_results: Optional[List[OperatorRun]] = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
query_files: str = None,
|
||||
generated_asset_results: Dict[str, Dict] = {},
|
||||
program_execution_context: List[str] = None,
|
||||
location_data: LocationData = None,
|
||||
chat_history: list[ChatMessageModel] = [],
|
||||
messages: List[ChatMessage],
|
||||
# Model
|
||||
model: str = "gpt-4.1-mini",
|
||||
api_key: Optional[str] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
temperature: float = 0.6,
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
user_name: str = None,
|
||||
agent: Agent = None,
|
||||
vision_available: bool = False,
|
||||
deepthought: Optional[bool] = False,
|
||||
tracer: dict = {},
|
||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||
"""
|
||||
Converse with user using OpenAI's ChatGPT
|
||||
"""
|
||||
# Initialize Variables
|
||||
current_date = datetime.now()
|
||||
|
||||
if agent and agent.personality:
|
||||
system_prompt = prompts.custom_personality.format(
|
||||
name=agent.name,
|
||||
bio=agent.personality,
|
||||
current_date=current_date.strftime("%Y-%m-%d"),
|
||||
day_of_week=current_date.strftime("%A"),
|
||||
)
|
||||
else:
|
||||
system_prompt = prompts.personality.format(
|
||||
current_date=current_date.strftime("%Y-%m-%d"),
|
||||
day_of_week=current_date.strftime("%A"),
|
||||
)
|
||||
|
||||
if location_data:
|
||||
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
||||
system_prompt = f"{system_prompt}\n{location_prompt}"
|
||||
|
||||
if user_name:
|
||||
user_name_prompt = prompts.user_name.format(name=user_name)
|
||||
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
||||
|
||||
context_message = ""
|
||||
if not is_none_or_empty(references):
|
||||
context_message = f"{prompts.notes_conversation.format(references=yaml_dump(references))}\n\n"
|
||||
if not is_none_or_empty(online_results):
|
||||
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
|
||||
if not is_none_or_empty(code_results):
|
||||
context_message += (
|
||||
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
|
||||
)
|
||||
if not is_none_or_empty(operator_results):
|
||||
operator_content = [
|
||||
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
|
||||
]
|
||||
context_message += (
|
||||
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
|
||||
)
|
||||
|
||||
context_message = context_message.strip()
|
||||
|
||||
# Setup Prompt with Primer or Conversation History
|
||||
messages = generate_chatml_messages_with_context(
|
||||
user_query,
|
||||
system_prompt,
|
||||
chat_history,
|
||||
context_message=context_message,
|
||||
model_name=model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tokenizer_name=tokenizer_name,
|
||||
query_images=query_images,
|
||||
vision_enabled=vision_available,
|
||||
model_type=ChatModel.ModelType.OPENAI,
|
||||
query_files=query_files,
|
||||
generated_asset_results=generated_asset_results,
|
||||
program_execution_context=program_execution_context,
|
||||
)
|
||||
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
|
||||
|
||||
# Get Response from GPT
|
||||
|
||||
@@ -33,6 +33,7 @@ from apscheduler.triggers.cron import CronTrigger
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.utils import timezone as django_timezone
|
||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile, WebSocket
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from starlette.authentication import has_required_scope
|
||||
from starlette.requests import URL
|
||||
@@ -124,6 +125,7 @@ from khoj.utils.helpers import (
|
||||
mode_descriptions_for_llm,
|
||||
timer,
|
||||
tool_descriptions_for_llm,
|
||||
truncate_code_context,
|
||||
)
|
||||
from khoj.utils.rawconfig import (
|
||||
ChatRequestBody,
|
||||
@@ -132,6 +134,7 @@ from khoj.utils.rawconfig import (
|
||||
SearchResponse,
|
||||
)
|
||||
from khoj.utils.state import SearchType
|
||||
from khoj.utils.yaml import yaml_dump
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1598,6 +1601,105 @@ def send_message_to_model_wrapper_sync(
|
||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||
|
||||
|
||||
def build_conversation_context(
|
||||
# Query and Context
|
||||
user_query: str,
|
||||
references: List[Dict],
|
||||
online_results: Dict[str, Dict],
|
||||
code_results: Dict[str, Dict],
|
||||
operator_results: List[OperatorRun],
|
||||
query_files: str = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
generated_asset_results: Dict[str, Dict] = {},
|
||||
program_execution_context: List[str] = None,
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
# Model config
|
||||
agent: Agent = None,
|
||||
model_name: str = None,
|
||||
model_type: ChatModel.ModelType = None,
|
||||
max_prompt_size: int = None,
|
||||
tokenizer_name: str = None,
|
||||
vision_available: bool = False,
|
||||
) -> List[ChatMessage]:
|
||||
"""
|
||||
Construct system, context and chatml messages for chat response.
|
||||
Share common logic across different model types.
|
||||
|
||||
Returns:
|
||||
List of ChatMessages with context
|
||||
"""
|
||||
# Initialize Variables
|
||||
current_date = datetime.now()
|
||||
|
||||
# Build system prompt
|
||||
if agent and agent.personality:
|
||||
system_prompt = prompts.custom_personality.format(
|
||||
name=agent.name,
|
||||
bio=agent.personality,
|
||||
current_date=current_date.strftime("%Y-%m-%d"),
|
||||
day_of_week=current_date.strftime("%A"),
|
||||
)
|
||||
else:
|
||||
system_prompt = prompts.personality.format(
|
||||
current_date=current_date.strftime("%Y-%m-%d"),
|
||||
day_of_week=current_date.strftime("%A"),
|
||||
)
|
||||
|
||||
# Add Gemini-specific personality enhancement
|
||||
if model_type == ChatModel.ModelType.GOOGLE:
|
||||
system_prompt += f"\n\n{prompts.gemini_verbose_language_personality}"
|
||||
|
||||
# Add location context if available
|
||||
if location_data:
|
||||
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
||||
system_prompt += f"\n{location_prompt}"
|
||||
|
||||
# Add user name context if available
|
||||
if user_name:
|
||||
user_name_prompt = prompts.user_name.format(name=user_name)
|
||||
system_prompt += f"\n{user_name_prompt}"
|
||||
|
||||
# Build context message
|
||||
context_message = ""
|
||||
if not is_none_or_empty(references):
|
||||
context_message = f"{prompts.notes_conversation.format(references=yaml_dump(references))}\n\n"
|
||||
if not is_none_or_empty(online_results):
|
||||
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
|
||||
if not is_none_or_empty(code_results):
|
||||
context_message += (
|
||||
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
|
||||
)
|
||||
if not is_none_or_empty(operator_results):
|
||||
operator_content = [
|
||||
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
|
||||
]
|
||||
context_message += (
|
||||
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
|
||||
)
|
||||
context_message = context_message.strip()
|
||||
|
||||
# Generate the chatml messages
|
||||
messages = generate_chatml_messages_with_context(
|
||||
user_message=user_query,
|
||||
query_files=query_files,
|
||||
query_images=query_images,
|
||||
context_message=context_message,
|
||||
generated_asset_results=generated_asset_results,
|
||||
program_execution_context=program_execution_context,
|
||||
chat_history=chat_history,
|
||||
system_message=system_prompt,
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tokenizer_name=tokenizer_name,
|
||||
vision_enabled=vision_available,
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
async def agenerate_chat_response(
|
||||
q: str,
|
||||
chat_history: List[ChatMessageModel],
|
||||
@@ -1645,33 +1747,39 @@ async def agenerate_chat_response(
|
||||
chat_model = vision_enabled_config
|
||||
vision_available = True
|
||||
|
||||
# Build shared conversation context and generate chatml messages
|
||||
messages = build_conversation_context(
|
||||
user_query=query_to_run,
|
||||
references=compiled_references,
|
||||
online_results=online_results,
|
||||
code_results=code_results,
|
||||
operator_results=operator_results,
|
||||
query_files=query_files,
|
||||
query_images=query_images,
|
||||
generated_asset_results=generated_asset_results,
|
||||
program_execution_context=program_execution_context,
|
||||
chat_history=chat_history,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
agent=agent,
|
||||
model_type=chat_model.model_type,
|
||||
model_name=chat_model.name,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tokenizer_name=chat_model.tokenizer,
|
||||
vision_available=vision_available,
|
||||
)
|
||||
|
||||
if chat_model.model_type == ChatModel.ModelType.OPENAI:
|
||||
openai_chat_config = chat_model.ai_model_api
|
||||
api_key = openai_chat_config.api_key
|
||||
chat_model_name = chat_model.name
|
||||
chat_response_generator = converse_openai(
|
||||
# Query
|
||||
query_to_run,
|
||||
# Context
|
||||
references=compiled_references,
|
||||
online_results=online_results,
|
||||
code_results=code_results,
|
||||
operator_results=operator_results,
|
||||
query_images=query_images,
|
||||
query_files=query_files,
|
||||
generated_asset_results=generated_asset_results,
|
||||
program_execution_context=program_execution_context,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
chat_history=chat_history,
|
||||
# Query + Context Messages
|
||||
messages,
|
||||
# Model
|
||||
model=chat_model_name,
|
||||
api_key=api_key,
|
||||
api_base_url=openai_chat_config.api_base_url,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tokenizer_name=chat_model.tokenizer,
|
||||
agent=agent,
|
||||
vision_available=vision_available,
|
||||
deepthought=deepthought,
|
||||
tracer=tracer,
|
||||
)
|
||||
@@ -1680,28 +1788,12 @@ async def agenerate_chat_response(
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
chat_response_generator = converse_anthropic(
|
||||
# Query
|
||||
query_to_run,
|
||||
# Context
|
||||
references=compiled_references,
|
||||
online_results=online_results,
|
||||
code_results=code_results,
|
||||
operator_results=operator_results,
|
||||
query_images=query_images,
|
||||
query_files=query_files,
|
||||
generated_asset_results=generated_asset_results,
|
||||
program_execution_context=program_execution_context,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
chat_history=chat_history,
|
||||
# Query + Context Messages
|
||||
messages,
|
||||
# Model
|
||||
model=chat_model.name,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tokenizer_name=chat_model.tokenizer,
|
||||
agent=agent,
|
||||
vision_available=vision_available,
|
||||
deepthought=deepthought,
|
||||
tracer=tracer,
|
||||
)
|
||||
@@ -1709,28 +1801,12 @@ async def agenerate_chat_response(
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
chat_response_generator = converse_gemini(
|
||||
# Query
|
||||
query_to_run,
|
||||
# Context
|
||||
references=compiled_references,
|
||||
online_results=online_results,
|
||||
code_results=code_results,
|
||||
operator_results=operator_results,
|
||||
query_images=query_images,
|
||||
query_files=query_files,
|
||||
generated_asset_results=generated_asset_results,
|
||||
program_execution_context=program_execution_context,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
chat_history=chat_history,
|
||||
# Query + Context Messages
|
||||
messages,
|
||||
# Model
|
||||
model=chat_model.name,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tokenizer_name=chat_model.tokenizer,
|
||||
agent=agent,
|
||||
vision_available=vision_available,
|
||||
deepthought=deepthought,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user