Refresh AI model setup (#13171)

Instead of initializing the model at start time, we do it at run time to be
able to swap the model provider more easily.

Also introduce a third driver for OpenAI-compatible providers, which among
other things allows for local models with Ollama.
This commit is contained in:
Félix Malfait
2025-07-11 13:09:54 +02:00
committed by GitHub
parent fd13bb0258
commit 0a93468b95
32 changed files with 566 additions and 417 deletions

View File

@ -1041,7 +1041,8 @@ export enum MessageChannelVisibility {
export enum ModelProvider {
ANTHROPIC = 'ANTHROPIC',
NONE = 'NONE',
OPENAI = 'OPENAI'
OPENAI = 'OPENAI',
OPENAI_COMPATIBLE = 'OPENAI_COMPATIBLE'
}
export type Mutation = {

View File

@ -998,7 +998,8 @@ export enum MessageChannelVisibility {
export enum ModelProvider {
ANTHROPIC = 'ANTHROPIC',
NONE = 'NONE',
OPENAI = 'OPENAI'
OPENAI = 'OPENAI',
OPENAI_COMPATIBLE = 'OPENAI_COMPATIBLE'
}
export type Mutation = {

View File

@ -1,3 +1,4 @@
import { useSnackBar } from '@/ui/feedback/snack-bar-manager/hooks/useSnackBar';
import { useRecoilComponentStateV2 } from '@/ui/utilities/state/component-state/hooks/useRecoilComponentStateV2';
import { useState } from 'react';
import { useRecoilState } from 'recoil';
@ -22,6 +23,7 @@ interface OptimisticMessage extends AgentChatMessage {
export const useAgentChat = (agentId: string) => {
const apolloClient = useApolloClient();
const { enqueueErrorSnackBar } = useSnackBar();
const [agentChatMessages, setAgentChatMessages] = useRecoilComponentStateV2(
agentChatMessagesComponentState,
@ -112,6 +114,11 @@ export const useAgentChat = (agentId: string) => {
}));
scrollToBottom();
},
onError: (message: string) => {
enqueueErrorSnackBar({
message,
});
},
});
},
},

View File

@ -1,12 +1,13 @@
export type AgentStreamingEvent = {
type: 'text-delta' | 'tool-call';
type: 'text-delta' | 'tool-call' | 'error';
message: string;
};
export type AgentStreamingParserCallbacks = {
onTextDelta?: (message: string) => void;
onToolCall?: (message: string) => void;
onError?: (error: Error, rawLine: string) => void;
onError?: (message: string) => void;
onParseError?: (error: Error, rawLine: string) => void;
};
export const parseAgentStreamingChunk = (
@ -27,6 +28,9 @@ export const parseAgentStreamingChunk = (
case 'tool-call':
callbacks.onToolCall?.(event.message);
break;
case 'error':
callbacks.onError?.(event.message);
break;
}
} catch (error) {
// eslint-disable-next-line no-console
@ -36,7 +40,7 @@ export const parseAgentStreamingChunk = (
error instanceof Error
? error
: new Error(`Unknown parsing error: ${String(error)}`);
callbacks.onError?.(errorMessage, line);
callbacks.onParseError?.(errorMessage, line);
}
}
}

View File

@ -1 +0,0 @@
// Dependency-injection token under which the active AiDriver implementation
// is provided to consumers such as AiService.
export const AI_DRIVER = Symbol('AI_DRIVER');

View File

@ -1,17 +0,0 @@
import { AiDriver } from 'src/engine/core-modules/ai/interfaces/ai.interface';
import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service';
/**
 * Produces the AiModule driver configuration from the AI_DRIVER config value.
 *
 * Only the OpenAI driver is currently implemented and it is also the fallback
 * for an unset or unrecognized value, so every input resolves to the same
 * config. The original switch duplicated the identical return statement in
 * two branches; they are merged here while keeping the switch as a seam for
 * future drivers.
 */
export const aiModuleFactory = (twentyConfigService: TwentyConfigService) => {
  const driver = twentyConfigService.get('AI_DRIVER');

  switch (driver) {
    case AiDriver.OPENAI:
    default:
      // Default to OpenAI driver if no driver is specified
      return { type: AiDriver.OPENAI };
  }
};

View File

@ -1,65 +1,46 @@
import { DynamicModule, Global, Provider } from '@nestjs/common';
import { Global, Module } from '@nestjs/common';
import { TypeOrmModule } from '@nestjs/typeorm';
import {
AiDriver,
AiModuleAsyncOptions,
} from 'src/engine/core-modules/ai/interfaces/ai.interface';
import { AI_DRIVER } from 'src/engine/core-modules/ai/ai.constants';
import { AiService } from 'src/engine/core-modules/ai/services/ai.service';
import { AiController } from 'src/engine/core-modules/ai/controllers/ai.controller';
import { OpenAIDriver } from 'src/engine/core-modules/ai/drivers/openai.driver';
import { AIBillingService } from 'src/engine/core-modules/ai/services/ai-billing.service';
import { FeatureFlagModule } from 'src/engine/core-modules/feature-flag/feature-flag.module';
import { McpController } from 'src/engine/core-modules/ai/controllers/mcp.controller';
import { AuthModule } from 'src/engine/core-modules/auth/auth.module';
import { ToolService } from 'src/engine/core-modules/ai/services/tool.service';
import { AIBillingService } from 'src/engine/core-modules/ai/services/ai-billing.service';
import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service';
import { AiService } from 'src/engine/core-modules/ai/services/ai.service';
import { McpService } from 'src/engine/core-modules/ai/services/mcp.service';
import { ToolService } from 'src/engine/core-modules/ai/services/tool.service';
import { TokenModule } from 'src/engine/core-modules/auth/token/token.module';
import { FeatureFlagModule } from 'src/engine/core-modules/feature-flag/feature-flag.module';
import { ObjectMetadataModule } from 'src/engine/metadata-modules/object-metadata/object-metadata.module';
import { RoleEntity } from 'src/engine/metadata-modules/role/role.entity';
import { UserRoleModule } from 'src/engine/metadata-modules/user-role/user-role.module';
import { WorkspacePermissionsCacheModule } from 'src/engine/metadata-modules/workspace-permissions-cache/workspace-permissions-cache.module';
import { WorkspaceCacheStorageModule } from 'src/engine/workspace-cache-storage/workspace-cache-storage.module';
import { UserRoleModule } from 'src/engine/metadata-modules/user-role/user-role.module';
import { RoleEntity } from 'src/engine/metadata-modules/role/role.entity';
@Global()
export class AiModule {
static forRoot(options: AiModuleAsyncOptions): DynamicModule {
const provider: Provider = {
provide: AI_DRIVER,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
useFactory: (...args: any[]) => {
const config = options.useFactory(...args);
switch (config?.type) {
case AiDriver.OPENAI: {
return new OpenAIDriver();
}
}
},
inject: options.inject || [],
};
return {
module: AiModule,
imports: [
TypeOrmModule.forFeature([RoleEntity], 'core'),
FeatureFlagModule,
ObjectMetadataModule,
WorkspacePermissionsCacheModule,
WorkspaceCacheStorageModule,
UserRoleModule,
AuthModule,
],
controllers: [AiController, McpController],
providers: [
AiService,
ToolService,
AIBillingService,
McpService,
provider,
],
exports: [AiService, AIBillingService, ToolService, McpService],
};
}
}
@Module({
imports: [
TypeOrmModule.forFeature([RoleEntity], 'core'),
TokenModule,
FeatureFlagModule,
ObjectMetadataModule,
WorkspacePermissionsCacheModule,
WorkspaceCacheStorageModule,
UserRoleModule,
],
controllers: [AiController, McpController],
providers: [
AiService,
AiModelRegistryService,
ToolService,
AIBillingService,
McpService,
],
exports: [
AiService,
AiModelRegistryService,
AIBillingService,
ToolService,
McpService,
],
})
export class AiModule {}

View File

@ -1,7 +1,9 @@
import { getAIModelsWithAuto } from 'src/engine/core-modules/ai/utils/get-ai-models-with-auto.util';
import { getDefaultModelConfig } from 'src/engine/core-modules/ai/utils/get-default-model-config.util';
import { Test, TestingModule } from '@nestjs/testing';
import { AI_MODELS, DEFAULT_MODEL_ID, ModelProvider } from './ai-models.const';
import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service';
import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service';
import { AI_MODELS, ModelProvider } from './ai-models.const';
describe('AI_MODELS', () => {
it('should contain all expected models', () => {
@ -15,41 +17,120 @@ describe('AI_MODELS', () => {
'claude-3-5-haiku-20241022',
]);
});
it('should have the default model as the first model', () => {
const DEFAULT_MODEL = AI_MODELS.find(
(model) => model.modelId === DEFAULT_MODEL_ID,
);
expect(DEFAULT_MODEL).toBeDefined();
expect(DEFAULT_MODEL?.modelId).toBe(DEFAULT_MODEL_ID);
});
});
describe('getAIModelsWithAuto', () => {
it('should return AI_MODELS with auto model prepended', () => {
const ORIGINAL_MODELS = AI_MODELS;
const MODELS_WITH_AUTO = getAIModelsWithAuto();
describe('AiModelRegistryService', () => {
let SERVICE: AiModelRegistryService;
let MOCK_CONFIG_SERVICE: jest.Mocked<TwentyConfigService>;
expect(MODELS_WITH_AUTO).toHaveLength(ORIGINAL_MODELS.length + 1);
expect(MODELS_WITH_AUTO[0].modelId).toBe('auto');
expect(MODELS_WITH_AUTO[0].label).toBe('Auto');
expect(MODELS_WITH_AUTO[0].provider).toBe(ModelProvider.NONE);
beforeEach(async () => {
MOCK_CONFIG_SERVICE = {
get: jest.fn(),
} as any;
// Check that the rest of the models are the same
expect(MODELS_WITH_AUTO.slice(1)).toEqual(ORIGINAL_MODELS);
const MODULE: TestingModule = await Test.createTestingModule({
providers: [
AiModelRegistryService,
{
provide: TwentyConfigService,
useValue: MOCK_CONFIG_SERVICE,
},
],
}).compile();
SERVICE = MODULE.get<AiModelRegistryService>(AiModelRegistryService);
});
it('should have auto model with default model costs', () => {
const MODELS_WITH_AUTO = getAIModelsWithAuto();
const AUTO_MODEL = MODELS_WITH_AUTO[0];
const DEFAULT_MODEL = getDefaultModelConfig();
it('should return effective model config for auto', () => {
MOCK_CONFIG_SERVICE.get.mockReturnValue('gpt-4o');
expect(AUTO_MODEL.inputCostPer1kTokensInCents).toBe(
DEFAULT_MODEL.inputCostPer1kTokensInCents,
expect(() => SERVICE.getEffectiveModelConfig('auto')).toThrow(
'No AI models are available. Please configure at least one provider.',
);
expect(AUTO_MODEL.outputCostPer1kTokensInCents).toBe(
DEFAULT_MODEL.outputCostPer1kTokensInCents,
});
it('should return effective model config for auto when models are available', () => {
MOCK_CONFIG_SERVICE.get.mockReturnValue('gpt-4o');
jest.spyOn(SERVICE, 'getAvailableModels').mockReturnValue([
{
modelId: 'gpt-4o',
provider: ModelProvider.OPENAI,
model: {} as any,
},
]);
jest.spyOn(SERVICE, 'getModel').mockReturnValue({
modelId: 'gpt-4o',
provider: ModelProvider.OPENAI,
model: {} as any,
});
const RESULT = SERVICE.getEffectiveModelConfig('auto');
expect(RESULT).toBeDefined();
expect(RESULT.modelId).toBe('gpt-4o');
expect(RESULT.provider).toBe(ModelProvider.OPENAI);
});
it('should return effective model config for auto with custom model', () => {
MOCK_CONFIG_SERVICE.get.mockReturnValue('mistral');
jest.spyOn(SERVICE, 'getAvailableModels').mockReturnValue([
{
modelId: 'mistral',
provider: ModelProvider.OPENAI_COMPATIBLE,
model: {} as any,
},
]);
jest.spyOn(SERVICE, 'getModel').mockReturnValue({
modelId: 'mistral',
provider: ModelProvider.OPENAI_COMPATIBLE,
model: {} as any,
});
const RESULT = SERVICE.getEffectiveModelConfig('auto');
expect(RESULT).toBeDefined();
expect(RESULT.modelId).toBe('mistral');
expect(RESULT.provider).toBe(ModelProvider.OPENAI_COMPATIBLE);
expect(RESULT.label).toBe('mistral');
expect(RESULT.inputCostPer1kTokensInCents).toBe(0);
expect(RESULT.outputCostPer1kTokensInCents).toBe(0);
});
it('should return effective model config for specific model', () => {
const RESULT = SERVICE.getEffectiveModelConfig('gpt-4o-mini');
expect(RESULT).toBeDefined();
expect(RESULT.modelId).toBe('gpt-4o-mini');
expect(RESULT.provider).toBe(ModelProvider.OPENAI);
});
it('should return effective model config for custom model', () => {
// Mock that the custom model exists in registry
jest.spyOn(SERVICE, 'getModel').mockReturnValue({
modelId: 'mistral',
provider: ModelProvider.OPENAI_COMPATIBLE,
model: {} as any,
});
const RESULT = SERVICE.getEffectiveModelConfig('mistral');
expect(RESULT).toBeDefined();
expect(RESULT.modelId).toBe('mistral');
expect(RESULT.provider).toBe(ModelProvider.OPENAI_COMPATIBLE);
expect(RESULT.label).toBe('mistral');
expect(RESULT.inputCostPer1kTokensInCents).toBe(0);
expect(RESULT.outputCostPer1kTokensInCents).toBe(0);
});
it('should throw error for non-existent model', () => {
jest.spyOn(SERVICE, 'getModel').mockReturnValue(undefined);
expect(() => SERVICE.getEffectiveModelConfig('non-existent-model')).toThrow(
'Model with ID non-existent-model not found',
);
});
});

View File

@ -2,6 +2,7 @@ export enum ModelProvider {
NONE = 'none',
OPENAI = 'openai',
ANTHROPIC = 'anthropic',
OPENAI_COMPATIBLE = 'open_ai_compatible',
}
export type ModelId =
@ -11,9 +12,8 @@ export type ModelId =
| 'gpt-4-turbo'
| 'claude-opus-4-20250514'
| 'claude-sonnet-4-20250514'
| 'claude-3-5-haiku-20241022';
export const DEFAULT_MODEL_ID: ModelId = 'gpt-4o';
| 'claude-3-5-haiku-20241022'
| string; // Allow custom model names
export interface AIModelConfig {
modelId: ModelId;

View File

@ -1,8 +0,0 @@
import { CoreMessage, StreamTextResult } from 'ai';
/**
 * Contract for an AI model driver: stream a model completion for the given
 * chat messages. `CoreMessage` and `StreamTextResult` come from the `ai` SDK.
 */
export interface AiDriver {
  streamText(
    messages: CoreMessage[],
    // Optional sampling controls forwarded to the underlying model call.
    options?: { temperature?: number; maxTokens?: number },
  ): StreamTextResult<Record<string, never>, undefined>;
}

View File

@ -1,18 +0,0 @@
import { openai } from '@ai-sdk/openai';
import { CoreMessage, StreamTextResult, streamText } from 'ai';
import { AiDriver } from 'src/engine/core-modules/ai/drivers/interfaces/ai-driver.interface';
/**
 * AiDriver backed by the OpenAI provider of the `ai` SDK.
 * The model is fixed to 'gpt-4o-mini'.
 */
export class OpenAIDriver implements AiDriver {
  streamText(
    messages: CoreMessage[],
    options?: { temperature?: number; maxTokens?: number },
  ): StreamTextResult<Record<string, never>, undefined> {
    const { temperature, maxTokens } = options ?? {};

    return streamText({
      model: openai('gpt-4o-mini'),
      messages,
      temperature,
      maxTokens,
    });
  }
}

View File

@ -1,14 +0,0 @@
import { InjectionToken, OptionalFactoryDependency } from '@nestjs/common';
// Identifiers for the pluggable AI drivers; only OpenAI is implemented.
export enum AiDriver {
  OPENAI = 'openai',
}

// Resolved module configuration selecting which driver to instantiate.
export interface AiModuleOptions {
  type: AiDriver;
}

// Async options for AiModule.forRoot: a factory (plus the DI tokens to
// inject into it) that produces the module options at bootstrap time.
// Returning undefined falls back to the module's default driver.
export type AiModuleAsyncOptions = {
  inject?: (InjectionToken | OptionalFactoryDependency)[];
  useFactory: (...args: unknown[]) => AiModuleOptions | undefined;
};

View File

@ -5,6 +5,7 @@ import { BillingMeterEventName } from 'src/engine/core-modules/billing/enums/bil
import { WorkspaceEventEmitter } from 'src/engine/workspace-event-emitter/workspace-event-emitter';
import { AIBillingService } from './ai-billing.service';
import { AiModelRegistryService } from './ai-model-registry.service';
describe('AIBillingService', () => {
let service: AIBillingService;
@ -21,6 +22,16 @@ describe('AIBillingService', () => {
emitCustomBatchEvent: jest.fn(),
};
const mockAiModelRegistryMethods = {
getEffectiveModelConfig: jest.fn().mockReturnValue({
modelId: 'gpt-4o',
label: 'GPT-4o',
provider: 'openai',
inputCostPer1kTokensInCents: 0.25,
outputCostPer1kTokensInCents: 1.0,
}),
};
const module: TestingModule = await Test.createTestingModule({
providers: [
AIBillingService,
@ -28,6 +39,10 @@ describe('AIBillingService', () => {
provide: WorkspaceEventEmitter,
useValue: mockEventEmitterMethods,
},
{
provide: AiModelRegistryService,
useValue: mockAiModelRegistryMethods,
},
],
}).compile();

View File

@ -2,7 +2,7 @@ import { Injectable, Logger } from '@nestjs/common';
import { ModelId } from 'src/engine/core-modules/ai/constants/ai-models.const';
import { DOLLAR_TO_CREDIT_MULTIPLIER } from 'src/engine/core-modules/ai/constants/dollar-to-credit-multiplier';
import { getAIModelById } from 'src/engine/core-modules/ai/utils/get-ai-model-by-id.util';
import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service';
import { BILLING_FEATURE_USED } from 'src/engine/core-modules/billing/constants/billing-feature-used.constant';
import { BillingMeterEventName } from 'src/engine/core-modules/billing/enums/billing-meter-event-names';
import { BillingUsageEvent } from 'src/engine/core-modules/billing/types/billing-usage-event.type';
@ -18,10 +18,13 @@ export interface TokenUsage {
export class AIBillingService {
private readonly logger = new Logger(AIBillingService.name);
constructor(private readonly workspaceEventEmitter: WorkspaceEventEmitter) {}
constructor(
private readonly workspaceEventEmitter: WorkspaceEventEmitter,
private readonly aiModelRegistryService: AiModelRegistryService,
) {}
async calculateCost(modelId: ModelId, usage: TokenUsage): Promise<number> {
const model = getAIModelById(modelId);
const model = this.aiModelRegistryService.getEffectiveModelConfig(modelId);
if (!model) {
throw new Error(`AI model with id ${modelId} not found`);

View File

@ -0,0 +1,185 @@
import { Injectable } from '@nestjs/common';
import { anthropic } from '@ai-sdk/anthropic';
import { createOpenAI, openai } from '@ai-sdk/openai';
import { LanguageModel } from 'ai';
import {
AI_MODELS,
AIModelConfig,
ModelProvider,
} from 'src/engine/core-modules/ai/constants/ai-models.const';
import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service';
export interface RegisteredAIModel {
modelId: string;
provider: ModelProvider;
model: LanguageModel;
}
@Injectable()
export class AiModelRegistryService {
private modelRegistry: Map<string, RegisteredAIModel> = new Map();
constructor(private twentyConfigService: TwentyConfigService) {
this.buildModelRegistry();
}
private buildModelRegistry(): void {
this.modelRegistry.clear();
const openaiApiKey = this.twentyConfigService.get('OPENAI_API_KEY');
if (openaiApiKey) {
this.registerOpenAIModels();
}
const anthropicApiKey = this.twentyConfigService.get('ANTHROPIC_API_KEY');
if (anthropicApiKey) {
this.registerAnthropicModels();
}
const openaiCompatibleBaseUrl = this.twentyConfigService.get(
'OPENAI_COMPATIBLE_BASE_URL',
);
const openaiCompatibleModelNames = this.twentyConfigService.get(
'OPENAI_COMPATIBLE_MODEL_NAMES',
);
if (openaiCompatibleBaseUrl && openaiCompatibleModelNames) {
this.registerOpenAICompatibleModels(
openaiCompatibleBaseUrl,
openaiCompatibleModelNames,
);
}
}
private registerOpenAIModels(): void {
const openaiModels = AI_MODELS.filter(
(model) => model.provider === ModelProvider.OPENAI,
);
openaiModels.forEach((modelConfig) => {
this.modelRegistry.set(modelConfig.modelId, {
modelId: modelConfig.modelId,
provider: ModelProvider.OPENAI,
model: openai(modelConfig.modelId),
});
});
}
private registerAnthropicModels(): void {
const anthropicModels = AI_MODELS.filter(
(model) => model.provider === ModelProvider.ANTHROPIC,
);
anthropicModels.forEach((modelConfig) => {
this.modelRegistry.set(modelConfig.modelId, {
modelId: modelConfig.modelId,
provider: ModelProvider.ANTHROPIC,
model: anthropic(modelConfig.modelId),
});
});
}
private registerOpenAICompatibleModels(
baseUrl: string,
modelNamesString: string,
): void {
const apiKey = this.twentyConfigService.get('OPENAI_COMPATIBLE_API_KEY');
const provider = createOpenAI({
baseURL: baseUrl,
apiKey: apiKey,
});
const modelNames = modelNamesString
.split(',')
.map((name) => name.trim())
.filter((name) => name.length > 0);
modelNames.forEach((modelId) => {
this.modelRegistry.set(modelId, {
modelId,
provider: ModelProvider.OPENAI_COMPATIBLE,
model: provider(modelId),
});
});
}
getModel(modelId: string): RegisteredAIModel | undefined {
return this.modelRegistry.get(modelId);
}
getAvailableModels(): RegisteredAIModel[] {
return Array.from(this.modelRegistry.values());
}
getDefaultModel(): RegisteredAIModel | undefined {
const defaultModelId = this.twentyConfigService.get('DEFAULT_MODEL_ID');
let model = this.getModel(defaultModelId);
if (!model) {
const availableModels = this.getAvailableModels();
model = availableModels[0];
}
return model;
}
getEffectiveModelConfig(modelId: string): AIModelConfig {
if (modelId === 'auto') {
const defaultModel = this.getDefaultModel();
if (!defaultModel) {
throw new Error(
'No AI models are available. Please configure at least one provider.',
);
}
const modelConfig = AI_MODELS.find(
(model) => model.modelId === defaultModel.modelId,
);
if (modelConfig) {
return modelConfig;
}
return this.createDefaultConfigForCustomModel(defaultModel);
}
const predefinedModel = AI_MODELS.find(
(model) => model.modelId === modelId,
);
if (predefinedModel) {
return predefinedModel;
}
const registeredModel = this.getModel(modelId);
if (registeredModel) {
return this.createDefaultConfigForCustomModel(registeredModel);
}
throw new Error(`Model with ID ${modelId} not found`);
}
private createDefaultConfigForCustomModel(
registeredModel: RegisteredAIModel,
): AIModelConfig {
return {
modelId: registeredModel.modelId,
label: registeredModel.modelId,
provider: registeredModel.provider,
inputCostPer1kTokensInCents: 0,
outputCostPer1kTokensInCents: 0,
};
}
// Force refresh the registry (useful if config changes)
refreshRegistry(): void {
this.buildModelRegistry();
}
}

View File

@ -1,19 +1,47 @@
import { Inject, Injectable } from '@nestjs/common';
import { Injectable } from '@nestjs/common';
import { CoreMessage, StreamTextResult } from 'ai';
import { CoreMessage, StreamTextResult, streamText } from 'ai';
import { AiDriver } from 'src/engine/core-modules/ai/drivers/interfaces/ai-driver.interface';
import { AI_DRIVER } from 'src/engine/core-modules/ai/ai.constants';
import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service';
@Injectable()
export class AiService {
constructor(@Inject(AI_DRIVER) private driver: AiDriver) {}
constructor(private aiModelRegistryService: AiModelRegistryService) {}
streamText(
messages: CoreMessage[],
options?: { temperature?: number; maxTokens?: number },
options?: {
temperature?: number;
maxTokens?: number;
modelId?: string; // Optional model override
},
): StreamTextResult<Record<string, never>, undefined> {
return this.driver.streamText(messages, options);
const modelId = options?.modelId;
const registeredModel = modelId
? this.aiModelRegistryService.getModel(modelId)
: this.aiModelRegistryService.getDefaultModel();
if (!registeredModel) {
throw new Error(
modelId
? `Model "${modelId}" is not available. Please check your configuration.`
: 'No AI models are available. Please configure at least one provider.',
);
}
return streamText({
model: registeredModel.model,
messages,
temperature: options?.temperature,
maxTokens: options?.maxTokens,
});
}
getAvailableModels() {
return this.aiModelRegistryService.getAvailableModels();
}
getDefaultModel() {
return this.aiModelRegistryService.getDefaultModel();
}
}

View File

@ -1,7 +0,0 @@
import { AIModelConfig } from 'src/engine/core-modules/ai/constants/ai-models.const';
import { getEffectiveModelConfig } from './get-effective-model-config.util';
/**
 * Resolves an AI model configuration by id.
 * Thin alias over getEffectiveModelConfig, which also handles the special
 * 'auto' id by resolving it to the default model.
 */
export const getAIModelById = (modelId: string): AIModelConfig =>
  getEffectiveModelConfig(modelId);

View File

@ -1,35 +0,0 @@
import {
AI_MODELS,
ModelProvider,
} from 'src/engine/core-modules/ai/constants/ai-models.const';
import { getAIModelsWithAuto } from './get-ai-models-with-auto.util';
import { getDefaultModelConfig } from './get-default-model-config.util';
describe('getAIModelsWithAuto', () => {
it('should return AI_MODELS with auto model prepended', () => {
const ORIGINAL_MODELS = AI_MODELS;
const MODELS_WITH_AUTO = getAIModelsWithAuto();
expect(MODELS_WITH_AUTO).toHaveLength(ORIGINAL_MODELS.length + 1);
expect(MODELS_WITH_AUTO[0].modelId).toBe('auto');
expect(MODELS_WITH_AUTO[0].label).toBe('Auto');
expect(MODELS_WITH_AUTO[0].provider).toBe(ModelProvider.NONE);
// Check that the rest of the models are the same
expect(MODELS_WITH_AUTO.slice(1)).toEqual(ORIGINAL_MODELS);
});
it('should have auto model with default model costs', () => {
const MODELS_WITH_AUTO = getAIModelsWithAuto();
const AUTO_MODEL = MODELS_WITH_AUTO[0];
const DEFAULT_MODEL = getDefaultModelConfig();
expect(AUTO_MODEL.inputCostPer1kTokensInCents).toBe(
DEFAULT_MODEL.inputCostPer1kTokensInCents,
);
expect(AUTO_MODEL.outputCostPer1kTokensInCents).toBe(
DEFAULT_MODEL.outputCostPer1kTokensInCents,
);
});
});

View File

@ -1,22 +0,0 @@
import {
AI_MODELS,
AIModelConfig,
ModelProvider,
} from 'src/engine/core-modules/ai/constants/ai-models.const';
import { getDefaultModelConfig } from './get-default-model-config.util';
export const getAIModelsWithAuto = (): AIModelConfig[] => {
return [
{
modelId: 'auto',
label: 'Auto',
provider: ModelProvider.NONE,
inputCostPer1kTokensInCents:
getDefaultModelConfig().inputCostPer1kTokensInCents,
outputCostPer1kTokensInCents:
getDefaultModelConfig().outputCostPer1kTokensInCents,
},
...AI_MODELS,
];
};

View File

@ -1,29 +0,0 @@
import {
AI_MODELS,
DEFAULT_MODEL_ID,
ModelProvider,
} from 'src/engine/core-modules/ai/constants/ai-models.const';
import { getDefaultModelConfig } from './get-default-model-config.util';
describe('getDefaultModelConfig', () => {
  it('should return the configuration for the default model', () => {
    const result = getDefaultModelConfig();

    expect(result).toBeDefined();
    expect(result.modelId).toBe(DEFAULT_MODEL_ID);
    expect(result.provider).toBe(ModelProvider.OPENAI);
  });

  it('should throw an error if default model is not found', () => {
    const originalFind = AI_MODELS.find;

    // Restore in finally: without it, a failing expectation would leave
    // AI_MODELS.find permanently stubbed for the rest of the test suite.
    try {
      AI_MODELS.find = jest.fn().mockReturnValue(undefined);

      expect(() => getDefaultModelConfig()).toThrow(
        `Default model with ID ${DEFAULT_MODEL_ID} not found`,
      );
    } finally {
      AI_MODELS.find = originalFind;
    }
  });
});

View File

@ -1,17 +0,0 @@
import {
AI_MODELS,
AIModelConfig,
DEFAULT_MODEL_ID,
} from 'src/engine/core-modules/ai/constants/ai-models.const';
export const getDefaultModelConfig = (): AIModelConfig => {
const defaultModel = AI_MODELS.find(
(model) => model.modelId === DEFAULT_MODEL_ID,
);
if (!defaultModel) {
throw new Error(`Default model with ID ${DEFAULT_MODEL_ID} not found`);
}
return defaultModel;
};

View File

@ -1,30 +0,0 @@
import {
DEFAULT_MODEL_ID,
ModelProvider,
} from 'src/engine/core-modules/ai/constants/ai-models.const';
import { getEffectiveModelConfig } from './get-effective-model-config.util';
describe('getEffectiveModelConfig', () => {
  it('should return default model config when modelId is "auto"', () => {
    // 'auto' is a sentinel that resolves to the default model's config.
    const config = getEffectiveModelConfig('auto');

    expect(config).toBeDefined();
    expect(config.modelId).toBe(DEFAULT_MODEL_ID);
    expect(config.provider).toBe(ModelProvider.OPENAI);
  });

  it('should return the correct model config for a specific model', () => {
    const config = getEffectiveModelConfig('gpt-4o');

    expect(config).toBeDefined();
    expect(config.modelId).toBe('gpt-4o');
    expect(config.provider).toBe(ModelProvider.OPENAI);
  });

  it('should throw an error for non-existent model', () => {
    const lookupUnknownModel = () =>
      getEffectiveModelConfig('non-existent-model' as any);

    expect(lookupUnknownModel).toThrow(
      `Model with ID non-existent-model not found`,
    );
  });
});

View File

@ -1,20 +0,0 @@
import {
AI_MODELS,
AIModelConfig,
} from 'src/engine/core-modules/ai/constants/ai-models.const';
import { getDefaultModelConfig } from './get-default-model-config.util';
/**
 * Resolves the config for `modelId`; the special id 'auto' resolves to the
 * default model's config.
 * @throws Error when the id is not 'auto' and not present in AI_MODELS.
 */
export const getEffectiveModelConfig = (modelId: string): AIModelConfig => {
  if (modelId === 'auto') {
    return getDefaultModelConfig();
  }

  const matchingModel = AI_MODELS.find(
    (candidate) => candidate.modelId === modelId,
  );

  if (matchingModel === undefined) {
    throw new Error(`Model with ID ${modelId} not found`);
  }

  return matchingModel;
};

View File

@ -10,7 +10,6 @@ import { AuthException } from 'src/engine/core-modules/auth/auth.exception';
import { JwtAuthStrategy } from 'src/engine/core-modules/auth/strategies/jwt.auth.strategy';
import { EmailService } from 'src/engine/core-modules/email/email.service';
import { JwtWrapperService } from 'src/engine/core-modules/jwt/services/jwt-wrapper.service';
import { SSOService } from 'src/engine/core-modules/sso/services/sso.service';
import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service';
import { UserWorkspace } from 'src/engine/core-modules/user-workspace/user-workspace.entity';
import { User } from 'src/engine/core-modules/user/user.entity';
@ -74,10 +73,6 @@ describe('AccessTokenService', () => {
provide: EmailService,
useValue: {},
},
{
provide: SSOService,
useValue: {},
},
{
provide: TwentyORMGlobalManager,
useValue: {

View File

@ -12,7 +12,6 @@ import { RefreshTokenService } from 'src/engine/core-modules/auth/token/services
import { RenewTokenService } from 'src/engine/core-modules/auth/token/services/renew-token.service';
import { WorkspaceAgnosticTokenService } from 'src/engine/core-modules/auth/token/services/workspace-agnostic-token.service';
import { JwtModule } from 'src/engine/core-modules/jwt/jwt.module';
import { WorkspaceSSOModule } from 'src/engine/core-modules/sso/sso.module';
import { UserWorkspace } from 'src/engine/core-modules/user-workspace/user-workspace.entity';
import { User } from 'src/engine/core-modules/user/user.entity';
import { Workspace } from 'src/engine/core-modules/workspace/workspace.entity';
@ -27,7 +26,6 @@ import { DataSourceModule } from 'src/engine/metadata-modules/data-source/data-s
),
TypeORMModule,
DataSourceModule,
WorkspaceSSOModule,
],
providers: [
RenewTokenService,

View File

@ -3,19 +3,13 @@ import { Test, TestingModule } from '@nestjs/testing';
import { NodeEnvironment } from 'src/engine/core-modules/twenty-config/interfaces/node-environment.interface';
import { SupportDriver } from 'src/engine/core-modules/twenty-config/interfaces/support.interface';
import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service';
import { CaptchaDriverType } from 'src/engine/core-modules/captcha/interfaces';
import { ClientConfigService } from 'src/engine/core-modules/client-config/services/client-config.service';
import { DomainManagerService } from 'src/engine/core-modules/domain-manager/services/domain-manager.service';
import { PUBLIC_FEATURE_FLAGS } from 'src/engine/core-modules/feature-flag/constants/public-feature-flag.const';
import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service';
jest.mock(
'src/engine/core-modules/ai/utils/get-ai-models-with-auto.util',
() => ({
getAIModelsWithAuto: jest.fn(() => []),
}),
);
describe('ClientConfigService', () => {
let service: ClientConfigService;
let twentyConfigService: TwentyConfigService;
@ -37,6 +31,12 @@ describe('ClientConfigService', () => {
getFrontUrl: jest.fn(),
},
},
{
provide: AiModelRegistryService,
useValue: {
getAvailableModels: jest.fn().mockReturnValue([]),
},
},
],
}).compile();

View File

@ -3,9 +3,12 @@ import { Injectable } from '@nestjs/common';
import { NodeEnvironment } from 'src/engine/core-modules/twenty-config/interfaces/node-environment.interface';
import { SupportDriver } from 'src/engine/core-modules/twenty-config/interfaces/support.interface';
import { ModelProvider } from 'src/engine/core-modules/ai/constants/ai-models.const';
import {
AI_MODELS,
ModelProvider,
} from 'src/engine/core-modules/ai/constants/ai-models.const';
import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service';
import { convertCentsToBillingCredits } from 'src/engine/core-modules/ai/utils/convert-cents-to-billing-credits.util';
import { getAIModelsWithAuto } from 'src/engine/core-modules/ai/utils/get-ai-models-with-auto.util';
import {
ClientAIModelConfig,
ClientConfig,
@ -19,41 +22,49 @@ export class ClientConfigService {
constructor(
private twentyConfigService: TwentyConfigService,
private domainManagerService: DomainManagerService,
private aiModelRegistryService: AiModelRegistryService,
) {}
async getClientConfig(): Promise<ClientConfig> {
const captchaProvider = this.twentyConfigService.get('CAPTCHA_DRIVER');
const supportDriver = this.twentyConfigService.get('SUPPORT_DRIVER');
const openaiApiKey = this.twentyConfigService.get('OPENAI_API_KEY');
const anthropicApiKey = this.twentyConfigService.get('ANTHROPIC_API_KEY');
const aiModels = getAIModelsWithAuto().reduce<ClientAIModelConfig[]>(
(acc, model) => {
const isAvailable =
(model.provider === ModelProvider.OPENAI && openaiApiKey) ||
(model.provider === ModelProvider.ANTHROPIC && anthropicApiKey);
const availableModels = this.aiModelRegistryService.getAvailableModels();
if (!isAvailable) {
return acc;
}
const aiModels: ClientAIModelConfig[] = availableModels.map(
(registeredModel) => {
const builtInModel = AI_MODELS.find(
(m) => m.modelId === registeredModel.modelId,
);
acc.push({
modelId: model.modelId,
label: model.label,
provider: model.provider,
inputCostPer1kTokensInCredits: convertCentsToBillingCredits(
model.inputCostPer1kTokensInCents,
),
outputCostPer1kTokensInCredits: convertCentsToBillingCredits(
model.outputCostPer1kTokensInCents,
),
});
return acc;
return {
modelId: registeredModel.modelId,
label: builtInModel?.label || registeredModel.modelId,
provider: registeredModel.provider,
inputCostPer1kTokensInCredits: builtInModel
? convertCentsToBillingCredits(
builtInModel.inputCostPer1kTokensInCents,
)
: 0,
outputCostPer1kTokensInCredits: builtInModel
? convertCentsToBillingCredits(
builtInModel.outputCostPer1kTokensInCents,
)
: 0,
};
},
[],
);
if (aiModels.length > 0) {
aiModels.unshift({
modelId: 'auto',
label: 'Auto',
provider: ModelProvider.NONE,
inputCostPer1kTokensInCredits: 0,
outputCostPer1kTokensInCredits: 0,
});
}
const clientConfig: ClientConfig = {
billing: {
isBillingEnabled: this.twentyConfigService.get('IS_BILLING_ENABLED'),

View File

@ -6,7 +6,6 @@ import { WorkspaceQueryRunnerModule } from 'src/engine/api/graphql/workspace-que
import { ActorModule } from 'src/engine/core-modules/actor/actor.module';
import { AdminPanelModule } from 'src/engine/core-modules/admin-panel/admin-panel.module';
import { AiModule } from 'src/engine/core-modules/ai/ai.module';
import { aiModuleFactory } from 'src/engine/core-modules/ai/ai.module-factory';
import { ApiKeyModule } from 'src/engine/core-modules/api-key/api-key.module';
import { AppTokenModule } from 'src/engine/core-modules/app-token/app-token.module';
import { ApprovedAccessDomainModule } from 'src/engine/core-modules/approved-access-domain/approved-access-domain.module';
@ -109,10 +108,7 @@ import { FileModule } from './file/file.module';
wildcard: true,
}),
CacheStorageModule,
AiModule.forRoot({
useFactory: aiModuleFactory,
inject: [TwentyConfigService],
}),
AiModule,
ServerlessModule.forRootAsync({
useFactory: serverlessModuleFactory,
inject: [TwentyConfigService, FileStorageService],

View File

@ -11,7 +11,6 @@ import {
} from 'class-validator';
import { isDefined } from 'twenty-shared/utils';
import { AiDriver } from 'src/engine/core-modules/ai/interfaces/ai.interface';
import { AwsRegion } from 'src/engine/core-modules/twenty-config/interfaces/aws-region.interface';
import { NodeEnvironment } from 'src/engine/core-modules/twenty-config/interfaces/node-environment.interface';
import { SupportDriver } from 'src/engine/core-modules/twenty-config/interfaces/support.interface';
@ -975,13 +974,12 @@ export class ConfigVariables {
@ConfigVariablesMetadata({
group: ConfigVariablesGroup.LLM,
description: 'Driver for the AI chat model',
type: ConfigVariableType.ENUM,
options: Object.values(AiDriver),
isEnvOnly: true,
description:
'Default model ID for AI operations (can be any available model)',
type: ConfigVariableType.STRING,
})
@CastToUpperSnakeCase()
AI_DRIVER: AiDriver;
@IsOptional()
DEFAULT_MODEL_ID = 'gpt-4o';
@ConfigVariablesMetadata({
group: ConfigVariablesGroup.LLM,
@ -989,6 +987,7 @@ export class ConfigVariables {
description: 'API key for OpenAI integration',
type: ConfigVariableType.STRING,
})
@IsOptional()
OPENAI_API_KEY: string;
@ConfigVariablesMetadata({
@ -997,8 +996,37 @@ export class ConfigVariables {
description: 'API key for Anthropic integration',
type: ConfigVariableType.STRING,
})
@IsOptional()
ANTHROPIC_API_KEY: string;
@ConfigVariablesMetadata({
group: ConfigVariablesGroup.LLM,
description: 'Base URL for OpenAI-compatible LLM provider (e.g., Ollama)',
type: ConfigVariableType.STRING,
})
@IsOptional()
@IsUrl({ require_tld: false, require_protocol: true })
OPENAI_COMPATIBLE_BASE_URL: string;
@ConfigVariablesMetadata({
group: ConfigVariablesGroup.LLM,
description:
'Model names for OpenAI-compatible LLM provider (comma-separated, e.g., "llama3.1, mistral, codellama")',
type: ConfigVariableType.STRING,
})
@IsOptional()
OPENAI_COMPATIBLE_MODEL_NAMES: string;
@ConfigVariablesMetadata({
group: ConfigVariablesGroup.LLM,
isSensitive: true,
description:
'API key for OpenAI-compatible LLM provider (optional for providers like Ollama)',
type: ConfigVariableType.STRING,
})
@IsOptional()
OPENAI_COMPATIBLE_API_KEY: string;
@ConfigVariablesMetadata({
group: ConfigVariablesGroup.ServerConfig,
description: 'Enable or disable multi-workspace support',

View File

@ -1,4 +1,4 @@
import { Injectable } from '@nestjs/common';
import { Injectable, Logger } from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm';
import { createAnthropic } from '@ai-sdk/anthropic';
@ -10,7 +10,7 @@ import {
ModelId,
ModelProvider,
} from 'src/engine/core-modules/ai/constants/ai-models.const';
import { getEffectiveModelConfig } from 'src/engine/core-modules/ai/utils/get-effective-model-config.util';
import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service';
import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service';
import {
AgentChatMessageEntity,
@ -22,7 +22,6 @@ import { AGENT_SYSTEM_PROMPTS } from 'src/engine/metadata-modules/agent/constant
import { convertOutputSchemaToZod } from 'src/engine/metadata-modules/agent/utils/convert-output-schema-to-zod';
import { OutputSchema } from 'src/modules/workflow/workflow-builder/workflow-schema/types/output-schema.type';
import { resolveInput } from 'src/modules/workflow/workflow-executor/utils/variable-resolver.util';
import { getAIModelById } from 'src/engine/core-modules/ai/utils/get-ai-model-by-id.util';
import { AgentEntity } from './agent.entity';
import { AgentException, AgentExceptionCode } from './agent.exception';
@ -41,21 +40,27 @@ export interface AgentExecutionResult {
@Injectable()
export class AgentExecutionService {
private readonly logger = new Logger(AgentExecutionService.name);
constructor(
private readonly twentyConfigService: TwentyConfigService,
private readonly agentToolService: AgentToolService,
private readonly aiModelRegistryService: AiModelRegistryService,
@InjectRepository(AgentEntity, 'core')
private readonly agentRepository: Repository<AgentEntity>,
) {}
getModel = (modelId: ModelId, provider: ModelProvider) => {
switch (provider) {
case ModelProvider.NONE: {
case ModelProvider.OPENAI_COMPATIBLE: {
const OpenAIProvider = createOpenAI({
apiKey: this.twentyConfigService.get('OPENAI_API_KEY'),
baseURL: this.twentyConfigService.get('OPENAI_COMPATIBLE_BASE_URL'),
apiKey: this.twentyConfigService.get('OPENAI_COMPATIBLE_API_KEY'),
});
return OpenAIProvider(getEffectiveModelConfig(modelId).modelId);
return OpenAIProvider(
this.aiModelRegistryService.getEffectiveModelConfig(modelId).modelId,
);
}
case ModelProvider.OPENAI: {
const OpenAIProvider = createOpenAI({
@ -83,9 +88,6 @@ export class AgentExecutionService {
let apiKey: string | undefined;
switch (provider) {
case ModelProvider.NONE:
apiKey = this.twentyConfigService.get('OPENAI_API_KEY');
break;
case ModelProvider.OPENAI:
apiKey = this.twentyConfigService.get('OPENAI_API_KEY');
break;
@ -93,14 +95,11 @@ export class AgentExecutionService {
apiKey = this.twentyConfigService.get('ANTHROPIC_API_KEY');
break;
default:
throw new AgentException(
`Unsupported provider: ${provider}`,
AgentExceptionCode.AGENT_EXECUTION_FAILED,
);
return;
}
if (!apiKey) {
throw new AgentException(
`${provider === ModelProvider.NONE ? 'OPENAI' : provider.toUpperCase()} API key not configured`,
`${provider.toUpperCase()} API key not configured`,
AgentExceptionCode.API_KEY_NOT_CONFIGURED,
);
}
@ -117,31 +116,55 @@ export class AgentExecutionService {
prompt?: string;
messages?: CoreMessage[];
}) {
const aiModel = getAIModelById(agent.modelId);
if (!aiModel) {
throw new AgentException(
`AI model with id ${agent.modelId} not found`,
AgentExceptionCode.AGENT_EXECUTION_FAILED,
try {
this.logger.log(
`Preparing AI request config for agent ${agent.id} with model ${agent.modelId}`,
);
const aiModel = this.aiModelRegistryService.getEffectiveModelConfig(
agent.modelId,
);
if (!aiModel) {
const error = `AI model with id ${agent.modelId} not found`;
this.logger.error(error);
throw new AgentException(
error,
AgentExceptionCode.AGENT_EXECUTION_FAILED,
);
}
this.logger.log(
`Resolved model: ${aiModel.modelId} (provider: ${aiModel.provider})`,
);
const provider = aiModel.provider;
await this.validateApiKey(provider);
const tools = await this.agentToolService.generateToolsForAgent(
agent.id,
agent.workspaceId,
);
this.logger.log(`Generated ${Object.keys(tools).length} tools for agent`);
return {
system,
tools,
model: this.getModel(aiModel.modelId, aiModel.provider),
...(messages && { messages }),
...(prompt && { prompt }),
maxSteps: AGENT_CONFIG.MAX_STEPS,
};
} catch (error) {
this.logger.error(
`Failed to prepare AI request config for agent ${agent.id}:`,
error instanceof Error ? error.stack : error,
);
throw error;
}
const provider = aiModel.provider;
await this.validateApiKey(provider);
const tools = await this.agentToolService.generateToolsForAgent(
agent.id,
agent.workspaceId,
);
return {
system,
tools,
model: this.getModel(agent.modelId, aiModel.provider),
...(messages && { messages }),
...(prompt && { prompt }),
maxSteps: AGENT_CONFIG.MAX_STEPS,
};
}
async streamChatResponse({
@ -173,6 +196,10 @@ export class AgentExecutionService {
messages: llmMessages,
});
this.logger.log(
`Sending request to AI model with ${llmMessages.length} messages`,
);
return streamText(aiRequestConfig);
}

View File

@ -1,4 +1,4 @@
import { Injectable } from '@nestjs/common';
import { Injectable, Logger } from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm';
import { Response } from 'express';
@ -28,6 +28,8 @@ export type StreamAgentChatResult = {
@Injectable()
export class AgentStreamingService {
private readonly logger = new Logger(AgentStreamingService.name);
constructor(
@InjectRepository(AgentChatThreadEntity, 'core')
private readonly threadRepository: Repository<AgentChatThreadEntity>,
@ -89,7 +91,11 @@ export class AgentStreamingService {
message: chunk.args?.toolDescription,
});
break;
case 'error':
this.logger.error(`Stream error: ${JSON.stringify(chunk)}`);
break;
default:
this.logger.log(`Unknown chunk type: ${chunk.type}`);
break;
}
}
@ -105,6 +111,10 @@ export class AgentStreamingService {
const errorMessage =
error instanceof Error ? error.message : 'Unknown error occurred';
if (error instanceof AgentException) {
this.logger.error(`Agent Exception Code: ${error.code}`);
}
if (!res.headersSent) {
this.setupStreamingHeaders(res);
}

View File

@ -1,8 +1,6 @@
import { Module } from '@nestjs/common';
import { TypeOrmModule } from '@nestjs/typeorm';
import { AiDriver } from 'src/engine/core-modules/ai/interfaces/ai.interface';
import { AiModule } from 'src/engine/core-modules/ai/ai.module';
import { AgentEntity } from 'src/engine/metadata-modules/agent/agent.entity';
import { AgentModule } from 'src/engine/metadata-modules/agent/agent.module';
@ -13,9 +11,7 @@ import { AiAgentWorkflowAction } from './ai-agent.workflow-action';
@Module({
imports: [
AgentModule,
AiModule.forRoot({
useFactory: () => ({ type: AiDriver.OPENAI }),
}),
AiModule,
TypeOrmModule.forFeature([AgentEntity], 'core'),
],
providers: [ScopedWorkspaceContextFactory, AiAgentWorkflowAction],