mirror of
				https://github.com/inventree/InvenTree.git
				synced 2025-11-04 15:15:42 +00:00 
			
		
		
		
	Floating point API bug (#3877)
* Add unit tests for internalpricebreak - Exposes an existing bug * Ensure that rounding-decimal and prices are rounded correctly - Force remove trailing digits / reduce precision
This commit is contained in:
		@@ -91,6 +91,18 @@ class InvenTreeModelMoneyField(ModelMoneyField):
 | 
			
		||||
        kwargs['form_class'] = InvenTreeMoneyField
 | 
			
		||||
        return super().formfield(**kwargs)
 | 
			
		||||
 | 
			
		||||
    def to_python(self, value):
 | 
			
		||||
        """Convert value to python type."""
 | 
			
		||||
        value = super().to_python(value)
 | 
			
		||||
        return round_decimal(value, self.decimal_places)
 | 
			
		||||
 | 
			
		||||
    def prepare_value(self, value):
 | 
			
		||||
        """Override the 'prepare_value' method, to remove trailing zeros when displaying.
 | 
			
		||||
 | 
			
		||||
        Why? It looks nice!
 | 
			
		||||
        """
 | 
			
		||||
        return round_decimal(value, self.decimal_places, normalize=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InvenTreeMoneyField(MoneyField):
 | 
			
		||||
    """Custom MoneyField for clean migrations while using dynamic currency settings."""
 | 
			
		||||
@@ -126,11 +138,16 @@ class DatePickerFormField(forms.DateField):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def round_decimal(value, places):
 | 
			
		||||
def round_decimal(value, places, normalize=False):
 | 
			
		||||
    """Round value to the specified number of places."""
 | 
			
		||||
    if value is not None:
 | 
			
		||||
        # see https://docs.python.org/2/library/decimal.html#decimal.Decimal.quantize for options
 | 
			
		||||
        return value.quantize(Decimal(10) ** -places)
 | 
			
		||||
 | 
			
		||||
    if type(value) in [Decimal, float]:
 | 
			
		||||
        value = round(value, places)
 | 
			
		||||
 | 
			
		||||
        if normalize:
 | 
			
		||||
            # Remove any trailing zeroes
 | 
			
		||||
            value = InvenTree.helpers.normalize(value)
 | 
			
		||||
 | 
			
		||||
    return value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -140,18 +157,14 @@ class RoundingDecimalFormField(forms.DecimalField):
 | 
			
		||||
    def to_python(self, value):
 | 
			
		||||
        """Convert value to python type."""
 | 
			
		||||
        value = super().to_python(value)
 | 
			
		||||
        value = round_decimal(value, self.decimal_places)
 | 
			
		||||
        return value
 | 
			
		||||
        return round_decimal(value, self.decimal_places)
 | 
			
		||||
 | 
			
		||||
    def prepare_value(self, value):
 | 
			
		||||
        """Override the 'prepare_value' method, to remove trailing zeros when displaying.
 | 
			
		||||
 | 
			
		||||
        Why? It looks nice!
 | 
			
		||||
        """
 | 
			
		||||
        if type(value) == Decimal:
 | 
			
		||||
            return InvenTree.helpers.normalize(value)
 | 
			
		||||
        else:
 | 
			
		||||
            return value
 | 
			
		||||
        return round_decimal(value, self.decimal_places, normalize=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RoundingDecimalField(models.DecimalField):
 | 
			
		||||
 
 | 
			
		||||
@@ -34,13 +34,14 @@ class InvenTreeMoneySerializer(MoneyField):
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        """Overrite default values."""
 | 
			
		||||
        kwargs["max_digits"] = kwargs.get("max_digits", 19)
 | 
			
		||||
        kwargs["decimal_places"] = kwargs.get("decimal_places", 4)
 | 
			
		||||
        self.decimal_places = kwargs["decimal_places"] = kwargs.get("decimal_places", 4)
 | 
			
		||||
        kwargs["required"] = kwargs.get("required", False)
 | 
			
		||||
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def get_value(self, data):
 | 
			
		||||
        """Test that the returned amount is a valid Decimal."""
 | 
			
		||||
 | 
			
		||||
        amount = super(DecimalField, self).get_value(data)
 | 
			
		||||
 | 
			
		||||
        # Convert an empty string to None
 | 
			
		||||
@@ -49,7 +50,9 @@ class InvenTreeMoneySerializer(MoneyField):
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            if amount is not None and amount is not empty:
 | 
			
		||||
                # Convert to a Decimal instance, and round to maximum allowed decimal places
 | 
			
		||||
                amount = Decimal(amount)
 | 
			
		||||
                amount = round(amount, self.decimal_places)
 | 
			
		||||
        except Exception:
 | 
			
		||||
            raise ValidationError({
 | 
			
		||||
                self.field_name: [_("Must be a valid number")],
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,6 @@
 | 
			
		||||
"""Unit tests for the various part API endpoints"""
 | 
			
		||||
 | 
			
		||||
from decimal import Decimal
 | 
			
		||||
from random import randint
 | 
			
		||||
 | 
			
		||||
from django.urls import reverse
 | 
			
		||||
@@ -2430,3 +2431,65 @@ class PartAttachmentTest(InvenTreeAPITestCase):
 | 
			
		||||
        self.assertEqual(data['part'], 1)
 | 
			
		||||
        self.assertEqual(data['link'], link)
 | 
			
		||||
        self.assertEqual(data['comment'], 'Hello world')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PartInternalPriceBreakTest(InvenTreeAPITestCase):
 | 
			
		||||
    """Unit tests for the PartInternalPrice API endpoints"""
 | 
			
		||||
 | 
			
		||||
    fixtures = [
 | 
			
		||||
        'category',
 | 
			
		||||
        'part',
 | 
			
		||||
        'params',
 | 
			
		||||
        'location',
 | 
			
		||||
        'bom',
 | 
			
		||||
        'company',
 | 
			
		||||
        'test_templates',
 | 
			
		||||
        'manufacturer_part',
 | 
			
		||||
        'supplier_part',
 | 
			
		||||
        'order',
 | 
			
		||||
        'stock',
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    roles = [
 | 
			
		||||
        'part.change',
 | 
			
		||||
        'part.add',
 | 
			
		||||
        'part.delete',
 | 
			
		||||
        'part_category.change',
 | 
			
		||||
        'part_category.add',
 | 
			
		||||
        'part_category.delete',
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    def test_create_price_breaks(self):
 | 
			
		||||
        """Test we can create price breaks at various quantities"""
 | 
			
		||||
 | 
			
		||||
        url = reverse('api-part-internal-price-list')
 | 
			
		||||
 | 
			
		||||
        breaks = [
 | 
			
		||||
            (1.0, 101),
 | 
			
		||||
            (1.1, 92.555555555),
 | 
			
		||||
            (1.5, 90.999999999),
 | 
			
		||||
            (1.756, 89),
 | 
			
		||||
            (2, 86),
 | 
			
		||||
            (25, 80)
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        for q, p in breaks:
 | 
			
		||||
            data = self.post(
 | 
			
		||||
                url,
 | 
			
		||||
                {
 | 
			
		||||
                    'part': 1,
 | 
			
		||||
                    'quantity': q,
 | 
			
		||||
                    'price': p,
 | 
			
		||||
                },
 | 
			
		||||
                expected_code=201
 | 
			
		||||
            ).data
 | 
			
		||||
 | 
			
		||||
            self.assertEqual(data['part'], 1)
 | 
			
		||||
            self.assertEqual(
 | 
			
		||||
                round(Decimal(data['quantity']), 4),
 | 
			
		||||
                round(Decimal(q), 4)
 | 
			
		||||
            )
 | 
			
		||||
            self.assertEqual(
 | 
			
		||||
                round(Decimal(data['price']), 4),
 | 
			
		||||
                round(Decimal(p), 4)
 | 
			
		||||
            )
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user