From f180b2ba94eafe6e6eed8b586f54c873430253b2 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Fri, 17 Nov 2023 23:26:15 -0800 Subject: [PATCH] Resolve mypy errors for various data types --- src/khoj/processor/conversation/utils.py | 11 +++++++---- src/khoj/routers/api.py | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index b0d401fa..ecd4f8ad 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -151,17 +151,20 @@ def truncate_messages( ) system_message = messages.pop() + assert type(system_message.content) == str system_message_tokens = len(encoder.encode(system_message.content)) - tokens = sum([len(encoder.encode(message.content)) for message in messages]) + tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str]) while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1: messages.pop() - tokens = sum([len(encoder.encode(message.content)) for message in messages]) + assert type(system_message.content) == str + tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str]) # Truncate current message if still over max supported prompt size by model if (tokens + system_message_tokens) > max_prompt_size: - current_message = "\n".join(messages[0].content.split("\n")[:-1]) - original_question = "\n".join(messages[0].content.split("\n")[-1:]) + assert type(system_message.content) == str + current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else "" + original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else "" original_question_tokens = len(encoder.encode(original_question)) remaining_tokens = max_prompt_size - original_question_tokens - system_message_tokens truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip() diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index be2643bd..b384d8a3 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -296,7 +296,7 @@ async def get_all_filenames( client=client, ) - return await sync_to_async(list)(EntryAdapters.aget_all_filenames_by_source(user, content_source)) + return await sync_to_async(list)(EntryAdapters.aget_all_filenames_by_source(user, content_source)) # type: ignore[call-arg] @api.post("/config/data/conversation/model", status_code=200)