mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Retry on empty response or error in chat completion by LLM over API
Previously, all exceptions were being caught, so the retry logic was never triggered. The exception catching had been added to close the LLM thread back when threads, rather than async, were used for final response generation. This is no longer required since the move to async, so we can now re-enable retry on failures. Also raise an error if the response is empty, so that the LLM completion is retried.
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from time import perf_counter
|
||||
from typing import Dict, List
|
||||
from typing import AsyncGenerator, Dict, List
|
||||
|
||||
import anthropic
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
@@ -100,6 +100,11 @@ def anthropic_completion_with_backoff(
|
||||
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
|
||||
)
|
||||
|
||||
# Validate the response. If empty, raise an error to retry.
|
||||
if is_none_or_empty(aggregated_response):
|
||||
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
|
||||
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
@@ -112,8 +117,8 @@ def anthropic_completion_with_backoff(
|
||||
@retry(
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
stop=stop_after_attempt(2),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
reraise=False,
|
||||
)
|
||||
async def anthropic_chat_completion_with_backoff(
|
||||
messages: list[ChatMessage],
|
||||
@@ -126,75 +131,77 @@ async def anthropic_chat_completion_with_backoff(
|
||||
deepthought=False,
|
||||
model_kwargs=None,
|
||||
tracer={},
|
||||
):
|
||||
try:
|
||||
client = anthropic_async_clients.get(api_key)
|
||||
if not client:
|
||||
client = get_anthropic_async_client(api_key, api_base_url)
|
||||
anthropic_async_clients[api_key] = client
|
||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||
client = anthropic_async_clients.get(api_key)
|
||||
if not client:
|
||||
client = get_anthropic_async_client(api_key, api_base_url)
|
||||
anthropic_async_clients[api_key] = client
|
||||
|
||||
model_kwargs = model_kwargs or dict()
|
||||
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
|
||||
if deepthought and model_name.startswith("claude-3-7"):
|
||||
model_kwargs["thinking"] = {"type": "enabled", "budget_tokens": MAX_REASONING_TOKENS_ANTHROPIC}
|
||||
max_tokens += MAX_REASONING_TOKENS_ANTHROPIC
|
||||
# Temperature control not supported when using extended thinking
|
||||
temperature = 1.0
|
||||
model_kwargs = model_kwargs or dict()
|
||||
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
|
||||
if deepthought and model_name.startswith("claude-3-7"):
|
||||
model_kwargs["thinking"] = {"type": "enabled", "budget_tokens": MAX_REASONING_TOKENS_ANTHROPIC}
|
||||
max_tokens += MAX_REASONING_TOKENS_ANTHROPIC
|
||||
# Temperature control not supported when using extended thinking
|
||||
temperature = 1.0
|
||||
|
||||
formatted_messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
|
||||
formatted_messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
|
||||
|
||||
aggregated_response = ""
|
||||
response_started = False
|
||||
final_message = None
|
||||
start_time = perf_counter()
|
||||
async with client.messages.stream(
|
||||
messages=formatted_messages,
|
||||
model=model_name, # type: ignore
|
||||
temperature=temperature,
|
||||
system=system_prompt,
|
||||
timeout=20,
|
||||
max_tokens=max_tokens,
|
||||
**model_kwargs,
|
||||
) as stream:
|
||||
async for chunk in stream:
|
||||
# Log the time taken to start response
|
||||
if not response_started:
|
||||
response_started = True
|
||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||
# Skip empty chunks
|
||||
if chunk.type != "content_block_delta":
|
||||
continue
|
||||
# Handle streamed response chunk
|
||||
response_chunk: ResponseWithThought = None
|
||||
if chunk.delta.type == "text_delta":
|
||||
response_chunk = ResponseWithThought(response=chunk.delta.text)
|
||||
aggregated_response += chunk.delta.text
|
||||
if chunk.delta.type == "thinking_delta":
|
||||
response_chunk = ResponseWithThought(thought=chunk.delta.thinking)
|
||||
# Handle streamed response chunk
|
||||
if response_chunk:
|
||||
yield response_chunk
|
||||
final_message = await stream.get_final_message()
|
||||
aggregated_response = ""
|
||||
response_started = False
|
||||
final_message = None
|
||||
start_time = perf_counter()
|
||||
async with client.messages.stream(
|
||||
messages=formatted_messages,
|
||||
model=model_name, # type: ignore
|
||||
temperature=temperature,
|
||||
system=system_prompt,
|
||||
timeout=20,
|
||||
max_tokens=max_tokens,
|
||||
**model_kwargs,
|
||||
) as stream:
|
||||
async for chunk in stream:
|
||||
# Log the time taken to start response
|
||||
if not response_started:
|
||||
response_started = True
|
||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||
# Skip empty chunks
|
||||
if chunk.type != "content_block_delta":
|
||||
continue
|
||||
# Handle streamed response chunk
|
||||
response_chunk: ResponseWithThought = None
|
||||
if chunk.delta.type == "text_delta":
|
||||
response_chunk = ResponseWithThought(response=chunk.delta.text)
|
||||
aggregated_response += chunk.delta.text
|
||||
if chunk.delta.type == "thinking_delta":
|
||||
response_chunk = ResponseWithThought(thought=chunk.delta.thinking)
|
||||
# Handle streamed response chunk
|
||||
if response_chunk:
|
||||
yield response_chunk
|
||||
final_message = await stream.get_final_message()
|
||||
|
||||
# Log the time taken to stream the entire response
|
||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||
# Calculate cost of chat
|
||||
input_tokens = final_message.usage.input_tokens
|
||||
output_tokens = final_message.usage.output_tokens
|
||||
cache_read_tokens = final_message.usage.cache_read_input_tokens
|
||||
cache_write_tokens = final_message.usage.cache_creation_input_tokens
|
||||
tracer["usage"] = get_chat_usage_metrics(
|
||||
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
|
||||
)
|
||||
|
||||
# Calculate cost of chat
|
||||
input_tokens = final_message.usage.input_tokens
|
||||
output_tokens = final_message.usage.output_tokens
|
||||
cache_read_tokens = final_message.usage.cache_read_input_tokens
|
||||
cache_write_tokens = final_message.usage.cache_creation_input_tokens
|
||||
tracer["usage"] = get_chat_usage_metrics(
|
||||
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
|
||||
)
|
||||
# Validate the response. If empty, raise an error to retry.
|
||||
if is_none_or_empty(aggregated_response):
|
||||
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
|
||||
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
if is_promptrace_enabled():
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in anthropic_chat_completion_with_backoff stream: {e}", exc_info=True)
|
||||
# Log the time taken to stream the entire response
|
||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
if is_promptrace_enabled():
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
|
||||
|
||||
def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None):
|
||||
|
||||
@@ -73,6 +73,9 @@ def _is_retryable_error(exception: BaseException) -> bool:
|
||||
# client errors
|
||||
if isinstance(exception, httpx.TimeoutException) or isinstance(exception, httpx.NetworkError):
|
||||
return True
|
||||
# validation errors
|
||||
if isinstance(exception, ValueError):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@@ -84,8 +87,8 @@ def _is_retryable_error(exception: BaseException) -> bool:
|
||||
reraise=True,
|
||||
)
|
||||
def gemini_completion_with_backoff(
|
||||
messages,
|
||||
system_prompt,
|
||||
messages: list[ChatMessage],
|
||||
system_prompt: str,
|
||||
model_name: str,
|
||||
temperature=1.0,
|
||||
api_key=None,
|
||||
@@ -144,6 +147,11 @@ def gemini_completion_with_backoff(
|
||||
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
|
||||
)
|
||||
|
||||
# Validate the response. If empty, raise an error to retry.
|
||||
if is_none_or_empty(response_text):
|
||||
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
|
||||
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
@@ -157,89 +165,90 @@ def gemini_completion_with_backoff(
|
||||
retry=retry_if_exception(_is_retryable_error),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
stop=stop_after_attempt(3),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
reraise=False,
|
||||
)
|
||||
async def gemini_chat_completion_with_backoff(
|
||||
messages,
|
||||
model_name,
|
||||
temperature,
|
||||
api_key,
|
||||
api_base_url,
|
||||
system_prompt,
|
||||
messages: list[ChatMessage],
|
||||
model_name: str,
|
||||
temperature: float,
|
||||
api_key: str,
|
||||
api_base_url: str,
|
||||
system_prompt: str,
|
||||
model_kwargs=None,
|
||||
deepthought=False,
|
||||
tracer: dict = {},
|
||||
) -> AsyncGenerator[str, None]:
|
||||
try:
|
||||
client = gemini_clients.get(api_key)
|
||||
if not client:
|
||||
client = get_gemini_client(api_key, api_base_url)
|
||||
gemini_clients[api_key] = client
|
||||
client = gemini_clients.get(api_key)
|
||||
if not client:
|
||||
client = get_gemini_client(api_key, api_base_url)
|
||||
gemini_clients[api_key] = client
|
||||
|
||||
formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
||||
formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
||||
|
||||
thinking_config = None
|
||||
if deepthought and model_name.startswith("gemini-2-5"):
|
||||
thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
|
||||
thinking_config = None
|
||||
if deepthought and model_name.startswith("gemini-2-5"):
|
||||
thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
|
||||
|
||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||
config = gtypes.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
temperature=temperature,
|
||||
thinking_config=thinking_config,
|
||||
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
|
||||
stop_sequences=["Notes:\n["],
|
||||
safety_settings=SAFETY_SETTINGS,
|
||||
seed=seed,
|
||||
http_options=gtypes.HttpOptions(async_client_args={"timeout": httpx.Timeout(30.0, read=60.0)}),
|
||||
)
|
||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||
config = gtypes.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
temperature=temperature,
|
||||
thinking_config=thinking_config,
|
||||
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
|
||||
stop_sequences=["Notes:\n["],
|
||||
safety_settings=SAFETY_SETTINGS,
|
||||
seed=seed,
|
||||
http_options=gtypes.HttpOptions(async_client_args={"timeout": httpx.Timeout(30.0, read=60.0)}),
|
||||
)
|
||||
|
||||
aggregated_response = ""
|
||||
final_chunk = None
|
||||
response_started = False
|
||||
start_time = perf_counter()
|
||||
chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream(
|
||||
model=model_name, config=config, contents=formatted_messages
|
||||
)
|
||||
async for chunk in chat_stream:
|
||||
# Log the time taken to start response
|
||||
if not response_started:
|
||||
response_started = True
|
||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||
# Keep track of the last chunk for usage data
|
||||
final_chunk = chunk
|
||||
# Handle streamed response chunk
|
||||
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
||||
message = message or chunk.text
|
||||
aggregated_response += message
|
||||
yield message
|
||||
if stopped:
|
||||
raise ValueError(message)
|
||||
aggregated_response = ""
|
||||
final_chunk = None
|
||||
response_started = False
|
||||
start_time = perf_counter()
|
||||
chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream(
|
||||
model=model_name, config=config, contents=formatted_messages
|
||||
)
|
||||
async for chunk in chat_stream:
|
||||
# Log the time taken to start response
|
||||
if not response_started:
|
||||
response_started = True
|
||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||
# Keep track of the last chunk for usage data
|
||||
final_chunk = chunk
|
||||
# Handle streamed response chunk
|
||||
stop_message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
||||
message = stop_message or chunk.text
|
||||
aggregated_response += message
|
||||
yield message
|
||||
if stopped:
|
||||
logger.warning(
|
||||
f"LLM Response Prevented for {model_name}: {stop_message}.\n"
|
||||
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
||||
)
|
||||
break
|
||||
|
||||
# Log the time taken to stream the entire response
|
||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||
# Calculate cost of chat
|
||||
input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
|
||||
output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
|
||||
thought_tokens = final_chunk.usage_metadata.thoughts_token_count or 0 if final_chunk else 0
|
||||
tracer["usage"] = get_chat_usage_metrics(
|
||||
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
|
||||
)
|
||||
|
||||
# Calculate cost of chat
|
||||
input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
|
||||
output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
|
||||
thought_tokens = final_chunk.usage_metadata.thoughts_token_count or 0 if final_chunk else 0
|
||||
tracer["usage"] = get_chat_usage_metrics(
|
||||
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
|
||||
)
|
||||
# Validate the response. If empty, raise an error to retry.
|
||||
if is_none_or_empty(aggregated_response):
|
||||
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
|
||||
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
if is_promptrace_enabled():
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"
|
||||
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in gemini_chat_completion_with_backoff stream: {e}", exc_info=True)
|
||||
# Log the time taken to stream the entire response
|
||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
if is_promptrace_enabled():
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
|
||||
|
||||
def handle_gemini_response(
|
||||
|
||||
@@ -38,6 +38,7 @@ from khoj.utils.helpers import (
|
||||
get_chat_usage_metrics,
|
||||
get_openai_async_client,
|
||||
get_openai_client,
|
||||
is_none_or_empty,
|
||||
is_promptrace_enabled,
|
||||
)
|
||||
|
||||
@@ -54,6 +55,7 @@ openai_async_clients: Dict[str, openai.AsyncOpenAI] = {}
|
||||
| retry_if_exception_type(openai._exceptions.APIConnectionError)
|
||||
| retry_if_exception_type(openai._exceptions.RateLimitError)
|
||||
| retry_if_exception_type(openai._exceptions.APIStatusError)
|
||||
| retry_if_exception_type(ValueError)
|
||||
),
|
||||
wait=wait_random_exponential(min=1, max=10),
|
||||
stop=stop_after_attempt(3),
|
||||
@@ -136,6 +138,11 @@ def completion_with_backoff(
|
||||
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
|
||||
)
|
||||
|
||||
# Validate the response. If empty, raise an error to retry.
|
||||
if is_none_or_empty(aggregated_response):
|
||||
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
|
||||
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
@@ -152,14 +159,15 @@ def completion_with_backoff(
|
||||
| retry_if_exception_type(openai._exceptions.APIConnectionError)
|
||||
| retry_if_exception_type(openai._exceptions.RateLimitError)
|
||||
| retry_if_exception_type(openai._exceptions.APIStatusError)
|
||||
| retry_if_exception_type(ValueError)
|
||||
),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
stop=stop_after_attempt(3),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
reraise=False,
|
||||
)
|
||||
async def chat_completion_with_backoff(
|
||||
messages,
|
||||
messages: list[ChatMessage],
|
||||
model_name: str,
|
||||
temperature,
|
||||
openai_api_key=None,
|
||||
@@ -168,120 +176,122 @@ async def chat_completion_with_backoff(
|
||||
model_kwargs: dict = {},
|
||||
tracer: dict = {},
|
||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||
try:
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
client = openai_async_clients.get(client_key)
|
||||
if not client:
|
||||
client = get_openai_async_client(openai_api_key, api_base_url)
|
||||
openai_async_clients[client_key] = client
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
client = openai_async_clients.get(client_key)
|
||||
if not client:
|
||||
client = get_openai_async_client(openai_api_key, api_base_url)
|
||||
openai_async_clients[client_key] = client
|
||||
|
||||
stream_processor = adefault_stream_processor
|
||||
formatted_messages = format_message_for_api(messages, api_base_url)
|
||||
stream_processor = adefault_stream_processor
|
||||
formatted_messages = format_message_for_api(messages, api_base_url)
|
||||
|
||||
# Configure thinking for openai reasoning models
|
||||
if is_openai_reasoning_model(model_name, api_base_url):
|
||||
temperature = 1
|
||||
reasoning_effort = "medium" if deepthought else "low"
|
||||
model_kwargs["reasoning_effort"] = reasoning_effort
|
||||
model_kwargs.pop("stop", None) # Remove unsupported stop param for reasoning models
|
||||
# Configure thinking for openai reasoning models
|
||||
if is_openai_reasoning_model(model_name, api_base_url):
|
||||
temperature = 1
|
||||
reasoning_effort = "medium" if deepthought else "low"
|
||||
model_kwargs["reasoning_effort"] = reasoning_effort
|
||||
model_kwargs.pop("stop", None) # Remove unsupported stop param for reasoning models
|
||||
|
||||
# Get the first system message and add the string `Formatting re-enabled` to it.
|
||||
# See https://platform.openai.com/docs/guides/reasoning-best-practices
|
||||
if len(formatted_messages) > 0:
|
||||
system_messages = [
|
||||
(i, message) for i, message in enumerate(formatted_messages) if message["role"] == "system"
|
||||
]
|
||||
if len(system_messages) > 0:
|
||||
first_system_message_index, first_system_message = system_messages[0]
|
||||
first_system_message_content = first_system_message["content"]
|
||||
formatted_messages[first_system_message_index][
|
||||
"content"
|
||||
] = f"{first_system_message_content}\nFormatting re-enabled"
|
||||
elif is_twitter_reasoning_model(model_name, api_base_url):
|
||||
stream_processor = adeepseek_stream_processor
|
||||
reasoning_effort = "high" if deepthought else "low"
|
||||
model_kwargs["reasoning_effort"] = reasoning_effort
|
||||
elif model_name.startswith("deepseek-reasoner"):
|
||||
stream_processor = adeepseek_stream_processor
|
||||
# Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
|
||||
# The first message should always be a user message (except system message).
|
||||
updated_messages: List[dict] = []
|
||||
for i, message in enumerate(formatted_messages):
|
||||
if i > 0 and message["role"] == formatted_messages[i - 1]["role"]:
|
||||
updated_messages[-1]["content"] += " " + message["content"]
|
||||
elif i == 1 and formatted_messages[i - 1]["role"] == "system" and message["role"] == "assistant":
|
||||
updated_messages[-1]["content"] += " " + message["content"]
|
||||
else:
|
||||
updated_messages.append(message)
|
||||
formatted_messages = updated_messages
|
||||
elif is_qwen_reasoning_model(model_name, api_base_url):
|
||||
stream_processor = partial(ain_stream_thought_processor, thought_tag="think")
|
||||
# Reasoning is enabled by default. Disable when deepthought is False.
|
||||
# See https://qwenlm.github.io/blog/qwen3/#advanced-usages
|
||||
if not deepthought and len(formatted_messages) > 0:
|
||||
formatted_messages[-1]["content"] = formatted_messages[-1]["content"] + " /no_think"
|
||||
# Get the first system message and add the string `Formatting re-enabled` to it.
|
||||
# See https://platform.openai.com/docs/guides/reasoning-best-practices
|
||||
if len(formatted_messages) > 0:
|
||||
system_messages = [
|
||||
(i, message) for i, message in enumerate(formatted_messages) if message["role"] == "system"
|
||||
]
|
||||
if len(system_messages) > 0:
|
||||
first_system_message_index, first_system_message = system_messages[0]
|
||||
first_system_message_content = first_system_message["content"]
|
||||
formatted_messages[first_system_message_index][
|
||||
"content"
|
||||
] = f"{first_system_message_content}\nFormatting re-enabled"
|
||||
elif is_twitter_reasoning_model(model_name, api_base_url):
|
||||
stream_processor = adeepseek_stream_processor
|
||||
reasoning_effort = "high" if deepthought else "low"
|
||||
model_kwargs["reasoning_effort"] = reasoning_effort
|
||||
elif model_name.startswith("deepseek-reasoner"):
|
||||
stream_processor = adeepseek_stream_processor
|
||||
# Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
|
||||
# The first message should always be a user message (except system message).
|
||||
updated_messages: List[dict] = []
|
||||
for i, message in enumerate(formatted_messages):
|
||||
if i > 0 and message["role"] == formatted_messages[i - 1]["role"]:
|
||||
updated_messages[-1]["content"] += " " + message["content"]
|
||||
elif i == 1 and formatted_messages[i - 1]["role"] == "system" and message["role"] == "assistant":
|
||||
updated_messages[-1]["content"] += " " + message["content"]
|
||||
else:
|
||||
updated_messages.append(message)
|
||||
formatted_messages = updated_messages
|
||||
elif is_qwen_reasoning_model(model_name, api_base_url):
|
||||
stream_processor = partial(ain_stream_thought_processor, thought_tag="think")
|
||||
# Reasoning is enabled by default. Disable when deepthought is False.
|
||||
# See https://qwenlm.github.io/blog/qwen3/#advanced-usages
|
||||
if not deepthought and len(formatted_messages) > 0:
|
||||
formatted_messages[-1]["content"] = formatted_messages[-1]["content"] + " /no_think"
|
||||
|
||||
stream = True
|
||||
read_timeout = 300 if is_local_api(api_base_url) else 60
|
||||
model_kwargs["stream_options"] = {"include_usage": True}
|
||||
if os.getenv("KHOJ_LLM_SEED"):
|
||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||
stream = True
|
||||
read_timeout = 300 if is_local_api(api_base_url) else 60
|
||||
model_kwargs["stream_options"] = {"include_usage": True}
|
||||
if os.getenv("KHOJ_LLM_SEED"):
|
||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||
|
||||
aggregated_response = ""
|
||||
final_chunk = None
|
||||
response_started = False
|
||||
start_time = perf_counter()
|
||||
chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
|
||||
messages=formatted_messages, # type: ignore
|
||||
model=model_name,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
timeout=httpx.Timeout(30, read=read_timeout),
|
||||
**model_kwargs,
|
||||
)
|
||||
async for chunk in stream_processor(chat_stream):
|
||||
# Log the time taken to start response
|
||||
if not response_started:
|
||||
response_started = True
|
||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||
# Keep track of the last chunk for usage data
|
||||
final_chunk = chunk
|
||||
# Skip empty chunks
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
# Handle streamed response chunk
|
||||
response_chunk: ResponseWithThought = None
|
||||
response_delta = chunk.choices[0].delta
|
||||
if response_delta.content:
|
||||
response_chunk = ResponseWithThought(response=response_delta.content)
|
||||
aggregated_response += response_chunk.response
|
||||
elif response_delta.thought:
|
||||
response_chunk = ResponseWithThought(thought=response_delta.thought)
|
||||
if response_chunk:
|
||||
yield response_chunk
|
||||
aggregated_response = ""
|
||||
final_chunk = None
|
||||
response_started = False
|
||||
start_time = perf_counter()
|
||||
chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
|
||||
messages=formatted_messages, # type: ignore
|
||||
model=model_name,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
timeout=httpx.Timeout(30, read=read_timeout),
|
||||
**model_kwargs,
|
||||
)
|
||||
async for chunk in stream_processor(chat_stream):
|
||||
# Log the time taken to start response
|
||||
if not response_started:
|
||||
response_started = True
|
||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||
# Keep track of the last chunk for usage data
|
||||
final_chunk = chunk
|
||||
# Skip empty chunks
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
# Handle streamed response chunk
|
||||
response_chunk: ResponseWithThought = None
|
||||
response_delta = chunk.choices[0].delta
|
||||
if response_delta.content:
|
||||
response_chunk = ResponseWithThought(response=response_delta.content)
|
||||
aggregated_response += response_chunk.response
|
||||
elif response_delta.thought:
|
||||
response_chunk = ResponseWithThought(thought=response_delta.thought)
|
||||
if response_chunk:
|
||||
yield response_chunk
|
||||
|
||||
# Log the time taken to stream the entire response
|
||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||
# Calculate cost of chat after stream finishes
|
||||
input_tokens, output_tokens, cost = 0, 0, 0
|
||||
if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage:
|
||||
input_tokens = final_chunk.usage.prompt_tokens
|
||||
output_tokens = final_chunk.usage.completion_tokens
|
||||
# Estimated costs returned by DeepInfra API
|
||||
if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra:
|
||||
cost = final_chunk.usage.model_extra.get("estimated_cost", 0)
|
||||
tracer["usage"] = get_chat_usage_metrics(
|
||||
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
|
||||
)
|
||||
|
||||
# Calculate cost of chat after stream finishes
|
||||
input_tokens, output_tokens, cost = 0, 0, 0
|
||||
if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage:
|
||||
input_tokens = final_chunk.usage.prompt_tokens
|
||||
output_tokens = final_chunk.usage.completion_tokens
|
||||
# Estimated costs returned by DeepInfra API
|
||||
if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra:
|
||||
cost = final_chunk.usage.model_extra.get("estimated_cost", 0)
|
||||
# Validate the response. If empty, raise an error to retry.
|
||||
if is_none_or_empty(aggregated_response):
|
||||
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
|
||||
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
tracer["usage"] = get_chat_usage_metrics(
|
||||
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
|
||||
)
|
||||
if is_promptrace_enabled():
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in chat_completion_with_backoff stream: {e}", exc_info=True)
|
||||
# Log the time taken to stream the entire response
|
||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
tracer["temperature"] = temperature
|
||||
if is_promptrace_enabled():
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
|
||||
|
||||
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
|
||||
|
||||
Reference in New Issue
Block a user