2
0
mirror of https://github.com/inventree/InvenTree.git synced 2025-10-24 18:07:38 +00:00

Floating point API bug (#3877)

* Add unit tests for internalpricebreak

- Exposes an existing bug

* Ensure that rounding-decimal and prices are rounded correctly

- Force remove trailing digits / reduce precision
This commit is contained in:
Oliver
2022-10-29 14:18:19 +11:00
committed by GitHub
parent 5263ccdca3
commit 55c8b73b0a
3 changed files with 90 additions and 11 deletions

View File

@@ -91,6 +91,18 @@ class InvenTreeModelMoneyField(ModelMoneyField):
kwargs['form_class'] = InvenTreeMoneyField kwargs['form_class'] = InvenTreeMoneyField
return super().formfield(**kwargs) return super().formfield(**kwargs)
def to_python(self, value):
"""Convert value to python type."""
value = super().to_python(value)
return round_decimal(value, self.decimal_places)
def prepare_value(self, value):
"""Override the 'prepare_value' method, to remove trailing zeros when displaying.
Why? It looks nice!
"""
return round_decimal(value, self.decimal_places, normalize=True)
class InvenTreeMoneyField(MoneyField): class InvenTreeMoneyField(MoneyField):
"""Custom MoneyField for clean migrations while using dynamic currency settings.""" """Custom MoneyField for clean migrations while using dynamic currency settings."""
@@ -126,11 +138,16 @@ class DatePickerFormField(forms.DateField):
) )
def round_decimal(value, places): def round_decimal(value, places, normalize=False):
"""Round value to the specified number of places.""" """Round value to the specified number of places."""
if value is not None:
# see https://docs.python.org/2/library/decimal.html#decimal.Decimal.quantize for options if type(value) in [Decimal, float]:
return value.quantize(Decimal(10) ** -places) value = round(value, places)
if normalize:
# Remove any trailing zeroes
value = InvenTree.helpers.normalize(value)
return value return value
@@ -140,18 +157,14 @@ class RoundingDecimalFormField(forms.DecimalField):
def to_python(self, value): def to_python(self, value):
"""Convert value to python type.""" """Convert value to python type."""
value = super().to_python(value) value = super().to_python(value)
value = round_decimal(value, self.decimal_places) return round_decimal(value, self.decimal_places)
return value
def prepare_value(self, value): def prepare_value(self, value):
"""Override the 'prepare_value' method, to remove trailing zeros when displaying. """Override the 'prepare_value' method, to remove trailing zeros when displaying.
Why? It looks nice! Why? It looks nice!
""" """
if type(value) == Decimal: return round_decimal(value, self.decimal_places, normalize=True)
return InvenTree.helpers.normalize(value)
else:
return value
class RoundingDecimalField(models.DecimalField): class RoundingDecimalField(models.DecimalField):

View File

@@ -34,13 +34,14 @@ class InvenTreeMoneySerializer(MoneyField):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Overrite default values.""" """Overrite default values."""
kwargs["max_digits"] = kwargs.get("max_digits", 19) kwargs["max_digits"] = kwargs.get("max_digits", 19)
kwargs["decimal_places"] = kwargs.get("decimal_places", 4) self.decimal_places = kwargs["decimal_places"] = kwargs.get("decimal_places", 4)
kwargs["required"] = kwargs.get("required", False) kwargs["required"] = kwargs.get("required", False)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def get_value(self, data): def get_value(self, data):
"""Test that the returned amount is a valid Decimal.""" """Test that the returned amount is a valid Decimal."""
amount = super(DecimalField, self).get_value(data) amount = super(DecimalField, self).get_value(data)
# Convert an empty string to None # Convert an empty string to None
@@ -49,7 +50,9 @@ class InvenTreeMoneySerializer(MoneyField):
try: try:
if amount is not None and amount is not empty: if amount is not None and amount is not empty:
# Convert to a Decimal instance, and round to maximum allowed decimal places
amount = Decimal(amount) amount = Decimal(amount)
amount = round(amount, self.decimal_places)
except Exception: except Exception:
raise ValidationError({ raise ValidationError({
self.field_name: [_("Must be a valid number")], self.field_name: [_("Must be a valid number")],

View File

@@ -1,5 +1,6 @@
"""Unit tests for the various part API endpoints""" """Unit tests for the various part API endpoints"""
from decimal import Decimal
from random import randint from random import randint
from django.urls import reverse from django.urls import reverse
@@ -2430,3 +2431,65 @@ class PartAttachmentTest(InvenTreeAPITestCase):
self.assertEqual(data['part'], 1) self.assertEqual(data['part'], 1)
self.assertEqual(data['link'], link) self.assertEqual(data['link'], link)
self.assertEqual(data['comment'], 'Hello world') self.assertEqual(data['comment'], 'Hello world')
class PartInternalPriceBreakTest(InvenTreeAPITestCase):
"""Unit tests for the PartInternalPrice API endpoints"""
fixtures = [
'category',
'part',
'params',
'location',
'bom',
'company',
'test_templates',
'manufacturer_part',
'supplier_part',
'order',
'stock',
]
roles = [
'part.change',
'part.add',
'part.delete',
'part_category.change',
'part_category.add',
'part_category.delete',
]
def test_create_price_breaks(self):
"""Test we can create price breaks at various quantities"""
url = reverse('api-part-internal-price-list')
breaks = [
(1.0, 101),
(1.1, 92.555555555),
(1.5, 90.999999999),
(1.756, 89),
(2, 86),
(25, 80)
]
for q, p in breaks:
data = self.post(
url,
{
'part': 1,
'quantity': q,
'price': p,
},
expected_code=201
).data
self.assertEqual(data['part'], 1)
self.assertEqual(
round(Decimal(data['quantity']), 4),
round(Decimal(q), 4)
)
self.assertEqual(
round(Decimal(data['price']), 4),
round(Decimal(p), 4)
)