mirror of
https://github.com/inventree/InvenTree.git
synced 2025-04-29 12:06:44 +00:00
Floating point API bug (#3877)
* Add unit tests for internalpricebreak - Exposes an existing bug * Ensure that rounding-decimal and prices are rounded correctly - Force remove trailing digits / reduce precision
This commit is contained in:
parent
5263ccdca3
commit
55c8b73b0a
@ -91,6 +91,18 @@ class InvenTreeModelMoneyField(ModelMoneyField):
|
|||||||
kwargs['form_class'] = InvenTreeMoneyField
|
kwargs['form_class'] = InvenTreeMoneyField
|
||||||
return super().formfield(**kwargs)
|
return super().formfield(**kwargs)
|
||||||
|
|
||||||
|
def to_python(self, value):
|
||||||
|
"""Convert value to python type."""
|
||||||
|
value = super().to_python(value)
|
||||||
|
return round_decimal(value, self.decimal_places)
|
||||||
|
|
||||||
|
def prepare_value(self, value):
|
||||||
|
"""Override the 'prepare_value' method, to remove trailing zeros when displaying.
|
||||||
|
|
||||||
|
Why? It looks nice!
|
||||||
|
"""
|
||||||
|
return round_decimal(value, self.decimal_places, normalize=True)
|
||||||
|
|
||||||
|
|
||||||
class InvenTreeMoneyField(MoneyField):
|
class InvenTreeMoneyField(MoneyField):
|
||||||
"""Custom MoneyField for clean migrations while using dynamic currency settings."""
|
"""Custom MoneyField for clean migrations while using dynamic currency settings."""
|
||||||
@ -126,11 +138,16 @@ class DatePickerFormField(forms.DateField):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def round_decimal(value, places):
|
def round_decimal(value, places, normalize=False):
|
||||||
"""Round value to the specified number of places."""
|
"""Round value to the specified number of places."""
|
||||||
if value is not None:
|
|
||||||
# see https://docs.python.org/2/library/decimal.html#decimal.Decimal.quantize for options
|
if type(value) in [Decimal, float]:
|
||||||
return value.quantize(Decimal(10) ** -places)
|
value = round(value, places)
|
||||||
|
|
||||||
|
if normalize:
|
||||||
|
# Remove any trailing zeroes
|
||||||
|
value = InvenTree.helpers.normalize(value)
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
@ -140,18 +157,14 @@ class RoundingDecimalFormField(forms.DecimalField):
|
|||||||
def to_python(self, value):
|
def to_python(self, value):
|
||||||
"""Convert value to python type."""
|
"""Convert value to python type."""
|
||||||
value = super().to_python(value)
|
value = super().to_python(value)
|
||||||
value = round_decimal(value, self.decimal_places)
|
return round_decimal(value, self.decimal_places)
|
||||||
return value
|
|
||||||
|
|
||||||
def prepare_value(self, value):
|
def prepare_value(self, value):
|
||||||
"""Override the 'prepare_value' method, to remove trailing zeros when displaying.
|
"""Override the 'prepare_value' method, to remove trailing zeros when displaying.
|
||||||
|
|
||||||
Why? It looks nice!
|
Why? It looks nice!
|
||||||
"""
|
"""
|
||||||
if type(value) == Decimal:
|
return round_decimal(value, self.decimal_places, normalize=True)
|
||||||
return InvenTree.helpers.normalize(value)
|
|
||||||
else:
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
class RoundingDecimalField(models.DecimalField):
|
class RoundingDecimalField(models.DecimalField):
|
||||||
|
@ -34,13 +34,14 @@ class InvenTreeMoneySerializer(MoneyField):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
"""Overrite default values."""
|
"""Overrite default values."""
|
||||||
kwargs["max_digits"] = kwargs.get("max_digits", 19)
|
kwargs["max_digits"] = kwargs.get("max_digits", 19)
|
||||||
kwargs["decimal_places"] = kwargs.get("decimal_places", 4)
|
self.decimal_places = kwargs["decimal_places"] = kwargs.get("decimal_places", 4)
|
||||||
kwargs["required"] = kwargs.get("required", False)
|
kwargs["required"] = kwargs.get("required", False)
|
||||||
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def get_value(self, data):
|
def get_value(self, data):
|
||||||
"""Test that the returned amount is a valid Decimal."""
|
"""Test that the returned amount is a valid Decimal."""
|
||||||
|
|
||||||
amount = super(DecimalField, self).get_value(data)
|
amount = super(DecimalField, self).get_value(data)
|
||||||
|
|
||||||
# Convert an empty string to None
|
# Convert an empty string to None
|
||||||
@ -49,7 +50,9 @@ class InvenTreeMoneySerializer(MoneyField):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if amount is not None and amount is not empty:
|
if amount is not None and amount is not empty:
|
||||||
|
# Convert to a Decimal instance, and round to maximum allowed decimal places
|
||||||
amount = Decimal(amount)
|
amount = Decimal(amount)
|
||||||
|
amount = round(amount, self.decimal_places)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise ValidationError({
|
raise ValidationError({
|
||||||
self.field_name: [_("Must be a valid number")],
|
self.field_name: [_("Must be a valid number")],
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Unit tests for the various part API endpoints"""
|
"""Unit tests for the various part API endpoints"""
|
||||||
|
|
||||||
|
from decimal import Decimal
|
||||||
from random import randint
|
from random import randint
|
||||||
|
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
@ -2430,3 +2431,65 @@ class PartAttachmentTest(InvenTreeAPITestCase):
|
|||||||
self.assertEqual(data['part'], 1)
|
self.assertEqual(data['part'], 1)
|
||||||
self.assertEqual(data['link'], link)
|
self.assertEqual(data['link'], link)
|
||||||
self.assertEqual(data['comment'], 'Hello world')
|
self.assertEqual(data['comment'], 'Hello world')
|
||||||
|
|
||||||
|
|
||||||
|
class PartInternalPriceBreakTest(InvenTreeAPITestCase):
|
||||||
|
"""Unit tests for the PartInternalPrice API endpoints"""
|
||||||
|
|
||||||
|
fixtures = [
|
||||||
|
'category',
|
||||||
|
'part',
|
||||||
|
'params',
|
||||||
|
'location',
|
||||||
|
'bom',
|
||||||
|
'company',
|
||||||
|
'test_templates',
|
||||||
|
'manufacturer_part',
|
||||||
|
'supplier_part',
|
||||||
|
'order',
|
||||||
|
'stock',
|
||||||
|
]
|
||||||
|
|
||||||
|
roles = [
|
||||||
|
'part.change',
|
||||||
|
'part.add',
|
||||||
|
'part.delete',
|
||||||
|
'part_category.change',
|
||||||
|
'part_category.add',
|
||||||
|
'part_category.delete',
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_create_price_breaks(self):
|
||||||
|
"""Test we can create price breaks at various quantities"""
|
||||||
|
|
||||||
|
url = reverse('api-part-internal-price-list')
|
||||||
|
|
||||||
|
breaks = [
|
||||||
|
(1.0, 101),
|
||||||
|
(1.1, 92.555555555),
|
||||||
|
(1.5, 90.999999999),
|
||||||
|
(1.756, 89),
|
||||||
|
(2, 86),
|
||||||
|
(25, 80)
|
||||||
|
]
|
||||||
|
|
||||||
|
for q, p in breaks:
|
||||||
|
data = self.post(
|
||||||
|
url,
|
||||||
|
{
|
||||||
|
'part': 1,
|
||||||
|
'quantity': q,
|
||||||
|
'price': p,
|
||||||
|
},
|
||||||
|
expected_code=201
|
||||||
|
).data
|
||||||
|
|
||||||
|
self.assertEqual(data['part'], 1)
|
||||||
|
self.assertEqual(
|
||||||
|
round(Decimal(data['quantity']), 4),
|
||||||
|
round(Decimal(q), 4)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
round(Decimal(data['price']), 4),
|
||||||
|
round(Decimal(p), 4)
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user