mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Rename ResponseWithThought response field to text for better naming
This commit is contained in:
@@ -154,7 +154,7 @@ def anthropic_completion_with_backoff(
|
|||||||
if is_promptrace_enabled():
|
if is_promptrace_enabled():
|
||||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
|
|
||||||
return ResponseWithThought(response=aggregated_response, thought=thoughts)
|
return ResponseWithThought(text=aggregated_response, thought=thoughts)
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
@@ -211,10 +211,10 @@ async def anthropic_chat_completion_with_backoff(
|
|||||||
if chunk.type == "message_delta":
|
if chunk.type == "message_delta":
|
||||||
if chunk.delta.stop_reason == "refusal":
|
if chunk.delta.stop_reason == "refusal":
|
||||||
yield ResponseWithThought(
|
yield ResponseWithThought(
|
||||||
response="...I'm sorry, but my safety filters prevent me from assisting with this query."
|
text="...I'm sorry, but my safety filters prevent me from assisting with this query."
|
||||||
)
|
)
|
||||||
elif chunk.delta.stop_reason == "max_tokens":
|
elif chunk.delta.stop_reason == "max_tokens":
|
||||||
yield ResponseWithThought(response="...I'm sorry, but I've hit my response length limit.")
|
yield ResponseWithThought(text="...I'm sorry, but I've hit my response length limit.")
|
||||||
if chunk.delta.stop_reason in ["refusal", "max_tokens"]:
|
if chunk.delta.stop_reason in ["refusal", "max_tokens"]:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"LLM Response Prevented for {model_name}: {chunk.delta.stop_reason}.\n"
|
f"LLM Response Prevented for {model_name}: {chunk.delta.stop_reason}.\n"
|
||||||
@@ -227,7 +227,7 @@ async def anthropic_chat_completion_with_backoff(
|
|||||||
# Handle streamed response chunk
|
# Handle streamed response chunk
|
||||||
response_chunk: ResponseWithThought = None
|
response_chunk: ResponseWithThought = None
|
||||||
if chunk.delta.type == "text_delta":
|
if chunk.delta.type == "text_delta":
|
||||||
response_chunk = ResponseWithThought(response=chunk.delta.text)
|
response_chunk = ResponseWithThought(text=chunk.delta.text)
|
||||||
aggregated_response += chunk.delta.text
|
aggregated_response += chunk.delta.text
|
||||||
if chunk.delta.type == "thinking_delta":
|
if chunk.delta.type == "thinking_delta":
|
||||||
response_chunk = ResponseWithThought(thought=chunk.delta.thinking)
|
response_chunk = ResponseWithThought(thought=chunk.delta.thinking)
|
||||||
|
|||||||
@@ -190,7 +190,7 @@ def gemini_completion_with_backoff(
|
|||||||
if is_promptrace_enabled():
|
if is_promptrace_enabled():
|
||||||
commit_conversation_trace(messages, response_text, tracer)
|
commit_conversation_trace(messages, response_text, tracer)
|
||||||
|
|
||||||
return ResponseWithThought(response=response_text, thought=response_thoughts)
|
return ResponseWithThought(text=response_text, thought=response_thoughts)
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
@@ -258,7 +258,7 @@ async def gemini_chat_completion_with_backoff(
|
|||||||
# handle safety, rate-limit, other finish reasons
|
# handle safety, rate-limit, other finish reasons
|
||||||
stop_message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
stop_message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
||||||
if stopped:
|
if stopped:
|
||||||
yield ResponseWithThought(response=stop_message)
|
yield ResponseWithThought(text=stop_message)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"LLM Response Prevented for {model_name}: {stop_message}.\n"
|
f"LLM Response Prevented for {model_name}: {stop_message}.\n"
|
||||||
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
||||||
@@ -271,7 +271,7 @@ async def gemini_chat_completion_with_backoff(
|
|||||||
yield ResponseWithThought(thought=part.text)
|
yield ResponseWithThought(thought=part.text)
|
||||||
elif part.text:
|
elif part.text:
|
||||||
aggregated_response += part.text
|
aggregated_response += part.text
|
||||||
yield ResponseWithThought(response=part.text)
|
yield ResponseWithThought(text=part.text)
|
||||||
# Calculate cost of chat
|
# Calculate cost of chat
|
||||||
input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
|
input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
|
||||||
output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
|
output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
|
||||||
|
|||||||
@@ -145,12 +145,12 @@ async def converse_offline(
|
|||||||
aggregated_response += response_delta
|
aggregated_response += response_delta
|
||||||
# Put chunk into the asyncio queue (non-blocking)
|
# Put chunk into the asyncio queue (non-blocking)
|
||||||
try:
|
try:
|
||||||
queue.put_nowait(ResponseWithThought(response=response_delta))
|
queue.put_nowait(ResponseWithThought(text=response_delta))
|
||||||
except asyncio.QueueFull:
|
except asyncio.QueueFull:
|
||||||
# Should not happen with default queue size unless consumer is very slow
|
# Should not happen with default queue size unless consumer is very slow
|
||||||
logger.warning("Asyncio queue full during offline LLM streaming.")
|
logger.warning("Asyncio queue full during offline LLM streaming.")
|
||||||
# Potentially block here or handle differently if needed
|
# Potentially block here or handle differently if needed
|
||||||
asyncio.run(queue.put(ResponseWithThought(response=response_delta)))
|
asyncio.run(queue.put(ResponseWithThought(text=response_delta)))
|
||||||
|
|
||||||
# Log the time taken to stream the entire response
|
# Log the time taken to stream the entire response
|
||||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||||
@@ -221,4 +221,4 @@ def send_message_to_model_offline(
|
|||||||
if is_promptrace_enabled():
|
if is_promptrace_enabled():
|
||||||
commit_conversation_trace(messages, response_text, tracer)
|
commit_conversation_trace(messages, response_text, tracer)
|
||||||
|
|
||||||
return ResponseWithThought(response=response_text)
|
return ResponseWithThought(text=response_text)
|
||||||
|
|||||||
@@ -181,7 +181,7 @@ def completion_with_backoff(
|
|||||||
if is_promptrace_enabled():
|
if is_promptrace_enabled():
|
||||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
|
|
||||||
return ResponseWithThought(response=aggregated_response, thought=thoughts)
|
return ResponseWithThought(text=aggregated_response, thought=thoughts)
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
@@ -297,7 +297,7 @@ async def chat_completion_with_backoff(
|
|||||||
raise ValueError("No response by model.")
|
raise ValueError("No response by model.")
|
||||||
aggregated_response = response.choices[0].message.content
|
aggregated_response = response.choices[0].message.content
|
||||||
final_chunk = response
|
final_chunk = response
|
||||||
yield ResponseWithThought(response=aggregated_response)
|
yield ResponseWithThought(text=aggregated_response)
|
||||||
else:
|
else:
|
||||||
async for chunk in stream_processor(response):
|
async for chunk in stream_processor(response):
|
||||||
# Log the time taken to start response
|
# Log the time taken to start response
|
||||||
@@ -313,8 +313,8 @@ async def chat_completion_with_backoff(
|
|||||||
response_chunk: ResponseWithThought = None
|
response_chunk: ResponseWithThought = None
|
||||||
response_delta = chunk.choices[0].delta
|
response_delta = chunk.choices[0].delta
|
||||||
if response_delta.content:
|
if response_delta.content:
|
||||||
response_chunk = ResponseWithThought(response=response_delta.content)
|
response_chunk = ResponseWithThought(text=response_delta.content)
|
||||||
aggregated_response += response_chunk.response
|
aggregated_response += response_chunk.text
|
||||||
elif response_delta.thought:
|
elif response_delta.thought:
|
||||||
response_chunk = ResponseWithThought(thought=response_delta.thought)
|
response_chunk = ResponseWithThought(thought=response_delta.thought)
|
||||||
if response_chunk:
|
if response_chunk:
|
||||||
|
|||||||
@@ -1187,6 +1187,6 @@ class StructuredOutputSupport(int, Enum):
|
|||||||
|
|
||||||
|
|
||||||
class ResponseWithThought:
|
class ResponseWithThought:
|
||||||
def __init__(self, response: str = None, thought: str = None):
|
def __init__(self, text: str = None, thought: str = None):
|
||||||
self.response = response
|
self.text = text
|
||||||
self.thought = thought
|
self.thought = thought
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ class BinaryOperatorAgent(OperatorAgent):
|
|||||||
agent_chat_model=self.reasoning_model,
|
agent_chat_model=self.reasoning_model,
|
||||||
tracer=self.tracer,
|
tracer=self.tracer,
|
||||||
)
|
)
|
||||||
natural_language_action = raw_response.response
|
natural_language_action = raw_response.text
|
||||||
|
|
||||||
if not isinstance(natural_language_action, str) or not natural_language_action.strip():
|
if not isinstance(natural_language_action, str) or not natural_language_action.strip():
|
||||||
raise ValueError(f"Natural language action is empty or not a string. Got {natural_language_action}")
|
raise ValueError(f"Natural language action is empty or not a string. Got {natural_language_action}")
|
||||||
@@ -256,10 +256,10 @@ class BinaryOperatorAgent(OperatorAgent):
|
|||||||
|
|
||||||
# Append summary messages to history
|
# Append summary messages to history
|
||||||
trigger_summary = AgentMessage(role="user", content=summarize_prompt)
|
trigger_summary = AgentMessage(role="user", content=summarize_prompt)
|
||||||
summary_message = AgentMessage(role="assistant", content=summary.response)
|
summary_message = AgentMessage(role="assistant", content=summary.text)
|
||||||
self.messages.extend([trigger_summary, summary_message])
|
self.messages.extend([trigger_summary, summary_message])
|
||||||
|
|
||||||
return summary.response
|
return summary.text
|
||||||
|
|
||||||
def _compile_response(self, response_content: str | List) -> str:
|
def _compile_response(self, response_content: str | List) -> str:
|
||||||
"""Compile response content into a string, handling OpenAI message structures."""
|
"""Compile response content into a string, handling OpenAI message structures."""
|
||||||
|
|||||||
@@ -161,7 +161,7 @@ async def generate_python_code(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Extract python code wrapped in markdown code blocks from the response
|
# Extract python code wrapped in markdown code blocks from the response
|
||||||
code_blocks = re.findall(r"```(?:python)?\n(.*?)```", response.response, re.DOTALL)
|
code_blocks = re.findall(r"```(?:python)?\n(.*?)```", response.text, re.DOTALL)
|
||||||
|
|
||||||
if not code_blocks:
|
if not code_blocks:
|
||||||
raise ValueError("No Python code blocks found in response")
|
raise ValueError("No Python code blocks found in response")
|
||||||
|
|||||||
@@ -1390,7 +1390,7 @@ async def chat(
|
|||||||
continue
|
continue
|
||||||
if cancellation_event.is_set():
|
if cancellation_event.is_set():
|
||||||
break
|
break
|
||||||
message = item.response
|
message = item.text
|
||||||
full_response += message if message else ""
|
full_response += message if message else ""
|
||||||
if item.thought:
|
if item.thought:
|
||||||
async for result in send_event(ChatEvent.THOUGHT, item.thought):
|
async for result in send_event(ChatEvent.THOUGHT, item.thought):
|
||||||
|
|||||||
@@ -304,7 +304,7 @@ async def acreate_title_from_history(
|
|||||||
with timer("Chat actor: Generate title from conversation history", logger):
|
with timer("Chat actor: Generate title from conversation history", logger):
|
||||||
response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
|
response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
|
||||||
|
|
||||||
return response.response.strip()
|
return response.text.strip()
|
||||||
|
|
||||||
|
|
||||||
async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
|
async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
|
||||||
@@ -316,7 +316,7 @@ async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
|
|||||||
with timer("Chat actor: Generate title from query", logger):
|
with timer("Chat actor: Generate title from query", logger):
|
||||||
response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
|
response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
|
||||||
|
|
||||||
return response.response.strip()
|
return response.text.strip()
|
||||||
|
|
||||||
|
|
||||||
async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax: bool = False) -> Tuple[bool, str]:
|
async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax: bool = False) -> Tuple[bool, str]:
|
||||||
@@ -340,7 +340,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax:
|
|||||||
safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck
|
safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck
|
||||||
)
|
)
|
||||||
|
|
||||||
response = response.response.strip()
|
response = response.text.strip()
|
||||||
try:
|
try:
|
||||||
response = json.loads(clean_json(response))
|
response = json.loads(clean_json(response))
|
||||||
is_safe = str(response.get("safe", "true")).lower() == "true"
|
is_safe = str(response.get("safe", "true")).lower() == "true"
|
||||||
@@ -430,7 +430,7 @@ async def aget_data_sources_and_output_format(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = clean_json(raw_response.response)
|
response = clean_json(raw_response.text)
|
||||||
response = json.loads(response)
|
response = json.loads(response)
|
||||||
|
|
||||||
chosen_sources = [s.strip() for s in response.get("source", []) if s.strip()]
|
chosen_sources = [s.strip() for s in response.get("source", []) if s.strip()]
|
||||||
@@ -520,7 +520,7 @@ async def infer_webpage_urls(
|
|||||||
|
|
||||||
# Validate that the response is a non-empty, JSON-serializable list of URLs
|
# Validate that the response is a non-empty, JSON-serializable list of URLs
|
||||||
try:
|
try:
|
||||||
response = clean_json(raw_response.response)
|
response = clean_json(raw_response.text)
|
||||||
urls = json.loads(response)
|
urls = json.loads(response)
|
||||||
valid_unique_urls = {str(url).strip() for url in urls["links"] if is_valid_url(url)}
|
valid_unique_urls = {str(url).strip() for url in urls["links"] if is_valid_url(url)}
|
||||||
if is_none_or_empty(valid_unique_urls):
|
if is_none_or_empty(valid_unique_urls):
|
||||||
@@ -585,7 +585,7 @@ async def generate_online_subqueries(
|
|||||||
|
|
||||||
# Validate that the response is a non-empty, JSON-serializable list
|
# Validate that the response is a non-empty, JSON-serializable list
|
||||||
try:
|
try:
|
||||||
response = clean_json(raw_response.response)
|
response = clean_json(raw_response.text)
|
||||||
response = pyjson5.loads(response)
|
response = pyjson5.loads(response)
|
||||||
response = {q.strip() for q in response["queries"] if q.strip()}
|
response = {q.strip() for q in response["queries"] if q.strip()}
|
||||||
if not isinstance(response, set) or not response or len(response) == 0:
|
if not isinstance(response, set) or not response or len(response) == 0:
|
||||||
@@ -646,7 +646,7 @@ async def aschedule_query(
|
|||||||
|
|
||||||
# Validate that the response is a non-empty, JSON-serializable list
|
# Validate that the response is a non-empty, JSON-serializable list
|
||||||
try:
|
try:
|
||||||
raw_response = raw_response.response.strip()
|
raw_response = raw_response.text.strip()
|
||||||
response: Dict[str, str] = json.loads(clean_json(raw_response))
|
response: Dict[str, str] = json.loads(clean_json(raw_response))
|
||||||
if not response or not isinstance(response, Dict) or len(response) != 3:
|
if not response or not isinstance(response, Dict) or len(response) != 3:
|
||||||
raise AssertionError(f"Invalid response for scheduling query : {response}")
|
raise AssertionError(f"Invalid response for scheduling query : {response}")
|
||||||
@@ -684,7 +684,7 @@ async def extract_relevant_info(
|
|||||||
agent_chat_model=agent_chat_model,
|
agent_chat_model=agent_chat_model,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
return response.response.strip()
|
return response.text.strip()
|
||||||
|
|
||||||
|
|
||||||
async def extract_relevant_summary(
|
async def extract_relevant_summary(
|
||||||
@@ -727,7 +727,7 @@ async def extract_relevant_summary(
|
|||||||
agent_chat_model=agent_chat_model,
|
agent_chat_model=agent_chat_model,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
return response.response.strip()
|
return response.text.strip()
|
||||||
|
|
||||||
|
|
||||||
async def generate_summary_from_files(
|
async def generate_summary_from_files(
|
||||||
@@ -898,7 +898,7 @@ async def generate_better_diagram_description(
|
|||||||
agent_chat_model=agent_chat_model,
|
agent_chat_model=agent_chat_model,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
response = response.response.strip()
|
response = response.text.strip()
|
||||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||||
response = response[1:-1]
|
response = response[1:-1]
|
||||||
|
|
||||||
@@ -926,7 +926,7 @@ async def generate_excalidraw_diagram_from_description(
|
|||||||
raw_response = await send_message_to_model_wrapper(
|
raw_response = await send_message_to_model_wrapper(
|
||||||
query=excalidraw_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer
|
query=excalidraw_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer
|
||||||
)
|
)
|
||||||
raw_response_text = clean_json(raw_response.response)
|
raw_response_text = clean_json(raw_response.text)
|
||||||
try:
|
try:
|
||||||
# Expect response to have `elements` and `scratchpad` keys
|
# Expect response to have `elements` and `scratchpad` keys
|
||||||
response: Dict[str, str] = json.loads(raw_response_text)
|
response: Dict[str, str] = json.loads(raw_response_text)
|
||||||
@@ -1049,7 +1049,7 @@ async def generate_better_mermaidjs_diagram_description(
|
|||||||
agent_chat_model=agent_chat_model,
|
agent_chat_model=agent_chat_model,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
response_text = response.response.strip()
|
response_text = response.text.strip()
|
||||||
if response_text.startswith(('"', "'")) and response_text.endswith(('"', "'")):
|
if response_text.startswith(('"', "'")) and response_text.endswith(('"', "'")):
|
||||||
response_text = response_text[1:-1]
|
response_text = response_text[1:-1]
|
||||||
|
|
||||||
@@ -1077,7 +1077,7 @@ async def generate_mermaidjs_diagram_from_description(
|
|||||||
raw_response = await send_message_to_model_wrapper(
|
raw_response = await send_message_to_model_wrapper(
|
||||||
query=mermaidjs_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer
|
query=mermaidjs_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer
|
||||||
)
|
)
|
||||||
return clean_mermaidjs(raw_response.response.strip())
|
return clean_mermaidjs(raw_response.text.strip())
|
||||||
|
|
||||||
|
|
||||||
async def generate_better_image_prompt(
|
async def generate_better_image_prompt(
|
||||||
@@ -1152,7 +1152,7 @@ async def generate_better_image_prompt(
|
|||||||
agent_chat_model=agent_chat_model,
|
agent_chat_model=agent_chat_model,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
response_text = response.response.strip()
|
response_text = response.text.strip()
|
||||||
if response_text.startswith(('"', "'")) and response_text.endswith(('"', "'")):
|
if response_text.startswith(('"', "'")) and response_text.endswith(('"', "'")):
|
||||||
response_text = response_text[1:-1]
|
response_text = response_text[1:-1]
|
||||||
|
|
||||||
@@ -1330,7 +1330,7 @@ async def extract_questions(
|
|||||||
|
|
||||||
# Extract questions from the response
|
# Extract questions from the response
|
||||||
try:
|
try:
|
||||||
response = clean_json(raw_response.response)
|
response = clean_json(raw_response.text)
|
||||||
response = pyjson5.loads(response)
|
response = pyjson5.loads(response)
|
||||||
queries = [q.strip() for q in response["queries"] if q.strip()]
|
queries = [q.strip() for q in response["queries"] if q.strip()]
|
||||||
if not isinstance(queries, list) or not queries:
|
if not isinstance(queries, list) or not queries:
|
||||||
|
|||||||
@@ -157,7 +157,7 @@ async def apick_next_tool(
|
|||||||
try:
|
try:
|
||||||
# Try parse the response as function call response to infer next tool to use.
|
# Try parse the response as function call response to infer next tool to use.
|
||||||
# TODO: Handle multiple tool calls.
|
# TODO: Handle multiple tool calls.
|
||||||
response_text = raw_response.response
|
response_text = raw_response.text
|
||||||
parsed_response = [ToolCall(**item) for item in load_complex_json(response_text)][0]
|
parsed_response = [ToolCall(**item) for item in load_complex_json(response_text)][0]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Otherwise assume the model has decided to end the research run and respond to the user.
|
# Otherwise assume the model has decided to end the research run and respond to the user.
|
||||||
|
|||||||
@@ -189,7 +189,7 @@ async def test_chat_with_no_chat_history_or_retrieved_content():
|
|||||||
user_query="Hello, my name is Testatron. Who are you?",
|
user_query="Hello, my name is Testatron. Who are you?",
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
)
|
)
|
||||||
response = "".join([response_chunk.response async for response_chunk in response_gen])
|
response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = ["Khoj", "khoj"]
|
expected_responses = ["Khoj", "khoj"]
|
||||||
@@ -217,7 +217,7 @@ async def test_answer_from_chat_history_and_no_content():
|
|||||||
chat_history=populate_chat_history(message_list),
|
chat_history=populate_chat_history(message_list),
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
)
|
)
|
||||||
response = "".join([response_chunk.response async for response_chunk in response_gen])
|
response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = ["Testatron", "testatron"]
|
expected_responses = ["Testatron", "testatron"]
|
||||||
@@ -250,7 +250,7 @@ async def test_answer_from_chat_history_and_previously_retrieved_content():
|
|||||||
chat_history=populate_chat_history(message_list),
|
chat_history=populate_chat_history(message_list),
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
)
|
)
|
||||||
response = "".join([response_chunk.response async for response_chunk in response_gen])
|
response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(response) > 0
|
assert len(response) > 0
|
||||||
@@ -279,7 +279,7 @@ async def test_answer_from_chat_history_and_currently_retrieved_content():
|
|||||||
chat_history=populate_chat_history(message_list),
|
chat_history=populate_chat_history(message_list),
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
)
|
)
|
||||||
response = "".join([response_chunk.response async for response_chunk in response_gen])
|
response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(response) > 0
|
assert len(response) > 0
|
||||||
@@ -305,7 +305,7 @@ async def test_refuse_answering_unanswerable_question():
|
|||||||
chat_history=populate_chat_history(message_list),
|
chat_history=populate_chat_history(message_list),
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
)
|
)
|
||||||
response = "".join([response_chunk.response async for response_chunk in response_gen])
|
response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = [
|
expected_responses = [
|
||||||
@@ -359,7 +359,7 @@ Expenses:Food:Dining 10.00 USD""",
|
|||||||
user_query="What did I have for Dinner today?",
|
user_query="What did I have for Dinner today?",
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
)
|
)
|
||||||
response = "".join([response_chunk.response async for response_chunk in response_gen])
|
response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = ["tacos", "Tacos"]
|
expected_responses = ["tacos", "Tacos"]
|
||||||
@@ -405,7 +405,7 @@ Expenses:Food:Dining 10.00 USD""",
|
|||||||
user_query="How much did I spend on dining this year?",
|
user_query="How much did I spend on dining this year?",
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
)
|
)
|
||||||
response = "".join([response_chunk.response async for response_chunk in response_gen])
|
response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(response) > 0
|
assert len(response) > 0
|
||||||
@@ -432,7 +432,7 @@ async def test_answer_general_question_not_in_chat_history_or_retrieved_content(
|
|||||||
chat_history=populate_chat_history(message_list),
|
chat_history=populate_chat_history(message_list),
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
)
|
)
|
||||||
response = "".join([response_chunk.response async for response_chunk in response_gen])
|
response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = ["test", "bug", "code"]
|
expected_responses = ["test", "bug", "code"]
|
||||||
@@ -473,7 +473,7 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
|
|||||||
user_query="How many kids does my older sister have?",
|
user_query="How many kids does my older sister have?",
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
)
|
)
|
||||||
response = "".join([response_chunk.response async for response_chunk in response_gen])
|
response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = [
|
expected_responses = [
|
||||||
@@ -508,14 +508,14 @@ async def test_agent_prompt_should_be_used(openai_agent):
|
|||||||
user_query="What did I buy?",
|
user_query="What did I buy?",
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
)
|
)
|
||||||
no_agent_response = "".join([response_chunk.response async for response_chunk in response_gen])
|
no_agent_response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||||
response_gen = converse_openai(
|
response_gen = converse_openai(
|
||||||
references=context, # Assume context retrieved from notes for the user_query
|
references=context, # Assume context retrieved from notes for the user_query
|
||||||
user_query="What did I buy?",
|
user_query="What did I buy?",
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
agent=openai_agent,
|
agent=openai_agent,
|
||||||
)
|
)
|
||||||
agent_response = "".join([response_chunk.response async for response_chunk in response_gen])
|
agent_response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||||
|
|
||||||
# Assert that the model without the agent prompt does not include the summary of purchases
|
# Assert that the model without the agent prompt does not include the summary of purchases
|
||||||
assert all([expected_response not in no_agent_response for expected_response in expected_responses]), (
|
assert all([expected_response not in no_agent_response for expected_response in expected_responses]), (
|
||||||
|
|||||||
Reference in New Issue
Block a user