mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Support constraining OpenAI model output to specified response schema
This commit is contained in:
@@ -121,21 +121,34 @@ def extract_questions(
|
||||
|
||||
|
||||
def send_message_to_model(
|
||||
messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {}
|
||||
messages,
|
||||
api_key,
|
||||
model,
|
||||
response_type="text",
|
||||
response_schema=None,
|
||||
api_base_url=None,
|
||||
temperature=0,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
|
||||
# Get Response from GPT
|
||||
model_kwargs = {}
|
||||
json_support = get_openai_api_json_support(model, api_base_url)
|
||||
if response_schema and json_support == JsonSupport.SCHEMA:
|
||||
model_kwargs["response_format"] = response_schema
|
||||
elif response_type == "json_object" and json_support == JsonSupport.OBJECT:
|
||||
model_kwargs["response_format"] = {"type": response_type}
|
||||
|
||||
# Get Response from GPT
|
||||
return completion_with_backoff(
|
||||
messages=messages,
|
||||
model_name=model,
|
||||
openai_api_key=api_key,
|
||||
temperature=temperature,
|
||||
api_base_url=api_base_url,
|
||||
model_kwargs={"response_format": {"type": response_type}} if json_support >= JsonSupport.OBJECT else {},
|
||||
model_kwargs=model_kwargs,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
|
||||
@@ -67,33 +67,24 @@ def completion_with_backoff(
|
||||
temperature = 1
|
||||
model_kwargs["reasoning_effort"] = "medium"
|
||||
|
||||
stream = True
|
||||
model_kwargs["stream_options"] = {"include_usage": True}
|
||||
if os.getenv("KHOJ_LLM_SEED"):
|
||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||
|
||||
chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
|
||||
aggregated_response = ""
|
||||
with client.beta.chat.completions.stream(
|
||||
messages=formatted_messages, # type: ignore
|
||||
model=model_name, # type: ignore
|
||||
stream=stream,
|
||||
model=model_name,
|
||||
temperature=temperature,
|
||||
timeout=20,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
aggregated_response = ""
|
||||
if not stream:
|
||||
chunk = chat
|
||||
aggregated_response = chunk.choices[0].message.content
|
||||
else:
|
||||
) as chat:
|
||||
for chunk in chat:
|
||||
if len(chunk.choices) == 0:
|
||||
if chunk.type == "error":
|
||||
logger.error(f"Openai api response error: {chunk.error}", exc_info=True)
|
||||
continue
|
||||
delta_chunk = chunk.choices[0].delta # type: ignore
|
||||
if isinstance(delta_chunk, str):
|
||||
aggregated_response += delta_chunk
|
||||
elif delta_chunk.content:
|
||||
aggregated_response += delta_chunk.content
|
||||
elif chunk.type == "content.delta":
|
||||
aggregated_response += chunk.delta
|
||||
|
||||
# Calculate cost of chat
|
||||
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
||||
|
||||
@@ -1209,6 +1209,7 @@ async def send_message_to_model_wrapper(
|
||||
api_key=api_key,
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
response_schema=response_schema,
|
||||
api_base_url=api_base_url,
|
||||
tracer=tracer,
|
||||
)
|
||||
@@ -1326,6 +1327,7 @@ def send_message_to_model_wrapper_sync(
|
||||
api_base_url=api_base_url,
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
response_schema=response_schema,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user