From 5a9c19492b07266ebf647f089c3917b9dbd153a8 Mon Sep 17 00:00:00 2001 From: Oliver Walters Date: Mon, 27 May 2019 22:44:13 +1000 Subject: [PATCH] Design an aggregation filter for stock items - If 'aggregate=1' is sent to the stock API, aggregate the returned stock items by part and location - Suprisingly this actually works right out of the gate --- InvenTree/stock/api.py | 58 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/InvenTree/stock/api.py b/InvenTree/stock/api.py index ec998dde41..525352f0e6 100644 --- a/InvenTree/stock/api.py +++ b/InvenTree/stock/api.py @@ -8,6 +8,8 @@ from django_filters import NumberFilter from django.conf.urls import url, include from django.urls import reverse +from django.db.models import Sum, Count + from .models import StockLocation, StockItem from .models import StockItemTracking @@ -241,6 +243,7 @@ class StockList(generics.ListCreateAPIView): - POST: Create a new StockItem Additional query parameters are available: + - aggregate: If 'true' then stock items are aggregated by Part and Location - location: Filter stock by location - category: Filter by parts belonging to a certain category - supplier: Filter by supplier @@ -257,6 +260,52 @@ class StockList(generics.ListCreateAPIView): kwargs['context'] = self.get_serializer_context() return self.serializer_class(*args, **kwargs) + def list(self, request, *args, **kwargs): + + queryset = self.filter_queryset(self.get_queryset()) + + if str2bool(self.request.GET.get('aggregate', None)): + # Aggregate stock by part type + queryset = queryset.values( + 'part', + 'part__name', + 'part__image', + 'location', + 'location__name').annotate( + stock=Sum('quantity'), + items=Count('part')) + + for result in queryset: + # If there is only 1 stock item (which will be a lot of the time), + # Add that data to the dict + if result['items'] == 1: + items = StockItem.objects.filter( + part=result['part'], + location=result['location'] + ) + + if items.count() == 1: + result.pop('items') + + item = items[0] + + # Add in some extra information specific to this StockItem + result['pk'] = item.id + result['serial'] = item.serial + result['batch'] = item.batch + result['notes'] = item.notes + + return Response(queryset) + + page = self.paginate_queryset(queryset) + if page is not None: + serializer = self.get_serializer(page, many=True) + return self.get_paginated_response(serializer.data) + + serializer = self.get_serializer(queryset, many=True) + return Response(serializer.data) + + def get_queryset(self): """ If the query includes a particular location, @@ -264,7 +313,7 @@ class StockList(generics.ListCreateAPIView): """ # Start with all objects - stock_list = StockItem.objects.all() + stock_list = StockItem.objects.filter(customer=None, belongs_to=None) # Does the client wish to filter by the Part ID? part_id = self.request.query_params.get('part', None) @@ -313,6 +362,13 @@ class StockList(generics.ListCreateAPIView): # Pre-fetch related objects for better response time stock_list = self.get_serializer_class().setup_eager_loading(stock_list) + # Also ensure that we pre-fecth all the related items + stock_list = stock_list.prefetch_related( + 'part', + 'part__category', + 'location' + ) + return stock_list serializer_class = StockItemSerializer