2
0
mirror of https://github.com/inventree/InvenTree.git synced 2026-06-06 00:44:25 +00:00

feat(backend): add typechecking with ty (#9664)

* Add ty for type checking

* fix various typing issues

* fix req

* more fixes

* and more types

* and more typing

* fix imports

* more fixes

* fix types and optional statements

* ensure patch only runs if it is installed

* add type check to qc

* more fixes

* install all reqs

* fix more types

* more fixes

* disable container stuff for now

* move typecheck to seperate job

* try to use putput for path

* use env instead

* fix typo

* add missing install

* remove unclear imports - not sure why this was done

* add kwarg names

* fix introduced issue in url call

* ignore import

* fix broken typing changes

* fix filter import

* reduce change set

* remove api-change

* fix dict

* ignore typing errors

* fix more type issues

* ignore errors

* style fix

* fix type

* bump ty

* fix more

* type fixes

* update ignores

* fix import

* fix defaults

* fix ignore

* fix some issues

* fix type
This commit is contained in:
Matthias Mair
2025-09-17 13:30:02 +02:00
committed by GitHub
parent f057247fc1
commit 21cb488eef
100 changed files with 524 additions and 267 deletions
+16 -6
View File
@@ -17,6 +17,7 @@ import os
import re import re
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Optional
import requests import requests
@@ -183,7 +184,8 @@ def check_version_number(version_string, allow_duplicate=False):
return highest_release return highest_release
if __name__ == '__main__': def main() -> bool:
"""Run the version check."""
parser = argparse.ArgumentParser(description='InvenTree Version Check') parser = argparse.ArgumentParser(description='InvenTree Version Check')
parser.add_argument( parser.add_argument(
'--show-version', '--show-version',
@@ -220,7 +222,7 @@ if __name__ == '__main__':
# Ensure that we are running in GH Actions # Ensure that we are running in GH Actions
if os.environ.get('GITHUB_ACTIONS', '') != 'true': if os.environ.get('GITHUB_ACTIONS', '') != 'true':
print('This script is intended to be run within a GitHub Action!') print('This script is intended to be run within a GitHub Action!')
sys.exit(1) return False
print('Running InvenTree version check...') print('Running InvenTree version check...')
@@ -261,11 +263,11 @@ if __name__ == '__main__':
) )
# Determine which docker tag we are going to use # Determine which docker tag we are going to use
docker_tags = None docker_tags: Optional[list[str]] = None
if GITHUB_REF_TYPE == 'tag': if GITHUB_REF_TYPE == 'tag':
# GITHUB_REF should be of the form /refs/heads/<tag> # GITHUB_REF should be of the form /refs/heads/<tag>
version_tag = GITHUB_REF.split('/')[-1] version_tag: str = GITHUB_REF.split('/')[-1]
print(f"Checking requirements for tagged release - '{version_tag}':") print(f"Checking requirements for tagged release - '{version_tag}':")
if version_tag != inventree_version: if version_tag != inventree_version:
@@ -287,11 +289,11 @@ if __name__ == '__main__':
print('GITHUB_REF_TYPE:', GITHUB_REF_TYPE) print('GITHUB_REF_TYPE:', GITHUB_REF_TYPE)
print('GITHUB_BASE_REF:', GITHUB_BASE_REF) print('GITHUB_BASE_REF:', GITHUB_BASE_REF)
print('GITHUB_REF:', GITHUB_REF) print('GITHUB_REF:', GITHUB_REF)
sys.exit(1) return False
if docker_tags is None: if docker_tags is None:
print('Docker tags could not be determined') print('Docker tags could not be determined')
sys.exit(1) return False
print(f"Version check passed for '{inventree_version}'!") print(f"Version check passed for '{inventree_version}'!")
print(f"Docker tags: '{docker_tags}'") print(f"Docker tags: '{docker_tags}'")
@@ -308,3 +310,11 @@ if __name__ == '__main__':
if GITHUB_REF_TYPE == 'tag' and highest_release: if GITHUB_REF_TYPE == 'tag' and highest_release:
env_file.write('stable_release=true\n') env_file.write('stable_release=true\n')
return True
if __name__ == '__main__':
rslt = main()
if rslt is not True:
print('Version check failed!')
sys.exit(1)
+21
View File
@@ -97,6 +97,27 @@ jobs:
pip install --require-hashes -r contrib/dev_reqs/requirements.txt pip install --require-hashes -r contrib/dev_reqs/requirements.txt
python3 .github/scripts/version_check.py python3 .github/scripts/version_check.py
typecheck:
name: Style [Typecheck]
runs-on: ubuntu-24.04
needs: [paths-filter, pre-commit]
if: needs.paths-filter.outputs.server == 'true' || needs.paths-filter.outputs.requirements == 'true' || needs.paths-filter.outputs.force == 'true'
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # pin@v4.2.2
with:
persist-credentials: false
- name: Environment Setup
id: setup
uses: ./.github/actions/setup
with:
apt-dependency: gettext poppler-utils
dev-install: true
update: true
- name: Check types
run: |
ty check --python ${Python_ROOT_DIR}/bin/python3
mkdocs: mkdocs:
name: Style [Documentation] name: Style [Documentation]
runs-on: ubuntu-24.04 runs-on: ubuntu-24.04
+1 -1
View File
@@ -4,7 +4,7 @@ import json
import os import os
import re import re
from datetime import datetime from datetime import datetime
from distutils.version import StrictVersion from distutils.version import StrictVersion # type: ignore[import]
from pathlib import Path from pathlib import Path
import requests import requests
+1 -1
View File
@@ -126,7 +126,7 @@ def check_link(url) -> bool:
return False return False
def get_build_environment() -> str: def get_build_environment() -> Optional[str]:
"""Returns the branch we are currently building on, based on the environment variables of the various CI platforms.""" """Returns the branch we are currently building on, based on the environment variables of the various CI platforms."""
# Check if we are in ReadTheDocs # Check if we are in ReadTheDocs
if os.environ.get('READTHEDOCS') == 'True': if os.environ.get('READTHEDOCS') == 'True':
+14
View File
@@ -101,6 +101,20 @@ python-version = "3.9.2"
no-strip-extras=true no-strip-extras=true
generate-hashes=true generate-hashes=true
[tool.ty.src]
root = "src/backend/InvenTree"
[tool.ty.rules]
unresolved-reference="ignore" # 21 # see https://github.com/astral-sh/ty/issues/220
unresolved-attribute="ignore" # 505 # need Plugin Mixin typing
call-non-callable="ignore" # 8 ##
invalid-return-type="ignore" # 22 ##
invalid-argument-type="ignore" # 49
possibly-unbound-attribute="ignore" # 25 # https://github.com/astral-sh/ty/issues/164
unknown-argument="ignore" # 3 # need to wait for betterdjango field stubs
invalid-assignment="ignore" # 17 # need to wait for betterdjango field stubs
[tool.coverage.run] [tool.coverage.run]
source = ["src/backend/InvenTree", "InvenTree"] source = ["src/backend/InvenTree", "InvenTree"]
dynamic_context = "test_function" dynamic_context = "test_function"
@@ -48,10 +48,13 @@ class AllUserRequire2FAMiddleware(MiddlewareMixin):
def is_allowed_page(self, request: HttpRequest) -> bool: def is_allowed_page(self, request: HttpRequest) -> bool:
"""Check if the current page can be accessed without mfa.""" """Check if the current page can be accessed without mfa."""
match = request.resolver_match
return ( return (
any(ref in self.app_names for ref in request.resolver_match.app_names) None
or request.resolver_match.url_name in self.allowed_pages if match is None
or request.resolver_match.route == 'favicon.ico' else any(ref in self.app_names for ref in match.app_names)
or match.url_name in self.allowed_pages
or match.route == 'favicon.ico'
) )
def enforce_2fa(self, request): def enforce_2fa(self, request):
+1
View File
@@ -18,6 +18,7 @@ from rest_framework.serializers import ValidationError
from rest_framework.views import APIView from rest_framework.views import APIView
import InvenTree.config import InvenTree.config
import InvenTree.permissions
import InvenTree.version import InvenTree.version
from common.settings import get_global_setting from common.settings import get_global_setting
from InvenTree import helpers from InvenTree import helpers
+3
View File
@@ -131,6 +131,9 @@ class InvenTreeConfig(AppConfig):
tasks = InvenTree.tasks.tasks.task_list tasks = InvenTree.tasks.tasks.task_list
for task in tasks: for task in tasks:
if not task:
continue # pragma: no cover
ref_name = f'{task.func.__module__}.{task.func.__name__}' ref_name = f'{task.func.__module__}.{task.func.__name__}'
if ref_name in existing_tasks: if ref_name in existing_tasks:
+3 -2
View File
@@ -2,6 +2,7 @@
import socket import socket
import threading import threading
from typing import Any
import structlog import structlog
@@ -140,7 +141,7 @@ def delete_session_cache() -> None:
del thread_data.request_cache del thread_data.request_cache
def get_session_cache(key: str) -> any: def get_session_cache(key: str) -> Any:
"""Return a cached value from the session cache.""" """Return a cached value from the session cache."""
# Only return a cached value if the request object is available too # Only return a cached value if the request object is available too
if not hasattr(thread_data, 'request'): if not hasattr(thread_data, 'request'):
@@ -152,7 +153,7 @@ def get_session_cache(key: str) -> any:
return val return val
def set_session_cache(key: str, value: any) -> None: def set_session_cache(key: str, value: Any) -> None:
"""Set a cached value in the session cache.""" """Set a cached value in the session cache."""
# Only set a cached value if the request object is available too # Only set a cached value if the request object is available too
if not hasattr(thread_data, 'request'): if not hasattr(thread_data, 'request'):
+1 -1
View File
@@ -171,7 +171,7 @@ def get_config_file(create=True) -> Path:
return cfg_filename return cfg_filename
def load_config_data(set_cache: bool = False) -> map: def load_config_data(set_cache: bool = False) -> Union[map, None]:
"""Load configuration data from the config file. """Load configuration data from the config file.
Arguments: Arguments:
@@ -191,7 +191,7 @@ def convert_physical_value(value: str, unit: Optional[str] = None, strip_units=T
attempts.append(f'{value}{unit}') attempts.append(f'{value}{unit}')
attempts.append(f'{eng}{unit}') attempts.append(f'{eng}{unit}')
value = None value: Optional[str] = None
# Run through the available "attempts", take the first successful result # Run through the available "attempts", take the first successful result
for attempt in attempts: for attempt in attempts:
@@ -66,7 +66,8 @@ def log_error(
data = error_data data = error_data
else: else:
try: try:
data = '\n'.join(traceback.format_exception(kind, info, data)) formatted_exception = traceback.format_exception(kind, info, data) # type: ignore[no-matching-overload]
data = '\n'.join(formatted_exception)
except AttributeError: except AttributeError:
data = 'No traceback information available' data = 'No traceback information available'
@@ -150,8 +151,10 @@ def exception_handler(exc, context):
if response is not None: if response is not None:
# Convert errors returned under the label '__all__' to 'non_field_errors' # Convert errors returned under the label '__all__' to 'non_field_errors'
if '__all__' in response.data: data = response.data
response.data['non_field_errors'] = response.data['__all__']
del response.data['__all__'] if data and '__all__' in data:
data['non_field_errors'] = data['__all__']
del data['__all__']
return response return response
+7 -6
View File
@@ -6,7 +6,8 @@ from django.conf import settings
from django.utils import timezone from django.utils import timezone
from django.utils.timezone import make_aware from django.utils.timezone import make_aware
from django_filters import rest_framework as rest_filters import django_filters.rest_framework.backends as drf_backend
import django_filters.rest_framework.filters as rest_filters
from rest_framework import filters from rest_framework import filters
import InvenTree.helpers import InvenTree.helpers
@@ -20,7 +21,7 @@ class InvenTreeDateFilter(rest_filters.DateFilter):
if settings.USE_TZ and value is not None: if settings.USE_TZ and value is not None:
tz = timezone.get_current_timezone() tz = timezone.get_current_timezone()
value = datetime(value.year, value.month, value.day) value = datetime(value.year, value.month, value.day)
value = make_aware(value, tz, True) value = make_aware(value, timezone=tz, is_dst=True)
return super().filter(qs, value) return super().filter(qs, value)
@@ -192,17 +193,17 @@ class NumberOrNullFilter(rest_filters.NumberFilter):
SEARCH_ORDER_FILTER = [ SEARCH_ORDER_FILTER = [
rest_filters.DjangoFilterBackend, drf_backend.DjangoFilterBackend,
InvenTreeSearchFilter, InvenTreeSearchFilter,
filters.OrderingFilter, filters.OrderingFilter,
] ]
SEARCH_ORDER_FILTER_ALIAS = [ SEARCH_ORDER_FILTER_ALIAS = [
rest_filters.DjangoFilterBackend, drf_backend.DjangoFilterBackend,
InvenTreeSearchFilter, InvenTreeSearchFilter,
InvenTreeOrderingFilter, InvenTreeOrderingFilter,
] ]
ORDER_FILTER = [rest_filters.DjangoFilterBackend, filters.OrderingFilter] ORDER_FILTER = [drf_backend.DjangoFilterBackend, filters.OrderingFilter]
ORDER_FILTER_ALIAS = [rest_filters.DjangoFilterBackend, InvenTreeOrderingFilter] ORDER_FILTER_ALIAS = [drf_backend.DjangoFilterBackend, InvenTreeOrderingFilter]
+2 -2
View File
@@ -107,7 +107,7 @@ def construct_format_regex(fmt_string: str) -> str:
# Add a named capture group for the format entry # Add a named capture group for the format entry
if name: if name:
# Check if integer values are required # Check if integer values are required
c = '\\d' if _fmt.endswith('d') else '.' c = '\\d' if _fmt and _fmt.endswith('d') else '.'
# Specify width # Specify width
# TODO: Introspect required width # TODO: Introspect required width
@@ -124,7 +124,7 @@ def construct_format_regex(fmt_string: str) -> str:
return pattern return pattern
def validate_string(value: str, fmt_string: str) -> str: def validate_string(value: str, fmt_string: str) -> bool:
"""Validate that the provided string matches the specified format. """Validate that the provided string matches the specified format.
Args: Args:
+11 -11
View File
@@ -9,7 +9,7 @@ import os
import os.path import os.path
import re import re
from decimal import Decimal, InvalidOperation from decimal import Decimal, InvalidOperation
from typing import Optional, TypeVar from typing import Optional, TypeVar, Union
from wsgiref.util import FileWrapper from wsgiref.util import FileWrapper
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
@@ -22,6 +22,8 @@ from django.utils import timezone
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
import bleach import bleach
import bleach.css_sanitizer
import bleach.sanitizer
import structlog import structlog
from bleach import clean from bleach import clean
from djmoney.money import Money from djmoney.money import Money
@@ -124,7 +126,7 @@ def extract_int(
return ref_int return ref_int
def generateTestKey(test_name: str) -> str: def generateTestKey(test_name: Union[str, None]) -> str:
"""Generate a test 'key' for a given test name. This must not have illegal chars as it will be used for dict lookup in a template. """Generate a test 'key' for a given test name. This must not have illegal chars as it will be used for dict lookup in a template.
Tests must be named such that they will have unique keys. Tests must be named such that they will have unique keys.
@@ -366,9 +368,7 @@ def increment(value):
except ValueError: except ValueError:
pass pass
number = number.zfill(width) return prefix + str(number).zfill(width)
return prefix + number
def decimal2string(d): def decimal2string(d):
@@ -966,7 +966,7 @@ def current_time(local=True):
""" """
if settings.USE_TZ: if settings.USE_TZ:
now = timezone.now() now = timezone.now()
now = to_local_time(now, target_tz=server_timezone() if local else 'UTC') now = to_local_time(now, target_tz_str=server_timezone() if local else 'UTC')
return now return now
else: else:
return datetime.datetime.now() return datetime.datetime.now()
@@ -985,12 +985,12 @@ def server_timezone() -> str:
return settings.TIME_ZONE return settings.TIME_ZONE
def to_local_time(time, target_tz: Optional[str] = None): def to_local_time(time, target_tz_str: Optional[str] = None):
"""Convert the provided time object to the local timezone. """Convert the provided time object to the local timezone.
Arguments: Arguments:
time: The time / date to convert time: The time / date to convert
target_tz: The desired timezone (string) - defaults to server time target_tz_str: The desired timezone (string) - defaults to server time
Returns: Returns:
A timezone aware datetime object, with the desired timezone A timezone aware datetime object, with the desired timezone
@@ -1014,11 +1014,11 @@ def to_local_time(time, target_tz: Optional[str] = None):
# Default to UTC if not provided # Default to UTC if not provided
source_tz = ZoneInfo('UTC') source_tz = ZoneInfo('UTC')
if not target_tz: if not target_tz_str:
target_tz = server_timezone() target_tz_str = server_timezone()
try: try:
target_tz = ZoneInfo(str(target_tz)) target_tz = ZoneInfo(str(target_tz_str))
except ZoneInfoNotFoundError: except ZoneInfoNotFoundError:
target_tz = ZoneInfo('UTC') target_tz = ZoneInfo('UTC')
@@ -114,7 +114,7 @@ def send_email(
return True, None return True, None
def get_email_for_user(user) -> str: def get_email_for_user(user) -> Optional[str]:
"""Find an email address for the specified user.""" """Find an email address for the specified user."""
# First check if the user has an associated email address # First check if the user has an associated email address
if user.email: if user.email:
@@ -11,6 +11,7 @@ from django.db.utils import OperationalError, ProgrammingError
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
import requests import requests
import requests.exceptions
import structlog import structlog
from djmoney.contrib.exchange.models import convert_money from djmoney.contrib.exchange.models import convert_money
from djmoney.money import Money from djmoney.money import Money
@@ -328,8 +329,9 @@ def notify_users(
'template': {'subject': content.name.format(**content_context)}, 'template': {'subject': content.name.format(**content_context)},
} }
if content.template: tmp = content.template
context['template']['html'] = content.template.format(**content_context) if tmp:
context['template']['html'] = tmp.format(**content_context)
# Create notification # Create notification
trigger_notification( trigger_notification(
@@ -1,7 +1,7 @@
"""Extended schema generator.""" """Extended schema generator."""
from pathlib import Path from pathlib import Path
from typing import TypeVar from typing import TypeVar, Union
from django.conf import settings from django.conf import settings
@@ -26,7 +26,7 @@ def prep_name(ref):
return f'{dja_ref_prefix}.{ref}' return f'{dja_ref_prefix}.{ref}'
def sub_component_name(name: T) -> T: def sub_component_name(name: T) -> Union[T, str]:
"""Clean up component references.""" """Clean up component references."""
if not isinstance(name, str): if not isinstance(name, str):
return name return name
@@ -2,9 +2,10 @@
import time import time
from django.core.exceptions import ImproperlyConfigured
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from django.db import connection from django.db import connection
from django.db.utils import ImproperlyConfigured, OperationalError from django.db.utils import OperationalError
class Command(BaseCommand): class Command(BaseCommand):
+1 -1
View File
@@ -1172,7 +1172,7 @@ class InvenTreeBarcodeMixin(models.Model):
raise ValueError("Provide either 'barcode_hash' or 'barcode_data'") raise ValueError("Provide either 'barcode_hash' or 'barcode_data'")
# If barcode_hash is not provided, create from supplier barcode_data # If barcode_hash is not provided, create from supplier barcode_data
if barcode_hash is None: if barcode_hash is None and barcode_data is not None:
barcode_hash = InvenTree.helpers.hash_barcode(barcode_data) barcode_hash = InvenTree.helpers.hash_barcode(barcode_data)
# Check for existing item # Check for existing item
+11 -5
View File
@@ -256,7 +256,9 @@ class RolePermissionOrReadOnly(RolePermission):
def get_required_alternate_scopes(self, request, view): def get_required_alternate_scopes(self, request, view):
"""Return the required scopes for the current request.""" """Return the required scopes for the current request."""
scopes = map_scope( scopes = map_scope(
only_read=True, read_name=DEFAULT_STAFF, map_read=permissions.SAFE_METHODS only_read=True,
read_name=DEFAULT_STAFF,
map_read=list(permissions.SAFE_METHODS),
) )
return scopes return scopes
@@ -294,7 +296,7 @@ class IsSuperuserOrReadOnlyOrScope(OASTokenMixin, permissions.IsAdminUser):
return map_scope( return map_scope(
only_read=True, only_read=True,
read_name=DEFAULT_SUPERUSER, read_name=DEFAULT_SUPERUSER,
map_read=permissions.SAFE_METHODS, map_read=list(permissions.SAFE_METHODS),
) )
@@ -319,7 +321,9 @@ class IsStaffOrReadOnlyScope(OASTokenMixin, permissions.IsAuthenticated):
def get_required_alternate_scopes(self, request, view): def get_required_alternate_scopes(self, request, view):
"""Return the required scopes for the current request.""" """Return the required scopes for the current request."""
return map_scope( return map_scope(
only_read=True, read_name=DEFAULT_STAFF, map_read=permissions.SAFE_METHODS only_read=True,
read_name=DEFAULT_STAFF,
map_read=list(permissions.SAFE_METHODS),
) )
@@ -349,7 +353,7 @@ def auth_exempt(view_func):
def wrapped_view(*args, **kwargs): def wrapped_view(*args, **kwargs):
return view_func(*args, **kwargs) return view_func(*args, **kwargs)
wrapped_view.auth_exempt = True wrapped_view.auth_exempt = True # type:ignore[unresolved-attribute]
return wraps(view_func)(wrapped_view) return wraps(view_func)(wrapped_view)
@@ -400,7 +404,9 @@ class GlobalSettingsPermissions(OASTokenMixin, permissions.BasePermission):
def get_required_alternate_scopes(self, request, view): def get_required_alternate_scopes(self, request, view):
"""Return the required scopes for the current request.""" """Return the required scopes for the current request."""
return map_scope( return map_scope(
only_read=True, read_name=DEFAULT_STAFF, map_read=permissions.SAFE_METHODS only_read=True,
read_name=DEFAULT_STAFF,
map_read=list(permissions.SAFE_METHODS),
) )
+2 -2
View File
@@ -188,8 +188,8 @@ ALLOWED_ATTRIBUTES_SVG = [
def sanitize_svg( def sanitize_svg(
file_data, file_data,
strip: bool = True, strip: bool = True,
elements: str = ALLOWED_ELEMENTS_SVG, elements: list[str] = ALLOWED_ELEMENTS_SVG,
attributes: str = ALLOWED_ATTRIBUTES_SVG, attributes: list[str] = ALLOWED_ATTRIBUTES_SVG,
) -> str: ) -> str:
"""Sanitize a SVG file. """Sanitize a SVG file.
@@ -373,16 +373,15 @@ class InvenTreeModelSerializer(serializers.ModelSerializer):
instance.full_clean() instance.full_clean()
except (ValidationError, DjangoValidationError) as exc: except (ValidationError, DjangoValidationError) as exc:
if hasattr(exc, 'message_dict'): if hasattr(exc, 'message_dict'):
data = exc.message_dict data = {**exc.message_dict}
elif hasattr(exc, 'message'): elif hasattr(exc, 'message'):
data = {'non_field_errors': [str(exc.message)]} data = {'non_field_errors': [str(exc.message)]}
else: else:
data = {'non_field_errors': [str(exc)]} data = {'non_field_errors': [str(exc)]}
# Change '__all__' key (django style) to 'non_field_errors' (DRF style) # Change '__all__' key (django style) to 'non_field_errors' (DRF style)
if '__all__' in data: if hasattr(data, '__all__'):
data['non_field_errors'] = data['__all__'] data['non_field_errors'] = data.pop('__all__')
del data['__all__']
raise ValidationError(data) raise ValidationError(data)
+27 -15
View File
@@ -43,6 +43,13 @@ from users.oauth2_scopes import oauth2_scopes
from . import config, locales from . import config, locales
try:
import django_stubs_ext
django_stubs_ext.monkeypatch() # pragma: no cover
except ImportError: # pragma: no cover
pass
checkMinPythonVersion() checkMinPythonVersion()
INVENTREE_BASE_URL = 'https://inventree.org' INVENTREE_BASE_URL = 'https://inventree.org'
@@ -382,22 +389,25 @@ QUERYCOUNT = {
} }
AUTHENTICATION_BACKENDS = CONFIG.get( default_auth_backends = [
'authentication_backends', 'oauth2_provider.backends.OAuth2Backend', # OAuth2 provider
[ 'django.contrib.auth.backends.RemoteUserBackend', # proxy login
'oauth2_provider.backends.OAuth2Backend', # OAuth2 provider 'django.contrib.auth.backends.ModelBackend',
'django.contrib.auth.backends.RemoteUserBackend', # proxy login 'allauth.account.auth_backends.AuthenticationBackend', # SSO login via external providers
'django.contrib.auth.backends.ModelBackend', 'sesame.backends.ModelBackend', # Magic link login django-sesame
'allauth.account.auth_backends.AuthenticationBackend', # SSO login via external providers ]
'sesame.backends.ModelBackend', # Magic link login django-sesame
], AUTHENTICATION_BACKENDS = (
CONFIG.get('authentication_backends', default_auth_backends)
if CONFIG
else default_auth_backends
) )
# LDAP support # LDAP support
LDAP_AUTH = get_boolean_setting('INVENTREE_LDAP_ENABLED', 'ldap.enabled', False) LDAP_AUTH = get_boolean_setting('INVENTREE_LDAP_ENABLED', 'ldap.enabled', False)
if LDAP_AUTH: if LDAP_AUTH:
import django_auth_ldap.config import django_auth_ldap.config # type: ignore[unresolved-import]
import ldap import ldap # type: ignore[unresolved-import]
AUTHENTICATION_BACKENDS.append('django_auth_ldap.backend.LDAPBackend') AUTHENTICATION_BACKENDS.append('django_auth_ldap.backend.LDAPBackend')
@@ -450,7 +460,7 @@ if LDAP_AUTH:
) )
AUTH_LDAP_USER_SEARCH = django_auth_ldap.config.LDAPSearch( AUTH_LDAP_USER_SEARCH = django_auth_ldap.config.LDAPSearch(
get_setting('INVENTREE_LDAP_SEARCH_BASE_DN', 'ldap.search_base_dn'), get_setting('INVENTREE_LDAP_SEARCH_BASE_DN', 'ldap.search_base_dn'),
ldap.SCOPE_SUBTREE, ldap.SCOPE_SUBTREE, # type: ignore[unresolved-attribute]
str( str(
get_setting( get_setting(
'INVENTREE_LDAP_SEARCH_FILTER_STR', 'INVENTREE_LDAP_SEARCH_FILTER_STR',
@@ -486,7 +496,7 @@ if LDAP_AUTH:
) )
AUTH_LDAP_GROUP_SEARCH = django_auth_ldap.config.LDAPSearch( AUTH_LDAP_GROUP_SEARCH = django_auth_ldap.config.LDAPSearch(
get_setting('INVENTREE_LDAP_GROUP_SEARCH', 'ldap.group_search'), get_setting('INVENTREE_LDAP_GROUP_SEARCH', 'ldap.group_search'),
ldap.SCOPE_SUBTREE, ldap.SCOPE_SUBTREE, # type: ignore[unresolved-attribute]
f'(objectClass={AUTH_LDAP_GROUP_OBJECT_CLASS})', f'(objectClass={AUTH_LDAP_GROUP_OBJECT_CLASS})',
) )
AUTH_LDAP_GROUP_TYPE_CLASS = get_setting( AUTH_LDAP_GROUP_TYPE_CLASS = get_setting(
@@ -604,7 +614,7 @@ Configure the database backend based on the user-specified values.
logger.debug('Configuring database backend:') logger.debug('Configuring database backend:')
# Extract database configuration from the config.yaml file # Extract database configuration from the config.yaml file
db_config = CONFIG.get('database', None) db_config = CONFIG.get('database', None) if CONFIG else None
if not db_config: if not db_config:
db_config = {} db_config = {}
@@ -690,7 +700,9 @@ if db_options is None:
# Specific options for postgres backend # Specific options for postgres backend
if 'postgres' in DB_ENGINE: # pragma: no cover if 'postgres' in DB_ENGINE: # pragma: no cover
from django.db.backends.postgresql.psycopg_any import IsolationLevel from django.db.backends.postgresql.psycopg_any import ( # type: ignore[unresolved-import]
IsolationLevel,
)
# Connection timeout # Connection timeout
if 'connect_timeout' not in db_options: if 'connect_timeout' not in db_options:
+2 -1
View File
@@ -50,7 +50,7 @@ def check_provider(provider):
if not app: if not app:
return False return False
if allauth.app_settings.SITES_ENABLED: if allauth.app_settings.SITES_ENABLED: # type: ignore[unresolved-attribute]
# At least one matching site must be specified # At least one matching site must be specified
if not app.sites.exists(): if not app.sites.exists():
logger.error('SocialApp %s has no sites configured', app) logger.error('SocialApp %s has no sites configured', app)
@@ -102,6 +102,7 @@ def ensure_sso_groups(sender, sociallogin: SocialLogin, **kwargs):
# ensure user has groups # ensure user has groups
user = sociallogin.account.user user = sociallogin.account.user
for group_name in group_names: for group_name in group_names:
try: try:
user.groups.get(name=group_name) user.groups.get(name=group_name)
+5 -3
View File
@@ -285,7 +285,7 @@ class ScheduledTask:
QUARTERLY: str = 'Q' QUARTERLY: str = 'Q'
YEARLY: str = 'Y' YEARLY: str = 'Y'
TYPE: tuple[str] = (MINUTES, HOURLY, DAILY, WEEKLY, MONTHLY, QUARTERLY, YEARLY) TYPE: tuple[str] = (MINUTES, HOURLY, DAILY, WEEKLY, MONTHLY, QUARTERLY, YEARLY) # type: ignore[invalid-assignment]
class TaskRegister: class TaskRegister:
@@ -302,7 +302,9 @@ tasks = TaskRegister()
def scheduled_task( def scheduled_task(
interval: str, minutes: Optional[int] = None, tasklist: TaskRegister = None interval: str,
minutes: Optional[int] = None,
tasklist: Optional[TaskRegister] = None,
): ):
"""Register the given task as a scheduled task. """Register the given task as a scheduled task.
@@ -544,7 +546,7 @@ def check_for_updates():
match = re.match(r'^.*(\d+)\.(\d+)\.(\d+).*$', tag) match = re.match(r'^.*(\d+)\.(\d+)\.(\d+).*$', tag)
if len(match.groups()) != 3: # pragma: no cover if not match or len(match.groups()) != 3: # pragma: no cover
logger.warning("Version '%s' did not match expected pattern", tag) logger.warning("Version '%s' did not match expected pattern", tag)
return return
+1 -1
View File
@@ -567,7 +567,7 @@ class GeneralApiTests(InvenTreeAPITestCase):
self.assertIn('License file not found at', str(log.output)) self.assertIn('License file not found at', str(log.output))
with TemporaryDirectory() as tmp: with TemporaryDirectory() as tmp: # type: ignore[no-matching-overload]
sample_file = Path(tmp, 'temp.txt') sample_file = Path(tmp, 'temp.txt')
sample_file.write_text('abc') sample_file.write_text('abc')
+8 -3
View File
@@ -4,7 +4,7 @@ import base64
import logging import logging
from typing import Optional from typing import Optional
from opentelemetry import metrics, trace from opentelemetry import metrics, trace # type: ignore[import]
from opentelemetry.instrumentation.django import DjangoInstrumentor from opentelemetry.instrumentation.django import DjangoInstrumentor
from opentelemetry.instrumentation.redis import RedisInstrumentor from opentelemetry.instrumentation.redis import RedisInstrumentor
from opentelemetry.instrumentation.requests import RequestsInstrumentor from opentelemetry.instrumentation.requests import RequestsInstrumentor
@@ -29,8 +29,8 @@ TRACE_PROV = None
def setup_tracing( def setup_tracing(
endpoint: str, endpoint: Optional[str] = None,
headers: dict, headers: Optional[dict] = None,
resources_input: Optional[dict] = None, resources_input: Optional[dict] = None,
console: bool = False, console: bool = False,
auth: Optional[dict] = None, auth: Optional[dict] = None,
@@ -50,6 +50,11 @@ def setup_tracing(
""" """
if InvenTree.ready.isImportingData() or InvenTree.ready.isRunningMigrations(): if InvenTree.ready.isImportingData() or InvenTree.ready.isRunningMigrations():
return return
if endpoint is None or headers is None:
print(
'Tracing endpoint or headers not specified - skipping tracing setup'
) # pragma: no cover
return # pragma: no cover
# Logger configuration # Logger configuration
logger = logging.getLogger('inventree') logger = logging.getLogger('inventree')
+5 -1
View File
@@ -115,7 +115,7 @@ def getOldestMigrationFile(app, exclude_extension=True, ignore_initial=True):
oldest_num = num oldest_num = num
oldest_file = f oldest_file = f
if exclude_extension: if exclude_extension and oldest_file:
oldest_file = oldest_file.replace('.py', '') oldest_file = oldest_file.replace('.py', '')
return oldest_file return oldest_file
@@ -583,6 +583,10 @@ class InvenTreeAPITestCase(
result = re.search( result = re.search(
r'(attachment|inline); filename=[\'"]([\w\d\-.]+)[\'"]', disposition r'(attachment|inline); filename=[\'"]([\w\d\-.]+)[\'"]', disposition
) )
if not result:
raise ValueError(
'No filename match found in disposition'
) # pragma: no cover
fn = result.groups()[1] fn = result.groups()[1]
@@ -6,6 +6,7 @@ from django.core.exceptions import ValidationError
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
import pint import pint
import pint.errors
from moneyed import CURRENCIES from moneyed import CURRENCIES
import InvenTree.conversion import InvenTree.conversion
+1 -1
View File
@@ -107,7 +107,7 @@ def inventreeVersionTuple(version=None):
match = re.match(r'^.*(\d+)\.(\d+)\.(\d+).*$', str(version)) match = re.match(r'^.*(\d+)\.(\d+)\.(\d+).*$', str(version))
return [int(g) for g in match.groups()] return [int(g) for g in match.groups()] if match else []
def isInvenTreeDevelopmentVersion(): def isInvenTreeDevelopmentVersion():
+10 -6
View File
@@ -2,17 +2,21 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.db.models import F, Q from django.db.models import F, Q
from django.urls import include, path from django.urls import include, path
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django_filters import rest_framework as rest_filters from django_filters import rest_framework as rest_filters
from django_filters.rest_framework.filterset import FilterSet
from drf_spectacular.utils import extend_schema, extend_schema_field from drf_spectacular.utils import extend_schema, extend_schema_field
from rest_framework import serializers, status from rest_framework import serializers, status
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from rest_framework.response import Response from rest_framework.response import Response
import build.models as build_models
import build.serializers import build.serializers
import common.models import common.models
import part.models as part_models import part.models as part_models
@@ -33,7 +37,7 @@ from InvenTree.mixins import CreateAPI, ListCreateAPI, RetrieveUpdateDestroyAPI
from users.models import Owner from users.models import Owner
class BuildFilter(rest_filters.FilterSet): class BuildFilter(FilterSet):
"""Custom filterset for BuildList API endpoint.""" """Custom filterset for BuildList API endpoint."""
class Meta: class Meta:
@@ -431,7 +435,7 @@ class BuildUnallocate(CreateAPI):
return ctx return ctx
class BuildLineFilter(rest_filters.FilterSet): class BuildLineFilter(FilterSet):
"""Custom filterset for the BuildLine API endpoint.""" """Custom filterset for the BuildLine API endpoint."""
class Meta: class Meta:
@@ -605,7 +609,7 @@ class BuildLineList(BuildLineMixin, DataExportViewMixin, ListCreateAPI):
'bom_item__reference', 'bom_item__reference',
] ]
def get_source_build(self) -> Build | None: def get_source_build(self) -> Optional[Build]:
"""Return the target build for the BuildLine queryset.""" """Return the target build for the BuildLine queryset."""
source_build = None source_build = None
@@ -622,7 +626,7 @@ class BuildLineList(BuildLineMixin, DataExportViewMixin, ListCreateAPI):
class BuildLineDetail(BuildLineMixin, RetrieveUpdateDestroyAPI): class BuildLineDetail(BuildLineMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a BuildLine object.""" """API endpoint for detail view of a BuildLine object."""
def get_source_build(self) -> Build | None: def get_source_build(self) -> Optional[Build]:
"""Return the target source location for the BuildLine queryset.""" """Return the target source location for the BuildLine queryset."""
return None return None
@@ -783,7 +787,7 @@ class BuildItemDetail(RetrieveUpdateDestroyAPI):
serializer_class = build.serializers.BuildItemSerializer serializer_class = build.serializers.BuildItemSerializer
class BuildItemFilter(rest_filters.FilterSet): class BuildItemFilter(FilterSet):
"""Custom filterset for the BuildItemList API endpoint.""" """Custom filterset for the BuildItemList API endpoint."""
class Meta: class Meta:
@@ -829,7 +833,7 @@ class BuildItemFilter(rest_filters.FilterSet):
return queryset.filter(stock_item__part=part) return queryset.filter(stock_item__part=part)
build = rest_filters.ModelChoiceFilter( build = rest_filters.ModelChoiceFilter(
queryset=build.models.Build.objects.all(), queryset=build_models.Build.objects.all(),
label=_('Build Order'), label=_('Build Order'),
field_name='build_line__build', field_name='build_line__build',
) )
+2 -2
View File
@@ -1064,7 +1064,7 @@ class Build(
lines = lines.exclude(bom_item__consumable=True) lines = lines.exclude(bom_item__consumable=True)
lines = annotate_allocated_quantity(lines) lines = annotate_allocated_quantity(lines)
for build_line in lines: for build_line in lines: # type: ignore[non-iterable]
reduce_by = build_line.allocated - build_line.quantity reduce_by = build_line.allocated - build_line.quantity
if reduce_by <= 0: if reduce_by <= 0:
@@ -1359,7 +1359,7 @@ class Build(
except (ValidationError, serializers.ValidationError) as exc: except (ValidationError, serializers.ValidationError) as exc:
# Catch model errors and re-throw as DRF errors # Catch model errors and re-throw as DRF errors
raise ValidationError( raise ValidationError(
detail=serializers.as_serializer_error(exc) exc.message, detail=serializers.as_serializer_error(exc)
) )
if unallocated_quantity <= 0: if unallocated_quantity <= 0:
@@ -23,6 +23,7 @@ from rest_framework.serializers import ValidationError
import build.tasks import build.tasks
import common.models import common.models
import common.settings
import company.serializers import company.serializers
import InvenTree.helpers import InvenTree.helpers
import InvenTree.tasks import InvenTree.tasks
+17 -2
View File
@@ -1,6 +1,7 @@
"""Unit tests for the BuildOrder API.""" """Unit tests for the BuildOrder API."""
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional
from django.urls import reverse from django.urls import reverse
@@ -668,6 +669,11 @@ class BuildAllocationTest(BuildAPITest):
wrong_line = line wrong_line = line
break break
if not wrong_line:
raise self.fail(
'No matching BuildLine found for the given stock item'
) # pragma: no cover
data = self.post( data = self.post(
self.url, self.url,
{ {
@@ -695,6 +701,11 @@ class BuildAllocationTest(BuildAPITest):
right_line = line right_line = line
break break
if not right_line:
raise self.fail(
'No matching BuildLine found for the given stock item'
) # pragma: no cover
self.post( self.post(
self.url, self.url,
{ {
@@ -722,11 +733,15 @@ class BuildAllocationTest(BuildAPITest):
# Find the correct BuildLine # Find the correct BuildLine
si = StockItem.objects.get(pk=2) si = StockItem.objects.get(pk=2)
right_line = None right_line: Optional[BuildLine] = None
for line in self.build.build_lines.all(): for line in self.build.build_lines.all():
if line.bom_item.sub_part.pk == si.part.pk: if line.bom_item.sub_part.pk == si.part.pk:
right_line = line right_line: BuildLine = line
break break
if not right_line:
raise self.fail(
'No matching BuildLine found for the given stock item'
) # pragma: no cover
self.post( self.post(
self.url, self.url,
+4 -2
View File
@@ -1,6 +1,7 @@
"""Provides a JSON API for common components.""" """Provides a JSON API for common components."""
import json import json
import json.decoder
from django.conf import settings from django.conf import settings
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
@@ -13,8 +14,9 @@ from django.utils.translation import gettext_lazy as _
from django.views.decorators.cache import cache_control from django.views.decorators.cache import cache_control
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
import django_filters.rest_framework.filters as rest_filters
import django_q.models import django_q.models
from django_filters import rest_framework as rest_filters from django_filters.rest_framework.filterset import FilterSet
from django_q.tasks import async_task from django_q.tasks import async_task
from djmoney.contrib.exchange.models import ExchangeBackend, Rate from djmoney.contrib.exchange.models import ExchangeBackend, Rate
from drf_spectacular.utils import OpenApiResponse, extend_schema from drf_spectacular.utils import OpenApiResponse, extend_schema
@@ -676,7 +678,7 @@ class ContentTypeModelDetail(ContentTypeDetail):
return super().get(request, *args, **kwargs) return super().get(request, *args, **kwargs)
class AttachmentFilter(rest_filters.FilterSet): class AttachmentFilter(FilterSet):
"""Filterset for the AttachmentList API endpoint.""" """Filterset for the AttachmentList API endpoint."""
class Meta: class Meta:
+1 -1
View File
@@ -160,7 +160,7 @@ def get_price(
- If MOQ (minimum order quantity) is required, bump quantity - If MOQ (minimum order quantity) is required, bump quantity
- If order multiples are to be observed, then we need to calculate based on that, too - If order multiples are to be observed, then we need to calculate based on that, too
""" """
from common.currency import currency_code_default # from common.currency import currency_code_default
if hasattr(instance, break_name): if hasattr(instance, break_name):
price_breaks = getattr(instance, break_name).all() price_breaks = getattr(instance, break_name).all()
+6 -5
View File
@@ -30,6 +30,7 @@ from django.core.mail import EmailMultiAlternatives, get_connection
from django.core.mail.utils import DNS_NAME from django.core.mail.utils import DNS_NAME
from django.core.validators import MinValueValidator from django.core.validators import MinValueValidator
from django.db import models, transaction from django.db import models, transaction
from django.db.models import enums
from django.db.models.signals import post_delete, post_save from django.db.models.signals import post_delete, post_save
from django.db.utils import IntegrityError, OperationalError, ProgrammingError from django.db.utils import IntegrityError, OperationalError, ProgrammingError
from django.dispatch import receiver from django.dispatch import receiver
@@ -66,7 +67,7 @@ from InvenTree.version import inventree_identifier
logger = structlog.get_logger('inventree') logger = structlog.get_logger('inventree')
class RenderMeta(models.enums.ChoicesMeta): class RenderMeta(enums.ChoicesMeta):
"""Metaclass for rendering choices.""" """Metaclass for rendering choices."""
choice_fnc = None choice_fnc = None
@@ -80,7 +81,7 @@ class RenderMeta(models.enums.ChoicesMeta):
return [] return []
class RenderChoices(models.TextChoices, metaclass=RenderMeta): class RenderChoices(models.TextChoices, metaclass=RenderMeta): # type: ignore
"""Class for creating enumerated string choices for schema rendering.""" """Class for creating enumerated string choices for schema rendering."""
@@ -971,7 +972,7 @@ class BaseInvenTreeSetting(models.Model):
return setting.get('model', None) return setting.get('model', None)
def model_filters(self) -> dict: def model_filters(self) -> Optional[dict]:
"""Return the model filters associated with this setting.""" """Return the model filters associated with this setting."""
setting = self.get_setting_definition( setting = self.get_setting_definition(
self.key, **self.get_filters_for_instance() self.key, **self.get_filters_for_instance()
@@ -1505,8 +1506,8 @@ class WebhookEndpoint(models.Model):
request (optional): Original request object. Defaults to None. request (optional): Original request object. Defaults to None.
""" """
return WebhookMessage.objects.create( return WebhookMessage.objects.create(
host=request.get_host(), host=request.get_host() if request else '',
header=json.dumps(dict(headers.items())), header=json.dumps(dict(headers.items())) if headers else None,
body=payload, body=payload,
endpoint=self, endpoint=self,
) )
@@ -84,9 +84,7 @@ class InvenTreeNotificationBodies:
) )
def trigger_notification( def trigger_notification(obj: Model, category: str = '', obj_ref: str = 'pk', **kwargs):
obj: Model, category: Optional[str] = None, obj_ref: str = 'pk', **kwargs
):
"""Send out a notification. """Send out a notification.
Args: Args:
+8 -5
View File
@@ -19,7 +19,7 @@ from django.test import Client, TestCase
from django.test.utils import override_settings from django.test.utils import override_settings
from django.urls import reverse from django.urls import reverse
import PIL from PIL import Image
import common.validators import common.validators
from common.notifications import trigger_notification from common.notifications import trigger_notification
@@ -200,7 +200,7 @@ class AttachmentTest(InvenTreeAPITestCase):
# Assign 'delete' permission to 'part' model # Assign 'delete' permission to 'part' model
self.assignRole('part.delete') self.assignRole('part.delete')
response = self.delete(url, expected_code=204) self.delete(url, expected_code=204)
class SettingsTest(InvenTreeTestCase): class SettingsTest(InvenTreeTestCase):
@@ -671,9 +671,9 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase):
# Find the associated setting # Find the associated setting
setting = next((s for s in response.data if s['key'] == key), None) setting = next((s for s in response.data if s['key'] == key), None)
assert setting is not None
# Check default value (should be False, not 'False') # Check default value (should be False, not 'False')
self.assertIsNotNone(setting)
self.assertFalse(setting['value']) self.assertFalse(setting['value'])
# Check that we can manually set the value # Check that we can manually set the value
@@ -851,9 +851,9 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
# Find the associated setting # Find the associated setting
setting = next((s for s in response.data if s['key'] == key), None) setting = next((s for s in response.data if s['key'] == key), None)
assert setting is not None
# Check default value (should be 10, not '10') # Check default value (should be 10, not '10')
self.assertIsNotNone(setting)
self.assertEqual(setting['value'], 10) self.assertEqual(setting['value'], 10)
# Check that writing an invalid value returns an error # Check that writing an invalid value returns an error
@@ -1535,7 +1535,7 @@ class NotesImageTest(InvenTreeAPITestCase):
n = NotesImage.objects.count() n = NotesImage.objects.count()
# Construct a simple image file # Construct a simple image file
image = PIL.Image.new('RGB', (100, 100), color='red') image = Image.new('RGB', (100, 100), color='red')
with io.BytesIO() as output: with io.BytesIO() as output:
image.save(output, format='PNG') image.save(output, format='PNG')
@@ -1589,6 +1589,7 @@ class ProjectCodesTest(InvenTreeAPITestCase):
# Get the first project code # Get the first project code
code = ProjectCode.objects.first() code = ProjectCode.objects.first()
assert code is not None and code.pk
# Delete it # Delete it
self.delete( self.delete(
@@ -1686,6 +1687,7 @@ class CustomUnitAPITest(InvenTreeAPITestCase):
def test_edit(self): def test_edit(self):
"""Test edit permissions for CustomUnit model.""" """Test edit permissions for CustomUnit model."""
unit = CustomUnit.objects.first() unit = CustomUnit.objects.first()
assert unit is not None and unit.pk
# Try to edit without permission # Try to edit without permission
self.user.is_staff = False self.user.is_staff = False
@@ -1713,6 +1715,7 @@ class CustomUnitAPITest(InvenTreeAPITestCase):
def test_validation(self): def test_validation(self):
"""Test that validation works as expected.""" """Test that validation works as expected."""
unit = CustomUnit.objects.first() unit = CustomUnit.objects.first()
assert unit is not None and unit.pk
self.user.is_staff = True self.user.is_staff = True
self.user.save() self.user.save()
+6 -5
View File
@@ -4,7 +4,8 @@ from django.db.models import Q
from django.urls import include, path from django.urls import include, path
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django_filters import rest_framework as rest_filters import django_filters.rest_framework.filters as rest_filters
from django_filters.rest_framework.filterset import FilterSet
import part.models import part.models
from data_exporter.mixins import DataExportViewMixin from data_exporter.mixins import DataExportViewMixin
@@ -127,7 +128,7 @@ class AddressDetail(RetrieveUpdateDestroyAPI):
serializer_class = AddressSerializer serializer_class = AddressSerializer
class ManufacturerPartFilter(rest_filters.FilterSet): class ManufacturerPartFilter(FilterSet):
"""Custom API filters for the ManufacturerPart list endpoint.""" """Custom API filters for the ManufacturerPart list endpoint."""
class Meta: class Meta:
@@ -204,7 +205,7 @@ class ManufacturerPartDetail(RetrieveUpdateDestroyAPI):
serializer_class = ManufacturerPartSerializer serializer_class = ManufacturerPartSerializer
class ManufacturerPartParameterFilter(rest_filters.FilterSet): class ManufacturerPartParameterFilter(FilterSet):
"""Custom filterset for the ManufacturerPartParameterList API endpoint.""" """Custom filterset for the ManufacturerPartParameterList API endpoint."""
class Meta: class Meta:
@@ -259,7 +260,7 @@ class ManufacturerPartParameterDetail(RetrieveUpdateDestroyAPI):
serializer_class = ManufacturerPartParameterSerializer serializer_class = ManufacturerPartParameterSerializer
class SupplierPartFilter(rest_filters.FilterSet): class SupplierPartFilter(FilterSet):
"""API filters for the SupplierPartList endpoint.""" """API filters for the SupplierPartList endpoint."""
class Meta: class Meta:
@@ -418,7 +419,7 @@ class SupplierPartDetail(SupplierPartMixin, RetrieveUpdateDestroyAPI):
""" """
class SupplierPriceBreakFilter(rest_filters.FilterSet): class SupplierPriceBreakFilter(FilterSet):
"""Custom API filters for the SupplierPriceBreak list endpoint.""" """Custom API filters for the SupplierPriceBreak list endpoint."""
class Meta: class Meta:
@@ -51,7 +51,7 @@ def reverse_association(apps, schema_editor): # pragma: no cover
row = cursor.fetchone() row = cursor.fetchone()
if len(row) > 0: if row and len(row) > 0:
try: try:
manufacturer_id = int(row[0]) manufacturer_id = int(row[0])
except (TypeError, ValueError): except (TypeError, ValueError):
@@ -67,12 +67,12 @@ def reverse_association(apps, schema_editor): # pragma: no cover
response = cursor.execute(f"SELECT name from company_company where id={manufacturer_id};") response = cursor.execute(f"SELECT name from company_company where id={manufacturer_id};")
row = cursor.fetchone() row = cursor.fetchone()
if row:
name = row[0]
name = row[0] print(" - Manufacturer name: '{name}'".format(name=name))
print(" - Manufacturer name: '{name}'".format(name=name)) response = cursor.execute("UPDATE part_supplierpart SET manufacturer_name='{name}' WHERE id={ID};".format(name=name, ID=supplier_part_id))
response = cursor.execute("UPDATE part_supplierpart SET manufacturer_name='{name}' WHERE id={ID};".format(name=name, ID=supplier_part_id))
def associate_manufacturers(apps, schema_editor): def associate_manufacturers(apps, schema_editor):
""" """
@@ -106,7 +106,7 @@ def associate_manufacturers(apps, schema_editor):
response = cursor.execute(query) response = cursor.execute(query)
row = cursor.fetchone() row = cursor.fetchone()
if len(row) > 0: if row and len(row) > 0:
return row[0] return row[0]
return '' # pragma: no cover return '' # pragma: no cover
@@ -296,11 +296,11 @@ def associate_manufacturers(apps, schema_editor):
# Double-check if the typed name corresponds to an existing item # Double-check if the typed name corresponds to an existing item
elif response in companies.keys(): elif response in companies.keys():
link_part(part, companies[response]) link_part(part_id, companies[response])
return return
elif response in links.keys(): elif response in links.keys():
link_part(part, links[response]) link_part(part_id, links[response])
return return
# No match, create a new manufacturer # No match, create a new manufacturer
+14 -1
View File
@@ -156,7 +156,10 @@ class CompanyTest(InvenTreeAPITestCase):
def test_company_notes(self): def test_company_notes(self):
"""Test the markdown 'notes' field for the Company model.""" """Test the markdown 'notes' field for the Company model."""
pk = Company.objects.first().pk company = Company.objects.first()
assert company
pk = company.pk
url = reverse('api-company-detail', kwargs={'pk': pk}) url = reverse('api-company-detail', kwargs={'pk': pk})
# Attempt to inject malicious markdown into the "notes" field # Attempt to inject malicious markdown into the "notes" field
@@ -253,6 +256,7 @@ class ContactTest(InvenTreeAPITestCase):
n = Contact.objects.count() n = Contact.objects.count()
company = Company.objects.first() company = Company.objects.first()
assert company
# Without required permissions, creation should fail # Without required permissions, creation should fail
self.post( self.post(
@@ -271,6 +275,8 @@ class ContactTest(InvenTreeAPITestCase):
"""Test that we can edit a Contact via the API.""" """Test that we can edit a Contact via the API."""
# Get the first contact # Get the first contact
contact = Contact.objects.first() contact = Contact.objects.first()
assert contact
# Use this contact in the tests # Use this contact in the tests
url = reverse('api-contact-detail', kwargs={'pk': contact.pk}) url = reverse('api-contact-detail', kwargs={'pk': contact.pk})
@@ -294,6 +300,8 @@ class ContactTest(InvenTreeAPITestCase):
"""Tests that we can delete a Contact via the API.""" """Tests that we can delete a Contact via the API."""
# Get the last contact # Get the last contact
contact = Contact.objects.first() contact = Contact.objects.first()
assert contact
url = reverse('api-contact-detail', kwargs={'pk': contact.pk}) url = reverse('api-contact-detail', kwargs={'pk': contact.pk})
# Delete (without required permissions) # Delete (without required permissions)
@@ -348,6 +356,7 @@ class AddressTest(InvenTreeAPITestCase):
def test_filter_list(self): def test_filter_list(self):
"""Test listing addresses filtered on company.""" """Test listing addresses filtered on company."""
company = Company.objects.first() company = Company.objects.first()
assert company
response = self.get(self.url, {'company': company.pk}, expected_code=200) response = self.get(self.url, {'company': company.pk}, expected_code=200)
@@ -356,6 +365,7 @@ class AddressTest(InvenTreeAPITestCase):
def test_create(self): def test_create(self):
"""Test creating a new address.""" """Test creating a new address."""
company = Company.objects.first() company = Company.objects.first()
assert company
self.post(self.url, {'company': company.pk, 'title': 'HQ'}, expected_code=403) self.post(self.url, {'company': company.pk, 'title': 'HQ'}, expected_code=403)
@@ -366,6 +376,7 @@ class AddressTest(InvenTreeAPITestCase):
def test_get(self): def test_get(self):
"""Test that objects are properly returned from a get.""" """Test that objects are properly returned from a get."""
addr = Address.objects.first() addr = Address.objects.first()
assert addr
url = reverse('api-address-detail', kwargs={'pk': addr.pk}) url = reverse('api-address-detail', kwargs={'pk': addr.pk})
response = self.get(url, expected_code=200) response = self.get(url, expected_code=200)
@@ -386,6 +397,7 @@ class AddressTest(InvenTreeAPITestCase):
def test_edit(self): def test_edit(self):
"""Test editing an Address object.""" """Test editing an Address object."""
addr = Address.objects.first() addr = Address.objects.first()
assert addr
url = reverse('api-address-detail', kwargs={'pk': addr.pk}) url = reverse('api-address-detail', kwargs={'pk': addr.pk})
@@ -403,6 +415,7 @@ class AddressTest(InvenTreeAPITestCase):
def test_delete(self): def test_delete(self):
"""Test deleting an object.""" """Test deleting an object."""
addr = Address.objects.first() addr = Address.objects.first()
assert addr
url = reverse('api-address-detail', kwargs={'pk': addr.pk}) url = reverse('api-address-detail', kwargs={'pk': addr.pk})
+1 -1
View File
@@ -18,7 +18,7 @@ class DataExporterConfig(AppConfig):
def cleanup(self): def cleanup(self):
"""Cleanup any old export files.""" """Cleanup any old export files."""
try: try:
from data_exporter.tasks import cleanup_old_export_outputs from data_exporter.tasks import cleanup_old_export_outputs # type: ignore
cleanup_old_export_outputs() cleanup_old_export_outputs()
except Exception: except Exception:
@@ -1,6 +1,7 @@
"""Mixin classes for the exporter app.""" """Mixin classes for the exporter app."""
from collections import OrderedDict from collections import OrderedDict
from typing import Any
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.core.files.base import ContentFile from django.core.files.base import ContentFile
@@ -127,7 +128,7 @@ class DataExportSerializerMixin:
""" """
return headers return headers
def get_nested_value(self, row: dict, key: str) -> any: def get_nested_value(self, row: dict, key: str) -> Any:
"""Get a nested value from a dictionary. """Get a nested value from a dictionary.
This method allows for dot notation to access nested fields. This method allows for dot notation to access nested fields.
@@ -6,7 +6,6 @@ from rest_framework import serializers
import InvenTree.exceptions import InvenTree.exceptions
import InvenTree.helpers import InvenTree.helpers
import InvenTree.serializers
from plugin import PluginMixinEnum, registry from plugin import PluginMixinEnum, registry
@@ -53,7 +52,7 @@ class DataExportOptionsSerializer(serializers.Serializer):
try: try:
supports_export = plugin.supports_export( supports_export = plugin.supports_export(
model_class, model_class,
user=request.user, user=request.user if request else None,
serializer_class=serializer_class, serializer_class=serializer_class,
view_class=view_class, view_class=view_class,
) )
@@ -6,6 +6,7 @@ There is a rendered state for each state value. The rendered state is used for d
States can be extended with custom options for each InvenTree instance - those options are stored in the database and need to link back to state values. States can be extended with custom options for each InvenTree instance - those options are stored in the database and need to link back to state values.
""" """
from . import fields
from .states import ColorEnum, StatusCode, StatusCodeMixin from .states import ColorEnum, StatusCode, StatusCodeMixin
from .transition import StateTransitionMixin, TransitionMethod from .transition import StateTransitionMixin, TransitionMethod
@@ -15,4 +16,5 @@ __all__ = [
'StatusCode', 'StatusCode',
'StatusCodeMixin', 'StatusCodeMixin',
'TransitionMethod', 'TransitionMethod',
'fields',
] ]
@@ -4,6 +4,7 @@ import enum
import logging import logging
import re import re
from enum import Enum from enum import Enum
from typing import Optional
logger = logging.getLogger('inventree') logger = logging.getLogger('inventree')
@@ -297,7 +298,7 @@ class StatusCodeMixin:
"""Return the status code for this object.""" """Return the status code for this object."""
return getattr(self, self.STATUS_FIELD) return getattr(self, self.STATUS_FIELD)
def get_custom_status(self) -> int: def get_custom_status(self) -> Optional[int]:
"""Return the custom status code for this object.""" """Return the custom status code for this object."""
return getattr(self, f'{self.STATUS_FIELD}_custom_key', None) return getattr(self, f'{self.STATUS_FIELD}_custom_key', None)
+2 -1
View File
@@ -20,4 +20,5 @@ def status_label(typ: str, key: int, include_custom: bool = False, *args, **kwar
def display_status_label(typ: str, key: int, fallback: int, *args, **kwargs): def display_status_label(typ: str, key: int, fallback: int, *args, **kwargs):
"""Render a status label.""" """Render a status label."""
render_key = int(key) if key else fallback render_key = int(key) if key else fallback
return status_label(typ, render_key, *args, include_custom=True, **kwargs) kwargs['include_custom'] = True
return status_label(typ, render_key, *args, **kwargs)
@@ -1,5 +1,7 @@
"""Classes and functions for plugin controlled object state transitions.""" """Classes and functions for plugin controlled object state transitions."""
from typing import Callable
from django.db.models import Model from django.db.models import Model
import structlog import structlog
@@ -30,7 +32,7 @@ class TransitionMethod:
current_state: int, current_state: int,
target_state: int, target_state: int,
instance: Model, instance: Model,
default_action: callable, default_action: Callable,
**kwargs, **kwargs,
) -> bool: ) -> bool:
"""Perform a state transition. """Perform a state transition.
+2 -4
View File
@@ -303,9 +303,7 @@ class DataImportSession(models.Model):
if not any(row_data.values()): if not any(row_data.values()):
continue continue
row = importer.models.DataImportRow( row = DataImportRow(session=self, row_data=row_data, row_index=idx)
session=self, row_data=row_data, row_index=idx
)
row.extract_data( row.extract_data(
field_mapping=field_mapping, field_mapping=field_mapping,
@@ -317,7 +315,7 @@ class DataImportSession(models.Model):
imported_rows.append(row) imported_rows.append(row)
# Perform database writes as a single operation # Perform database writes as a single operation
importer.models.DataImportRow.objects.bulk_create(imported_rows) DataImportRow.objects.bulk_create(imported_rows)
# Mark the import task as "PROCESSING" # Mark the import task as "PROCESSING"
self.status = DataImportStatusCode.PROCESSING.value self.status = DataImportStatusCode.PROCESSING.value
+4 -1
View File
@@ -1,9 +1,12 @@
"""Data import operational functions.""" """Data import operational functions."""
from typing import Optional
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
import tablib import tablib
import tablib.core
import InvenTree.helpers import InvenTree.helpers
@@ -82,7 +85,7 @@ def extract_column_names(data_file) -> list:
return headers return headers
def get_field_label(field) -> str: def get_field_label(field) -> Optional[str]:
"""Return the label for a field in a serializer class. """Return the label for a field in a serializer class.
Check for labels in the following order of descending priority: Check for labels in the following order of descending priority:
+2 -2
View File
@@ -1,7 +1,7 @@
"""Models for the machine app.""" """Models for the machine app."""
import uuid import uuid
from typing import Literal from typing import Literal, Optional
from django.contrib import admin from django.contrib import admin
from django.db import models from django.db import models
@@ -192,7 +192,7 @@ class MachineSetting(common.models.BaseInvenTreeSetting):
If not provided, we'll look at the machine registry to see what settings this machine driver requires If not provided, we'll look at the machine registry to see what settings this machine driver requires
""" """
if 'settings' not in kwargs: if 'settings' not in kwargs:
machine_config: MachineConfig = kwargs.pop('machine_config', None) machine_config: Optional[MachineConfig] = kwargs.pop('machine_config', None)
if machine_config and machine_config.machine: if machine_config and machine_config.machine:
config_type = kwargs.get('config_type') config_type = kwargs.get('config_type')
if config_type == cls.ConfigType.DRIVER: if config_type == cls.ConfigType.DRIVER:
+3 -1
View File
@@ -423,7 +423,9 @@ class MachineRegistry(
# If the plugin registry has changed, the machine registry hash will change # If the plugin registry has changed, the machine registry hash will change
plugin_registry.update_plugin_hash() plugin_registry.update_plugin_hash()
data.update(plugin_registry.registry_hash.encode()) current_hash = plugin_registry.registry_hash
if current_hash:
data.update(current_hash.encode())
for pk, machine in self.machines.items(): for pk, machine in self.machines.items():
data.update(str(pk).encode()) data.update(str(pk).encode())
+1
View File
@@ -220,6 +220,7 @@ class TestLabelPrinterMachineType(InvenTreeAPITestCase):
parts = Part.objects.all()[:2] parts = Part.objects.all()[:2]
template = LabelTemplate.objects.filter(enabled=True, model_type='part').first() template = LabelTemplate.objects.filter(enabled=True, model_type='part').first()
assert template
url = reverse('api-label-print') url = reverse('api-label-print')
+6 -5
View File
@@ -11,8 +11,9 @@ from django.http.response import JsonResponse
from django.urls import include, path, re_path from django.urls import include, path, re_path
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
import django_filters.rest_framework.filters as rest_filters
import rest_framework.serializers import rest_framework.serializers
from django_filters import rest_framework as rest_filters from django_filters.rest_framework.filterset import FilterSet
from django_ical.views import ICalFeed from django_ical.views import ICalFeed
from drf_spectacular.types import OpenApiTypes from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema, extend_schema_field from drf_spectacular.utils import extend_schema, extend_schema_field
@@ -100,7 +101,7 @@ class OrderCreateMixin:
) )
class OrderFilter(rest_filters.FilterSet): class OrderFilter(FilterSet):
"""Base class for custom API filters for the OrderList endpoint.""" """Base class for custom API filters for the OrderList endpoint."""
# Filter against order status # Filter against order status
@@ -258,7 +259,7 @@ class OrderFilter(rest_filters.FilterSet):
return queryset.filter(q1 | q2 | q3 | q4).distinct() return queryset.filter(q1 | q2 | q3 | q4).distinct()
class LineItemFilter(rest_filters.FilterSet): class LineItemFilter(FilterSet):
"""Base class for custom API filters for order line item list(s).""" """Base class for custom API filters for order line item list(s)."""
# Filter by order status # Filter by order status
@@ -1147,7 +1148,7 @@ class SalesOrderAllocate(SalesOrderContextMixin, CreateAPI):
serializer_class = serializers.SalesOrderShipmentAllocationSerializer serializer_class = serializers.SalesOrderShipmentAllocationSerializer
class SalesOrderAllocationFilter(rest_filters.FilterSet): class SalesOrderAllocationFilter(FilterSet):
"""Custom filterset for the SalesOrderAllocationList endpoint.""" """Custom filterset for the SalesOrderAllocationList endpoint."""
class Meta: class Meta:
@@ -1321,7 +1322,7 @@ class SalesOrderAllocationDetail(SalesOrderAllocationMixin, RetrieveUpdateDestro
"""API endpoint for detali view of a SalesOrderAllocation object.""" """API endpoint for detali view of a SalesOrderAllocation object."""
class SalesOrderShipmentFilter(rest_filters.FilterSet): class SalesOrderShipmentFilter(FilterSet):
"""Custom filterset for the SalesOrderShipmentList endpoint.""" """Custom filterset for the SalesOrderShipmentList endpoint."""
class Meta: class Meta:
+3 -2
View File
@@ -1,6 +1,7 @@
"""Background tasks for the 'order' app.""" """Background tasks for the 'order' app."""
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Union
from django.contrib.auth.models import Group, User from django.contrib.auth.models import Group, User
from django.db import transaction from django.db import transaction
@@ -104,7 +105,7 @@ def check_overdue_purchase_orders():
@tracer.start_as_current_span('notify_overdue_sales_order') @tracer.start_as_current_span('notify_overdue_sales_order')
def notify_overdue_sales_order(so: order.models.SalesOrder) -> None: def notify_overdue_sales_order(so: order.models.SalesOrder) -> None:
"""Notify appropriate users that a SalesOrder has just become 'overdue'.""" """Notify appropriate users that a SalesOrder has just become 'overdue'."""
targets: list[User, Group, Owner] = [] targets: list[Union[User, Group, Owner]] = []
if so.created_by: if so.created_by:
targets.append(so.created_by) targets.append(so.created_by)
@@ -171,7 +172,7 @@ def check_overdue_sales_orders():
@tracer.start_as_current_span('notify_overdue_return_order') @tracer.start_as_current_span('notify_overdue_return_order')
def notify_overdue_return_order(ro: order.models.ReturnOrder) -> None: def notify_overdue_return_order(ro: order.models.ReturnOrder) -> None:
"""Notify appropriate users that a ReturnOrder has just become 'overdue'.""" """Notify appropriate users that a ReturnOrder has just become 'overdue'."""
targets: list[User, Group, Owner] = [] targets: list[Union[User, Group, Owner]] = []
if ro.created_by: if ro.created_by:
targets.append(ro.created_by) targets.append(ro.created_by)
+17 -4
View File
@@ -4,6 +4,7 @@ import base64
import io import io
import json import json
from datetime import date, datetime, timedelta from datetime import date, datetime, timedelta
from typing import Optional
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db import connection from django.db import connection
@@ -420,7 +421,9 @@ class PurchaseOrderTest(OrderTest):
self.assertIn('Responsible user or group must be specified', str(response.data)) self.assertIn('Responsible user or group must be specified', str(response.data))
data['responsible'] = Owner.objects.first().pk owner = Owner.objects.first()
assert owner
data['responsible'] = owner.pk
response = self.post(url, data, expected_code=201) response = self.post(url, data, expected_code=201)
@@ -1689,6 +1692,7 @@ class SalesOrderTest(OrderTest):
shipment = models.SalesOrderShipment.objects.create( shipment = models.SalesOrderShipment.objects.create(
order=so, reference='SHIP-12345' order=so, reference='SHIP-12345'
) )
assert shipment
# Allocate some stock # Allocate some stock
item = StockItem.objects.create(part=part, quantity=100, location=None) item = StockItem.objects.create(part=part, quantity=100, location=None)
@@ -1825,10 +1829,13 @@ class SalesOrderLineItemTest(OrderTest):
self.assignRole('sales_order.add') self.assignRole('sales_order.add')
# Crete a new SalesOrder via the API # Crete a new SalesOrder via the API
company = Company.objects.filter(is_customer=True).first()
assert company
response = self.post( response = self.post(
reverse('api-so-list'), reverse('api-so-list'),
{ {
'customer': Company.objects.filter(is_customer=True).first().pk, 'customer': company.pk,
'reference': 'SO-12345', 'reference': 'SO-12345',
'description': 'Test Sales Order', 'description': 'Test Sales Order',
}, },
@@ -1878,6 +1885,7 @@ class SalesOrderLineItemTest(OrderTest):
p = Part.objects.get(pk=item) p = Part.objects.get(pk=item)
s = StockItem.objects.create(part=p, quantity=100) s = StockItem.objects.create(part=p, quantity=100)
l = models.SalesOrderLineItem.objects.filter(order=order, part=p).first() l = models.SalesOrderLineItem.objects.filter(order=order, part=p).first()
assert l
# Allocate against the API # Allocate against the API
self.post( self.post(
@@ -2099,12 +2107,14 @@ class SalesOrderAllocateTest(OrderTest):
return line_item.part.is_template return line_item.part.is_template
for line in filter(check_template, self.order.lines.all()): for line in filter(check_template, self.order.lines.all()):
stock_item = None stock_item: Optional[StockItem] = None
stock_item = None stock_item = None
# Allocate a matching variant # Allocate a matching variant
parts = Part.objects.filter(salable=True).filter(variant_of=line.part.pk) parts: list[Part] = Part.objects.filter(salable=True).filter(
variant_of=line.part.pk
)
for part in parts: for part in parts:
stock_item = part.stock_items.last() stock_item = part.stock_items.last()
@@ -2118,6 +2128,9 @@ class SalesOrderAllocateTest(OrderTest):
if stock_item is not None: if stock_item is not None:
break break
if stock_item is None:
raise self.fail('No stock item found for part') # pragma: no cover
# Fully-allocate each line # Fully-allocate each line
data['items'].append({ data['items'].append({
'line_item': line.pk, 'line_item': line.pk,
+15 -14
View File
@@ -6,8 +6,9 @@ from django.db.models import Count, F, Q
from django.urls import include, path from django.urls import include, path
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django_filters import rest_framework as rest_filters import django_filters.rest_framework.filters as rest_filters
from django_filters.rest_framework import DjangoFilterBackend from django_filters.rest_framework import DjangoFilterBackend
from django_filters.rest_framework.filterset import FilterSet
from drf_spectacular.types import OpenApiTypes from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema_field from drf_spectacular.utils import extend_schema_field
from rest_framework import serializers from rest_framework import serializers
@@ -98,7 +99,7 @@ class CategoryMixin:
return ctx return ctx
class CategoryFilter(rest_filters.FilterSet): class CategoryFilter(FilterSet):
"""Custom filterset class for the PartCategoryList endpoint.""" """Custom filterset class for the PartCategoryList endpoint."""
class Meta: class Meta:
@@ -282,11 +283,11 @@ class CategoryDetail(CategoryMixin, CustomRetrieveUpdateDestroyAPI):
return super().destroy( return super().destroy(
request, request,
*args, *args,
**dict( **{
kwargs, **kwargs,
delete_parts=delete_parts, 'delete_parts': delete_parts,
delete_child_categories=delete_child_categories, 'delete_child_categories': delete_child_categories,
), },
) )
@@ -399,7 +400,7 @@ class PartInternalPriceList(DataExportViewMixin, ListCreateAPI):
ordering = 'quantity' ordering = 'quantity'
class PartTestTemplateFilter(rest_filters.FilterSet): class PartTestTemplateFilter(FilterSet):
"""Custom filterset class for the PartTestTemplateList endpoint.""" """Custom filterset class for the PartTestTemplateList endpoint."""
class Meta: class Meta:
@@ -644,7 +645,7 @@ class PartValidateBOM(RetrieveUpdateAPI):
return Response(serializer.data) return Response(serializer.data)
class PartFilter(rest_filters.FilterSet): class PartFilter(FilterSet):
"""Custom filters for the PartList endpoint. """Custom filters for the PartList endpoint.
Uses the django_filters extension framework Uses the django_filters extension framework
@@ -1196,7 +1197,7 @@ class PartDetail(PartMixin, RetrieveUpdateDestroyAPI):
return response return response
class PartRelatedFilter(rest_filters.FilterSet): class PartRelatedFilter(FilterSet):
"""FilterSet for PartRelated objects.""" """FilterSet for PartRelated objects."""
class Meta: class Meta:
@@ -1243,7 +1244,7 @@ class PartRelatedDetail(PartRelatedMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for accessing detail view of a PartRelated object.""" """API endpoint for accessing detail view of a PartRelated object."""
class PartParameterTemplateFilter(rest_filters.FilterSet): class PartParameterTemplateFilter(FilterSet):
"""FilterSet for PartParameterTemplate objects.""" """FilterSet for PartParameterTemplate objects."""
class Meta: class Meta:
@@ -1377,7 +1378,7 @@ class PartParameterAPIMixin:
return super().get_serializer(*args, **kwargs) return super().get_serializer(*args, **kwargs)
class PartParameterFilter(rest_filters.FilterSet): class PartParameterFilter(FilterSet):
"""Custom filters for the PartParameterList API endpoint.""" """Custom filters for the PartParameterList API endpoint."""
class Meta: class Meta:
@@ -1438,7 +1439,7 @@ class PartParameterDetail(PartParameterAPIMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a single PartParameter object.""" """API endpoint for detail view of a single PartParameter object."""
class PartStocktakeFilter(rest_filters.FilterSet): class PartStocktakeFilter(FilterSet):
"""Custom filter for the PartStocktakeList endpoint.""" """Custom filter for the PartStocktakeList endpoint."""
class Meta: class Meta:
@@ -1480,7 +1481,7 @@ class PartStocktakeDetail(RetrieveUpdateDestroyAPI):
serializer_class = part_serializers.PartStocktakeSerializer serializer_class = part_serializers.PartStocktakeSerializer
class BomFilter(rest_filters.FilterSet): class BomFilter(FilterSet):
"""Custom filters for the BOM list.""" """Custom filters for the BOM list."""
class Meta: class Meta:
+3 -2
View File
@@ -12,6 +12,7 @@ Useful References:
""" """
from decimal import Decimal from decimal import Decimal
from typing import Optional
from django.db import models from django.db import models
from django.db.models import ( from django.db.models import (
@@ -137,7 +138,7 @@ def annotate_on_order_quantity(reference: str = '') -> QuerySet:
) )
def annotate_total_stock(reference: str = '', filter: Q = None) -> QuerySet: def annotate_total_stock(reference: str = '', filter: Optional[Q] = None) -> QuerySet:
"""Annotate 'total stock' quantity against a queryset. """Annotate 'total stock' quantity against a queryset.
- This function calculates the 'total stock' for a given part - This function calculates the 'total stock' for a given part
@@ -269,7 +270,7 @@ def annotate_sales_order_allocations(reference: str = '', location=None) -> Quer
) )
def variant_stock_query(reference: str = '', filter: Q = None) -> QuerySet: def variant_stock_query(reference: str = '', filter: Optional[Q] = None) -> QuerySet:
"""Create a queryset to retrieve all stock items for variant parts under the specified part. """Create a queryset to retrieve all stock items for variant parts under the specified part.
- Useful for annotating a queryset with aggregated information about variant parts - Useful for annotating a queryset with aggregated information about variant parts
+2 -2
View File
@@ -819,7 +819,7 @@ class Part(
if not check_duplicates: if not check_duplicates:
return return
from part.models import Part # from part.models import Part
from stock.models import StockItem from stock.models import StockItem
if get_global_setting('SERIAL_NUMBER_GLOBALLY_UNIQUE', False): if get_global_setting('SERIAL_NUMBER_GLOBALLY_UNIQUE', False):
@@ -850,7 +850,7 @@ class Part(
def find_conflicting_serial_numbers(self, serials: list) -> list: def find_conflicting_serial_numbers(self, serials: list) -> list:
"""For a provided list of serials, return a list of those which are conflicting.""" """For a provided list of serials, return a list of those which are conflicting."""
from part.models import Part # from part.models import Part
from stock.models import StockItem from stock.models import StockItem
conflicts = [] conflicts = []
+6 -4
View File
@@ -11,7 +11,7 @@ from django.db import connection
from django.test.utils import CaptureQueriesContext from django.test.utils import CaptureQueriesContext
from django.urls import reverse from django.urls import reverse
import PIL from PIL import Image
from rest_framework.test import APIClient from rest_framework.test import APIClient
import build.models import build.models
@@ -65,7 +65,7 @@ class PartImageTestMixin:
fn = get_testfolder_dir() / 'part_image_123abc.png' fn = get_testfolder_dir() / 'part_image_123abc.png'
img = PIL.Image.new('RGB', (128, 128), color='blue') img = Image.new('RGB', (128, 128), color='blue')
img.save(fn) img.save(fn)
with open(fn, 'rb') as img_file: with open(fn, 'rb') as img_file:
@@ -1770,7 +1770,7 @@ class PartDetailTests(PartImageTestMixin, PartAPITestBase):
for fmt in ['jpg', 'j2k', 'png', 'bmp', 'webp']: for fmt in ['jpg', 'j2k', 'png', 'bmp', 'webp']:
fn = f'{test_path}.{fmt}' fn = f'{test_path}.{fmt}'
img = PIL.Image.new('RGB', (128, 128), color='red') img = Image.new('RGB', (128, 128), color='red')
img.save(fn) img.save(fn)
with open(fn, 'rb') as dummy_image: with open(fn, 'rb') as dummy_image:
@@ -1820,7 +1820,7 @@ class PartDetailTests(PartImageTestMixin, PartAPITestBase):
fn = get_testfolder_dir() / 'part_image_123abc.png' fn = get_testfolder_dir() / 'part_image_123abc.png'
img = PIL.Image.new('RGB', (128, 128), color='blue') img = Image.new('RGB', (128, 128), color='blue')
img.save(fn) img.save(fn)
# Upload the image to a part # Upload the image to a part
@@ -2463,6 +2463,7 @@ class BomItemTest(InvenTreeAPITestCase):
# Now, let's validate an item # Now, let's validate an item
bom_item = BomItem.objects.first() bom_item = BomItem.objects.first()
assert bom_item
bom_item.validate_hash() bom_item.validate_hash()
@@ -3109,6 +3110,7 @@ class PartTestTemplateTest(PartAPITestBase):
def test_choices(self): def test_choices(self):
"""Test the 'choices' field for the PartTestTemplate model.""" """Test the 'choices' field for the PartTestTemplate model."""
template = PartTestTemplate.objects.first() template = PartTestTemplate.objects.first()
assert template
url = reverse('api-part-test-template-detail', kwargs={'pk': template.pk}) url = reverse('api-part-test-template-detail', kwargs={'pk': template.pk})
+1 -1
View File
@@ -2,8 +2,8 @@
from django.contrib import admin from django.contrib import admin
import plugin.registry as pl_registry
from plugin import models from plugin import models
from plugin.registry import registry as pl_registry
def plugin_update(queryset, new_status: bool): def plugin_update(queryset, new_status: bool):
+3 -2
View File
@@ -6,8 +6,9 @@ from django.core.exceptions import ValidationError
from django.urls import include, path, re_path from django.urls import include, path, re_path
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django_filters import rest_framework as rest_filters import django_filters.rest_framework.filters as rest_filters
from django_filters.rest_framework import DjangoFilterBackend from django_filters.rest_framework import DjangoFilterBackend
from django_filters.rest_framework.filterset import FilterSet
from drf_spectacular.utils import extend_schema from drf_spectacular.utils import extend_schema
from rest_framework import status from rest_framework import status
from rest_framework.exceptions import NotFound from rest_framework.exceptions import NotFound
@@ -36,7 +37,7 @@ from plugin.plugin import InvenTreePlugin
from plugin.registry import registry from plugin.registry import registry
class PluginFilter(rest_filters.FilterSet): class PluginFilter(FilterSet):
"""Filter for the PluginConfig model. """Filter for the PluginConfig model.
Provides custom filtering options for the FilterList API endpoint. Provides custom filtering options for the FilterList API endpoint.
@@ -5,7 +5,7 @@ from django.urls import include, path
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
import structlog import structlog
from django_filters import rest_framework as rest_filters from django_filters.rest_framework.filterset import FilterSet
from drf_spectacular.utils import extend_schema, extend_schema_view from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework import status from rest_framework import status
from rest_framework.exceptions import PermissionDenied, ValidationError from rest_framework.exceptions import PermissionDenied, ValidationError
@@ -770,7 +770,7 @@ class BarcodeScanResultMixin:
return queryset return queryset
class BarcodeScanResultFilter(rest_filters.FilterSet): class BarcodeScanResultFilter(FilterSet):
"""Custom filterset for the BarcodeScanResult API.""" """Custom filterset for the BarcodeScanResult API."""
class Meta: class Meta:
@@ -2,6 +2,8 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db.models import Q from django.db.models import Q
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@@ -113,7 +115,7 @@ class SupplierBarcodeMixin(BarcodeMixin):
return fields.get(key, backup_value) return fields.get(key, backup_value)
def get_part(self) -> Part | None: def get_part(self) -> Optional[Part]:
"""Extract the Part object from the barcode fields.""" """Extract the Part object from the barcode fields."""
# TODO: Implement this # TODO: Implement this
return None return None
@@ -128,7 +130,7 @@ class SupplierBarcodeMixin(BarcodeMixin):
"""Return the supplier part number from the barcode fields.""" """Return the supplier part number from the barcode fields."""
return self.get_field_value(self.SUPPLIER_PART_NUMBER) return self.get_field_value(self.SUPPLIER_PART_NUMBER)
def get_supplier_part(self) -> SupplierPart | None: def get_supplier_part(self) -> Optional[SupplierPart]:
"""Return the SupplierPart object for the scanned barcode. """Return the SupplierPart object for the scanned barcode.
Returns: Returns:
@@ -172,7 +174,7 @@ class SupplierBarcodeMixin(BarcodeMixin):
"""Return the manufacturer part number from the barcode fields.""" """Return the manufacturer part number from the barcode fields."""
return self.get_field_value(self.MANUFACTURER_PART_NUMBER) return self.get_field_value(self.MANUFACTURER_PART_NUMBER)
def get_manufacturer_part(self) -> ManufacturerPart | None: def get_manufacturer_part(self) -> Optional[ManufacturerPart]:
"""Return the ManufacturerPart object for the scanned barcode. """Return the ManufacturerPart object for the scanned barcode.
Returns: Returns:
@@ -213,7 +215,7 @@ class SupplierBarcodeMixin(BarcodeMixin):
"""Return the supplier order number from the barcode fields.""" """Return the supplier order number from the barcode fields."""
return self.get_field_value(self.SUPPLIER_ORDER_NUMBER) return self.get_field_value(self.SUPPLIER_ORDER_NUMBER)
def get_purchase_order(self) -> PurchaseOrder | None: def get_purchase_order(self) -> Optional[PurchaseOrder]:
"""Extract the PurchaseOrder object from the barcode fields. """Extract the PurchaseOrder object from the barcode fields.
Inspect the customer_order_number and supplier_order_number fields, Inspect the customer_order_number and supplier_order_number fields,
@@ -260,7 +262,7 @@ class SupplierBarcodeMixin(BarcodeMixin):
'extract_barcode_fields must be implemented by each plugin' 'extract_barcode_fields must be implemented by each plugin'
) )
def scan(self, barcode_data: str) -> dict: def scan(self, barcode_data: str) -> Optional[dict]:
"""Perform a generic 'scan' operation on a supplier barcode. """Perform a generic 'scan' operation on a supplier barcode.
The supplier barcode may provide sufficient information to match against The supplier barcode may provide sufficient information to match against
@@ -319,7 +321,7 @@ class SupplierBarcodeMixin(BarcodeMixin):
location=None, location=None,
auto_allocate: bool = True, auto_allocate: bool = True,
**kwargs, **kwargs,
) -> dict | None: ) -> Optional[dict]:
"""Attempt to receive an item against a PurchaseOrder via barcode scanning. """Attempt to receive an item against a PurchaseOrder via barcode scanning.
Arguments: Arguments:
@@ -430,7 +432,7 @@ class SupplierBarcodeMixin(BarcodeMixin):
return response return response
def get_supplier(self, cache: bool = False) -> Company | None: def get_supplier(self, cache: bool = False) -> Optional[Company]:
"""Get the supplier for the SUPPLIER_ID set in the plugin settings. """Get the supplier for the SUPPLIER_ID set in the plugin settings.
If it's not defined, try to guess it and set it if possible. If it's not defined, try to guess it and set it if possible.
@@ -461,9 +463,12 @@ class SupplierBarcodeMixin(BarcodeMixin):
if len(suppliers) != 1: if len(suppliers) != 1:
return _cache_supplier(None) return _cache_supplier(None)
self.set_setting('SUPPLIER_ID', suppliers.first().pk) supplier = suppliers.first()
assert supplier
return _cache_supplier(suppliers.first()) self.set_setting('SUPPLIER_ID', supplier.pk)
return _cache_supplier(supplier)
@classmethod @classmethod
def ecia_field_map(cls): def ecia_field_map(cls):
@@ -163,7 +163,8 @@ class APICallMixin:
url = f'{self.api_url}/{endpoint}' url = f'{self.api_url}/{endpoint}'
# build kwargs for call # build kwargs for call
kwargs.update({'url': url, 'headers': headers}) kwargs.update({'headers': headers})
kwargs.pop('url', None)
if data and json: if data and json:
raise ValueError('You can either pass `data` or `json` to this function.') raise ValueError('You can either pass `data` or `json` to this function.')
@@ -175,7 +176,7 @@ class APICallMixin:
kwargs['data'] = data kwargs['data'] = data
# run command # run command
response = requests.request(method, **kwargs) response = requests.request(method, url=url, **kwargs)
# return # return
if simple_response: if simple_response:
@@ -1,7 +1,7 @@
"""Plugin class for custom data exporting.""" """Plugin class for custom data exporting."""
from collections import OrderedDict from collections import OrderedDict
from typing import Union from typing import Optional, Union
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.db.models import QuerySet from django.db.models import QuerySet
@@ -36,8 +36,8 @@ class DataExportMixin:
self, self,
model_class: type, model_class: type,
user: User, user: User,
serializer_class: serializers.Serializer = None, serializer_class: Optional[serializers.Serializer] = None,
view_class: views.APIView = None, view_class: Optional[views.APIView] = None,
*args, *args,
**kwargs, **kwargs,
) -> bool: ) -> bool:
@@ -72,7 +72,7 @@ class ValidationMixin:
def validate_model_instance( def validate_model_instance(
self, instance: Model, deltas: Optional[dict] = None self, instance: Model, deltas: Optional[dict] = None
) -> None: ) -> Optional[bool]:
"""Run custom validation on a database model instance. """Run custom validation on a database model instance.
This method is called when a model instance is being validated. This method is called when a model instance is being validated.
@@ -90,7 +90,7 @@ class ValidationMixin:
""" """
return None return None
def validate_part_name(self, name: str, part: part.models.Part) -> None: def validate_part_name(self, name: str, part: part.models.Part) -> Optional[bool]:
"""Perform validation on a proposed Part name. """Perform validation on a proposed Part name.
Arguments: Arguments:
@@ -105,7 +105,7 @@ class ValidationMixin:
""" """
return None return None
def validate_part_ipn(self, ipn: str, part: part.models.Part) -> None: def validate_part_ipn(self, ipn: str, part: part.models.Part) -> Optional[bool]:
"""Perform validation on a proposed Part IPN (internal part number). """Perform validation on a proposed Part IPN (internal part number).
Arguments: Arguments:
@@ -122,7 +122,7 @@ class ValidationMixin:
def validate_batch_code( def validate_batch_code(
self, batch_code: str, item: stock.models.StockItem self, batch_code: str, item: stock.models.StockItem
) -> None: ) -> Optional[bool]:
"""Validate the supplied batch code. """Validate the supplied batch code.
Arguments: Arguments:
@@ -137,7 +137,7 @@ class ValidationMixin:
""" """
return None return None
def generate_batch_code(self, **kwargs) -> str: def generate_batch_code(self, **kwargs) -> Optional[str]:
"""Generate a new batch code. """Generate a new batch code.
This method is called when a new batch code is required. This method is called when a new batch code is required.
@@ -154,8 +154,8 @@ class ValidationMixin:
self, self,
serial: str, serial: str,
part: part.models.Part, part: part.models.Part,
stock_item: stock.models.StockItem = None, stock_item: Optional[stock.models.StockItem] = None,
) -> None: ) -> Optional[bool]:
"""Validate the supplied serial number. """Validate the supplied serial number.
Arguments: Arguments:
@@ -171,7 +171,7 @@ class ValidationMixin:
""" """
return None return None
def convert_serial_to_int(self, serial: str) -> int: def convert_serial_to_int(self, serial: str) -> Optional[int]:
"""Convert a serial number (string) into an integer representation. """Convert a serial number (string) into an integer representation.
This integer value is used for efficient sorting based on serial numbers. This integer value is used for efficient sorting based on serial numbers.
@@ -192,7 +192,7 @@ class ValidationMixin:
""" """
return None return None
def get_latest_serial_number(self, part, **kwargs): def get_latest_serial_number(self, part, **kwargs) -> Optional[str]:
"""Return the 'latest' serial number for a given Part instance. """Return the 'latest' serial number for a given Part instance.
A plugin which implements this method can either return: A plugin which implements this method can either return:
@@ -209,8 +209,8 @@ class ValidationMixin:
return None return None
def increment_serial_number( def increment_serial_number(
self, serial: str, part: part.models.Part = None, **kwargs self, serial: str, part: Optional[part.models.Part] = None, **kwargs
) -> str: ) -> Optional[str]:
"""Return the next sequential serial based on the provided value. """Return the next sequential serial based on the provided value.
A plugin which implements this method can either return: A plugin which implements this method can either return:
@@ -229,7 +229,7 @@ class ValidationMixin:
def validate_part_parameter( def validate_part_parameter(
self, parameter: part.models.PartParameter, data: str self, parameter: part.models.PartParameter, data: str
) -> None: ) -> Optional[bool]:
"""Validate a parameter value. """Validate a parameter value.
Arguments: Arguments:
@@ -311,6 +311,7 @@ class APICallMixinTest(BaseMixinDefinition, TestCase):
self.assertTrue(result) self.assertTrue(result)
self.assertNotIn('error', result) self.assertNotIn('error', result)
assert result is not None
self.assertEqual(result['name'], 'morpheus') self.assertEqual(result['name'], 'morpheus')
# api_call with endpoint with leading slash # api_call with endpoint with leading slash
@@ -113,6 +113,7 @@ class LabelMixinTests(PrintTestMixins, InvenTreeAPITestCase):
parts = Part.objects.all()[:2] parts = Part.objects.all()[:2]
template = LabelTemplate.objects.filter(enabled=True, model_type='part').first() template = LabelTemplate.objects.filter(enabled=True, model_type='part').first()
assert template
self.assertIsNotNone(template) self.assertIsNotNone(template)
self.assertTrue(template.enabled) self.assertTrue(template.enabled)
@@ -227,6 +228,8 @@ class LabelMixinTests(PrintTestMixins, InvenTreeAPITestCase):
# Lookup references # Lookup references
parts = Part.objects.all()[:2] parts = Part.objects.all()[:2]
template = LabelTemplate.objects.filter(enabled=True, model_type='part').first() template = LabelTemplate.objects.filter(enabled=True, model_type='part').first()
assert template
self.do_activate_plugin() self.do_activate_plugin()
plugin = registry.get_plugin(self.plugin_ref) plugin = registry.get_plugin(self.plugin_ref)
@@ -7,4 +7,4 @@ class BrokenFileIntegrationPlugin(InvenTreePlugin):
"""An very broken plugin.""" """An very broken plugin."""
aaa = bb # noqa: F821 aaa = bb # noqa: F821 # type: ignore[unresolved-reference]
+1 -1
View File
@@ -51,7 +51,7 @@ class MixinNotImplementedError(NotImplementedError):
def log_registry_error(error, reference: str = 'general'): def log_registry_error(error, reference: str = 'general'):
"""Log an plugin error.""" """Log an plugin error."""
from plugin import registry from plugin.registry import registry
# make sure the registry is set up # make sure the registry is set up
if reference not in registry.errors: if reference not in registry.errors:
+3 -2
View File
@@ -2,6 +2,7 @@
import inspect import inspect
import warnings import warnings
from typing import Optional
from django.conf import settings from django.conf import settings
from django.contrib import admin from django.contrib import admin
@@ -219,7 +220,7 @@ class PluginConfig(InvenTree.models.MetadataMixin, models.Model):
return pkg_name is not None return pkg_name is not None
@property @property
def admin_source(self) -> str: def admin_source(self) -> Optional[str]:
"""Return the path to the javascript file which renders custom admin content for this plugin. """Return the path to the javascript file which renders custom admin content for this plugin.
- It is required that the file provides a 'renderPluginSettings' function! - It is required that the file provides a 'renderPluginSettings' function!
@@ -239,7 +240,7 @@ class PluginConfig(InvenTree.models.MetadataMixin, models.Model):
return None return None
@property @property
def admin_context(self) -> dict: def admin_context(self) -> Optional[dict]:
"""Return the context data for the admin integration.""" """Return the context data for the admin integration."""
if not self.plugin: if not self.plugin:
return None return None
+6 -5
View File
@@ -3,7 +3,7 @@
import inspect import inspect
import warnings import warnings
from datetime import datetime from datetime import datetime
from distutils.sysconfig import get_python_lib from distutils.sysconfig import get_python_lib # type: ignore[import]
from importlib.metadata import PackageNotFoundError, metadata from importlib.metadata import PackageNotFoundError, metadata
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
@@ -568,8 +568,9 @@ class InvenTreePlugin(VersionMixin, MixinBase, MetaBase):
package = {} package = {}
# process date # process date
if package.get('date'): date = package.get('date')
package['date'] = datetime.fromisoformat(package.get('date')) if date:
package['date'] = datetime.fromisoformat(date)
# set variables # set variables
self.package = package self.package = package
@@ -608,7 +609,7 @@ class InvenTreePlugin(VersionMixin, MixinBase, MetaBase):
return url return url
def get_admin_source(self) -> str: def get_admin_source(self) -> Union[str, None]:
"""Return a path to a JavaScript file which contains custom UI settings. """Return a path to a JavaScript file which contains custom UI settings.
The frontend code expects that this file provides a function named 'renderPluginSettings'. The frontend code expects that this file provides a function named 'renderPluginSettings'.
@@ -618,7 +619,7 @@ class InvenTreePlugin(VersionMixin, MixinBase, MetaBase):
return self.plugin_static_file(self.ADMIN_SOURCE) return self.plugin_static_file(self.ADMIN_SOURCE)
def get_admin_context(self) -> dict: def get_admin_context(self) -> Union[dict, None]:
"""Return a context dictionary for the admin panel settings. """Return a context dictionary for the admin panel settings.
This is an optional method which can be overridden by the plugin. This is an optional method which can be overridden by the plugin.
+11 -8
View File
@@ -121,12 +121,12 @@ class PluginsRegistry:
self.ready = False # Marks if the registry is ready to be used self.ready = False # Marks if the registry is ready to be used
# Keep an internal hash of the plugin registry state # Keep an internal hash of the plugin registry state
self.registry_hash = None self.registry_hash: Optional[str] = None
self.plugin_modules: list[InvenTreePlugin] = [] # Holds all discovered plugins self.plugin_modules: list[InvenTreePlugin] = [] # Holds all discovered plugins
self.mixin_modules: dict[str, Any] = {} # Holds all discovered mixins self.mixin_modules: dict[str, Any] = {} # Holds all discovered mixins
self.errors = {} # Holds errors discovered during loading self.errors: dict[str, list[Any]] = {} # Holds errors discovered during loading
self.loading_lock = Lock() # Lock to prevent multiple loading at the same time self.loading_lock = Lock() # Lock to prevent multiple loading at the same time
@@ -289,7 +289,7 @@ class PluginsRegistry:
@registry_entrypoint(default_value=[]) @registry_entrypoint(default_value=[])
def with_mixin( def with_mixin(
self, mixin: str, active: bool = True, builtin: Optional[bool] = None self, mixin: str, active: Optional[bool] = True, builtin: Optional[bool] = None
) -> list[InvenTreePlugin]: ) -> list[InvenTreePlugin]:
"""Returns reference to all plugins that have a specified mixin enabled. """Returns reference to all plugins that have a specified mixin enabled.
@@ -764,9 +764,9 @@ class PluginsRegistry:
f"Plugin '{p}' is not compatible with the current InvenTree version {v}" f"Plugin '{p}' is not compatible with the current InvenTree version {v}"
) )
if v := plg_i.MIN_VERSION: if v := plg_i.MIN_VERSION:
_msg += _(f'Plugin requires at least version {v}') _msg += _(f'Plugin requires at least version {v}') # type: ignore[unsupported-operator]
if v := plg_i.MAX_VERSION: if v := plg_i.MAX_VERSION:
_msg += _(f'Plugin requires at most version {v}') _msg += _(f'Plugin requires at most version {v}') # type: ignore[unsupported-operator]
# Log to error stack # Log to error stack
log_registry_error(_msg, reference=f'{p}:init_plugin') log_registry_error(_msg, reference=f'{p}:init_plugin')
else: else:
@@ -809,7 +809,7 @@ class PluginsRegistry:
logger.exception( logger.exception(
'[PLUGIN] Encountered an error with %s:\n%s', '[PLUGIN] Encountered an error with %s:\n%s',
error.path, getattr(error, 'path', None),
str(error), str(error),
) )
@@ -1084,11 +1084,14 @@ def _load_source(modname, filename):
# loader = importlib.machinery.SourceFileLoader(modname, filename) # loader = importlib.machinery.SourceFileLoader(modname, filename)
spec = importlib.util.spec_from_file_location(modname, filename) # , loader=loader) spec = importlib.util.spec_from_file_location(modname, filename) # , loader=loader)
if spec is None:
raise ImportError(f"Cannot find module '{modname}'") # pragma: no cover
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
sys.modules[module.__name__] = module sys.modules[module.__name__] = module
if spec.loader: loader = spec.loader
spec.loader.exec_module(module) if loader is not None:
loader.exec_module(module)
return module return module
+6 -3
View File
@@ -7,6 +7,7 @@ import tempfile
import textwrap import textwrap
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Optional
from unittest import mock from unittest import mock
from unittest.mock import patch from unittest.mock import patch
@@ -204,7 +205,9 @@ class InvenTreePluginTests(TestCase):
self.assertFalse(self.plugin_version.check_version([0, 1, 4])) self.assertFalse(self.plugin_version.check_version([0, 1, 4]))
plug = registry.plugins_full.get('sampleversion') plug = registry.plugins_full.get('sampleversion')
self.assertEqual(plug.is_active(), False) self.assertIsNotNone(plug)
if plug:
self.assertEqual(plug.is_active(), False)
class RegistryTests(TestQueryMixin, PluginRegistryMixin, TestCase): class RegistryTests(TestQueryMixin, PluginRegistryMixin, TestCase):
@@ -251,7 +254,7 @@ class RegistryTests(TestQueryMixin, PluginRegistryMixin, TestCase):
def test_folder_loading(self): def test_folder_loading(self):
"""Test that plugins in folders outside of BASE_DIR get loaded.""" """Test that plugins in folders outside of BASE_DIR get loaded."""
# Run in temporary directory -> always a new random name # Run in temporary directory -> always a new random name
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp: # type: ignore[no-matching-overload]
# Fill directory with sample data # Fill directory with sample data
new_dir = Path(tmp).joinpath('mock') new_dir = Path(tmp).joinpath('mock')
shutil.copytree(self.mockDir(), new_dir) shutil.copytree(self.mockDir(), new_dir)
@@ -400,7 +403,7 @@ class RegistryTests(TestQueryMixin, PluginRegistryMixin, TestCase):
def create_plugin_file( def create_plugin_file(
version: str, enabled: bool = True, reload: bool = True version: str, enabled: bool = True, reload: bool = True
) -> str: ) -> Optional[str]:
"""Create a plugin file with the given version. """Create a plugin file with the given version.
Arguments: Arguments:
+3 -2
View File
@@ -6,8 +6,9 @@ from django.utils.decorators import method_decorator
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django.views.decorators.cache import never_cache from django.views.decorators.cache import never_cache
from django_filters import rest_framework as rest_filters import django_filters.rest_framework.filters as rest_filters
from django_filters.rest_framework import DjangoFilterBackend from django_filters.rest_framework import DjangoFilterBackend
from django_filters.rest_framework.filterset import FilterSet
from rest_framework.generics import GenericAPIView from rest_framework.generics import GenericAPIView
from rest_framework.response import Response from rest_framework.response import Response
@@ -31,7 +32,7 @@ class TemplatePermissionMixin:
permission_classes = [InvenTree.permissions.IsStaffOrReadOnlyScope] permission_classes = [InvenTree.permissions.IsStaffOrReadOnlyScope]
class ReportFilterBase(rest_filters.FilterSet): class ReportFilterBase(FilterSet):
"""Base filter class for label and report templates.""" """Base filter class for label and report templates."""
enabled = rest_filters.BooleanFilter() enabled = rest_filters.BooleanFilter()
+1 -1
View File
@@ -67,7 +67,7 @@ class ReportConfig(AppConfig):
def cleanup(self): def cleanup(self):
"""Cleanup old label and report outputs.""" """Cleanup old label and report outputs."""
try: try:
from report.tasks import cleanup_old_report_outputs from report.tasks import cleanup_old_report_outputs # type: ignore[import]
cleanup_old_report_outputs() cleanup_old_report_outputs()
except Exception: except Exception:
+6 -3
View File
@@ -238,7 +238,7 @@ class ReportTemplateBase(MetadataMixin, InvenTree.models.InvenTreeModel):
), ),
) )
def generate_filename(self, context, **kwargs): def generate_filename(self, context, **kwargs) -> str:
"""Generate a filename for this report.""" """Generate a filename for this report."""
template_string = Template(self.filename_pattern) template_string = Template(self.filename_pattern)
@@ -491,7 +491,7 @@ class ReportTemplate(TemplateUploadMixin, ReportTemplateBase):
debug_mode = get_global_setting('REPORT_DEBUG_MODE', False) debug_mode = get_global_setting('REPORT_DEBUG_MODE', False)
# Start with a default report name # Start with a default report name
report_name = None report_name: Optional[str] = None
# If a DataOutput object is not provided, create a new one # If a DataOutput object is not provided, create a new one
if not output: if not output:
@@ -608,6 +608,9 @@ class ReportTemplate(TemplateUploadMixin, ReportTemplateBase):
'path': request.path if request else None, 'path': request.path if request else None,
}) })
if not report_name:
report_name = '' # pragma: no cover
if not report_name.endswith('.pdf'): if not report_name.endswith('.pdf'):
report_name += '.pdf' report_name += '.pdf'
@@ -695,7 +698,7 @@ class LabelTemplate(TemplateUploadMixin, ReportTemplateBase):
def get_context(self, instance, request=None, **kwargs): def get_context(self, instance, request=None, **kwargs):
"""Supply context data to the label template for rendering.""" """Supply context data to the label template for rendering."""
base_context = super().get_context(instance, request, **kwargs) base_context = super().get_context(instance, request, **kwargs)
label_context: LabelContextExtension = { label_context: LabelContextExtension = { # type: ignore[invalid-assignment]
'width': self.width, 'width': self.width,
'height': self.height, 'height': self.height,
'page_style': None, 'page_style': None,
@@ -4,6 +4,7 @@ from django import template
from django.utils.safestring import mark_safe from django.utils.safestring import mark_safe
import barcode as python_barcode import barcode as python_barcode
import barcode.writer as python_barcode_writer
import qrcode.constants as ECL import qrcode.constants as ECL
from PIL import Image, ImageColor from PIL import Image, ImageColor
from qrcode.main import QRCode from qrcode.main import QRCode
@@ -122,7 +123,7 @@ def barcode(data: str, barcode_class='code128', **kwargs) -> str:
data = str(data).zfill(constructor.digits) data = str(data).zfill(constructor.digits)
writer = python_barcode.writer.ImageWriter writer = python_barcode_writer.ImageWriter
barcode_image = constructor(data, writer=writer()) barcode_image = constructor(data, writer=writer())
@@ -148,7 +149,7 @@ def datamatrix(data: str, **kwargs) -> str:
Returns: Returns:
image (str): base64 encoded image data image (str): base64 encoded image data
""" """
from ppf.datamatrix import DataMatrix from ppf.datamatrix.datamatrix import DataMatrix
data = str(data).strip() data = str(data).strip()
@@ -50,7 +50,7 @@ def filter_queryset(queryset: QuerySet, **kwargs) -> QuerySet:
@register.simple_tag() @register.simple_tag()
def filter_db_model(model_name: str, **kwargs) -> QuerySet: def filter_db_model(model_name: str, **kwargs) -> Optional[QuerySet]:
"""Filter a database model based on the provided keyword arguments. """Filter a database model based on the provided keyword arguments.
Arguments: Arguments:
@@ -102,7 +102,7 @@ def getindex(container: list, index: int) -> Any:
@register.simple_tag() @register.simple_tag()
def getkey(container: dict, key: str, backup_value: Optional[any] = None) -> Any: def getkey(container: dict, key: str, backup_value: Optional[Any] = None) -> Any:
"""Perform key lookup in the provided dict object. """Perform key lookup in the provided dict object.
This function is provided to get around template rendering limitations. This function is provided to get around template rendering limitations.
@@ -301,14 +301,13 @@ def part_image(part: Part, preview: bool = False, thumbnail: bool = False, **kwa
if type(part) is not Part: if type(part) is not Part:
raise TypeError(_('part_image tag requires a Part instance')) raise TypeError(_('part_image tag requires a Part instance'))
if not part.image: part_img = part.image
if not part_img:
img = None img = None
elif preview: elif preview:
img = None if not hasattr(part.image, 'preview') else part.image.preview.name img = None if not hasattr(part.image, 'preview') else part_img.preview.name
elif thumbnail: elif thumbnail:
img = ( img = None if not hasattr(part.image, 'thumbnail') else part_img.thumbnail.name
None if not hasattr(part.image, 'thumbnail') else part.image.thumbnail.name
)
else: else:
img = part.image.name img = part.image.name
@@ -316,7 +315,7 @@ def part_image(part: Part, preview: bool = False, thumbnail: bool = False, **kwa
@register.simple_tag() @register.simple_tag()
def part_parameter(part: Part, parameter_name: str) -> str: def part_parameter(part: Part, parameter_name: str) -> Optional[str]:
"""Return a PartParameter object for the given part and parameter name. """Return a PartParameter object for the given part and parameter name.
Arguments: Arguments:
@@ -348,12 +347,15 @@ def company_image(
if type(company) is not Company: if type(company) is not Company:
raise TypeError(_('company_image tag requires a Company instance')) raise TypeError(_('company_image tag requires a Company instance'))
if preview: cmp_img = company.image
img = company.image.preview.name if not cmp_img:
img = None
elif preview:
img = cmp_img.preview.name
elif thumbnail: elif thumbnail:
img = company.image.thumbnail.name img = cmp_img.thumbnail.name
else: else:
img = company.image.name img = cmp_img.name
return uploaded_image(img, **kwargs) return uploaded_image(img, **kwargs)
+4
View File
@@ -84,6 +84,7 @@ class ReportTest(InvenTreeAPITestCase):
# Filter by items # Filter by items
part_pk = Part.objects.first().pk part_pk = Part.objects.first().pk
report = ReportTemplate.objects.filter(model_type='part').first() report = ReportTemplate.objects.filter(model_type='part').first()
assert report
try: try:
response = self.get( response = self.get(
@@ -236,6 +237,7 @@ class ReportTest(InvenTreeAPITestCase):
url = reverse('api-report-template-list') url = reverse('api-report-template-list')
template = ReportTemplate.objects.first() template = ReportTemplate.objects.first()
assert template
detail_url = reverse('api-report-template-detail', kwargs={'pk': template.pk}) detail_url = reverse('api-report-template-detail', kwargs={'pk': template.pk})
@@ -415,6 +417,7 @@ class PrintTestMixins:
qs = qs.objects.all() qs = qs.objects.all()
template = mdl.objects.filter(enabled=True, model_type=model_type).first() template = mdl.objects.filter(enabled=True, model_type=model_type).first()
assert template
plugin = registry.get_plugin(self.plugin_ref) plugin = registry.get_plugin(self.plugin_ref)
# Single page printing # Single page printing
@@ -475,6 +478,7 @@ class TestReportTest(PrintTestMixins, ReportTest):
template = ReportTemplate.objects.filter( template = ReportTemplate.objects.filter(
enabled=True, model_type='stockitem' enabled=True, model_type='stockitem'
).first() ).first()
assert template
self.assertIsNotNone(template) self.assertIsNotNone(template)
+10 -9
View File
@@ -9,7 +9,8 @@ from django.db.models import F, Q
from django.urls import include, path from django.urls import include, path
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django_filters import rest_framework as rest_filters import django_filters.rest_framework.filters as rest_filters
from django_filters.rest_framework.filterset import FilterSet
from drf_spectacular.types import OpenApiTypes from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema, extend_schema_field from drf_spectacular.utils import extend_schema, extend_schema_field
from rest_framework import status from rest_framework import status
@@ -257,7 +258,7 @@ class StockMerge(CreateAPI):
return ctx return ctx
class StockLocationFilter(rest_filters.FilterSet): class StockLocationFilter(FilterSet):
"""Base class for custom API filters for the StockLocation endpoint.""" """Base class for custom API filters for the StockLocation endpoint."""
class Meta: class Meta:
@@ -425,11 +426,11 @@ class StockLocationDetail(StockLocationMixin, CustomRetrieveUpdateDestroyAPI):
return super().destroy( return super().destroy(
request, request,
*args, *args,
**dict( **{
kwargs, **kwargs,
delete_sub_locations=delete_sub_locations, 'delete_sub_locations': delete_sub_locations,
delete_stock_items=delete_stock_items, 'delete_stock_items': delete_stock_items,
), },
) )
@@ -505,7 +506,7 @@ class StockLocationTypeDetail(RetrieveUpdateDestroyAPI):
return queryset return queryset
class StockFilter(rest_filters.FilterSet): class StockFilter(FilterSet):
"""FilterSet for StockItem LIST API.""" """FilterSet for StockItem LIST API."""
class Meta: class Meta:
@@ -1339,7 +1340,7 @@ class StockItemTestResultDetail(StockItemTestResultMixin, RetrieveUpdateDestroyA
"""Detail endpoint for StockItemTestResult.""" """Detail endpoint for StockItemTestResult."""
class StockItemTestResultFilter(rest_filters.FilterSet): class StockItemTestResultFilter(FilterSet):
"""API filter for the StockItemTestResult list.""" """API filter for the StockItemTestResult list."""
class Meta: class Meta:
+3 -1
View File
@@ -1,12 +1,14 @@
"""Custom query filters for the Stock models.""" """Custom query filters for the Stock models."""
from typing import Optional
from django.db.models import F, Func, IntegerField, OuterRef, Q, Subquery from django.db.models import F, Func, IntegerField, OuterRef, Q, Subquery
from django.db.models.functions import Coalesce from django.db.models.functions import Coalesce
import stock.models import stock.models
def annotate_location_items(filter: Q = None): def annotate_location_items(filter: Optional[Q] = None):
"""Construct a queryset annotation which returns the number of stock items in a particular location. """Construct a queryset annotation which returns the number of stock items in a particular location.
- Includes items in subcategories also - Includes items in subcategories also
+2 -1
View File
@@ -1,6 +1,7 @@
"""Generator functions for the stock app.""" """Generator functions for the stock app."""
from inspect import signature from inspect import signature
from typing import Optional
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
@@ -78,7 +79,7 @@ def generate_batch_code(**kwargs):
return Template(batch_template).render(context) return Template(batch_template).render(context)
def generate_serial_number(part=None, quantity=1, **kwargs) -> str: def generate_serial_number(part=None, quantity=1, **kwargs) -> Optional[str]:
"""Generate a default 'serial number' for a new StockItem.""" """Generate a default 'serial number' for a new StockItem."""
quantity = quantity or 1 quantity = quantity or 1
@@ -51,10 +51,10 @@ def update_history(apps, schema_editor):
q = entry.quantity q = entry.quantity
if idx == 0 or not q == quantity: if idx == 0 or q != quantity:
try: try:
deltas['quantity']: float(q) deltas['quantity']= float(q)
updated = True updated = True
except Exception: except Exception:
print(f"WARNING: Error converting quantity '{q}'") print(f"WARNING: Error converting quantity '{q}'")
+2 -2
View File
@@ -667,7 +667,7 @@ class StockItem(
return items return items
@staticmethod @staticmethod
def convert_serial_to_int(serial: str) -> int: def convert_serial_to_int(serial: str) -> Optional[int]:
"""Convert the provided serial number to an integer value. """Convert the provided serial number to an integer value.
This function hooks into the plugin system to allow for custom serial number conversion. This function hooks into the plugin system to allow for custom serial number conversion.
@@ -1784,7 +1784,7 @@ class StockItem(
self, self,
entry_type: int, entry_type: int,
user: User, user: User,
deltas: dict | None = None, deltas: Optional[dict] = None,
notes: str = '', notes: str = '',
commit: bool = True, commit: bool = True,
**kwargs, **kwargs,
@@ -20,6 +20,7 @@ import build.models
import company.models import company.models
import company.serializers as company_serializers import company.serializers as company_serializers
import InvenTree.helpers import InvenTree.helpers
import InvenTree.ready
import InvenTree.serializers import InvenTree.serializers
import order.models import order.models
import part.filters as part_filters import part.filters as part_filters
+1
View File
@@ -2079,6 +2079,7 @@ class StockTestResultTest(StockAPITestCase):
url = reverse('api-stock-test-result-list') url = reverse('api-stock-test-result-list')
test_template = PartTestTemplate.objects.first() test_template = PartTestTemplate.objects.first()
assert test_template
test_template.choices = 'AA, BB, CC' test_template.choices = 'AA, BB, CC'
test_template.save() test_template.save()
+3 -1
View File
@@ -47,7 +47,9 @@ User.add_to_class('__str__', user_model_str) # Overriding User.__str__
if settings.LDAP_AUTH: if settings.LDAP_AUTH:
from django_auth_ldap.backend import populate_user from django_auth_ldap.backend import ( # type: ignore[unresolved-import]
populate_user,
)
@receiver(populate_user) @receiver(populate_user)
def create_email_address(user, **kwargs): def create_email_address(user, **kwargs):
+2 -2
View File
@@ -31,8 +31,8 @@ def get_model_permission_string(model: models.Model, permission: str) -> str:
Returns: Returns:
str: The permission string (e.g. 'part.view_part') str: The permission string (e.g. 'part.view_part')
""" """
model, app = split_model(model) _model, _app = split_model(model)
return f'{app}.{permission}_{model}' return f'{_app}.{permission}_{_model}'
def split_permission(app: str, perm: str) -> tuple[str, str]: def split_permission(app: str, perm: str) -> tuple[str, str]:
+1 -1
View File
@@ -281,7 +281,7 @@ class GroupSerializer(InvenTreeModelSerializer):
class ExtendedUserSerializer(UserSerializer): class ExtendedUserSerializer(UserSerializer):
"""Serializer for a User with a bit more info.""" """Serializer for a User with a bit more info."""
from users.serializers import GroupSerializer # from users.serializers import GroupSerializer
class Meta(UserSerializer.Meta): class Meta(UserSerializer.Meta):
"""Metaclass defines serializer fields.""" """Metaclass defines serializer fields."""
+3 -1
View File
@@ -1,5 +1,7 @@
"""Background tasks for the users app.""" """Background tasks for the users app."""
from typing import Any
from django.contrib.auth.models import Group, Permission from django.contrib.auth.models import Group, Permission
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
@@ -99,7 +101,7 @@ def update_group_roles(group: Group, debug: bool = False) -> None:
permissions_to_delete.add(permission_string) permissions_to_delete.add(permission_string)
# Pre-fetch all the RuleSet objects # Pre-fetch all the RuleSet objects
rulesets = { rulesets: dict[Any, RuleSet] = {
r.name: r for r in RuleSet.objects.filter(group=group).prefetch_related('group') r.name: r for r in RuleSet.objects.filter(group=group).prefetch_related('group')
} }
+4
View File
@@ -131,6 +131,8 @@ class UserAPITests(InvenTreeAPITestCase):
def test_user_detail(self): def test_user_detail(self):
"""Test the UserDetail API endpoint.""" """Test the UserDetail API endpoint."""
user = User.objects.first() user = User.objects.first()
assert user
url = reverse('api-user-detail', kwargs={'pk': user.pk}) url = reverse('api-user-detail', kwargs={'pk': user.pk})
user.is_staff = False user.is_staff = False
@@ -274,6 +276,7 @@ class UserTokenTests(InvenTreeAPITestCase):
# If we re-generate a token, the value changes # If we re-generate a token, the value changes
token = ApiToken.objects.filter(name='cat').first() token = ApiToken.objects.filter(name='cat').first()
assert token
# Request the token with the same name # Request the token with the same name
data = self.get(url, data={'name': 'cat'}, expected_code=200).data data = self.get(url, data={'name': 'cat'}, expected_code=200).data
@@ -331,6 +334,7 @@ class UserTokenTests(InvenTreeAPITestCase):
# Grab the token, and update # Grab the token, and update
token = ApiToken.objects.first() token = ApiToken.objects.first()
assert token
self.assertEqual(token.key, token_key) self.assertEqual(token.key, token_key)
self.assertIsNotNone(token.last_seen) self.assertIsNotNone(token.last_seen)
@@ -1,6 +1,7 @@
"""Template tag to render SPA imports.""" """Template tag to render SPA imports."""
import json import json
import json.decoder
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
+3
View File
@@ -9,3 +9,6 @@ pip-tools # Compile pip requirements
pre-commit # Git pre-commit pre-commit # Git pre-commit
setuptools # Standard dependency setuptools # Standard dependency
pdfminer.six # PDF validation pdfminer.six # PDF validation
ty # type checking
django-types # typing
django-stubs # typing
+46
View File
@@ -6,6 +6,7 @@ asgiref==3.9.1 \
# via # via
# -c src/backend/requirements.txt # -c src/backend/requirements.txt
# django # django
# django-stubs
build==1.3.0 \ build==1.3.0 \
--hash=sha256:698edd0ea270bde950f53aed21f3a0135672206f3911e0176261a31e0e07b397 \ --hash=sha256:698edd0ea270bde950f53aed21f3a0135672206f3911e0176261a31e0e07b397 \
--hash=sha256:7145f0b5061ba90a1500d60bd1b13ca0a8a4cebdd0cc16ed8adf1c0e739f43b4 --hash=sha256:7145f0b5061ba90a1500d60bd1b13ca0a8a4cebdd0cc16ed8adf1c0e739f43b4
@@ -326,16 +327,30 @@ django==4.2.24 \
# via # via
# -c src/backend/requirements.txt # -c src/backend/requirements.txt
# django-slowtests # django-slowtests
# django-stubs
# django-stubs-ext
django-querycount==0.8.3 \ django-querycount==0.8.3 \
--hash=sha256:0782484e8a1bd29498fa0195a67106e47cdcc98fafe80cebb1991964077cb694 --hash=sha256:0782484e8a1bd29498fa0195a67106e47cdcc98fafe80cebb1991964077cb694
# via -r src/backend/requirements-dev.in # via -r src/backend/requirements-dev.in
django-slowtests==1.1.1 \ django-slowtests==1.1.1 \
--hash=sha256:3c6936d420c9df444ac03625b41d97de043c662bbde61fbcd33e4cd407d0c247 --hash=sha256:3c6936d420c9df444ac03625b41d97de043c662bbde61fbcd33e4cd407d0c247
# via -r src/backend/requirements-dev.in # via -r src/backend/requirements-dev.in
django-stubs==5.1.3 \
--hash=sha256:716758ced158b439213062e52de6df3cff7c586f9f9ad7ab59210efbea5dfe78 \
--hash=sha256:8c230bc5bebee6da282ba8a27ad1503c84a0c4cd2f46e63d149e76d2a63e639a
# via -r src/backend/requirements-dev.in
django-stubs-ext==5.1.3 \
--hash=sha256:3e60f82337f0d40a362f349bf15539144b96e4ceb4dbd0239be1cd71f6a74ad0 \
--hash=sha256:64561fbc53e963cc1eed2c8eb27e18b8e48dcb90771205180fe29fc8a59e55fd
# via django-stubs
django-test-migrations==1.4.0 \ django-test-migrations==1.4.0 \
--hash=sha256:294dff98f6d43d020d4046b971bac5339e7c71458a35e9ad6450c388fe16ed6b \ --hash=sha256:294dff98f6d43d020d4046b971bac5339e7c71458a35e9ad6450c388fe16ed6b \
--hash=sha256:f0c9c92864ed27d0c9a582e92056637e91227f54bd868a50cb9a1726668c563e --hash=sha256:f0c9c92864ed27d0c9a582e92056637e91227f54bd868a50cb9a1726668c563e
# via -r src/backend/requirements-dev.in # via -r src/backend/requirements-dev.in
django-types==0.20.0 \
--hash=sha256:4e55d2c56155e3d69d75def9eb1d95a891303f2ac19fccf6fe8056da4293fae7 \
--hash=sha256:a0b5c2c9a1e591684bb21a93b64e50ca6cb2d3eab48f49faff1eac706bd3a9c7
# via -r src/backend/requirements-dev.in
filelock==3.18.0 \ filelock==3.18.0 \
--hash=sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2 \ --hash=sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2 \
--hash=sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de --hash=sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de
@@ -505,13 +520,44 @@ tomli==2.2.1 \
# -c src/backend/requirements.txt # -c src/backend/requirements.txt
# build # build
# coverage # coverage
# django-stubs
# pip-tools # pip-tools
ty==0.0.1a20 \
--hash=sha256:0b481f26513f38543df514189fb16744690bcba8d23afee95a01927d93b46e36 \
--hash=sha256:3c2ace3a22fab4bd79f84c74e3dab26e798bfba7006bea4008d6321c1bd6efc6 \
--hash=sha256:3ff75cd4c744d09914e8c9db8d99e02f82c9379ad56b0a3fc4c5c9c923cfa84e \
--hash=sha256:726d0738be4459ac7ffae312ba96c5f486d6cbc082723f322555d7cba9397871 \
--hash=sha256:7abbe3c02218c12228b1d7c5f98c57240029cc3bcb15b6997b707c19be3908c1 \
--hash=sha256:83a7ee12465841619b5eb3ca962ffc7d576bb1c1ac812638681aee241acbfbbe \
--hash=sha256:8a138fa4f74e6ed34e9fd14652d132409700c7ff57682c2fed656109ebfba42f \
--hash=sha256:8eff8871d6b88d150e2a67beba2c57048f20c090c219f38ed02eebaada04c124 \
--hash=sha256:933b65a152f277aa0e23ba9027e5df2c2cc09e18293e87f2a918658634db5f15 \
--hash=sha256:b4124ab75e0e6f09fe7bc9df4a77ee43c5e0ef7e61b0c149d7c089d971437cbd \
--hash=sha256:b8c4336987a6a781d4392a9fd7b3a39edb7e4f3dd4f860e03f46c932b52aefa2 \
--hash=sha256:cad12c857ea4b97bf61e02f6796e13061ccca5e41f054cbd657862d80aa43bae \
--hash=sha256:d8ac1c5a14cda5fad1a8b53959d9a5d979fe16ce1cc2785ea8676fed143ac85f \
--hash=sha256:e26437772be7f7808868701f2bf9e14e706a6ec4c7d02dbd377ff94d7ba60c11 \
--hash=sha256:f153b65c7fcb6b8b59547ddb6353761b3e8d8bb6f0edd15e3e3ac14405949f7a \
--hash=sha256:f41e77ff118da3385915e13c3f366b3a2f823461de54abd2e0ca72b170ba0f19 \
--hash=sha256:f73a7aca1f0d38af4d6999b375eb00553f3bfcba102ae976756cc142e14f3450 \
--hash=sha256:fff51c75ee3f7cc6d7722f2f15789ef8ffe6fd2af70e7269ac785763c906688e
# via -r src/backend/requirements-dev.in
types-psycopg2==2.9.21.20250915 \
--hash=sha256:bfeb8f54c32490e7b5edc46215ab4163693192bc90407b4a023822de9239f5c8 \
--hash=sha256:eefe5ccdc693fc086146e84c9ba437bb278efe1ef330b299a0cb71169dc6c55f
# via django-types
types-pyyaml==6.0.12.20250915 \
--hash=sha256:0f8b54a528c303f0e6f7165687dd33fafa81c807fcac23f632b63aa624ced1d3 \
--hash=sha256:e7d4d9e064e89a3b3cae120b4990cd370874d2bf12fa5f46c97018dd5d3c9ab6
# via django-stubs
typing-extensions==4.14.1 \ typing-extensions==4.14.1 \
--hash=sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36 \ --hash=sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36 \
--hash=sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76 --hash=sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76
# via # via
# -c src/backend/requirements.txt # -c src/backend/requirements.txt
# asgiref # asgiref
# django-stubs
# django-stubs-ext
# django-test-migrations # django-test-migrations
virtualenv==20.33.1 \ virtualenv==20.33.1 \
--hash=sha256:07c19bc66c11acab6a5958b815cbcee30891cd1c2ccf53785a28651a0d8d8a67 \ --hash=sha256:07c19bc66c11acab6a5958b815cbcee30891cd1c2ccf53785a28651a0d8d8a67 \
+13 -7
View File
@@ -77,7 +77,7 @@ def is_pkg_installer_by_path():
def get_installer(content: Optional[dict] = None): def get_installer(content: Optional[dict] = None):
"""Get the installer for the current environment or a content dict.""" """Get the installer for the current environment or a content dict."""
if content is None: if content is None:
content = os.environ content = dict(os.environ)
return content.get('INVENTREE_PKG_INSTALLER', None) return content.get('INVENTREE_PKG_INSTALLER', None)
@@ -461,7 +461,9 @@ def check_file_existence(filename: Path, overwrite: bool = False):
@state_logger('TASK01') @state_logger('TASK01')
def plugins(c, uv=False): def plugins(c, uv=False):
"""Installs all plugins as specified in 'plugins.txt'.""" """Installs all plugins as specified in 'plugins.txt'."""
from src.backend.InvenTree.InvenTree.config import get_plugin_file from src.backend.InvenTree.InvenTree.config import ( # type: ignore[import]
get_plugin_file,
)
plugin_file = get_plugin_file() plugin_file = get_plugin_file()
@@ -573,7 +575,9 @@ def rebuild_models(c):
@task @task
def rebuild_thumbnails(c): def rebuild_thumbnails(c):
"""Rebuild missing image thumbnails.""" """Rebuild missing image thumbnails."""
from src.backend.InvenTree.InvenTree.config import get_media_dir from src.backend.InvenTree.InvenTree.config import ( # type: ignore[import]
get_media_dir,
)
info(f'Rebuilding image thumbnails in {get_media_dir()}') info(f'Rebuilding image thumbnails in {get_media_dir()}')
manage(c, 'rebuild_thumbnails', pty=True) manage(c, 'rebuild_thumbnails', pty=True)
@@ -1165,7 +1169,7 @@ def test_translations(c):
info('Fill in dummy translations...') info('Fill in dummy translations...')
file_path = pathlib.Path(settings.LOCALE_PATHS[0], 'xx', 'LC_MESSAGES', 'django.po') file_path = pathlib.Path(settings.LOCALE_PATHS[0], 'xx', 'LC_MESSAGES', 'django.po')
new_file_path = str(file_path) + '_new' new_file_path = Path(str(file_path) + '_new')
# compile regex # compile regex
reg = re.compile( reg = re.compile(
@@ -1303,7 +1307,9 @@ def setup_test(
path='inventree-demo-dataset', path='inventree-demo-dataset',
): ):
"""Setup a testing environment.""" """Setup a testing environment."""
from src.backend.InvenTree.InvenTree.config import get_media_dir from src.backend.InvenTree.InvenTree.config import ( # type: ignore[import]
get_media_dir,
)
if not ignore_update: if not ignore_update:
update(c) update(c)
@@ -1453,8 +1459,8 @@ def export_definitions(c, basedir: str = ''):
@task(default=True) @task(default=True)
def version(c): def version(c):
"""Show the current version of InvenTree.""" """Show the current version of InvenTree."""
import src.backend.InvenTree.InvenTree.version as InvenTreeVersion import src.backend.InvenTree.InvenTree.version as InvenTreeVersion # type: ignore[import]
from src.backend.InvenTree.InvenTree.config import ( from src.backend.InvenTree.InvenTree.config import ( # type: ignore[import]
get_backup_dir, get_backup_dir,
get_config_file, get_config_file,
get_media_dir, get_media_dir,