"""Unit tests for the PartCategory model"""

from django.core.exceptions import ValidationError
from django.test import TestCase

from .models import Part, PartCategory, PartParameter, PartParameterTemplate


class CategoryTest(TestCase):
    """Tests to ensure that the relational category tree functions correctly.

    Loads the following test fixtures:
    - category.yaml
    - part.yaml
    - location.yaml
    - params.yaml
    """

    fixtures = [
        'category',
        'part',
        'location',
        'params',
    ]

    def setUp(self):
        """Extract some interesting categories for time-saving"""
        self.electronics = PartCategory.objects.get(name='Electronics')
        self.mechanical = PartCategory.objects.get(name='Mechanical')
        self.resistors = PartCategory.objects.get(name='Resistors')
        self.capacitors = PartCategory.objects.get(name='Capacitors')
        self.fasteners = PartCategory.objects.get(name='Fasteners')
        self.ic = PartCategory.objects.get(name='IC')
        self.transceivers = PartCategory.objects.get(name='Transceivers')

    def test_parents(self):
        """Test that the parent fields are properly set, based on the test fixtures."""
        self.assertEqual(self.resistors.parent, self.electronics)
        self.assertEqual(self.capacitors.parent, self.electronics)
        self.assertIsNone(self.electronics.parent)

        self.assertEqual(self.fasteners.parent, self.mechanical)

    def test_children_count(self):
        """Test that categories have the correct number of children."""
        self.assertTrue(self.electronics.has_children)
        self.assertTrue(self.mechanical.has_children)

        # Count in the database rather than materializing the querysets
        self.assertEqual(self.electronics.children.count(), 3)
        self.assertEqual(self.mechanical.children.count(), 1)

    def test_unique_childs(self):
        """Test the 'unique_children' functionality."""
        childs = [item.pk for item in self.electronics.getUniqueChildren()]

        # Descendants at any depth are included (Transceivers sits under IC)
        self.assertIn(self.transceivers.id, childs)
        self.assertIn(self.ic.id, childs)

        # Categories from an unrelated branch are not
        self.assertNotIn(self.fasteners.id, childs)

    def test_unique_parents(self):
        """Test the 'unique_parents' functionality."""
        parents = [item.pk for item in self.transceivers.getUniqueParents()]

        # Ancestors at any depth are included
        self.assertIn(self.electronics.id, parents)
        self.assertIn(self.ic.id, parents)
        self.assertNotIn(self.fasteners.id, parents)

    def test_path_string(self):
        """Test that the category path string works correctly."""
        self.assertEqual(str(self.resistors), 'Electronics/Resistors - Resistors')
        self.assertEqual(str(self.transceivers.pathstring), 'Electronics/IC/Transceivers')

    def test_url(self):
        """Test that the PartCategory URL works."""
        self.assertEqual(self.capacitors.get_absolute_url(), '/part/category/3/')

    def test_part_count(self):
        """Test that the Category part count works."""
        self.assertEqual(self.fasteners.partcount(), 2)
        self.assertEqual(self.capacitors.partcount(), 1)

        self.assertEqual(self.electronics.partcount(), 3)

        self.assertEqual(self.mechanical.partcount(), 9)
        self.assertEqual(self.mechanical.partcount(active=True), 8)
        # NOTE(review): the first positional argument presumably disables
        # cascading into subcategories (9 total vs 7 direct) - confirm
        # against the PartCategory.partcount signature.
        self.assertEqual(self.mechanical.partcount(False), 7)

        self.assertEqual(self.electronics.item_count, self.electronics.partcount())

    def test_parameters(self):
        """Test that the Category parameters are correctly fetched."""
        # Check the number of SQL queries needed to iterate over parameters
        with self.assertNumQueries(7):
            # Prefetch: 3 queries (parts, parameters and parameters_template)
            fasteners = self.fasteners.prefetch_parts_parameters()
            # Iterate through all parts and parameters
            for fastener in fasteners:
                self.assertIsInstance(fastener, Part)
                for parameter in fastener.parameters.all():
                    self.assertIsInstance(parameter, PartParameter)
                    self.assertIsInstance(parameter.template, PartParameterTemplate)

            # Test number of unique parameters
            self.assertEqual(len(self.fasteners.get_unique_parameters(prefetch=fasteners)), 1)
            # Test number of parameters found for each part
            parts_parameters = self.fasteners.get_parts_parameters(prefetch=fasteners)
            part_infos = ['pk', 'name', 'description']
            for part_parameter in parts_parameters:
                # Strip the part information keys, leaving only parameter entries
                for item in part_infos:
                    part_parameter.pop(item)
                self.assertEqual(len(part_parameter), 1)

    def test_invalid_name(self):
        """Test that an illegal character is prohibited in a category name"""
        cat = PartCategory(name='test/with/illegal/chars', description='Test category', parent=None)

        with self.assertRaises(ValidationError) as err:
            cat.full_clean()
            cat.save()  # pragma: no cover

        self.assertIn('Illegal character in name', str(err.exception.error_dict.get('name')))

        cat.name = 'good name'
        cat.save()

    def test_delete(self):
        """Test that category deletion moves the children properly."""
        # Delete the 'IC' category and 'Transceiver' should move to be under 'Electronics'
        self.assertEqual(self.transceivers.parent, self.ic)
        self.assertEqual(self.ic.parent, self.electronics)

        self.ic.delete()

        # Get the data again
        transceivers = PartCategory.objects.get(name='Transceivers')
        self.assertEqual(transceivers.parent, self.electronics)

        # Now delete the 'fasteners' category - the parts should move to 'mechanical'
        self.fasteners.delete()

        fasteners = list(Part.objects.filter(description__contains='screw'))

        # Guard against a vacuous pass: with no matching parts the loop
        # below would assert nothing at all.
        self.assertGreater(len(fasteners), 0)

        for f in fasteners:
            self.assertEqual(f.category, self.mechanical)

    def test_default_locations(self):
        """Test traversal for default locations."""
        self.assertEqual(str(self.fasteners.default_location), 'Office/Drawer_1 - In my desk')

        # Any part under electronics should default to 'Home'
        r1 = Part.objects.get(name='R_2K2_0805')
        self.assertIsNone(r1.default_location)
        self.assertEqual(r1.get_default_location().name, 'Home')

        # But one part has a default_location set
        r2 = Part.objects.get(name='R_4K7_0603')
        self.assertEqual(r2.get_default_location().name, 'Bathroom')

        # And one part should have no default location at all
        w = Part.objects.get(name='Widget')
        self.assertIsNone(w.get_default_location())

    def test_category_tree(self):
        """Unit tests for the part category tree structure (MPTT)

        Ensure that the MPTT structure is rebuilt correctly,
        and the correct ancestor tree is observed.
        """
        # Clear out any existing parts
        Part.objects.all().delete()

        # First, create a structured tree of part categories
        A = PartCategory.objects.create(
            name='A',
            description='Top level category',
        )

        B1 = PartCategory.objects.create(name='B1', parent=A)
        B2 = PartCategory.objects.create(name='B2', parent=A)
        B3 = PartCategory.objects.create(name='B3', parent=A)

        C11 = PartCategory.objects.create(name='C11', parent=B1)
        C12 = PartCategory.objects.create(name='C12', parent=B1)
        C13 = PartCategory.objects.create(name='C13', parent=B1)

        C21 = PartCategory.objects.create(name='C21', parent=B2)
        C22 = PartCategory.objects.create(name='C22', parent=B2)
        C23 = PartCategory.objects.create(name='C23', parent=B2)

        C31 = PartCategory.objects.create(name='C31', parent=B3)
        C32 = PartCategory.objects.create(name='C32', parent=B3)
        C33 = PartCategory.objects.create(name='C33', parent=B3)

        # Check that the tree_id value is correct
        for cat in [B1, B2, B3, C11, C22, C33]:
            self.assertEqual(cat.tree_id, A.tree_id)
            self.assertEqual(cat.level, cat.parent.level + 1)
            self.assertEqual(cat.get_ancestors().count(), cat.level)

        # Spot check for C31
        ancestors = C31.get_ancestors(include_self=True)

        self.assertEqual(ancestors.count(), 3)
        self.assertEqual(ancestors[0], A)
        self.assertEqual(ancestors[1], B3)
        self.assertEqual(ancestors[2], C31)

        # At this point, we are confident that the tree is correctly structured

        # Add some parts to category B3
        for i in range(10):
            Part.objects.create(
                name=f'Part {i}',
                description='A test part',
                category=B3,
            )

        self.assertEqual(Part.objects.filter(category=B3).count(), 10)
        self.assertEqual(Part.objects.filter(category=A).count(), 0)

        # Delete category B3
        B3.delete()

        # Child parts have been moved to category A
        self.assertEqual(Part.objects.filter(category=A).count(), 10)

        for cat in [C31, C32, C33]:
            # These categories should now be directly under A
            cat.refresh_from_db()

            self.assertEqual(cat.parent, A)
            self.assertEqual(cat.level, 1)
            self.assertEqual(cat.get_ancestors().count(), 1)
            self.assertEqual(cat.get_ancestors()[0], A)

        # Now, delete category A
        A.delete()

        # Parts have now been moved to the top level (no category)
        self.assertEqual(Part.objects.filter(category=None).count(), 10)

        for cat in [B1, B2, C31, C32, C33]:
            # These should now all be "top level" categories
            cat.refresh_from_db()

            self.assertEqual(cat.level, 0)
            self.assertIsNone(cat.parent)

        # Check descendants for B1
        descendants = B1.get_descendants()
        self.assertEqual(descendants.count(), 3)

        for cat in [C11, C12, C13]:
            # assertIn reports the missing member on failure, unlike assertTrue
            self.assertIn(cat, descendants)

        # Check category C1x, should be B1 -> C1x
        for cat in [C11, C12, C13]:
            cat.refresh_from_db()

            self.assertEqual(cat.level, 1)
            self.assertEqual(cat.parent, B1)
            ancestors = cat.get_ancestors(include_self=True)

            self.assertEqual(ancestors.count(), 2)
            self.assertEqual(ancestors[0], B1)
            self.assertEqual(ancestors[1], cat)

        # Check category C2x, should be B2 -> C2x
        for cat in [C21, C22, C23]:
            cat.refresh_from_db()

            self.assertEqual(cat.level, 1)
            self.assertEqual(cat.parent, B2)
            ancestors = cat.get_ancestors(include_self=True)

            self.assertEqual(ancestors.count(), 2)
            self.assertEqual(ancestors[0], B2)
            self.assertEqual(ancestors[1], cat)