Rename ResponseWithThought response field to text for better naming

Author: Debanjum
Date: 2025-06-13 18:40:10 -07:00
parent 490f0a435d
commit 721c55a37b
11 changed files with 48 additions and 48 deletions
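The rename is mechanical: every constructor call and attribute access on ResponseWithThought swaps the `response` keyword for `text`, repeated across the provider clients, routers, and tests below. A minimal before/after sketch (values are illustrative, not from the diff):

    # Before this commit
    chunk = ResponseWithThought(response="Hello")
    print(chunk.response)

    # After this commit
    chunk = ResponseWithThought(text="Hello")
    print(chunk.text)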

View File

@@ -154,7 +154,7 @@ def anthropic_completion_with_backoff(
     if is_promptrace_enabled():
         commit_conversation_trace(messages, aggregated_response, tracer)
-    return ResponseWithThought(response=aggregated_response, thought=thoughts)
+    return ResponseWithThought(text=aggregated_response, thought=thoughts)

 @retry(
@@ -211,10 +211,10 @@ async def anthropic_chat_completion_with_backoff(
         if chunk.type == "message_delta":
             if chunk.delta.stop_reason == "refusal":
                 yield ResponseWithThought(
-                    response="...I'm sorry, but my safety filters prevent me from assisting with this query."
+                    text="...I'm sorry, but my safety filters prevent me from assisting with this query."
                 )
             elif chunk.delta.stop_reason == "max_tokens":
-                yield ResponseWithThought(response="...I'm sorry, but I've hit my response length limit.")
+                yield ResponseWithThought(text="...I'm sorry, but I've hit my response length limit.")
             if chunk.delta.stop_reason in ["refusal", "max_tokens"]:
                 logger.warning(
                     f"LLM Response Prevented for {model_name}: {chunk.delta.stop_reason}.\n"
@@ -227,7 +227,7 @@ async def anthropic_chat_completion_with_backoff(
         # Handle streamed response chunk
         response_chunk: ResponseWithThought = None
         if chunk.delta.type == "text_delta":
-            response_chunk = ResponseWithThought(response=chunk.delta.text)
+            response_chunk = ResponseWithThought(text=chunk.delta.text)
             aggregated_response += chunk.delta.text
         if chunk.delta.type == "thinking_delta":
             response_chunk = ResponseWithThought(thought=chunk.delta.thinking)

View File

@@ -190,7 +190,7 @@ def gemini_completion_with_backoff(
     if is_promptrace_enabled():
         commit_conversation_trace(messages, response_text, tracer)
-    return ResponseWithThought(response=response_text, thought=response_thoughts)
+    return ResponseWithThought(text=response_text, thought=response_thoughts)

 @retry(
@@ -258,7 +258,7 @@ async def gemini_chat_completion_with_backoff(
         # handle safety, rate-limit, other finish reasons
         stop_message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
         if stopped:
-            yield ResponseWithThought(response=stop_message)
+            yield ResponseWithThought(text=stop_message)
             logger.warning(
                 f"LLM Response Prevented for {model_name}: {stop_message}.\n"
                 + f"Last Message by {messages[-1].role}: {messages[-1].content}"
@@ -271,7 +271,7 @@ async def gemini_chat_completion_with_backoff(
             yield ResponseWithThought(thought=part.text)
         elif part.text:
             aggregated_response += part.text
-            yield ResponseWithThought(response=part.text)
+            yield ResponseWithThought(text=part.text)
     # Calculate cost of chat
     input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
     output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0

View File

@@ -145,12 +145,12 @@ async def converse_offline(
             aggregated_response += response_delta
             # Put chunk into the asyncio queue (non-blocking)
             try:
-                queue.put_nowait(ResponseWithThought(response=response_delta))
+                queue.put_nowait(ResponseWithThought(text=response_delta))
             except asyncio.QueueFull:
                 # Should not happen with default queue size unless consumer is very slow
                 logger.warning("Asyncio queue full during offline LLM streaming.")
                 # Potentially block here or handle differently if needed
-                asyncio.run(queue.put(ResponseWithThought(response=response_delta)))
+                asyncio.run(queue.put(ResponseWithThought(text=response_delta)))
         # Log the time taken to stream the entire response
         logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")

@@ -221,4 +221,4 @@ def send_message_to_model_offline(
     if is_promptrace_enabled():
         commit_conversation_trace(messages, response_text, tracer)
-    return ResponseWithThought(response=response_text)
+    return ResponseWithThought(text=response_text)
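Consumers on the other side of that queue only need the renamed field. A hypothetical drain loop, where the None end-of-stream sentinel and the function name are assumptions, not part of this diff:

    import asyncio

    # Hypothetical consumer: drains ResponseWithThought chunks from the queue,
    # assuming the producer enqueues None once streaming finishes.
    async def drain(queue: asyncio.Queue) -> str:
        aggregated = ""
        while (chunk := await queue.get()) is not None:
            aggregated += chunk.text or ""
        return aggregated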

View File

@@ -181,7 +181,7 @@ def completion_with_backoff(
     if is_promptrace_enabled():
         commit_conversation_trace(messages, aggregated_response, tracer)
-    return ResponseWithThought(response=aggregated_response, thought=thoughts)
+    return ResponseWithThought(text=aggregated_response, thought=thoughts)

 @retry(
@@ -297,7 +297,7 @@ async def chat_completion_with_backoff(
             raise ValueError("No response by model.")
         aggregated_response = response.choices[0].message.content
         final_chunk = response
-        yield ResponseWithThought(response=aggregated_response)
+        yield ResponseWithThought(text=aggregated_response)
     else:
         async for chunk in stream_processor(response):
             # Log the time taken to start response
@@ -313,8 +313,8 @@ async def chat_completion_with_backoff(
             response_chunk: ResponseWithThought = None
             response_delta = chunk.choices[0].delta
             if response_delta.content:
-                response_chunk = ResponseWithThought(response=response_delta.content)
-                aggregated_response += response_chunk.response
+                response_chunk = ResponseWithThought(text=response_delta.content)
+                aggregated_response += response_chunk.text
             elif response_delta.thought:
                 response_chunk = ResponseWithThought(thought=response_delta.thought)
             if response_chunk:
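All of these streaming generators now yield chunks whose visible output lives in `.text` and whose reasoning lives in `.thought`. A sketch of a generic collector (hypothetical helper, not in this commit) mirroring the aggregation the handlers above perform:

    # Hypothetical consumer for any streaming generator changed in this commit.
    async def collect(stream):
        full_text, thoughts = "", ""
        async for chunk in stream:
            full_text += chunk.text or ""
            thoughts += chunk.thought or ""
        return full_text, thoughts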

View File

@@ -1187,6 +1187,6 @@ class StructuredOutputSupport(int, Enum):
 class ResponseWithThought:
-    def __init__(self, response: str = None, thought: str = None):
-        self.response = response
+    def __init__(self, text: str = None, thought: str = None):
+        self.text = text
         self.thought = thought
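Both fields default to None, so downstream code can branch on whichever one a chunk carries, as the chat router hunk further below does with item.text and item.thought. A small illustrative check:

    chunk = ResponseWithThought(thought="planning next step")
    assert chunk.text is None and chunk.thought == "planning next step"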

View File

@@ -129,7 +129,7 @@ class BinaryOperatorAgent(OperatorAgent):
             agent_chat_model=self.reasoning_model,
             tracer=self.tracer,
         )
-        natural_language_action = raw_response.response
+        natural_language_action = raw_response.text
         if not isinstance(natural_language_action, str) or not natural_language_action.strip():
             raise ValueError(f"Natural language action is empty or not a string. Got {natural_language_action}")

@@ -256,10 +256,10 @@ class BinaryOperatorAgent(OperatorAgent):
         # Append summary messages to history
         trigger_summary = AgentMessage(role="user", content=summarize_prompt)
-        summary_message = AgentMessage(role="assistant", content=summary.response)
+        summary_message = AgentMessage(role="assistant", content=summary.text)
         self.messages.extend([trigger_summary, summary_message])
-        return summary.response
+        return summary.text

     def _compile_response(self, response_content: str | List) -> str:
         """Compile response content into a string, handling OpenAI message structures."""

View File

@@ -161,7 +161,7 @@ async def generate_python_code(
     )
     # Extract python code wrapped in markdown code blocks from the response
-    code_blocks = re.findall(r"```(?:python)?\n(.*?)```", response.response, re.DOTALL)
+    code_blocks = re.findall(r"```(?:python)?\n(.*?)```", response.text, re.DOTALL)
     if not code_blocks:
         raise ValueError("No Python code blocks found in response")
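The extraction regex itself is untouched; only the attribute it reads was renamed. A quick illustration of what the pattern captures (the sample string is hypothetical):

    import re

    # Markdown code fences are matched lazily; the capture group keeps the body only.
    sample = "Here is the script:\n```python\nprint('hi')\n```"
    code_blocks = re.findall(r"```(?:python)?\n(.*?)```", sample, re.DOTALL)
    assert code_blocks == ["print('hi')\n"]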

View File

@@ -1390,7 +1390,7 @@ async def chat(
                 continue
             if cancellation_event.is_set():
                 break
-            message = item.response
+            message = item.text
             full_response += message if message else ""
             if item.thought:
                 async for result in send_event(ChatEvent.THOUGHT, item.thought):

View File

@@ -304,7 +304,7 @@ async def acreate_title_from_history(
     with timer("Chat actor: Generate title from conversation history", logger):
         response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
-    return response.response.strip()
+    return response.text.strip()

 async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
@@ -316,7 +316,7 @@ async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
     with timer("Chat actor: Generate title from query", logger):
         response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
-    return response.response.strip()
+    return response.text.strip()

 async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax: bool = False) -> Tuple[bool, str]:
@@ -340,7 +340,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax:
         safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck
     )
-    response = response.response.strip()
+    response = response.text.strip()
     try:
         response = json.loads(clean_json(response))
         is_safe = str(response.get("safe", "true")).lower() == "true"
@@ -430,7 +430,7 @@ async def aget_data_sources_and_output_format(
     )
     try:
-        response = clean_json(raw_response.response)
+        response = clean_json(raw_response.text)
         response = json.loads(response)
         chosen_sources = [s.strip() for s in response.get("source", []) if s.strip()]
@@ -520,7 +520,7 @@ async def infer_webpage_urls(
     # Validate that the response is a non-empty, JSON-serializable list of URLs
     try:
-        response = clean_json(raw_response.response)
+        response = clean_json(raw_response.text)
         urls = json.loads(response)
         valid_unique_urls = {str(url).strip() for url in urls["links"] if is_valid_url(url)}
         if is_none_or_empty(valid_unique_urls):
@@ -585,7 +585,7 @@ async def generate_online_subqueries(
     # Validate that the response is a non-empty, JSON-serializable list
     try:
-        response = clean_json(raw_response.response)
+        response = clean_json(raw_response.text)
         response = pyjson5.loads(response)
         response = {q.strip() for q in response["queries"] if q.strip()}
         if not isinstance(response, set) or not response or len(response) == 0:
@@ -646,7 +646,7 @@ async def aschedule_query(
     # Validate that the response is a non-empty, JSON-serializable list
     try:
-        raw_response = raw_response.response.strip()
+        raw_response = raw_response.text.strip()
         response: Dict[str, str] = json.loads(clean_json(raw_response))
         if not response or not isinstance(response, Dict) or len(response) != 3:
             raise AssertionError(f"Invalid response for scheduling query : {response}")
@@ -684,7 +684,7 @@ async def extract_relevant_info(
         agent_chat_model=agent_chat_model,
         tracer=tracer,
     )
-    return response.response.strip()
+    return response.text.strip()

 async def extract_relevant_summary(
@@ -727,7 +727,7 @@ async def extract_relevant_summary(
         agent_chat_model=agent_chat_model,
         tracer=tracer,
     )
-    return response.response.strip()
+    return response.text.strip()

 async def generate_summary_from_files(
@@ -898,7 +898,7 @@ async def generate_better_diagram_description(
         agent_chat_model=agent_chat_model,
         tracer=tracer,
     )
-    response = response.response.strip()
+    response = response.text.strip()
     if response.startswith(('"', "'")) and response.endswith(('"', "'")):
         response = response[1:-1]
@@ -926,7 +926,7 @@ async def generate_excalidraw_diagram_from_description(
     raw_response = await send_message_to_model_wrapper(
         query=excalidraw_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer
     )
-    raw_response_text = clean_json(raw_response.response)
+    raw_response_text = clean_json(raw_response.text)
     try:
         # Expect response to have `elements` and `scratchpad` keys
         response: Dict[str, str] = json.loads(raw_response_text)
@@ -1049,7 +1049,7 @@ async def generate_better_mermaidjs_diagram_description(
         agent_chat_model=agent_chat_model,
         tracer=tracer,
     )
-    response_text = response.response.strip()
+    response_text = response.text.strip()
     if response_text.startswith(('"', "'")) and response_text.endswith(('"', "'")):
         response_text = response_text[1:-1]
@@ -1077,7 +1077,7 @@ async def generate_mermaidjs_diagram_from_description(
     raw_response = await send_message_to_model_wrapper(
         query=mermaidjs_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer
     )
-    return clean_mermaidjs(raw_response.response.strip())
+    return clean_mermaidjs(raw_response.text.strip())

 async def generate_better_image_prompt(
@@ -1152,7 +1152,7 @@ async def generate_better_image_prompt(
         agent_chat_model=agent_chat_model,
         tracer=tracer,
     )
-    response_text = response.response.strip()
+    response_text = response.text.strip()
     if response_text.startswith(('"', "'")) and response_text.endswith(('"', "'")):
         response_text = response_text[1:-1]
@@ -1330,7 +1330,7 @@ async def extract_questions(
     # Extract questions from the response
     try:
-        response = clean_json(raw_response.response)
+        response = clean_json(raw_response.text)
         response = pyjson5.loads(response)
         queries = [q.strip() for q in response["queries"] if q.strip()]
         if not isinstance(queries, list) or not queries:

View File

@@ -157,7 +157,7 @@ async def apick_next_tool(
     try:
         # Try parse the response as function call response to infer next tool to use.
         # TODO: Handle multiple tool calls.
-        response_text = raw_response.response
+        response_text = raw_response.text
         parsed_response = [ToolCall(**item) for item in load_complex_json(response_text)][0]
     except Exception as e:
         # Otherwise assume the model has decided to end the research run and respond to the user.

View File

@@ -189,7 +189,7 @@ async def test_chat_with_no_chat_history_or_retrieved_content():
         user_query="Hello, my name is Testatron. Who are you?",
         api_key=api_key,
     )
-    response = "".join([response_chunk.response async for response_chunk in response_gen])
+    response = "".join([response_chunk.text async for response_chunk in response_gen])

     # Assert
     expected_responses = ["Khoj", "khoj"]
@@ -217,7 +217,7 @@ async def test_answer_from_chat_history_and_no_content():
         chat_history=populate_chat_history(message_list),
         api_key=api_key,
     )
-    response = "".join([response_chunk.response async for response_chunk in response_gen])
+    response = "".join([response_chunk.text async for response_chunk in response_gen])

     # Assert
     expected_responses = ["Testatron", "testatron"]
@@ -250,7 +250,7 @@ async def test_answer_from_chat_history_and_previously_retrieved_content():
         chat_history=populate_chat_history(message_list),
         api_key=api_key,
     )
-    response = "".join([response_chunk.response async for response_chunk in response_gen])
+    response = "".join([response_chunk.text async for response_chunk in response_gen])

     # Assert
     assert len(response) > 0
@@ -279,7 +279,7 @@ async def test_answer_from_chat_history_and_currently_retrieved_content():
         chat_history=populate_chat_history(message_list),
         api_key=api_key,
     )
-    response = "".join([response_chunk.response async for response_chunk in response_gen])
+    response = "".join([response_chunk.text async for response_chunk in response_gen])

     # Assert
     assert len(response) > 0
@@ -305,7 +305,7 @@ async def test_refuse_answering_unanswerable_question():
         chat_history=populate_chat_history(message_list),
         api_key=api_key,
     )
-    response = "".join([response_chunk.response async for response_chunk in response_gen])
+    response = "".join([response_chunk.text async for response_chunk in response_gen])

     # Assert
     expected_responses = [
@@ -359,7 +359,7 @@ Expenses:Food:Dining 10.00 USD""",
         user_query="What did I have for Dinner today?",
         api_key=api_key,
     )
-    response = "".join([response_chunk.response async for response_chunk in response_gen])
+    response = "".join([response_chunk.text async for response_chunk in response_gen])

     # Assert
     expected_responses = ["tacos", "Tacos"]
@@ -405,7 +405,7 @@ Expenses:Food:Dining 10.00 USD""",
         user_query="How much did I spend on dining this year?",
         api_key=api_key,
     )
-    response = "".join([response_chunk.response async for response_chunk in response_gen])
+    response = "".join([response_chunk.text async for response_chunk in response_gen])

     # Assert
     assert len(response) > 0
@@ -432,7 +432,7 @@ async def test_answer_general_question_not_in_chat_history_or_retrieved_content(
         chat_history=populate_chat_history(message_list),
         api_key=api_key,
    )
-    response = "".join([response_chunk.response async for response_chunk in response_gen])
+    response = "".join([response_chunk.text async for response_chunk in response_gen])

     # Assert
     expected_responses = ["test", "bug", "code"]
@@ -473,7 +473,7 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
         user_query="How many kids does my older sister have?",
         api_key=api_key,
     )
-    response = "".join([response_chunk.response async for response_chunk in response_gen])
+    response = "".join([response_chunk.text async for response_chunk in response_gen])

     # Assert
     expected_responses = [
@@ -508,14 +508,14 @@ async def test_agent_prompt_should_be_used(openai_agent):
         user_query="What did I buy?",
         api_key=api_key,
     )
-    no_agent_response = "".join([response_chunk.response async for response_chunk in response_gen])
+    no_agent_response = "".join([response_chunk.text async for response_chunk in response_gen])

     response_gen = converse_openai(
         references=context,  # Assume context retrieved from notes for the user_query
         user_query="What did I buy?",
         api_key=api_key,
         agent=openai_agent,
     )
-    agent_response = "".join([response_chunk.response async for response_chunk in response_gen])
+    agent_response = "".join([response_chunk.text async for response_chunk in response_gen])

     # Assert that the model without the agent prompt does not include the summary of purchases
     assert all([expected_response not in no_agent_response for expected_response in expected_responses]), (