2
0
mirror of https://github.com/inventree/InvenTree.git synced 2026-07-04 06:00:38 +00:00

Data import permissions (#12169)

* Update data importer child permissions

- Row data
- Column data

* Add unit tests

* Cleanup session data after import is completed

* Further scope narrowing
This commit is contained in:
Oliver
2026-06-15 21:03:44 +10:00
committed by GitHub
parent aece90512c
commit 3c17367e3c
3 changed files with 243 additions and 12 deletions
+39 -7
View File
@@ -115,6 +115,10 @@ class DataImportSessionAcceptFields(APIView):
"""Accept the field mapping for a DataImportSession.""" """Accept the field mapping for a DataImportSession."""
session = get_object_or_404(importer.models.DataImportSession, pk=pk) session = get_object_or_404(importer.models.DataImportSession, pk=pk)
# Check session ownership
if not request.user.is_staff and session.user != request.user:
raise PermissionDenied()
# Check that the user has permission to accept the field mapping # Check that the user has permission to accept the field mapping
if model_class := session.model_class: if model_class := session.model_class:
if not check_user_permission(request.user, model_class, 'change'): if not check_user_permission(request.user, model_class, 'change'):
@@ -137,17 +141,45 @@ class DataImportSessionAcceptRows(DataImporterPermissionMixin, CreateAPI):
ctx = super().get_serializer_context() ctx = super().get_serializer_context()
try: try:
ctx['session'] = importer.models.DataImportSession.objects.get( session = importer.models.DataImportSession.objects.get(
pk=self.kwargs.get('pk', None) pk=self.kwargs.get('pk', None)
) )
except Exception: except importer.models.DataImportSession.DoesNotExist:
pass session = None
if session:
user = self.request.user
if not user.is_staff and session.user != user:
raise PermissionDenied()
ctx['session'] = session
ctx['request'] = self.request ctx['request'] = self.request
return ctx return ctx
class DataImportColumnMappingList(DataImporterPermissionMixin, ListAPI): class DataImportSessionChildMixin(DataImporterPermissionMixin):
"""Mixin for DataImportRow and DataImportColumnMap views.
Ensures users can only access objects that belong to an import session they own.
Staff users retain access to all objects.
"""
def get_queryset(self):
"""Return only objects whose session belongs to the requesting user."""
queryset = super().get_queryset()
try:
user = self.request.user
except AttributeError:
raise PermissionDenied('User information is not available')
if user.is_staff:
return queryset
return queryset.filter(session__user=user)
class DataImportColumnMappingList(DataImportSessionChildMixin, ListAPI):
"""API endpoint for accessing a list of DataImportColumnMap objects.""" """API endpoint for accessing a list of DataImportColumnMap objects."""
queryset = importer.models.DataImportColumnMap.objects.all() queryset = importer.models.DataImportColumnMap.objects.all()
@@ -158,14 +190,14 @@ class DataImportColumnMappingList(DataImporterPermissionMixin, ListAPI):
filterset_fields = ['session'] filterset_fields = ['session']
class DataImportColumnMappingDetail(DataImporterPermissionMixin, RetrieveUpdateAPI): class DataImportColumnMappingDetail(DataImportSessionChildMixin, RetrieveUpdateAPI):
"""Detail endpoint for a single DataImportColumnMap object.""" """Detail endpoint for a single DataImportColumnMap object."""
queryset = importer.models.DataImportColumnMap.objects.all() queryset = importer.models.DataImportColumnMap.objects.all()
serializer_class = importer.serializers.DataImportColumnMapSerializer serializer_class = importer.serializers.DataImportColumnMapSerializer
class DataImportRowList(DataImporterPermissionMixin, BulkDeleteMixin, ListAPI): class DataImportRowList(DataImportSessionChildMixin, BulkDeleteMixin, ListAPI):
"""API endpoint for accessing a list of DataImportRow objects.""" """API endpoint for accessing a list of DataImportRow objects."""
queryset = importer.models.DataImportRow.objects.all() queryset = importer.models.DataImportRow.objects.all()
@@ -180,7 +212,7 @@ class DataImportRowList(DataImporterPermissionMixin, BulkDeleteMixin, ListAPI):
ordering = 'row_index' ordering = 'row_index'
class DataImportRowDetail(DataImporterPermissionMixin, RetrieveUpdateDestroyAPI): class DataImportRowDetail(DataImportSessionChildMixin, RetrieveUpdateDestroyAPI):
"""Detail endpoint for a single DataImportRow object.""" """Detail endpoint for a single DataImportRow object."""
queryset = importer.models.DataImportRow.objects.all() queryset = importer.models.DataImportRow.objects.all()
+9 -2
View File
@@ -344,14 +344,21 @@ class DataImportSession(models.Model):
self.save() self.save()
def check_complete(self) -> bool: def check_complete(self) -> bool:
"""Check if the import session is complete.""" """Check if the import session is complete.
When all rows have been accepted, the rows and column mappings are
deleted as they are no longer needed. The session itself is retained
as an audit record.
"""
if self.completed_row_count < self.row_count: if self.completed_row_count < self.row_count:
return False return False
# Update the status of this session
if self.status != DataImportStatusCode.COMPLETE.value: if self.status != DataImportStatusCode.COMPLETE.value:
self.status = DataImportStatusCode.COMPLETE.value self.status = DataImportStatusCode.COMPLETE.value
self.save() self.save()
# Clear staging data now that all rows have been imported
self.rows.all().delete()
self.column_mappings.all().delete()
return True return True
+195 -3
View File
@@ -2,10 +2,11 @@
import os import os
from django.contrib.auth.models import User
from django.core.files.base import ContentFile from django.core.files.base import ContentFile
from django.urls import reverse from django.urls import reverse
from importer.models import DataImportRow, DataImportSession from importer.models import DataImportColumnMap, DataImportRow, DataImportSession
from InvenTree.unit_test import AdminTestCase, InvenTreeAPITestCase, InvenTreeTestCase from InvenTree.unit_test import AdminTestCase, InvenTreeAPITestCase, InvenTreeTestCase
@@ -58,14 +59,20 @@ class ImporterTest(ImporterMixin, InvenTreeTestCase):
self.assertEqual(session.rows.count(), 12) self.assertEqual(session.rows.count(), 12)
# Check that some data has been imported # Check that some data has been imported
for row in session.rows.all(): rows = list(session.rows.all())
self.assertEqual(len(rows), 12)
for row in rows:
self.assertIsNotNone(row.data.get('name', None)) self.assertIsNotNone(row.data.get('name', None))
self.assertTrue(row.valid) self.assertTrue(row.valid)
row.validate(commit=True) row.validate(commit=True)
self.assertTrue(row.complete) self.assertTrue(row.complete)
self.assertEqual(session.completed_row_count, 12) # All rows accepted: rows and mappings are cleared, session is retained
session.refresh_from_db()
self.assertEqual(session.rows.count(), 0)
self.assertEqual(session.column_mappings.count(), 0)
# Check that the new companies have been created # Check that the new companies have been created
self.assertEqual(n + 12, Company.objects.count()) self.assertEqual(n + 12, Company.objects.count())
@@ -204,6 +211,191 @@ class ImportAPITest(ImporterMixin, InvenTreeAPITestCase):
for session in response.data: for session in response.data:
self.assertEqual(session['user'], self.user.pk) self.assertEqual(session['user'], self.user.pk)
def test_accept_fields_ownership(self):
"""Test that accept_fields rejects requests for sessions owned by another user."""
other_user = User.objects.create_user(
username='other_accept', password='password'
)
f = self.helper_file('companies.csv')
session = DataImportSession.objects.create(
data_file=f, model_type='company', user=other_user
)
url = reverse('api-import-session-accept-fields', kwargs={'pk': session.pk})
# Non-owner, non-staff should be denied
self.user.is_staff = False
self.user.save()
self.post(url, expected_code=403)
# Staff should be allowed (subject to model permission)
# Company is part of the purchase_order ruleset
self.user.is_staff = True
self.user.save()
self.assignRole('purchase_order.change')
self.post(url, expected_code=200)
def test_accept_rows_ownership(self):
"""Test that accept_rows rejects requests for sessions owned by another user."""
other_user = User.objects.create_user(
username='other_accept_rows', password='password'
)
f = self.helper_file('companies.csv')
session = DataImportSession.objects.create(
data_file=f, model_type='company', user=other_user
)
session.extract_columns()
url = reverse('api-import-session-accept-rows', kwargs={'pk': session.pk})
self.user.is_staff = False
self.user.save()
self.post(url, {'rows': []}, expected_code=403)
# Staff can reach the endpoint (rows list is empty so validation rejects with 400, not 403)
self.user.is_staff = True
self.user.save()
self.post(url, {'rows': []}, expected_code=400)
def test_session_cleanup_on_complete(self):
"""Test that a completed import session deletes itself and all associated data."""
url = reverse('api-importer-session-list')
data_file = self.helper_file('part_categories.csv')
data = self.post(
url,
{'model_type': 'partcategory', 'data_file': data_file},
format='multipart',
).data
session_id = data['pk']
session_pk = session_id
self.assignRole('part_category.add')
self.post(
reverse('api-import-session-accept-fields', kwargs={'pk': session_id}),
expected_code=200,
)
rows = self.get(
reverse('api-importer-row-list'), data={'session': session_id}
).data
row_ids = [r['pk'] for r in rows]
self.assertGreater(len(row_ids), 0)
# Confirm rows and mappings exist before acceptance
self.assertTrue(DataImportRow.objects.filter(session_id=session_pk).exists())
self.assertTrue(
DataImportColumnMap.objects.filter(session_id=session_pk).exists()
)
# Accept all rows — this should trigger cleanup of rows and mappings
self.post(
reverse('api-import-session-accept-rows', kwargs={'pk': session_id}),
{'rows': row_ids},
)
# Rows and column mappings must be cleared
self.assertFalse(DataImportRow.objects.filter(session_id=session_pk).exists())
self.assertFalse(
DataImportColumnMap.objects.filter(session_id=session_pk).exists()
)
# Session itself is retained as an audit record with COMPLETE status
from importer.models import DataImportSession
from importer.status_codes import DataImportStatusCode
session_obj = DataImportSession.objects.get(pk=session_pk)
self.assertEqual(session_obj.status, DataImportStatusCode.COMPLETE.value)
detail = self.get(
reverse('api-import-session-detail', kwargs={'pk': session_id}),
expected_code=200,
).data
self.assertEqual(detail['row_count'], 0)
self.assertEqual(detail['completed_row_count'], 0)
def test_row_and_mapping_ownership(self):
"""Test that DataImportRow and DataImportColumnMap endpoints filter by session ownership."""
f = self.helper_file('companies.csv')
other_user = User.objects.create_user(
username='other_importer', password='password'
)
# Session owned by self.user
session_mine = DataImportSession.objects.create(
data_file=f, model_type='company', user=self.user
)
session_mine.extract_columns()
# Session owned by another user
f2 = self.helper_file('companies.csv')
session_other = DataImportSession.objects.create(
data_file=f2, model_type='company', user=other_user
)
session_other.extract_columns()
row_list_url = reverse('api-importer-row-list')
mapping_list_url = reverse('api-importer-mapping-list')
# Non-staff: should only see rows/mappings from own session
self.user.is_staff = False
self.user.save()
rows = self.get(row_list_url).data
for row in rows:
self.assertEqual(row['session'], session_mine.pk)
mappings = self.get(mapping_list_url).data
for mapping in mappings:
self.assertEqual(mapping['session'], session_mine.pk)
# Detail endpoint: own session's row/mapping should be accessible
own_row = DataImportRow.objects.filter(session=session_mine).first()
other_row = DataImportRow.objects.filter(session=session_other).first()
if own_row:
self.get(
reverse('api-importer-row-detail', kwargs={'pk': own_row.pk}),
expected_code=200,
)
if other_row:
self.get(
reverse('api-importer-row-detail', kwargs={'pk': other_row.pk}),
expected_code=404,
)
own_mapping = DataImportColumnMap.objects.filter(session=session_mine).first()
other_mapping = DataImportColumnMap.objects.filter(
session=session_other
).first()
if own_mapping:
self.get(
reverse('api-importer-mapping-detail', kwargs={'pk': own_mapping.pk}),
expected_code=200,
)
if other_mapping:
self.get(
reverse('api-importer-mapping-detail', kwargs={'pk': other_mapping.pk}),
expected_code=404,
)
# Staff user: should see rows/mappings from all sessions
self.user.is_staff = True
self.user.save()
all_row_pks = set(DataImportRow.objects.values_list('pk', flat=True))
response_rows = self.get(row_list_url).data
self.assertEqual({r['pk'] for r in response_rows}, all_row_pks)
all_mapping_pks = set(DataImportColumnMap.objects.values_list('pk', flat=True))
response_mappings = self.get(mapping_list_url).data
self.assertEqual({m['pk'] for m in response_mappings}, all_mapping_pks)
class AdminTest(ImporterMixin, AdminTestCase): class AdminTest(ImporterMixin, AdminTestCase):
"""Tests for the admin interface integration.""" """Tests for the admin interface integration."""