From 2fef34852c5c855c9be1793a120034cf8d1ff622 Mon Sep 17 00:00:00 2001
From: Oliver <oliver.henry.walters@gmail.com>
Date: Wed, 13 Mar 2024 20:37:05 +1100
Subject: [PATCH] Unit tests for HOST settings (#6698)

- CORS
- ALLOWED_HOSTS
---
 InvenTree/InvenTree/tests.py | 142 +++++++++++++++++++++++++++++++++++
 1 file changed, 142 insertions(+)

diff --git a/InvenTree/InvenTree/tests.py b/InvenTree/InvenTree/tests.py
index 1a8fb3ed1c..bd5ebf709b 100644
--- a/InvenTree/InvenTree/tests.py
+++ b/InvenTree/InvenTree/tests.py
@@ -40,6 +40,147 @@ from .tasks import offload_task
 from .validators import validate_overage
 
 
+class HostTest(InvenTreeTestCase):
+    """Test for host configuration."""
+
+    @override_settings(ALLOWED_HOSTS=['testserver'])
+    def test_allowed_hosts(self):
+        """Test that the ALLOWED_HOSTS functions as expected."""
+        self.assertIn('testserver', settings.ALLOWED_HOSTS)
+
+        response = self.client.get('/api/', headers={'host': 'testserver'})
+
+        self.assertEqual(response.status_code, 200)
+
+        response = self.client.get('/api/', headers={'host': 'invalidserver'})
+
+        self.assertEqual(response.status_code, 400)
+
+    @override_settings(ALLOWED_HOSTS=['invalidserver.co.uk'])
+    def test_allowed_hosts_2(self):
+        """Another test for ALLOWED_HOSTS functionality."""
+        response = self.client.get('/api/', headers={'host': 'invalidserver.co.uk'})
+
+        self.assertEqual(response.status_code, 200)
+
+
+class CorsTest(TestCase):
+    """Unit tests for CORS functionality."""
+
+    def cors_headers(self):
+        """Return a list of CORS headers."""
+        return [
+            'access-control-allow-origin',
+            'access-control-allow-credentials',
+            'access-control-allow-methods',
+            'access-control-allow-headers',
+        ]
+
+    def preflight(self, url, origin, method='GET'):
+        """Make a CORS preflight request to the specified URL."""
+        headers = {'origin': origin, 'access-control-request-method': method}
+
+        return self.client.options(url, headers=headers)
+
+    def test_no_origin(self):
+        """Test that CORS headers are not included for regular requests.
+
+        - We use the /api/ endpoint for this test (it does not require auth)
+        - By default, in debug mode *all* CORS origins are allowed
+        """
+        # Perform an initial response without the "origin" header
+        response = self.client.get('/api/')
+        self.assertEqual(response.status_code, 200)
+
+        for header in self.cors_headers():
+            self.assertNotIn(header, response.headers)
+
+        # Now, perform a "preflight" request with the "origin" header
+        response = self.preflight('/api/', origin='http://random-external-server.com')
+        self.assertEqual(response.status_code, 200)
+
+        for header in self.cors_headers():
+            self.assertIn(header, response.headers)
+
+        self.assertEqual(response.headers['content-length'], '0')
+        self.assertEqual(
+            response.headers['access-control-allow-origin'],
+            'http://random-external-server.com',
+        )
+
+    @override_settings(
+        CORS_ALLOW_ALL_ORIGINS=False,
+        CORS_ALLOWED_ORIGINS=['http://my-external-server.com'],
+        CORS_ALLOWED_ORIGIN_REGEXES=[],
+    )
+    def test_auth_view(self):
+        """Test that CORS requests work for the /auth/ view.
+
+        Here, we are not authorized by default,
+        but the CORS headers should still be included.
+        """
+        url = '/auth/'
+
+        # First, a preflight request with a "valid" origin
+
+        response = self.preflight(url, origin='http://my-external-server.com')
+
+        self.assertEqual(response.status_code, 200)
+
+        for header in self.cors_headers():
+            self.assertIn(header, response.headers)
+
+        # Next, a preflight request with an "invalid" origin
+        response = self.preflight(url, origin='http://random-external-server.com')
+
+        self.assertEqual(response.status_code, 200)
+
+        for header in self.cors_headers():
+            self.assertNotIn(header, response.headers)
+
+        # Next, make a GET request (without a token)
+        response = self.client.get(
+            url, headers={'origin': 'http://my-external-server.com'}
+        )
+
+        # Unauthorized
+        self.assertEqual(response.status_code, 401)
+
+        self.assertIn('access-control-allow-origin', response.headers)
+        self.assertNotIn('access-control-allow-methods', response.headers)
+
+    @override_settings(
+        CORS_ALLOW_ALL_ORIGINS=False,
+        CORS_ALLOWED_ORIGINS=[],
+        CORS_ALLOWED_ORIGIN_REGEXES=['http://.*myserver.com'],
+    )
+    def test_cors_regex(self):
+        """Test that CORS regexes work as expected."""
+        valid_urls = [
+            'http://www.myserver.com',
+            'http://test.myserver.com',
+            'http://myserver.com',
+            'http://www.myserver.com:8080',
+        ]
+
+        invalid_urls = [
+            'http://myserver.org',
+            'http://www.other-server.org',
+            'http://google.com',
+            'http://myserver.co.uk:8080',
+        ]
+
+        for url in valid_urls:
+            response = self.preflight('/api/', origin=url)
+            self.assertEqual(response.status_code, 200)
+            self.assertIn('access-control-allow-origin', response.headers)
+
+        for url in invalid_urls:
+            response = self.preflight('/api/', origin=url)
+            self.assertEqual(response.status_code, 200)
+            self.assertNotIn('access-control-allow-origin', response.headers)
+
+
 class ConversionTest(TestCase):
     """Tests for conversion of physical units."""
 
@@ -912,6 +1053,7 @@ class TestVersionNumber(TestCase):
         hash = str(
             subprocess.check_output('git rev-parse --short HEAD'.split()), 'utf-8'
         ).strip()
+
         self.assertEqual(hash, version.inventreeCommitHash())
 
         d = (