mirror of
				https://github.com/inventree/InvenTree.git
				synced 2025-11-04 15:15:42 +00:00 
			
		
		
		
	Part API query tests (#4423)
* Add unit tests for validating number of queries * Simplify category_detail annotation to PartList API endpoint - Previous approach was an old hack from before the n+1 problem was understood
This commit is contained in:
		@@ -1118,6 +1118,7 @@ class PartList(APIDownloadMixin, ListCreateAPI):
 | 
			
		||||
            params = self.request.query_params
 | 
			
		||||
 | 
			
		||||
            kwargs['parameters'] = str2bool(params.get('parameters', None))
 | 
			
		||||
            kwargs['category_detail'] = str2bool(params.get('category_detail', False))
 | 
			
		||||
 | 
			
		||||
        except AttributeError:
 | 
			
		||||
            pass
 | 
			
		||||
@@ -1156,41 +1157,6 @@ class PartList(APIDownloadMixin, ListCreateAPI):
 | 
			
		||||
 | 
			
		||||
        data = serializer.data
 | 
			
		||||
 | 
			
		||||
        # Do we wish to include PartCategory detail?
 | 
			
		||||
        if str2bool(request.query_params.get('category_detail', False)):
 | 
			
		||||
 | 
			
		||||
            # Work out which part categories we need to query
 | 
			
		||||
            category_ids = set()
 | 
			
		||||
 | 
			
		||||
            for part in data:
 | 
			
		||||
                cat_id = part['category']
 | 
			
		||||
 | 
			
		||||
                if cat_id is not None:
 | 
			
		||||
                    category_ids.add(cat_id)
 | 
			
		||||
 | 
			
		||||
            # Fetch only the required PartCategory objects from the database
 | 
			
		||||
            categories = PartCategory.objects.filter(pk__in=category_ids).prefetch_related(
 | 
			
		||||
                'parts',
 | 
			
		||||
                'parent',
 | 
			
		||||
                'children',
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            category_map = {}
 | 
			
		||||
 | 
			
		||||
            # Serialize each PartCategory object
 | 
			
		||||
            for category in categories:
 | 
			
		||||
                category_map[category.pk] = part_serializers.CategorySerializer(category).data
 | 
			
		||||
 | 
			
		||||
            for part in data:
 | 
			
		||||
                cat_id = part['category']
 | 
			
		||||
 | 
			
		||||
                if cat_id is not None and cat_id in category_map.keys():
 | 
			
		||||
                    detail = category_map[cat_id]
 | 
			
		||||
                else:
 | 
			
		||||
                    detail = None
 | 
			
		||||
 | 
			
		||||
                part['category_detail'] = detail
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
        Determine the response type based on the request.
 | 
			
		||||
        a) For HTTP requests (e.g. via the browseable API) return a DRF response
 | 
			
		||||
 
 | 
			
		||||
@@ -470,22 +470,19 @@ class PartSerializer(RemoteImageMixin, InvenTreeModelSerializer):
 | 
			
		||||
        - Allows us to optionally pass extra fields based on the query.
 | 
			
		||||
        """
 | 
			
		||||
        self.starred_parts = kwargs.pop('starred_parts', [])
 | 
			
		||||
 | 
			
		||||
        category_detail = kwargs.pop('category_detail', False)
 | 
			
		||||
 | 
			
		||||
        parameters = kwargs.pop('parameters', False)
 | 
			
		||||
 | 
			
		||||
        create = kwargs.pop('create', False)
 | 
			
		||||
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
        if category_detail is not True:
 | 
			
		||||
        if not category_detail:
 | 
			
		||||
            self.fields.pop('category_detail')
 | 
			
		||||
 | 
			
		||||
        if parameters is not True:
 | 
			
		||||
        if not parameters:
 | 
			
		||||
            self.fields.pop('parameters')
 | 
			
		||||
 | 
			
		||||
        if create is not True:
 | 
			
		||||
        if not create:
 | 
			
		||||
            # These fields are only used for the LIST API endpoint
 | 
			
		||||
            for f in self.skip_create_fields()[1:]:
 | 
			
		||||
                self.fields.pop(f)
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,8 @@ from enum import IntEnum
 | 
			
		||||
from random import randint
 | 
			
		||||
 | 
			
		||||
from django.core.exceptions import ValidationError
 | 
			
		||||
from django.db import connection
 | 
			
		||||
from django.test.utils import CaptureQueriesContext
 | 
			
		||||
from django.urls import reverse
 | 
			
		||||
 | 
			
		||||
import PIL
 | 
			
		||||
@@ -1704,6 +1706,60 @@ class PartDetailTests(PartAPITestBase):
 | 
			
		||||
        self.assertEqual(part.metadata['x'], 'y')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PartListTests(PartAPITestBase):
 | 
			
		||||
    """Unit tests for the Part List API endpoint"""
 | 
			
		||||
 | 
			
		||||
    def test_query_count(self):
 | 
			
		||||
        """Test that the query count is unchanged, independent of query results"""
 | 
			
		||||
 | 
			
		||||
        queries = [
 | 
			
		||||
            {'limit': 1},
 | 
			
		||||
            {'limit': 10},
 | 
			
		||||
            {'limit': 50},
 | 
			
		||||
            {'category': 1},
 | 
			
		||||
            {},
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        url = reverse('api-part-list')
 | 
			
		||||
 | 
			
		||||
        # Create a bunch of extra parts (efficiently)
 | 
			
		||||
        parts = []
 | 
			
		||||
 | 
			
		||||
        for ii in range(100):
 | 
			
		||||
            parts.append(Part(
 | 
			
		||||
                name=f"Extra part {ii}",
 | 
			
		||||
                description="A new part which will appear via the API",
 | 
			
		||||
                level=0, tree_id=0,
 | 
			
		||||
                lft=0, rght=0,
 | 
			
		||||
            ))
 | 
			
		||||
 | 
			
		||||
        Part.objects.bulk_create(parts)
 | 
			
		||||
 | 
			
		||||
        for query in queries:
 | 
			
		||||
 | 
			
		||||
            with CaptureQueriesContext(connection) as ctx:
 | 
			
		||||
                self.get(url, query, expected_code=200)
 | 
			
		||||
 | 
			
		||||
            # No more than 20 database queries
 | 
			
		||||
            self.assertLess(len(ctx), 20)
 | 
			
		||||
 | 
			
		||||
        # Test 'category_detail' annotation
 | 
			
		||||
        for b in [False, True]:
 | 
			
		||||
            with CaptureQueriesContext(connection) as ctx:
 | 
			
		||||
                results = self.get(
 | 
			
		||||
                    reverse('api-part-list'),
 | 
			
		||||
                    {'category_detail': b},
 | 
			
		||||
                    expected_code=200
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                for result in results.data:
 | 
			
		||||
                    if b and result['category'] is not None:
 | 
			
		||||
                        self.assertIn('category_detail', result)
 | 
			
		||||
 | 
			
		||||
            # No more than 20 DB queries
 | 
			
		||||
            self.assertLessEqual(len(ctx), 20)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PartNotesTests(InvenTreeAPITestCase):
 | 
			
		||||
    """Tests for the 'notes' field (markdown field)"""
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user