From 3c17367e3c2b23631a5d1d7d9b7d93c4174b8eb1 Mon Sep 17 00:00:00 2001 From: Oliver Date: Mon, 15 Jun 2026 21:03:44 +1000 Subject: [PATCH] Data import permissions (#12169) * Update data importer child permissions - Row data - Column data * Add unit tests * Cleanup session data after import is completed * Further scope narrowing --- src/backend/InvenTree/importer/api.py | 46 +++++- src/backend/InvenTree/importer/models.py | 11 +- src/backend/InvenTree/importer/tests.py | 198 ++++++++++++++++++++++- 3 files changed, 243 insertions(+), 12 deletions(-) diff --git a/src/backend/InvenTree/importer/api.py b/src/backend/InvenTree/importer/api.py index c64f71c654..fd8c6024fd 100644 --- a/src/backend/InvenTree/importer/api.py +++ b/src/backend/InvenTree/importer/api.py @@ -115,6 +115,10 @@ class DataImportSessionAcceptFields(APIView): """Accept the field mapping for a DataImportSession.""" session = get_object_or_404(importer.models.DataImportSession, pk=pk) + # Check session ownership + if not request.user.is_staff and session.user != request.user: + raise PermissionDenied() + # Check that the user has permission to accept the field mapping if model_class := session.model_class: if not check_user_permission(request.user, model_class, 'change'): @@ -137,17 +141,45 @@ class DataImportSessionAcceptRows(DataImporterPermissionMixin, CreateAPI): ctx = super().get_serializer_context() try: - ctx['session'] = importer.models.DataImportSession.objects.get( + session = importer.models.DataImportSession.objects.get( pk=self.kwargs.get('pk', None) ) - except Exception: - pass + except importer.models.DataImportSession.DoesNotExist: + session = None + + if session: + user = self.request.user + if not user.is_staff and session.user != user: + raise PermissionDenied() + ctx['session'] = session ctx['request'] = self.request return ctx -class DataImportColumnMappingList(DataImporterPermissionMixin, ListAPI): +class DataImportSessionChildMixin(DataImporterPermissionMixin): + """Mixin for DataImportRow and DataImportColumnMap views. + + Ensures users can only access objects that belong to an import session they own. + Staff users retain access to all objects. + """ + + def get_queryset(self): + """Return only objects whose session belongs to the requesting user.""" + queryset = super().get_queryset() + + try: + user = self.request.user + except AttributeError: + raise PermissionDenied('User information is not available') + + if user.is_staff: + return queryset + + return queryset.filter(session__user=user) + + +class DataImportColumnMappingList(DataImportSessionChildMixin, ListAPI): """API endpoint for accessing a list of DataImportColumnMap objects.""" queryset = importer.models.DataImportColumnMap.objects.all() @@ -158,14 +190,14 @@ class DataImportColumnMappingList(DataImporterPermissionMixin, ListAPI): filterset_fields = ['session'] -class DataImportColumnMappingDetail(DataImporterPermissionMixin, RetrieveUpdateAPI): +class DataImportColumnMappingDetail(DataImportSessionChildMixin, RetrieveUpdateAPI): """Detail endpoint for a single DataImportColumnMap object.""" queryset = importer.models.DataImportColumnMap.objects.all() serializer_class = importer.serializers.DataImportColumnMapSerializer -class DataImportRowList(DataImporterPermissionMixin, BulkDeleteMixin, ListAPI): +class DataImportRowList(DataImportSessionChildMixin, BulkDeleteMixin, ListAPI): """API endpoint for accessing a list of DataImportRow objects.""" queryset = importer.models.DataImportRow.objects.all() @@ -180,7 +212,7 @@ class DataImportRowList(DataImporterPermissionMixin, BulkDeleteMixin, ListAPI): ordering = 'row_index' -class DataImportRowDetail(DataImporterPermissionMixin, RetrieveUpdateDestroyAPI): +class DataImportRowDetail(DataImportSessionChildMixin, RetrieveUpdateDestroyAPI): """Detail endpoint for a single DataImportRow object.""" queryset = importer.models.DataImportRow.objects.all() diff --git a/src/backend/InvenTree/importer/models.py b/src/backend/InvenTree/importer/models.py index 8dae1376ab..b25644c5b2 100644 --- a/src/backend/InvenTree/importer/models.py +++ b/src/backend/InvenTree/importer/models.py @@ -344,14 +344,21 @@ class DataImportSession(models.Model): self.save() def check_complete(self) -> bool: - """Check if the import session is complete.""" + """Check if the import session is complete. + + When all rows have been accepted, the rows and column mappings are + deleted as they are no longer needed. The session itself is retained + as an audit record. + """ if self.completed_row_count < self.row_count: return False - # Update the status of this session if self.status != DataImportStatusCode.COMPLETE.value: self.status = DataImportStatusCode.COMPLETE.value self.save() + # Clear staging data now that all rows have been imported + self.rows.all().delete() + self.column_mappings.all().delete() return True diff --git a/src/backend/InvenTree/importer/tests.py b/src/backend/InvenTree/importer/tests.py index 17490b9283..17cde8ccbb 100644 --- a/src/backend/InvenTree/importer/tests.py +++ b/src/backend/InvenTree/importer/tests.py @@ -2,10 +2,11 @@ import os +from django.contrib.auth.models import User from django.core.files.base import ContentFile from django.urls import reverse -from importer.models import DataImportRow, DataImportSession +from importer.models import DataImportColumnMap, DataImportRow, DataImportSession from InvenTree.unit_test import AdminTestCase, InvenTreeAPITestCase, InvenTreeTestCase @@ -58,14 +59,20 @@ class ImporterTest(ImporterMixin, InvenTreeTestCase): self.assertEqual(session.rows.count(), 12) # Check that some data has been imported - for row in session.rows.all(): + rows = list(session.rows.all()) + self.assertEqual(len(rows), 12) + + for row in rows: self.assertIsNotNone(row.data.get('name', None)) self.assertTrue(row.valid) row.validate(commit=True) self.assertTrue(row.complete) - self.assertEqual(session.completed_row_count, 12) + # All rows accepted: rows and mappings are cleared, session is retained + session.refresh_from_db() + self.assertEqual(session.rows.count(), 0) + self.assertEqual(session.column_mappings.count(), 0) # Check that the new companies have been created self.assertEqual(n + 12, Company.objects.count()) @@ -204,6 +211,191 @@ class ImportAPITest(ImporterMixin, InvenTreeAPITestCase): for session in response.data: self.assertEqual(session['user'], self.user.pk) + def test_accept_fields_ownership(self): + """Test that accept_fields rejects requests for sessions owned by another user.""" + other_user = User.objects.create_user( + username='other_accept', password='password' + ) + + f = self.helper_file('companies.csv') + session = DataImportSession.objects.create( + data_file=f, model_type='company', user=other_user + ) + + url = reverse('api-import-session-accept-fields', kwargs={'pk': session.pk}) + + # Non-owner, non-staff should be denied + self.user.is_staff = False + self.user.save() + self.post(url, expected_code=403) + + # Staff should be allowed (subject to model permission) + # Company is part of the purchase_order ruleset + self.user.is_staff = True + self.user.save() + self.assignRole('purchase_order.change') + self.post(url, expected_code=200) + + def test_accept_rows_ownership(self): + """Test that accept_rows rejects requests for sessions owned by another user.""" + other_user = User.objects.create_user( + username='other_accept_rows', password='password' + ) + + f = self.helper_file('companies.csv') + session = DataImportSession.objects.create( + data_file=f, model_type='company', user=other_user + ) + session.extract_columns() + + url = reverse('api-import-session-accept-rows', kwargs={'pk': session.pk}) + + self.user.is_staff = False + self.user.save() + self.post(url, {'rows': []}, expected_code=403) + + # Staff can reach the endpoint (rows list is empty so validation rejects with 400, not 403) + self.user.is_staff = True + self.user.save() + self.post(url, {'rows': []}, expected_code=400) + + def test_session_cleanup_on_complete(self): + """Test that a completed import session deletes itself and all associated data.""" + url = reverse('api-importer-session-list') + data_file = self.helper_file('part_categories.csv') + + data = self.post( + url, + {'model_type': 'partcategory', 'data_file': data_file}, + format='multipart', + ).data + + session_id = data['pk'] + session_pk = session_id + + self.assignRole('part_category.add') + self.post( + reverse('api-import-session-accept-fields', kwargs={'pk': session_id}), + expected_code=200, + ) + + rows = self.get( + reverse('api-importer-row-list'), data={'session': session_id} + ).data + row_ids = [r['pk'] for r in rows] + self.assertGreater(len(row_ids), 0) + + # Confirm rows and mappings exist before acceptance + self.assertTrue(DataImportRow.objects.filter(session_id=session_pk).exists()) + self.assertTrue( + DataImportColumnMap.objects.filter(session_id=session_pk).exists() + ) + + # Accept all rows — this should trigger cleanup of rows and mappings + self.post( + reverse('api-import-session-accept-rows', kwargs={'pk': session_id}), + {'rows': row_ids}, + ) + + # Rows and column mappings must be cleared + self.assertFalse(DataImportRow.objects.filter(session_id=session_pk).exists()) + self.assertFalse( + DataImportColumnMap.objects.filter(session_id=session_pk).exists() + ) + + # Session itself is retained as an audit record with COMPLETE status + from importer.models import DataImportSession + from importer.status_codes import DataImportStatusCode + + session_obj = DataImportSession.objects.get(pk=session_pk) + self.assertEqual(session_obj.status, DataImportStatusCode.COMPLETE.value) + + detail = self.get( + reverse('api-import-session-detail', kwargs={'pk': session_id}), + expected_code=200, + ).data + self.assertEqual(detail['row_count'], 0) + self.assertEqual(detail['completed_row_count'], 0) + + def test_row_and_mapping_ownership(self): + """Test that DataImportRow and DataImportColumnMap endpoints filter by session ownership.""" + f = self.helper_file('companies.csv') + + other_user = User.objects.create_user( + username='other_importer', password='password' + ) + + # Session owned by self.user + session_mine = DataImportSession.objects.create( + data_file=f, model_type='company', user=self.user + ) + session_mine.extract_columns() + + # Session owned by another user + f2 = self.helper_file('companies.csv') + session_other = DataImportSession.objects.create( + data_file=f2, model_type='company', user=other_user + ) + session_other.extract_columns() + + row_list_url = reverse('api-importer-row-list') + mapping_list_url = reverse('api-importer-mapping-list') + + # Non-staff: should only see rows/mappings from own session + self.user.is_staff = False + self.user.save() + + rows = self.get(row_list_url).data + for row in rows: + self.assertEqual(row['session'], session_mine.pk) + + mappings = self.get(mapping_list_url).data + for mapping in mappings: + self.assertEqual(mapping['session'], session_mine.pk) + + # Detail endpoint: own session's row/mapping should be accessible + own_row = DataImportRow.objects.filter(session=session_mine).first() + other_row = DataImportRow.objects.filter(session=session_other).first() + + if own_row: + self.get( + reverse('api-importer-row-detail', kwargs={'pk': own_row.pk}), + expected_code=200, + ) + if other_row: + self.get( + reverse('api-importer-row-detail', kwargs={'pk': other_row.pk}), + expected_code=404, + ) + + own_mapping = DataImportColumnMap.objects.filter(session=session_mine).first() + other_mapping = DataImportColumnMap.objects.filter( + session=session_other + ).first() + + if own_mapping: + self.get( + reverse('api-importer-mapping-detail', kwargs={'pk': own_mapping.pk}), + expected_code=200, + ) + if other_mapping: + self.get( + reverse('api-importer-mapping-detail', kwargs={'pk': other_mapping.pk}), + expected_code=404, + ) + + # Staff user: should see rows/mappings from all sessions + self.user.is_staff = True + self.user.save() + + all_row_pks = set(DataImportRow.objects.values_list('pk', flat=True)) + response_rows = self.get(row_list_url).data + self.assertEqual({r['pk'] for r in response_rows}, all_row_pks) + + all_mapping_pks = set(DataImportColumnMap.objects.values_list('pk', flat=True)) + response_mappings = self.get(mapping_list_url).data + self.assertEqual({m['pk'] for m in response_mappings}, all_mapping_pks) + class AdminTest(ImporterMixin, AdminTestCase): """Tests for the admin interface integration."""