mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
228 lines
8.3 KiB
Python
228 lines
8.3 KiB
Python
# Standard Packages
|
|
import os
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import Any, Optional
|
|
from uuid import UUID
|
|
import asyncio
|
|
from threading import Thread
|
|
import json
|
|
|
|
# External Packages
|
|
from langchain.chat_models import ChatOpenAI
|
|
from langchain.llms import OpenAI
|
|
from langchain.schema import ChatMessage
|
|
from langchain.callbacks.base import BaseCallbackHandler
|
|
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
|
from langchain.callbacks import AsyncIteratorCallbackHandler
|
|
from langchain.callbacks.base import BaseCallbackManager, AsyncCallbackHandler
|
|
import openai
|
|
import tiktoken
|
|
from tenacity import (
|
|
before_sleep_log,
|
|
retry,
|
|
retry_if_exception_type,
|
|
stop_after_attempt,
|
|
wait_exponential,
|
|
wait_random_exponential,
|
|
)
|
|
|
|
# Internal Packages
|
|
from khoj.utils.helpers import merge_dicts
|
|
import queue
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# Max prompt tokens (context window size) supported per OpenAI chat model;
# used by truncate_messages to keep conversation history within model limits
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}
|
|
|
|
|
|
class ThreadedGenerator:
    """Queue-backed iterator bridging a producer thread and a streaming consumer.

    A worker thread pushes response chunks with send() and ends the stream with
    close(); the consumer simply iterates. The StopIteration class itself is used
    as the end-of-stream sentinel on the queue.
    """

    def __init__(self, compiled_references, completion_func=None):
        self.queue = queue.Queue()
        self.compiled_references = compiled_references
        self.completion_func = completion_func
        self.response = ""

    def __iter__(self):
        return self

    def __next__(self):
        chunk = self.queue.get()
        if chunk is not StopIteration:
            return chunk
        # Stream finished: hand the accumulated response to the callback, if any
        if self.completion_func:
            self.completion_func(gpt_response=self.response)
        raise StopIteration

    def send(self, data):
        # Accumulate the full response while also streaming the chunk out
        self.response += data
        self.queue.put(data)

    def close(self):
        # Emit compiled reference notes before signalling end-of-stream
        if self.compiled_references and len(self.compiled_references) > 0:
            self.queue.put(f"### compiled references:{json.dumps(self.compiled_references)}")
        self.queue.put(StopIteration)
|
|
|
|
|
|
class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
    """Langchain callback that forwards each streamed LLM token into a ThreadedGenerator."""

    def __init__(self, gen: ThreadedGenerator):
        StreamingStdOutCallbackHandler.__init__(self)
        self.gen = gen

    def on_llm_new_token(self, token: str, **kwargs) -> Any:
        """Push a freshly generated token onto the consumer-facing queue."""
        self.gen.send(token)
|
|
|
|
|
|
@retry(
    retry=(
        retry_if_exception_type(openai.error.Timeout)
        | retry_if_exception_type(openai.error.APIError)
        | retry_if_exception_type(openai.error.APIConnectionError)
        | retry_if_exception_type(openai.error.RateLimitError)
        | retry_if_exception_type(openai.error.ServiceUnavailableError)
    ),
    wait=wait_random_exponential(min=1, max=10),
    stop=stop_after_attempt(3),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,
)
def completion_with_backoff(**kwargs):
    """Run an OpenAI text completion, retrying transient API errors with random exponential backoff.

    Expects a `prompt` keyword argument. An optional `api_key` overrides the
    OPENAI_API_KEY environment variable; all remaining kwargs pass through to
    the langchain OpenAI client.
    """
    prompt = kwargs.pop("prompt")
    # Pop (not get) `api_key` so the raw key is not also forwarded to OpenAI(**kwargs)
    if "api_key" in kwargs:
        kwargs["openai_api_key"] = kwargs.pop("api_key")
    else:
        kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
    llm = OpenAI(**kwargs, request_timeout=20, max_retries=1)
    return llm(prompt)
|
|
|
|
|
|
@retry(
    retry=(
        retry_if_exception_type(openai.error.Timeout)
        | retry_if_exception_type(openai.error.APIError)
        | retry_if_exception_type(openai.error.APIConnectionError)
        | retry_if_exception_type(openai.error.RateLimitError)
        | retry_if_exception_type(openai.error.ServiceUnavailableError)
    ),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    stop=stop_after_attempt(3),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,
)
def chat_completion_with_backoff(
    messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None
):
    """Kick off a streaming chat completion on a background thread.

    Returns a ThreadedGenerator that yields response tokens as they arrive;
    the worker thread closes the generator once the model finishes responding.
    """
    generator = ThreadedGenerator(compiled_references, completion_func=completion_func)
    worker = Thread(target=llm_thread, args=(generator, messages, model_name, temperature, openai_api_key))
    worker.start()
    return generator
|
|
|
|
|
|
def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
    """Worker-thread entry point: stream a chat completion, pushing tokens into `g`.

    Always closes the generator, so the consuming thread is never left blocked
    on the queue even if the OpenAI API call raises.
    """
    try:
        callback_handler = StreamingChatCallbackHandler(g)
        chat = ChatOpenAI(
            streaming=True,
            verbose=True,
            callback_manager=BaseCallbackManager([callback_handler]),
            model_name=model_name,
            temperature=temperature,
            openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
            request_timeout=20,
            max_retries=1,
        )

        chat(messages=messages)
    finally:
        # Without this, a failed API call would leave the consumer hanging
        # forever on queue.get() in ThreadedGenerator.__next__
        g.close()
|
|
|
|
|
|
def generate_chatml_messages_with_context(
    user_message, system_message, conversation_log=None, model_name="gpt-3.5-turbo", lookback_turns=2
):
    """Generate messages for ChatGPT with context from previous conversation.

    Builds a newest-first message list (user message, up to `lookback_turns`
    prior exchanges, system prompt), truncates it to the model's max prompt
    size, and returns it in chronological order.
    """
    # None sentinel instead of a mutable `{}` default, which is shared across calls
    conversation_log = conversation_log or {}

    # Extract Chat History for Context (each message plus its compiled notes, if any)
    chat_logs = []
    for chat in conversation_log.get("chat", []):
        chat_notes = f'\n\n Notes:\n{chat.get("context")}' if chat.get("context") else "\n"
        chat_logs += [chat["message"] + chat_notes]

    rest_backnforths = []
    # Extract in reverse chronological order
    for user_msg, assistant_msg in zip(chat_logs[-2::-2], chat_logs[::-2]):
        if len(rest_backnforths) >= 2 * lookback_turns:
            break
        rest_backnforths += reciprocal_conversation_to_chatml([user_msg, assistant_msg])[::-1]

    # Format user and system messages to chatml format
    system_chatml_message = [ChatMessage(content=system_message, role="system")]
    user_chatml_message = [ChatMessage(content=user_message, role="user")]

    # Newest-first: user message at the head, system prompt at the tail
    messages = user_chatml_message + rest_backnforths + system_chatml_message

    # Truncate oldest messages from conversation history until under max supported prompt size by model
    messages = truncate_messages(messages, max_prompt_size[model_name], model_name)

    # Return message in chronological order
    return messages[::-1]
|
|
|
|
|
|
def truncate_messages(messages, max_prompt_size, model_name):
    """Truncate messages to fit within max prompt size supported by model"""
    # `messages` arrives newest-first (see generate_chatml_messages_with_context),
    # so pop() from the end discards the oldest entries first. NOTE(review): the
    # system message sits at the tail of that list, so it is the first to be
    # dropped under token pressure — confirm this is intended.
    encoder = tiktoken.encoding_for_model(model_name)
    tokens = sum([len(encoder.encode(message.content)) for message in messages])
    while tokens > max_prompt_size and len(messages) > 1:
        messages.pop()
        tokens = sum([len(encoder.encode(message.content)) for message in messages])

    # Truncate last message if still over max supported prompt size by model
    if tokens > max_prompt_size:
        # Keep the final line of the remaining message intact (presumably the
        # user's question — TODO confirm) and token-truncate everything before it
        last_message = "\n".join(messages[-1].content.split("\n")[:-1])
        original_question = "\n".join(messages[-1].content.split("\n")[-1:])
        original_question_tokens = len(encoder.encode(original_question))
        remaining_tokens = max_prompt_size - original_question_tokens
        truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
        logger.debug(
            f"Truncate last message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
        )
        # NOTE(review): truncated_message is strip()ed and original_question has no
        # leading newline, so the two concatenate without a separator — verify intended
        messages = [ChatMessage(content=truncated_message + original_question, role=messages[-1].role)]

    return messages
|
|
|
|
|
|
def reciprocal_conversation_to_chatml(message_pair):
    """Convert a single back and forth between user and assistant to chatml format"""
    roles = ["user", "assistant"]
    return [ChatMessage(content=msg, role=who) for msg, who in zip(message_pair, roles)]
|
|
|
|
|
|
def message_to_prompt(
    user_message, conversation_history="", gpt_message=None, start_sequence="\nAI:", restart_sequence="\nHuman:"
):
    """Create prompt for GPT from messages and conversation history"""
    # Only prefix a space when there is an actual AI reply to append
    suffix = f" {gpt_message}" if gpt_message else ""
    parts = [conversation_history, restart_sequence, " ", user_message, start_sequence, suffix]
    return "".join(parts)
|
|
|
|
|
|
def message_to_log(user_message, gpt_message, user_message_metadata=None, khoj_message_metadata=None, conversation_log=None):
    """Create json logs from messages, metadata for conversation log.

    Appends one human entry and one khoj entry to `conversation_log` and
    returns it.
    """
    # None sentinels replace the old mutable defaults ({} / []): the shared
    # default conversation_log list was mutated via extend() and grew across calls
    user_message_metadata = user_message_metadata or {}
    khoj_message_metadata = khoj_message_metadata or {}
    conversation_log = conversation_log if conversation_log is not None else []

    default_khoj_message_metadata = {
        "intent": {"type": "remember", "memory-type": "notes", "query": user_message},
        "trigger-emotion": "calm",
    }
    khoj_response_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Create json log from Human's message
    human_log = merge_dicts({"message": user_message, "by": "you"}, user_message_metadata)

    # Create json log from GPT's response
    khoj_log = merge_dicts(khoj_message_metadata, default_khoj_message_metadata)
    khoj_log = merge_dicts({"message": gpt_message, "by": "khoj", "created": khoj_response_time}, khoj_log)

    conversation_log.extend([human_log, khoj_log])
    return conversation_log
|
|
|
|
|
|
def extract_summaries(metadata):
    """Extract summaries from metadata"""
    summaries = (f'\n{session["summary"]}' for session in metadata)
    return "".join(summaries)
|