2
0
mirror of https://github.com/inventree/InvenTree.git synced 2026-01-10 05:08:09 +00:00

[API] Search improvements (#11094)

* Improve prefetching

* Cache user groups for permission  check

* Use a GET request to execute search

- Prevent forced prefetch
- Reduce execution time significantly

* Fix group caching

* Improve StockItemSerializer

- Select related for pricing_data rather than prefetch

* Add benchmarking for search endpoint

* Adjust prefetch

* Ensure no errors returned

* Fix prefetch

* Fix more prefetch issues

* Remove debug print

* Fix for performance testing

* Data is already returned as dict

* Test fix

* Extract model types better
This commit is contained in:
Oliver
2026-01-08 18:06:23 +11:00
committed by GitHub
parent 2457197446
commit 4709dc8a9a
9 changed files with 125 additions and 32 deletions

View File

@@ -17,6 +17,7 @@ from django_q.models import OrmQ
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
from rest_framework import serializers
from rest_framework.generics import GenericAPIView
from rest_framework.request import clone_request
from rest_framework.response import Response
from rest_framework.serializers import ValidationError
from rest_framework.views import APIView
@@ -31,7 +32,7 @@ from InvenTree.mixins import ListCreateAPI
from InvenTree.sso import sso_registration_enabled
from plugin.serializers import MetadataSerializer
from users.models import ApiToken
from users.permissions import check_user_permission
from users.permissions import check_user_permission, prefetch_rule_sets
from .helpers import plugins_info
from .helpers_email import is_email_configured
@@ -767,6 +768,13 @@ class APISearchView(GenericAPIView):
search_filters = self.get_result_filters()
# Create a clone of the request object to modify
# Use GET method for the individual list views
cloned_request = clone_request(request, 'GET')
# Fetch and cache all groups associated with the current user
groups = prefetch_rule_sets(request.user)
for key, cls in self.get_result_types().items():
# Only return results which are specifically requested
if key in data:
@@ -790,22 +798,23 @@ class APISearchView(GenericAPIView):
view = cls()
# Override regular query params with specific ones for this search request
request._request.GET = params
view.request = request
cloned_request._request.GET = params
view.request = cloned_request
view.format_kwarg = 'format'
# Check permissions and update results dict with particular query
model = view.serializer_class.Meta.model
if not check_user_permission(
request.user, model, 'view', groups=groups
):
results[key] = {
'error': _('User does not have permission to view this model')
}
continue
try:
if check_user_permission(request.user, model, 'view'):
results[key] = view.list(request, *args, **kwargs).data
else:
results[key] = {
'error': _(
'User does not have permission to view this model'
)
}
results[key] = view.list(request, *args, **kwargs).data
except Exception as exc:
results[key] = {'error': str(exc)}

View File

@@ -21,6 +21,7 @@ from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from rest_framework.fields import empty
from rest_framework.mixins import ListModelMixin
from rest_framework.permissions import SAFE_METHODS
from rest_framework.serializers import DecimalField
from rest_framework.utils import model_meta
from taggit.serializers import TaggitSerializer, TagListSerializerField
@@ -229,7 +230,7 @@ class FilterableSerializerMixin:
# Skip filtering for a write requests - all fields should be present for data creation
if request := self.context.get('request', None):
if method := getattr(request, 'method', None):
if str(method).lower() in ['post', 'put', 'patch'] and not is_exporting:
if method not in SAFE_METHODS and not is_exporting:
return
# Throw out fields which are not requested (either by default or explicitly)

View File

@@ -268,11 +268,7 @@ class ManufacturerPartSerializer(
source='part', many=False, read_only=True, allow_null=True
),
True,
prefetch_fields=[
Prefetch(
'part', queryset=part.models.Part.objects.select_related('pricing_data')
)
],
prefetch_fields=['part', 'part__pricing_data', 'part__category'],
)
pretty_name = enable_filter(
@@ -438,7 +434,7 @@ class SupplierPartSerializer(
label=_('Part'), source='part', many=False, read_only=True, allow_null=True
),
False,
prefetch_fields=['part'],
prefetch_fields=['part', 'part__pricing_data'],
)
supplier_detail = enable_filter(

View File

@@ -2,7 +2,14 @@
from django.urls import reverse
from company.models import Address, Company, Contact, SupplierPart, SupplierPriceBreak
from company.models import (
Address,
Company,
Contact,
ManufacturerPart,
SupplierPart,
SupplierPriceBreak,
)
from InvenTree.unit_test import InvenTreeAPITestCase
from part.models import Part
from users.permissions import check_user_permission
@@ -498,7 +505,9 @@ class ManufacturerTest(InvenTreeAPITestCase):
def test_manufacturer_part_detail(self):
"""Tests for the ManufacturerPart detail endpoint."""
url = reverse('api-manufacturer-part-detail', kwargs={'pk': 1})
mp = ManufacturerPart.objects.first()
url = reverse('api-manufacturer-part-detail', kwargs={'pk': mp.pk})
response = self.get(url)
self.assertEqual(response.data['MPN'], 'MPN123')

View File

@@ -362,7 +362,9 @@ class PurchaseOrderOutputOptions(OutputConfiguration):
class PurchaseOrderMixin(SerializerContextMixin):
"""Mixin class for PurchaseOrder endpoints."""
queryset = models.PurchaseOrder.objects.all()
queryset = models.PurchaseOrder.objects.all().prefetch_related(
'supplier', 'created_by'
)
serializer_class = serializers.PurchaseOrderSerializer
def get_queryset(self, *args, **kwargs):
@@ -371,8 +373,6 @@ class PurchaseOrderMixin(SerializerContextMixin):
queryset = serializers.PurchaseOrderSerializer.annotate_queryset(queryset)
queryset = queryset.prefetch_related('supplier', 'created_by')
return queryset
@@ -824,15 +824,15 @@ class SalesOrderFilter(OrderFilter):
class SalesOrderMixin(SerializerContextMixin):
"""Mixin class for SalesOrder endpoints."""
queryset = models.SalesOrder.objects.all()
queryset = models.SalesOrder.objects.all().prefetch_related(
'customer', 'created_by'
)
serializer_class = serializers.SalesOrderSerializer
def get_queryset(self, *args, **kwargs):
"""Return annotated queryset for this endpoint."""
queryset = super().get_queryset(*args, **kwargs)
queryset = queryset.prefetch_related('customer', 'created_by')
queryset = serializers.SalesOrderSerializer.annotate_queryset(queryset)
return queryset

View File

@@ -1009,7 +1009,9 @@ class PartMixin(SerializerContextMixin):
"""Mixin class for Part API endpoints."""
serializer_class = part_serializers.PartSerializer
queryset = Part.objects.all().select_related('pricing_data')
queryset = (
Part.objects.all().select_related('pricing_data').prefetch_related('category')
)
starred_parts = None
is_create = False

View File

@@ -489,14 +489,13 @@ class StockItemSerializer(
),
'parent',
'part__category',
'part__pricing_data',
'supplier_part',
'supplier_part__manufacturer_part',
'customer',
'belongs_to',
'sales_order',
'consumed_by',
).select_related('part')
).select_related('part', 'part__pricing_data')
# Annotate the queryset with the total allocated to sales orders
queryset = queryset.annotate(

View File

@@ -130,7 +130,11 @@ def check_user_role(
def check_user_permission(
user: User, model: models.Model, permission: str, allow_inactive: bool = False
user: User,
model: models.Model,
permission: str,
allow_inactive: bool = False,
groups: Optional[QuerySet] = None,
) -> bool:
"""Check if the user has a particular permission against a given model type.
@@ -139,6 +143,7 @@ def check_user_permission(
model: The model class to check (e.g. 'part')
permission: The permission to check (e.g. 'view' / 'delete')
allow_inactive: If False, disallow inactive users from having permissions
groups: Optional cached queryset of groups to check (defaults to user's groups)
Returns:
bool: True if the user has the specified permission
@@ -160,9 +165,11 @@ def check_user_permission(
if table_name in get_ruleset_ignore():
return True
groups = groups or prefetch_rule_sets(user)
for role, table_names in get_ruleset_models().items():
if table_name in table_names:
if check_user_role(user, role, permission):
if check_user_role(user, role, permission, groups=groups):
return True
# Check for children models which inherits from parent role
@@ -172,7 +179,7 @@ def check_user_permission(
if parent_child_string == table_name:
# Check if parent role has change permission
if check_user_role(user, parent, 'change'):
if check_user_role(user, parent, 'change', groups=groups):
return True
# Generate the permission name based on the model and permission

View File

@@ -88,3 +88,73 @@ def test_api_options_performance(url):
assert result
assert 'actions' in result
assert len(result['actions']) > 0
@pytest.mark.benchmark
@pytest.mark.parametrize(
'key',
[
'all',
'part',
'partcategory',
'supplierpart',
'manufacturerpart',
'stockitem',
'stocklocation',
'build',
'supplier',
'manufacturer',
'customer',
'purchaseorder',
'salesorder',
'salesordershipment',
'returnorder',
],
)
def test_search_performance(key: str):
"""Benchmark the API search performance."""
SEARCH_URL = '/api/search/'
# An indicative search query for various model types
SEARCH_DATA = {
'part': {'active': True},
'partcategory': {},
'supplierpart': {
'part_detail': True,
'supplier_detail': True,
'manufacturer_detail': True,
},
'manufacturerpart': {
'part_detail': True,
'supplier_detail': True,
'manufacturer_detail': True,
},
'stockitem': {'part_detail': True, 'location_detail': True, 'in_stock': True},
'stocklocation': {},
'build': {'part_detail': True},
'supplier': {},
'manufacturer': {},
'customer': {},
'purchaseorder': {'supplier_detail': True, 'outstanding': True},
'salesorder': {'customer_detail': True, 'outstanding': True},
'salesordershipment': {},
'returnorder': {'customer_detail': True, 'outstanding': True},
}
model_types = list(SEARCH_DATA.keys())
search_params = SEARCH_DATA if key == 'all' else {key: SEARCH_DATA[key]}
# Add in a common search term
search_params.update({'search': '0', 'limit': 50})
response = api_client.post(SEARCH_URL, data=search_params)
assert response
if key == 'all':
for model_type in model_types:
assert model_type in response
assert 'error' not in response[model_type]
else:
assert key in response
assert 'error' not in response[key]