mirror of
https://github.com/inventree/InvenTree.git
synced 2026-01-10 05:08:09 +00:00
[API] Search improvements (#11094)
* Improve prefetching * Cache user groups for permission check * Use a GET request to execute search - Prevent forced prefetch - Reduce execution time significantly * Fix group caching * Improve StockItemSerializer - Select related for pricing_data rather than prefetch * Add benchmarking for search endpoint * Adjust prefetch * Ensure no errors returned * Fix prefetch * Fix more prefetch issues * Remove debug print * Fix for performance testing * Data is already returned as dict * Test fix * Extract model types better
This commit is contained in:
@@ -17,6 +17,7 @@ from django_q.models import OrmQ
|
|||||||
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
from rest_framework.generics import GenericAPIView
|
from rest_framework.generics import GenericAPIView
|
||||||
|
from rest_framework.request import clone_request
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.serializers import ValidationError
|
from rest_framework.serializers import ValidationError
|
||||||
from rest_framework.views import APIView
|
from rest_framework.views import APIView
|
||||||
@@ -31,7 +32,7 @@ from InvenTree.mixins import ListCreateAPI
|
|||||||
from InvenTree.sso import sso_registration_enabled
|
from InvenTree.sso import sso_registration_enabled
|
||||||
from plugin.serializers import MetadataSerializer
|
from plugin.serializers import MetadataSerializer
|
||||||
from users.models import ApiToken
|
from users.models import ApiToken
|
||||||
from users.permissions import check_user_permission
|
from users.permissions import check_user_permission, prefetch_rule_sets
|
||||||
|
|
||||||
from .helpers import plugins_info
|
from .helpers import plugins_info
|
||||||
from .helpers_email import is_email_configured
|
from .helpers_email import is_email_configured
|
||||||
@@ -767,6 +768,13 @@ class APISearchView(GenericAPIView):
|
|||||||
|
|
||||||
search_filters = self.get_result_filters()
|
search_filters = self.get_result_filters()
|
||||||
|
|
||||||
|
# Create a clone of the request object to modify
|
||||||
|
# Use GET method for the individual list views
|
||||||
|
cloned_request = clone_request(request, 'GET')
|
||||||
|
|
||||||
|
# Fetch and cache all groups associated with the current user
|
||||||
|
groups = prefetch_rule_sets(request.user)
|
||||||
|
|
||||||
for key, cls in self.get_result_types().items():
|
for key, cls in self.get_result_types().items():
|
||||||
# Only return results which are specifically requested
|
# Only return results which are specifically requested
|
||||||
if key in data:
|
if key in data:
|
||||||
@@ -790,22 +798,23 @@ class APISearchView(GenericAPIView):
|
|||||||
view = cls()
|
view = cls()
|
||||||
|
|
||||||
# Override regular query params with specific ones for this search request
|
# Override regular query params with specific ones for this search request
|
||||||
request._request.GET = params
|
cloned_request._request.GET = params
|
||||||
view.request = request
|
view.request = cloned_request
|
||||||
view.format_kwarg = 'format'
|
view.format_kwarg = 'format'
|
||||||
|
|
||||||
# Check permissions and update results dict with particular query
|
# Check permissions and update results dict with particular query
|
||||||
model = view.serializer_class.Meta.model
|
model = view.serializer_class.Meta.model
|
||||||
|
|
||||||
|
if not check_user_permission(
|
||||||
|
request.user, model, 'view', groups=groups
|
||||||
|
):
|
||||||
|
results[key] = {
|
||||||
|
'error': _('User does not have permission to view this model')
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if check_user_permission(request.user, model, 'view'):
|
results[key] = view.list(request, *args, **kwargs).data
|
||||||
results[key] = view.list(request, *args, **kwargs).data
|
|
||||||
else:
|
|
||||||
results[key] = {
|
|
||||||
'error': _(
|
|
||||||
'User does not have permission to view this model'
|
|
||||||
)
|
|
||||||
}
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
results[key] = {'error': str(exc)}
|
results[key] = {'error': str(exc)}
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from rest_framework import serializers
|
|||||||
from rest_framework.exceptions import ValidationError
|
from rest_framework.exceptions import ValidationError
|
||||||
from rest_framework.fields import empty
|
from rest_framework.fields import empty
|
||||||
from rest_framework.mixins import ListModelMixin
|
from rest_framework.mixins import ListModelMixin
|
||||||
|
from rest_framework.permissions import SAFE_METHODS
|
||||||
from rest_framework.serializers import DecimalField
|
from rest_framework.serializers import DecimalField
|
||||||
from rest_framework.utils import model_meta
|
from rest_framework.utils import model_meta
|
||||||
from taggit.serializers import TaggitSerializer, TagListSerializerField
|
from taggit.serializers import TaggitSerializer, TagListSerializerField
|
||||||
@@ -229,7 +230,7 @@ class FilterableSerializerMixin:
|
|||||||
# Skip filtering for a write requests - all fields should be present for data creation
|
# Skip filtering for a write requests - all fields should be present for data creation
|
||||||
if request := self.context.get('request', None):
|
if request := self.context.get('request', None):
|
||||||
if method := getattr(request, 'method', None):
|
if method := getattr(request, 'method', None):
|
||||||
if str(method).lower() in ['post', 'put', 'patch'] and not is_exporting:
|
if method not in SAFE_METHODS and not is_exporting:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Throw out fields which are not requested (either by default or explicitly)
|
# Throw out fields which are not requested (either by default or explicitly)
|
||||||
|
|||||||
@@ -268,11 +268,7 @@ class ManufacturerPartSerializer(
|
|||||||
source='part', many=False, read_only=True, allow_null=True
|
source='part', many=False, read_only=True, allow_null=True
|
||||||
),
|
),
|
||||||
True,
|
True,
|
||||||
prefetch_fields=[
|
prefetch_fields=['part', 'part__pricing_data', 'part__category'],
|
||||||
Prefetch(
|
|
||||||
'part', queryset=part.models.Part.objects.select_related('pricing_data')
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
pretty_name = enable_filter(
|
pretty_name = enable_filter(
|
||||||
@@ -438,7 +434,7 @@ class SupplierPartSerializer(
|
|||||||
label=_('Part'), source='part', many=False, read_only=True, allow_null=True
|
label=_('Part'), source='part', many=False, read_only=True, allow_null=True
|
||||||
),
|
),
|
||||||
False,
|
False,
|
||||||
prefetch_fields=['part'],
|
prefetch_fields=['part', 'part__pricing_data'],
|
||||||
)
|
)
|
||||||
|
|
||||||
supplier_detail = enable_filter(
|
supplier_detail = enable_filter(
|
||||||
|
|||||||
@@ -2,7 +2,14 @@
|
|||||||
|
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
|
|
||||||
from company.models import Address, Company, Contact, SupplierPart, SupplierPriceBreak
|
from company.models import (
|
||||||
|
Address,
|
||||||
|
Company,
|
||||||
|
Contact,
|
||||||
|
ManufacturerPart,
|
||||||
|
SupplierPart,
|
||||||
|
SupplierPriceBreak,
|
||||||
|
)
|
||||||
from InvenTree.unit_test import InvenTreeAPITestCase
|
from InvenTree.unit_test import InvenTreeAPITestCase
|
||||||
from part.models import Part
|
from part.models import Part
|
||||||
from users.permissions import check_user_permission
|
from users.permissions import check_user_permission
|
||||||
@@ -498,7 +505,9 @@ class ManufacturerTest(InvenTreeAPITestCase):
|
|||||||
|
|
||||||
def test_manufacturer_part_detail(self):
|
def test_manufacturer_part_detail(self):
|
||||||
"""Tests for the ManufacturerPart detail endpoint."""
|
"""Tests for the ManufacturerPart detail endpoint."""
|
||||||
url = reverse('api-manufacturer-part-detail', kwargs={'pk': 1})
|
mp = ManufacturerPart.objects.first()
|
||||||
|
|
||||||
|
url = reverse('api-manufacturer-part-detail', kwargs={'pk': mp.pk})
|
||||||
|
|
||||||
response = self.get(url)
|
response = self.get(url)
|
||||||
self.assertEqual(response.data['MPN'], 'MPN123')
|
self.assertEqual(response.data['MPN'], 'MPN123')
|
||||||
|
|||||||
@@ -362,7 +362,9 @@ class PurchaseOrderOutputOptions(OutputConfiguration):
|
|||||||
class PurchaseOrderMixin(SerializerContextMixin):
|
class PurchaseOrderMixin(SerializerContextMixin):
|
||||||
"""Mixin class for PurchaseOrder endpoints."""
|
"""Mixin class for PurchaseOrder endpoints."""
|
||||||
|
|
||||||
queryset = models.PurchaseOrder.objects.all()
|
queryset = models.PurchaseOrder.objects.all().prefetch_related(
|
||||||
|
'supplier', 'created_by'
|
||||||
|
)
|
||||||
serializer_class = serializers.PurchaseOrderSerializer
|
serializer_class = serializers.PurchaseOrderSerializer
|
||||||
|
|
||||||
def get_queryset(self, *args, **kwargs):
|
def get_queryset(self, *args, **kwargs):
|
||||||
@@ -371,8 +373,6 @@ class PurchaseOrderMixin(SerializerContextMixin):
|
|||||||
|
|
||||||
queryset = serializers.PurchaseOrderSerializer.annotate_queryset(queryset)
|
queryset = serializers.PurchaseOrderSerializer.annotate_queryset(queryset)
|
||||||
|
|
||||||
queryset = queryset.prefetch_related('supplier', 'created_by')
|
|
||||||
|
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
|
|
||||||
@@ -824,15 +824,15 @@ class SalesOrderFilter(OrderFilter):
|
|||||||
class SalesOrderMixin(SerializerContextMixin):
|
class SalesOrderMixin(SerializerContextMixin):
|
||||||
"""Mixin class for SalesOrder endpoints."""
|
"""Mixin class for SalesOrder endpoints."""
|
||||||
|
|
||||||
queryset = models.SalesOrder.objects.all()
|
queryset = models.SalesOrder.objects.all().prefetch_related(
|
||||||
|
'customer', 'created_by'
|
||||||
|
)
|
||||||
serializer_class = serializers.SalesOrderSerializer
|
serializer_class = serializers.SalesOrderSerializer
|
||||||
|
|
||||||
def get_queryset(self, *args, **kwargs):
|
def get_queryset(self, *args, **kwargs):
|
||||||
"""Return annotated queryset for this endpoint."""
|
"""Return annotated queryset for this endpoint."""
|
||||||
queryset = super().get_queryset(*args, **kwargs)
|
queryset = super().get_queryset(*args, **kwargs)
|
||||||
|
|
||||||
queryset = queryset.prefetch_related('customer', 'created_by')
|
|
||||||
|
|
||||||
queryset = serializers.SalesOrderSerializer.annotate_queryset(queryset)
|
queryset = serializers.SalesOrderSerializer.annotate_queryset(queryset)
|
||||||
|
|
||||||
return queryset
|
return queryset
|
||||||
|
|||||||
@@ -1009,7 +1009,9 @@ class PartMixin(SerializerContextMixin):
|
|||||||
"""Mixin class for Part API endpoints."""
|
"""Mixin class for Part API endpoints."""
|
||||||
|
|
||||||
serializer_class = part_serializers.PartSerializer
|
serializer_class = part_serializers.PartSerializer
|
||||||
queryset = Part.objects.all().select_related('pricing_data')
|
queryset = (
|
||||||
|
Part.objects.all().select_related('pricing_data').prefetch_related('category')
|
||||||
|
)
|
||||||
|
|
||||||
starred_parts = None
|
starred_parts = None
|
||||||
is_create = False
|
is_create = False
|
||||||
|
|||||||
@@ -489,14 +489,13 @@ class StockItemSerializer(
|
|||||||
),
|
),
|
||||||
'parent',
|
'parent',
|
||||||
'part__category',
|
'part__category',
|
||||||
'part__pricing_data',
|
|
||||||
'supplier_part',
|
'supplier_part',
|
||||||
'supplier_part__manufacturer_part',
|
'supplier_part__manufacturer_part',
|
||||||
'customer',
|
'customer',
|
||||||
'belongs_to',
|
'belongs_to',
|
||||||
'sales_order',
|
'sales_order',
|
||||||
'consumed_by',
|
'consumed_by',
|
||||||
).select_related('part')
|
).select_related('part', 'part__pricing_data')
|
||||||
|
|
||||||
# Annotate the queryset with the total allocated to sales orders
|
# Annotate the queryset with the total allocated to sales orders
|
||||||
queryset = queryset.annotate(
|
queryset = queryset.annotate(
|
||||||
|
|||||||
@@ -130,7 +130,11 @@ def check_user_role(
|
|||||||
|
|
||||||
|
|
||||||
def check_user_permission(
|
def check_user_permission(
|
||||||
user: User, model: models.Model, permission: str, allow_inactive: bool = False
|
user: User,
|
||||||
|
model: models.Model,
|
||||||
|
permission: str,
|
||||||
|
allow_inactive: bool = False,
|
||||||
|
groups: Optional[QuerySet] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check if the user has a particular permission against a given model type.
|
"""Check if the user has a particular permission against a given model type.
|
||||||
|
|
||||||
@@ -139,6 +143,7 @@ def check_user_permission(
|
|||||||
model: The model class to check (e.g. 'part')
|
model: The model class to check (e.g. 'part')
|
||||||
permission: The permission to check (e.g. 'view' / 'delete')
|
permission: The permission to check (e.g. 'view' / 'delete')
|
||||||
allow_inactive: If False, disallow inactive users from having permissions
|
allow_inactive: If False, disallow inactive users from having permissions
|
||||||
|
groups: Optional cached queryset of groups to check (defaults to user's groups)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the user has the specified permission
|
bool: True if the user has the specified permission
|
||||||
@@ -160,9 +165,11 @@ def check_user_permission(
|
|||||||
if table_name in get_ruleset_ignore():
|
if table_name in get_ruleset_ignore():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
groups = groups or prefetch_rule_sets(user)
|
||||||
|
|
||||||
for role, table_names in get_ruleset_models().items():
|
for role, table_names in get_ruleset_models().items():
|
||||||
if table_name in table_names:
|
if table_name in table_names:
|
||||||
if check_user_role(user, role, permission):
|
if check_user_role(user, role, permission, groups=groups):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Check for children models which inherits from parent role
|
# Check for children models which inherits from parent role
|
||||||
@@ -172,7 +179,7 @@ def check_user_permission(
|
|||||||
|
|
||||||
if parent_child_string == table_name:
|
if parent_child_string == table_name:
|
||||||
# Check if parent role has change permission
|
# Check if parent role has change permission
|
||||||
if check_user_role(user, parent, 'change'):
|
if check_user_role(user, parent, 'change', groups=groups):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Generate the permission name based on the model and permission
|
# Generate the permission name based on the model and permission
|
||||||
|
|||||||
@@ -88,3 +88,73 @@ def test_api_options_performance(url):
|
|||||||
assert result
|
assert result
|
||||||
assert 'actions' in result
|
assert 'actions' in result
|
||||||
assert len(result['actions']) > 0
|
assert len(result['actions']) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.benchmark
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'key',
|
||||||
|
[
|
||||||
|
'all',
|
||||||
|
'part',
|
||||||
|
'partcategory',
|
||||||
|
'supplierpart',
|
||||||
|
'manufacturerpart',
|
||||||
|
'stockitem',
|
||||||
|
'stocklocation',
|
||||||
|
'build',
|
||||||
|
'supplier',
|
||||||
|
'manufacturer',
|
||||||
|
'customer',
|
||||||
|
'purchaseorder',
|
||||||
|
'salesorder',
|
||||||
|
'salesordershipment',
|
||||||
|
'returnorder',
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_search_performance(key: str):
|
||||||
|
"""Benchmark the API search performance."""
|
||||||
|
SEARCH_URL = '/api/search/'
|
||||||
|
|
||||||
|
# An indicative search query for various model types
|
||||||
|
SEARCH_DATA = {
|
||||||
|
'part': {'active': True},
|
||||||
|
'partcategory': {},
|
||||||
|
'supplierpart': {
|
||||||
|
'part_detail': True,
|
||||||
|
'supplier_detail': True,
|
||||||
|
'manufacturer_detail': True,
|
||||||
|
},
|
||||||
|
'manufacturerpart': {
|
||||||
|
'part_detail': True,
|
||||||
|
'supplier_detail': True,
|
||||||
|
'manufacturer_detail': True,
|
||||||
|
},
|
||||||
|
'stockitem': {'part_detail': True, 'location_detail': True, 'in_stock': True},
|
||||||
|
'stocklocation': {},
|
||||||
|
'build': {'part_detail': True},
|
||||||
|
'supplier': {},
|
||||||
|
'manufacturer': {},
|
||||||
|
'customer': {},
|
||||||
|
'purchaseorder': {'supplier_detail': True, 'outstanding': True},
|
||||||
|
'salesorder': {'customer_detail': True, 'outstanding': True},
|
||||||
|
'salesordershipment': {},
|
||||||
|
'returnorder': {'customer_detail': True, 'outstanding': True},
|
||||||
|
}
|
||||||
|
|
||||||
|
model_types = list(SEARCH_DATA.keys())
|
||||||
|
|
||||||
|
search_params = SEARCH_DATA if key == 'all' else {key: SEARCH_DATA[key]}
|
||||||
|
|
||||||
|
# Add in a common search term
|
||||||
|
search_params.update({'search': '0', 'limit': 50})
|
||||||
|
|
||||||
|
response = api_client.post(SEARCH_URL, data=search_params)
|
||||||
|
assert response
|
||||||
|
|
||||||
|
if key == 'all':
|
||||||
|
for model_type in model_types:
|
||||||
|
assert model_type in response
|
||||||
|
assert 'error' not in response[model_type]
|
||||||
|
else:
|
||||||
|
assert key in response
|
||||||
|
assert 'error' not in response[key]
|
||||||
|
|||||||
Reference in New Issue
Block a user