diff --git a/InvenTree/InvenTree/tests.py b/InvenTree/InvenTree/tests.py index 1a8fb3ed1c..bd5ebf709b 100644 --- a/InvenTree/InvenTree/tests.py +++ b/InvenTree/InvenTree/tests.py @@ -40,6 +40,147 @@ from .tasks import offload_task from .validators import validate_overage +class HostTest(InvenTreeTestCase): + """Test for host configuration.""" + + @override_settings(ALLOWED_HOSTS=['testserver']) + def test_allowed_hosts(self): + """Test that the ALLOWED_HOSTS functions as expected.""" + self.assertIn('testserver', settings.ALLOWED_HOSTS) + + response = self.client.get('/api/', headers={'host': 'testserver'}) + + self.assertEqual(response.status_code, 200) + + response = self.client.get('/api/', headers={'host': 'invalidserver'}) + + self.assertEqual(response.status_code, 400) + + @override_settings(ALLOWED_HOSTS=['invalidserver.co.uk']) + def test_allowed_hosts_2(self): + """Another test for ALLOWED_HOSTS functionality.""" + response = self.client.get('/api/', headers={'host': 'invalidserver.co.uk'}) + + self.assertEqual(response.status_code, 200) + + +class CorsTest(TestCase): + """Unit tests for CORS functionality.""" + + def cors_headers(self): + """Return a list of CORS headers.""" + return [ + 'access-control-allow-origin', + 'access-control-allow-credentials', + 'access-control-allow-methods', + 'access-control-allow-headers', + ] + + def preflight(self, url, origin, method='GET'): + """Make a CORS preflight request to the specified URL.""" + headers = {'origin': origin, 'access-control-request-method': method} + + return self.client.options(url, headers=headers) + + def test_no_origin(self): + """Test that CORS headers are not included for regular requests. + + - We use the /api/ endpoint for this test (it does not require auth) + - By default, in debug mode *all* CORS origins are allowed + """ + # Perform an initial response without the "origin" header + response = self.client.get('/api/') + self.assertEqual(response.status_code, 200) + + for header in self.cors_headers(): + self.assertNotIn(header, response.headers) + + # Now, perform a "preflight" request with the "origin" header + response = self.preflight('/api/', origin='http://random-external-server.com') + self.assertEqual(response.status_code, 200) + + for header in self.cors_headers(): + self.assertIn(header, response.headers) + + self.assertEqual(response.headers['content-length'], '0') + self.assertEqual( + response.headers['access-control-allow-origin'], + 'http://random-external-server.com', + ) + + @override_settings( + CORS_ALLOW_ALL_ORIGINS=False, + CORS_ALLOWED_ORIGINS=['http://my-external-server.com'], + CORS_ALLOWED_ORIGIN_REGEXES=[], + ) + def test_auth_view(self): + """Test that CORS requests work for the /auth/ view. + + Here, we are not authorized by default, + but the CORS headers should still be included. + """ + url = '/auth/' + + # First, a preflight request with a "valid" origin + + response = self.preflight(url, origin='http://my-external-server.com') + + self.assertEqual(response.status_code, 200) + + for header in self.cors_headers(): + self.assertIn(header, response.headers) + + # Next, a preflight request with an "invalid" origin + response = self.preflight(url, origin='http://random-external-server.com') + + self.assertEqual(response.status_code, 200) + + for header in self.cors_headers(): + self.assertNotIn(header, response.headers) + + # Next, make a GET request (without a token) + response = self.client.get( + url, headers={'origin': 'http://my-external-server.com'} + ) + + # Unauthorized + self.assertEqual(response.status_code, 401) + + self.assertIn('access-control-allow-origin', response.headers) + self.assertNotIn('access-control-allow-methods', response.headers) + + @override_settings( + CORS_ALLOW_ALL_ORIGINS=False, + CORS_ALLOWED_ORIGINS=[], + CORS_ALLOWED_ORIGIN_REGEXES=['http://.*myserver.com'], + ) + def test_cors_regex(self): + """Test that CORS regexes work as expected.""" + valid_urls = [ + 'http://www.myserver.com', + 'http://test.myserver.com', + 'http://myserver.com', + 'http://www.myserver.com:8080', + ] + + invalid_urls = [ + 'http://myserver.org', + 'http://www.other-server.org', + 'http://google.com', + 'http://myserver.co.uk:8080', + ] + + for url in valid_urls: + response = self.preflight('/api/', origin=url) + self.assertEqual(response.status_code, 200) + self.assertIn('access-control-allow-origin', response.headers) + + for url in invalid_urls: + response = self.preflight('/api/', origin=url) + self.assertEqual(response.status_code, 200) + self.assertNotIn('access-control-allow-origin', response.headers) + + class ConversionTest(TestCase): """Tests for conversion of physical units.""" @@ -912,6 +1053,7 @@ class TestVersionNumber(TestCase): hash = str( subprocess.check_output('git rev-parse --short HEAD'.split()), 'utf-8' ).strip() + self.assertEqual(hash, version.inventreeCommitHash()) d = (