mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 21:29:12 +00:00
Rebase with master
This commit is contained in:
@@ -19,7 +19,7 @@ from starlette.authentication import requires
|
||||
from khoj.configure import configure_server
|
||||
from khoj.database import adapters
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_default_search_model
|
||||
from khoj.database.models import ChatModelOptions
|
||||
from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions
|
||||
from khoj.database.models import Entry as DbEntry
|
||||
from khoj.database.models import (
|
||||
GithubConfig,
|
||||
@@ -35,15 +35,18 @@ from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
|
||||
from khoj.processor.conversation.openai.gpt import extract_questions
|
||||
from khoj.processor.conversation.openai.whisper import transcribe_audio
|
||||
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
||||
from khoj.processor.conversation.utils import save_to_conversation_log
|
||||
from khoj.processor.tools.online_search import search_with_google
|
||||
from khoj.routers.helpers import (
|
||||
ApiUserRateLimiter,
|
||||
CommonQueryParams,
|
||||
agenerate_chat_response,
|
||||
get_conversation_command,
|
||||
text_to_image,
|
||||
is_ready_to_chat,
|
||||
update_telemetry_state,
|
||||
validate_conversation_config,
|
||||
ConversationCommandRateLimiter,
|
||||
)
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
@@ -65,6 +68,7 @@ from khoj.utils.state import SearchType
|
||||
# Initialize Router
|
||||
api = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
conversation_command_rate_limiter = ConversationCommandRateLimiter(trial_rate_limit=5, subscribed_rate_limit=100)
|
||||
|
||||
|
||||
def map_config_to_object(content_source: str):
|
||||
@@ -603,7 +607,13 @@ async def chat_options(
|
||||
|
||||
@api.post("/transcribe")
|
||||
@requires(["authenticated"])
|
||||
async def transcribe(request: Request, common: CommonQueryParams, file: UploadFile = File(...)):
|
||||
async def transcribe(
|
||||
request: Request,
|
||||
common: CommonQueryParams,
|
||||
file: UploadFile = File(...),
|
||||
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=1, subscribed_requests=10, window=60)),
|
||||
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
|
||||
):
|
||||
user: KhojUser = request.user.object
|
||||
audio_filename = f"{user.uuid}-{str(uuid.uuid4())}.webm"
|
||||
user_message: str = None
|
||||
@@ -623,17 +633,15 @@ async def transcribe(request: Request, common: CommonQueryParams, file: UploadFi
|
||||
|
||||
# Send the audio data to the Whisper API
|
||||
speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
|
||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
||||
if not speech_to_text_config:
|
||||
# If the user has not configured a speech to text model, return an unprocessable entity error
|
||||
status_code = 422
|
||||
elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
api_key = openai_chat_config.api_key
|
||||
# If no speech to text model is configured on the server, return a 501 (Not Implemented) error
|
||||
status_code = 501
|
||||
elif state.openai_client and speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI:
|
||||
speech2text_model = speech_to_text_config.model_name
|
||||
user_message = await transcribe_audio(audio_file, model=speech2text_model, api_key=api_key)
|
||||
elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||
user_message = await transcribe_audio(audio_file, speech2text_model, client=state.openai_client)
|
||||
elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE:
|
||||
speech2text_model = speech_to_text_config.model_name
|
||||
user_message = await transcribe_audio_offline(audio_filename, model=speech2text_model)
|
||||
user_message = await transcribe_audio_offline(audio_filename, speech2text_model)
|
||||
finally:
|
||||
# Close and Delete the temporary audio file
|
||||
audio_file.close()
|
||||
@@ -666,11 +674,13 @@ async def chat(
|
||||
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
|
||||
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
|
||||
) -> Response:
|
||||
user = request.user.object
|
||||
user: KhojUser = request.user.object
|
||||
|
||||
await is_ready_to_chat(user)
|
||||
conversation_command = get_conversation_command(query=q, any_references=True)
|
||||
|
||||
conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command)
|
||||
|
||||
q = q.replace(f"/{conversation_command.value}", "").strip()
|
||||
|
||||
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
|
||||
@@ -704,6 +714,27 @@ async def chat(
|
||||
media_type="text/event-stream",
|
||||
status_code=200,
|
||||
)
|
||||
elif conversation_command == ConversationCommand.Image:
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
metadata={"conversation_command": conversation_command.value},
|
||||
**common.__dict__,
|
||||
)
|
||||
image, status_code, improved_image_prompt = await text_to_image(q)
|
||||
if image is None:
|
||||
content_obj = {
|
||||
"image": image,
|
||||
"intentType": "text-to-image",
|
||||
"detail": "Failed to generate image. Make sure your image generation configuration is set.",
|
||||
}
|
||||
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q, image, user, meta_log, intent_type="text-to-image", inferred_queries=[improved_image_prompt]
|
||||
)
|
||||
content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt]} # type: ignore
|
||||
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
|
||||
|
||||
# Get the (streamed) chat response from the LLM of choice.
|
||||
llm_response, chat_metadata = await agenerate_chat_response(
|
||||
@@ -787,7 +818,6 @@ async def extract_references_and_questions(
|
||||
conversation_config = await ConversationAdapters.aget_conversation_config(user)
|
||||
if conversation_config is None:
|
||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
|
||||
if (
|
||||
offline_chat_config
|
||||
and offline_chat_config.enabled
|
||||
@@ -804,7 +834,7 @@ async def extract_references_and_questions(
|
||||
inferred_queries = extract_questions_offline(
|
||||
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
|
||||
)
|
||||
elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
||||
openai_chat = await ConversationAdapters.get_openai_chat()
|
||||
api_key = openai_chat_config.api_key
|
||||
|
||||
Reference in New Issue
Block a user