Support constraining OpenAI model output to specified response schema

This commit is contained in:
Debanjum
2025-03-19 17:46:26 +05:30
parent 4a4d225455
commit ac4b36b9fd
3 changed files with 26 additions and 20 deletions

View File

@@ -121,21 +121,34 @@ def extract_questions(
def send_message_to_model(
    messages,
    api_key,
    model,
    response_type="text",
    response_schema=None,
    api_base_url=None,
    temperature=0,
    tracer: dict = None,
):
    """
    Send messages to the configured OpenAI(-compatible) chat model and return its response.

    Args:
        messages: Chat messages to send to the model.
        api_key: API key for the OpenAI-compatible endpoint.
        model: Name of the chat model to use.
        response_type: "text" or "json_object"; used when only JSON-object mode is supported.
        response_schema: Optional response schema to constrain model output; used when the
            model/endpoint supports structured (schema-constrained) output.
        api_base_url: Optional base URL for OpenAI-compatible API servers.
        temperature: Sampling temperature for the model.
        tracer: Optional tracing context propagated to the completion call.
    """
    # Avoid a shared mutable default argument; normalize to a fresh dict per call.
    if tracer is None:
        tracer = {}

    # Prefer full schema-constrained output when supported; otherwise fall back to
    # plain JSON-object mode if the caller asked for JSON and the endpoint supports it.
    model_kwargs = {}
    json_support = get_openai_api_json_support(model, api_base_url)
    if response_schema and json_support == JsonSupport.SCHEMA:
        model_kwargs["response_format"] = response_schema
    elif response_type == "json_object" and json_support == JsonSupport.OBJECT:
        model_kwargs["response_format"] = {"type": response_type}

    # Get Response from GPT
    return completion_with_backoff(
        messages=messages,
        model_name=model,
        openai_api_key=api_key,
        temperature=temperature,
        api_base_url=api_base_url,
        model_kwargs=model_kwargs,
        tracer=tracer,
    )

View File

@@ -67,33 +67,24 @@ def completion_with_backoff(
temperature = 1
model_kwargs["reasoning_effort"] = "medium"
stream = True
model_kwargs["stream_options"] = {"include_usage": True}
if os.getenv("KHOJ_LLM_SEED"):
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
aggregated_response = ""
with client.beta.chat.completions.stream(
messages=formatted_messages, # type: ignore
model=model_name, # type: ignore
stream=stream,
model=model_name,
temperature=temperature,
timeout=20,
**model_kwargs,
)
aggregated_response = ""
if not stream:
chunk = chat
aggregated_response = chunk.choices[0].message.content
else:
) as chat:
for chunk in chat:
if len(chunk.choices) == 0:
if chunk.type == "error":
logger.error(f"Openai api response error: {chunk.error}", exc_info=True)
continue
delta_chunk = chunk.choices[0].delta # type: ignore
if isinstance(delta_chunk, str):
aggregated_response += delta_chunk
elif delta_chunk.content:
aggregated_response += delta_chunk.content
elif chunk.type == "content.delta":
aggregated_response += chunk.delta
# Calculate cost of chat
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0

View File

@@ -1209,6 +1209,7 @@ async def send_message_to_model_wrapper(
api_key=api_key,
model=chat_model_name,
response_type=response_type,
response_schema=response_schema,
api_base_url=api_base_url,
tracer=tracer,
)
@@ -1326,6 +1327,7 @@ def send_message_to_model_wrapper_sync(
api_base_url=api_base_url,
model=chat_model_name,
response_type=response_type,
response_schema=response_schema,
tracer=tracer,
)