mirror of
https://github.com/inventree/InvenTree.git
synced 2025-04-29 20:16:44 +00:00
Part API query tests (#4423)
* Add unit tests for validating number of queries * Simplify category_detail annotation to PartList API endpoint - Previous approach was an old hack from before the n+1 problem was understood
This commit is contained in:
parent
b657fb4405
commit
71db557d3b
@ -1118,6 +1118,7 @@ class PartList(APIDownloadMixin, ListCreateAPI):
|
|||||||
params = self.request.query_params
|
params = self.request.query_params
|
||||||
|
|
||||||
kwargs['parameters'] = str2bool(params.get('parameters', None))
|
kwargs['parameters'] = str2bool(params.get('parameters', None))
|
||||||
|
kwargs['category_detail'] = str2bool(params.get('category_detail', False))
|
||||||
|
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
@ -1156,41 +1157,6 @@ class PartList(APIDownloadMixin, ListCreateAPI):
|
|||||||
|
|
||||||
data = serializer.data
|
data = serializer.data
|
||||||
|
|
||||||
# Do we wish to include PartCategory detail?
|
|
||||||
if str2bool(request.query_params.get('category_detail', False)):
|
|
||||||
|
|
||||||
# Work out which part categories we need to query
|
|
||||||
category_ids = set()
|
|
||||||
|
|
||||||
for part in data:
|
|
||||||
cat_id = part['category']
|
|
||||||
|
|
||||||
if cat_id is not None:
|
|
||||||
category_ids.add(cat_id)
|
|
||||||
|
|
||||||
# Fetch only the required PartCategory objects from the database
|
|
||||||
categories = PartCategory.objects.filter(pk__in=category_ids).prefetch_related(
|
|
||||||
'parts',
|
|
||||||
'parent',
|
|
||||||
'children',
|
|
||||||
)
|
|
||||||
|
|
||||||
category_map = {}
|
|
||||||
|
|
||||||
# Serialize each PartCategory object
|
|
||||||
for category in categories:
|
|
||||||
category_map[category.pk] = part_serializers.CategorySerializer(category).data
|
|
||||||
|
|
||||||
for part in data:
|
|
||||||
cat_id = part['category']
|
|
||||||
|
|
||||||
if cat_id is not None and cat_id in category_map.keys():
|
|
||||||
detail = category_map[cat_id]
|
|
||||||
else:
|
|
||||||
detail = None
|
|
||||||
|
|
||||||
part['category_detail'] = detail
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Determine the response type based on the request.
|
Determine the response type based on the request.
|
||||||
a) For HTTP requests (e.g. via the browseable API) return a DRF response
|
a) For HTTP requests (e.g. via the browseable API) return a DRF response
|
||||||
|
@ -470,22 +470,19 @@ class PartSerializer(RemoteImageMixin, InvenTreeModelSerializer):
|
|||||||
- Allows us to optionally pass extra fields based on the query.
|
- Allows us to optionally pass extra fields based on the query.
|
||||||
"""
|
"""
|
||||||
self.starred_parts = kwargs.pop('starred_parts', [])
|
self.starred_parts = kwargs.pop('starred_parts', [])
|
||||||
|
|
||||||
category_detail = kwargs.pop('category_detail', False)
|
category_detail = kwargs.pop('category_detail', False)
|
||||||
|
|
||||||
parameters = kwargs.pop('parameters', False)
|
parameters = kwargs.pop('parameters', False)
|
||||||
|
|
||||||
create = kwargs.pop('create', False)
|
create = kwargs.pop('create', False)
|
||||||
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
if category_detail is not True:
|
if not category_detail:
|
||||||
self.fields.pop('category_detail')
|
self.fields.pop('category_detail')
|
||||||
|
|
||||||
if parameters is not True:
|
if not parameters:
|
||||||
self.fields.pop('parameters')
|
self.fields.pop('parameters')
|
||||||
|
|
||||||
if create is not True:
|
if not create:
|
||||||
# These fields are only used for the LIST API endpoint
|
# These fields are only used for the LIST API endpoint
|
||||||
for f in self.skip_create_fields()[1:]:
|
for f in self.skip_create_fields()[1:]:
|
||||||
self.fields.pop(f)
|
self.fields.pop(f)
|
||||||
|
@ -5,6 +5,8 @@ from enum import IntEnum
|
|||||||
from random import randint
|
from random import randint
|
||||||
|
|
||||||
from django.core.exceptions import ValidationError
|
from django.core.exceptions import ValidationError
|
||||||
|
from django.db import connection
|
||||||
|
from django.test.utils import CaptureQueriesContext
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
@ -1704,6 +1706,60 @@ class PartDetailTests(PartAPITestBase):
|
|||||||
self.assertEqual(part.metadata['x'], 'y')
|
self.assertEqual(part.metadata['x'], 'y')
|
||||||
|
|
||||||
|
|
||||||
|
class PartListTests(PartAPITestBase):
|
||||||
|
"""Unit tests for the Part List API endpoint"""
|
||||||
|
|
||||||
|
def test_query_count(self):
|
||||||
|
"""Test that the query count is unchanged, independent of query results"""
|
||||||
|
|
||||||
|
queries = [
|
||||||
|
{'limit': 1},
|
||||||
|
{'limit': 10},
|
||||||
|
{'limit': 50},
|
||||||
|
{'category': 1},
|
||||||
|
{},
|
||||||
|
]
|
||||||
|
|
||||||
|
url = reverse('api-part-list')
|
||||||
|
|
||||||
|
# Create a bunch of extra parts (efficiently)
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
for ii in range(100):
|
||||||
|
parts.append(Part(
|
||||||
|
name=f"Extra part {ii}",
|
||||||
|
description="A new part which will appear via the API",
|
||||||
|
level=0, tree_id=0,
|
||||||
|
lft=0, rght=0,
|
||||||
|
))
|
||||||
|
|
||||||
|
Part.objects.bulk_create(parts)
|
||||||
|
|
||||||
|
for query in queries:
|
||||||
|
|
||||||
|
with CaptureQueriesContext(connection) as ctx:
|
||||||
|
self.get(url, query, expected_code=200)
|
||||||
|
|
||||||
|
# No more than 20 database queries
|
||||||
|
self.assertLess(len(ctx), 20)
|
||||||
|
|
||||||
|
# Test 'category_detail' annotation
|
||||||
|
for b in [False, True]:
|
||||||
|
with CaptureQueriesContext(connection) as ctx:
|
||||||
|
results = self.get(
|
||||||
|
reverse('api-part-list'),
|
||||||
|
{'category_detail': b},
|
||||||
|
expected_code=200
|
||||||
|
)
|
||||||
|
|
||||||
|
for result in results.data:
|
||||||
|
if b and result['category'] is not None:
|
||||||
|
self.assertIn('category_detail', result)
|
||||||
|
|
||||||
|
# No more than 20 DB queries
|
||||||
|
self.assertLessEqual(len(ctx), 20)
|
||||||
|
|
||||||
|
|
||||||
class PartNotesTests(InvenTreeAPITestCase):
|
class PartNotesTests(InvenTreeAPITestCase):
|
||||||
"""Tests for the 'notes' field (markdown field)"""
|
"""Tests for the 'notes' field (markdown field)"""
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user