2
0
mirror of https://github.com/inventree/InvenTree.git synced 2026-05-13 21:17:33 +00:00

feat(backend): add typechecking with ty (#9664)

* Add ty for type checking

* fix various typing issues

* fix req

* more fixes

* and more types

* and more typing

* fix imports

* more fixes

* fix types and optional statements

* ensure patch only runs if it is installed

* add type check to qc

* more fixes

* install all reqs

* fix more types

* more fixes

* disable container stuff for now

* move typecheck to seperate job

* try to use putput for path

* use env instead

* fix typo

* add missing install

* remove unclear imports - not sure why this was done

* add kwarg names

* fix introduced issue in url call

* ignore import

* fix broken typing changes

* fix filter import

* reduce change set

* remove api-change

* fix dict

* ignore typing errors

* fix more type issues

* ignore errors

* style fix

* fix type

* bump ty

* fix more

* type fixes

* update ignores

* fix import

* fix defaults

* fix ignore

* fix some issues

* fix type
This commit is contained in:
Matthias Mair
2025-09-17 13:30:02 +02:00
committed by GitHub
parent f057247fc1
commit 21cb488eef
100 changed files with 524 additions and 267 deletions
+16 -6
View File
@@ -17,6 +17,7 @@ import os
import re
import sys
from pathlib import Path
from typing import Optional
import requests
@@ -183,7 +184,8 @@ def check_version_number(version_string, allow_duplicate=False):
return highest_release
if __name__ == '__main__':
def main() -> bool:
"""Run the version check."""
parser = argparse.ArgumentParser(description='InvenTree Version Check')
parser.add_argument(
'--show-version',
@@ -220,7 +222,7 @@ if __name__ == '__main__':
# Ensure that we are running in GH Actions
if os.environ.get('GITHUB_ACTIONS', '') != 'true':
print('This script is intended to be run within a GitHub Action!')
sys.exit(1)
return False
print('Running InvenTree version check...')
@@ -261,11 +263,11 @@ if __name__ == '__main__':
)
# Determine which docker tag we are going to use
docker_tags = None
docker_tags: Optional[list[str]] = None
if GITHUB_REF_TYPE == 'tag':
# GITHUB_REF should be of the form /refs/heads/<tag>
version_tag = GITHUB_REF.split('/')[-1]
version_tag: str = GITHUB_REF.split('/')[-1]
print(f"Checking requirements for tagged release - '{version_tag}':")
if version_tag != inventree_version:
@@ -287,11 +289,11 @@ if __name__ == '__main__':
print('GITHUB_REF_TYPE:', GITHUB_REF_TYPE)
print('GITHUB_BASE_REF:', GITHUB_BASE_REF)
print('GITHUB_REF:', GITHUB_REF)
sys.exit(1)
return False
if docker_tags is None:
print('Docker tags could not be determined')
sys.exit(1)
return False
print(f"Version check passed for '{inventree_version}'!")
print(f"Docker tags: '{docker_tags}'")
@@ -308,3 +310,11 @@ if __name__ == '__main__':
if GITHUB_REF_TYPE == 'tag' and highest_release:
env_file.write('stable_release=true\n')
return True
if __name__ == '__main__':
rslt = main()
if rslt is not True:
print('Version check failed!')
sys.exit(1)
+21
View File
@@ -97,6 +97,27 @@ jobs:
pip install --require-hashes -r contrib/dev_reqs/requirements.txt
python3 .github/scripts/version_check.py
typecheck:
name: Style [Typecheck]
runs-on: ubuntu-24.04
needs: [paths-filter, pre-commit]
if: needs.paths-filter.outputs.server == 'true' || needs.paths-filter.outputs.requirements == 'true' || needs.paths-filter.outputs.force == 'true'
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # pin@v4.2.2
with:
persist-credentials: false
- name: Environment Setup
id: setup
uses: ./.github/actions/setup
with:
apt-dependency: gettext poppler-utils
dev-install: true
update: true
- name: Check types
run: |
ty check --python ${Python_ROOT_DIR}/bin/python3
mkdocs:
name: Style [Documentation]
runs-on: ubuntu-24.04
+1 -1
View File
@@ -4,7 +4,7 @@ import json
import os
import re
from datetime import datetime
from distutils.version import StrictVersion
from distutils.version import StrictVersion # type: ignore[import]
from pathlib import Path
import requests
+1 -1
View File
@@ -126,7 +126,7 @@ def check_link(url) -> bool:
return False
def get_build_environment() -> str:
def get_build_environment() -> Optional[str]:
"""Returns the branch we are currently building on, based on the environment variables of the various CI platforms."""
# Check if we are in ReadTheDocs
if os.environ.get('READTHEDOCS') == 'True':
+14
View File
@@ -101,6 +101,20 @@ python-version = "3.9.2"
no-strip-extras=true
generate-hashes=true
[tool.ty.src]
root = "src/backend/InvenTree"
[tool.ty.rules]
unresolved-reference="ignore" # 21 # see https://github.com/astral-sh/ty/issues/220
unresolved-attribute="ignore" # 505 # need Plugin Mixin typing
call-non-callable="ignore" # 8 ##
invalid-return-type="ignore" # 22 ##
invalid-argument-type="ignore" # 49
possibly-unbound-attribute="ignore" # 25 # https://github.com/astral-sh/ty/issues/164
unknown-argument="ignore" # 3 # need to wait for betterdjango field stubs
invalid-assignment="ignore" # 17 # need to wait for betterdjango field stubs
[tool.coverage.run]
source = ["src/backend/InvenTree", "InvenTree"]
dynamic_context = "test_function"
@@ -48,10 +48,13 @@ class AllUserRequire2FAMiddleware(MiddlewareMixin):
def is_allowed_page(self, request: HttpRequest) -> bool:
"""Check if the current page can be accessed without mfa."""
match = request.resolver_match
return (
any(ref in self.app_names for ref in request.resolver_match.app_names)
or request.resolver_match.url_name in self.allowed_pages
or request.resolver_match.route == 'favicon.ico'
None
if match is None
else any(ref in self.app_names for ref in match.app_names)
or match.url_name in self.allowed_pages
or match.route == 'favicon.ico'
)
def enforce_2fa(self, request):
+1
View File
@@ -18,6 +18,7 @@ from rest_framework.serializers import ValidationError
from rest_framework.views import APIView
import InvenTree.config
import InvenTree.permissions
import InvenTree.version
from common.settings import get_global_setting
from InvenTree import helpers
+3
View File
@@ -131,6 +131,9 @@ class InvenTreeConfig(AppConfig):
tasks = InvenTree.tasks.tasks.task_list
for task in tasks:
if not task:
continue # pragma: no cover
ref_name = f'{task.func.__module__}.{task.func.__name__}'
if ref_name in existing_tasks:
+3 -2
View File
@@ -2,6 +2,7 @@
import socket
import threading
from typing import Any
import structlog
@@ -140,7 +141,7 @@ def delete_session_cache() -> None:
del thread_data.request_cache
def get_session_cache(key: str) -> any:
def get_session_cache(key: str) -> Any:
"""Return a cached value from the session cache."""
# Only return a cached value if the request object is available too
if not hasattr(thread_data, 'request'):
@@ -152,7 +153,7 @@ def get_session_cache(key: str) -> any:
return val
def set_session_cache(key: str, value: any) -> None:
def set_session_cache(key: str, value: Any) -> None:
"""Set a cached value in the session cache."""
# Only set a cached value if the request object is available too
if not hasattr(thread_data, 'request'):
+1 -1
View File
@@ -171,7 +171,7 @@ def get_config_file(create=True) -> Path:
return cfg_filename
def load_config_data(set_cache: bool = False) -> map:
def load_config_data(set_cache: bool = False) -> Union[map, None]:
"""Load configuration data from the config file.
Arguments:
@@ -191,7 +191,7 @@ def convert_physical_value(value: str, unit: Optional[str] = None, strip_units=T
attempts.append(f'{value}{unit}')
attempts.append(f'{eng}{unit}')
value = None
value: Optional[str] = None
# Run through the available "attempts", take the first successful result
for attempt in attempts:
@@ -66,7 +66,8 @@ def log_error(
data = error_data
else:
try:
data = '\n'.join(traceback.format_exception(kind, info, data))
formatted_exception = traceback.format_exception(kind, info, data) # type: ignore[no-matching-overload]
data = '\n'.join(formatted_exception)
except AttributeError:
data = 'No traceback information available'
@@ -150,8 +151,10 @@ def exception_handler(exc, context):
if response is not None:
# Convert errors returned under the label '__all__' to 'non_field_errors'
if '__all__' in response.data:
response.data['non_field_errors'] = response.data['__all__']
del response.data['__all__']
data = response.data
if data and '__all__' in data:
data['non_field_errors'] = data['__all__']
del data['__all__']
return response
+7 -6
View File
@@ -6,7 +6,8 @@ from django.conf import settings
from django.utils import timezone
from django.utils.timezone import make_aware
from django_filters import rest_framework as rest_filters
import django_filters.rest_framework.backends as drf_backend
import django_filters.rest_framework.filters as rest_filters
from rest_framework import filters
import InvenTree.helpers
@@ -20,7 +21,7 @@ class InvenTreeDateFilter(rest_filters.DateFilter):
if settings.USE_TZ and value is not None:
tz = timezone.get_current_timezone()
value = datetime(value.year, value.month, value.day)
value = make_aware(value, tz, True)
value = make_aware(value, timezone=tz, is_dst=True)
return super().filter(qs, value)
@@ -192,17 +193,17 @@ class NumberOrNullFilter(rest_filters.NumberFilter):
SEARCH_ORDER_FILTER = [
rest_filters.DjangoFilterBackend,
drf_backend.DjangoFilterBackend,
InvenTreeSearchFilter,
filters.OrderingFilter,
]
SEARCH_ORDER_FILTER_ALIAS = [
rest_filters.DjangoFilterBackend,
drf_backend.DjangoFilterBackend,
InvenTreeSearchFilter,
InvenTreeOrderingFilter,
]
ORDER_FILTER = [rest_filters.DjangoFilterBackend, filters.OrderingFilter]
ORDER_FILTER = [drf_backend.DjangoFilterBackend, filters.OrderingFilter]
ORDER_FILTER_ALIAS = [rest_filters.DjangoFilterBackend, InvenTreeOrderingFilter]
ORDER_FILTER_ALIAS = [drf_backend.DjangoFilterBackend, InvenTreeOrderingFilter]
+2 -2
View File
@@ -107,7 +107,7 @@ def construct_format_regex(fmt_string: str) -> str:
# Add a named capture group for the format entry
if name:
# Check if integer values are required
c = '\\d' if _fmt.endswith('d') else '.'
c = '\\d' if _fmt and _fmt.endswith('d') else '.'
# Specify width
# TODO: Introspect required width
@@ -124,7 +124,7 @@ def construct_format_regex(fmt_string: str) -> str:
return pattern
def validate_string(value: str, fmt_string: str) -> str:
def validate_string(value: str, fmt_string: str) -> bool:
"""Validate that the provided string matches the specified format.
Args:
+11 -11
View File
@@ -9,7 +9,7 @@ import os
import os.path
import re
from decimal import Decimal, InvalidOperation
from typing import Optional, TypeVar
from typing import Optional, TypeVar, Union
from wsgiref.util import FileWrapper
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
@@ -22,6 +22,8 @@ from django.utils import timezone
from django.utils.translation import gettext_lazy as _
import bleach
import bleach.css_sanitizer
import bleach.sanitizer
import structlog
from bleach import clean
from djmoney.money import Money
@@ -124,7 +126,7 @@ def extract_int(
return ref_int
def generateTestKey(test_name: str) -> str:
def generateTestKey(test_name: Union[str, None]) -> str:
"""Generate a test 'key' for a given test name. This must not have illegal chars as it will be used for dict lookup in a template.
Tests must be named such that they will have unique keys.
@@ -366,9 +368,7 @@ def increment(value):
except ValueError:
pass
number = number.zfill(width)
return prefix + number
return prefix + str(number).zfill(width)
def decimal2string(d):
@@ -966,7 +966,7 @@ def current_time(local=True):
"""
if settings.USE_TZ:
now = timezone.now()
now = to_local_time(now, target_tz=server_timezone() if local else 'UTC')
now = to_local_time(now, target_tz_str=server_timezone() if local else 'UTC')
return now
else:
return datetime.datetime.now()
@@ -985,12 +985,12 @@ def server_timezone() -> str:
return settings.TIME_ZONE
def to_local_time(time, target_tz: Optional[str] = None):
def to_local_time(time, target_tz_str: Optional[str] = None):
"""Convert the provided time object to the local timezone.
Arguments:
time: The time / date to convert
target_tz: The desired timezone (string) - defaults to server time
target_tz_str: The desired timezone (string) - defaults to server time
Returns:
A timezone aware datetime object, with the desired timezone
@@ -1014,11 +1014,11 @@ def to_local_time(time, target_tz: Optional[str] = None):
# Default to UTC if not provided
source_tz = ZoneInfo('UTC')
if not target_tz:
target_tz = server_timezone()
if not target_tz_str:
target_tz_str = server_timezone()
try:
target_tz = ZoneInfo(str(target_tz))
target_tz = ZoneInfo(str(target_tz_str))
except ZoneInfoNotFoundError:
target_tz = ZoneInfo('UTC')
@@ -114,7 +114,7 @@ def send_email(
return True, None
def get_email_for_user(user) -> str:
def get_email_for_user(user) -> Optional[str]:
"""Find an email address for the specified user."""
# First check if the user has an associated email address
if user.email:
@@ -11,6 +11,7 @@ from django.db.utils import OperationalError, ProgrammingError
from django.utils.translation import gettext_lazy as _
import requests
import requests.exceptions
import structlog
from djmoney.contrib.exchange.models import convert_money
from djmoney.money import Money
@@ -328,8 +329,9 @@ def notify_users(
'template': {'subject': content.name.format(**content_context)},
}
if content.template:
context['template']['html'] = content.template.format(**content_context)
tmp = content.template
if tmp:
context['template']['html'] = tmp.format(**content_context)
# Create notification
trigger_notification(
@@ -1,7 +1,7 @@
"""Extended schema generator."""
from pathlib import Path
from typing import TypeVar
from typing import TypeVar, Union
from django.conf import settings
@@ -26,7 +26,7 @@ def prep_name(ref):
return f'{dja_ref_prefix}.{ref}'
def sub_component_name(name: T) -> T:
def sub_component_name(name: T) -> Union[T, str]:
"""Clean up component references."""
if not isinstance(name, str):
return name
@@ -2,9 +2,10 @@
import time
from django.core.exceptions import ImproperlyConfigured
from django.core.management.base import BaseCommand
from django.db import connection
from django.db.utils import ImproperlyConfigured, OperationalError
from django.db.utils import OperationalError
class Command(BaseCommand):
+1 -1
View File
@@ -1172,7 +1172,7 @@ class InvenTreeBarcodeMixin(models.Model):
raise ValueError("Provide either 'barcode_hash' or 'barcode_data'")
# If barcode_hash is not provided, create from supplier barcode_data
if barcode_hash is None:
if barcode_hash is None and barcode_data is not None:
barcode_hash = InvenTree.helpers.hash_barcode(barcode_data)
# Check for existing item
+11 -5
View File
@@ -256,7 +256,9 @@ class RolePermissionOrReadOnly(RolePermission):
def get_required_alternate_scopes(self, request, view):
"""Return the required scopes for the current request."""
scopes = map_scope(
only_read=True, read_name=DEFAULT_STAFF, map_read=permissions.SAFE_METHODS
only_read=True,
read_name=DEFAULT_STAFF,
map_read=list(permissions.SAFE_METHODS),
)
return scopes
@@ -294,7 +296,7 @@ class IsSuperuserOrReadOnlyOrScope(OASTokenMixin, permissions.IsAdminUser):
return map_scope(
only_read=True,
read_name=DEFAULT_SUPERUSER,
map_read=permissions.SAFE_METHODS,
map_read=list(permissions.SAFE_METHODS),
)
@@ -319,7 +321,9 @@ class IsStaffOrReadOnlyScope(OASTokenMixin, permissions.IsAuthenticated):
def get_required_alternate_scopes(self, request, view):
"""Return the required scopes for the current request."""
return map_scope(
only_read=True, read_name=DEFAULT_STAFF, map_read=permissions.SAFE_METHODS
only_read=True,
read_name=DEFAULT_STAFF,
map_read=list(permissions.SAFE_METHODS),
)
@@ -349,7 +353,7 @@ def auth_exempt(view_func):
def wrapped_view(*args, **kwargs):
return view_func(*args, **kwargs)
wrapped_view.auth_exempt = True
wrapped_view.auth_exempt = True # type:ignore[unresolved-attribute]
return wraps(view_func)(wrapped_view)
@@ -400,7 +404,9 @@ class GlobalSettingsPermissions(OASTokenMixin, permissions.BasePermission):
def get_required_alternate_scopes(self, request, view):
"""Return the required scopes for the current request."""
return map_scope(
only_read=True, read_name=DEFAULT_STAFF, map_read=permissions.SAFE_METHODS
only_read=True,
read_name=DEFAULT_STAFF,
map_read=list(permissions.SAFE_METHODS),
)
+2 -2
View File
@@ -188,8 +188,8 @@ ALLOWED_ATTRIBUTES_SVG = [
def sanitize_svg(
file_data,
strip: bool = True,
elements: str = ALLOWED_ELEMENTS_SVG,
attributes: str = ALLOWED_ATTRIBUTES_SVG,
elements: list[str] = ALLOWED_ELEMENTS_SVG,
attributes: list[str] = ALLOWED_ATTRIBUTES_SVG,
) -> str:
"""Sanitize a SVG file.
@@ -373,16 +373,15 @@ class InvenTreeModelSerializer(serializers.ModelSerializer):
instance.full_clean()
except (ValidationError, DjangoValidationError) as exc:
if hasattr(exc, 'message_dict'):
data = exc.message_dict
data = {**exc.message_dict}
elif hasattr(exc, 'message'):
data = {'non_field_errors': [str(exc.message)]}
else:
data = {'non_field_errors': [str(exc)]}
# Change '__all__' key (django style) to 'non_field_errors' (DRF style)
if '__all__' in data:
data['non_field_errors'] = data['__all__']
del data['__all__']
if hasattr(data, '__all__'):
data['non_field_errors'] = data.pop('__all__')
raise ValidationError(data)
+27 -15
View File
@@ -43,6 +43,13 @@ from users.oauth2_scopes import oauth2_scopes
from . import config, locales
try:
import django_stubs_ext
django_stubs_ext.monkeypatch() # pragma: no cover
except ImportError: # pragma: no cover
pass
checkMinPythonVersion()
INVENTREE_BASE_URL = 'https://inventree.org'
@@ -382,22 +389,25 @@ QUERYCOUNT = {
}
AUTHENTICATION_BACKENDS = CONFIG.get(
'authentication_backends',
[
'oauth2_provider.backends.OAuth2Backend', # OAuth2 provider
'django.contrib.auth.backends.RemoteUserBackend', # proxy login
'django.contrib.auth.backends.ModelBackend',
'allauth.account.auth_backends.AuthenticationBackend', # SSO login via external providers
'sesame.backends.ModelBackend', # Magic link login django-sesame
],
default_auth_backends = [
'oauth2_provider.backends.OAuth2Backend', # OAuth2 provider
'django.contrib.auth.backends.RemoteUserBackend', # proxy login
'django.contrib.auth.backends.ModelBackend',
'allauth.account.auth_backends.AuthenticationBackend', # SSO login via external providers
'sesame.backends.ModelBackend', # Magic link login django-sesame
]
AUTHENTICATION_BACKENDS = (
CONFIG.get('authentication_backends', default_auth_backends)
if CONFIG
else default_auth_backends
)
# LDAP support
LDAP_AUTH = get_boolean_setting('INVENTREE_LDAP_ENABLED', 'ldap.enabled', False)
if LDAP_AUTH:
import django_auth_ldap.config
import ldap
import django_auth_ldap.config # type: ignore[unresolved-import]
import ldap # type: ignore[unresolved-import]
AUTHENTICATION_BACKENDS.append('django_auth_ldap.backend.LDAPBackend')
@@ -450,7 +460,7 @@ if LDAP_AUTH:
)
AUTH_LDAP_USER_SEARCH = django_auth_ldap.config.LDAPSearch(
get_setting('INVENTREE_LDAP_SEARCH_BASE_DN', 'ldap.search_base_dn'),
ldap.SCOPE_SUBTREE,
ldap.SCOPE_SUBTREE, # type: ignore[unresolved-attribute]
str(
get_setting(
'INVENTREE_LDAP_SEARCH_FILTER_STR',
@@ -486,7 +496,7 @@ if LDAP_AUTH:
)
AUTH_LDAP_GROUP_SEARCH = django_auth_ldap.config.LDAPSearch(
get_setting('INVENTREE_LDAP_GROUP_SEARCH', 'ldap.group_search'),
ldap.SCOPE_SUBTREE,
ldap.SCOPE_SUBTREE, # type: ignore[unresolved-attribute]
f'(objectClass={AUTH_LDAP_GROUP_OBJECT_CLASS})',
)
AUTH_LDAP_GROUP_TYPE_CLASS = get_setting(
@@ -604,7 +614,7 @@ Configure the database backend based on the user-specified values.
logger.debug('Configuring database backend:')
# Extract database configuration from the config.yaml file
db_config = CONFIG.get('database', None)
db_config = CONFIG.get('database', None) if CONFIG else None
if not db_config:
db_config = {}
@@ -690,7 +700,9 @@ if db_options is None:
# Specific options for postgres backend
if 'postgres' in DB_ENGINE: # pragma: no cover
from django.db.backends.postgresql.psycopg_any import IsolationLevel
from django.db.backends.postgresql.psycopg_any import ( # type: ignore[unresolved-import]
IsolationLevel,
)
# Connection timeout
if 'connect_timeout' not in db_options:
+2 -1
View File
@@ -50,7 +50,7 @@ def check_provider(provider):
if not app:
return False
if allauth.app_settings.SITES_ENABLED:
if allauth.app_settings.SITES_ENABLED: # type: ignore[unresolved-attribute]
# At least one matching site must be specified
if not app.sites.exists():
logger.error('SocialApp %s has no sites configured', app)
@@ -102,6 +102,7 @@ def ensure_sso_groups(sender, sociallogin: SocialLogin, **kwargs):
# ensure user has groups
user = sociallogin.account.user
for group_name in group_names:
try:
user.groups.get(name=group_name)
+5 -3
View File
@@ -285,7 +285,7 @@ class ScheduledTask:
QUARTERLY: str = 'Q'
YEARLY: str = 'Y'
TYPE: tuple[str] = (MINUTES, HOURLY, DAILY, WEEKLY, MONTHLY, QUARTERLY, YEARLY)
TYPE: tuple[str] = (MINUTES, HOURLY, DAILY, WEEKLY, MONTHLY, QUARTERLY, YEARLY) # type: ignore[invalid-assignment]
class TaskRegister:
@@ -302,7 +302,9 @@ tasks = TaskRegister()
def scheduled_task(
interval: str, minutes: Optional[int] = None, tasklist: TaskRegister = None
interval: str,
minutes: Optional[int] = None,
tasklist: Optional[TaskRegister] = None,
):
"""Register the given task as a scheduled task.
@@ -544,7 +546,7 @@ def check_for_updates():
match = re.match(r'^.*(\d+)\.(\d+)\.(\d+).*$', tag)
if len(match.groups()) != 3: # pragma: no cover
if not match or len(match.groups()) != 3: # pragma: no cover
logger.warning("Version '%s' did not match expected pattern", tag)
return
+1 -1
View File
@@ -567,7 +567,7 @@ class GeneralApiTests(InvenTreeAPITestCase):
self.assertIn('License file not found at', str(log.output))
with TemporaryDirectory() as tmp:
with TemporaryDirectory() as tmp: # type: ignore[no-matching-overload]
sample_file = Path(tmp, 'temp.txt')
sample_file.write_text('abc')
+8 -3
View File
@@ -4,7 +4,7 @@ import base64
import logging
from typing import Optional
from opentelemetry import metrics, trace
from opentelemetry import metrics, trace # type: ignore[import]
from opentelemetry.instrumentation.django import DjangoInstrumentor
from opentelemetry.instrumentation.redis import RedisInstrumentor
from opentelemetry.instrumentation.requests import RequestsInstrumentor
@@ -29,8 +29,8 @@ TRACE_PROV = None
def setup_tracing(
endpoint: str,
headers: dict,
endpoint: Optional[str] = None,
headers: Optional[dict] = None,
resources_input: Optional[dict] = None,
console: bool = False,
auth: Optional[dict] = None,
@@ -50,6 +50,11 @@ def setup_tracing(
"""
if InvenTree.ready.isImportingData() or InvenTree.ready.isRunningMigrations():
return
if endpoint is None or headers is None:
print(
'Tracing endpoint or headers not specified - skipping tracing setup'
) # pragma: no cover
return # pragma: no cover
# Logger configuration
logger = logging.getLogger('inventree')
+5 -1
View File
@@ -115,7 +115,7 @@ def getOldestMigrationFile(app, exclude_extension=True, ignore_initial=True):
oldest_num = num
oldest_file = f
if exclude_extension:
if exclude_extension and oldest_file:
oldest_file = oldest_file.replace('.py', '')
return oldest_file
@@ -583,6 +583,10 @@ class InvenTreeAPITestCase(
result = re.search(
r'(attachment|inline); filename=[\'"]([\w\d\-.]+)[\'"]', disposition
)
if not result:
raise ValueError(
'No filename match found in disposition'
) # pragma: no cover
fn = result.groups()[1]
@@ -6,6 +6,7 @@ from django.core.exceptions import ValidationError
from django.utils.translation import gettext_lazy as _
import pint
import pint.errors
from moneyed import CURRENCIES
import InvenTree.conversion
+1 -1
View File
@@ -107,7 +107,7 @@ def inventreeVersionTuple(version=None):
match = re.match(r'^.*(\d+)\.(\d+)\.(\d+).*$', str(version))
return [int(g) for g in match.groups()]
return [int(g) for g in match.groups()] if match else []
def isInvenTreeDevelopmentVersion():
+10 -6
View File
@@ -2,17 +2,21 @@
from __future__ import annotations
from typing import Optional
from django.contrib.auth.models import User
from django.db.models import F, Q
from django.urls import include, path
from django.utils.translation import gettext_lazy as _
from django_filters import rest_framework as rest_filters
from django_filters.rest_framework.filterset import FilterSet
from drf_spectacular.utils import extend_schema, extend_schema_field
from rest_framework import serializers, status
from rest_framework.exceptions import ValidationError
from rest_framework.response import Response
import build.models as build_models
import build.serializers
import common.models
import part.models as part_models
@@ -33,7 +37,7 @@ from InvenTree.mixins import CreateAPI, ListCreateAPI, RetrieveUpdateDestroyAPI
from users.models import Owner
class BuildFilter(rest_filters.FilterSet):
class BuildFilter(FilterSet):
"""Custom filterset for BuildList API endpoint."""
class Meta:
@@ -431,7 +435,7 @@ class BuildUnallocate(CreateAPI):
return ctx
class BuildLineFilter(rest_filters.FilterSet):
class BuildLineFilter(FilterSet):
"""Custom filterset for the BuildLine API endpoint."""
class Meta:
@@ -605,7 +609,7 @@ class BuildLineList(BuildLineMixin, DataExportViewMixin, ListCreateAPI):
'bom_item__reference',
]
def get_source_build(self) -> Build | None:
def get_source_build(self) -> Optional[Build]:
"""Return the target build for the BuildLine queryset."""
source_build = None
@@ -622,7 +626,7 @@ class BuildLineList(BuildLineMixin, DataExportViewMixin, ListCreateAPI):
class BuildLineDetail(BuildLineMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a BuildLine object."""
def get_source_build(self) -> Build | None:
def get_source_build(self) -> Optional[Build]:
"""Return the target source location for the BuildLine queryset."""
return None
@@ -783,7 +787,7 @@ class BuildItemDetail(RetrieveUpdateDestroyAPI):
serializer_class = build.serializers.BuildItemSerializer
class BuildItemFilter(rest_filters.FilterSet):
class BuildItemFilter(FilterSet):
"""Custom filterset for the BuildItemList API endpoint."""
class Meta:
@@ -829,7 +833,7 @@ class BuildItemFilter(rest_filters.FilterSet):
return queryset.filter(stock_item__part=part)
build = rest_filters.ModelChoiceFilter(
queryset=build.models.Build.objects.all(),
queryset=build_models.Build.objects.all(),
label=_('Build Order'),
field_name='build_line__build',
)
+2 -2
View File
@@ -1064,7 +1064,7 @@ class Build(
lines = lines.exclude(bom_item__consumable=True)
lines = annotate_allocated_quantity(lines)
for build_line in lines:
for build_line in lines: # type: ignore[non-iterable]
reduce_by = build_line.allocated - build_line.quantity
if reduce_by <= 0:
@@ -1359,7 +1359,7 @@ class Build(
except (ValidationError, serializers.ValidationError) as exc:
# Catch model errors and re-throw as DRF errors
raise ValidationError(
detail=serializers.as_serializer_error(exc)
exc.message, detail=serializers.as_serializer_error(exc)
)
if unallocated_quantity <= 0:
@@ -23,6 +23,7 @@ from rest_framework.serializers import ValidationError
import build.tasks
import common.models
import common.settings
import company.serializers
import InvenTree.helpers
import InvenTree.tasks
+17 -2
View File
@@ -1,6 +1,7 @@
"""Unit tests for the BuildOrder API."""
from datetime import datetime, timedelta
from typing import Optional
from django.urls import reverse
@@ -668,6 +669,11 @@ class BuildAllocationTest(BuildAPITest):
wrong_line = line
break
if not wrong_line:
raise self.fail(
'No matching BuildLine found for the given stock item'
) # pragma: no cover
data = self.post(
self.url,
{
@@ -695,6 +701,11 @@ class BuildAllocationTest(BuildAPITest):
right_line = line
break
if not right_line:
raise self.fail(
'No matching BuildLine found for the given stock item'
) # pragma: no cover
self.post(
self.url,
{
@@ -722,11 +733,15 @@ class BuildAllocationTest(BuildAPITest):
# Find the correct BuildLine
si = StockItem.objects.get(pk=2)
right_line = None
right_line: Optional[BuildLine] = None
for line in self.build.build_lines.all():
if line.bom_item.sub_part.pk == si.part.pk:
right_line = line
right_line: BuildLine = line
break
if not right_line:
raise self.fail(
'No matching BuildLine found for the given stock item'
) # pragma: no cover
self.post(
self.url,
+4 -2
View File
@@ -1,6 +1,7 @@
"""Provides a JSON API for common components."""
import json
import json.decoder
from django.conf import settings
from django.contrib.contenttypes.models import ContentType
@@ -13,8 +14,9 @@ from django.utils.translation import gettext_lazy as _
from django.views.decorators.cache import cache_control
from django.views.decorators.csrf import csrf_exempt
import django_filters.rest_framework.filters as rest_filters
import django_q.models
from django_filters import rest_framework as rest_filters
from django_filters.rest_framework.filterset import FilterSet
from django_q.tasks import async_task
from djmoney.contrib.exchange.models import ExchangeBackend, Rate
from drf_spectacular.utils import OpenApiResponse, extend_schema
@@ -676,7 +678,7 @@ class ContentTypeModelDetail(ContentTypeDetail):
return super().get(request, *args, **kwargs)
class AttachmentFilter(rest_filters.FilterSet):
class AttachmentFilter(FilterSet):
"""Filterset for the AttachmentList API endpoint."""
class Meta:
+1 -1
View File
@@ -160,7 +160,7 @@ def get_price(
- If MOQ (minimum order quantity) is required, bump quantity
- If order multiples are to be observed, then we need to calculate based on that, too
"""
from common.currency import currency_code_default
# from common.currency import currency_code_default
if hasattr(instance, break_name):
price_breaks = getattr(instance, break_name).all()
+6 -5
View File
@@ -30,6 +30,7 @@ from django.core.mail import EmailMultiAlternatives, get_connection
from django.core.mail.utils import DNS_NAME
from django.core.validators import MinValueValidator
from django.db import models, transaction
from django.db.models import enums
from django.db.models.signals import post_delete, post_save
from django.db.utils import IntegrityError, OperationalError, ProgrammingError
from django.dispatch import receiver
@@ -66,7 +67,7 @@ from InvenTree.version import inventree_identifier
logger = structlog.get_logger('inventree')
class RenderMeta(models.enums.ChoicesMeta):
class RenderMeta(enums.ChoicesMeta):
"""Metaclass for rendering choices."""
choice_fnc = None
@@ -80,7 +81,7 @@ class RenderMeta(models.enums.ChoicesMeta):
return []
class RenderChoices(models.TextChoices, metaclass=RenderMeta):
class RenderChoices(models.TextChoices, metaclass=RenderMeta): # type: ignore
"""Class for creating enumerated string choices for schema rendering."""
@@ -971,7 +972,7 @@ class BaseInvenTreeSetting(models.Model):
return setting.get('model', None)
def model_filters(self) -> dict:
def model_filters(self) -> Optional[dict]:
"""Return the model filters associated with this setting."""
setting = self.get_setting_definition(
self.key, **self.get_filters_for_instance()
@@ -1505,8 +1506,8 @@ class WebhookEndpoint(models.Model):
request (optional): Original request object. Defaults to None.
"""
return WebhookMessage.objects.create(
host=request.get_host(),
header=json.dumps(dict(headers.items())),
host=request.get_host() if request else '',
header=json.dumps(dict(headers.items())) if headers else None,
body=payload,
endpoint=self,
)
@@ -84,9 +84,7 @@ class InvenTreeNotificationBodies:
)
def trigger_notification(
obj: Model, category: Optional[str] = None, obj_ref: str = 'pk', **kwargs
):
def trigger_notification(obj: Model, category: str = '', obj_ref: str = 'pk', **kwargs):
"""Send out a notification.
Args:
+8 -5
View File
@@ -19,7 +19,7 @@ from django.test import Client, TestCase
from django.test.utils import override_settings
from django.urls import reverse
import PIL
from PIL import Image
import common.validators
from common.notifications import trigger_notification
@@ -200,7 +200,7 @@ class AttachmentTest(InvenTreeAPITestCase):
# Assign 'delete' permission to 'part' model
self.assignRole('part.delete')
response = self.delete(url, expected_code=204)
self.delete(url, expected_code=204)
class SettingsTest(InvenTreeTestCase):
@@ -671,9 +671,9 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase):
# Find the associated setting
setting = next((s for s in response.data if s['key'] == key), None)
assert setting is not None
# Check default value (should be False, not 'False')
self.assertIsNotNone(setting)
self.assertFalse(setting['value'])
# Check that we can manually set the value
@@ -851,9 +851,9 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
# Find the associated setting
setting = next((s for s in response.data if s['key'] == key), None)
assert setting is not None
# Check default value (should be 10, not '10')
self.assertIsNotNone(setting)
self.assertEqual(setting['value'], 10)
# Check that writing an invalid value returns an error
@@ -1535,7 +1535,7 @@ class NotesImageTest(InvenTreeAPITestCase):
n = NotesImage.objects.count()
# Construct a simple image file
image = PIL.Image.new('RGB', (100, 100), color='red')
image = Image.new('RGB', (100, 100), color='red')
with io.BytesIO() as output:
image.save(output, format='PNG')
@@ -1589,6 +1589,7 @@ class ProjectCodesTest(InvenTreeAPITestCase):
# Get the first project code
code = ProjectCode.objects.first()
assert code is not None and code.pk
# Delete it
self.delete(
@@ -1686,6 +1687,7 @@ class CustomUnitAPITest(InvenTreeAPITestCase):
def test_edit(self):
"""Test edit permissions for CustomUnit model."""
unit = CustomUnit.objects.first()
assert unit is not None and unit.pk
# Try to edit without permission
self.user.is_staff = False
@@ -1713,6 +1715,7 @@ class CustomUnitAPITest(InvenTreeAPITestCase):
def test_validation(self):
"""Test that validation works as expected."""
unit = CustomUnit.objects.first()
assert unit is not None and unit.pk
self.user.is_staff = True
self.user.save()
+6 -5
View File
@@ -4,7 +4,8 @@ from django.db.models import Q
from django.urls import include, path
from django.utils.translation import gettext_lazy as _
from django_filters import rest_framework as rest_filters
import django_filters.rest_framework.filters as rest_filters
from django_filters.rest_framework.filterset import FilterSet
import part.models
from data_exporter.mixins import DataExportViewMixin
@@ -127,7 +128,7 @@ class AddressDetail(RetrieveUpdateDestroyAPI):
serializer_class = AddressSerializer
class ManufacturerPartFilter(rest_filters.FilterSet):
class ManufacturerPartFilter(FilterSet):
"""Custom API filters for the ManufacturerPart list endpoint."""
class Meta:
@@ -204,7 +205,7 @@ class ManufacturerPartDetail(RetrieveUpdateDestroyAPI):
serializer_class = ManufacturerPartSerializer
class ManufacturerPartParameterFilter(rest_filters.FilterSet):
class ManufacturerPartParameterFilter(FilterSet):
"""Custom filterset for the ManufacturerPartParameterList API endpoint."""
class Meta:
@@ -259,7 +260,7 @@ class ManufacturerPartParameterDetail(RetrieveUpdateDestroyAPI):
serializer_class = ManufacturerPartParameterSerializer
class SupplierPartFilter(rest_filters.FilterSet):
class SupplierPartFilter(FilterSet):
"""API filters for the SupplierPartList endpoint."""
class Meta:
@@ -418,7 +419,7 @@ class SupplierPartDetail(SupplierPartMixin, RetrieveUpdateDestroyAPI):
"""
class SupplierPriceBreakFilter(rest_filters.FilterSet):
class SupplierPriceBreakFilter(FilterSet):
"""Custom API filters for the SupplierPriceBreak list endpoint."""
class Meta:
@@ -51,7 +51,7 @@ def reverse_association(apps, schema_editor): # pragma: no cover
row = cursor.fetchone()
if len(row) > 0:
if row and len(row) > 0:
try:
manufacturer_id = int(row[0])
except (TypeError, ValueError):
@@ -67,12 +67,12 @@ def reverse_association(apps, schema_editor): # pragma: no cover
response = cursor.execute(f"SELECT name from company_company where id={manufacturer_id};")
row = cursor.fetchone()
if row:
name = row[0]
name = row[0]
print(" - Manufacturer name: '{name}'".format(name=name))
print(" - Manufacturer name: '{name}'".format(name=name))
response = cursor.execute("UPDATE part_supplierpart SET manufacturer_name='{name}' WHERE id={ID};".format(name=name, ID=supplier_part_id))
response = cursor.execute("UPDATE part_supplierpart SET manufacturer_name='{name}' WHERE id={ID};".format(name=name, ID=supplier_part_id))
def associate_manufacturers(apps, schema_editor):
"""
@@ -106,7 +106,7 @@ def associate_manufacturers(apps, schema_editor):
response = cursor.execute(query)
row = cursor.fetchone()
if len(row) > 0:
if row and len(row) > 0:
return row[0]
return '' # pragma: no cover
@@ -296,11 +296,11 @@ def associate_manufacturers(apps, schema_editor):
# Double-check if the typed name corresponds to an existing item
elif response in companies.keys():
link_part(part, companies[response])
link_part(part_id, companies[response])
return
elif response in links.keys():
link_part(part, links[response])
link_part(part_id, links[response])
return
# No match, create a new manufacturer
+14 -1
View File
@@ -156,7 +156,10 @@ class CompanyTest(InvenTreeAPITestCase):
def test_company_notes(self):
"""Test the markdown 'notes' field for the Company model."""
pk = Company.objects.first().pk
company = Company.objects.first()
assert company
pk = company.pk
url = reverse('api-company-detail', kwargs={'pk': pk})
# Attempt to inject malicious markdown into the "notes" field
@@ -253,6 +256,7 @@ class ContactTest(InvenTreeAPITestCase):
n = Contact.objects.count()
company = Company.objects.first()
assert company
# Without required permissions, creation should fail
self.post(
@@ -271,6 +275,8 @@ class ContactTest(InvenTreeAPITestCase):
"""Test that we can edit a Contact via the API."""
# Get the first contact
contact = Contact.objects.first()
assert contact
# Use this contact in the tests
url = reverse('api-contact-detail', kwargs={'pk': contact.pk})
@@ -294,6 +300,8 @@ class ContactTest(InvenTreeAPITestCase):
"""Tests that we can delete a Contact via the API."""
# Get the last contact
contact = Contact.objects.first()
assert contact
url = reverse('api-contact-detail', kwargs={'pk': contact.pk})
# Delete (without required permissions)
@@ -348,6 +356,7 @@ class AddressTest(InvenTreeAPITestCase):
def test_filter_list(self):
"""Test listing addresses filtered on company."""
company = Company.objects.first()
assert company
response = self.get(self.url, {'company': company.pk}, expected_code=200)
@@ -356,6 +365,7 @@ class AddressTest(InvenTreeAPITestCase):
def test_create(self):
"""Test creating a new address."""
company = Company.objects.first()
assert company
self.post(self.url, {'company': company.pk, 'title': 'HQ'}, expected_code=403)
@@ -366,6 +376,7 @@ class AddressTest(InvenTreeAPITestCase):
def test_get(self):
"""Test that objects are properly returned from a get."""
addr = Address.objects.first()
assert addr
url = reverse('api-address-detail', kwargs={'pk': addr.pk})
response = self.get(url, expected_code=200)
@@ -386,6 +397,7 @@ class AddressTest(InvenTreeAPITestCase):
def test_edit(self):
"""Test editing an Address object."""
addr = Address.objects.first()
assert addr
url = reverse('api-address-detail', kwargs={'pk': addr.pk})
@@ -403,6 +415,7 @@ class AddressTest(InvenTreeAPITestCase):
def test_delete(self):
"""Test deleting an object."""
addr = Address.objects.first()
assert addr
url = reverse('api-address-detail', kwargs={'pk': addr.pk})
+1 -1
View File
@@ -18,7 +18,7 @@ class DataExporterConfig(AppConfig):
def cleanup(self):
"""Cleanup any old export files."""
try:
from data_exporter.tasks import cleanup_old_export_outputs
from data_exporter.tasks import cleanup_old_export_outputs # type: ignore
cleanup_old_export_outputs()
except Exception:
@@ -1,6 +1,7 @@
"""Mixin classes for the exporter app."""
from collections import OrderedDict
from typing import Any
from django.core.exceptions import ValidationError
from django.core.files.base import ContentFile
@@ -127,7 +128,7 @@ class DataExportSerializerMixin:
"""
return headers
def get_nested_value(self, row: dict, key: str) -> any:
def get_nested_value(self, row: dict, key: str) -> Any:
"""Get a nested value from a dictionary.
This method allows for dot notation to access nested fields.
@@ -6,7 +6,6 @@ from rest_framework import serializers
import InvenTree.exceptions
import InvenTree.helpers
import InvenTree.serializers
from plugin import PluginMixinEnum, registry
@@ -53,7 +52,7 @@ class DataExportOptionsSerializer(serializers.Serializer):
try:
supports_export = plugin.supports_export(
model_class,
user=request.user,
user=request.user if request else None,
serializer_class=serializer_class,
view_class=view_class,
)
@@ -6,6 +6,7 @@ There is a rendered state for each state value. The rendered state is used for d
States can be extended with custom options for each InvenTree instance - those options are stored in the database and need to link back to state values.
"""
from . import fields
from .states import ColorEnum, StatusCode, StatusCodeMixin
from .transition import StateTransitionMixin, TransitionMethod
@@ -15,4 +16,5 @@ __all__ = [
'StatusCode',
'StatusCodeMixin',
'TransitionMethod',
'fields',
]
@@ -4,6 +4,7 @@ import enum
import logging
import re
from enum import Enum
from typing import Optional
logger = logging.getLogger('inventree')
@@ -297,7 +298,7 @@ class StatusCodeMixin:
"""Return the status code for this object."""
return getattr(self, self.STATUS_FIELD)
def get_custom_status(self) -> int:
def get_custom_status(self) -> Optional[int]:
"""Return the custom status code for this object."""
return getattr(self, f'{self.STATUS_FIELD}_custom_key', None)
+2 -1
View File
@@ -20,4 +20,5 @@ def status_label(typ: str, key: int, include_custom: bool = False, *args, **kwar
def display_status_label(typ: str, key: int, fallback: int, *args, **kwargs):
"""Render a status label."""
render_key = int(key) if key else fallback
return status_label(typ, render_key, *args, include_custom=True, **kwargs)
kwargs['include_custom'] = True
return status_label(typ, render_key, *args, **kwargs)
@@ -1,5 +1,7 @@
"""Classes and functions for plugin controlled object state transitions."""
from typing import Callable
from django.db.models import Model
import structlog
@@ -30,7 +32,7 @@ class TransitionMethod:
current_state: int,
target_state: int,
instance: Model,
default_action: callable,
default_action: Callable,
**kwargs,
) -> bool:
"""Perform a state transition.
+2 -4
View File
@@ -303,9 +303,7 @@ class DataImportSession(models.Model):
if not any(row_data.values()):
continue
row = importer.models.DataImportRow(
session=self, row_data=row_data, row_index=idx
)
row = DataImportRow(session=self, row_data=row_data, row_index=idx)
row.extract_data(
field_mapping=field_mapping,
@@ -317,7 +315,7 @@ class DataImportSession(models.Model):
imported_rows.append(row)
# Perform database writes as a single operation
importer.models.DataImportRow.objects.bulk_create(imported_rows)
DataImportRow.objects.bulk_create(imported_rows)
# Mark the import task as "PROCESSING"
self.status = DataImportStatusCode.PROCESSING.value
+4 -1
View File
@@ -1,9 +1,12 @@
"""Data import operational functions."""
from typing import Optional
from django.core.exceptions import ValidationError
from django.utils.translation import gettext_lazy as _
import tablib
import tablib.core
import InvenTree.helpers
@@ -82,7 +85,7 @@ def extract_column_names(data_file) -> list:
return headers
def get_field_label(field) -> str:
def get_field_label(field) -> Optional[str]:
"""Return the label for a field in a serializer class.
Check for labels in the following order of descending priority:
+2 -2
View File
@@ -1,7 +1,7 @@
"""Models for the machine app."""
import uuid
from typing import Literal
from typing import Literal, Optional
from django.contrib import admin
from django.db import models
@@ -192,7 +192,7 @@ class MachineSetting(common.models.BaseInvenTreeSetting):
If not provided, we'll look at the machine registry to see what settings this machine driver requires
"""
if 'settings' not in kwargs:
machine_config: MachineConfig = kwargs.pop('machine_config', None)
machine_config: Optional[MachineConfig] = kwargs.pop('machine_config', None)
if machine_config and machine_config.machine:
config_type = kwargs.get('config_type')
if config_type == cls.ConfigType.DRIVER:
+3 -1
View File
@@ -423,7 +423,9 @@ class MachineRegistry(
# If the plugin registry has changed, the machine registry hash will change
plugin_registry.update_plugin_hash()
data.update(plugin_registry.registry_hash.encode())
current_hash = plugin_registry.registry_hash
if current_hash:
data.update(current_hash.encode())
for pk, machine in self.machines.items():
data.update(str(pk).encode())
+1
View File
@@ -220,6 +220,7 @@ class TestLabelPrinterMachineType(InvenTreeAPITestCase):
parts = Part.objects.all()[:2]
template = LabelTemplate.objects.filter(enabled=True, model_type='part').first()
assert template
url = reverse('api-label-print')
+6 -5
View File
@@ -11,8 +11,9 @@ from django.http.response import JsonResponse
from django.urls import include, path, re_path
from django.utils.translation import gettext_lazy as _
import django_filters.rest_framework.filters as rest_filters
import rest_framework.serializers
from django_filters import rest_framework as rest_filters
from django_filters.rest_framework.filterset import FilterSet
from django_ical.views import ICalFeed
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema, extend_schema_field
@@ -100,7 +101,7 @@ class OrderCreateMixin:
)
class OrderFilter(rest_filters.FilterSet):
class OrderFilter(FilterSet):
"""Base class for custom API filters for the OrderList endpoint."""
# Filter against order status
@@ -258,7 +259,7 @@ class OrderFilter(rest_filters.FilterSet):
return queryset.filter(q1 | q2 | q3 | q4).distinct()
class LineItemFilter(rest_filters.FilterSet):
class LineItemFilter(FilterSet):
"""Base class for custom API filters for order line item list(s)."""
# Filter by order status
@@ -1147,7 +1148,7 @@ class SalesOrderAllocate(SalesOrderContextMixin, CreateAPI):
serializer_class = serializers.SalesOrderShipmentAllocationSerializer
class SalesOrderAllocationFilter(rest_filters.FilterSet):
class SalesOrderAllocationFilter(FilterSet):
"""Custom filterset for the SalesOrderAllocationList endpoint."""
class Meta:
@@ -1321,7 +1322,7 @@ class SalesOrderAllocationDetail(SalesOrderAllocationMixin, RetrieveUpdateDestro
"""API endpoint for detali view of a SalesOrderAllocation object."""
class SalesOrderShipmentFilter(rest_filters.FilterSet):
class SalesOrderShipmentFilter(FilterSet):
"""Custom filterset for the SalesOrderShipmentList endpoint."""
class Meta:
+3 -2
View File
@@ -1,6 +1,7 @@
"""Background tasks for the 'order' app."""
from datetime import datetime, timedelta
from typing import Union
from django.contrib.auth.models import Group, User
from django.db import transaction
@@ -104,7 +105,7 @@ def check_overdue_purchase_orders():
@tracer.start_as_current_span('notify_overdue_sales_order')
def notify_overdue_sales_order(so: order.models.SalesOrder) -> None:
"""Notify appropriate users that a SalesOrder has just become 'overdue'."""
targets: list[User, Group, Owner] = []
targets: list[Union[User, Group, Owner]] = []
if so.created_by:
targets.append(so.created_by)
@@ -171,7 +172,7 @@ def check_overdue_sales_orders():
@tracer.start_as_current_span('notify_overdue_return_order')
def notify_overdue_return_order(ro: order.models.ReturnOrder) -> None:
"""Notify appropriate users that a ReturnOrder has just become 'overdue'."""
targets: list[User, Group, Owner] = []
targets: list[Union[User, Group, Owner]] = []
if ro.created_by:
targets.append(ro.created_by)
+17 -4
View File
@@ -4,6 +4,7 @@ import base64
import io
import json
from datetime import date, datetime, timedelta
from typing import Optional
from django.core.exceptions import ValidationError
from django.db import connection
@@ -420,7 +421,9 @@ class PurchaseOrderTest(OrderTest):
self.assertIn('Responsible user or group must be specified', str(response.data))
data['responsible'] = Owner.objects.first().pk
owner = Owner.objects.first()
assert owner
data['responsible'] = owner.pk
response = self.post(url, data, expected_code=201)
@@ -1689,6 +1692,7 @@ class SalesOrderTest(OrderTest):
shipment = models.SalesOrderShipment.objects.create(
order=so, reference='SHIP-12345'
)
assert shipment
# Allocate some stock
item = StockItem.objects.create(part=part, quantity=100, location=None)
@@ -1825,10 +1829,13 @@ class SalesOrderLineItemTest(OrderTest):
self.assignRole('sales_order.add')
# Crete a new SalesOrder via the API
company = Company.objects.filter(is_customer=True).first()
assert company
response = self.post(
reverse('api-so-list'),
{
'customer': Company.objects.filter(is_customer=True).first().pk,
'customer': company.pk,
'reference': 'SO-12345',
'description': 'Test Sales Order',
},
@@ -1878,6 +1885,7 @@ class SalesOrderLineItemTest(OrderTest):
p = Part.objects.get(pk=item)
s = StockItem.objects.create(part=p, quantity=100)
l = models.SalesOrderLineItem.objects.filter(order=order, part=p).first()
assert l
# Allocate against the API
self.post(
@@ -2099,12 +2107,14 @@ class SalesOrderAllocateTest(OrderTest):
return line_item.part.is_template
for line in filter(check_template, self.order.lines.all()):
stock_item = None
stock_item: Optional[StockItem] = None
stock_item = None
# Allocate a matching variant
parts = Part.objects.filter(salable=True).filter(variant_of=line.part.pk)
parts: list[Part] = Part.objects.filter(salable=True).filter(
variant_of=line.part.pk
)
for part in parts:
stock_item = part.stock_items.last()
@@ -2118,6 +2128,9 @@ class SalesOrderAllocateTest(OrderTest):
if stock_item is not None:
break
if stock_item is None:
raise self.fail('No stock item found for part') # pragma: no cover
# Fully-allocate each line
data['items'].append({
'line_item': line.pk,
+15 -14
View File
@@ -6,8 +6,9 @@ from django.db.models import Count, F, Q
from django.urls import include, path
from django.utils.translation import gettext_lazy as _
from django_filters import rest_framework as rest_filters
import django_filters.rest_framework.filters as rest_filters
from django_filters.rest_framework import DjangoFilterBackend
from django_filters.rest_framework.filterset import FilterSet
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema_field
from rest_framework import serializers
@@ -98,7 +99,7 @@ class CategoryMixin:
return ctx
class CategoryFilter(rest_filters.FilterSet):
class CategoryFilter(FilterSet):
"""Custom filterset class for the PartCategoryList endpoint."""
class Meta:
@@ -282,11 +283,11 @@ class CategoryDetail(CategoryMixin, CustomRetrieveUpdateDestroyAPI):
return super().destroy(
request,
*args,
**dict(
kwargs,
delete_parts=delete_parts,
delete_child_categories=delete_child_categories,
),
**{
**kwargs,
'delete_parts': delete_parts,
'delete_child_categories': delete_child_categories,
},
)
@@ -399,7 +400,7 @@ class PartInternalPriceList(DataExportViewMixin, ListCreateAPI):
ordering = 'quantity'
class PartTestTemplateFilter(rest_filters.FilterSet):
class PartTestTemplateFilter(FilterSet):
"""Custom filterset class for the PartTestTemplateList endpoint."""
class Meta:
@@ -644,7 +645,7 @@ class PartValidateBOM(RetrieveUpdateAPI):
return Response(serializer.data)
class PartFilter(rest_filters.FilterSet):
class PartFilter(FilterSet):
"""Custom filters for the PartList endpoint.
Uses the django_filters extension framework
@@ -1196,7 +1197,7 @@ class PartDetail(PartMixin, RetrieveUpdateDestroyAPI):
return response
class PartRelatedFilter(rest_filters.FilterSet):
class PartRelatedFilter(FilterSet):
"""FilterSet for PartRelated objects."""
class Meta:
@@ -1243,7 +1244,7 @@ class PartRelatedDetail(PartRelatedMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for accessing detail view of a PartRelated object."""
class PartParameterTemplateFilter(rest_filters.FilterSet):
class PartParameterTemplateFilter(FilterSet):
"""FilterSet for PartParameterTemplate objects."""
class Meta:
@@ -1377,7 +1378,7 @@ class PartParameterAPIMixin:
return super().get_serializer(*args, **kwargs)
class PartParameterFilter(rest_filters.FilterSet):
class PartParameterFilter(FilterSet):
"""Custom filters for the PartParameterList API endpoint."""
class Meta:
@@ -1438,7 +1439,7 @@ class PartParameterDetail(PartParameterAPIMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a single PartParameter object."""
class PartStocktakeFilter(rest_filters.FilterSet):
class PartStocktakeFilter(FilterSet):
"""Custom filter for the PartStocktakeList endpoint."""
class Meta:
@@ -1480,7 +1481,7 @@ class PartStocktakeDetail(RetrieveUpdateDestroyAPI):
serializer_class = part_serializers.PartStocktakeSerializer
class BomFilter(rest_filters.FilterSet):
class BomFilter(FilterSet):
"""Custom filters for the BOM list."""
class Meta:
+3 -2
View File
@@ -12,6 +12,7 @@ Useful References:
"""
from decimal import Decimal
from typing import Optional
from django.db import models
from django.db.models import (
@@ -137,7 +138,7 @@ def annotate_on_order_quantity(reference: str = '') -> QuerySet:
)
def annotate_total_stock(reference: str = '', filter: Q = None) -> QuerySet:
def annotate_total_stock(reference: str = '', filter: Optional[Q] = None) -> QuerySet:
"""Annotate 'total stock' quantity against a queryset.
- This function calculates the 'total stock' for a given part
@@ -269,7 +270,7 @@ def annotate_sales_order_allocations(reference: str = '', location=None) -> Quer
)
def variant_stock_query(reference: str = '', filter: Q = None) -> QuerySet:
def variant_stock_query(reference: str = '', filter: Optional[Q] = None) -> QuerySet:
"""Create a queryset to retrieve all stock items for variant parts under the specified part.
- Useful for annotating a queryset with aggregated information about variant parts
+2 -2
View File
@@ -819,7 +819,7 @@ class Part(
if not check_duplicates:
return
from part.models import Part
# from part.models import Part
from stock.models import StockItem
if get_global_setting('SERIAL_NUMBER_GLOBALLY_UNIQUE', False):
@@ -850,7 +850,7 @@ class Part(
def find_conflicting_serial_numbers(self, serials: list) -> list:
"""For a provided list of serials, return a list of those which are conflicting."""
from part.models import Part
# from part.models import Part
from stock.models import StockItem
conflicts = []
+6 -4
View File
@@ -11,7 +11,7 @@ from django.db import connection
from django.test.utils import CaptureQueriesContext
from django.urls import reverse
import PIL
from PIL import Image
from rest_framework.test import APIClient
import build.models
@@ -65,7 +65,7 @@ class PartImageTestMixin:
fn = get_testfolder_dir() / 'part_image_123abc.png'
img = PIL.Image.new('RGB', (128, 128), color='blue')
img = Image.new('RGB', (128, 128), color='blue')
img.save(fn)
with open(fn, 'rb') as img_file:
@@ -1770,7 +1770,7 @@ class PartDetailTests(PartImageTestMixin, PartAPITestBase):
for fmt in ['jpg', 'j2k', 'png', 'bmp', 'webp']:
fn = f'{test_path}.{fmt}'
img = PIL.Image.new('RGB', (128, 128), color='red')
img = Image.new('RGB', (128, 128), color='red')
img.save(fn)
with open(fn, 'rb') as dummy_image:
@@ -1820,7 +1820,7 @@ class PartDetailTests(PartImageTestMixin, PartAPITestBase):
fn = get_testfolder_dir() / 'part_image_123abc.png'
img = PIL.Image.new('RGB', (128, 128), color='blue')
img = Image.new('RGB', (128, 128), color='blue')
img.save(fn)
# Upload the image to a part
@@ -2463,6 +2463,7 @@ class BomItemTest(InvenTreeAPITestCase):
# Now, let's validate an item
bom_item = BomItem.objects.first()
assert bom_item
bom_item.validate_hash()
@@ -3109,6 +3110,7 @@ class PartTestTemplateTest(PartAPITestBase):
def test_choices(self):
"""Test the 'choices' field for the PartTestTemplate model."""
template = PartTestTemplate.objects.first()
assert template
url = reverse('api-part-test-template-detail', kwargs={'pk': template.pk})
+1 -1
View File
@@ -2,8 +2,8 @@
from django.contrib import admin
import plugin.registry as pl_registry
from plugin import models
from plugin.registry import registry as pl_registry
def plugin_update(queryset, new_status: bool):
+3 -2
View File
@@ -6,8 +6,9 @@ from django.core.exceptions import ValidationError
from django.urls import include, path, re_path
from django.utils.translation import gettext_lazy as _
from django_filters import rest_framework as rest_filters
import django_filters.rest_framework.filters as rest_filters
from django_filters.rest_framework import DjangoFilterBackend
from django_filters.rest_framework.filterset import FilterSet
from drf_spectacular.utils import extend_schema
from rest_framework import status
from rest_framework.exceptions import NotFound
@@ -36,7 +37,7 @@ from plugin.plugin import InvenTreePlugin
from plugin.registry import registry
class PluginFilter(rest_filters.FilterSet):
class PluginFilter(FilterSet):
"""Filter for the PluginConfig model.
Provides custom filtering options for the FilterList API endpoint.
@@ -5,7 +5,7 @@ from django.urls import include, path
from django.utils.translation import gettext_lazy as _
import structlog
from django_filters import rest_framework as rest_filters
from django_filters.rest_framework.filterset import FilterSet
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework import status
from rest_framework.exceptions import PermissionDenied, ValidationError
@@ -770,7 +770,7 @@ class BarcodeScanResultMixin:
return queryset
class BarcodeScanResultFilter(rest_filters.FilterSet):
class BarcodeScanResultFilter(FilterSet):
"""Custom filterset for the BarcodeScanResult API."""
class Meta:
@@ -2,6 +2,8 @@
from __future__ import annotations
from typing import Optional
from django.core.exceptions import ValidationError
from django.db.models import Q
from django.utils.translation import gettext_lazy as _
@@ -113,7 +115,7 @@ class SupplierBarcodeMixin(BarcodeMixin):
return fields.get(key, backup_value)
def get_part(self) -> Part | None:
def get_part(self) -> Optional[Part]:
"""Extract the Part object from the barcode fields."""
# TODO: Implement this
return None
@@ -128,7 +130,7 @@ class SupplierBarcodeMixin(BarcodeMixin):
"""Return the supplier part number from the barcode fields."""
return self.get_field_value(self.SUPPLIER_PART_NUMBER)
def get_supplier_part(self) -> SupplierPart | None:
def get_supplier_part(self) -> Optional[SupplierPart]:
"""Return the SupplierPart object for the scanned barcode.
Returns:
@@ -172,7 +174,7 @@ class SupplierBarcodeMixin(BarcodeMixin):
"""Return the manufacturer part number from the barcode fields."""
return self.get_field_value(self.MANUFACTURER_PART_NUMBER)
def get_manufacturer_part(self) -> ManufacturerPart | None:
def get_manufacturer_part(self) -> Optional[ManufacturerPart]:
"""Return the ManufacturerPart object for the scanned barcode.
Returns:
@@ -213,7 +215,7 @@ class SupplierBarcodeMixin(BarcodeMixin):
"""Return the supplier order number from the barcode fields."""
return self.get_field_value(self.SUPPLIER_ORDER_NUMBER)
def get_purchase_order(self) -> PurchaseOrder | None:
def get_purchase_order(self) -> Optional[PurchaseOrder]:
"""Extract the PurchaseOrder object from the barcode fields.
Inspect the customer_order_number and supplier_order_number fields,
@@ -260,7 +262,7 @@ class SupplierBarcodeMixin(BarcodeMixin):
'extract_barcode_fields must be implemented by each plugin'
)
def scan(self, barcode_data: str) -> dict:
def scan(self, barcode_data: str) -> Optional[dict]:
"""Perform a generic 'scan' operation on a supplier barcode.
The supplier barcode may provide sufficient information to match against
@@ -319,7 +321,7 @@ class SupplierBarcodeMixin(BarcodeMixin):
location=None,
auto_allocate: bool = True,
**kwargs,
) -> dict | None:
) -> Optional[dict]:
"""Attempt to receive an item against a PurchaseOrder via barcode scanning.
Arguments:
@@ -430,7 +432,7 @@ class SupplierBarcodeMixin(BarcodeMixin):
return response
def get_supplier(self, cache: bool = False) -> Company | None:
def get_supplier(self, cache: bool = False) -> Optional[Company]:
"""Get the supplier for the SUPPLIER_ID set in the plugin settings.
If it's not defined, try to guess it and set it if possible.
@@ -461,9 +463,12 @@ class SupplierBarcodeMixin(BarcodeMixin):
if len(suppliers) != 1:
return _cache_supplier(None)
self.set_setting('SUPPLIER_ID', suppliers.first().pk)
supplier = suppliers.first()
assert supplier
return _cache_supplier(suppliers.first())
self.set_setting('SUPPLIER_ID', supplier.pk)
return _cache_supplier(supplier)
@classmethod
def ecia_field_map(cls):
@@ -163,7 +163,8 @@ class APICallMixin:
url = f'{self.api_url}/{endpoint}'
# build kwargs for call
kwargs.update({'url': url, 'headers': headers})
kwargs.update({'headers': headers})
kwargs.pop('url', None)
if data and json:
raise ValueError('You can either pass `data` or `json` to this function.')
@@ -175,7 +176,7 @@ class APICallMixin:
kwargs['data'] = data
# run command
response = requests.request(method, **kwargs)
response = requests.request(method, url=url, **kwargs)
# return
if simple_response:
@@ -1,7 +1,7 @@
"""Plugin class for custom data exporting."""
from collections import OrderedDict
from typing import Union
from typing import Optional, Union
from django.contrib.auth.models import User
from django.db.models import QuerySet
@@ -36,8 +36,8 @@ class DataExportMixin:
self,
model_class: type,
user: User,
serializer_class: serializers.Serializer = None,
view_class: views.APIView = None,
serializer_class: Optional[serializers.Serializer] = None,
view_class: Optional[views.APIView] = None,
*args,
**kwargs,
) -> bool:
@@ -72,7 +72,7 @@ class ValidationMixin:
def validate_model_instance(
self, instance: Model, deltas: Optional[dict] = None
) -> None:
) -> Optional[bool]:
"""Run custom validation on a database model instance.
This method is called when a model instance is being validated.
@@ -90,7 +90,7 @@ class ValidationMixin:
"""
return None
def validate_part_name(self, name: str, part: part.models.Part) -> None:
def validate_part_name(self, name: str, part: part.models.Part) -> Optional[bool]:
"""Perform validation on a proposed Part name.
Arguments:
@@ -105,7 +105,7 @@ class ValidationMixin:
"""
return None
def validate_part_ipn(self, ipn: str, part: part.models.Part) -> None:
def validate_part_ipn(self, ipn: str, part: part.models.Part) -> Optional[bool]:
"""Perform validation on a proposed Part IPN (internal part number).
Arguments:
@@ -122,7 +122,7 @@ class ValidationMixin:
def validate_batch_code(
self, batch_code: str, item: stock.models.StockItem
) -> None:
) -> Optional[bool]:
"""Validate the supplied batch code.
Arguments:
@@ -137,7 +137,7 @@ class ValidationMixin:
"""
return None
def generate_batch_code(self, **kwargs) -> str:
def generate_batch_code(self, **kwargs) -> Optional[str]:
"""Generate a new batch code.
This method is called when a new batch code is required.
@@ -154,8 +154,8 @@ class ValidationMixin:
self,
serial: str,
part: part.models.Part,
stock_item: stock.models.StockItem = None,
) -> None:
stock_item: Optional[stock.models.StockItem] = None,
) -> Optional[bool]:
"""Validate the supplied serial number.
Arguments:
@@ -171,7 +171,7 @@ class ValidationMixin:
"""
return None
def convert_serial_to_int(self, serial: str) -> int:
def convert_serial_to_int(self, serial: str) -> Optional[int]:
"""Convert a serial number (string) into an integer representation.
This integer value is used for efficient sorting based on serial numbers.
@@ -192,7 +192,7 @@ class ValidationMixin:
"""
return None
def get_latest_serial_number(self, part, **kwargs):
def get_latest_serial_number(self, part, **kwargs) -> Optional[str]:
"""Return the 'latest' serial number for a given Part instance.
A plugin which implements this method can either return:
@@ -209,8 +209,8 @@ class ValidationMixin:
return None
def increment_serial_number(
self, serial: str, part: part.models.Part = None, **kwargs
) -> str:
self, serial: str, part: Optional[part.models.Part] = None, **kwargs
) -> Optional[str]:
"""Return the next sequential serial based on the provided value.
A plugin which implements this method can either return:
@@ -229,7 +229,7 @@ class ValidationMixin:
def validate_part_parameter(
self, parameter: part.models.PartParameter, data: str
) -> None:
) -> Optional[bool]:
"""Validate a parameter value.
Arguments:
@@ -311,6 +311,7 @@ class APICallMixinTest(BaseMixinDefinition, TestCase):
self.assertTrue(result)
self.assertNotIn('error', result)
assert result is not None
self.assertEqual(result['name'], 'morpheus')
# api_call with endpoint with leading slash
@@ -113,6 +113,7 @@ class LabelMixinTests(PrintTestMixins, InvenTreeAPITestCase):
parts = Part.objects.all()[:2]
template = LabelTemplate.objects.filter(enabled=True, model_type='part').first()
assert template
self.assertIsNotNone(template)
self.assertTrue(template.enabled)
@@ -227,6 +228,8 @@ class LabelMixinTests(PrintTestMixins, InvenTreeAPITestCase):
# Lookup references
parts = Part.objects.all()[:2]
template = LabelTemplate.objects.filter(enabled=True, model_type='part').first()
assert template
self.do_activate_plugin()
plugin = registry.get_plugin(self.plugin_ref)
@@ -7,4 +7,4 @@ class BrokenFileIntegrationPlugin(InvenTreePlugin):
"""An very broken plugin."""
aaa = bb # noqa: F821
aaa = bb # noqa: F821 # type: ignore[unresolved-reference]
+1 -1
View File
@@ -51,7 +51,7 @@ class MixinNotImplementedError(NotImplementedError):
def log_registry_error(error, reference: str = 'general'):
"""Log an plugin error."""
from plugin import registry
from plugin.registry import registry
# make sure the registry is set up
if reference not in registry.errors:
+3 -2
View File
@@ -2,6 +2,7 @@
import inspect
import warnings
from typing import Optional
from django.conf import settings
from django.contrib import admin
@@ -219,7 +220,7 @@ class PluginConfig(InvenTree.models.MetadataMixin, models.Model):
return pkg_name is not None
@property
def admin_source(self) -> str:
def admin_source(self) -> Optional[str]:
"""Return the path to the javascript file which renders custom admin content for this plugin.
- It is required that the file provides a 'renderPluginSettings' function!
@@ -239,7 +240,7 @@ class PluginConfig(InvenTree.models.MetadataMixin, models.Model):
return None
@property
def admin_context(self) -> dict:
def admin_context(self) -> Optional[dict]:
"""Return the context data for the admin integration."""
if not self.plugin:
return None
+6 -5
View File
@@ -3,7 +3,7 @@
import inspect
import warnings
from datetime import datetime
from distutils.sysconfig import get_python_lib
from distutils.sysconfig import get_python_lib # type: ignore[import]
from importlib.metadata import PackageNotFoundError, metadata
from pathlib import Path
from typing import Optional, Union
@@ -568,8 +568,9 @@ class InvenTreePlugin(VersionMixin, MixinBase, MetaBase):
package = {}
# process date
if package.get('date'):
package['date'] = datetime.fromisoformat(package.get('date'))
date = package.get('date')
if date:
package['date'] = datetime.fromisoformat(date)
# set variables
self.package = package
@@ -608,7 +609,7 @@ class InvenTreePlugin(VersionMixin, MixinBase, MetaBase):
return url
def get_admin_source(self) -> str:
def get_admin_source(self) -> Union[str, None]:
"""Return a path to a JavaScript file which contains custom UI settings.
The frontend code expects that this file provides a function named 'renderPluginSettings'.
@@ -618,7 +619,7 @@ class InvenTreePlugin(VersionMixin, MixinBase, MetaBase):
return self.plugin_static_file(self.ADMIN_SOURCE)
def get_admin_context(self) -> dict:
def get_admin_context(self) -> Union[dict, None]:
"""Return a context dictionary for the admin panel settings.
This is an optional method which can be overridden by the plugin.
+11 -8
View File
@@ -121,12 +121,12 @@ class PluginsRegistry:
self.ready = False # Marks if the registry is ready to be used
# Keep an internal hash of the plugin registry state
self.registry_hash = None
self.registry_hash: Optional[str] = None
self.plugin_modules: list[InvenTreePlugin] = [] # Holds all discovered plugins
self.mixin_modules: dict[str, Any] = {} # Holds all discovered mixins
self.errors = {} # Holds errors discovered during loading
self.errors: dict[str, list[Any]] = {} # Holds errors discovered during loading
self.loading_lock = Lock() # Lock to prevent multiple loading at the same time
@@ -289,7 +289,7 @@ class PluginsRegistry:
@registry_entrypoint(default_value=[])
def with_mixin(
self, mixin: str, active: bool = True, builtin: Optional[bool] = None
self, mixin: str, active: Optional[bool] = True, builtin: Optional[bool] = None
) -> list[InvenTreePlugin]:
"""Returns reference to all plugins that have a specified mixin enabled.
@@ -764,9 +764,9 @@ class PluginsRegistry:
f"Plugin '{p}' is not compatible with the current InvenTree version {v}"
)
if v := plg_i.MIN_VERSION:
_msg += _(f'Plugin requires at least version {v}')
_msg += _(f'Plugin requires at least version {v}') # type: ignore[unsupported-operator]
if v := plg_i.MAX_VERSION:
_msg += _(f'Plugin requires at most version {v}')
_msg += _(f'Plugin requires at most version {v}') # type: ignore[unsupported-operator]
# Log to error stack
log_registry_error(_msg, reference=f'{p}:init_plugin')
else:
@@ -809,7 +809,7 @@ class PluginsRegistry:
logger.exception(
'[PLUGIN] Encountered an error with %s:\n%s',
error.path,
getattr(error, 'path', None),
str(error),
)
@@ -1084,11 +1084,14 @@ def _load_source(modname, filename):
# loader = importlib.machinery.SourceFileLoader(modname, filename)
spec = importlib.util.spec_from_file_location(modname, filename) # , loader=loader)
if spec is None:
raise ImportError(f"Cannot find module '{modname}'") # pragma: no cover
module = importlib.util.module_from_spec(spec)
sys.modules[module.__name__] = module
if spec.loader:
spec.loader.exec_module(module)
loader = spec.loader
if loader is not None:
loader.exec_module(module)
return module
+6 -3
View File
@@ -7,6 +7,7 @@ import tempfile
import textwrap
from datetime import datetime
from pathlib import Path
from typing import Optional
from unittest import mock
from unittest.mock import patch
@@ -204,7 +205,9 @@ class InvenTreePluginTests(TestCase):
self.assertFalse(self.plugin_version.check_version([0, 1, 4]))
plug = registry.plugins_full.get('sampleversion')
self.assertEqual(plug.is_active(), False)
self.assertIsNotNone(plug)
if plug:
self.assertEqual(plug.is_active(), False)
class RegistryTests(TestQueryMixin, PluginRegistryMixin, TestCase):
@@ -251,7 +254,7 @@ class RegistryTests(TestQueryMixin, PluginRegistryMixin, TestCase):
def test_folder_loading(self):
"""Test that plugins in folders outside of BASE_DIR get loaded."""
# Run in temporary directory -> always a new random name
with tempfile.TemporaryDirectory() as tmp:
with tempfile.TemporaryDirectory() as tmp: # type: ignore[no-matching-overload]
# Fill directory with sample data
new_dir = Path(tmp).joinpath('mock')
shutil.copytree(self.mockDir(), new_dir)
@@ -400,7 +403,7 @@ class RegistryTests(TestQueryMixin, PluginRegistryMixin, TestCase):
def create_plugin_file(
version: str, enabled: bool = True, reload: bool = True
) -> str:
) -> Optional[str]:
"""Create a plugin file with the given version.
Arguments:
+3 -2
View File
@@ -6,8 +6,9 @@ from django.utils.decorators import method_decorator
from django.utils.translation import gettext_lazy as _
from django.views.decorators.cache import never_cache
from django_filters import rest_framework as rest_filters
import django_filters.rest_framework.filters as rest_filters
from django_filters.rest_framework import DjangoFilterBackend
from django_filters.rest_framework.filterset import FilterSet
from rest_framework.generics import GenericAPIView
from rest_framework.response import Response
@@ -31,7 +32,7 @@ class TemplatePermissionMixin:
permission_classes = [InvenTree.permissions.IsStaffOrReadOnlyScope]
class ReportFilterBase(rest_filters.FilterSet):
class ReportFilterBase(FilterSet):
"""Base filter class for label and report templates."""
enabled = rest_filters.BooleanFilter()
+1 -1
View File
@@ -67,7 +67,7 @@ class ReportConfig(AppConfig):
def cleanup(self):
"""Cleanup old label and report outputs."""
try:
from report.tasks import cleanup_old_report_outputs
from report.tasks import cleanup_old_report_outputs # type: ignore[import]
cleanup_old_report_outputs()
except Exception:
+6 -3
View File
@@ -238,7 +238,7 @@ class ReportTemplateBase(MetadataMixin, InvenTree.models.InvenTreeModel):
),
)
def generate_filename(self, context, **kwargs):
def generate_filename(self, context, **kwargs) -> str:
"""Generate a filename for this report."""
template_string = Template(self.filename_pattern)
@@ -491,7 +491,7 @@ class ReportTemplate(TemplateUploadMixin, ReportTemplateBase):
debug_mode = get_global_setting('REPORT_DEBUG_MODE', False)
# Start with a default report name
report_name = None
report_name: Optional[str] = None
# If a DataOutput object is not provided, create a new one
if not output:
@@ -608,6 +608,9 @@ class ReportTemplate(TemplateUploadMixin, ReportTemplateBase):
'path': request.path if request else None,
})
if not report_name:
report_name = '' # pragma: no cover
if not report_name.endswith('.pdf'):
report_name += '.pdf'
@@ -695,7 +698,7 @@ class LabelTemplate(TemplateUploadMixin, ReportTemplateBase):
def get_context(self, instance, request=None, **kwargs):
"""Supply context data to the label template for rendering."""
base_context = super().get_context(instance, request, **kwargs)
label_context: LabelContextExtension = {
label_context: LabelContextExtension = { # type: ignore[invalid-assignment]
'width': self.width,
'height': self.height,
'page_style': None,
@@ -4,6 +4,7 @@ from django import template
from django.utils.safestring import mark_safe
import barcode as python_barcode
import barcode.writer as python_barcode_writer
import qrcode.constants as ECL
from PIL import Image, ImageColor
from qrcode.main import QRCode
@@ -122,7 +123,7 @@ def barcode(data: str, barcode_class='code128', **kwargs) -> str:
data = str(data).zfill(constructor.digits)
writer = python_barcode.writer.ImageWriter
writer = python_barcode_writer.ImageWriter
barcode_image = constructor(data, writer=writer())
@@ -148,7 +149,7 @@ def datamatrix(data: str, **kwargs) -> str:
Returns:
image (str): base64 encoded image data
"""
from ppf.datamatrix import DataMatrix
from ppf.datamatrix.datamatrix import DataMatrix
data = str(data).strip()
@@ -50,7 +50,7 @@ def filter_queryset(queryset: QuerySet, **kwargs) -> QuerySet:
@register.simple_tag()
def filter_db_model(model_name: str, **kwargs) -> QuerySet:
def filter_db_model(model_name: str, **kwargs) -> Optional[QuerySet]:
"""Filter a database model based on the provided keyword arguments.
Arguments:
@@ -102,7 +102,7 @@ def getindex(container: list, index: int) -> Any:
@register.simple_tag()
def getkey(container: dict, key: str, backup_value: Optional[any] = None) -> Any:
def getkey(container: dict, key: str, backup_value: Optional[Any] = None) -> Any:
"""Perform key lookup in the provided dict object.
This function is provided to get around template rendering limitations.
@@ -301,14 +301,13 @@ def part_image(part: Part, preview: bool = False, thumbnail: bool = False, **kwa
if type(part) is not Part:
raise TypeError(_('part_image tag requires a Part instance'))
if not part.image:
part_img = part.image
if not part_img:
img = None
elif preview:
img = None if not hasattr(part.image, 'preview') else part.image.preview.name
img = None if not hasattr(part.image, 'preview') else part_img.preview.name
elif thumbnail:
img = (
None if not hasattr(part.image, 'thumbnail') else part.image.thumbnail.name
)
img = None if not hasattr(part.image, 'thumbnail') else part_img.thumbnail.name
else:
img = part.image.name
@@ -316,7 +315,7 @@ def part_image(part: Part, preview: bool = False, thumbnail: bool = False, **kwa
@register.simple_tag()
def part_parameter(part: Part, parameter_name: str) -> str:
def part_parameter(part: Part, parameter_name: str) -> Optional[str]:
"""Return a PartParameter object for the given part and parameter name.
Arguments:
@@ -348,12 +347,15 @@ def company_image(
if type(company) is not Company:
raise TypeError(_('company_image tag requires a Company instance'))
if preview:
img = company.image.preview.name
cmp_img = company.image
if not cmp_img:
img = None
elif preview:
img = cmp_img.preview.name
elif thumbnail:
img = company.image.thumbnail.name
img = cmp_img.thumbnail.name
else:
img = company.image.name
img = cmp_img.name
return uploaded_image(img, **kwargs)
+4
View File
@@ -84,6 +84,7 @@ class ReportTest(InvenTreeAPITestCase):
# Filter by items
part_pk = Part.objects.first().pk
report = ReportTemplate.objects.filter(model_type='part').first()
assert report
try:
response = self.get(
@@ -236,6 +237,7 @@ class ReportTest(InvenTreeAPITestCase):
url = reverse('api-report-template-list')
template = ReportTemplate.objects.first()
assert template
detail_url = reverse('api-report-template-detail', kwargs={'pk': template.pk})
@@ -415,6 +417,7 @@ class PrintTestMixins:
qs = qs.objects.all()
template = mdl.objects.filter(enabled=True, model_type=model_type).first()
assert template
plugin = registry.get_plugin(self.plugin_ref)
# Single page printing
@@ -475,6 +478,7 @@ class TestReportTest(PrintTestMixins, ReportTest):
template = ReportTemplate.objects.filter(
enabled=True, model_type='stockitem'
).first()
assert template
self.assertIsNotNone(template)
+10 -9
View File
@@ -9,7 +9,8 @@ from django.db.models import F, Q
from django.urls import include, path
from django.utils.translation import gettext_lazy as _
from django_filters import rest_framework as rest_filters
import django_filters.rest_framework.filters as rest_filters
from django_filters.rest_framework.filterset import FilterSet
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema, extend_schema_field
from rest_framework import status
@@ -257,7 +258,7 @@ class StockMerge(CreateAPI):
return ctx
class StockLocationFilter(rest_filters.FilterSet):
class StockLocationFilter(FilterSet):
"""Base class for custom API filters for the StockLocation endpoint."""
class Meta:
@@ -425,11 +426,11 @@ class StockLocationDetail(StockLocationMixin, CustomRetrieveUpdateDestroyAPI):
return super().destroy(
request,
*args,
**dict(
kwargs,
delete_sub_locations=delete_sub_locations,
delete_stock_items=delete_stock_items,
),
**{
**kwargs,
'delete_sub_locations': delete_sub_locations,
'delete_stock_items': delete_stock_items,
},
)
@@ -505,7 +506,7 @@ class StockLocationTypeDetail(RetrieveUpdateDestroyAPI):
return queryset
class StockFilter(rest_filters.FilterSet):
class StockFilter(FilterSet):
"""FilterSet for StockItem LIST API."""
class Meta:
@@ -1339,7 +1340,7 @@ class StockItemTestResultDetail(StockItemTestResultMixin, RetrieveUpdateDestroyA
"""Detail endpoint for StockItemTestResult."""
class StockItemTestResultFilter(rest_filters.FilterSet):
class StockItemTestResultFilter(FilterSet):
"""API filter for the StockItemTestResult list."""
class Meta:
+3 -1
View File
@@ -1,12 +1,14 @@
"""Custom query filters for the Stock models."""
from typing import Optional
from django.db.models import F, Func, IntegerField, OuterRef, Q, Subquery
from django.db.models.functions import Coalesce
import stock.models
def annotate_location_items(filter: Q = None):
def annotate_location_items(filter: Optional[Q] = None):
"""Construct a queryset annotation which returns the number of stock items in a particular location.
- Includes items in subcategories also
+2 -1
View File
@@ -1,6 +1,7 @@
"""Generator functions for the stock app."""
from inspect import signature
from typing import Optional
from django.core.exceptions import ValidationError
@@ -78,7 +79,7 @@ def generate_batch_code(**kwargs):
return Template(batch_template).render(context)
def generate_serial_number(part=None, quantity=1, **kwargs) -> str:
def generate_serial_number(part=None, quantity=1, **kwargs) -> Optional[str]:
"""Generate a default 'serial number' for a new StockItem."""
quantity = quantity or 1
@@ -51,10 +51,10 @@ def update_history(apps, schema_editor):
q = entry.quantity
if idx == 0 or not q == quantity:
if idx == 0 or q != quantity:
try:
deltas['quantity']: float(q)
deltas['quantity']= float(q)
updated = True
except Exception:
print(f"WARNING: Error converting quantity '{q}'")
+2 -2
View File
@@ -667,7 +667,7 @@ class StockItem(
return items
@staticmethod
def convert_serial_to_int(serial: str) -> int:
def convert_serial_to_int(serial: str) -> Optional[int]:
"""Convert the provided serial number to an integer value.
This function hooks into the plugin system to allow for custom serial number conversion.
@@ -1784,7 +1784,7 @@ class StockItem(
self,
entry_type: int,
user: User,
deltas: dict | None = None,
deltas: Optional[dict] = None,
notes: str = '',
commit: bool = True,
**kwargs,
@@ -20,6 +20,7 @@ import build.models
import company.models
import company.serializers as company_serializers
import InvenTree.helpers
import InvenTree.ready
import InvenTree.serializers
import order.models
import part.filters as part_filters
+1
View File
@@ -2079,6 +2079,7 @@ class StockTestResultTest(StockAPITestCase):
url = reverse('api-stock-test-result-list')
test_template = PartTestTemplate.objects.first()
assert test_template
test_template.choices = 'AA, BB, CC'
test_template.save()
+3 -1
View File
@@ -47,7 +47,9 @@ User.add_to_class('__str__', user_model_str) # Overriding User.__str__
if settings.LDAP_AUTH:
from django_auth_ldap.backend import populate_user
from django_auth_ldap.backend import ( # type: ignore[unresolved-import]
populate_user,
)
@receiver(populate_user)
def create_email_address(user, **kwargs):
+2 -2
View File
@@ -31,8 +31,8 @@ def get_model_permission_string(model: models.Model, permission: str) -> str:
Returns:
str: The permission string (e.g. 'part.view_part')
"""
model, app = split_model(model)
return f'{app}.{permission}_{model}'
_model, _app = split_model(model)
return f'{_app}.{permission}_{_model}'
def split_permission(app: str, perm: str) -> tuple[str, str]:
+1 -1
View File
@@ -281,7 +281,7 @@ class GroupSerializer(InvenTreeModelSerializer):
class ExtendedUserSerializer(UserSerializer):
"""Serializer for a User with a bit more info."""
from users.serializers import GroupSerializer
# from users.serializers import GroupSerializer
class Meta(UserSerializer.Meta):
"""Metaclass defines serializer fields."""
+3 -1
View File
@@ -1,5 +1,7 @@
"""Background tasks for the users app."""
from typing import Any
from django.contrib.auth.models import Group, Permission
from django.contrib.contenttypes.models import ContentType
@@ -99,7 +101,7 @@ def update_group_roles(group: Group, debug: bool = False) -> None:
permissions_to_delete.add(permission_string)
# Pre-fetch all the RuleSet objects
rulesets = {
rulesets: dict[Any, RuleSet] = {
r.name: r for r in RuleSet.objects.filter(group=group).prefetch_related('group')
}
+4
View File
@@ -131,6 +131,8 @@ class UserAPITests(InvenTreeAPITestCase):
def test_user_detail(self):
"""Test the UserDetail API endpoint."""
user = User.objects.first()
assert user
url = reverse('api-user-detail', kwargs={'pk': user.pk})
user.is_staff = False
@@ -274,6 +276,7 @@ class UserTokenTests(InvenTreeAPITestCase):
# If we re-generate a token, the value changes
token = ApiToken.objects.filter(name='cat').first()
assert token
# Request the token with the same name
data = self.get(url, data={'name': 'cat'}, expected_code=200).data
@@ -331,6 +334,7 @@ class UserTokenTests(InvenTreeAPITestCase):
# Grab the token, and update
token = ApiToken.objects.first()
assert token
self.assertEqual(token.key, token_key)
self.assertIsNotNone(token.last_seen)
@@ -1,6 +1,7 @@
"""Template tag to render SPA imports."""
import json
import json.decoder
from pathlib import Path
from typing import Union
+3
View File
@@ -9,3 +9,6 @@ pip-tools # Compile pip requirements
pre-commit # Git pre-commit
setuptools # Standard dependency
pdfminer.six # PDF validation
ty # type checking
django-types # typing
django-stubs # typing
+46
View File
@@ -6,6 +6,7 @@ asgiref==3.9.1 \
# via
# -c src/backend/requirements.txt
# django
# django-stubs
build==1.3.0 \
--hash=sha256:698edd0ea270bde950f53aed21f3a0135672206f3911e0176261a31e0e07b397 \
--hash=sha256:7145f0b5061ba90a1500d60bd1b13ca0a8a4cebdd0cc16ed8adf1c0e739f43b4
@@ -326,16 +327,30 @@ django==4.2.24 \
# via
# -c src/backend/requirements.txt
# django-slowtests
# django-stubs
# django-stubs-ext
django-querycount==0.8.3 \
--hash=sha256:0782484e8a1bd29498fa0195a67106e47cdcc98fafe80cebb1991964077cb694
# via -r src/backend/requirements-dev.in
django-slowtests==1.1.1 \
--hash=sha256:3c6936d420c9df444ac03625b41d97de043c662bbde61fbcd33e4cd407d0c247
# via -r src/backend/requirements-dev.in
django-stubs==5.1.3 \
--hash=sha256:716758ced158b439213062e52de6df3cff7c586f9f9ad7ab59210efbea5dfe78 \
--hash=sha256:8c230bc5bebee6da282ba8a27ad1503c84a0c4cd2f46e63d149e76d2a63e639a
# via -r src/backend/requirements-dev.in
django-stubs-ext==5.1.3 \
--hash=sha256:3e60f82337f0d40a362f349bf15539144b96e4ceb4dbd0239be1cd71f6a74ad0 \
--hash=sha256:64561fbc53e963cc1eed2c8eb27e18b8e48dcb90771205180fe29fc8a59e55fd
# via django-stubs
django-test-migrations==1.4.0 \
--hash=sha256:294dff98f6d43d020d4046b971bac5339e7c71458a35e9ad6450c388fe16ed6b \
--hash=sha256:f0c9c92864ed27d0c9a582e92056637e91227f54bd868a50cb9a1726668c563e
# via -r src/backend/requirements-dev.in
django-types==0.20.0 \
--hash=sha256:4e55d2c56155e3d69d75def9eb1d95a891303f2ac19fccf6fe8056da4293fae7 \
--hash=sha256:a0b5c2c9a1e591684bb21a93b64e50ca6cb2d3eab48f49faff1eac706bd3a9c7
# via -r src/backend/requirements-dev.in
filelock==3.18.0 \
--hash=sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2 \
--hash=sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de
@@ -505,13 +520,44 @@ tomli==2.2.1 \
# -c src/backend/requirements.txt
# build
# coverage
# django-stubs
# pip-tools
ty==0.0.1a20 \
--hash=sha256:0b481f26513f38543df514189fb16744690bcba8d23afee95a01927d93b46e36 \
--hash=sha256:3c2ace3a22fab4bd79f84c74e3dab26e798bfba7006bea4008d6321c1bd6efc6 \
--hash=sha256:3ff75cd4c744d09914e8c9db8d99e02f82c9379ad56b0a3fc4c5c9c923cfa84e \
--hash=sha256:726d0738be4459ac7ffae312ba96c5f486d6cbc082723f322555d7cba9397871 \
--hash=sha256:7abbe3c02218c12228b1d7c5f98c57240029cc3bcb15b6997b707c19be3908c1 \
--hash=sha256:83a7ee12465841619b5eb3ca962ffc7d576bb1c1ac812638681aee241acbfbbe \
--hash=sha256:8a138fa4f74e6ed34e9fd14652d132409700c7ff57682c2fed656109ebfba42f \
--hash=sha256:8eff8871d6b88d150e2a67beba2c57048f20c090c219f38ed02eebaada04c124 \
--hash=sha256:933b65a152f277aa0e23ba9027e5df2c2cc09e18293e87f2a918658634db5f15 \
--hash=sha256:b4124ab75e0e6f09fe7bc9df4a77ee43c5e0ef7e61b0c149d7c089d971437cbd \
--hash=sha256:b8c4336987a6a781d4392a9fd7b3a39edb7e4f3dd4f860e03f46c932b52aefa2 \
--hash=sha256:cad12c857ea4b97bf61e02f6796e13061ccca5e41f054cbd657862d80aa43bae \
--hash=sha256:d8ac1c5a14cda5fad1a8b53959d9a5d979fe16ce1cc2785ea8676fed143ac85f \
--hash=sha256:e26437772be7f7808868701f2bf9e14e706a6ec4c7d02dbd377ff94d7ba60c11 \
--hash=sha256:f153b65c7fcb6b8b59547ddb6353761b3e8d8bb6f0edd15e3e3ac14405949f7a \
--hash=sha256:f41e77ff118da3385915e13c3f366b3a2f823461de54abd2e0ca72b170ba0f19 \
--hash=sha256:f73a7aca1f0d38af4d6999b375eb00553f3bfcba102ae976756cc142e14f3450 \
--hash=sha256:fff51c75ee3f7cc6d7722f2f15789ef8ffe6fd2af70e7269ac785763c906688e
# via -r src/backend/requirements-dev.in
types-psycopg2==2.9.21.20250915 \
--hash=sha256:bfeb8f54c32490e7b5edc46215ab4163693192bc90407b4a023822de9239f5c8 \
--hash=sha256:eefe5ccdc693fc086146e84c9ba437bb278efe1ef330b299a0cb71169dc6c55f
# via django-types
types-pyyaml==6.0.12.20250915 \
--hash=sha256:0f8b54a528c303f0e6f7165687dd33fafa81c807fcac23f632b63aa624ced1d3 \
--hash=sha256:e7d4d9e064e89a3b3cae120b4990cd370874d2bf12fa5f46c97018dd5d3c9ab6
# via django-stubs
typing-extensions==4.14.1 \
--hash=sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36 \
--hash=sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76
# via
# -c src/backend/requirements.txt
# asgiref
# django-stubs
# django-stubs-ext
# django-test-migrations
virtualenv==20.33.1 \
--hash=sha256:07c19bc66c11acab6a5958b815cbcee30891cd1c2ccf53785a28651a0d8d8a67 \
+13 -7
View File
@@ -77,7 +77,7 @@ def is_pkg_installer_by_path():
def get_installer(content: Optional[dict] = None):
"""Get the installer for the current environment or a content dict."""
if content is None:
content = os.environ
content = dict(os.environ)
return content.get('INVENTREE_PKG_INSTALLER', None)
@@ -461,7 +461,9 @@ def check_file_existence(filename: Path, overwrite: bool = False):
@state_logger('TASK01')
def plugins(c, uv=False):
"""Installs all plugins as specified in 'plugins.txt'."""
from src.backend.InvenTree.InvenTree.config import get_plugin_file
from src.backend.InvenTree.InvenTree.config import ( # type: ignore[import]
get_plugin_file,
)
plugin_file = get_plugin_file()
@@ -573,7 +575,9 @@ def rebuild_models(c):
@task
def rebuild_thumbnails(c):
"""Rebuild missing image thumbnails."""
from src.backend.InvenTree.InvenTree.config import get_media_dir
from src.backend.InvenTree.InvenTree.config import ( # type: ignore[import]
get_media_dir,
)
info(f'Rebuilding image thumbnails in {get_media_dir()}')
manage(c, 'rebuild_thumbnails', pty=True)
@@ -1165,7 +1169,7 @@ def test_translations(c):
info('Fill in dummy translations...')
file_path = pathlib.Path(settings.LOCALE_PATHS[0], 'xx', 'LC_MESSAGES', 'django.po')
new_file_path = str(file_path) + '_new'
new_file_path = Path(str(file_path) + '_new')
# compile regex
reg = re.compile(
@@ -1303,7 +1307,9 @@ def setup_test(
path='inventree-demo-dataset',
):
"""Setup a testing environment."""
from src.backend.InvenTree.InvenTree.config import get_media_dir
from src.backend.InvenTree.InvenTree.config import ( # type: ignore[import]
get_media_dir,
)
if not ignore_update:
update(c)
@@ -1453,8 +1459,8 @@ def export_definitions(c, basedir: str = ''):
@task(default=True)
def version(c):
"""Show the current version of InvenTree."""
import src.backend.InvenTree.InvenTree.version as InvenTreeVersion
from src.backend.InvenTree.InvenTree.config import (
import src.backend.InvenTree.InvenTree.version as InvenTreeVersion # type: ignore[import]
from src.backend.InvenTree.InvenTree.config import ( # type: ignore[import]
get_backup_dir,
get_config_file,
get_media_dir,