diff --git a/src/backend/InvenTree/InvenTree/models.py b/src/backend/InvenTree/InvenTree/models.py index 5301141881..c8e8eb4ecf 100644 --- a/src/backend/InvenTree/InvenTree/models.py +++ b/src/backend/InvenTree/InvenTree/models.py @@ -451,7 +451,17 @@ class ReferenceIndexingMixin(models.Model): reference_int = models.BigIntegerField(default=0) -class InvenTreeModel(PluginValidationMixin, models.Model): +class ContentTypeMixin: + """Mixin class which supports retrieval of the ContentType for a model instance.""" + + def get_content_type(self): + """Return the ContentType object associated with this model.""" + from django.contrib.contenttypes.models import ContentType + + return ContentType.objects.get_for_model(self.__class__) + + +class InvenTreeModel(ContentTypeMixin, PluginValidationMixin, models.Model): """Base class for InvenTree models, which provides some common functionality. Includes the following mixins by default: @@ -658,7 +668,7 @@ class InvenTreeAttachmentMixin(InvenTreePermissionCheckMixin): Attachment.objects.create(**kwargs) -class InvenTreeTree(MPTTModel): +class InvenTreeTree(ContentTypeMixin, MPTTModel): """Provides an abstracted self-referencing tree model, based on the MPTTModel class. Our implementation provides the following key improvements: diff --git a/src/backend/InvenTree/part/test_param.py b/src/backend/InvenTree/part/test_param.py index 4e3bae3a70..c1a536bc25 100644 --- a/src/backend/InvenTree/part/test_param.py +++ b/src/backend/InvenTree/part/test_param.py @@ -369,13 +369,21 @@ class PartParameterTest(InvenTreeAPITestCase): # test that having non unique part/template combinations fails res = self.post(url, data, expected_code=400) + self.assertEqual(len(res.data), 3) self.assertEqual(len(res.data[1]), 0) for err in [res.data[0], res.data[2]]: - self.assertEqual(len(err), 2) + self.assertEqual(len(err), 3) self.assertEqual(str(err['model_id'][0]), 'This field must be unique.') + self.assertEqual(str(err['model_type'][0]), 'This field must be unique.') self.assertEqual(str(err['template'][0]), 'This field must be unique.') - self.assertEqual(Parameter.objects.filter(content_object=part4).count(), 0) + + self.assertEqual( + Parameter.objects.filter( + model_type=part4.get_content_type(), model_id=part4.pk + ).count(), + 0, + ) # Now, create a valid set of parameters data = [ @@ -384,7 +392,13 @@ class PartParameterTest(InvenTreeAPITestCase): ] res = self.post(url, data, expected_code=201) self.assertEqual(len(res.data), 2) - self.assertEqual(Parameter.objects.filter(content_object=part4).count(), 2) + + self.assertEqual( + Parameter.objects.filter( + model_type=part4.get_content_type(), model_id=part4.pk + ).count(), + 2, + ) def test_param_detail(self): """Tests for the Parameter detail endpoint."""