diff --git a/src/backend/InvenTree/part/api.py b/src/backend/InvenTree/part/api.py index dfe32cd689..55bf695eb0 100644 --- a/src/backend/InvenTree/part/api.py +++ b/src/backend/InvenTree/part/api.py @@ -73,20 +73,6 @@ class CategoryMixin: queryset = part_serializers.CategorySerializer.annotate_queryset(queryset) return queryset - def get_serializer_context(self): - """Add extra context to the serializer for the CategoryDetail endpoint.""" - ctx = super().get_serializer_context() - - try: - ctx['starred_categories'] = [ - star.category for star in self.request.user.starred_categories.all() - ] - except AttributeError: - # Error is thrown if the view does not have an associated request - ctx['starred_categories'] = [] - - return ctx - class CategoryFilter(FilterSet): """Custom filterset class for the PartCategoryList endpoint.""" @@ -266,14 +252,13 @@ class CategoryDetail(CategoryMixin, OutputOptionsMixin, CustomRetrieveUpdateDest """Perform 'update' function and mark this part as 'starred' (or not).""" # Clean up input data data = self.clean_data(request.data) + response = super().update(request, *args, **kwargs) if 'starred' in data: starred = str2bool(data.get('starred', False)) self.get_object().set_starred(request.user, starred, include_parents=False) - response = super().update(request, *args, **kwargs) - return response def destroy(self, request, *args, **kwargs): @@ -1027,27 +1012,8 @@ class PartMixin(SerializerContextMixin): # Indicate that we can create a new Part via this endpoint kwargs['create'] = self.is_create - # Pass a list of "starred" parts to the current user to the serializer - # We do this to reduce the number of database queries required! - if ( - self.starred_parts is None - and self.request is not None - and hasattr(self.request.user, 'starred_parts') - ): - self.starred_parts = [ - star.part for star in self.request.user.starred_parts.all() - ] - kwargs['starred_parts'] = self.starred_parts - return super().get_serializer(*args, **kwargs) - def get_serializer_context(self): - """Extend serializer context data.""" - context = super().get_serializer_context() - context['request'] = self.request - - return context - class PartOutputOptions(OutputConfiguration): """Output options for Part endpoints.""" @@ -1132,6 +1098,7 @@ class PartDetail(PartMixin, OutputOptionsMixin, RetrieveUpdateDestroyAPI): """ # Clean input data data = self.clean_data(request.data) + response = super().update(request, *args, **kwargs) if 'starred' in data: starred = str2bool(data.get('starred', False)) @@ -1140,8 +1107,6 @@ class PartDetail(PartMixin, OutputOptionsMixin, RetrieveUpdateDestroyAPI): request.user, starred, include_variants=False, include_categories=False ) - response = super().update(request, *args, **kwargs) - return response diff --git a/src/backend/InvenTree/part/serializers.py b/src/backend/InvenTree/part/serializers.py index a0a00933eb..c1b6814eb1 100644 --- a/src/backend/InvenTree/part/serializers.py +++ b/src/backend/InvenTree/part/serializers.py @@ -129,7 +129,16 @@ class CategorySerializer( def get_starred(self, category) -> bool: """Return True if the category is directly "starred" by the current user.""" - return category in self.context.get('starred_categories', []) + if not self.request or not self.request.user: + return False + + # Cache the "starred_categories" list for the current user + if not hasattr(self, 'starred_categories'): + self.starred_categories = [ + star.category.pk for star in self.request.user.starred_categories.all() + ] + + return category.pk in self.starred_categories path = enable_filter( FilterableListField( @@ -638,7 +647,6 @@ class PartSerializer( - Allows us to optionally pass extra fields based on the query. """ - self.starred_parts = kwargs.pop('starred_parts', []) create = kwargs.pop('create', False) super().__init__(*args, **kwargs) @@ -754,7 +762,16 @@ class PartSerializer( def get_starred(self, part) -> bool: """Return "true" if the part is starred by the current user.""" - return part in self.starred_parts + if not self.request or not self.request.user: + return False + + # Cache the "starred_parts" list for the current user + if not hasattr(self, 'starred_parts'): + self.starred_parts = [ + star.part.pk for star in self.request.user.starred_parts.all() + ] + + return part.pk in self.starred_parts # Extra detail for the category category_detail = enable_filter(