mirror of
https://github.com/inventree/InvenTree.git
synced 2026-01-10 05:08:09 +00:00
[API] Search improvements (#11094)
* Improve prefetching * Cache user groups for permission check * Use a GET request to execute search - Prevent forced prefetch - Reduce execution time significantly * Fix group caching * Improve StockItemSerializer - Select related for pricing_data rather than prefetch * Add benchmarking for search endpoint * Adjust prefetch * Ensure no errors returned * Fix prefetch * Fix more prefetch issues * Remove debug print * Fix for performance testing * Data is already returned as dict * Test fix * Extract model types better
This commit is contained in:
@@ -17,6 +17,7 @@ from django_q.models import OrmQ
|
||||
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
||||
from rest_framework import serializers
|
||||
from rest_framework.generics import GenericAPIView
|
||||
from rest_framework.request import clone_request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import ValidationError
|
||||
from rest_framework.views import APIView
|
||||
@@ -31,7 +32,7 @@ from InvenTree.mixins import ListCreateAPI
|
||||
from InvenTree.sso import sso_registration_enabled
|
||||
from plugin.serializers import MetadataSerializer
|
||||
from users.models import ApiToken
|
||||
from users.permissions import check_user_permission
|
||||
from users.permissions import check_user_permission, prefetch_rule_sets
|
||||
|
||||
from .helpers import plugins_info
|
||||
from .helpers_email import is_email_configured
|
||||
@@ -767,6 +768,13 @@ class APISearchView(GenericAPIView):
|
||||
|
||||
search_filters = self.get_result_filters()
|
||||
|
||||
# Create a clone of the request object to modify
|
||||
# Use GET method for the individual list views
|
||||
cloned_request = clone_request(request, 'GET')
|
||||
|
||||
# Fetch and cache all groups associated with the current user
|
||||
groups = prefetch_rule_sets(request.user)
|
||||
|
||||
for key, cls in self.get_result_types().items():
|
||||
# Only return results which are specifically requested
|
||||
if key in data:
|
||||
@@ -790,22 +798,23 @@ class APISearchView(GenericAPIView):
|
||||
view = cls()
|
||||
|
||||
# Override regular query params with specific ones for this search request
|
||||
request._request.GET = params
|
||||
view.request = request
|
||||
cloned_request._request.GET = params
|
||||
view.request = cloned_request
|
||||
view.format_kwarg = 'format'
|
||||
|
||||
# Check permissions and update results dict with particular query
|
||||
model = view.serializer_class.Meta.model
|
||||
|
||||
try:
|
||||
if check_user_permission(request.user, model, 'view'):
|
||||
results[key] = view.list(request, *args, **kwargs).data
|
||||
else:
|
||||
if not check_user_permission(
|
||||
request.user, model, 'view', groups=groups
|
||||
):
|
||||
results[key] = {
|
||||
'error': _(
|
||||
'User does not have permission to view this model'
|
||||
)
|
||||
'error': _('User does not have permission to view this model')
|
||||
}
|
||||
continue
|
||||
|
||||
try:
|
||||
results[key] = view.list(request, *args, **kwargs).data
|
||||
except Exception as exc:
|
||||
results[key] = {'error': str(exc)}
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from rest_framework import serializers
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import empty
|
||||
from rest_framework.mixins import ListModelMixin
|
||||
from rest_framework.permissions import SAFE_METHODS
|
||||
from rest_framework.serializers import DecimalField
|
||||
from rest_framework.utils import model_meta
|
||||
from taggit.serializers import TaggitSerializer, TagListSerializerField
|
||||
@@ -229,7 +230,7 @@ class FilterableSerializerMixin:
|
||||
# Skip filtering for a write requests - all fields should be present for data creation
|
||||
if request := self.context.get('request', None):
|
||||
if method := getattr(request, 'method', None):
|
||||
if str(method).lower() in ['post', 'put', 'patch'] and not is_exporting:
|
||||
if method not in SAFE_METHODS and not is_exporting:
|
||||
return
|
||||
|
||||
# Throw out fields which are not requested (either by default or explicitly)
|
||||
|
||||
@@ -268,11 +268,7 @@ class ManufacturerPartSerializer(
|
||||
source='part', many=False, read_only=True, allow_null=True
|
||||
),
|
||||
True,
|
||||
prefetch_fields=[
|
||||
Prefetch(
|
||||
'part', queryset=part.models.Part.objects.select_related('pricing_data')
|
||||
)
|
||||
],
|
||||
prefetch_fields=['part', 'part__pricing_data', 'part__category'],
|
||||
)
|
||||
|
||||
pretty_name = enable_filter(
|
||||
@@ -438,7 +434,7 @@ class SupplierPartSerializer(
|
||||
label=_('Part'), source='part', many=False, read_only=True, allow_null=True
|
||||
),
|
||||
False,
|
||||
prefetch_fields=['part'],
|
||||
prefetch_fields=['part', 'part__pricing_data'],
|
||||
)
|
||||
|
||||
supplier_detail = enable_filter(
|
||||
|
||||
@@ -2,7 +2,14 @@
|
||||
|
||||
from django.urls import reverse
|
||||
|
||||
from company.models import Address, Company, Contact, SupplierPart, SupplierPriceBreak
|
||||
from company.models import (
|
||||
Address,
|
||||
Company,
|
||||
Contact,
|
||||
ManufacturerPart,
|
||||
SupplierPart,
|
||||
SupplierPriceBreak,
|
||||
)
|
||||
from InvenTree.unit_test import InvenTreeAPITestCase
|
||||
from part.models import Part
|
||||
from users.permissions import check_user_permission
|
||||
@@ -498,7 +505,9 @@ class ManufacturerTest(InvenTreeAPITestCase):
|
||||
|
||||
def test_manufacturer_part_detail(self):
|
||||
"""Tests for the ManufacturerPart detail endpoint."""
|
||||
url = reverse('api-manufacturer-part-detail', kwargs={'pk': 1})
|
||||
mp = ManufacturerPart.objects.first()
|
||||
|
||||
url = reverse('api-manufacturer-part-detail', kwargs={'pk': mp.pk})
|
||||
|
||||
response = self.get(url)
|
||||
self.assertEqual(response.data['MPN'], 'MPN123')
|
||||
|
||||
@@ -362,7 +362,9 @@ class PurchaseOrderOutputOptions(OutputConfiguration):
|
||||
class PurchaseOrderMixin(SerializerContextMixin):
|
||||
"""Mixin class for PurchaseOrder endpoints."""
|
||||
|
||||
queryset = models.PurchaseOrder.objects.all()
|
||||
queryset = models.PurchaseOrder.objects.all().prefetch_related(
|
||||
'supplier', 'created_by'
|
||||
)
|
||||
serializer_class = serializers.PurchaseOrderSerializer
|
||||
|
||||
def get_queryset(self, *args, **kwargs):
|
||||
@@ -371,8 +373,6 @@ class PurchaseOrderMixin(SerializerContextMixin):
|
||||
|
||||
queryset = serializers.PurchaseOrderSerializer.annotate_queryset(queryset)
|
||||
|
||||
queryset = queryset.prefetch_related('supplier', 'created_by')
|
||||
|
||||
return queryset
|
||||
|
||||
|
||||
@@ -824,15 +824,15 @@ class SalesOrderFilter(OrderFilter):
|
||||
class SalesOrderMixin(SerializerContextMixin):
|
||||
"""Mixin class for SalesOrder endpoints."""
|
||||
|
||||
queryset = models.SalesOrder.objects.all()
|
||||
queryset = models.SalesOrder.objects.all().prefetch_related(
|
||||
'customer', 'created_by'
|
||||
)
|
||||
serializer_class = serializers.SalesOrderSerializer
|
||||
|
||||
def get_queryset(self, *args, **kwargs):
|
||||
"""Return annotated queryset for this endpoint."""
|
||||
queryset = super().get_queryset(*args, **kwargs)
|
||||
|
||||
queryset = queryset.prefetch_related('customer', 'created_by')
|
||||
|
||||
queryset = serializers.SalesOrderSerializer.annotate_queryset(queryset)
|
||||
|
||||
return queryset
|
||||
|
||||
@@ -1009,7 +1009,9 @@ class PartMixin(SerializerContextMixin):
|
||||
"""Mixin class for Part API endpoints."""
|
||||
|
||||
serializer_class = part_serializers.PartSerializer
|
||||
queryset = Part.objects.all().select_related('pricing_data')
|
||||
queryset = (
|
||||
Part.objects.all().select_related('pricing_data').prefetch_related('category')
|
||||
)
|
||||
|
||||
starred_parts = None
|
||||
is_create = False
|
||||
|
||||
@@ -489,14 +489,13 @@ class StockItemSerializer(
|
||||
),
|
||||
'parent',
|
||||
'part__category',
|
||||
'part__pricing_data',
|
||||
'supplier_part',
|
||||
'supplier_part__manufacturer_part',
|
||||
'customer',
|
||||
'belongs_to',
|
||||
'sales_order',
|
||||
'consumed_by',
|
||||
).select_related('part')
|
||||
).select_related('part', 'part__pricing_data')
|
||||
|
||||
# Annotate the queryset with the total allocated to sales orders
|
||||
queryset = queryset.annotate(
|
||||
|
||||
@@ -130,7 +130,11 @@ def check_user_role(
|
||||
|
||||
|
||||
def check_user_permission(
|
||||
user: User, model: models.Model, permission: str, allow_inactive: bool = False
|
||||
user: User,
|
||||
model: models.Model,
|
||||
permission: str,
|
||||
allow_inactive: bool = False,
|
||||
groups: Optional[QuerySet] = None,
|
||||
) -> bool:
|
||||
"""Check if the user has a particular permission against a given model type.
|
||||
|
||||
@@ -139,6 +143,7 @@ def check_user_permission(
|
||||
model: The model class to check (e.g. 'part')
|
||||
permission: The permission to check (e.g. 'view' / 'delete')
|
||||
allow_inactive: If False, disallow inactive users from having permissions
|
||||
groups: Optional cached queryset of groups to check (defaults to user's groups)
|
||||
|
||||
Returns:
|
||||
bool: True if the user has the specified permission
|
||||
@@ -160,9 +165,11 @@ def check_user_permission(
|
||||
if table_name in get_ruleset_ignore():
|
||||
return True
|
||||
|
||||
groups = groups or prefetch_rule_sets(user)
|
||||
|
||||
for role, table_names in get_ruleset_models().items():
|
||||
if table_name in table_names:
|
||||
if check_user_role(user, role, permission):
|
||||
if check_user_role(user, role, permission, groups=groups):
|
||||
return True
|
||||
|
||||
# Check for children models which inherits from parent role
|
||||
@@ -172,7 +179,7 @@ def check_user_permission(
|
||||
|
||||
if parent_child_string == table_name:
|
||||
# Check if parent role has change permission
|
||||
if check_user_role(user, parent, 'change'):
|
||||
if check_user_role(user, parent, 'change', groups=groups):
|
||||
return True
|
||||
|
||||
# Generate the permission name based on the model and permission
|
||||
|
||||
@@ -88,3 +88,73 @@ def test_api_options_performance(url):
|
||||
assert result
|
||||
assert 'actions' in result
|
||||
assert len(result['actions']) > 0
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
@pytest.mark.parametrize(
|
||||
'key',
|
||||
[
|
||||
'all',
|
||||
'part',
|
||||
'partcategory',
|
||||
'supplierpart',
|
||||
'manufacturerpart',
|
||||
'stockitem',
|
||||
'stocklocation',
|
||||
'build',
|
||||
'supplier',
|
||||
'manufacturer',
|
||||
'customer',
|
||||
'purchaseorder',
|
||||
'salesorder',
|
||||
'salesordershipment',
|
||||
'returnorder',
|
||||
],
|
||||
)
|
||||
def test_search_performance(key: str):
|
||||
"""Benchmark the API search performance."""
|
||||
SEARCH_URL = '/api/search/'
|
||||
|
||||
# An indicative search query for various model types
|
||||
SEARCH_DATA = {
|
||||
'part': {'active': True},
|
||||
'partcategory': {},
|
||||
'supplierpart': {
|
||||
'part_detail': True,
|
||||
'supplier_detail': True,
|
||||
'manufacturer_detail': True,
|
||||
},
|
||||
'manufacturerpart': {
|
||||
'part_detail': True,
|
||||
'supplier_detail': True,
|
||||
'manufacturer_detail': True,
|
||||
},
|
||||
'stockitem': {'part_detail': True, 'location_detail': True, 'in_stock': True},
|
||||
'stocklocation': {},
|
||||
'build': {'part_detail': True},
|
||||
'supplier': {},
|
||||
'manufacturer': {},
|
||||
'customer': {},
|
||||
'purchaseorder': {'supplier_detail': True, 'outstanding': True},
|
||||
'salesorder': {'customer_detail': True, 'outstanding': True},
|
||||
'salesordershipment': {},
|
||||
'returnorder': {'customer_detail': True, 'outstanding': True},
|
||||
}
|
||||
|
||||
model_types = list(SEARCH_DATA.keys())
|
||||
|
||||
search_params = SEARCH_DATA if key == 'all' else {key: SEARCH_DATA[key]}
|
||||
|
||||
# Add in a common search term
|
||||
search_params.update({'search': '0', 'limit': 50})
|
||||
|
||||
response = api_client.post(SEARCH_URL, data=search_params)
|
||||
assert response
|
||||
|
||||
if key == 'all':
|
||||
for model_type in model_types:
|
||||
assert model_type in response
|
||||
assert 'error' not in response[model_type]
|
||||
else:
|
||||
assert key in response
|
||||
assert 'error' not in response[key]
|
||||
|
||||
Reference in New Issue
Block a user