2
0
mirror of https://github.com/inventree/InvenTree.git synced 2025-06-17 12:35:46 +00:00

Merge remote-tracking branch 'inventree/master'

This commit is contained in:
Oliver Walters
2022-10-29 22:10:21 +11:00
4 changed files with 98 additions and 11 deletions

View File

@ -91,6 +91,18 @@ class InvenTreeModelMoneyField(ModelMoneyField):
kwargs['form_class'] = InvenTreeMoneyField kwargs['form_class'] = InvenTreeMoneyField
return super().formfield(**kwargs) return super().formfield(**kwargs)
def to_python(self, value):
"""Convert value to python type."""
value = super().to_python(value)
return round_decimal(value, self.decimal_places)
def prepare_value(self, value):
"""Override the 'prepare_value' method, to remove trailing zeros when displaying.
Why? It looks nice!
"""
return round_decimal(value, self.decimal_places, normalize=True)
class InvenTreeMoneyField(MoneyField): class InvenTreeMoneyField(MoneyField):
"""Custom MoneyField for clean migrations while using dynamic currency settings.""" """Custom MoneyField for clean migrations while using dynamic currency settings."""
@ -126,11 +138,16 @@ class DatePickerFormField(forms.DateField):
) )
def round_decimal(value, places): def round_decimal(value, places, normalize=False):
"""Round value to the specified number of places.""" """Round value to the specified number of places."""
if value is not None:
# see https://docs.python.org/2/library/decimal.html#decimal.Decimal.quantize for options if type(value) in [Decimal, float]:
return value.quantize(Decimal(10) ** -places) value = round(value, places)
if normalize:
# Remove any trailing zeroes
value = InvenTree.helpers.normalize(value)
return value return value
@ -140,18 +157,14 @@ class RoundingDecimalFormField(forms.DecimalField):
def to_python(self, value): def to_python(self, value):
"""Convert value to python type.""" """Convert value to python type."""
value = super().to_python(value) value = super().to_python(value)
value = round_decimal(value, self.decimal_places) return round_decimal(value, self.decimal_places)
return value
def prepare_value(self, value): def prepare_value(self, value):
"""Override the 'prepare_value' method, to remove trailing zeros when displaying. """Override the 'prepare_value' method, to remove trailing zeros when displaying.
Why? It looks nice! Why? It looks nice!
""" """
if type(value) == Decimal: return round_decimal(value, self.decimal_places, normalize=True)
return InvenTree.helpers.normalize(value)
else:
return value
class RoundingDecimalField(models.DecimalField): class RoundingDecimalField(models.DecimalField):

View File

@ -34,13 +34,14 @@ class InvenTreeMoneySerializer(MoneyField):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Overrite default values.""" """Overrite default values."""
kwargs["max_digits"] = kwargs.get("max_digits", 19) kwargs["max_digits"] = kwargs.get("max_digits", 19)
kwargs["decimal_places"] = kwargs.get("decimal_places", 4) self.decimal_places = kwargs["decimal_places"] = kwargs.get("decimal_places", 4)
kwargs["required"] = kwargs.get("required", False) kwargs["required"] = kwargs.get("required", False)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def get_value(self, data): def get_value(self, data):
"""Test that the returned amount is a valid Decimal.""" """Test that the returned amount is a valid Decimal."""
amount = super(DecimalField, self).get_value(data) amount = super(DecimalField, self).get_value(data)
# Convert an empty string to None # Convert an empty string to None
@ -49,7 +50,9 @@ class InvenTreeMoneySerializer(MoneyField):
try: try:
if amount is not None and amount is not empty: if amount is not None and amount is not empty:
# Convert to a Decimal instance, and round to maximum allowed decimal places
amount = Decimal(amount) amount = Decimal(amount)
amount = round(amount, self.decimal_places)
except Exception: except Exception:
raise ValidationError({ raise ValidationError({
self.field_name: [_("Must be a valid number")], self.field_name: [_("Must be a valid number")],

View File

@ -217,6 +217,7 @@ class ManufacturerTest(InvenTreeAPITestCase):
'part': 1, 'part': 1,
'manufacturer': 7, 'manufacturer': 7,
'MPN': 'PART_NUMBER', 'MPN': 'PART_NUMBER',
'link': 'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E',
}, },
expected_code=201 expected_code=201
) )
@ -229,17 +230,24 @@ class ManufacturerTest(InvenTreeAPITestCase):
'supplier': 1, 'supplier': 1,
'SKU': 'SKU_TEST', 'SKU': 'SKU_TEST',
'manufacturer_part': pk, 'manufacturer_part': pk,
'link': 'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E',
} }
response = self.client.post(url, data, format='json') response = self.client.post(url, data, format='json')
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
# Check link is not modified
self.assertEqual(response.data['link'], 'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E')
# Check manufacturer part # Check manufacturer part
manufacturer_part_id = int(response.data['manufacturer_part_detail']['pk']) manufacturer_part_id = int(response.data['manufacturer_part_detail']['pk'])
url = reverse('api-manufacturer-part-detail', kwargs={'pk': manufacturer_part_id}) url = reverse('api-manufacturer-part-detail', kwargs={'pk': manufacturer_part_id})
response = self.get(url) response = self.get(url)
self.assertEqual(response.data['MPN'], 'PART_NUMBER') self.assertEqual(response.data['MPN'], 'PART_NUMBER')
# Check link is not modified
self.assertEqual(response.data['link'], 'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E')
class SupplierPartTest(InvenTreeAPITestCase): class SupplierPartTest(InvenTreeAPITestCase):
"""Unit tests for the SupplierPart API endpoints""" """Unit tests for the SupplierPart API endpoints"""

View File

@ -1,5 +1,6 @@
"""Unit tests for the various part API endpoints""" """Unit tests for the various part API endpoints"""
from decimal import Decimal
from random import randint from random import randint
from django.urls import reverse from django.urls import reverse
@ -2430,3 +2431,65 @@ class PartAttachmentTest(InvenTreeAPITestCase):
self.assertEqual(data['part'], 1) self.assertEqual(data['part'], 1)
self.assertEqual(data['link'], link) self.assertEqual(data['link'], link)
self.assertEqual(data['comment'], 'Hello world') self.assertEqual(data['comment'], 'Hello world')
class PartInternalPriceBreakTest(InvenTreeAPITestCase):
"""Unit tests for the PartInternalPrice API endpoints"""
fixtures = [
'category',
'part',
'params',
'location',
'bom',
'company',
'test_templates',
'manufacturer_part',
'supplier_part',
'order',
'stock',
]
roles = [
'part.change',
'part.add',
'part.delete',
'part_category.change',
'part_category.add',
'part_category.delete',
]
def test_create_price_breaks(self):
"""Test we can create price breaks at various quantities"""
url = reverse('api-part-internal-price-list')
breaks = [
(1.0, 101),
(1.1, 92.555555555),
(1.5, 90.999999999),
(1.756, 89),
(2, 86),
(25, 80)
]
for q, p in breaks:
data = self.post(
url,
{
'part': 1,
'quantity': q,
'price': p,
},
expected_code=201
).data
self.assertEqual(data['part'], 1)
self.assertEqual(
round(Decimal(data['quantity']), 4),
round(Decimal(q), 4)
)
self.assertEqual(
round(Decimal(data['price']), 4),
round(Decimal(p), 4)
)