2
0
mirror of https://github.com/inventree/InvenTree.git synced 2025-04-29 12:06:44 +00:00
InvenTree/InvenTree/InvenTree/serializers.py
Lukas 99d122baa9
Refactor model helpers into own file (#4927)
* Refactor model helpers into own file to allow helper import when apps not loaded yet

* Import helper functions at module level

* Added missing imports where vscode couldn't help because there is no explicit import
2023-05-31 09:18:42 +10:00

737 lines
22 KiB
Python

"""Serializers used in various InvenTree apps."""
import os
from collections import OrderedDict
from decimal import Decimal
from django.conf import settings
from django.contrib.auth.models import User
from django.core.exceptions import ValidationError as DjangoValidationError
from django.db import models
from django.utils.translation import gettext_lazy as _
import tablib
from djmoney.contrib.django_rest_framework.fields import MoneyField
from djmoney.money import Money
from djmoney.utils import MONEY_CLASSES, get_currency_field_name
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from rest_framework.fields import empty
from rest_framework.serializers import DecimalField
from rest_framework.utils import model_meta
from taggit.serializers import TaggitSerializer
from common.models import InvenTreeSetting
from common.settings import currency_code_default, currency_code_mappings
from InvenTree.fields import InvenTreeRestURLField, InvenTreeURLField
from InvenTree.helpers_model import download_image_from_url
class InvenTreeMoneySerializer(MoneyField):
    """Serializer field for money values, which rejects amounts that are not numerically valid.

    Ref: https://github.com/django-money/django-money/blob/master/djmoney/contrib/django_rest_framework/fields.py
    """

    def __init__(self, *args, **kwargs):
        """Apply default field options, unless the caller has provided them explicitly."""
        kwargs.setdefault("max_digits", 19)
        self.decimal_places = kwargs.setdefault("decimal_places", 6)
        kwargs.setdefault("required", False)

        super().__init__(*args, **kwargs)

    def get_value(self, data):
        """Extract the amount from the supplied data, ensuring it is a valid Decimal.

        Returns a Money instance if a currency is available,
        otherwise the (possibly empty / None) amount is returned as-is.
        """
        # Deliberately skip the DecimalField implementation in the MRO
        value = super(DecimalField, self).get_value(data)

        # A blank string is treated the same as "no value"
        if not str(value).strip():
            value = None

        has_value = value is not None and value is not empty

        if has_value:
            try:
                # Round to the maximum number of decimal places allowed for this field
                value = round(Decimal(value), self.decimal_places)
            except Exception:
                raise ValidationError({
                    self.field_name: [_("Must be a valid number")],
                })

        currency = data.get(get_currency_field_name(self.field_name), self.default_currency)

        if has_value and currency and not isinstance(value, MONEY_CLASSES):
            return Money(value, currency)

        return value
class InvenTreeCurrencySerializer(serializers.ChoiceField):
    """Choice field which presents the available currency codes."""

    def __init__(self, *args, **kwargs):
        """Populate choices, default value, label and help text, unless provided by the caller."""
        currency_choices = currency_code_mappings()

        # A nullable field is also treated as a "blankable" field
        allow_blank = kwargs.get('allow_blank', False) or kwargs.get('allow_null', False)

        if allow_blank:
            # Prepend an "empty" option
            currency_choices = [('', '---------')] + currency_choices

        kwargs['choices'] = currency_choices

        # Only supply a default if the caller specified neither 'default' nor 'required'
        if not ('default' in kwargs or 'required' in kwargs):
            if allow_blank:
                kwargs['default'] = ''
            else:
                kwargs['default'] = currency_code_default

        kwargs.setdefault('label', _('Currency'))
        kwargs.setdefault('help_text', _('Select currency from available options'))

        super().__init__(*args, **kwargs)
class InvenTreeModelSerializer(serializers.ModelSerializer):
    """ModelSerializer which also validates data against the underlying model class.

    In addition to the standard DRF behaviour, this class:

    - Fills in model-level default values for fields not supplied by the user
    - Runs 'full_clean' on a model instance during validation
    - Converts django ValidationError exceptions into DRF ValidationError exceptions
    """

    # Switch out URLField mapping
    serializer_field_mapping = {
        **serializers.ModelSerializer.serializer_field_mapping,
        models.URLField: InvenTreeRestURLField,
        InvenTreeURLField: InvenTreeRestURLField,
    }

    def __init__(self, instance=None, data=empty, **kwargs):
        """Ensure that *default* values (as specified in the ORM) are used, *if* not provided by the user.

        Defaults are only injected when a new instance is being created
        (no instance supplied) and data has been provided.
        """
        if instance is None and data is not empty:
            # Copy the provided data into a new mapping, so the caller's object is not modified
            new_data = OrderedDict()

            if data is not None:
                new_data.update(data)

            data = new_data

            # Add missing fields which have default values
            self._apply_model_defaults(data)

        super().__init__(instance, data, **kwargs)

    def _apply_model_defaults(self, values):
        """Insert model default values into 'values' (in place) for any field not already present.

        A field is updated IF (and ONLY IF):
        - The field has a specified default value
        - The field does not already have a value set
        """
        field_info = model_meta.get_field_info(self.Meta.model)

        for field_name, field in field_info.fields.items():
            if not field.has_default() or field_name in values:
                continue

            value = field.default

            # Account for callable functions
            if callable(value):
                try:
                    value = value()
                except Exception:
                    # Best effort only: skip a default which cannot be evaluated
                    continue

            values[field_name] = value

    def get_initial(self):
        """Construct initial data for the serializer.

        When creating a new instance, the 'default' values specified
        by the django model definition are added to the initial data.
        """
        initials = super().get_initial().copy()

        # Are we creating a new instance?
        if self.instance is None:
            self._apply_model_defaults(initials)

        return initials

    def skip_create_fields(self):
        """Return a list of 'fields' which should be skipped for model creation.

        This is used to 'bypass' a shortcoming of the DRF framework,
        which does not allow us to have writeable serializer fields which do not exist on the model.

        Default implementation returns an empty list
        """
        return []

    def save(self, **kwargs):
        """Save the instance, re-throwing any ValidationError raised by `save` as a DRF ValidationError.

        Returns the saved instance.
        """
        try:
            super().save(**kwargs)
        except (ValidationError, DjangoValidationError) as exc:
            raise ValidationError(detail=serializers.as_serializer_error(exc)) from exc

        return self.instance

    def create(self, validated_data):
        """Create a new instance, after removing the fields listed by skip_create_fields()."""
        initial_data = validated_data.copy()

        # Remove any fields which do not exist on the model
        for field in self.skip_create_fields():
            initial_data.pop(field, None)

        return super().create(initial_data)

    def update(self, instance, validated_data):
        """Update the instance, re-throwing any django ValidationError as a DRF ValidationError."""
        try:
            instance = super().update(instance, validated_data)
        except (ValidationError, DjangoValidationError) as exc:
            raise ValidationError(detail=serializers.as_serializer_error(exc)) from exc

        return instance

    def run_validation(self, data=empty):
        """Perform serializer validation.

        In addition to running validators on the serializer fields,
        this class ensures that the underlying model is also validated.

        Returns the validated data.
        Raises a DRF ValidationError if either the serializer or the model validation fails.
        """
        # Run any native validation checks first (may raise a ValidationError)
        data = super().run_validation(data)

        if not hasattr(self, 'instance') or self.instance is None:
            # No instance exists (we are creating a new one)
            initial_data = data.copy()

            for field in self.skip_create_fields():
                # Remove any fields we do not wish to provide to the model
                initial_data.pop(field, None)

            # Create a (RAM only) instance for extra testing
            instance = self.Meta.model(**initial_data)
        else:
            # Instance already exists (we are updating!)
            instance = self.instance

            # Update instance fields
            for attr, value in data.items():
                try:
                    setattr(instance, attr, value)
                except (ValidationError, DjangoValidationError) as exc:
                    raise ValidationError(detail=serializers.as_serializer_error(exc)) from exc

        # Run a 'full_clean' on the model.
        # Note that by default, DRF does *not* perform full model validation!
        try:
            instance.full_clean()
        except DjangoValidationError as exc:
            errors = exc.message_dict

            # Change '__all__' key (django style) to 'non_field_errors' (DRF style)
            if '__all__' in errors:
                errors['non_field_errors'] = errors.pop('__all__')

            raise ValidationError(errors) from exc
        except ValidationError as exc:
            # A DRF ValidationError has no 'message_dict' attribute,
            # so it must be converted separately
            raise ValidationError(detail=serializers.as_serializer_error(exc)) from exc

        return data
class InvenTreeTaggitSerializer(TaggitSerializer):
    """Updated from https://github.com/glemmaPaul/django-taggit-serializer."""

    def update(self, instance, validated_data):
        """Overridden update method, which re-adds the tag manager after the update.

        The tag values are removed from the validated data before the update is performed,
        and saved against the updated object afterwards.
        """
        tags, remaining_data = self._pop_tags(validated_data)

        updated = super().update(instance, remaining_data)
        model_class = updated.__class__

        for tag_field in tags.keys():
            # Re-add the tag manager, taken from a freshly fetched copy of the object
            fresh_copy = model_class.objects.get(id=updated.id)
            setattr(updated, tag_field, getattr(fresh_copy, tag_field))

        return self._save_tags(updated, tags)
class InvenTreeTagModelSerializer(InvenTreeTaggitSerializer, InvenTreeModelSerializer):
    """Model serializer with tag support.

    Combines InvenTreeTaggitSerializer and InvenTreeModelSerializer, and adds nothing further.
    """
class UserSerializer(InvenTreeModelSerializer):
    """Serializer for the django User model."""

    class Meta:
        """Serializer options: the target model, and the fields which are exposed."""

        model = User
        fields = ['pk', 'username', 'first_name', 'last_name', 'email']
class InvenTreeAttachmentSerializerField(serializers.FileField):
    """Override the DRF native FileField serializer, to remove the leading server path.

    For example, the FileField might supply something like:

    http://127.0.0.1:8000/media/foo/bar.jpg

    Whereas we wish to return:

    /media/foo/bar.jpg

    If the server process is serving the data at 127.0.0.1,
    but a proxy service (e.g. nginx) is then providing DNS lookup to the outside world,
    then an attachment which prefixes the "address" of the internal server
    will not be accessible from the outside world.
    """

    def to_representation(self, value):
        """Return the file location joined onto MEDIA_URL, or None if there is no file."""
        if value:
            media_root = str(settings.MEDIA_URL)
            return os.path.join(media_root, str(value))

        return None
class InvenTreeAttachmentSerializer(InvenTreeModelSerializer):
    """Special case of an InvenTreeModelSerializer, which handles an "attachment" model.

    Compared to the base serializer, this adds a 'filename' field
    (sourced from 'basename') to support "renaming" of the attachment file.
    """

    @staticmethod
    def attachment_fields(extra_fields=None):
        """Return the standard attachment field names, followed by any extra fields provided."""
        field_names = [
            'pk',
            'attachment',
            'filename',
            'link',
            'comment',
            'upload_date',
            'user',
            'user_detail',
        ]

        if extra_fields:
            field_names.extend(extra_fields)

        return field_names

    # Read-only details of the associated user
    user_detail = UserSerializer(source='user', read_only=True, many=False)

    attachment = InvenTreeAttachmentSerializerField(required=False, allow_null=False)

    # The 'filename' field must be present in the serializer
    filename = serializers.CharField(
        label=_('Filename'),
        required=False,
        source='basename',
        allow_blank=False,
    )

    upload_date = serializers.DateField(read_only=True)
class InvenTreeImageSerializerField(serializers.ImageField):
    """Custom image serializer.

    On upload, validate that the file is a valid image file
    """

    def to_representation(self, value):
        """Return the image location joined onto MEDIA_URL, or None if there is no image."""
        if value:
            media_root = str(settings.MEDIA_URL)
            return os.path.join(media_root, str(value))

        return None
class InvenTreeDecimalField(serializers.FloatField):
    """Custom serializer for decimal fields.

    Solves the following issues:
    - The normal DRF DecimalField renders values with trailing zeros
    - Using a FloatField can result in rounding issues: https://code.djangoproject.com/ticket/30290
    """

    def to_internal_value(self, data):
        """Convert the incoming value to a Decimal, via its string representation.

        Raises a ValidationError if the conversion fails for any reason.
        """
        try:
            converted = Decimal(str(data))
        except Exception:
            raise serializers.ValidationError(_("Invalid value"))

        return converted
class DataFileUploadSerializer(serializers.Serializer):
    """Generic serializer for uploading a data file, and extracting a dataset.

    - Validates uploaded file
    - Extracts column names
    - Extracts data rows
    """

    # Implementing class should register a target model (database model) to be used for import
    TARGET_MODEL = None

    class Meta:
        """Metaclass options."""

        fields = [
            'data_file',
        ]

    data_file = serializers.FileField(
        label=_("Data File"),
        help_text=_("Select data file for upload"),
        required=True,
        allow_empty_file=False,
    )

    def validate_data_file(self, data_file):
        """Perform validation checks on the uploaded data file.

        On success, the file name is stored as 'self.filename',
        and the loaded tablib dataset is stored as 'self.dataset'.

        Returns the uploaded file.
        Raises a ValidationError if the file type is not supported, the file is too large,
        the file cannot be read / decoded / parsed, or no columns or rows are found.
        """
        self.filename = data_file.name

        # Extract the extension (without the leading '.')
        # The extension is lower-cased so that e.g. 'DATA.CSV' is accepted
        ext = os.path.splitext(data_file.name)[1][1:].lower()

        accepted_file_types = [
            'xls', 'xlsx',
            'csv', 'tsv',
            'xml',
        ]

        if ext not in accepted_file_types:
            raise serializers.ValidationError(_("Unsupported file type"))

        # Impose a 50MB limit on uploaded BOM files
        max_upload_file_size = 50 * 1024 * 1024

        if data_file.size > max_upload_file_size:
            raise serializers.ValidationError(_("File is too large"))

        # Read file data into memory (bytes object)
        try:
            data = data_file.read()
        except Exception as e:
            raise serializers.ValidationError(str(e)) from e

        if ext in ['csv', 'tsv', 'xml']:
            # Text formats are decoded to a string.
            # 'utf-8-sig' decodes UTF-8, and additionally strips a leading byte-order-mark
            # (which would otherwise become part of the first column name)
            try:
                data = data.decode('utf-8-sig')
            except Exception as e:
                raise serializers.ValidationError(str(e)) from e

        # Convert to a tablib dataset (we expect headers)
        try:
            self.dataset = tablib.Dataset().load(data, ext, headers=True)
        except Exception as e:
            raise serializers.ValidationError(str(e)) from e

        # Truthiness check covers both "no headers" cases: an empty list, or None
        if not self.dataset.headers:
            raise serializers.ValidationError(_("No columns found in file"))

        if len(self.dataset) == 0:
            raise serializers.ValidationError(_("No data rows found in file"))

        return data_file

    def match_column(self, column_name, field_names, exact=False):
        """Attempt to match a column name (from the file) to a field (defined in the model).

        Order of matching is:
        - Direct match
        - Case insensitive match
        - Fuzzy match (not yet implemented)

        Returns the matching field name, or None if no match is found.
        """
        if not column_name:
            return None

        column_name = str(column_name).strip()
        column_name_lower = column_name.lower()

        if column_name in field_names:
            return column_name

        for field_name in field_names:
            if field_name.lower() == column_name_lower:
                return field_name

        if exact:
            # Finished available 'exact' matches
            return None

        # TODO: Fuzzy pattern matching for column names

        # No matches found
        return None

    def extract_data(self):
        """Returns dataset extracted from the file.

        The returned dict contains:
        - 'file_fields': dict of file column name -> {'value': matched model field name (or None)}
        - 'model_fields': dict of available import fields for the target model
        - 'rows': the values of each data row
        - 'filename': name of the uploaded file

        Note: validate_data_file must have been run first, as it stores the dataset and filename.
        """
        # Provide a dict of available import fields for the model
        model_fields = {}

        # Keep track of columns we have already extracted
        matched_columns = set()

        if self.TARGET_MODEL:
            try:
                model_fields = self.TARGET_MODEL.get_import_fields()
            except Exception:
                # Best effort only: fall back to "no import fields"
                pass

        # Extract a list of valid model field names
        model_field_names = list(model_fields.keys())

        # Provide a dict of available columns from the dataset
        file_columns = {}

        for header in self.dataset.headers:
            column = {}

            # Attempt to "match" file columns to model fields
            match = self.match_column(header, model_field_names, exact=True)

            # Each model field may only be matched to a single file column
            if match is not None and match not in matched_columns:
                matched_columns.add(match)
                column['value'] = match
            else:
                column['value'] = None

            file_columns[header] = column

        return {
            'file_fields': file_columns,
            'model_fields': model_fields,
            'rows': [row.values() for row in self.dataset.dict],
            'filename': self.filename,
        }

    def save(self):
        """Empty overwrite for save."""
        ...
class DataFileExtractSerializer(serializers.Serializer):
    """Generic serializer for extracting data from an imported dataset.

    - User provides an array of matched headers
    - User provides an array of raw data rows
    """

    # Implementing class should register a target model (database model) to be used for import
    TARGET_MODEL = None

    class Meta:
        """Metaclass options."""

        fields = [
            'columns',
            'rows',
        ]

    # Mapping of columns
    columns = serializers.ListField(
        child=serializers.CharField(
            allow_blank=True,
        ),
    )

    rows = serializers.ListField(
        child=serializers.ListField(
            child=serializers.CharField(
                allow_blank=True,
                allow_null=True,
            ),
        )
    )

    def _get_model_fields(self):
        """Return the dict of import fields for the target model.

        Returns an empty dict if no target model is registered,
        or if the import fields cannot be retrieved.
        """
        if not self.TARGET_MODEL:
            return {}

        try:
            return self.TARGET_MODEL.get_import_fields()
        except Exception:
            # Best effort only: fall back to "no import fields"
            return {}

    def validate(self, data):
        """Clean data.

        Stores the provided columns and rows against the serializer,
        and raises a ValidationError if either is empty, or if the column mapping is not valid.
        """
        data = super().validate(data)

        self.columns = data.get('columns', [])
        self.rows = data.get('rows', [])

        if not self.rows:
            raise serializers.ValidationError(_("No data rows provided"))

        if not self.columns:
            raise serializers.ValidationError(_("No data columns supplied"))

        self.validate_extracted_columns()

        return data

    @property
    def data(self):
        """Returns current data.

        The returned dict contains the model import fields, the column mapping,
        and each (processed) row alongside its original values.
        """
        model_fields = self._get_model_fields()

        rows = []

        for row in self.rows:
            # Optionally pre-process each row, before sending back to the client
            processed_row = self.process_row(self.row_to_dict(row))

            # Rows for which no data is returned are not included
            if processed_row:
                rows.append({
                    "original": row,
                    "data": processed_row,
                })

        return {
            'fields': model_fields,
            'columns': self.columns,
            'rows': rows,
        }

    def process_row(self, row):
        """Process a 'row' of data, which is a mapped column:value dict.

        Returns either a mapped column:value dict, or None.

        If the function returns None (or an empty dict), the row is ignored!
        """
        # Default implementation simply returns the original row data
        return row

    def row_to_dict(self, row):
        """Convert a "row" to a named data dict.

        Values are keyed by the column name at the same position.
        Values without a (non-empty) column name are discarded.
        The returned dict always contains an (initially empty) 'errors' entry.
        """
        row_dict = {
            'errors': {},
        }

        # zip() stops at the shorter sequence, so surplus values are ignored
        for col, value in zip(self.columns, row):
            if col:
                row_dict[col] = value

        return row_dict

    def validate_extracted_columns(self):
        """Perform custom validation of header mapping.

        Raises a ValidationError if a required model field is not present in the columns,
        or if a (non-empty) column name is supplied more than once.
        """
        model_fields = self._get_model_fields()

        # Check for missing required columns
        for name, field in model_fields.items():
            if field.get('required', False) and name not in self.columns:
                # The message is interpolated *after* translation,
                # so that the translatable string is constant
                raise serializers.ValidationError(
                    _("Missing required column: '{name}'").format(name=name)
                )

        cols_seen = set()

        for col in self.columns:
            if not col:
                continue

            # Check for duplicated columns
            if col in cols_seen:
                raise serializers.ValidationError(
                    _("Duplicate column: '{col}'").format(col=col)
                )

            cols_seen.add(col)

    def save(self):
        """No "save" action for this serializer."""
        pass
class RemoteImageMixin(metaclass=serializers.SerializerMetaclass):
    """Mixin class which allows downloading an 'image' from a remote URL.

    Adds the optional, write-only `remote_image` field to the serializer
    """

    def skip_create_fields(self):
        """Exclude the 'remote_image' field when a new instance is created."""
        return ['remote_image']

    remote_image = serializers.URLField(
        required=False,
        allow_blank=False,
        write_only=True,
        label=_("URL"),
        help_text=_("URL of remote image file"),
    )

    def validate_remote_image(self, url):
        """Perform custom validation for the remote image URL.

        - Attempt to download the image and store it as 'self.remote_image_file'
        - Any download error is re-thrown as a ValidationError

        Returns the URL, or None if no URL was provided.
        """
        if not url:
            return None

        download_enabled = InvenTreeSetting.get_setting('INVENTREE_DOWNLOAD_FROM_URL')

        if not download_enabled:
            raise ValidationError(_("Downloading images from remote URL is not enabled"))

        try:
            image_file = download_image_from_url(url)
        except Exception as exc:
            # Clear any stored image before reporting the failure
            self.remote_image_file = None
            raise ValidationError(str(exc))

        self.remote_image_file = image_file

        return url