refacto(auth): improve type + remove complexity (#9949)

- Improve type
- Remove unnecessary code
- Fix the issue that prevents the usage of invitations when a user signs
in with social media.
- Add Microsoft icon for sso list page
This commit is contained in:
Antoine Moreaux
2025-01-31 15:39:45 +01:00
committed by GitHub
parent 591301f7ce
commit b801307d92
11 changed files with 120 additions and 125 deletions

View File

@ -1,6 +1,11 @@
/* @license Enterprise */
import { IconComponent, IconGoogle, IconKey } from 'twenty-ui';
import {
IconComponent,
IconGoogle,
IconKey,
IconMicrosoftOutlook,
} from 'twenty-ui';
export const guessSSOIdentityProviderIconByUrl = (
url: string,
@ -9,5 +14,9 @@ export const guessSSOIdentityProviderIconByUrl = (
return IconGoogle;
}
if (url.includes('microsoft')) {
return IconMicrosoftOutlook;
}
return IconKey;
};

View File

@ -51,7 +51,6 @@ export class GoogleAuthController {
email,
picture,
workspaceInviteHash,
workspacePersonalInviteToken,
workspaceId,
billingCheckoutSessionState,
} = req.user;
@ -65,7 +64,7 @@ export class GoogleAuthController {
try {
const invitation =
currentWorkspace && workspacePersonalInviteToken && email
currentWorkspace && email
? await this.authService.findInvitationForSignInUp({
currentWorkspace,
email,

View File

@ -52,7 +52,6 @@ export class MicrosoftAuthController {
email,
picture,
workspaceInviteHash,
workspacePersonalInviteToken,
workspaceId,
billingCheckoutSessionState,
} = req.user;
@ -66,7 +65,7 @@ export class MicrosoftAuthController {
try {
const invitation =
currentWorkspace && workspacePersonalInviteToken && email
currentWorkspace && email
? await this.authService.findInvitationForSignInUp({
currentWorkspace,
email,

View File

@ -34,6 +34,8 @@ import { User } from 'src/engine/core-modules/user/user.entity';
import { AuthOAuthExceptionFilter } from 'src/engine/core-modules/auth/filters/auth-oauth-exception.filter';
import { EnvironmentService } from 'src/engine/core-modules/environment/environment.service';
import { GuardRedirectService } from 'src/engine/core-modules/guard-redirect/services/guard-redirect.service';
import { SAMLRequest } from 'src/engine/core-modules/auth/strategies/saml.auth.strategy';
import { OIDCRequest } from 'src/engine/core-modules/auth/strategies/oidc.auth.strategy';
@Controller('auth')
export class SSOAuthController {
@ -85,14 +87,14 @@ export class SSOAuthController {
@Get('oidc/callback')
@UseGuards(EnterpriseFeaturesEnabledGuard, OIDCAuthGuard)
@UseFilters(AuthOAuthExceptionFilter)
async oidcAuthCallback(@Req() req: any, @Res() res: Response) {
async oidcAuthCallback(@Req() req: OIDCRequest, @Res() res: Response) {
return await this.authCallback(req, res);
}
@Post('saml/callback/:identityProviderId')
@UseGuards(EnterpriseFeaturesEnabledGuard, SAMLAuthGuard)
@UseFilters(AuthOAuthExceptionFilter)
async samlAuthCallback(@Req() req: any, @Res() res: Response) {
async samlAuthCallback(@Req() req: SAMLRequest, @Res() res: Response) {
try {
return await this.authCallback(req, res);
} catch (err) {
@ -103,10 +105,10 @@ export class SSOAuthController {
}
}
private async authCallback({ user }: any, res: Response) {
private async authCallback(req: OIDCRequest | SAMLRequest, res: Response) {
const workspaceIdentityProvider =
await this.findWorkspaceIdentityProviderByIdentityProviderId(
user.identityProviderId,
req.user.identityProviderId,
);
try {
@ -117,7 +119,7 @@ export class SSOAuthController {
);
}
if (!user.user.email) {
if (!req.user.email) {
throw new AuthException(
'Email not found from identity provider.',
AuthExceptionCode.OAUTH_ACCESS_DENIED,
@ -125,7 +127,7 @@ export class SSOAuthController {
}
const { loginToken, identityProvider } = await this.generateLoginToken(
user.user,
req.user,
workspaceIdentityProvider,
);
@ -157,7 +159,7 @@ export class SSOAuthController {
}
private async generateLoginToken(
payload: { email: string } & Record<string, string>,
payload: { email: string },
identityProvider: WorkspaceSSOIdentityProvider,
) {
if (!identityProvider) {

View File

@ -3,6 +3,7 @@ import { AuthGuard } from '@nestjs/passport';
import { InjectRepository } from '@nestjs/typeorm';
import { Repository } from 'typeorm';
import { Request } from 'express';
import {
AuthException,
@ -26,7 +27,7 @@ export class GoogleOauthGuard extends AuthGuard('google') {
}
async canActivate(context: ExecutionContext) {
const request = context.switchToHttp().getRequest();
const request = context.switchToHttp().getRequest<Request>();
let workspace: Workspace | null = null;
try {
@ -40,9 +41,6 @@ export class GoogleOauthGuard extends AuthGuard('google') {
});
}
const workspaceInviteHash = request.query.inviteHash;
const workspacePersonalInviteToken = request.query.inviteToken;
if (request.query.error === 'access_denied') {
throw new AuthException(
'Google OAuth access denied',
@ -50,26 +48,6 @@ export class GoogleOauthGuard extends AuthGuard('google') {
);
}
if (workspaceInviteHash && typeof workspaceInviteHash === 'string') {
request.params.workspaceInviteHash = workspaceInviteHash;
}
if (
workspacePersonalInviteToken &&
typeof workspacePersonalInviteToken === 'string'
) {
request.params.workspacePersonalInviteToken =
workspacePersonalInviteToken;
}
if (
request.query.billingCheckoutSessionState &&
typeof request.query.billingCheckoutSessionState === 'string'
) {
request.params.billingCheckoutSessionState =
request.query.billingCheckoutSessionState;
}
return (await super.canActivate(context)) as boolean;
} catch (err) {
this.guardRedirectService.dispatchErrorFromGuard(

View File

@ -36,29 +36,6 @@ export class MicrosoftOAuthGuard extends AuthGuard('microsoft') {
});
}
const workspaceInviteHash = request.query.inviteHash;
const workspacePersonalInviteToken = request.query.inviteToken;
if (workspaceInviteHash && typeof workspaceInviteHash === 'string') {
request.params.workspaceInviteHash = workspaceInviteHash;
}
if (
workspacePersonalInviteToken &&
typeof workspacePersonalInviteToken === 'string'
) {
request.params.workspacePersonalInviteToken =
workspacePersonalInviteToken;
}
if (
request.query.billingCheckoutSessionState &&
typeof request.query.billingCheckoutSessionState === 'string'
) {
request.params.billingCheckoutSessionState =
request.query.billingCheckoutSessionState;
}
return (await super.canActivate(context)) as boolean;
} catch (err) {
this.guardRedirectService.dispatchErrorFromGuard(

View File

@ -34,24 +34,14 @@ export class GoogleStrategy extends PassportStrategy(Strategy, 'google') {
});
}
authenticate(req: any, options: any) {
authenticate(req: Request, options: any) {
options = {
...options,
state: JSON.stringify({
workspaceInviteHash: req.params.workspaceInviteHash,
workspaceInviteHash: req.query.workspaceInviteHash,
workspaceId: req.params.workspaceId,
...(req.params.billingCheckoutSessionState
? {
billingCheckoutSessionState:
req.params.billingCheckoutSessionState,
}
: {}),
...(req.params.workspacePersonalInviteToken
? {
workspacePersonalInviteToken:
req.params.workspacePersonalInviteToken,
}
: {}),
billingCheckoutSessionState: req.query.billingCheckoutSessionState,
workspacePersonalInviteToken: req.query.workspacePersonalInviteToken,
}),
};

View File

@ -38,24 +38,14 @@ export class MicrosoftStrategy extends PassportStrategy(Strategy, 'microsoft') {
});
}
authenticate(req: any, options: any) {
authenticate(req: Request, options: any) {
options = {
...options,
state: JSON.stringify({
workspaceInviteHash: req.params.workspaceInviteHash,
workspaceInviteHash: req.query.workspaceInviteHash,
workspaceId: req.params.workspaceId,
...(req.params.billingCheckoutSessionState
? {
billingCheckoutSessionState:
req.params.billingCheckoutSessionState,
}
: {}),
...(req.params.workspacePersonalInviteToken
? {
workspacePersonalInviteToken:
req.params.workspacePersonalInviteToken,
}
: {}),
billingCheckoutSessionState: req.query.billingCheckoutSessionState,
workspacePersonalInviteToken: req.query.workspacePersonalInviteToken,
}),
};

View File

@ -3,11 +3,26 @@
import { Injectable } from '@nestjs/common';
import { PassportStrategy } from '@nestjs/passport';
import { isEmail } from 'class-validator';
import { Request } from 'express';
import { Strategy, StrategyOptions, TokenSet } from 'openid-client';
import {
Strategy,
StrategyOptions,
StrategyVerifyCallbackReq,
} from 'openid-client';
AuthException,
AuthExceptionCode,
} from 'src/engine/core-modules/auth/auth.exception';
export type OIDCRequest = Omit<
Request,
'user' | 'workspace' | 'workspaceMetadataVersion'
> & {
user: {
identityProviderId: string;
email: string;
firstName?: string | null;
lastName?: string | null;
};
};
@Injectable()
export class OIDCAuthStrategy extends PassportStrategy(
@ -30,7 +45,7 @@ export class OIDCAuthStrategy extends PassportStrategy(
});
}
async authenticate(req: any, options: any) {
async authenticate(req: Request, options: any) {
return super.authenticate(req, {
...options,
state: JSON.stringify({
@ -39,37 +54,50 @@ export class OIDCAuthStrategy extends PassportStrategy(
});
}
validate: StrategyVerifyCallbackReq<{
private extractState(req: Request): {
identityProviderId: string;
user: {
email?: string;
firstName?: string | null;
lastName?: string | null;
};
}> = async (req, tokenset, done) => {
} {
try {
const state = JSON.parse(
'query' in req &&
req.query &&
typeof req.query === 'object' &&
'state' in req.query &&
req.query.state &&
typeof req.query.state === 'string'
req.query.state && typeof req.query.state === 'string'
? req.query.state
: '{}',
);
if (!state.identityProviderId) {
throw new Error();
}
return {
identityProviderId: state.identityProviderId,
};
} catch (err) {
throw new AuthException('Invalid state', AuthExceptionCode.INVALID_INPUT);
}
}
async validate(
req: Request,
tokenset: TokenSet,
done: (err: any, user?: OIDCRequest['user']) => void,
) {
try {
const state = this.extractState(req);
const userinfo = await this.client.userinfo(tokenset);
const user = {
if (!userinfo.email || !isEmail(userinfo.email)) {
return done(new Error('Invalid email'));
}
done(null, {
email: userinfo.email,
identityProviderId: state.identityProviderId,
...(userinfo.given_name ? { firstName: userinfo.given_name } : {}),
...(userinfo.family_name ? { lastName: userinfo.family_name } : {}),
};
done(null, { user, identityProviderId: state.identityProviderId });
});
} catch (err) {
done(err);
}
};
}
}

View File

@ -15,6 +15,20 @@ import { isEmail } from 'class-validator';
import { Request } from 'express';
import { SSOService } from 'src/engine/core-modules/sso/services/sso.service';
import {
AuthException,
AuthExceptionCode,
} from 'src/engine/core-modules/auth/auth.exception';
export type SAMLRequest = Omit<
Request,
'user' | 'workspace' | 'workspaceMetadataVersion'
> & {
user: {
identityProviderId: string;
email: string;
};
};
@Injectable()
export class SamlAuthStrategy extends PassportStrategy(
@ -69,6 +83,24 @@ export class SamlAuthStrategy extends PassportStrategy(
});
}
private extractState(req: Request): {
identityProviderId: string;
} {
try {
if ('RelayState' in req.body && typeof req.body.RelayState === 'string') {
const RelayState = JSON.parse(req.body.RelayState);
return {
identityProviderId: RelayState.identityProviderId,
};
}
throw new Error();
} catch (err) {
throw new AuthException('Invalid state', AuthExceptionCode.INVALID_INPUT);
}
}
validate: VerifyWithRequest = async (request, profile, done) => {
try {
if (!profile) {
@ -80,20 +112,11 @@ export class SamlAuthStrategy extends PassportStrategy(
if (!isEmail(email)) {
return done(new Error('Invalid email'));
}
const state = this.extractState(request);
const result: {
user: Record<string, string>;
identityProviderId?: string;
} = { user: { email } };
if (
'RelayState' in request.body &&
typeof request.body.RelayState === 'string'
) {
const RelayState = JSON.parse(request.body.RelayState);
result.identityProviderId = RelayState.identityProviderId;
}
const result: Pick<SAMLRequest, 'user'> = {
user: { ...state, email },
};
done(null, result);
} catch (err) {

View File

@ -1,6 +1,6 @@
import { useTheme } from '@emotion/react';
import IconMicrosoftOutlookRaw from '../assets/microsoft-outlook.svg?react';
import IconMicrosoftRaw from '../assets/microsoft.svg?react';
interface IconMicrosoftOutlookProps {
size?: number;
@ -10,5 +10,5 @@ export const IconMicrosoftOutlook = (props: IconMicrosoftOutlookProps) => {
const theme = useTheme();
const size = props.size ?? theme.icon.size.lg;
return <IconMicrosoftOutlookRaw height={size} width={size} />;
return <IconMicrosoftRaw height={size} width={size} />;
};