2
0
mirror of https://github.com/inventree/InvenTree.git synced 2025-10-31 13:15:43 +00:00

Supplier Mixin (#9761)

* commit initial draft for supplier import

* complete import wizard

* allow importing only mp and sp

* improved sample supplier plugin

* add docs

* add tests

* bump api version

* fix schema docu

* fix issues from code review

* commit unstaged changes

* fix test

* refactor part parameter bulk creation

* try to fix test

* fix tests

* fix test for mysql

* fix test

* support multiple suppliers by a single plugin

* hide import button if there is no supplier import plugin

* make form submitable via enter

* add pui test

* try to prevent race condition

* refactor api calls in pui tests

* try to fix tests again?

* fix tests

* trigger: ci

* update changelog

* fix api_version

* fix style

* Update CHANGELOG.md

Co-authored-by: Matthias Mair <code@mjmair.com>

* add user docs

---------

Co-authored-by: Matthias Mair <code@mjmair.com>
This commit is contained in:
Lukas Wolf
2025-10-17 22:13:03 +02:00
committed by GitHub
parent d534f67c62
commit de270a5fe7
41 changed files with 2298 additions and 119 deletions

View File

@@ -1,5 +1,6 @@
"""Main JSON interface views."""
import collections
import json
from pathlib import Path
@@ -488,16 +489,46 @@ class BulkCreateMixin:
if isinstance(data, list):
created_items = []
errors = []
has_errors = False
# If data is a list, we assume it is a bulk create request
if len(data) == 0:
raise ValidationError({'non_field_errors': _('No data provided')})
for item in data:
serializer = self.get_serializer(data=item)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
created_items.append(serializer.data)
# validate unique together fields
if unique_create_fields := getattr(self, 'unique_create_fields', None):
existing = collections.defaultdict(list)
for idx, item in enumerate(data):
key = tuple(item[v] for v in unique_create_fields)
existing[key].append(idx)
unique_errors = [[] for _ in range(len(data))]
has_unique_errors = False
for item in existing.values():
if len(item) > 1:
has_unique_errors = True
error = {}
for field_name in unique_create_fields:
error[field_name] = [_('This field must be unique.')]
for idx in item:
unique_errors[idx] = error
if has_unique_errors:
raise ValidationError(unique_errors)
with transaction.atomic():
for item in data:
serializer = self.get_serializer(data=item)
if serializer.is_valid():
self.perform_create(serializer)
created_items.append(serializer.data)
errors.append([])
else:
errors.append(serializer.errors)
has_errors = True
if has_errors:
raise ValidationError(errors)
return Response(created_items, status=201)

View File

@@ -1,12 +1,16 @@
"""InvenTree API version information."""
# InvenTree API version
INVENTREE_API_VERSION = 409
INVENTREE_API_VERSION = 410
"""Increment this API version number whenever there is a significant change to the API that any clients need to know about."""
INVENTREE_API_TEXT = """
v410 -> 2025-06-12 : https://github.com/inventree/InvenTree/pull/9761
- Add supplier search and import API endpoints
- Add part parameter bulk create API endpoint
v409 -> 2025-10-17 : https://github.com/inventree/InvenTree/pull/10601
- Adds ability to filter StockList API by manufacturer part ID

View File

@@ -17,6 +17,7 @@ from rest_framework.response import Response
import part.filters
from data_exporter.mixins import DataExportViewMixin
from InvenTree.api import (
BulkCreateMixin,
BulkDeleteMixin,
BulkUpdateMixin,
ListCreateDestroyAPIView,
@@ -1416,7 +1417,11 @@ class PartParameterFilter(FilterSet):
class PartParameterList(
PartParameterAPIMixin, OutputOptionsMixin, DataExportViewMixin, ListCreateAPI
BulkCreateMixin,
PartParameterAPIMixin,
OutputOptionsMixin,
DataExportViewMixin,
ListCreateAPI,
):
"""API endpoint for accessing a list of PartParameter objects.
@@ -1444,6 +1449,8 @@ class PartParameterList(
'template__units',
]
unique_create_fields = ['part', 'template']
class PartParameterDetail(
PartParameterAPIMixin, OutputOptionsMixin, RetrieveUpdateDestroyAPI

View File

@@ -364,6 +364,36 @@ class PartParameterTest(InvenTreeAPITestCase):
self.assertEqual(len(response.data), 8)
def test_bulk_create_params(self):
"""Test that we can bulk create parameters via the API."""
url = reverse('api-part-parameter-list')
part4 = Part.objects.get(pk=4)
data = [
{'part': 4, 'template': 1, 'data': 70},
{'part': 4, 'template': 2, 'data': 80},
{'part': 4, 'template': 1, 'data': 80},
]
# test that having non unique part/template combinations fails
res = self.post(url, data, expected_code=400)
self.assertEqual(len(res.data), 3)
self.assertEqual(len(res.data[1]), 0)
for err in [res.data[0], res.data[2]]:
self.assertEqual(len(err), 2)
self.assertEqual(str(err['part'][0]), 'This field must be unique.')
self.assertEqual(str(err['template'][0]), 'This field must be unique.')
self.assertEqual(PartParameter.objects.filter(part=part4).count(), 0)
# Now, create a valid set of parameters
data = [
{'part': 4, 'template': 1, 'data': 70},
{'part': 4, 'template': 2, 'data': 80},
]
res = self.post(url, data, expected_code=201)
self.assertEqual(len(res.data), 2)
self.assertEqual(PartParameter.objects.filter(part=part4).count(), 2)
def test_param_detail(self):
"""Tests for the PartParameter detail endpoint."""
url = reverse('api-part-parameter-detail', kwargs={'pk': 5})

View File

@@ -31,6 +31,7 @@ from InvenTree.mixins import (
from plugin.base.action.api import ActionPluginView
from plugin.base.barcodes.api import barcode_api_urls
from plugin.base.locate.api import LocatePluginView
from plugin.base.supplier.api import supplier_api_urls
from plugin.base.ui.api import ui_plugins_api_urls
from plugin.models import PluginConfig, PluginSetting, PluginUserSetting
from plugin.plugin import InvenTreePlugin
@@ -601,4 +602,5 @@ plugin_api_urls = [
path('', PluginList.as_view(), name='api-plugin-list'),
]),
),
path('supplier/', include(supplier_api_urls)),
]

View File

@@ -0,0 +1,246 @@
"""API views for supplier plugins in InvenTree."""
from typing import TYPE_CHECKING
from django.db import transaction
from django.urls import path
from drf_spectacular.utils import OpenApiParameter, extend_schema
from rest_framework import status
from rest_framework.exceptions import NotFound
from rest_framework.response import Response
from rest_framework.views import APIView
from InvenTree import permissions
from part.models import PartCategoryParameterTemplate
from plugin import registry
from plugin.plugin import PluginMixinEnum
from .serializers import (
ImportRequestSerializer,
ImportResultSerializer,
SearchResultSerializer,
SupplierListSerializer,
)
if TYPE_CHECKING:
from plugin.base.supplier.mixins import SupplierMixin
else: # pragma: no cover
class SupplierMixin:
"""Dummy class for type checking."""
def get_supplier_plugin(plugin_slug: str, supplier_slug: str) -> SupplierMixin:
"""Return the supplier plugin for the given plugin and supplier slugs."""
supplier_plugin = None
for plugin in registry.with_mixin(PluginMixinEnum.SUPPLIER):
if plugin.slug == plugin_slug:
supplier_plugin = plugin
break
if not supplier_plugin:
raise NotFound(detail=f"Plugin '{plugin_slug}' not found")
if not any(s.slug == supplier_slug for s in supplier_plugin.get_suppliers()):
raise NotFound(
detail=f"Supplier '{supplier_slug}' not found for plugin '{plugin_slug}'"
)
return supplier_plugin
class ListSupplier(APIView):
"""List all available supplier plugins.
- GET: List supplier plugins
"""
role_required = 'part.add'
permission_classes = [
permissions.IsAuthenticatedOrReadScope,
permissions.RolePermission,
]
serializer_class = SupplierListSerializer
@extend_schema(responses={200: SupplierListSerializer(many=True)})
def get(self, request):
"""List all available supplier plugins."""
suppliers = []
for plugin in registry.with_mixin(PluginMixinEnum.SUPPLIER):
suppliers.extend([
{
'plugin_slug': plugin.slug,
'supplier_slug': supplier.slug,
'supplier_name': supplier.name,
}
for supplier in plugin.get_suppliers()
])
return Response(suppliers)
class SearchPart(APIView):
"""Search parts by supplier.
- GET: Start part search
"""
role_required = 'part.add'
permission_classes = [
permissions.IsAuthenticatedOrReadScope,
permissions.RolePermission,
]
serializer_class = SearchResultSerializer
@extend_schema(
parameters=[
OpenApiParameter(name='plugin', description='Plugin slug', required=True),
OpenApiParameter(
name='supplier', description='Supplier slug', required=True
),
OpenApiParameter(name='term', description='Search term', required=True),
],
responses={200: SearchResultSerializer(many=True)},
)
def get(self, request):
"""Search parts by supplier."""
plugin_slug = request.query_params.get('plugin', '')
supplier_slug = request.query_params.get('supplier', '')
term = request.query_params.get('term', '')
supplier_plugin = get_supplier_plugin(plugin_slug, supplier_slug)
try:
results = supplier_plugin.get_search_results(supplier_slug, term)
except Exception as e:
return Response(
{'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
response = SearchResultSerializer(results, many=True).data
return Response(response)
class ImportPart(APIView):
"""Import a part by supplier.
- POST: Attempt to import part by sku
"""
role_required = 'part.add'
permission_classes = [
permissions.IsAuthenticatedOrReadScope,
permissions.RolePermission,
]
serializer_class = ImportResultSerializer
@extend_schema(
request=ImportRequestSerializer, responses={200: ImportResultSerializer}
)
def post(self, request):
"""Import a part by supplier."""
serializer = ImportRequestSerializer(data=request.data)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
# Extract validated data
plugin_slug = serializer.validated_data.get('plugin', '')
supplier_slug = serializer.validated_data.get('supplier', '')
part_import_id = serializer.validated_data.get('part_import_id', '')
category = serializer.validated_data.get('category_id', None)
part = serializer.validated_data.get('part_id', None)
supplier_plugin = get_supplier_plugin(plugin_slug, supplier_slug)
# Validate part/category
if not part and not category:
return Response(
{
'detail': "'category_id' is not provided, but required if no part_id is provided"
},
status=status.HTTP_400_BAD_REQUEST,
)
from plugin.base.supplier.mixins import supplier
# Import part data
try:
import_data = supplier_plugin.get_import_data(supplier_slug, part_import_id)
with transaction.atomic():
# create part if it does not exist
if not part:
part = supplier_plugin.import_part(
import_data, category=category, creation_user=request.user
)
# create manufacturer part
manufacturer_part = supplier_plugin.import_manufacturer_part(
import_data, part=part
)
# create supplier part
supplier_part = supplier_plugin.import_supplier_part(
import_data, part=part, manufacturer_part=manufacturer_part
)
# set default supplier if not set
if not part.default_supplier:
part.default_supplier = supplier_part
part.save()
# get pricing
pricing = supplier_plugin.get_pricing_data(import_data)
# get parameters
parameters = supplier_plugin.get_parameters(import_data)
except supplier.PartNotFoundError:
return Response(
{'detail': f"Part with id: '{part_import_id}' not found"},
status=status.HTTP_404_NOT_FOUND,
)
except Exception as e:
return Response(
{'detail': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
# add default parameters for category
if category:
categories = category.get_ancestors(include_self=True)
category_parameters = PartCategoryParameterTemplate.objects.filter(
category__in=categories
)
for c in category_parameters:
for p in parameters:
if p.parameter_template == c.parameter_template:
p.on_category = True
p.value = p.value if p.value is not None else c.default_value
break
else:
parameters.append(
supplier.ImportParameter(
name=c.parameter_template.name,
value=c.default_value,
on_category=True,
parameter_template=c.parameter_template,
)
)
parameters.sort(key=lambda x: x.on_category, reverse=True)
response = ImportResultSerializer({
'part_id': part.pk,
'part_detail': part,
'supplier_part_id': supplier_part.pk,
'manufacturer_part_id': manufacturer_part.pk,
'pricing': pricing,
'parameters': parameters,
}).data
return Response(response)
supplier_api_urls = [
path('list/', ListSupplier.as_view(), name='api-supplier-list'),
path('search/', SearchPart.as_view(), name='api-supplier-search'),
path('import/', ImportPart.as_view(), name='api-supplier-import'),
]

View File

@@ -0,0 +1,88 @@
"""Dataclasses for supplier plugins."""
from dataclasses import dataclass
from typing import Optional
import part.models as part_models
@dataclass
class Supplier:
"""Data class to represent a supplier.
Note that one plugin can connect to multiple suppliers this way with e.g. different credentials.
Attributes:
slug (str): A unique identifier for the supplier.
name (str): The human-readable name of the supplier.
"""
slug: str
name: str
@dataclass
class SearchResult:
"""Data class to represent a search result from a supplier.
Attributes:
sku (str): The stock keeping unit identifier for the part.
name (str): The name of the part.
exact (bool): Indicates if the search result is an exact match.
description (Optional[str]): A brief description of the part.
price (Optional[str]): The price of the part as a string.
link (Optional[str]): A URL link to the part on the supplier's website.
image_url (Optional[str]): A URL to an image of the part.
id (Optional[str]): An optional identifier for the part (part_id), defaults to sku if not provided
existing_part (Optional[part_models.Part]): An existing part in the system that matches this search result, if any.
"""
sku: str
name: str
exact: bool
description: Optional[str] = None
price: Optional[str] = None
link: Optional[str] = None
image_url: Optional[str] = None
id: Optional[str] = None
existing_part: Optional[part_models.Part] = None
def __post_init__(self):
"""Post-initialization to set the ID if not provided."""
if not self.id:
self.id = self.sku
@dataclass
class ImportParameter:
"""Data class to represent a parameter for a part during import.
Attributes:
name (str): The name of the parameter.
value (str): The value of the parameter.
on_category (Optional[bool]): Indicates if the parameter is associated with a category. This will be automatically set by InvenTree
parameter_template (Optional[PartParameterTemplate]): The associated parameter template, if any.
"""
name: str
value: str
on_category: Optional[bool] = False
parameter_template: Optional[part_models.PartParameterTemplate] = None
def __post_init__(self):
"""Post-initialization to fetch the parameter template if not provided."""
if not self.parameter_template:
try:
self.parameter_template = part_models.PartParameterTemplate.objects.get(
name__iexact=self.name
)
except part_models.PartParameterTemplate.DoesNotExist:
pass
class PartNotFoundError(Exception):
"""Exception raised when a part is not found during import."""
class PartImportError(Exception):
"""Exception raised when an error occurs during part import."""

View File

@@ -0,0 +1,177 @@
"""Plugin mixin class for Supplier Integration."""
import io
from typing import Any, Generic, Optional, TypeVar
import django.contrib.auth.models
from django.core.exceptions import ValidationError
from django.core.files.base import ContentFile
import company.models
import part.models as part_models
from InvenTree.helpers_model import download_image_from_url
from plugin import PluginMixinEnum
from plugin.base.supplier import helpers as supplier
from plugin.mixins import SettingsMixin
PartData = TypeVar('PartData')
class SupplierMixin(SettingsMixin, Generic[PartData]):
"""Mixin which provides integration to specific suppliers."""
class MixinMeta:
"""Meta options for this mixin."""
MIXIN_NAME = 'Supplier'
def __init__(self):
"""Register mixin."""
super().__init__()
self.add_mixin(PluginMixinEnum.SUPPLIER, True, __class__)
self.SETTINGS['SUPPLIER'] = {
'name': 'Supplier',
'description': 'The Supplier which this plugin integrates with.',
'model': 'company.company',
'model_filters': {'is_supplier': True},
'required': True,
}
@property
def supplier_company(self):
"""Return the supplier company object."""
pk = self.get_setting('SUPPLIER', cache=True)
if not pk:
raise supplier.PartImportError('Supplier setting is missing.')
return company.models.Company.objects.get(pk=pk)
# --- Methods to be overridden by plugins ---
def get_suppliers(self) -> list[supplier.Supplier]:
"""Return a list of available suppliers."""
raise NotImplementedError('This method needs to be overridden.')
def get_search_results(
self, supplier_slug: str, term: str
) -> list[supplier.SearchResult]:
"""Return a list of search results for the given search term."""
raise NotImplementedError('This method needs to be overridden.')
def get_import_data(self, supplier_slug: str, part_id: str) -> PartData:
"""Return the import data for the given part ID."""
raise NotImplementedError('This method needs to be overridden.')
def get_pricing_data(self, data: PartData) -> dict[int, tuple[float, str]]:
"""Return a dictionary of pricing data for the given part data."""
raise NotImplementedError('This method needs to be overridden.')
def get_parameters(self, data: PartData) -> list[supplier.ImportParameter]:
"""Return a list of parameters for the given part data."""
raise NotImplementedError('This method needs to be overridden.')
def import_part(
self,
data: PartData,
*,
category: Optional[part_models.PartCategory],
creation_user: Optional[django.contrib.auth.models.User],
) -> part_models.Part:
"""Import a part using the provided data.
This may include:
- Creating a new part
- Add an image to the part
- if this part has several variants, (create) a template part and assign it to the part
- create related parts
- add attachments to the part
"""
raise NotImplementedError('This method needs to be overridden.')
def import_manufacturer_part(
self, data: PartData, *, part: part_models.Part
) -> company.models.ManufacturerPart:
"""Import a manufacturer part using the provided data.
This may include:
- Creating a new manufacturer
- Creating a new manufacturer part
- Assigning the part to the manufacturer part
- Setting the default supplier for the part
- Adding parameters to the manufacturer part
- Adding attachments to the manufacturer part
"""
raise NotImplementedError('This method needs to be overridden.')
def import_supplier_part(
self,
data: PartData,
*,
part: part_models.Part,
manufacturer_part: company.models.ManufacturerPart,
) -> part_models.SupplierPart:
"""Import a supplier part using the provided data.
This may include:
- Creating a new supplier part
- Creating supplier price breaks
"""
raise NotImplementedError('This method needs to be overridden.')
# --- Helper methods for importing parts ---
def download_image(self, img_url: str):
"""Download an image from the given URL and return it as a ContentFile."""
img_r = download_image_from_url(img_url)
fmt = img_r.format or 'PNG'
buffer = io.BytesIO()
img_r.save(buffer, format=fmt)
return ContentFile(buffer.getvalue()), fmt
def get_template_part(
self, other_variants: list[part_models.Part], template_kwargs: dict[str, Any]
) -> part_models.Part:
"""Helper function to handle variant parts.
This helper function identifies all roots for the provided 'other_variants' list
- for no root => root part will be created using the 'template_kwargs'
- for one root
- root is a template => return it
- root is no template, create a new template like if there is no root
and assign it to only root that was found and return it
- for multiple roots => error raised
"""
root_set = {v.get_root() for v in other_variants}
# check how much roots for the variant parts exist to identify the parent_part
parent_part = None # part that should be used as parent_part
root_part = None # part that was discovered as root part in root_set
if len(root_set) == 1:
root_part = next(iter(root_set))
if root_part.is_template:
parent_part = root_part
if len(root_set) == 0 or (root_part and not root_part.is_template):
parent_part = part_models.Part.objects.create(**template_kwargs)
if not parent_part:
raise supplier.PartImportError(
f'A few variant parts from the supplier are already imported, but have different InvenTree variant root parts, try to merge them to the same root variant template part (parts: {", ".join(str(p.pk) for p in other_variants)}).'
)
# assign parent_part to root_part if root_part has no variant of already
if root_part and not root_part.is_template and not root_part.variant_of:
root_part.variant_of = parent_part # type: ignore
root_part.save()
return parent_part
def create_related_parts(
self, part: part_models.Part, related_parts: list[part_models.Part]
):
"""Create relationships between the given part and related parts."""
for p in related_parts:
try:
part_models.PartRelated.objects.create(part_1=part, part_2=p)
except ValidationError:
pass # pass, duplicate relationship detected

View File

@@ -0,0 +1,115 @@
"""Serializer definitions for the supplier plugin base."""
from typing import Any, Optional
from rest_framework import serializers
import part.models as part_models
from part.serializers import PartSerializer
class SupplierListSerializer(serializers.Serializer):
"""Serializer for a supplier plugin."""
plugin_slug = serializers.CharField()
supplier_slug = serializers.CharField()
supplier_name = serializers.CharField()
class SearchResultSerializer(serializers.Serializer):
"""Serializer for a search result."""
class Meta:
"""Meta options for the SearchResultSerializer."""
fields = [
'id',
'sku',
'name',
'exact',
'description',
'price',
'link',
'image_url',
'existing_part_id',
]
read_only_fields = fields
id = serializers.CharField()
sku = serializers.CharField()
name = serializers.CharField()
exact = serializers.BooleanField()
description = serializers.CharField()
price = serializers.CharField()
link = serializers.CharField()
image_url = serializers.CharField()
existing_part_id = serializers.SerializerMethodField()
def get_existing_part_id(self, value) -> Optional[int]:
"""Return the ID of the existing part if available."""
return getattr(value.existing_part, 'pk', None)
class ImportParameterSerializer(serializers.Serializer):
"""Serializer for a ImportParameter."""
class Meta:
"""Meta options for the ImportParameterSerializer."""
fields = ['name', 'value', 'parameter_template', 'on_category']
name = serializers.CharField()
value = serializers.CharField()
parameter_template = serializers.SerializerMethodField()
on_category = serializers.BooleanField()
def get_parameter_template(self, value) -> Optional[int]:
"""Return the ID of the parameter template if available."""
return getattr(value.parameter_template, 'pk', None)
class ImportRequestSerializer(serializers.Serializer):
"""Serializer for the import request."""
plugin = serializers.CharField(required=True)
supplier = serializers.CharField(required=True)
part_import_id = serializers.CharField(required=True)
category_id = serializers.PrimaryKeyRelatedField(
queryset=part_models.PartCategory.objects.all(),
many=False,
required=False,
allow_null=True,
)
part_id = serializers.PrimaryKeyRelatedField(
queryset=part_models.Part.objects.all(),
many=False,
required=False,
allow_null=True,
)
class ImportResultSerializer(serializers.Serializer):
"""Serializer for the import result."""
class Meta:
"""Meta options for the ImportResultSerializer."""
fields = [
'part_id',
'part_detail',
'manufacturer_part_id',
'supplier_part_id',
'pricing',
'parameters',
]
part_id = serializers.IntegerField()
part_detail = PartSerializer()
manufacturer_part_id = serializers.IntegerField()
supplier_part_id = serializers.IntegerField()
pricing = serializers.SerializerMethodField()
parameters = ImportParameterSerializer(many=True)
def get_pricing(self, value: Any) -> list[tuple[float, str]]:
"""Return the pricing data as a dictionary."""
return value['pricing']

View File

@@ -20,6 +20,8 @@ from plugin.base.integration.ValidationMixin import ValidationMixin
from plugin.base.label.mixins import LabelPrintingMixin
from plugin.base.locate.mixins import LocateMixin
from plugin.base.mail.mixins import MailMixin
from plugin.base.supplier import helpers as supplier
from plugin.base.supplier.mixins import SupplierMixin
from plugin.base.ui.mixins import UserInterfaceMixin
__all__ = [
@@ -41,8 +43,10 @@ __all__ = [
'ScheduleMixin',
'SettingsMixin',
'SupplierBarcodeMixin',
'SupplierMixin',
'TransitionMixin',
'UrlsMixin',
'UserInterfaceMixin',
'ValidationMixin',
'supplier',
]

View File

@@ -75,6 +75,7 @@ class PluginMixinEnum(StringEnum):
SCHEDULE = 'schedule'
SETTINGS = 'settings'
SETTINGS_CONTENT = 'settingscontent'
SUPPLIER = 'supplier'
STATE_TRANSITION = 'statetransition'
SUPPLIER_BARCODE = 'supplier-barcode'
URLS = 'urls'

View File

@@ -0,0 +1,182 @@
"""Sample supplier plugin."""
from company.models import Company, ManufacturerPart, SupplierPart, SupplierPriceBreak
from part.models import Part
from plugin.mixins import SupplierMixin, supplier
from plugin.plugin import InvenTreePlugin
class SampleSupplierPlugin(SupplierMixin, InvenTreePlugin):
"""Example plugin to integrate with a dummy supplier."""
NAME = 'SampleSupplierPlugin'
SLUG = 'samplesupplier'
TITLE = 'My sample supplier plugin'
VERSION = '0.0.1'
def __init__(self):
"""Initialize the sample supplier plugin."""
super().__init__()
self.sample_data = []
for material in ['Steel', 'Aluminium', 'Brass']:
for size in ['M1', 'M2', 'M3', 'M4', 'M5']:
for length in range(5, 30, 5):
self.sample_data.append({
'material': material,
'thread': size,
'length': length,
'sku': f'BOLT-{material}-{size}-{length}',
'name': f'Bolt {size}x{length}mm {material}',
'description': f'This is a sample part description demonstration purposes for the {size}x{length} {material} bolt.',
'price': {
1: [1.0, 'EUR'],
10: [0.9, 'EUR'],
100: [0.8, 'EUR'],
5000: [0.5, 'EUR'],
},
'link': f'https://example.com/sample-part-{size}-{length}-{material}',
'image_url': r'https://github.com/inventree/demo-dataset/blob/main/media/part_images/flat-head.png?raw=true',
'brand': 'Bolt Manufacturer',
})
def get_suppliers(self) -> list[supplier.Supplier]:
"""Return a list of available suppliers."""
return [supplier.Supplier(slug='sample-fasteners', name='Sample Fasteners')]
def get_search_results(
self, supplier_slug: str, term: str
) -> list[supplier.SearchResult]:
"""Return a list of search results based on the search term."""
return [
supplier.SearchResult(
sku=p['sku'],
name=p['name'],
description=p['description'],
exact=p['sku'] == term,
price=f'{p["price"][1][0]:.2f}',
link=p['link'],
image_url=p['image_url'],
existing_part=getattr(
SupplierPart.objects.filter(SKU=p['sku']).first(), 'part', None
),
)
for p in self.sample_data
if all(t.lower() in p['name'].lower() for t in term.split())
]
def get_import_data(self, supplier_slug: str, part_id: str):
"""Return import data for a specific part ID."""
for p in self.sample_data:
if p['sku'] == part_id:
p = p.copy()
p['variants'] = [
x['sku']
for x in self.sample_data
if x['thread'] == p['thread'] and x['length'] == p['length']
]
return p
raise supplier.PartNotFoundError()
def get_pricing_data(self, data) -> dict[int, tuple[float, str]]:
"""Return pricing data for the given part data."""
return data['price']
def get_parameters(self, data) -> list[supplier.ImportParameter]:
"""Return a list of parameters for the given part data."""
return [
supplier.ImportParameter(name='Thread', value=data['thread'][1:]),
supplier.ImportParameter(name='Length', value=f'{data["length"]}mm'),
supplier.ImportParameter(name='Material', value=data['material']),
supplier.ImportParameter(name='Head', value='Flat Head'),
]
def import_part(self, data, **kwargs) -> Part:
"""Import a part based on the provided data."""
part, created = Part.objects.get_or_create(
name__iexact=data['sku'],
purchaseable=True,
defaults={
'name': data['sku'],
'description': data['description'],
'link': data['link'],
**kwargs,
},
)
# If the part was created, set additional fields
if created:
if data['image_url']:
file, fmt = self.download_image(data['image_url'])
filename = f'part_{part.pk}_image.{fmt.lower()}'
part.image.save(filename, file)
# link other variants if they exist in our inventree database
if len(data['variants']):
# search for other parts that may already have a template part associated
variant_parts = [
x.part
for x in SupplierPart.objects.filter(SKU__in=data['variants'])
]
parent_part = self.get_template_part(
variant_parts,
{
# we cannot extract a real name for the root part, but we can try to guess a unique name
'name': data['sku'].replace(data['material'] + '-', ''),
'description': data['name'].replace(' ' + data['material'], ''),
'link': data['link'],
'image': part.image.name,
'is_template': True,
**kwargs,
},
)
# after the template part was created, we need to refresh the part from the db because its tree id may have changed
# which results in an error if saved directly
part.refresh_from_db()
part.variant_of = parent_part # type: ignore
part.save()
return part
def import_manufacturer_part(self, data, **kwargs) -> ManufacturerPart:
"""Import a manufacturer part based on the provided data."""
mft, _ = Company.objects.get_or_create(
name__iexact=data['brand'],
defaults={
'is_manufacturer': True,
'is_supplier': False,
'name': data['brand'],
},
)
mft_part, created = ManufacturerPart.objects.get_or_create(
MPN=f'MAN-{data["sku"]}', manufacturer=mft, **kwargs
)
if created:
# Attachments, notes, parameters and more can be added here
pass
return mft_part
def import_supplier_part(self, data, **kwargs) -> SupplierPart:
"""Import a supplier part based on the provided data."""
spp, _ = SupplierPart.objects.get_or_create(
SKU=data['sku'],
supplier=self.supplier_company,
**kwargs,
defaults={'link': data['link']},
)
SupplierPriceBreak.objects.filter(part=spp).delete()
SupplierPriceBreak.objects.bulk_create([
SupplierPriceBreak(
part=spp, quantity=quantity, price=price, price_currency=currency
)
for quantity, (price, currency) in data['price'].items()
])
return spp

View File

@@ -0,0 +1,211 @@
"""Unit tests for locate_sample sample plugins."""
from django.urls import reverse
from company.models import ManufacturerPart, SupplierPart
from InvenTree.unit_test import InvenTreeAPITestCase
from part.models import (
Part,
PartCategory,
PartCategoryParameterTemplate,
PartParameterTemplate,
)
from plugin import registry
class SampleSupplierTest(InvenTreeAPITestCase):
"""Tests for SampleSupplierPlugin."""
fixtures = ['location', 'category', 'part', 'stock', 'company']
roles = ['part.add']
def test_list(self):
"""Check the list api."""
# Test APIs
url = reverse('api-supplier-list')
# No plugin
res = self.get(url, expected_code=200)
self.assertEqual(len(res.data), 0)
# Activate plugin
config = registry.get_plugin('samplesupplier', active=None).plugin_config()
config.active = True
config.save()
# One active plugin
res = self.get(url, expected_code=200)
self.assertEqual(len(res.data), 1)
self.assertEqual(res.data[0]['plugin_slug'], 'samplesupplier')
self.assertEqual(res.data[0]['supplier_slug'], 'sample-fasteners')
self.assertEqual(res.data[0]['supplier_name'], 'Sample Fasteners')
def test_search(self):
"""Check the search api."""
# Activate plugin
config = registry.get_plugin('samplesupplier', active=None).plugin_config()
config.active = True
config.save()
# Test APIs
url = reverse('api-supplier-search')
# No plugin
self.get(
url,
{'plugin': 'non-existent-plugin', 'supplier': 'sample-fasteners'},
expected_code=404,
)
# No supplier
self.get(
url,
{'plugin': 'samplesupplier', 'supplier': 'non-existent-supplier'},
expected_code=404,
)
# valid supplier
res = self.get(
url,
{'plugin': 'samplesupplier', 'supplier': 'sample-fasteners', 'term': 'M5'},
expected_code=200,
)
self.assertEqual(len(res.data), 15)
self.assertEqual(res.data[0]['sku'], 'BOLT-Steel-M5-5')
def test_import_part(self):
"""Test importing a part by supplier."""
# Activate plugin
plugin = registry.get_plugin('samplesupplier', active=None)
config = plugin.plugin_config()
config.active = True
config.save()
# Test APIs
url = reverse('api-supplier-import')
# No plugin
self.post(
url,
{
'plugin': 'non-existent-plugin',
'supplier': 'sample-fasteners',
'part_import_id': 'BOLT-Steel-M5-5',
},
expected_code=404,
)
# No supplier
self.post(
url,
{
'plugin': 'samplesupplier',
'supplier': 'non-existent-supplier',
'part_import_id': 'BOLT-Steel-M5-5',
},
expected_code=404,
)
# valid supplier, no part or category provided
self.post(
url,
{
'plugin': 'samplesupplier',
'supplier': 'sample-fasteners',
'part_import_id': 'BOLT-Steel-M5-5',
},
expected_code=400,
)
# valid supplier, but no supplier company set
self.post(
url,
{
'plugin': 'samplesupplier',
'supplier': 'sample-fasteners',
'part_import_id': 'BOLT-Steel-M5-5',
'category_id': 1,
},
expected_code=500,
)
# Set the supplier company now
plugin.set_setting('SUPPLIER', 1)
# valid supplier, valid part import
category = PartCategory.objects.get(pk=1)
p_len = PartParameterTemplate(name='Length', units='mm')
p_test = PartParameterTemplate(name='Test Parameter')
p_len.save()
p_test.save()
PartCategoryParameterTemplate.objects.bulk_create([
PartCategoryParameterTemplate(category=category, parameter_template=p_len),
PartCategoryParameterTemplate(
category=category, parameter_template=p_test, default_value='Test Value'
),
])
res = self.post(
url,
{
'plugin': 'samplesupplier',
'supplier': 'sample-fasteners',
'part_import_id': 'BOLT-Steel-M5-5',
'category_id': 1,
},
expected_code=200,
)
part = Part.objects.get(name='BOLT-Steel-M5-5')
self.assertIsNotNone(part)
self.assertEqual(part.pk, res.data['part_id'])
self.assertIsNotNone(SupplierPart.objects.get(pk=res.data['supplier_part_id']))
self.assertIsNotNone(
ManufacturerPart.objects.get(pk=res.data['manufacturer_part_id'])
)
self.assertSetEqual(
{x['name'] for x in res.data['parameters']},
{'Thread', 'Length', 'Material', 'Head', 'Test Parameter'},
)
for p in res.data['parameters']:
if p['name'] == 'Length':
self.assertEqual(p['value'], '5mm')
self.assertEqual(p['parameter_template'], p_len.pk)
self.assertTrue(p['on_category'])
elif p['name'] == 'Test Parameter':
self.assertEqual(p['value'], 'Test Value')
self.assertEqual(p['parameter_template'], p_test.pk)
self.assertTrue(p['on_category'])
# valid supplier, import only manufacturer and supplier part
part2 = Part.objects.create(name='Test Part', purchaseable=True)
res = self.post(
url,
{
'plugin': 'samplesupplier',
'supplier': 'sample-fasteners',
'part_import_id': 'BOLT-Steel-M5-10',
'part_id': part2.pk,
},
expected_code=200,
)
self.assertEqual(part2.pk, res.data['part_id'])
sp = SupplierPart.objects.get(pk=res.data['supplier_part_id'])
mp = ManufacturerPart.objects.get(pk=res.data['manufacturer_part_id'])
self.assertIsNotNone(sp)
self.assertIsNotNone(mp)
self.assertEqual(sp.part.pk, part2.pk)
self.assertEqual(mp.part.pk, part2.pk)
# PartNotFoundError
self.post(
url,
{
'plugin': 'samplesupplier',
'supplier': 'sample-fasteners',
'part_import_id': 'non-existent-part',
'category_id': 1,
},
expected_code=404,
)