mirror of
https://github.com/inventree/InvenTree.git
synced 2025-06-18 13:05:42 +00:00
Tree query improvements (#3443)
* Allow part category table to be ordered by part count * Add queryset annotation for part-category part-count - Uses subquery to annotate the part-count for sub-categories - Huge reduction in number of queries * Update 'pathstring' property of PartCategory and StockLocation - No longer a dynamically calculated value - Constructed when the model is saved, and then written to the database - Limited to 250 characters * Data migration to re-construct pathstring for PartCategory objects * Fix for tree model save() method * Add unit tests for pathstring construction * Data migration for StockLocation pathstring values * Update part API - Add new annotation to PartLocationDetail view * Update API version * Apply similar annotation to StockLocation API endpoints * Extra tests for PartCategory API * Unit test fixes * Allow PartCategory and StockLocation lists to be sorted by 'pathstring' * Further unit test fixes
This commit is contained in:
@ -53,6 +53,13 @@ class CategoryList(ListCreateAPI):
|
||||
queryset = PartCategory.objects.all()
|
||||
serializer_class = part_serializers.CategorySerializer
|
||||
|
||||
def get_queryset(self, *args, **kwargs):
|
||||
"""Return an annotated queryset for the CategoryList endpoint"""
|
||||
|
||||
queryset = super().get_queryset(*args, **kwargs)
|
||||
queryset = part_serializers.CategorySerializer.annotate_queryset(queryset)
|
||||
return queryset
|
||||
|
||||
def get_serializer_context(self):
|
||||
"""Add extra context data to the serializer for the PartCategoryList endpoint"""
|
||||
ctx = super().get_serializer_context()
|
||||
@ -141,9 +148,11 @@ class CategoryList(ListCreateAPI):
|
||||
|
||||
ordering_fields = [
|
||||
'name',
|
||||
'pathstring',
|
||||
'level',
|
||||
'tree_id',
|
||||
'lft',
|
||||
'part_count',
|
||||
]
|
||||
|
||||
# Use hierarchical ordering by default
|
||||
@ -165,6 +174,13 @@ class CategoryDetail(RetrieveUpdateDestroyAPI):
|
||||
serializer_class = part_serializers.CategorySerializer
|
||||
queryset = PartCategory.objects.all()
|
||||
|
||||
def get_queryset(self, *args, **kwargs):
|
||||
"""Return an annotated queryset for the CategoryDetail endpoint"""
|
||||
|
||||
queryset = super().get_queryset(*args, **kwargs)
|
||||
queryset = part_serializers.CategorySerializer.annotate_queryset(queryset)
|
||||
return queryset
|
||||
|
||||
def get_serializer_context(self):
|
||||
"""Add extra context to the serializer for the CategoryDetail endpoint"""
|
||||
ctx = super().get_serializer_context()
|
||||
|
@ -1,4 +1,4 @@
|
||||
"""Custom query filters for the Part model
|
||||
"""Custom query filters for the Part models
|
||||
|
||||
The code here makes heavy use of subquery annotations!
|
||||
|
||||
@ -19,11 +19,13 @@ Relevant PRs:
|
||||
from decimal import Decimal
|
||||
|
||||
from django.db import models
|
||||
from django.db.models import F, FloatField, Func, OuterRef, Q, Subquery
|
||||
from django.db.models import (F, FloatField, Func, IntegerField, OuterRef, Q,
|
||||
Subquery)
|
||||
from django.db.models.functions import Coalesce
|
||||
|
||||
from sql_util.utils import SubquerySum
|
||||
|
||||
import part.models
|
||||
import stock.models
|
||||
from InvenTree.status_codes import (BuildStatus, PurchaseOrderStatus,
|
||||
SalesOrderStatus)
|
||||
@ -158,3 +160,29 @@ def annotate_variant_quantity(subquery: Q, reference: str = 'quantity'):
|
||||
0,
|
||||
output_field=FloatField(),
|
||||
)
|
||||
|
||||
|
||||
def annotate_category_parts():
|
||||
"""Construct a queryset annotation which returns the number of parts in a particular category.
|
||||
|
||||
- Includes parts in subcategories also
|
||||
- Requires subquery to perform annotation
|
||||
"""
|
||||
|
||||
# Construct a subquery to provide all parts in this category and any subcategories:
|
||||
subquery = part.models.Part.objects.exclude(category=None).filter(
|
||||
category__tree_id=OuterRef('tree_id'),
|
||||
category__lft__gte=OuterRef('lft'),
|
||||
category__rght__lte=OuterRef('rght'),
|
||||
category__level__gte=OuterRef('level'),
|
||||
)
|
||||
|
||||
return Coalesce(
|
||||
Subquery(
|
||||
subquery.annotate(
|
||||
total=Func(F('pk'), function='COUNT', output_field=IntegerField())
|
||||
).values('total'),
|
||||
),
|
||||
0,
|
||||
output_field=IntegerField()
|
||||
)
|
||||
|
18
InvenTree/part/migrations/0082_partcategory_pathstring.py
Normal file
18
InvenTree/part/migrations/0082_partcategory_pathstring.py
Normal file
@ -0,0 +1,18 @@
|
||||
# Generated by Django 3.2.14 on 2022-07-31 23:54
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('part', '0081_alter_partcategory_name'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='partcategory',
|
||||
name='pathstring',
|
||||
field=models.CharField(blank=True, help_text='Path', max_length=250, verbose_name='Path'),
|
||||
),
|
||||
]
|
54
InvenTree/part/migrations/0083_auto_20220731_2357.py
Normal file
54
InvenTree/part/migrations/0083_auto_20220731_2357.py
Normal file
@ -0,0 +1,54 @@
|
||||
# Generated by Django 3.2.14 on 2022-07-31 23:57
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
from InvenTree.helpers import constructPathString
|
||||
|
||||
|
||||
def update_pathstring(apps, schema_editor):
|
||||
"""Construct pathstring for all existing PartCategory objects"""
|
||||
|
||||
PartCategory = apps.get_model('part', 'partcategory')
|
||||
|
||||
n = PartCategory.objects.count()
|
||||
|
||||
if n > 0:
|
||||
|
||||
for cat in PartCategory.objects.all():
|
||||
|
||||
# Construct complete path for category
|
||||
path = [cat.name]
|
||||
|
||||
parent = cat.parent
|
||||
|
||||
# Iterate up the tree
|
||||
while parent is not None:
|
||||
path = [parent.name] + path
|
||||
parent = parent.parent
|
||||
|
||||
pathstring = constructPathString(path)
|
||||
|
||||
cat.pathstring = pathstring
|
||||
cat.save()
|
||||
|
||||
print(f"\n--- Updated 'pathstring' for {n} PartCategory objects ---\n")
|
||||
|
||||
|
||||
def nupdate_pathstring(apps, schema_editor):
|
||||
"""Empty function for reverse migration compatibility"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('part', '0082_partcategory_pathstring'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RunPython(
|
||||
update_pathstring,
|
||||
reverse_code=nupdate_pathstring
|
||||
)
|
||||
]
|
@ -41,9 +41,20 @@ class CategorySerializer(InvenTreeModelSerializer):
|
||||
"""Return True if the category is directly "starred" by the current user."""
|
||||
return category in self.context.get('starred_categories', [])
|
||||
|
||||
@staticmethod
|
||||
def annotate_queryset(queryset):
|
||||
"""Annotate extra information to the queryset"""
|
||||
|
||||
# Annotate the number of 'parts' which exist in each category (including subcategories!)
|
||||
queryset = queryset.annotate(
|
||||
part_count=part.filters.annotate_category_parts()
|
||||
)
|
||||
|
||||
return queryset
|
||||
|
||||
url = serializers.CharField(source='get_absolute_url', read_only=True)
|
||||
|
||||
parts = serializers.IntegerField(source='item_count', read_only=True)
|
||||
part_count = serializers.IntegerField(read_only=True)
|
||||
|
||||
level = serializers.IntegerField(read_only=True)
|
||||
|
||||
@ -60,7 +71,7 @@ class CategorySerializer(InvenTreeModelSerializer):
|
||||
'default_keywords',
|
||||
'level',
|
||||
'parent',
|
||||
'parts',
|
||||
'part_count',
|
||||
'pathstring',
|
||||
'starred',
|
||||
'url',
|
||||
|
@ -77,6 +77,76 @@ class PartCategoryAPITest(InvenTreeAPITestCase):
|
||||
|
||||
self.assertEqual(len(response.data), 5)
|
||||
|
||||
# Check that the required fields are present
|
||||
fields = [
|
||||
'pk',
|
||||
'name',
|
||||
'description',
|
||||
'default_location',
|
||||
'level',
|
||||
'parent',
|
||||
'part_count',
|
||||
'pathstring',
|
||||
'url'
|
||||
]
|
||||
|
||||
for result in response.data:
|
||||
for f in fields:
|
||||
self.assertIn(f, result)
|
||||
|
||||
def test_part_count(self):
|
||||
"""Test that the 'part_count' field is annotated correctly"""
|
||||
|
||||
url = reverse('api-part-category-list')
|
||||
|
||||
# Create a parent category
|
||||
cat = PartCategory.objects.create(
|
||||
name='Parent Cat',
|
||||
description='Some name',
|
||||
parent=None
|
||||
)
|
||||
|
||||
# Create child categories
|
||||
for ii in range(10):
|
||||
child = PartCategory.objects.create(
|
||||
name=f"Child cat {ii}",
|
||||
description="A child category",
|
||||
parent=cat
|
||||
)
|
||||
|
||||
# Create parts in this category
|
||||
for jj in range(10):
|
||||
Part.objects.create(
|
||||
name=f"Part xyz {jj}",
|
||||
description="A test part",
|
||||
category=child
|
||||
)
|
||||
|
||||
# Filter by parent category
|
||||
response = self.get(
|
||||
url,
|
||||
{
|
||||
'parent': cat.pk,
|
||||
},
|
||||
expected_code=200
|
||||
)
|
||||
|
||||
# 10 child categories
|
||||
self.assertEqual(len(response.data), 10)
|
||||
|
||||
for result in response.data:
|
||||
self.assertEqual(result['parent'], cat.pk)
|
||||
self.assertEqual(result['part_count'], 10)
|
||||
|
||||
# Detail view for parent category
|
||||
response = self.get(
|
||||
f'/api/part/category/{cat.pk}/',
|
||||
expected_code=200
|
||||
)
|
||||
|
||||
# Annotation should include parts from all sub-categories
|
||||
self.assertEqual(response.data['part_count'], 100)
|
||||
|
||||
def test_category_metadata(self):
|
||||
"""Test metadata endpoint for the PartCategory."""
|
||||
cat = PartCategory.objects.get(pk=1)
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""Unit tests for the PartCategory model"""
|
||||
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.test import TestCase
|
||||
|
||||
from .models import Part, PartCategory, PartParameter, PartParameterTemplate
|
||||
@ -63,9 +64,69 @@ class CategoryTest(TestCase):
|
||||
|
||||
def test_path_string(self):
|
||||
"""Test that the category path string works correctly."""
|
||||
|
||||
# Note that due to data migrations, these fields need to be saved first
|
||||
self.resistors.save()
|
||||
self.transceivers.save()
|
||||
|
||||
self.assertEqual(str(self.resistors), 'Electronics/Resistors - Resistors')
|
||||
self.assertEqual(str(self.transceivers.pathstring), 'Electronics/IC/Transceivers')
|
||||
|
||||
# Create a new subcategory
|
||||
subcat = PartCategory.objects.create(
|
||||
name='Subcategory',
|
||||
description='My little sub category',
|
||||
parent=self.transceivers
|
||||
)
|
||||
|
||||
# Pathstring should have been updated correctly
|
||||
self.assertEqual(subcat.pathstring, 'Electronics/IC/Transceivers/Subcategory')
|
||||
self.assertEqual(len(subcat.path), 4)
|
||||
|
||||
# Move to a new parent location
|
||||
subcat.parent = self.resistors
|
||||
subcat.save()
|
||||
self.assertEqual(subcat.pathstring, 'Electronics/Resistors/Subcategory')
|
||||
self.assertEqual(len(subcat.path), 3)
|
||||
|
||||
# Move to top-level
|
||||
subcat.parent = None
|
||||
subcat.save()
|
||||
self.assertEqual(subcat.pathstring, 'Subcategory')
|
||||
self.assertEqual(len(subcat.path), 1)
|
||||
|
||||
# Construct a very long pathstring and ensure it gets updated correctly
|
||||
cat = PartCategory.objects.create(
|
||||
name='Cat',
|
||||
description='A long running category',
|
||||
parent=None
|
||||
)
|
||||
|
||||
parent = cat
|
||||
|
||||
for idx in range(26):
|
||||
letter = chr(ord('A') + idx)
|
||||
|
||||
child = PartCategory.objects.create(
|
||||
name=letter * 10,
|
||||
description=f"Subcategory {letter}",
|
||||
parent=parent
|
||||
)
|
||||
|
||||
parent = child
|
||||
|
||||
self.assertTrue(len(child.path), 26)
|
||||
self.assertEqual(
|
||||
child.pathstring,
|
||||
"Cat/AAAAAAAAAA/BBBBBBBBBB/CCCCCCCCCC/DDDDDDDDDD/EEEEEEEEEE/FFFFFFFFFF/GGGGGGGGGG/HHHHHHHHHH/IIIIIIIIII/JJJJJJJJJJ/.../OOOOOOOOOO/PPPPPPPPPP/QQQQQQQQQQ/RRRRRRRRRR/SSSSSSSSSS/TTTTTTTTTT/UUUUUUUUUU/VVVVVVVVVV/WWWWWWWWWW/XXXXXXXXXX/YYYYYYYYYY/ZZZZZZZZZZ"
|
||||
)
|
||||
self.assertTrue(len(child.pathstring) <= 250)
|
||||
|
||||
# Attempt an invalid move
|
||||
with self.assertRaises(ValidationError):
|
||||
cat.parent = child
|
||||
cat.save()
|
||||
|
||||
def test_url(self):
|
||||
"""Test that the PartCategory URL works."""
|
||||
self.assertEqual(self.capacitors.get_absolute_url(), '/part/category/3/')
|
||||
@ -130,6 +191,9 @@ class CategoryTest(TestCase):
|
||||
|
||||
def test_default_locations(self):
|
||||
"""Test traversal for default locations."""
|
||||
|
||||
self.assertIsNotNone(self.fasteners.default_location)
|
||||
self.fasteners.default_location.save()
|
||||
self.assertEqual(str(self.fasteners.default_location), 'Office/Drawer_1 - In my desk')
|
||||
|
||||
# Any part under electronics should default to 'Home'
|
||||
|
@ -220,6 +220,7 @@ class PartTest(TestCase):
|
||||
|
||||
def test_category(self):
|
||||
"""Test PartCategory path"""
|
||||
self.c1.category.save()
|
||||
self.assertEqual(str(self.c1.category), 'Electronics/Capacitors - Capacitors')
|
||||
|
||||
orphan = Part.objects.get(name='Orphan')
|
||||
|
Reference in New Issue
Block a user