mirror of
https://github.com/inventree/InvenTree.git
synced 2025-04-27 19:16:44 +00:00
- Prevent shadow overwrite of default_values dict - Remove dead code (cherry picked from commit 7049e84ac3bb07bb476a4307bb29d5edaee40dd3) Co-authored-by: Oliver <oliver.henry.walters@gmail.com>
This commit is contained in:
parent
407ccb7bd2
commit
3b6b41976f
@ -111,17 +111,13 @@ class DataImportSession(models.Model):
|
||||
)
|
||||
|
||||
@property
|
||||
def field_mapping(self):
|
||||
def field_mapping(self) -> dict:
|
||||
"""Construct a dict of field mappings for this import session.
|
||||
|
||||
Returns: A dict of field: column mappings
|
||||
Returns:
|
||||
A dict of field -> column mappings
|
||||
"""
|
||||
mapping = {}
|
||||
|
||||
for i in self.column_mappings.all():
|
||||
mapping[i.field] = i.column
|
||||
|
||||
return mapping
|
||||
return {mapping.field: mapping.column for mapping in self.column_mappings.all()}
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
@ -138,7 +134,7 @@ class DataImportSession(models.Model):
|
||||
|
||||
return supported_models().get(self.model_type, None)
|
||||
|
||||
def extract_columns(self):
|
||||
def extract_columns(self) -> None:
|
||||
"""Run initial column extraction and mapping.
|
||||
|
||||
This method is called when the import session is first created.
|
||||
@ -204,7 +200,7 @@ class DataImportSession(models.Model):
|
||||
self.status = DataImportStatusCode.MAPPING.value
|
||||
self.save()
|
||||
|
||||
def accept_mapping(self):
|
||||
def accept_mapping(self) -> None:
|
||||
"""Accept current mapping configuration.
|
||||
|
||||
- Validate that the current column mapping is correct
|
||||
@ -243,7 +239,7 @@ class DataImportSession(models.Model):
|
||||
# No errors, so trigger the data import process
|
||||
self.trigger_data_import()
|
||||
|
||||
def trigger_data_import(self):
|
||||
def trigger_data_import(self) -> None:
|
||||
"""Trigger the data import process for this session.
|
||||
|
||||
Offloads the task to the background worker process.
|
||||
@ -256,7 +252,7 @@ class DataImportSession(models.Model):
|
||||
|
||||
offload_task(importer.tasks.import_data, self.pk)
|
||||
|
||||
def import_data(self):
|
||||
def import_data(self) -> None:
|
||||
"""Perform the data import process for this session."""
|
||||
# Clear any existing data rows
|
||||
self.rows.all().delete()
|
||||
@ -316,12 +312,12 @@ class DataImportSession(models.Model):
|
||||
return True
|
||||
|
||||
@property
|
||||
def row_count(self):
|
||||
def row_count(self) -> int:
|
||||
"""Return the number of rows in the import session."""
|
||||
return self.rows.count()
|
||||
|
||||
@property
|
||||
def completed_row_count(self):
|
||||
def completed_row_count(self) -> int:
|
||||
"""Return the number of completed rows for this session."""
|
||||
return self.rows.filter(complete=True).count()
|
||||
|
||||
@ -349,7 +345,7 @@ class DataImportSession(models.Model):
|
||||
self._available_fields = fields
|
||||
return fields
|
||||
|
||||
def required_fields(self):
|
||||
def required_fields(self) -> dict:
|
||||
"""Returns information on which fields are *required* for import."""
|
||||
fields = self.available_fields()
|
||||
|
||||
@ -591,7 +587,7 @@ class DataImportRow(models.Model):
|
||||
value = value or None
|
||||
|
||||
# Use the default value, if provided
|
||||
if value in [None, ''] and field in default_values:
|
||||
if value is None and field in default_values:
|
||||
value = default_values[field]
|
||||
|
||||
data[field] = value
|
||||
@ -607,7 +603,9 @@ class DataImportRow(models.Model):
|
||||
- If available, we use the "default" values provided by the import session
|
||||
- If available, we use the "override" values provided by the import session
|
||||
"""
|
||||
data = self.default_values
|
||||
data = {}
|
||||
|
||||
data.update(self.default_values)
|
||||
|
||||
if self.data:
|
||||
data.update(self.data)
|
||||
|
@ -81,23 +81,6 @@ def extract_column_names(data_file) -> list:
|
||||
return headers
|
||||
|
||||
|
||||
def extract_rows(data_file) -> list:
|
||||
"""Extract rows from the data file.
|
||||
|
||||
Each returned row is a dictionary of column_name: value pairs.
|
||||
"""
|
||||
data = load_data_file(data_file)
|
||||
|
||||
headers = data.headers
|
||||
|
||||
rows = []
|
||||
|
||||
for row in data:
|
||||
rows.append(dict(zip(headers, row)))
|
||||
|
||||
return rows
|
||||
|
||||
|
||||
def get_field_label(field) -> str:
|
||||
"""Return the label for a field in a serializer class.
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user