2
0
mirror of https://github.com/inventree/InvenTree.git synced 2026-01-10 05:08:09 +00:00

[API] Search improvements (#11094)

* Improve prefetching

* Cache user groups for permission  check

* Use a GET request to execute search

- Prevent forced prefetch
- Reduce execution time significantly

* Fix group caching

* Improve StockItemSerializer

- Select related for pricing_data rather than prefetch

* Add benchmarking for search endpoint

* Adjust prefetch

* Ensure no errors returned

* Fix prefetch

* Fix more prefetch issues

* Remove debug print

* Fix for performance testing

* Data is already returned as dict

* Test fix

* Extract model types better
This commit is contained in:
Oliver
2026-01-08 18:06:23 +11:00
committed by GitHub
parent 2457197446
commit 4709dc8a9a
9 changed files with 125 additions and 32 deletions

View File

@@ -17,6 +17,7 @@ from django_q.models import OrmQ
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
from rest_framework import serializers from rest_framework import serializers
from rest_framework.generics import GenericAPIView from rest_framework.generics import GenericAPIView
from rest_framework.request import clone_request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.serializers import ValidationError from rest_framework.serializers import ValidationError
from rest_framework.views import APIView from rest_framework.views import APIView
@@ -31,7 +32,7 @@ from InvenTree.mixins import ListCreateAPI
from InvenTree.sso import sso_registration_enabled from InvenTree.sso import sso_registration_enabled
from plugin.serializers import MetadataSerializer from plugin.serializers import MetadataSerializer
from users.models import ApiToken from users.models import ApiToken
from users.permissions import check_user_permission from users.permissions import check_user_permission, prefetch_rule_sets
from .helpers import plugins_info from .helpers import plugins_info
from .helpers_email import is_email_configured from .helpers_email import is_email_configured
@@ -767,6 +768,13 @@ class APISearchView(GenericAPIView):
search_filters = self.get_result_filters() search_filters = self.get_result_filters()
# Create a clone of the request object to modify
# Use GET method for the individual list views
cloned_request = clone_request(request, 'GET')
# Fetch and cache all groups associated with the current user
groups = prefetch_rule_sets(request.user)
for key, cls in self.get_result_types().items(): for key, cls in self.get_result_types().items():
# Only return results which are specifically requested # Only return results which are specifically requested
if key in data: if key in data:
@@ -790,22 +798,23 @@ class APISearchView(GenericAPIView):
view = cls() view = cls()
# Override regular query params with specific ones for this search request # Override regular query params with specific ones for this search request
request._request.GET = params cloned_request._request.GET = params
view.request = request view.request = cloned_request
view.format_kwarg = 'format' view.format_kwarg = 'format'
# Check permissions and update results dict with particular query # Check permissions and update results dict with particular query
model = view.serializer_class.Meta.model model = view.serializer_class.Meta.model
if not check_user_permission(
request.user, model, 'view', groups=groups
):
results[key] = {
'error': _('User does not have permission to view this model')
}
continue
try: try:
if check_user_permission(request.user, model, 'view'): results[key] = view.list(request, *args, **kwargs).data
results[key] = view.list(request, *args, **kwargs).data
else:
results[key] = {
'error': _(
'User does not have permission to view this model'
)
}
except Exception as exc: except Exception as exc:
results[key] = {'error': str(exc)} results[key] = {'error': str(exc)}

View File

@@ -21,6 +21,7 @@ from rest_framework import serializers
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from rest_framework.fields import empty from rest_framework.fields import empty
from rest_framework.mixins import ListModelMixin from rest_framework.mixins import ListModelMixin
from rest_framework.permissions import SAFE_METHODS
from rest_framework.serializers import DecimalField from rest_framework.serializers import DecimalField
from rest_framework.utils import model_meta from rest_framework.utils import model_meta
from taggit.serializers import TaggitSerializer, TagListSerializerField from taggit.serializers import TaggitSerializer, TagListSerializerField
@@ -229,7 +230,7 @@ class FilterableSerializerMixin:
# Skip filtering for a write requests - all fields should be present for data creation # Skip filtering for a write requests - all fields should be present for data creation
if request := self.context.get('request', None): if request := self.context.get('request', None):
if method := getattr(request, 'method', None): if method := getattr(request, 'method', None):
if str(method).lower() in ['post', 'put', 'patch'] and not is_exporting: if method not in SAFE_METHODS and not is_exporting:
return return
# Throw out fields which are not requested (either by default or explicitly) # Throw out fields which are not requested (either by default or explicitly)

View File

@@ -268,11 +268,7 @@ class ManufacturerPartSerializer(
source='part', many=False, read_only=True, allow_null=True source='part', many=False, read_only=True, allow_null=True
), ),
True, True,
prefetch_fields=[ prefetch_fields=['part', 'part__pricing_data', 'part__category'],
Prefetch(
'part', queryset=part.models.Part.objects.select_related('pricing_data')
)
],
) )
pretty_name = enable_filter( pretty_name = enable_filter(
@@ -438,7 +434,7 @@ class SupplierPartSerializer(
label=_('Part'), source='part', many=False, read_only=True, allow_null=True label=_('Part'), source='part', many=False, read_only=True, allow_null=True
), ),
False, False,
prefetch_fields=['part'], prefetch_fields=['part', 'part__pricing_data'],
) )
supplier_detail = enable_filter( supplier_detail = enable_filter(

View File

@@ -2,7 +2,14 @@
from django.urls import reverse from django.urls import reverse
from company.models import Address, Company, Contact, SupplierPart, SupplierPriceBreak from company.models import (
Address,
Company,
Contact,
ManufacturerPart,
SupplierPart,
SupplierPriceBreak,
)
from InvenTree.unit_test import InvenTreeAPITestCase from InvenTree.unit_test import InvenTreeAPITestCase
from part.models import Part from part.models import Part
from users.permissions import check_user_permission from users.permissions import check_user_permission
@@ -498,7 +505,9 @@ class ManufacturerTest(InvenTreeAPITestCase):
def test_manufacturer_part_detail(self): def test_manufacturer_part_detail(self):
"""Tests for the ManufacturerPart detail endpoint.""" """Tests for the ManufacturerPart detail endpoint."""
url = reverse('api-manufacturer-part-detail', kwargs={'pk': 1}) mp = ManufacturerPart.objects.first()
url = reverse('api-manufacturer-part-detail', kwargs={'pk': mp.pk})
response = self.get(url) response = self.get(url)
self.assertEqual(response.data['MPN'], 'MPN123') self.assertEqual(response.data['MPN'], 'MPN123')

View File

@@ -362,7 +362,9 @@ class PurchaseOrderOutputOptions(OutputConfiguration):
class PurchaseOrderMixin(SerializerContextMixin): class PurchaseOrderMixin(SerializerContextMixin):
"""Mixin class for PurchaseOrder endpoints.""" """Mixin class for PurchaseOrder endpoints."""
queryset = models.PurchaseOrder.objects.all() queryset = models.PurchaseOrder.objects.all().prefetch_related(
'supplier', 'created_by'
)
serializer_class = serializers.PurchaseOrderSerializer serializer_class = serializers.PurchaseOrderSerializer
def get_queryset(self, *args, **kwargs): def get_queryset(self, *args, **kwargs):
@@ -371,8 +373,6 @@ class PurchaseOrderMixin(SerializerContextMixin):
queryset = serializers.PurchaseOrderSerializer.annotate_queryset(queryset) queryset = serializers.PurchaseOrderSerializer.annotate_queryset(queryset)
queryset = queryset.prefetch_related('supplier', 'created_by')
return queryset return queryset
@@ -824,15 +824,15 @@ class SalesOrderFilter(OrderFilter):
class SalesOrderMixin(SerializerContextMixin): class SalesOrderMixin(SerializerContextMixin):
"""Mixin class for SalesOrder endpoints.""" """Mixin class for SalesOrder endpoints."""
queryset = models.SalesOrder.objects.all() queryset = models.SalesOrder.objects.all().prefetch_related(
'customer', 'created_by'
)
serializer_class = serializers.SalesOrderSerializer serializer_class = serializers.SalesOrderSerializer
def get_queryset(self, *args, **kwargs): def get_queryset(self, *args, **kwargs):
"""Return annotated queryset for this endpoint.""" """Return annotated queryset for this endpoint."""
queryset = super().get_queryset(*args, **kwargs) queryset = super().get_queryset(*args, **kwargs)
queryset = queryset.prefetch_related('customer', 'created_by')
queryset = serializers.SalesOrderSerializer.annotate_queryset(queryset) queryset = serializers.SalesOrderSerializer.annotate_queryset(queryset)
return queryset return queryset

View File

@@ -1009,7 +1009,9 @@ class PartMixin(SerializerContextMixin):
"""Mixin class for Part API endpoints.""" """Mixin class for Part API endpoints."""
serializer_class = part_serializers.PartSerializer serializer_class = part_serializers.PartSerializer
queryset = Part.objects.all().select_related('pricing_data') queryset = (
Part.objects.all().select_related('pricing_data').prefetch_related('category')
)
starred_parts = None starred_parts = None
is_create = False is_create = False

View File

@@ -489,14 +489,13 @@ class StockItemSerializer(
), ),
'parent', 'parent',
'part__category', 'part__category',
'part__pricing_data',
'supplier_part', 'supplier_part',
'supplier_part__manufacturer_part', 'supplier_part__manufacturer_part',
'customer', 'customer',
'belongs_to', 'belongs_to',
'sales_order', 'sales_order',
'consumed_by', 'consumed_by',
).select_related('part') ).select_related('part', 'part__pricing_data')
# Annotate the queryset with the total allocated to sales orders # Annotate the queryset with the total allocated to sales orders
queryset = queryset.annotate( queryset = queryset.annotate(

View File

@@ -130,7 +130,11 @@ def check_user_role(
def check_user_permission( def check_user_permission(
user: User, model: models.Model, permission: str, allow_inactive: bool = False user: User,
model: models.Model,
permission: str,
allow_inactive: bool = False,
groups: Optional[QuerySet] = None,
) -> bool: ) -> bool:
"""Check if the user has a particular permission against a given model type. """Check if the user has a particular permission against a given model type.
@@ -139,6 +143,7 @@ def check_user_permission(
model: The model class to check (e.g. 'part') model: The model class to check (e.g. 'part')
permission: The permission to check (e.g. 'view' / 'delete') permission: The permission to check (e.g. 'view' / 'delete')
allow_inactive: If False, disallow inactive users from having permissions allow_inactive: If False, disallow inactive users from having permissions
groups: Optional cached queryset of groups to check (defaults to user's groups)
Returns: Returns:
bool: True if the user has the specified permission bool: True if the user has the specified permission
@@ -160,9 +165,11 @@ def check_user_permission(
if table_name in get_ruleset_ignore(): if table_name in get_ruleset_ignore():
return True return True
groups = groups or prefetch_rule_sets(user)
for role, table_names in get_ruleset_models().items(): for role, table_names in get_ruleset_models().items():
if table_name in table_names: if table_name in table_names:
if check_user_role(user, role, permission): if check_user_role(user, role, permission, groups=groups):
return True return True
# Check for children models which inherits from parent role # Check for children models which inherits from parent role
@@ -172,7 +179,7 @@ def check_user_permission(
if parent_child_string == table_name: if parent_child_string == table_name:
# Check if parent role has change permission # Check if parent role has change permission
if check_user_role(user, parent, 'change'): if check_user_role(user, parent, 'change', groups=groups):
return True return True
# Generate the permission name based on the model and permission # Generate the permission name based on the model and permission

View File

@@ -88,3 +88,73 @@ def test_api_options_performance(url):
assert result assert result
assert 'actions' in result assert 'actions' in result
assert len(result['actions']) > 0 assert len(result['actions']) > 0
@pytest.mark.benchmark
@pytest.mark.parametrize(
'key',
[
'all',
'part',
'partcategory',
'supplierpart',
'manufacturerpart',
'stockitem',
'stocklocation',
'build',
'supplier',
'manufacturer',
'customer',
'purchaseorder',
'salesorder',
'salesordershipment',
'returnorder',
],
)
def test_search_performance(key: str):
"""Benchmark the API search performance."""
SEARCH_URL = '/api/search/'
# An indicative search query for various model types
SEARCH_DATA = {
'part': {'active': True},
'partcategory': {},
'supplierpart': {
'part_detail': True,
'supplier_detail': True,
'manufacturer_detail': True,
},
'manufacturerpart': {
'part_detail': True,
'supplier_detail': True,
'manufacturer_detail': True,
},
'stockitem': {'part_detail': True, 'location_detail': True, 'in_stock': True},
'stocklocation': {},
'build': {'part_detail': True},
'supplier': {},
'manufacturer': {},
'customer': {},
'purchaseorder': {'supplier_detail': True, 'outstanding': True},
'salesorder': {'customer_detail': True, 'outstanding': True},
'salesordershipment': {},
'returnorder': {'customer_detail': True, 'outstanding': True},
}
model_types = list(SEARCH_DATA.keys())
search_params = SEARCH_DATA if key == 'all' else {key: SEARCH_DATA[key]}
# Add in a common search term
search_params.update({'search': '0', 'limit': 50})
response = api_client.post(SEARCH_URL, data=search_params)
assert response
if key == 'all':
for model_type in model_types:
assert model_type in response
assert 'error' not in response[model_type]
else:
assert key in response
assert 'error' not in response[key]