2
0
mirror of https://github.com/inventree/InvenTree.git synced 2026-03-30 08:01:07 +00:00

[API] Category star fix (#11588)

* [API] Bug fix for PartStar and PartCategoryStar

- Logic refactor and fixes

* Add playwright tests

* Remove debug statements

* Revert API string changes
This commit is contained in:
Oliver
2026-03-21 23:47:11 +11:00
committed by GitHub
parent cf619b4184
commit 8e289a3208
3 changed files with 54 additions and 40 deletions

View File

@@ -74,20 +74,6 @@ class CategoryMixin:
queryset = part_serializers.CategorySerializer.annotate_queryset(queryset)
return queryset
def get_serializer_context(self):
"""Add extra context to the serializer for the CategoryDetail endpoint."""
ctx = super().get_serializer_context()
try:
ctx['starred_categories'] = [
star.category for star in self.request.user.starred_categories.all()
]
except AttributeError:
# Error is thrown if the view does not have an associated request
ctx['starred_categories'] = []
return ctx
class CategoryFilter(FilterSet):
"""Custom filterset class for the PartCategoryList endpoint."""
@@ -267,14 +253,13 @@ class CategoryDetail(CategoryMixin, OutputOptionsMixin, CustomRetrieveUpdateDest
"""Perform 'update' function and mark this part as 'starred' (or not)."""
# Clean up input data
data = self.clean_data(request.data)
response = super().update(request, *args, **kwargs)
if 'starred' in data:
starred = str2bool(data.get('starred', False))
self.get_object().set_starred(request.user, starred, include_parents=False)
response = super().update(request, *args, **kwargs)
return response
def destroy(self, request, *args, **kwargs):
@@ -1036,27 +1021,8 @@ class PartMixin(SerializerContextMixin):
# Indicate that we can create a new Part via this endpoint
kwargs['create'] = self.is_create
# Pass a list of "starred" parts to the current user to the serializer
# We do this to reduce the number of database queries required!
if (
self.starred_parts is None
and self.request is not None
and hasattr(self.request.user, 'starred_parts')
):
self.starred_parts = [
star.part for star in self.request.user.starred_parts.all()
]
kwargs['starred_parts'] = self.starred_parts
return super().get_serializer(*args, **kwargs)
def get_serializer_context(self):
"""Extend serializer context data."""
context = super().get_serializer_context()
context['request'] = self.request
return context
class PartOutputOptions(OutputConfiguration):
"""Output options for Part endpoints."""
@@ -1141,6 +1107,7 @@ class PartDetail(PartMixin, OutputOptionsMixin, RetrieveUpdateDestroyAPI):
"""
# Clean input data
data = self.clean_data(request.data)
response = super().update(request, *args, **kwargs)
if 'starred' in data:
starred = str2bool(data.get('starred', False))
@@ -1149,8 +1116,6 @@ class PartDetail(PartMixin, OutputOptionsMixin, RetrieveUpdateDestroyAPI):
request.user, starred, include_variants=False, include_categories=False
)
response = super().update(request, *args, **kwargs)
return response

View File

@@ -130,7 +130,16 @@ class CategorySerializer(
def get_starred(self, category) -> bool:
"""Return True if the category is directly "starred" by the current user."""
return category in self.context.get('starred_categories', [])
if not self.request or not self.request.user:
return False
# Cache the "starred_categories" list for the current user
if not hasattr(self, 'starred_categories'):
self.starred_categories = [
star.category.pk for star in self.request.user.starred_categories.all()
]
return category.pk in self.starred_categories
path = enable_filter(
FilterableListField(
@@ -638,7 +647,6 @@ class PartSerializer(
- Allows us to optionally pass extra fields based on the query.
"""
self.starred_parts = kwargs.pop('starred_parts', [])
create = kwargs.pop('create', False)
super().__init__(*args, **kwargs)
@@ -754,7 +762,16 @@ class PartSerializer(
def get_starred(self, part) -> bool:
"""Return "true" if the part is starred by the current user."""
return part in self.starred_parts
if not self.request or not self.request.user:
return False
# Cache the "starred_parts" list for the current user
if not hasattr(self, 'starred_parts'):
self.starred_parts = [
star.part.pk for star in self.request.user.starred_parts.all()
]
return part.pk in self.starred_parts
# Extra detail for the category
category_detail = enable_filter(

View File

@@ -97,6 +97,38 @@ test('Parts - Image Selection', async ({ browser }) => {
await page.getByText('The image has been removed successfully').waitFor();
});
// Test subscription logic for parts and categories
test('Parts - Subscriptions', async ({ browser }) => {
const page = await doCachedLogin(browser, { url: 'part/category/3/parts' });
// Click to subscribe to this category
await page
.getByRole('button', { name: 'action-button-subscribe-to-' })
.click();
await page.getByText('Subscription added').waitFor();
// Click to unsubscribe from this category
await page
.getByRole('button', { name: 'action-button-unsubscribe-' })
.click();
await page.getByText('Subscription removed').waitFor();
// Navigate through to a part detail page
await page.getByRole('cell', { name: 'Thumbnail M3x10 FHS-PLA' }).click();
// Click to subscribe to this part
await page
.getByRole('button', { name: 'action-button-subscribe-to-' })
.click();
await page.getByText('Subscription added').waitFor();
// Click to unsubscribe from this part
await page
.getByRole('button', { name: 'action-button-unsubscribe-' })
.click();
await page.getByText('Subscription removed').waitFor();
});
test('Parts - Manufacturer Parts', async ({ browser }) => {
const page = await doCachedLogin(browser, { url: 'part/84/' });