2
0
mirror of https://github.com/inventree/InvenTree.git synced 2025-04-28 19:46:46 +00:00

Refactor login state management (#7158)

* Refactor login state management

- Previously relied only on presence of cookie
- Cookie may not actually be *valid*
- Inspect actual login state by looking at userState values
- Ensures better sequencing of global state API requests
- Login state is now correctly preseed across browsers

* Ignore errors for user/me/ API endpoint in playwright test

* Do not request notifications unless logged in

* Prevent duplicate licenses

* Update src/frontend/src/views/DesktopAppView.tsx

Co-authored-by: Matthias Mair <code@mjmair.com>

* Simplify checkLoginState

* Fix bug in return types

* Update playwright tests

* linting

* Remove error msg

* Use token auth for API calls

- Will (hopefully) allow us to bypass csrfmiddle request handling?

* Refetch token if not available

* Use cache for DISPLAY_FULL_NAMES setting

* Update src/frontend/tests/baseFixtures.ts

Co-authored-by: Matthias Mair <code@mjmair.com>

* PUI test updates

* Tweak doLogout function

* Revert change to baseFixtures.ts

* Cleanup

* Fix highlighted property

* Test cleanup

---------

Co-authored-by: Matthias Mair <code@mjmair.com>
This commit is contained in:
Oliver 2024-05-07 23:11:38 +10:00 committed by GitHub
parent 6c944c73dd
commit 289af4e924
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 225 additions and 109 deletions

View File

@ -73,8 +73,24 @@ class LicenseView(APIView):
logger.exception("Exception while reading license file '%s': %s", path, e) logger.exception("Exception while reading license file '%s': %s", path, e)
return [] return []
# Ensure consistent string between backend and frontend licenses output = []
return [{key.lower(): value for key, value in entry.items()} for entry in data] names = set()
# Ensure we do not have any duplicate 'name' values in the list
for entry in data:
name = None
for key in entry.keys():
if key.lower() == 'name':
name = entry[key]
break
if name is None or name in names:
continue
names.add(name)
output.append({key.lower(): value for key, value in entry.items()})
return output
@extend_schema(responses={200: OpenApiResponse(response=LicenseViewSerializer)}) @extend_schema(responses={200: OpenApiResponse(response=LicenseViewSerializer)})
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):

View File

@ -70,7 +70,8 @@ class AuthRequiredMiddleware(object):
# API requests are handled by the DRF library # API requests are handled by the DRF library
if request.path_info.startswith('/api/'): if request.path_info.startswith('/api/'):
return self.get_response(request) response = self.get_response(request)
return response
# Is the function exempt from auth requirements? # Is the function exempt from auth requirements?
path_func = resolve(request.path).func path_func = resolve(request.path).func

View File

@ -34,7 +34,7 @@ logger = logging.getLogger('inventree')
# string representation of a user # string representation of a user
def user_model_str(self): def user_model_str(self):
"""Function to override the default Django User __str__.""" """Function to override the default Django User __str__."""
if common_models.InvenTreeSetting.get_setting('DISPLAY_FULL_NAMES'): if common_models.InvenTreeSetting.get_setting('DISPLAY_FULL_NAMES', cache=True):
if self.first_name or self.last_name: if self.first_name or self.last_name:
return f'{self.first_name} {self.last_name}' return f'{self.first_name} {self.last_name}'
return self.username return self.username
@ -831,7 +831,9 @@ class Owner(models.Model):
"""Defines the owner string representation.""" """Defines the owner string representation."""
if ( if (
self.owner_type.name == 'user' self.owner_type.name == 'user'
and common_models.InvenTreeSetting.get_setting('DISPLAY_FULL_NAMES') and common_models.InvenTreeSetting.get_setting(
'DISPLAY_FULL_NAMES', cache=True
)
): ):
display_name = self.owner.get_full_name() display_name = self.owner.get_full_name()
else: else:
@ -842,7 +844,9 @@ class Owner(models.Model):
"""Return the 'name' of this owner.""" """Return the 'name' of this owner."""
if ( if (
self.owner_type.name == 'user' self.owner_type.name == 'user'
and common_models.InvenTreeSetting.get_setting('DISPLAY_FULL_NAMES') and common_models.InvenTreeSetting.get_setting(
'DISPLAY_FULL_NAMES', cache=True
)
): ):
return self.owner.get_full_name() or str(self.owner) return self.owner.get_full_name() or str(self.owner)
return str(self.owner) return str(self.owner)

View File

@ -5,6 +5,7 @@ export default defineConfig({
fullyParallel: true, fullyParallel: true,
timeout: 60000, timeout: 60000,
forbidOnly: !!process.env.CI, forbidOnly: !!process.env.CI,
timeout: 5 * 60 * 1000,
retries: process.env.CI ? 1 : 0, retries: process.env.CI ? 1 : 0,
workers: process.env.CI ? 2 : undefined, workers: process.env.CI ? 2 : undefined,
reporter: process.env.CI ? [['html', { open: 'never' }], ['github']] : 'list', reporter: process.env.CI ? [['html', { open: 'never' }], ['github']] : 'list',

View File

@ -2,6 +2,7 @@ import { QueryClient } from '@tanstack/react-query';
import axios from 'axios'; import axios from 'axios';
import { useLocalState } from './states/LocalState'; import { useLocalState } from './states/LocalState';
import { useUserState } from './states/UserState';
// Global API instance // Global API instance
export const api = axios.create({}); export const api = axios.create({});
@ -11,6 +12,7 @@ export const api = axios.create({});
*/ */
export function setApiDefaults() { export function setApiDefaults() {
const host = useLocalState.getState().host; const host = useLocalState.getState().host;
const token = useUserState.getState().token;
api.defaults.baseURL = host; api.defaults.baseURL = host;
api.defaults.timeout = 2500; api.defaults.timeout = 2500;
@ -19,6 +21,12 @@ export function setApiDefaults() {
api.defaults.withXSRFToken = true; api.defaults.withXSRFToken = true;
api.defaults.xsrfCookieName = 'csrftoken'; api.defaults.xsrfCookieName = 'csrftoken';
api.defaults.xsrfHeaderName = 'X-CSRFToken'; api.defaults.xsrfHeaderName = 'X-CSRFToken';
if (token) {
api.defaults.headers['Authorization'] = `Token ${token}`;
} else {
delete api.defaults.headers['Authorization'];
}
} }
export const queryClient = new QueryClient(); export const queryClient = new QueryClient();

View File

@ -17,9 +17,10 @@ import { useLocation, useNavigate } from 'react-router-dom';
import { api } from '../../App'; import { api } from '../../App';
import { ApiEndpoints } from '../../enums/ApiEndpoints'; import { ApiEndpoints } from '../../enums/ApiEndpoints';
import { doBasicLogin, doSimpleLogin, isLoggedIn } from '../../functions/auth'; import { doBasicLogin, doSimpleLogin } from '../../functions/auth';
import { showLoginNotification } from '../../functions/notifications'; import { showLoginNotification } from '../../functions/notifications';
import { apiUrl, useServerApiState } from '../../states/ApiState'; import { apiUrl, useServerApiState } from '../../states/ApiState';
import { useUserState } from '../../states/UserState';
import { SsoButton } from '../buttons/SSOButton'; import { SsoButton } from '../buttons/SSOButton';
export function AuthenticationForm() { export function AuthenticationForm() {
@ -31,6 +32,7 @@ export function AuthenticationForm() {
const [auth_settings] = useServerApiState((state) => [state.auth_settings]); const [auth_settings] = useServerApiState((state) => [state.auth_settings]);
const navigate = useNavigate(); const navigate = useNavigate();
const location = useLocation(); const location = useLocation();
const { isLoggedIn } = useUserState();
const [isLoggingIn, setIsLoggingIn] = useState<boolean>(false); const [isLoggingIn, setIsLoggingIn] = useState<boolean>(false);

View File

@ -45,17 +45,17 @@ function ConditionalDocTooltip({
export function MenuLinks({ export function MenuLinks({
links, links,
highlighted highlighted = false
}: { }: {
links: MenuLinkItem[]; links: MenuLinkItem[];
highlighted?: boolean; highlighted?: boolean;
}) { }) {
const { classes } = InvenTreeStyle(); const { classes } = InvenTreeStyle();
highlighted = highlighted || false;
const filteredLinks = links.filter( const filteredLinks = links.filter(
(item) => !highlighted || item.highlight === true (item) => !highlighted || item.highlight === true
); );
return ( return (
<SimpleGrid cols={2} spacing={0}> <SimpleGrid cols={2} spacing={0}>
{filteredLinks.map((item) => ( {filteredLinks.map((item) => (

View File

@ -11,6 +11,7 @@ import { ApiEndpoints } from '../../enums/ApiEndpoints';
import { InvenTreeStyle } from '../../globalStyle'; import { InvenTreeStyle } from '../../globalStyle';
import { apiUrl } from '../../states/ApiState'; import { apiUrl } from '../../states/ApiState';
import { useLocalState } from '../../states/LocalState'; import { useLocalState } from '../../states/LocalState';
import { useUserState } from '../../states/UserState';
import { ScanButton } from '../buttons/ScanButton'; import { ScanButton } from '../buttons/ScanButton';
import { SpotlightButton } from '../buttons/SpotlightButton'; import { SpotlightButton } from '../buttons/SpotlightButton';
import { MainMenu } from './MainMenu'; import { MainMenu } from './MainMenu';
@ -37,11 +38,14 @@ export function Header() {
{ open: openNotificationDrawer, close: closeNotificationDrawer } { open: openNotificationDrawer, close: closeNotificationDrawer }
] = useDisclosure(false); ] = useDisclosure(false);
const { isLoggedIn } = useUserState();
const [notificationCount, setNotificationCount] = useState<number>(0); const [notificationCount, setNotificationCount] = useState<number>(0);
// Fetch number of notifications for the current user // Fetch number of notifications for the current user
const notifications = useQuery({ const notifications = useQuery({
queryKey: ['notification-count'], queryKey: ['notification-count'],
enabled: isLoggedIn(),
queryFn: async () => { queryFn: async () => {
try { try {
const params = { const params = {

View File

@ -6,13 +6,14 @@ import { useEffect, useState } from 'react';
import { Navigate, Outlet, useLocation, useNavigate } from 'react-router-dom'; import { Navigate, Outlet, useLocation, useNavigate } from 'react-router-dom';
import { getActions } from '../../defaults/actions'; import { getActions } from '../../defaults/actions';
import { isLoggedIn } from '../../functions/auth';
import { InvenTreeStyle } from '../../globalStyle'; import { InvenTreeStyle } from '../../globalStyle';
import { useUserState } from '../../states/UserState';
import { Footer } from './Footer'; import { Footer } from './Footer';
import { Header } from './Header'; import { Header } from './Header';
export const ProtectedRoute = ({ children }: { children: JSX.Element }) => { export const ProtectedRoute = ({ children }: { children: JSX.Element }) => {
const location = useLocation(); const location = useLocation();
const { isLoggedIn } = useUserState();
if (!isLoggedIn()) { if (!isLoggedIn()) {
return ( return (

View File

@ -51,7 +51,7 @@ export function PartCategoryTree({
) )
.catch((error) => { .catch((error) => {
console.error('Error fetching part category tree:', error); console.error('Error fetching part category tree:', error);
return error; return [];
}), }),
refetchOnMount: true refetchOnMount: true
}); });

View File

@ -43,7 +43,7 @@ export function StockLocationTree({
) )
.catch((error) => { .catch((error) => {
console.error('Error fetching stock location tree:', error); console.error('Error fetching stock location tree:', error);
return error; return [];
}), }),
refetchOnMount: true refetchOnMount: true
}); });

View File

@ -104,7 +104,7 @@ export function LanguageContext({ children }: { children: JSX.Element }) {
}) })
/* istanbul ignore next */ /* istanbul ignore next */
.catch((err) => { .catch((err) => {
console.error('Failed loading translations', err); console.error('ERR: Failed loading translations', err);
if (isMounted.current) setLoadedState('error'); if (isMounted.current) setLoadedState('error');
}); });

View File

@ -6,6 +6,7 @@ import { api, setApiDefaults } from '../App';
import { ApiEndpoints } from '../enums/ApiEndpoints'; import { ApiEndpoints } from '../enums/ApiEndpoints';
import { apiUrl } from '../states/ApiState'; import { apiUrl } from '../states/ApiState';
import { useLocalState } from '../states/LocalState'; import { useLocalState } from '../states/LocalState';
import { useUserState } from '../states/UserState';
import { fetchGlobalStates } from '../states/states'; import { fetchGlobalStates } from '../states/states';
import { showLoginNotification } from './notifications'; import { showLoginNotification } from './notifications';
@ -16,7 +17,8 @@ import { showLoginNotification } from './notifications';
*/ */
export const doBasicLogin = async (username: string, password: string) => { export const doBasicLogin = async (username: string, password: string) => {
const { host } = useLocalState.getState(); const { host } = useLocalState.getState();
// const apiState = useServerApiState.getState(); const { clearUserState, setToken, fetchUserState, isLoggedIn } =
useUserState.getState();
if (username.length == 0 || password.length == 0) { if (username.length == 0 || password.length == 0) {
return; return;
@ -26,6 +28,8 @@ export const doBasicLogin = async (username: string, password: string) => {
const login_url = apiUrl(ApiEndpoints.user_login); const login_url = apiUrl(ApiEndpoints.user_login);
let result: boolean = false;
// Attempt login with // Attempt login with
await api await api
.post( .post(
@ -39,18 +43,21 @@ export const doBasicLogin = async (username: string, password: string) => {
} }
) )
.then((response) => { .then((response) => {
switch (response.status) { if (response.status == 200) {
case 200: if (response.data.key) {
fetchGlobalStates(); setToken(response.data.key);
break; result = true;
default: }
clearCsrfCookie();
break;
} }
}) })
.catch(() => { .catch(() => {});
clearCsrfCookie();
}); if (result) {
await fetchUserState();
await fetchGlobalStates();
} else {
clearUserState();
}
}; };
/** /**
@ -59,16 +66,21 @@ export const doBasicLogin = async (username: string, password: string) => {
* @arg deleteToken: If true, delete the token from the server * @arg deleteToken: If true, delete the token from the server
*/ */
export const doLogout = async (navigate: any) => { export const doLogout = async (navigate: any) => {
const { clearUserState, isLoggedIn } = useUserState.getState();
// Logout from the server session // Logout from the server session
await api.post(apiUrl(ApiEndpoints.user_logout)).finally(() => { if (isLoggedIn() || !!getCsrfCookie()) {
clearCsrfCookie(); await api.post(apiUrl(ApiEndpoints.user_logout)).catch(() => {});
navigate('/login');
showLoginNotification({ showLoginNotification({
title: t`Logged Out`, title: t`Logged Out`,
message: t`Successfully logged out` message: t`Successfully logged out`
}); });
}); }
clearUserState();
clearCsrfCookie();
navigate('/login');
}; };
export const doSimpleLogin = async (email: string) => { export const doSimpleLogin = async (email: string) => {
@ -122,17 +134,19 @@ export function handleReset(navigate: any, values: { email: string }) {
* - An existing API token is stored in the session * - An existing API token is stored in the session
* - An existing CSRF cookie is stored in the browser * - An existing CSRF cookie is stored in the browser
*/ */
export function checkLoginState( export const checkLoginState = async (
navigate: any, navigate: any,
redirect?: string, redirect?: string,
no_redirect?: boolean no_redirect?: boolean
) { ) => {
setApiDefaults(); setApiDefaults();
if (redirect == '/') { if (redirect == '/') {
redirect = '/home'; redirect = '/home';
} }
const { isLoggedIn, fetchUserState } = useUserState.getState();
// Callback function when login is successful // Callback function when login is successful
const loginSuccess = () => { const loginSuccess = () => {
showLoginNotification({ showLoginNotification({
@ -140,6 +154,8 @@ export function checkLoginState(
message: t`Successfully logged in` message: t`Successfully logged in`
}); });
fetchGlobalStates();
navigate(redirect ?? '/home'); navigate(redirect ?? '/home');
}; };
@ -150,24 +166,22 @@ export function checkLoginState(
} }
}; };
// Check the 'user_me' endpoint to see if the user is logged in
if (isLoggedIn()) { if (isLoggedIn()) {
api // Already logged in
.get(apiUrl(ApiEndpoints.user_me)) loginSuccess();
.then((response) => { return;
if (response.status == 200) { }
// Not yet logged in, but we might have a valid session cookie
// Attempt to login
await fetchUserState();
if (isLoggedIn()) {
loginSuccess(); loginSuccess();
} else { } else {
loginFailure(); loginFailure();
} }
}) };
.catch(() => {
loginFailure();
});
} else {
loginFailure();
}
}
/* /*
* Return the value of the CSRF cookie, if available * Return the value of the CSRF cookie, if available
@ -181,10 +195,6 @@ export function getCsrfCookie() {
return cookieValue; return cookieValue;
} }
export function isLoggedIn() {
return !!getCsrfCookie();
}
/* /*
* Clear out the CSRF and session cookies (force session logout) * Clear out the CSRF and session cookies (force session logout)
*/ */

View File

@ -17,11 +17,16 @@ export function AccountDetailPanel() {
const form = useForm({ initialValues: user }); const form = useForm({ initialValues: user });
const [editing, setEditing] = useToggle([false, true] as const); const [editing, setEditing] = useToggle([false, true] as const);
function SaveData(values: any) { function SaveData(values: any) {
api.put(apiUrl(ApiEndpoints.user_me), values).then((res) => { api
.put(apiUrl(ApiEndpoints.user_me), values)
.then((res) => {
if (res.status === 200) { if (res.status === 200) {
setEditing(); setEditing();
fetchUserState(); fetchUserState();
} }
})
.catch(() => {
console.error('ERR: Error saving user data');
}); });
} }

View File

@ -5,9 +5,9 @@ import { create, createStore } from 'zustand';
import { api } from '../App'; import { api } from '../App';
import { ApiEndpoints } from '../enums/ApiEndpoints'; import { ApiEndpoints } from '../enums/ApiEndpoints';
import { isLoggedIn } from '../functions/auth';
import { isTrue } from '../functions/conversion'; import { isTrue } from '../functions/conversion';
import { PathParams, apiUrl } from './ApiState'; import { PathParams, apiUrl } from './ApiState';
import { useUserState } from './UserState';
import { Setting, SettingsLookup } from './states'; import { Setting, SettingsLookup } from './states';
export interface SettingsStateProps { export interface SettingsStateProps {
@ -29,6 +29,8 @@ export const useGlobalSettingsState = create<SettingsStateProps>(
lookup: {}, lookup: {},
endpoint: ApiEndpoints.settings_global_list, endpoint: ApiEndpoints.settings_global_list,
fetchSettings: async () => { fetchSettings: async () => {
const { isLoggedIn } = useUserState.getState();
if (!isLoggedIn()) { if (!isLoggedIn()) {
return; return;
} }
@ -63,6 +65,8 @@ export const useUserSettingsState = create<SettingsStateProps>((set, get) => ({
lookup: {}, lookup: {},
endpoint: ApiEndpoints.settings_user_list, endpoint: ApiEndpoints.settings_user_list,
fetchSettings: async () => { fetchSettings: async () => {
const { isLoggedIn } = useUserState.getState();
if (!isLoggedIn()) { if (!isLoggedIn()) {
return; return;
} }

View File

@ -6,8 +6,8 @@ import { StatusCodeListInterface } from '../components/render/StatusRenderer';
import { statusCodeList } from '../defaults/backendMappings'; import { statusCodeList } from '../defaults/backendMappings';
import { ApiEndpoints } from '../enums/ApiEndpoints'; import { ApiEndpoints } from '../enums/ApiEndpoints';
import { ModelType } from '../enums/ModelType'; import { ModelType } from '../enums/ModelType';
import { isLoggedIn } from '../functions/auth';
import { apiUrl } from './ApiState'; import { apiUrl } from './ApiState';
import { useUserState } from './UserState';
type StatusLookup = Record<ModelType | string, StatusCodeListInterface>; type StatusLookup = Record<ModelType | string, StatusCodeListInterface>;
@ -23,6 +23,8 @@ export const useGlobalStatusState = create<ServerStateProps>()(
status: undefined, status: undefined,
setStatus: (newStatus: StatusLookup) => set({ status: newStatus }), setStatus: (newStatus: StatusLookup) => set({ status: newStatus }),
fetchStatus: async () => { fetchStatus: async () => {
const { isLoggedIn } = useUserState.getState();
// Fetch status data for rendering labels // Fetch status data for rendering labels
if (!isLoggedIn()) { if (!isLoggedIn()) {
return; return;

View File

@ -1,22 +1,28 @@
import { create } from 'zustand'; import { create } from 'zustand';
import { api } from '../App'; import { api, setApiDefaults } from '../App';
import { ApiEndpoints } from '../enums/ApiEndpoints'; import { ApiEndpoints } from '../enums/ApiEndpoints';
import { UserPermissions, UserRoles } from '../enums/Roles'; import { UserPermissions, UserRoles } from '../enums/Roles';
import { isLoggedIn } from '../functions/auth'; import { clearCsrfCookie } from '../functions/auth';
import { apiUrl } from './ApiState'; import { apiUrl } from './ApiState';
import { UserProps } from './states'; import { UserProps } from './states';
interface UserStateProps { interface UserStateProps {
user: UserProps | undefined; user: UserProps | undefined;
token: string | undefined;
username: () => string; username: () => string;
setUser: (newUser: UserProps) => void; setUser: (newUser: UserProps) => void;
setToken: (newToken: string) => void;
clearToken: () => void;
fetchUserToken: () => void;
fetchUserState: () => void; fetchUserState: () => void;
clearUserState: () => void;
checkUserRole: (role: UserRoles, permission: UserPermissions) => boolean; checkUserRole: (role: UserRoles, permission: UserPermissions) => boolean;
hasDeleteRole: (role: UserRoles) => boolean; hasDeleteRole: (role: UserRoles) => boolean;
hasChangeRole: (role: UserRoles) => boolean; hasChangeRole: (role: UserRoles) => boolean;
hasAddRole: (role: UserRoles) => boolean; hasAddRole: (role: UserRoles) => boolean;
hasViewRole: (role: UserRoles) => boolean; hasViewRole: (role: UserRoles) => boolean;
isLoggedIn: () => boolean;
isStaff: () => boolean; isStaff: () => boolean;
isSuperuser: () => boolean; isSuperuser: () => boolean;
} }
@ -26,6 +32,15 @@ interface UserStateProps {
*/ */
export const useUserState = create<UserStateProps>((set, get) => ({ export const useUserState = create<UserStateProps>((set, get) => ({
user: undefined, user: undefined,
token: undefined,
setToken: (newToken: string) => {
set({ token: newToken });
setApiDefaults();
},
clearToken: () => {
set({ token: undefined });
setApiDefaults();
},
username: () => { username: () => {
const user: UserProps = get().user as UserProps; const user: UserProps = get().user as UserProps;
@ -36,9 +51,29 @@ export const useUserState = create<UserStateProps>((set, get) => ({
} }
}, },
setUser: (newUser: UserProps) => set({ user: newUser }), setUser: (newUser: UserProps) => set({ user: newUser }),
clearUserState: () => {
set({ user: undefined });
set({ token: undefined });
clearCsrfCookie();
setApiDefaults();
},
fetchUserToken: async () => {
await api
.get(apiUrl(ApiEndpoints.user_token))
.then((response) => {
if (response.status == 200 && response.data.token) {
get().setToken(response.data.token);
} else {
get().clearToken();
}
})
.catch(() => {
get().clearToken();
});
},
fetchUserState: async () => { fetchUserState: async () => {
if (!isLoggedIn()) { if (!get().token) {
return; await get().fetchUserToken();
} }
// Fetch user data // Fetch user data
@ -47,6 +82,7 @@ export const useUserState = create<UserStateProps>((set, get) => ({
timeout: 2000 timeout: 2000
}) })
.then((response) => { .then((response) => {
if (response.status == 200) {
const user: UserProps = { const user: UserProps = {
pk: response.data.pk, pk: response.data.pk,
first_name: response.data?.first_name ?? '', first_name: response.data?.first_name ?? '',
@ -55,15 +91,23 @@ export const useUserState = create<UserStateProps>((set, get) => ({
username: response.data.username username: response.data.username
}; };
set({ user: user }); set({ user: user });
} else {
get().clearUserState();
}
}) })
.catch((error) => { .catch(() => {
console.error('ERR: Error fetching user data'); get().clearUserState();
}); });
if (!get().isLoggedIn()) {
return;
}
// Fetch role data // Fetch role data
await api await api
.get(apiUrl(ApiEndpoints.user_roles)) .get(apiUrl(ApiEndpoints.user_roles))
.then((response) => { .then((response) => {
if (response.status == 200) {
const user: UserProps = get().user as UserProps; const user: UserProps = get().user as UserProps;
// Update user with role data // Update user with role data
@ -73,9 +117,13 @@ export const useUserState = create<UserStateProps>((set, get) => ({
user.is_superuser = response.data?.is_superuser ?? false; user.is_superuser = response.data?.is_superuser ?? false;
set({ user: user }); set({ user: user });
} }
} else {
get().clearUserState();
}
}) })
.catch((_error) => { .catch((_error) => {
console.error('ERR: Error fetching user roles'); console.error('ERR: Error fetching user roles');
get().clearUserState();
}); });
}, },
checkUserRole: (role: UserRoles, permission: UserPermissions) => { checkUserRole: (role: UserRoles, permission: UserPermissions) => {
@ -93,6 +141,13 @@ export const useUserState = create<UserStateProps>((set, get) => ({
return user?.roles[role]?.includes(permission) ?? false; return user?.roles[role]?.includes(permission) ?? false;
}, },
isLoggedIn: () => {
if (!get().token) {
return false;
}
const user: UserProps = get().user as UserProps;
return !!user && !!user.pk;
},
isStaff: () => { isStaff: () => {
const user: UserProps = get().user as UserProps; const user: UserProps = get().user as UserProps;
return user?.is_staff ?? false; return user?.is_staff ?? false;

View File

@ -1,5 +1,4 @@
import { setApiDefaults } from '../App'; import { setApiDefaults } from '../App';
import { isLoggedIn } from '../functions/auth';
import { useServerApiState } from './ApiState'; import { useServerApiState } from './ApiState';
import { useGlobalSettingsState, useUserSettingsState } from './SettingsState'; import { useGlobalSettingsState, useUserSettingsState } from './SettingsState';
import { useGlobalStatusState } from './StatusState'; import { useGlobalStatusState } from './StatusState';
@ -126,6 +125,8 @@ export type SettingsLookup = {
* Necessary on login, or if locale is changed. * Necessary on login, or if locale is changed.
*/ */
export function fetchGlobalStates() { export function fetchGlobalStates() {
const { isLoggedIn } = useUserState.getState();
if (!isLoggedIn()) { if (!isLoggedIn()) {
return; return;
} }

View File

@ -1,46 +1,22 @@
import { QueryClientProvider } from '@tanstack/react-query'; import { QueryClientProvider } from '@tanstack/react-query';
import { useEffect, useState } from 'react'; import { useEffect } from 'react';
import { BrowserRouter } from 'react-router-dom'; import { BrowserRouter } from 'react-router-dom';
import { queryClient } from '../App'; import { queryClient } from '../App';
import { BaseContext } from '../contexts/BaseContext'; import { BaseContext } from '../contexts/BaseContext';
import { defaultHostList } from '../defaults/defaultHostList'; import { defaultHostList } from '../defaults/defaultHostList';
import { isLoggedIn } from '../functions/auth';
import { base_url } from '../main'; import { base_url } from '../main';
import { routes } from '../router'; import { routes } from '../router';
import { useLocalState } from '../states/LocalState'; import { useLocalState } from '../states/LocalState';
import {
useGlobalSettingsState,
useUserSettingsState
} from '../states/SettingsState';
import { useUserState } from '../states/UserState';
export default function DesktopAppView() { export default function DesktopAppView() {
const [hostList] = useLocalState((state) => [state.hostList]); const [hostList] = useLocalState((state) => [state.hostList]);
const [fetchUserState] = useUserState((state) => [state.fetchUserState]);
const [fetchGlobalSettings] = useGlobalSettingsState((state) => [
state.fetchSettings
]);
const [fetchUserSettings] = useUserSettingsState((state) => [
state.fetchSettings
]);
// Server Session
const [fetchedServerSession, setFetchedServerSession] = useState(false);
useEffect(() => { useEffect(() => {
if (Object.keys(hostList).length === 0) { if (Object.keys(hostList).length === 0) {
useLocalState.setState({ hostList: defaultHostList }); useLocalState.setState({ hostList: defaultHostList });
} }
}, [hostList]);
if (isLoggedIn() && !fetchedServerSession) {
setFetchedServerSession(true);
fetchUserState();
fetchGlobalSettings();
fetchUserSettings();
}
}, [fetchedServerSession]);
return ( return (
<BaseContext> <BaseContext>

View File

@ -59,6 +59,8 @@ export const test = baseTest.extend({
if ( if (
msg.type() === 'error' && msg.type() === 'error' &&
!msg.text().startsWith('ERR: ') && !msg.text().startsWith('ERR: ') &&
url != 'http://localhost:8000/api/user/me/' &&
url != 'http://localhost:8000/api/user/token/' &&
url != 'http://localhost:8000/api/barcode/' && url != 'http://localhost:8000/api/barcode/' &&
url != 'http://localhost:8000/api/news/?search=&offset=0&limit=25' && url != 'http://localhost:8000/api/news/?search=&offset=0&limit=25' &&
url != 'https://docs.inventree.org/en/versions.json' && url != 'https://docs.inventree.org/en/versions.json' &&

View File

@ -9,7 +9,6 @@ export const doLogin = async (page, username?: string, password?: string) => {
password = password ?? user.password; password = password ?? user.password;
await page.goto(logoutUrl); await page.goto(logoutUrl);
await page.goto(loginUrl);
await expect(page).toHaveTitle(RegExp('^InvenTree.*$')); await expect(page).toHaveTitle(RegExp('^InvenTree.*$'));
await page.waitForURL('**/platform/login'); await page.waitForURL('**/platform/login');
await page.getByLabel('username').fill(username); await page.getByLabel('username').fill(username);

View File

@ -1,5 +1,5 @@
import { expect, test } from './baseFixtures.js'; import { expect, test } from './baseFixtures.js';
import { baseUrl, user } from './defaults.js'; import { baseUrl, loginUrl, user } from './defaults.js';
import { doLogin, doQuickLogin } from './login.js'; import { doLogin, doQuickLogin } from './login.js';
test('PUI - Basic Login Test', async ({ page }) => { test('PUI - Basic Login Test', async ({ page }) => {
@ -17,6 +17,22 @@ test('PUI - Basic Login Test', async ({ page }) => {
await page await page
.getByRole('heading', { name: `Welcome to your Dashboard, ${user.name}` }) .getByRole('heading', { name: `Welcome to your Dashboard, ${user.name}` })
.click(); .click();
// Check that the username is provided
await page.getByText(user.username);
await expect(page).toHaveTitle(RegExp('^InvenTree'));
// Go to the dashboard
await page.goto(baseUrl);
await page.waitForURL('**/platform');
// Logout (via menu)
await page.getByRole('button', { name: 'Ally Access' }).click();
await page.getByRole('menuitem', { name: 'Logout' }).click();
await page.waitForURL('**/platform/login');
await page.getByLabel('username');
}); });
test('PUI - Quick Login Test', async ({ page }) => { test('PUI - Quick Login Test', async ({ page }) => {
@ -34,4 +50,8 @@ test('PUI - Quick Login Test', async ({ page }) => {
await page await page
.getByRole('heading', { name: `Welcome to your Dashboard, ${user.name}` }) .getByRole('heading', { name: `Welcome to your Dashboard, ${user.name}` })
.click(); .click();
// Logout (via URL)
await page.goto(`${baseUrl}/logout/`);
await page.waitForURL('**/platform/login');
}); });

View File

@ -71,9 +71,10 @@ test('PUI - Parts - Supplier Parts', async ({ page }) => {
test('PUI - Sales', async ({ page }) => { test('PUI - Sales', async ({ page }) => {
await doQuickLogin(page); await doQuickLogin(page);
await page.goto(`${baseUrl}/sales/`); await page.goto(`${baseUrl}/sales/index/`);
await page.waitForURL('**/platform/sales/**'); await page.waitForURL('**/platform/sales/**');
await page.getByRole('tab', { name: 'Sales Orders' }).click();
await page.waitForURL('**/platform/sales/index/salesorders'); await page.waitForURL('**/platform/sales/index/salesorders');
await page.getByRole('tab', { name: 'Return Orders' }).click(); await page.getByRole('tab', { name: 'Return Orders' }).click();

View File

@ -5,8 +5,12 @@ import { doQuickLogin } from './login.js';
test('PUI - Stock', async ({ page }) => { test('PUI - Stock', async ({ page }) => {
await doQuickLogin(page); await doQuickLogin(page);
await page.goto(`${baseUrl}/stock`); await page.goto(`${baseUrl}/stock/location/index/`);
await page.waitForURL('**/platform/stock/location/**');
await page.getByRole('tab', { name: 'Location Details' }).click();
await page.waitForURL('**/platform/stock/location/index/details'); await page.waitForURL('**/platform/stock/location/index/details');
await page.getByRole('tab', { name: 'Stock Items' }).click(); await page.getByRole('tab', { name: 'Stock Items' }).click();
await page.getByRole('cell', { name: '1551ABK' }).click(); await page.getByRole('cell', { name: '1551ABK' }).click();
await page.getByRole('tab', { name: 'Stock', exact: true }).click(); await page.getByRole('tab', { name: 'Stock', exact: true }).click();