2
0
mirror of https://github.com/inventree/InvenTree.git synced 2025-06-18 13:05:42 +00:00

Add bleach (#41) (#3204)

* use shims for API view inheritation

* Add mixin for input sanitation

* fix clean operation to fix all string values

* Also clean up dicts
this is to future-proof this function

* Update docstirng

* proof custom methods against XSS through authenticated users
This commit is contained in:
Matthias Mair
2022-06-16 02:01:53 +02:00
committed by GitHub
parent f8a2760955
commit e83995b4f5
12 changed files with 310 additions and 178 deletions

View File

@ -14,7 +14,7 @@ from django_filters.rest_framework import DjangoFilterBackend
from djmoney.contrib.exchange.exceptions import MissingRate
from djmoney.contrib.exchange.models import convert_money
from djmoney.money import Money
from rest_framework import filters, generics, serializers, status
from rest_framework import filters, serializers, status
from rest_framework.exceptions import ValidationError
from rest_framework.response import Response
@ -25,6 +25,9 @@ from company.models import Company, ManufacturerPart, SupplierPart
from InvenTree.api import (APIDownloadMixin, AttachmentMixin,
ListCreateDestroyAPIView)
from InvenTree.helpers import DownloadFile, increment, isNull, str2bool
from InvenTree.mixins import (CreateAPI, ListAPI, ListCreateAPI, RetrieveAPI,
RetrieveUpdateAPI, RetrieveUpdateDestroyAPI,
UpdateAPI)
from InvenTree.status_codes import (BuildStatus, PurchaseOrderStatus,
SalesOrderStatus)
from part.admin import PartResource
@ -39,7 +42,7 @@ from .models import (BomItem, BomItemSubstitute, Part, PartAttachment,
PartTestTemplate)
class CategoryList(generics.ListCreateAPIView):
class CategoryList(ListCreateAPI):
"""API endpoint for accessing a list of PartCategory objects.
- GET: Return a list of PartCategory objects
@ -155,7 +158,7 @@ class CategoryList(generics.ListCreateAPIView):
]
class CategoryDetail(generics.RetrieveUpdateDestroyAPIView):
class CategoryDetail(RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a single PartCategory object."""
serializer_class = part_serializers.CategorySerializer
@ -175,8 +178,11 @@ class CategoryDetail(generics.RetrieveUpdateDestroyAPIView):
def update(self, request, *args, **kwargs):
"""Perform 'update' function and mark this part as 'starred' (or not)"""
if 'starred' in request.data:
starred = str2bool(request.data.get('starred', False))
# Clean up input data
data = self.clean_data(request.data)
if 'starred' in data:
starred = str2bool(data.get('starred', False))
self.get_object().set_starred(request.user, starred)
@ -185,7 +191,7 @@ class CategoryDetail(generics.RetrieveUpdateDestroyAPIView):
return response
class CategoryMetadata(generics.RetrieveUpdateAPIView):
class CategoryMetadata(RetrieveUpdateAPI):
"""API endpoint for viewing / updating PartCategory metadata."""
def get_serializer(self, *args, **kwargs):
@ -195,7 +201,7 @@ class CategoryMetadata(generics.RetrieveUpdateAPIView):
queryset = PartCategory.objects.all()
class CategoryParameterList(generics.ListCreateAPIView):
class CategoryParameterList(ListCreateAPI):
"""API endpoint for accessing a list of PartCategoryParameterTemplate objects.
- GET: Return a list of PartCategoryParameterTemplate objects
@ -236,14 +242,14 @@ class CategoryParameterList(generics.ListCreateAPIView):
return queryset
class CategoryParameterDetail(generics.RetrieveUpdateDestroyAPIView):
class CategoryParameterDetail(RetrieveUpdateDestroyAPI):
"""Detail endpoint fro the PartCategoryParameterTemplate model"""
queryset = PartCategoryParameterTemplate.objects.all()
serializer_class = part_serializers.CategoryParameterTemplateSerializer
class CategoryTree(generics.ListAPIView):
class CategoryTree(ListAPI):
"""API endpoint for accessing a list of PartCategory objects ready for rendering a tree."""
queryset = PartCategory.objects.all()
@ -258,14 +264,14 @@ class CategoryTree(generics.ListAPIView):
ordering = ['level', 'name']
class PartSalePriceDetail(generics.RetrieveUpdateDestroyAPIView):
class PartSalePriceDetail(RetrieveUpdateDestroyAPI):
"""Detail endpoint for PartSellPriceBreak model."""
queryset = PartSellPriceBreak.objects.all()
serializer_class = part_serializers.PartSalePriceSerializer
class PartSalePriceList(generics.ListCreateAPIView):
class PartSalePriceList(ListCreateAPI):
"""API endpoint for list view of PartSalePriceBreak model."""
queryset = PartSellPriceBreak.objects.all()
@ -280,14 +286,14 @@ class PartSalePriceList(generics.ListCreateAPIView):
]
class PartInternalPriceDetail(generics.RetrieveUpdateDestroyAPIView):
class PartInternalPriceDetail(RetrieveUpdateDestroyAPI):
"""Detail endpoint for PartInternalPriceBreak model."""
queryset = PartInternalPriceBreak.objects.all()
serializer_class = part_serializers.PartInternalPriceSerializer
class PartInternalPriceList(generics.ListCreateAPIView):
class PartInternalPriceList(ListCreateAPI):
"""API endpoint for list view of PartInternalPriceBreak model."""
queryset = PartInternalPriceBreak.objects.all()
@ -318,21 +324,21 @@ class PartAttachmentList(AttachmentMixin, ListCreateDestroyAPIView):
]
class PartAttachmentDetail(AttachmentMixin, generics.RetrieveUpdateDestroyAPIView):
class PartAttachmentDetail(AttachmentMixin, RetrieveUpdateDestroyAPI):
"""Detail endpoint for PartAttachment model."""
queryset = PartAttachment.objects.all()
serializer_class = part_serializers.PartAttachmentSerializer
class PartTestTemplateDetail(generics.RetrieveUpdateDestroyAPIView):
class PartTestTemplateDetail(RetrieveUpdateDestroyAPI):
"""Detail endpoint for PartTestTemplate model."""
queryset = PartTestTemplate.objects.all()
serializer_class = part_serializers.PartTestTemplateSerializer
class PartTestTemplateList(generics.ListCreateAPIView):
class PartTestTemplateList(ListCreateAPI):
"""API endpoint for listing (and creating) a PartTestTemplate."""
queryset = PartTestTemplate.objects.all()
@ -372,7 +378,7 @@ class PartTestTemplateList(generics.ListCreateAPIView):
]
class PartThumbs(generics.ListAPIView):
class PartThumbs(ListAPI):
"""API endpoint for retrieving information on available Part thumbnails."""
queryset = Part.objects.all()
@ -415,7 +421,7 @@ class PartThumbs(generics.ListAPIView):
]
class PartThumbsUpdate(generics.RetrieveUpdateAPIView):
class PartThumbsUpdate(RetrieveUpdateAPI):
"""API endpoint for updating Part thumbnails."""
queryset = Part.objects.all()
@ -426,7 +432,7 @@ class PartThumbsUpdate(generics.RetrieveUpdateAPIView):
]
class PartScheduling(generics.RetrieveAPIView):
class PartScheduling(RetrieveAPI):
"""API endpoint for delivering "scheduling" information about a given part via the API.
Returns a chronologically ordered list about future "scheduled" events,
@ -560,7 +566,7 @@ class PartScheduling(generics.RetrieveAPIView):
return Response(schedule)
class PartMetadata(generics.RetrieveUpdateAPIView):
class PartMetadata(RetrieveUpdateAPI):
"""API endpoint for viewing / updating Part metadata."""
def get_serializer(self, *args, **kwargs):
@ -570,7 +576,7 @@ class PartMetadata(generics.RetrieveUpdateAPIView):
queryset = Part.objects.all()
class PartSerialNumberDetail(generics.RetrieveAPIView):
class PartSerialNumberDetail(RetrieveAPI):
"""API endpoint for returning extra serial number information about a particular part."""
queryset = Part.objects.all()
@ -595,7 +601,7 @@ class PartSerialNumberDetail(generics.RetrieveAPIView):
return Response(data)
class PartCopyBOM(generics.CreateAPIView):
class PartCopyBOM(CreateAPI):
"""API endpoint for duplicating a BOM."""
queryset = Part.objects.all()
@ -613,7 +619,7 @@ class PartCopyBOM(generics.CreateAPIView):
return ctx
class PartValidateBOM(generics.RetrieveUpdateAPIView):
class PartValidateBOM(RetrieveUpdateAPI):
"""API endpoint for 'validating' the BOM for a given Part."""
class BOMValidateSerializer(serializers.ModelSerializer):
@ -654,7 +660,10 @@ class PartValidateBOM(generics.RetrieveUpdateAPIView):
partial = kwargs.pop('partial', False)
serializer = self.get_serializer(part, data=request.data, partial=partial)
# Clean up input data before using it
data = self.clean_data(request.data)
serializer = self.get_serializer(part, data=data, partial=partial)
serializer.is_valid(raise_exception=True)
part.validate_bom(request.user)
@ -664,7 +673,7 @@ class PartValidateBOM(generics.RetrieveUpdateAPIView):
})
class PartDetail(generics.RetrieveUpdateDestroyAPIView):
class PartDetail(RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a single Part object."""
queryset = Part.objects.all()
@ -721,8 +730,11 @@ class PartDetail(generics.RetrieveUpdateDestroyAPIView):
- If the 'starred' field is provided, update the 'starred' status against current user
"""
if 'starred' in request.data:
starred = str2bool(request.data.get('starred', False))
# Clean input data
data = self.clean_data(request.data)
if 'starred' in data:
starred = str2bool(data.get('starred', False))
self.get_object().set_starred(request.user, starred)
@ -874,7 +886,7 @@ class PartFilter(rest_filters.FilterSet):
virtual = rest_filters.BooleanFilter()
class PartList(APIDownloadMixin, generics.ListCreateAPIView):
class PartList(APIDownloadMixin, ListCreateAPI):
"""API endpoint for accessing a list of Part objects.
- GET: Return list of objects
@ -1003,7 +1015,10 @@ class PartList(APIDownloadMixin, generics.ListCreateAPIView):
"""
# TODO: Unit tests for this function!
serializer = self.get_serializer(data=request.data)
# Clean up input data
data = self.clean_data(request.data)
serializer = self.get_serializer(data=data)
serializer.is_valid(raise_exception=True)
part = serializer.save()
@ -1011,23 +1026,23 @@ class PartList(APIDownloadMixin, generics.ListCreateAPIView):
# Optionally copy templates from category or parent category
copy_templates = {
'main': str2bool(request.data.get('copy_category_templates', False)),
'parent': str2bool(request.data.get('copy_parent_templates', False))
'main': str2bool(data.get('copy_category_templates', False)),
'parent': str2bool(data.get('copy_parent_templates', False))
}
part.save(**{'add_category_templates': copy_templates})
# Optionally copy data from another part (e.g. when duplicating)
copy_from = request.data.get('copy_from', None)
copy_from = data.get('copy_from', None)
if copy_from is not None:
try:
original = Part.objects.get(pk=copy_from)
copy_bom = str2bool(request.data.get('copy_bom', False))
copy_parameters = str2bool(request.data.get('copy_parameters', False))
copy_image = str2bool(request.data.get('copy_image', True))
copy_bom = str2bool(data.get('copy_bom', False))
copy_parameters = str2bool(data.get('copy_parameters', False))
copy_image = str2bool(data.get('copy_image', True))
# Copy image?
if copy_image:
@ -1046,12 +1061,12 @@ class PartList(APIDownloadMixin, generics.ListCreateAPIView):
pass
# Optionally create initial stock item
initial_stock = str2bool(request.data.get('initial_stock', False))
initial_stock = str2bool(data.get('initial_stock', False))
if initial_stock:
try:
initial_stock_quantity = Decimal(request.data.get('initial_stock_quantity', ''))
initial_stock_quantity = Decimal(data.get('initial_stock_quantity', ''))
if initial_stock_quantity <= 0:
raise ValidationError({
@ -1062,7 +1077,7 @@ class PartList(APIDownloadMixin, generics.ListCreateAPIView):
'initial_stock_quantity': [_('Must be a valid quantity')],
})
initial_stock_location = request.data.get('initial_stock_location', None)
initial_stock_location = data.get('initial_stock_location', None)
try:
initial_stock_location = StockLocation.objects.get(pk=initial_stock_location)
@ -1086,20 +1101,20 @@ class PartList(APIDownloadMixin, generics.ListCreateAPIView):
stock_item.save(user=request.user)
# Optionally add manufacturer / supplier data to the part
if part.purchaseable and str2bool(request.data.get('add_supplier_info', False)):
if part.purchaseable and str2bool(data.get('add_supplier_info', False)):
try:
manufacturer = Company.objects.get(pk=request.data.get('manufacturer', None))
manufacturer = Company.objects.get(pk=data.get('manufacturer', None))
except Exception:
manufacturer = None
try:
supplier = Company.objects.get(pk=request.data.get('supplier', None))
supplier = Company.objects.get(pk=data.get('supplier', None))
except Exception:
supplier = None
mpn = str(request.data.get('MPN', '')).strip()
sku = str(request.data.get('SKU', '')).strip()
mpn = str(data.get('MPN', '')).strip()
sku = str(data.get('SKU', '')).strip()
# Construct a manufacturer part
if manufacturer or mpn:
@ -1347,7 +1362,7 @@ class PartList(APIDownloadMixin, generics.ListCreateAPIView):
]
class PartRelatedList(generics.ListCreateAPIView):
class PartRelatedList(ListCreateAPI):
"""API endpoint for accessing a list of PartRelated objects."""
queryset = PartRelated.objects.all()
@ -1374,14 +1389,14 @@ class PartRelatedList(generics.ListCreateAPIView):
return queryset
class PartRelatedDetail(generics.RetrieveUpdateDestroyAPIView):
class PartRelatedDetail(RetrieveUpdateDestroyAPI):
"""API endpoint for accessing detail view of a PartRelated object."""
queryset = PartRelated.objects.all()
serializer_class = part_serializers.PartRelationSerializer
class PartParameterTemplateList(generics.ListCreateAPIView):
class PartParameterTemplateList(ListCreateAPI):
"""API endpoint for accessing a list of PartParameterTemplate objects.
- GET: Return list of PartParameterTemplate objects
@ -1441,14 +1456,14 @@ class PartParameterTemplateList(generics.ListCreateAPIView):
return queryset
class PartParameterTemplateDetail(generics.RetrieveUpdateDestroyAPIView):
class PartParameterTemplateDetail(RetrieveUpdateDestroyAPI):
"""API endpoint for accessing the detail view for a PartParameterTemplate object"""
queryset = PartParameterTemplate.objects.all()
serializer_class = part_serializers.PartParameterTemplateSerializer
class PartParameterList(generics.ListCreateAPIView):
class PartParameterList(ListCreateAPI):
"""API endpoint for accessing a list of PartParameter objects.
- GET: Return list of PartParameter objects
@ -1468,7 +1483,7 @@ class PartParameterList(generics.ListCreateAPIView):
]
class PartParameterDetail(generics.RetrieveUpdateDestroyAPIView):
class PartParameterDetail(RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a single PartParameter object."""
queryset = PartParameter.objects.all()
@ -1747,7 +1762,7 @@ class BomList(ListCreateDestroyAPIView):
]
class BomImportUpload(generics.CreateAPIView):
class BomImportUpload(CreateAPI):
"""API endpoint for uploading a complete Bill of Materials.
It is assumed that the BOM has been extracted from a file using the BomExtract endpoint.
@ -1758,7 +1773,10 @@ class BomImportUpload(generics.CreateAPIView):
def create(self, request, *args, **kwargs):
"""Custom create function to return the extracted data."""
serializer = self.get_serializer(data=request.data)
# Clean up input data
data = self.clean_data(request.data)
serializer = self.get_serializer(data=data)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
headers = self.get_success_headers(serializer.data)
@ -1768,21 +1786,21 @@ class BomImportUpload(generics.CreateAPIView):
return Response(data, status=status.HTTP_201_CREATED, headers=headers)
class BomImportExtract(generics.CreateAPIView):
class BomImportExtract(CreateAPI):
"""API endpoint for extracting BOM data from a BOM file."""
queryset = Part.objects.none()
serializer_class = part_serializers.BomImportExtractSerializer
class BomImportSubmit(generics.CreateAPIView):
class BomImportSubmit(CreateAPI):
"""API endpoint for submitting BOM data from a BOM file."""
queryset = BomItem.objects.none()
serializer_class = part_serializers.BomImportSubmitSerializer
class BomDetail(generics.RetrieveUpdateDestroyAPIView):
class BomDetail(RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a single BomItem object."""
queryset = BomItem.objects.all()
@ -1798,7 +1816,7 @@ class BomDetail(generics.RetrieveUpdateDestroyAPIView):
return queryset
class BomItemValidate(generics.UpdateAPIView):
class BomItemValidate(UpdateAPI):
"""API endpoint for validating a BomItem."""
class BomItemValidationSerializer(serializers.Serializer):
@ -1812,11 +1830,13 @@ class BomItemValidate(generics.UpdateAPIView):
"""Perform update request."""
partial = kwargs.pop('partial', False)
valid = request.data.get('valid', False)
# Clean up input data
data = self.clean_data(request.data)
valid = data.get('valid', False)
instance = self.get_object()
serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer = self.get_serializer(instance, data=data, partial=partial)
serializer.is_valid(raise_exception=True)
if type(instance) == BomItem:
@ -1825,7 +1845,7 @@ class BomItemValidate(generics.UpdateAPIView):
return Response(serializer.data)
class BomItemSubstituteList(generics.ListCreateAPIView):
class BomItemSubstituteList(ListCreateAPI):
"""API endpoint for accessing a list of BomItemSubstitute objects."""
serializer_class = part_serializers.BomItemSubstituteSerializer
@ -1843,7 +1863,7 @@ class BomItemSubstituteList(generics.ListCreateAPIView):
]
class BomItemSubstituteDetail(generics.RetrieveUpdateDestroyAPIView):
class BomItemSubstituteDetail(RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a single BomItemSubstitute object."""
queryset = BomItemSubstitute.objects.all()