mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-10 05:39:11 +00:00
Support constraining OpenAI model output to a specified response schema
This commit is contained in:
@@ -121,21 +121,34 @@ def extract_questions(
|
|||||||
|
|
||||||
|
|
||||||
def send_message_to_model(
|
def send_message_to_model(
|
||||||
messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {}
|
messages,
|
||||||
|
api_key,
|
||||||
|
model,
|
||||||
|
response_type="text",
|
||||||
|
response_schema=None,
|
||||||
|
api_base_url=None,
|
||||||
|
temperature=0,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Send message to model
|
Send message to model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Get Response from GPT
|
model_kwargs = {}
|
||||||
json_support = get_openai_api_json_support(model, api_base_url)
|
json_support = get_openai_api_json_support(model, api_base_url)
|
||||||
|
if response_schema and json_support == JsonSupport.SCHEMA:
|
||||||
|
model_kwargs["response_format"] = response_schema
|
||||||
|
elif response_type == "json_object" and json_support == JsonSupport.OBJECT:
|
||||||
|
model_kwargs["response_format"] = {"type": response_type}
|
||||||
|
|
||||||
|
# Get Response from GPT
|
||||||
return completion_with_backoff(
|
return completion_with_backoff(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
openai_api_key=api_key,
|
openai_api_key=api_key,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
model_kwargs={"response_format": {"type": response_type}} if json_support >= JsonSupport.OBJECT else {},
|
model_kwargs=model_kwargs,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -67,33 +67,24 @@ def completion_with_backoff(
|
|||||||
temperature = 1
|
temperature = 1
|
||||||
model_kwargs["reasoning_effort"] = "medium"
|
model_kwargs["reasoning_effort"] = "medium"
|
||||||
|
|
||||||
stream = True
|
|
||||||
model_kwargs["stream_options"] = {"include_usage": True}
|
model_kwargs["stream_options"] = {"include_usage": True}
|
||||||
if os.getenv("KHOJ_LLM_SEED"):
|
if os.getenv("KHOJ_LLM_SEED"):
|
||||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||||
|
|
||||||
chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
|
aggregated_response = ""
|
||||||
|
with client.beta.chat.completions.stream(
|
||||||
messages=formatted_messages, # type: ignore
|
messages=formatted_messages, # type: ignore
|
||||||
model=model_name, # type: ignore
|
model=model_name,
|
||||||
stream=stream,
|
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
timeout=20,
|
timeout=20,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
) as chat:
|
||||||
|
|
||||||
aggregated_response = ""
|
|
||||||
if not stream:
|
|
||||||
chunk = chat
|
|
||||||
aggregated_response = chunk.choices[0].message.content
|
|
||||||
else:
|
|
||||||
for chunk in chat:
|
for chunk in chat:
|
||||||
if len(chunk.choices) == 0:
|
if chunk.type == "error":
|
||||||
|
logger.error(f"Openai api response error: {chunk.error}", exc_info=True)
|
||||||
continue
|
continue
|
||||||
delta_chunk = chunk.choices[0].delta # type: ignore
|
elif chunk.type == "content.delta":
|
||||||
if isinstance(delta_chunk, str):
|
aggregated_response += chunk.delta
|
||||||
aggregated_response += delta_chunk
|
|
||||||
elif delta_chunk.content:
|
|
||||||
aggregated_response += delta_chunk.content
|
|
||||||
|
|
||||||
# Calculate cost of chat
|
# Calculate cost of chat
|
||||||
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
||||||
|
|||||||
@@ -1209,6 +1209,7 @@ async def send_message_to_model_wrapper(
|
|||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=chat_model_name,
|
model=chat_model_name,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
|
response_schema=response_schema,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
@@ -1326,6 +1327,7 @@ def send_message_to_model_wrapper_sync(
|
|||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
model=chat_model_name,
|
model=chat_model_name,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
|
response_schema=response_schema,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user