mirror of
				https://github.com/inventree/InvenTree.git
				synced 2025-11-03 22:55:43 +00:00 
			
		
		
		
	Add dj-rest-auth (#4187)
* Add dj-rest-auth [FR] User registration via API Fixes #3978 * add jwt support for API * check for old password * Add check if registration is allowed * make email mandatory if selected * lower postgres version? * update req * revert psql change * move form options out * Update reqs * Add handlers for most OAuth2 * refactor and add logging * make error message more actionable * add handler for twitter * add keycloak endpoint * warning for legacy apps * remove legacy twitter support * rename file * move url to sub * make JWT optional (default off) * Add var to config template * Add API endpoint to list available providers * fix url pattern
This commit is contained in:
		@@ -21,6 +21,8 @@ from crispy_forms.bootstrap import (AppendedText, PrependedAppendedText,
 | 
			
		||||
                                    PrependedText)
 | 
			
		||||
from crispy_forms.helper import FormHelper
 | 
			
		||||
from crispy_forms.layout import Field, Layout
 | 
			
		||||
from dj_rest_auth.registration.serializers import RegisterSerializer
 | 
			
		||||
from rest_framework import serializers
 | 
			
		||||
 | 
			
		||||
from common.models import InvenTreeSetting
 | 
			
		||||
from InvenTree.exceptions import log_error
 | 
			
		||||
@@ -206,6 +208,11 @@ class CustomSignupForm(SignupForm):
 | 
			
		||||
        return cleaned_data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def registration_enabled():
 | 
			
		||||
    """Determine whether user registration is enabled."""
 | 
			
		||||
    return settings.EMAIL_HOST and (InvenTreeSetting.get_setting('LOGIN_ENABLE_REG') or InvenTreeSetting.get_setting('LOGIN_ENABLE_SSO_REG'))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RegistratonMixin:
 | 
			
		||||
    """Mixin to check if registration should be enabled."""
 | 
			
		||||
 | 
			
		||||
@@ -214,7 +221,7 @@ class RegistratonMixin:
 | 
			
		||||
 | 
			
		||||
        Configure the class variable `REGISTRATION_SETTING` to set which setting should be used, default: `LOGIN_ENABLE_REG`.
 | 
			
		||||
        """
 | 
			
		||||
        if settings.EMAIL_HOST and (InvenTreeSetting.get_setting('LOGIN_ENABLE_REG') or InvenTreeSetting.get_setting('LOGIN_ENABLE_SSO_REG')):
 | 
			
		||||
        if registration_enabled():
 | 
			
		||||
            return super().is_open_for_signup(request, *args, **kwargs)
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
@@ -319,3 +326,20 @@ class CustomSocialAccountAdapter(CustomUrlMixin, RegistratonMixin, DefaultSocial
 | 
			
		||||
 | 
			
		||||
        # Otherwise defer to the original allauth adapter.
 | 
			
		||||
        return super().login(request, user)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# override dj-rest-auth
 | 
			
		||||
class CustomRegisterSerializer(RegisterSerializer):
 | 
			
		||||
    """Override of serializer to use dynamic settings."""
 | 
			
		||||
    email = serializers.EmailField()
 | 
			
		||||
 | 
			
		||||
    def __init__(self, instance=None, data=..., **kwargs):
 | 
			
		||||
        """Check settings to influence which fields are needed."""
 | 
			
		||||
        kwargs['email_required'] = InvenTreeSetting.get_setting('LOGIN_MAIL_REQUIRED')
 | 
			
		||||
        super().__init__(instance, data, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def save(self, request):
 | 
			
		||||
        """Override to check if registration is open."""
 | 
			
		||||
        if registration_enabled():
 | 
			
		||||
            return super().save(request)
 | 
			
		||||
        raise forms.ValidationError(_('Registration is disabled.'))
 | 
			
		||||
 
 | 
			
		||||
@@ -245,6 +245,8 @@ INSTALLED_APPS = [
 | 
			
		||||
    'django_otp.plugins.otp_static',        # Backup codes
 | 
			
		||||
 | 
			
		||||
    'allauth_2fa',                          # MFA flow for allauth
 | 
			
		||||
    'dj_rest_auth',                         # Authentication APIs - dj-rest-auth
 | 
			
		||||
    'dj_rest_auth.registration',            # Registration APIs - dj-rest-auth'
 | 
			
		||||
    'drf_spectacular',                      # API documentation
 | 
			
		||||
 | 
			
		||||
    'django_ical',                          # For exporting calendars
 | 
			
		||||
@@ -380,6 +382,23 @@ if DEBUG:
 | 
			
		||||
    # Enable browsable API if in DEBUG mode
 | 
			
		||||
    REST_FRAMEWORK['DEFAULT_RENDERER_CLASSES'].append('rest_framework.renderers.BrowsableAPIRenderer')
 | 
			
		||||
 | 
			
		||||
# dj-rest-auth
 | 
			
		||||
# JWT switch
 | 
			
		||||
USE_JWT = get_boolean_setting('INVENTREE_USE_JWT', 'use_jwt', False)
 | 
			
		||||
REST_USE_JWT = USE_JWT
 | 
			
		||||
OLD_PASSWORD_FIELD_ENABLED = True
 | 
			
		||||
REST_AUTH_REGISTER_SERIALIZERS = {'REGISTER_SERIALIZER': 'InvenTree.forms.CustomRegisterSerializer'}
 | 
			
		||||
 | 
			
		||||
# JWT settings - rest_framework_simplejwt
 | 
			
		||||
if USE_JWT:
 | 
			
		||||
    JWT_AUTH_COOKIE = 'inventree-auth'
 | 
			
		||||
    JWT_AUTH_REFRESH_COOKIE = 'inventree-token'
 | 
			
		||||
    REST_FRAMEWORK['DEFAULT_AUTHENTICATION_CLASSES'] + (
 | 
			
		||||
        'dj_rest_auth.jwt_auth.JWTCookieAuthentication',
 | 
			
		||||
    )
 | 
			
		||||
    INSTALLED_APPS.append('rest_framework_simplejwt')
 | 
			
		||||
 | 
			
		||||
# WSGI default setting
 | 
			
		||||
SPECTACULAR_SETTINGS = {
 | 
			
		||||
    'TITLE': 'InvenTree API',
 | 
			
		||||
    'DESCRIPTION': 'API for InvenTree - the intuitive open source inventory management system',
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										127
									
								
								InvenTree/InvenTree/social_auth_urls.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										127
									
								
								InvenTree/InvenTree/social_auth_urls.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,127 @@
 | 
			
		||||
"""API endpoints for social authentication with allauth."""
 | 
			
		||||
import logging
 | 
			
		||||
from importlib import import_module
 | 
			
		||||
 | 
			
		||||
from django.urls import include, path, reverse
 | 
			
		||||
 | 
			
		||||
from allauth.socialaccount import providers
 | 
			
		||||
from allauth.socialaccount.models import SocialApp
 | 
			
		||||
from allauth.socialaccount.providers.keycloak.views import \
 | 
			
		||||
    KeycloakOAuth2Adapter
 | 
			
		||||
from allauth.socialaccount.providers.oauth2.views import (OAuth2Adapter,
 | 
			
		||||
                                                          OAuth2LoginView)
 | 
			
		||||
from rest_framework.generics import ListAPIView
 | 
			
		||||
from rest_framework.permissions import AllowAny
 | 
			
		||||
from rest_framework.response import Response
 | 
			
		||||
 | 
			
		||||
from common.models import InvenTreeSetting
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger('inventree')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GenericOAuth2ApiLoginView(OAuth2LoginView):
 | 
			
		||||
    """Api view to login a user with a social account"""
 | 
			
		||||
    def dispatch(self, request, *args, **kwargs):
 | 
			
		||||
        """Dispatch the regular login view directly."""
 | 
			
		||||
        return self.login(request, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GenericOAuth2ApiConnectView(GenericOAuth2ApiLoginView):
 | 
			
		||||
    """Api view to connect a social account to the current user"""
 | 
			
		||||
 | 
			
		||||
    def dispatch(self, request, *args, **kwargs):
 | 
			
		||||
        """Dispatch the connect request directly."""
 | 
			
		||||
 | 
			
		||||
        # Override the request method be in connection mode
 | 
			
		||||
        request.GET = request.GET.copy()
 | 
			
		||||
        request.GET['process'] = 'connect'
 | 
			
		||||
 | 
			
		||||
        # Resume the dispatch
 | 
			
		||||
        return super().dispatch(request, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def handle_oauth2(adapter: OAuth2Adapter):
 | 
			
		||||
    """Define urls for oauth2 endpoints."""
 | 
			
		||||
    return [
 | 
			
		||||
        path('login/', GenericOAuth2ApiLoginView.adapter_view(adapter), name=f'{provider.id}_api_login'),
 | 
			
		||||
        path('connect/', GenericOAuth2ApiConnectView.adapter_view(adapter), name=f'{provider.id}_api_connect'),
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def handle_keycloak():
 | 
			
		||||
    """Define urls for keycloak."""
 | 
			
		||||
    return [
 | 
			
		||||
        path('login/', GenericOAuth2ApiLoginView.adapter_view(KeycloakOAuth2Adapter), name='keycloak_api_login'),
 | 
			
		||||
        path('connect/', GenericOAuth2ApiConnectView.adapter_view(KeycloakOAuth2Adapter), name='keycloak_api_connet'),
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
legacy = {
 | 
			
		||||
    'twitter': 'twitter_oauth2',
 | 
			
		||||
    'bitbucket': 'bitbucket_oauth2',
 | 
			
		||||
    'linkedin': 'linkedin_oauth2',
 | 
			
		||||
    'vimeo': 'vimeo_oauth2',
 | 
			
		||||
    'openid': 'openid_connect',
 | 
			
		||||
}  # legacy connectors
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Collect urls for all loaded providers
 | 
			
		||||
social_auth_urlpatterns = []
 | 
			
		||||
 | 
			
		||||
provider_urlpatterns = []
 | 
			
		||||
for provider in providers.registry.get_list():
 | 
			
		||||
    try:
 | 
			
		||||
        prov_mod = import_module(provider.get_package() + ".views")
 | 
			
		||||
    except ImportError:
 | 
			
		||||
        continue
 | 
			
		||||
 | 
			
		||||
    # Try to extract the adapter class
 | 
			
		||||
    adapters = [cls for cls in prov_mod.__dict__.values() if isinstance(cls, type) and not cls == OAuth2Adapter and issubclass(cls, OAuth2Adapter)]
 | 
			
		||||
 | 
			
		||||
    # Get urls
 | 
			
		||||
    urls = []
 | 
			
		||||
    if len(adapters) == 1:
 | 
			
		||||
        urls = handle_oauth2(adapter=adapters[0])
 | 
			
		||||
    else:
 | 
			
		||||
        if provider.id in legacy:
 | 
			
		||||
            logger.warning(f'`{provider.id}` is not supported on platform UI. Use `{legacy[provider.id]}` instead.')
 | 
			
		||||
            continue
 | 
			
		||||
        elif provider.id == 'keycloak':
 | 
			
		||||
            urls = handle_keycloak()
 | 
			
		||||
        else:
 | 
			
		||||
            logger.error(f'Found handler that is not yet ready for platform UI: `{provider.id}`. Open an feature request on GitHub if you need it implemented.')
 | 
			
		||||
            continue
 | 
			
		||||
    provider_urlpatterns += [path(f'{provider.id}/', include(urls))]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
social_auth_urlpatterns += provider_urlpatterns
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SocialProvierListView(ListAPIView):
 | 
			
		||||
    """List of available social providers."""
 | 
			
		||||
    permission_classes = (AllowAny,)
 | 
			
		||||
 | 
			
		||||
    def get(self, request, *args, **kwargs):
 | 
			
		||||
        """Get the list of providers."""
 | 
			
		||||
        provider_list = []
 | 
			
		||||
        for provider in providers.registry.get_list():
 | 
			
		||||
            provider_data = {
 | 
			
		||||
                'id': provider.id,
 | 
			
		||||
                'name': provider.name,
 | 
			
		||||
                'login': request.build_absolute_uri(reverse(f'{provider.id}_api_login')),
 | 
			
		||||
                'connect': request.build_absolute_uri(reverse(f'{provider.id}_api_connect')),
 | 
			
		||||
            }
 | 
			
		||||
            try:
 | 
			
		||||
                provider_data['display_name'] = provider.get_app(request).name
 | 
			
		||||
            except SocialApp.DoesNotExist:
 | 
			
		||||
                provider_data['display_name'] = provider.name
 | 
			
		||||
 | 
			
		||||
            provider_list.append(provider_data)
 | 
			
		||||
 | 
			
		||||
        data = {
 | 
			
		||||
            'sso_enabled': InvenTreeSetting.get_setting('LOGIN_ENABLE_SSO'),
 | 
			
		||||
            'sso_registration': InvenTreeSetting.get_setting('LOGIN_ENABLE_SSO_REG'),
 | 
			
		||||
            'mfa_required': InvenTreeSetting.get_setting('LOGIN_ENFORCE_MFA'),
 | 
			
		||||
            'providers': provider_list
 | 
			
		||||
        }
 | 
			
		||||
        return Response(data)
 | 
			
		||||
@@ -9,6 +9,8 @@ from django.contrib import admin
 | 
			
		||||
from django.urls import include, path, re_path
 | 
			
		||||
from django.views.generic.base import RedirectView
 | 
			
		||||
 | 
			
		||||
from dj_rest_auth.registration.views import (SocialAccountDisconnectView,
 | 
			
		||||
                                             SocialAccountListView)
 | 
			
		||||
from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView
 | 
			
		||||
 | 
			
		||||
from build.api import build_api_urls
 | 
			
		||||
@@ -31,6 +33,7 @@ from stock.urls import stock_urls
 | 
			
		||||
from users.api import user_urls
 | 
			
		||||
 | 
			
		||||
from .api import APISearchView, InfoView, NotFoundView
 | 
			
		||||
from .social_auth_urls import SocialProvierListView, social_auth_urlpatterns
 | 
			
		||||
from .views import (AboutView, AppearanceSelectView, CustomConnectionsView,
 | 
			
		||||
                    CustomEmailView, CustomLoginView,
 | 
			
		||||
                    CustomPasswordResetFromKeyView,
 | 
			
		||||
@@ -71,6 +74,14 @@ apipatterns = [
 | 
			
		||||
    # InvenTree information endpoint
 | 
			
		||||
    path('', InfoView.as_view(), name='api-inventree-info'),
 | 
			
		||||
 | 
			
		||||
    # Third party API endpoints
 | 
			
		||||
    path('auth/', include('dj_rest_auth.urls')),
 | 
			
		||||
    path('auth/registration/', include('dj_rest_auth.registration.urls')),
 | 
			
		||||
    path('auth/providers/', SocialProvierListView.as_view(), name='social_providers'),
 | 
			
		||||
    path('auth/social/', include(social_auth_urlpatterns)),
 | 
			
		||||
    path('auth/social/', SocialAccountListView.as_view(), name='social_account_list'),
 | 
			
		||||
    path('auth/social/<int:pk>/disconnect/', SocialAccountDisconnectView.as_view(), name='social_account_disconnect'),
 | 
			
		||||
 | 
			
		||||
    # Unknown endpoint
 | 
			
		||||
    re_path(r'^.*$', NotFoundView.as_view(), name='api-404'),
 | 
			
		||||
]
 | 
			
		||||
 
 | 
			
		||||
@@ -183,6 +183,11 @@ login_default_protocol: http
 | 
			
		||||
remote_login_enabled: False
 | 
			
		||||
remote_login_header: HTTP_REMOTE_USER
 | 
			
		||||
 | 
			
		||||
# JWT tokens
 | 
			
		||||
# JWT can be used optionally to authenticate users. Turned off by default.
 | 
			
		||||
# Alternatively, use the environment variable INVENTREE_USE_JWT
 | 
			
		||||
# use_jwt: True
 | 
			
		||||
 | 
			
		||||
# Logout redirect configuration
 | 
			
		||||
# This setting may be required if using remote / proxy login to redirect requests
 | 
			
		||||
# during the logout process (default is 'index'). Please read the docs for more details
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user