diff --git a/src/backend/InvenTree/InvenTree/serializers.py b/src/backend/InvenTree/InvenTree/serializers.py index d84ac452d0..14c24de598 100644 --- a/src/backend/InvenTree/InvenTree/serializers.py +++ b/src/backend/InvenTree/InvenTree/serializers.py @@ -85,6 +85,9 @@ class FilterableSerializerMixin: _was_filtered = False no_filters = False + """If True, do not raise an exception if no filterable fields are found.""" + filter_on_query = True + """If True, also look for filter parameters in the request query parameters.""" def __init__(self, *args, **kwargs): """Initialization routine for the serializer. This gathers and applies filters through kwargs.""" @@ -113,12 +116,24 @@ class FilterableSerializerMixin: if getattr(a, 'is_filterable', None) } + # Gather query parameters from the request context + query_params = {} + if context := kwargs.get('context', {}): + query_params = dict(getattr(context.get('request', {}), 'query_params', {})) + # Remove filter args from kwargs to avoid issues with super().__init__ poped_kwargs = {} # store popped kwargs as a arg might be reused for multiple fields tgs_vals: dict[str, bool] = {} for k, v in self.filter_targets.items(): pop_ref = v['filter_name'] or k val = kwargs.pop(pop_ref, poped_kwargs.get(pop_ref)) + + # Optionally also look in query parameters + if val is None and self.filter_on_query: + val = query_params.pop(pop_ref, None) + if isinstance(val, list) and len(val) == 1: + val = val[0] + if val: # Save popped value for reuse poped_kwargs[pop_ref] = val tgs_vals[k] = ( diff --git a/src/backend/InvenTree/InvenTree/test_serializers.py b/src/backend/InvenTree/InvenTree/test_serializers.py index 259eadc764..d913abbf5b 100644 --- a/src/backend/InvenTree/InvenTree/test_serializers.py +++ b/src/backend/InvenTree/InvenTree/test_serializers.py @@ -77,3 +77,17 @@ class FilteredSerializers(InvenTreeAPITestCase): self.assertNotContains(response, 'field_b') self.assertContains(response, 'field_c') self.assertContains(response, 'field_d') + + # Request with filter for field_b + response = self.client.get(url, {'field_b': True}) + self.assertContains(response, 'field_a') + self.assertContains(response, 'field_b') + self.assertContains(response, 'field_c') + self.assertContains(response, 'field_d') + + # Disable field_c using custom filter name + response = self.client.get(url, {'crazy_name': 'false'}) + self.assertContains(response, 'field_a') + self.assertNotContains(response, 'field_b') + self.assertNotContains(response, 'field_c') + self.assertNotContains(response, 'field_d')