Begin refactoring AI module (#12464)
Cleaning up to prepare for a few tests
This commit is contained in:
14
package.json
14
package.json
@ -7,9 +7,9 @@
|
||||
"@aws-sdk/client-s3": "^3.363.0",
|
||||
"@aws-sdk/client-sts": "^3.744.0",
|
||||
"@aws-sdk/credential-providers": "^3.363.0",
|
||||
"@blocknote/mantine": "^0.22.0",
|
||||
"@blocknote/react": "^0.22.0",
|
||||
"@blocknote/server-util": "0.17.1",
|
||||
"@blocknote/mantine": "^0.31.1",
|
||||
"@blocknote/react": "^0.31.1",
|
||||
"@blocknote/server-util": "^0.17.1",
|
||||
"@codesandbox/sandpack-react": "^2.13.5",
|
||||
"@dagrejs/dagre": "^1.1.2",
|
||||
"@emotion/react": "^11.11.1",
|
||||
@ -45,9 +45,9 @@
|
||||
"@ptc-org/nestjs-query-typeorm": "4.2.1-alpha.2",
|
||||
"@react-email/components": "0.0.35",
|
||||
"@react-email/render": "0.0.17",
|
||||
"@sentry/node": "^8",
|
||||
"@sentry/profiling-node": "^8",
|
||||
"@sentry/react": "^8",
|
||||
"@sentry/node": "^9.26.0",
|
||||
"@sentry/profiling-node": "^9.26.0",
|
||||
"@sentry/react": "^9.26.0",
|
||||
"@sniptt/guards": "^0.2.0",
|
||||
"@stoplight/elements": "^8.0.5",
|
||||
"@swc/jest": "^0.2.29",
|
||||
@ -221,7 +221,7 @@
|
||||
"@nx/vite": "18.3.3",
|
||||
"@nx/web": "18.3.3",
|
||||
"@playwright/test": "^1.46.0",
|
||||
"@sentry/types": "^7.109.0",
|
||||
"@sentry/types": "^8",
|
||||
"@storybook/addon-actions": "^7.6.3",
|
||||
"@storybook/addon-coverage": "^1.0.0",
|
||||
"@storybook/addon-essentials": "^7.6.7",
|
||||
|
||||
@ -29,8 +29,9 @@
|
||||
"workerDirectory": "public"
|
||||
},
|
||||
"dependencies": {
|
||||
"@blocknote/xl-docx-exporter": "^0.22.0",
|
||||
"@blocknote/xl-pdf-exporter": "^0.22.0",
|
||||
"@blocknote/xl-ai": "^0.31.1",
|
||||
"@blocknote/xl-docx-exporter": "^0.31.1",
|
||||
"@blocknote/xl-pdf-exporter": "^0.31.1",
|
||||
"@cyntler/react-doc-viewer": "^1.17.0",
|
||||
"@lingui/core": "^5.1.2",
|
||||
"@lingui/detect-locale": "^5.2.0",
|
||||
|
||||
@ -653,13 +653,13 @@ export type FeatureFlagDto = {
|
||||
|
||||
export enum FeatureFlagKey {
|
||||
IS_AIRTABLE_INTEGRATION_ENABLED = 'IS_AIRTABLE_INTEGRATION_ENABLED',
|
||||
IS_COPILOT_ENABLED = 'IS_COPILOT_ENABLED',
|
||||
IS_JSON_FILTER_ENABLED = 'IS_JSON_FILTER_ENABLED',
|
||||
IS_PERMISSIONS_V2_ENABLED = 'IS_PERMISSIONS_V2_ENABLED',
|
||||
IS_POSTGRESQL_INTEGRATION_ENABLED = 'IS_POSTGRESQL_INTEGRATION_ENABLED',
|
||||
IS_STRIPE_INTEGRATION_ENABLED = 'IS_STRIPE_INTEGRATION_ENABLED',
|
||||
IS_UNIQUE_INDEXES_ENABLED = 'IS_UNIQUE_INDEXES_ENABLED',
|
||||
IS_WORKFLOW_ENABLED = 'IS_WORKFLOW_ENABLED'
|
||||
IS_WORKFLOW_ENABLED = 'IS_WORKFLOW_ENABLED',
|
||||
IS_AI_ENABLED = 'IS_AI_ENABLED'
|
||||
}
|
||||
|
||||
export type Field = {
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { gql } from '@apollo/client';
|
||||
import * as Apollo from '@apollo/client';
|
||||
import { gql } from '@apollo/client';
|
||||
export type Maybe<T> = T | null;
|
||||
export type InputMaybe<T> = Maybe<T>;
|
||||
export type Exact<T extends { [key: string]: unknown }> = { [K in keyof T]: T[K] };
|
||||
@ -584,13 +584,13 @@ export type FeatureFlagDto = {
|
||||
|
||||
export enum FeatureFlagKey {
|
||||
IS_AIRTABLE_INTEGRATION_ENABLED = 'IS_AIRTABLE_INTEGRATION_ENABLED',
|
||||
IS_COPILOT_ENABLED = 'IS_COPILOT_ENABLED',
|
||||
IS_JSON_FILTER_ENABLED = 'IS_JSON_FILTER_ENABLED',
|
||||
IS_PERMISSIONS_V2_ENABLED = 'IS_PERMISSIONS_V2_ENABLED',
|
||||
IS_POSTGRESQL_INTEGRATION_ENABLED = 'IS_POSTGRESQL_INTEGRATION_ENABLED',
|
||||
IS_STRIPE_INTEGRATION_ENABLED = 'IS_STRIPE_INTEGRATION_ENABLED',
|
||||
IS_UNIQUE_INDEXES_ENABLED = 'IS_UNIQUE_INDEXES_ENABLED',
|
||||
IS_WORKFLOW_ENABLED = 'IS_WORKFLOW_ENABLED'
|
||||
IS_WORKFLOW_ENABLED = 'IS_WORKFLOW_ENABLED',
|
||||
IS_AI_ENABLED = 'IS_AI_ENABLED'
|
||||
}
|
||||
|
||||
export type Field = {
|
||||
|
||||
@ -4,7 +4,7 @@ import React, { useEffect, useState } from 'react';
|
||||
import rehypeStringify from 'rehype-stringify';
|
||||
import remarkParse from 'remark-parse';
|
||||
import remarkRehype from 'remark-rehype';
|
||||
import { unified } from 'unified';
|
||||
import { PluggableList, unified } from 'unified';
|
||||
import { visit } from 'unist-util-visit';
|
||||
|
||||
import { SettingsPageContainer } from '@/settings/components/SettingsPageContainer';
|
||||
@ -89,9 +89,7 @@ export const Releases = () => {
|
||||
for (const release of json) {
|
||||
release.html = String(
|
||||
await unified()
|
||||
.use(remarkParse)
|
||||
.use(remarkRehype)
|
||||
.use(rehypeStringify)
|
||||
.use([remarkParse, remarkRehype, rehypeStringify] as PluggableList)
|
||||
.use(() => (tree: any) => {
|
||||
visit(tree, (node) => {
|
||||
if (node.tagName === 'h1' || node.tagName === 'h2') {
|
||||
|
||||
@ -15,11 +15,10 @@
|
||||
"typeorm": "../../node_modules/typeorm/.bin/typeorm"
|
||||
},
|
||||
"dependencies": {
|
||||
"@ai-sdk/openai": "^1.3.22",
|
||||
"@clickhouse/client": "^1.11.0",
|
||||
"@esbuild-plugins/node-modules-polyfill": "^0.2.2",
|
||||
"@graphql-yoga/nestjs": "patch:@graphql-yoga/nestjs@2.1.0#./patches/@graphql-yoga+nestjs+2.1.0.patch",
|
||||
"@langchain/mistralai": "^0.0.24",
|
||||
"@langchain/openai": "^0.1.3",
|
||||
"@lingui/core": "^5.1.2",
|
||||
"@monaco-editor/react": "^4.6.0",
|
||||
"@nestjs/cache-manager": "^2.2.1",
|
||||
@ -28,11 +27,14 @@
|
||||
"@nestjs/schedule": "^3.0.0",
|
||||
"@node-saml/passport-saml": "^5.0.0",
|
||||
"@opentelemetry/api": "^1.9.0",
|
||||
"@opentelemetry/auto-instrumentations-node": "^0.60.0",
|
||||
"@opentelemetry/exporter-metrics-otlp-http": "^0.200.0",
|
||||
"@opentelemetry/sdk-metrics": "^2.0.0",
|
||||
"@opentelemetry/sdk-node": "^0.202.0",
|
||||
"@ptc-org/nestjs-query-graphql": "patch:@ptc-org/nestjs-query-graphql@4.2.0#./patches/@ptc-org+nestjs-query-graphql+4.2.0.patch",
|
||||
"@revertdotdev/revert-react": "^0.0.21",
|
||||
"@sentry/nestjs": "^8.30.0",
|
||||
"@sentry/nestjs": "^8.55.0",
|
||||
"ai": "^4.3.16",
|
||||
"cache-manager": "^5.4.0",
|
||||
"cache-manager-redis-yet": "^4.1.2",
|
||||
"class-validator": "patch:class-validator@0.14.0#./patches/class-validator+0.14.0.patch",
|
||||
@ -43,8 +45,6 @@
|
||||
"handlebars": "^4.7.8",
|
||||
"jsdom": "~22.1.0",
|
||||
"jwt-decode": "^4.0.0",
|
||||
"langchain": "^0.2.6",
|
||||
"langfuse-langchain": "^3.11.2",
|
||||
"lodash.differencewith": "^4.5.0",
|
||||
"lodash.merge": "^4.6.2",
|
||||
"lodash.omitby": "^4.6.0",
|
||||
|
||||
@ -0,0 +1 @@
|
||||
export const AI_DRIVER = Symbol('AI_DRIVER');
|
||||
@ -0,0 +1,17 @@
|
||||
import { AiDriver } from 'src/engine/core-modules/ai/interfaces/ai.interface';
|
||||
|
||||
import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service';
|
||||
|
||||
export const aiModuleFactory = (twentyConfigService: TwentyConfigService) => {
|
||||
const driver = twentyConfigService.get('AI_DRIVER');
|
||||
|
||||
switch (driver) {
|
||||
case AiDriver.OPENAI: {
|
||||
return { type: AiDriver.OPENAI };
|
||||
}
|
||||
default: {
|
||||
// Default to OpenAI driver if no driver is specified
|
||||
return { type: AiDriver.OPENAI };
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -0,0 +1,40 @@
|
||||
import { DynamicModule, Global, Provider } from '@nestjs/common';
|
||||
|
||||
import {
|
||||
AiDriver,
|
||||
AiModuleAsyncOptions,
|
||||
} from 'src/engine/core-modules/ai/interfaces/ai.interface';
|
||||
|
||||
import { AI_DRIVER } from 'src/engine/core-modules/ai/ai.constants';
|
||||
import { AiService } from 'src/engine/core-modules/ai/ai.service';
|
||||
import { AiController } from 'src/engine/core-modules/ai/controllers/ai.controller';
|
||||
import { OpenAIDriver } from 'src/engine/core-modules/ai/drivers/openai.driver';
|
||||
import { FeatureFlagModule } from 'src/engine/core-modules/feature-flag/feature-flag.module';
|
||||
|
||||
@Global()
|
||||
export class AiModule {
|
||||
static forRoot(options: AiModuleAsyncOptions): DynamicModule {
|
||||
const provider: Provider = {
|
||||
provide: AI_DRIVER,
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
useFactory: (...args: any[]) => {
|
||||
const config = options.useFactory(...args);
|
||||
|
||||
switch (config?.type) {
|
||||
case AiDriver.OPENAI: {
|
||||
return new OpenAIDriver();
|
||||
}
|
||||
}
|
||||
},
|
||||
inject: options.inject || [],
|
||||
};
|
||||
|
||||
return {
|
||||
module: AiModule,
|
||||
imports: [FeatureFlagModule],
|
||||
controllers: [AiController],
|
||||
providers: [AiService, provider],
|
||||
exports: [AiService],
|
||||
};
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,19 @@
|
||||
import { Inject, Injectable } from '@nestjs/common';
|
||||
|
||||
import { CoreMessage, StreamTextResult } from 'ai';
|
||||
|
||||
import { AiDriver } from 'src/engine/core-modules/ai/drivers/interfaces/ai-driver.interface';
|
||||
|
||||
import { AI_DRIVER } from 'src/engine/core-modules/ai/ai.constants';
|
||||
|
||||
@Injectable()
|
||||
export class AiService {
|
||||
constructor(@Inject(AI_DRIVER) private driver: AiDriver) {}
|
||||
|
||||
streamText(
|
||||
messages: CoreMessage[],
|
||||
options?: { temperature?: number; maxTokens?: number },
|
||||
): StreamTextResult<Record<string, never>, undefined> {
|
||||
return this.driver.streamText(messages, options);
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,123 @@
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { AiService } from 'src/engine/core-modules/ai/ai.service';
|
||||
import { FeatureFlagService } from 'src/engine/core-modules/feature-flag/services/feature-flag.service';
|
||||
|
||||
import { AiController } from './ai.controller';
|
||||
|
||||
describe('AiController', () => {
|
||||
let controller: AiController;
|
||||
let aiService: jest.Mocked<AiService>;
|
||||
let featureFlagService: jest.Mocked<FeatureFlagService>;
|
||||
|
||||
beforeEach(async () => {
|
||||
const mockAiService = {
|
||||
streamText: jest.fn(),
|
||||
};
|
||||
|
||||
const mockFeatureFlagService = {
|
||||
isFeatureEnabled: jest.fn().mockResolvedValue(true),
|
||||
};
|
||||
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
controllers: [AiController],
|
||||
providers: [
|
||||
{
|
||||
provide: AiService,
|
||||
useValue: mockAiService,
|
||||
},
|
||||
{
|
||||
provide: FeatureFlagService,
|
||||
useValue: mockFeatureFlagService,
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
controller = module.get<AiController>(AiController);
|
||||
aiService = module.get(AiService);
|
||||
featureFlagService = module.get(FeatureFlagService);
|
||||
});
|
||||
|
||||
it('should be defined', () => {
|
||||
expect(controller).toBeDefined();
|
||||
});
|
||||
|
||||
describe('chat', () => {
|
||||
const mockWorkspace = { id: 'workspace-1' } as any;
|
||||
|
||||
it('should handle valid chat request', async () => {
|
||||
const mockRequest = {
|
||||
messages: [{ role: 'user' as const, content: 'Hello' }],
|
||||
temperature: 0.7,
|
||||
maxTokens: 100,
|
||||
};
|
||||
|
||||
const mockRes = {
|
||||
setHeader: jest.fn(),
|
||||
write: jest.fn(),
|
||||
end: jest.fn(),
|
||||
} as any;
|
||||
|
||||
const mockStreamTextResult = {
|
||||
pipeDataStreamToResponse: jest.fn(),
|
||||
};
|
||||
|
||||
aiService.streamText.mockReturnValue(mockStreamTextResult as any);
|
||||
|
||||
await controller.chat(mockRequest, mockWorkspace, mockRes);
|
||||
|
||||
expect(featureFlagService.isFeatureEnabled).toHaveBeenCalled();
|
||||
expect(aiService.streamText).toHaveBeenCalledWith(mockRequest.messages, {
|
||||
temperature: 0.7,
|
||||
maxTokens: 100,
|
||||
});
|
||||
expect(
|
||||
mockStreamTextResult.pipeDataStreamToResponse,
|
||||
).toHaveBeenCalledWith(mockRes);
|
||||
});
|
||||
|
||||
it('should throw error for empty messages', async () => {
|
||||
const mockRequest = {
|
||||
messages: [],
|
||||
};
|
||||
|
||||
const mockRes = {} as any;
|
||||
|
||||
await expect(
|
||||
controller.chat(mockRequest, mockWorkspace, mockRes),
|
||||
).rejects.toThrow('Messages array is required and cannot be empty');
|
||||
});
|
||||
|
||||
it('should handle service errors', async () => {
|
||||
const mockRequest = {
|
||||
messages: [{ role: 'user' as const, content: 'Hello' }],
|
||||
};
|
||||
|
||||
const mockRes = {} as any;
|
||||
|
||||
aiService.streamText.mockImplementation(() => {
|
||||
throw new Error('Service error');
|
||||
});
|
||||
|
||||
await expect(
|
||||
controller.chat(mockRequest, mockWorkspace, mockRes),
|
||||
).rejects.toThrow(
|
||||
'An error occurred while processing your request: Service error',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw error when AI feature is disabled', async () => {
|
||||
featureFlagService.isFeatureEnabled.mockResolvedValue(false);
|
||||
|
||||
const mockRequest = {
|
||||
messages: [{ role: 'user' as const, content: 'Hello' }],
|
||||
};
|
||||
|
||||
const mockRes = {} as any;
|
||||
|
||||
await expect(
|
||||
controller.chat(mockRequest, mockWorkspace, mockRes),
|
||||
).rejects.toThrow('AI feature is not enabled for this workspace');
|
||||
});
|
||||
});
|
||||
});
|
||||
@ -0,0 +1,79 @@
|
||||
import {
|
||||
Body,
|
||||
Controller,
|
||||
HttpException,
|
||||
HttpStatus,
|
||||
Post,
|
||||
Res,
|
||||
UseGuards,
|
||||
} from '@nestjs/common';
|
||||
|
||||
import { CoreMessage } from 'ai';
|
||||
import { Response } from 'express';
|
||||
|
||||
import { AiService } from 'src/engine/core-modules/ai/ai.service';
|
||||
import { FeatureFlagKey } from 'src/engine/core-modules/feature-flag/enums/feature-flag-key.enum';
|
||||
import { FeatureFlagService } from 'src/engine/core-modules/feature-flag/services/feature-flag.service';
|
||||
import { Workspace } from 'src/engine/core-modules/workspace/workspace.entity';
|
||||
import { AuthWorkspace } from 'src/engine/decorators/auth/auth-workspace.decorator';
|
||||
import { WorkspaceAuthGuard } from 'src/engine/guards/workspace-auth.guard';
|
||||
|
||||
export interface ChatRequest {
|
||||
messages: CoreMessage[];
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
}
|
||||
|
||||
@Controller('chat')
|
||||
@UseGuards(WorkspaceAuthGuard)
|
||||
export class AiController {
|
||||
constructor(
|
||||
private readonly aiService: AiService,
|
||||
private readonly featureFlagService: FeatureFlagService,
|
||||
) {}
|
||||
|
||||
@Post()
|
||||
async chat(
|
||||
@Body() request: ChatRequest,
|
||||
@AuthWorkspace() workspace: Workspace,
|
||||
@Res() res: Response,
|
||||
) {
|
||||
const isAiEnabled = await this.featureFlagService.isFeatureEnabled(
|
||||
FeatureFlagKey.IS_AI_ENABLED,
|
||||
workspace.id,
|
||||
);
|
||||
|
||||
if (!isAiEnabled) {
|
||||
throw new HttpException(
|
||||
'AI feature is not enabled for this workspace',
|
||||
HttpStatus.FORBIDDEN,
|
||||
);
|
||||
}
|
||||
|
||||
const { messages, temperature, maxTokens } = request;
|
||||
|
||||
if (!messages || messages.length === 0) {
|
||||
throw new HttpException(
|
||||
'Messages array is required and cannot be empty',
|
||||
HttpStatus.BAD_REQUEST,
|
||||
);
|
||||
}
|
||||
|
||||
try {
|
||||
const result = this.aiService.streamText(messages, {
|
||||
temperature,
|
||||
maxTokens,
|
||||
});
|
||||
|
||||
result.pipeDataStreamToResponse(res);
|
||||
} catch (error) {
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : 'Unknown error occurred';
|
||||
|
||||
throw new HttpException(
|
||||
`An error occurred while processing your request: ${errorMessage}`,
|
||||
HttpStatus.INTERNAL_SERVER_ERROR,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,8 @@
|
||||
import { CoreMessage, StreamTextResult } from 'ai';
|
||||
|
||||
export interface AiDriver {
|
||||
streamText(
|
||||
messages: CoreMessage[],
|
||||
options?: { temperature?: number; maxTokens?: number },
|
||||
): StreamTextResult<Record<string, never>, undefined>;
|
||||
}
|
||||
@ -0,0 +1,18 @@
|
||||
import { openai } from '@ai-sdk/openai';
|
||||
import { CoreMessage, StreamTextResult, streamText } from 'ai';
|
||||
|
||||
import { AiDriver } from 'src/engine/core-modules/ai/drivers/interfaces/ai-driver.interface';
|
||||
|
||||
export class OpenAIDriver implements AiDriver {
|
||||
streamText(
|
||||
messages: CoreMessage[],
|
||||
options?: { temperature?: number; maxTokens?: number },
|
||||
): StreamTextResult<Record<string, never>, undefined> {
|
||||
return streamText({
|
||||
model: openai('gpt-4o-mini'),
|
||||
messages,
|
||||
temperature: options?.temperature,
|
||||
maxTokens: options?.maxTokens,
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,14 @@
|
||||
import { InjectionToken, OptionalFactoryDependency } from '@nestjs/common';
|
||||
|
||||
export enum AiDriver {
|
||||
OPENAI = 'openai',
|
||||
}
|
||||
|
||||
export interface AiModuleOptions {
|
||||
type: AiDriver;
|
||||
}
|
||||
|
||||
export type AiModuleAsyncOptions = {
|
||||
inject?: (InjectionToken | OptionalFactoryDependency)[];
|
||||
useFactory: (...args: unknown[]) => AiModuleOptions | undefined;
|
||||
};
|
||||
@ -5,6 +5,8 @@ import { EventEmitterModule } from '@nestjs/event-emitter';
|
||||
import { WorkspaceQueryRunnerModule } from 'src/engine/api/graphql/workspace-query-runner/workspace-query-runner.module';
|
||||
import { ActorModule } from 'src/engine/core-modules/actor/actor.module';
|
||||
import { AdminPanelModule } from 'src/engine/core-modules/admin-panel/admin-panel.module';
|
||||
import { AiModule } from 'src/engine/core-modules/ai/ai.module';
|
||||
import { aiModuleFactory } from 'src/engine/core-modules/ai/ai.module-factory';
|
||||
import { AppTokenModule } from 'src/engine/core-modules/app-token/app-token.module';
|
||||
import { ApprovedAccessDomainModule } from 'src/engine/core-modules/approved-access-domain/approved-access-domain.module';
|
||||
import { AuthModule } from 'src/engine/core-modules/auth/auth.module';
|
||||
@ -21,10 +23,6 @@ import { FileStorageModule } from 'src/engine/core-modules/file-storage/file-sto
|
||||
import { FileStorageService } from 'src/engine/core-modules/file-storage/file-storage.service';
|
||||
import { HealthModule } from 'src/engine/core-modules/health/health.module';
|
||||
import { LabModule } from 'src/engine/core-modules/lab/lab.module';
|
||||
import { LLMChatModelModule } from 'src/engine/core-modules/llm-chat-model/llm-chat-model.module';
|
||||
import { llmChatModelModuleFactory } from 'src/engine/core-modules/llm-chat-model/llm-chat-model.module-factory';
|
||||
import { LLMTracingModule } from 'src/engine/core-modules/llm-tracing/llm-tracing.module';
|
||||
import { llmTracingModuleFactory } from 'src/engine/core-modules/llm-tracing/llm-tracing.module-factory';
|
||||
import { LoggerModule } from 'src/engine/core-modules/logger/logger.module';
|
||||
import { loggerModuleFactory } from 'src/engine/core-modules/logger/logger.module-factory';
|
||||
import { MessageQueueModule } from 'src/engine/core-modules/message-queue/message-queue.module';
|
||||
@ -105,12 +103,8 @@ import { FileModule } from './file/file.module';
|
||||
wildcard: true,
|
||||
}),
|
||||
CacheStorageModule,
|
||||
LLMChatModelModule.forRoot({
|
||||
useFactory: llmChatModelModuleFactory,
|
||||
inject: [TwentyConfigService],
|
||||
}),
|
||||
LLMTracingModule.forRoot({
|
||||
useFactory: llmTracingModuleFactory,
|
||||
AiModule.forRoot({
|
||||
useFactory: aiModuleFactory,
|
||||
inject: [TwentyConfigService],
|
||||
}),
|
||||
ServerlessModule.forRootAsync({
|
||||
|
||||
@ -2,9 +2,9 @@ export enum FeatureFlagKey {
|
||||
IS_AIRTABLE_INTEGRATION_ENABLED = 'IS_AIRTABLE_INTEGRATION_ENABLED',
|
||||
IS_POSTGRESQL_INTEGRATION_ENABLED = 'IS_POSTGRESQL_INTEGRATION_ENABLED',
|
||||
IS_STRIPE_INTEGRATION_ENABLED = 'IS_STRIPE_INTEGRATION_ENABLED',
|
||||
IS_COPILOT_ENABLED = 'IS_COPILOT_ENABLED',
|
||||
IS_WORKFLOW_ENABLED = 'IS_WORKFLOW_ENABLED',
|
||||
IS_UNIQUE_INDEXES_ENABLED = 'IS_UNIQUE_INDEXES_ENABLED',
|
||||
IS_JSON_FILTER_ENABLED = 'IS_JSON_FILTER_ENABLED',
|
||||
IS_PERMISSIONS_V2_ENABLED = 'IS_PERMISSIONS_V2_ENABLED',
|
||||
IS_AI_ENABLED = 'IS_AI_ENABLED',
|
||||
}
|
||||
|
||||
@ -122,12 +122,12 @@ describe('FeatureFlagService', () => {
|
||||
mockWorkspaceFeatureFlagsMapCacheService.getWorkspaceFeatureFlagsMap.mockResolvedValue(
|
||||
{
|
||||
[FeatureFlagKey.IS_WORKFLOW_ENABLED]: true,
|
||||
[FeatureFlagKey.IS_COPILOT_ENABLED]: false,
|
||||
[FeatureFlagKey.IS_AI_ENABLED]: false,
|
||||
},
|
||||
);
|
||||
const mockFeatureFlags = [
|
||||
{ key: FeatureFlagKey.IS_WORKFLOW_ENABLED, value: true },
|
||||
{ key: FeatureFlagKey.IS_COPILOT_ENABLED, value: false },
|
||||
{ key: FeatureFlagKey.IS_AI_ENABLED, value: false },
|
||||
];
|
||||
|
||||
// Act
|
||||
@ -146,7 +146,7 @@ describe('FeatureFlagService', () => {
|
||||
// Prepare
|
||||
const mockFeatureFlags = [
|
||||
{ key: FeatureFlagKey.IS_WORKFLOW_ENABLED, value: true, workspaceId },
|
||||
{ key: FeatureFlagKey.IS_COPILOT_ENABLED, value: false, workspaceId },
|
||||
{ key: FeatureFlagKey.IS_AI_ENABLED, value: false, workspaceId },
|
||||
];
|
||||
|
||||
mockFeatureFlagRepository.find.mockResolvedValue(mockFeatureFlags);
|
||||
@ -157,7 +157,7 @@ describe('FeatureFlagService', () => {
|
||||
// Assert
|
||||
expect(result).toEqual({
|
||||
[FeatureFlagKey.IS_WORKFLOW_ENABLED]: true,
|
||||
[FeatureFlagKey.IS_COPILOT_ENABLED]: false,
|
||||
[FeatureFlagKey.IS_AI_ENABLED]: false,
|
||||
});
|
||||
});
|
||||
});
|
||||
@ -167,7 +167,7 @@ describe('FeatureFlagService', () => {
|
||||
// Prepare
|
||||
const keys = [
|
||||
FeatureFlagKey.IS_WORKFLOW_ENABLED,
|
||||
FeatureFlagKey.IS_COPILOT_ENABLED,
|
||||
FeatureFlagKey.IS_AI_ENABLED,
|
||||
];
|
||||
|
||||
mockFeatureFlagRepository.upsert.mockResolvedValue({});
|
||||
|
||||
@ -1,5 +0,0 @@
|
||||
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
|
||||
|
||||
export interface LLMChatModelDriver {
|
||||
getJSONChatModel(): BaseChatModel;
|
||||
}
|
||||
@ -1,22 +0,0 @@
|
||||
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
|
||||
import { ChatOpenAI } from '@langchain/openai';
|
||||
|
||||
import { LLMChatModelDriver } from 'src/engine/core-modules/llm-chat-model/drivers/interfaces/llm-prompt-template-driver.interface';
|
||||
|
||||
export class OpenAIDriver implements LLMChatModelDriver {
|
||||
private chatModel: BaseChatModel;
|
||||
|
||||
constructor() {
|
||||
this.chatModel = new ChatOpenAI({
|
||||
model: 'gpt-4o',
|
||||
}).bind({
|
||||
response_format: {
|
||||
type: 'json_object',
|
||||
},
|
||||
}) as unknown as BaseChatModel;
|
||||
}
|
||||
|
||||
getJSONChatModel() {
|
||||
return this.chatModel;
|
||||
}
|
||||
}
|
||||
@ -1,15 +0,0 @@
|
||||
import { FactoryProvider, ModuleMetadata } from '@nestjs/common';
|
||||
|
||||
export enum LLMChatModelDriver {
|
||||
OPENAI = 'OPENAI',
|
||||
}
|
||||
|
||||
export interface LLMChatModelModuleOptions {
|
||||
type: LLMChatModelDriver;
|
||||
}
|
||||
|
||||
export type LLMChatModelModuleAsyncOptions = {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
useFactory: (...args: any[]) => LLMChatModelModuleOptions | undefined;
|
||||
} & Pick<ModuleMetadata, 'imports'> &
|
||||
Pick<FactoryProvider, 'inject'>;
|
||||
@ -1 +0,0 @@
|
||||
export const LLM_CHAT_MODEL_DRIVER = Symbol('LLM_CHAT_MODEL_DRIVER');
|
||||
@ -1,17 +0,0 @@
|
||||
import { LLMChatModelDriver } from 'src/engine/core-modules/llm-chat-model/interfaces/llm-chat-model.interface';
|
||||
|
||||
import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service';
|
||||
|
||||
export const llmChatModelModuleFactory = (
|
||||
twentyConfigService: TwentyConfigService,
|
||||
) => {
|
||||
const driver = twentyConfigService.get('LLM_CHAT_MODEL_DRIVER');
|
||||
|
||||
switch (driver) {
|
||||
case LLMChatModelDriver.OPENAI: {
|
||||
return { type: LLMChatModelDriver.OPENAI };
|
||||
}
|
||||
default:
|
||||
// `No LLM chat model driver (${driver})`);
|
||||
}
|
||||
};
|
||||
@ -1,36 +0,0 @@
|
||||
import { DynamicModule, Global } from '@nestjs/common';
|
||||
|
||||
import {
|
||||
LLMChatModelDriver,
|
||||
LLMChatModelModuleAsyncOptions,
|
||||
} from 'src/engine/core-modules/llm-chat-model/interfaces/llm-chat-model.interface';
|
||||
|
||||
import { OpenAIDriver } from 'src/engine/core-modules/llm-chat-model/drivers/openai.driver';
|
||||
import { LLM_CHAT_MODEL_DRIVER } from 'src/engine/core-modules/llm-chat-model/llm-chat-model.constants';
|
||||
import { LLMChatModelService } from 'src/engine/core-modules/llm-chat-model/llm-chat-model.service';
|
||||
|
||||
@Global()
|
||||
export class LLMChatModelModule {
|
||||
static forRoot(options: LLMChatModelModuleAsyncOptions): DynamicModule {
|
||||
const provider = {
|
||||
provide: LLM_CHAT_MODEL_DRIVER,
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
useFactory: (...args: any[]) => {
|
||||
const config = options.useFactory(...args);
|
||||
|
||||
switch (config?.type) {
|
||||
case LLMChatModelDriver.OPENAI: {
|
||||
return new OpenAIDriver();
|
||||
}
|
||||
}
|
||||
},
|
||||
inject: options.inject || [],
|
||||
};
|
||||
|
||||
return {
|
||||
module: LLMChatModelModule,
|
||||
providers: [LLMChatModelService, provider],
|
||||
exports: [LLMChatModelService],
|
||||
};
|
||||
}
|
||||
}
|
||||
@ -1,16 +0,0 @@
|
||||
import { Injectable, Inject } from '@nestjs/common';
|
||||
|
||||
import { LLMChatModelDriver } from 'src/engine/core-modules/llm-chat-model/drivers/interfaces/llm-prompt-template-driver.interface';
|
||||
|
||||
import { LLM_CHAT_MODEL_DRIVER } from 'src/engine/core-modules/llm-chat-model/llm-chat-model.constants';
|
||||
|
||||
@Injectable()
|
||||
export class LLMChatModelService {
|
||||
constructor(
|
||||
@Inject(LLM_CHAT_MODEL_DRIVER) private driver: LLMChatModelDriver,
|
||||
) {}
|
||||
|
||||
getJSONChatModel() {
|
||||
return this.driver.getJSONChatModel();
|
||||
}
|
||||
}
|
||||
@ -1,26 +0,0 @@
|
||||
/* eslint-disable no-console */
|
||||
import { BaseCallbackHandler } from '@langchain/core/callbacks/base';
|
||||
import { Run } from '@langchain/core/tracers/base';
|
||||
import { ConsoleCallbackHandler } from '@langchain/core/tracers/console';
|
||||
|
||||
import { LLMTracingDriver } from 'src/engine/core-modules/llm-tracing/drivers/interfaces/llm-tracing-driver.interface';
|
||||
|
||||
class WithMetadataConsoleCallbackHandler extends ConsoleCallbackHandler {
|
||||
private metadata: Record<string, unknown>;
|
||||
|
||||
constructor(metadata: Record<string, unknown>) {
|
||||
super();
|
||||
this.metadata = metadata;
|
||||
}
|
||||
|
||||
onChainStart(run: Run) {
|
||||
console.log(`Chain metadata: ${JSON.stringify(this.metadata)}`);
|
||||
super.onChainStart(run);
|
||||
}
|
||||
}
|
||||
|
||||
export class ConsoleDriver implements LLMTracingDriver {
|
||||
getCallbackHandler(metadata: Record<string, unknown>): BaseCallbackHandler {
|
||||
return new WithMetadataConsoleCallbackHandler(metadata);
|
||||
}
|
||||
}
|
||||
@ -1,5 +0,0 @@
|
||||
import { BaseCallbackHandler } from '@langchain/core/callbacks/base';
|
||||
|
||||
export interface LLMTracingDriver {
|
||||
getCallbackHandler(metadata: Record<string, unknown>): BaseCallbackHandler;
|
||||
}
|
||||
@ -1,26 +0,0 @@
|
||||
import { BaseCallbackHandler } from '@langchain/core/callbacks/base';
|
||||
import CallbackHandler from 'langfuse-langchain';
|
||||
|
||||
import { LLMTracingDriver } from 'src/engine/core-modules/llm-tracing/drivers/interfaces/llm-tracing-driver.interface';
|
||||
|
||||
export interface LangfuseDriverOptions {
|
||||
secretKey: string;
|
||||
publicKey: string;
|
||||
}
|
||||
|
||||
export class LangfuseDriver implements LLMTracingDriver {
|
||||
private options: LangfuseDriverOptions;
|
||||
|
||||
constructor(options: LangfuseDriverOptions) {
|
||||
this.options = options;
|
||||
}
|
||||
|
||||
getCallbackHandler(metadata: Record<string, unknown>): BaseCallbackHandler {
|
||||
return new CallbackHandler({
|
||||
secretKey: this.options.secretKey,
|
||||
publicKey: this.options.publicKey,
|
||||
baseUrl: 'https://cloud.langfuse.com',
|
||||
metadata: metadata,
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -1,27 +0,0 @@
|
||||
import { FactoryProvider, ModuleMetadata } from '@nestjs/common';
|
||||
|
||||
import { LangfuseDriverOptions } from 'src/engine/core-modules/llm-tracing/drivers/langfuse.driver';
|
||||
|
||||
export enum LLMTracingDriver {
|
||||
LANGFUSE = 'LANGFUSE',
|
||||
CONSOLE = 'CONSOLE',
|
||||
}
|
||||
|
||||
export interface LangfuseDriverFactoryOptions {
|
||||
type: LLMTracingDriver.LANGFUSE;
|
||||
options: LangfuseDriverOptions;
|
||||
}
|
||||
|
||||
export interface ConsoleDriverFactoryOptions {
|
||||
type: LLMTracingDriver.CONSOLE;
|
||||
}
|
||||
|
||||
export type LLMTracingModuleOptions =
|
||||
| LangfuseDriverFactoryOptions
|
||||
| ConsoleDriverFactoryOptions;
|
||||
|
||||
export type LLMTracingModuleAsyncOptions = {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
useFactory: (...args: any[]) => LLMTracingModuleOptions;
|
||||
} & Pick<ModuleMetadata, 'imports'> &
|
||||
Pick<FactoryProvider, 'inject'>;
|
||||
@ -1 +0,0 @@
|
||||
export const LLM_TRACING_DRIVER = Symbol('LLM_TRACING_DRIVER');
|
||||
@ -1,34 +0,0 @@
|
||||
import { LLMTracingDriver } from 'src/engine/core-modules/llm-tracing/interfaces/llm-tracing.interface';
|
||||
|
||||
import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service';
|
||||
|
||||
export const llmTracingModuleFactory = (
|
||||
twentyConfigService: TwentyConfigService,
|
||||
) => {
|
||||
const driver = twentyConfigService.get('LLM_TRACING_DRIVER');
|
||||
|
||||
switch (driver) {
|
||||
case LLMTracingDriver.CONSOLE: {
|
||||
return { type: LLMTracingDriver.CONSOLE as const };
|
||||
}
|
||||
case LLMTracingDriver.LANGFUSE: {
|
||||
const secretKey = twentyConfigService.get('LANGFUSE_SECRET_KEY');
|
||||
const publicKey = twentyConfigService.get('LANGFUSE_PUBLIC_KEY');
|
||||
|
||||
if (!(secretKey && publicKey)) {
|
||||
throw new Error(
|
||||
`${driver} LLM tracing driver requires LANGFUSE_SECRET_KEY and LANGFUSE_PUBLIC_KEY to be defined, check your .env file`,
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
type: LLMTracingDriver.LANGFUSE as const,
|
||||
options: { secretKey, publicKey },
|
||||
};
|
||||
}
|
||||
default:
|
||||
throw new Error(
|
||||
`Invalid LLM tracing driver (${driver}), check your .env file`,
|
||||
);
|
||||
}
|
||||
};
|
||||
@ -1,40 +0,0 @@
|
||||
import { DynamicModule, Global } from '@nestjs/common';
|
||||
|
||||
import {
|
||||
LLMTracingDriver,
|
||||
LLMTracingModuleAsyncOptions,
|
||||
} from 'src/engine/core-modules/llm-tracing/interfaces/llm-tracing.interface';
|
||||
|
||||
import { ConsoleDriver } from 'src/engine/core-modules/llm-tracing/drivers/console.driver';
|
||||
import { LangfuseDriver } from 'src/engine/core-modules/llm-tracing/drivers/langfuse.driver';
|
||||
import { LLM_TRACING_DRIVER } from 'src/engine/core-modules/llm-tracing/llm-tracing.constants';
|
||||
import { LLMTracingService } from 'src/engine/core-modules/llm-tracing/llm-tracing.service';
|
||||
|
||||
@Global()
|
||||
export class LLMTracingModule {
|
||||
static forRoot(options: LLMTracingModuleAsyncOptions): DynamicModule {
|
||||
const provider = {
|
||||
provide: LLM_TRACING_DRIVER,
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
useFactory: (...args: any[]) => {
|
||||
const config = options.useFactory(...args);
|
||||
|
||||
switch (config.type) {
|
||||
case LLMTracingDriver.LANGFUSE: {
|
||||
return new LangfuseDriver(config.options);
|
||||
}
|
||||
case LLMTracingDriver.CONSOLE: {
|
||||
return new ConsoleDriver();
|
||||
}
|
||||
}
|
||||
},
|
||||
inject: options.inject || [],
|
||||
};
|
||||
|
||||
return {
|
||||
module: LLMTracingModule,
|
||||
providers: [LLMTracingService, provider],
|
||||
exports: [LLMTracingService],
|
||||
};
|
||||
}
|
||||
}
|
||||
@ -1,16 +0,0 @@
|
||||
import { Injectable, Inject } from '@nestjs/common';
|
||||
|
||||
import { BaseCallbackHandler } from '@langchain/core/callbacks/base';
|
||||
|
||||
import { LLMTracingDriver } from 'src/engine/core-modules/llm-tracing/drivers/interfaces/llm-tracing-driver.interface';
|
||||
|
||||
import { LLM_TRACING_DRIVER } from 'src/engine/core-modules/llm-tracing/llm-tracing.constants';
|
||||
|
||||
@Injectable()
|
||||
export class LLMTracingService {
|
||||
constructor(@Inject(LLM_TRACING_DRIVER) private driver: LLMTracingDriver) {}
|
||||
|
||||
getCallbackHandler(metadata: Record<string, unknown>): BaseCallbackHandler {
|
||||
return this.driver.getCallbackHandler(metadata);
|
||||
}
|
||||
}
|
||||
@ -11,8 +11,7 @@ import {
|
||||
} from 'class-validator';
|
||||
import { isDefined } from 'twenty-shared/utils';
|
||||
|
||||
import { LLMChatModelDriver } from 'src/engine/core-modules/llm-chat-model/interfaces/llm-chat-model.interface';
|
||||
import { LLMTracingDriver } from 'src/engine/core-modules/llm-tracing/interfaces/llm-tracing.interface';
|
||||
import { AiDriver } from 'src/engine/core-modules/ai/interfaces/ai.interface';
|
||||
import { AwsRegion } from 'src/engine/core-modules/twenty-config/interfaces/aws-region.interface';
|
||||
import { NodeEnvironment } from 'src/engine/core-modules/twenty-config/interfaces/node-environment.interface';
|
||||
import { SupportDriver } from 'src/engine/core-modules/twenty-config/interfaces/support.interface';
|
||||
@ -952,13 +951,13 @@ export class ConfigVariables {
|
||||
|
||||
@ConfigVariablesMetadata({
|
||||
group: ConfigVariablesGroup.LLM,
|
||||
description: 'Driver for the LLM chat model',
|
||||
description: 'Driver for the AI chat model',
|
||||
type: ConfigVariableType.ENUM,
|
||||
options: Object.values(LLMChatModelDriver),
|
||||
options: Object.values(AiDriver),
|
||||
isEnvOnly: true,
|
||||
})
|
||||
@CastToUpperSnakeCase()
|
||||
LLM_CHAT_MODEL_DRIVER: LLMChatModelDriver;
|
||||
AI_DRIVER: AiDriver;
|
||||
|
||||
@ConfigVariablesMetadata({
|
||||
group: ConfigVariablesGroup.LLM,
|
||||
@ -968,31 +967,6 @@ export class ConfigVariables {
|
||||
})
|
||||
OPENAI_API_KEY: string;
|
||||
|
||||
@ConfigVariablesMetadata({
|
||||
group: ConfigVariablesGroup.LLM,
|
||||
isSensitive: true,
|
||||
description: 'Secret key for Langfuse integration',
|
||||
type: ConfigVariableType.STRING,
|
||||
})
|
||||
LANGFUSE_SECRET_KEY: string;
|
||||
|
||||
@ConfigVariablesMetadata({
|
||||
group: ConfigVariablesGroup.LLM,
|
||||
description: 'Public key for Langfuse integration',
|
||||
type: ConfigVariableType.STRING,
|
||||
})
|
||||
LANGFUSE_PUBLIC_KEY: string;
|
||||
|
||||
@ConfigVariablesMetadata({
|
||||
group: ConfigVariablesGroup.LLM,
|
||||
description: 'Driver for LLM tracing',
|
||||
type: ConfigVariableType.ENUM,
|
||||
options: Object.values(LLMTracingDriver),
|
||||
isEnvOnly: true,
|
||||
})
|
||||
@CastToUpperSnakeCase()
|
||||
LLM_TRACING_DRIVER: LLMTracingDriver = LLMTracingDriver.CONSOLE;
|
||||
|
||||
@ConfigVariablesMetadata({
|
||||
group: ConfigVariablesGroup.ServerConfig,
|
||||
description: 'Enable or disable multi-workspace support',
|
||||
|
||||
@ -39,6 +39,7 @@ if (process.env.EXCEPTION_HANDLER_DRIVER === ExceptionHandlerDriver.SENTRY) {
|
||||
Sentry.expressIntegration(),
|
||||
Sentry.graphqlIntegration(),
|
||||
Sentry.postgresIntegration(),
|
||||
Sentry.vercelAIIntegration(),
|
||||
nodeProfilingIntegration(),
|
||||
],
|
||||
tracesSampleRate: 0.1,
|
||||
|
||||
Reference in New Issue
Block a user