From 22218fd5c605cd1099755f1eb49e6d34736e0623 Mon Sep 17 00:00:00 2001 From: Oliver Date: Sun, 13 Jul 2025 21:27:32 +1000 Subject: [PATCH] [bug] Tree fix (#9979) * Refactor InvenTreeTree model structure - Allow for tree with null items * Refactor pathstring * Factor pathstring out into a separate mixin - Keep tree operations separate (in InvenTreeTree) - Pathstring operations are now in PathStringMixin * throw error - Ensure that this func gets removed in future commit * Fix node delete code * Migrate "Build" model to new structure * Add unit tests for Build tree structure * Refactor StockLocationTreeTest * Implement tree rebuild test for StockItem model * Add unit test for stock item serialization * Refactor "Part" model to use mixin * Add unit tests for part variant tree * Add test for node deletion * Adjust unit tests * Ensure items are not created with null tree_id * Further unit tests and updates * Fix unit tests * Remove duplicate check * Adjust build fixture * Remove rebuild call * Fixing more tests * Remove calls to rebuild part tree * Add test for tree fixtures * Report tree rebuild errors to sentry * Remove helper func * Updates for splitStock * Cleaner inheritance * Simpilfy test - tree_id is somewhat ephemeral * Handle null parent * Enforce partial rebuild if parent changes * Fix * Remove hard-coded "parent" references * Fix order of delete operations * Fix unit test * Unit test tweaks * Improved handling for deleting a root node * Only set tree_id if not already specified * Only rebuild valid tree_id values * Cast to list * Adjust unit test - Test values were wrong, due to bad data in fixtures * Do not bulk delete - mysql no likey * Enhanced rebuild logic * Fix for unit test * Improve logic for _create_serial_numbers * Unit test fix * Remove unused function --- src/backend/InvenTree/InvenTree/helpers.py | 2 +- src/backend/InvenTree/InvenTree/models.py | 441 +++++++++----- src/backend/InvenTree/InvenTree/sentry.py | 12 +- src/backend/InvenTree/InvenTree/tests.py | 73 ++- .../InvenTree/build/fixtures/build.yaml | 30 +- src/backend/InvenTree/build/models.py | 19 +- src/backend/InvenTree/build/test_api.py | 18 +- src/backend/InvenTree/build/tests.py | 163 ++++++ src/backend/InvenTree/order/test_api.py | 23 +- .../InvenTree/part/fixtures/category.yaml | 34 +- src/backend/InvenTree/part/fixtures/part.yaml | 95 ++- src/backend/InvenTree/part/models.py | 21 +- src/backend/InvenTree/part/test_api.py | 26 +- src/backend/InvenTree/part/test_bom_item.py | 2 - src/backend/InvenTree/part/test_category.py | 51 +- src/backend/InvenTree/part/test_part.py | 190 +++++- .../integration/test_validation_sample.py | 17 +- .../InvenTree/stock/fixtures/location.yaml | 8 +- .../InvenTree/stock/fixtures/stock.yaml | 182 +++--- src/backend/InvenTree/stock/models.py | 86 +-- src/backend/InvenTree/stock/serializers.py | 2 +- src/backend/InvenTree/stock/tasks.py | 52 +- src/backend/InvenTree/stock/test_api.py | 26 +- src/backend/InvenTree/stock/tests.py | 540 +++++++++++------- 24 files changed, 1447 insertions(+), 666 deletions(-) diff --git a/src/backend/InvenTree/InvenTree/helpers.py b/src/backend/InvenTree/InvenTree/helpers.py index 2f1080c212..03cfbcb794 100644 --- a/src/backend/InvenTree/InvenTree/helpers.py +++ b/src/backend/InvenTree/InvenTree/helpers.py @@ -154,7 +154,7 @@ def generateTestKey(test_name: str) -> str: return key -def constructPathString(path, max_chars=250): +def constructPathString(path: list[str], max_chars: int = 250) -> str: """Construct a 'path string' for the given path. Arguments: diff --git a/src/backend/InvenTree/InvenTree/models.py b/src/backend/InvenTree/InvenTree/models.py index 6297216f8c..554ba1062d 100644 --- a/src/backend/InvenTree/InvenTree/models.py +++ b/src/backend/InvenTree/InvenTree/models.py @@ -4,7 +4,6 @@ from datetime import datetime from string import Formatter from django.contrib.auth import get_user_model -from django.contrib.contenttypes.models import ContentType from django.core.exceptions import ValidationError from django.db import models from django.db.models import QuerySet @@ -26,6 +25,7 @@ import InvenTree.fields import InvenTree.format import InvenTree.helpers import InvenTree.helpers_model +import InvenTree.sentry logger = structlog.get_logger('inventree') @@ -123,7 +123,7 @@ class PluginValidationMixin(DiffMixin): self.run_plugin_validation() super().save(*args, **kwargs) - def delete(self): + def delete(self, *args, **kwargs): """Run plugin validation on model delete. Allows plugins to prevent model instances from being deleted. @@ -143,7 +143,7 @@ class PluginValidationMixin(DiffMixin): log_error('validate_model_deletion', plugin=plugin.slug) continue - super().delete() + super().delete(*args, **kwargs) class MetadataMixin(models.Model): @@ -474,13 +474,13 @@ class InvenTreeAttachmentMixin: - attachments: Return a queryset containing all attachments for this model """ - def delete(self): + def delete(self, *args, **kwargs): """Handle the deletion of a model instance. Before deleting the model instance, delete any associated attachments. """ self.attachments.all().delete() - super().delete() + super().delete(*args, **kwargs) @property def attachments(self): @@ -525,44 +525,51 @@ class InvenTreeAttachmentMixin: Attachment.objects.create(**kwargs) -class InvenTreeTree(MetadataMixin, PluginValidationMixin, MPTTModel): - """Provides an abstracted self-referencing tree model for data categories. +class InvenTreeTree(MPTTModel): + """Provides an abstracted self-referencing tree model, based on the MPTTModel class. - - Each Category has one parent Category, which can be blank (for a top-level Category). - - Each Category can have zero-or-more child Category(y/ies) + Our implementation provides the following key improvements: - Attributes: - name: brief name - description: longer form description - parent: The item immediately above this one. An item with a null parent is a top-level item + - Allow tracking of separate concepts of "nodes" and "items" + - Better handling of deletion of nodes and items + - Ensure tree is correctly rebuilt after deletion and other operations + - Improved protection against recursive tree structures """ + # How each node reference its parent object + NODE_PARENT_KEY = 'parent' + # How items (not nodes) are hooked into the tree # e.g. for StockLocation, this value is 'location' ITEM_PARENT_KEY = None - # Extra fields to include in the get_path result. E.g. icon - EXTRA_PATH_FIELDS = [] - class Meta: """Metaclass defines extra model properties.""" abstract = True class MPTTMeta: - """Set insert order.""" + """MPTT metaclass options.""" order_insertion_by = ['name'] - def delete(self, delete_children=False, delete_items=False): + def delete(self, *args, **kwargs): """Handle the deletion of a tree node. - 1. Update nodes and items under the current node - 2. Delete this node - 3. Rebuild the model tree - 4. Rebuild the path for any remaining lower nodes + kwargs: + delete_children: If True, delete all child nodes (otherwise, point to the parent of this node) + delete_items: If True, delete all items associated with this node (otherwise, point to the parent of this node) + + Order of operations: + 1. Update nodes and items under the current node + 2. Delete this node + 3. Rebuild the model tree """ - tree_id = self.tree_id if self.parent else None + delete_children = kwargs.pop('delete_children', False) + delete_items = kwargs.pop('delete_items', False) + + tree_id = self.tree_id + parent = getattr(self, self.NODE_PARENT_KEY, None) # Ensure that we have the latest version of the database object try: @@ -573,10 +580,20 @@ class InvenTreeTree(MetadataMixin, PluginValidationMixin, MPTTModel): 'Object %s of type %s no longer exists', str(self), str(self.__class__) ) - # Cache node ID values for lower nodes, before we delete this one - lower_nodes = list( - self.get_descendants(include_self=False).values_list('pk', flat=True) - ) + # When deleting a top level node with multiple children, + # we need to assign a new tree_id to each child node + # otherwise they will all have the same tree_id (which is not allowed) + lower_trees = [] + + if not parent: # No parent, which means this is a top-level node + for child in self.get_children(): + # Store a flattened list of node IDs for each of the lower trees + nodes = list( + child.get_descendants(include_self=True) + .values_list('pk', flat=True) + .distinct() + ) + lower_trees.append(nodes) # 1. Update nodes and items under the current node self.handle_tree_delete( @@ -584,49 +601,48 @@ class InvenTreeTree(MetadataMixin, PluginValidationMixin, MPTTModel): ) # 2. Delete *this* node - super().delete() + super().delete(*args, **kwargs) + + # A set of tree_id values which need to be rebuilt + trees = set() - # 3. Update the tree structure if tree_id: - try: - self.__class__.objects.partial_rebuild(tree_id) - except Exception: - InvenTree.exceptions.log_error( - f'{self.__class__.__name__}.partial_rebuild' - ) - logger.warning( - 'Failed to rebuild tree for %s <%s>', - self.__class__.__name__, - self.pk, - ) - # If the partial rebuild fails, rebuild the entire tree - self.__class__.objects.rebuild() - else: + # If this node had a tree_id, we need to rebuild that tree + trees.add(tree_id) + + # Did we delete a top-level node? + next_tree_id = self.getNextTreeID() + + # If there is only one sub-tree, it can retain the same tree_id value + for tree in lower_trees[1:]: + # Bulk update the tree_id for all lower nodes + lower_nodes = self.__class__.objects.filter(pk__in=tree) + lower_nodes.update(tree_id=next_tree_id) + trees.add(next_tree_id) + next_tree_id += 1 + + # 3. Rebuild the model tree(s) as required + # - If any partial rebuilds fail, we will rebuild the entire tree + + result = True + + for tree_id in trees: + if tree_id: + if not self.partial_rebuild(tree_id): + result = False + + if not result: + # Rebuild the entire tree (expensive!!!) self.__class__.objects.rebuild() - # 4. Rebuild the path for any remaining lower nodes - nodes = self.__class__.objects.filter(pk__in=lower_nodes) - - nodes_to_update = [] - - for node in nodes: - new_path = node.construct_pathstring() - - if new_path != node.pathstring: - node.pathstring = new_path - nodes_to_update.append(node) - - if len(nodes_to_update) > 0: - self.__class__.objects.bulk_update(nodes_to_update, ['pathstring']) - def handle_tree_delete(self, delete_children=False, delete_items=False): """Delete a single instance of the tree, based on provided kwargs. Removing a tree "node" from the database must be considered carefully, based on what the user intends for any items which exist *under* that node. - - "children" are any nodes which exist *under* this node (e.g. PartCategory) - - "items" are any items which exist *under* this node (e.g. Part) + - "children" are any nodes (of the same type) which exist *under* this node (e.g. PartCategory) + - "items" are any items (of a different type) which exist *under* this node (e.g. Part) Arguments: delete_children: If True, delete all child items @@ -645,30 +661,34 @@ class InvenTreeTree(MetadataMixin, PluginValidationMixin, MPTTModel): # - Delete all items at any lower level # - Delete all descendant nodes if delete_children and delete_items: - self.get_items(cascade=True).delete() + self.delete_items(cascade=True) self.delete_nodes(child_nodes) # Case B: Delete all child nodes, but move all child items up to the parent # - Move all items at any lower level to the parent of this item # - Delete all descendant nodes elif delete_children and not delete_items: - self.get_items(cascade=True).update(**{self.ITEM_PARENT_KEY: self.parent}) - + if items := self.get_items(cascade=True): + parent = getattr(self, self.NODE_PARENT_KEY, None) + items.update(**{self.ITEM_PARENT_KEY: parent}) self.delete_nodes(child_nodes) # Case C: Delete all child items, but keep all child nodes # - Remove all items directly associated with this node # - Move any direct child nodes up one level elif not delete_children and delete_items: - self.get_items(cascade=False).delete() - self.get_children().update(parent=self.parent) + self.delete_items(cascade=False) + parent = getattr(self, self.NODE_PARENT_KEY, None) + self.get_children().update(**{self.NODE_PARENT_KEY: parent}) # Case D: Keep all child items, and keep all child nodes # - Move all items directly associated with this node up one level # - Move any direct child nodes up one level elif not delete_children and not delete_items: - self.get_items(cascade=False).update(**{self.ITEM_PARENT_KEY: self.parent}) - self.get_children().update(parent=self.parent) + parent = getattr(self, self.NODE_PARENT_KEY, None) + if items := self.get_items(cascade=False): + items.update(**{self.ITEM_PARENT_KEY: parent}) + self.get_children().update(**{self.NODE_PARENT_KEY: parent}) def delete_nodes(self, nodes): """Delete a set of nodes from the tree. @@ -681,67 +701,144 @@ class InvenTreeTree(MetadataMixin, PluginValidationMixin, MPTTModel): Arguments: nodes: A queryset of nodes to delete """ - nodes.update(parent=None) + nodes.update(**{self.NODE_PARENT_KEY: None}) nodes.delete() - def validate_unique(self, exclude=None): - """Validate that this tree instance satisfies our uniqueness requirements. - - Note that a 'unique_together' requirement for ('name', 'parent') is insufficient, - as it ignores cases where parent=None (i.e. top-level items) - """ - super().validate_unique(exclude) - - results = self.__class__.objects.filter( - name=self.name, parent=self.parent - ).exclude(pk=self.pk) - - if results.exists(): - raise ValidationError({ - 'name': _('Duplicate names cannot exist under the same parent') - }) - def api_instance_filters(self): """Instance filters for InvenTreeTree models.""" - return {'parent': {'exclude_tree': self.pk}} - - def construct_pathstring(self): - """Construct the pathstring for this tree node.""" - return InvenTree.helpers.constructPathString([item.name for item in self.path]) + return {self.NODE_PARENT_KEY: {'exclude_tree': self.pk}} def save(self, *args, **kwargs): """Custom save method for InvenTreeTree abstract model.""" + db_instance = None + + parent = getattr(self, self.NODE_PARENT_KEY, None) + + if not self.tree_id: + if parent: + # If we have a parent, use the parent's tree_id + self.tree_id = parent.tree_id + else: + # Otherwise, we need to generate a new tree_id + self.tree_id = self.getNextTreeID() + + if self.pk: + try: + db_instance = self.get_db_instance() + except self.__class__.DoesNotExist: + # If the instance does not exist, we cannot get the db instance + db_instance = None try: super().save(*args, **kwargs) except InvalidMove: # Provide better error for parent selection - raise ValidationError({'parent': _('Invalid choice')}) + raise ValidationError({self.NODE_PARENT_KEY: _('Invalid choice')}) - # Re-calculate the 'pathstring' field - pathstring = self.construct_pathstring() + trees = set() - if pathstring != self.pathstring: - kwargs.pop('force_insert', None) + if db_instance: + # If the tree_id or parent has changed, we need to rebuild the tree + if getattr(db_instance, self.NODE_PARENT_KEY) != getattr( + self, self.NODE_PARENT_KEY + ): + trees.add(db_instance.tree_id) + if db_instance.tree_id != self.tree_id: + trees.add(self.tree_id) + trees.add(db_instance.tree_id) + else: + # New instance, so we need to rebuild the tree + trees.add(self.tree_id) - kwargs['force_update'] = True + for tree_id in trees: + if tree_id: + self.partial_rebuild(tree_id) - self.pathstring = pathstring - super().save(*args, **kwargs) + def partial_rebuild(self, tree_id: int) -> bool: + """Perform a partial rebuild of the tree structure. - # Update the pathstring for any child nodes - lower_nodes = self.get_descendants(include_self=False) + If a failure occurs, log the error and return False. + """ + try: + self.__class__.objects.partial_rebuild(tree_id) + return True + except Exception as e: + # This is a critical error, explicitly report to sentry + InvenTree.sentry.report_exception(e) - nodes_to_update = [] + InvenTree.exceptions.log_error(f'{self.__class__.__name__}.partial_rebuild') + logger.exception( + 'Failed to rebuild tree for %s <%s>: %s', + self.__class__.__name__, + self.pk, + e, + ) + return False - for node in lower_nodes: - new_path = node.construct_pathstring() + def delete_items(self, cascade: bool = False): + """Delete any 'items' which exist under this node in the tree. - if new_path != node.pathstring: - node.pathstring = new_path - nodes_to_update.append(node) + - Note that an 'item' is an instance of a different model class. + - Not all tree structures will have items associated with them. + """ + if items := self.get_items(cascade=cascade): + items.delete() - if len(nodes_to_update) > 0: - self.__class__.objects.bulk_update(nodes_to_update, ['pathstring']) + def get_items(self, cascade: bool = False): + """Return a queryset of items which exist *under* this node in the tree. + + - For a StockLocation instance, this would be a queryset of StockItem objects + - For a PartCategory instance, this would be a queryset of Part objects + + The default implementation returns None, indicating that no items exist under this node. + """ + return None + + def getUniqueParents(self) -> QuerySet: + """Return a flat set of all parent items that exist above this node.""" + return self.get_ancestors() + + def getUniqueChildren(self, include_self=True) -> QuerySet: + """Return a flat set of all child items that exist under this node.""" + return self.get_descendants(include_self=include_self) + + @property + def has_children(self) -> bool: + """True if there are any children under this item.""" + return self.getUniqueChildren(include_self=False).count() > 0 + + @classmethod + def getNextTreeID(cls) -> int: + """Return the next available tree_id for this model class.""" + instance = cls.objects.order_by('-tree_id').first() + + if instance: + return instance.tree_id + 1 + else: + return 1 + + +class PathStringMixin(models.Model): + """Mixin class for adding a 'pathstring' field to a model class. + + The pathstring is a string representation of the path to this model instance, + which can be used for display purposes. + + The pathstring is automatically generated when the model instance is saved. + """ + + # Field to use for constructing a "pathstring" for the tree + PATH_FIELD = 'name' + + # Extra fields to include in the get_path result. E.g. icon + EXTRA_PATH_FIELDS = [] + + class Meta: + """Metaclass options for this mixin. + + Note: abstract must be true, as this is only a mixin, not a separate table + """ + + abstract = True name = models.CharField( blank=False, max_length=100, verbose_name=_('Name'), help_text=_('Name') @@ -769,48 +866,110 @@ class InvenTreeTree(MetadataMixin, PluginValidationMixin, MPTTModel): blank=True, max_length=250, verbose_name=_('Path'), help_text=_('Path') ) - def get_items(self, cascade=False): - """Return a queryset of items which exist *under* this node in the tree. + def save(self, *args, **kwargs): + """Update the pathstring field when saving the model instance.""" + old_pathstring = self.pathstring - - For a StockLocation instance, this would be a queryset of StockItem objects - - For a PartCategory instance, this would be a queryset of Part objects + # Rebuild upper first, to ensure the lower nodes are updated correctly + super().save(*args, **kwargs) - The default implementation returns an empty list + # Ensure that the pathstring is correctly constructed + pathstring = self.construct_pathstring(refresh=True) + + if pathstring != old_pathstring: + kwargs.pop('force_insert', None) + kwargs['force_update'] = True + + self.pathstring = pathstring + super().save(*args, **kwargs) + + # Bulk-update any child nodes, if applicable + lower_nodes = list( + self.get_descendants(include_self=False).values_list('pk', flat=True) + ) + + self.rebuild_lower_nodes(lower_nodes) + + def delete(self, *args, **kwargs): + """Custom delete method for PathStringMixin. + + - Before deleting the object, update the pathstring for any child nodes. + - Then, delete the object. """ - raise NotImplementedError(f'items() method not implemented for {type(self)}') + # Ensure that we have the latest version of the database object + try: + self.refresh_from_db() + except self.__class__.DoesNotExist: + # If the object no longer exists, raise a ValidationError + raise ValidationError( + 'Object %s of type %s no longer exists', str(self), str(self.__class__) + ) - def getUniqueParents(self) -> QuerySet: - """Return a flat set of all parent items that exist above this node.""" - return self.get_ancestors() + # Store the node ID values for lower nodes, before we delete this one + lower_nodes = list( + self.get_descendants(include_self=False).values_list('pk', flat=True) + ) - def getUniqueChildren(self, include_self=True) -> QuerySet: - """Return a flat set of all child items that exist under this node.""" - return self.get_descendants(include_self=include_self) + # Delete this node - after which we expect the tree structure will be updated + super().delete(*args, **kwargs) - @property - def has_children(self) -> bool: - """True if there are any children under this item.""" - return self.getUniqueChildren(include_self=False).count() > 0 + # Rebuild the pathstring for lower nodes + self.rebuild_lower_nodes(lower_nodes) - def getAcceptableParents(self) -> list: - """Returns a list of acceptable parent items within this model Acceptable parents are ones which are not underneath this item. + def __str__(self): + """String representation of a category is the full path to that category.""" + return f'{self.pathstring} - {self.description}' - Setting the parent of an item to its own child results in recursion. + def rebuild_lower_nodes(self, lower_nodes: list[int]): + """Rebuild the pathstring for lower nodes in the tree. + + - This is used when the pathstring for this node is updated, and we need to update all lower nodes. + - We use a bulk-update to update the pathstring for all lower nodes in the tree. """ - contents = ContentType.objects.get_for_model(type(self)) + nodes = self.__class__.objects.filter(pk__in=lower_nodes) - available = contents.get_all_objects_for_this_type() + nodes_to_update = [] - # List of child IDs - children = self.getUniqueChildren() + for node in nodes: + new_path = node.construct_pathstring() - acceptable = [None] + if new_path != node.pathstring: + node.pathstring = new_path + nodes_to_update.append(node) - for a in available: - if a.id not in children: - acceptable.append(a) + if len(nodes_to_update) > 0: + self.__class__.objects.bulk_update(nodes_to_update, ['pathstring']) - return acceptable + def construct_pathstring(self, refresh: bool = False) -> str: + """Construct the pathstring for this tree node. + + Arguments: + refresh: If True, force a refresh of the model instance + """ + if refresh: + # Refresh the model instance from the database + self.refresh_from_db() + + return InvenTree.helpers.constructPathString([ + getattr(item, self.PATH_FIELD, item.pk) for item in self.path + ]) + + def validate_unique(self, exclude=None): + """Validate that this tree instance satisfies our uniqueness requirements. + + Note that a 'unique_together' requirement for ('name', 'parent') is insufficient, + as it ignores cases where parent=None (i.e. top-level items) + """ + super().validate_unique(exclude) + + results = self.__class__.objects.filter( + name=self.name, parent=self.parent + ).exclude(pk=self.pk) + + if results.exists(): + raise ValidationError( + _('Duplicate names cannot exist under the same parent') + ) @property def parentpath(self) -> list: @@ -845,16 +1004,12 @@ class InvenTreeTree(MetadataMixin, PluginValidationMixin, MPTTModel): return [ { 'pk': item.pk, - 'name': item.name, + 'name': getattr(item, self.PATH_FIELD, item.pk), **{k: getattr(item, k, None) for k in self.EXTRA_PATH_FIELDS}, } for item in self.path ] - def __str__(self): - """String representation of a category is the full path to that category.""" - return f'{self.pathstring} - {self.description}' - class InvenTreeNotesMixin(models.Model): """A mixin class for adding notes functionality to a model class. @@ -872,7 +1027,7 @@ class InvenTreeNotesMixin(models.Model): abstract = True - def delete(self): + def delete(self, *args, **kwargs): """Custom delete method for InvenTreeNotesMixin. - Before deleting the object, check if there are any uploaded images associated with it. @@ -894,7 +1049,7 @@ class InvenTreeNotesMixin(models.Model): images.delete() - super().delete() + super().delete(*args, **kwargs) notes = InvenTree.fields.InvenTreeNotesField( verbose_name=_('Notes'), help_text=_('Markdown notes (optional)') diff --git a/src/backend/InvenTree/InvenTree/sentry.py b/src/backend/InvenTree/InvenTree/sentry.py index 3e8f577350..a050fbf71e 100644 --- a/src/backend/InvenTree/InvenTree/sentry.py +++ b/src/backend/InvenTree/InvenTree/sentry.py @@ -1,5 +1,7 @@ """Configuration for Sentry.io error reporting.""" +from typing import Optional + from django.conf import settings from django.core.exceptions import ValidationError from django.http import Http404 @@ -64,17 +66,17 @@ def init_sentry(dsn, sample_rate, tags): sentry_sdk.set_tag('git_date', InvenTree.version.inventreeCommitDate()) -def report_exception(exc): +def report_exception(exc, scope: Optional[dict] = None): """Report an exception to sentry.io.""" - if settings.TESTING: - # Skip reporting exceptions in testing mode - return + assert settings.TESTING == False, ( + 'report_exception should not be called in testing mode' + ) if settings.SENTRY_ENABLED and settings.SENTRY_DSN: if not any(isinstance(exc, e) for e in sentry_ignore_errors()): logger.info('Reporting exception to sentry.io: %s', exc) try: - sentry_sdk.capture_exception(exc) + sentry_sdk.capture_exception(exc, scope=scope) except Exception: logger.warning('Failed to report exception to sentry.io') diff --git a/src/backend/InvenTree/InvenTree/tests.py b/src/backend/InvenTree/InvenTree/tests.py index c4a8c2a3e2..9b8f47dd96 100644 --- a/src/backend/InvenTree/InvenTree/tests.py +++ b/src/backend/InvenTree/InvenTree/tests.py @@ -43,6 +43,73 @@ from .tasks import offload_task from .validators import validate_overage +class TreeFixtureTest(TestCase): + """Unit testing for our MPTT fixture data.""" + + fixtures = ['location', 'category', 'part', 'stock', 'build'] + + def node_string(self, node): + """Construct a string representation of a tree node.""" + return ':'.join([ + str(getattr(node, attr, None)) + for attr in ['parent', 'level', 'lft', 'rght'] + ]) + + def run_tree_test(self, model): + """Run MPTT test for a given model type. + + The intent here is to check that the MPTT tree structure + does not change after rebuilding the tree. + + This ensures that the fixutre data is consistent. + """ + nodes = {} + + for instance in model.objects.all(): + nodes[instance.pk] = self.node_string(instance) + + # Rebuild the tree structure + model.objects.rebuild() + + faults = [] + + # Check that no nodes have changed + for instance in model.objects.all().order_by('pk'): + ns = self.node_string(instance) + if ns != nodes[instance.pk]: + faults.append( + f'Node {instance.pk} changed: {nodes[instance.pk]} -> {ns}' + ) + + if len(faults) > 0: + print(f'!!! Fixture data changed for: {model.__name__} !!!') + + for f in faults: + print('-', f) + + assert len(faults) == 0 + + def test_part(self): + """Test MPTT tree structure for Part model.""" + from part.models import Part, PartCategory + + self.run_tree_test(Part) + self.run_tree_test(PartCategory) + + def test_build(self): + """Test MPTT tree structure for Build model.""" + from build.models import Build + + self.run_tree_test(Build) + + def test_stock(self): + """Test MPTT tree structure for Stock model.""" + from stock.models import StockItem, StockLocation + + self.run_tree_test(StockItem) + self.run_tree_test(StockLocation) + + class HostTest(InvenTreeTestCase): """Test for host configuration.""" @@ -811,12 +878,6 @@ class TestMPTT(TestCase): fixtures = ['location'] - @classmethod - def setUpTestData(cls): - """Setup for all tests.""" - super().setUpTestData() - StockLocation.objects.rebuild() - def test_self_as_parent(self): """Test that we cannot set self as parent.""" loc = StockLocation.objects.get(pk=4) diff --git a/src/backend/InvenTree/build/fixtures/build.yaml b/src/backend/InvenTree/build/fixtures/build.yaml index 186faa9c5a..a84a043d12 100644 --- a/src/backend/InvenTree/build/fixtures/build.yaml +++ b/src/backend/InvenTree/build/fixtures/build.yaml @@ -12,10 +12,10 @@ status: 10 # PENDING creation_date: '2019-03-16' link: http://www.google.com + tree_id: 1 level: 0 - lft: 0 - rght: 0 - tree_id: 0 + lft: 1 + rght: 2 - model: build.build pk: 2 @@ -28,10 +28,10 @@ quantity: 21 notes: 'Some more simple notes' creation_date: '2019-03-16' + tree_id: 2 level: 0 - lft: 0 - rght: 0 - tree_id: 1 + lft: 1 + rght: 2 - model: build.build pk: 3 @@ -44,10 +44,10 @@ quantity: 21 notes: 'Some even more simple notes' creation_date: '2019-03-16' + tree_id: 4 level: 0 - lft: 0 - rght: 0 - tree_id: 1 + lft: 1 + rght: 2 - model: build.build pk: 4 @@ -60,10 +60,10 @@ quantity: 21 notes: 'Some even even more simple notes' creation_date: '2019-03-16' + tree_id: 5 level: 0 - lft: 0 - rght: 0 - tree_id: 1 + lft: 1 + rght: 2 - model: build.build pk: 5 @@ -76,7 +76,7 @@ quantity: 10 creation_date: '2019-03-16' notes: "A thing" + tree_id: 3 level: 0 - lft: 0 - rght: 0 - tree_id: 1 + lft: 1 + rght: 2 diff --git a/src/backend/InvenTree/build/models.py b/src/backend/InvenTree/build/models.py index 744a722e14..43411f2e9d 100644 --- a/src/backend/InvenTree/build/models.py +++ b/src/backend/InvenTree/build/models.py @@ -14,8 +14,7 @@ from django.urls import reverse from django.utils.translation import gettext_lazy as _ import structlog -from mptt.exceptions import InvalidMove -from mptt.models import MPTTModel, TreeForeignKey +from mptt.models import TreeForeignKey from rest_framework import serializers import generic.states @@ -74,16 +73,16 @@ class BuildReportContext(report.mixins.BaseReportContext): class Build( + InvenTree.models.PluginValidationMixin, report.mixins.InvenTreeReportMixin, InvenTree.models.InvenTreeAttachmentMixin, InvenTree.models.InvenTreeBarcodeMixin, InvenTree.models.InvenTreeNotesMixin, - InvenTree.models.MetadataMixin, - InvenTree.models.PluginValidationMixin, InvenTree.models.ReferenceIndexingMixin, StateTransitionMixin, StatusCodeMixin, - MPTTModel, + InvenTree.models.MetadataMixin, + InvenTree.models.InvenTreeTree, ): """A Build object organises the creation of new StockItem objects from other existing StockItem objects. @@ -117,6 +116,11 @@ class Build( verbose_name = _('Build Order') verbose_name_plural = _('Build Orders') + class MPTTMeta: + """MPTT options for the BuildOrder model.""" + + order_insertion_by = ['reference'] + OVERDUE_FILTER = ( Q(status__in=BuildStatusGroups.ACTIVE_CODES) & ~Q(target_date=None) @@ -183,10 +187,7 @@ class Build( if not self.destination: self.destination = self.part.get_default_location() - try: - super().save(*args, **kwargs) - except InvalidMove: - raise ValidationError({'parent': _('Invalid choice for parent build')}) + super().save(*args, **kwargs) def clean(self): """Validate the BuildOrder model.""" diff --git a/src/backend/InvenTree/build/test_api.py b/src/backend/InvenTree/build/test_api.py index 7cbb5c65ed..3a19aa6623 100644 --- a/src/backend/InvenTree/build/test_api.py +++ b/src/backend/InvenTree/build/test_api.py @@ -1079,7 +1079,6 @@ class BuildListTest(BuildAPITest): for ii, sub_build in enumerate(Build.objects.filter(parent=parent)): for i in range(3): x = ii * 10 + i + 50 - Build.objects.create( part=part, reference=f'BO-{x}', @@ -1091,7 +1090,22 @@ class BuildListTest(BuildAPITest): # 20 new builds should have been created! self.assertEqual(Build.objects.count(), (n + 20)) - Build.objects.rebuild() + parent.refresh_from_db() + + # There should be 5 sub-builds + self.assertEqual(parent.get_children().count(), 5) + + # Check tree structure for direct children + for sub_build in parent.get_children(): + self.assertEqual(sub_build.parent, parent) + self.assertLess(sub_build.rght, parent.rght) + self.assertGreater(sub_build.lft, parent.lft) + self.assertEqual(sub_build.level, parent.level + 1) + self.assertEqual(sub_build.tree_id, parent.tree_id) + self.assertEqual(sub_build.get_children().count(), 3) + + # And a total of 20 descendants + self.assertEqual(parent.get_descendants().count(), 20) # Search by parent response = self.get(self.url, data={'parent': parent.pk}) diff --git a/src/backend/InvenTree/build/tests.py b/src/backend/InvenTree/build/tests.py index ae94855a8a..9856927ad2 100644 --- a/src/backend/InvenTree/build/tests.py +++ b/src/backend/InvenTree/build/tests.py @@ -144,3 +144,166 @@ class BuildTestSimple(InvenTreeTestCase): # Check that expected quantity of new builds is created self.assertEqual(Build.objects.count(), n + 4) + + +class BuildTreeTest(InvenTreeTestCase): + """Unit tests for the Build tree structure.""" + + @classmethod + def setUpTestData(cls): + """Initialize test data for the Build tree tests.""" + from build.models import Build + from part.models import Part + + # Create a test assembly part + cls.assembly = Part.objects.create( + name='Test Assembly', + description='A test assembly part', + assembly=True, + active=True, + locked=False, + ) + + # Generate a top-level build + cls.build = Build.objects.create( + part=cls.assembly, quantity=5, reference='BO-1234', target_date=None + ) + + def test_basic_tree(self): + """Test basic tree structure functionality. + + - In this test we test a simple non-branching tree structure. + - Check that the tree structure is correctly created. + - Verify parent-child relationships and tree properties. + - Ensure that the number of children and descendants is as expected. + - Validate that the tree properties (tree_id, level, lft, rght) are correct + - Check that node deletion works correctly. + """ + from build.models import Build + + # Create a cascading tree structure of builds + child = self.build + + builds = [self.build] + + self.assertEqual(Build.objects.count(), 1) + + for i in range(10): + child = Build.objects.create( + part=self.assembly, quantity=2, reference=f'BO-{1235 + i}', parent=child + ) + + builds.append(child) + + self.assertEqual(Build.objects.count(), 11) + + # Test the tree structure for each node + for idx, child in enumerate(builds): + # Check parent-child relationships + expected_parent = builds[idx - 1] if idx > 0 else None + self.assertEqual(child.parent, expected_parent) + + # Check number of children + expected_children = 0 if idx == 10 else 1 + self.assertEqual(child.get_children().count(), expected_children) + + # Check number of descendants + expected_descendants = max(10 - idx, 0) + self.assertEqual( + child.get_descendants(include_self=False).count(), expected_descendants + ) + + # Test tree structure + self.assertEqual(child.tree_id, self.build.tree_id) + self.assertEqual(child.level, idx) + self.assertEqual(child.lft, idx + 1) + self.assertEqual(child.rght, 22 - idx) + + # Test deletion of a node - delete BO-1238 + Build.objects.get(reference='BO-1238').delete() + + # We expect that only a SINGLE node is deleted + self.assertEqual(Build.objects.count(), 10) + self.assertEqual(self.build.get_descendants(include_self=False).count(), 9) + + # Check that the item parents have been correctly remapped + build_reference_map = { + 'BO-1235': 'BO-1234', + 'BO-1236': 'BO-1235', + 'BO-1237': 'BO-1236', + 'BO-1239': 'BO-1237', # BO-1238 was deleted, so BO-1239's parent is now BO-1237 + 'BO-1240': 'BO-1239', + 'BO-1241': 'BO-1240', + 'BO-1242': 'BO-1241', + 'BO-1243': 'BO-1242', + 'BO-1244': 'BO-1243', + } + + # Check that the tree structure is still valid + for child_ref, parent_ref in build_reference_map.items(): + build = Build.objects.get(reference=child_ref) + parent = Build.objects.get(reference=parent_ref) + self.assertEqual(parent_ref, parent.reference) + self.assertEqual(build.tree_id, self.build.tree_id) + self.assertEqual(build.level, parent.level + 1) + self.assertEqual(build.lft, parent.lft + 1) + self.assertEqual(build.rght, parent.rght - 1) + + def test_complex_tree(self): + """Test a more complex tree structure with multiple branches. + + - Ensure that grafting nodes works correctly. + """ + ref = 1235 + + for ii in range(3): + # Create child builds + child = Build.objects.create( + part=self.assembly, + quantity=2, + reference=f'BO-{ref + (ii * 4)}', + parent=self.build, + ) + + for jj in range(3): + # Create grandchild builds + grandchild = Build.objects.create( + part=self.assembly, + quantity=2, + reference=f'BO-{ref + (ii * 4) + jj + 1}', + parent=child, + ) + + self.assertEqual(grandchild.parent, child) + self.assertEqual(grandchild.tree_id, self.build.tree_id) + self.assertEqual(grandchild.level, 2) + + self.assertEqual(child.get_children().count(), 3) + self.assertEqual(child.get_descendants(include_self=False).count(), 3) + + self.assertEqual(child.level, 1) + self.assertEqual(child.tree_id, self.build.tree_id) + + # Basic tests + self.assertEqual(Build.objects.count(), 13) + self.assertEqual(self.build.get_children().count(), 3) + self.assertEqual(self.build.get_descendants(include_self=False).count(), 12) + + # Move one of the child builds + build = Build.objects.get(reference='BO-1239') + self.assertEqual(build.parent.reference, 'BO-1234') + self.assertEqual(build.level, 1) + self.assertEqual(build.get_children().count(), 3) + for bo in build.get_children(): + self.assertEqual(bo.level, 2) + + parent = Build.objects.get(reference='BO-1235') + build.parent = parent + build.save() + + build = Build.objects.get(reference='BO-1239') + self.assertEqual(build.parent.reference, 'BO-1235') + self.assertEqual(build.level, 2) + self.assertEqual(build.get_children().count(), 3) + for bo in build.get_children(): + self.assertEqual(bo.level, 3) diff --git a/src/backend/InvenTree/order/test_api.py b/src/backend/InvenTree/order/test_api.py index 58fe2ff17f..42f702da19 100644 --- a/src/backend/InvenTree/order/test_api.py +++ b/src/backend/InvenTree/order/test_api.py @@ -1919,6 +1919,11 @@ class SalesOrderDownloadTest(OrderTest): class SalesOrderAllocateTest(OrderTest): """Unit tests for allocating stock items against a SalesOrder.""" + @classmethod + def setUpTestData(cls): + """Init routine for this unit test class.""" + super().setUpTestData() + def setUp(self): """Init routines for this unit testing class.""" super().setUp() @@ -2008,7 +2013,10 @@ class SalesOrderAllocateTest(OrderTest): data = {'items': [], 'shipment': self.shipment.pk} for line in self.order.lines.all(): - stock_item = line.part.stock_items.last() + for stock_item in line.part.stock_items.all(): + # Find a non-serialized stock item to allocate + if not stock_item.serialized: + break # Fully-allocate each line data['items'].append({ @@ -2040,11 +2048,22 @@ class SalesOrderAllocateTest(OrderTest): for line in filter(check_template, self.order.lines.all()): stock_item = None + stock_item = None + # Allocate a matching variant parts = Part.objects.filter(salable=True).filter(variant_of=line.part.pk) for part in parts: stock_item = part.stock_items.last() - break + + for item in part.stock_items.all(): + if item.serialized: + continue + + stock_item = item + break + + if stock_item is not None: + break # Fully-allocate each line data['items'].append({ diff --git a/src/backend/InvenTree/part/fixtures/category.yaml b/src/backend/InvenTree/part/fixtures/category.yaml index ab25df536c..b3d220c3e2 100644 --- a/src/backend/InvenTree/part/fixtures/category.yaml +++ b/src/backend/InvenTree/part/fixtures/category.yaml @@ -7,8 +7,8 @@ description: Electronic components parent: null default_location: 1 - level: 0 tree_id: 1 + level: 0 lft: 1 rght: 12 @@ -19,10 +19,10 @@ description: Resistors parent: 1 default_location: null - level: 1 tree_id: 1 - lft: 2 - rght: 3 + level: 1 + lft: 10 + rght: 11 - model: part.partcategory pk: 3 @@ -33,8 +33,8 @@ default_location: null level: 1 tree_id: 1 - lft: 4 - rght: 5 + lft: 2 + rght: 3 - model: part.partcategory pk: 4 @@ -43,10 +43,10 @@ description: Integrated Circuits parent: 1 default_location: null - level: 1 tree_id: 1 - lft: 6 - rght: 11 + level: 1 + lft: 4 + rght: 9 - model: part.partcategory pk: 5 @@ -55,10 +55,10 @@ description: Microcontrollers parent: 4 default_location: null - level: 2 tree_id: 1 - lft: 7 - rght: 8 + level: 2 + lft: 5 + rght: 6 - model: part.partcategory pk: 6 @@ -67,10 +67,10 @@ description: Communication interfaces parent: 4 default_location: null - level: 2 tree_id: 1 - lft: 9 - rght: 10 + level: 2 + lft: 7 + rght: 8 - model: part.partcategory pk: 7 @@ -78,8 +78,8 @@ name: Mechanical description: Mechanical components default_location: null - level: 0 tree_id: 2 + level: 0 lft: 1 rght: 4 @@ -90,7 +90,7 @@ description: Screws, bolts, etc parent: 7 default_location: 5 - level: 1 tree_id: 2 + level: 1 lft: 2 rght: 3 diff --git a/src/backend/InvenTree/part/fixtures/part.yaml b/src/backend/InvenTree/part/fixtures/part.yaml index e242aee615..58023cf7ad 100644 --- a/src/backend/InvenTree/part/fixtures/part.yaml +++ b/src/backend/InvenTree/part/fixtures/part.yaml @@ -8,12 +8,12 @@ category: 8 link: http://www.acme.com/parts/m2x4lphs creation_date: '2018-01-01' - tree_id: 0 purchaseable: True testable: False + tree_id: 5 level: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: part.part pk: 2 @@ -22,10 +22,10 @@ description: 'M3x12 socket head cap screw' category: 8 creation_date: '2019-02-02' - tree_id: 0 + tree_id: 6 level: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 # Create some resistors @@ -36,11 +36,10 @@ description: '2.2kOhm resistor in 0805 package' category: 2 creation_date: '2020-03-03' - tree_id: 0 + tree_id: 8 level: 0 - lft: 0 - rght: 0 - + lft: 1 + rght: 2 - model: part.part pk: 4 @@ -50,10 +49,10 @@ category: 2 creation_date: '2021-04-04' default_location: 2 # Home/Bathroom - tree_id: 0 + tree_id: 9 level: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 # Create some capacitors - model: part.part @@ -64,10 +63,10 @@ purchaseable: true category: 3 creation_date: '2022-05-05' - tree_id: 0 + tree_id: 3 level: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: part.part pk: 25 @@ -80,11 +79,11 @@ assembly: true trackable: true testable: true - tree_id: 0 - level: 0 - lft: 0 - rght: 0 default_expiry: 10 + tree_id: 10 + level: 0 + lft: 1 + rght: 2 - model: part.part pk: 50 @@ -94,10 +93,10 @@ category: null salable: true creation_date: '2024-07-07' - tree_id: 0 + tree_id: 7 level: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 # A part that can be made from other parts - model: part.part @@ -115,10 +114,10 @@ testable: True IPN: BOB revision: A2 - tree_id: 0 + tree_id: 2 level: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: part.part pk: 101 @@ -128,10 +127,10 @@ salable: true creation_date: '2026-09-09' active: True - tree_id: 0 + tree_id: 1 level: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 # A 'template' part - model: part.part @@ -145,10 +144,10 @@ creation_date: '2027-10-10' salable: true category: 7 - tree_id: 1 + tree_id: 4 level: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 10 - model: part.part pk: 10001 @@ -160,10 +159,10 @@ testable: true creation_date: '2028-11-11' category: 7 - tree_id: 1 - level: 0 - lft: 0 - rght: 0 + tree_id: 4 + level: 1 + lft: 2 + rght: 3 - model: part.part pk: 10002 @@ -177,10 +176,10 @@ salable: true creation_date: '2029-12-12' category: 7 - tree_id: 1 - level: 0 - lft: 0 - rght: 0 + tree_id: 4 + level: 1 + lft: 8 + rght: 9 - model: part.part pk: 10003 @@ -193,10 +192,10 @@ trackable: false testable: true creation_date: '2030-01-01' - tree_id: 1 - level: 0 - lft: 0 - rght: 0 + tree_id: 4 + level: 1 + lft: 4 + rght: 7 - model: part.part pk: 10004 @@ -209,10 +208,10 @@ creation_date: '2031-02-02' trackable: true testable: true - tree_id: 1 - level: 0 - lft: 0 - rght: 0 + tree_id: 4 + level: 2 + lft: 5 + rght: 6 - model: part.partrelated pk: 1 diff --git a/src/backend/InvenTree/part/models.py b/src/backend/InvenTree/part/models.py index 9b3501167c..166db4a1c3 100644 --- a/src/backend/InvenTree/part/models.py +++ b/src/backend/InvenTree/part/models.py @@ -30,9 +30,8 @@ from django_cleanup import cleanup from djmoney.contrib.exchange.exceptions import MissingRate from djmoney.contrib.exchange.models import convert_money from djmoney.money import Money -from mptt.exceptions import InvalidMove from mptt.managers import TreeManager -from mptt.models import MPTTModel, TreeForeignKey +from mptt.models import TreeForeignKey from stdimage.models import StdImageField from taggit.managers import TaggableManager @@ -70,7 +69,12 @@ from stock import models as StockModels logger = structlog.get_logger('inventree') -class PartCategory(InvenTree.models.InvenTreeTree): +class PartCategory( + InvenTree.models.PluginValidationMixin, + InvenTree.models.MetadataMixin, + InvenTree.models.PathStringMixin, + InvenTree.models.InvenTreeTree, +): """PartCategory provides hierarchical organization of Part objects. Attributes: @@ -401,13 +405,13 @@ class PartReportContext(report.mixins.BaseReportContext): @cleanup.ignore class Part( + InvenTree.models.PluginValidationMixin, InvenTree.models.InvenTreeAttachmentMixin, InvenTree.models.InvenTreeBarcodeMixin, InvenTree.models.InvenTreeNotesMixin, report.mixins.InvenTreeReportMixin, InvenTree.models.MetadataMixin, - InvenTree.models.PluginValidationMixin, - MPTTModel, + InvenTree.models.InvenTreeTree, ): """The Part object represents an abstract part, the 'concept' of an actual entity. @@ -447,6 +451,8 @@ class Part( last_stocktake: Date at which last stocktake was performed for this Part """ + NODE_PARENT_KEY = 'variant_of' + objects = PartManager() tags = TaggableManager(blank=True) @@ -550,10 +556,7 @@ class Part( self.full_clean() - try: - super().save(*args, **kwargs) - except InvalidMove: - raise ValidationError({'variant_of': _('Invalid choice for parent part')}) + super().save(*args, **kwargs) if _new: # Only run if the check was not run previously (due to not existing in the database) diff --git a/src/backend/InvenTree/part/test_api.py b/src/backend/InvenTree/part/test_api.py index 8739510973..5f70139675 100644 --- a/src/backend/InvenTree/part/test_api.py +++ b/src/backend/InvenTree/part/test_api.py @@ -112,8 +112,8 @@ class PartCategoryAPITest(InvenTreeAPITestCase): url = reverse('api-part-category-list') # star categories manually for tests as it is not possible with fixures - # because the current user is no fixure itself and throws an invalid - # foreign key constrain + # because the current user is not fixured itself and throws an invalid + # foreign key constraint for pk in [3, 4]: PartCategory.objects.get(pk=pk).set_starred(self.user, True) @@ -537,8 +537,6 @@ class PartCategoryAPITest(InvenTreeAPITestCase): parent=loc, ) - PartCategory.objects.rebuild() - with self.assertNumQueriesLessThan(15): response = self.get(reverse('api-part-category-tree'), expected_code=200) @@ -588,7 +586,6 @@ class PartCategoryAPITest(InvenTreeAPITestCase): sub4 = PartCategory.objects.create(name='sub4', parent=sub3) sub5 = PartCategory.objects.create(name='sub5', parent=sub2) Part.objects.create(name='test', category=sub4) - PartCategory.objects.rebuild() # This query will trigger an internal server error if annotation results are not limited to 1 url = reverse('api-part-list') @@ -1057,9 +1054,6 @@ class PartAPITest(PartAPITestBase): Uses the 'chair template' part (pk=10000) """ - # Rebuild the MPTT structure before running these tests - Part.objects.rebuild() - url = reverse('api-part-list') response = self.get(url, {'variant_of': 10000}, expected_code=200) @@ -1105,7 +1099,6 @@ class PartAPITest(PartAPITestBase): def test_variant_stock(self): """Unit tests for the 'variant_stock' annotation, which provides a stock count for *variant* parts.""" # Ensure the MPTT structure is in a known state before running tests - Part.objects.rebuild() # Initially, there are no "chairs" in stock, # so each 'chair' template should report variant_stock=0 @@ -2021,9 +2014,6 @@ class PartAPIAggregationTest(InvenTreeAPITestCase): """Create test data as part of setup routine.""" super().setUpTestData() - # Ensure the part "variant" tree is correctly structured - Part.objects.rebuild() - # Add a new part cls.part = Part.objects.create( name='Banana', @@ -2379,9 +2369,6 @@ class BomItemTest(InvenTreeAPITestCase): """Set up the test case.""" super().setUp() - # Rebuild part tree so BOM items validate correctly - Part.objects.rebuild() - def test_bom_list(self): """Tests for the BomItem list endpoint.""" # How many BOM items currently exist in the database? @@ -2569,8 +2556,6 @@ class BomItemTest(InvenTreeAPITestCase): variant.save() - Part.objects.rebuild() - # Create some stock items for this new part for _ in range(ii): StockItem.objects.create(part=variant, location=loc, quantity=100) @@ -2700,8 +2685,6 @@ class BomItemTest(InvenTreeAPITestCase): def test_bom_variant_stock(self): """Test for 'available_variant_stock' annotation.""" - Part.objects.rebuild() - # BOM item we are interested in bom_item = BomItem.objects.get(pk=1) @@ -3080,11 +3063,6 @@ class PartMetadataAPITest(InvenTreeAPITestCase): roles = ['part.change', 'part_category.change'] - def setUp(self): - """Setup unit tets.""" - super().setUp() - Part.objects.rebuild() - def metatester(self, apikey, model): """Generic tester.""" modeldata = model.objects.first() diff --git a/src/backend/InvenTree/part/test_bom_item.py b/src/backend/InvenTree/part/test_bom_item.py index f978f90638..722daf61da 100644 --- a/src/backend/InvenTree/part/test_bom_item.py +++ b/src/backend/InvenTree/part/test_bom_item.py @@ -29,8 +29,6 @@ class BomItemTest(TestCase): """Create initial data.""" super().setUp() - Part.objects.rebuild() - self.bob = Part.objects.get(id=100) self.orphan = Part.objects.get(name='Orphan') self.r1 = Part.objects.get(name='R_2K2_0805') diff --git a/src/backend/InvenTree/part/test_category.py b/src/backend/InvenTree/part/test_category.py index 198c5a32a6..d9f45071e1 100644 --- a/src/backend/InvenTree/part/test_category.py +++ b/src/backend/InvenTree/part/test_category.py @@ -23,8 +23,8 @@ class CategoryTest(TestCase): super().setUpTestData() cls.electronics = PartCategory.objects.get(name='Electronics') - cls.mechanical = PartCategory.objects.get(name='Mechanical') cls.resistors = PartCategory.objects.get(name='Resistors') + cls.mechanical = PartCategory.objects.get(name='Mechanical') cls.capacitors = PartCategory.objects.get(name='Capacitors') cls.fasteners = PartCategory.objects.get(name='Fasteners') cls.ic = PartCategory.objects.get(name='IC') @@ -66,6 +66,7 @@ class CategoryTest(TestCase): def test_path_string(self): """Test that the category path string works correctly.""" # Note that due to data migrations, these fields need to be saved first + self.resistors.save() self.transceivers.save() @@ -88,6 +89,9 @@ class CategoryTest(TestCase): # Move to a new parent location subcat.parent = self.resistors subcat.save() + + # subcat.refresh_from_db() + self.assertEqual(subcat.pathstring, 'Electronics/Resistors/Subcategory') self.assertEqual(len(subcat.path), 3) @@ -217,6 +221,45 @@ class CategoryTest(TestCase): w = Part.objects.get(name='Widget') self.assertIsNone(w.get_default_location()) + def test_root_delete(self): + """Test that deleting a root category works correctly.""" + # Clear out the existing categories + # Note: Cannot call bulk delete here, as it will not trigger MPTT updates + for p in PartCategory.objects.all(): + p.delete() + + # Create a new root category + root = PartCategory.objects.create(name='Root Category', description='Root') + + # Create a child category + for i in range(10): + PartCategory.objects.create( + name=f'Child Category {i}', description='Child', parent=root + ) + + root.refresh_from_db() + + self.assertEqual(root.get_descendants(include_self=False).count(), 10) + self.assertEqual(PartCategory.objects.count(), 11) + + # There is only a single tree_id value + tree_ids = PartCategory.objects.values_list('tree_id', flat=True).distinct() + tree_ids = set(tree_ids) + self.assertEqual(len(tree_ids), 1) + + # Delete the root category + root.delete() + + # All child categories are now "root" categories + for cat in PartCategory.objects.all(): + self.assertIsNone(cat.parent) + self.assertEqual(cat.level, 0) + + # 10 unique tree_id values should now exist + tree_ids = PartCategory.objects.values_list('tree_id', flat=True).distinct() + tree_ids = set(tree_ids) + self.assertEqual(len(tree_ids), 10) + def test_category_tree(self): """Unit tests for the part category tree structure (MPTT). @@ -226,8 +269,6 @@ class CategoryTest(TestCase): # Clear out any existing parts Part.objects.all().delete() - PartCategory.objects.rebuild() - # First, create a structured tree of part categories A = PartCategory.objects.create(name='A', description='Top level category') @@ -309,9 +350,8 @@ class CategoryTest(TestCase): for loc in [B1, B2, C31, C32, C33]: # These should now all be "top level" categories loc.refresh_from_db() - - self.assertEqual(loc.level, 0) self.assertEqual(loc.parent, None) + self.assertEqual(loc.level, 0) # Pathstring should be the same as the name self.assertEqual(loc.pathstring, loc.name) @@ -319,6 +359,7 @@ class CategoryTest(TestCase): # Test pathstring for direct children for child in loc.get_children(): self.assertEqual(child.pathstring, f'{loc.name}/{child.name}') + self.assertEqual(child.level, 1) # Check descendants for B1 descendants = B1.get_descendants() diff --git a/src/backend/InvenTree/part/test_part.py b/src/backend/InvenTree/part/test_part.py index 7931926945..a6f2f5bddb 100644 --- a/src/backend/InvenTree/part/test_part.py +++ b/src/backend/InvenTree/part/test_part.py @@ -160,11 +160,8 @@ class PartTest(TestCase): cls.r1 = Part.objects.get(name='R_2K2_0805') cls.r2 = Part.objects.get(name='R_4K7_0603') - cls.c1 = Part.objects.get(name='C_22N_0805') - Part.objects.rebuild() - def test_barcode_mixin(self): """Test the barcode mixin functionality.""" self.assertEqual(Part.barcode_model_type(), 'part') @@ -173,18 +170,6 @@ class PartTest(TestCase): barcode = p.format_barcode() self.assertEqual(barcode, '{"part": 1}') - def test_tree(self): - """Test that the part variant tree is working properly.""" - chair = Part.objects.get(pk=10000) - self.assertEqual(chair.get_children().count(), 3) - self.assertEqual(chair.get_descendant_count(), 4) - - green = Part.objects.get(pk=10004) - self.assertEqual(green.get_ancestors().count(), 2) - self.assertEqual(green.get_root(), chair) - self.assertEqual(green.get_family().count(), 3) - self.assertEqual(Part.objects.filter(tree_id=chair.tree_id).count(), 5) - def test_str(self): """Test string representation of a Part.""" p = Part.objects.get(pk=100) @@ -417,7 +402,6 @@ class PartTest(TestCase): ) with self.assertRaises(ValidationError) as exc: - print('rev a:', rev_a.revision_of, part.revision_of) rev_a.revision_of = part rev_a.save() @@ -466,6 +450,180 @@ class PartTest(TestCase): self.assertEqual(part.revisions.count(), 2) +class VariantTreeTest(TestCase): + """Unit test for the Part variant tree structure.""" + + fixtures = ['category', 'part', 'location'] + + @classmethod + def setUpTestData(cls): + """Rebuild Part tree before running tests.""" + super().setUpTestData() + + def test_tree(self): + """Test tree structure for fixtured data.""" + chair = Part.objects.get(pk=10000) + self.assertEqual(chair.get_children().count(), 3) + self.assertEqual(chair.get_descendant_count(), 4) + + green = Part.objects.get(pk=10004) + self.assertEqual(green.get_ancestors().count(), 2) + self.assertEqual(green.get_root(), chair) + self.assertEqual(green.get_family().count(), 3) + self.assertEqual(Part.objects.filter(tree_id=chair.tree_id).count(), 5) + + def test_part_creation(self): + """Test that parts are created with the correct tree structure.""" + part_1 = Part.objects.create(name='Part 1', description='Part 1 description') + + part_2 = Part.objects.create(name='Part 2', description='Part 2 description') + + # Check that both parts have been created with unique tree IDs + self.assertNotEqual(part_1.tree_id, part_2.tree_id) + + for p in [part_1, part_2]: + self.assertEqual(p.level, 0) + self.assertEqual(p.lft, 1) + self.assertEqual(p.rght, 2) + self.assertIsNone(p.variant_of) + + self.assertEqual(Part.objects.filter(tree_id=p.tree_id).count(), 1) + + def test_complex_tree(self): + """Test a complex part template/variant tree.""" + template = Part.objects.create( + name='Top Level Template', + description='A top-level template part', + is_template=True, + ) + + # Create some variant parts + for x in ['A', 'B', 'C']: + variant = Part.objects.create( + name=f'Variant {x}', + description=f'Variant part {x}', + variant_of=template, + is_template=True, + ) + + for ii in range(1, 4): + Part.objects.create( + name=f'Sub-Variant {x}-{ii}', + description=f'Sub-variant part {x}-{ii}', + variant_of=variant, + ) + + template.refresh_from_db() + + self.assertEqual(template.get_children().count(), 3) + self.assertEqual(template.get_descendants(include_self=False).count(), 12) + + for variant in template.get_children(): + self.assertEqual(variant.variant_of, template) + self.assertEqual(variant.get_ancestors().count(), 1) + self.assertEqual(variant.get_descendants(include_self=False).count(), 3) + + for child in variant.get_children(): + self.assertEqual(child.variant_of, variant) + self.assertEqual(child.get_ancestors().count(), 2) + self.assertEqual(child.get_descendants(include_self=False).count(), 0) + + # Let's graft one variant onto another + variant_a = Part.objects.get(name='Variant A') + variant_b = Part.objects.get(name='Variant B') + variant_c = Part.objects.get(name='Variant C') + + variant_a.variant_of = variant_b + variant_a.save() + + template.refresh_from_db() + self.assertEqual(template.get_children().count(), 2) + + variant_a.refresh_from_db() + variant_b.refresh_from_db() + + self.assertEqual(variant_a.get_ancestors().count(), 2) + self.assertEqual(variant_a.variant_of, variant_b) + self.assertEqual(variant_b.get_children().count(), 4) + + for child in variant_a.get_children(): + self.assertEqual(child.variant_of, variant_a) + self.assertEqual(child.tree_id, template.tree_id) + self.assertEqual(child.get_ancestors().count(), 3) + self.assertEqual(child.level, 3) + self.assertGreater(child.lft, variant_a.lft) + self.assertGreater(child.lft, template.lft) + self.assertLess(child.rght, variant_a.rght) + self.assertLess(child.rght, template.rght) + self.assertLess(child.lft, child.rght) + + # Let's graft one variant to its own tree + variant_c.variant_of = None + variant_c.save() + + template.refresh_from_db() + variant_a.refresh_from_db() + variant_b.refresh_from_db() + variant_c.refresh_from_db() + + # Check total descendent count + self.assertEqual(template.get_descendant_count(), 8) + self.assertEqual(variant_a.get_descendant_count(), 3) + self.assertEqual(variant_b.get_descendant_count(), 7) + self.assertEqual(variant_c.get_descendant_count(), 3) + + # Check tree ID values + self.assertEqual(template.tree_id, variant_a.tree_id) + self.assertEqual(template.tree_id, variant_b.tree_id) + self.assertNotEqual(template.tree_id, variant_c.tree_id) + + for child in variant_a.get_children(): + # template -> variant_b -> variant_b -> child + self.assertEqual(child.tree_id, template.tree_id) + self.assertEqual(child.get_ancestors().count(), 3) + self.assertLess(child.lft, child.rght) + + for child in variant_b.get_children(): + # template -> variant_b -> child + self.assertEqual(child.tree_id, template.tree_id) + self.assertEqual(child.get_ancestors().count(), 2) + self.assertLess(child.lft, child.rght) + + for child in variant_c.get_children(): + # variant_c -> child + self.assertEqual(child.tree_id, variant_c.tree_id) + self.assertEqual(child.get_ancestors().count(), 1) + self.assertLess(child.lft, child.rght) + + # Next, let's delete an entire variant - ensure that sub-variants are moved up + b_childs = variant_b.get_children() + + with self.assertRaises(ValidationError): + variant_b.delete() + + # Mark as inactive to allow deletion + variant_b.active = False + variant_b.save() + variant_b.delete() + + template.refresh_from_db() + variant_a.refresh_from_db() + + # Top-level template should have now 4 direct children: + # - 3x children grafted from variant_a + # - variant_a - previously child of variant a + self.assertEqual(template.get_children().count(), 4) + + self.assertEqual(variant_a.get_children().count(), 3) + self.assertEqual(variant_a.variant_of, template) + + for child in b_childs: + child.refresh_from_db() + self.assertEqual(child.variant_of, template) + self.assertEqual(child.get_ancestors().count(), 1) + self.assertEqual(child.level, 2) + + class TestTemplateTest(TestCase): """Unit test for the TestTemplate class.""" diff --git a/src/backend/InvenTree/plugin/samples/integration/test_validation_sample.py b/src/backend/InvenTree/plugin/samples/integration/test_validation_sample.py index 010692cb8b..27e7ab9ce4 100644 --- a/src/backend/InvenTree/plugin/samples/integration/test_validation_sample.py +++ b/src/backend/InvenTree/plugin/samples/integration/test_validation_sample.py @@ -16,10 +16,13 @@ class SampleValidatorPluginTest(InvenTreeAPITestCase, InvenTreeTestCase): def setUp(self): """Set up the test environment.""" + super().setUp() + cat = part.models.PartCategory.objects.first() self.part = part.models.Part.objects.create( name='TestPart', category=cat, description='A test part', component=True ) + self.assembly = part.models.Part.objects.create( name='TestAssembly', category=cat, @@ -27,10 +30,6 @@ class SampleValidatorPluginTest(InvenTreeAPITestCase, InvenTreeTestCase): component=False, assembly=True, ) - self.bom_item = part.models.BomItem.objects.create( - part=self.assembly, sub_part=self.part, quantity=1 - ) - super().setUp() def get_plugin(self): """Return the SampleValidatorPlugin instance.""" @@ -43,6 +42,16 @@ class SampleValidatorPluginTest(InvenTreeAPITestCase, InvenTreeTestCase): def test_validate_model_instance(self): """Test the validate_model_instance function.""" # First, ensure that the plugin is disabled + + # Create a BomItem to run tests on + # We need to refresh the part + self.part.refresh_from_db() + self.assembly.refresh_from_db() + + self.bom_item = part.models.BomItem.objects.create( + part=self.assembly, sub_part=self.part, quantity=1 + ) + self.enable_plugin(False) plg = self.get_plugin() diff --git a/src/backend/InvenTree/stock/fixtures/location.yaml b/src/backend/InvenTree/stock/fixtures/location.yaml index 9269fe91b4..f4383f2ad3 100644 --- a/src/backend/InvenTree/stock/fixtures/location.yaml +++ b/src/backend/InvenTree/stock/fixtures/location.yaml @@ -27,8 +27,8 @@ name: 'Dining Room' description: 'A table lives here' parent: 1 - level: 0 tree_id: 1 + level: 1 lft: 4 rght: 5 @@ -49,8 +49,8 @@ name: 'Drawer_1' description: 'In my desk' parent: 4 - level: 0 tree_id: 2 + level: 1 lft: 2 rght: 3 @@ -60,8 +60,8 @@ name: 'Drawer_2' description: 'Also in my desk' parent: 4 - level: 0 tree_id: 2 + level: 1 lft: 4 rght: 5 @@ -71,7 +71,7 @@ name: 'Drawer_3' description: 'Again, in my desk' parent: 4 - level: 0 tree_id: 2 + level: 1 lft: 6 rght: 7 diff --git a/src/backend/InvenTree/stock/fixtures/stock.yaml b/src/backend/InvenTree/stock/fixtures/stock.yaml index 99ca9cacea..364762fa22 100644 --- a/src/backend/InvenTree/stock/fixtures/stock.yaml +++ b/src/backend/InvenTree/stock/fixtures/stock.yaml @@ -8,12 +8,12 @@ location: 3 batch: 'B123' quantity: 4000 - level: 0 - tree_id: 0 - lft: 0 - rght: 0 purchase_price: 123 purchase_price_currency: AUD + tree_id: 20 + level: 0 + lft: 1 + rght: 2 # 5,000 screws in the bathroom - model: stock.stockitem @@ -22,10 +22,10 @@ part: 1 location: 2 quantity: 5000 + tree_id: 21 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 # Capacitor C_22N_0805 in 'Office' - model: stock.stockitem @@ -34,10 +34,10 @@ part: 5 location: 4 quantity: 666 + tree_id: 16 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 # 1234 2K2 resistors in 'Drawer_1' - model: stock.stockitem @@ -46,10 +46,10 @@ part: 3 location: 5 quantity: 1234 + tree_id: 22 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 # Some widgets in drawer 3 - model: stock.stockitem @@ -60,10 +60,10 @@ location: 7 quantity: 10 delete_on_deplete: False + tree_id: 28 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 101 @@ -72,10 +72,10 @@ batch: "B2345" location: 7 quantity: 5 + tree_id: 27 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 102 @@ -84,10 +84,10 @@ batch: 'BCDE' location: 7 quantity: 0 + tree_id: 26 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 105 @@ -97,10 +97,10 @@ quantity: 1 serial: 1000 serial_int: 1000 + tree_id: 29 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 # Stock items for template / variant parts - model: stock.stockitem @@ -110,10 +110,10 @@ location: 7 quantity: 5 batch: "BBAAA" + tree_id: 2 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 501 @@ -123,10 +123,10 @@ quantity: 1 serial: 1 serial_int: 1 + tree_id: 3 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 502 @@ -136,10 +136,10 @@ quantity: 1 serial: 2 serial_int: 2 + tree_id: 4 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 503 @@ -149,10 +149,10 @@ quantity: 1 serial: 3 serial_int: 3 + tree_id: 5 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 504 @@ -162,10 +162,10 @@ quantity: 1 serial: 4 serial_int: 4 + tree_id: 6 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 505 @@ -175,10 +175,10 @@ quantity: 1 serial: 5 serial_int: 5 + tree_id: 1 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 510 @@ -188,10 +188,10 @@ quantity: 1 serial: 10 serial_int: 10 + tree_id: 24 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 511 @@ -201,10 +201,10 @@ quantity: 1 serial: 11 serial_int: 11 + tree_id: 23 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 512 @@ -214,10 +214,10 @@ quantity: 1 serial: 12 serial_int: 12 + tree_id: 25 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 520 @@ -227,12 +227,12 @@ quantity: 1 serial: 20 serial_int: 20 - level: 0 - tree_id: 0 - lft: 0 - rght: 0 expiry_date: "1990-10-10" barcode_hash: 9e5ae7fc20568ed4814c10967bba8b65 + tree_id: 18 + level: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 521 @@ -242,12 +242,12 @@ quantity: 1 serial: 21 serial_int: 21 - level: 0 - tree_id: 0 - lft: 0 - rght: 0 status: 60 barcode_hash: 1be0dfa925825c5c6c79301449e50c2d + tree_id: 17 + level: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 522 @@ -257,12 +257,12 @@ quantity: 1 serial: 22 serial_int: 22 - level: 0 - tree_id: 0 - lft: 0 - rght: 0 expiry_date: "1990-10-10" status: 70 + tree_id: 19 + level: 0 + lft: 1 + rght: 2 # Multiple stock items for "Bob" (PK 100) - model: stock.stockitem @@ -271,10 +271,10 @@ part: 100 location: 1 quantity: 10 + tree_id: 9 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 1001 @@ -282,10 +282,10 @@ part: 100 location: 1 quantity: 11 + tree_id: 14 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 1002 @@ -293,10 +293,10 @@ part: 100 location: 1 quantity: 12 + tree_id: 8 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 1003 @@ -304,10 +304,10 @@ part: 100 location: 1 quantity: 13 + tree_id: 15 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 1004 @@ -315,10 +315,10 @@ part: 100 location: 1 quantity: 14 + tree_id: 7 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 1005 @@ -326,10 +326,10 @@ part: 100 location: 1 quantity: 15 + tree_id: 13 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 1006 @@ -337,10 +337,10 @@ part: 100 location: 1 quantity: 16 + tree_id: 12 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 1007 @@ -348,10 +348,10 @@ part: 100 location: 7 quantity: 17 + tree_id: 11 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 - model: stock.stockitem pk: 1008 @@ -359,7 +359,7 @@ part: 100 location: 7 quantity: 18 + tree_id: 10 level: 0 - tree_id: 0 - lft: 0 - rght: 0 + lft: 1 + rght: 2 diff --git a/src/backend/InvenTree/stock/models.py b/src/backend/InvenTree/stock/models.py index 7e6f95cb9a..80493d5610 100644 --- a/src/backend/InvenTree/stock/models.py +++ b/src/backend/InvenTree/stock/models.py @@ -14,7 +14,7 @@ from django.core.validators import MinValueValidator from django.db import models, transaction from django.db.models import Q, QuerySet, Sum from django.db.models.functions import Coalesce -from django.db.models.signals import post_delete, post_save, pre_delete +from django.db.models.signals import post_delete, post_save from django.db.utils import IntegrityError, OperationalError from django.dispatch import receiver from django.urls import reverse @@ -23,7 +23,7 @@ from django.utils.translation import gettext_lazy as _ import structlog from djmoney.contrib.exchange.models import convert_money from mptt.managers import TreeManager -from mptt.models import MPTTModel, TreeForeignKey +from mptt.models import TreeForeignKey from taggit.managers import TaggableManager import build.models @@ -135,8 +135,11 @@ class StockLocationReportContext(report.mixins.BaseReportContext): class StockLocation( + InvenTree.models.PluginValidationMixin, InvenTree.models.InvenTreeBarcodeMixin, report.mixins.InvenTreeReportMixin, + InvenTree.models.PathStringMixin, + InvenTree.models.MetadataMixin, InvenTree.models.InvenTreeTree, ): """Organization tree for StockItem objects. @@ -409,15 +412,15 @@ class StockItemReportContext(report.mixins.BaseReportContext): class StockItem( + InvenTree.models.PluginValidationMixin, InvenTree.models.InvenTreeAttachmentMixin, InvenTree.models.InvenTreeBarcodeMixin, InvenTree.models.InvenTreeNotesMixin, StatusCodeMixin, report.mixins.InvenTreeReportMixin, - InvenTree.models.MetadataMixin, - InvenTree.models.PluginValidationMixin, common.models.MetaMixin, - MPTTModel, + InvenTree.models.MetadataMixin, + InvenTree.models.InvenTreeTree, ): """A StockItem object represents a quantity of physical instances of a part. @@ -453,6 +456,11 @@ class StockItem( verbose_name = _('Stock Item') + class MPTTMeta: + """MPTT metaclass options.""" + + order_insertion_by = ['part'] + @staticmethod def get_api_url(): """Return API url.""" @@ -600,13 +608,19 @@ class StockItem( raise ValidationError({'part': _('Part must be specified')}) part = data['part'] - tree_id = kwargs.pop('tree_id', 0) - data['parent'] = kwargs.pop('parent', None) or data.get('parent') - data['tree_id'] = tree_id - data['level'] = kwargs.pop('level', 0) - data['rght'] = kwargs.pop('rght', 0) - data['lft'] = kwargs.pop('lft', 0) + parent = kwargs.pop('parent', None) or data.get('parent') + tree_id = kwargs.pop('tree_id', StockItem.getNextTreeID()) + + if parent: + # Override with parent's tree_id if provided + tree_id = parent.tree_id + + # Pre-calculate MPTT fields + data['parent'] = parent if parent else None + data['level'] = parent.level + 1 if parent else 0 + data['lft'] = 0 if parent else 1 + data['rght'] = 0 if parent else 2 # Force single quantity for each item data['quantity'] = 1 @@ -615,6 +629,13 @@ class StockItem( data['serial'] = serial data['serial_int'] = StockItem.convert_serial_to_int(serial) + data['tree_id'] = tree_id + + if not parent: + # No parent, this is a top-level item, so increment the tree_id + # This is because each new item is a "top-level" node in the StockItem tree + tree_id += 1 + # Construct a new StockItem from the provided dict items.append(StockItem(**data)) @@ -622,9 +643,12 @@ class StockItem( StockItem.objects.bulk_create(items) # We will need to rebuild the stock item tree manually, due to the bulk_create operation - InvenTree.tasks.offload_task( - stock.tasks.rebuild_stock_item_tree, tree_id=tree_id, group='stock' - ) + if parent and parent.tree_id: + # Rebuild the tree structure for this StockItem tree + logger.info( + 'Rebuilding StockItem tree structure for tree_id: %s', parent.tree_id + ) + stock.tasks.rebuild_stock_item_tree(parent.tree_id) # Return the newly created StockItem objects return StockItem.objects.filter(part=part, serial__in=serials) @@ -1748,8 +1772,8 @@ class StockItem( self, quantity: int, serials: list[str], - user: User, - notes: str = '', + user: Optional[User] = None, + notes: Optional[str] = '', location: Optional[StockLocation] = None, ): """Split this stock item into unique serial numbers. @@ -2085,10 +2109,17 @@ class StockItem( self.save() # Rebuild stock trees as required + rebuild_result = True for tree_id in tree_ids: - InvenTree.tasks.offload_task( - stock.tasks.rebuild_stock_item_tree, tree_id=tree_id, group='stock' + if not stock.tasks.rebuild_stock_item_tree(tree_id, rebuild_on_fail=False): + rebuild_result = False + + if not rebuild_result: + # If the rebuild failed, offload the task to a background worker + logger.warning( + 'Failed to rebuild stock item tree during merge_stock_items operation, offloading task.' ) + InvenTree.tasks.offload_task(stock.tasks.rebuild_stock_items, group='stock') @transaction.atomic def splitStock(self, quantity, location=None, user=None, **kwargs): @@ -2150,7 +2181,7 @@ class StockItem( # Update the new stock item to ensure the tree structure is observed new_stock.parent = self - new_stock.level = self.level + 1 + new_stock.tree_id = None # Move to the new location if specified, otherwise use current location if location: @@ -2192,9 +2223,7 @@ class StockItem( ) # Rebuild the tree for this parent item - InvenTree.tasks.offload_task( - stock.tasks.rebuild_stock_item_tree, tree_id=self.tree_id, group='stock' - ) + stock.tasks.rebuild_stock_item_tree(self.tree_id) # Attempt to reload the new item from the database try: @@ -2648,19 +2677,6 @@ class StockItem( return status['passed'] >= status['total'] -@receiver(pre_delete, sender=StockItem, dispatch_uid='stock_item_pre_delete_log') -def before_delete_stock_item(sender, instance, using, **kwargs): - """Receives pre_delete signal from StockItem object. - - Before a StockItem is deleted, ensure that each child object is updated, - to point to the new parent item. - """ - # Update each StockItem parent field - for child in instance.children.all(): - child.parent = instance.parent - child.save() - - @receiver(post_delete, sender=StockItem, dispatch_uid='stock_item_post_delete_log') def after_delete_stock_item(sender, instance: StockItem, **kwargs): """Function to be executed after a StockItem object is deleted.""" diff --git a/src/backend/InvenTree/stock/serializers.py b/src/backend/InvenTree/stock/serializers.py index d48c9371f8..55c14640e2 100644 --- a/src/backend/InvenTree/stock/serializers.py +++ b/src/backend/InvenTree/stock/serializers.py @@ -788,7 +788,7 @@ class SerializeStockItemSerializer(serializers.Serializer): item.serializeStock( data['quantity'], serials, - user, + user=user, notes=data.get('notes', ''), location=data['destination'], ) diff --git a/src/backend/InvenTree/stock/tasks.py b/src/backend/InvenTree/stock/tasks.py index 70ec32a1bc..e3b4f5e4ab 100644 --- a/src/backend/InvenTree/stock/tasks.py +++ b/src/backend/InvenTree/stock/tasks.py @@ -7,23 +7,61 @@ tracer = trace.get_tracer(__name__) logger = structlog.get_logger('inventree') -@tracer.start_as_current_span('rebuild_stock_item_tree') -def rebuild_stock_item_tree(tree_id=None): - """Rebuild the stock tree structure. +@tracer.start_as_current_span('rebuild_stock_items') +def rebuild_stock_items(): + """Rebuild the entire StockItem tree structure. - The StockItem tree uses the MPTT library to manage the tree structure. + This may be necessary if the tree structure has become corrupted or inconsistent. """ from InvenTree.exceptions import log_error + from InvenTree.sentry import report_exception + from stock.models import StockItem + + logger.info('Rebuilding StockItem tree structure') + + try: + StockItem.objects.rebuild() + except Exception as e: + # This is a critical error, explicitly report to sentry + report_exception(e) + + log_error('rebuild_stock_items') + logger.exception('Failed to rebuild StockItem tree: %s', e) + + +def rebuild_stock_item_tree(tree_id: int, rebuild_on_fail: bool = True) -> bool: + """Rebuild the stock tree structure. + + Arguments: + tree_id (int): The ID of the StockItem tree to rebuild. + rebuild_on_fail (bool): If True, will attempt to rebuild the entire StockItem tree if the partial rebuild fails. + + Returns: + bool: True if the partial tree rebuild was successful, False otherwise. + + - If the rebuild fails, schedule a rebuild of the entire StockItem tree. + """ + from InvenTree.exceptions import log_error + from InvenTree.sentry import report_exception + from InvenTree.tasks import offload_task from stock.models import StockItem if tree_id: try: StockItem.objects.partial_rebuild(tree_id) - except Exception: + logger.info('Rebuilt StockItem tree for tree_id: %s', tree_id) + return True + except Exception as e: + # This is a critical error, explicitly report to sentry + report_exception(e) + log_error('rebuild_stock_item_tree') - logger.warning('Failed to rebuild StockItem tree') + logger.warning('Failed to rebuild StockItem tree for tree_id: %s', tree_id) # If the partial rebuild fails, rebuild the entire tree - StockItem.objects.rebuild() + if rebuild_on_fail: + offload_task(rebuild_stock_items, group='stock') + return False else: # No tree_id provided, so rebuild the entire tree StockItem.objects.rebuild() + return True diff --git a/src/backend/InvenTree/stock/test_api.py b/src/backend/InvenTree/stock/test_api.py index a200c199a4..6693f3c57e 100644 --- a/src/backend/InvenTree/stock/test_api.py +++ b/src/backend/InvenTree/stock/test_api.py @@ -73,11 +73,11 @@ class StockLocationTest(StockAPITestCase): ({}, 8, 'no parameters'), ({'parent': 1, 'cascade': False}, 2, 'Filter by parent, no cascading'), ({'parent': 1, 'cascade': True}, 2, 'Filter by parent, cascading'), - ({'cascade': True, 'depth': 0}, 7, 'Cascade with no parent, depth=0'), + ({'cascade': True, 'depth': 0}, 3, 'Cascade with no parent, depth=0'), ({'cascade': False, 'depth': 10}, 3, 'Cascade with no parent, depth=10'), ( {'parent': 1, 'cascade': False, 'depth': 0}, - 1, + 0, 'Dont cascade with depth=0 and parent', ), ( @@ -450,8 +450,6 @@ class StockLocationTest(StockAPITestCase): name=f'Location {idx}', description=f'Test location {idx}', parent=loc ) - StockLocation.objects.rebuild() - with self.assertNumQueriesLessThan(15): response = self.get(reverse('api-location-tree'), expected_code=200) @@ -596,13 +594,13 @@ class StockItemListTest(StockAPITestCase): def test_filter_by_part(self): """Filter StockItem by Part reference.""" + # 4 stock items associated with part 25 response = self.get_stock(part=25) + self.assertEqual(len(response), 4) - self.assertEqual(len(response), 17) - + # 3 stock items associated with part 10004 response = self.get_stock(part=10004) - - self.assertEqual(len(response), 12) + self.assertEqual(len(response), 3) def test_filter_by_ipn(self): """Filter StockItem by IPN reference.""" @@ -884,9 +882,15 @@ class StockItemListTest(StockAPITestCase): # Part name should match self.assertEqual(row['Part.Name'], item.part.name) + parts = Part.objects.get(pk=25).get_descendants(include_self=True) + self.assertEqual(parts.count(), 1) + + items = StockItem.objects.filter(part__in=parts) + self.assertEqual(items.count(), 4) + # Export stock items with a specific part with self.export_data(self.list_url, {'part': 25}) as data_file: - self.process_csv(data_file, required_rows=17) + self.process_csv(data_file, required_rows=items.count()) def test_filter_by_allocated(self): """Test that we can filter by "allocated" status. @@ -1034,9 +1038,9 @@ class StockItemListTest(StockAPITestCase): # With full data response = self.post(url, {'part': 1, 'quantity': 1}) - self.assertEqual(response.data['serial_number'], '1001') + self.assertEqual(response.data['serial_number'], '1') response = self.post(url, {'part': 1, 'quantity': 3}) - self.assertEqual(response.data['serial_number'], '1001,1002,1003') + self.assertEqual(response.data['serial_number'], '1,2,3') # Wrong quantities response = self.post(url, {'part': 1, 'quantity': 'abc'}, expected_code=400) diff --git a/src/backend/InvenTree/stock/tests.py b/src/backend/InvenTree/stock/tests.py index 6961638c3c..f71270d821 100644 --- a/src/backend/InvenTree/stock/tests.py +++ b/src/backend/InvenTree/stock/tests.py @@ -50,89 +50,10 @@ class StockTestBase(InvenTreeTestCase): cls.drawer2 = StockLocation.objects.get(name='Drawer_2') cls.drawer3 = StockLocation.objects.get(name='Drawer_3') - # Ensure the MPTT objects are correctly rebuild - Part.objects.rebuild() - StockItem.objects.rebuild() - class StockTest(StockTestBase): """Tests to ensure that the stock location tree functions correctly.""" - def test_pathstring(self): - """Check that pathstring updates occur as expected.""" - a = StockLocation.objects.create(name='A') - b = StockLocation.objects.create(name='B', parent=a) - c = StockLocation.objects.create(name='C', parent=b) - d = StockLocation.objects.create(name='D', parent=c) - - def refresh(): - a.refresh_from_db() - b.refresh_from_db() - c.refresh_from_db() - d.refresh_from_db() - - # Initial checks - self.assertEqual(a.pathstring, 'A') - self.assertEqual(b.pathstring, 'A/B') - self.assertEqual(c.pathstring, 'A/B/C') - self.assertEqual(d.pathstring, 'A/B/C/D') - - c.name = 'Cc' - c.save() - - refresh() - self.assertEqual(a.pathstring, 'A') - self.assertEqual(b.pathstring, 'A/B') - self.assertEqual(c.pathstring, 'A/B/Cc') - self.assertEqual(d.pathstring, 'A/B/Cc/D') - - b.name = 'Bb' - b.save() - - refresh() - self.assertEqual(a.pathstring, 'A') - self.assertEqual(b.pathstring, 'A/Bb') - self.assertEqual(c.pathstring, 'A/Bb/Cc') - self.assertEqual(d.pathstring, 'A/Bb/Cc/D') - - a.name = 'Aa' - a.save() - - refresh() - self.assertEqual(a.pathstring, 'Aa') - self.assertEqual(b.pathstring, 'Aa/Bb') - self.assertEqual(c.pathstring, 'Aa/Bb/Cc') - self.assertEqual(d.pathstring, 'Aa/Bb/Cc/D') - - d.name = 'Dd' - d.save() - - refresh() - self.assertEqual(a.pathstring, 'Aa') - self.assertEqual(b.pathstring, 'Aa/Bb') - self.assertEqual(c.pathstring, 'Aa/Bb/Cc') - self.assertEqual(d.pathstring, 'Aa/Bb/Cc/Dd') - - # Test a really long name - # (it will be clipped to < 250 characters) - a.name = 'A' * 100 - a.save() - b.name = 'B' * 100 - b.save() - c.name = 'C' * 100 - c.save() - d.name = 'D' * 100 - d.save() - - refresh() - self.assertEqual(len(a.pathstring), 100) - self.assertEqual(len(b.pathstring), 201) - self.assertEqual(len(c.pathstring), 249) - self.assertEqual(len(d.pathstring), 249) - - self.assertTrue(d.pathstring.startswith('AAAAAAAA')) - self.assertTrue(d.pathstring.endswith('DDDDDDDD')) - def test_link(self): """Test the link URL field validation.""" item = StockItem.objects.get(pk=1) @@ -768,132 +689,6 @@ class StockTest(StockTestBase): # Serialize the remainder of the stock item.serializeStock(2, [99, 100], self.user) - def test_location_tree(self): - """Unit tests for stock location tree structure (MPTT). - - Ensure that the MPTT structure is rebuilt correctly, - and the current ancestor tree is observed. - - Ref: https://github.com/inventree/InvenTree/issues/2636 - Ref: https://github.com/inventree/InvenTree/issues/2733 - """ - # First, we will create a stock location structure - - A = StockLocation.objects.create(name='A', description='Top level location') - B1 = StockLocation.objects.create(name='B1', parent=A) - B2 = StockLocation.objects.create(name='B2', parent=A) - B3 = StockLocation.objects.create(name='B3', parent=A) - C11 = StockLocation.objects.create(name='C11', parent=B1) - C12 = StockLocation.objects.create(name='C12', parent=B1) - C21 = StockLocation.objects.create(name='C21', parent=B2) - C22 = StockLocation.objects.create(name='C22', parent=B2) - C31 = StockLocation.objects.create(name='C31', parent=B3) - C32 = StockLocation.objects.create(name='C32', parent=B3) - - # Check that the tree_id is correct for each sublocation - for loc in [B1, B2, B3, C11, C12, C21, C22, C31, C32]: - self.assertEqual(loc.tree_id, A.tree_id) - - # Check that the tree levels are correct for each node in the tree - - self.assertEqual(A.level, 0) - self.assertEqual(A.get_ancestors().count(), 0) - - for loc in [B1, B2, B3]: - self.assertEqual(loc.parent, A) - self.assertEqual(loc.level, 1) - self.assertEqual(loc.get_ancestors().count(), 1) - - for loc in [C11, C12]: - self.assertEqual(loc.parent, B1) - self.assertEqual(loc.level, 2) - self.assertEqual(loc.get_ancestors().count(), 2) - - for loc in [C21, C22]: - self.assertEqual(loc.parent, B2) - self.assertEqual(loc.level, 2) - self.assertEqual(loc.get_ancestors().count(), 2) - - for loc in [C31, C32]: - self.assertEqual(loc.parent, B3) - self.assertEqual(loc.level, 2) - self.assertEqual(loc.get_ancestors().count(), 2) - - # Spot-check for C32 - ancestors = C32.get_ancestors(include_self=True) - - self.assertEqual(ancestors[0], A) - self.assertEqual(ancestors[1], B3) - self.assertEqual(ancestors[2], C32) - - # At this point, we are confident that the tree is correctly structured. - - # Let's delete node B3 from the tree. We expect that: - # - C31 should move directly under A - # - C32 should move directly under A - - # Add some stock items to B3 - for _ in range(10): - StockItem.objects.create( - part=Part.objects.get(pk=1), quantity=10, location=B3 - ) - - self.assertEqual(StockItem.objects.filter(location=B3).count(), 10) - self.assertEqual(StockItem.objects.filter(location=A).count(), 0) - - B3.delete() - - A.refresh_from_db() - C31.refresh_from_db() - C32.refresh_from_db() - - # Stock items have been moved to A - self.assertEqual(StockItem.objects.filter(location=A).count(), 10) - - # Parent should be A - self.assertEqual(C31.parent, A) - self.assertEqual(C32.parent, A) - - self.assertEqual(C31.tree_id, A.tree_id) - self.assertEqual(C31.level, 1) - - self.assertEqual(C32.tree_id, A.tree_id) - self.assertEqual(C32.level, 1) - - # Ancestor tree should be just A - ancestors = C31.get_ancestors() - self.assertEqual(ancestors.count(), 1) - self.assertEqual(ancestors[0], A) - - ancestors = C32.get_ancestors() - self.assertEqual(ancestors.count(), 1) - self.assertEqual(ancestors[0], A) - - # Delete A - A.delete() - - # Stock items have been moved to top-level location - self.assertEqual(StockItem.objects.filter(location=None).count(), 10) - - for loc in [B1, B2, C11, C12, C21, C22]: - loc.refresh_from_db() - - self.assertEqual(B1.parent, None) - self.assertEqual(B2.parent, None) - - self.assertEqual(C11.parent, B1) - self.assertEqual(C12.parent, B1) - self.assertEqual(C11.get_ancestors().count(), 1) - self.assertEqual(C12.get_ancestors().count(), 1) - - self.assertEqual(C21.parent, B2) - self.assertEqual(C22.parent, B2) - - ancestors = C21.get_ancestors() - - self.assertEqual(C21.get_ancestors().count(), 1) - self.assertEqual(C22.get_ancestors().count(), 1) - def test_metadata(self): """Unit tests for the metadata field.""" for model in [StockItem, StockLocation]: @@ -1092,7 +887,9 @@ class VariantTest(StockTestBase): item.save() # Attempt to create the same serial number but for a variant (should fail!) + # Reset the primary key and tree_id values item.pk = None + item.tree_id = None item.part = Part.objects.get(pk=10004) with self.assertRaises(ValidationError): @@ -1102,13 +899,216 @@ class VariantTest(StockTestBase): item.save() +class StockLocationTreeTest(StockTestBase): + """Unit test for the StockLocation tree structure.""" + + def test_pathstring(self): + """Check that pathstring updates occur as expected.""" + a = StockLocation.objects.create(name='A') + b = StockLocation.objects.create(name='B', parent=a) + c = StockLocation.objects.create(name='C', parent=b) + d = StockLocation.objects.create(name='D', parent=c) + + def refresh(): + a.refresh_from_db() + b.refresh_from_db() + c.refresh_from_db() + d.refresh_from_db() + + # Initial checks + self.assertEqual(a.pathstring, 'A') + self.assertEqual(b.pathstring, 'A/B') + self.assertEqual(c.pathstring, 'A/B/C') + self.assertEqual(d.pathstring, 'A/B/C/D') + + c.name = 'Cc' + c.save() + + refresh() + self.assertEqual(a.pathstring, 'A') + self.assertEqual(b.pathstring, 'A/B') + self.assertEqual(c.pathstring, 'A/B/Cc') + self.assertEqual(d.pathstring, 'A/B/Cc/D') + + b.name = 'Bb' + b.save() + + refresh() + self.assertEqual(a.pathstring, 'A') + self.assertEqual(b.pathstring, 'A/Bb') + self.assertEqual(c.pathstring, 'A/Bb/Cc') + self.assertEqual(d.pathstring, 'A/Bb/Cc/D') + + a.name = 'Aa' + a.save() + + refresh() + self.assertEqual(a.pathstring, 'Aa') + self.assertEqual(b.pathstring, 'Aa/Bb') + self.assertEqual(c.pathstring, 'Aa/Bb/Cc') + self.assertEqual(d.pathstring, 'Aa/Bb/Cc/D') + + d.name = 'Dd' + d.save() + + refresh() + self.assertEqual(a.pathstring, 'Aa') + self.assertEqual(b.pathstring, 'Aa/Bb') + self.assertEqual(c.pathstring, 'Aa/Bb/Cc') + self.assertEqual(d.pathstring, 'Aa/Bb/Cc/Dd') + + # Test a really long name + # (it will be clipped to < 250 characters) + a.name = 'A' * 100 + a.save() + b.name = 'B' * 100 + b.save() + c.name = 'C' * 100 + c.save() + d.name = 'D' * 100 + d.save() + + refresh() + self.assertEqual(len(a.pathstring), 100) + self.assertEqual(len(b.pathstring), 201) + self.assertEqual(len(c.pathstring), 249) + self.assertEqual(len(d.pathstring), 249) + + self.assertTrue(d.pathstring.startswith('AAAAAAAA')) + self.assertTrue(d.pathstring.endswith('DDDDDDDD')) + + def test_location_tree(self): + """Unit tests for stock location tree structure (MPTT). + + Ensure that the MPTT structure is rebuilt correctly, + and the current ancestor tree is observed. + + Ref: https://github.com/inventree/InvenTree/issues/2636 + Ref: https://github.com/inventree/InvenTree/issues/2733 + """ + # First, we will create a stock location structure + + A = StockLocation.objects.create(name='A', description='Top level location') + B1 = StockLocation.objects.create(name='B1', parent=A) + B2 = StockLocation.objects.create(name='B2', parent=A) + B3 = StockLocation.objects.create(name='B3', parent=A) + C11 = StockLocation.objects.create(name='C11', parent=B1) + C12 = StockLocation.objects.create(name='C12', parent=B1) + C21 = StockLocation.objects.create(name='C21', parent=B2) + C22 = StockLocation.objects.create(name='C22', parent=B2) + C31 = StockLocation.objects.create(name='C31', parent=B3) + C32 = StockLocation.objects.create(name='C32', parent=B3) + + # Check that the tree_id is correct for each sublocation + for loc in [B1, B2, B3, C11, C12, C21, C22, C31, C32]: + self.assertEqual(loc.tree_id, A.tree_id) + + # Check that the tree levels are correct for each node in the tree + + self.assertEqual(A.level, 0) + self.assertEqual(A.get_ancestors().count(), 0) + + for loc in [B1, B2, B3]: + self.assertEqual(loc.parent, A) + self.assertEqual(loc.level, 1) + self.assertEqual(loc.get_ancestors().count(), 1) + + for loc in [C11, C12]: + self.assertEqual(loc.parent, B1) + self.assertEqual(loc.level, 2) + self.assertEqual(loc.get_ancestors().count(), 2) + + for loc in [C21, C22]: + self.assertEqual(loc.parent, B2) + self.assertEqual(loc.level, 2) + self.assertEqual(loc.get_ancestors().count(), 2) + + for loc in [C31, C32]: + self.assertEqual(loc.parent, B3) + self.assertEqual(loc.level, 2) + self.assertEqual(loc.get_ancestors().count(), 2) + + # Spot-check for C32 + ancestors = C32.get_ancestors(include_self=True) + + self.assertEqual(ancestors[0], A) + self.assertEqual(ancestors[1], B3) + self.assertEqual(ancestors[2], C32) + + # At this point, we are confident that the tree is correctly structured. + + # Let's delete node B3 from the tree. We expect that: + # - C31 should move directly under A + # - C32 should move directly under A + + # Add some stock items to B3 + for _ in range(10): + StockItem.objects.create( + part=Part.objects.get(pk=1), quantity=10, location=B3 + ) + + self.assertEqual(StockItem.objects.filter(location=B3).count(), 10) + self.assertEqual(StockItem.objects.filter(location=A).count(), 0) + + B3.delete() + + A.refresh_from_db() + C31.refresh_from_db() + C32.refresh_from_db() + + # Stock items have been moved to A + self.assertEqual(StockItem.objects.filter(location=A).count(), 10) + + # Parent should be A + self.assertEqual(C31.parent, A) + self.assertEqual(C32.parent, A) + + self.assertEqual(C31.tree_id, A.tree_id) + self.assertEqual(C31.level, 1) + + self.assertEqual(C32.tree_id, A.tree_id) + self.assertEqual(C32.level, 1) + + # Ancestor tree should be just A + ancestors = C31.get_ancestors() + self.assertEqual(ancestors.count(), 1) + self.assertEqual(ancestors[0], A) + + ancestors = C32.get_ancestors() + self.assertEqual(ancestors.count(), 1) + self.assertEqual(ancestors[0], A) + + # Delete A + A.delete() + + # Stock items have been moved to top-level location + self.assertEqual(StockItem.objects.filter(location=None).count(), 10) + + for loc in [B1, B2, C11, C12, C21, C22]: + loc.refresh_from_db() + + self.assertEqual(B1.parent, None) + self.assertEqual(B2.parent, None) + + self.assertEqual(C11.parent, B1) + self.assertEqual(C12.parent, B1) + self.assertEqual(C11.get_ancestors().count(), 1) + self.assertEqual(C12.get_ancestors().count(), 1) + + self.assertEqual(C21.parent, B2) + self.assertEqual(C22.parent, B2) + + ancestors = C21.get_ancestors() + + self.assertEqual(C21.get_ancestors().count(), 1) + self.assertEqual(C22.get_ancestors().count(), 1) + + class StockTreeTest(StockTestBase): """Unit test for StockItem tree structure.""" def test_stock_split(self): """Test that stock splitting works correctly.""" - StockItem.objects.rebuild() - part = Part.objects.create(name='My part', description='My part description') location = StockLocation.objects.create(name='Test Location') @@ -1158,6 +1158,127 @@ class StockTreeTest(StockTestBase): item.refresh_from_db() self.assertEqual(item.get_descendants(include_self=True).count(), n + 30) + def test_tree_rebuild(self): + """Test that tree rebuild works correctly.""" + part = Part.objects.create(name='My part', description='My part description') + location = StockLocation.objects.create(name='Test Location') + + N = StockItem.objects.count() + + # Create an initial stock item + item = StockItem.objects.create(part=part, quantity=1000, location=location) + + # Split out ten child items + for _idx in range(10): + item.splitStock(10) + + item.refresh_from_db() + + self.assertEqual(StockItem.objects.count(), N + 11) + self.assertEqual(item.get_children().count(), 10) + self.assertEqual(item.get_descendants(include_self=True).count(), 11) + + # Split the first child item + child = item.get_children().first() + + self.assertEqual(child.parent, item) + self.assertEqual(child.tree_id, item.tree_id) + self.assertEqual(child.level, 1) + + # Split out three grandchildren + for _ in range(3): + child.splitStock(2) + + item.refresh_from_db() + child.refresh_from_db() + + self.assertEqual(child.get_descendants(include_self=True).count(), 4) + self.assertEqual(child.get_children().count(), 3) + + # Check tree structure for grandchildren + grandchildren = child.get_children() + + for gc in grandchildren: + self.assertEqual(gc.parent, child) + self.assertEqual(gc.parent.parent, item) + self.assertEqual(gc.tree_id, item.tree_id) + self.assertEqual(gc.level, 2) + self.assertGreater(gc.lft, child.lft) + self.assertLess(gc.rght, child.rght) + + self.assertEqual(item.get_children().count(), 10) + self.assertEqual(item.get_descendants(include_self=True).count(), 14) + + # Now, delete the child node + # We expect that the grandchildren will be re-parented to the parent node + child.delete() + + for gc in grandchildren: + gc.refresh_from_db() + + # Check that the grandchildren have been re-parented to the top-level + self.assertEqual(gc.parent, item) + self.assertEqual(gc.tree_id, item.tree_id) + self.assertEqual(gc.level, 1) + self.assertGreater(gc.lft, item.lft) + self.assertLess(gc.rght, item.rght) + + item.refresh_from_db() + + self.assertEqual(item.get_children().count(), 12) + self.assertEqual(item.get_descendants(include_self=True).count(), 13) + + def test_serialize(self): + """Test that StockItem serialization maintains tree structure.""" + part = Part.objects.create( + name='My part', description='My part description', trackable=True + ) + + N = StockItem.objects.count() + + # Create an initial stock item + item_1 = StockItem.objects.create(part=part, quantity=1000) + item_2 = item_1.splitStock(750) + + item_1.refresh_from_db() + + self.assertEqual(StockItem.objects.count(), N + 2) + self.assertEqual(item_1.get_children().count(), 1) + self.assertEqual(item_2.parent, item_1) + + # Serialize the secondary item + serials = [str(i) for i in range(20)] + items = item_2.serializeStock(20, serials) + + self.assertEqual(len(items), 20) + self.assertEqual(StockItem.objects.count(), N + 22) + + item_1.refresh_from_db() + item_2.refresh_from_db() + + self.assertEqual(item_1.get_children().count(), 1) + self.assertEqual(item_2.get_children().count(), 20) + + for child in items: + self.assertEqual(child.tree_id, item_2.tree_id) + self.assertEqual(child.level, 2) + self.assertEqual(child.parent, item_2) + self.assertGreater(child.lft, item_2.lft) + self.assertLess(child.rght, item_2.rght) + + # Delete item_2 : we expect that all children will be re-parented to item_1 + item_2.delete() + + for child in items: + child.refresh_from_db() + + # Check that the children have been re-parented to item_1 + self.assertEqual(child.parent, item_1) + self.assertEqual(child.tree_id, item_1.tree_id) + self.assertEqual(child.level, 1) + self.assertGreater(child.lft, item_1.lft) + self.assertLess(child.rght, item_1.rght) + class TestResultTest(StockTestBase): """Tests for the StockItemTestResult model.""" @@ -1245,11 +1366,12 @@ class TestResultTest(StockTestBase): from plugin.registry import registry - StockItem.objects.rebuild() - item = StockItem.objects.get(pk=522) + # Let's duplicate this item item.pk = None + item.parent = None + item.tree_id = None item.serial = None item.quantity = 50