mirror of
https://github.com/inventree/InvenTree.git
synced 2026-07-04 06:00:38 +00:00
Data import permissions (#12169)
* Update data importer child permissions - Row data - Column data * Add unit tests * Cleanup session data after import is completed * Further scope narrowing
This commit is contained in:
@@ -115,6 +115,10 @@ class DataImportSessionAcceptFields(APIView):
|
|||||||
"""Accept the field mapping for a DataImportSession."""
|
"""Accept the field mapping for a DataImportSession."""
|
||||||
session = get_object_or_404(importer.models.DataImportSession, pk=pk)
|
session = get_object_or_404(importer.models.DataImportSession, pk=pk)
|
||||||
|
|
||||||
|
# Check session ownership
|
||||||
|
if not request.user.is_staff and session.user != request.user:
|
||||||
|
raise PermissionDenied()
|
||||||
|
|
||||||
# Check that the user has permission to accept the field mapping
|
# Check that the user has permission to accept the field mapping
|
||||||
if model_class := session.model_class:
|
if model_class := session.model_class:
|
||||||
if not check_user_permission(request.user, model_class, 'change'):
|
if not check_user_permission(request.user, model_class, 'change'):
|
||||||
@@ -137,17 +141,45 @@ class DataImportSessionAcceptRows(DataImporterPermissionMixin, CreateAPI):
|
|||||||
ctx = super().get_serializer_context()
|
ctx = super().get_serializer_context()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ctx['session'] = importer.models.DataImportSession.objects.get(
|
session = importer.models.DataImportSession.objects.get(
|
||||||
pk=self.kwargs.get('pk', None)
|
pk=self.kwargs.get('pk', None)
|
||||||
)
|
)
|
||||||
except Exception:
|
except importer.models.DataImportSession.DoesNotExist:
|
||||||
pass
|
session = None
|
||||||
|
|
||||||
|
if session:
|
||||||
|
user = self.request.user
|
||||||
|
if not user.is_staff and session.user != user:
|
||||||
|
raise PermissionDenied()
|
||||||
|
ctx['session'] = session
|
||||||
|
|
||||||
ctx['request'] = self.request
|
ctx['request'] = self.request
|
||||||
return ctx
|
return ctx
|
||||||
|
|
||||||
|
|
||||||
class DataImportColumnMappingList(DataImporterPermissionMixin, ListAPI):
|
class DataImportSessionChildMixin(DataImporterPermissionMixin):
|
||||||
|
"""Mixin for DataImportRow and DataImportColumnMap views.
|
||||||
|
|
||||||
|
Ensures users can only access objects that belong to an import session they own.
|
||||||
|
Staff users retain access to all objects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_queryset(self):
|
||||||
|
"""Return only objects whose session belongs to the requesting user."""
|
||||||
|
queryset = super().get_queryset()
|
||||||
|
|
||||||
|
try:
|
||||||
|
user = self.request.user
|
||||||
|
except AttributeError:
|
||||||
|
raise PermissionDenied('User information is not available')
|
||||||
|
|
||||||
|
if user.is_staff:
|
||||||
|
return queryset
|
||||||
|
|
||||||
|
return queryset.filter(session__user=user)
|
||||||
|
|
||||||
|
|
||||||
|
class DataImportColumnMappingList(DataImportSessionChildMixin, ListAPI):
|
||||||
"""API endpoint for accessing a list of DataImportColumnMap objects."""
|
"""API endpoint for accessing a list of DataImportColumnMap objects."""
|
||||||
|
|
||||||
queryset = importer.models.DataImportColumnMap.objects.all()
|
queryset = importer.models.DataImportColumnMap.objects.all()
|
||||||
@@ -158,14 +190,14 @@ class DataImportColumnMappingList(DataImporterPermissionMixin, ListAPI):
|
|||||||
filterset_fields = ['session']
|
filterset_fields = ['session']
|
||||||
|
|
||||||
|
|
||||||
class DataImportColumnMappingDetail(DataImporterPermissionMixin, RetrieveUpdateAPI):
|
class DataImportColumnMappingDetail(DataImportSessionChildMixin, RetrieveUpdateAPI):
|
||||||
"""Detail endpoint for a single DataImportColumnMap object."""
|
"""Detail endpoint for a single DataImportColumnMap object."""
|
||||||
|
|
||||||
queryset = importer.models.DataImportColumnMap.objects.all()
|
queryset = importer.models.DataImportColumnMap.objects.all()
|
||||||
serializer_class = importer.serializers.DataImportColumnMapSerializer
|
serializer_class = importer.serializers.DataImportColumnMapSerializer
|
||||||
|
|
||||||
|
|
||||||
class DataImportRowList(DataImporterPermissionMixin, BulkDeleteMixin, ListAPI):
|
class DataImportRowList(DataImportSessionChildMixin, BulkDeleteMixin, ListAPI):
|
||||||
"""API endpoint for accessing a list of DataImportRow objects."""
|
"""API endpoint for accessing a list of DataImportRow objects."""
|
||||||
|
|
||||||
queryset = importer.models.DataImportRow.objects.all()
|
queryset = importer.models.DataImportRow.objects.all()
|
||||||
@@ -180,7 +212,7 @@ class DataImportRowList(DataImporterPermissionMixin, BulkDeleteMixin, ListAPI):
|
|||||||
ordering = 'row_index'
|
ordering = 'row_index'
|
||||||
|
|
||||||
|
|
||||||
class DataImportRowDetail(DataImporterPermissionMixin, RetrieveUpdateDestroyAPI):
|
class DataImportRowDetail(DataImportSessionChildMixin, RetrieveUpdateDestroyAPI):
|
||||||
"""Detail endpoint for a single DataImportRow object."""
|
"""Detail endpoint for a single DataImportRow object."""
|
||||||
|
|
||||||
queryset = importer.models.DataImportRow.objects.all()
|
queryset = importer.models.DataImportRow.objects.all()
|
||||||
|
|||||||
@@ -344,14 +344,21 @@ class DataImportSession(models.Model):
|
|||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
def check_complete(self) -> bool:
|
def check_complete(self) -> bool:
|
||||||
"""Check if the import session is complete."""
|
"""Check if the import session is complete.
|
||||||
|
|
||||||
|
When all rows have been accepted, the rows and column mappings are
|
||||||
|
deleted as they are no longer needed. The session itself is retained
|
||||||
|
as an audit record.
|
||||||
|
"""
|
||||||
if self.completed_row_count < self.row_count:
|
if self.completed_row_count < self.row_count:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Update the status of this session
|
|
||||||
if self.status != DataImportStatusCode.COMPLETE.value:
|
if self.status != DataImportStatusCode.COMPLETE.value:
|
||||||
self.status = DataImportStatusCode.COMPLETE.value
|
self.status = DataImportStatusCode.COMPLETE.value
|
||||||
self.save()
|
self.save()
|
||||||
|
# Clear staging data now that all rows have been imported
|
||||||
|
self.rows.all().delete()
|
||||||
|
self.column_mappings.all().delete()
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,11 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from django.contrib.auth.models import User
|
||||||
from django.core.files.base import ContentFile
|
from django.core.files.base import ContentFile
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
|
|
||||||
from importer.models import DataImportRow, DataImportSession
|
from importer.models import DataImportColumnMap, DataImportRow, DataImportSession
|
||||||
from InvenTree.unit_test import AdminTestCase, InvenTreeAPITestCase, InvenTreeTestCase
|
from InvenTree.unit_test import AdminTestCase, InvenTreeAPITestCase, InvenTreeTestCase
|
||||||
|
|
||||||
|
|
||||||
@@ -58,14 +59,20 @@ class ImporterTest(ImporterMixin, InvenTreeTestCase):
|
|||||||
self.assertEqual(session.rows.count(), 12)
|
self.assertEqual(session.rows.count(), 12)
|
||||||
|
|
||||||
# Check that some data has been imported
|
# Check that some data has been imported
|
||||||
for row in session.rows.all():
|
rows = list(session.rows.all())
|
||||||
|
self.assertEqual(len(rows), 12)
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
self.assertIsNotNone(row.data.get('name', None))
|
self.assertIsNotNone(row.data.get('name', None))
|
||||||
self.assertTrue(row.valid)
|
self.assertTrue(row.valid)
|
||||||
|
|
||||||
row.validate(commit=True)
|
row.validate(commit=True)
|
||||||
self.assertTrue(row.complete)
|
self.assertTrue(row.complete)
|
||||||
|
|
||||||
self.assertEqual(session.completed_row_count, 12)
|
# All rows accepted: rows and mappings are cleared, session is retained
|
||||||
|
session.refresh_from_db()
|
||||||
|
self.assertEqual(session.rows.count(), 0)
|
||||||
|
self.assertEqual(session.column_mappings.count(), 0)
|
||||||
|
|
||||||
# Check that the new companies have been created
|
# Check that the new companies have been created
|
||||||
self.assertEqual(n + 12, Company.objects.count())
|
self.assertEqual(n + 12, Company.objects.count())
|
||||||
@@ -204,6 +211,191 @@ class ImportAPITest(ImporterMixin, InvenTreeAPITestCase):
|
|||||||
for session in response.data:
|
for session in response.data:
|
||||||
self.assertEqual(session['user'], self.user.pk)
|
self.assertEqual(session['user'], self.user.pk)
|
||||||
|
|
||||||
|
def test_accept_fields_ownership(self):
|
||||||
|
"""Test that accept_fields rejects requests for sessions owned by another user."""
|
||||||
|
other_user = User.objects.create_user(
|
||||||
|
username='other_accept', password='password'
|
||||||
|
)
|
||||||
|
|
||||||
|
f = self.helper_file('companies.csv')
|
||||||
|
session = DataImportSession.objects.create(
|
||||||
|
data_file=f, model_type='company', user=other_user
|
||||||
|
)
|
||||||
|
|
||||||
|
url = reverse('api-import-session-accept-fields', kwargs={'pk': session.pk})
|
||||||
|
|
||||||
|
# Non-owner, non-staff should be denied
|
||||||
|
self.user.is_staff = False
|
||||||
|
self.user.save()
|
||||||
|
self.post(url, expected_code=403)
|
||||||
|
|
||||||
|
# Staff should be allowed (subject to model permission)
|
||||||
|
# Company is part of the purchase_order ruleset
|
||||||
|
self.user.is_staff = True
|
||||||
|
self.user.save()
|
||||||
|
self.assignRole('purchase_order.change')
|
||||||
|
self.post(url, expected_code=200)
|
||||||
|
|
||||||
|
def test_accept_rows_ownership(self):
|
||||||
|
"""Test that accept_rows rejects requests for sessions owned by another user."""
|
||||||
|
other_user = User.objects.create_user(
|
||||||
|
username='other_accept_rows', password='password'
|
||||||
|
)
|
||||||
|
|
||||||
|
f = self.helper_file('companies.csv')
|
||||||
|
session = DataImportSession.objects.create(
|
||||||
|
data_file=f, model_type='company', user=other_user
|
||||||
|
)
|
||||||
|
session.extract_columns()
|
||||||
|
|
||||||
|
url = reverse('api-import-session-accept-rows', kwargs={'pk': session.pk})
|
||||||
|
|
||||||
|
self.user.is_staff = False
|
||||||
|
self.user.save()
|
||||||
|
self.post(url, {'rows': []}, expected_code=403)
|
||||||
|
|
||||||
|
# Staff can reach the endpoint (rows list is empty so validation rejects with 400, not 403)
|
||||||
|
self.user.is_staff = True
|
||||||
|
self.user.save()
|
||||||
|
self.post(url, {'rows': []}, expected_code=400)
|
||||||
|
|
||||||
|
def test_session_cleanup_on_complete(self):
|
||||||
|
"""Test that a completed import session deletes itself and all associated data."""
|
||||||
|
url = reverse('api-importer-session-list')
|
||||||
|
data_file = self.helper_file('part_categories.csv')
|
||||||
|
|
||||||
|
data = self.post(
|
||||||
|
url,
|
||||||
|
{'model_type': 'partcategory', 'data_file': data_file},
|
||||||
|
format='multipart',
|
||||||
|
).data
|
||||||
|
|
||||||
|
session_id = data['pk']
|
||||||
|
session_pk = session_id
|
||||||
|
|
||||||
|
self.assignRole('part_category.add')
|
||||||
|
self.post(
|
||||||
|
reverse('api-import-session-accept-fields', kwargs={'pk': session_id}),
|
||||||
|
expected_code=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = self.get(
|
||||||
|
reverse('api-importer-row-list'), data={'session': session_id}
|
||||||
|
).data
|
||||||
|
row_ids = [r['pk'] for r in rows]
|
||||||
|
self.assertGreater(len(row_ids), 0)
|
||||||
|
|
||||||
|
# Confirm rows and mappings exist before acceptance
|
||||||
|
self.assertTrue(DataImportRow.objects.filter(session_id=session_pk).exists())
|
||||||
|
self.assertTrue(
|
||||||
|
DataImportColumnMap.objects.filter(session_id=session_pk).exists()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Accept all rows — this should trigger cleanup of rows and mappings
|
||||||
|
self.post(
|
||||||
|
reverse('api-import-session-accept-rows', kwargs={'pk': session_id}),
|
||||||
|
{'rows': row_ids},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Rows and column mappings must be cleared
|
||||||
|
self.assertFalse(DataImportRow.objects.filter(session_id=session_pk).exists())
|
||||||
|
self.assertFalse(
|
||||||
|
DataImportColumnMap.objects.filter(session_id=session_pk).exists()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Session itself is retained as an audit record with COMPLETE status
|
||||||
|
from importer.models import DataImportSession
|
||||||
|
from importer.status_codes import DataImportStatusCode
|
||||||
|
|
||||||
|
session_obj = DataImportSession.objects.get(pk=session_pk)
|
||||||
|
self.assertEqual(session_obj.status, DataImportStatusCode.COMPLETE.value)
|
||||||
|
|
||||||
|
detail = self.get(
|
||||||
|
reverse('api-import-session-detail', kwargs={'pk': session_id}),
|
||||||
|
expected_code=200,
|
||||||
|
).data
|
||||||
|
self.assertEqual(detail['row_count'], 0)
|
||||||
|
self.assertEqual(detail['completed_row_count'], 0)
|
||||||
|
|
||||||
|
def test_row_and_mapping_ownership(self):
|
||||||
|
"""Test that DataImportRow and DataImportColumnMap endpoints filter by session ownership."""
|
||||||
|
f = self.helper_file('companies.csv')
|
||||||
|
|
||||||
|
other_user = User.objects.create_user(
|
||||||
|
username='other_importer', password='password'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Session owned by self.user
|
||||||
|
session_mine = DataImportSession.objects.create(
|
||||||
|
data_file=f, model_type='company', user=self.user
|
||||||
|
)
|
||||||
|
session_mine.extract_columns()
|
||||||
|
|
||||||
|
# Session owned by another user
|
||||||
|
f2 = self.helper_file('companies.csv')
|
||||||
|
session_other = DataImportSession.objects.create(
|
||||||
|
data_file=f2, model_type='company', user=other_user
|
||||||
|
)
|
||||||
|
session_other.extract_columns()
|
||||||
|
|
||||||
|
row_list_url = reverse('api-importer-row-list')
|
||||||
|
mapping_list_url = reverse('api-importer-mapping-list')
|
||||||
|
|
||||||
|
# Non-staff: should only see rows/mappings from own session
|
||||||
|
self.user.is_staff = False
|
||||||
|
self.user.save()
|
||||||
|
|
||||||
|
rows = self.get(row_list_url).data
|
||||||
|
for row in rows:
|
||||||
|
self.assertEqual(row['session'], session_mine.pk)
|
||||||
|
|
||||||
|
mappings = self.get(mapping_list_url).data
|
||||||
|
for mapping in mappings:
|
||||||
|
self.assertEqual(mapping['session'], session_mine.pk)
|
||||||
|
|
||||||
|
# Detail endpoint: own session's row/mapping should be accessible
|
||||||
|
own_row = DataImportRow.objects.filter(session=session_mine).first()
|
||||||
|
other_row = DataImportRow.objects.filter(session=session_other).first()
|
||||||
|
|
||||||
|
if own_row:
|
||||||
|
self.get(
|
||||||
|
reverse('api-importer-row-detail', kwargs={'pk': own_row.pk}),
|
||||||
|
expected_code=200,
|
||||||
|
)
|
||||||
|
if other_row:
|
||||||
|
self.get(
|
||||||
|
reverse('api-importer-row-detail', kwargs={'pk': other_row.pk}),
|
||||||
|
expected_code=404,
|
||||||
|
)
|
||||||
|
|
||||||
|
own_mapping = DataImportColumnMap.objects.filter(session=session_mine).first()
|
||||||
|
other_mapping = DataImportColumnMap.objects.filter(
|
||||||
|
session=session_other
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if own_mapping:
|
||||||
|
self.get(
|
||||||
|
reverse('api-importer-mapping-detail', kwargs={'pk': own_mapping.pk}),
|
||||||
|
expected_code=200,
|
||||||
|
)
|
||||||
|
if other_mapping:
|
||||||
|
self.get(
|
||||||
|
reverse('api-importer-mapping-detail', kwargs={'pk': other_mapping.pk}),
|
||||||
|
expected_code=404,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Staff user: should see rows/mappings from all sessions
|
||||||
|
self.user.is_staff = True
|
||||||
|
self.user.save()
|
||||||
|
|
||||||
|
all_row_pks = set(DataImportRow.objects.values_list('pk', flat=True))
|
||||||
|
response_rows = self.get(row_list_url).data
|
||||||
|
self.assertEqual({r['pk'] for r in response_rows}, all_row_pks)
|
||||||
|
|
||||||
|
all_mapping_pks = set(DataImportColumnMap.objects.values_list('pk', flat=True))
|
||||||
|
response_mappings = self.get(mapping_list_url).data
|
||||||
|
self.assertEqual({m['pk'] for m in response_mappings}, all_mapping_pks)
|
||||||
|
|
||||||
|
|
||||||
class AdminTest(ImporterMixin, AdminTestCase):
|
class AdminTest(ImporterMixin, AdminTestCase):
|
||||||
"""Tests for the admin interface integration."""
|
"""Tests for the admin interface integration."""
|
||||||
|
|||||||
Reference in New Issue
Block a user