refacto(auth): improve type + remove complexity (#9949)

- Improve type
- Remove unnecessary code
- Fix the issue that prevents the usage of invitations when a user signs
in with social media.
- Add Microsoft icon for sso list page
This commit is contained in:
Antoine Moreaux
2025-01-31 15:39:45 +01:00
committed by GitHub
parent 591301f7ce
commit b801307d92
11 changed files with 120 additions and 125 deletions

View File

@ -1,6 +1,11 @@
/* @license Enterprise */ /* @license Enterprise */
import { IconComponent, IconGoogle, IconKey } from 'twenty-ui'; import {
IconComponent,
IconGoogle,
IconKey,
IconMicrosoftOutlook,
} from 'twenty-ui';
export const guessSSOIdentityProviderIconByUrl = ( export const guessSSOIdentityProviderIconByUrl = (
url: string, url: string,
@ -9,5 +14,9 @@ export const guessSSOIdentityProviderIconByUrl = (
return IconGoogle; return IconGoogle;
} }
if (url.includes('microsoft')) {
return IconMicrosoftOutlook;
}
return IconKey; return IconKey;
}; };

View File

@ -51,7 +51,6 @@ export class GoogleAuthController {
email, email,
picture, picture,
workspaceInviteHash, workspaceInviteHash,
workspacePersonalInviteToken,
workspaceId, workspaceId,
billingCheckoutSessionState, billingCheckoutSessionState,
} = req.user; } = req.user;
@ -65,7 +64,7 @@ export class GoogleAuthController {
try { try {
const invitation = const invitation =
currentWorkspace && workspacePersonalInviteToken && email currentWorkspace && email
? await this.authService.findInvitationForSignInUp({ ? await this.authService.findInvitationForSignInUp({
currentWorkspace, currentWorkspace,
email, email,

View File

@ -52,7 +52,6 @@ export class MicrosoftAuthController {
email, email,
picture, picture,
workspaceInviteHash, workspaceInviteHash,
workspacePersonalInviteToken,
workspaceId, workspaceId,
billingCheckoutSessionState, billingCheckoutSessionState,
} = req.user; } = req.user;
@ -66,7 +65,7 @@ export class MicrosoftAuthController {
try { try {
const invitation = const invitation =
currentWorkspace && workspacePersonalInviteToken && email currentWorkspace && email
? await this.authService.findInvitationForSignInUp({ ? await this.authService.findInvitationForSignInUp({
currentWorkspace, currentWorkspace,
email, email,

View File

@ -34,6 +34,8 @@ import { User } from 'src/engine/core-modules/user/user.entity';
import { AuthOAuthExceptionFilter } from 'src/engine/core-modules/auth/filters/auth-oauth-exception.filter'; import { AuthOAuthExceptionFilter } from 'src/engine/core-modules/auth/filters/auth-oauth-exception.filter';
import { EnvironmentService } from 'src/engine/core-modules/environment/environment.service'; import { EnvironmentService } from 'src/engine/core-modules/environment/environment.service';
import { GuardRedirectService } from 'src/engine/core-modules/guard-redirect/services/guard-redirect.service'; import { GuardRedirectService } from 'src/engine/core-modules/guard-redirect/services/guard-redirect.service';
import { SAMLRequest } from 'src/engine/core-modules/auth/strategies/saml.auth.strategy';
import { OIDCRequest } from 'src/engine/core-modules/auth/strategies/oidc.auth.strategy';
@Controller('auth') @Controller('auth')
export class SSOAuthController { export class SSOAuthController {
@ -85,14 +87,14 @@ export class SSOAuthController {
@Get('oidc/callback') @Get('oidc/callback')
@UseGuards(EnterpriseFeaturesEnabledGuard, OIDCAuthGuard) @UseGuards(EnterpriseFeaturesEnabledGuard, OIDCAuthGuard)
@UseFilters(AuthOAuthExceptionFilter) @UseFilters(AuthOAuthExceptionFilter)
async oidcAuthCallback(@Req() req: any, @Res() res: Response) { async oidcAuthCallback(@Req() req: OIDCRequest, @Res() res: Response) {
return await this.authCallback(req, res); return await this.authCallback(req, res);
} }
@Post('saml/callback/:identityProviderId') @Post('saml/callback/:identityProviderId')
@UseGuards(EnterpriseFeaturesEnabledGuard, SAMLAuthGuard) @UseGuards(EnterpriseFeaturesEnabledGuard, SAMLAuthGuard)
@UseFilters(AuthOAuthExceptionFilter) @UseFilters(AuthOAuthExceptionFilter)
async samlAuthCallback(@Req() req: any, @Res() res: Response) { async samlAuthCallback(@Req() req: SAMLRequest, @Res() res: Response) {
try { try {
return await this.authCallback(req, res); return await this.authCallback(req, res);
} catch (err) { } catch (err) {
@ -103,10 +105,10 @@ export class SSOAuthController {
} }
} }
private async authCallback({ user }: any, res: Response) { private async authCallback(req: OIDCRequest | SAMLRequest, res: Response) {
const workspaceIdentityProvider = const workspaceIdentityProvider =
await this.findWorkspaceIdentityProviderByIdentityProviderId( await this.findWorkspaceIdentityProviderByIdentityProviderId(
user.identityProviderId, req.user.identityProviderId,
); );
try { try {
@ -117,7 +119,7 @@ export class SSOAuthController {
); );
} }
if (!user.user.email) { if (!req.user.email) {
throw new AuthException( throw new AuthException(
'Email not found from identity provider.', 'Email not found from identity provider.',
AuthExceptionCode.OAUTH_ACCESS_DENIED, AuthExceptionCode.OAUTH_ACCESS_DENIED,
@ -125,7 +127,7 @@ export class SSOAuthController {
} }
const { loginToken, identityProvider } = await this.generateLoginToken( const { loginToken, identityProvider } = await this.generateLoginToken(
user.user, req.user,
workspaceIdentityProvider, workspaceIdentityProvider,
); );
@ -157,7 +159,7 @@ export class SSOAuthController {
} }
private async generateLoginToken( private async generateLoginToken(
payload: { email: string } & Record<string, string>, payload: { email: string },
identityProvider: WorkspaceSSOIdentityProvider, identityProvider: WorkspaceSSOIdentityProvider,
) { ) {
if (!identityProvider) { if (!identityProvider) {

View File

@ -3,6 +3,7 @@ import { AuthGuard } from '@nestjs/passport';
import { InjectRepository } from '@nestjs/typeorm'; import { InjectRepository } from '@nestjs/typeorm';
import { Repository } from 'typeorm'; import { Repository } from 'typeorm';
import { Request } from 'express';
import { import {
AuthException, AuthException,
@ -26,7 +27,7 @@ export class GoogleOauthGuard extends AuthGuard('google') {
} }
async canActivate(context: ExecutionContext) { async canActivate(context: ExecutionContext) {
const request = context.switchToHttp().getRequest(); const request = context.switchToHttp().getRequest<Request>();
let workspace: Workspace | null = null; let workspace: Workspace | null = null;
try { try {
@ -40,9 +41,6 @@ export class GoogleOauthGuard extends AuthGuard('google') {
}); });
} }
const workspaceInviteHash = request.query.inviteHash;
const workspacePersonalInviteToken = request.query.inviteToken;
if (request.query.error === 'access_denied') { if (request.query.error === 'access_denied') {
throw new AuthException( throw new AuthException(
'Google OAuth access denied', 'Google OAuth access denied',
@ -50,26 +48,6 @@ export class GoogleOauthGuard extends AuthGuard('google') {
); );
} }
if (workspaceInviteHash && typeof workspaceInviteHash === 'string') {
request.params.workspaceInviteHash = workspaceInviteHash;
}
if (
workspacePersonalInviteToken &&
typeof workspacePersonalInviteToken === 'string'
) {
request.params.workspacePersonalInviteToken =
workspacePersonalInviteToken;
}
if (
request.query.billingCheckoutSessionState &&
typeof request.query.billingCheckoutSessionState === 'string'
) {
request.params.billingCheckoutSessionState =
request.query.billingCheckoutSessionState;
}
return (await super.canActivate(context)) as boolean; return (await super.canActivate(context)) as boolean;
} catch (err) { } catch (err) {
this.guardRedirectService.dispatchErrorFromGuard( this.guardRedirectService.dispatchErrorFromGuard(

View File

@ -36,29 +36,6 @@ export class MicrosoftOAuthGuard extends AuthGuard('microsoft') {
}); });
} }
const workspaceInviteHash = request.query.inviteHash;
const workspacePersonalInviteToken = request.query.inviteToken;
if (workspaceInviteHash && typeof workspaceInviteHash === 'string') {
request.params.workspaceInviteHash = workspaceInviteHash;
}
if (
workspacePersonalInviteToken &&
typeof workspacePersonalInviteToken === 'string'
) {
request.params.workspacePersonalInviteToken =
workspacePersonalInviteToken;
}
if (
request.query.billingCheckoutSessionState &&
typeof request.query.billingCheckoutSessionState === 'string'
) {
request.params.billingCheckoutSessionState =
request.query.billingCheckoutSessionState;
}
return (await super.canActivate(context)) as boolean; return (await super.canActivate(context)) as boolean;
} catch (err) { } catch (err) {
this.guardRedirectService.dispatchErrorFromGuard( this.guardRedirectService.dispatchErrorFromGuard(

View File

@ -34,24 +34,14 @@ export class GoogleStrategy extends PassportStrategy(Strategy, 'google') {
}); });
} }
authenticate(req: any, options: any) { authenticate(req: Request, options: any) {
options = { options = {
...options, ...options,
state: JSON.stringify({ state: JSON.stringify({
workspaceInviteHash: req.params.workspaceInviteHash, workspaceInviteHash: req.query.workspaceInviteHash,
workspaceId: req.params.workspaceId, workspaceId: req.params.workspaceId,
...(req.params.billingCheckoutSessionState billingCheckoutSessionState: req.query.billingCheckoutSessionState,
? { workspacePersonalInviteToken: req.query.workspacePersonalInviteToken,
billingCheckoutSessionState:
req.params.billingCheckoutSessionState,
}
: {}),
...(req.params.workspacePersonalInviteToken
? {
workspacePersonalInviteToken:
req.params.workspacePersonalInviteToken,
}
: {}),
}), }),
}; };

View File

@ -38,24 +38,14 @@ export class MicrosoftStrategy extends PassportStrategy(Strategy, 'microsoft') {
}); });
} }
authenticate(req: any, options: any) { authenticate(req: Request, options: any) {
options = { options = {
...options, ...options,
state: JSON.stringify({ state: JSON.stringify({
workspaceInviteHash: req.params.workspaceInviteHash, workspaceInviteHash: req.query.workspaceInviteHash,
workspaceId: req.params.workspaceId, workspaceId: req.params.workspaceId,
...(req.params.billingCheckoutSessionState billingCheckoutSessionState: req.query.billingCheckoutSessionState,
? { workspacePersonalInviteToken: req.query.workspacePersonalInviteToken,
billingCheckoutSessionState:
req.params.billingCheckoutSessionState,
}
: {}),
...(req.params.workspacePersonalInviteToken
? {
workspacePersonalInviteToken:
req.params.workspacePersonalInviteToken,
}
: {}),
}), }),
}; };

View File

@ -3,11 +3,26 @@
import { Injectable } from '@nestjs/common'; import { Injectable } from '@nestjs/common';
import { PassportStrategy } from '@nestjs/passport'; import { PassportStrategy } from '@nestjs/passport';
import { isEmail } from 'class-validator';
import { Request } from 'express';
import { Strategy, StrategyOptions, TokenSet } from 'openid-client';
import { import {
Strategy, AuthException,
StrategyOptions, AuthExceptionCode,
StrategyVerifyCallbackReq, } from 'src/engine/core-modules/auth/auth.exception';
} from 'openid-client';
export type OIDCRequest = Omit<
Request,
'user' | 'workspace' | 'workspaceMetadataVersion'
> & {
user: {
identityProviderId: string;
email: string;
firstName?: string | null;
lastName?: string | null;
};
};
@Injectable() @Injectable()
export class OIDCAuthStrategy extends PassportStrategy( export class OIDCAuthStrategy extends PassportStrategy(
@ -30,7 +45,7 @@ export class OIDCAuthStrategy extends PassportStrategy(
}); });
} }
async authenticate(req: any, options: any) { async authenticate(req: Request, options: any) {
return super.authenticate(req, { return super.authenticate(req, {
...options, ...options,
state: JSON.stringify({ state: JSON.stringify({
@ -39,37 +54,50 @@ export class OIDCAuthStrategy extends PassportStrategy(
}); });
} }
validate: StrategyVerifyCallbackReq<{ private extractState(req: Request): {
identityProviderId: string; identityProviderId: string;
user: { } {
email?: string;
firstName?: string | null;
lastName?: string | null;
};
}> = async (req, tokenset, done) => {
try { try {
const state = JSON.parse( const state = JSON.parse(
'query' in req && req.query.state && typeof req.query.state === 'string'
req.query &&
typeof req.query === 'object' &&
'state' in req.query &&
req.query.state &&
typeof req.query.state === 'string'
? req.query.state ? req.query.state
: '{}', : '{}',
); );
if (!state.identityProviderId) {
throw new Error();
}
return {
identityProviderId: state.identityProviderId,
};
} catch (err) {
throw new AuthException('Invalid state', AuthExceptionCode.INVALID_INPUT);
}
}
async validate(
req: Request,
tokenset: TokenSet,
done: (err: any, user?: OIDCRequest['user']) => void,
) {
try {
const state = this.extractState(req);
const userinfo = await this.client.userinfo(tokenset); const userinfo = await this.client.userinfo(tokenset);
const user = { if (!userinfo.email || !isEmail(userinfo.email)) {
return done(new Error('Invalid email'));
}
done(null, {
email: userinfo.email, email: userinfo.email,
identityProviderId: state.identityProviderId,
...(userinfo.given_name ? { firstName: userinfo.given_name } : {}), ...(userinfo.given_name ? { firstName: userinfo.given_name } : {}),
...(userinfo.family_name ? { lastName: userinfo.family_name } : {}), ...(userinfo.family_name ? { lastName: userinfo.family_name } : {}),
}; });
done(null, { user, identityProviderId: state.identityProviderId });
} catch (err) { } catch (err) {
done(err); done(err);
} }
}; }
} }

View File

@ -15,6 +15,20 @@ import { isEmail } from 'class-validator';
import { Request } from 'express'; import { Request } from 'express';
import { SSOService } from 'src/engine/core-modules/sso/services/sso.service'; import { SSOService } from 'src/engine/core-modules/sso/services/sso.service';
import {
AuthException,
AuthExceptionCode,
} from 'src/engine/core-modules/auth/auth.exception';
export type SAMLRequest = Omit<
Request,
'user' | 'workspace' | 'workspaceMetadataVersion'
> & {
user: {
identityProviderId: string;
email: string;
};
};
@Injectable() @Injectable()
export class SamlAuthStrategy extends PassportStrategy( export class SamlAuthStrategy extends PassportStrategy(
@ -69,6 +83,24 @@ export class SamlAuthStrategy extends PassportStrategy(
}); });
} }
private extractState(req: Request): {
identityProviderId: string;
} {
try {
if ('RelayState' in req.body && typeof req.body.RelayState === 'string') {
const RelayState = JSON.parse(req.body.RelayState);
return {
identityProviderId: RelayState.identityProviderId,
};
}
throw new Error();
} catch (err) {
throw new AuthException('Invalid state', AuthExceptionCode.INVALID_INPUT);
}
}
validate: VerifyWithRequest = async (request, profile, done) => { validate: VerifyWithRequest = async (request, profile, done) => {
try { try {
if (!profile) { if (!profile) {
@ -80,20 +112,11 @@ export class SamlAuthStrategy extends PassportStrategy(
if (!isEmail(email)) { if (!isEmail(email)) {
return done(new Error('Invalid email')); return done(new Error('Invalid email'));
} }
const state = this.extractState(request);
const result: { const result: Pick<SAMLRequest, 'user'> = {
user: Record<string, string>; user: { ...state, email },
identityProviderId?: string; };
} = { user: { email } };
if (
'RelayState' in request.body &&
typeof request.body.RelayState === 'string'
) {
const RelayState = JSON.parse(request.body.RelayState);
result.identityProviderId = RelayState.identityProviderId;
}
done(null, result); done(null, result);
} catch (err) { } catch (err) {

View File

@ -1,6 +1,6 @@
import { useTheme } from '@emotion/react'; import { useTheme } from '@emotion/react';
import IconMicrosoftOutlookRaw from '../assets/microsoft-outlook.svg?react'; import IconMicrosoftRaw from '../assets/microsoft.svg?react';
interface IconMicrosoftOutlookProps { interface IconMicrosoftOutlookProps {
size?: number; size?: number;
@ -10,5 +10,5 @@ export const IconMicrosoftOutlook = (props: IconMicrosoftOutlookProps) => {
const theme = useTheme(); const theme = useTheme();
const size = props.size ?? theme.icon.size.lg; const size = props.size ?? theme.icon.size.lg;
return <IconMicrosoftOutlookRaw height={size} width={size} />; return <IconMicrosoftRaw height={size} width={size} />;
}; };