2
0
mirror of https://github.com/inventree/InvenTree.git synced 2025-04-29 20:16:44 +00:00

Part API query tests (#4423)

* Add unit tests for validating number of queries

* Simplify category_detail annotation to PartList API endpoint

- Previous approach was an old hack from before the n+1 problem was understood
This commit is contained in:
Oliver 2023-02-26 23:33:23 +11:00 committed by GitHub
parent b657fb4405
commit 71db557d3b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 60 additions and 41 deletions

View File

@ -1118,6 +1118,7 @@ class PartList(APIDownloadMixin, ListCreateAPI):
params = self.request.query_params params = self.request.query_params
kwargs['parameters'] = str2bool(params.get('parameters', None)) kwargs['parameters'] = str2bool(params.get('parameters', None))
kwargs['category_detail'] = str2bool(params.get('category_detail', False))
except AttributeError: except AttributeError:
pass pass
@ -1156,41 +1157,6 @@ class PartList(APIDownloadMixin, ListCreateAPI):
data = serializer.data data = serializer.data
# Do we wish to include PartCategory detail?
if str2bool(request.query_params.get('category_detail', False)):
# Work out which part categories we need to query
category_ids = set()
for part in data:
cat_id = part['category']
if cat_id is not None:
category_ids.add(cat_id)
# Fetch only the required PartCategory objects from the database
categories = PartCategory.objects.filter(pk__in=category_ids).prefetch_related(
'parts',
'parent',
'children',
)
category_map = {}
# Serialize each PartCategory object
for category in categories:
category_map[category.pk] = part_serializers.CategorySerializer(category).data
for part in data:
cat_id = part['category']
if cat_id is not None and cat_id in category_map.keys():
detail = category_map[cat_id]
else:
detail = None
part['category_detail'] = detail
""" """
Determine the response type based on the request. Determine the response type based on the request.
a) For HTTP requests (e.g. via the browseable API) return a DRF response a) For HTTP requests (e.g. via the browseable API) return a DRF response

View File

@ -470,22 +470,19 @@ class PartSerializer(RemoteImageMixin, InvenTreeModelSerializer):
- Allows us to optionally pass extra fields based on the query. - Allows us to optionally pass extra fields based on the query.
""" """
self.starred_parts = kwargs.pop('starred_parts', []) self.starred_parts = kwargs.pop('starred_parts', [])
category_detail = kwargs.pop('category_detail', False) category_detail = kwargs.pop('category_detail', False)
parameters = kwargs.pop('parameters', False) parameters = kwargs.pop('parameters', False)
create = kwargs.pop('create', False) create = kwargs.pop('create', False)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if category_detail is not True: if not category_detail:
self.fields.pop('category_detail') self.fields.pop('category_detail')
if parameters is not True: if not parameters:
self.fields.pop('parameters') self.fields.pop('parameters')
if create is not True: if not create:
# These fields are only used for the LIST API endpoint # These fields are only used for the LIST API endpoint
for f in self.skip_create_fields()[1:]: for f in self.skip_create_fields()[1:]:
self.fields.pop(f) self.fields.pop(f)

View File

@ -5,6 +5,8 @@ from enum import IntEnum
from random import randint from random import randint
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db import connection
from django.test.utils import CaptureQueriesContext
from django.urls import reverse from django.urls import reverse
import PIL import PIL
@ -1704,6 +1706,60 @@ class PartDetailTests(PartAPITestBase):
self.assertEqual(part.metadata['x'], 'y') self.assertEqual(part.metadata['x'], 'y')
class PartListTests(PartAPITestBase):
"""Unit tests for the Part List API endpoint"""
def test_query_count(self):
"""Test that the query count is unchanged, independent of query results"""
queries = [
{'limit': 1},
{'limit': 10},
{'limit': 50},
{'category': 1},
{},
]
url = reverse('api-part-list')
# Create a bunch of extra parts (efficiently)
parts = []
for ii in range(100):
parts.append(Part(
name=f"Extra part {ii}",
description="A new part which will appear via the API",
level=0, tree_id=0,
lft=0, rght=0,
))
Part.objects.bulk_create(parts)
for query in queries:
with CaptureQueriesContext(connection) as ctx:
self.get(url, query, expected_code=200)
# No more than 20 database queries
self.assertLess(len(ctx), 20)
# Test 'category_detail' annotation
for b in [False, True]:
with CaptureQueriesContext(connection) as ctx:
results = self.get(
reverse('api-part-list'),
{'category_detail': b},
expected_code=200
)
for result in results.data:
if b and result['category'] is not None:
self.assertIn('category_detail', result)
# No more than 20 DB queries
self.assertLessEqual(len(ctx), 20)
class PartNotesTests(InvenTreeAPITestCase): class PartNotesTests(InvenTreeAPITestCase):
"""Tests for the 'notes' field (markdown field)""" """Tests for the 'notes' field (markdown field)"""