diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index 18eaea47..f087fc93 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -121,21 +121,34 @@ def extract_questions(
 
 
 def send_message_to_model(
-    messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {}
+    messages,
+    api_key,
+    model,
+    response_type="text",
+    response_schema=None,
+    api_base_url=None,
+    temperature=0,
+    tracer: dict = {},
 ):
     """
     Send message to model
     """
 
-    # Get Response from GPT
+    model_kwargs = {}
     json_support = get_openai_api_json_support(model, api_base_url)
+    if response_schema and json_support == JsonSupport.SCHEMA:
+        model_kwargs["response_format"] = response_schema
+    elif response_type == "json_object" and json_support == JsonSupport.OBJECT:
+        model_kwargs["response_format"] = {"type": response_type}
+
+    # Get Response from GPT
     return completion_with_backoff(
         messages=messages,
         model_name=model,
         openai_api_key=api_key,
         temperature=temperature,
         api_base_url=api_base_url,
-        model_kwargs={"response_format": {"type": response_type}} if json_support >= JsonSupport.OBJECT else {},
+        model_kwargs=model_kwargs,
         tracer=tracer,
     )
 
diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index f80c446a..25ddd60a 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -67,33 +67,24 @@ def completion_with_backoff(
         temperature = 1
         model_kwargs["reasoning_effort"] = "medium"
 
-    stream = True
     model_kwargs["stream_options"] = {"include_usage": True}
     if os.getenv("KHOJ_LLM_SEED"):
         model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
 
-    chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
+    aggregated_response = ""
+    with client.beta.chat.completions.stream(
         messages=formatted_messages,  # type: ignore
-        model=model_name,  # type: ignore
-        stream=stream,
+        model=model_name,
         temperature=temperature,
         timeout=20,
         **model_kwargs,
-    )
-
-    aggregated_response = ""
-    if not stream:
-        chunk = chat
-        aggregated_response = chunk.choices[0].message.content
-    else:
+    ) as chat:
         for chunk in chat:
-            if len(chunk.choices) == 0:
+            if chunk.type == "error":
+                logger.error(f"Openai api response error: {chunk.error}", exc_info=True)
                 continue
-            delta_chunk = chunk.choices[0].delta  # type: ignore
-            if isinstance(delta_chunk, str):
-                aggregated_response += delta_chunk
-            elif delta_chunk.content:
-                aggregated_response += delta_chunk.content
+            elif chunk.type == "content.delta":
+                aggregated_response += chunk.delta
 
     # Calculate cost of chat
     input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 75f38948..4f3a0fcc 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -1209,6 +1209,7 @@ async def send_message_to_model_wrapper(
             api_key=api_key,
             model=chat_model_name,
             response_type=response_type,
+            response_schema=response_schema,
             api_base_url=api_base_url,
             tracer=tracer,
         )
@@ -1326,6 +1327,7 @@ def send_message_to_model_wrapper_sync(
             api_base_url=api_base_url,
             model=chat_model_name,
             response_type=response_type,
+            response_schema=response_schema,
             tracer=tracer,
         )
 