From 039ed78253f03b7c780b26a9712a84375e5c0726 Mon Sep 17 00:00:00 2001 From: sabaimran <65192171+sabaimran@users.noreply.github.com> Date: Thu, 18 Jan 2024 05:54:14 -0800 Subject: [PATCH] Add support for a first-party client app to call into Khoj (Part 1) (#601) * Add support for a first party client app - Based on a client id and client secret, allow a first party app to call into the Khoj backend with a phone number identifier - Add migration to add phone numbers to the KhojUser object * Add plus in front of country code when registering a phone number. - Decrease free tier limit to 5 (from 10) - Return a response object when handling stripe webhooks * Fix telemetry method which references authenticated user's client app * Add better error handling for null phone numbers, simplify logic of authenticating user * Pull the client_secret in the API call from the authorization header * Add a migration merge to resolve phone number and other changes --- pyproject.toml | 2 + src/khoj/app/settings.py | 1 + src/khoj/configure.py | 59 ++++++++++++++++++- src/khoj/database/adapters/__init__.py | 43 ++++++++++++-- src/khoj/database/admin.py | 20 ++++++- ...lication_khojuser_phone_number_and_more.py | 46 +++++++++++++++ .../migrations/0027_merge_20240118_1324.py | 13 ++++ src/khoj/database/models/__init__.py | 12 ++++ src/khoj/routers/api.py | 12 ++-- src/khoj/routers/helpers.py | 13 ++-- src/khoj/routers/subscription.py | 1 + 11 files changed, 204 insertions(+), 18 deletions(-) create mode 100644 src/khoj/database/migrations/0025_clientapplication_khojuser_phone_number_and_more.py create mode 100644 src/khoj/database/migrations/0027_merge_20240118_1324.py diff --git a/pyproject.toml b/pyproject.toml index 5e8d6f25..7f969f3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,8 @@ dependencies = [ "rapidocr-onnxruntime == 1.3.8", "stripe == 7.3.0", "openai-whisper >= 20231117", + "django-phonenumber-field == 7.3.0", + "phonenumbers == 8.13.27", ] dynamic = ["version"] diff --git a/src/khoj/app/settings.py b/src/khoj/app/settings.py index 8f7d409d..c0d2d8a8 100644 --- a/src/khoj/app/settings.py +++ b/src/khoj/app/settings.py @@ -62,6 +62,7 @@ INSTALLED_APPS = [ "django.contrib.sessions", "django.contrib.messages", "django.contrib.staticfiles", + "phonenumber_field", ] MIDDLEWARE = [ diff --git a/src/khoj/configure.py b/src/khoj/configure.py index e2f16884..84a1b322 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -9,6 +9,7 @@ import openai import requests import schedule from django.utils.timezone import make_aware +from fastapi import Response from starlette.authentication import ( AuthCredentials, AuthenticationBackend, @@ -20,27 +21,32 @@ from starlette.middleware.sessions import SessionMiddleware from starlette.requests import HTTPConnection from khoj.database.adapters import ( + ClientApplicationAdapters, ConversationAdapters, SubscriptionState, + aget_or_create_user_by_phone_number, + aget_user_by_phone_number, aget_user_subscription_state, get_all_users, get_or_create_search_models, ) -from khoj.database.models import KhojUser, Subscription +from khoj.database.models import ClientApplication, KhojUser, Subscription from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.routers.indexer import configure_content, configure_search, load_content from khoj.utils import constants, state from khoj.utils.config import SearchType from khoj.utils.fs_syncer import collect_files +from khoj.utils.helpers import is_none_or_empty from khoj.utils.rawconfig import FullConfig logger = logging.getLogger(__name__) class AuthenticatedKhojUser(SimpleUser): - def __init__(self, user): + def __init__(self, user, client_app: Optional[ClientApplication] = None): self.object = user - super().__init__(user.email) + self.client_app = client_app + super().__init__(user.username) class UserAuthenticationBackend(AuthenticationBackend): @@ -108,6 +114,53 @@ class UserAuthenticationBackend(AuthenticationBackend): if subscribed: return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) + # Get query params for client_id and client_secret + client_id = request.query_params.get("client_id") + if client_id: + # Get the client secret, which is passed in the Authorization header + client_secret = request.headers["Authorization"].split("Bearer ")[1] + if not client_secret: + return Response( + status_code=401, + content="Please provide a client secret in the Authorization header with a client_id query param.", + ) + + # Get the client application + client_application = await ClientApplicationAdapters.aget_client_application_by_id(client_id, client_secret) + if client_application is None: + return AuthCredentials(), UnauthenticatedUser() + # Get the identifier used for the user + phone_number = request.query_params.get("phone_number") + if is_none_or_empty(phone_number): + return AuthCredentials(), UnauthenticatedUser() + + if not phone_number.startswith("+"): + phone_number = f"+{phone_number}" + + create_if_not_exists = request.query_params.get("create_if_not_exists") + if create_if_not_exists: + user = await aget_or_create_user_by_phone_number(phone_number) + else: + user = await aget_user_by_phone_number(phone_number) + + if user is None: + return AuthCredentials(), UnauthenticatedUser() + + if not state.billing_enabled: + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user, client_application) + + subscription_state = await aget_user_subscription_state(user) + subscribed = ( + subscription_state == SubscriptionState.SUBSCRIBED.value + or subscription_state == SubscriptionState.TRIAL.value + or subscription_state == SubscriptionState.UNSUBSCRIBED.value + ) + if subscribed: + return ( + AuthCredentials(["authenticated", "premium"]), + AuthenticatedKhojUser(user), + ) + return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user, client_application) if state.anonymous_mode: user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst() if user: diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index f09fedc6..1b5e4d1b 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -17,6 +17,7 @@ from torch import Tensor from khoj.database.models import ( ChatModelOptions, + ClientApplication, Conversation, Entry, GithubConfig, @@ -40,7 +41,7 @@ from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.word_filter import WordFilter from khoj.utils import state from khoj.utils.config import GPT4AllProcessorModel -from khoj.utils.helpers import generate_random_name +from khoj.utils.helpers import generate_random_name, is_none_or_empty class SubscriptionState(Enum): @@ -85,6 +86,28 @@ async def get_or_create_user(token: dict) -> KhojUser: return user +async def aget_or_create_user_by_phone_number(phone_number: str) -> KhojUser: + if is_none_or_empty(phone_number): + return None + user = await aget_user_by_phone_number(phone_number) + if not user: + user = await acreate_user_by_phone_number(phone_number) + return user + + +async def acreate_user_by_phone_number(phone_number: str) -> KhojUser: + if is_none_or_empty(phone_number): + return None + user, _ = await KhojUser.objects.filter(phone_number=phone_number).aupdate_or_create( + defaults={"username": phone_number, "phone_number": phone_number} + ) + await user.asave() + + await Subscription.objects.acreate(user=user, type="trial") + + return user + + async def get_or_create_user_by_email(email: str) -> KhojUser: user, _ = await KhojUser.objects.filter(email=email).aupdate_or_create(defaults={"username": email, "email": email}) await user.asave() @@ -187,6 +210,12 @@ async def get_user_by_token(token: dict) -> KhojUser: return google_user.user +async def aget_user_by_phone_number(phone_number: str) -> KhojUser: + if is_none_or_empty(phone_number): + return None + return await KhojUser.objects.filter(phone_number=phone_number).prefetch_related("subscription").afirst() + + async def retrieve_user(session_id: str) -> KhojUser: session = SessionStore(session_key=session_id) if not await sync_to_async(session.exists)(session_key=session_id): @@ -270,6 +299,12 @@ async def aset_user_search_model(user: KhojUser, search_model_config_id: int): return new_config +class ClientApplicationAdapters: + @staticmethod + async def aget_client_application_by_id(client_id: str, client_secret: str): + return await ClientApplication.objects.filter(client_id=client_id, client_secret=client_secret).afirst() + + class ConversationAdapters: @staticmethod def get_conversation_by_user(user: KhojUser): @@ -279,11 +314,11 @@ class ConversationAdapters: return Conversation.objects.create(user=user) @staticmethod - async def aget_conversation_by_user(user: KhojUser): - conversation = Conversation.objects.filter(user=user) + async def aget_conversation_by_user(user: KhojUser, client_application: ClientApplication = None): + conversation = Conversation.objects.filter(user=user, client=client_application) if await conversation.aexists(): return await conversation.afirst() - return await Conversation.objects.acreate(user=user) + return await Conversation.objects.acreate(user=user, client=client_application) @staticmethod async def adelete_conversation_by_user(user: KhojUser): diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index f8d084d8..521e81de 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -7,6 +7,7 @@ from django.http import HttpResponse from khoj.database.models import ( ChatModelOptions, + ClientApplication, Conversation, KhojUser, OfflineChatProcessorConversationConfig, @@ -19,10 +20,24 @@ from khoj.database.models import ( UserSearchModelConfig, ) -# Register your models here. + +class KhojUserAdmin(UserAdmin): + list_display = ( + "id", + "email", + "username", + "is_active", + "is_staff", + "is_superuser", + "phone_number", + ) + search_fields = ("email", "username", "phone_number") + filter_horizontal = ("groups", "user_permissions") + + fieldsets = (("Personal info", {"fields": ("phone_number",)}),) + UserAdmin.fieldsets -admin.site.register(KhojUser, UserAdmin) +admin.site.register(KhojUser, KhojUserAdmin) admin.site.register(ChatModelOptions) admin.site.register(SpeechToTextModelOptions) @@ -33,6 +48,7 @@ admin.site.register(Subscription) admin.site.register(ReflectiveQuestion) admin.site.register(UserSearchModelConfig) admin.site.register(TextToImageModelConfig) +admin.site.register(ClientApplication) @admin.register(Conversation) diff --git a/src/khoj/database/migrations/0025_clientapplication_khojuser_phone_number_and_more.py b/src/khoj/database/migrations/0025_clientapplication_khojuser_phone_number_and_more.py new file mode 100644 index 00000000..14cc568a --- /dev/null +++ b/src/khoj/database/migrations/0025_clientapplication_khojuser_phone_number_and_more.py @@ -0,0 +1,46 @@ +# Generated by Django 4.2.7 on 2024-01-04 12:22 + +import django.db.models.deletion +import phonenumber_field.modelfields +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0024_alter_entry_embeddings"), + ] + + operations = [ + migrations.CreateModel( + name="ClientApplication", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("name", models.CharField(max_length=200)), + ("client_id", models.CharField(max_length=200)), + ("client_secret", models.CharField(max_length=200)), + ], + options={ + "abstract": False, + }, + ), + migrations.AddField( + model_name="khojuser", + name="phone_number", + field=phonenumber_field.modelfields.PhoneNumberField( + blank=True, default=None, max_length=128, null=True, region=None + ), + ), + migrations.AddField( + model_name="conversation", + name="client", + field=models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="database.clientapplication", + ), + ), + ] diff --git a/src/khoj/database/migrations/0027_merge_20240118_1324.py b/src/khoj/database/migrations/0027_merge_20240118_1324.py new file mode 100644 index 00000000..63aa0c25 --- /dev/null +++ b/src/khoj/database/migrations/0027_merge_20240118_1324.py @@ -0,0 +1,13 @@ +# Generated by Django 4.2.7 on 2024-01-18 13:24 +from typing import List + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0025_clientapplication_khojuser_phone_number_and_more"), + ("database", "0026_searchmodelconfig_cross_encoder_inference_endpoint_and_more"), + ] + + operations: List[str] = [] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 030e7ea8..93f4f2ac 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -3,6 +3,7 @@ import uuid from django.contrib.auth.models import AbstractUser from django.db import models from pgvector.django import VectorField +from phonenumber_field.modelfields import PhoneNumberField class BaseModel(models.Model): @@ -13,8 +14,18 @@ class BaseModel(models.Model): abstract = True +class ClientApplication(BaseModel): + name = models.CharField(max_length=200) + client_id = models.CharField(max_length=200) + client_secret = models.CharField(max_length=200) + + def __str__(self): + return self.name + + class KhojUser(AbstractUser): uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False)) + phone_number = PhoneNumberField(null=True, default=None, blank=True) def save(self, *args, **kwargs): if not self.uuid: @@ -165,6 +176,7 @@ class UserSearchModelConfig(BaseModel): class Conversation(BaseModel): user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) conversation_log = models.JSONField(default=dict) + client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True) class ReflectiveQuestion(BaseModel): diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 01d939ae..3047087d 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -359,8 +359,8 @@ async def chat( n: Optional[int] = 5, d: Optional[float] = 0.18, stream: Optional[bool] = False, - rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)), - rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)), + rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60)), + rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24)), ) -> Response: user: KhojUser = request.user.object q = unquote(q) @@ -372,7 +372,7 @@ async def chat( q = q.replace(f"/{conversation_command.value}", "").strip() - meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log + meta_log = (await ConversationAdapters.aget_conversation_by_user(user, request.user.client_app)).conversation_log compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( request, common, meta_log, q, (n or 5), (d or math.inf), conversation_command @@ -392,7 +392,11 @@ async def chat( elif conversation_command == ConversationCommand.Notes and not await EntryAdapters.auser_has_entries(user): no_entries_found_format = no_entries_found.format() - return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200) + if stream: + return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200) + else: + response_obj = {"response": no_entries_found_format} + return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200) elif conversation_command == ConversationCommand.Online: try: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index d7a92a20..70c7254b 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -13,7 +13,12 @@ from fastapi import Depends, Header, HTTPException, Request, UploadFile from starlette.authentication import has_required_scope from khoj.database.adapters import ConversationAdapters, EntryAdapters -from khoj.database.models import KhojUser, Subscription, TextToImageModelConfig +from khoj.database.models import ( + ClientApplication, + KhojUser, + Subscription, + TextToImageModelConfig, +) from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.chat_model import ( converse_offline, @@ -74,6 +79,7 @@ def update_telemetry_state( metadata: Optional[dict] = None, ): user: KhojUser = request.user.object if request.user.is_authenticated else None + client_app: ClientApplication = request.user.client_app if request.user.is_authenticated else None subscription: Subscription = user.subscription if user and hasattr(user, "subscription") else None user_state = { "client_host": request.client.host if request.client else None, @@ -83,6 +89,7 @@ def update_telemetry_state( "server_id": str(user.uuid) if user else None, "subscription_type": subscription.type if subscription else None, "is_recurring": subscription.is_recurring if subscription else None, + "client_id": str(client_app.name) if client_app else None, } if metadata: @@ -113,10 +120,6 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver return ConversationCommand.Default -async def construct_conversation_logs(user: KhojUser): - return (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log - - async def agenerate_chat_response(*args): loop = asyncio.get_event_loop() return await loop.run_in_executor(executor, generate_chat_response, *args) diff --git a/src/khoj/routers/subscription.py b/src/khoj/routers/subscription.py index 09d2a7d4..1ce49e04 100644 --- a/src/khoj/routers/subscription.py +++ b/src/khoj/routers/subscription.py @@ -5,6 +5,7 @@ from datetime import datetime, timezone import stripe from asgiref.sync import sync_to_async from fastapi import APIRouter, Request +from fastapi.responses import Response from starlette.authentication import requires from khoj.database import adapters