2
0
mirror of https://github.com/inventree/InvenTree.git synced 2026-04-02 09:31:02 +00:00

Refactoring for report helper functions (#11579)

* Refactoring for media file report helper functions

* Updated unit tests

* Improved error handling

* Generic path return for asset

* Fix return type annotations

* Fix existing test

* Tweaked unit test

* Collect static files in CI

* Run static step for all DB tests

* Update action.yaml

* Fix for action.yaml

* Updated unit tests
This commit is contained in:
Oliver
2026-03-21 17:38:41 +11:00
committed by GitHub
parent 5adf33d354
commit 6d8606bbe4
8 changed files with 224 additions and 58 deletions

View File

@@ -15,6 +15,10 @@ inputs:
required: false required: false
description: 'Install the InvenTree requirements?' description: 'Install the InvenTree requirements?'
default: 'false' default: 'false'
static:
required: false
description: 'Should the static files be built?'
default: 'false'
dev-install: dev-install:
required: false required: false
description: 'Install the InvenTree development requirements?' description: 'Install the InvenTree development requirements?'
@@ -103,3 +107,7 @@ runs:
if: ${{ inputs.update == 'true' }} if: ${{ inputs.update == 'true' }}
shell: bash shell: bash
run: invoke update --skip-backup --skip-static run: invoke update --skip-backup --skip-static
- name: Collect static files
if: ${{ inputs.static == 'true' }}
shell: bash
run: invoke static --skip-plugins

View File

@@ -341,6 +341,7 @@ jobs:
apt-dependency: gettext poppler-utils apt-dependency: gettext poppler-utils
dev-install: true dev-install: true
update: true update: true
static: true
npm: true npm: true
- name: Download Python Code For `${WRAPPER_NAME}` - name: Download Python Code For `${WRAPPER_NAME}`
run: git clone --depth 1 https://github.com/inventree/${WRAPPER_NAME} ./${WRAPPER_NAME} run: git clone --depth 1 https://github.com/inventree/${WRAPPER_NAME} ./${WRAPPER_NAME}
@@ -398,6 +399,7 @@ jobs:
apt-dependency: gettext poppler-utils apt-dependency: gettext poppler-utils
dev-install: true dev-install: true
update: true update: true
static: true
- name: Data Export Test - name: Data Export Test
uses: ./.github/actions/migration uses: ./.github/actions/migration
- name: Test Translations - name: Test Translations
@@ -500,6 +502,7 @@ jobs:
pip-dependency: psycopg django-redis>=5.0.0 pip-dependency: psycopg django-redis>=5.0.0
dev-install: true dev-install: true
update: true update: true
static: true
- name: Run Tests - name: Run Tests
run: invoke dev.test --check --translations run: invoke dev.test --check --translations
- name: Data Export Test - name: Data Export Test
@@ -548,6 +551,7 @@ jobs:
pip-dependency: mysqlclient pip-dependency: mysqlclient
dev-install: true dev-install: true
update: true update: true
static: true
- name: Run Tests - name: Run Tests
run: invoke dev.test --check --translations run: invoke dev.test --check --translations
- name: Data Export Test - name: Data Export Test

View File

@@ -229,8 +229,15 @@ def getStaticUrl(filename):
return os.path.join(STATIC_URL, str(filename)) return os.path.join(STATIC_URL, str(filename))
def TestIfImage(img): def TestIfImage(img) -> bool:
"""Test if an image file is indeed an image.""" """Test if an image file is indeed an image.
Arguments:
img: A file-like object
Returns:
True if the file is a valid image, False otherwise
"""
try: try:
Image.open(img).verify() Image.open(img).verify()
return True return True

View File

@@ -74,7 +74,6 @@ class PartImageTestMixin:
{'image': img_file}, {'image': img_file},
expected_code=200, expected_code=200,
) )
print(response.data)
image_name = response.data['image'] image_name = response.data['image']
self.assertTrue(image_name.startswith('/media/part_images/part_image')) self.assertTrue(image_name.startswith('/media/part_images/part_image'))
return image_name return image_name
@@ -1838,7 +1837,7 @@ class PartDetailTests(PartImageTestMixin, PartAPITestBase):
# Part should not have an image! # Part should not have an image!
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
print(p.image.file) _x = p.image.file
# Try to upload a non-image file # Try to upload a non-image file
test_path = get_testfolder_dir() / 'dummy_image' test_path = get_testfolder_dir() / 'dummy_image'

View File

@@ -5,12 +5,16 @@ import logging
import os import os
from datetime import date, datetime from datetime import date, datetime
from decimal import Decimal, InvalidOperation from decimal import Decimal, InvalidOperation
from io import BytesIO
from pathlib import Path
from typing import Any, Optional from typing import Any, Optional
from django import template from django import template
from django.apps.registry import apps from django.apps.registry import apps
from django.conf import settings from django.conf import settings
from django.core.exceptions import ValidationError from django.contrib.staticfiles.storage import staticfiles_storage
from django.core.exceptions import SuspiciousFileOperation, ValidationError
from django.core.files.storage import default_storage
from django.db.models import Model from django.db.models import Model
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.utils.safestring import SafeString, mark_safe from django.utils.safestring import SafeString, mark_safe
@@ -145,6 +149,111 @@ def getkey(container: dict, key: str, backup_value: Optional[Any] = None) -> Any
return container.get(key, backup_value) return container.get(key, backup_value)
def media_file_exists(path: Path | str) -> bool:
"""Check if a media file exists at the specified path.
Arguments:
path: The path to the media file, relative to the media storage root
Returns:
True if the file exists, False otherwise
"""
if not path:
return False
try:
return default_storage.exists(str(path))
except SuspiciousFileOperation:
# Prevent path traversal attacks
raise ValidationError(_('Invalid media file path') + f": '{path}'")
def static_file_exists(path: Path | str) -> bool:
"""Check if a static file exists at the specified path.
Arguments:
path: The path to the static file, relative to the static storage root
Returns:
True if the file exists, False otherwise
"""
if not path:
return False
try:
return staticfiles_storage.exists(str(path))
except SuspiciousFileOperation:
# Prevent path traversal attacks
raise ValidationError(_('Invalid static file path') + f": '{path}'")
def get_static_file_contents(
path: Path | str, raise_error: bool = True
) -> bytes | None:
"""Return the contents of a static file.
Arguments:
path: The path to the static file, relative to the static storage root
raise_error: If True, raise an error if the file cannot be found (default = True)
Returns:
The contents of the static file, or None if the file cannot be found
"""
if not path:
if raise_error:
raise ValueError('No media file specified')
else:
return None
if not staticfiles_storage.exists(path):
if raise_error:
raise FileNotFoundError(f'Static file does not exist: {path!s}')
else:
return None
with staticfiles_storage.open(str(path)) as f:
file_data = f.read()
return file_data
def get_media_file_contents(path: Path | str, raise_error: bool = True) -> bytes | None:
"""Return the fully qualified file path to an uploaded media file.
Arguments:
path: The path to the media file, relative to the media storage root
raise_error: If True, raise an error if the file cannot be found (default = True)
Returns:
The contents of the media file, or None if the file cannot be found
Raises:
FileNotFoundError: If the requested media file cannot be loaded
PermissionError: If the requested media file is outside of the media root
ValidationError: If the provided path is invalid
Notes:
- The resulting path is resolved against the media root directory
"""
if not path:
if raise_error:
raise ValueError('No media file specified')
else:
return None
if not media_file_exists(path):
if raise_error:
raise FileNotFoundError(f'Media file does not exist: {path!s}')
else:
return None
# Load the file - and return the contents
with default_storage.open(str(path)) as f:
file_data = f.read()
return file_data
@register.simple_tag() @register.simple_tag()
def asset(filename): def asset(filename):
"""Return fully-qualified path for an upload report asset file. """Return fully-qualified path for an upload report asset file.
@@ -159,18 +268,21 @@ def asset(filename):
# Prepend an empty string to enforce 'stringiness' # Prepend an empty string to enforce 'stringiness'
filename = '' + filename filename = '' + filename
# If in debug mode, return URL to the image, not a local file # Remove any leading slash characters from the filename, to prevent path traversal attacks
debug_mode = get_global_setting('REPORT_DEBUG_MODE', cache=False) filename = str(filename).lstrip('/\\')
# Test if the file actually exists full_path = Path('report', 'assets', filename)
full_path = settings.MEDIA_ROOT.joinpath('report', 'assets', filename).resolve()
if not full_path.exists() or not full_path.is_file(): if not media_file_exists(full_path):
raise FileNotFoundError(_('Asset file does not exist') + f": '{filename}'") raise FileNotFoundError(_('Asset file not found') + f": '{filename}'")
if debug_mode: # In debug mode, return a web URL to the asset file (rather than a local file path)
return os.path.join(settings.MEDIA_URL, 'report', 'assets', filename) if get_global_setting('REPORT_DEBUG_MODE', cache=False):
return f'file://{full_path}' return str(Path(settings.MEDIA_URL, 'report', 'assets', filename))
storage_path = default_storage.path(str(full_path))
return f'file://{storage_path}'
@register.simple_tag() @register.simple_tag()
@@ -187,7 +299,7 @@ def uploaded_image(
"""Return raw image data from an 'uploaded' image. """Return raw image data from an 'uploaded' image.
Arguments: Arguments:
filename: The filename of the image relative to the MEDIA_ROOT directory filename: The filename of the image relative to the media root directory
replace_missing: Optionally return a placeholder image if the provided filename does not exist (default = True) replace_missing: Optionally return a placeholder image if the provided filename does not exist (default = True)
replacement_file: The filename of the placeholder image (default = 'blank_image.png') replacement_file: The filename of the placeholder image (default = 'blank_image.png')
validate: Optionally validate that the file is a valid image file validate: Optionally validate that the file is a valid image file
@@ -205,38 +317,43 @@ def uploaded_image(
# Prepend an empty string to enforce 'stringiness' # Prepend an empty string to enforce 'stringiness'
filename = '' + filename filename = '' + filename
# Strip out any leading slash characters from the filename, to prevent path traversal attacks
filename = str(filename).lstrip('/\\')
# If in debug mode, return URL to the image, not a local file # If in debug mode, return URL to the image, not a local file
debug_mode = get_global_setting('REPORT_DEBUG_MODE', cache=False) debug_mode = get_global_setting('REPORT_DEBUG_MODE', cache=False)
# Check if the file exists # Load image data - this will check if the file exists
if not filename: exists = bool(filename) and media_file_exists(filename)
exists = False
else:
try:
full_path = settings.MEDIA_ROOT.joinpath(filename).resolve()
exists = full_path.exists() and full_path.is_file()
except Exception: # pragma: no cover
exists = False # pragma: no cover
if exists and validate and not InvenTree.helpers.TestIfImage(full_path):
logger.warning("File '%s' is not a valid image", filename)
exists = False
if not exists and not replace_missing: if not exists and not replace_missing:
raise FileNotFoundError(_('Image file not found') + f": '{filename}'") raise FileNotFoundError(_('Image file not found') + f": '{filename}'")
if exists:
img_data = get_media_file_contents(filename, raise_error=False)
# Check if the image data is valid
if (
img_data
and validate
and not InvenTree.helpers.TestIfImage(BytesIO(img_data))
):
logger.warning("File '%s' is not a valid image", filename)
img_data = None
exists = False
else:
# Load the backup image from the static files directory
replacement_file_path = Path('img', replacement_file)
img_data = get_static_file_contents(replacement_file_path)
if debug_mode: if debug_mode:
# In debug mode, return a web path (rather than an encoded image blob) # In debug mode, return a web path (rather than an encoded image blob)
if exists: if exists:
return os.path.join(settings.MEDIA_URL, filename) return os.path.join(settings.MEDIA_URL, filename)
return os.path.join(settings.STATIC_URL, 'img', replacement_file) return os.path.join(settings.STATIC_URL, 'img', replacement_file)
elif not exists: if img_data:
full_path = settings.STATIC_ROOT.joinpath('img', replacement_file).resolve() img = Image.open(BytesIO(img_data))
# Load the image, check that it is valid
if full_path.exists() and full_path.is_file():
img = Image.open(full_path)
else: else:
# A placeholder image showing that the image is missing # A placeholder image showing that the image is missing
img = Image.new('RGB', (64, 64), color='red') img = Image.new('RGB', (64, 64), color='red')
@@ -288,22 +405,15 @@ def encode_svg_image(filename: str) -> str:
# Prepend an empty string to enforce 'stringiness' # Prepend an empty string to enforce 'stringiness'
filename = '' + filename filename = '' + filename
# Check if the file exists # Remove any leading slash characters from the filename, to prevent path traversal attacks
filename = str(filename).lstrip('/\\')
if not filename: if not filename:
exists = False raise FileNotFoundError(_('No image file specified'))
else:
try:
full_path = settings.MEDIA_ROOT.joinpath(filename).resolve()
exists = full_path.exists() and full_path.is_file()
except Exception:
exists = False
if not exists: # Read out the file contents
raise FileNotFoundError(_('Image file not found') + f": '{filename}'") # Note: This will check if the file exists, and raise an error if it does not
data = get_media_file_contents(filename)
# Read the file data
with open(full_path, 'rb') as f:
data = f.read()
# Return the base64-encoded data # Return the base64-encoded data
return 'data:image/svg+xml;charset=utf-8;base64,' + base64.b64encode(data).decode( return 'data:image/svg+xml;charset=utf-8;base64,' + base64.b64encode(data).decode(
@@ -323,8 +433,15 @@ def part_image(part: Part, preview: bool = False, thumbnail: bool = False, **kwa
Raises: Raises:
TypeError: If provided part is not a Part instance TypeError: If provided part is not a Part instance
""" """
if type(part) is not Part: if not part or not isinstance(part, Part):
raise TypeError(_('part_image tag requires a Part instance')) raise TypeError(_('part_image tag requires a Part instance'))
image_filename = InvenTree.helpers.image2name(part.image, preview, thumbnail)
if kwargs.get('check_exists'):
if not media_file_exists(image_filename):
raise FileNotFoundError(_('Image file not found') + f": '{image_filename}'")
return uploaded_image( return uploaded_image(
InvenTree.helpers.image2name(part.image, preview, thumbnail), **kwargs InvenTree.helpers.image2name(part.image, preview, thumbnail), **kwargs
) )

View File

@@ -14,7 +14,6 @@ from djmoney.money import Money
from PIL import Image from PIL import Image
from common.models import InvenTreeSetting, Parameter, ParameterTemplate from common.models import InvenTreeSetting, Parameter, ParameterTemplate
from InvenTree.config import get_testfolder_dir
from InvenTree.unit_test import InvenTreeTestCase from InvenTree.unit_test import InvenTreeTestCase
from part.models import Part # TODO fix import: PartParameter, PartParameterTemplate from part.models import Part # TODO fix import: PartParameter, PartParameterTemplate
from part.test_api import PartImageTestMixin from part.test_api import PartImageTestMixin
@@ -81,7 +80,27 @@ class ReportTagTest(PartImageTestMixin, InvenTreeTestCase):
self.debug_mode(False) self.debug_mode(False)
asset = report_tags.asset('test.txt') asset = report_tags.asset('test.txt')
self.assertEqual(asset, f'file://{asset_dir}/test.txt')
# Test for attempted path traversal
with self.assertRaises(ValidationError):
report_tags.asset('../../../report/assets/test.txt')
def test_file_access(self):
"""Tests for media and static file access."""
for fn in [None, '', '@@@@@@', 'fake_file.txt']:
self.assertFalse(report_tags.media_file_exists(fn))
self.assertFalse(report_tags.static_file_exists(fn))
with self.assertRaises(FileNotFoundError):
report_tags.get_media_file_contents('dummy_file.txt')
with self.assertRaises(ValueError):
report_tags.get_static_file_contents(None)
# Try again, without throwing an error
self.assertIsNone(
report_tags.get_media_file_contents('dummy_file.txt', raise_error=False)
)
def test_uploaded_image(self): def test_uploaded_image(self):
"""Tests for retrieving uploaded images.""" """Tests for retrieving uploaded images."""
@@ -148,6 +167,10 @@ class ReportTagTest(PartImageTestMixin, InvenTreeTestCase):
) )
self.assertTrue(img.startswith('data:image/png;charset=utf-8;base64,')) self.assertTrue(img.startswith('data:image/png;charset=utf-8;base64,'))
# Attempted path traversal
with self.assertRaises(ValidationError):
report_tags.uploaded_image('../../../part/images/test.jpg')
def test_part_image(self): def test_part_image(self):
"""Unit tests for the 'part_image' tag.""" """Unit tests for the 'part_image' tag."""
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
@@ -157,8 +180,10 @@ class ReportTagTest(PartImageTestMixin, InvenTreeTestCase):
self.create_test_image() self.create_test_image()
obj.refresh_from_db() obj.refresh_from_db()
report_tags.part_image(obj, preview=True) r = report_tags.part_image(obj, preview=True)
report_tags.part_image(obj, thumbnail=True) self.assertIn('data:image/png;charset=utf-8;base64,', r)
r = report_tags.part_image(obj, thumbnail=True)
self.assertIn('data:image/png;charset=utf-8;base64,', r)
def test_company_image(self): def test_company_image(self):
"""Unit tests for the 'company_image' tag.""" """Unit tests for the 'company_image' tag."""
@@ -392,12 +417,16 @@ class ReportTagTest(PartImageTestMixin, InvenTreeTestCase):
def test_encode_svg_image(self): def test_encode_svg_image(self):
"""Test the encode_svg_image template tag.""" """Test the encode_svg_image template tag."""
# Generate smallest possible SVG for testing # Generate smallest possible SVG for testing
svg_path = get_testfolder_dir() / 'part_image_123abc.png' # Store it in the media directory
img_path = 'part_image_123abc.png'
svg_path = settings.MEDIA_ROOT / img_path
with open(svg_path, 'w', encoding='utf8') as f: with open(svg_path, 'w', encoding='utf8') as f:
f.write('<svg xmlns="http://www.w3.org/2000/svg>') f.write('<svg xmlns="http://www.w3.org/2000/svg>')
# Test with a valid SVG file # Test with a valid SVG file
svg = report_tags.encode_svg_image(svg_path) svg = report_tags.encode_svg_image(img_path)
self.assertTrue(svg.startswith('data:image/svg+xml;charset=utf-8;base64,')) self.assertTrue(svg.startswith('data:image/svg+xml;charset=utf-8;base64,'))
self.assertIn('svg', svg) self.assertIn('svg', svg)
self.assertEqual( self.assertEqual(

View File

@@ -571,7 +571,9 @@ class TestReportTest(PrintTestMixins, ReportTest):
def test_mdl_salesorder(self): def test_mdl_salesorder(self):
"""Test the SalesOrder model.""" """Test the SalesOrder model."""
self.run_print_test(SalesOrder, 'salesorder', label=False) for enabled in [True, False]:
set_global_setting('REPORT_DEBUG_MODE', enabled)
self.run_print_test(SalesOrder, 'salesorder', label=False)
class AdminTest(AdminTestCase): class AdminTest(AdminTestCase):

View File

@@ -112,9 +112,9 @@ export const PdfPreviewComponent: PreviewAreaComponent = forwardRef(
style={{ style={{
display: 'flex', display: 'flex',
justifyContent: 'center', justifyContent: 'center',
alignItems: 'center',
height: '100%', height: '100%',
width: '100%' width: '100%',
paddingTop: '50px'
}} }}
> >
<Trans>Preview not available, click "Reload Preview".</Trans> <Trans>Preview not available, click "Reload Preview".</Trans>