From 55c8b73b0ae09c534eaafd351b6b775b87f70690 Mon Sep 17 00:00:00 2001 From: Oliver Date: Sat, 29 Oct 2022 14:18:19 +1100 Subject: [PATCH] Floating point API bug (#3877) * Add unit tests for internalpricebreak - Exposes an existing bug * Ensure that rounding-decimal and prices are rounded correctly - Force remove trailing digits / reduce precision --- InvenTree/InvenTree/fields.py | 33 +++++++++++----- InvenTree/InvenTree/serializers.py | 5 ++- InvenTree/part/test_api.py | 63 ++++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 11 deletions(-) diff --git a/InvenTree/InvenTree/fields.py b/InvenTree/InvenTree/fields.py index d754737cfa..a5ec354bdb 100644 --- a/InvenTree/InvenTree/fields.py +++ b/InvenTree/InvenTree/fields.py @@ -91,6 +91,18 @@ class InvenTreeModelMoneyField(ModelMoneyField): kwargs['form_class'] = InvenTreeMoneyField return super().formfield(**kwargs) + def to_python(self, value): + """Convert value to python type.""" + value = super().to_python(value) + return round_decimal(value, self.decimal_places) + + def prepare_value(self, value): + """Override the 'prepare_value' method, to remove trailing zeros when displaying. + + Why? It looks nice! + """ + return round_decimal(value, self.decimal_places, normalize=True) + class InvenTreeMoneyField(MoneyField): """Custom MoneyField for clean migrations while using dynamic currency settings.""" @@ -126,11 +138,16 @@ class DatePickerFormField(forms.DateField): ) -def round_decimal(value, places): +def round_decimal(value, places, normalize=False): """Round value to the specified number of places.""" - if value is not None: - # see https://docs.python.org/2/library/decimal.html#decimal.Decimal.quantize for options - return value.quantize(Decimal(10) ** -places) + + if type(value) in [Decimal, float]: + value = round(value, places) + + if normalize: + # Remove any trailing zeroes + value = InvenTree.helpers.normalize(value) + return value @@ -140,18 +157,14 @@ class RoundingDecimalFormField(forms.DecimalField): def to_python(self, value): """Convert value to python type.""" value = super().to_python(value) - value = round_decimal(value, self.decimal_places) - return value + return round_decimal(value, self.decimal_places) def prepare_value(self, value): """Override the 'prepare_value' method, to remove trailing zeros when displaying. Why? It looks nice! """ - if type(value) == Decimal: - return InvenTree.helpers.normalize(value) - else: - return value + return round_decimal(value, self.decimal_places, normalize=True) class RoundingDecimalField(models.DecimalField): diff --git a/InvenTree/InvenTree/serializers.py b/InvenTree/InvenTree/serializers.py index 385d2f4837..c97da05284 100644 --- a/InvenTree/InvenTree/serializers.py +++ b/InvenTree/InvenTree/serializers.py @@ -34,13 +34,14 @@ class InvenTreeMoneySerializer(MoneyField): def __init__(self, *args, **kwargs): """Overrite default values.""" kwargs["max_digits"] = kwargs.get("max_digits", 19) - kwargs["decimal_places"] = kwargs.get("decimal_places", 4) + self.decimal_places = kwargs["decimal_places"] = kwargs.get("decimal_places", 4) kwargs["required"] = kwargs.get("required", False) super().__init__(*args, **kwargs) def get_value(self, data): """Test that the returned amount is a valid Decimal.""" + amount = super(DecimalField, self).get_value(data) # Convert an empty string to None @@ -49,7 +50,9 @@ class InvenTreeMoneySerializer(MoneyField): try: if amount is not None and amount is not empty: + # Convert to a Decimal instance, and round to maximum allowed decimal places amount = Decimal(amount) + amount = round(amount, self.decimal_places) except Exception: raise ValidationError({ self.field_name: [_("Must be a valid number")], diff --git a/InvenTree/part/test_api.py b/InvenTree/part/test_api.py index 9555ed5092..635e4305ae 100644 --- a/InvenTree/part/test_api.py +++ b/InvenTree/part/test_api.py @@ -1,5 +1,6 @@ """Unit tests for the various part API endpoints""" +from decimal import Decimal from random import randint from django.urls import reverse @@ -2430,3 +2431,65 @@ class PartAttachmentTest(InvenTreeAPITestCase): self.assertEqual(data['part'], 1) self.assertEqual(data['link'], link) self.assertEqual(data['comment'], 'Hello world') + + +class PartInternalPriceBreakTest(InvenTreeAPITestCase): + """Unit tests for the PartInternalPrice API endpoints""" + + fixtures = [ + 'category', + 'part', + 'params', + 'location', + 'bom', + 'company', + 'test_templates', + 'manufacturer_part', + 'supplier_part', + 'order', + 'stock', + ] + + roles = [ + 'part.change', + 'part.add', + 'part.delete', + 'part_category.change', + 'part_category.add', + 'part_category.delete', + ] + + def test_create_price_breaks(self): + """Test we can create price breaks at various quantities""" + + url = reverse('api-part-internal-price-list') + + breaks = [ + (1.0, 101), + (1.1, 92.555555555), + (1.5, 90.999999999), + (1.756, 89), + (2, 86), + (25, 80) + ] + + for q, p in breaks: + data = self.post( + url, + { + 'part': 1, + 'quantity': q, + 'price': p, + }, + expected_code=201 + ).data + + self.assertEqual(data['part'], 1) + self.assertEqual( + round(Decimal(data['quantity']), 4), + round(Decimal(q), 4) + ) + self.assertEqual( + round(Decimal(data['price']), 4), + round(Decimal(p), 4) + )