From 8e289a3208f76e74ea7215f041d9d361cfe6e326 Mon Sep 17 00:00:00 2001 From: Oliver Date: Sat, 21 Mar 2026 23:47:11 +1100 Subject: [PATCH] [API] Category star fix (#11588) * [API] Bug fix for PartStar and PartCategoryStar - Logic refactor and fixes * Add playwright tests * Remove debug statements * Revert API string changes --- src/backend/InvenTree/part/api.py | 39 ++--------------------- src/backend/InvenTree/part/serializers.py | 23 +++++++++++-- src/frontend/tests/pages/pui_part.spec.ts | 32 +++++++++++++++++++ 3 files changed, 54 insertions(+), 40 deletions(-) diff --git a/src/backend/InvenTree/part/api.py b/src/backend/InvenTree/part/api.py index 38bb79834b..dbf8cf815a 100644 --- a/src/backend/InvenTree/part/api.py +++ b/src/backend/InvenTree/part/api.py @@ -74,20 +74,6 @@ class CategoryMixin: queryset = part_serializers.CategorySerializer.annotate_queryset(queryset) return queryset - def get_serializer_context(self): - """Add extra context to the serializer for the CategoryDetail endpoint.""" - ctx = super().get_serializer_context() - - try: - ctx['starred_categories'] = [ - star.category for star in self.request.user.starred_categories.all() - ] - except AttributeError: - # Error is thrown if the view does not have an associated request - ctx['starred_categories'] = [] - - return ctx - class CategoryFilter(FilterSet): """Custom filterset class for the PartCategoryList endpoint.""" @@ -267,14 +253,13 @@ class CategoryDetail(CategoryMixin, OutputOptionsMixin, CustomRetrieveUpdateDest """Perform 'update' function and mark this part as 'starred' (or not).""" # Clean up input data data = self.clean_data(request.data) + response = super().update(request, *args, **kwargs) if 'starred' in data: starred = str2bool(data.get('starred', False)) self.get_object().set_starred(request.user, starred, include_parents=False) - response = super().update(request, *args, **kwargs) - return response def destroy(self, request, *args, **kwargs): @@ -1036,27 +1021,8 @@ class PartMixin(SerializerContextMixin): # Indicate that we can create a new Part via this endpoint kwargs['create'] = self.is_create - # Pass a list of "starred" parts to the current user to the serializer - # We do this to reduce the number of database queries required! - if ( - self.starred_parts is None - and self.request is not None - and hasattr(self.request.user, 'starred_parts') - ): - self.starred_parts = [ - star.part for star in self.request.user.starred_parts.all() - ] - kwargs['starred_parts'] = self.starred_parts - return super().get_serializer(*args, **kwargs) - def get_serializer_context(self): - """Extend serializer context data.""" - context = super().get_serializer_context() - context['request'] = self.request - - return context - class PartOutputOptions(OutputConfiguration): """Output options for Part endpoints.""" @@ -1141,6 +1107,7 @@ class PartDetail(PartMixin, OutputOptionsMixin, RetrieveUpdateDestroyAPI): """ # Clean input data data = self.clean_data(request.data) + response = super().update(request, *args, **kwargs) if 'starred' in data: starred = str2bool(data.get('starred', False)) @@ -1149,8 +1116,6 @@ class PartDetail(PartMixin, OutputOptionsMixin, RetrieveUpdateDestroyAPI): request.user, starred, include_variants=False, include_categories=False ) - response = super().update(request, *args, **kwargs) - return response diff --git a/src/backend/InvenTree/part/serializers.py b/src/backend/InvenTree/part/serializers.py index 8102ac4641..8788ca8624 100644 --- a/src/backend/InvenTree/part/serializers.py +++ b/src/backend/InvenTree/part/serializers.py @@ -130,7 +130,16 @@ class CategorySerializer( def get_starred(self, category) -> bool: """Return True if the category is directly "starred" by the current user.""" - return category in self.context.get('starred_categories', []) + if not self.request or not self.request.user: + return False + + # Cache the "starred_categories" list for the current user + if not hasattr(self, 'starred_categories'): + self.starred_categories = [ + star.category.pk for star in self.request.user.starred_categories.all() + ] + + return category.pk in self.starred_categories path = enable_filter( FilterableListField( @@ -638,7 +647,6 @@ class PartSerializer( - Allows us to optionally pass extra fields based on the query. """ - self.starred_parts = kwargs.pop('starred_parts', []) create = kwargs.pop('create', False) super().__init__(*args, **kwargs) @@ -754,7 +762,16 @@ class PartSerializer( def get_starred(self, part) -> bool: """Return "true" if the part is starred by the current user.""" - return part in self.starred_parts + if not self.request or not self.request.user: + return False + + # Cache the "starred_parts" list for the current user + if not hasattr(self, 'starred_parts'): + self.starred_parts = [ + star.part.pk for star in self.request.user.starred_parts.all() + ] + + return part.pk in self.starred_parts # Extra detail for the category category_detail = enable_filter( diff --git a/src/frontend/tests/pages/pui_part.spec.ts b/src/frontend/tests/pages/pui_part.spec.ts index c8554bba5b..3277980488 100644 --- a/src/frontend/tests/pages/pui_part.spec.ts +++ b/src/frontend/tests/pages/pui_part.spec.ts @@ -97,6 +97,38 @@ test('Parts - Image Selection', async ({ browser }) => { await page.getByText('The image has been removed successfully').waitFor(); }); +// Test subscription logic for parts and categories +test('Parts - Subscriptions', async ({ browser }) => { + const page = await doCachedLogin(browser, { url: 'part/category/3/parts' }); + + // Click to subscribe to this category + await page + .getByRole('button', { name: 'action-button-subscribe-to-' }) + .click(); + await page.getByText('Subscription added').waitFor(); + + // Click to unsubscribe from this category + await page + .getByRole('button', { name: 'action-button-unsubscribe-' }) + .click(); + await page.getByText('Subscription removed').waitFor(); + + // Navigate through to a part detail page + await page.getByRole('cell', { name: 'Thumbnail M3x10 FHS-PLA' }).click(); + + // Click to subscribe to this part + await page + .getByRole('button', { name: 'action-button-subscribe-to-' }) + .click(); + await page.getByText('Subscription added').waitFor(); + + // Click to unsubscribe from this part + await page + .getByRole('button', { name: 'action-button-unsubscribe-' }) + .click(); + await page.getByText('Subscription removed').waitFor(); +}); + test('Parts - Manufacturer Parts', async ({ browser }) => { const page = await doCachedLogin(browser, { url: 'part/84/' });