diff --git a/src/backend/InvenTree/order/api.py b/src/backend/InvenTree/order/api.py index fc4bb4763c..52520808dd 100644 --- a/src/backend/InvenTree/order/api.py +++ b/src/backend/InvenTree/order/api.py @@ -75,6 +75,24 @@ class GeneralExtraLineList(DataExportViewMixin): filterset_fields = ['order'] +class OrderCreateMixin: + """Mixin class which handles order creation via API.""" + + def create(self, request, *args, **kwargs): + """Save user information on order creation.""" + serializer = self.get_serializer(data=self.clean_data(request.data)) + serializer.is_valid(raise_exception=True) + + item = serializer.save() + item.created_by = request.user + item.save() + + headers = self.get_success_headers(serializer.data) + return Response( + serializer.data, status=status.HTTP_201_CREATED, headers=headers + ) + + class OrderFilter(rest_filters.FilterSet): """Base class for custom API filters for the OrderList endpoint.""" @@ -266,7 +284,9 @@ class PurchaseOrderMixin: return queryset -class PurchaseOrderList(PurchaseOrderMixin, DataExportViewMixin, ListCreateAPI): +class PurchaseOrderList( + PurchaseOrderMixin, OrderCreateMixin, DataExportViewMixin, ListCreateAPI +): """API endpoint for accessing a list of PurchaseOrder objects. - GET: Return list of PurchaseOrder objects (with filters) @@ -728,7 +748,9 @@ class SalesOrderMixin: return queryset -class SalesOrderList(SalesOrderMixin, DataExportViewMixin, ListCreateAPI): +class SalesOrderList( + SalesOrderMixin, OrderCreateMixin, DataExportViewMixin, ListCreateAPI +): """API endpoint for accessing a list of SalesOrder objects. - GET: Return list of SalesOrder objects (with filters) @@ -737,20 +759,6 @@ class SalesOrderList(SalesOrderMixin, DataExportViewMixin, ListCreateAPI): filterset_class = SalesOrderFilter - def create(self, request, *args, **kwargs): - """Save user information on create.""" - serializer = self.get_serializer(data=self.clean_data(request.data)) - serializer.is_valid(raise_exception=True) - - item = serializer.save() - item.created_by = request.user - item.save() - - headers = self.get_success_headers(serializer.data) - return Response( - serializer.data, status=status.HTTP_201_CREATED, headers=headers - ) - def filter_queryset(self, queryset): """Perform custom filtering operations on the SalesOrder queryset.""" queryset = super().filter_queryset(queryset) @@ -1345,25 +1353,13 @@ class ReturnOrderMixin: return queryset -class ReturnOrderList(ReturnOrderMixin, DataExportViewMixin, ListCreateAPI): +class ReturnOrderList( + ReturnOrderMixin, OrderCreateMixin, DataExportViewMixin, ListCreateAPI +): """API endpoint for accessing a list of ReturnOrder objects.""" filterset_class = ReturnOrderFilter - def create(self, request, *args, **kwargs): - """Save user information on create.""" - serializer = self.get_serializer(data=self.clean_data(request.data)) - serializer.is_valid(raise_exception=True) - - item = serializer.save() - item.created_by = request.user - item.save() - - headers = self.get_success_headers(serializer.data) - return Response( - serializer.data, status=status.HTTP_201_CREATED, headers=headers - ) - filter_backends = SEARCH_ORDER_FILTER_ALIAS ordering_field_aliases = { diff --git a/src/backend/InvenTree/order/test_api.py b/src/backend/InvenTree/order/test_api.py index 2ace5cecb0..97cdc0d3ee 100644 --- a/src/backend/InvenTree/order/test_api.py +++ b/src/backend/InvenTree/order/test_api.py @@ -255,7 +255,10 @@ class PurchaseOrderTest(OrderTest): order = models.PurchaseOrder.objects.get(pk=response.data['pk']) - self.assertEqual(order.reference, 'PO-92233720368547758089999999999999999') + # Check that the created_by field is set correctly + self.assertEqual(order.created_by.username, 'testuser') + + self.assertEqual(order.reference, huge_number) self.assertEqual(order.reference_int, 0x7FFFFFFF) def test_po_reference_wildcard_default(self): @@ -1407,6 +1410,11 @@ class SalesOrderTest(OrderTest): # Grab the PK for the newly created SalesOrder pk = response.data['pk'] + # Basic checks against the newly created SalesOrder + so = models.SalesOrder.objects.get(pk=pk) + self.assertEqual(so.reference, 'SO-12345') + self.assertEqual(so.created_by.username, 'testuser') + # Try to create a SO with identical reference (should fail) response = self.post( url,