mirror of
				https://github.com/inventree/InvenTree.git
				synced 2025-10-30 20:55:42 +00:00 
			
		
		
		
	Merge pull request #735 from SchrodingersGat/token-auth-fix
Improvements for token authentication
This commit is contained in:
		| @@ -4,9 +4,10 @@ from rest_framework.test import APITestCase | |||||||
| from rest_framework import status | from rest_framework import status | ||||||
|  |  | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
|  |  | ||||||
| from django.contrib.auth import get_user_model | from django.contrib.auth import get_user_model | ||||||
|  |  | ||||||
|  | from base64 import b64encode | ||||||
|  |  | ||||||
|  |  | ||||||
| class APITests(APITestCase): | class APITests(APITestCase): | ||||||
|     """ Tests for the InvenTree API """ |     """ Tests for the InvenTree API """ | ||||||
| @@ -21,24 +22,48 @@ class APITests(APITestCase): | |||||||
|     username = 'test_user' |     username = 'test_user' | ||||||
|     password = 'test_pass' |     password = 'test_pass' | ||||||
|  |  | ||||||
|  |     token = None | ||||||
|  |  | ||||||
|     def setUp(self): |     def setUp(self): | ||||||
|  |  | ||||||
|         # Create a user (but do not log in!) |         # Create a user (but do not log in!) | ||||||
|         User = get_user_model() |         User = get_user_model() | ||||||
|         User.objects.create_user(self.username, 'user@email.com', self.password) |         User.objects.create_user(self.username, 'user@email.com', self.password) | ||||||
|  |  | ||||||
|     def get_token(self): |     def basicAuth(self): | ||||||
|         token_url = reverse('api-token') |         # Use basic authentication | ||||||
|  |  | ||||||
|         # POST to retreive a token |         authstring = bytes("{u}:{p}".format(u=self.username, p=self.password), "ascii") | ||||||
|         response = self.client.post(token_url, format='json', data={'username': self.username, 'password': self.password}) |  | ||||||
|  |         # Use "basic" auth by default | ||||||
|  |         auth = b64encode(authstring).decode("ascii") | ||||||
|  |         self.client.credentials(HTTP_AUTHORIZATION="Basic {auth}".format(auth=auth)) | ||||||
|  |  | ||||||
|  |     def tokenAuth(self): | ||||||
|  |  | ||||||
|  |         self.basicAuth() | ||||||
|  |         token_url = reverse('api-token') | ||||||
|  |         response = self.client.get(token_url, format='json', data={}) | ||||||
|  |  | ||||||
|  |         self.assertEqual(response.status_code, status.HTTP_200_OK) | ||||||
|  |         self.assertIn('token', response.data) | ||||||
|  |  | ||||||
|         token = response.data['token'] |         token = response.data['token'] | ||||||
|  |  | ||||||
|         self.client.credentials(HTTP_AUTHORIZATION='Token ' + token) |  | ||||||
|  |  | ||||||
|         self.token = token |         self.token = token | ||||||
|  |  | ||||||
|  |     def token_failure(self): | ||||||
|  |         # Test token endpoint without basic auth | ||||||
|  |         url = reverse('api-token') | ||||||
|  |         response = self.client.get(url, format='json') | ||||||
|  |  | ||||||
|  |         self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) | ||||||
|  |         self.assertIsNone(self.token) | ||||||
|  |  | ||||||
|  |     def token_success(self): | ||||||
|  |  | ||||||
|  |         self.tokenAuth() | ||||||
|  |         self.assertIsNotNone(self.token) | ||||||
|  |  | ||||||
|     def test_info_view(self): |     def test_info_view(self): | ||||||
|         """ |         """ | ||||||
|         Test that we can read the 'info-view' endpoint. |         Test that we can read the 'info-view' endpoint. | ||||||
| @@ -55,51 +80,18 @@ class APITests(APITestCase): | |||||||
|  |  | ||||||
|         self.assertEquals('InvenTree', data['server']) |         self.assertEquals('InvenTree', data['server']) | ||||||
|  |  | ||||||
|     def test_get_token_fail(self): |     def test_barcode_fail(self): | ||||||
|         """ Ensure that an invalid user cannot get a token """ |         # Test barcode endpoint without auth | ||||||
|  |         response = self.client.post(reverse('api-barcode-plugin'), format='json') | ||||||
|         token_url = reverse('api-token') |  | ||||||
|  |  | ||||||
|         response = self.client.post(token_url, format='json', data={'username': 'bad', 'password': 'also_bad'}) |  | ||||||
|  |  | ||||||
|         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) |  | ||||||
|         self.assertFalse('token' in response.data) |  | ||||||
|  |  | ||||||
|     def test_get_token_pass(self): |  | ||||||
|         """ Ensure that a valid user can request an API token """ |  | ||||||
|  |  | ||||||
|         token_url = reverse('api-token') |  | ||||||
|  |  | ||||||
|         # POST to retreive a token |  | ||||||
|         response = self.client.post(token_url, format='json', data={'username': self.username, 'password': self.password}) |  | ||||||
|  |  | ||||||
|         self.assertEqual(response.status_code, status.HTTP_200_OK) |  | ||||||
|         self.assertTrue('token' in response.data) |  | ||||||
|         self.assertTrue('pk' in response.data) |  | ||||||
|         self.assertTrue(len(response.data['token']) > 0) |  | ||||||
|  |  | ||||||
|         # Now, use the token to access other data |  | ||||||
|         token = response.data['token'] |  | ||||||
|  |  | ||||||
|         part_url = reverse('api-part-list') |  | ||||||
|  |  | ||||||
|         # Try to access without a token |  | ||||||
|         response = self.client.get(part_url, format='json') |  | ||||||
|  |  | ||||||
|         self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) |         self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) | ||||||
|  |  | ||||||
|         # Now, with the token |  | ||||||
|         self.client.credentials(HTTP_AUTHORIZATION='Token ' + token) |  | ||||||
|         response = self.client.get(part_url, format='json') |  | ||||||
|  |  | ||||||
|         self.assertEqual(response.status_code, status.HTTP_200_OK) |  | ||||||
|  |  | ||||||
|     def test_barcode(self): |     def test_barcode(self): | ||||||
|         """ Test the barcode endpoint """ |         """ Test the barcode endpoint """ | ||||||
|  |  | ||||||
|         url = reverse('api-barcode-plugin') |         self.tokenAuth() | ||||||
|  |  | ||||||
|         self.get_token() |         url = reverse('api-barcode-plugin') | ||||||
|  |  | ||||||
|         data = { |         data = { | ||||||
|             'barcode': { |             'barcode': { | ||||||
|   | |||||||
| @@ -3,7 +3,7 @@ from django.contrib.auth.models import User | |||||||
| from django.core.exceptions import ObjectDoesNotExist | from django.core.exceptions import ObjectDoesNotExist | ||||||
| from .serializers import UserSerializer | from .serializers import UserSerializer | ||||||
|  |  | ||||||
| from rest_framework.authtoken.views import ObtainAuthToken | from rest_framework.views import APIView | ||||||
| from rest_framework.authtoken.models import Token | from rest_framework.authtoken.models import Token | ||||||
| from rest_framework.response import Response | from rest_framework.response import Response | ||||||
| from rest_framework import status | from rest_framework import status | ||||||
| @@ -25,27 +25,31 @@ class UserList(generics.ListAPIView): | |||||||
|     permission_classes = (permissions.IsAuthenticated,) |     permission_classes = (permissions.IsAuthenticated,) | ||||||
|  |  | ||||||
|  |  | ||||||
| class GetAuthToken(ObtainAuthToken): | class GetAuthToken(APIView): | ||||||
|     """ Return authentication token for an authenticated user. """ |     """ Return authentication token for an authenticated user. """ | ||||||
|  |  | ||||||
|     def post(self, request, *args, **kwargs): |     permission_classes = [ | ||||||
|  |         permissions.IsAuthenticated, | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     def get(self, request, *args, **kwargs): | ||||||
|         return self.login(request) |         return self.login(request) | ||||||
|  |  | ||||||
|     def delete(self, request): |     def delete(self, request): | ||||||
|         return self.logout(request) |         return self.logout(request) | ||||||
|  |  | ||||||
|     def login(self, request): |     def login(self, request): | ||||||
|         serializer = self.serializer_class(data=request.data, |  | ||||||
|                                            context={'request': request}) |  | ||||||
|         serializer.is_valid(raise_exception=True) |  | ||||||
|         user = serializer.validated_data['user'] |  | ||||||
|         token, created = Token.objects.get_or_create(user=user) |  | ||||||
|  |  | ||||||
|  |         if request.user.is_authenticated: | ||||||
|  |             # Get the user token (or create one if it does not exist) | ||||||
|  |             token, created = Token.objects.get_or_create(user=request.user) | ||||||
|             return Response({ |             return Response({ | ||||||
|                 'token': token.key, |                 'token': token.key, | ||||||
|             'pk': user.pk, |             }) | ||||||
|             'username': user.username, |  | ||||||
|             'email': user.email |         else: | ||||||
|  |             return Response({ | ||||||
|  |                 'error': 'User not authenticated', | ||||||
|             }) |             }) | ||||||
|  |  | ||||||
|     def logout(self, request): |     def logout(self, request): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user