2
0
mirror of https://github.com/inventree/InvenTree.git synced 2025-12-25 13:43:30 +00:00

Data export fix (#11055)

* Only look at query_params for top-level serializers

- Nested serializers should *not* look at query params

* Prevent all fields when exporting data

* Add unit test for large dataset export

* Fix code

* Pass through via context rather than primary kwarg

* Fix for file download

* Ensure request is passed through to the serializer

* Ensure query params are passed through when exporting data

* Fix code comment

* Fix for unit test helper func

* Increase max export time
This commit is contained in:
Oliver
2025-12-23 08:46:41 +11:00
committed by GitHub
parent 947a3e26a0
commit 9d2ac521ef
7 changed files with 110 additions and 20 deletions

View File

@@ -239,10 +239,17 @@ class OutputOptionsMixin:
def get_serializer(self, *args, **kwargs): def get_serializer(self, *args, **kwargs):
"""Return serializer instance with output options applied.""" """Return serializer instance with output options applied."""
if self.output_options and hasattr(self, 'request'): request = getattr(self, 'request', None)
if self.output_options and request:
params = self.request.query_params params = self.request.query_params
kwargs.update(self.output_options.format_params(params)) kwargs.update(self.output_options.format_params(params))
# Ensure the request is included in the serializer context
context = kwargs.get('context', {})
context['request'] = request
kwargs['context'] = context
serializer = super().get_serializer(*args, **kwargs) serializer = super().get_serializer(*args, **kwargs)
# Check if the serializer actually can be filtered - it makes little sense to use this mixin without that prerequisite # Check if the serializer actually can be filtered - it makes little sense to use this mixin without that prerequisite

View File

@@ -163,10 +163,26 @@ class FilterableSerializerMixin:
def gather_filters(self, kwargs) -> None: def gather_filters(self, kwargs) -> None:
"""Gather filterable fields through introspection.""" """Gather filterable fields through introspection."""
context = kwargs.get('context', {})
top_level_serializer = context.get('top_level_serializer', None)
request = context.get('request', None) or getattr(self, 'request', None)
# Gather query parameters from the request context
query_params = dict(getattr(request, 'query_params', {})) if request else {}
is_top_level = (
top_level_serializer is None
or top_level_serializer == self.__class__.__name__
)
# Update the context to ensure that the top_level_serializer flag is removed for nested serializers
if top_level_serializer is None:
context['top_level_serializer'] = self.__class__.__name__
kwargs['context'] = context
# Fast exit if this has already been done or would not have any effect # Fast exit if this has already been done or would not have any effect
if getattr(self, '_was_filtered', False) or not hasattr(self, 'fields'): if getattr(self, '_was_filtered', False) or not hasattr(self, 'fields'):
return return
self._was_filtered = True
# Actually gather the filterable fields # Actually gather the filterable fields
# Also see `enable_filter` where `is_filterable` and `is_filterable_vals` are set # Also see `enable_filter` where `is_filterable` and `is_filterable_vals` are set
@@ -176,21 +192,22 @@ class FilterableSerializerMixin:
if getattr(a, 'is_filterable', None) if getattr(a, 'is_filterable', None)
} }
# Gather query parameters from the request context
query_params = {}
if context := kwargs.get('context', {}):
query_params = dict(getattr(context.get('request', {}), 'query_params', {}))
# Remove filter args from kwargs to avoid issues with super().__init__ # Remove filter args from kwargs to avoid issues with super().__init__
popped_kwargs = {} # store popped kwargs as an arg might be reused for multiple fields popped_kwargs = {} # store popped kwargs as an arg might be reused for multiple fields
tgs_vals: dict[str, bool] = {} tgs_vals: dict[str, bool] = {}
for k, v in self.filter_targets.items(): for k, v in self.filter_targets.items():
pop_ref = v['filter_name'] or k pop_ref = v['filter_name'] or k
val = kwargs.pop(pop_ref, popped_kwargs.get(pop_ref)) val = kwargs.pop(pop_ref, popped_kwargs.get(pop_ref))
# Optionally also look in query parameters # Optionally also look in query parameters
if val is None and self.filter_on_query and v.get('filter_by_query', True): # Note that we only do this for a top-level serializer, to avoid issues with nested serializers
if (
is_top_level
and val is None
and self.filter_on_query
and v.get('filter_by_query', True)
):
val = query_params.pop(pop_ref, None) val = query_params.pop(pop_ref, None)
if isinstance(val, list) and len(val) == 1: if isinstance(val, list) and len(val) == 1:
val = val[0] val = val[0]
@@ -199,7 +216,9 @@ class FilterableSerializerMixin:
tgs_vals[k] = ( tgs_vals[k] = (
str2bool(val) if isinstance(val, (str, int, float)) else val str2bool(val) if isinstance(val, (str, int, float)) else val
) # Support for various filtering style for backwards compatibility ) # Support for various filtering style for backwards compatibility
self.filter_target_values = tgs_vals self.filter_target_values = tgs_vals
self._was_filtered = True
# Ensure this mixin is not broadly applied as it is expensive on scale (total CI time increased by 21% when running all coverage tests) # Ensure this mixin is not broadly applied as it is expensive on scale (total CI time increased by 21% when running all coverage tests)
if len(self.filter_targets) == 0 and not self.no_filters: if len(self.filter_targets) == 0 and not self.no_filters:
@@ -216,14 +235,12 @@ class FilterableSerializerMixin:
): ):
return return
# Skip filtering when exporting data - leave all fields intact is_exporting = getattr(self, '_exporting_data', False)
if getattr(self, '_exporting_data', False):
return
# Skip filtering for write requests - all fields should be present for data creation # Skip filtering for write requests - all fields should be present for data creation
if request := self.context.get('request', None): if request := self.context.get('request', None):
if method := getattr(request, 'method', None): if method := getattr(request, 'method', None):
if str(method).lower() in ['post', 'put', 'patch']: if str(method).lower() in ['post', 'put', 'patch'] and not is_exporting:
return return
# Throw out fields which are not requested (either by default or explicitly) # Throw out fields which are not requested (either by default or explicitly)

View File

@@ -157,7 +157,7 @@ class FilteredSerializers(InvenTreeAPITestCase):
_ = BadSerializer() _ = BadSerializer()
self.assertTrue(True) # Dummy assertion to ensure we reach here self.assertTrue(True) # Dummy assertion to ensure we reach here
def test_failiure_OutputOptionsMixin(self): def test_failure_OutputOptionsMixin(self):
"""Test failure case for OutputOptionsMixin.""" """Test failure case for OutputOptionsMixin."""
class BadSerializer(InvenTree.serializers.InvenTreeModelSerializer): class BadSerializer(InvenTree.serializers.InvenTreeModelSerializer):

View File

@@ -655,7 +655,9 @@ class InvenTreeAPITestCase(
# Append URL params # Append URL params
url += '?' + '&'.join([f'{key}={value}' for key, value in params.items()]) url += '?' + '&'.join([f'{key}={value}' for key, value in params.items()])
response = self.client.get(url, data=None, format='json') response = self.get(
url, data=None, format='json', expected_code=expected_code, **kwargs
)
self.check_response(url, response, expected_code=expected_code) self.check_response(url, response, expected_code=expected_code)
# Check that the response is of the correct type # Check that the response is of the correct type

View File

@@ -337,6 +337,12 @@ class DataExportViewMixin:
# Update the output instance with the total number of items to export # Update the output instance with the total number of items to export
output.total = queryset.count() output.total = queryset.count()
output.save() output.save()
request = context.get('request', None)
if request:
query_params = getattr(request, 'query_params', {})
context.update(**query_params)
context['request'] = request
data = None data = None
serializer = serializer_class(context=context, exporting=True) serializer = serializer_class(context=context, exporting=True)
@@ -363,7 +369,12 @@ class DataExportViewMixin:
# The returned data *must* be a list of dict objects # The returned data *must* be a list of dict objects
try: try:
data = export_plugin.export_data( data = export_plugin.export_data(
queryset, serializer_class, headers, export_context, output queryset,
serializer_class,
headers,
export_context,
output,
serializer_context=context,
) )
except Exception as e: except Exception as e:

View File

@@ -90,6 +90,7 @@ class DataExportMixin:
headers: OrderedDict, headers: OrderedDict,
context: dict, context: dict,
output: DataOutput, output: DataOutput,
serializer_context: Optional[dict] = None,
**kwargs, **kwargs,
) -> list: ) -> list:
"""Export data from the queryset. """Export data from the queryset.
@@ -100,6 +101,7 @@ class DataExportMixin:
Arguments: Arguments:
queryset: The queryset to export queryset: The queryset to export
serializer_class: The serializer class to use for exporting the data serializer_class: The serializer class to use for exporting the data
serializer_context: Optional context for the serializer
headers: The headers for the export headers: The headers for the export
context: Any custom context for the export (provided by the plugin serializer) context: Any custom context for the export (provided by the plugin serializer)
output: The DataOutput object for the export output: The DataOutput object for the export
@@ -107,7 +109,9 @@ class DataExportMixin:
Returns: The exported data (a list of dict objects) Returns: The exported data (a list of dict objects)
""" """
# The default implementation simply serializes the queryset # The default implementation simply serializes the queryset
return serializer_class(queryset, many=True, exporting=True).data return serializer_class(
queryset, many=True, exporting=True, context=serializer_context or {}
).data
def get_export_options_serializer(self, **kwargs) -> serializers.Serializer | None: def get_export_options_serializer(self, **kwargs) -> serializers.Serializer | None:
"""Return a serializer class with dynamic export options for this plugin. """Return a serializer class with dynamic export options for this plugin.

View File

@@ -867,9 +867,13 @@ class StockItemListTest(StockAPITestCase):
excluded_headers = ['metadata'] excluded_headers = ['metadata']
filters = {} filters = {
'part_detail': True,
'location_detail': True,
'supplier_part_detail': True,
}
with self.export_data(self.list_url, filters) as data_file: with self.export_data(self.list_url, params=filters) as data_file:
self.process_csv( self.process_csv(
data_file, data_file,
required_cols=required_headers, required_cols=required_headers,
@@ -881,7 +885,7 @@ class StockItemListTest(StockAPITestCase):
filters['location'] = 1 filters['location'] = 1
filters['cascade'] = True filters['cascade'] = True
with self.export_data(self.list_url, filters) as data_file: with self.export_data(self.list_url, params=filters) as data_file:
data = self.process_csv(data_file, required_rows=9) data = self.process_csv(data_file, required_rows=9)
for row in data: for row in data:
@@ -909,6 +913,51 @@ class StockItemListTest(StockAPITestCase):
with self.export_data(self.list_url, {'part': 25}) as data_file: with self.export_data(self.list_url, {'part': 25}) as data_file:
self.process_csv(data_file, required_rows=items.count()) self.process_csv(data_file, required_rows=items.count())
def test_large_export(self):
    """Test export of very large dataset.

    - Ensure that the time taken to export a large dataset is reasonable.
    - Ensure that the number of DB queries is reasonable.
    """
    # Single named constant instead of repeating the magic number 2500
    N_ITEMS = 2500

    # Cycle through all available locations and (non-virtual) parts
    locations = list(StockLocation.objects.all())
    parts = list(Part.objects.filter(virtual=False))

    n_locations = len(locations)
    n_parts = len(parts)

    # Build every item in memory, then insert with a single bulk_create
    # call (one query instead of one INSERT per item).
    # NOTE(review): the tree fields (level / tree_id / lft / rght) are
    # zeroed here - presumably to satisfy non-null MPTT columns on bulk
    # insert without rebuilding the tree; confirm against the model.
    stock_items = [
        StockItem(
            part=parts[idx % n_parts],
            location=locations[idx % n_locations],
            quantity=10,
            level=0,
            tree_id=0,
            lft=0,
            rght=0,
        )
        for idx in range(N_ITEMS)
    ]

    StockItem.objects.bulk_create(stock_items)
    self.assertGreaterEqual(StockItem.objects.count(), N_ITEMS)

    # Note: While the export is quick on pgsql, it is still quite slow on sqlite3
    with self.export_data(
        self.list_url, max_query_count=50, max_query_time=7.5
    ) as data_file:
        data = self.process_csv(data_file)
        self.assertGreaterEqual(len(data), N_ITEMS)
def test_filter_by_allocated(self): def test_filter_by_allocated(self):
"""Test that we can filter by "allocated" status. """Test that we can filter by "allocated" status.