From 47cbf3071d82d09f235344b51c5246ad8788987d Mon Sep 17 00:00:00 2001
From: Oliver Walters <oliver.henry.walters@gmail.com>
Date: Thu, 12 Nov 2020 21:36:32 +1100
Subject: [PATCH] Add option to add a single-quantity price-break when creating
 a new SupplierPart object

- Add unit testing!
---
 InvenTree/company/forms.py        |   1 -
 InvenTree/company/models.py       |  19 +++++
 InvenTree/company/test_views.py   | 115 ++++++++++++++++++++++++++++--
 InvenTree/company/views.py        |  32 +++++++++
 InvenTree/part/fixtures/part.yaml |   1 +
 5 files changed, 162 insertions(+), 6 deletions(-)

diff --git a/InvenTree/company/forms.py b/InvenTree/company/forms.py
index 8ee434050f..0ad95c3e8c 100644
--- a/InvenTree/company/forms.py
+++ b/InvenTree/company/forms.py
@@ -14,7 +14,6 @@ import django.forms
 import djmoney.settings
 from djmoney.forms.fields import MoneyField
 
-from common.models import InvenTreeSetting
 import common.settings
 
 from .models import Company
diff --git a/InvenTree/company/models.py b/InvenTree/company/models.py
index 241c5fe1f7..81718a9acd 100644
--- a/InvenTree/company/models.py
+++ b/InvenTree/company/models.py
@@ -380,6 +380,25 @@ class SupplierPart(models.Model):
     def unit_pricing(self):
         return self.get_price(1)
 
+    def add_price_break(self, quantity, price):
+        """
+        Create a new price break for this part
+
+        args:
+            quantity - Numerical quantity
+            price - Must be a Money object
+        """
+
+        # Check if a price break at that quantity already exists...
+        if self.price_breaks.filter(quantity=quantity, part=self.pk).exists():
+            return
+
+        SupplierPriceBreak.objects.create(
+            part=self,
+            quantity=quantity,
+            price=price
+        )
+
     def get_price(self, quantity, moq=True, multiples=True, currency=None):
         """ Calculate the supplier price based on quantity price breaks.
 
diff --git a/InvenTree/company/test_views.py b/InvenTree/company/test_views.py
index d895c18957..a68a740d33 100644
--- a/InvenTree/company/test_views.py
+++ b/InvenTree/company/test_views.py
@@ -3,6 +3,8 @@
 # -*- coding: utf-8 -*-
 from __future__ import unicode_literals
 
+import json
+
 from django.test import TestCase
 from django.urls import reverse
 from django.contrib.auth import get_user_model
@@ -11,7 +13,7 @@ from django.contrib.auth.models import Group
 from .models import SupplierPart
 
 
-class CompanyViewTest(TestCase):
+class CompanyViewTestBase(TestCase):
 
     fixtures = [
         'category',
@@ -47,14 +49,105 @@ class CompanyViewTest(TestCase):
 
         self.client.login(username='username', password='password')
 
-    def test_company_index(self):
-        """ Test the company index """
 
-        response = self.client.get(reverse('company-index'))
+class SupplierPartViewTests(CompanyViewTestBase):
+    """
+    Tests for the SupplierPart views.
+    """
+
+    def post(self, data, valid=None):
+        """
+        POST against this form and return the response (as a JSON object)
+        """
+        url = reverse('supplier-part-create')
+
+        response = self.client.post(url, data, HTTP_X_REQUESTED_WITH='XMLHttpRequest')
+
         self.assertEqual(response.status_code, 200)
 
+        json_data = json.loads(response.content)
+
+        # If a particular status code is required
+        if valid is not None:
+            if valid:
+                self.assertEqual(json_data['form_valid'], True)
+            else:
+                self.assertEqual(json_data['form_valid'], False)
+
+        form_errors = json.loads(json_data['form_errors'])
+
+        return json_data, form_errors
+
+    def test_supplier_part_create(self):
+        """
+        Test the SupplierPartCreate view.
+        
+        This view allows some additional functionality,
+        specifically it allows the user to create a single-quantity price break
+        automatically, when saving the new SupplierPart model.
+        """
+
+        url = reverse('supplier-part-create')
+
+        # First check that we can GET the form
+        response = self.client.get(url, HTTP_X_REQUESTED_WITH='XMLHttpRequest')
+        self.assertEqual(response.status_code, 200)
+
+        # How many supplier parts are already in the database?
+        n = SupplierPart.objects.all().count()
+
+        data = {
+            'part': 1,
+            'supplier': 1,
+        }
+
+        # SKU is required! (form should fail)
+        (response, errors) = self.post(data, valid=False)
+
+        self.assertIsNotNone(errors.get('SKU', None))
+
+        data['SKU'] = 'TEST-ME-123'
+
+        (response, errors) = self.post(data, valid=True)
+
+        # Check that the SupplierPart was created!
+        self.assertEqual(n + 1, SupplierPart.objects.all().count())
+
+        # Check that it was created *without* a price-break
+        supplier_part = SupplierPart.objects.get(pk=response['pk'])
+
+        self.assertEqual(supplier_part.price_breaks.count(), 0)
+
+        # Duplicate SKU is prohibited
+        (response, errors) = self.post(data, valid=False)
+
+        self.assertIsNotNone(errors.get('__all__', None))
+
+        # Add with a different SKU, *and* a single-quantity price
+        data['SKU'] = 'TEST-ME-1234'
+        data['single_pricing_0'] = '123.4'
+        data['single_pricing_1'] = 'CAD'
+
+        (response, errors) = self.post(data, valid=True)
+
+        pk = response.get('pk')
+
+        # Check that *another* SupplierPart was created
+        self.assertEqual(n + 2, SupplierPart.objects.all().count())
+
+        supplier_part = SupplierPart.objects.get(pk=pk)
+
+        # Check that a price-break has been created!
+        self.assertEqual(supplier_part.price_breaks.count(), 1)
+
+        price_break = supplier_part.price_breaks.first()
+
+        self.assertEqual(price_break.quantity, 1)
+
     def test_supplier_part_delete(self):
-        """ Test the SupplierPartDelete view """
+        """
+        Test the SupplierPartDelete view
+        """
 
         url = reverse('supplier-part-delete')
 
@@ -80,3 +173,15 @@ class CompanyViewTest(TestCase):
         self.assertEqual(response.status_code, 200)
 
         self.assertEqual(n - 2, SupplierPart.objects.count())
+
+
+class CompanyViewTest(CompanyViewTestBase):
+    """
+    Tests for various 'Company' views
+    """
+
+    def test_company_index(self):
+        """ Test the company index """
+
+        response = self.client.get(reverse('company-index'))
+        self.assertEqual(response.status_code, 200)
diff --git a/InvenTree/company/views.py b/InvenTree/company/views.py
index e82107ec14..2f734a7cc5 100644
--- a/InvenTree/company/views.py
+++ b/InvenTree/company/views.py
@@ -271,6 +271,14 @@ class SupplierPartEdit(AjaxUpdateView):
     ajax_form_title = _('Edit Supplier Part')
     role_required = 'purchase_order.change'
 
+    def get_form(self):
+        form = super().get_form()
+
+        # Hide the single-pricing field (only for creating a new SupplierPart!)
+        form.fields['single_pricing'].widget = HiddenInput()
+
+        return form
+
 
 class SupplierPartCreate(AjaxCreateView):
     """ Create view for making new SupplierPart """
@@ -282,6 +290,30 @@ class SupplierPartCreate(AjaxCreateView):
     context_object_name = 'part'
     role_required = 'purchase_order.add'
 
+    def validate(self, part, form):
+
+        single_pricing = form.cleaned_data.get('single_pricing', None)
+
+        if single_pricing:
+            # TODO - What validation steps can be performed on the single_pricing field?
+            pass
+
+    def save(self, form):
+        """
+        If single_pricing is defined, add a price break for quantity=1
+        """
+
+        # Save the supplier part object
+        supplier_part = super().save(form)
+
+        single_pricing = form.cleaned_data.get('single_pricing', None)
+
+        if single_pricing:
+
+            supplier_part.add_price_break(1, single_pricing)
+
+        return supplier_part
+
     def get_form(self):
         """ Create Form instance to create a new SupplierPart object.
         Hide some fields if they are not appropriate in context
diff --git a/InvenTree/part/fixtures/part.yaml b/InvenTree/part/fixtures/part.yaml
index 9883edfcd3..f6d9d246db 100644
--- a/InvenTree/part/fixtures/part.yaml
+++ b/InvenTree/part/fixtures/part.yaml
@@ -8,6 +8,7 @@
     category: 8
     link: www.acme.com/parts/m2x4lphs
     tree_id: 0
+    purchaseable: True
     level: 0
     lft: 0
     rght: 0