2
0
mirror of https://github.com/inventree/InvenTree.git synced 2025-05-05 14:58:50 +00:00

788 lines
20 KiB
Python

"""
Provides helper functions used throughout the InvenTree project
"""
import io
import json
import os.path
import re
from decimal import Decimal, InvalidOperation
from wsgiref.util import FileWrapper
from django.contrib.auth.models import Permission
from django.core.exceptions import FieldError, ValidationError
from django.http import StreamingHttpResponse
from django.test import TestCase
from django.utils.translation import gettext_lazy as _
from djmoney.money import Money
from PIL import Image
import InvenTree.version
from common.models import InvenTreeSetting
from common.settings import currency_code_default
from .api_tester import UserMixin
from .settings import MEDIA_URL, STATIC_URL
def getSetting(key, backup_value=None):
    """
    Read the value of a global InvenTree setting.

    Thin wrapper around InvenTreeSetting.get_setting; ``backup_value`` is passed
    through unchanged as the fallback value.
    """
    value = InvenTreeSetting.get_setting(key, backup_value=backup_value)
    return value
def generateTestKey(test_name):
    """
    Turn a test name into a lookup 'key' safe for use in templates.

    The name is stripped and lower-cased, then every character that is not an
    ASCII letter or digit (whitespace included) is removed. Distinct tests must
    therefore have names that still differ after this reduction.
    """
    lowered = test_name.strip().lower()
    # Whitespace is covered by the negated character class as well
    return re.sub(r'[^a-zA-Z0-9]', '', lowered)
def getMediaUrl(filename):
    """
    Build the access path for ``filename`` inside the media directory.

    The filename is converted with str() and joined onto MEDIA_URL via os.path.join.
    """
    relative_path = str(filename)
    return os.path.join(MEDIA_URL, relative_path)
def getStaticUrl(filename):
    """
    Build the access path for ``filename`` inside the static files directory.

    The filename is converted with str() and joined onto STATIC_URL via os.path.join.
    """
    relative_path = str(filename)
    return os.path.join(STATIC_URL, relative_path)
def construct_absolute_url(*arg):
    """
    Construct (or attempt to construct) an absolute URL from a relative URL.

    This is useful when (for example) sending an email to a user with a link
    to something in the InvenTree web framework.

    Args:
        *arg: URL path components; each is converted with str() and joined with '/'

    Returns:
        '<base>/<path>' when the INVENTREE_BASE_URL setting is set, otherwise the
        joined relative path unchanged.

    This requires the INVENTREE_BASE_URL setting to be configured!
    """
    base = InvenTreeSetting.get_setting('INVENTREE_BASE_URL')

    # An unset setting must not be rendered as the literal string 'None'
    base = str(base).strip() if base else ''

    url = '/'.join(str(part) for part in arg)

    if not base:
        return url

    # Exactly one '/' between the base and the relative path, however many
    # slashes either side supplied
    return f"{base.rstrip('/')}/{url.lstrip('/')}"
def getBlankImage():
    """
    Path of the static 'blank image' placeholder (img/blank_image.png).
    """
    placeholder = "img/blank_image.png"
    return getStaticUrl(placeholder)
def getBlankThumbnail():
    """
    Path of the static 'blank image' thumbnail placeholder (img/blank_image.thumbnail.png).
    """
    placeholder = "img/blank_image.thumbnail.png"
    return getStaticUrl(placeholder)
def TestIfImage(img):
    """
    Test if an image file is indeed an image.

    Args:
        img: Anything accepted by PIL.Image.open (a path or a file-like object)

    Returns:
        True if PIL can open and verify the data, False otherwise
    """
    try:
        # Use the context manager so a file opened by PIL from a path is closed
        with Image.open(img) as image:
            image.verify()
    except Exception:
        # Any decode/verification failure means "not an image". A bare except
        # would also swallow KeyboardInterrupt / SystemExit.
        return False

    return True
def TestIfImageURL(url):
    """
    Check whether a URL (or filename) looks like an image.

    Only the extension of the final path component is compared (case-insensitively)
    against a fixed set of image extensions; the file is never opened.
    """
    image_extensions = {
        '.jpg', '.jpeg',
        '.png', '.bmp',
        '.tif', '.tiff',
        '.webp', '.gif',
    }
    _stem, extension = os.path.splitext(os.path.basename(url))
    return extension.lower() in image_extensions
def str2bool(text, test=True):
    """ Check whether a value's string form 'looks' like a given boolean.

    Args:
        text: Input value (converted with str() and lower-cased before matching)
        test (default = True): Which boolean to look for; True matches truthy
            words ('1', 'yes', 'on', ...), False matches falsy words ('0', 'no', 'none', ...)

    Returns:
        True if the text matches the selected boolean vocabulary
    """
    truthy_words = ('1', 'y', 'yes', 't', 'true', 'ok', 'on')
    falsy_words = ('0', 'n', 'no', 'none', 'f', 'false', 'off')

    vocabulary = truthy_words if test else falsy_words
    return str(text).lower() in vocabulary
def is_bool(text):
    """
    Return True when a value 'looks' like either a true or a false boolean,
    according to str2bool's two vocabularies.
    """
    return str2bool(text, True) or str2bool(text, False)
def isNull(text):
    """
    Check whether a value 'looks' like a null value.

    Handy when querying the API against a null key.

    Args:
        text: Input value (converted with str(), stripped and lower-cased)

    Returns:
        True for 'top', 'null', 'none', 'empty', 'false', '-1' or the empty string
    """
    cleaned = str(text).strip().lower()
    null_markers = {'top', 'null', 'none', 'empty', 'false', '-1', ''}
    return cleaned in null_markers
def normalize(d):
    """
    Normalize a decimal number, and remove exponential formatting.

    Args:
        d: A Decimal, or any value accepted by the Decimal constructor

    Returns:
        A normalized Decimal. Integral values are rewritten without an exponent
        (e.g. 5E+3 -> 5000) when that fits in the current decimal context
        precision; otherwise (and for NaN / Infinity) the normalized value is
        returned as-is.
    """
    if not isinstance(d, Decimal):
        d = Decimal(d)

    d = d.normalize()

    # Ref: https://docs.python.org/3/library/decimal.html (remove_exponent recipe)
    # quantize() signals InvalidOperation for non-finite values and for results
    # with more digits than the context precision, so guard both cases.
    if d.is_finite() and d == d.to_integral():
        try:
            return d.quantize(Decimal(1))
        except InvalidOperation:
            return d

    return d
def increment(n):
    """
    Attempt to increment an integer (or a string that ends with an integer!)

    The trailing run of digits is incremented and re-padded to its original
    width; everything before it is kept verbatim. Values without trailing
    digits are returned unchanged (after whitespace stripping).

    e.g.
        001 -> 002
        2 -> 3
        AB01 -> AB02
        QQQ -> QQQ

    Args:
        n: Value to increment (converted with str() and stripped)

    Returns:
        The incremented string
    """
    value = str(n).strip()

    # Ignore empty strings
    if not value:
        return value

    # Match only the trailing digits. The previous pattern '(.*?)(\d+)?$' used
    # '.', which does not match newlines, so any text before an embedded newline
    # was silently dropped from the result.
    trailing = re.search(r"\d+$", value)

    # No number at the end? Simply return the value (without incrementing!)
    if trailing is None:
        return value

    digits = trailing.group()
    prefix = value[:trailing.start()]

    try:
        incremented = str(int(digits) + 1)
    except ValueError:
        incremented = digits

    # Preserve zero-padding width, e.g. 009 -> 010
    return prefix + incremented.zfill(len(digits))
def decimal2string(d):
    """
    Format a Decimal number as a string,
    stripping out any trailing zeroes or decimal points.
    Essentially make it look like a whole number if it is one.

    Args:
        d: A python Decimal object (other numbers and numeric strings are
           also accepted; non-numeric values are returned via str())

    Returns:
        A string representation of the input number
    """
    if isinstance(d, Decimal):
        d = normalize(d)

    try:
        # Ensure that the provided value can actually be converted to a float
        float(d)
    except (TypeError, ValueError):
        # Not a number (e.g. arbitrary text, or None)
        return str(d)

    s = str(d)

    # Separate any exponent (e.g. '1.5E-10', '1.5e+20') so that only the
    # mantissa is trimmed. Stripping zeros from the whole string would eat
    # exponent digits and change the value.
    exp_pos = max(s.find('e'), s.find('E'))
    if exp_pos == -1:
        mantissa, exponent = s, ''
    else:
        mantissa, exponent = s[:exp_pos], s[exp_pos:]

    # Return entire number if there is no decimal place
    if '.' not in mantissa:
        return s

    return mantissa.rstrip("0").rstrip(".") + exponent
def decimal2money(d, currency=None):
    """
    Wrap a numeric amount in a djmoney Money object.

    Args:
        d: A python Decimal object (the amount)
        currency: Currency code for the amount; any falsy value falls back to
            currency_code_default()

    Returns:
        Money(d, <resolved currency>)
    """
    currency_code = currency or currency_code_default()
    return Money(d, currency_code)
def WrapWithQuotes(text, quote='"'):
    """ Wrap the supplied text with quotes

    A quote is only added at either end if it is not already present. Text too
    short to hold both an opening and a closing quote (e.g. '' or a lone quote
    character) always ends up as a balanced pair.

    Args:
        text: Input text to wrap
        quote: Quote character to use for wrapping (default = '"')

    Returns:
        Supplied text wrapped in quote char
    """
    if not text.startswith(quote):
        text = quote + text

    # A lone quote satisfies both startswith and endswith, so also require room
    # for two quotes before treating the text as already closed
    if not text.endswith(quote) or len(text) < 2 * len(quote):
        text = text + quote

    return text
def MakeBarcode(object_name, object_pk, object_data=None, **kwargs):
    """ Generate a string for a barcode. Adds some global InvenTree parameters.

    Args:
        object_name: string describing the object type e.g. 'StockItem'
        object_pk: ID (Primary Key) of the object in the database
        object_data: Python dict of extra data. In URL mode its 'request' and
            'item_url' keys are read; in full (non-brief) mode it is rendered into
            the JSON output, so it must only contain JSON-serializable values.
            The caller's dict is not modified.
        url (kwarg, default False): return a URL instead of JSON
        brief (kwarg, default True): return only {object_name: object_pk} as JSON

    Returns:
        In URL mode: the absolute item URL (if a request is supplied), else the
        item URL; if no item URL is available, '{}'.
        Otherwise: a JSON string (keys sorted) of the supplied data plus, in full
        mode, the InvenTree tool / version / instance fields.
    """
    if object_data is None:
        object_data = {}

    url = kwargs.get('url', False)
    brief = kwargs.get('brief', True)

    data = {}

    if url:
        request = object_data.get('request', None)
        item_url = object_data.get('item_url', None)

        if request and item_url:
            # Return URL (No JSON)
            return request.build_absolute_uri(item_url)

        if item_url:
            # Return URL (No JSON)
            return item_url

        # NOTE: without an item_url this falls through and returns '{}'
    elif brief:
        data[object_name] = object_pk
    else:
        data['tool'] = 'InvenTree'
        data['version'] = InvenTree.version.inventreeVersion()
        data['instance'] = InvenTree.version.inventreeInstanceName()

        # Ensure PK is included, without mutating the caller's dictionary
        payload = dict(object_data)
        payload['id'] = object_pk
        data[object_name] = payload

    return json.dumps(data, sort_keys=True)
def GetExportFormats():
    """Allowed file formats (extensions) for data export, as a fresh list."""
    return "csv tsv xls xlsx json yaml".split()
def DownloadFile(data, filename, content_type='application/text', inline=False):
    """
    Create a dynamic file for the user to download.

    Args:
        data: Raw file data (string or bytes)
        filename: Filename for the file download
        content_type: Content type for the download
        inline: Download "inline" or as attachment? (Default = attachment)

    Return:
        A StreamingHttpResponse object wrapping the supplied data
    """
    filename = WrapWithQuotes(filename)

    response = StreamingHttpResponse(content_type=content_type)

    if isinstance(data, str):
        # Encode with the charset Django would use for str chunks, so that
        # Content-Length is a byte count: len() of a str counts characters,
        # which is wrong for any non-ASCII content.
        data = data.encode(response.charset)

    response.streaming_content = FileWrapper(io.BytesIO(data))
    response['Content-Length'] = len(data)

    disposition = "inline" if inline else "attachment"

    response['Content-Disposition'] = f'{disposition}; filename={filename}'

    return response
def extract_serial_numbers(serials, expected_quantity, next_number: int):
    """
    Attempt to extract serial numbers from an input string.

    Requirements:
        - Serial numbers can be either strings, or integers
        - Serial numbers can be split by whitespace / newline / comma chars
        - Serial numbers can be supplied as an inclusive range using hyphen char e.g. 10-20
        - Serial numbers can be defined as ~ for getting the next available serial number
        - Serial numbers can be supplied as <start>+ for getting all expected numbers starting from <start>
        - Serial numbers can be supplied as <start>+<n> for <start> followed by <n> further
          numbers (e.g. 10+2 -> 10, 11, 12)

    Args:
        serials: input string with patterns
        expected_quantity: The number of (unique) serial numbers we expect
        next_number(int): the next possible serial number (substituted for each '~')

    Returns:
        List of serial numbers; entries that parse as integers are stored as int,
        everything else as str

    Raises:
        ValidationError: on an invalid quantity, empty input, malformed groups,
            duplicate serials, or a count that does not match expected_quantity
    """
    serials = serials.strip()

    # fill in the next serial number into the serial
    while '~' in serials:
        serials = serials.replace('~', str(next_number), 1)
        next_number += 1

    # Split input string by whitespace or comma (,) characters.
    # Leading/trailing commas yield empty strings here; drop them so they are
    # never recorded as (empty) serial numbers.
    groups = [group for group in re.split(r"[\s,]+", serials) if group]

    numbers = []
    errors = []

    # Helper function to check for duplicated numbers
    def add_sn(sn):
        # Attempt integer conversion first, so numerical strings are never stored
        try:
            sn = int(sn)
        except ValueError:
            pass

        if sn in numbers:
            errors.append(_('Duplicate serial: {sn}').format(sn=sn))
        else:
            numbers.append(sn)

    try:
        expected_quantity = int(expected_quantity)
    except (TypeError, ValueError):
        raise ValidationError([_("Invalid quantity provided")])

    if len(serials) == 0:
        raise ValidationError([_("Empty serial number string")])

    # If the user has supplied the correct number of serials, don't process them for groups
    # just add them so any duplicates (or future validations) are checked
    if groups and len(groups) == expected_quantity:
        for group in groups:
            add_sn(group)

        if len(errors) > 0:
            raise ValidationError(errors)

        return numbers

    for group in groups:
        group = group.strip()

        # Hyphen indicates a range of numbers
        if '-' in group:
            items = group.split('-')

            if len(items) == 2 and all(i.isnumeric() for i in items):
                a = items[0].strip()
                b = items[1].strip()

                try:
                    a = int(a)
                    b = int(b)

                    if a < b:
                        for n in range(a, b + 1):
                            add_sn(n)
                    else:
                        errors.append(_("Invalid group range: {g}").format(g=group))

                except ValueError:
                    errors.append(_("Invalid group: {g}").format(g=group))
                    continue
            else:
                # More than 2 hyphens or non-numeric group so add without interpolating
                add_sn(group)

        # plus signals either
        # 1: 'start+': expected number of serials, starting at start
        # 2: 'start+number': start followed by <number> further serials
        elif '+' in group:
            items = group.split('+')

            # case 1, 2
            if len(items) == 2:
                # Non-numeric parts (e.g. 'A+1', '1+x', '+5') must be reported as
                # validation errors rather than leaking a raw ValueError
                try:
                    start = int(items[0])

                    # case 2
                    if items[1]:
                        end = start + int(items[1]) + 1
                    # case 1
                    else:
                        end = start + (expected_quantity - len(numbers))
                except ValueError:
                    errors.append(_("Invalid group sequence: {g}").format(g=group))
                    continue

                for n in range(start, end):
                    add_sn(n)
            # no case
            else:
                errors.append(_("Invalid group sequence: {g}").format(g=group))

        # At this point, we assume that the "group" is just a single serial value
        # (empty groups were removed when splitting)
        else:
            add_sn(group)

    if len(errors) > 0:
        raise ValidationError(errors)

    if len(numbers) == 0:
        raise ValidationError([_("No serial numbers found")])

    # The number of extracted serial numbers must match the expected quantity
    if expected_quantity != len(numbers):
        raise ValidationError([_("Number of unique serial numbers ({s}) must match quantity ({q})").format(s=len(numbers), q=expected_quantity)])

    return numbers
def validateFilterString(value, model=None):
    """
    Parse and validate a filter string of comma-separated key=value pairs.

    Each pair should nominally be a database filter for the model being filtered,
    e.g. "category=6, IPN=12" or "part__name=widget". The ReportTemplate class
    uses this to decide which items a report applies to; an acceptance test
    report for a single IPN could use: filters = "IPN = ACME0001"

    Keys and values are whitespace-stripped. A group without exactly one '=' or
    with an empty key/value raises ValidationError. When ``model`` is given, the
    filters are tried against model.objects.filter and a FieldError is
    re-raised as ValidationError.

    Returns a map of key:value pairs (empty for blank input)
    """
    results = {}

    text = str(value).strip()
    if not text:
        return results

    for raw_group in text.split(','):
        group = raw_group.strip()
        pair = group.split('=')

        if len(pair) != 2:
            raise ValidationError("Invalid group: {g}".format(g=group))

        key, val = (part.strip() for part in pair)

        if not key or not val:
            raise ValidationError("Invalid group: {g}".format(g=group))

        results[key] = val

    # Confirm the collected filters can actually be applied to the model
    if model is not None:
        try:
            model.objects.filter(**results)
        except FieldError as e:
            raise ValidationError(str(e))

    return results
def addUserPermission(user, permission):
    """
    Grant the Permission with codename ``permission`` to ``user``.

    NOTE(review): lookup is by codename only; Permission.objects.get raises if
    the codename is missing or shared by several apps.
    """
    user.user_permissions.add(Permission.objects.get(codename=permission))
def addUserPermissions(user, permissions):
    """
    Grant every permission codename in ``permissions`` to ``user``, one at a
    time via addUserPermission.
    """
    for codename in permissions:
        addUserPermission(user, codename)
def getMigrationFileNames(app):
    """
    List the migration filenames for an app.

    Looks in '<this file's directory>/../<app>/migrations' and keeps only names
    that start with digits followed by '_' and end in '.py'.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    migrations_path = os.path.join(here, '..', app, 'migrations')

    migration_regex = r"^[\d]+_.*\.py$"

    return [name for name in os.listdir(migrations_path) if re.match(migration_regex, name)]
def getOldestMigrationFile(app, exclude_extension=True, ignore_initial=True):
    """
    Return the filename associated with the oldest migration.

    Args:
        app: App name passed to getMigrationFileNames
        exclude_extension: Strip the trailing '.py' from the result
        ignore_initial: Skip files starting with '0001_initial'

    Returns:
        The filename with the lowest numeric prefix, or None if the app has no
        (non-ignored) migration files
    """
    oldest_num = -1
    oldest_file = None

    for f in getMigrationFileNames(app):

        if ignore_initial and f.startswith('0001_initial'):
            continue

        # Filenames are guaranteed to start with '<digits>_'
        num = int(f.split('_')[0])

        if oldest_file is None or num < oldest_num:
            oldest_num = num
            oldest_file = f

    if oldest_file is not None and exclude_extension:
        # Remove only the '.py' suffix, not every '.py' occurrence in the name
        oldest_file = os.path.splitext(oldest_file)[0]

    return oldest_file
def getNewestMigrationFile(app, exclude_extension=True):
    """
    Return the filename associated with the newest migration.

    Args:
        app: App name passed to getMigrationFileNames
        exclude_extension: Strip the trailing '.py' from the result

    Returns:
        The filename with the highest numeric prefix, or None if the app has
        no migration files
    """
    newest_file = None
    newest_num = -1

    for f in getMigrationFileNames(app):
        # Filenames are guaranteed to start with '<digits>_'
        num = int(f.split('_')[0])

        if newest_file is None or num > newest_num:
            newest_num = num
            newest_file = f

    if newest_file is not None and exclude_extension:
        # Remove only the '.py' suffix, not every '.py' occurrence in the name
        newest_file = os.path.splitext(newest_file)[0]

    return newest_file
def clean_decimal(number):
    """ Clean-up decimal value

    Accepts numbers or strings using either '.' or ',' as the decimal separator,
    optionally with thousands separators and spaces, and returns a Decimal.

    Separator rules:
        - both ',' and '.' present: whichever appears last is the decimal separator
        - a single ',' only: it is the decimal separator ('12,5' -> 12.5)
        - a repeated ',' or '.' only: it is a thousands separator

    Returns:
        Decimal(0) for None, '', 0 or unparseable input. Integral values are
        returned without exponent when that fits the decimal context precision;
        otherwise the normalized value is returned.
    """
    # Check if empty
    if number is None or number == '' or number == 0:
        return Decimal(0)

    # Convert to string and remove spaces
    number = str(number).replace(' ', '')

    # Guess what type of decimal and thousands separators are used
    count_comma = number.count(',')
    count_point = number.count('.')

    if count_comma and count_point:
        # Both present: the last one is the decimal separator
        if number.rfind(',') > number.rfind('.'):
            # e.g. '1.234,56': points are thousands separators
            number = number.replace('.', '').replace(',', '.')
        else:
            # e.g. '1,234.56': commas are thousands separators
            number = number.replace(',', '')
    elif count_comma == 1:
        # Comma is used as decimal separator
        number = number.replace(',', '.')
    elif count_comma > 1:
        # A repeated comma cannot be a decimal separator, e.g. '1,000,000'
        number = number.replace(',', '')
    elif count_point > 1:
        # A repeated point cannot be a decimal separator, e.g. '1.000.000'
        number = number.replace('.', '')

    # Convert to Decimal type
    try:
        clean_number = Decimal(number)
    except InvalidOperation:
        # Number cannot be converted to Decimal (eg. a string containing letters)
        return Decimal(0)

    # quantize() signals InvalidOperation for non-finite values and for results
    # with more digits than the context precision, so guard both cases
    if clean_number.is_finite() and clean_number == clean_number.to_integral():
        try:
            return clean_number.quantize(Decimal(1))
        except InvalidOperation:
            return clean_number.normalize()

    return clean_number.normalize()
def get_objectreference(obj, type_ref: str = 'content_type', object_ref: str = 'object_id'):
    """lookup method for the GenericForeignKey fields

    Attributes:
    - obj: object that will be resolved
    - type_ref: field name for the contenttype field in the model
    - object_ref: field name for the object id in the model

    Example implementation in the serializer:
    ```
    target = serializers.SerializerMethodField()
    def get_target(self, obj):
        return get_objectreference(obj, 'target_content_type', 'target_object_id')
    ```

    The method name must always be the name of the field prefixed by 'get_'

    Returns:
        None if either reference is unset, the content type no longer resolves
        to a model class, or the referenced object does not exist. Otherwise a
        dict with 'name', 'model' and (when the object defines
        get_absolute_url) 'link'.
    """
    model_cls = getattr(obj, type_ref)
    obj_id = getattr(obj, object_ref)

    # check if references are set -> return nothing if not
    if model_cls is None or obj_id is None:
        return None

    # resolve referenced data into objects
    model_cls = model_cls.model_class()

    # model_class() can return None when the model is no longer installed
    if model_cls is None:
        return None

    try:
        item = model_cls.objects.get(id=obj_id)
    except model_cls.DoesNotExist:
        # Generic references are not enforced by the database, so the target
        # may have been deleted; treat a dangling reference as "nothing"
        return None

    url_fnc = getattr(item, 'get_absolute_url', None)

    # create output
    ret = {}
    if url_fnc:
        ret['link'] = url_fnc()

    return {
        'name': str(item),
        'model': str(model_cls._meta.verbose_name),
        **ret
    }
def inheritors(cls):
    """
    Collect every direct and indirect subclass of ``cls`` (not ``cls`` itself).

    Classes reachable through several inheritance paths appear once, since the
    result is a set.
    """
    discovered = set()
    queue = [cls]
    position = 0

    # Breadth-first walk over __subclasses__(); a class is queued once, on first sight
    while position < len(queue):
        current = queue[position]
        position += 1
        for subclass in current.__subclasses__():
            if subclass not in discovered:
                discovered.add(subclass)
                queue.append(subclass)

    return discovered
class InvenTreeTestCase(UserMixin, TestCase):
    """Base test case combining UserMixin (from .api_tester) with Django's TestCase."""