Update truncation logic to handle multi-part message content
@@ -203,7 +203,10 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
     system_prompt = system_prompt or ""
     for message in messages.copy():
         if message.role == "system":
-            system_prompt += message.content
+            if isinstance(message.content, list):
+                system_prompt += "\n".join([part["text"] for part in message.content if part["type"] == "text"])
+            else:
+                system_prompt += message.content
             messages.remove(message)
     system_prompt = None if is_none_or_empty(system_prompt) else system_prompt

@@ -301,7 +301,10 @@ def format_messages_for_gemini(
     messages = deepcopy(original_messages)
     for message in messages.copy():
         if message.role == "system":
-            system_prompt += message.content
+            if isinstance(message.content, list):
+                system_prompt += "\n".join([part["text"] for part in message.content if part["type"] == "text"])
+            else:
+                system_prompt += message.content
             messages.remove(message)
     system_prompt = None if is_none_or_empty(system_prompt) else system_prompt

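Both hunks apply the same collation pattern: when a system message carries multi-part content, only its "text" parts are joined into the system prompt, and any other part type (e.g. images) is dropped. A minimal standalone sketch of that pattern (the helper name is hypothetical, not part of the commit):

    # Hypothetical helper mirroring the collation above: join "text" parts,
    # skip other part types, pass plain strings through unchanged.
    def collate_system_text(content) -> str:
        if isinstance(content, list):
            return "\n".join([part["text"] for part in content if part["type"] == "text"])
        return content

    parts = [
        {"type": "text", "text": "You are Khoj."},
        {"type": "image_url", "image_url": {"url": "placeholder"}},
        {"type": "text", "text": "Be concise."},
    ]
    assert collate_system_text(parts) == "You are Khoj.\nBe concise."
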
@@ -4,14 +4,12 @@ import logging
 import math
 import mimetypes
 import os
 import queue
 import re
 import uuid
 from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum
 from io import BytesIO
 from time import perf_counter
 from typing import Any, Callable, Dict, List, Optional

 import PIL.Image
@@ -20,8 +18,9 @@ import requests
 import tiktoken
 import yaml
 from langchain.schema import ChatMessage
+from llama_cpp import LlamaTokenizer
 from llama_cpp.llama import Llama
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

 from khoj.database.adapters import ConversationAdapters
 from khoj.database.models import ChatModel, ClientApplication, KhojUser
@@ -382,7 +381,7 @@ def gather_raw_query_files(

 def generate_chatml_messages_with_context(
     user_message,
-    system_message=None,
+    system_message: str = None,
     conversation_log={},
     model_name="gpt-4o-mini",
     loaded_model: Optional[Llama] = None,
@@ -480,7 +479,7 @@ def generate_chatml_messages_with_context(
         if len(chatml_messages) >= 3 * lookback_turns:
             break

-    messages = []
+    messages: list[ChatMessage] = []

     if not is_none_or_empty(generated_asset_results):
         messages.append(
@@ -517,6 +516,11 @@ def generate_chatml_messages_with_context(
     if not is_none_or_empty(system_message):
         messages.append(ChatMessage(content=system_message, role="system"))

+    # Normalize message content to list of chatml dictionaries
+    for message in messages:
+        if isinstance(message.content, str):
+            message.content = [{"type": "text", "text": message.content}]
+
     # Truncate oldest messages from conversation history until under max supported prompt size by model
     messages = truncate_messages(messages, max_prompt_size, model_name, loaded_model, tokenizer_name)

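This normalization pass guarantees that every message.content entering truncate_messages is a list of chatml-style dicts, so the truncation logic downstream only has to handle one content shape. A standalone sketch of the invariant, using langchain's ChatMessage:

    from langchain.schema import ChatMessage

    messages = [
        ChatMessage(role="user", content="plain string"),
        ChatMessage(role="user", content=[{"type": "text", "text": "already a part list"}]),
    ]
    for message in messages:
        if isinstance(message.content, str):
            message.content = [{"type": "text", "text": message.content}]
    # Every message now holds list-of-dict content.
    assert all(isinstance(m.content, list) for m in messages)
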
@@ -524,14 +528,11 @@ def generate_chatml_messages_with_context(
     return messages[::-1]


-def truncate_messages(
-    messages: list[ChatMessage],
-    max_prompt_size: int,
+def get_encoder(
     model_name: str,
     loaded_model: Optional[Llama] = None,
     tokenizer_name=None,
-) -> list[ChatMessage]:
-    """Truncate messages to fit within max prompt size supported by model"""
+) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer:
     default_tokenizer = "gpt-4o"

     try:
@@ -554,6 +555,48 @@ def truncate_messages(
         logger.debug(
             f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for model: {model_name} in Khoj settings to improve context stuffing."
         )
     return encoder


+def count_tokens(
+    message_content: str | list[str | dict],
+    encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer | tiktoken.Encoding,
+) -> int:
+    """
+    Count the total number of tokens in the message content.
+
+    Assumes each image takes 500 tokens as an approximation.
+    """
+    if isinstance(message_content, list):
+        image_count = 0
+        message_content_parts: list[str] = []
+        # Collate message content into single string to ease token counting
+        for part in message_content:
+            if isinstance(part, dict) and part.get("type") == "text":
+                message_content_parts.append(part["text"])
+            elif isinstance(part, dict) and part.get("type") == "image_url":
+                image_count += 1
+            elif isinstance(part, str):
+                message_content_parts.append(part)
+            else:
+                logger.warning(f"Unknown message type: {part}. Skipping.")
+        message_content = "\n".join(message_content_parts).rstrip()
+        return len(encoder.encode(message_content)) + image_count * 500
+    elif isinstance(message_content, str):
+        return len(encoder.encode(message_content))
+    else:
+        return len(encoder.encode(json.dumps(message_content)))
+
+
+def truncate_messages(
+    messages: list[ChatMessage],
+    max_prompt_size: int,
+    model_name: str,
+    loaded_model: Optional[Llama] = None,
+    tokenizer_name=None,
+) -> list[ChatMessage]:
+    """Truncate messages to fit within max prompt size supported by model"""
+    encoder = get_encoder(model_name, loaded_model, tokenizer_name)
+
+    # Extract system message from messages
+    system_message = None
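Together, get_encoder and count_tokens give the truncation path a single token accounting that covers plain strings, text parts, and image parts alike. A usage sketch, assuming both functions are importable from khoj.processor.conversation.utils after this commit:

    from khoj.processor.conversation.utils import count_tokens, get_encoder

    encoder = get_encoder("gpt-4o-mini")  # for OpenAI model names this should resolve to a tiktoken encoding
    mixed_content = [
        {"type": "text", "text": "Describe this image"},
        {"type": "image_url", "image_url": {"url": "placeholder"}},
    ]
    # Text parts are tokenized; each image part adds a flat 500-token estimate.
    text_tokens = len(encoder.encode("Describe this image"))
    assert count_tokens(mixed_content, encoder) == text_tokens + 500
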
@@ -562,35 +605,55 @@ def truncate_messages(
             system_message = messages.pop(idx)
             break

-    # TODO: Handle truncation of multi-part message.content, i.e when message.content is a list[dict] rather than a string
-    system_message_tokens = (
-        len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
-    )
-
-    tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
+    system_message_tokens = count_tokens(system_message.content, encoder) if system_message else 0
+    tokens = sum([count_tokens(message.content, encoder) for message in messages])
+    total_tokens = tokens + system_message_tokens + 4 * len(messages)

     # Drop older messages until under max supported prompt size by model
     # Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.)
-    while (tokens + system_message_tokens + 4 * len(messages)) > max_prompt_size and len(messages) > 1:
-        messages.pop()
-        tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
+    while total_tokens > max_prompt_size and (len(messages) > 1 or len(messages[0].content) > 1):
+        if len(messages[-1].content) > 1:
+            # The oldest content part is earlier in content list. So pop from the front.
+            messages[-1].content.pop(0)
+        else:
+            # The oldest message is the last one. So pop from the back.
+            messages.pop()
+        tokens = sum([count_tokens(message.content, encoder) for message in messages])
+        total_tokens = tokens + system_message_tokens + 4 * len(messages)

     # Truncate current message if still over max supported prompt size by model
-    if (tokens + system_message_tokens) > max_prompt_size:
-        current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else ""
-        original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
-        original_question = f"\n{original_question}"
-        original_question_tokens = len(encoder.encode(original_question))
+    total_tokens = tokens + system_message_tokens + 4 * len(messages)
+    if total_tokens > max_prompt_size:
+        # At this point, a single message with a single content part of type dict should remain
+        assert (
+            len(messages) == 1 and len(messages[0].content) == 1 and isinstance(messages[0].content[0], dict)
+        ), "Expected a single message with a single content part remaining at this point in truncation"
+
+        # Collate message content into single string to ease truncation
+        part = messages[0].content[0]
+        message_content: str = part["text"] if part["type"] == "text" else json.dumps(part)
+        message_role = messages[0].role
+
+        remaining_context = "\n".join(message_content.split("\n")[:-1])
+        original_question = "\n" + "\n".join(message_content.split("\n")[-1:])
+
+        original_question_tokens = count_tokens(original_question, encoder)
         remaining_tokens = max_prompt_size - system_message_tokens
         if remaining_tokens > original_question_tokens:
             remaining_tokens -= original_question_tokens
-            truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
-            messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
+            truncated_context = encoder.decode(encoder.encode(remaining_context)[:remaining_tokens]).strip()
+            truncated_content = truncated_context + original_question
         else:
-            truncated_message = encoder.decode(encoder.encode(original_question)[:remaining_tokens]).strip()
-            messages = [ChatMessage(content=truncated_message, role=messages[0].role)]
+            truncated_content = encoder.decode(encoder.encode(original_question)[:remaining_tokens]).strip()
+        messages = [ChatMessage(content=[{"type": "text", "text": truncated_content}], role=message_role)]
+
+        truncated_snippet = (
+            f"{truncated_content[:1000]}\n...\n{truncated_content[-1000:]}"
+            if len(truncated_content) > 2000
+            else truncated_content
+        )
         logger.debug(
-            f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message[:1000]}..."
+            f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_snippet}"
         )

     if system_message:

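The reworked loop drops whole messages from the back of the list and, within the oldest message, drops content parts from the front, before falling back to token-level truncation of the single remaining part. A sketch of the resulting behavior, mirroring the updated tests (assumes khoj's test environment so the module imports cleanly):

    from langchain.schema import ChatMessage
    from khoj.processor.conversation.utils import truncate_messages

    # Messages are ordered newest first; content is normalized to part lists.
    messages = [
        ChatMessage(role="user", content=[{"type": "text", "text": "Question?"}]),
        ChatMessage(role="user", content=[{"type": "text", "text": f"{i}"} for i in range(100)]),
    ]
    truncated = truncate_messages(messages, max_prompt_size=40, model_name="gpt-4o-mini")
    # Older content parts are popped before the newest message is touched.
    assert truncated[0].content[-1]["text"] == "Question?"
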
@@ -6,6 +6,7 @@ from typing import Any, Dict, List

 from apscheduler.schedulers.background import BackgroundScheduler
 from openai import OpenAI
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from whisper import Whisper

 from khoj.database.models import ProcessLock

@@ -40,7 +41,7 @@ khoj_version: str = None
 device = get_device()
 chat_on_gpu: bool = True
 anonymous_mode: bool = False
-pretrained_tokenizers: Dict[str, Any] = dict()
+pretrained_tokenizers: Dict[str, PreTrainedTokenizer | PreTrainedTokenizerFast] = dict()
 billing_enabled: bool = (
     os.getenv("STRIPE_API_KEY") is not None
     and os.getenv("STRIPE_SIGNING_SECRET") is not None
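The tighter type on pretrained_tokenizers documents that this shared cache only ever holds Hugging Face tokenizers. A hypothetical usage sketch (the module path is inferred from this diff's state file, and the model id is illustrative only):

    from transformers import AutoTokenizer

    from khoj.utils import state

    # Load a tokenizer once, cache it in shared state, reuse across requests.
    if "my-model" not in state.pretrained_tokenizers:
        state.pretrained_tokenizers["my-model"] = AutoTokenizer.from_pretrained("my-org/my-model")
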
@@ -1,3 +1,5 @@
+from copy import deepcopy
+
 import tiktoken
 from langchain.schema import ChatMessage

@@ -5,7 +7,7 @@ from khoj.processor.conversation import utils


 class TestTruncateMessage:
-    max_prompt_size = 10
+    max_prompt_size = 40
     model_name = "gpt-4o-mini"
     encoder = tiktoken.encoding_for_model(model_name)

@@ -15,45 +17,108 @@ class TestTruncateMessage:

         # Act
         truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])

         # Assert
         # The original object has been modified. Verify certain properties
         assert len(chat_history) < 50
-        assert len(chat_history) > 1
+        assert len(chat_history) > 5
         assert tokens <= self.max_prompt_size

+    def test_truncate_message_only_oldest_big(self):
+        # Arrange
+        chat_history = generate_chat_history(5)
+        big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
+        chat_history.append(big_chat_message)
+
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
+
+        # Assert
+        # The original object has been modified. Verify certain properties
+        assert len(chat_history) == 5
+        assert tokens <= self.max_prompt_size
+
+    def test_truncate_message_with_image(self):
+        # Arrange
+        image_content_item = {"type": "image_url", "image_url": {"url": "placeholder"}}
+        content_list = [{"type": "text", "text": f"{index}"} for index in range(100)]
+        content_list += [image_content_item, {"type": "text", "text": "Question?"}]
+        big_chat_message = ChatMessage(role="user", content=content_list)
+        copy_big_chat_message = deepcopy(big_chat_message)
+        chat_history = [big_chat_message]
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
+
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
+
+        # Assert
+        # The original object has been modified. Verify certain properties
+        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved"
+        assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
+
+    def test_truncate_message_with_content_list(self):
+        # Arrange
+        chat_history = generate_chat_history(5)
+        content_list = [{"type": "text", "text": f"{index}"} for index in range(100)]
+        content_list += [{"type": "text", "text": "Question?"}]
+        big_chat_message = ChatMessage(role="user", content=content_list)
+        copy_big_chat_message = deepcopy(big_chat_message)
+        chat_history.insert(0, big_chat_message)
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
+
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
+
+        # Assert
+        # The original object has been modified. Verify certain properties
+        assert (
+            len(chat_history) == 1
+        ), "Only most recent message should be present as it itself is larger than context size"
+        assert len(truncated_chat_history[0].content) < len(
+            copy_big_chat_message.content
+        ), "message content list should be modified"
+        assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved"
+        assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
+
     def test_truncate_message_first_large(self):
         # Arrange
         chat_history = generate_chat_history(5)
-        big_chat_message = ChatMessage(role="user", content=f"{generate_content(6)}\nQuestion?")
+        big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
         copy_big_chat_message = big_chat_message.copy()
         chat_history.insert(0, big_chat_message)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])

         # Act
         truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])

         # Assert
         # The original object has been modified. Verify certain properties
-        assert len(chat_history) == 1
-        assert truncated_chat_history[0] != copy_big_chat_message
-        assert tokens <= self.max_prompt_size
+        assert (
+            len(chat_history) == 1
+        ), "Only most recent message should be present as it itself is larger than context size"
+        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
+        assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"

-    def test_truncate_message_last_large(self):
+    def test_truncate_message_large_system_message_first(self):
         # Arrange
         chat_history = generate_chat_history(5)
         chat_history[0].role = "system"  # Mark the first message as system message
-        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
+        big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
         copy_big_chat_message = big_chat_message.copy()

         chat_history.insert(0, big_chat_message)
-        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
+        initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])

         # Act
         truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
-        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+        final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])

         # Assert
         # The original object has been modified. Verify certain properties.
@@ -62,46 +127,52 @@ class TestTruncateMessage:
         ) # Because the system_prompt is popped off from the chat_messages list
         assert len(truncated_chat_history) < 10
         assert len(truncated_chat_history) > 1
-        assert truncated_chat_history[0] != copy_big_chat_message
-        assert initial_tokens > self.max_prompt_size
-        assert final_tokens <= self.max_prompt_size
+        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
+        assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
+        assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"

     def test_truncate_single_large_non_system_message(self):
         # Arrange
-        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
+        big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
         copy_big_chat_message = big_chat_message.copy()
         chat_messages = [big_chat_message]
-        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+        initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_messages])

         # Act
         truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
-        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+        final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])

         # Assert
         # The original object has been modified. Verify certain properties
-        assert initial_tokens > self.max_prompt_size
-        assert final_tokens <= self.max_prompt_size
-        assert len(chat_messages) == 1
-        assert truncated_chat_history[0] != copy_big_chat_message
+        assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
+        assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
+        assert (
+            len(chat_messages) == 1
+        ), "Only most recent message should be present as it itself is larger than context size"
+        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"

     def test_truncate_single_large_question(self):
         # Arrange
-        big_chat_message_content = " ".join(["hi"] * (self.max_prompt_size + 1))
+        big_chat_message_content = [{"type": "text", "text": " ".join(["hi"] * (self.max_prompt_size + 1))}]
         big_chat_message = ChatMessage(role="user", content=big_chat_message_content)
         copy_big_chat_message = big_chat_message.copy()
         chat_messages = [big_chat_message]
-        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+        initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_messages])

         # Act
         truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
-        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+        final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])

         # Assert
         # The original object has been modified. Verify certain properties
-        assert initial_tokens > self.max_prompt_size
-        assert final_tokens <= self.max_prompt_size
-        assert len(chat_messages) == 1
-        assert truncated_chat_history[0] != copy_big_chat_message
+        assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
+        assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
+        assert (
+            len(chat_messages) == 1
+        ), "Only most recent message should be present as it itself is larger than context size"
+        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"


 def test_load_complex_raw_json_string():
@@ -116,12 +187,12 @@ def test_load_complex_raw_json_string():
     assert parsed_json == expeced_json


-def generate_content(count):
-    return " ".join([f"{index}" for index, _ in enumerate(range(count))])
+def generate_content(count, suffix=""):
+    return [{"type": "text", "text": " ".join([f"{index}" for index, _ in enumerate(range(count))]) + "\n" + suffix}]


 def generate_chat_history(count):
     return [
-        ChatMessage(role="user" if index % 2 == 0 else "assistant", content=f"{index}")
+        ChatMessage(role="user" if index % 2 == 0 else "assistant", content=[{"type": "text", "text": f"{index}"}])
         for index, _ in enumerate(range(count))
     ]
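For reference, the reworked test helpers now emit chatml-style content lists rather than plain strings, for example:

    # generate_content(3, suffix="Question?")
    # -> [{"type": "text", "text": "0 1 2\nQuestion?"}]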