From 7049e84ac3bb07bb476a4307bb29d5edaee40dd3 Mon Sep 17 00:00:00 2001 From: Oliver Date: Wed, 12 Feb 2025 07:24:24 +1100 Subject: [PATCH] Fix for data import (#9060) - Prevent shadow overwrite of default_values dict - Remove dead code --- src/backend/InvenTree/importer/models.py | 32 +++++++++----------- src/backend/InvenTree/importer/operations.py | 17 ----------- 2 files changed, 15 insertions(+), 34 deletions(-) diff --git a/src/backend/InvenTree/importer/models.py b/src/backend/InvenTree/importer/models.py index a23394b88a..5191242cf3 100644 --- a/src/backend/InvenTree/importer/models.py +++ b/src/backend/InvenTree/importer/models.py @@ -111,17 +111,13 @@ class DataImportSession(models.Model): ) @property - def field_mapping(self): + def field_mapping(self) -> dict: """Construct a dict of field mappings for this import session. - Returns: A dict of field: column mappings + Returns: + A dict of field -> column mappings """ - mapping = {} - - for i in self.column_mappings.all(): - mapping[i.field] = i.column - - return mapping + return {mapping.field: mapping.column for mapping in self.column_mappings.all()} @property def model_class(self): @@ -138,7 +134,7 @@ class DataImportSession(models.Model): return supported_models().get(self.model_type, None) - def extract_columns(self): + def extract_columns(self) -> None: """Run initial column extraction and mapping. This method is called when the import session is first created. @@ -211,7 +207,7 @@ class DataImportSession(models.Model): self.status = DataImportStatusCode.MAPPING.value self.save() - def accept_mapping(self): + def accept_mapping(self) -> None: """Accept current mapping configuration. - Validate that the current column mapping is correct @@ -250,7 +246,7 @@ class DataImportSession(models.Model): # No errors, so trigger the data import process self.trigger_data_import() - def trigger_data_import(self): + def trigger_data_import(self) -> None: """Trigger the data import process for this session. Offloads the task to the background worker process. @@ -263,7 +259,7 @@ class DataImportSession(models.Model): offload_task(importer.tasks.import_data, self.pk) - def import_data(self): + def import_data(self) -> None: """Perform the data import process for this session.""" # Clear any existing data rows self.rows.all().delete() @@ -323,12 +319,12 @@ class DataImportSession(models.Model): return True @property - def row_count(self): + def row_count(self) -> int: """Return the number of rows in the import session.""" return self.rows.count() @property - def completed_row_count(self): + def completed_row_count(self) -> int: """Return the number of completed rows for this session.""" return self.rows.filter(complete=True).count() @@ -356,7 +352,7 @@ class DataImportSession(models.Model): self._available_fields = fields return fields - def required_fields(self): + def required_fields(self) -> dict: """Returns information on which fields are *required* for import.""" fields = self.available_fields() @@ -598,7 +594,7 @@ class DataImportRow(models.Model): value = value or None # Use the default value, if provided - if value in [None, ''] and field in default_values: + if value is None and field in default_values: value = default_values[field] data[field] = value @@ -614,7 +610,9 @@ class DataImportRow(models.Model): - If available, we use the "default" values provided by the import session - If available, we use the "override" values provided by the import session """ - data = self.default_values + data = {} + + data.update(self.default_values) if self.data: data.update(self.data) diff --git a/src/backend/InvenTree/importer/operations.py b/src/backend/InvenTree/importer/operations.py index 5e5551c284..a15713b840 100644 --- a/src/backend/InvenTree/importer/operations.py +++ b/src/backend/InvenTree/importer/operations.py @@ -83,23 +83,6 @@ def extract_column_names(data_file) -> list: return headers -def extract_rows(data_file) -> list: - """Extract rows from the data file. - - Each returned row is a dictionary of column_name: value pairs. - """ - data = load_data_file(data_file) - - headers = data.headers - - rows = [] - - for row in data: - rows.append(dict(zip(headers, row))) - - return rows - - def get_field_label(field) -> str: """Return the label for a field in a serializer class.