diff --git a/packages/twenty-front/src/generated-metadata/graphql.ts b/packages/twenty-front/src/generated-metadata/graphql.ts index 4187b8906..a0a2e2b97 100644 --- a/packages/twenty-front/src/generated-metadata/graphql.ts +++ b/packages/twenty-front/src/generated-metadata/graphql.ts @@ -612,6 +612,7 @@ export type FullName = { export type GetAuthorizationUrlInput = { identityProviderId: Scalars['String']['input']; + workspaceInviteHash?: InputMaybe; }; export type GetAuthorizationUrlOutput = { diff --git a/packages/twenty-front/src/generated/graphql.tsx b/packages/twenty-front/src/generated/graphql.tsx index cf4d5acbf..60947d6e7 100644 --- a/packages/twenty-front/src/generated/graphql.tsx +++ b/packages/twenty-front/src/generated/graphql.tsx @@ -600,6 +600,7 @@ export type FullName = { export type GetAuthorizationUrlInput = { identityProviderId: Scalars['String']; + workspaceInviteHash?: InputMaybe; }; export type GetAuthorizationUrlOutput = { diff --git a/packages/twenty-front/src/modules/app/hooks/useCreateAppRouter.tsx b/packages/twenty-front/src/modules/app/hooks/useCreateAppRouter.tsx index d7f11bac8..74ac80d76 100644 --- a/packages/twenty-front/src/modules/app/hooks/useCreateAppRouter.tsx +++ b/packages/twenty-front/src/modules/app/hooks/useCreateAppRouter.tsx @@ -13,7 +13,6 @@ import { Route, } from 'react-router-dom'; import { Authorize } from '~/pages/auth/Authorize'; -import { Invite } from '~/pages/auth/Invite'; import { PasswordReset } from '~/pages/auth/PasswordReset'; import { SignInUp } from '~/pages/auth/SignInUp'; import { NotFound } from '~/pages/not-found/NotFound'; @@ -43,7 +42,7 @@ export const useCreateAppRouter = ( } /> } /> } /> - } /> + } /> } /> } /> } /> diff --git a/packages/twenty-front/src/modules/auth/sign-in-up/hooks/useSSO.ts b/packages/twenty-front/src/modules/auth/sign-in-up/hooks/useSSO.ts index 3b3a3405c..abe014412 100644 --- a/packages/twenty-front/src/modules/auth/sign-in-up/hooks/useSSO.ts +++ b/packages/twenty-front/src/modules/auth/sign-in-up/hooks/useSSO.ts @@ -5,11 +5,13 @@ import { useRedirect } from '@/domain-manager/hooks/useRedirect'; import { SnackBarVariant } from '@/ui/feedback/snack-bar-manager/components/SnackBar'; import { useSnackBar } from '@/ui/feedback/snack-bar-manager/hooks/useSnackBar'; import { useApolloClient } from '@apollo/client'; +import { useParams } from 'react-router-dom'; export const useSSO = () => { const apolloClient = useApolloClient(); - const { enqueueSnackBar } = useSnackBar(); + const workspaceInviteHash = useParams().workspaceInviteHash; + const { enqueueSnackBar } = useSnackBar(); const { redirect } = useRedirect(); const redirectToSSOLoginPage = async (identityProviderId: string) => { @@ -18,7 +20,7 @@ export const useSSO = () => { authorizationUrlForSSOResult = await apolloClient.mutate({ mutation: GET_AUTHORIZATION_URL, variables: { - input: { identityProviderId }, + input: { identityProviderId, workspaceInviteHash }, }, }); } catch (error: any) { diff --git a/packages/twenty-front/src/modules/auth/sign-in-up/hooks/useWorkspaceFromInviteHash.ts b/packages/twenty-front/src/modules/auth/sign-in-up/hooks/useWorkspaceFromInviteHash.ts index 209618eff..dbc5b011a 100644 --- a/packages/twenty-front/src/modules/auth/sign-in-up/hooks/useWorkspaceFromInviteHash.ts +++ b/packages/twenty-front/src/modules/auth/sign-in-up/hooks/useWorkspaceFromInviteHash.ts @@ -24,6 +24,7 @@ export const useWorkspaceFromInviteHash = () => { ); const { data: workspaceFromInviteHash, loading } = useGetWorkspaceFromInviteHashQuery({ + skip: !workspaceInviteHash, variables: { inviteHash: workspaceInviteHash || '' }, onError: (error) => { enqueueSnackBar(error.message, { diff --git a/packages/twenty-front/src/pages/auth/Invite.tsx b/packages/twenty-front/src/pages/auth/Invite.tsx deleted file mode 100644 index 86a449b96..000000000 --- a/packages/twenty-front/src/pages/auth/Invite.tsx +++ /dev/null @@ -1,27 +0,0 @@ -import { Logo } from '@/auth/components/Logo'; -import { Title } from '@/auth/components/Title'; -import { SignInUpWorkspaceScopeForm } from '@/auth/sign-in-up/components/SignInUpWorkspaceScopeForm'; -import { useWorkspaceFromInviteHash } from '@/auth/sign-in-up/hooks/useWorkspaceFromInviteHash'; -import { useMemo } from 'react'; -import { AnimatedEaseIn } from 'twenty-ui'; - -import { SignInUpWorkspaceScopeFormEffect } from '@/auth/sign-in-up/components/SignInUpWorkspaceScopeFormEffect'; - -export const Invite = () => { - const { workspace: workspaceFromInviteHash } = useWorkspaceFromInviteHash(); - - const title = useMemo(() => { - return `Join ${workspaceFromInviteHash?.displayName ?? ''} team`; - }, [workspaceFromInviteHash?.displayName]); - - return ( - <> - - - - {title} - - - - ); -}; diff --git a/packages/twenty-front/src/pages/auth/SignInUp.tsx b/packages/twenty-front/src/pages/auth/SignInUp.tsx index 41cf12971..82b8302fd 100644 --- a/packages/twenty-front/src/pages/auth/SignInUp.tsx +++ b/packages/twenty-front/src/pages/auth/SignInUp.tsx @@ -23,29 +23,25 @@ import { AnimatedEaseIn } from 'twenty-ui'; import { useSearchParams } from 'react-router-dom'; import { PublicWorkspaceDataOutput } from '~/generated-metadata/graphql'; +import { useWorkspaceFromInviteHash } from '@/auth/sign-in-up/hooks/useWorkspaceFromInviteHash'; const StandardContent = ({ workspacePublicData, signInUpForm, signInUpStep, + title, }: { workspacePublicData: PublicWorkspaceDataOutput | null; signInUpForm: JSX.Element | null; signInUpStep: SignInUpStep; + title: string; }) => { return ( <> - - Welcome to{' '} - {!isDefined(workspacePublicData?.displayName) - ? DEFAULT_WORKSPACE_NAME - : workspacePublicData?.displayName === '' - ? 'Your Workspace' - : workspacePublicData?.displayName} - + {title} {signInUpForm} {signInUpStep !== SignInUpStep.Password && } @@ -61,9 +57,29 @@ export const SignInUp = () => { const workspacePublicData = useRecoilValue(workspacePublicDataState); const { loading } = useGetPublicWorkspaceDataBySubdomain(); const isMultiWorkspaceEnabled = useRecoilValue(isMultiWorkspaceEnabledState); + const { workspaceInviteHash, workspace: workspaceFromInviteHash } = + useWorkspaceFromInviteHash(); const [searchParams] = useSearchParams(); + const title = useMemo(() => { + if (isDefined(workspaceInviteHash)) { + return `Join ${workspaceFromInviteHash?.displayName ?? ''} team`; + } + + return `Welcome to ${ + !isDefined(workspacePublicData?.displayName) + ? DEFAULT_WORKSPACE_NAME + : workspacePublicData?.displayName === '' + ? 'Your Workspace' + : workspacePublicData?.displayName + }`; + }, [ + workspaceFromInviteHash?.displayName, + workspaceInviteHash, + workspacePublicData?.displayName, + ]); + const signInUpForm = useMemo(() => { if (loading) return null; @@ -110,6 +126,7 @@ export const SignInUp = () => { workspacePublicData={workspacePublicData} signInUpForm={signInUpForm} signInUpStep={signInUpStep} + title={title} /> ); }; diff --git a/packages/twenty-front/src/pages/auth/__stories__/Invite.stories.tsx b/packages/twenty-front/src/pages/auth/__stories__/SignInUpWithInvite.stories.tsx similarity index 95% rename from packages/twenty-front/src/pages/auth/__stories__/Invite.stories.tsx rename to packages/twenty-front/src/pages/auth/__stories__/SignInUpWithInvite.stories.tsx index 85665e8be..3a510c881 100644 --- a/packages/twenty-front/src/pages/auth/__stories__/Invite.stories.tsx +++ b/packages/twenty-front/src/pages/auth/__stories__/SignInUpWithInvite.stories.tsx @@ -12,11 +12,11 @@ import { import { graphqlMocks } from '~/testing/graphqlMocks'; import { AppPath } from '@/types/AppPath'; -import { Invite } from '../Invite'; +import { SignInUp } from '../SignInUp'; const meta: Meta = { title: 'Pages/Auth/Invite', - component: Invite, + component: SignInUp, decorators: [PageDecorator], args: { routePath: AppPath.Invite, @@ -67,7 +67,7 @@ const meta: Meta = { export default meta; -export type Story = StoryObj; +export type Story = StoryObj; export const Default: Story = { play: async ({ canvasElement }) => { diff --git a/packages/twenty-server/src/engine/core-modules/auth/controllers/sso-auth.controller.ts b/packages/twenty-server/src/engine/core-modules/auth/controllers/sso-auth.controller.ts index 002fc8594..b791b21d5 100644 --- a/packages/twenty-server/src/engine/core-modules/auth/controllers/sso-auth.controller.ts +++ b/packages/twenty-server/src/engine/core-modules/auth/controllers/sso-auth.controller.ts @@ -36,6 +36,8 @@ import { EnvironmentService } from 'src/engine/core-modules/environment/environm import { GuardRedirectService } from 'src/engine/core-modules/guard-redirect/services/guard-redirect.service'; import { SAMLRequest } from 'src/engine/core-modules/auth/strategies/saml.auth.strategy'; import { OIDCRequest } from 'src/engine/core-modules/auth/strategies/oidc.auth.strategy'; +import { Workspace } from 'src/engine/core-modules/workspace/workspace.entity'; +import { workspaceValidator } from 'src/engine/core-modules/workspace/workspace.validate'; @Controller('auth') export class SSOAuthController { @@ -107,9 +109,10 @@ export class SSOAuthController { private async authCallback(req: OIDCRequest | SAMLRequest, res: Response) { const workspaceIdentityProvider = - await this.findWorkspaceIdentityProviderByIdentityProviderId( - req.user.identityProviderId, - ); + await this.workspaceSSOIdentityProviderRepository.findOne({ + where: { id: req.user.identityProviderId }, + relations: ['workspace'], + }); try { if (!workspaceIdentityProvider) { @@ -126,15 +129,30 @@ export class SSOAuthController { ); } - const { loginToken, identityProvider } = await this.generateLoginToken( + const currentWorkspace = await this.authService.findWorkspaceForSignInUp({ + workspaceId: workspaceIdentityProvider.workspaceId, + workspaceInviteHash: req.user.workspaceInviteHash, + email: req.user.email, + authProvider: 'sso', + }); + + workspaceValidator.assertIsDefinedOrThrow( + currentWorkspace, + new AuthException( + 'Workspace not found', + AuthExceptionCode.OAUTH_ACCESS_DENIED, + ), + ); + + const { loginToken } = await this.generateLoginToken( req.user, - workspaceIdentityProvider, + currentWorkspace, ); return res.redirect( this.authService.computeRedirectURI({ loginToken: loginToken.token, - subdomain: identityProvider.workspace.subdomain, + subdomain: currentWorkspace.subdomain, }), ); } catch (err) { @@ -149,33 +167,16 @@ export class SSOAuthController { } } - private async findWorkspaceIdentityProviderByIdentityProviderId( - identityProviderId: string, - ) { - return await this.workspaceSSOIdentityProviderRepository.findOne({ - where: { id: identityProviderId }, - relations: ['workspace'], - }); - } - private async generateLoginToken( - payload: { email: string }, - identityProvider: WorkspaceSSOIdentityProvider, + payload: { email: string; workspaceInviteHash?: string }, + currentWorkspace: Workspace, ) { - if (!identityProvider) { - throw new AuthException( - 'Identity provider not found', - AuthExceptionCode.INVALID_DATA, - ); - } - - const invitation = - payload.email && identityProvider.workspace - ? await this.authService.findInvitationForSignInUp({ - currentWorkspace: identityProvider.workspace, - email: payload.email, - }) - : undefined; + const invitation = payload.email + ? await this.authService.findInvitationForSignInUp({ + currentWorkspace, + email: payload.email, + }) + : undefined; const existingUser = await this.userRepository.findOne({ where: { @@ -191,12 +192,13 @@ export class SSOAuthController { await this.authService.checkAccessForSignIn({ userData, invitation, - workspace: identityProvider.workspace, + workspaceInviteHash: payload.workspaceInviteHash, + workspace: currentWorkspace, }); const { workspace, user } = await this.authService.signInUp({ userData, - workspace: identityProvider.workspace, + workspace: currentWorkspace, invitation, authParams: { provider: 'sso', @@ -204,7 +206,7 @@ export class SSOAuthController { }); return { - identityProvider, + workspace, loginToken: await this.loginTokenService.generateLoginToken( user.email, workspace.id, diff --git a/packages/twenty-server/src/engine/core-modules/auth/strategies/oidc.auth.strategy.ts b/packages/twenty-server/src/engine/core-modules/auth/strategies/oidc.auth.strategy.ts index c57623496..42f37fd32 100644 --- a/packages/twenty-server/src/engine/core-modules/auth/strategies/oidc.auth.strategy.ts +++ b/packages/twenty-server/src/engine/core-modules/auth/strategies/oidc.auth.strategy.ts @@ -3,7 +3,6 @@ import { Injectable } from '@nestjs/common'; import { PassportStrategy } from '@nestjs/passport'; -import { isEmail } from 'class-validator'; import { Request } from 'express'; import { Strategy, StrategyOptions, TokenSet } from 'openid-client'; @@ -21,6 +20,7 @@ export type OIDCRequest = Omit< email: string; firstName?: string | null; lastName?: string | null; + workspaceInviteHash?: string; }; }; @@ -50,12 +50,17 @@ export class OIDCAuthStrategy extends PassportStrategy( ...options, state: JSON.stringify({ identityProviderId: req.params.identityProviderId, + ...(req.query.forceSubdomainUrl ? { forceSubdomainUrl: true } : {}), + ...(req.query.workspaceInviteHash + ? { workspaceInviteHash: req.query.workspaceInviteHash } + : {}), }), }); } private extractState(req: Request): { identityProviderId: string; + workspaceInviteHash?: string; } { try { const state = JSON.parse( @@ -70,6 +75,7 @@ export class OIDCAuthStrategy extends PassportStrategy( return { identityProviderId: state.identityProviderId, + workspaceInviteHash: state.workspaceInviteHash, }; } catch (err) { throw new AuthException('Invalid state', AuthExceptionCode.INVALID_INPUT); @@ -86,12 +92,20 @@ export class OIDCAuthStrategy extends PassportStrategy( const userinfo = await this.client.userinfo(tokenset); - if (!userinfo.email || !isEmail(userinfo.email)) { - return done(new Error('Invalid email')); + const email = userinfo.email ?? userinfo.upn; + + if (!email || typeof email !== 'string') { + return done( + new AuthException( + 'Email not found in identity provider payload', + AuthExceptionCode.INVALID_DATA, + ), + ); } done(null, { - email: userinfo.email, + email, + workspaceInviteHash: state.workspaceInviteHash, identityProviderId: state.identityProviderId, ...(userinfo.given_name ? { firstName: userinfo.given_name } : {}), ...(userinfo.family_name ? { lastName: userinfo.family_name } : {}), diff --git a/packages/twenty-server/src/engine/core-modules/auth/strategies/saml.auth.strategy.ts b/packages/twenty-server/src/engine/core-modules/auth/strategies/saml.auth.strategy.ts index c3c9e5e7e..4dff257d8 100644 --- a/packages/twenty-server/src/engine/core-modules/auth/strategies/saml.auth.strategy.ts +++ b/packages/twenty-server/src/engine/core-modules/auth/strategies/saml.auth.strategy.ts @@ -26,6 +26,7 @@ export type SAMLRequest = Omit< > & { user: { identityProviderId: string; + workspaceInviteHash?: string; email: string; }; }; @@ -78,6 +79,9 @@ export class SamlAuthStrategy extends PassportStrategy( additionalParams: { RelayState: JSON.stringify({ identityProviderId: req.params.identityProviderId, + ...(req.query.workspaceInviteHash + ? { workspaceInviteHash: req.query.workspaceInviteHash } + : {}), }), }, }); @@ -85,6 +89,7 @@ export class SamlAuthStrategy extends PassportStrategy( private extractState(req: Request): { identityProviderId: string; + workspaceInviteHash?: string; } { try { if ('RelayState' in req.body && typeof req.body.RelayState === 'string') { @@ -92,6 +97,7 @@ export class SamlAuthStrategy extends PassportStrategy( return { identityProviderId: RelayState.identityProviderId, + workspaceInviteHash: RelayState.workspaceInviteHash, }; } @@ -114,11 +120,7 @@ export class SamlAuthStrategy extends PassportStrategy( } const state = this.extractState(request); - const result: Pick = { - user: { ...state, email }, - }; - - done(null, result); + done(null, { ...state, email }); } catch (err) { done(err); } diff --git a/packages/twenty-server/src/engine/core-modules/sso/dtos/get-authorization-url.input.ts b/packages/twenty-server/src/engine/core-modules/sso/dtos/get-authorization-url.input.ts index e0adc9645..e627a4a7e 100644 --- a/packages/twenty-server/src/engine/core-modules/sso/dtos/get-authorization-url.input.ts +++ b/packages/twenty-server/src/engine/core-modules/sso/dtos/get-authorization-url.input.ts @@ -2,11 +2,16 @@ import { Field, InputType } from '@nestjs/graphql'; -import { IsString } from 'class-validator'; +import { IsOptional, IsString } from 'class-validator'; @InputType() export class GetAuthorizationUrlInput { @Field(() => String) @IsString() identityProviderId: string; + + @Field(() => String, { nullable: true }) + @IsString() + @IsOptional() + workspaceInviteHash?: string; } diff --git a/packages/twenty-server/src/engine/core-modules/sso/services/sso.service.ts b/packages/twenty-server/src/engine/core-modules/sso/services/sso.service.ts index 4d3185210..881824581 100644 --- a/packages/twenty-server/src/engine/core-modules/sso/services/sso.service.ts +++ b/packages/twenty-server/src/engine/core-modules/sso/services/sso.service.ts @@ -152,11 +152,18 @@ export class SSOService { buildIssuerURL( identityProvider: Pick, + searchParams?: Record, ) { const authorizationUrl = new URL(this.environmentService.get('SERVER_URL')); authorizationUrl.pathname = `/auth/${identityProvider.type.toLowerCase()}/login/${identityProvider.id}`; + if (searchParams) { + Object.entries(searchParams).forEach(([key, value]) => { + authorizationUrl.searchParams.append(key, value.toString()); + }); + } + return authorizationUrl.toString(); } @@ -191,7 +198,10 @@ export class SSOService { }); } - async getAuthorizationUrl(identityProviderId: string) { + async getAuthorizationUrl( + identityProviderId: string, + searchParams: Record, + ) { const identityProvider = (await this.workspaceSSOIdentityProviderRepository.findOne({ where: { @@ -208,7 +218,7 @@ export class SSOService { return { id: identityProvider.id, - authorizationURL: this.buildIssuerURL(identityProvider), + authorizationURL: this.buildIssuerURL(identityProvider, searchParams), type: identityProvider.type, }; } diff --git a/packages/twenty-server/src/engine/core-modules/sso/sso.resolver.ts b/packages/twenty-server/src/engine/core-modules/sso/sso.resolver.ts index bdf5b0889..d79a5130b 100644 --- a/packages/twenty-server/src/engine/core-modules/sso/sso.resolver.ts +++ b/packages/twenty-server/src/engine/core-modules/sso/sso.resolver.ts @@ -3,6 +3,8 @@ import { UseGuards } from '@nestjs/common'; import { Args, Mutation, Query, Resolver } from '@nestjs/graphql'; +import omit from 'lodash.omit'; + import { EnterpriseFeaturesEnabledGuard } from 'src/engine/core-modules/auth/guards/enterprise-features-enabled.guard'; import { DeleteSsoInput } from 'src/engine/core-modules/sso/dtos/delete-sso.input'; import { DeleteSsoOutput } from 'src/engine/core-modules/sso/dtos/delete-sso.output'; @@ -47,10 +49,11 @@ export class SSOResolver { } @Mutation(() => GetAuthorizationUrlOutput) - async getAuthorizationUrl( - @Args('input') { identityProviderId }: GetAuthorizationUrlInput, - ) { - return this.sSOService.getAuthorizationUrl(identityProviderId); + async getAuthorizationUrl(@Args('input') params: GetAuthorizationUrlInput) { + return await this.sSOService.getAuthorizationUrl( + params.identityProviderId, + omit(params, ['identityProviderId']), + ); } @UseGuards(WorkspaceAuthGuard, EnterpriseFeaturesEnabledGuard)