Mirror of https://github.com/inventree/InvenTree.git (synced 2025-12-25 13:43:30 +00:00)
Data export fix (#11055)
* Only look at query_params for top-level serializers - nested serializers should *not* look at query params
* Prevent all fields when exporting data
* Add unit test for large dataset export
* Fix code
* Pass through via context rather than primary kwarg
* Fix for file download
* Ensure request is passed through to the serializer
* Ensure query params are passed through when exporting data
* Fix code comment
* Fix for unit test helper func
* Increase max export time
@@ -239,10 +239,17 @@ class OutputOptionsMixin:

     def get_serializer(self, *args, **kwargs):
         """Return serializer instance with output options applied."""
-        if self.output_options and hasattr(self, 'request'):
+        request = getattr(self, 'request', None)
+
+        if self.output_options and request:
             params = self.request.query_params
             kwargs.update(self.output_options.format_params(params))

+        # Ensure the request is included in the serializer context
+        context = kwargs.get('context', {})
+        context['request'] = request
+        kwargs['context'] = context
+
         serializer = super().get_serializer(*args, **kwargs)

         # Check if the serializer actually can be filtered - makes not much sense to use this mixin without that prerequisite
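The change above makes two guarantees: the view no longer assumes a `request` attribute exists, and the request is always injected into the serializer context, even when no output options apply. A minimal sketch of why that injection matters (the serializer below is hypothetical, not part of this commit):

from rest_framework import serializers


class ExampleSerializer(serializers.Serializer):
    """Hypothetical serializer that adapts its output to the request."""

    name = serializers.CharField()

    def to_representation(self, instance):
        data = super().to_representation(instance)
        # Present only because get_serializer() placed it into the context;
        # without that injection this silently evaluates to None.
        request = self.context.get('request')
        if request is not None and request.query_params.get('verbose'):
            data['verbose'] = True
        return data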
@@ -163,10 +163,26 @@ class FilterableSerializerMixin:

     def gather_filters(self, kwargs) -> None:
         """Gather filterable fields through introspection."""
+        context = kwargs.get('context', {})
+        top_level_serializer = context.get('top_level_serializer', None)
+        request = context.get('request', None) or getattr(self, 'request', None)
+
+        # Gather query parameters from the request context
+        query_params = dict(getattr(request, 'query_params', {})) if request else {}
+
+        is_top_level = (
+            top_level_serializer is None
+            or top_level_serializer == self.__class__.__name__
+        )
+
+        # Update the context to ensure that the top_level_serializer flag is removed for nested serializers
+        if top_level_serializer is None:
+            context['top_level_serializer'] = self.__class__.__name__
+        kwargs['context'] = context
+
         # Fast exit if this has already been done or would not have any effect
         if getattr(self, '_was_filtered', False) or not hasattr(self, 'fields'):
             return
-        self._was_filtered = True

         # Actually gather the filterable fields
         # Also see `enable_filter` where` is_filterable and is_filterable_vals are set
@@ -176,21 +192,22 @@ class FilterableSerializerMixin:
             if getattr(a, 'is_filterable', None)
         }

-        # Gather query parameters from the request context
-        query_params = {}
-        if context := kwargs.get('context', {}):
-            query_params = dict(getattr(context.get('request', {}), 'query_params', {}))
-
         # Remove filter args from kwargs to avoid issues with super().__init__
         popped_kwargs = {}  # store popped kwargs as a arg might be reused for multiple fields
         tgs_vals: dict[str, bool] = {}
         for k, v in self.filter_targets.items():
             pop_ref = v['filter_name'] or k
             val = kwargs.pop(pop_ref, popped_kwargs.get(pop_ref))

             # Optionally also look in query parameters
-            if val is None and self.filter_on_query and v.get('filter_by_query', True):
+            # Note that we only do this for a top-level serializer, to avoid issues with nested serializers
+            if (
+                is_top_level
+                and val is None
+                and self.filter_on_query
+                and v.get('filter_by_query', True)
+            ):
                 val = query_params.pop(pop_ref, None)

             if isinstance(val, list) and len(val) == 1:
                 val = val[0]
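Together, these two hunks gate query-parameter filtering on the outermost serializer only: the first serializer constructed for a request writes its class name into the shared context, and any nested serializer that later sees a different name skips the query parameters. A condensed sketch of the mechanism (a standalone helper, not the actual mixin code):

def claim_top_level(serializer, context: dict) -> bool:
    """Return True if `serializer` is the top-level serializer for this context.

    The first caller claims the 'top_level_serializer' slot; nested
    serializers sharing the same context dict then compare unequal.
    """
    marker = context.get('top_level_serializer')
    if marker is None:
        context['top_level_serializer'] = type(serializer).__name__
        return True
    return marker == type(serializer).__name__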
@@ -199,7 +216,9 @@ class FilterableSerializerMixin:
             tgs_vals[k] = (
                 str2bool(val) if isinstance(val, (str, int, float)) else val
             )  # Support for various filtering style for backwards compatibility

         self.filter_target_values = tgs_vals
+        self._was_filtered = True
+
         # Ensure this mixin is not broadly applied as it is expensive on scale (total CI time increased by 21% when running all coverage tests)
         if len(self.filter_targets) == 0 and not self.no_filters:
@@ -216,14 +235,12 @@ class FilterableSerializerMixin:
         ):
             return

-        # Skip filtering when exporting data - leave all fields intact
-        if getattr(self, '_exporting_data', False):
-            return
+        is_exporting = getattr(self, '_exporting_data', False)

         # Skip filtering for a write requests - all fields should be present for data creation
         if request := self.context.get('request', None):
             if method := getattr(request, 'method', None):
-                if str(method).lower() in ['post', 'put', 'patch']:
+                if str(method).lower() in ['post', 'put', 'patch'] and not is_exporting:
                     return

         # Throw out fields which are not requested (either by default or explicitly)
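Note the behavioral flip here: previously, exporting caused an unconditional early return, leaving every field intact; now export requests go through normal field filtering, and the `_exporting_data` flag only disarms the early return for write methods. A sketch of the new decision (not the mixin itself):

def keeps_all_fields(method: str, is_exporting: bool) -> bool:
    """Sketch: True when field filtering should be skipped entirely."""
    # Write requests keep all fields for data creation, unless the
    # request is actually a data export, which should remain filterable.
    return method.lower() in ('post', 'put', 'patch') and not is_exporting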
@@ -157,7 +157,7 @@ class FilteredSerializers(InvenTreeAPITestCase):
         _ = BadSerializer()
         self.assertTrue(True)  # Dummy assertion to ensure we reach here

-    def test_failiure_OutputOptionsMixin(self):
+    def test_failure_OutputOptionsMixin(self):
         """Test failure case for OutputOptionsMixin."""

         class BadSerializer(InvenTree.serializers.InvenTreeModelSerializer):
@@ -655,7 +655,9 @@ class InvenTreeAPITestCase(
         # Append URL params
         url += '?' + '&'.join([f'{key}={value}' for key, value in params.items()])

-        response = self.client.get(url, data=None, format='json')
+        response = self.get(
+            url, data=None, format='json', expected_code=expected_code, **kwargs
+        )
         self.check_response(url, response, expected_code=expected_code)

         # Check that the response is of the correct type
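A small but necessary test-helper fix: the download request now goes through the test case's own get() wrapper rather than the raw Django test client, so the expected status code is checked consistently and extra keyword arguments (presumably including the query budgets used by the new large-export test below) are passed through.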
@@ -337,6 +337,12 @@ class DataExportViewMixin:
         # Update the output instance with the total number of items to export
         output.total = queryset.count()
         output.save()
+        request = context.get('request', None)
+
+        if request:
+            query_params = getattr(request, 'query_params', {})
+            context.update(**query_params)
+            context['request'] = request

         data = None
         serializer = serializer_class(context=context, exporting=True)
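Two things happen in the added block: the export-time query parameters are flattened directly into the serializer context (so field-selection flags such as `part_detail=true` survive into the export path), and the request object itself is re-attached. A standalone sketch of that assembly (the helper name is illustrative):

def build_export_context(request) -> dict:
    """Sketch of the context assembly above (not the actual helper)."""
    context: dict = {}
    if request is not None:
        # Flatten query params (e.g. ?part_detail=true) into the context...
        context.update(**getattr(request, 'query_params', {}))
        # ...and keep the request object available for nested serializers.
        context['request'] = request
    return context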
@@ -363,7 +369,12 @@ class DataExportViewMixin:
         # The returned data *must* be a list of dict objects
         try:
             data = export_plugin.export_data(
-                queryset, serializer_class, headers, export_context, output
+                queryset,
+                serializer_class,
+                headers,
+                export_context,
+                output,
+                serializer_context=context,
             )

         except Exception as e:
@@ -90,6 +90,7 @@ class DataExportMixin:
         headers: OrderedDict,
         context: dict,
         output: DataOutput,
+        serializer_context: Optional[dict] = None,
         **kwargs,
     ) -> list:
         """Export data from the queryset.
@@ -100,6 +101,7 @@ class DataExportMixin:
         Arguments:
             queryset: The queryset to export
             serializer_class: The serializer class to use for exporting the data
+            serializer_context: Optional context for the serializer
             headers: The headers for the export
             context: Any custom context for the export (provided by the plugin serializer)
             output: The DataOutput object for the export
@@ -107,7 +109,9 @@ class DataExportMixin:
         Returns: The exported data (a list of dict objects)
         """
         # The default implementation simply serializes the queryset
-        return serializer_class(queryset, many=True, exporting=True).data
+        return serializer_class(
+            queryset, many=True, exporting=True, context=serializer_context or {}
+        ).data

     def get_export_options_serializer(self, **kwargs) -> serializers.Serializer | None:
         """Return a serializer class with dynamic export options for this plugin.
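For plugin authors: export_data() gained an optional serializer_context keyword, and the default implementation now forwards it so request and query-parameter information reaches the serializer. A custom exporter overriding this method should accept and forward the new argument as well; a sketch (the plugin class name is hypothetical, and real plugins inherit further base classes not shown here):

class MyExporterPlugin(DataExportMixin):  # DataExportMixin as defined above
    """Hypothetical exporter demonstrating the new keyword argument."""

    def export_data(
        self,
        queryset,
        serializer_class,
        headers,
        context,
        output,
        serializer_context=None,
        **kwargs,
    ) -> list:
        # Forward the serializer context so query-param driven field
        # selection still applies to the exported rows.
        return serializer_class(
            queryset, many=True, exporting=True, context=serializer_context or {}
        ).data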
@@ -867,9 +867,13 @@ class StockItemListTest(StockAPITestCase):

         excluded_headers = ['metadata']

-        filters = {}
+        filters = {
+            'part_detail': True,
+            'location_detail': True,
+            'supplier_part_detail': True,
+        }

-        with self.export_data(self.list_url, filters) as data_file:
+        with self.export_data(self.list_url, params=filters) as data_file:
             self.process_csv(
                 data_file,
                 required_cols=required_headers,
@@ -881,7 +885,7 @@ class StockItemListTest(StockAPITestCase):
         filters['location'] = 1
         filters['cascade'] = True

-        with self.export_data(self.list_url, filters) as data_file:
+        with self.export_data(self.list_url, params=filters) as data_file:
             data = self.process_csv(data_file, required_rows=9)

         for row in data:
@@ -909,6 +913,51 @@ class StockItemListTest(StockAPITestCase):
         with self.export_data(self.list_url, {'part': 25}) as data_file:
             self.process_csv(data_file, required_rows=items.count())

+    def test_large_export(self):
+        """Test export of very large dataset.
+
+        - Ensure that the time taken to export a large dataset is reasonable.
+        - Ensure that the number of DB queries is reasonable.
+        """
+        # Create a large number of stock items
+        locations = list(StockLocation.objects.all())
+        parts = list(Part.objects.filter(virtual=False))
+
+        idx = 0
+
+        N_LOCATIONS = len(locations)
+        N_PARTS = len(parts)
+
+        stock_items = []
+
+        while idx < 2500:
+            part = parts[idx % N_PARTS]
+            location = locations[idx % N_LOCATIONS]
+
+            item = StockItem(
+                part=part,
+                location=location,
+                quantity=10,
+                level=0,
+                tree_id=0,
+                lft=0,
+                rght=0,
+            )
+            stock_items.append(item)
+            idx += 1
+
+        StockItem.objects.bulk_create(stock_items)
+
+        self.assertGreaterEqual(StockItem.objects.count(), 2500)
+
+        # Note: While the export is quick on pgsql, it is still quite slow on sqlite3
+        with self.export_data(
+            self.list_url, max_query_count=50, max_query_time=7.5
+        ) as data_file:
+            data = self.process_csv(data_file)
+
+        self.assertGreaterEqual(len(data), 2500)
+
     def test_filter_by_allocated(self):
         """Test that we can filter by "allocated" status.
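Two details of the new test are worth noting: the tree fields (level, tree_id, lft, rght) are set to zero, presumably so that bulk_create() can insert 2500 items without maintaining valid tree data, and the export runs under explicit budgets (at most 50 queries and 7.5 seconds), which appears to correspond to the "Increase max export time" bullet in the commit message.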