mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-10 13:26:13 +00:00
Retry on empty response or error in chat completion by LLM over API
Previously, all exceptions were being caught, so the retry logic wasn't getting triggered. The exception catching had been added to close the LLM thread back when threads, instead of async, were being used for final response generation. This isn't required anymore since moving to async, so we can now re-enable retry on failures. Also raise an error if the response is empty, so that the LLM completion is retried.
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from time import perf_counter
|
from time import perf_counter
|
||||||
from typing import Dict, List
|
from typing import AsyncGenerator, Dict, List
|
||||||
|
|
||||||
import anthropic
|
import anthropic
|
||||||
from langchain_core.messages.chat import ChatMessage
|
from langchain_core.messages.chat import ChatMessage
|
||||||
@@ -100,6 +100,11 @@ def anthropic_completion_with_backoff(
|
|||||||
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
|
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Validate the response. If empty, raise an error to retry.
|
||||||
|
if is_none_or_empty(aggregated_response):
|
||||||
|
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
|
||||||
|
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
tracer["chat_model"] = model_name
|
tracer["chat_model"] = model_name
|
||||||
tracer["temperature"] = temperature
|
tracer["temperature"] = temperature
|
||||||
@@ -112,8 +117,8 @@ def anthropic_completion_with_backoff(
|
|||||||
@retry(
|
@retry(
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
stop=stop_after_attempt(2),
|
stop=stop_after_attempt(2),
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||||
reraise=True,
|
reraise=False,
|
||||||
)
|
)
|
||||||
async def anthropic_chat_completion_with_backoff(
|
async def anthropic_chat_completion_with_backoff(
|
||||||
messages: list[ChatMessage],
|
messages: list[ChatMessage],
|
||||||
@@ -126,75 +131,77 @@ async def anthropic_chat_completion_with_backoff(
|
|||||||
deepthought=False,
|
deepthought=False,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
tracer={},
|
tracer={},
|
||||||
):
|
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||||
try:
|
client = anthropic_async_clients.get(api_key)
|
||||||
client = anthropic_async_clients.get(api_key)
|
if not client:
|
||||||
if not client:
|
client = get_anthropic_async_client(api_key, api_base_url)
|
||||||
client = get_anthropic_async_client(api_key, api_base_url)
|
anthropic_async_clients[api_key] = client
|
||||||
anthropic_async_clients[api_key] = client
|
|
||||||
|
|
||||||
model_kwargs = model_kwargs or dict()
|
model_kwargs = model_kwargs or dict()
|
||||||
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
|
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
|
||||||
if deepthought and model_name.startswith("claude-3-7"):
|
if deepthought and model_name.startswith("claude-3-7"):
|
||||||
model_kwargs["thinking"] = {"type": "enabled", "budget_tokens": MAX_REASONING_TOKENS_ANTHROPIC}
|
model_kwargs["thinking"] = {"type": "enabled", "budget_tokens": MAX_REASONING_TOKENS_ANTHROPIC}
|
||||||
max_tokens += MAX_REASONING_TOKENS_ANTHROPIC
|
max_tokens += MAX_REASONING_TOKENS_ANTHROPIC
|
||||||
# Temperature control not supported when using extended thinking
|
# Temperature control not supported when using extended thinking
|
||||||
temperature = 1.0
|
temperature = 1.0
|
||||||
|
|
||||||
formatted_messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
|
formatted_messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
|
||||||
|
|
||||||
aggregated_response = ""
|
aggregated_response = ""
|
||||||
response_started = False
|
response_started = False
|
||||||
final_message = None
|
final_message = None
|
||||||
start_time = perf_counter()
|
start_time = perf_counter()
|
||||||
async with client.messages.stream(
|
async with client.messages.stream(
|
||||||
messages=formatted_messages,
|
messages=formatted_messages,
|
||||||
model=model_name, # type: ignore
|
model=model_name, # type: ignore
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
system=system_prompt,
|
system=system_prompt,
|
||||||
timeout=20,
|
timeout=20,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) as stream:
|
) as stream:
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
# Log the time taken to start response
|
# Log the time taken to start response
|
||||||
if not response_started:
|
if not response_started:
|
||||||
response_started = True
|
response_started = True
|
||||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||||
# Skip empty chunks
|
# Skip empty chunks
|
||||||
if chunk.type != "content_block_delta":
|
if chunk.type != "content_block_delta":
|
||||||
continue
|
continue
|
||||||
# Handle streamed response chunk
|
# Handle streamed response chunk
|
||||||
response_chunk: ResponseWithThought = None
|
response_chunk: ResponseWithThought = None
|
||||||
if chunk.delta.type == "text_delta":
|
if chunk.delta.type == "text_delta":
|
||||||
response_chunk = ResponseWithThought(response=chunk.delta.text)
|
response_chunk = ResponseWithThought(response=chunk.delta.text)
|
||||||
aggregated_response += chunk.delta.text
|
aggregated_response += chunk.delta.text
|
||||||
if chunk.delta.type == "thinking_delta":
|
if chunk.delta.type == "thinking_delta":
|
||||||
response_chunk = ResponseWithThought(thought=chunk.delta.thinking)
|
response_chunk = ResponseWithThought(thought=chunk.delta.thinking)
|
||||||
# Handle streamed response chunk
|
# Handle streamed response chunk
|
||||||
if response_chunk:
|
if response_chunk:
|
||||||
yield response_chunk
|
yield response_chunk
|
||||||
final_message = await stream.get_final_message()
|
final_message = await stream.get_final_message()
|
||||||
|
|
||||||
# Log the time taken to stream the entire response
|
# Calculate cost of chat
|
||||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
input_tokens = final_message.usage.input_tokens
|
||||||
|
output_tokens = final_message.usage.output_tokens
|
||||||
|
cache_read_tokens = final_message.usage.cache_read_input_tokens
|
||||||
|
cache_write_tokens = final_message.usage.cache_creation_input_tokens
|
||||||
|
tracer["usage"] = get_chat_usage_metrics(
|
||||||
|
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
|
||||||
|
)
|
||||||
|
|
||||||
# Calculate cost of chat
|
# Validate the response. If empty, raise an error to retry.
|
||||||
input_tokens = final_message.usage.input_tokens
|
if is_none_or_empty(aggregated_response):
|
||||||
output_tokens = final_message.usage.output_tokens
|
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
|
||||||
cache_read_tokens = final_message.usage.cache_read_input_tokens
|
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
|
||||||
cache_write_tokens = final_message.usage.cache_creation_input_tokens
|
|
||||||
tracer["usage"] = get_chat_usage_metrics(
|
|
||||||
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save conversation trace
|
# Log the time taken to stream the entire response
|
||||||
tracer["chat_model"] = model_name
|
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||||
tracer["temperature"] = temperature
|
|
||||||
if is_promptrace_enabled():
|
# Save conversation trace
|
||||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
tracer["chat_model"] = model_name
|
||||||
except Exception as e:
|
tracer["temperature"] = temperature
|
||||||
logger.error(f"Error in anthropic_chat_completion_with_backoff stream: {e}", exc_info=True)
|
if is_promptrace_enabled():
|
||||||
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
|
|
||||||
|
|
||||||
def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None):
|
def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None):
|
||||||
|
|||||||
@@ -73,6 +73,9 @@ def _is_retryable_error(exception: BaseException) -> bool:
|
|||||||
# client errors
|
# client errors
|
||||||
if isinstance(exception, httpx.TimeoutException) or isinstance(exception, httpx.NetworkError):
|
if isinstance(exception, httpx.TimeoutException) or isinstance(exception, httpx.NetworkError):
|
||||||
return True
|
return True
|
||||||
|
# validation errors
|
||||||
|
if isinstance(exception, ValueError):
|
||||||
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@@ -84,8 +87,8 @@ def _is_retryable_error(exception: BaseException) -> bool:
|
|||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def gemini_completion_with_backoff(
|
def gemini_completion_with_backoff(
|
||||||
messages,
|
messages: list[ChatMessage],
|
||||||
system_prompt,
|
system_prompt: str,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
temperature=1.0,
|
temperature=1.0,
|
||||||
api_key=None,
|
api_key=None,
|
||||||
@@ -144,6 +147,11 @@ def gemini_completion_with_backoff(
|
|||||||
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
|
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Validate the response. If empty, raise an error to retry.
|
||||||
|
if is_none_or_empty(response_text):
|
||||||
|
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
|
||||||
|
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
tracer["chat_model"] = model_name
|
tracer["chat_model"] = model_name
|
||||||
tracer["temperature"] = temperature
|
tracer["temperature"] = temperature
|
||||||
@@ -157,89 +165,90 @@ def gemini_completion_with_backoff(
|
|||||||
retry=retry_if_exception(_is_retryable_error),
|
retry=retry_if_exception(_is_retryable_error),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||||
reraise=True,
|
reraise=False,
|
||||||
)
|
)
|
||||||
async def gemini_chat_completion_with_backoff(
|
async def gemini_chat_completion_with_backoff(
|
||||||
messages,
|
messages: list[ChatMessage],
|
||||||
model_name,
|
model_name: str,
|
||||||
temperature,
|
temperature: float,
|
||||||
api_key,
|
api_key: str,
|
||||||
api_base_url,
|
api_base_url: str,
|
||||||
system_prompt,
|
system_prompt: str,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
deepthought=False,
|
deepthought=False,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
try:
|
client = gemini_clients.get(api_key)
|
||||||
client = gemini_clients.get(api_key)
|
if not client:
|
||||||
if not client:
|
client = get_gemini_client(api_key, api_base_url)
|
||||||
client = get_gemini_client(api_key, api_base_url)
|
gemini_clients[api_key] = client
|
||||||
gemini_clients[api_key] = client
|
|
||||||
|
|
||||||
formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
||||||
|
|
||||||
thinking_config = None
|
thinking_config = None
|
||||||
if deepthought and model_name.startswith("gemini-2-5"):
|
if deepthought and model_name.startswith("gemini-2-5"):
|
||||||
thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
|
thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
|
||||||
|
|
||||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||||
config = gtypes.GenerateContentConfig(
|
config = gtypes.GenerateContentConfig(
|
||||||
system_instruction=system_prompt,
|
system_instruction=system_prompt,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
thinking_config=thinking_config,
|
thinking_config=thinking_config,
|
||||||
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
|
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
|
||||||
stop_sequences=["Notes:\n["],
|
stop_sequences=["Notes:\n["],
|
||||||
safety_settings=SAFETY_SETTINGS,
|
safety_settings=SAFETY_SETTINGS,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
http_options=gtypes.HttpOptions(async_client_args={"timeout": httpx.Timeout(30.0, read=60.0)}),
|
http_options=gtypes.HttpOptions(async_client_args={"timeout": httpx.Timeout(30.0, read=60.0)}),
|
||||||
)
|
)
|
||||||
|
|
||||||
aggregated_response = ""
|
aggregated_response = ""
|
||||||
final_chunk = None
|
final_chunk = None
|
||||||
response_started = False
|
response_started = False
|
||||||
start_time = perf_counter()
|
start_time = perf_counter()
|
||||||
chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream(
|
chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream(
|
||||||
model=model_name, config=config, contents=formatted_messages
|
model=model_name, config=config, contents=formatted_messages
|
||||||
)
|
)
|
||||||
async for chunk in chat_stream:
|
async for chunk in chat_stream:
|
||||||
# Log the time taken to start response
|
# Log the time taken to start response
|
||||||
if not response_started:
|
if not response_started:
|
||||||
response_started = True
|
response_started = True
|
||||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||||
# Keep track of the last chunk for usage data
|
# Keep track of the last chunk for usage data
|
||||||
final_chunk = chunk
|
final_chunk = chunk
|
||||||
# Handle streamed response chunk
|
# Handle streamed response chunk
|
||||||
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
stop_message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
||||||
message = message or chunk.text
|
message = stop_message or chunk.text
|
||||||
aggregated_response += message
|
aggregated_response += message
|
||||||
yield message
|
yield message
|
||||||
if stopped:
|
if stopped:
|
||||||
raise ValueError(message)
|
logger.warning(
|
||||||
|
f"LLM Response Prevented for {model_name}: {stop_message}.\n"
|
||||||
|
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
# Log the time taken to stream the entire response
|
# Calculate cost of chat
|
||||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
|
||||||
|
output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
|
||||||
|
thought_tokens = final_chunk.usage_metadata.thoughts_token_count or 0 if final_chunk else 0
|
||||||
|
tracer["usage"] = get_chat_usage_metrics(
|
||||||
|
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
|
||||||
|
)
|
||||||
|
|
||||||
# Calculate cost of chat
|
# Validate the response. If empty, raise an error to retry.
|
||||||
input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
|
if is_none_or_empty(aggregated_response):
|
||||||
output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
|
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
|
||||||
thought_tokens = final_chunk.usage_metadata.thoughts_token_count or 0 if final_chunk else 0
|
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
|
||||||
tracer["usage"] = get_chat_usage_metrics(
|
|
||||||
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save conversation trace
|
# Log the time taken to stream the entire response
|
||||||
tracer["chat_model"] = model_name
|
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||||
tracer["temperature"] = temperature
|
|
||||||
if is_promptrace_enabled():
|
# Save conversation trace
|
||||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
tracer["chat_model"] = model_name
|
||||||
except ValueError as e:
|
tracer["temperature"] = temperature
|
||||||
logger.warning(
|
if is_promptrace_enabled():
|
||||||
f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in gemini_chat_completion_with_backoff stream: {e}", exc_info=True)
|
|
||||||
|
|
||||||
|
|
||||||
def handle_gemini_response(
|
def handle_gemini_response(
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from khoj.utils.helpers import (
|
|||||||
get_chat_usage_metrics,
|
get_chat_usage_metrics,
|
||||||
get_openai_async_client,
|
get_openai_async_client,
|
||||||
get_openai_client,
|
get_openai_client,
|
||||||
|
is_none_or_empty,
|
||||||
is_promptrace_enabled,
|
is_promptrace_enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -54,6 +55,7 @@ openai_async_clients: Dict[str, openai.AsyncOpenAI] = {}
|
|||||||
| retry_if_exception_type(openai._exceptions.APIConnectionError)
|
| retry_if_exception_type(openai._exceptions.APIConnectionError)
|
||||||
| retry_if_exception_type(openai._exceptions.RateLimitError)
|
| retry_if_exception_type(openai._exceptions.RateLimitError)
|
||||||
| retry_if_exception_type(openai._exceptions.APIStatusError)
|
| retry_if_exception_type(openai._exceptions.APIStatusError)
|
||||||
|
| retry_if_exception_type(ValueError)
|
||||||
),
|
),
|
||||||
wait=wait_random_exponential(min=1, max=10),
|
wait=wait_random_exponential(min=1, max=10),
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
@@ -136,6 +138,11 @@ def completion_with_backoff(
|
|||||||
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
|
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Validate the response. If empty, raise an error to retry.
|
||||||
|
if is_none_or_empty(aggregated_response):
|
||||||
|
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
|
||||||
|
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
tracer["chat_model"] = model_name
|
tracer["chat_model"] = model_name
|
||||||
tracer["temperature"] = temperature
|
tracer["temperature"] = temperature
|
||||||
@@ -152,14 +159,15 @@ def completion_with_backoff(
|
|||||||
| retry_if_exception_type(openai._exceptions.APIConnectionError)
|
| retry_if_exception_type(openai._exceptions.APIConnectionError)
|
||||||
| retry_if_exception_type(openai._exceptions.RateLimitError)
|
| retry_if_exception_type(openai._exceptions.RateLimitError)
|
||||||
| retry_if_exception_type(openai._exceptions.APIStatusError)
|
| retry_if_exception_type(openai._exceptions.APIStatusError)
|
||||||
|
| retry_if_exception_type(ValueError)
|
||||||
),
|
),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||||
reraise=True,
|
reraise=False,
|
||||||
)
|
)
|
||||||
async def chat_completion_with_backoff(
|
async def chat_completion_with_backoff(
|
||||||
messages,
|
messages: list[ChatMessage],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
temperature,
|
temperature,
|
||||||
openai_api_key=None,
|
openai_api_key=None,
|
||||||
@@ -168,120 +176,122 @@ async def chat_completion_with_backoff(
|
|||||||
model_kwargs: dict = {},
|
model_kwargs: dict = {},
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||||
try:
|
client_key = f"{openai_api_key}--{api_base_url}"
|
||||||
client_key = f"{openai_api_key}--{api_base_url}"
|
client = openai_async_clients.get(client_key)
|
||||||
client = openai_async_clients.get(client_key)
|
if not client:
|
||||||
if not client:
|
client = get_openai_async_client(openai_api_key, api_base_url)
|
||||||
client = get_openai_async_client(openai_api_key, api_base_url)
|
openai_async_clients[client_key] = client
|
||||||
openai_async_clients[client_key] = client
|
|
||||||
|
|
||||||
stream_processor = adefault_stream_processor
|
stream_processor = adefault_stream_processor
|
||||||
formatted_messages = format_message_for_api(messages, api_base_url)
|
formatted_messages = format_message_for_api(messages, api_base_url)
|
||||||
|
|
||||||
# Configure thinking for openai reasoning models
|
# Configure thinking for openai reasoning models
|
||||||
if is_openai_reasoning_model(model_name, api_base_url):
|
if is_openai_reasoning_model(model_name, api_base_url):
|
||||||
temperature = 1
|
temperature = 1
|
||||||
reasoning_effort = "medium" if deepthought else "low"
|
reasoning_effort = "medium" if deepthought else "low"
|
||||||
model_kwargs["reasoning_effort"] = reasoning_effort
|
model_kwargs["reasoning_effort"] = reasoning_effort
|
||||||
model_kwargs.pop("stop", None) # Remove unsupported stop param for reasoning models
|
model_kwargs.pop("stop", None) # Remove unsupported stop param for reasoning models
|
||||||
|
|
||||||
# Get the first system message and add the string `Formatting re-enabled` to it.
|
# Get the first system message and add the string `Formatting re-enabled` to it.
|
||||||
# See https://platform.openai.com/docs/guides/reasoning-best-practices
|
# See https://platform.openai.com/docs/guides/reasoning-best-practices
|
||||||
if len(formatted_messages) > 0:
|
if len(formatted_messages) > 0:
|
||||||
system_messages = [
|
system_messages = [
|
||||||
(i, message) for i, message in enumerate(formatted_messages) if message["role"] == "system"
|
(i, message) for i, message in enumerate(formatted_messages) if message["role"] == "system"
|
||||||
]
|
]
|
||||||
if len(system_messages) > 0:
|
if len(system_messages) > 0:
|
||||||
first_system_message_index, first_system_message = system_messages[0]
|
first_system_message_index, first_system_message = system_messages[0]
|
||||||
first_system_message_content = first_system_message["content"]
|
first_system_message_content = first_system_message["content"]
|
||||||
formatted_messages[first_system_message_index][
|
formatted_messages[first_system_message_index][
|
||||||
"content"
|
"content"
|
||||||
] = f"{first_system_message_content}\nFormatting re-enabled"
|
] = f"{first_system_message_content}\nFormatting re-enabled"
|
||||||
elif is_twitter_reasoning_model(model_name, api_base_url):
|
elif is_twitter_reasoning_model(model_name, api_base_url):
|
||||||
stream_processor = adeepseek_stream_processor
|
stream_processor = adeepseek_stream_processor
|
||||||
reasoning_effort = "high" if deepthought else "low"
|
reasoning_effort = "high" if deepthought else "low"
|
||||||
model_kwargs["reasoning_effort"] = reasoning_effort
|
model_kwargs["reasoning_effort"] = reasoning_effort
|
||||||
elif model_name.startswith("deepseek-reasoner"):
|
elif model_name.startswith("deepseek-reasoner"):
|
||||||
stream_processor = adeepseek_stream_processor
|
stream_processor = adeepseek_stream_processor
|
||||||
# Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
|
# Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
|
||||||
# The first message should always be a user message (except system message).
|
# The first message should always be a user message (except system message).
|
||||||
updated_messages: List[dict] = []
|
updated_messages: List[dict] = []
|
||||||
for i, message in enumerate(formatted_messages):
|
for i, message in enumerate(formatted_messages):
|
||||||
if i > 0 and message["role"] == formatted_messages[i - 1]["role"]:
|
if i > 0 and message["role"] == formatted_messages[i - 1]["role"]:
|
||||||
updated_messages[-1]["content"] += " " + message["content"]
|
updated_messages[-1]["content"] += " " + message["content"]
|
||||||
elif i == 1 and formatted_messages[i - 1]["role"] == "system" and message["role"] == "assistant":
|
elif i == 1 and formatted_messages[i - 1]["role"] == "system" and message["role"] == "assistant":
|
||||||
updated_messages[-1]["content"] += " " + message["content"]
|
updated_messages[-1]["content"] += " " + message["content"]
|
||||||
else:
|
else:
|
||||||
updated_messages.append(message)
|
updated_messages.append(message)
|
||||||
formatted_messages = updated_messages
|
formatted_messages = updated_messages
|
||||||
elif is_qwen_reasoning_model(model_name, api_base_url):
|
elif is_qwen_reasoning_model(model_name, api_base_url):
|
||||||
stream_processor = partial(ain_stream_thought_processor, thought_tag="think")
|
stream_processor = partial(ain_stream_thought_processor, thought_tag="think")
|
||||||
# Reasoning is enabled by default. Disable when deepthought is False.
|
# Reasoning is enabled by default. Disable when deepthought is False.
|
||||||
# See https://qwenlm.github.io/blog/qwen3/#advanced-usages
|
# See https://qwenlm.github.io/blog/qwen3/#advanced-usages
|
||||||
if not deepthought and len(formatted_messages) > 0:
|
if not deepthought and len(formatted_messages) > 0:
|
||||||
formatted_messages[-1]["content"] = formatted_messages[-1]["content"] + " /no_think"
|
formatted_messages[-1]["content"] = formatted_messages[-1]["content"] + " /no_think"
|
||||||
|
|
||||||
stream = True
|
stream = True
|
||||||
read_timeout = 300 if is_local_api(api_base_url) else 60
|
read_timeout = 300 if is_local_api(api_base_url) else 60
|
||||||
model_kwargs["stream_options"] = {"include_usage": True}
|
model_kwargs["stream_options"] = {"include_usage": True}
|
||||||
if os.getenv("KHOJ_LLM_SEED"):
|
if os.getenv("KHOJ_LLM_SEED"):
|
||||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||||
|
|
||||||
aggregated_response = ""
|
aggregated_response = ""
|
||||||
final_chunk = None
|
final_chunk = None
|
||||||
response_started = False
|
response_started = False
|
||||||
start_time = perf_counter()
|
start_time = perf_counter()
|
||||||
chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
|
chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
|
||||||
messages=formatted_messages, # type: ignore
|
messages=formatted_messages, # type: ignore
|
||||||
model=model_name,
|
model=model_name,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
timeout=httpx.Timeout(30, read=read_timeout),
|
timeout=httpx.Timeout(30, read=read_timeout),
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
async for chunk in stream_processor(chat_stream):
|
async for chunk in stream_processor(chat_stream):
|
||||||
# Log the time taken to start response
|
# Log the time taken to start response
|
||||||
if not response_started:
|
if not response_started:
|
||||||
response_started = True
|
response_started = True
|
||||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||||
# Keep track of the last chunk for usage data
|
# Keep track of the last chunk for usage data
|
||||||
final_chunk = chunk
|
final_chunk = chunk
|
||||||
# Skip empty chunks
|
# Skip empty chunks
|
||||||
if len(chunk.choices) == 0:
|
if len(chunk.choices) == 0:
|
||||||
continue
|
continue
|
||||||
# Handle streamed response chunk
|
# Handle streamed response chunk
|
||||||
response_chunk: ResponseWithThought = None
|
response_chunk: ResponseWithThought = None
|
||||||
response_delta = chunk.choices[0].delta
|
response_delta = chunk.choices[0].delta
|
||||||
if response_delta.content:
|
if response_delta.content:
|
||||||
response_chunk = ResponseWithThought(response=response_delta.content)
|
response_chunk = ResponseWithThought(response=response_delta.content)
|
||||||
aggregated_response += response_chunk.response
|
aggregated_response += response_chunk.response
|
||||||
elif response_delta.thought:
|
elif response_delta.thought:
|
||||||
response_chunk = ResponseWithThought(thought=response_delta.thought)
|
response_chunk = ResponseWithThought(thought=response_delta.thought)
|
||||||
if response_chunk:
|
if response_chunk:
|
||||||
yield response_chunk
|
yield response_chunk
|
||||||
|
|
||||||
# Log the time taken to stream the entire response
|
# Calculate cost of chat after stream finishes
|
||||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
input_tokens, output_tokens, cost = 0, 0, 0
|
||||||
|
if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage:
|
||||||
|
input_tokens = final_chunk.usage.prompt_tokens
|
||||||
|
output_tokens = final_chunk.usage.completion_tokens
|
||||||
|
# Estimated costs returned by DeepInfra API
|
||||||
|
if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra:
|
||||||
|
cost = final_chunk.usage.model_extra.get("estimated_cost", 0)
|
||||||
|
tracer["usage"] = get_chat_usage_metrics(
|
||||||
|
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
|
||||||
|
)
|
||||||
|
|
||||||
# Calculate cost of chat after stream finishes
|
# Validate the response. If empty, raise an error to retry.
|
||||||
input_tokens, output_tokens, cost = 0, 0, 0
|
if is_none_or_empty(aggregated_response):
|
||||||
if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage:
|
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
|
||||||
input_tokens = final_chunk.usage.prompt_tokens
|
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
|
||||||
output_tokens = final_chunk.usage.completion_tokens
|
|
||||||
# Estimated costs returned by DeepInfra API
|
|
||||||
if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra:
|
|
||||||
cost = final_chunk.usage.model_extra.get("estimated_cost", 0)
|
|
||||||
|
|
||||||
# Save conversation trace
|
# Log the time taken to stream the entire response
|
||||||
tracer["chat_model"] = model_name
|
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||||
tracer["temperature"] = temperature
|
|
||||||
tracer["usage"] = get_chat_usage_metrics(
|
# Save conversation trace
|
||||||
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
|
tracer["chat_model"] = model_name
|
||||||
)
|
tracer["temperature"] = temperature
|
||||||
if is_promptrace_enabled():
|
if is_promptrace_enabled():
|
||||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in chat_completion_with_backoff stream: {e}", exc_info=True)
|
|
||||||
|
|
||||||
|
|
||||||
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
|
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
|
||||||
|
|||||||
Reference in New Issue
Block a user