From fd0a57c4a14af000e03682afa27d61f3a5731b30 Mon Sep 17 00:00:00 2001 From: Oliver Date: Mon, 30 Oct 2023 06:57:40 +1100 Subject: [PATCH] Improve deletion behaviour for InvenTreeTree model (#5806) * Improve deletion behaviour for InvenTreeTree model - Remove recursive call to function - Handle database operations as bulk queries - Ensure child nodes have their pathstring updated correctly - Remove old @receiver hook - Refactor StockLocation.delete method - Refactor PartCategory.delete method - Atomic transactions potentially problematic here * Add docstring * Fix method name * Use bulk-update instead of recursive save when pathstring changes * Improvements for tree delete method - Handle case where item has already been deleted * Raise exception rather than simply logging * Update unit tests * Improvements to unrelated unit test * Fix urls.md * Fix typo --- InvenTree/InvenTree/models.py | 158 +++++++++++++++++++++++++----- InvenTree/part/models.py | 48 ++------- InvenTree/part/test_bom_import.py | 8 +- InvenTree/part/test_category.py | 85 ++++++++++++++++ InvenTree/stock/api.py | 20 ++-- InvenTree/stock/models.py | 49 ++------- 6 files changed, 254 insertions(+), 114 deletions(-) diff --git a/InvenTree/InvenTree/models.py b/InvenTree/InvenTree/models.py index 15910552dd..640fa23306 100644 --- a/InvenTree/InvenTree/models.py +++ b/InvenTree/InvenTree/models.py @@ -12,7 +12,7 @@ from django.contrib.auth.models import User from django.contrib.contenttypes.models import ContentType from django.core.exceptions import ValidationError from django.db import models -from django.db.models.signals import post_save, pre_delete +from django.db.models.signals import post_save from django.dispatch import receiver from django.urls import reverse from django.utils.translation import gettext_lazy as _ @@ -580,6 +580,10 @@ class InvenTreeTree(MPTTModel): parent: The item immediately above this one. An item with a null parent is a top-level item """ + # How items (not nodes) are hooked into the tree + # e.g. for StockLocation, this value is 'location' + ITEM_PARENT_KEY = None + class Meta: """Metaclass defines extra model properties.""" abstract = True @@ -588,6 +592,106 @@ class InvenTreeTree(MPTTModel): """Set insert order.""" order_insertion_by = ['name'] + def delete(self, delete_children=False, delete_items=False): + """Handle the deletion of a tree node. + + 1. Update nodes and items under the current node + 2. Delete this node + 3. Rebuild the model tree + 4. Rebuild the path for any remaining lower nodes + """ + tree_id = self.tree_id if self.parent else None + + # Ensure that we have the latest version of the database object + try: + self.refresh_from_db() + except self.__class__.DoesNotExist: + # If the object no longer exists, raise a ValidationError + raise ValidationError("Object %s of type %s no longer exists", str(self), str(self.__class__)) + + # Cache node ID values for lower nodes, before we delete this one + lower_nodes = list(self.get_descendants(include_self=False).values_list('pk', flat=True)) + + # 1. Update nodes and items under the current node + self.handle_tree_delete(delete_children=delete_children, delete_items=delete_items) + + # 2. Delete *this* node + super().delete() + + # 3. Update the tree structure + if tree_id: + self.__class__.objects.partial_rebuild(tree_id) + else: + self.__class__.objects.rebuild() + + # 4. Rebuild the path for any remaining lower nodes + nodes = self.__class__.objects.filter(pk__in=lower_nodes) + + nodes_to_update = [] + + for node in nodes: + new_path = node.construct_pathstring() + + if new_path != node.pathstring: + node.pathstring = new_path + nodes_to_update.append(node) + + if len(nodes_to_update) > 0: + self.__class__.objects.bulk_update(nodes_to_update, ['pathstring']) + + def handle_tree_delete(self, delete_children=False, delete_items=False): + """Delete a single instance of the tree, based on provided kwargs. + + Removing a tree "node" from the database must be considered carefully, + based on what the user intends for any items which exist *under* that node. + + - "children" are any nodes which exist *under* this node (e.g. PartCategory) + - "items" are any items which exist *under* this node (e.g. Part) + + Arguments: + delete_children: If True, delete all child items + delete_items: If True, delete all items associated with this node + + There are multiple scenarios we can consider here: + + A) delete_children = True and delete_items = True + B) delete_children = True and delete_items = False + C) delete_children = False and delete_items = True + D) delete_children = False and delete_items = False + """ + + # Case A: Delete all child items, and all child nodes. + # - Delete all items at any lower level + # - Delete all descendant nodes + if delete_children and delete_items: + self.get_items(cascade=True).delete() + self.get_descendants(include_self=False).delete() + + # Case B: Delete all child nodes, but move all child items up to the parent + # - Move all items at any lower level to the parent of this item + # - Delete all descendant nodes + elif delete_children and not delete_items: + self.get_items(cascade=True).update(**{ + self.ITEM_PARENT_KEY: self.parent + }) + self.get_descendants(include_self=False).delete() + + # Case C: Delete all child items, but keep all child nodes + # - Remove all items directly associated with this node + # - Move any direct child nodes up one level + elif not delete_children and delete_items: + self.get_items(cascade=False).delete() + self.get_children().update(parent=self.parent) + + # Case D: Keep all child items, and keep all child nodes + # - Move all items directly associated with this node up one level + # - Move any direct child nodes up one level + elif not delete_children and not delete_items: + self.get_items(cascade=False).update(**{ + self.ITEM_PARENT_KEY: self.parent + }) + self.get_children().update(parent=self.parent) + def validate_unique(self, exclude=None): """Validate that this tree instance satisfies our uniqueness requirements. @@ -614,6 +718,12 @@ class InvenTreeTree(MPTTModel): } } + def construct_pathstring(self): + """Construct the pathstring for this tree node""" + return InvenTree.helpers.constructPathString( + [item.name for item in self.path] + ) + def save(self, *args, **kwargs): """Custom save method for InvenTreeTree abstract model""" try: @@ -625,9 +735,7 @@ class InvenTreeTree(MPTTModel): }) # Re-calculate the 'pathstring' field - pathstring = InvenTree.helpers.constructPathString( - [item.name for item in self.path] - ) + pathstring = self.construct_pathstring() if pathstring != self.pathstring: @@ -639,9 +747,20 @@ class InvenTreeTree(MPTTModel): self.pathstring = pathstring super().save(*args, **kwargs) - # Ensure that the pathstring changes are propagated down the tree also - for child in self.get_children(): - child.save(*args, **kwargs) + # Update the pathstring for any child nodes + lower_nodes = self.get_descendants(include_self=False) + + nodes_to_update = [] + + for node in lower_nodes: + new_path = node.construct_pathstring() + + if new_path != node.pathstring: + node.pathstring = new_path + nodes_to_update.append(node) + + if len(nodes_to_update) > 0: + self.__class__.objects.bulk_update(nodes_to_update, ['pathstring']) name = models.CharField( blank=False, @@ -673,16 +792,15 @@ class InvenTreeTree(MPTTModel): help_text=_('Path') ) - @property - def item_count(self): - """Return the number of items which exist *under* this node in the tree. + def get_items(self, cascade=False): + """Return a queryset of items which exist *under* this node in the tree. - Here an 'item' is considered to be the 'leaf' at the end of each branch, - and the exact nature here will depend on the class implementation. + - For a StockLocation instance, this would be a queryset of StockItem objects + - For a PartCategory instance, this would be a queryset of Part objects - The default implementation returns zero + The default implementation returns an empty list """ - return 0 + raise NotImplementedError(f"items() method not implemented for {type(self)}") def getUniqueParents(self): """Return a flat set of all parent items that exist above this node. @@ -878,18 +996,6 @@ class InvenTreeBarcodeMixin(models.Model): self.save() -@receiver(pre_delete, sender=InvenTreeTree, dispatch_uid='tree_pre_delete_log') -def before_delete_tree_item(sender, instance, using, **kwargs): - """Receives pre_delete signal from InvenTreeTree object. - - Before an item is deleted, update each child object to point to the parent of the object being deleted. - """ - # Update each tree item below this one - for child in instance.children.all(): - child.parent = instance.parent - child.save() - - @receiver(post_save, sender=Error, dispatch_uid='error_post_save_notification') def after_error_logged(sender, instance: Error, created: bool, **kwargs): """Callback when a server error is logged. diff --git a/InvenTree/part/models.py b/InvenTree/part/models.py index b39196645c..7451c8cf63 100644 --- a/InvenTree/part/models.py +++ b/InvenTree/part/models.py @@ -72,55 +72,23 @@ class PartCategory(MetadataMixin, InvenTreeTree): default_keywords: Default keywords for parts created in this category """ + ITEM_PARENT_KEY = 'category' + class Meta: """Metaclass defines extra model properties""" verbose_name = _("Part Category") verbose_name_plural = _("Part Categories") - def delete_recursive(self, *args, **kwargs): - """This function handles the recursive deletion of subcategories depending on kwargs contents""" - delete_parts = kwargs.get('delete_parts', False) - parent_category = kwargs.get('parent_category', None) - - if parent_category is None: - # First iteration, (no part_category kwargs passed) - parent_category = self.parent - - for child_part in self.parts.all(): - if delete_parts: - child_part.delete() - else: - child_part.category = parent_category - child_part.save() - - for child_category in self.children.all(): - if kwargs.get('delete_child_categories', False): - child_category.delete_recursive(**{ - "delete_child_categories": True, - "delete_parts": delete_parts, - "parent_category": parent_category}) - else: - child_category.parent = parent_category - child_category.save() - - super().delete(*args, **{}) - def delete(self, *args, **kwargs): """Custom model deletion routine, which updates any child categories or parts. This must be handled within a transaction.atomic(), otherwise the tree structure is damaged """ - with transaction.atomic(): - self.delete_recursive(**{ - "delete_parts": kwargs.get('delete_parts', False), - "delete_child_categories": kwargs.get('delete_child_categories', False), - "parent_category": self.parent}) - if self.parent is not None: - # Partially rebuild the tree (cheaper than a complete rebuild) - PartCategory.objects.partial_rebuild(self.tree_id) - else: - PartCategory.objects.rebuild() + super().delete( + delete_children=kwargs.get('delete_child_categories', False), + delete_items=kwargs.get('delete_parts', False), + ) default_location = TreeForeignKey( 'stock.StockLocation', related_name="default_categories", @@ -189,6 +157,10 @@ class PartCategory(MetadataMixin, InvenTreeTree): """Return the number of parts contained in this PartCategory""" return self.partcount() + def get_items(self, cascade=False): + """Return a queryset containing the parts which exist in this category""" + return self.get_parts(cascade=cascade) + def partcount(self, cascade=True, active=False): """Return the total part count under this category (including children of child categories).""" query = self.get_parts(cascade=cascade) diff --git a/InvenTree/part/test_bom_import.py b/InvenTree/part/test_bom_import.py index 7f77d4254a..de6c706993 100644 --- a/InvenTree/part/test_bom_import.py +++ b/InvenTree/part/test_bom_import.py @@ -237,9 +237,9 @@ class BomUploadTest(InvenTreeAPITestCase): components = Part.objects.filter(component=True) - for idx, _ in enumerate(components): + for component in components: dataset.append([ - f"Component {idx}", + component.name, 10, ]) @@ -266,9 +266,9 @@ class BomUploadTest(InvenTreeAPITestCase): dataset.headers = ['part_ipn', 'quantity'] - for idx, _ in enumerate(components): + for component in components: dataset.append([ - f"CMP_{idx}", + component.IPN, 10, ]) diff --git a/InvenTree/part/test_category.py b/InvenTree/part/test_category.py index 0762f40cf3..4e5400c2fd 100644 --- a/InvenTree/part/test_category.py +++ b/InvenTree/part/test_category.py @@ -248,6 +248,18 @@ class CategoryTest(TestCase): C32 = PartCategory.objects.create(name='C32', parent=B3) C33 = PartCategory.objects.create(name='C33', parent=B3) + D31 = PartCategory.objects.create(name='D31', parent=C31) + D32 = PartCategory.objects.create(name='D32', parent=C32) + D33 = PartCategory.objects.create(name='D33', parent=C33) + + E33 = PartCategory.objects.create(name='E33', parent=D33) + + # Check that pathstrings have been generated correctly + self.assertEqual(B3.pathstring, 'A/B3') + self.assertEqual(C11.pathstring, 'A/B1/C11') + self.assertEqual(C22.pathstring, 'A/B2/C22') + self.assertEqual(C33.pathstring, 'A/B3/C33') + # Check that the tree_id value is correct for cat in [B1, B2, B3, C11, C22, C33]: self.assertEqual(cat.tree_id, A.tree_id) @@ -289,6 +301,8 @@ class CategoryTest(TestCase): self.assertEqual(cat.get_ancestors().count(), 1) self.assertEqual(cat.get_ancestors()[0], A) + self.assertEqual(cat.pathstring, f'A/{cat.name}') + # Now, delete category A A.delete() @@ -302,6 +316,13 @@ class CategoryTest(TestCase): self.assertEqual(loc.level, 0) self.assertEqual(loc.parent, None) + # Pathstring should be the same as the name + self.assertEqual(loc.pathstring, loc.name) + + # Test pathstring for direct children + for child in loc.get_children(): + self.assertEqual(child.pathstring, f'{loc.name}/{child.name}') + # Check descendants for B1 descendants = B1.get_descendants() self.assertEqual(descendants.count(), 3) @@ -321,6 +342,8 @@ class CategoryTest(TestCase): self.assertEqual(ancestors[0], B1) self.assertEqual(ancestors[1], loc) + self.assertEqual(loc.pathstring, f'B1/{loc.name}') + # Check category C2x, should be B2 -> C2x for loc in [C21, C22, C23]: loc.refresh_from_db() @@ -332,3 +355,65 @@ class CategoryTest(TestCase): self.assertEqual(ancestors.count(), 2) self.assertEqual(ancestors[0], B2) self.assertEqual(ancestors[1], loc) + + self.assertEqual(loc.pathstring, f'B2/{loc.name}') + + # Check category D3x, should be C3x -> D3x + D31.refresh_from_db() + self.assertEqual(D31.pathstring, 'C31/D31') + D32.refresh_from_db() + self.assertEqual(D32.pathstring, 'C32/D32') + D33.refresh_from_db() + self.assertEqual(D33.pathstring, 'C33/D33') + + # Check category E33 + E33.refresh_from_db() + self.assertEqual(E33.pathstring, 'C33/D33/E33') + + # Change the name of an upper level + C33.name = '-C33-' + C33.save() + + D33.refresh_from_db() + self.assertEqual(D33.pathstring, '-C33-/D33') + + E33.refresh_from_db() + self.assertEqual(E33.pathstring, '-C33-/D33/E33') + + # Test the "delete child categories" functionality + C33.delete(delete_child_categories=True) + + # Any child underneath C33 should have been deleted + for cat in [D33, E33]: + with self.assertRaises(PartCategory.DoesNotExist): + cat.refresh_from_db() + + Part.objects.all().delete() + + # Create some sample parts under D32 + for ii in range(10): + Part.objects.create( + name=f'Part D32 {ii}', + description='A test part', + category=D32, + ) + + self.assertEqual(Part.objects.filter(category=D32).count(), 10) + self.assertEqual(Part.objects.filter(category=C32).count(), 0) + + # Delete D32, should move the parts up to C32 + D32.delete(delete_child_categories=False, delete_parts=False) + + # All parts should have been deleted + self.assertEqual(Part.objects.filter(category=C32).count(), 10) + + # Now, delete C32 and delete all parts underneath + C32.delete(delete_parts=True) + + # 10 parts should have been deleted from the database + self.assertEqual(Part.objects.count(), 0) + + # Finally, try deleting a category which has already been deleted + # should log an exception + with self.assertRaises(ValidationError): + B3.delete() diff --git a/InvenTree/stock/api.py b/InvenTree/stock/api.py index 52ac122f7c..4482f85130 100644 --- a/InvenTree/stock/api.py +++ b/InvenTree/stock/api.py @@ -1418,13 +1418,19 @@ class LocationDetail(CustomRetrieveUpdateDestroyAPI): def destroy(self, request, *args, **kwargs): """Delete a Stock location instance via the API""" - delete_stock_items = 'delete_stock_items' in request.data and request.data['delete_stock_items'] == '1' - delete_sub_locations = 'delete_sub_locations' in request.data and request.data['delete_sub_locations'] == '1' - return super().destroy(request, - *args, - **dict(kwargs, - delete_sub_locations=delete_sub_locations, - delete_stock_items=delete_stock_items)) + + delete_stock_items = str(request.data.get('delete_stock_items', 0)) == '1' + delete_sub_locations = str(request.data.get('delete_sub_locations', 0)) == '1' + + return super().destroy( + request, + *args, + **dict( + kwargs, + delete_sub_locations=delete_sub_locations, + delete_stock_items=delete_stock_items + ) + ) stock_api_urls = [ diff --git a/InvenTree/stock/models.py b/InvenTree/stock/models.py index 4b49db3937..71b7e24320 100644 --- a/InvenTree/stock/models.py +++ b/InvenTree/stock/models.py @@ -108,6 +108,8 @@ class StockLocation(InvenTreeBarcodeMixin, MetadataMixin, InvenTreeTree): Stock locations can be hierarchical as required """ + ITEM_PARENT_KEY = 'location' + objects = StockLocationManager() class Meta: @@ -118,51 +120,16 @@ class StockLocation(InvenTreeBarcodeMixin, MetadataMixin, InvenTreeTree): tags = TaggableManager(blank=True) - def delete_recursive(self, *args, **kwargs): - """This function handles the recursive deletion of sub-locations depending on kwargs contents""" - delete_stock_items = kwargs.get('delete_stock_items', False) - parent_location = kwargs.get('parent_location', None) - - if parent_location is None: - # First iteration, (no parent_location kwargs passed) - parent_location = self.parent - - for child_item in self.get_stock_items(False): - if delete_stock_items: - child_item.delete() - else: - child_item.location = parent_location - child_item.save() - - for child_location in self.children.all(): - if kwargs.get('delete_sub_locations', False): - child_location.delete_recursive(**{ - "delete_sub_locations": True, - "delete_stock_items": delete_stock_items, - "parent_location": parent_location}) - else: - child_location.parent = parent_location - child_location.save() - - super().delete(*args, **{}) - def delete(self, *args, **kwargs): """Custom model deletion routine, which updates any child locations or items. This must be handled within a transaction.atomic(), otherwise the tree structure is damaged """ - with transaction.atomic(): - self.delete_recursive(**{ - "delete_stock_items": kwargs.get('delete_stock_items', False), - "delete_sub_locations": kwargs.get('delete_sub_locations', False), - "parent_category": self.parent}) - - if self.parent is not None: - # Partially rebuild the tree (cheaper than a complete rebuild) - StockLocation.objects.partial_rebuild(self.tree_id) - else: - StockLocation.objects.rebuild() + super().delete( + delete_children=kwargs.get('delete_sub_locations', False), + delete_items=kwargs.get('delete_stock_items', False), + ) @staticmethod def get_api_url(): @@ -300,6 +267,10 @@ class StockLocation(InvenTreeBarcodeMixin, MetadataMixin, InvenTreeTree): """ return self.stock_item_count() + def get_items(self, cascade=False): + """Return a queryset for all stock items under this category""" + return self.get_stock_items(cascade=cascade) + def generate_batch_code(): """Generate a default 'batch code' for a new StockItem.