From 2c8dbb8308254d0ff65cdf204ca850cf2f794352 Mon Sep 17 00:00:00 2001 From: Oliver Date: Thu, 10 Mar 2022 16:07:05 +1100 Subject: [PATCH] Merge pull request #2736 from SchrodingersGat/loc-del-bug Fix behaviour when deleting a StockLocation (cherry picked from commit ff9d48f1c0dc27542e2fc678703035a156f1399e) --- InvenTree/part/models.py | 52 +++++----- InvenTree/part/test_category.py | 119 ++++++++++++++++++++++ InvenTree/stock/models.py | 43 +++++--- InvenTree/stock/tests.py | 168 ++++++++++++++++++++++++++++++++ 4 files changed, 346 insertions(+), 36 deletions(-) diff --git a/InvenTree/part/models.py b/InvenTree/part/models.py index 09e1f77542..4653e41a61 100644 --- a/InvenTree/part/models.py +++ b/InvenTree/part/models.py @@ -20,7 +20,7 @@ from django.db.models.functions import Coalesce from django.core.validators import MinValueValidator from django.contrib.auth.models import User -from django.db.models.signals import pre_delete, post_save +from django.db.models.signals import post_save from django.dispatch import receiver from jinja2 import Template @@ -76,6 +76,35 @@ class PartCategory(InvenTreeTree): default_keywords: Default keywords for parts created in this category """ + def delete(self, *args, **kwargs): + """ + Custom model deletion routine, which updates any child categories or parts. + This must be handled within a transaction.atomic(), otherwise the tree structure is damaged + """ + + with transaction.atomic(): + + parent = self.parent + tree_id = self.tree_id + + # Update each part in this category to point to the parent category + for part in self.parts.all(): + part.category = self.parent + part.save() + + # Update each child category + for child in self.children.all(): + child.parent = self.parent + child.save() + + super().delete(*args, **kwargs) + + if parent is not None: + # Partially rebuild the tree (cheaper than a complete rebuild) + PartCategory.objects.partial_rebuild(tree_id) + else: + PartCategory.objects.rebuild() + default_location = TreeForeignKey( 'stock.StockLocation', related_name="default_categories", null=True, blank=True, @@ -260,27 +289,6 @@ class PartCategory(InvenTreeTree): ).delete() -@receiver(pre_delete, sender=PartCategory, dispatch_uid='partcategory_delete_log') -def before_delete_part_category(sender, instance, using, **kwargs): - """ Receives before_delete signal for PartCategory object - - Before deleting, update child Part and PartCategory objects: - - - For each child category, set the parent to the parent of *this* category - - For each part, set the 'category' to the parent of *this* category - """ - - # Update each part in this category to point to the parent category - for part in instance.parts.all(): - part.category = instance.parent - part.save() - - # Update each child category - for child in instance.children.all(): - child.parent = instance.parent - child.save() - - def rename_part_image(instance, filename): """ Function for renaming a part image file diff --git a/InvenTree/part/test_category.py b/InvenTree/part/test_category.py index 53030d402a..6eb76fa845 100644 --- a/InvenTree/part/test_category.py +++ b/InvenTree/part/test_category.py @@ -172,3 +172,122 @@ class CategoryTest(TestCase): # And one part should have no default location at all w = Part.objects.get(name='Widget') self.assertIsNone(w.get_default_location()) + + def test_category_tree(self): + """ + Unit tests for the part category tree structure (MPTT) + Ensure that the MPTT structure is rebuilt correctly, + and the correct ancestor tree is observed. + """ + + # Clear out any existing parts + Part.objects.all().delete() + + # First, create a structured tree of part categories + A = PartCategory.objects.create( + name='A', + description='Top level category', + ) + + B1 = PartCategory.objects.create(name='B1', parent=A) + B2 = PartCategory.objects.create(name='B2', parent=A) + B3 = PartCategory.objects.create(name='B3', parent=A) + + C11 = PartCategory.objects.create(name='C11', parent=B1) + C12 = PartCategory.objects.create(name='C12', parent=B1) + C13 = PartCategory.objects.create(name='C13', parent=B1) + + C21 = PartCategory.objects.create(name='C21', parent=B2) + C22 = PartCategory.objects.create(name='C22', parent=B2) + C23 = PartCategory.objects.create(name='C23', parent=B2) + + C31 = PartCategory.objects.create(name='C31', parent=B3) + C32 = PartCategory.objects.create(name='C32', parent=B3) + C33 = PartCategory.objects.create(name='C33', parent=B3) + + # Check that the tree_id value is correct + for cat in [B1, B2, B3, C11, C22, C33]: + self.assertEqual(cat.tree_id, A.tree_id) + self.assertEqual(cat.level, cat.parent.level + 1) + self.assertEqual(cat.get_ancestors().count(), cat.level) + + # Spot check for C31 + ancestors = C31.get_ancestors(include_self=True) + + self.assertEqual(ancestors.count(), 3) + self.assertEqual(ancestors[0], A) + self.assertEqual(ancestors[1], B3) + self.assertEqual(ancestors[2], C31) + + # At this point, we are confident that the tree is correctly structured + + # Add some parts to category B3 + + for i in range(10): + Part.objects.create( + name=f'Part {i}', + description='A test part', + category=B3, + ) + + self.assertEqual(Part.objects.filter(category=B3).count(), 10) + self.assertEqual(Part.objects.filter(category=A).count(), 0) + + # Delete category B3 + B3.delete() + + # Child parts have been moved to category A + self.assertEqual(Part.objects.filter(category=A).count(), 10) + + for cat in [C31, C32, C33]: + # These categories should now be directly under A + cat.refresh_from_db() + + self.assertEqual(cat.parent, A) + self.assertEqual(cat.level, 1) + self.assertEqual(cat.get_ancestors().count(), 1) + self.assertEqual(cat.get_ancestors()[0], A) + + # Now, delete category A + A.delete() + + # Parts have now been moved to the top-level category + self.assertEqual(Part.objects.filter(category=None).count(), 10) + + for loc in [B1, B2, C31, C32, C33]: + # These should now all be "top level" categories + loc.refresh_from_db() + + self.assertEqual(loc.level, 0) + self.assertEqual(loc.parent, None) + + # Check descendants for B1 + descendants = B1.get_descendants() + self.assertEqual(descendants.count(), 3) + + for loc in [C11, C12, C13]: + self.assertTrue(loc in descendants) + + # Check category C1x, should be B1 -> C1x + for loc in [C11, C12, C13]: + loc.refresh_from_db() + + self.assertEqual(loc.level, 1) + self.assertEqual(loc.parent, B1) + ancestors = loc.get_ancestors(include_self=True) + + self.assertEqual(ancestors.count(), 2) + self.assertEqual(ancestors[0], B1) + self.assertEqual(ancestors[1], loc) + + # Check category C2x, should be B2 -> C2x + for loc in [C21, C22, C23]: + loc.refresh_from_db() + + self.assertEqual(loc.level, 1) + self.assertEqual(loc.parent, B2) + ancestors = loc.get_ancestors(include_self=True) + + self.assertEqual(ancestors.count(), 2) + self.assertEqual(ancestors[0], B2) + self.assertEqual(ancestors[1], loc) diff --git a/InvenTree/stock/models.py b/InvenTree/stock/models.py index 27d6cf5fc3..171ee7e0a3 100644 --- a/InvenTree/stock/models.py +++ b/InvenTree/stock/models.py @@ -54,6 +54,35 @@ class StockLocation(InvenTreeTree): Stock locations can be heirarchical as required """ + def delete(self, *args, **kwargs): + """ + Custom model deletion routine, which updates any child locations or items. + This must be handled within a transaction.atomic(), otherwise the tree structure is damaged + """ + + with transaction.atomic(): + + parent = self.parent + tree_id = self.tree_id + + # Update each stock item in the stock location + for item in self.stock_items.all(): + item.location = self.parent + item.save() + + # Update each child category + for child in self.children.all(): + child.parent = self.parent + child.save() + + super().delete(*args, **kwargs) + + if parent is not None: + # Partially rebuild the tree (cheaper than a complete rebuild) + StockLocation.objects.partial_rebuild(tree_id) + else: + StockLocation.objects.rebuild() + @staticmethod def get_api_url(): return reverse('api-location-list') @@ -159,20 +188,6 @@ class StockLocation(InvenTreeTree): return self.stock_item_count() -@receiver(pre_delete, sender=StockLocation, dispatch_uid='stocklocation_delete_log') -def before_delete_stock_location(sender, instance, using, **kwargs): - - # Update each part in the stock location - for item in instance.stock_items.all(): - item.location = instance.parent - item.save() - - # Update each child category - for child in instance.children.all(): - child.parent = instance.parent - child.save() - - class StockItemManager(TreeManager): """ Custom database manager for the StockItem class. diff --git a/InvenTree/stock/tests.py b/InvenTree/stock/tests.py index d1e68fc8e5..50f77a593b 100644 --- a/InvenTree/stock/tests.py +++ b/InvenTree/stock/tests.py @@ -524,6 +524,174 @@ class StockTest(TestCase): # Serialize the remainder of the stock item.serializeStock(2, [99, 100], self.user) + def test_location_tree(self): + """ + Unit tests for stock location tree structure (MPTT). + Ensure that the MPTT structure is rebuilt correctly, + and the corrent ancestor tree is observed. + + Ref: https://github.com/inventree/InvenTree/issues/2636 + Ref: https://github.com/inventree/InvenTree/issues/2733 + """ + + # First, we will create a stock location structure + + A = StockLocation.objects.create( + name='A', + description='Top level location' + ) + + B1 = StockLocation.objects.create( + name='B1', + parent=A + ) + + B2 = StockLocation.objects.create( + name='B2', + parent=A + ) + + B3 = StockLocation.objects.create( + name='B3', + parent=A + ) + + C11 = StockLocation.objects.create( + name='C11', + parent=B1, + ) + + C12 = StockLocation.objects.create( + name='C12', + parent=B1, + ) + + C21 = StockLocation.objects.create( + name='C21', + parent=B2, + ) + + C22 = StockLocation.objects.create( + name='C22', + parent=B2, + ) + + C31 = StockLocation.objects.create( + name='C31', + parent=B3, + ) + + C32 = StockLocation.objects.create( + name='C32', + parent=B3 + ) + + # Check that the tree_id is correct for each sublocation + for loc in [B1, B2, B3, C11, C12, C21, C22, C31, C32]: + self.assertEqual(loc.tree_id, A.tree_id) + + # Check that the tree levels are correct for each node in the tree + + self.assertEqual(A.level, 0) + self.assertEqual(A.get_ancestors().count(), 0) + + for loc in [B1, B2, B3]: + self.assertEqual(loc.parent, A) + self.assertEqual(loc.level, 1) + self.assertEqual(loc.get_ancestors().count(), 1) + + for loc in [C11, C12]: + self.assertEqual(loc.parent, B1) + self.assertEqual(loc.level, 2) + self.assertEqual(loc.get_ancestors().count(), 2) + + for loc in [C21, C22]: + self.assertEqual(loc.parent, B2) + self.assertEqual(loc.level, 2) + self.assertEqual(loc.get_ancestors().count(), 2) + + for loc in [C31, C32]: + self.assertEqual(loc.parent, B3) + self.assertEqual(loc.level, 2) + self.assertEqual(loc.get_ancestors().count(), 2) + + # Spot-check for C32 + ancestors = C32.get_ancestors(include_self=True) + + self.assertEqual(ancestors[0], A) + self.assertEqual(ancestors[1], B3) + self.assertEqual(ancestors[2], C32) + + # At this point, we are confident that the tree is correctly structured. + + # Let's delete node B3 from the tree. We expect that: + # - C31 should move directly under A + # - C32 should move directly under A + + # Add some stock items to B3 + for i in range(10): + StockItem.objects.create( + part=Part.objects.get(pk=1), + quantity=10, + location=B3 + ) + + self.assertEqual(StockItem.objects.filter(location=B3).count(), 10) + self.assertEqual(StockItem.objects.filter(location=A).count(), 0) + + B3.delete() + + A.refresh_from_db() + C31.refresh_from_db() + C32.refresh_from_db() + + # Stock items have been moved to A + self.assertEqual(StockItem.objects.filter(location=A).count(), 10) + + # Parent should be A + self.assertEqual(C31.parent, A) + self.assertEqual(C32.parent, A) + + self.assertEqual(C31.tree_id, A.tree_id) + self.assertEqual(C31.level, 1) + + self.assertEqual(C32.tree_id, A.tree_id) + self.assertEqual(C32.level, 1) + + # Ancestor tree should be just A + ancestors = C31.get_ancestors() + self.assertEqual(ancestors.count(), 1) + self.assertEqual(ancestors[0], A) + + ancestors = C32.get_ancestors() + self.assertEqual(ancestors.count(), 1) + self.assertEqual(ancestors[0], A) + + # Delete A + A.delete() + + # Stock items have been moved to top-level location + self.assertEqual(StockItem.objects.filter(location=None).count(), 10) + + for loc in [B1, B2, C11, C12, C21, C22]: + loc.refresh_from_db() + + self.assertEqual(B1.parent, None) + self.assertEqual(B2.parent, None) + + self.assertEqual(C11.parent, B1) + self.assertEqual(C12.parent, B1) + self.assertEqual(C11.get_ancestors().count(), 1) + self.assertEqual(C12.get_ancestors().count(), 1) + + self.assertEqual(C21.parent, B2) + self.assertEqual(C22.parent, B2) + + ancestors = C21.get_ancestors() + + self.assertEqual(C21.get_ancestors().count(), 1) + self.assertEqual(C22.get_ancestors().count(), 1) + class VariantTest(StockTest): """