From 0a93468b9593f88daf367048cb8d7ec0a28ef498 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Malfait?= Date: Fri, 11 Jul 2025 13:09:54 +0200 Subject: [PATCH] Refresh AI model setup (#13171) Instead of initializing model at start time we do it at run time to be able to swap model provider more easily. Also introduce a third driver for openai-compatible providers, which among other allows for local models with Ollama --- .../src/generated-metadata/graphql.ts | 3 +- .../twenty-front/src/generated/graphql.ts | 3 +- .../ai-agent-action/hooks/useAgentChat.ts | 7 + .../utils/parseAgentStreamingChunk.ts | 10 +- .../engine/core-modules/ai/ai.constants.ts | 1 - .../core-modules/ai/ai.module-factory.ts | 17 -- .../src/engine/core-modules/ai/ai.module.ts | 91 ++++----- .../ai/constants/ai-models.const.spec.ts | 141 ++++++++++--- .../ai/constants/ai-models.const.ts | 6 +- .../drivers/interfaces/ai-driver.interface.ts | 8 - .../core-modules/ai/drivers/openai.driver.ts | 18 -- .../ai/interfaces/ai.interface.ts | 14 -- .../ai/services/ai-billing.service.spec.ts | 15 ++ .../ai/services/ai-billing.service.ts | 9 +- .../ai/services/ai-model-registry.service.ts | 185 ++++++++++++++++++ .../core-modules/ai/services/ai.service.ts | 44 ++++- .../ai/utils/get-ai-model-by-id.util.ts | 7 - .../get-ai-models-with-auto.util.spec.ts | 35 ---- .../ai/utils/get-ai-models-with-auto.util.ts | 22 --- .../get-default-model-config.util.spec.ts | 29 --- .../ai/utils/get-default-model-config.util.ts | 17 -- .../get-effective-model-config.util.spec.ts | 30 --- .../utils/get-effective-model-config.util.ts | 20 -- .../services/access-token.service.spec.ts | 5 - .../core-modules/auth/token/token.module.ts | 2 - .../services/client-config.service.spec.ts | 14 +- .../services/client-config.service.ts | 63 +++--- .../engine/core-modules/core-engine.module.ts | 6 +- .../twenty-config/config-variables.ts | 42 +++- .../agent/agent-execution.service.ts | 101 ++++++---- .../agent/agent-streaming.service.ts | 12 +- .../ai-agent/ai-agent-action.module.ts | 6 +- 32 files changed, 566 insertions(+), 417 deletions(-) delete mode 100644 packages/twenty-server/src/engine/core-modules/ai/ai.constants.ts delete mode 100644 packages/twenty-server/src/engine/core-modules/ai/ai.module-factory.ts delete mode 100644 packages/twenty-server/src/engine/core-modules/ai/drivers/interfaces/ai-driver.interface.ts delete mode 100644 packages/twenty-server/src/engine/core-modules/ai/drivers/openai.driver.ts delete mode 100644 packages/twenty-server/src/engine/core-modules/ai/interfaces/ai.interface.ts create mode 100644 packages/twenty-server/src/engine/core-modules/ai/services/ai-model-registry.service.ts delete mode 100644 packages/twenty-server/src/engine/core-modules/ai/utils/get-ai-model-by-id.util.ts delete mode 100644 packages/twenty-server/src/engine/core-modules/ai/utils/get-ai-models-with-auto.util.spec.ts delete mode 100644 packages/twenty-server/src/engine/core-modules/ai/utils/get-ai-models-with-auto.util.ts delete mode 100644 packages/twenty-server/src/engine/core-modules/ai/utils/get-default-model-config.util.spec.ts delete mode 100644 packages/twenty-server/src/engine/core-modules/ai/utils/get-default-model-config.util.ts delete mode 100644 packages/twenty-server/src/engine/core-modules/ai/utils/get-effective-model-config.util.spec.ts delete mode 100644 packages/twenty-server/src/engine/core-modules/ai/utils/get-effective-model-config.util.ts diff --git a/packages/twenty-front/src/generated-metadata/graphql.ts b/packages/twenty-front/src/generated-metadata/graphql.ts index 299e3f8f5..3c76b4a93 100644 --- a/packages/twenty-front/src/generated-metadata/graphql.ts +++ b/packages/twenty-front/src/generated-metadata/graphql.ts @@ -1041,7 +1041,8 @@ export enum MessageChannelVisibility { export enum ModelProvider { ANTHROPIC = 'ANTHROPIC', NONE = 'NONE', - OPENAI = 'OPENAI' + OPENAI = 'OPENAI', + OPENAI_COMPATIBLE = 'OPENAI_COMPATIBLE' } export type Mutation = { diff --git a/packages/twenty-front/src/generated/graphql.ts b/packages/twenty-front/src/generated/graphql.ts index 8bcf6aa33..33cae512d 100644 --- a/packages/twenty-front/src/generated/graphql.ts +++ b/packages/twenty-front/src/generated/graphql.ts @@ -998,7 +998,8 @@ export enum MessageChannelVisibility { export enum ModelProvider { ANTHROPIC = 'ANTHROPIC', NONE = 'NONE', - OPENAI = 'OPENAI' + OPENAI = 'OPENAI', + OPENAI_COMPATIBLE = 'OPENAI_COMPATIBLE' } export type Mutation = { diff --git a/packages/twenty-front/src/modules/workflow/workflow-steps/workflow-actions/ai-agent-action/hooks/useAgentChat.ts b/packages/twenty-front/src/modules/workflow/workflow-steps/workflow-actions/ai-agent-action/hooks/useAgentChat.ts index 11b48fd80..d692befeb 100644 --- a/packages/twenty-front/src/modules/workflow/workflow-steps/workflow-actions/ai-agent-action/hooks/useAgentChat.ts +++ b/packages/twenty-front/src/modules/workflow/workflow-steps/workflow-actions/ai-agent-action/hooks/useAgentChat.ts @@ -1,3 +1,4 @@ +import { useSnackBar } from '@/ui/feedback/snack-bar-manager/hooks/useSnackBar'; import { useRecoilComponentStateV2 } from '@/ui/utilities/state/component-state/hooks/useRecoilComponentStateV2'; import { useState } from 'react'; import { useRecoilState } from 'recoil'; @@ -22,6 +23,7 @@ interface OptimisticMessage extends AgentChatMessage { export const useAgentChat = (agentId: string) => { const apolloClient = useApolloClient(); + const { enqueueErrorSnackBar } = useSnackBar(); const [agentChatMessages, setAgentChatMessages] = useRecoilComponentStateV2( agentChatMessagesComponentState, @@ -112,6 +114,11 @@ export const useAgentChat = (agentId: string) => { })); scrollToBottom(); }, + onError: (message: string) => { + enqueueErrorSnackBar({ + message, + }); + }, }); }, }, diff --git a/packages/twenty-front/src/modules/workflow/workflow-steps/workflow-actions/ai-agent-action/utils/parseAgentStreamingChunk.ts b/packages/twenty-front/src/modules/workflow/workflow-steps/workflow-actions/ai-agent-action/utils/parseAgentStreamingChunk.ts index ebad33061..a0cb4006c 100644 --- a/packages/twenty-front/src/modules/workflow/workflow-steps/workflow-actions/ai-agent-action/utils/parseAgentStreamingChunk.ts +++ b/packages/twenty-front/src/modules/workflow/workflow-steps/workflow-actions/ai-agent-action/utils/parseAgentStreamingChunk.ts @@ -1,12 +1,13 @@ export type AgentStreamingEvent = { - type: 'text-delta' | 'tool-call'; + type: 'text-delta' | 'tool-call' | 'error'; message: string; }; export type AgentStreamingParserCallbacks = { onTextDelta?: (message: string) => void; onToolCall?: (message: string) => void; - onError?: (error: Error, rawLine: string) => void; + onError?: (message: string) => void; + onParseError?: (error: Error, rawLine: string) => void; }; export const parseAgentStreamingChunk = ( @@ -27,6 +28,9 @@ export const parseAgentStreamingChunk = ( case 'tool-call': callbacks.onToolCall?.(event.message); break; + case 'error': + callbacks.onError?.(event.message); + break; } } catch (error) { // eslint-disable-next-line no-console @@ -36,7 +40,7 @@ export const parseAgentStreamingChunk = ( error instanceof Error ? error : new Error(`Unknown parsing error: ${String(error)}`); - callbacks.onError?.(errorMessage, line); + callbacks.onParseError?.(errorMessage, line); } } } diff --git a/packages/twenty-server/src/engine/core-modules/ai/ai.constants.ts b/packages/twenty-server/src/engine/core-modules/ai/ai.constants.ts deleted file mode 100644 index 8f94b9e9c..000000000 --- a/packages/twenty-server/src/engine/core-modules/ai/ai.constants.ts +++ /dev/null @@ -1 +0,0 @@ -export const AI_DRIVER = Symbol('AI_DRIVER'); diff --git a/packages/twenty-server/src/engine/core-modules/ai/ai.module-factory.ts b/packages/twenty-server/src/engine/core-modules/ai/ai.module-factory.ts deleted file mode 100644 index 4ae585674..000000000 --- a/packages/twenty-server/src/engine/core-modules/ai/ai.module-factory.ts +++ /dev/null @@ -1,17 +0,0 @@ -import { AiDriver } from 'src/engine/core-modules/ai/interfaces/ai.interface'; - -import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service'; - -export const aiModuleFactory = (twentyConfigService: TwentyConfigService) => { - const driver = twentyConfigService.get('AI_DRIVER'); - - switch (driver) { - case AiDriver.OPENAI: { - return { type: AiDriver.OPENAI }; - } - default: { - // Default to OpenAI driver if no driver is specified - return { type: AiDriver.OPENAI }; - } - } -}; diff --git a/packages/twenty-server/src/engine/core-modules/ai/ai.module.ts b/packages/twenty-server/src/engine/core-modules/ai/ai.module.ts index 980969bb4..2cd845d5a 100644 --- a/packages/twenty-server/src/engine/core-modules/ai/ai.module.ts +++ b/packages/twenty-server/src/engine/core-modules/ai/ai.module.ts @@ -1,65 +1,46 @@ -import { DynamicModule, Global, Provider } from '@nestjs/common'; +import { Global, Module } from '@nestjs/common'; import { TypeOrmModule } from '@nestjs/typeorm'; -import { - AiDriver, - AiModuleAsyncOptions, -} from 'src/engine/core-modules/ai/interfaces/ai.interface'; - -import { AI_DRIVER } from 'src/engine/core-modules/ai/ai.constants'; -import { AiService } from 'src/engine/core-modules/ai/services/ai.service'; import { AiController } from 'src/engine/core-modules/ai/controllers/ai.controller'; -import { OpenAIDriver } from 'src/engine/core-modules/ai/drivers/openai.driver'; -import { AIBillingService } from 'src/engine/core-modules/ai/services/ai-billing.service'; -import { FeatureFlagModule } from 'src/engine/core-modules/feature-flag/feature-flag.module'; import { McpController } from 'src/engine/core-modules/ai/controllers/mcp.controller'; -import { AuthModule } from 'src/engine/core-modules/auth/auth.module'; -import { ToolService } from 'src/engine/core-modules/ai/services/tool.service'; +import { AIBillingService } from 'src/engine/core-modules/ai/services/ai-billing.service'; +import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service'; +import { AiService } from 'src/engine/core-modules/ai/services/ai.service'; import { McpService } from 'src/engine/core-modules/ai/services/mcp.service'; +import { ToolService } from 'src/engine/core-modules/ai/services/tool.service'; +import { TokenModule } from 'src/engine/core-modules/auth/token/token.module'; +import { FeatureFlagModule } from 'src/engine/core-modules/feature-flag/feature-flag.module'; import { ObjectMetadataModule } from 'src/engine/metadata-modules/object-metadata/object-metadata.module'; +import { RoleEntity } from 'src/engine/metadata-modules/role/role.entity'; +import { UserRoleModule } from 'src/engine/metadata-modules/user-role/user-role.module'; import { WorkspacePermissionsCacheModule } from 'src/engine/metadata-modules/workspace-permissions-cache/workspace-permissions-cache.module'; import { WorkspaceCacheStorageModule } from 'src/engine/workspace-cache-storage/workspace-cache-storage.module'; -import { UserRoleModule } from 'src/engine/metadata-modules/user-role/user-role.module'; -import { RoleEntity } from 'src/engine/metadata-modules/role/role.entity'; @Global() -export class AiModule { - static forRoot(options: AiModuleAsyncOptions): DynamicModule { - const provider: Provider = { - provide: AI_DRIVER, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - useFactory: (...args: any[]) => { - const config = options.useFactory(...args); - - switch (config?.type) { - case AiDriver.OPENAI: { - return new OpenAIDriver(); - } - } - }, - inject: options.inject || [], - }; - - return { - module: AiModule, - imports: [ - TypeOrmModule.forFeature([RoleEntity], 'core'), - FeatureFlagModule, - ObjectMetadataModule, - WorkspacePermissionsCacheModule, - WorkspaceCacheStorageModule, - UserRoleModule, - AuthModule, - ], - controllers: [AiController, McpController], - providers: [ - AiService, - ToolService, - AIBillingService, - McpService, - provider, - ], - exports: [AiService, AIBillingService, ToolService, McpService], - }; - } -} +@Module({ + imports: [ + TypeOrmModule.forFeature([RoleEntity], 'core'), + TokenModule, + FeatureFlagModule, + ObjectMetadataModule, + WorkspacePermissionsCacheModule, + WorkspaceCacheStorageModule, + UserRoleModule, + ], + controllers: [AiController, McpController], + providers: [ + AiService, + AiModelRegistryService, + ToolService, + AIBillingService, + McpService, + ], + exports: [ + AiService, + AiModelRegistryService, + AIBillingService, + ToolService, + McpService, + ], +}) +export class AiModule {} diff --git a/packages/twenty-server/src/engine/core-modules/ai/constants/ai-models.const.spec.ts b/packages/twenty-server/src/engine/core-modules/ai/constants/ai-models.const.spec.ts index 088b894e1..0d6e920d3 100644 --- a/packages/twenty-server/src/engine/core-modules/ai/constants/ai-models.const.spec.ts +++ b/packages/twenty-server/src/engine/core-modules/ai/constants/ai-models.const.spec.ts @@ -1,7 +1,9 @@ -import { getAIModelsWithAuto } from 'src/engine/core-modules/ai/utils/get-ai-models-with-auto.util'; -import { getDefaultModelConfig } from 'src/engine/core-modules/ai/utils/get-default-model-config.util'; +import { Test, TestingModule } from '@nestjs/testing'; -import { AI_MODELS, DEFAULT_MODEL_ID, ModelProvider } from './ai-models.const'; +import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service'; +import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service'; + +import { AI_MODELS, ModelProvider } from './ai-models.const'; describe('AI_MODELS', () => { it('should contain all expected models', () => { @@ -15,41 +17,120 @@ describe('AI_MODELS', () => { 'claude-3-5-haiku-20241022', ]); }); - - it('should have the default model as the first model', () => { - const DEFAULT_MODEL = AI_MODELS.find( - (model) => model.modelId === DEFAULT_MODEL_ID, - ); - - expect(DEFAULT_MODEL).toBeDefined(); - expect(DEFAULT_MODEL?.modelId).toBe(DEFAULT_MODEL_ID); - }); }); -describe('getAIModelsWithAuto', () => { - it('should return AI_MODELS with auto model prepended', () => { - const ORIGINAL_MODELS = AI_MODELS; - const MODELS_WITH_AUTO = getAIModelsWithAuto(); +describe('AiModelRegistryService', () => { + let SERVICE: AiModelRegistryService; + let MOCK_CONFIG_SERVICE: jest.Mocked; - expect(MODELS_WITH_AUTO).toHaveLength(ORIGINAL_MODELS.length + 1); - expect(MODELS_WITH_AUTO[0].modelId).toBe('auto'); - expect(MODELS_WITH_AUTO[0].label).toBe('Auto'); - expect(MODELS_WITH_AUTO[0].provider).toBe(ModelProvider.NONE); + beforeEach(async () => { + MOCK_CONFIG_SERVICE = { + get: jest.fn(), + } as any; - // Check that the rest of the models are the same - expect(MODELS_WITH_AUTO.slice(1)).toEqual(ORIGINAL_MODELS); + const MODULE: TestingModule = await Test.createTestingModule({ + providers: [ + AiModelRegistryService, + { + provide: TwentyConfigService, + useValue: MOCK_CONFIG_SERVICE, + }, + ], + }).compile(); + + SERVICE = MODULE.get(AiModelRegistryService); }); - it('should have auto model with default model costs', () => { - const MODELS_WITH_AUTO = getAIModelsWithAuto(); - const AUTO_MODEL = MODELS_WITH_AUTO[0]; - const DEFAULT_MODEL = getDefaultModelConfig(); + it('should return effective model config for auto', () => { + MOCK_CONFIG_SERVICE.get.mockReturnValue('gpt-4o'); - expect(AUTO_MODEL.inputCostPer1kTokensInCents).toBe( - DEFAULT_MODEL.inputCostPer1kTokensInCents, + expect(() => SERVICE.getEffectiveModelConfig('auto')).toThrow( + 'No AI models are available. Please configure at least one provider.', ); - expect(AUTO_MODEL.outputCostPer1kTokensInCents).toBe( - DEFAULT_MODEL.outputCostPer1kTokensInCents, + }); + + it('should return effective model config for auto when models are available', () => { + MOCK_CONFIG_SERVICE.get.mockReturnValue('gpt-4o'); + + jest.spyOn(SERVICE, 'getAvailableModels').mockReturnValue([ + { + modelId: 'gpt-4o', + provider: ModelProvider.OPENAI, + model: {} as any, + }, + ]); + + jest.spyOn(SERVICE, 'getModel').mockReturnValue({ + modelId: 'gpt-4o', + provider: ModelProvider.OPENAI, + model: {} as any, + }); + + const RESULT = SERVICE.getEffectiveModelConfig('auto'); + + expect(RESULT).toBeDefined(); + expect(RESULT.modelId).toBe('gpt-4o'); + expect(RESULT.provider).toBe(ModelProvider.OPENAI); + }); + + it('should return effective model config for auto with custom model', () => { + MOCK_CONFIG_SERVICE.get.mockReturnValue('mistral'); + + jest.spyOn(SERVICE, 'getAvailableModels').mockReturnValue([ + { + modelId: 'mistral', + provider: ModelProvider.OPENAI_COMPATIBLE, + model: {} as any, + }, + ]); + + jest.spyOn(SERVICE, 'getModel').mockReturnValue({ + modelId: 'mistral', + provider: ModelProvider.OPENAI_COMPATIBLE, + model: {} as any, + }); + + const RESULT = SERVICE.getEffectiveModelConfig('auto'); + + expect(RESULT).toBeDefined(); + expect(RESULT.modelId).toBe('mistral'); + expect(RESULT.provider).toBe(ModelProvider.OPENAI_COMPATIBLE); + expect(RESULT.label).toBe('mistral'); + expect(RESULT.inputCostPer1kTokensInCents).toBe(0); + expect(RESULT.outputCostPer1kTokensInCents).toBe(0); + }); + + it('should return effective model config for specific model', () => { + const RESULT = SERVICE.getEffectiveModelConfig('gpt-4o-mini'); + + expect(RESULT).toBeDefined(); + expect(RESULT.modelId).toBe('gpt-4o-mini'); + expect(RESULT.provider).toBe(ModelProvider.OPENAI); + }); + + it('should return effective model config for custom model', () => { + // Mock that the custom model exists in registry + jest.spyOn(SERVICE, 'getModel').mockReturnValue({ + modelId: 'mistral', + provider: ModelProvider.OPENAI_COMPATIBLE, + model: {} as any, + }); + + const RESULT = SERVICE.getEffectiveModelConfig('mistral'); + + expect(RESULT).toBeDefined(); + expect(RESULT.modelId).toBe('mistral'); + expect(RESULT.provider).toBe(ModelProvider.OPENAI_COMPATIBLE); + expect(RESULT.label).toBe('mistral'); + expect(RESULT.inputCostPer1kTokensInCents).toBe(0); + expect(RESULT.outputCostPer1kTokensInCents).toBe(0); + }); + + it('should throw error for non-existent model', () => { + jest.spyOn(SERVICE, 'getModel').mockReturnValue(undefined); + + expect(() => SERVICE.getEffectiveModelConfig('non-existent-model')).toThrow( + 'Model with ID non-existent-model not found', ); }); }); diff --git a/packages/twenty-server/src/engine/core-modules/ai/constants/ai-models.const.ts b/packages/twenty-server/src/engine/core-modules/ai/constants/ai-models.const.ts index 1cc29639f..60d47ecd9 100644 --- a/packages/twenty-server/src/engine/core-modules/ai/constants/ai-models.const.ts +++ b/packages/twenty-server/src/engine/core-modules/ai/constants/ai-models.const.ts @@ -2,6 +2,7 @@ export enum ModelProvider { NONE = 'none', OPENAI = 'openai', ANTHROPIC = 'anthropic', + OPENAI_COMPATIBLE = 'open_ai_compatible', } export type ModelId = @@ -11,9 +12,8 @@ export type ModelId = | 'gpt-4-turbo' | 'claude-opus-4-20250514' | 'claude-sonnet-4-20250514' - | 'claude-3-5-haiku-20241022'; - -export const DEFAULT_MODEL_ID: ModelId = 'gpt-4o'; + | 'claude-3-5-haiku-20241022' + | string; // Allow custom model names export interface AIModelConfig { modelId: ModelId; diff --git a/packages/twenty-server/src/engine/core-modules/ai/drivers/interfaces/ai-driver.interface.ts b/packages/twenty-server/src/engine/core-modules/ai/drivers/interfaces/ai-driver.interface.ts deleted file mode 100644 index c3d578a9a..000000000 --- a/packages/twenty-server/src/engine/core-modules/ai/drivers/interfaces/ai-driver.interface.ts +++ /dev/null @@ -1,8 +0,0 @@ -import { CoreMessage, StreamTextResult } from 'ai'; - -export interface AiDriver { - streamText( - messages: CoreMessage[], - options?: { temperature?: number; maxTokens?: number }, - ): StreamTextResult, undefined>; -} diff --git a/packages/twenty-server/src/engine/core-modules/ai/drivers/openai.driver.ts b/packages/twenty-server/src/engine/core-modules/ai/drivers/openai.driver.ts deleted file mode 100644 index fe0da7e12..000000000 --- a/packages/twenty-server/src/engine/core-modules/ai/drivers/openai.driver.ts +++ /dev/null @@ -1,18 +0,0 @@ -import { openai } from '@ai-sdk/openai'; -import { CoreMessage, StreamTextResult, streamText } from 'ai'; - -import { AiDriver } from 'src/engine/core-modules/ai/drivers/interfaces/ai-driver.interface'; - -export class OpenAIDriver implements AiDriver { - streamText( - messages: CoreMessage[], - options?: { temperature?: number; maxTokens?: number }, - ): StreamTextResult, undefined> { - return streamText({ - model: openai('gpt-4o-mini'), - messages, - temperature: options?.temperature, - maxTokens: options?.maxTokens, - }); - } -} diff --git a/packages/twenty-server/src/engine/core-modules/ai/interfaces/ai.interface.ts b/packages/twenty-server/src/engine/core-modules/ai/interfaces/ai.interface.ts deleted file mode 100644 index 060628180..000000000 --- a/packages/twenty-server/src/engine/core-modules/ai/interfaces/ai.interface.ts +++ /dev/null @@ -1,14 +0,0 @@ -import { InjectionToken, OptionalFactoryDependency } from '@nestjs/common'; - -export enum AiDriver { - OPENAI = 'openai', -} - -export interface AiModuleOptions { - type: AiDriver; -} - -export type AiModuleAsyncOptions = { - inject?: (InjectionToken | OptionalFactoryDependency)[]; - useFactory: (...args: unknown[]) => AiModuleOptions | undefined; -}; diff --git a/packages/twenty-server/src/engine/core-modules/ai/services/ai-billing.service.spec.ts b/packages/twenty-server/src/engine/core-modules/ai/services/ai-billing.service.spec.ts index eeed3b492..ebf868902 100644 --- a/packages/twenty-server/src/engine/core-modules/ai/services/ai-billing.service.spec.ts +++ b/packages/twenty-server/src/engine/core-modules/ai/services/ai-billing.service.spec.ts @@ -5,6 +5,7 @@ import { BillingMeterEventName } from 'src/engine/core-modules/billing/enums/bil import { WorkspaceEventEmitter } from 'src/engine/workspace-event-emitter/workspace-event-emitter'; import { AIBillingService } from './ai-billing.service'; +import { AiModelRegistryService } from './ai-model-registry.service'; describe('AIBillingService', () => { let service: AIBillingService; @@ -21,6 +22,16 @@ describe('AIBillingService', () => { emitCustomBatchEvent: jest.fn(), }; + const mockAiModelRegistryMethods = { + getEffectiveModelConfig: jest.fn().mockReturnValue({ + modelId: 'gpt-4o', + label: 'GPT-4o', + provider: 'openai', + inputCostPer1kTokensInCents: 0.25, + outputCostPer1kTokensInCents: 1.0, + }), + }; + const module: TestingModule = await Test.createTestingModule({ providers: [ AIBillingService, @@ -28,6 +39,10 @@ describe('AIBillingService', () => { provide: WorkspaceEventEmitter, useValue: mockEventEmitterMethods, }, + { + provide: AiModelRegistryService, + useValue: mockAiModelRegistryMethods, + }, ], }).compile(); diff --git a/packages/twenty-server/src/engine/core-modules/ai/services/ai-billing.service.ts b/packages/twenty-server/src/engine/core-modules/ai/services/ai-billing.service.ts index 6aff7756e..0372a3c29 100644 --- a/packages/twenty-server/src/engine/core-modules/ai/services/ai-billing.service.ts +++ b/packages/twenty-server/src/engine/core-modules/ai/services/ai-billing.service.ts @@ -2,7 +2,7 @@ import { Injectable, Logger } from '@nestjs/common'; import { ModelId } from 'src/engine/core-modules/ai/constants/ai-models.const'; import { DOLLAR_TO_CREDIT_MULTIPLIER } from 'src/engine/core-modules/ai/constants/dollar-to-credit-multiplier'; -import { getAIModelById } from 'src/engine/core-modules/ai/utils/get-ai-model-by-id.util'; +import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service'; import { BILLING_FEATURE_USED } from 'src/engine/core-modules/billing/constants/billing-feature-used.constant'; import { BillingMeterEventName } from 'src/engine/core-modules/billing/enums/billing-meter-event-names'; import { BillingUsageEvent } from 'src/engine/core-modules/billing/types/billing-usage-event.type'; @@ -18,10 +18,13 @@ export interface TokenUsage { export class AIBillingService { private readonly logger = new Logger(AIBillingService.name); - constructor(private readonly workspaceEventEmitter: WorkspaceEventEmitter) {} + constructor( + private readonly workspaceEventEmitter: WorkspaceEventEmitter, + private readonly aiModelRegistryService: AiModelRegistryService, + ) {} async calculateCost(modelId: ModelId, usage: TokenUsage): Promise { - const model = getAIModelById(modelId); + const model = this.aiModelRegistryService.getEffectiveModelConfig(modelId); if (!model) { throw new Error(`AI model with id ${modelId} not found`); diff --git a/packages/twenty-server/src/engine/core-modules/ai/services/ai-model-registry.service.ts b/packages/twenty-server/src/engine/core-modules/ai/services/ai-model-registry.service.ts new file mode 100644 index 000000000..dc5e7a464 --- /dev/null +++ b/packages/twenty-server/src/engine/core-modules/ai/services/ai-model-registry.service.ts @@ -0,0 +1,185 @@ +import { Injectable } from '@nestjs/common'; + +import { anthropic } from '@ai-sdk/anthropic'; +import { createOpenAI, openai } from '@ai-sdk/openai'; +import { LanguageModel } from 'ai'; + +import { + AI_MODELS, + AIModelConfig, + ModelProvider, +} from 'src/engine/core-modules/ai/constants/ai-models.const'; +import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service'; + +export interface RegisteredAIModel { + modelId: string; + provider: ModelProvider; + model: LanguageModel; +} + +@Injectable() +export class AiModelRegistryService { + private modelRegistry: Map = new Map(); + + constructor(private twentyConfigService: TwentyConfigService) { + this.buildModelRegistry(); + } + + private buildModelRegistry(): void { + this.modelRegistry.clear(); + + const openaiApiKey = this.twentyConfigService.get('OPENAI_API_KEY'); + + if (openaiApiKey) { + this.registerOpenAIModels(); + } + + const anthropicApiKey = this.twentyConfigService.get('ANTHROPIC_API_KEY'); + + if (anthropicApiKey) { + this.registerAnthropicModels(); + } + + const openaiCompatibleBaseUrl = this.twentyConfigService.get( + 'OPENAI_COMPATIBLE_BASE_URL', + ); + const openaiCompatibleModelNames = this.twentyConfigService.get( + 'OPENAI_COMPATIBLE_MODEL_NAMES', + ); + + if (openaiCompatibleBaseUrl && openaiCompatibleModelNames) { + this.registerOpenAICompatibleModels( + openaiCompatibleBaseUrl, + openaiCompatibleModelNames, + ); + } + } + + private registerOpenAIModels(): void { + const openaiModels = AI_MODELS.filter( + (model) => model.provider === ModelProvider.OPENAI, + ); + + openaiModels.forEach((modelConfig) => { + this.modelRegistry.set(modelConfig.modelId, { + modelId: modelConfig.modelId, + provider: ModelProvider.OPENAI, + model: openai(modelConfig.modelId), + }); + }); + } + + private registerAnthropicModels(): void { + const anthropicModels = AI_MODELS.filter( + (model) => model.provider === ModelProvider.ANTHROPIC, + ); + + anthropicModels.forEach((modelConfig) => { + this.modelRegistry.set(modelConfig.modelId, { + modelId: modelConfig.modelId, + provider: ModelProvider.ANTHROPIC, + model: anthropic(modelConfig.modelId), + }); + }); + } + + private registerOpenAICompatibleModels( + baseUrl: string, + modelNamesString: string, + ): void { + const apiKey = this.twentyConfigService.get('OPENAI_COMPATIBLE_API_KEY'); + const provider = createOpenAI({ + baseURL: baseUrl, + apiKey: apiKey, + }); + + const modelNames = modelNamesString + .split(',') + .map((name) => name.trim()) + .filter((name) => name.length > 0); + + modelNames.forEach((modelId) => { + this.modelRegistry.set(modelId, { + modelId, + provider: ModelProvider.OPENAI_COMPATIBLE, + model: provider(modelId), + }); + }); + } + + getModel(modelId: string): RegisteredAIModel | undefined { + return this.modelRegistry.get(modelId); + } + + getAvailableModels(): RegisteredAIModel[] { + return Array.from(this.modelRegistry.values()); + } + + getDefaultModel(): RegisteredAIModel | undefined { + const defaultModelId = this.twentyConfigService.get('DEFAULT_MODEL_ID'); + let model = this.getModel(defaultModelId); + + if (!model) { + const availableModels = this.getAvailableModels(); + + model = availableModels[0]; + } + + return model; + } + + getEffectiveModelConfig(modelId: string): AIModelConfig { + if (modelId === 'auto') { + const defaultModel = this.getDefaultModel(); + + if (!defaultModel) { + throw new Error( + 'No AI models are available. Please configure at least one provider.', + ); + } + + const modelConfig = AI_MODELS.find( + (model) => model.modelId === defaultModel.modelId, + ); + + if (modelConfig) { + return modelConfig; + } + + return this.createDefaultConfigForCustomModel(defaultModel); + } + + const predefinedModel = AI_MODELS.find( + (model) => model.modelId === modelId, + ); + + if (predefinedModel) { + return predefinedModel; + } + + const registeredModel = this.getModel(modelId); + + if (registeredModel) { + return this.createDefaultConfigForCustomModel(registeredModel); + } + + throw new Error(`Model with ID ${modelId} not found`); + } + + private createDefaultConfigForCustomModel( + registeredModel: RegisteredAIModel, + ): AIModelConfig { + return { + modelId: registeredModel.modelId, + label: registeredModel.modelId, + provider: registeredModel.provider, + inputCostPer1kTokensInCents: 0, + outputCostPer1kTokensInCents: 0, + }; + } + + // Force refresh the registry (useful if config changes) + refreshRegistry(): void { + this.buildModelRegistry(); + } +} diff --git a/packages/twenty-server/src/engine/core-modules/ai/services/ai.service.ts b/packages/twenty-server/src/engine/core-modules/ai/services/ai.service.ts index 771b10d01..7c3316e71 100644 --- a/packages/twenty-server/src/engine/core-modules/ai/services/ai.service.ts +++ b/packages/twenty-server/src/engine/core-modules/ai/services/ai.service.ts @@ -1,19 +1,47 @@ -import { Inject, Injectable } from '@nestjs/common'; +import { Injectable } from '@nestjs/common'; -import { CoreMessage, StreamTextResult } from 'ai'; +import { CoreMessage, StreamTextResult, streamText } from 'ai'; -import { AiDriver } from 'src/engine/core-modules/ai/drivers/interfaces/ai-driver.interface'; - -import { AI_DRIVER } from 'src/engine/core-modules/ai/ai.constants'; +import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service'; @Injectable() export class AiService { - constructor(@Inject(AI_DRIVER) private driver: AiDriver) {} + constructor(private aiModelRegistryService: AiModelRegistryService) {} streamText( messages: CoreMessage[], - options?: { temperature?: number; maxTokens?: number }, + options?: { + temperature?: number; + maxTokens?: number; + modelId?: string; // Optional model override + }, ): StreamTextResult, undefined> { - return this.driver.streamText(messages, options); + const modelId = options?.modelId; + const registeredModel = modelId + ? this.aiModelRegistryService.getModel(modelId) + : this.aiModelRegistryService.getDefaultModel(); + + if (!registeredModel) { + throw new Error( + modelId + ? `Model "${modelId}" is not available. Please check your configuration.` + : 'No AI models are available. Please configure at least one provider.', + ); + } + + return streamText({ + model: registeredModel.model, + messages, + temperature: options?.temperature, + maxTokens: options?.maxTokens, + }); + } + + getAvailableModels() { + return this.aiModelRegistryService.getAvailableModels(); + } + + getDefaultModel() { + return this.aiModelRegistryService.getDefaultModel(); } } diff --git a/packages/twenty-server/src/engine/core-modules/ai/utils/get-ai-model-by-id.util.ts b/packages/twenty-server/src/engine/core-modules/ai/utils/get-ai-model-by-id.util.ts deleted file mode 100644 index aeb22df97..000000000 --- a/packages/twenty-server/src/engine/core-modules/ai/utils/get-ai-model-by-id.util.ts +++ /dev/null @@ -1,7 +0,0 @@ -import { AIModelConfig } from 'src/engine/core-modules/ai/constants/ai-models.const'; - -import { getEffectiveModelConfig } from './get-effective-model-config.util'; - -export const getAIModelById = (modelId: string): AIModelConfig => { - return getEffectiveModelConfig(modelId); -}; diff --git a/packages/twenty-server/src/engine/core-modules/ai/utils/get-ai-models-with-auto.util.spec.ts b/packages/twenty-server/src/engine/core-modules/ai/utils/get-ai-models-with-auto.util.spec.ts deleted file mode 100644 index 26d961642..000000000 --- a/packages/twenty-server/src/engine/core-modules/ai/utils/get-ai-models-with-auto.util.spec.ts +++ /dev/null @@ -1,35 +0,0 @@ -import { - AI_MODELS, - ModelProvider, -} from 'src/engine/core-modules/ai/constants/ai-models.const'; - -import { getAIModelsWithAuto } from './get-ai-models-with-auto.util'; -import { getDefaultModelConfig } from './get-default-model-config.util'; - -describe('getAIModelsWithAuto', () => { - it('should return AI_MODELS with auto model prepended', () => { - const ORIGINAL_MODELS = AI_MODELS; - const MODELS_WITH_AUTO = getAIModelsWithAuto(); - - expect(MODELS_WITH_AUTO).toHaveLength(ORIGINAL_MODELS.length + 1); - expect(MODELS_WITH_AUTO[0].modelId).toBe('auto'); - expect(MODELS_WITH_AUTO[0].label).toBe('Auto'); - expect(MODELS_WITH_AUTO[0].provider).toBe(ModelProvider.NONE); - - // Check that the rest of the models are the same - expect(MODELS_WITH_AUTO.slice(1)).toEqual(ORIGINAL_MODELS); - }); - - it('should have auto model with default model costs', () => { - const MODELS_WITH_AUTO = getAIModelsWithAuto(); - const AUTO_MODEL = MODELS_WITH_AUTO[0]; - const DEFAULT_MODEL = getDefaultModelConfig(); - - expect(AUTO_MODEL.inputCostPer1kTokensInCents).toBe( - DEFAULT_MODEL.inputCostPer1kTokensInCents, - ); - expect(AUTO_MODEL.outputCostPer1kTokensInCents).toBe( - DEFAULT_MODEL.outputCostPer1kTokensInCents, - ); - }); -}); diff --git a/packages/twenty-server/src/engine/core-modules/ai/utils/get-ai-models-with-auto.util.ts b/packages/twenty-server/src/engine/core-modules/ai/utils/get-ai-models-with-auto.util.ts deleted file mode 100644 index 330c74f60..000000000 --- a/packages/twenty-server/src/engine/core-modules/ai/utils/get-ai-models-with-auto.util.ts +++ /dev/null @@ -1,22 +0,0 @@ -import { - AI_MODELS, - AIModelConfig, - ModelProvider, -} from 'src/engine/core-modules/ai/constants/ai-models.const'; - -import { getDefaultModelConfig } from './get-default-model-config.util'; - -export const getAIModelsWithAuto = (): AIModelConfig[] => { - return [ - { - modelId: 'auto', - label: 'Auto', - provider: ModelProvider.NONE, - inputCostPer1kTokensInCents: - getDefaultModelConfig().inputCostPer1kTokensInCents, - outputCostPer1kTokensInCents: - getDefaultModelConfig().outputCostPer1kTokensInCents, - }, - ...AI_MODELS, - ]; -}; diff --git a/packages/twenty-server/src/engine/core-modules/ai/utils/get-default-model-config.util.spec.ts b/packages/twenty-server/src/engine/core-modules/ai/utils/get-default-model-config.util.spec.ts deleted file mode 100644 index 563726ce7..000000000 --- a/packages/twenty-server/src/engine/core-modules/ai/utils/get-default-model-config.util.spec.ts +++ /dev/null @@ -1,29 +0,0 @@ -import { - AI_MODELS, - DEFAULT_MODEL_ID, - ModelProvider, -} from 'src/engine/core-modules/ai/constants/ai-models.const'; - -import { getDefaultModelConfig } from './get-default-model-config.util'; - -describe('getDefaultModelConfig', () => { - it('should return the configuration for the default model', () => { - const result = getDefaultModelConfig(); - - expect(result).toBeDefined(); - expect(result.modelId).toBe(DEFAULT_MODEL_ID); - expect(result.provider).toBe(ModelProvider.OPENAI); - }); - - it('should throw an error if default model is not found', () => { - const originalFind = AI_MODELS.find; - - AI_MODELS.find = jest.fn().mockReturnValue(undefined); - - expect(() => getDefaultModelConfig()).toThrow( - `Default model with ID ${DEFAULT_MODEL_ID} not found`, - ); - - AI_MODELS.find = originalFind; - }); -}); diff --git a/packages/twenty-server/src/engine/core-modules/ai/utils/get-default-model-config.util.ts b/packages/twenty-server/src/engine/core-modules/ai/utils/get-default-model-config.util.ts deleted file mode 100644 index a38a26f6e..000000000 --- a/packages/twenty-server/src/engine/core-modules/ai/utils/get-default-model-config.util.ts +++ /dev/null @@ -1,17 +0,0 @@ -import { - AI_MODELS, - AIModelConfig, - DEFAULT_MODEL_ID, -} from 'src/engine/core-modules/ai/constants/ai-models.const'; - -export const getDefaultModelConfig = (): AIModelConfig => { - const defaultModel = AI_MODELS.find( - (model) => model.modelId === DEFAULT_MODEL_ID, - ); - - if (!defaultModel) { - throw new Error(`Default model with ID ${DEFAULT_MODEL_ID} not found`); - } - - return defaultModel; -}; diff --git a/packages/twenty-server/src/engine/core-modules/ai/utils/get-effective-model-config.util.spec.ts b/packages/twenty-server/src/engine/core-modules/ai/utils/get-effective-model-config.util.spec.ts deleted file mode 100644 index 1936179a9..000000000 --- a/packages/twenty-server/src/engine/core-modules/ai/utils/get-effective-model-config.util.spec.ts +++ /dev/null @@ -1,30 +0,0 @@ -import { - DEFAULT_MODEL_ID, - ModelProvider, -} from 'src/engine/core-modules/ai/constants/ai-models.const'; - -import { getEffectiveModelConfig } from './get-effective-model-config.util'; - -describe('getEffectiveModelConfig', () => { - it('should return default model config when modelId is "auto"', () => { - const result = getEffectiveModelConfig('auto'); - - expect(result).toBeDefined(); - expect(result.modelId).toBe(DEFAULT_MODEL_ID); - expect(result.provider).toBe(ModelProvider.OPENAI); - }); - - it('should return the correct model config for a specific model', () => { - const result = getEffectiveModelConfig('gpt-4o'); - - expect(result).toBeDefined(); - expect(result.modelId).toBe('gpt-4o'); - expect(result.provider).toBe(ModelProvider.OPENAI); - }); - - it('should throw an error for non-existent model', () => { - expect(() => getEffectiveModelConfig('non-existent-model' as any)).toThrow( - `Model with ID non-existent-model not found`, - ); - }); -}); diff --git a/packages/twenty-server/src/engine/core-modules/ai/utils/get-effective-model-config.util.ts b/packages/twenty-server/src/engine/core-modules/ai/utils/get-effective-model-config.util.ts deleted file mode 100644 index a59f3ff65..000000000 --- a/packages/twenty-server/src/engine/core-modules/ai/utils/get-effective-model-config.util.ts +++ /dev/null @@ -1,20 +0,0 @@ -import { - AI_MODELS, - AIModelConfig, -} from 'src/engine/core-modules/ai/constants/ai-models.const'; - -import { getDefaultModelConfig } from './get-default-model-config.util'; - -export const getEffectiveModelConfig = (modelId: string): AIModelConfig => { - if (modelId === 'auto') { - return getDefaultModelConfig(); - } - - const model = AI_MODELS.find((model) => model.modelId === modelId); - - if (!model) { - throw new Error(`Model with ID ${modelId} not found`); - } - - return model; -}; diff --git a/packages/twenty-server/src/engine/core-modules/auth/token/services/access-token.service.spec.ts b/packages/twenty-server/src/engine/core-modules/auth/token/services/access-token.service.spec.ts index 0494de098..dcaaea1e4 100644 --- a/packages/twenty-server/src/engine/core-modules/auth/token/services/access-token.service.spec.ts +++ b/packages/twenty-server/src/engine/core-modules/auth/token/services/access-token.service.spec.ts @@ -10,7 +10,6 @@ import { AuthException } from 'src/engine/core-modules/auth/auth.exception'; import { JwtAuthStrategy } from 'src/engine/core-modules/auth/strategies/jwt.auth.strategy'; import { EmailService } from 'src/engine/core-modules/email/email.service'; import { JwtWrapperService } from 'src/engine/core-modules/jwt/services/jwt-wrapper.service'; -import { SSOService } from 'src/engine/core-modules/sso/services/sso.service'; import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service'; import { UserWorkspace } from 'src/engine/core-modules/user-workspace/user-workspace.entity'; import { User } from 'src/engine/core-modules/user/user.entity'; @@ -74,10 +73,6 @@ describe('AccessTokenService', () => { provide: EmailService, useValue: {}, }, - { - provide: SSOService, - useValue: {}, - }, { provide: TwentyORMGlobalManager, useValue: { diff --git a/packages/twenty-server/src/engine/core-modules/auth/token/token.module.ts b/packages/twenty-server/src/engine/core-modules/auth/token/token.module.ts index fe610a5f0..23eea1d4a 100644 --- a/packages/twenty-server/src/engine/core-modules/auth/token/token.module.ts +++ b/packages/twenty-server/src/engine/core-modules/auth/token/token.module.ts @@ -12,7 +12,6 @@ import { RefreshTokenService } from 'src/engine/core-modules/auth/token/services import { RenewTokenService } from 'src/engine/core-modules/auth/token/services/renew-token.service'; import { WorkspaceAgnosticTokenService } from 'src/engine/core-modules/auth/token/services/workspace-agnostic-token.service'; import { JwtModule } from 'src/engine/core-modules/jwt/jwt.module'; -import { WorkspaceSSOModule } from 'src/engine/core-modules/sso/sso.module'; import { UserWorkspace } from 'src/engine/core-modules/user-workspace/user-workspace.entity'; import { User } from 'src/engine/core-modules/user/user.entity'; import { Workspace } from 'src/engine/core-modules/workspace/workspace.entity'; @@ -27,7 +26,6 @@ import { DataSourceModule } from 'src/engine/metadata-modules/data-source/data-s ), TypeORMModule, DataSourceModule, - WorkspaceSSOModule, ], providers: [ RenewTokenService, diff --git a/packages/twenty-server/src/engine/core-modules/client-config/services/client-config.service.spec.ts b/packages/twenty-server/src/engine/core-modules/client-config/services/client-config.service.spec.ts index 7246b1a13..3e24331c5 100644 --- a/packages/twenty-server/src/engine/core-modules/client-config/services/client-config.service.spec.ts +++ b/packages/twenty-server/src/engine/core-modules/client-config/services/client-config.service.spec.ts @@ -3,19 +3,13 @@ import { Test, TestingModule } from '@nestjs/testing'; import { NodeEnvironment } from 'src/engine/core-modules/twenty-config/interfaces/node-environment.interface'; import { SupportDriver } from 'src/engine/core-modules/twenty-config/interfaces/support.interface'; +import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service'; import { CaptchaDriverType } from 'src/engine/core-modules/captcha/interfaces'; import { ClientConfigService } from 'src/engine/core-modules/client-config/services/client-config.service'; import { DomainManagerService } from 'src/engine/core-modules/domain-manager/services/domain-manager.service'; import { PUBLIC_FEATURE_FLAGS } from 'src/engine/core-modules/feature-flag/constants/public-feature-flag.const'; import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service'; -jest.mock( - 'src/engine/core-modules/ai/utils/get-ai-models-with-auto.util', - () => ({ - getAIModelsWithAuto: jest.fn(() => []), - }), -); - describe('ClientConfigService', () => { let service: ClientConfigService; let twentyConfigService: TwentyConfigService; @@ -37,6 +31,12 @@ describe('ClientConfigService', () => { getFrontUrl: jest.fn(), }, }, + { + provide: AiModelRegistryService, + useValue: { + getAvailableModels: jest.fn().mockReturnValue([]), + }, + }, ], }).compile(); diff --git a/packages/twenty-server/src/engine/core-modules/client-config/services/client-config.service.ts b/packages/twenty-server/src/engine/core-modules/client-config/services/client-config.service.ts index 3e7cb44b4..a895c7880 100644 --- a/packages/twenty-server/src/engine/core-modules/client-config/services/client-config.service.ts +++ b/packages/twenty-server/src/engine/core-modules/client-config/services/client-config.service.ts @@ -3,9 +3,12 @@ import { Injectable } from '@nestjs/common'; import { NodeEnvironment } from 'src/engine/core-modules/twenty-config/interfaces/node-environment.interface'; import { SupportDriver } from 'src/engine/core-modules/twenty-config/interfaces/support.interface'; -import { ModelProvider } from 'src/engine/core-modules/ai/constants/ai-models.const'; +import { + AI_MODELS, + ModelProvider, +} from 'src/engine/core-modules/ai/constants/ai-models.const'; +import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service'; import { convertCentsToBillingCredits } from 'src/engine/core-modules/ai/utils/convert-cents-to-billing-credits.util'; -import { getAIModelsWithAuto } from 'src/engine/core-modules/ai/utils/get-ai-models-with-auto.util'; import { ClientAIModelConfig, ClientConfig, @@ -19,41 +22,49 @@ export class ClientConfigService { constructor( private twentyConfigService: TwentyConfigService, private domainManagerService: DomainManagerService, + private aiModelRegistryService: AiModelRegistryService, ) {} async getClientConfig(): Promise { const captchaProvider = this.twentyConfigService.get('CAPTCHA_DRIVER'); const supportDriver = this.twentyConfigService.get('SUPPORT_DRIVER'); - const openaiApiKey = this.twentyConfigService.get('OPENAI_API_KEY'); - const anthropicApiKey = this.twentyConfigService.get('ANTHROPIC_API_KEY'); - const aiModels = getAIModelsWithAuto().reduce( - (acc, model) => { - const isAvailable = - (model.provider === ModelProvider.OPENAI && openaiApiKey) || - (model.provider === ModelProvider.ANTHROPIC && anthropicApiKey); + const availableModels = this.aiModelRegistryService.getAvailableModels(); - if (!isAvailable) { - return acc; - } + const aiModels: ClientAIModelConfig[] = availableModels.map( + (registeredModel) => { + const builtInModel = AI_MODELS.find( + (m) => m.modelId === registeredModel.modelId, + ); - acc.push({ - modelId: model.modelId, - label: model.label, - provider: model.provider, - inputCostPer1kTokensInCredits: convertCentsToBillingCredits( - model.inputCostPer1kTokensInCents, - ), - outputCostPer1kTokensInCredits: convertCentsToBillingCredits( - model.outputCostPer1kTokensInCents, - ), - }); - - return acc; + return { + modelId: registeredModel.modelId, + label: builtInModel?.label || registeredModel.modelId, + provider: registeredModel.provider, + inputCostPer1kTokensInCredits: builtInModel + ? convertCentsToBillingCredits( + builtInModel.inputCostPer1kTokensInCents, + ) + : 0, + outputCostPer1kTokensInCredits: builtInModel + ? convertCentsToBillingCredits( + builtInModel.outputCostPer1kTokensInCents, + ) + : 0, + }; }, - [], ); + if (aiModels.length > 0) { + aiModels.unshift({ + modelId: 'auto', + label: 'Auto', + provider: ModelProvider.NONE, + inputCostPer1kTokensInCredits: 0, + outputCostPer1kTokensInCredits: 0, + }); + } + const clientConfig: ClientConfig = { billing: { isBillingEnabled: this.twentyConfigService.get('IS_BILLING_ENABLED'), diff --git a/packages/twenty-server/src/engine/core-modules/core-engine.module.ts b/packages/twenty-server/src/engine/core-modules/core-engine.module.ts index bc1409a68..5a07991c2 100644 --- a/packages/twenty-server/src/engine/core-modules/core-engine.module.ts +++ b/packages/twenty-server/src/engine/core-modules/core-engine.module.ts @@ -6,7 +6,6 @@ import { WorkspaceQueryRunnerModule } from 'src/engine/api/graphql/workspace-que import { ActorModule } from 'src/engine/core-modules/actor/actor.module'; import { AdminPanelModule } from 'src/engine/core-modules/admin-panel/admin-panel.module'; import { AiModule } from 'src/engine/core-modules/ai/ai.module'; -import { aiModuleFactory } from 'src/engine/core-modules/ai/ai.module-factory'; import { ApiKeyModule } from 'src/engine/core-modules/api-key/api-key.module'; import { AppTokenModule } from 'src/engine/core-modules/app-token/app-token.module'; import { ApprovedAccessDomainModule } from 'src/engine/core-modules/approved-access-domain/approved-access-domain.module'; @@ -109,10 +108,7 @@ import { FileModule } from './file/file.module'; wildcard: true, }), CacheStorageModule, - AiModule.forRoot({ - useFactory: aiModuleFactory, - inject: [TwentyConfigService], - }), + AiModule, ServerlessModule.forRootAsync({ useFactory: serverlessModuleFactory, inject: [TwentyConfigService, FileStorageService], diff --git a/packages/twenty-server/src/engine/core-modules/twenty-config/config-variables.ts b/packages/twenty-server/src/engine/core-modules/twenty-config/config-variables.ts index 3027f2129..238733b4d 100644 --- a/packages/twenty-server/src/engine/core-modules/twenty-config/config-variables.ts +++ b/packages/twenty-server/src/engine/core-modules/twenty-config/config-variables.ts @@ -11,7 +11,6 @@ import { } from 'class-validator'; import { isDefined } from 'twenty-shared/utils'; -import { AiDriver } from 'src/engine/core-modules/ai/interfaces/ai.interface'; import { AwsRegion } from 'src/engine/core-modules/twenty-config/interfaces/aws-region.interface'; import { NodeEnvironment } from 'src/engine/core-modules/twenty-config/interfaces/node-environment.interface'; import { SupportDriver } from 'src/engine/core-modules/twenty-config/interfaces/support.interface'; @@ -975,13 +974,12 @@ export class ConfigVariables { @ConfigVariablesMetadata({ group: ConfigVariablesGroup.LLM, - description: 'Driver for the AI chat model', - type: ConfigVariableType.ENUM, - options: Object.values(AiDriver), - isEnvOnly: true, + description: + 'Default model ID for AI operations (can be any available model)', + type: ConfigVariableType.STRING, }) - @CastToUpperSnakeCase() - AI_DRIVER: AiDriver; + @IsOptional() + DEFAULT_MODEL_ID = 'gpt-4o'; @ConfigVariablesMetadata({ group: ConfigVariablesGroup.LLM, @@ -989,6 +987,7 @@ export class ConfigVariables { description: 'API key for OpenAI integration', type: ConfigVariableType.STRING, }) + @IsOptional() OPENAI_API_KEY: string; @ConfigVariablesMetadata({ @@ -997,8 +996,37 @@ export class ConfigVariables { description: 'API key for Anthropic integration', type: ConfigVariableType.STRING, }) + @IsOptional() ANTHROPIC_API_KEY: string; + @ConfigVariablesMetadata({ + group: ConfigVariablesGroup.LLM, + description: 'Base URL for OpenAI-compatible LLM provider (e.g., Ollama)', + type: ConfigVariableType.STRING, + }) + @IsOptional() + @IsUrl({ require_tld: false, require_protocol: true }) + OPENAI_COMPATIBLE_BASE_URL: string; + + @ConfigVariablesMetadata({ + group: ConfigVariablesGroup.LLM, + description: + 'Model names for OpenAI-compatible LLM provider (comma-separated, e.g., "llama3.1, mistral, codellama")', + type: ConfigVariableType.STRING, + }) + @IsOptional() + OPENAI_COMPATIBLE_MODEL_NAMES: string; + + @ConfigVariablesMetadata({ + group: ConfigVariablesGroup.LLM, + isSensitive: true, + description: + 'API key for OpenAI-compatible LLM provider (optional for providers like Ollama)', + type: ConfigVariableType.STRING, + }) + @IsOptional() + OPENAI_COMPATIBLE_API_KEY: string; + @ConfigVariablesMetadata({ group: ConfigVariablesGroup.ServerConfig, description: 'Enable or disable multi-workspace support', diff --git a/packages/twenty-server/src/engine/metadata-modules/agent/agent-execution.service.ts b/packages/twenty-server/src/engine/metadata-modules/agent/agent-execution.service.ts index 058ff333b..425f1804a 100644 --- a/packages/twenty-server/src/engine/metadata-modules/agent/agent-execution.service.ts +++ b/packages/twenty-server/src/engine/metadata-modules/agent/agent-execution.service.ts @@ -1,4 +1,4 @@ -import { Injectable } from '@nestjs/common'; +import { Injectable, Logger } from '@nestjs/common'; import { InjectRepository } from '@nestjs/typeorm'; import { createAnthropic } from '@ai-sdk/anthropic'; @@ -10,7 +10,7 @@ import { ModelId, ModelProvider, } from 'src/engine/core-modules/ai/constants/ai-models.const'; -import { getEffectiveModelConfig } from 'src/engine/core-modules/ai/utils/get-effective-model-config.util'; +import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service'; import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service'; import { AgentChatMessageEntity, @@ -22,7 +22,6 @@ import { AGENT_SYSTEM_PROMPTS } from 'src/engine/metadata-modules/agent/constant import { convertOutputSchemaToZod } from 'src/engine/metadata-modules/agent/utils/convert-output-schema-to-zod'; import { OutputSchema } from 'src/modules/workflow/workflow-builder/workflow-schema/types/output-schema.type'; import { resolveInput } from 'src/modules/workflow/workflow-executor/utils/variable-resolver.util'; -import { getAIModelById } from 'src/engine/core-modules/ai/utils/get-ai-model-by-id.util'; import { AgentEntity } from './agent.entity'; import { AgentException, AgentExceptionCode } from './agent.exception'; @@ -41,21 +40,27 @@ export interface AgentExecutionResult { @Injectable() export class AgentExecutionService { + private readonly logger = new Logger(AgentExecutionService.name); + constructor( private readonly twentyConfigService: TwentyConfigService, private readonly agentToolService: AgentToolService, + private readonly aiModelRegistryService: AiModelRegistryService, @InjectRepository(AgentEntity, 'core') private readonly agentRepository: Repository, ) {} getModel = (modelId: ModelId, provider: ModelProvider) => { switch (provider) { - case ModelProvider.NONE: { + case ModelProvider.OPENAI_COMPATIBLE: { const OpenAIProvider = createOpenAI({ - apiKey: this.twentyConfigService.get('OPENAI_API_KEY'), + baseURL: this.twentyConfigService.get('OPENAI_COMPATIBLE_BASE_URL'), + apiKey: this.twentyConfigService.get('OPENAI_COMPATIBLE_API_KEY'), }); - return OpenAIProvider(getEffectiveModelConfig(modelId).modelId); + return OpenAIProvider( + this.aiModelRegistryService.getEffectiveModelConfig(modelId).modelId, + ); } case ModelProvider.OPENAI: { const OpenAIProvider = createOpenAI({ @@ -83,9 +88,6 @@ export class AgentExecutionService { let apiKey: string | undefined; switch (provider) { - case ModelProvider.NONE: - apiKey = this.twentyConfigService.get('OPENAI_API_KEY'); - break; case ModelProvider.OPENAI: apiKey = this.twentyConfigService.get('OPENAI_API_KEY'); break; @@ -93,14 +95,11 @@ export class AgentExecutionService { apiKey = this.twentyConfigService.get('ANTHROPIC_API_KEY'); break; default: - throw new AgentException( - `Unsupported provider: ${provider}`, - AgentExceptionCode.AGENT_EXECUTION_FAILED, - ); + return; } if (!apiKey) { throw new AgentException( - `${provider === ModelProvider.NONE ? 'OPENAI' : provider.toUpperCase()} API key not configured`, + `${provider.toUpperCase()} API key not configured`, AgentExceptionCode.API_KEY_NOT_CONFIGURED, ); } @@ -117,31 +116,55 @@ export class AgentExecutionService { prompt?: string; messages?: CoreMessage[]; }) { - const aiModel = getAIModelById(agent.modelId); - - if (!aiModel) { - throw new AgentException( - `AI model with id ${agent.modelId} not found`, - AgentExceptionCode.AGENT_EXECUTION_FAILED, + try { + this.logger.log( + `Preparing AI request config for agent ${agent.id} with model ${agent.modelId}`, ); + + const aiModel = this.aiModelRegistryService.getEffectiveModelConfig( + agent.modelId, + ); + + if (!aiModel) { + const error = `AI model with id ${agent.modelId} not found`; + + this.logger.error(error); + throw new AgentException( + error, + AgentExceptionCode.AGENT_EXECUTION_FAILED, + ); + } + + this.logger.log( + `Resolved model: ${aiModel.modelId} (provider: ${aiModel.provider})`, + ); + + const provider = aiModel.provider; + + await this.validateApiKey(provider); + + const tools = await this.agentToolService.generateToolsForAgent( + agent.id, + agent.workspaceId, + ); + + this.logger.log(`Generated ${Object.keys(tools).length} tools for agent`); + + return { + system, + tools, + model: this.getModel(aiModel.modelId, aiModel.provider), + ...(messages && { messages }), + ...(prompt && { prompt }), + maxSteps: AGENT_CONFIG.MAX_STEPS, + }; + } catch (error) { + this.logger.error( + `Failed to prepare AI request config for agent ${agent.id}:`, + error instanceof Error ? error.stack : error, + ); + throw error; } - const provider = aiModel.provider; - - await this.validateApiKey(provider); - - const tools = await this.agentToolService.generateToolsForAgent( - agent.id, - agent.workspaceId, - ); - - return { - system, - tools, - model: this.getModel(agent.modelId, aiModel.provider), - ...(messages && { messages }), - ...(prompt && { prompt }), - maxSteps: AGENT_CONFIG.MAX_STEPS, - }; } async streamChatResponse({ @@ -173,6 +196,10 @@ export class AgentExecutionService { messages: llmMessages, }); + this.logger.log( + `Sending request to AI model with ${llmMessages.length} messages`, + ); + return streamText(aiRequestConfig); } diff --git a/packages/twenty-server/src/engine/metadata-modules/agent/agent-streaming.service.ts b/packages/twenty-server/src/engine/metadata-modules/agent/agent-streaming.service.ts index 5358c12ff..917fb64d5 100644 --- a/packages/twenty-server/src/engine/metadata-modules/agent/agent-streaming.service.ts +++ b/packages/twenty-server/src/engine/metadata-modules/agent/agent-streaming.service.ts @@ -1,4 +1,4 @@ -import { Injectable } from '@nestjs/common'; +import { Injectable, Logger } from '@nestjs/common'; import { InjectRepository } from '@nestjs/typeorm'; import { Response } from 'express'; @@ -28,6 +28,8 @@ export type StreamAgentChatResult = { @Injectable() export class AgentStreamingService { + private readonly logger = new Logger(AgentStreamingService.name); + constructor( @InjectRepository(AgentChatThreadEntity, 'core') private readonly threadRepository: Repository, @@ -89,7 +91,11 @@ export class AgentStreamingService { message: chunk.args?.toolDescription, }); break; + case 'error': + this.logger.error(`Stream error: ${JSON.stringify(chunk)}`); + break; default: + this.logger.log(`Unknown chunk type: ${chunk.type}`); break; } } @@ -105,6 +111,10 @@ export class AgentStreamingService { const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred'; + if (error instanceof AgentException) { + this.logger.error(`Agent Exception Code: ${error.code}`); + } + if (!res.headersSent) { this.setupStreamingHeaders(res); } diff --git a/packages/twenty-server/src/modules/workflow/workflow-executor/workflow-actions/ai-agent/ai-agent-action.module.ts b/packages/twenty-server/src/modules/workflow/workflow-executor/workflow-actions/ai-agent/ai-agent-action.module.ts index 849148d2c..f7e570a4c 100644 --- a/packages/twenty-server/src/modules/workflow/workflow-executor/workflow-actions/ai-agent/ai-agent-action.module.ts +++ b/packages/twenty-server/src/modules/workflow/workflow-executor/workflow-actions/ai-agent/ai-agent-action.module.ts @@ -1,8 +1,6 @@ import { Module } from '@nestjs/common'; import { TypeOrmModule } from '@nestjs/typeorm'; -import { AiDriver } from 'src/engine/core-modules/ai/interfaces/ai.interface'; - import { AiModule } from 'src/engine/core-modules/ai/ai.module'; import { AgentEntity } from 'src/engine/metadata-modules/agent/agent.entity'; import { AgentModule } from 'src/engine/metadata-modules/agent/agent.module'; @@ -13,9 +11,7 @@ import { AiAgentWorkflowAction } from './ai-agent.workflow-action'; @Module({ imports: [ AgentModule, - AiModule.forRoot({ - useFactory: () => ({ type: AiDriver.OPENAI }), - }), + AiModule, TypeOrmModule.forFeature([AgentEntity], 'core'), ], providers: [ScopedWorkspaceContextFactory, AiAgentWorkflowAction],