diff --git a/src/backend/InvenTree/InvenTree/api.py b/src/backend/InvenTree/InvenTree/api.py index c3fee4cdb9..9f61ade0aa 100644 --- a/src/backend/InvenTree/InvenTree/api.py +++ b/src/backend/InvenTree/InvenTree/api.py @@ -17,6 +17,7 @@ from django_q.models import OrmQ from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema from rest_framework import serializers from rest_framework.generics import GenericAPIView +from rest_framework.request import clone_request from rest_framework.response import Response from rest_framework.serializers import ValidationError from rest_framework.views import APIView @@ -31,7 +32,7 @@ from InvenTree.mixins import ListCreateAPI from InvenTree.sso import sso_registration_enabled from plugin.serializers import MetadataSerializer from users.models import ApiToken -from users.permissions import check_user_permission +from users.permissions import check_user_permission, prefetch_rule_sets from .helpers import plugins_info from .helpers_email import is_email_configured @@ -767,6 +768,13 @@ class APISearchView(GenericAPIView): search_filters = self.get_result_filters() + # Create a clone of the request object to modify + # Use GET method for the individual list views + cloned_request = clone_request(request, 'GET') + + # Fetch and cache all groups associated with the current user + groups = prefetch_rule_sets(request.user) + for key, cls in self.get_result_types().items(): # Only return results which are specifically requested if key in data: @@ -790,22 +798,23 @@ class APISearchView(GenericAPIView): view = cls() # Override regular query params with specific ones for this search request - request._request.GET = params - view.request = request + cloned_request._request.GET = params + view.request = cloned_request view.format_kwarg = 'format' # Check permissions and update results dict with particular query model = view.serializer_class.Meta.model + if not check_user_permission( + request.user, model, 'view', groups=groups + ): + results[key] = { + 'error': _('User does not have permission to view this model') + } + continue + try: - if check_user_permission(request.user, model, 'view'): - results[key] = view.list(request, *args, **kwargs).data - else: - results[key] = { - 'error': _( - 'User does not have permission to view this model' - ) - } + results[key] = view.list(request, *args, **kwargs).data except Exception as exc: results[key] = {'error': str(exc)} diff --git a/src/backend/InvenTree/InvenTree/serializers.py b/src/backend/InvenTree/InvenTree/serializers.py index b215a0639a..4bd226a574 100644 --- a/src/backend/InvenTree/InvenTree/serializers.py +++ b/src/backend/InvenTree/InvenTree/serializers.py @@ -21,6 +21,7 @@ from rest_framework import serializers from rest_framework.exceptions import ValidationError from rest_framework.fields import empty from rest_framework.mixins import ListModelMixin +from rest_framework.permissions import SAFE_METHODS from rest_framework.serializers import DecimalField from rest_framework.utils import model_meta from taggit.serializers import TaggitSerializer, TagListSerializerField @@ -229,7 +230,7 @@ class FilterableSerializerMixin: # Skip filtering for a write requests - all fields should be present for data creation if request := self.context.get('request', None): if method := getattr(request, 'method', None): - if str(method).lower() in ['post', 'put', 'patch'] and not is_exporting: + if method not in SAFE_METHODS and not is_exporting: return # Throw out fields which are not requested (either by default or explicitly) diff --git a/src/backend/InvenTree/company/serializers.py b/src/backend/InvenTree/company/serializers.py index b3c936ae49..53f10eb961 100644 --- a/src/backend/InvenTree/company/serializers.py +++ b/src/backend/InvenTree/company/serializers.py @@ -268,11 +268,7 @@ class ManufacturerPartSerializer( source='part', many=False, read_only=True, allow_null=True ), True, - prefetch_fields=[ - Prefetch( - 'part', queryset=part.models.Part.objects.select_related('pricing_data') - ) - ], + prefetch_fields=['part', 'part__pricing_data', 'part__category'], ) pretty_name = enable_filter( @@ -438,7 +434,7 @@ class SupplierPartSerializer( label=_('Part'), source='part', many=False, read_only=True, allow_null=True ), False, - prefetch_fields=['part'], + prefetch_fields=['part', 'part__pricing_data'], ) supplier_detail = enable_filter( diff --git a/src/backend/InvenTree/company/test_api.py b/src/backend/InvenTree/company/test_api.py index b3701fb08f..8ff721f123 100644 --- a/src/backend/InvenTree/company/test_api.py +++ b/src/backend/InvenTree/company/test_api.py @@ -2,7 +2,14 @@ from django.urls import reverse -from company.models import Address, Company, Contact, SupplierPart, SupplierPriceBreak +from company.models import ( + Address, + Company, + Contact, + ManufacturerPart, + SupplierPart, + SupplierPriceBreak, +) from InvenTree.unit_test import InvenTreeAPITestCase from part.models import Part from users.permissions import check_user_permission @@ -498,7 +505,9 @@ class ManufacturerTest(InvenTreeAPITestCase): def test_manufacturer_part_detail(self): """Tests for the ManufacturerPart detail endpoint.""" - url = reverse('api-manufacturer-part-detail', kwargs={'pk': 1}) + mp = ManufacturerPart.objects.first() + + url = reverse('api-manufacturer-part-detail', kwargs={'pk': mp.pk}) response = self.get(url) self.assertEqual(response.data['MPN'], 'MPN123') diff --git a/src/backend/InvenTree/order/api.py b/src/backend/InvenTree/order/api.py index 4c0eec6156..c59fc2c22a 100644 --- a/src/backend/InvenTree/order/api.py +++ b/src/backend/InvenTree/order/api.py @@ -362,7 +362,9 @@ class PurchaseOrderOutputOptions(OutputConfiguration): class PurchaseOrderMixin(SerializerContextMixin): """Mixin class for PurchaseOrder endpoints.""" - queryset = models.PurchaseOrder.objects.all() + queryset = models.PurchaseOrder.objects.all().prefetch_related( + 'supplier', 'created_by' + ) serializer_class = serializers.PurchaseOrderSerializer def get_queryset(self, *args, **kwargs): @@ -371,8 +373,6 @@ class PurchaseOrderMixin(SerializerContextMixin): queryset = serializers.PurchaseOrderSerializer.annotate_queryset(queryset) - queryset = queryset.prefetch_related('supplier', 'created_by') - return queryset @@ -824,15 +824,15 @@ class SalesOrderFilter(OrderFilter): class SalesOrderMixin(SerializerContextMixin): """Mixin class for SalesOrder endpoints.""" - queryset = models.SalesOrder.objects.all() + queryset = models.SalesOrder.objects.all().prefetch_related( + 'customer', 'created_by' + ) serializer_class = serializers.SalesOrderSerializer def get_queryset(self, *args, **kwargs): """Return annotated queryset for this endpoint.""" queryset = super().get_queryset(*args, **kwargs) - queryset = queryset.prefetch_related('customer', 'created_by') - queryset = serializers.SalesOrderSerializer.annotate_queryset(queryset) return queryset diff --git a/src/backend/InvenTree/part/api.py b/src/backend/InvenTree/part/api.py index 25c3c3ad29..64e3cabefb 100644 --- a/src/backend/InvenTree/part/api.py +++ b/src/backend/InvenTree/part/api.py @@ -1009,7 +1009,9 @@ class PartMixin(SerializerContextMixin): """Mixin class for Part API endpoints.""" serializer_class = part_serializers.PartSerializer - queryset = Part.objects.all().select_related('pricing_data') + queryset = ( + Part.objects.all().select_related('pricing_data').prefetch_related('category') + ) starred_parts = None is_create = False diff --git a/src/backend/InvenTree/stock/serializers.py b/src/backend/InvenTree/stock/serializers.py index e5314ed6ef..35054ffe40 100644 --- a/src/backend/InvenTree/stock/serializers.py +++ b/src/backend/InvenTree/stock/serializers.py @@ -489,14 +489,13 @@ class StockItemSerializer( ), 'parent', 'part__category', - 'part__pricing_data', 'supplier_part', 'supplier_part__manufacturer_part', 'customer', 'belongs_to', 'sales_order', 'consumed_by', - ).select_related('part') + ).select_related('part', 'part__pricing_data') # Annotate the queryset with the total allocated to sales orders queryset = queryset.annotate( diff --git a/src/backend/InvenTree/users/permissions.py b/src/backend/InvenTree/users/permissions.py index 67b735ae4b..7c93f5ec97 100644 --- a/src/backend/InvenTree/users/permissions.py +++ b/src/backend/InvenTree/users/permissions.py @@ -130,7 +130,11 @@ def check_user_role( def check_user_permission( - user: User, model: models.Model, permission: str, allow_inactive: bool = False + user: User, + model: models.Model, + permission: str, + allow_inactive: bool = False, + groups: Optional[QuerySet] = None, ) -> bool: """Check if the user has a particular permission against a given model type. @@ -139,6 +143,7 @@ def check_user_permission( model: The model class to check (e.g. 'part') permission: The permission to check (e.g. 'view' / 'delete') allow_inactive: If False, disallow inactive users from having permissions + groups: Optional cached queryset of groups to check (defaults to user's groups) Returns: bool: True if the user has the specified permission @@ -160,9 +165,11 @@ def check_user_permission( if table_name in get_ruleset_ignore(): return True + groups = groups or prefetch_rule_sets(user) + for role, table_names in get_ruleset_models().items(): if table_name in table_names: - if check_user_role(user, role, permission): + if check_user_role(user, role, permission, groups=groups): return True # Check for children models which inherits from parent role @@ -172,7 +179,7 @@ def check_user_permission( if parent_child_string == table_name: # Check if parent role has change permission - if check_user_role(user, parent, 'change'): + if check_user_role(user, parent, 'change', groups=groups): return True # Generate the permission name based on the model and permission diff --git a/src/performance/tests.py b/src/performance/tests.py index 34930630a9..cd15152433 100644 --- a/src/performance/tests.py +++ b/src/performance/tests.py @@ -88,3 +88,73 @@ def test_api_options_performance(url): assert result assert 'actions' in result assert len(result['actions']) > 0 + + +@pytest.mark.benchmark +@pytest.mark.parametrize( + 'key', + [ + 'all', + 'part', + 'partcategory', + 'supplierpart', + 'manufacturerpart', + 'stockitem', + 'stocklocation', + 'build', + 'supplier', + 'manufacturer', + 'customer', + 'purchaseorder', + 'salesorder', + 'salesordershipment', + 'returnorder', + ], +) +def test_search_performance(key: str): + """Benchmark the API search performance.""" + SEARCH_URL = '/api/search/' + + # An indicative search query for various model types + SEARCH_DATA = { + 'part': {'active': True}, + 'partcategory': {}, + 'supplierpart': { + 'part_detail': True, + 'supplier_detail': True, + 'manufacturer_detail': True, + }, + 'manufacturerpart': { + 'part_detail': True, + 'supplier_detail': True, + 'manufacturer_detail': True, + }, + 'stockitem': {'part_detail': True, 'location_detail': True, 'in_stock': True}, + 'stocklocation': {}, + 'build': {'part_detail': True}, + 'supplier': {}, + 'manufacturer': {}, + 'customer': {}, + 'purchaseorder': {'supplier_detail': True, 'outstanding': True}, + 'salesorder': {'customer_detail': True, 'outstanding': True}, + 'salesordershipment': {}, + 'returnorder': {'customer_detail': True, 'outstanding': True}, + } + + model_types = list(SEARCH_DATA.keys()) + + search_params = SEARCH_DATA if key == 'all' else {key: SEARCH_DATA[key]} + + # Add in a common search term + search_params.update({'search': '0', 'limit': 50}) + + response = api_client.post(SEARCH_URL, data=search_params) + assert response + + if key == 'all': + for model_type in model_types: + assert model_type in response + assert 'error' not in response[model_type] + else: + assert key in response + assert 'error' not in response[key]