From 252b35b2f01e16cba56f95e73f33e9fb7a99ef48 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 4 Dec 2023 17:58:04 -0500 Subject: [PATCH] Support /image slash command to generate images using the chat API --- src/khoj/routers/api.py | 9 ++++++++- src/khoj/routers/helpers.py | 35 ++++++++++++++++++++++++++++++----- src/khoj/utils/helpers.py | 2 ++ 3 files changed, 40 insertions(+), 6 deletions(-) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index ae125980..ae31c260 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -35,12 +35,14 @@ from khoj.processor.conversation.offline.whisper import transcribe_audio_offline from khoj.processor.conversation.openai.gpt import extract_questions from khoj.processor.conversation.openai.whisper import transcribe_audio from khoj.processor.conversation.prompts import help_message, no_entries_found +from khoj.processor.conversation.utils import save_to_conversation_log from khoj.processor.tools.online_search import search_with_google from khoj.routers.helpers import ( ApiUserRateLimiter, CommonQueryParams, agenerate_chat_response, get_conversation_command, + text_to_image, is_ready_to_chat, update_telemetry_state, validate_conversation_config, @@ -665,7 +667,7 @@ async def chat( rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)), rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)), ) -> Response: - user = request.user.object + user: KhojUser = request.user.object await is_ready_to_chat(user) conversation_command = get_conversation_command(query=q, any_references=True) @@ -703,6 +705,11 @@ async def chat( media_type="text/event-stream", status_code=200, ) + elif conversation_command == ConversationCommand.Image: + image_url, status_code = await text_to_image(q) + await sync_to_async(save_to_conversation_log)(q, image_url, user, meta_log, intent_type="text-to-image") + content_obj = 
{"imageUrl": image_url, "intentType": "text-to-image"} + return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code) # Get the (streamed) chat response from the LLM of choice. llm_response, chat_metadata = await agenerate_chat_response( diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 3e8ed155..4e43289f 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -9,23 +9,23 @@ from functools import partial from time import time from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union +# External Packages from fastapi import Depends, Header, HTTPException, Request, UploadFile +import openai from starlette.authentication import has_required_scope -from asgiref.sync import sync_to_async - +# Internal Packages from khoj.database.adapters import ConversationAdapters, EntryAdapters -from khoj.database.models import KhojUser, Subscription +from khoj.database.models import KhojUser, Subscription, TextToImageModelConfig from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline from khoj.processor.conversation.openai.gpt import converse, send_message_to_model from khoj.processor.conversation.utils import ThreadedGenerator, save_to_conversation_log - -# Internal Packages from khoj.utils import state from khoj.utils.config import GPT4AllProcessorModel from khoj.utils.helpers import ConversationCommand, log_telemetry + logger = logging.getLogger(__name__) executor = ThreadPoolExecutor(max_workers=1) @@ -102,6 +102,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver return ConversationCommand.General elif query.startswith("/online"): return ConversationCommand.Online + elif query.startswith("/image"): + return ConversationCommand.Image # If no relevant notes found for the given query elif not any_references: return 
ConversationCommand.General @@ -248,6 +250,29 @@ def generate_chat_response( return chat_response, metadata +async def text_to_image(message: str) -> Tuple[Optional[str], int]: + status_code = 200 + image_url = None + + # Generate the image using the configured OpenAI text-to-image model + text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config() + openai_chat_config = await ConversationAdapters.get_openai_chat_config() + if not text_to_image_config: + # If the user has not configured a text to image model, return an unprocessable entity error + status_code = 422 + elif openai_chat_config and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: + client = openai.OpenAI(api_key=openai_chat_config.api_key) + text2image_model = text_to_image_config.model_name + try: + response = client.images.generate(prompt=message, model=text2image_model) + image_url = response.data[0].url + except openai.OpenAIError as e: + logger.error(f"Image Generation failed with {e}") + status_code = 500 + + return image_url, status_code + + class ApiUserRateLimiter: def __init__(self, requests: int, subscribed_requests: int, window: int): self.requests = requests diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 42e3835d..21fe7e98 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -273,6 +273,7 @@ class ConversationCommand(str, Enum): Notes = "notes" Help = "help" Online = "online" + Image = "image" command_descriptions = { @@ -280,6 +281,7 @@ command_descriptions = { ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.", ConversationCommand.Default: "The default command when no command specified. 
It intelligently auto-switches between general and notes mode.", ConversationCommand.Online: "Look up information on the internet.", + ConversationCommand.Image: "Generate images by describing your imagination in words.", ConversationCommand.Help: "Display a help message with all available commands and other metadata.", }