mirror of
				https://github.com/inventree/InvenTree.git
				synced 2025-11-03 22:55:43 +00:00 
			
		
		
		
	Tree query improvements (#3443)
* Allow part category table to be ordered by part count * Add queryset annotation for part-category part-count - Uses subquery to annotate the part-count for sub-categories - Huge reduction in number of queries * Update 'pathstring' property of PartCategory and StockLocation - No longer a dynamically calculated value - Constructed when the model is saved, and then written to the database - Limited to 250 characters * Data migration to re-construct pathstring for PartCategory objects * Fix for tree model save() method * Add unit tests for pathstring construction * Data migration for StockLocation pathstring values * Update part API - Add new annotation to PartLocationDetail view * Update API version * Apply similar annotation to StockLocation API endpoints * Extra tests for PartCategory API * Unit test fixes * Allow PartCategory and StockLocation lists to be sorted by 'pathstring' * Further unit test fixes
This commit is contained in:
		@@ -53,6 +53,13 @@ class CategoryList(ListCreateAPI):
 | 
			
		||||
    queryset = PartCategory.objects.all()
 | 
			
		||||
    serializer_class = part_serializers.CategorySerializer
 | 
			
		||||
 | 
			
		||||
    def get_queryset(self, *args, **kwargs):
 | 
			
		||||
        """Return an annotated queryset for the CategoryList endpoint"""
 | 
			
		||||
 | 
			
		||||
        queryset = super().get_queryset(*args, **kwargs)
 | 
			
		||||
        queryset = part_serializers.CategorySerializer.annotate_queryset(queryset)
 | 
			
		||||
        return queryset
 | 
			
		||||
 | 
			
		||||
    def get_serializer_context(self):
 | 
			
		||||
        """Add extra context data to the serializer for the PartCategoryList endpoint"""
 | 
			
		||||
        ctx = super().get_serializer_context()
 | 
			
		||||
@@ -141,9 +148,11 @@ class CategoryList(ListCreateAPI):
 | 
			
		||||
 | 
			
		||||
    ordering_fields = [
 | 
			
		||||
        'name',
 | 
			
		||||
        'pathstring',
 | 
			
		||||
        'level',
 | 
			
		||||
        'tree_id',
 | 
			
		||||
        'lft',
 | 
			
		||||
        'part_count',
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    # Use hierarchical ordering by default
 | 
			
		||||
@@ -165,6 +174,13 @@ class CategoryDetail(RetrieveUpdateDestroyAPI):
 | 
			
		||||
    serializer_class = part_serializers.CategorySerializer
 | 
			
		||||
    queryset = PartCategory.objects.all()
 | 
			
		||||
 | 
			
		||||
    def get_queryset(self, *args, **kwargs):
 | 
			
		||||
        """Return an annotated queryset for the CategoryDetail endpoint"""
 | 
			
		||||
 | 
			
		||||
        queryset = super().get_queryset(*args, **kwargs)
 | 
			
		||||
        queryset = part_serializers.CategorySerializer.annotate_queryset(queryset)
 | 
			
		||||
        return queryset
 | 
			
		||||
 | 
			
		||||
    def get_serializer_context(self):
 | 
			
		||||
        """Add extra context to the serializer for the CategoryDetail endpoint"""
 | 
			
		||||
        ctx = super().get_serializer_context()
 | 
			
		||||
 
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
"""Custom query filters for the Part model
 | 
			
		||||
"""Custom query filters for the Part models
 | 
			
		||||
 | 
			
		||||
The code here makes heavy use of subquery annotations!
 | 
			
		||||
 | 
			
		||||
@@ -19,11 +19,13 @@ Relevant PRs:
 | 
			
		||||
from decimal import Decimal
 | 
			
		||||
 | 
			
		||||
from django.db import models
 | 
			
		||||
from django.db.models import F, FloatField, Func, OuterRef, Q, Subquery
 | 
			
		||||
from django.db.models import (F, FloatField, Func, IntegerField, OuterRef, Q,
 | 
			
		||||
                              Subquery)
 | 
			
		||||
from django.db.models.functions import Coalesce
 | 
			
		||||
 | 
			
		||||
from sql_util.utils import SubquerySum
 | 
			
		||||
 | 
			
		||||
import part.models
 | 
			
		||||
import stock.models
 | 
			
		||||
from InvenTree.status_codes import (BuildStatus, PurchaseOrderStatus,
 | 
			
		||||
                                    SalesOrderStatus)
 | 
			
		||||
@@ -158,3 +160,29 @@ def annotate_variant_quantity(subquery: Q, reference: str = 'quantity'):
 | 
			
		||||
        0,
 | 
			
		||||
        output_field=FloatField(),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def annotate_category_parts():
 | 
			
		||||
    """Construct a queryset annotation which returns the number of parts in a particular category.
 | 
			
		||||
 | 
			
		||||
    - Includes parts in subcategories also
 | 
			
		||||
    - Requires subquery to perform annotation
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    # Construct a subquery to provide all parts in this category and any subcategories:
 | 
			
		||||
    subquery = part.models.Part.objects.exclude(category=None).filter(
 | 
			
		||||
        category__tree_id=OuterRef('tree_id'),
 | 
			
		||||
        category__lft__gte=OuterRef('lft'),
 | 
			
		||||
        category__rght__lte=OuterRef('rght'),
 | 
			
		||||
        category__level__gte=OuterRef('level'),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    return Coalesce(
 | 
			
		||||
        Subquery(
 | 
			
		||||
            subquery.annotate(
 | 
			
		||||
                total=Func(F('pk'), function='COUNT', output_field=IntegerField())
 | 
			
		||||
            ).values('total'),
 | 
			
		||||
        ),
 | 
			
		||||
        0,
 | 
			
		||||
        output_field=IntegerField()
 | 
			
		||||
    )
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										18
									
								
								InvenTree/part/migrations/0082_partcategory_pathstring.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								InvenTree/part/migrations/0082_partcategory_pathstring.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,18 @@
 | 
			
		||||
# Generated by Django 3.2.14 on 2022-07-31 23:54
 | 
			
		||||
 | 
			
		||||
from django.db import migrations, models
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Migration(migrations.Migration):
 | 
			
		||||
 | 
			
		||||
    dependencies = [
 | 
			
		||||
        ('part', '0081_alter_partcategory_name'),
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    operations = [
 | 
			
		||||
        migrations.AddField(
 | 
			
		||||
            model_name='partcategory',
 | 
			
		||||
            name='pathstring',
 | 
			
		||||
            field=models.CharField(blank=True, help_text='Path', max_length=250, verbose_name='Path'),
 | 
			
		||||
        ),
 | 
			
		||||
    ]
 | 
			
		||||
							
								
								
									
										54
									
								
								InvenTree/part/migrations/0083_auto_20220731_2357.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								InvenTree/part/migrations/0083_auto_20220731_2357.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,54 @@
 | 
			
		||||
# Generated by Django 3.2.14 on 2022-07-31 23:57
 | 
			
		||||
 | 
			
		||||
from django.db import migrations
 | 
			
		||||
 | 
			
		||||
from InvenTree.helpers import constructPathString
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def update_pathstring(apps, schema_editor):
 | 
			
		||||
    """Construct pathstring for all existing PartCategory objects"""
 | 
			
		||||
 | 
			
		||||
    PartCategory = apps.get_model('part', 'partcategory')
 | 
			
		||||
 | 
			
		||||
    n = PartCategory.objects.count()
 | 
			
		||||
 | 
			
		||||
    if n > 0:
 | 
			
		||||
 | 
			
		||||
        for cat in PartCategory.objects.all():
 | 
			
		||||
 | 
			
		||||
            # Construct complete path for category
 | 
			
		||||
            path = [cat.name]
 | 
			
		||||
 | 
			
		||||
            parent = cat.parent
 | 
			
		||||
 | 
			
		||||
            # Iterate up the tree
 | 
			
		||||
            while parent is not None:
 | 
			
		||||
                path = [parent.name] + path
 | 
			
		||||
                parent = parent.parent
 | 
			
		||||
 | 
			
		||||
            pathstring = constructPathString(path)
 | 
			
		||||
 | 
			
		||||
            cat.pathstring = pathstring
 | 
			
		||||
            cat.save()
 | 
			
		||||
 | 
			
		||||
        print(f"\n--- Updated 'pathstring' for {n} PartCategory objects ---\n")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def nupdate_pathstring(apps, schema_editor):
 | 
			
		||||
    """Empty function for reverse migration compatibility"""
 | 
			
		||||
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Migration(migrations.Migration):
 | 
			
		||||
 | 
			
		||||
    dependencies = [
 | 
			
		||||
        ('part', '0082_partcategory_pathstring'),
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    operations = [
 | 
			
		||||
        migrations.RunPython(
 | 
			
		||||
            update_pathstring,
 | 
			
		||||
            reverse_code=nupdate_pathstring
 | 
			
		||||
        )
 | 
			
		||||
    ]
 | 
			
		||||
@@ -41,9 +41,20 @@ class CategorySerializer(InvenTreeModelSerializer):
 | 
			
		||||
        """Return True if the category is directly "starred" by the current user."""
 | 
			
		||||
        return category in self.context.get('starred_categories', [])
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def annotate_queryset(queryset):
 | 
			
		||||
        """Annotate extra information to the queryset"""
 | 
			
		||||
 | 
			
		||||
        # Annotate the number of 'parts' which exist in each category (including subcategories!)
 | 
			
		||||
        queryset = queryset.annotate(
 | 
			
		||||
            part_count=part.filters.annotate_category_parts()
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return queryset
 | 
			
		||||
 | 
			
		||||
    url = serializers.CharField(source='get_absolute_url', read_only=True)
 | 
			
		||||
 | 
			
		||||
    parts = serializers.IntegerField(source='item_count', read_only=True)
 | 
			
		||||
    part_count = serializers.IntegerField(read_only=True)
 | 
			
		||||
 | 
			
		||||
    level = serializers.IntegerField(read_only=True)
 | 
			
		||||
 | 
			
		||||
@@ -60,7 +71,7 @@ class CategorySerializer(InvenTreeModelSerializer):
 | 
			
		||||
            'default_keywords',
 | 
			
		||||
            'level',
 | 
			
		||||
            'parent',
 | 
			
		||||
            'parts',
 | 
			
		||||
            'part_count',
 | 
			
		||||
            'pathstring',
 | 
			
		||||
            'starred',
 | 
			
		||||
            'url',
 | 
			
		||||
 
 | 
			
		||||
@@ -77,6 +77,76 @@ class PartCategoryAPITest(InvenTreeAPITestCase):
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(response.data), 5)
 | 
			
		||||
 | 
			
		||||
        # Check that the required fields are present
 | 
			
		||||
        fields = [
 | 
			
		||||
            'pk',
 | 
			
		||||
            'name',
 | 
			
		||||
            'description',
 | 
			
		||||
            'default_location',
 | 
			
		||||
            'level',
 | 
			
		||||
            'parent',
 | 
			
		||||
            'part_count',
 | 
			
		||||
            'pathstring',
 | 
			
		||||
            'url'
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        for result in response.data:
 | 
			
		||||
            for f in fields:
 | 
			
		||||
                self.assertIn(f, result)
 | 
			
		||||
 | 
			
		||||
    def test_part_count(self):
 | 
			
		||||
        """Test that the 'part_count' field is annotated correctly"""
 | 
			
		||||
 | 
			
		||||
        url = reverse('api-part-category-list')
 | 
			
		||||
 | 
			
		||||
        # Create a parent category
 | 
			
		||||
        cat = PartCategory.objects.create(
 | 
			
		||||
            name='Parent Cat',
 | 
			
		||||
            description='Some name',
 | 
			
		||||
            parent=None
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Create child categories
 | 
			
		||||
        for ii in range(10):
 | 
			
		||||
            child = PartCategory.objects.create(
 | 
			
		||||
                name=f"Child cat {ii}",
 | 
			
		||||
                description="A child category",
 | 
			
		||||
                parent=cat
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # Create parts in this category
 | 
			
		||||
            for jj in range(10):
 | 
			
		||||
                Part.objects.create(
 | 
			
		||||
                    name=f"Part xyz {jj}",
 | 
			
		||||
                    description="A test part",
 | 
			
		||||
                    category=child
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        # Filter by parent category
 | 
			
		||||
        response = self.get(
 | 
			
		||||
            url,
 | 
			
		||||
            {
 | 
			
		||||
                'parent': cat.pk,
 | 
			
		||||
            },
 | 
			
		||||
            expected_code=200
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # 10 child categories
 | 
			
		||||
        self.assertEqual(len(response.data), 10)
 | 
			
		||||
 | 
			
		||||
        for result in response.data:
 | 
			
		||||
            self.assertEqual(result['parent'], cat.pk)
 | 
			
		||||
            self.assertEqual(result['part_count'], 10)
 | 
			
		||||
 | 
			
		||||
        # Detail view for parent category
 | 
			
		||||
        response = self.get(
 | 
			
		||||
            f'/api/part/category/{cat.pk}/',
 | 
			
		||||
            expected_code=200
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Annotation should include parts from all sub-categories
 | 
			
		||||
        self.assertEqual(response.data['part_count'], 100)
 | 
			
		||||
 | 
			
		||||
    def test_category_metadata(self):
 | 
			
		||||
        """Test metadata endpoint for the PartCategory."""
 | 
			
		||||
        cat = PartCategory.objects.get(pk=1)
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,6 @@
 | 
			
		||||
"""Unit tests for the PartCategory model"""
 | 
			
		||||
 | 
			
		||||
from django.core.exceptions import ValidationError
 | 
			
		||||
from django.test import TestCase
 | 
			
		||||
 | 
			
		||||
from .models import Part, PartCategory, PartParameter, PartParameterTemplate
 | 
			
		||||
@@ -63,9 +64,69 @@ class CategoryTest(TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_path_string(self):
 | 
			
		||||
        """Test that the category path string works correctly."""
 | 
			
		||||
 | 
			
		||||
        # Note that due to data migrations, these fields need to be saved first
 | 
			
		||||
        self.resistors.save()
 | 
			
		||||
        self.transceivers.save()
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(str(self.resistors), 'Electronics/Resistors - Resistors')
 | 
			
		||||
        self.assertEqual(str(self.transceivers.pathstring), 'Electronics/IC/Transceivers')
 | 
			
		||||
 | 
			
		||||
        # Create a new subcategory
 | 
			
		||||
        subcat = PartCategory.objects.create(
 | 
			
		||||
            name='Subcategory',
 | 
			
		||||
            description='My little sub category',
 | 
			
		||||
            parent=self.transceivers
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Pathstring should have been updated correctly
 | 
			
		||||
        self.assertEqual(subcat.pathstring, 'Electronics/IC/Transceivers/Subcategory')
 | 
			
		||||
        self.assertEqual(len(subcat.path), 4)
 | 
			
		||||
 | 
			
		||||
        # Move to a new parent location
 | 
			
		||||
        subcat.parent = self.resistors
 | 
			
		||||
        subcat.save()
 | 
			
		||||
        self.assertEqual(subcat.pathstring, 'Electronics/Resistors/Subcategory')
 | 
			
		||||
        self.assertEqual(len(subcat.path), 3)
 | 
			
		||||
 | 
			
		||||
        # Move to top-level
 | 
			
		||||
        subcat.parent = None
 | 
			
		||||
        subcat.save()
 | 
			
		||||
        self.assertEqual(subcat.pathstring, 'Subcategory')
 | 
			
		||||
        self.assertEqual(len(subcat.path), 1)
 | 
			
		||||
 | 
			
		||||
        # Construct a very long pathstring and ensure it gets updated correctly
 | 
			
		||||
        cat = PartCategory.objects.create(
 | 
			
		||||
            name='Cat',
 | 
			
		||||
            description='A long running category',
 | 
			
		||||
            parent=None
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        parent = cat
 | 
			
		||||
 | 
			
		||||
        for idx in range(26):
 | 
			
		||||
            letter = chr(ord('A') + idx)
 | 
			
		||||
 | 
			
		||||
            child = PartCategory.objects.create(
 | 
			
		||||
                name=letter * 10,
 | 
			
		||||
                description=f"Subcategory {letter}",
 | 
			
		||||
                parent=parent
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            parent = child
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(len(child.path), 26)
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            child.pathstring,
 | 
			
		||||
            "Cat/AAAAAAAAAA/BBBBBBBBBB/CCCCCCCCCC/DDDDDDDDDD/EEEEEEEEEE/FFFFFFFFFF/GGGGGGGGGG/HHHHHHHHHH/IIIIIIIIII/JJJJJJJJJJ/.../OOOOOOOOOO/PPPPPPPPPP/QQQQQQQQQQ/RRRRRRRRRR/SSSSSSSSSS/TTTTTTTTTT/UUUUUUUUUU/VVVVVVVVVV/WWWWWWWWWW/XXXXXXXXXX/YYYYYYYYYY/ZZZZZZZZZZ"
 | 
			
		||||
        )
 | 
			
		||||
        self.assertTrue(len(child.pathstring) <= 250)
 | 
			
		||||
 | 
			
		||||
        # Attempt an invalid move
 | 
			
		||||
        with self.assertRaises(ValidationError):
 | 
			
		||||
            cat.parent = child
 | 
			
		||||
            cat.save()
 | 
			
		||||
 | 
			
		||||
    def test_url(self):
 | 
			
		||||
        """Test that the PartCategory URL works."""
 | 
			
		||||
        self.assertEqual(self.capacitors.get_absolute_url(), '/part/category/3/')
 | 
			
		||||
@@ -130,6 +191,9 @@ class CategoryTest(TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_default_locations(self):
 | 
			
		||||
        """Test traversal for default locations."""
 | 
			
		||||
 | 
			
		||||
        self.assertIsNotNone(self.fasteners.default_location)
 | 
			
		||||
        self.fasteners.default_location.save()
 | 
			
		||||
        self.assertEqual(str(self.fasteners.default_location), 'Office/Drawer_1 - In my desk')
 | 
			
		||||
 | 
			
		||||
        # Any part under electronics should default to 'Home'
 | 
			
		||||
 
 | 
			
		||||
@@ -220,6 +220,7 @@ class PartTest(TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_category(self):
 | 
			
		||||
        """Test PartCategory path"""
 | 
			
		||||
        self.c1.category.save()
 | 
			
		||||
        self.assertEqual(str(self.c1.category), 'Electronics/Capacitors - Capacitors')
 | 
			
		||||
 | 
			
		||||
        orphan = Part.objects.get(name='Orphan')
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user