diff --git a/InvenTree/part/migrations/0112_auto_20230525_1606.py b/InvenTree/part/migrations/0112_auto_20230525_1606.py index 43e7d2abf5..c1aebf3866 100644 --- a/InvenTree/part/migrations/0112_auto_20230525_1606.py +++ b/InvenTree/part/migrations/0112_auto_20230525_1606.py @@ -57,7 +57,6 @@ class AddFieldOrSkip(migrations.AddField): try: super().database_forwards(app_label, schema_editor, from_state, to_state) - print(f'Added field {self.name} to model {self.model_name}') except Exception as exc: pass diff --git a/InvenTree/stock/models.py b/InvenTree/stock/models.py index 18e7fe4f6b..ec6251cb79 100644 --- a/InvenTree/stock/models.py +++ b/InvenTree/stock/models.py @@ -2,6 +2,7 @@ from __future__ import annotations +import logging import os from datetime import datetime, timedelta from decimal import Decimal, InvalidOperation @@ -49,6 +50,8 @@ from part import models as PartModels from plugin.events import trigger_event from users.models import Owner +logger = logging.getLogger('inventree') + class StockLocationType(MetadataMixin, models.Model): """A type of stock location like Warehouse, room, shelf, drawer. @@ -1706,9 +1709,12 @@ class StockItem( # Nullify the PK so a new record is created new_stock = StockItem.objects.get(pk=self.pk) new_stock.pk = None - new_stock.parent = self new_stock.quantity = quantity + # Update the new stock item to ensure the tree structure is observed + new_stock.parent = self + new_stock.level = self.level + 1 + # Move to the new location if specified, otherwise use current location if location: new_stock.location = location @@ -1748,6 +1754,19 @@ class StockItem( stockitem=new_stock, ) + # Rebuild the tree for this parent item + try: + StockItem.objects.partial_rebuild(tree_id=self.tree_id) + except Exception: + logger.warning('Rebuilding entire StockItem tree') + StockItem.objects.rebuild() + + # Attempt to reload the new item from the database + try: + new_stock.refresh_from_db() + except Exception: + pass + # Return a copy of the "new" stock item return new_stock diff --git a/InvenTree/stock/tests.py b/InvenTree/stock/tests.py index aa965a350b..6cc5e0e6bc 100644 --- a/InvenTree/stock/tests.py +++ b/InvenTree/stock/tests.py @@ -502,9 +502,18 @@ class StockTest(StockTestBase): ait = it.allocateToCustomer( customer, quantity=an, order=order, user=None, notes='Allocated some stock' ) + + self.assertEqual(ait.quantity, an) + self.assertTrue(ait.parent, it) + + # There should be only quantity 10x remaining + it.refresh_from_db() + self.assertEqual(it.quantity, 10) + ait.return_from_customer(it.location, None, notes='Stock removed from customer') # When returned stock is returned to its original (parent) location, check that the parent has correct quantity + it.refresh_from_db() self.assertEqual(it.quantity, n) ait = it.allocateToCustomer( @@ -987,6 +996,63 @@ class VariantTest(StockTestBase): item.save() +class StockTreeTest(StockTestBase): + """Unit test for StockItem tree structure.""" + + def test_stock_split(self): + """Test that stock splitting works correctly.""" + StockItem.objects.rebuild() + + part = Part.objects.create(name='My part', description='My part description') + location = StockLocation.objects.create(name='Test Location') + + # Create an initial stock item + item = StockItem.objects.create(part=part, quantity=1000, location=location) + + # Test that the initial MPTT values are correct + self.assertEqual(item.level, 0) + self.assertEqual(item.lft, 1) + self.assertEqual(item.rght, 2) + + children = [] + + self.assertEqual(item.get_descendants(include_self=False).count(), 0) + self.assertEqual(item.get_descendants(include_self=True).count(), 1) + + # Create child items by splitting stock + for idx in range(10): + child = item.splitStock(50, None, None) + children.append(child) + + # Check that the child item has been correctly created + self.assertEqual(child.parent.pk, item.pk) + self.assertEqual(child.tree_id, item.tree_id) + self.assertEqual(child.level, 1) + + item.refresh_from_db() + self.assertEqual(item.get_children().count(), idx + 1) + self.assertEqual(item.get_descendants(include_self=True).count(), idx + 2) + + item.refresh_from_db() + n = item.get_descendants(include_self=True).count() + + for child in children: + # Create multiple sub-childs + for _idx in range(3): + sub_child = child.splitStock(10, None, None) + self.assertEqual(sub_child.parent.pk, child.pk) + self.assertEqual(sub_child.tree_id, child.tree_id) + self.assertEqual(sub_child.level, 2) + + self.assertEqual(sub_child.get_ancestors(include_self=True).count(), 3) + + child.refresh_from_db() + self.assertEqual(child.get_descendants(include_self=True).count(), 4) + + item.refresh_from_db() + self.assertEqual(item.get_descendants(include_self=True).count(), n + 30) + + class TestResultTest(StockTestBase): """Tests for the StockItemTestResult model.""" @@ -1050,6 +1116,9 @@ class TestResultTest(StockTestBase): def test_duplicate_item_tests(self): """Test duplicate item behaviour.""" # Create an example stock item by copying one from the database (because we are lazy) + + StockItem.objects.rebuild() + item = StockItem.objects.get(pk=522) item.pk = None