feat(sso): fix saml + allow to use public invite with sso + fix invite page with multiple sso provider (#9963)

- Fix SAML issue
- Fix the wrong state on the Invite page when multiple SSO provider
exists
- Allow to signup with SSO and public invite link
- For OIDC, use the property upn to guess email for Microsoft and enable
oidc with a specific context in azure
- Improve error in OIDC flow when email not found
This commit is contained in:
Antoine Moreaux
2025-02-03 18:48:25 +01:00
committed by GitHub
parent 253a3eb83f
commit 47487f5d1c
14 changed files with 122 additions and 92 deletions

View File

@ -612,6 +612,7 @@ export type FullName = {
export type GetAuthorizationUrlInput = { export type GetAuthorizationUrlInput = {
identityProviderId: Scalars['String']['input']; identityProviderId: Scalars['String']['input'];
workspaceInviteHash?: InputMaybe<Scalars['String']['input']>;
}; };
export type GetAuthorizationUrlOutput = { export type GetAuthorizationUrlOutput = {

View File

@ -600,6 +600,7 @@ export type FullName = {
export type GetAuthorizationUrlInput = { export type GetAuthorizationUrlInput = {
identityProviderId: Scalars['String']; identityProviderId: Scalars['String'];
workspaceInviteHash?: InputMaybe<Scalars['String']>;
}; };
export type GetAuthorizationUrlOutput = { export type GetAuthorizationUrlOutput = {

View File

@ -13,7 +13,6 @@ import {
Route, Route,
} from 'react-router-dom'; } from 'react-router-dom';
import { Authorize } from '~/pages/auth/Authorize'; import { Authorize } from '~/pages/auth/Authorize';
import { Invite } from '~/pages/auth/Invite';
import { PasswordReset } from '~/pages/auth/PasswordReset'; import { PasswordReset } from '~/pages/auth/PasswordReset';
import { SignInUp } from '~/pages/auth/SignInUp'; import { SignInUp } from '~/pages/auth/SignInUp';
import { NotFound } from '~/pages/not-found/NotFound'; import { NotFound } from '~/pages/not-found/NotFound';
@ -43,7 +42,7 @@ export const useCreateAppRouter = (
<Route path={AppPath.Verify} element={<VerifyEffect />} /> <Route path={AppPath.Verify} element={<VerifyEffect />} />
<Route path={AppPath.VerifyEmail} element={<VerifyEmailEffect />} /> <Route path={AppPath.VerifyEmail} element={<VerifyEmailEffect />} />
<Route path={AppPath.SignInUp} element={<SignInUp />} /> <Route path={AppPath.SignInUp} element={<SignInUp />} />
<Route path={AppPath.Invite} element={<Invite />} /> <Route path={AppPath.Invite} element={<SignInUp />} />
<Route path={AppPath.ResetPassword} element={<PasswordReset />} /> <Route path={AppPath.ResetPassword} element={<PasswordReset />} />
<Route path={AppPath.CreateWorkspace} element={<CreateWorkspace />} /> <Route path={AppPath.CreateWorkspace} element={<CreateWorkspace />} />
<Route path={AppPath.CreateProfile} element={<CreateProfile />} /> <Route path={AppPath.CreateProfile} element={<CreateProfile />} />

View File

@ -5,11 +5,13 @@ import { useRedirect } from '@/domain-manager/hooks/useRedirect';
import { SnackBarVariant } from '@/ui/feedback/snack-bar-manager/components/SnackBar'; import { SnackBarVariant } from '@/ui/feedback/snack-bar-manager/components/SnackBar';
import { useSnackBar } from '@/ui/feedback/snack-bar-manager/hooks/useSnackBar'; import { useSnackBar } from '@/ui/feedback/snack-bar-manager/hooks/useSnackBar';
import { useApolloClient } from '@apollo/client'; import { useApolloClient } from '@apollo/client';
import { useParams } from 'react-router-dom';
export const useSSO = () => { export const useSSO = () => {
const apolloClient = useApolloClient(); const apolloClient = useApolloClient();
const { enqueueSnackBar } = useSnackBar(); const workspaceInviteHash = useParams().workspaceInviteHash;
const { enqueueSnackBar } = useSnackBar();
const { redirect } = useRedirect(); const { redirect } = useRedirect();
const redirectToSSOLoginPage = async (identityProviderId: string) => { const redirectToSSOLoginPage = async (identityProviderId: string) => {
@ -18,7 +20,7 @@ export const useSSO = () => {
authorizationUrlForSSOResult = await apolloClient.mutate({ authorizationUrlForSSOResult = await apolloClient.mutate({
mutation: GET_AUTHORIZATION_URL, mutation: GET_AUTHORIZATION_URL,
variables: { variables: {
input: { identityProviderId }, input: { identityProviderId, workspaceInviteHash },
}, },
}); });
} catch (error: any) { } catch (error: any) {

View File

@ -24,6 +24,7 @@ export const useWorkspaceFromInviteHash = () => {
); );
const { data: workspaceFromInviteHash, loading } = const { data: workspaceFromInviteHash, loading } =
useGetWorkspaceFromInviteHashQuery({ useGetWorkspaceFromInviteHashQuery({
skip: !workspaceInviteHash,
variables: { inviteHash: workspaceInviteHash || '' }, variables: { inviteHash: workspaceInviteHash || '' },
onError: (error) => { onError: (error) => {
enqueueSnackBar(error.message, { enqueueSnackBar(error.message, {

View File

@ -1,27 +0,0 @@
import { Logo } from '@/auth/components/Logo';
import { Title } from '@/auth/components/Title';
import { SignInUpWorkspaceScopeForm } from '@/auth/sign-in-up/components/SignInUpWorkspaceScopeForm';
import { useWorkspaceFromInviteHash } from '@/auth/sign-in-up/hooks/useWorkspaceFromInviteHash';
import { useMemo } from 'react';
import { AnimatedEaseIn } from 'twenty-ui';
import { SignInUpWorkspaceScopeFormEffect } from '@/auth/sign-in-up/components/SignInUpWorkspaceScopeFormEffect';
export const Invite = () => {
const { workspace: workspaceFromInviteHash } = useWorkspaceFromInviteHash();
const title = useMemo(() => {
return `Join ${workspaceFromInviteHash?.displayName ?? ''} team`;
}, [workspaceFromInviteHash?.displayName]);
return (
<>
<AnimatedEaseIn>
<Logo secondaryLogo={workspaceFromInviteHash?.logo} />
</AnimatedEaseIn>
<Title animate>{title}</Title>
<SignInUpWorkspaceScopeFormEffect />
<SignInUpWorkspaceScopeForm />
</>
);
};

View File

@ -23,29 +23,25 @@ import { AnimatedEaseIn } from 'twenty-ui';
import { useSearchParams } from 'react-router-dom'; import { useSearchParams } from 'react-router-dom';
import { PublicWorkspaceDataOutput } from '~/generated-metadata/graphql'; import { PublicWorkspaceDataOutput } from '~/generated-metadata/graphql';
import { useWorkspaceFromInviteHash } from '@/auth/sign-in-up/hooks/useWorkspaceFromInviteHash';
const StandardContent = ({ const StandardContent = ({
workspacePublicData, workspacePublicData,
signInUpForm, signInUpForm,
signInUpStep, signInUpStep,
title,
}: { }: {
workspacePublicData: PublicWorkspaceDataOutput | null; workspacePublicData: PublicWorkspaceDataOutput | null;
signInUpForm: JSX.Element | null; signInUpForm: JSX.Element | null;
signInUpStep: SignInUpStep; signInUpStep: SignInUpStep;
title: string;
}) => { }) => {
return ( return (
<> <>
<AnimatedEaseIn> <AnimatedEaseIn>
<Logo secondaryLogo={workspacePublicData?.logo} /> <Logo secondaryLogo={workspacePublicData?.logo} />
</AnimatedEaseIn> </AnimatedEaseIn>
<Title animate> <Title animate>{title}</Title>
Welcome to{' '}
{!isDefined(workspacePublicData?.displayName)
? DEFAULT_WORKSPACE_NAME
: workspacePublicData?.displayName === ''
? 'Your Workspace'
: workspacePublicData?.displayName}
</Title>
{signInUpForm} {signInUpForm}
{signInUpStep !== SignInUpStep.Password && <FooterNote />} {signInUpStep !== SignInUpStep.Password && <FooterNote />}
</> </>
@ -61,9 +57,29 @@ export const SignInUp = () => {
const workspacePublicData = useRecoilValue(workspacePublicDataState); const workspacePublicData = useRecoilValue(workspacePublicDataState);
const { loading } = useGetPublicWorkspaceDataBySubdomain(); const { loading } = useGetPublicWorkspaceDataBySubdomain();
const isMultiWorkspaceEnabled = useRecoilValue(isMultiWorkspaceEnabledState); const isMultiWorkspaceEnabled = useRecoilValue(isMultiWorkspaceEnabledState);
const { workspaceInviteHash, workspace: workspaceFromInviteHash } =
useWorkspaceFromInviteHash();
const [searchParams] = useSearchParams(); const [searchParams] = useSearchParams();
const title = useMemo(() => {
if (isDefined(workspaceInviteHash)) {
return `Join ${workspaceFromInviteHash?.displayName ?? ''} team`;
}
return `Welcome to ${
!isDefined(workspacePublicData?.displayName)
? DEFAULT_WORKSPACE_NAME
: workspacePublicData?.displayName === ''
? 'Your Workspace'
: workspacePublicData?.displayName
}`;
}, [
workspaceFromInviteHash?.displayName,
workspaceInviteHash,
workspacePublicData?.displayName,
]);
const signInUpForm = useMemo(() => { const signInUpForm = useMemo(() => {
if (loading) return null; if (loading) return null;
@ -110,6 +126,7 @@ export const SignInUp = () => {
workspacePublicData={workspacePublicData} workspacePublicData={workspacePublicData}
signInUpForm={signInUpForm} signInUpForm={signInUpForm}
signInUpStep={signInUpStep} signInUpStep={signInUpStep}
title={title}
/> />
); );
}; };

View File

@ -12,11 +12,11 @@ import {
import { graphqlMocks } from '~/testing/graphqlMocks'; import { graphqlMocks } from '~/testing/graphqlMocks';
import { AppPath } from '@/types/AppPath'; import { AppPath } from '@/types/AppPath';
import { Invite } from '../Invite'; import { SignInUp } from '../SignInUp';
const meta: Meta<PageDecoratorArgs> = { const meta: Meta<PageDecoratorArgs> = {
title: 'Pages/Auth/Invite', title: 'Pages/Auth/Invite',
component: Invite, component: SignInUp,
decorators: [PageDecorator], decorators: [PageDecorator],
args: { args: {
routePath: AppPath.Invite, routePath: AppPath.Invite,
@ -67,7 +67,7 @@ const meta: Meta<PageDecoratorArgs> = {
export default meta; export default meta;
export type Story = StoryObj<typeof Invite>; export type Story = StoryObj<typeof SignInUp>;
export const Default: Story = { export const Default: Story = {
play: async ({ canvasElement }) => { play: async ({ canvasElement }) => {

View File

@ -36,6 +36,8 @@ import { EnvironmentService } from 'src/engine/core-modules/environment/environm
import { GuardRedirectService } from 'src/engine/core-modules/guard-redirect/services/guard-redirect.service'; import { GuardRedirectService } from 'src/engine/core-modules/guard-redirect/services/guard-redirect.service';
import { SAMLRequest } from 'src/engine/core-modules/auth/strategies/saml.auth.strategy'; import { SAMLRequest } from 'src/engine/core-modules/auth/strategies/saml.auth.strategy';
import { OIDCRequest } from 'src/engine/core-modules/auth/strategies/oidc.auth.strategy'; import { OIDCRequest } from 'src/engine/core-modules/auth/strategies/oidc.auth.strategy';
import { Workspace } from 'src/engine/core-modules/workspace/workspace.entity';
import { workspaceValidator } from 'src/engine/core-modules/workspace/workspace.validate';
@Controller('auth') @Controller('auth')
export class SSOAuthController { export class SSOAuthController {
@ -107,9 +109,10 @@ export class SSOAuthController {
private async authCallback(req: OIDCRequest | SAMLRequest, res: Response) { private async authCallback(req: OIDCRequest | SAMLRequest, res: Response) {
const workspaceIdentityProvider = const workspaceIdentityProvider =
await this.findWorkspaceIdentityProviderByIdentityProviderId( await this.workspaceSSOIdentityProviderRepository.findOne({
req.user.identityProviderId, where: { id: req.user.identityProviderId },
); relations: ['workspace'],
});
try { try {
if (!workspaceIdentityProvider) { if (!workspaceIdentityProvider) {
@ -126,15 +129,30 @@ export class SSOAuthController {
); );
} }
const { loginToken, identityProvider } = await this.generateLoginToken( const currentWorkspace = await this.authService.findWorkspaceForSignInUp({
workspaceId: workspaceIdentityProvider.workspaceId,
workspaceInviteHash: req.user.workspaceInviteHash,
email: req.user.email,
authProvider: 'sso',
});
workspaceValidator.assertIsDefinedOrThrow(
currentWorkspace,
new AuthException(
'Workspace not found',
AuthExceptionCode.OAUTH_ACCESS_DENIED,
),
);
const { loginToken } = await this.generateLoginToken(
req.user, req.user,
workspaceIdentityProvider, currentWorkspace,
); );
return res.redirect( return res.redirect(
this.authService.computeRedirectURI({ this.authService.computeRedirectURI({
loginToken: loginToken.token, loginToken: loginToken.token,
subdomain: identityProvider.workspace.subdomain, subdomain: currentWorkspace.subdomain,
}), }),
); );
} catch (err) { } catch (err) {
@ -149,33 +167,16 @@ export class SSOAuthController {
} }
} }
private async findWorkspaceIdentityProviderByIdentityProviderId(
identityProviderId: string,
) {
return await this.workspaceSSOIdentityProviderRepository.findOne({
where: { id: identityProviderId },
relations: ['workspace'],
});
}
private async generateLoginToken( private async generateLoginToken(
payload: { email: string }, payload: { email: string; workspaceInviteHash?: string },
identityProvider: WorkspaceSSOIdentityProvider, currentWorkspace: Workspace,
) { ) {
if (!identityProvider) { const invitation = payload.email
throw new AuthException( ? await this.authService.findInvitationForSignInUp({
'Identity provider not found', currentWorkspace,
AuthExceptionCode.INVALID_DATA, email: payload.email,
); })
} : undefined;
const invitation =
payload.email && identityProvider.workspace
? await this.authService.findInvitationForSignInUp({
currentWorkspace: identityProvider.workspace,
email: payload.email,
})
: undefined;
const existingUser = await this.userRepository.findOne({ const existingUser = await this.userRepository.findOne({
where: { where: {
@ -191,12 +192,13 @@ export class SSOAuthController {
await this.authService.checkAccessForSignIn({ await this.authService.checkAccessForSignIn({
userData, userData,
invitation, invitation,
workspace: identityProvider.workspace, workspaceInviteHash: payload.workspaceInviteHash,
workspace: currentWorkspace,
}); });
const { workspace, user } = await this.authService.signInUp({ const { workspace, user } = await this.authService.signInUp({
userData, userData,
workspace: identityProvider.workspace, workspace: currentWorkspace,
invitation, invitation,
authParams: { authParams: {
provider: 'sso', provider: 'sso',
@ -204,7 +206,7 @@ export class SSOAuthController {
}); });
return { return {
identityProvider, workspace,
loginToken: await this.loginTokenService.generateLoginToken( loginToken: await this.loginTokenService.generateLoginToken(
user.email, user.email,
workspace.id, workspace.id,

View File

@ -3,7 +3,6 @@
import { Injectable } from '@nestjs/common'; import { Injectable } from '@nestjs/common';
import { PassportStrategy } from '@nestjs/passport'; import { PassportStrategy } from '@nestjs/passport';
import { isEmail } from 'class-validator';
import { Request } from 'express'; import { Request } from 'express';
import { Strategy, StrategyOptions, TokenSet } from 'openid-client'; import { Strategy, StrategyOptions, TokenSet } from 'openid-client';
@ -21,6 +20,7 @@ export type OIDCRequest = Omit<
email: string; email: string;
firstName?: string | null; firstName?: string | null;
lastName?: string | null; lastName?: string | null;
workspaceInviteHash?: string;
}; };
}; };
@ -50,12 +50,17 @@ export class OIDCAuthStrategy extends PassportStrategy(
...options, ...options,
state: JSON.stringify({ state: JSON.stringify({
identityProviderId: req.params.identityProviderId, identityProviderId: req.params.identityProviderId,
...(req.query.forceSubdomainUrl ? { forceSubdomainUrl: true } : {}),
...(req.query.workspaceInviteHash
? { workspaceInviteHash: req.query.workspaceInviteHash }
: {}),
}), }),
}); });
} }
private extractState(req: Request): { private extractState(req: Request): {
identityProviderId: string; identityProviderId: string;
workspaceInviteHash?: string;
} { } {
try { try {
const state = JSON.parse( const state = JSON.parse(
@ -70,6 +75,7 @@ export class OIDCAuthStrategy extends PassportStrategy(
return { return {
identityProviderId: state.identityProviderId, identityProviderId: state.identityProviderId,
workspaceInviteHash: state.workspaceInviteHash,
}; };
} catch (err) { } catch (err) {
throw new AuthException('Invalid state', AuthExceptionCode.INVALID_INPUT); throw new AuthException('Invalid state', AuthExceptionCode.INVALID_INPUT);
@ -86,12 +92,20 @@ export class OIDCAuthStrategy extends PassportStrategy(
const userinfo = await this.client.userinfo(tokenset); const userinfo = await this.client.userinfo(tokenset);
if (!userinfo.email || !isEmail(userinfo.email)) { const email = userinfo.email ?? userinfo.upn;
return done(new Error('Invalid email'));
if (!email || typeof email !== 'string') {
return done(
new AuthException(
'Email not found in identity provider payload',
AuthExceptionCode.INVALID_DATA,
),
);
} }
done(null, { done(null, {
email: userinfo.email, email,
workspaceInviteHash: state.workspaceInviteHash,
identityProviderId: state.identityProviderId, identityProviderId: state.identityProviderId,
...(userinfo.given_name ? { firstName: userinfo.given_name } : {}), ...(userinfo.given_name ? { firstName: userinfo.given_name } : {}),
...(userinfo.family_name ? { lastName: userinfo.family_name } : {}), ...(userinfo.family_name ? { lastName: userinfo.family_name } : {}),

View File

@ -26,6 +26,7 @@ export type SAMLRequest = Omit<
> & { > & {
user: { user: {
identityProviderId: string; identityProviderId: string;
workspaceInviteHash?: string;
email: string; email: string;
}; };
}; };
@ -78,6 +79,9 @@ export class SamlAuthStrategy extends PassportStrategy(
additionalParams: { additionalParams: {
RelayState: JSON.stringify({ RelayState: JSON.stringify({
identityProviderId: req.params.identityProviderId, identityProviderId: req.params.identityProviderId,
...(req.query.workspaceInviteHash
? { workspaceInviteHash: req.query.workspaceInviteHash }
: {}),
}), }),
}, },
}); });
@ -85,6 +89,7 @@ export class SamlAuthStrategy extends PassportStrategy(
private extractState(req: Request): { private extractState(req: Request): {
identityProviderId: string; identityProviderId: string;
workspaceInviteHash?: string;
} { } {
try { try {
if ('RelayState' in req.body && typeof req.body.RelayState === 'string') { if ('RelayState' in req.body && typeof req.body.RelayState === 'string') {
@ -92,6 +97,7 @@ export class SamlAuthStrategy extends PassportStrategy(
return { return {
identityProviderId: RelayState.identityProviderId, identityProviderId: RelayState.identityProviderId,
workspaceInviteHash: RelayState.workspaceInviteHash,
}; };
} }
@ -114,11 +120,7 @@ export class SamlAuthStrategy extends PassportStrategy(
} }
const state = this.extractState(request); const state = this.extractState(request);
const result: Pick<SAMLRequest, 'user'> = { done(null, { ...state, email });
user: { ...state, email },
};
done(null, result);
} catch (err) { } catch (err) {
done(err); done(err);
} }

View File

@ -2,11 +2,16 @@
import { Field, InputType } from '@nestjs/graphql'; import { Field, InputType } from '@nestjs/graphql';
import { IsString } from 'class-validator'; import { IsOptional, IsString } from 'class-validator';
@InputType() @InputType()
export class GetAuthorizationUrlInput { export class GetAuthorizationUrlInput {
@Field(() => String) @Field(() => String)
@IsString() @IsString()
identityProviderId: string; identityProviderId: string;
@Field(() => String, { nullable: true })
@IsString()
@IsOptional()
workspaceInviteHash?: string;
} }

View File

@ -152,11 +152,18 @@ export class SSOService {
buildIssuerURL( buildIssuerURL(
identityProvider: Pick<WorkspaceSSOIdentityProvider, 'id' | 'type'>, identityProvider: Pick<WorkspaceSSOIdentityProvider, 'id' | 'type'>,
searchParams?: Record<string, string | boolean>,
) { ) {
const authorizationUrl = new URL(this.environmentService.get('SERVER_URL')); const authorizationUrl = new URL(this.environmentService.get('SERVER_URL'));
authorizationUrl.pathname = `/auth/${identityProvider.type.toLowerCase()}/login/${identityProvider.id}`; authorizationUrl.pathname = `/auth/${identityProvider.type.toLowerCase()}/login/${identityProvider.id}`;
if (searchParams) {
Object.entries(searchParams).forEach(([key, value]) => {
authorizationUrl.searchParams.append(key, value.toString());
});
}
return authorizationUrl.toString(); return authorizationUrl.toString();
} }
@ -191,7 +198,10 @@ export class SSOService {
}); });
} }
async getAuthorizationUrl(identityProviderId: string) { async getAuthorizationUrl(
identityProviderId: string,
searchParams: Record<string, string | boolean>,
) {
const identityProvider = const identityProvider =
(await this.workspaceSSOIdentityProviderRepository.findOne({ (await this.workspaceSSOIdentityProviderRepository.findOne({
where: { where: {
@ -208,7 +218,7 @@ export class SSOService {
return { return {
id: identityProvider.id, id: identityProvider.id,
authorizationURL: this.buildIssuerURL(identityProvider), authorizationURL: this.buildIssuerURL(identityProvider, searchParams),
type: identityProvider.type, type: identityProvider.type,
}; };
} }

View File

@ -3,6 +3,8 @@
import { UseGuards } from '@nestjs/common'; import { UseGuards } from '@nestjs/common';
import { Args, Mutation, Query, Resolver } from '@nestjs/graphql'; import { Args, Mutation, Query, Resolver } from '@nestjs/graphql';
import omit from 'lodash.omit';
import { EnterpriseFeaturesEnabledGuard } from 'src/engine/core-modules/auth/guards/enterprise-features-enabled.guard'; import { EnterpriseFeaturesEnabledGuard } from 'src/engine/core-modules/auth/guards/enterprise-features-enabled.guard';
import { DeleteSsoInput } from 'src/engine/core-modules/sso/dtos/delete-sso.input'; import { DeleteSsoInput } from 'src/engine/core-modules/sso/dtos/delete-sso.input';
import { DeleteSsoOutput } from 'src/engine/core-modules/sso/dtos/delete-sso.output'; import { DeleteSsoOutput } from 'src/engine/core-modules/sso/dtos/delete-sso.output';
@ -47,10 +49,11 @@ export class SSOResolver {
} }
@Mutation(() => GetAuthorizationUrlOutput) @Mutation(() => GetAuthorizationUrlOutput)
async getAuthorizationUrl( async getAuthorizationUrl(@Args('input') params: GetAuthorizationUrlInput) {
@Args('input') { identityProviderId }: GetAuthorizationUrlInput, return await this.sSOService.getAuthorizationUrl(
) { params.identityProviderId,
return this.sSOService.getAuthorizationUrl(identityProviderId); omit(params, ['identityProviderId']),
);
} }
@UseGuards(WorkspaceAuthGuard, EnterpriseFeaturesEnabledGuard) @UseGuards(WorkspaceAuthGuard, EnterpriseFeaturesEnabledGuard)