2
0
mirror of https://github.com/inventree/InvenTree.git synced 2025-04-29 20:16:44 +00:00

Improve StockItem API speed (#5042)

- Removes child detail fields which cannot be effectively annotated
- Prefetch required fields
- Add unit test method for checking query count
This commit is contained in:
Oliver 2023-06-14 18:33:49 +10:00 committed by GitHub
parent 8d16abcefb
commit be6ab14c9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 73 additions and 11 deletions

View File

@ -2,13 +2,17 @@
import csv import csv
import io import io
import json
import re import re
from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group, Permission from django.contrib.auth.models import Group, Permission
from django.db import connections
from django.http.response import StreamingHttpResponse from django.http.response import StreamingHttpResponse
from django.test import TestCase from django.test import TestCase
from django.test.utils import CaptureQueriesContext
from djmoney.contrib.exchange.models import ExchangeBackend, Rate from djmoney.contrib.exchange.models import ExchangeBackend, Rate
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
@ -241,6 +245,30 @@ class InvenTreeTestCase(ExchangeRateMixin, UserMixin, TestCase):
class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase): class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
"""Base class for running InvenTree API tests.""" """Base class for running InvenTree API tests."""
@contextmanager
def assertNumQueriesLessThan(self, value, using='default', verbose=False, debug=False):
"""Context manager to check that the number of queries is less than a certain value.
Example:
with self.assertNumQueriesLessThan(10):
# Do some stuff
Ref: https://stackoverflow.com/questions/1254170/django-is-there-a-way-to-count-sql-queries-from-an-unit-test/59089020#59089020
"""
with CaptureQueriesContext(connections[using]) as context:
yield # your test will be run here
if verbose:
msg = "\r\n%s" % json.dumps(context.captured_queries, indent=4)
else:
msg = None
n = len(context.captured_queries)
if debug:
print(f"Expected less than {value} queries, got {n} queries")
self.assertLess(n, value, msg=msg)
def checkResponse(self, url, method, expected_code, response): def checkResponse(self, url, method, expected_code, response):
"""Debug output for an unexpected response""" """Debug output for an unexpected response"""

View File

@ -334,7 +334,7 @@ class SupplierPartSerializer(InvenTreeTagModelSerializer):
MPN = serializers.CharField(read_only=True) MPN = serializers.CharField(read_only=True)
manufacturer_part_detail = ManufacturerPartSerializer(source='manufacturer_part', read_only=True) manufacturer_part_detail = ManufacturerPartSerializer(source='manufacturer_part', part_detail=False, read_only=True)
url = serializers.CharField(source='get_absolute_url', read_only=True) url = serializers.CharField(source='get_absolute_url', read_only=True)

View File

@ -795,15 +795,6 @@ class StockList(APIDownloadMixin, ListCreateDestroyAPIView):
queryset = StockSerializers.StockItemSerializer.annotate_queryset(queryset) queryset = StockSerializers.StockItemSerializer.annotate_queryset(queryset)
# Also ensure that we pre-fecth all the related items
queryset = queryset.prefetch_related(
'part',
'part__category',
'location',
'test_results',
'tags',
)
return queryset return queryset
def filter_queryset(self, queryset): def filter_queryset(self, queryset):

View File

@ -231,10 +231,17 @@ class StockItemSerializer(InvenTree.serializers.InvenTreeTagModelSerializer):
"""Add some extra annotations to the queryset, performing database queries as efficiently as possible.""" """Add some extra annotations to the queryset, performing database queries as efficiently as possible."""
queryset = queryset.prefetch_related( queryset = queryset.prefetch_related(
'location',
'sales_order', 'sales_order',
'purchase_order', 'purchase_order',
'part', 'part',
'part__category',
'part__pricing_data', 'part__pricing_data',
'supplier_part',
'supplier_part__manufacturer_part',
'supplier_part__tags',
'test_results',
'tags',
) )
# Annotate the queryset with the total allocated to sales orders # Annotate the queryset with the total allocated to sales orders
@ -280,7 +287,7 @@ class StockItemSerializer(InvenTree.serializers.InvenTreeTagModelSerializer):
status_text = serializers.CharField(source='get_status_display', read_only=True) status_text = serializers.CharField(source='get_status_display', read_only=True)
# Optional detail fields, which can be appended via query parameters # Optional detail fields, which can be appended via query parameters
supplier_part_detail = SupplierPartSerializer(source='supplier_part', many=False, read_only=True) supplier_part_detail = SupplierPartSerializer(source='supplier_part', supplier_detail=False, manufacturer_detail=False, part_detail=False, many=False, read_only=True)
part_detail = PartBriefSerializer(source='part', many=False, read_only=True) part_detail = PartBriefSerializer(source='part', many=False, read_only=True)
location_detail = LocationBriefSerializer(source='location', many=False, read_only=True) location_detail = LocationBriefSerializer(source='location', many=False, read_only=True)
tests = StockItemTestResultSerializer(source='test_results', many=True, read_only=True) tests = StockItemTestResultSerializer(source='test_results', many=True, read_only=True)

View File

@ -557,6 +557,42 @@ class StockItemListTest(StockAPITestCase):
self.assertEqual(len(dataset), 17) self.assertEqual(len(dataset), 17)
def test_query_count(self):
"""Test that the number of queries required to fetch stock items is reasonable."""
def get_stock(data):
"""Helper function to fetch stock items."""
response = self.client.get(self.list_url, data=data)
self.assertEqual(response.status_code, 200)
return response.data
# Create a bunch of StockItem objects
prt = Part.objects.first()
StockItem.objects.bulk_create([
StockItem(
part=prt,
quantity=1,
level=0, tree_id=0, lft=0, rght=0,
) for _ in range(100)
])
# List *all* stock items
with self.assertNumQueriesLessThan(25):
get_stock({})
# List all stock items, with part detail
with self.assertNumQueriesLessThan(20):
get_stock({'part_detail': True})
# List all stock items, with supplier_part detail
with self.assertNumQueriesLessThan(20):
get_stock({'supplier_part_detail': True})
# List all stock items, with 'location' and 'tests' detail
with self.assertNumQueriesLessThan(20):
get_stock({'location_detail': True, 'tests': True})
class StockItemTest(StockAPITestCase): class StockItemTest(StockAPITestCase):
"""Series of API tests for the StockItem API.""" """Series of API tests for the StockItem API."""