2
0
mirror of https://github.com/inventree/InvenTree.git synced 2025-06-18 13:05:42 +00:00

Tree query improvements (#3443)

* Allow part category table to be ordered by part count

* Add queryset annotation for part-category part-count

- Uses subquery to annotate the part-count for sub-categories
- Huge reduction in number of queries

* Update 'pathstring' property of PartCategory and StockLocation

- No longer a dynamically calculated value
- Constructed when the model is saved, and then written to the database
- Limited to 250 characters

* Data migration to re-construct pathstring for PartCategory objects

* Fix for tree model save() method

* Add unit tests for pathstring construction

* Data migration for StockLocation pathstring values

* Update part API

- Add new annotation to PartLocationDetail view

* Update API version

* Apply similar annotation to StockLocation API endpoints

* Extra tests for PartCategory API

* Unit test fixes

* Allow PartCategory and StockLocation lists to be sorted by 'pathstring'

* Further unit test fixes
This commit is contained in:
Oliver
2022-08-01 13:43:27 +10:00
committed by GitHub
parent 1306db74b2
commit 175d9555b0
19 changed files with 478 additions and 21 deletions

View File

@ -53,6 +53,13 @@ class CategoryList(ListCreateAPI):
queryset = PartCategory.objects.all()
serializer_class = part_serializers.CategorySerializer
def get_queryset(self, *args, **kwargs):
"""Return an annotated queryset for the CategoryList endpoint"""
queryset = super().get_queryset(*args, **kwargs)
queryset = part_serializers.CategorySerializer.annotate_queryset(queryset)
return queryset
def get_serializer_context(self):
"""Add extra context data to the serializer for the PartCategoryList endpoint"""
ctx = super().get_serializer_context()
@ -141,9 +148,11 @@ class CategoryList(ListCreateAPI):
ordering_fields = [
'name',
'pathstring',
'level',
'tree_id',
'lft',
'part_count',
]
# Use hierarchical ordering by default
@ -165,6 +174,13 @@ class CategoryDetail(RetrieveUpdateDestroyAPI):
serializer_class = part_serializers.CategorySerializer
queryset = PartCategory.objects.all()
def get_queryset(self, *args, **kwargs):
"""Return an annotated queryset for the CategoryDetail endpoint"""
queryset = super().get_queryset(*args, **kwargs)
queryset = part_serializers.CategorySerializer.annotate_queryset(queryset)
return queryset
def get_serializer_context(self):
"""Add extra context to the serializer for the CategoryDetail endpoint"""
ctx = super().get_serializer_context()

View File

@ -1,4 +1,4 @@
"""Custom query filters for the Part model
"""Custom query filters for the Part models
The code here makes heavy use of subquery annotations!
@ -19,11 +19,13 @@ Relevant PRs:
from decimal import Decimal
from django.db import models
from django.db.models import F, FloatField, Func, OuterRef, Q, Subquery
from django.db.models import (F, FloatField, Func, IntegerField, OuterRef, Q,
Subquery)
from django.db.models.functions import Coalesce
from sql_util.utils import SubquerySum
import part.models
import stock.models
from InvenTree.status_codes import (BuildStatus, PurchaseOrderStatus,
SalesOrderStatus)
@ -158,3 +160,29 @@ def annotate_variant_quantity(subquery: Q, reference: str = 'quantity'):
0,
output_field=FloatField(),
)
def annotate_category_parts():
"""Construct a queryset annotation which returns the number of parts in a particular category.
- Includes parts in subcategories also
- Requires subquery to perform annotation
"""
# Construct a subquery to provide all parts in this category and any subcategories:
subquery = part.models.Part.objects.exclude(category=None).filter(
category__tree_id=OuterRef('tree_id'),
category__lft__gte=OuterRef('lft'),
category__rght__lte=OuterRef('rght'),
category__level__gte=OuterRef('level'),
)
return Coalesce(
Subquery(
subquery.annotate(
total=Func(F('pk'), function='COUNT', output_field=IntegerField())
).values('total'),
),
0,
output_field=IntegerField()
)

View File

@ -0,0 +1,18 @@
# Generated by Django 3.2.14 on 2022-07-31 23:54
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('part', '0081_alter_partcategory_name'),
]
operations = [
migrations.AddField(
model_name='partcategory',
name='pathstring',
field=models.CharField(blank=True, help_text='Path', max_length=250, verbose_name='Path'),
),
]

View File

@ -0,0 +1,54 @@
# Generated by Django 3.2.14 on 2022-07-31 23:57
from django.db import migrations
from InvenTree.helpers import constructPathString
def update_pathstring(apps, schema_editor):
"""Construct pathstring for all existing PartCategory objects"""
PartCategory = apps.get_model('part', 'partcategory')
n = PartCategory.objects.count()
if n > 0:
for cat in PartCategory.objects.all():
# Construct complete path for category
path = [cat.name]
parent = cat.parent
# Iterate up the tree
while parent is not None:
path = [parent.name] + path
parent = parent.parent
pathstring = constructPathString(path)
cat.pathstring = pathstring
cat.save()
print(f"\n--- Updated 'pathstring' for {n} PartCategory objects ---\n")
def nupdate_pathstring(apps, schema_editor):
"""Empty function for reverse migration compatibility"""
pass
class Migration(migrations.Migration):
dependencies = [
('part', '0082_partcategory_pathstring'),
]
operations = [
migrations.RunPython(
update_pathstring,
reverse_code=nupdate_pathstring
)
]

View File

@ -41,9 +41,20 @@ class CategorySerializer(InvenTreeModelSerializer):
"""Return True if the category is directly "starred" by the current user."""
return category in self.context.get('starred_categories', [])
@staticmethod
def annotate_queryset(queryset):
"""Annotate extra information to the queryset"""
# Annotate the number of 'parts' which exist in each category (including subcategories!)
queryset = queryset.annotate(
part_count=part.filters.annotate_category_parts()
)
return queryset
url = serializers.CharField(source='get_absolute_url', read_only=True)
parts = serializers.IntegerField(source='item_count', read_only=True)
part_count = serializers.IntegerField(read_only=True)
level = serializers.IntegerField(read_only=True)
@ -60,7 +71,7 @@ class CategorySerializer(InvenTreeModelSerializer):
'default_keywords',
'level',
'parent',
'parts',
'part_count',
'pathstring',
'starred',
'url',

View File

@ -77,6 +77,76 @@ class PartCategoryAPITest(InvenTreeAPITestCase):
self.assertEqual(len(response.data), 5)
# Check that the required fields are present
fields = [
'pk',
'name',
'description',
'default_location',
'level',
'parent',
'part_count',
'pathstring',
'url'
]
for result in response.data:
for f in fields:
self.assertIn(f, result)
def test_part_count(self):
"""Test that the 'part_count' field is annotated correctly"""
url = reverse('api-part-category-list')
# Create a parent category
cat = PartCategory.objects.create(
name='Parent Cat',
description='Some name',
parent=None
)
# Create child categories
for ii in range(10):
child = PartCategory.objects.create(
name=f"Child cat {ii}",
description="A child category",
parent=cat
)
# Create parts in this category
for jj in range(10):
Part.objects.create(
name=f"Part xyz {jj}",
description="A test part",
category=child
)
# Filter by parent category
response = self.get(
url,
{
'parent': cat.pk,
},
expected_code=200
)
# 10 child categories
self.assertEqual(len(response.data), 10)
for result in response.data:
self.assertEqual(result['parent'], cat.pk)
self.assertEqual(result['part_count'], 10)
# Detail view for parent category
response = self.get(
f'/api/part/category/{cat.pk}/',
expected_code=200
)
# Annotation should include parts from all sub-categories
self.assertEqual(response.data['part_count'], 100)
def test_category_metadata(self):
"""Test metadata endpoint for the PartCategory."""
cat = PartCategory.objects.get(pk=1)

View File

@ -1,5 +1,6 @@
"""Unit tests for the PartCategory model"""
from django.core.exceptions import ValidationError
from django.test import TestCase
from .models import Part, PartCategory, PartParameter, PartParameterTemplate
@ -63,9 +64,69 @@ class CategoryTest(TestCase):
def test_path_string(self):
"""Test that the category path string works correctly."""
# Note that due to data migrations, these fields need to be saved first
self.resistors.save()
self.transceivers.save()
self.assertEqual(str(self.resistors), 'Electronics/Resistors - Resistors')
self.assertEqual(str(self.transceivers.pathstring), 'Electronics/IC/Transceivers')
# Create a new subcategory
subcat = PartCategory.objects.create(
name='Subcategory',
description='My little sub category',
parent=self.transceivers
)
# Pathstring should have been updated correctly
self.assertEqual(subcat.pathstring, 'Electronics/IC/Transceivers/Subcategory')
self.assertEqual(len(subcat.path), 4)
# Move to a new parent location
subcat.parent = self.resistors
subcat.save()
self.assertEqual(subcat.pathstring, 'Electronics/Resistors/Subcategory')
self.assertEqual(len(subcat.path), 3)
# Move to top-level
subcat.parent = None
subcat.save()
self.assertEqual(subcat.pathstring, 'Subcategory')
self.assertEqual(len(subcat.path), 1)
# Construct a very long pathstring and ensure it gets updated correctly
cat = PartCategory.objects.create(
name='Cat',
description='A long running category',
parent=None
)
parent = cat
for idx in range(26):
letter = chr(ord('A') + idx)
child = PartCategory.objects.create(
name=letter * 10,
description=f"Subcategory {letter}",
parent=parent
)
parent = child
self.assertTrue(len(child.path), 26)
self.assertEqual(
child.pathstring,
"Cat/AAAAAAAAAA/BBBBBBBBBB/CCCCCCCCCC/DDDDDDDDDD/EEEEEEEEEE/FFFFFFFFFF/GGGGGGGGGG/HHHHHHHHHH/IIIIIIIIII/JJJJJJJJJJ/.../OOOOOOOOOO/PPPPPPPPPP/QQQQQQQQQQ/RRRRRRRRRR/SSSSSSSSSS/TTTTTTTTTT/UUUUUUUUUU/VVVVVVVVVV/WWWWWWWWWW/XXXXXXXXXX/YYYYYYYYYY/ZZZZZZZZZZ"
)
self.assertTrue(len(child.pathstring) <= 250)
# Attempt an invalid move
with self.assertRaises(ValidationError):
cat.parent = child
cat.save()
def test_url(self):
"""Test that the PartCategory URL works."""
self.assertEqual(self.capacitors.get_absolute_url(), '/part/category/3/')
@ -130,6 +191,9 @@ class CategoryTest(TestCase):
def test_default_locations(self):
"""Test traversal for default locations."""
self.assertIsNotNone(self.fasteners.default_location)
self.fasteners.default_location.save()
self.assertEqual(str(self.fasteners.default_location), 'Office/Drawer_1 - In my desk')
# Any part under electronics should default to 'Home'

View File

@ -220,6 +220,7 @@ class PartTest(TestCase):
def test_category(self):
"""Test PartCategory path"""
self.c1.category.save()
self.assertEqual(str(self.c1.category), 'Electronics/Capacitors - Capacitors')
orphan = Part.objects.get(name='Orphan')