Resolve mypy errors for various data types

This commit is contained in:
sabaimran
2023-11-17 23:26:15 -08:00
parent 3328a41f08
commit f180b2ba94
2 changed files with 8 additions and 5 deletions

View File

@@ -151,17 +151,20 @@ def truncate_messages(
)
system_message = messages.pop()
assert type(system_message.content) == str
system_message_tokens = len(encoder.encode(system_message.content))
tokens = sum([len(encoder.encode(message.content)) for message in messages])
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
messages.pop()
tokens = sum([len(encoder.encode(message.content)) for message in messages])
assert type(system_message.content) == str
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
# Truncate current message if still over max supported prompt size by model
if (tokens + system_message_tokens) > max_prompt_size:
current_message = "\n".join(messages[0].content.split("\n")[:-1])
original_question = "\n".join(messages[0].content.split("\n")[-1:])
assert type(system_message.content) == str
current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else ""
original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
original_question_tokens = len(encoder.encode(original_question))
remaining_tokens = max_prompt_size - original_question_tokens - system_message_tokens
truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()

View File

@@ -296,7 +296,7 @@ async def get_all_filenames(
client=client,
)
return await sync_to_async(list)(EntryAdapters.aget_all_filenames_by_source(user, content_source))
return await sync_to_async(list)(EntryAdapters.aget_all_filenames_by_source(user, content_source)) # type: ignore[call-arg]
@api.post("/config/data/conversation/model", status_code=200)