2
0
mirror of https://github.com/inventree/InvenTree.git synced 2025-04-29 12:06:44 +00:00
InvenTree/InvenTree/InvenTree/serializers.py

644 lines
18 KiB
Python

"""
Serializers used in various InvenTree apps
"""
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
import os
import tablib
from decimal import Decimal
from collections import OrderedDict
from django.conf import settings
from django.contrib.auth.models import User
from django.core.exceptions import ValidationError as DjangoValidationError
from django.utils.translation import gettext_lazy as _
from django.db import models
from djmoney.contrib.django_rest_framework.fields import MoneyField
from djmoney.money import Money
from djmoney.utils import MONEY_CLASSES, get_currency_field_name
from rest_framework import serializers
from rest_framework.utils import model_meta
from rest_framework.fields import empty
from rest_framework.exceptions import ValidationError
from rest_framework.serializers import DecimalField
from .models import extract_int
class InvenTreeMoneySerializer(MoneyField):
    """
    Custom serializer for 'MoneyField',
    which ensures that passed values are numerically valid

    Ref: https://github.com/django-money/django-money/blob/master/djmoney/contrib/django_rest_framework/fields.py
    """

    def __init__(self, *args, **kwargs):
        """
        Initialise the field, defaulting to 19 digits / 4 decimal places
        unless the caller supplies 'max_digits' / 'decimal_places'.
        """
        kwargs.setdefault("max_digits", 19)
        kwargs.setdefault("decimal_places", 4)

        super().__init__(*args, **kwargs)

    def get_value(self, data):
        """
        Test that the returned amount is a valid Decimal

        Returns:
            - None if the supplied amount is blank (or None)
            - the 'empty' sentinel, untouched, if that is what the parent lookup returned
            - a Money instance if a currency is available for the amount
            - otherwise the amount converted to a Decimal

        Raises:
            ValidationError (keyed on the field name) if the amount
            cannot be converted to a Decimal
        """
        # NOTE: super(DecimalField, self) starts the method lookup *after*
        # DecimalField in the MRO, i.e. DecimalField.get_value is bypassed
        amount = super(DecimalField, self).get_value(data)

        # Convert an empty string to None
        if not str(amount).strip():
            amount = None

        try:
            if amount is not None and amount is not empty:
                amount = Decimal(amount)
        except Exception:
            # Any conversion failure is reported against this field.
            # (Not a bare 'except', so KeyboardInterrupt / SystemExit still propagate)
            raise ValidationError({
                self.field_name: [_("Must be a valid number")],
            })

        # The currency is read from the companion currency field (if provided),
        # falling back to the default currency of this field
        currency = data.get(get_currency_field_name(self.field_name), self.default_currency)

        if currency and amount is not None and not isinstance(amount, MONEY_CLASSES) and amount is not empty:
            return Money(amount, currency)

        return amount
class UserSerializer(serializers.ModelSerializer):
    """ Serializer for User - provides all fields """

    class Meta:
        model = User

        # DRF only accepts a list / tuple of field names, or the special
        # value '__all__', for this option. The plain string 'all' is rejected
        # when the serializer fields are constructed.
        # NOTE(review): '__all__' exposes every column of the User model
        # (including the password hash) - confirm this is intended before
        # wiring this serializer to an API endpoint.
        fields = '__all__'
class UserSerializerBrief(serializers.ModelSerializer):
    """ Compact serializer for User - exposes only the primary key and username """

    class Meta:
        model = User

        # Minimal set of identifying fields
        fields = ['pk', 'username']
class InvenTreeModelSerializer(serializers.ModelSerializer):
    """
    Inherits the standard Django ModelSerializer class,
    but also ensures that the underlying model class data are checked on validation.
    """

    def __init__(self, instance=None, data=empty, **kwargs):
        """
        Custom __init__ routine to ensure that *default* values (as specified in the ORM)
        are used by the DRF serializers, *if* the values are not provided by the user.
        """

        # If instance is None, we are creating a new instance
        if instance is None and data is not empty:

            # Always work on a fresh OrderedDict, so that the
            # data supplied by the caller is never mutated
            new_data = OrderedDict()

            if data is not None:
                new_data.update(data)

            data = new_data

            # Add missing fields which have default values
            self._apply_model_defaults(data)

        super().__init__(instance, data, **kwargs)

    def _apply_model_defaults(self, values):
        """
        Fill 'values' (a mutable mapping) in-place with default values
        taken from the django model definition.

        A field is added IF (and ONLY IF):
        - The field has a specified default value
        - The field does not already have a value set
        """

        field_info = model_meta.get_field_info(self.Meta.model)

        for field_name, field in field_info.fields.items():

            if not field.has_default() or field_name in values:
                continue

            value = field.default

            # Account for callable functions
            if callable(value):
                try:
                    value = value()
                except Exception:
                    # Best-effort: a default which cannot be evaluated is skipped
                    continue

            values[field_name] = value

    def get_initial(self):
        """
        Construct initial data for the serializer.
        Use the 'default' values specified by the django model definition
        """

        initials = super().get_initial().copy()

        # Are we creating a new instance?
        if self.instance is None:
            self._apply_model_defaults(initials)

        return initials

    def save(self, **kwargs):
        """
        Catch any django ValidationError thrown at the moment save() is called,
        and re-throw as a DRF ValidationError
        """

        try:
            super().save(**kwargs)
        except (ValidationError, DjangoValidationError) as exc:
            raise ValidationError(detail=serializers.as_serializer_error(exc))

        return self.instance

    def update(self, instance, validated_data):
        """
        Catch any django ValidationError, and re-throw as a DRF ValidationError
        """

        try:
            instance = super().update(instance, validated_data)
        except (ValidationError, DjangoValidationError) as exc:
            raise ValidationError(detail=serializers.as_serializer_error(exc))

        return instance

    def run_validation(self, data=empty):
        """
        Perform serializer validation.
        In addition to running validators on the serializer fields,
        this class ensures that the underlying model is also validated.

        Returns the validated data.
        Raises a DRF ValidationError if either the field validation
        or the model validation fails.
        """

        # Run any native validation checks first (may raise a ValidationError)
        data = super().run_validation(data)

        # Now ensure the underlying model is correct
        if not hasattr(self, 'instance') or self.instance is None:
            # No instance exists (we are creating a new one)
            instance = self.Meta.model(**data)
        else:
            # Instance already exists (we are updating!)
            instance = self.instance

            # Update instance fields
            for attr, value in data.items():
                try:
                    setattr(instance, attr, value)
                except (ValidationError, DjangoValidationError) as exc:
                    raise ValidationError(detail=serializers.as_serializer_error(exc))

        # Run a 'full_clean' on the model.
        # Note that by default, DRF does *not* perform full model validation!
        try:
            instance.full_clean()
        except DjangoValidationError as exc:

            # 'message_dict' is only available when the error was constructed
            # from a dict; otherwise report the messages as non-field errors
            if hasattr(exc, 'error_dict'):
                errors = exc.message_dict
            else:
                errors = {'__all__': exc.messages}

            # Change '__all__' key (django style) to 'non_field_errors' (DRF style)
            if '__all__' in errors:
                errors['non_field_errors'] = errors.pop('__all__')

            raise ValidationError(errors)
        except ValidationError as exc:
            # A DRF ValidationError has no 'message_dict' attribute:
            # normalize it in the same way as save() / update()
            raise ValidationError(detail=serializers.as_serializer_error(exc))

        return data
class ReferenceIndexingSerializerMixin():
    """
    This serializer mixin ensures that the reference is not too big
    for the BigIntegerField

    Note: only the upper bound (BigIntegerField.MAX_BIGINT) is checked here.
    """

    def validate_reference(self, value):
        """
        Validate the 'reference' field.

        Raises a ValidationError if extract_int(value) is larger than
        the maximum value of a BigIntegerField, otherwise returns
        the value unchanged.
        """

        if extract_int(value) > models.BigIntegerField.MAX_BIGINT:
            raise serializers.ValidationError('reference is too big')

        return value
class InvenTreeAttachmentSerializerField(serializers.FileField):
    """
    Replacement for the DRF native FileField serializer,
    which returns the file location relative to MEDIA_URL
    (i.e. without the leading server address).

    For example, rather than returning something like:

        http://127.0.0.1:8000/media/foo/bar.jpg

    this field returns:

        /media/foo/bar.jpg

    Rationale: if the server process is serving the data at 127.0.0.1,
    but a proxy service (e.g. nginx) is then providing DNS lookup to the outside world,
    then an attachment which prefixes the "address" of the internal server
    will not be accessible from the outside world.
    """

    def to_representation(self, value):
        """
        Return the attachment name joined onto MEDIA_URL,
        or None if no file is set.
        """

        if value:
            media_root = str(settings.MEDIA_URL)
            return os.path.join(media_root, str(value))

        return None
class InvenTreeAttachmentSerializer(InvenTreeModelSerializer):
    """
    InvenTreeModelSerializer specialised for "attachment" models.

    Compared to the base class, this adds an optional 'filename' field
    (mapped to the 'basename' attribute), to support "renaming" of the attachment file.
    """

    # The attached file itself (optional, but may not be null if provided)
    attachment = InvenTreeAttachmentSerializerField(required=False, allow_null=False)

    # The 'filename' field must be present in the serializer
    filename = serializers.CharField(
        source='basename',
        label=_('Filename'),
        required=False,
        allow_blank=False,
    )
class InvenTreeImageSerializerField(serializers.ImageField):
    """
    Custom image serializer field, based on the DRF ImageField.

    The only override is the output representation:
    the image name is returned joined onto MEDIA_URL.
    """

    def to_representation(self, value):
        """
        Return the image name joined onto MEDIA_URL,
        or None if no image is set.
        """

        return os.path.join(str(settings.MEDIA_URL), str(value)) if value else None
class InvenTreeDecimalField(serializers.FloatField):
    """
    Custom serializer for decimal fields. Solves the following issues:

    - The normal DRF DecimalField renders values with trailing zeros
    - Using a FloatField can result in rounding issues: https://code.djangoproject.com/ticket/30290
    """

    def to_internal_value(self, data):
        """
        Convert the supplied value to a Decimal (via its string representation).

        Raises a ValidationError if the value cannot be converted.
        """

        # Convert the value to a string, and then a decimal
        try:
            return Decimal(str(data))
        except Exception:
            # Not a bare 'except', so KeyboardInterrupt / SystemExit still propagate
            raise serializers.ValidationError(_("Invalid value"))
class DataFileUploadSerializer(serializers.Serializer):
    """
    Generic serializer for uploading a data file, and extracting a dataset.

    - Validates uploaded file
    - Extracts column names
    - Extracts data rows
    """

    # Implementing class should register a target model (database model) to be used for import
    TARGET_MODEL = None

    class Meta:
        fields = [
            'data_file',
        ]

    data_file = serializers.FileField(
        label=_("Data File"),
        help_text=_("Select data file for upload"),
        required=True,
        allow_empty_file=False,
    )

    def validate_data_file(self, data_file):
        """
        Perform validation checks on the uploaded data file.

        On success, the file name is stored in self.filename and the parsed
        tablib dataset in self.dataset, and the uploaded file is returned.

        Raises a ValidationError if the file type is unsupported, the file is
        too large, cannot be read / decoded / parsed, or contains no columns or rows.
        """

        self.filename = data_file.name

        # Extract the extension, without the leading '.'
        # Lower-cased, so that e.g. 'DATA.CSV' is accepted just like 'data.csv'
        ext = os.path.splitext(data_file.name)[1][1:].lower()

        accepted_file_types = [
            'xls', 'xlsx',
            'csv', 'tsv',
            'xml',
        ]

        if ext not in accepted_file_types:
            raise serializers.ValidationError(_("Unsupported file type"))

        # Impose a 50MB limit on uploaded data files
        max_upload_file_size = 50 * 1024 * 1024

        if data_file.size > max_upload_file_size:
            raise serializers.ValidationError(_("File is too large"))

        # Read file data into memory (bytes object)
        try:
            data = data_file.read()
        except Exception as e:
            raise serializers.ValidationError(str(e))

        # Text based formats must be decoded to a string
        if ext in ['csv', 'tsv', 'xml']:
            try:
                # 'utf-8-sig' decodes plain UTF-8, and additionally strips a leading
                # byte-order-mark (if present), which would otherwise end up
                # glued to the name of the first column
                data = data.decode('utf-8-sig')
            except Exception as e:
                raise serializers.ValidationError(str(e))

        # Convert to a tablib dataset (we expect headers)
        try:
            self.dataset = tablib.Dataset().load(data, ext, headers=True)
        except Exception as e:
            raise serializers.ValidationError(str(e))

        if len(self.dataset.headers) == 0:
            raise serializers.ValidationError(_("No columns found in file"))

        if len(self.dataset) == 0:
            raise serializers.ValidationError(_("No data rows found in file"))

        return data_file

    def match_column(self, column_name, field_names, exact=False):
        """
        Attempt to match a column name (from the file) to a field (defined in the model)

        Order of matching is:
        - Direct match
        - Case insensitive match
        - Fuzzy match (not yet implemented)

        Returns the matching field name, or None if no match is found.
        """

        if not column_name:
            return None

        column_name = str(column_name).strip()

        # Direct match
        if column_name in field_names:
            return column_name

        # Case insensitive match
        column_name_lower = column_name.lower()

        for field_name in field_names:
            if field_name.lower() == column_name_lower:
                return field_name

        if exact:
            # Finished available 'exact' matches
            return None

        # TODO: Fuzzy pattern matching for column names

        # No matches found
        return None

    def extract_data(self):
        """
        Returns dataset extracted from the file

        The returned dict contains:
        - 'file_fields': maps each file header to {'value': <matched model field or None>}
        - 'model_fields': the import fields reported by TARGET_MODEL (empty if unavailable)
        - 'rows': the values of each data row
        - 'filename': the name of the uploaded file
        """

        # Provide a dict of available import fields for the model
        model_fields = {}

        if self.TARGET_MODEL:
            try:
                model_fields = self.TARGET_MODEL.get_import_fields()
            except Exception:
                # Best-effort: continue without any importable fields
                model_fields = {}

        # Extract a list of valid model field names
        model_field_names = list(model_fields.keys())

        # Keep track of columns we have already extracted
        matched_columns = set()

        # Provide a dict of available columns from the dataset
        file_columns = {}

        for header in self.dataset.headers:

            # Attempt to "match" file columns to model fields
            match = self.match_column(header, model_field_names, exact=True)

            if match is not None and match not in matched_columns:
                matched_columns.add(match)
            else:
                # No match, or this model field was already claimed by an earlier column
                match = None

            file_columns[header] = {'value': match}

        return {
            'file_fields': file_columns,
            'model_fields': model_fields,
            'rows': [row.values() for row in self.dataset.dict],
            'filename': self.filename,
        }

    def save(self):
        """
        No "save" action for this serializer
        """
        ...
class DataFileExtractSerializer(serializers.Serializer):
    """
    Generic serializer for extracting data from an imported dataset.

    - User provides an array of matched headers
    - User provides an array of raw data rows
    """

    # Implementing class should register a target model (database model) to be used for import
    TARGET_MODEL = None

    class Meta:
        fields = [
            'columns',
            'rows',
        ]

    # Mapping of columns
    columns = serializers.ListField(
        child=serializers.CharField(
            allow_blank=True,
        ),
    )

    rows = serializers.ListField(
        child=serializers.ListField(
            child=serializers.CharField(
                allow_blank=True,
                allow_null=True,
            ),
        )
    )

    def _get_import_fields(self):
        """
        Return the import fields reported by TARGET_MODEL,
        or an empty dict if no model is registered or the lookup fails.
        """

        if not self.TARGET_MODEL:
            return {}

        try:
            return self.TARGET_MODEL.get_import_fields()
        except Exception:
            # Best-effort: continue without any importable fields
            return {}

    def validate(self, data):
        """
        Validate the supplied columns and rows.

        Stores the values in self.columns / self.rows.
        Raises a ValidationError if either is empty,
        or if the column mapping is invalid.
        """

        data = super().validate(data)

        self.columns = data.get('columns', [])
        self.rows = data.get('rows', [])

        if len(self.rows) == 0:
            raise serializers.ValidationError(_("No data rows provided"))

        if len(self.columns) == 0:
            raise serializers.ValidationError(_("No data columns supplied"))

        self.validate_extracted_columns()

        return data

    @property
    def data(self):
        """
        Return the model fields, the column mapping, and the processed rows.

        Each returned row contains the "original" row values,
        and the processed "data" dict. Rows for which process_row()
        returns a falsy value are omitted.
        """

        # Falls back to an empty dict if no TARGET_MODEL is registered
        model_fields = self._get_import_fields()

        rows = []

        for row in self.rows:
            # Optionally pre-process each row, before sending back to the client
            processed_row = self.process_row(self.row_to_dict(row))

            if processed_row:
                rows.append({
                    "original": row,
                    "data": processed_row,
                })

        return {
            'fields': model_fields,
            'columns': self.columns,
            'rows': rows,
        }

    def process_row(self, row):
        """
        Process a 'row' of data, which is a mapped column:value dict

        Returns either a mapped column:value dict, or None.

        If the function returns None, the row is ignored!
        """

        # Default implementation simply returns the original row data
        return row

    def row_to_dict(self, row):
        """
        Convert a "row" to a named data dict

        Values are keyed by the column name at the same index.
        Values without a (non-blank) column name are discarded.
        The returned dict always contains an (initially empty) 'errors' dict.
        """

        row_dict = {
            'errors': {},
        }

        for idx, value in enumerate(row):

            if idx < len(self.columns):
                col = self.columns[idx]

                if col:
                    row_dict[col] = value

        return row_dict

    def validate_extracted_columns(self):
        """
        Perform custom validation of header mapping.

        Raises a ValidationError if a required model field is missing
        from the columns, or if a column name is supplied more than once.
        """

        # Falls back to an empty dict if no TARGET_MODEL is registered
        model_fields = self._get_import_fields()

        # Check for missing required columns
        for name, field in model_fields.items():

            required = field.get('required', False)

            if required and name not in self.columns:
                # Interpolate *after* the translation lookup, so that
                # the message template itself can be translated
                raise serializers.ValidationError(
                    _("Missing required column: '{name}'").format(name=name)
                )

        cols_seen = set()

        for col in self.columns:

            # Blank column names are ignored
            if not col:
                continue

            # Check for duplicated columns
            if col in cols_seen:
                raise serializers.ValidationError(
                    _("Duplicate column: '{col}'").format(col=col)
                )

            cols_seen.add(col)

    def save(self):
        """
        No "save" action for this serializer
        """
        ...