Merge branch 'master' into features/advanced-reasoning

- Conflicts:
  Combine both sides of the conflict in all 3 files below
  - src/khoj/processor/conversation/utils.py
  - src/khoj/routers/helpers.py
  - src/khoj/utils/helpers.py
This commit is contained in:
Debanjum Singh Solanky
2024-10-26 05:15:51 -07:00
18 changed files with 147 additions and 68 deletions

View File

@@ -1,7 +1,7 @@
{ {
"id": "khoj", "id": "khoj",
"name": "Khoj", "name": "Khoj",
"version": "1.26.4", "version": "1.27.1",
"minAppVersion": "0.15.0", "minAppVersion": "0.15.0",
"description": "Your Second Brain", "description": "Your Second Brain",
"author": "Khoj Inc.", "author": "Khoj Inc.",

View File

@@ -1,6 +1,6 @@
{ {
"name": "Khoj", "name": "Khoj",
"version": "1.26.4", "version": "1.27.1",
"description": "Your Second Brain", "description": "Your Second Brain",
"author": "Khoj Inc. <team@khoj.dev>", "author": "Khoj Inc. <team@khoj.dev>",
"license": "GPL-3.0-or-later", "license": "GPL-3.0-or-later",

View File

@@ -6,7 +6,7 @@
;; Saba Imran <saba@khoj.dev> ;; Saba Imran <saba@khoj.dev>
;; Description: Your Second Brain ;; Description: Your Second Brain
;; Keywords: search, chat, ai, org-mode, outlines, markdown, pdf, image ;; Keywords: search, chat, ai, org-mode, outlines, markdown, pdf, image
;; Version: 1.26.4 ;; Version: 1.27.1
;; Package-Requires: ((emacs "27.1") (transient "0.3.0") (dash "2.19.1")) ;; Package-Requires: ((emacs "27.1") (transient "0.3.0") (dash "2.19.1"))
;; URL: https://github.com/khoj-ai/khoj/tree/master/src/interface/emacs ;; URL: https://github.com/khoj-ai/khoj/tree/master/src/interface/emacs

View File

@@ -1,7 +1,7 @@
{ {
"id": "khoj", "id": "khoj",
"name": "Khoj", "name": "Khoj",
"version": "1.26.4", "version": "1.27.1",
"minAppVersion": "0.15.0", "minAppVersion": "0.15.0",
"description": "Your Second Brain", "description": "Your Second Brain",
"author": "Khoj Inc.", "author": "Khoj Inc.",

View File

@@ -1,6 +1,6 @@
{ {
"name": "Khoj", "name": "Khoj",
"version": "1.26.4", "version": "1.27.1",
"description": "Your Second Brain", "description": "Your Second Brain",
"author": "Debanjum Singh Solanky, Saba Imran <team@khoj.dev>", "author": "Debanjum Singh Solanky, Saba Imran <team@khoj.dev>",
"license": "GPL-3.0-or-later", "license": "GPL-3.0-or-later",

View File

@@ -82,5 +82,7 @@
"1.26.1": "0.15.0", "1.26.1": "0.15.0",
"1.26.2": "0.15.0", "1.26.2": "0.15.0",
"1.26.3": "0.15.0", "1.26.3": "0.15.0",
"1.26.4": "0.15.0" "1.26.4": "0.15.0",
"1.27.0": "0.15.0",
"1.27.1": "0.15.0"
} }

View File

@@ -1,6 +1,6 @@
{ {
"name": "khoj-ai", "name": "khoj-ai",
"version": "1.26.4", "version": "1.27.1",
"private": true, "private": true,
"scripts": { "scripts": {
"dev": "next dev", "dev": "next dev",
@@ -19,6 +19,7 @@
"prepare": "husky" "prepare": "husky"
}, },
"dependencies": { "dependencies": {
"@excalidraw/excalidraw": "^0.17.6",
"@hookform/resolvers": "^3.9.0", "@hookform/resolvers": "^3.9.0",
"@phosphor-icons/react": "^2.1.7", "@phosphor-icons/react": "^2.1.7",
"@radix-ui/react-alert-dialog": "^1.1.1", "@radix-ui/react-alert-dialog": "^1.1.1",
@@ -63,8 +64,7 @@
"swr": "^2.2.5", "swr": "^2.2.5",
"typescript": "^5", "typescript": "^5",
"vaul": "^0.9.1", "vaul": "^0.9.1",
"zod": "^3.23.8", "zod": "^3.23.8"
"@excalidraw/excalidraw": "^0.17.6"
}, },
"devDependencies": { "devDependencies": {
"@types/dompurify": "^3.0.5", "@types/dompurify": "^3.0.5",

View File

@@ -301,7 +301,7 @@ def subscription_to_state(subscription: Subscription) -> str:
return SubscriptionState.INVALID.value return SubscriptionState.INVALID.value
elif subscription.type == Subscription.Type.TRIAL: elif subscription.type == Subscription.Type.TRIAL:
# Check if the trial has expired # Check if the trial has expired
if datetime.now(tz=timezone.utc) > subscription.renewal_date: if subscription.renewal_date and datetime.now(tz=timezone.utc) > subscription.renewal_date:
return SubscriptionState.EXPIRED.value return SubscriptionState.EXPIRED.value
return SubscriptionState.TRIAL.value return SubscriptionState.TRIAL.value
elif subscription.is_recurring and subscription.renewal_date > datetime.now(tz=timezone.utc): elif subscription.is_recurring and subscription.renewal_date > datetime.now(tz=timezone.utc):

View File

@@ -11,8 +11,12 @@ from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.utils import ( from khoj.processor.conversation.anthropic.utils import (
anthropic_chat_completion_with_backoff, anthropic_chat_completion_with_backoff,
anthropic_completion_with_backoff, anthropic_completion_with_backoff,
format_messages_for_anthropic,
)
from khoj.processor.conversation.utils import (
construct_structured_message,
generate_chatml_messages_with_context,
) )
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
from khoj.utils.helpers import ConversationCommand, is_none_or_empty from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData from khoj.utils.rawconfig import LocationData
@@ -27,6 +31,8 @@ def extract_questions_anthropic(
temperature=0.7, temperature=0.7,
location_data: LocationData = None, location_data: LocationData = None,
user: KhojUser = None, user: KhojUser = None,
query_images: Optional[list[str]] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None, personality_context: Optional[str] = None,
): ):
""" """
@@ -68,6 +74,13 @@ def extract_questions_anthropic(
text=text, text=text,
) )
prompt = construct_structured_message(
message=prompt,
images=query_images,
model_type=ChatModelOptions.ModelType.ANTHROPIC,
vision_enabled=vision_enabled,
)
messages = [ChatMessage(content=prompt, role="user")] messages = [ChatMessage(content=prompt, role="user")]
response = anthropic_completion_with_backoff( response = anthropic_completion_with_backoff(
@@ -101,17 +114,7 @@ def anthropic_send_message_to_model(messages, api_key, model):
""" """
Send message to model Send message to model
""" """
# Anthropic requires the first message to be a 'user' message, and the system prompt is not to be sent in the messages parameter messages, system_prompt = format_messages_for_anthropic(messages)
system_prompt = None
if len(messages) == 1:
messages[0].role = "user"
else:
system_prompt = ""
for message in messages.copy():
if message.role == "system":
system_prompt += message.content
messages.remove(message)
# Get Response from GPT. Don't use response_type because Anthropic doesn't support it. # Get Response from GPT. Don't use response_type because Anthropic doesn't support it.
return anthropic_completion_with_backoff( return anthropic_completion_with_backoff(
@@ -128,7 +131,7 @@ def converse_anthropic(
online_results: Optional[Dict[str, Dict]] = None, online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None,
conversation_log={}, conversation_log={},
model: Optional[str] = "claude-instant-1.2", model: Optional[str] = "claude-3-5-sonnet-20241022",
api_key: Optional[str] = None, api_key: Optional[str] = None,
completion_func=None, completion_func=None,
conversation_commands=[ConversationCommand.Default], conversation_commands=[ConversationCommand.Default],
@@ -137,6 +140,8 @@ def converse_anthropic(
location_data: LocationData = None, location_data: LocationData = None,
user_name: str = None, user_name: str = None,
agent: Agent = None, agent: Agent = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
): ):
""" """
Converse with user using Anthropic's Claude Converse with user using Anthropic's Claude
@@ -194,17 +199,12 @@ def converse_anthropic(
model_name=model, model_name=model,
max_prompt_size=max_prompt_size, max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name, tokenizer_name=tokenizer_name,
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.ANTHROPIC, model_type=ChatModelOptions.ModelType.ANTHROPIC,
) )
if len(messages) > 1: messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
if messages[0].role == "assistant":
messages = messages[1:]
for message in messages.copy():
if message.role == "system":
system_prompt += message.content
messages.remove(message)
truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages}) truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
logger.debug(f"Conversation Context for Claude: {truncated_messages}") logger.debug(f"Conversation Context for Claude: {truncated_messages}")

View File

@@ -3,6 +3,7 @@ from threading import Thread
from typing import Dict, List from typing import Dict, List
import anthropic import anthropic
from langchain.schema import ChatMessage
from tenacity import ( from tenacity import (
before_sleep_log, before_sleep_log,
retry, retry,
@@ -11,7 +12,8 @@ from tenacity import (
wait_random_exponential, wait_random_exponential,
) )
from khoj.processor.conversation.utils import ThreadedGenerator from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
from khoj.utils.helpers import is_none_or_empty
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -115,3 +117,51 @@ def anthropic_llm_thread(
logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True) logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
finally: finally:
g.close() g.close()
def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt=None):
"""
Format messages for Anthropic
"""
# Extract system prompt
system_prompt = system_prompt or ""
for message in messages.copy():
if message.role == "system":
system_prompt += message.content
messages.remove(message)
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
# Anthropic requires the first message to be a 'user' message
if len(messages) == 1:
messages[0].role = "user"
elif len(messages) > 1 and messages[0].role == "assistant":
messages = messages[1:]
# Convert image urls to base64 encoded images in Anthropic message format
for message in messages:
if isinstance(message.content, list):
content = []
# Sort the content. Anthropic models prefer that text comes after images.
message.content.sort(key=lambda x: 0 if x["type"] == "image_url" else 1)
for idx, part in enumerate(message.content):
if part["type"] == "text":
content.append({"type": "text", "text": part["text"]})
elif part["type"] == "image_url":
image = get_image_from_url(part["image_url"]["url"], type="b64")
# Prefix each image with text block enumerating the image number
# This helps the model reference the image in its response. Recommended by Anthropic
content.extend(
[
{
"type": "text",
"text": f"Image {idx + 1}:",
},
{
"type": "image",
"source": {"type": "base64", "media_type": image.type, "data": image.content},
},
]
)
message.content = content
return messages, system_prompt

View File

@@ -1,11 +1,8 @@
import logging import logging
import random import random
from io import BytesIO
from threading import Thread from threading import Thread
import google.generativeai as genai import google.generativeai as genai
import PIL.Image
import requests
from google.generativeai.types.answer_types import FinishReason from google.generativeai.types.answer_types import FinishReason
from google.generativeai.types.generation_types import StopCandidateException from google.generativeai.types.generation_types import StopCandidateException
from google.generativeai.types.safety_types import ( from google.generativeai.types.safety_types import (
@@ -22,7 +19,7 @@ from tenacity import (
wait_random_exponential, wait_random_exponential,
) )
from khoj.processor.conversation.utils import ThreadedGenerator from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
from khoj.utils.helpers import is_none_or_empty from khoj.utils.helpers import is_none_or_empty
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -207,7 +204,7 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
if isinstance(message.content, list): if isinstance(message.content, list):
# Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini) # Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
message.content = [ message.content = [
get_image_from_url(item["image_url"]["url"]) if item["type"] == "image_url" else item["text"] get_image_from_url(item["image_url"]["url"]).content if item["type"] == "image_url" else item["text"]
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1) for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
] ]
elif isinstance(message.content, str): elif isinstance(message.content, str):
@@ -220,13 +217,3 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
messages[0].role = "user" messages[0].role = "user"
return messages, system_prompt return messages, system_prompt
def get_image_from_url(image_url: str) -> PIL.Image:
try:
response = requests.get(image_url)
response.raise_for_status() # Check if the request was successful
return PIL.Image.open(BytesIO(response.content))
except requests.exceptions.RequestException as e:
logger.error(f"Failed to get image from URL {image_url}: {e}")
return None

View File

@@ -616,7 +616,7 @@ AI: It's currently 28°C and partly cloudy in Bali.
Q: Share a painting using the weather for Bali every morning. Q: Share a painting using the weather for Bali every morning.
Khoj: {{"output": "automation"}} Khoj: {{"output": "automation"}}
Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a JSON. Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a JSON. Do not say anything else.
Chat History: Chat History:
{chat_history} {chat_history}

View File

@@ -1,11 +1,17 @@
import base64
import logging import logging
import math import math
import mimetypes
import queue import queue
from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from io import BytesIO
from time import perf_counter from time import perf_counter
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import PIL.Image
import requests
import tiktoken import tiktoken
from langchain.schema import ChatMessage from langchain.schema import ChatMessage
from llama_cpp.llama import Llama from llama_cpp.llama import Llama
@@ -215,7 +221,11 @@ def construct_structured_message(message: str, images: list[str], model_type: st
if not images or not vision_enabled: if not images or not vision_enabled:
return message return message
if model_type in [ChatModelOptions.ModelType.OPENAI, ChatModelOptions.ModelType.GOOGLE]: if model_type in [
ChatModelOptions.ModelType.OPENAI,
ChatModelOptions.ModelType.GOOGLE,
ChatModelOptions.ModelType.ANTHROPIC,
]:
return [ return [
{"type": "text", "text": message}, {"type": "text", "text": message},
*[{"type": "image_url", "image_url": {"url": image}} for image in images], *[{"type": "image_url", "image_url": {"url": image}} for image in images],
@@ -377,3 +387,31 @@ def defilter_query(query: str):
for filter in [DateFilter(), WordFilter(), FileFilter()]: for filter in [DateFilter(), WordFilter(), FileFilter()]:
defiltered_query = filter.defilter(defiltered_query) defiltered_query = filter.defilter(defiltered_query)
return defiltered_query return defiltered_query
@dataclass
class ImageWithType:
content: Any
type: str
def get_image_from_url(image_url: str, type="pil"):
try:
response = requests.get(image_url)
response.raise_for_status() # Check if the request was successful
# Get content type from response or infer from URL
content_type = response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp"
# Convert image to desired format
if type == "b64":
image_data = base64.b64encode(response.content).decode("utf-8")
elif type == "pil":
image_data = PIL.Image.open(BytesIO(response.content))
else:
raise ValueError(f"Invalid image type: {type}")
return ImageWithType(content=image_data, type=content_type)
except requests.exceptions.RequestException as e:
logger.error(f"Failed to get image from URL {image_url}: {e}")
return ImageWithType(content=None, type=None)

View File

@@ -204,9 +204,10 @@ def generate_image_with_replicate(
# Raise exception if the image generation task fails # Raise exception if the image generation task fails
if status != "succeeded": if status != "succeeded":
error = get_prediction.get("error")
if retry_count >= 10: if retry_count >= 10:
raise requests.RequestException("Image generation timed out") raise requests.RequestException("Image generation timed out")
raise requests.RequestException(f"Image generation failed with status: {status}") raise requests.RequestException(f"Image generation failed with status: {status}, message: {error}")
# Get the generated image # Get the generated image
image_url = get_prediction["output"][0] if isinstance(get_prediction["output"], list) else get_prediction["output"] image_url = get_prediction["output"][0] if isinstance(get_prediction["output"], list) else get_prediction["output"]

View File

@@ -447,11 +447,13 @@ async def extract_references_and_questions(
chat_model = conversation_config.chat_model chat_model = conversation_config.chat_model
inferred_queries = extract_questions_anthropic( inferred_queries = extract_questions_anthropic(
defiltered_query, defiltered_query,
query_images=query_images,
model=chat_model, model=chat_model,
api_key=api_key, api_key=api_key,
conversation_log=meta_log, conversation_log=meta_log,
location_data=location_data, location_data=location_data,
user=user, user=user,
vision_enabled=vision_enabled,
personality_context=personality_context, personality_context=personality_context,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:

View File

@@ -721,10 +721,7 @@ async def generate_better_diagram_description(
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
) )
if location_data: location = f"{location_data}" if location_data else "Unknown"
location_prompt = prompts.user_location.format(location=f"{location_data}")
else:
location_prompt = "Unknown"
user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references]) user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references])
@@ -742,7 +739,7 @@ async def generate_better_diagram_description(
improve_diagram_description_prompt = prompts.improve_diagram_description_prompt.format( improve_diagram_description_prompt = prompts.improve_diagram_description_prompt.format(
query=q, query=q,
chat_history=chat_history, chat_history=chat_history,
location=location_prompt, location=location,
current_date=today_date, current_date=today_date,
references=user_references, references=user_references,
online_results=simplified_online_results, online_results=simplified_online_results,
@@ -807,10 +804,7 @@ async def generate_better_image_prompt(
) )
model_type = model_type or TextToImageModelConfig.ModelType.OPENAI model_type = model_type or TextToImageModelConfig.ModelType.OPENAI
if location_data: location = f"{location_data}" if location_data else "Unknown"
location_prompt = prompts.user_location.format(location=f"{location_data}")
else:
location_prompt = "Unknown"
user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references]) user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references])
@@ -827,7 +821,7 @@ async def generate_better_image_prompt(
image_prompt = prompts.image_generation_improve_prompt_dalle.format( image_prompt = prompts.image_generation_improve_prompt_dalle.format(
query=q, query=q,
chat_history=conversation_history, chat_history=conversation_history,
location=location_prompt, location=location,
current_date=today_date, current_date=today_date,
references=user_references, references=user_references,
online_results=simplified_online_results, online_results=simplified_online_results,
@@ -837,7 +831,7 @@ async def generate_better_image_prompt(
image_prompt = prompts.image_generation_improve_prompt_sd.format( image_prompt = prompts.image_generation_improve_prompt_sd.format(
query=q, query=q,
chat_history=conversation_history, chat_history=conversation_history,
location=location_prompt, location=location,
current_date=today_date, current_date=today_date,
references=user_references, references=user_references,
online_results=simplified_online_results, online_results=simplified_online_results,
@@ -863,10 +857,13 @@ async def send_message_to_model_wrapper(
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user) conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
vision_available = conversation_config.vision_enabled vision_available = conversation_config.vision_enabled
if not vision_available and query_images: if not vision_available and query_images:
logger.warning(f"Vision is not enabled for default model: {conversation_config.chat_model}.")
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config() vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
if vision_enabled_config: if vision_enabled_config:
conversation_config = vision_enabled_config conversation_config = vision_enabled_config
vision_available = True vision_available = True
if vision_available and query_images:
logger.info(f"Using {conversation_config.chat_model} model to understand {len(query_images)} images.")
subscribed = await ais_user_subscribed(user) subscribed = await ais_user_subscribed(user)
chat_model = conversation_config.chat_model chat_model = conversation_config.chat_model
@@ -1154,9 +1151,10 @@ def generate_chat_response(
chat_response = converse_anthropic( chat_response = converse_anthropic(
compiled_references, compiled_references,
query_to_run, query_to_run,
online_results, query_images=query_images,
code_results, online_results=online_results,
meta_log, code_results=code_results,
conversation_log=meta_log,
model=conversation_config.chat_model, model=conversation_config.chat_model,
api_key=api_key, api_key=api_key,
completion_func=partial_completion, completion_func=partial_completion,
@@ -1166,6 +1164,7 @@ def generate_chat_response(
location_data=location_data, location_data=location_data,
user_name=user_name, user_name=user_name,
agent=agent, agent=agent,
vision_available=vision_available,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key api_key = conversation_config.openai_config.api_key

View File

@@ -362,9 +362,7 @@ function_calling_description_for_llm = {
} }
mode_descriptions_for_llm = { mode_descriptions_for_llm = {
ConversationCommand.Image: "Use this if the user is requesting you to generate images based on their description. This does not support generating charts or graphs.", ConversationCommand.Image: "Use this if you are confident the user is requesting you to create a new picture based on their description. This does not support generating charts or graphs.",
ConversationCommand.Automation: "Use this if the user is requesting a response at a scheduled date or time.",
ConversationCommand.Text: "Use this if the other response modes don't seem to fit the query.",
ConversationCommand.Automation: "Use this if you are confident the user is requesting a response at a scheduled date, time and frequency", ConversationCommand.Automation: "Use this if you are confident the user is requesting a response at a scheduled date, time and frequency",
ConversationCommand.Text: "Use this if a normal text response would be sufficient for accurately responding to the query.", ConversationCommand.Text: "Use this if a normal text response would be sufficient for accurately responding to the query.",
ConversationCommand.Diagram: "Use this if the user is requesting a visual representation that requires primitives like lines, rectangles, and text.", ConversationCommand.Diagram: "Use this if the user is requesting a visual representation that requires primitives like lines, rectangles, and text.",

View File

@@ -82,5 +82,7 @@
"1.26.1": "0.15.0", "1.26.1": "0.15.0",
"1.26.2": "0.15.0", "1.26.2": "0.15.0",
"1.26.3": "0.15.0", "1.26.3": "0.15.0",
"1.26.4": "0.15.0" "1.26.4": "0.15.0",
"1.27.0": "0.15.0",
"1.27.1": "0.15.0"
} }