diff --git a/InvenTree/part/test_category.py b/InvenTree/part/test_category.py index 53030d402a..6eb76fa845 100644 --- a/InvenTree/part/test_category.py +++ b/InvenTree/part/test_category.py @@ -172,3 +172,122 @@ class CategoryTest(TestCase): # And one part should have no default location at all w = Part.objects.get(name='Widget') self.assertIsNone(w.get_default_location()) + + def test_category_tree(self): + """ + Unit tests for the part category tree structure (MPTT) + Ensure that the MPTT structure is rebuilt correctly, + and the correct ancestor tree is observed. + """ + + # Clear out any existing parts + Part.objects.all().delete() + + # First, create a structured tree of part categories + A = PartCategory.objects.create( + name='A', + description='Top level category', + ) + + B1 = PartCategory.objects.create(name='B1', parent=A) + B2 = PartCategory.objects.create(name='B2', parent=A) + B3 = PartCategory.objects.create(name='B3', parent=A) + + C11 = PartCategory.objects.create(name='C11', parent=B1) + C12 = PartCategory.objects.create(name='C12', parent=B1) + C13 = PartCategory.objects.create(name='C13', parent=B1) + + C21 = PartCategory.objects.create(name='C21', parent=B2) + C22 = PartCategory.objects.create(name='C22', parent=B2) + C23 = PartCategory.objects.create(name='C23', parent=B2) + + C31 = PartCategory.objects.create(name='C31', parent=B3) + C32 = PartCategory.objects.create(name='C32', parent=B3) + C33 = PartCategory.objects.create(name='C33', parent=B3) + + # Check that the tree_id value is correct + for cat in [B1, B2, B3, C11, C22, C33]: + self.assertEqual(cat.tree_id, A.tree_id) + self.assertEqual(cat.level, cat.parent.level + 1) + self.assertEqual(cat.get_ancestors().count(), cat.level) + + # Spot check for C31 + ancestors = C31.get_ancestors(include_self=True) + + self.assertEqual(ancestors.count(), 3) + self.assertEqual(ancestors[0], A) + self.assertEqual(ancestors[1], B3) + self.assertEqual(ancestors[2], C31) + + # At this point, we are confident that the tree is correctly structured + + # Add some parts to category B3 + + for i in range(10): + Part.objects.create( + name=f'Part {i}', + description='A test part', + category=B3, + ) + + self.assertEqual(Part.objects.filter(category=B3).count(), 10) + self.assertEqual(Part.objects.filter(category=A).count(), 0) + + # Delete category B3 + B3.delete() + + # Child parts have been moved to category A + self.assertEqual(Part.objects.filter(category=A).count(), 10) + + for cat in [C31, C32, C33]: + # These categories should now be directly under A + cat.refresh_from_db() + + self.assertEqual(cat.parent, A) + self.assertEqual(cat.level, 1) + self.assertEqual(cat.get_ancestors().count(), 1) + self.assertEqual(cat.get_ancestors()[0], A) + + # Now, delete category A + A.delete() + + # Parts have now been moved to the top-level category + self.assertEqual(Part.objects.filter(category=None).count(), 10) + + for loc in [B1, B2, C31, C32, C33]: + # These should now all be "top level" categories + loc.refresh_from_db() + + self.assertEqual(loc.level, 0) + self.assertEqual(loc.parent, None) + + # Check descendants for B1 + descendants = B1.get_descendants() + self.assertEqual(descendants.count(), 3) + + for loc in [C11, C12, C13]: + self.assertTrue(loc in descendants) + + # Check category C1x, should be B1 -> C1x + for loc in [C11, C12, C13]: + loc.refresh_from_db() + + self.assertEqual(loc.level, 1) + self.assertEqual(loc.parent, B1) + ancestors = loc.get_ancestors(include_self=True) + + self.assertEqual(ancestors.count(), 2) + self.assertEqual(ancestors[0], B1) + self.assertEqual(ancestors[1], loc) + + # Check category C2x, should be B2 -> C2x + for loc in [C21, C22, C23]: + loc.refresh_from_db() + + self.assertEqual(loc.level, 1) + self.assertEqual(loc.parent, B2) + ancestors = loc.get_ancestors(include_self=True) + + self.assertEqual(ancestors.count(), 2) + self.assertEqual(ancestors[0], B2) + self.assertEqual(ancestors[1], loc) diff --git a/InvenTree/stock/tests.py b/InvenTree/stock/tests.py index 20dc562f54..50f77a593b 100644 --- a/InvenTree/stock/tests.py +++ b/InvenTree/stock/tests.py @@ -688,7 +688,6 @@ class StockTest(TestCase): self.assertEqual(C22.parent, B2) ancestors = C21.get_ancestors() - print("C21 ancestors:", ancestors) self.assertEqual(C21.get_ancestors().count(), 1) self.assertEqual(C22.get_ancestors().count(), 1)