mirror of
https://github.com/inventree/InvenTree.git
synced 2026-07-04 06:00:38 +00:00
Data import permissions (#12169)
* Update data importer child permissions - Row data - Column data * Add unit tests * Cleanup session data after import is completed * Further scope narrowing
This commit is contained in:
@@ -115,6 +115,10 @@ class DataImportSessionAcceptFields(APIView):
|
||||
"""Accept the field mapping for a DataImportSession."""
|
||||
session = get_object_or_404(importer.models.DataImportSession, pk=pk)
|
||||
|
||||
# Check session ownership
|
||||
if not request.user.is_staff and session.user != request.user:
|
||||
raise PermissionDenied()
|
||||
|
||||
# Check that the user has permission to accept the field mapping
|
||||
if model_class := session.model_class:
|
||||
if not check_user_permission(request.user, model_class, 'change'):
|
||||
@@ -137,17 +141,45 @@ class DataImportSessionAcceptRows(DataImporterPermissionMixin, CreateAPI):
|
||||
ctx = super().get_serializer_context()
|
||||
|
||||
try:
|
||||
ctx['session'] = importer.models.DataImportSession.objects.get(
|
||||
session = importer.models.DataImportSession.objects.get(
|
||||
pk=self.kwargs.get('pk', None)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
except importer.models.DataImportSession.DoesNotExist:
|
||||
session = None
|
||||
|
||||
if session:
|
||||
user = self.request.user
|
||||
if not user.is_staff and session.user != user:
|
||||
raise PermissionDenied()
|
||||
ctx['session'] = session
|
||||
|
||||
ctx['request'] = self.request
|
||||
return ctx
|
||||
|
||||
|
||||
class DataImportColumnMappingList(DataImporterPermissionMixin, ListAPI):
|
||||
class DataImportSessionChildMixin(DataImporterPermissionMixin):
|
||||
"""Mixin for DataImportRow and DataImportColumnMap views.
|
||||
|
||||
Ensures users can only access objects that belong to an import session they own.
|
||||
Staff users retain access to all objects.
|
||||
"""
|
||||
|
||||
def get_queryset(self):
|
||||
"""Return only objects whose session belongs to the requesting user."""
|
||||
queryset = super().get_queryset()
|
||||
|
||||
try:
|
||||
user = self.request.user
|
||||
except AttributeError:
|
||||
raise PermissionDenied('User information is not available')
|
||||
|
||||
if user.is_staff:
|
||||
return queryset
|
||||
|
||||
return queryset.filter(session__user=user)
|
||||
|
||||
|
||||
class DataImportColumnMappingList(DataImportSessionChildMixin, ListAPI):
|
||||
"""API endpoint for accessing a list of DataImportColumnMap objects."""
|
||||
|
||||
queryset = importer.models.DataImportColumnMap.objects.all()
|
||||
@@ -158,14 +190,14 @@ class DataImportColumnMappingList(DataImporterPermissionMixin, ListAPI):
|
||||
filterset_fields = ['session']
|
||||
|
||||
|
||||
class DataImportColumnMappingDetail(DataImporterPermissionMixin, RetrieveUpdateAPI):
|
||||
class DataImportColumnMappingDetail(DataImportSessionChildMixin, RetrieveUpdateAPI):
|
||||
"""Detail endpoint for a single DataImportColumnMap object."""
|
||||
|
||||
queryset = importer.models.DataImportColumnMap.objects.all()
|
||||
serializer_class = importer.serializers.DataImportColumnMapSerializer
|
||||
|
||||
|
||||
class DataImportRowList(DataImporterPermissionMixin, BulkDeleteMixin, ListAPI):
|
||||
class DataImportRowList(DataImportSessionChildMixin, BulkDeleteMixin, ListAPI):
|
||||
"""API endpoint for accessing a list of DataImportRow objects."""
|
||||
|
||||
queryset = importer.models.DataImportRow.objects.all()
|
||||
@@ -180,7 +212,7 @@ class DataImportRowList(DataImporterPermissionMixin, BulkDeleteMixin, ListAPI):
|
||||
ordering = 'row_index'
|
||||
|
||||
|
||||
class DataImportRowDetail(DataImporterPermissionMixin, RetrieveUpdateDestroyAPI):
|
||||
class DataImportRowDetail(DataImportSessionChildMixin, RetrieveUpdateDestroyAPI):
|
||||
"""Detail endpoint for a single DataImportRow object."""
|
||||
|
||||
queryset = importer.models.DataImportRow.objects.all()
|
||||
|
||||
@@ -344,14 +344,21 @@ class DataImportSession(models.Model):
|
||||
self.save()
|
||||
|
||||
def check_complete(self) -> bool:
|
||||
"""Check if the import session is complete."""
|
||||
"""Check if the import session is complete.
|
||||
|
||||
When all rows have been accepted, the rows and column mappings are
|
||||
deleted as they are no longer needed. The session itself is retained
|
||||
as an audit record.
|
||||
"""
|
||||
if self.completed_row_count < self.row_count:
|
||||
return False
|
||||
|
||||
# Update the status of this session
|
||||
if self.status != DataImportStatusCode.COMPLETE.value:
|
||||
self.status = DataImportStatusCode.COMPLETE.value
|
||||
self.save()
|
||||
# Clear staging data now that all rows have been imported
|
||||
self.rows.all().delete()
|
||||
self.column_mappings.all().delete()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@@ -2,10 +2,11 @@
|
||||
|
||||
import os
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
from django.core.files.base import ContentFile
|
||||
from django.urls import reverse
|
||||
|
||||
from importer.models import DataImportRow, DataImportSession
|
||||
from importer.models import DataImportColumnMap, DataImportRow, DataImportSession
|
||||
from InvenTree.unit_test import AdminTestCase, InvenTreeAPITestCase, InvenTreeTestCase
|
||||
|
||||
|
||||
@@ -58,14 +59,20 @@ class ImporterTest(ImporterMixin, InvenTreeTestCase):
|
||||
self.assertEqual(session.rows.count(), 12)
|
||||
|
||||
# Check that some data has been imported
|
||||
for row in session.rows.all():
|
||||
rows = list(session.rows.all())
|
||||
self.assertEqual(len(rows), 12)
|
||||
|
||||
for row in rows:
|
||||
self.assertIsNotNone(row.data.get('name', None))
|
||||
self.assertTrue(row.valid)
|
||||
|
||||
row.validate(commit=True)
|
||||
self.assertTrue(row.complete)
|
||||
|
||||
self.assertEqual(session.completed_row_count, 12)
|
||||
# All rows accepted: rows and mappings are cleared, session is retained
|
||||
session.refresh_from_db()
|
||||
self.assertEqual(session.rows.count(), 0)
|
||||
self.assertEqual(session.column_mappings.count(), 0)
|
||||
|
||||
# Check that the new companies have been created
|
||||
self.assertEqual(n + 12, Company.objects.count())
|
||||
@@ -204,6 +211,191 @@ class ImportAPITest(ImporterMixin, InvenTreeAPITestCase):
|
||||
for session in response.data:
|
||||
self.assertEqual(session['user'], self.user.pk)
|
||||
|
||||
def test_accept_fields_ownership(self):
|
||||
"""Test that accept_fields rejects requests for sessions owned by another user."""
|
||||
other_user = User.objects.create_user(
|
||||
username='other_accept', password='password'
|
||||
)
|
||||
|
||||
f = self.helper_file('companies.csv')
|
||||
session = DataImportSession.objects.create(
|
||||
data_file=f, model_type='company', user=other_user
|
||||
)
|
||||
|
||||
url = reverse('api-import-session-accept-fields', kwargs={'pk': session.pk})
|
||||
|
||||
# Non-owner, non-staff should be denied
|
||||
self.user.is_staff = False
|
||||
self.user.save()
|
||||
self.post(url, expected_code=403)
|
||||
|
||||
# Staff should be allowed (subject to model permission)
|
||||
# Company is part of the purchase_order ruleset
|
||||
self.user.is_staff = True
|
||||
self.user.save()
|
||||
self.assignRole('purchase_order.change')
|
||||
self.post(url, expected_code=200)
|
||||
|
||||
def test_accept_rows_ownership(self):
|
||||
"""Test that accept_rows rejects requests for sessions owned by another user."""
|
||||
other_user = User.objects.create_user(
|
||||
username='other_accept_rows', password='password'
|
||||
)
|
||||
|
||||
f = self.helper_file('companies.csv')
|
||||
session = DataImportSession.objects.create(
|
||||
data_file=f, model_type='company', user=other_user
|
||||
)
|
||||
session.extract_columns()
|
||||
|
||||
url = reverse('api-import-session-accept-rows', kwargs={'pk': session.pk})
|
||||
|
||||
self.user.is_staff = False
|
||||
self.user.save()
|
||||
self.post(url, {'rows': []}, expected_code=403)
|
||||
|
||||
# Staff can reach the endpoint (rows list is empty so validation rejects with 400, not 403)
|
||||
self.user.is_staff = True
|
||||
self.user.save()
|
||||
self.post(url, {'rows': []}, expected_code=400)
|
||||
|
||||
def test_session_cleanup_on_complete(self):
|
||||
"""Test that a completed import session deletes itself and all associated data."""
|
||||
url = reverse('api-importer-session-list')
|
||||
data_file = self.helper_file('part_categories.csv')
|
||||
|
||||
data = self.post(
|
||||
url,
|
||||
{'model_type': 'partcategory', 'data_file': data_file},
|
||||
format='multipart',
|
||||
).data
|
||||
|
||||
session_id = data['pk']
|
||||
session_pk = session_id
|
||||
|
||||
self.assignRole('part_category.add')
|
||||
self.post(
|
||||
reverse('api-import-session-accept-fields', kwargs={'pk': session_id}),
|
||||
expected_code=200,
|
||||
)
|
||||
|
||||
rows = self.get(
|
||||
reverse('api-importer-row-list'), data={'session': session_id}
|
||||
).data
|
||||
row_ids = [r['pk'] for r in rows]
|
||||
self.assertGreater(len(row_ids), 0)
|
||||
|
||||
# Confirm rows and mappings exist before acceptance
|
||||
self.assertTrue(DataImportRow.objects.filter(session_id=session_pk).exists())
|
||||
self.assertTrue(
|
||||
DataImportColumnMap.objects.filter(session_id=session_pk).exists()
|
||||
)
|
||||
|
||||
# Accept all rows — this should trigger cleanup of rows and mappings
|
||||
self.post(
|
||||
reverse('api-import-session-accept-rows', kwargs={'pk': session_id}),
|
||||
{'rows': row_ids},
|
||||
)
|
||||
|
||||
# Rows and column mappings must be cleared
|
||||
self.assertFalse(DataImportRow.objects.filter(session_id=session_pk).exists())
|
||||
self.assertFalse(
|
||||
DataImportColumnMap.objects.filter(session_id=session_pk).exists()
|
||||
)
|
||||
|
||||
# Session itself is retained as an audit record with COMPLETE status
|
||||
from importer.models import DataImportSession
|
||||
from importer.status_codes import DataImportStatusCode
|
||||
|
||||
session_obj = DataImportSession.objects.get(pk=session_pk)
|
||||
self.assertEqual(session_obj.status, DataImportStatusCode.COMPLETE.value)
|
||||
|
||||
detail = self.get(
|
||||
reverse('api-import-session-detail', kwargs={'pk': session_id}),
|
||||
expected_code=200,
|
||||
).data
|
||||
self.assertEqual(detail['row_count'], 0)
|
||||
self.assertEqual(detail['completed_row_count'], 0)
|
||||
|
||||
def test_row_and_mapping_ownership(self):
|
||||
"""Test that DataImportRow and DataImportColumnMap endpoints filter by session ownership."""
|
||||
f = self.helper_file('companies.csv')
|
||||
|
||||
other_user = User.objects.create_user(
|
||||
username='other_importer', password='password'
|
||||
)
|
||||
|
||||
# Session owned by self.user
|
||||
session_mine = DataImportSession.objects.create(
|
||||
data_file=f, model_type='company', user=self.user
|
||||
)
|
||||
session_mine.extract_columns()
|
||||
|
||||
# Session owned by another user
|
||||
f2 = self.helper_file('companies.csv')
|
||||
session_other = DataImportSession.objects.create(
|
||||
data_file=f2, model_type='company', user=other_user
|
||||
)
|
||||
session_other.extract_columns()
|
||||
|
||||
row_list_url = reverse('api-importer-row-list')
|
||||
mapping_list_url = reverse('api-importer-mapping-list')
|
||||
|
||||
# Non-staff: should only see rows/mappings from own session
|
||||
self.user.is_staff = False
|
||||
self.user.save()
|
||||
|
||||
rows = self.get(row_list_url).data
|
||||
for row in rows:
|
||||
self.assertEqual(row['session'], session_mine.pk)
|
||||
|
||||
mappings = self.get(mapping_list_url).data
|
||||
for mapping in mappings:
|
||||
self.assertEqual(mapping['session'], session_mine.pk)
|
||||
|
||||
# Detail endpoint: own session's row/mapping should be accessible
|
||||
own_row = DataImportRow.objects.filter(session=session_mine).first()
|
||||
other_row = DataImportRow.objects.filter(session=session_other).first()
|
||||
|
||||
if own_row:
|
||||
self.get(
|
||||
reverse('api-importer-row-detail', kwargs={'pk': own_row.pk}),
|
||||
expected_code=200,
|
||||
)
|
||||
if other_row:
|
||||
self.get(
|
||||
reverse('api-importer-row-detail', kwargs={'pk': other_row.pk}),
|
||||
expected_code=404,
|
||||
)
|
||||
|
||||
own_mapping = DataImportColumnMap.objects.filter(session=session_mine).first()
|
||||
other_mapping = DataImportColumnMap.objects.filter(
|
||||
session=session_other
|
||||
).first()
|
||||
|
||||
if own_mapping:
|
||||
self.get(
|
||||
reverse('api-importer-mapping-detail', kwargs={'pk': own_mapping.pk}),
|
||||
expected_code=200,
|
||||
)
|
||||
if other_mapping:
|
||||
self.get(
|
||||
reverse('api-importer-mapping-detail', kwargs={'pk': other_mapping.pk}),
|
||||
expected_code=404,
|
||||
)
|
||||
|
||||
# Staff user: should see rows/mappings from all sessions
|
||||
self.user.is_staff = True
|
||||
self.user.save()
|
||||
|
||||
all_row_pks = set(DataImportRow.objects.values_list('pk', flat=True))
|
||||
response_rows = self.get(row_list_url).data
|
||||
self.assertEqual({r['pk'] for r in response_rows}, all_row_pks)
|
||||
|
||||
all_mapping_pks = set(DataImportColumnMap.objects.values_list('pk', flat=True))
|
||||
response_mappings = self.get(mapping_list_url).data
|
||||
self.assertEqual({m['pk'] for m in response_mappings}, all_mapping_pks)
|
||||
|
||||
|
||||
class AdminTest(ImporterMixin, AdminTestCase):
|
||||
"""Tests for the admin interface integration."""
|
||||
|
||||
Reference in New Issue
Block a user