mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 13:23:15 +00:00
Support /image slash command to generate images using the chat API
This commit is contained in:
@@ -35,12 +35,14 @@ from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
|
|||||||
from khoj.processor.conversation.openai.gpt import extract_questions
|
from khoj.processor.conversation.openai.gpt import extract_questions
|
||||||
from khoj.processor.conversation.openai.whisper import transcribe_audio
|
from khoj.processor.conversation.openai.whisper import transcribe_audio
|
||||||
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
||||||
|
from khoj.processor.conversation.utils import save_to_conversation_log
|
||||||
from khoj.processor.tools.online_search import search_with_google
|
from khoj.processor.tools.online_search import search_with_google
|
||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
ApiUserRateLimiter,
|
ApiUserRateLimiter,
|
||||||
CommonQueryParams,
|
CommonQueryParams,
|
||||||
agenerate_chat_response,
|
agenerate_chat_response,
|
||||||
get_conversation_command,
|
get_conversation_command,
|
||||||
|
text_to_image,
|
||||||
is_ready_to_chat,
|
is_ready_to_chat,
|
||||||
update_telemetry_state,
|
update_telemetry_state,
|
||||||
validate_conversation_config,
|
validate_conversation_config,
|
||||||
@@ -665,7 +667,7 @@ async def chat(
|
|||||||
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
|
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
|
||||||
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
|
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
user = request.user.object
|
user: KhojUser = request.user.object
|
||||||
|
|
||||||
await is_ready_to_chat(user)
|
await is_ready_to_chat(user)
|
||||||
conversation_command = get_conversation_command(query=q, any_references=True)
|
conversation_command = get_conversation_command(query=q, any_references=True)
|
||||||
@@ -703,6 +705,11 @@ async def chat(
|
|||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
status_code=200,
|
status_code=200,
|
||||||
)
|
)
|
||||||
|
elif conversation_command == ConversationCommand.Image:
|
||||||
|
image_url, status_code = await text_to_image(q)
|
||||||
|
await sync_to_async(save_to_conversation_log)(q, image_url, user, meta_log, intent_type="text-to-image")
|
||||||
|
content_obj = {"imageUrl": image_url, "intentType": "text-to-image"}
|
||||||
|
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
|
||||||
|
|
||||||
# Get the (streamed) chat response from the LLM of choice.
|
# Get the (streamed) chat response from the LLM of choice.
|
||||||
llm_response, chat_metadata = await agenerate_chat_response(
|
llm_response, chat_metadata = await agenerate_chat_response(
|
||||||
|
|||||||
@@ -9,23 +9,23 @@ from functools import partial
|
|||||||
from time import time
|
from time import time
|
||||||
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
|
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
# External Packages
|
||||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||||
|
import openai
|
||||||
from starlette.authentication import has_required_scope
|
from starlette.authentication import has_required_scope
|
||||||
from asgiref.sync import sync_to_async
|
|
||||||
|
|
||||||
|
|
||||||
|
# Internal Packages
|
||||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||||
from khoj.database.models import KhojUser, Subscription
|
from khoj.database.models import KhojUser, Subscription, TextToImageModelConfig
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
|
from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
|
||||||
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
|
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
|
||||||
from khoj.processor.conversation.utils import ThreadedGenerator, save_to_conversation_log
|
from khoj.processor.conversation.utils import ThreadedGenerator, save_to_conversation_log
|
||||||
|
|
||||||
# Internal Packages
|
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.config import GPT4AllProcessorModel
|
from khoj.utils.config import GPT4AllProcessorModel
|
||||||
from khoj.utils.helpers import ConversationCommand, log_telemetry
|
from khoj.utils.helpers import ConversationCommand, log_telemetry
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
executor = ThreadPoolExecutor(max_workers=1)
|
executor = ThreadPoolExecutor(max_workers=1)
|
||||||
@@ -102,6 +102,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
|
|||||||
return ConversationCommand.General
|
return ConversationCommand.General
|
||||||
elif query.startswith("/online"):
|
elif query.startswith("/online"):
|
||||||
return ConversationCommand.Online
|
return ConversationCommand.Online
|
||||||
|
elif query.startswith("/image"):
|
||||||
|
return ConversationCommand.Image
|
||||||
# If no relevant notes found for the given query
|
# If no relevant notes found for the given query
|
||||||
elif not any_references:
|
elif not any_references:
|
||||||
return ConversationCommand.General
|
return ConversationCommand.General
|
||||||
@@ -248,6 +250,29 @@ def generate_chat_response(
|
|||||||
return chat_response, metadata
|
return chat_response, metadata
|
||||||
|
|
||||||
|
|
||||||
|
async def text_to_image(message: str) -> Tuple[Optional[str], int]:
|
||||||
|
status_code = 200
|
||||||
|
image_url = None
|
||||||
|
|
||||||
|
# Send the audio data to the Whisper API
|
||||||
|
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
|
||||||
|
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
||||||
|
if not text_to_image_config:
|
||||||
|
# If the user has not configured a text to image model, return an unprocessable entity error
|
||||||
|
status_code = 422
|
||||||
|
elif openai_chat_config and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||||
|
client = openai.OpenAI(api_key=openai_chat_config.api_key)
|
||||||
|
text2image_model = text_to_image_config.model_name
|
||||||
|
try:
|
||||||
|
response = client.images.generate(prompt=message, model=text2image_model)
|
||||||
|
image_url = response.data[0].url
|
||||||
|
except openai.OpenAIError as e:
|
||||||
|
logger.error(f"Image Generation failed with {e.http_status}: {e.error}")
|
||||||
|
status_code = 500
|
||||||
|
|
||||||
|
return image_url, status_code
|
||||||
|
|
||||||
|
|
||||||
class ApiUserRateLimiter:
|
class ApiUserRateLimiter:
|
||||||
def __init__(self, requests: int, subscribed_requests: int, window: int):
|
def __init__(self, requests: int, subscribed_requests: int, window: int):
|
||||||
self.requests = requests
|
self.requests = requests
|
||||||
|
|||||||
@@ -273,6 +273,7 @@ class ConversationCommand(str, Enum):
|
|||||||
Notes = "notes"
|
Notes = "notes"
|
||||||
Help = "help"
|
Help = "help"
|
||||||
Online = "online"
|
Online = "online"
|
||||||
|
Image = "image"
|
||||||
|
|
||||||
|
|
||||||
command_descriptions = {
|
command_descriptions = {
|
||||||
@@ -280,6 +281,7 @@ command_descriptions = {
|
|||||||
ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
|
ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
|
||||||
ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
|
ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
|
||||||
ConversationCommand.Online: "Look up information on the internet.",
|
ConversationCommand.Online: "Look up information on the internet.",
|
||||||
|
ConversationCommand.Image: "Generate images by describing your imagination in words.",
|
||||||
ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
|
ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user