diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index b0d401fa..ecd4f8ad 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -151,17 +151,20 @@ def truncate_messages( ) system_message = messages.pop() + assert type(system_message.content) == str system_message_tokens = len(encoder.encode(system_message.content)) - tokens = sum([len(encoder.encode(message.content)) for message in messages]) + tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str]) while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1: messages.pop() - tokens = sum([len(encoder.encode(message.content)) for message in messages]) + assert type(system_message.content) == str + tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str]) # Truncate current message if still over max supported prompt size by model if (tokens + system_message_tokens) > max_prompt_size: - current_message = "\n".join(messages[0].content.split("\n")[:-1]) - original_question = "\n".join(messages[0].content.split("\n")[-1:]) + assert type(system_message.content) == str + current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else "" + original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else "" original_question_tokens = len(encoder.encode(original_question)) remaining_tokens = max_prompt_size - original_question_tokens - system_message_tokens truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip() diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index be2643bd..b384d8a3 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -296,7 +296,7 @@ async def get_all_filenames( client=client, ) - return await sync_to_async(list)(EntryAdapters.aget_all_filenames_by_source(user, content_source)) + return await sync_to_async(list)(EntryAdapters.aget_all_filenames_by_source(user, content_source)) # type: ignore[call-arg] @api.post("/config/data/conversation/model", status_code=200)