Support /image slash command to generate images using the chat API

This commit is contained in:
Debanjum Singh Solanky
2023-12-04 17:58:04 -05:00
parent 1d9c1333f2
commit 252b35b2f0
3 changed files with 40 additions and 6 deletions

View File

@@ -35,12 +35,14 @@ from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
from khoj.processor.conversation.openai.gpt import extract_questions from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.openai.whisper import transcribe_audio from khoj.processor.conversation.openai.whisper import transcribe_audio
from khoj.processor.conversation.prompts import help_message, no_entries_found from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.tools.online_search import search_with_google from khoj.processor.tools.online_search import search_with_google
from khoj.routers.helpers import ( from khoj.routers.helpers import (
ApiUserRateLimiter, ApiUserRateLimiter,
CommonQueryParams, CommonQueryParams,
agenerate_chat_response, agenerate_chat_response,
get_conversation_command, get_conversation_command,
text_to_image,
is_ready_to_chat, is_ready_to_chat,
update_telemetry_state, update_telemetry_state,
validate_conversation_config, validate_conversation_config,
@@ -665,7 +667,7 @@ async def chat(
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)), rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)), rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
) -> Response: ) -> Response:
user = request.user.object user: KhojUser = request.user.object
await is_ready_to_chat(user) await is_ready_to_chat(user)
conversation_command = get_conversation_command(query=q, any_references=True) conversation_command = get_conversation_command(query=q, any_references=True)
@@ -703,6 +705,11 @@ async def chat(
media_type="text/event-stream", media_type="text/event-stream",
status_code=200, status_code=200,
) )
elif conversation_command == ConversationCommand.Image:
image_url, status_code = await text_to_image(q)
await sync_to_async(save_to_conversation_log)(q, image_url, user, meta_log, intent_type="text-to-image")
content_obj = {"imageUrl": image_url, "intentType": "text-to-image"}
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
# Get the (streamed) chat response from the LLM of choice. # Get the (streamed) chat response from the LLM of choice.
llm_response, chat_metadata = await agenerate_chat_response( llm_response, chat_metadata = await agenerate_chat_response(

View File

@@ -9,23 +9,23 @@ from functools import partial
from time import time from time import time
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
# External Packages
from fastapi import Depends, Header, HTTPException, Request, UploadFile from fastapi import Depends, Header, HTTPException, Request, UploadFile
import openai
from starlette.authentication import has_required_scope from starlette.authentication import has_required_scope
from asgiref.sync import sync_to_async
# Internal Packages
from khoj.database.adapters import ConversationAdapters, EntryAdapters from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import KhojUser, Subscription from khoj.database.models import KhojUser, Subscription, TextToImageModelConfig
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
from khoj.processor.conversation.utils import ThreadedGenerator, save_to_conversation_log from khoj.processor.conversation.utils import ThreadedGenerator, save_to_conversation_log
# Internal Packages
from khoj.utils import state from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import ConversationCommand, log_telemetry from khoj.utils.helpers import ConversationCommand, log_telemetry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
executor = ThreadPoolExecutor(max_workers=1) executor = ThreadPoolExecutor(max_workers=1)
@@ -102,6 +102,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
return ConversationCommand.General return ConversationCommand.General
elif query.startswith("/online"): elif query.startswith("/online"):
return ConversationCommand.Online return ConversationCommand.Online
elif query.startswith("/image"):
return ConversationCommand.Image
# If no relevant notes found for the given query # If no relevant notes found for the given query
elif not any_references: elif not any_references:
return ConversationCommand.General return ConversationCommand.General
@@ -248,6 +250,29 @@ def generate_chat_response(
return chat_response, metadata return chat_response, metadata
async def text_to_image(message: str) -> Tuple[Optional[str], int]:
    """Generate an image for the given message using the configured text-to-image model.

    Returns a tuple of (image_url, status_code): the URL of the generated image on
    success (otherwise None) and an HTTP status code describing the outcome
    (200 on success, 422 if no text-to-image model is configured, 500 on API failure).
    """
    status_code = 200
    image_url = None

    # Load the server's text-to-image and OpenAI configuration
    text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
    openai_chat_config = await ConversationAdapters.get_openai_chat_config()
    if not text_to_image_config:
        # If the user has not configured a text to image model, return an unprocessable entity error
        status_code = 422
    elif openai_chat_config and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
        client = openai.OpenAI(api_key=openai_chat_config.api_key)
        text2image_model = text_to_image_config.model_name
        try:
            response = client.images.generate(prompt=message, model=text2image_model)
            image_url = response.data[0].url
        except openai.OpenAIError as e:
            # openai>=1.0 errors do not expose .http_status/.error (that was the pre-1.0 API);
            # log the exception itself so the handler does not raise AttributeError
            logger.error(f"Image Generation failed with {e}")
            status_code = 500

    return image_url, status_code
class ApiUserRateLimiter: class ApiUserRateLimiter:
def __init__(self, requests: int, subscribed_requests: int, window: int): def __init__(self, requests: int, subscribed_requests: int, window: int):
self.requests = requests self.requests = requests

View File

@@ -273,6 +273,7 @@ class ConversationCommand(str, Enum):
Notes = "notes" Notes = "notes"
Help = "help" Help = "help"
Online = "online" Online = "online"
Image = "image"
command_descriptions = { command_descriptions = {
@@ -280,6 +281,7 @@ command_descriptions = {
ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.", ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.", ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
ConversationCommand.Online: "Look up information on the internet.", ConversationCommand.Online: "Look up information on the internet.",
ConversationCommand.Image: "Generate images by describing your imagination in words.",
ConversationCommand.Help: "Display a help message with all available commands and other metadata.", ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
} }