Refresh AI model setup (#13171)
Instead of initializing model at start time we do it at run time to be able to swap model provider more easily. Also introduce a third driver for openai-compatible providers, which among other allows for local models with Ollama
This commit is contained in:
@ -1041,7 +1041,8 @@ export enum MessageChannelVisibility {
|
||||
export enum ModelProvider {
|
||||
ANTHROPIC = 'ANTHROPIC',
|
||||
NONE = 'NONE',
|
||||
OPENAI = 'OPENAI'
|
||||
OPENAI = 'OPENAI',
|
||||
OPENAI_COMPATIBLE = 'OPENAI_COMPATIBLE'
|
||||
}
|
||||
|
||||
export type Mutation = {
|
||||
|
||||
@ -998,7 +998,8 @@ export enum MessageChannelVisibility {
|
||||
export enum ModelProvider {
|
||||
ANTHROPIC = 'ANTHROPIC',
|
||||
NONE = 'NONE',
|
||||
OPENAI = 'OPENAI'
|
||||
OPENAI = 'OPENAI',
|
||||
OPENAI_COMPATIBLE = 'OPENAI_COMPATIBLE'
|
||||
}
|
||||
|
||||
export type Mutation = {
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import { useSnackBar } from '@/ui/feedback/snack-bar-manager/hooks/useSnackBar';
|
||||
import { useRecoilComponentStateV2 } from '@/ui/utilities/state/component-state/hooks/useRecoilComponentStateV2';
|
||||
import { useState } from 'react';
|
||||
import { useRecoilState } from 'recoil';
|
||||
@ -22,6 +23,7 @@ interface OptimisticMessage extends AgentChatMessage {
|
||||
|
||||
export const useAgentChat = (agentId: string) => {
|
||||
const apolloClient = useApolloClient();
|
||||
const { enqueueErrorSnackBar } = useSnackBar();
|
||||
|
||||
const [agentChatMessages, setAgentChatMessages] = useRecoilComponentStateV2(
|
||||
agentChatMessagesComponentState,
|
||||
@ -112,6 +114,11 @@ export const useAgentChat = (agentId: string) => {
|
||||
}));
|
||||
scrollToBottom();
|
||||
},
|
||||
onError: (message: string) => {
|
||||
enqueueErrorSnackBar({
|
||||
message,
|
||||
});
|
||||
},
|
||||
});
|
||||
},
|
||||
},
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
export type AgentStreamingEvent = {
|
||||
type: 'text-delta' | 'tool-call';
|
||||
type: 'text-delta' | 'tool-call' | 'error';
|
||||
message: string;
|
||||
};
|
||||
|
||||
export type AgentStreamingParserCallbacks = {
|
||||
onTextDelta?: (message: string) => void;
|
||||
onToolCall?: (message: string) => void;
|
||||
onError?: (error: Error, rawLine: string) => void;
|
||||
onError?: (message: string) => void;
|
||||
onParseError?: (error: Error, rawLine: string) => void;
|
||||
};
|
||||
|
||||
export const parseAgentStreamingChunk = (
|
||||
@ -27,6 +28,9 @@ export const parseAgentStreamingChunk = (
|
||||
case 'tool-call':
|
||||
callbacks.onToolCall?.(event.message);
|
||||
break;
|
||||
case 'error':
|
||||
callbacks.onError?.(event.message);
|
||||
break;
|
||||
}
|
||||
} catch (error) {
|
||||
// eslint-disable-next-line no-console
|
||||
@ -36,7 +40,7 @@ export const parseAgentStreamingChunk = (
|
||||
error instanceof Error
|
||||
? error
|
||||
: new Error(`Unknown parsing error: ${String(error)}`);
|
||||
callbacks.onError?.(errorMessage, line);
|
||||
callbacks.onParseError?.(errorMessage, line);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1 +0,0 @@
|
||||
export const AI_DRIVER = Symbol('AI_DRIVER');
|
||||
@ -1,17 +0,0 @@
|
||||
import { AiDriver } from 'src/engine/core-modules/ai/interfaces/ai.interface';
|
||||
|
||||
import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service';
|
||||
|
||||
export const aiModuleFactory = (twentyConfigService: TwentyConfigService) => {
|
||||
const driver = twentyConfigService.get('AI_DRIVER');
|
||||
|
||||
switch (driver) {
|
||||
case AiDriver.OPENAI: {
|
||||
return { type: AiDriver.OPENAI };
|
||||
}
|
||||
default: {
|
||||
// Default to OpenAI driver if no driver is specified
|
||||
return { type: AiDriver.OPENAI };
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -1,65 +1,46 @@
|
||||
import { DynamicModule, Global, Provider } from '@nestjs/common';
|
||||
import { Global, Module } from '@nestjs/common';
|
||||
import { TypeOrmModule } from '@nestjs/typeorm';
|
||||
|
||||
import {
|
||||
AiDriver,
|
||||
AiModuleAsyncOptions,
|
||||
} from 'src/engine/core-modules/ai/interfaces/ai.interface';
|
||||
|
||||
import { AI_DRIVER } from 'src/engine/core-modules/ai/ai.constants';
|
||||
import { AiService } from 'src/engine/core-modules/ai/services/ai.service';
|
||||
import { AiController } from 'src/engine/core-modules/ai/controllers/ai.controller';
|
||||
import { OpenAIDriver } from 'src/engine/core-modules/ai/drivers/openai.driver';
|
||||
import { AIBillingService } from 'src/engine/core-modules/ai/services/ai-billing.service';
|
||||
import { FeatureFlagModule } from 'src/engine/core-modules/feature-flag/feature-flag.module';
|
||||
import { McpController } from 'src/engine/core-modules/ai/controllers/mcp.controller';
|
||||
import { AuthModule } from 'src/engine/core-modules/auth/auth.module';
|
||||
import { ToolService } from 'src/engine/core-modules/ai/services/tool.service';
|
||||
import { AIBillingService } from 'src/engine/core-modules/ai/services/ai-billing.service';
|
||||
import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service';
|
||||
import { AiService } from 'src/engine/core-modules/ai/services/ai.service';
|
||||
import { McpService } from 'src/engine/core-modules/ai/services/mcp.service';
|
||||
import { ToolService } from 'src/engine/core-modules/ai/services/tool.service';
|
||||
import { TokenModule } from 'src/engine/core-modules/auth/token/token.module';
|
||||
import { FeatureFlagModule } from 'src/engine/core-modules/feature-flag/feature-flag.module';
|
||||
import { ObjectMetadataModule } from 'src/engine/metadata-modules/object-metadata/object-metadata.module';
|
||||
import { RoleEntity } from 'src/engine/metadata-modules/role/role.entity';
|
||||
import { UserRoleModule } from 'src/engine/metadata-modules/user-role/user-role.module';
|
||||
import { WorkspacePermissionsCacheModule } from 'src/engine/metadata-modules/workspace-permissions-cache/workspace-permissions-cache.module';
|
||||
import { WorkspaceCacheStorageModule } from 'src/engine/workspace-cache-storage/workspace-cache-storage.module';
|
||||
import { UserRoleModule } from 'src/engine/metadata-modules/user-role/user-role.module';
|
||||
import { RoleEntity } from 'src/engine/metadata-modules/role/role.entity';
|
||||
|
||||
@Global()
|
||||
export class AiModule {
|
||||
static forRoot(options: AiModuleAsyncOptions): DynamicModule {
|
||||
const provider: Provider = {
|
||||
provide: AI_DRIVER,
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
useFactory: (...args: any[]) => {
|
||||
const config = options.useFactory(...args);
|
||||
|
||||
switch (config?.type) {
|
||||
case AiDriver.OPENAI: {
|
||||
return new OpenAIDriver();
|
||||
}
|
||||
}
|
||||
},
|
||||
inject: options.inject || [],
|
||||
};
|
||||
|
||||
return {
|
||||
module: AiModule,
|
||||
imports: [
|
||||
TypeOrmModule.forFeature([RoleEntity], 'core'),
|
||||
FeatureFlagModule,
|
||||
ObjectMetadataModule,
|
||||
WorkspacePermissionsCacheModule,
|
||||
WorkspaceCacheStorageModule,
|
||||
UserRoleModule,
|
||||
AuthModule,
|
||||
],
|
||||
controllers: [AiController, McpController],
|
||||
providers: [
|
||||
AiService,
|
||||
ToolService,
|
||||
AIBillingService,
|
||||
McpService,
|
||||
provider,
|
||||
],
|
||||
exports: [AiService, AIBillingService, ToolService, McpService],
|
||||
};
|
||||
}
|
||||
}
|
||||
@Module({
|
||||
imports: [
|
||||
TypeOrmModule.forFeature([RoleEntity], 'core'),
|
||||
TokenModule,
|
||||
FeatureFlagModule,
|
||||
ObjectMetadataModule,
|
||||
WorkspacePermissionsCacheModule,
|
||||
WorkspaceCacheStorageModule,
|
||||
UserRoleModule,
|
||||
],
|
||||
controllers: [AiController, McpController],
|
||||
providers: [
|
||||
AiService,
|
||||
AiModelRegistryService,
|
||||
ToolService,
|
||||
AIBillingService,
|
||||
McpService,
|
||||
],
|
||||
exports: [
|
||||
AiService,
|
||||
AiModelRegistryService,
|
||||
AIBillingService,
|
||||
ToolService,
|
||||
McpService,
|
||||
],
|
||||
})
|
||||
export class AiModule {}
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import { getAIModelsWithAuto } from 'src/engine/core-modules/ai/utils/get-ai-models-with-auto.util';
|
||||
import { getDefaultModelConfig } from 'src/engine/core-modules/ai/utils/get-default-model-config.util';
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { AI_MODELS, DEFAULT_MODEL_ID, ModelProvider } from './ai-models.const';
|
||||
import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service';
|
||||
import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service';
|
||||
|
||||
import { AI_MODELS, ModelProvider } from './ai-models.const';
|
||||
|
||||
describe('AI_MODELS', () => {
|
||||
it('should contain all expected models', () => {
|
||||
@ -15,41 +17,120 @@ describe('AI_MODELS', () => {
|
||||
'claude-3-5-haiku-20241022',
|
||||
]);
|
||||
});
|
||||
|
||||
it('should have the default model as the first model', () => {
|
||||
const DEFAULT_MODEL = AI_MODELS.find(
|
||||
(model) => model.modelId === DEFAULT_MODEL_ID,
|
||||
);
|
||||
|
||||
expect(DEFAULT_MODEL).toBeDefined();
|
||||
expect(DEFAULT_MODEL?.modelId).toBe(DEFAULT_MODEL_ID);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAIModelsWithAuto', () => {
|
||||
it('should return AI_MODELS with auto model prepended', () => {
|
||||
const ORIGINAL_MODELS = AI_MODELS;
|
||||
const MODELS_WITH_AUTO = getAIModelsWithAuto();
|
||||
describe('AiModelRegistryService', () => {
|
||||
let SERVICE: AiModelRegistryService;
|
||||
let MOCK_CONFIG_SERVICE: jest.Mocked<TwentyConfigService>;
|
||||
|
||||
expect(MODELS_WITH_AUTO).toHaveLength(ORIGINAL_MODELS.length + 1);
|
||||
expect(MODELS_WITH_AUTO[0].modelId).toBe('auto');
|
||||
expect(MODELS_WITH_AUTO[0].label).toBe('Auto');
|
||||
expect(MODELS_WITH_AUTO[0].provider).toBe(ModelProvider.NONE);
|
||||
beforeEach(async () => {
|
||||
MOCK_CONFIG_SERVICE = {
|
||||
get: jest.fn(),
|
||||
} as any;
|
||||
|
||||
// Check that the rest of the models are the same
|
||||
expect(MODELS_WITH_AUTO.slice(1)).toEqual(ORIGINAL_MODELS);
|
||||
const MODULE: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
AiModelRegistryService,
|
||||
{
|
||||
provide: TwentyConfigService,
|
||||
useValue: MOCK_CONFIG_SERVICE,
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
SERVICE = MODULE.get<AiModelRegistryService>(AiModelRegistryService);
|
||||
});
|
||||
|
||||
it('should have auto model with default model costs', () => {
|
||||
const MODELS_WITH_AUTO = getAIModelsWithAuto();
|
||||
const AUTO_MODEL = MODELS_WITH_AUTO[0];
|
||||
const DEFAULT_MODEL = getDefaultModelConfig();
|
||||
it('should return effective model config for auto', () => {
|
||||
MOCK_CONFIG_SERVICE.get.mockReturnValue('gpt-4o');
|
||||
|
||||
expect(AUTO_MODEL.inputCostPer1kTokensInCents).toBe(
|
||||
DEFAULT_MODEL.inputCostPer1kTokensInCents,
|
||||
expect(() => SERVICE.getEffectiveModelConfig('auto')).toThrow(
|
||||
'No AI models are available. Please configure at least one provider.',
|
||||
);
|
||||
expect(AUTO_MODEL.outputCostPer1kTokensInCents).toBe(
|
||||
DEFAULT_MODEL.outputCostPer1kTokensInCents,
|
||||
});
|
||||
|
||||
it('should return effective model config for auto when models are available', () => {
|
||||
MOCK_CONFIG_SERVICE.get.mockReturnValue('gpt-4o');
|
||||
|
||||
jest.spyOn(SERVICE, 'getAvailableModels').mockReturnValue([
|
||||
{
|
||||
modelId: 'gpt-4o',
|
||||
provider: ModelProvider.OPENAI,
|
||||
model: {} as any,
|
||||
},
|
||||
]);
|
||||
|
||||
jest.spyOn(SERVICE, 'getModel').mockReturnValue({
|
||||
modelId: 'gpt-4o',
|
||||
provider: ModelProvider.OPENAI,
|
||||
model: {} as any,
|
||||
});
|
||||
|
||||
const RESULT = SERVICE.getEffectiveModelConfig('auto');
|
||||
|
||||
expect(RESULT).toBeDefined();
|
||||
expect(RESULT.modelId).toBe('gpt-4o');
|
||||
expect(RESULT.provider).toBe(ModelProvider.OPENAI);
|
||||
});
|
||||
|
||||
it('should return effective model config for auto with custom model', () => {
|
||||
MOCK_CONFIG_SERVICE.get.mockReturnValue('mistral');
|
||||
|
||||
jest.spyOn(SERVICE, 'getAvailableModels').mockReturnValue([
|
||||
{
|
||||
modelId: 'mistral',
|
||||
provider: ModelProvider.OPENAI_COMPATIBLE,
|
||||
model: {} as any,
|
||||
},
|
||||
]);
|
||||
|
||||
jest.spyOn(SERVICE, 'getModel').mockReturnValue({
|
||||
modelId: 'mistral',
|
||||
provider: ModelProvider.OPENAI_COMPATIBLE,
|
||||
model: {} as any,
|
||||
});
|
||||
|
||||
const RESULT = SERVICE.getEffectiveModelConfig('auto');
|
||||
|
||||
expect(RESULT).toBeDefined();
|
||||
expect(RESULT.modelId).toBe('mistral');
|
||||
expect(RESULT.provider).toBe(ModelProvider.OPENAI_COMPATIBLE);
|
||||
expect(RESULT.label).toBe('mistral');
|
||||
expect(RESULT.inputCostPer1kTokensInCents).toBe(0);
|
||||
expect(RESULT.outputCostPer1kTokensInCents).toBe(0);
|
||||
});
|
||||
|
||||
it('should return effective model config for specific model', () => {
|
||||
const RESULT = SERVICE.getEffectiveModelConfig('gpt-4o-mini');
|
||||
|
||||
expect(RESULT).toBeDefined();
|
||||
expect(RESULT.modelId).toBe('gpt-4o-mini');
|
||||
expect(RESULT.provider).toBe(ModelProvider.OPENAI);
|
||||
});
|
||||
|
||||
it('should return effective model config for custom model', () => {
|
||||
// Mock that the custom model exists in registry
|
||||
jest.spyOn(SERVICE, 'getModel').mockReturnValue({
|
||||
modelId: 'mistral',
|
||||
provider: ModelProvider.OPENAI_COMPATIBLE,
|
||||
model: {} as any,
|
||||
});
|
||||
|
||||
const RESULT = SERVICE.getEffectiveModelConfig('mistral');
|
||||
|
||||
expect(RESULT).toBeDefined();
|
||||
expect(RESULT.modelId).toBe('mistral');
|
||||
expect(RESULT.provider).toBe(ModelProvider.OPENAI_COMPATIBLE);
|
||||
expect(RESULT.label).toBe('mistral');
|
||||
expect(RESULT.inputCostPer1kTokensInCents).toBe(0);
|
||||
expect(RESULT.outputCostPer1kTokensInCents).toBe(0);
|
||||
});
|
||||
|
||||
it('should throw error for non-existent model', () => {
|
||||
jest.spyOn(SERVICE, 'getModel').mockReturnValue(undefined);
|
||||
|
||||
expect(() => SERVICE.getEffectiveModelConfig('non-existent-model')).toThrow(
|
||||
'Model with ID non-existent-model not found',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@ -2,6 +2,7 @@ export enum ModelProvider {
|
||||
NONE = 'none',
|
||||
OPENAI = 'openai',
|
||||
ANTHROPIC = 'anthropic',
|
||||
OPENAI_COMPATIBLE = 'open_ai_compatible',
|
||||
}
|
||||
|
||||
export type ModelId =
|
||||
@ -11,9 +12,8 @@ export type ModelId =
|
||||
| 'gpt-4-turbo'
|
||||
| 'claude-opus-4-20250514'
|
||||
| 'claude-sonnet-4-20250514'
|
||||
| 'claude-3-5-haiku-20241022';
|
||||
|
||||
export const DEFAULT_MODEL_ID: ModelId = 'gpt-4o';
|
||||
| 'claude-3-5-haiku-20241022'
|
||||
| string; // Allow custom model names
|
||||
|
||||
export interface AIModelConfig {
|
||||
modelId: ModelId;
|
||||
|
||||
@ -1,8 +0,0 @@
|
||||
import { CoreMessage, StreamTextResult } from 'ai';
|
||||
|
||||
export interface AiDriver {
|
||||
streamText(
|
||||
messages: CoreMessage[],
|
||||
options?: { temperature?: number; maxTokens?: number },
|
||||
): StreamTextResult<Record<string, never>, undefined>;
|
||||
}
|
||||
@ -1,18 +0,0 @@
|
||||
import { openai } from '@ai-sdk/openai';
|
||||
import { CoreMessage, StreamTextResult, streamText } from 'ai';
|
||||
|
||||
import { AiDriver } from 'src/engine/core-modules/ai/drivers/interfaces/ai-driver.interface';
|
||||
|
||||
export class OpenAIDriver implements AiDriver {
|
||||
streamText(
|
||||
messages: CoreMessage[],
|
||||
options?: { temperature?: number; maxTokens?: number },
|
||||
): StreamTextResult<Record<string, never>, undefined> {
|
||||
return streamText({
|
||||
model: openai('gpt-4o-mini'),
|
||||
messages,
|
||||
temperature: options?.temperature,
|
||||
maxTokens: options?.maxTokens,
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -1,14 +0,0 @@
|
||||
import { InjectionToken, OptionalFactoryDependency } from '@nestjs/common';
|
||||
|
||||
export enum AiDriver {
|
||||
OPENAI = 'openai',
|
||||
}
|
||||
|
||||
export interface AiModuleOptions {
|
||||
type: AiDriver;
|
||||
}
|
||||
|
||||
export type AiModuleAsyncOptions = {
|
||||
inject?: (InjectionToken | OptionalFactoryDependency)[];
|
||||
useFactory: (...args: unknown[]) => AiModuleOptions | undefined;
|
||||
};
|
||||
@ -5,6 +5,7 @@ import { BillingMeterEventName } from 'src/engine/core-modules/billing/enums/bil
|
||||
import { WorkspaceEventEmitter } from 'src/engine/workspace-event-emitter/workspace-event-emitter';
|
||||
|
||||
import { AIBillingService } from './ai-billing.service';
|
||||
import { AiModelRegistryService } from './ai-model-registry.service';
|
||||
|
||||
describe('AIBillingService', () => {
|
||||
let service: AIBillingService;
|
||||
@ -21,6 +22,16 @@ describe('AIBillingService', () => {
|
||||
emitCustomBatchEvent: jest.fn(),
|
||||
};
|
||||
|
||||
const mockAiModelRegistryMethods = {
|
||||
getEffectiveModelConfig: jest.fn().mockReturnValue({
|
||||
modelId: 'gpt-4o',
|
||||
label: 'GPT-4o',
|
||||
provider: 'openai',
|
||||
inputCostPer1kTokensInCents: 0.25,
|
||||
outputCostPer1kTokensInCents: 1.0,
|
||||
}),
|
||||
};
|
||||
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
AIBillingService,
|
||||
@ -28,6 +39,10 @@ describe('AIBillingService', () => {
|
||||
provide: WorkspaceEventEmitter,
|
||||
useValue: mockEventEmitterMethods,
|
||||
},
|
||||
{
|
||||
provide: AiModelRegistryService,
|
||||
useValue: mockAiModelRegistryMethods,
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ import { Injectable, Logger } from '@nestjs/common';
|
||||
|
||||
import { ModelId } from 'src/engine/core-modules/ai/constants/ai-models.const';
|
||||
import { DOLLAR_TO_CREDIT_MULTIPLIER } from 'src/engine/core-modules/ai/constants/dollar-to-credit-multiplier';
|
||||
import { getAIModelById } from 'src/engine/core-modules/ai/utils/get-ai-model-by-id.util';
|
||||
import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service';
|
||||
import { BILLING_FEATURE_USED } from 'src/engine/core-modules/billing/constants/billing-feature-used.constant';
|
||||
import { BillingMeterEventName } from 'src/engine/core-modules/billing/enums/billing-meter-event-names';
|
||||
import { BillingUsageEvent } from 'src/engine/core-modules/billing/types/billing-usage-event.type';
|
||||
@ -18,10 +18,13 @@ export interface TokenUsage {
|
||||
export class AIBillingService {
|
||||
private readonly logger = new Logger(AIBillingService.name);
|
||||
|
||||
constructor(private readonly workspaceEventEmitter: WorkspaceEventEmitter) {}
|
||||
constructor(
|
||||
private readonly workspaceEventEmitter: WorkspaceEventEmitter,
|
||||
private readonly aiModelRegistryService: AiModelRegistryService,
|
||||
) {}
|
||||
|
||||
async calculateCost(modelId: ModelId, usage: TokenUsage): Promise<number> {
|
||||
const model = getAIModelById(modelId);
|
||||
const model = this.aiModelRegistryService.getEffectiveModelConfig(modelId);
|
||||
|
||||
if (!model) {
|
||||
throw new Error(`AI model with id ${modelId} not found`);
|
||||
|
||||
@ -0,0 +1,185 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
|
||||
import { anthropic } from '@ai-sdk/anthropic';
|
||||
import { createOpenAI, openai } from '@ai-sdk/openai';
|
||||
import { LanguageModel } from 'ai';
|
||||
|
||||
import {
|
||||
AI_MODELS,
|
||||
AIModelConfig,
|
||||
ModelProvider,
|
||||
} from 'src/engine/core-modules/ai/constants/ai-models.const';
|
||||
import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service';
|
||||
|
||||
export interface RegisteredAIModel {
|
||||
modelId: string;
|
||||
provider: ModelProvider;
|
||||
model: LanguageModel;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class AiModelRegistryService {
|
||||
private modelRegistry: Map<string, RegisteredAIModel> = new Map();
|
||||
|
||||
constructor(private twentyConfigService: TwentyConfigService) {
|
||||
this.buildModelRegistry();
|
||||
}
|
||||
|
||||
private buildModelRegistry(): void {
|
||||
this.modelRegistry.clear();
|
||||
|
||||
const openaiApiKey = this.twentyConfigService.get('OPENAI_API_KEY');
|
||||
|
||||
if (openaiApiKey) {
|
||||
this.registerOpenAIModels();
|
||||
}
|
||||
|
||||
const anthropicApiKey = this.twentyConfigService.get('ANTHROPIC_API_KEY');
|
||||
|
||||
if (anthropicApiKey) {
|
||||
this.registerAnthropicModels();
|
||||
}
|
||||
|
||||
const openaiCompatibleBaseUrl = this.twentyConfigService.get(
|
||||
'OPENAI_COMPATIBLE_BASE_URL',
|
||||
);
|
||||
const openaiCompatibleModelNames = this.twentyConfigService.get(
|
||||
'OPENAI_COMPATIBLE_MODEL_NAMES',
|
||||
);
|
||||
|
||||
if (openaiCompatibleBaseUrl && openaiCompatibleModelNames) {
|
||||
this.registerOpenAICompatibleModels(
|
||||
openaiCompatibleBaseUrl,
|
||||
openaiCompatibleModelNames,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private registerOpenAIModels(): void {
|
||||
const openaiModels = AI_MODELS.filter(
|
||||
(model) => model.provider === ModelProvider.OPENAI,
|
||||
);
|
||||
|
||||
openaiModels.forEach((modelConfig) => {
|
||||
this.modelRegistry.set(modelConfig.modelId, {
|
||||
modelId: modelConfig.modelId,
|
||||
provider: ModelProvider.OPENAI,
|
||||
model: openai(modelConfig.modelId),
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
private registerAnthropicModels(): void {
|
||||
const anthropicModels = AI_MODELS.filter(
|
||||
(model) => model.provider === ModelProvider.ANTHROPIC,
|
||||
);
|
||||
|
||||
anthropicModels.forEach((modelConfig) => {
|
||||
this.modelRegistry.set(modelConfig.modelId, {
|
||||
modelId: modelConfig.modelId,
|
||||
provider: ModelProvider.ANTHROPIC,
|
||||
model: anthropic(modelConfig.modelId),
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
private registerOpenAICompatibleModels(
|
||||
baseUrl: string,
|
||||
modelNamesString: string,
|
||||
): void {
|
||||
const apiKey = this.twentyConfigService.get('OPENAI_COMPATIBLE_API_KEY');
|
||||
const provider = createOpenAI({
|
||||
baseURL: baseUrl,
|
||||
apiKey: apiKey,
|
||||
});
|
||||
|
||||
const modelNames = modelNamesString
|
||||
.split(',')
|
||||
.map((name) => name.trim())
|
||||
.filter((name) => name.length > 0);
|
||||
|
||||
modelNames.forEach((modelId) => {
|
||||
this.modelRegistry.set(modelId, {
|
||||
modelId,
|
||||
provider: ModelProvider.OPENAI_COMPATIBLE,
|
||||
model: provider(modelId),
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
getModel(modelId: string): RegisteredAIModel | undefined {
|
||||
return this.modelRegistry.get(modelId);
|
||||
}
|
||||
|
||||
getAvailableModels(): RegisteredAIModel[] {
|
||||
return Array.from(this.modelRegistry.values());
|
||||
}
|
||||
|
||||
getDefaultModel(): RegisteredAIModel | undefined {
|
||||
const defaultModelId = this.twentyConfigService.get('DEFAULT_MODEL_ID');
|
||||
let model = this.getModel(defaultModelId);
|
||||
|
||||
if (!model) {
|
||||
const availableModels = this.getAvailableModels();
|
||||
|
||||
model = availableModels[0];
|
||||
}
|
||||
|
||||
return model;
|
||||
}
|
||||
|
||||
getEffectiveModelConfig(modelId: string): AIModelConfig {
|
||||
if (modelId === 'auto') {
|
||||
const defaultModel = this.getDefaultModel();
|
||||
|
||||
if (!defaultModel) {
|
||||
throw new Error(
|
||||
'No AI models are available. Please configure at least one provider.',
|
||||
);
|
||||
}
|
||||
|
||||
const modelConfig = AI_MODELS.find(
|
||||
(model) => model.modelId === defaultModel.modelId,
|
||||
);
|
||||
|
||||
if (modelConfig) {
|
||||
return modelConfig;
|
||||
}
|
||||
|
||||
return this.createDefaultConfigForCustomModel(defaultModel);
|
||||
}
|
||||
|
||||
const predefinedModel = AI_MODELS.find(
|
||||
(model) => model.modelId === modelId,
|
||||
);
|
||||
|
||||
if (predefinedModel) {
|
||||
return predefinedModel;
|
||||
}
|
||||
|
||||
const registeredModel = this.getModel(modelId);
|
||||
|
||||
if (registeredModel) {
|
||||
return this.createDefaultConfigForCustomModel(registeredModel);
|
||||
}
|
||||
|
||||
throw new Error(`Model with ID ${modelId} not found`);
|
||||
}
|
||||
|
||||
private createDefaultConfigForCustomModel(
|
||||
registeredModel: RegisteredAIModel,
|
||||
): AIModelConfig {
|
||||
return {
|
||||
modelId: registeredModel.modelId,
|
||||
label: registeredModel.modelId,
|
||||
provider: registeredModel.provider,
|
||||
inputCostPer1kTokensInCents: 0,
|
||||
outputCostPer1kTokensInCents: 0,
|
||||
};
|
||||
}
|
||||
|
||||
// Force refresh the registry (useful if config changes)
|
||||
refreshRegistry(): void {
|
||||
this.buildModelRegistry();
|
||||
}
|
||||
}
|
||||
@ -1,19 +1,47 @@
|
||||
import { Inject, Injectable } from '@nestjs/common';
|
||||
import { Injectable } from '@nestjs/common';
|
||||
|
||||
import { CoreMessage, StreamTextResult } from 'ai';
|
||||
import { CoreMessage, StreamTextResult, streamText } from 'ai';
|
||||
|
||||
import { AiDriver } from 'src/engine/core-modules/ai/drivers/interfaces/ai-driver.interface';
|
||||
|
||||
import { AI_DRIVER } from 'src/engine/core-modules/ai/ai.constants';
|
||||
import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service';
|
||||
|
||||
@Injectable()
|
||||
export class AiService {
|
||||
constructor(@Inject(AI_DRIVER) private driver: AiDriver) {}
|
||||
constructor(private aiModelRegistryService: AiModelRegistryService) {}
|
||||
|
||||
streamText(
|
||||
messages: CoreMessage[],
|
||||
options?: { temperature?: number; maxTokens?: number },
|
||||
options?: {
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
modelId?: string; // Optional model override
|
||||
},
|
||||
): StreamTextResult<Record<string, never>, undefined> {
|
||||
return this.driver.streamText(messages, options);
|
||||
const modelId = options?.modelId;
|
||||
const registeredModel = modelId
|
||||
? this.aiModelRegistryService.getModel(modelId)
|
||||
: this.aiModelRegistryService.getDefaultModel();
|
||||
|
||||
if (!registeredModel) {
|
||||
throw new Error(
|
||||
modelId
|
||||
? `Model "${modelId}" is not available. Please check your configuration.`
|
||||
: 'No AI models are available. Please configure at least one provider.',
|
||||
);
|
||||
}
|
||||
|
||||
return streamText({
|
||||
model: registeredModel.model,
|
||||
messages,
|
||||
temperature: options?.temperature,
|
||||
maxTokens: options?.maxTokens,
|
||||
});
|
||||
}
|
||||
|
||||
getAvailableModels() {
|
||||
return this.aiModelRegistryService.getAvailableModels();
|
||||
}
|
||||
|
||||
getDefaultModel() {
|
||||
return this.aiModelRegistryService.getDefaultModel();
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,7 +0,0 @@
|
||||
import { AIModelConfig } from 'src/engine/core-modules/ai/constants/ai-models.const';
|
||||
|
||||
import { getEffectiveModelConfig } from './get-effective-model-config.util';
|
||||
|
||||
export const getAIModelById = (modelId: string): AIModelConfig => {
|
||||
return getEffectiveModelConfig(modelId);
|
||||
};
|
||||
@ -1,35 +0,0 @@
|
||||
import {
|
||||
AI_MODELS,
|
||||
ModelProvider,
|
||||
} from 'src/engine/core-modules/ai/constants/ai-models.const';
|
||||
|
||||
import { getAIModelsWithAuto } from './get-ai-models-with-auto.util';
|
||||
import { getDefaultModelConfig } from './get-default-model-config.util';
|
||||
|
||||
describe('getAIModelsWithAuto', () => {
|
||||
it('should return AI_MODELS with auto model prepended', () => {
|
||||
const ORIGINAL_MODELS = AI_MODELS;
|
||||
const MODELS_WITH_AUTO = getAIModelsWithAuto();
|
||||
|
||||
expect(MODELS_WITH_AUTO).toHaveLength(ORIGINAL_MODELS.length + 1);
|
||||
expect(MODELS_WITH_AUTO[0].modelId).toBe('auto');
|
||||
expect(MODELS_WITH_AUTO[0].label).toBe('Auto');
|
||||
expect(MODELS_WITH_AUTO[0].provider).toBe(ModelProvider.NONE);
|
||||
|
||||
// Check that the rest of the models are the same
|
||||
expect(MODELS_WITH_AUTO.slice(1)).toEqual(ORIGINAL_MODELS);
|
||||
});
|
||||
|
||||
it('should have auto model with default model costs', () => {
|
||||
const MODELS_WITH_AUTO = getAIModelsWithAuto();
|
||||
const AUTO_MODEL = MODELS_WITH_AUTO[0];
|
||||
const DEFAULT_MODEL = getDefaultModelConfig();
|
||||
|
||||
expect(AUTO_MODEL.inputCostPer1kTokensInCents).toBe(
|
||||
DEFAULT_MODEL.inputCostPer1kTokensInCents,
|
||||
);
|
||||
expect(AUTO_MODEL.outputCostPer1kTokensInCents).toBe(
|
||||
DEFAULT_MODEL.outputCostPer1kTokensInCents,
|
||||
);
|
||||
});
|
||||
});
|
||||
@ -1,22 +0,0 @@
|
||||
import {
|
||||
AI_MODELS,
|
||||
AIModelConfig,
|
||||
ModelProvider,
|
||||
} from 'src/engine/core-modules/ai/constants/ai-models.const';
|
||||
|
||||
import { getDefaultModelConfig } from './get-default-model-config.util';
|
||||
|
||||
export const getAIModelsWithAuto = (): AIModelConfig[] => {
|
||||
return [
|
||||
{
|
||||
modelId: 'auto',
|
||||
label: 'Auto',
|
||||
provider: ModelProvider.NONE,
|
||||
inputCostPer1kTokensInCents:
|
||||
getDefaultModelConfig().inputCostPer1kTokensInCents,
|
||||
outputCostPer1kTokensInCents:
|
||||
getDefaultModelConfig().outputCostPer1kTokensInCents,
|
||||
},
|
||||
...AI_MODELS,
|
||||
];
|
||||
};
|
||||
@ -1,29 +0,0 @@
|
||||
import {
|
||||
AI_MODELS,
|
||||
DEFAULT_MODEL_ID,
|
||||
ModelProvider,
|
||||
} from 'src/engine/core-modules/ai/constants/ai-models.const';
|
||||
|
||||
import { getDefaultModelConfig } from './get-default-model-config.util';
|
||||
|
||||
describe('getDefaultModelConfig', () => {
|
||||
it('should return the configuration for the default model', () => {
|
||||
const result = getDefaultModelConfig();
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.modelId).toBe(DEFAULT_MODEL_ID);
|
||||
expect(result.provider).toBe(ModelProvider.OPENAI);
|
||||
});
|
||||
|
||||
it('should throw an error if default model is not found', () => {
|
||||
const originalFind = AI_MODELS.find;
|
||||
|
||||
AI_MODELS.find = jest.fn().mockReturnValue(undefined);
|
||||
|
||||
expect(() => getDefaultModelConfig()).toThrow(
|
||||
`Default model with ID ${DEFAULT_MODEL_ID} not found`,
|
||||
);
|
||||
|
||||
AI_MODELS.find = originalFind;
|
||||
});
|
||||
});
|
||||
@ -1,17 +0,0 @@
|
||||
import {
|
||||
AI_MODELS,
|
||||
AIModelConfig,
|
||||
DEFAULT_MODEL_ID,
|
||||
} from 'src/engine/core-modules/ai/constants/ai-models.const';
|
||||
|
||||
export const getDefaultModelConfig = (): AIModelConfig => {
|
||||
const defaultModel = AI_MODELS.find(
|
||||
(model) => model.modelId === DEFAULT_MODEL_ID,
|
||||
);
|
||||
|
||||
if (!defaultModel) {
|
||||
throw new Error(`Default model with ID ${DEFAULT_MODEL_ID} not found`);
|
||||
}
|
||||
|
||||
return defaultModel;
|
||||
};
|
||||
@ -1,30 +0,0 @@
|
||||
import {
|
||||
DEFAULT_MODEL_ID,
|
||||
ModelProvider,
|
||||
} from 'src/engine/core-modules/ai/constants/ai-models.const';
|
||||
|
||||
import { getEffectiveModelConfig } from './get-effective-model-config.util';
|
||||
|
||||
describe('getEffectiveModelConfig', () => {
|
||||
it('should return default model config when modelId is "auto"', () => {
|
||||
const result = getEffectiveModelConfig('auto');
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.modelId).toBe(DEFAULT_MODEL_ID);
|
||||
expect(result.provider).toBe(ModelProvider.OPENAI);
|
||||
});
|
||||
|
||||
it('should return the correct model config for a specific model', () => {
|
||||
const result = getEffectiveModelConfig('gpt-4o');
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.modelId).toBe('gpt-4o');
|
||||
expect(result.provider).toBe(ModelProvider.OPENAI);
|
||||
});
|
||||
|
||||
it('should throw an error for non-existent model', () => {
|
||||
expect(() => getEffectiveModelConfig('non-existent-model' as any)).toThrow(
|
||||
`Model with ID non-existent-model not found`,
|
||||
);
|
||||
});
|
||||
});
|
||||
@ -1,20 +0,0 @@
|
||||
import {
|
||||
AI_MODELS,
|
||||
AIModelConfig,
|
||||
} from 'src/engine/core-modules/ai/constants/ai-models.const';
|
||||
|
||||
import { getDefaultModelConfig } from './get-default-model-config.util';
|
||||
|
||||
export const getEffectiveModelConfig = (modelId: string): AIModelConfig => {
|
||||
if (modelId === 'auto') {
|
||||
return getDefaultModelConfig();
|
||||
}
|
||||
|
||||
const model = AI_MODELS.find((model) => model.modelId === modelId);
|
||||
|
||||
if (!model) {
|
||||
throw new Error(`Model with ID ${modelId} not found`);
|
||||
}
|
||||
|
||||
return model;
|
||||
};
|
||||
@ -10,7 +10,6 @@ import { AuthException } from 'src/engine/core-modules/auth/auth.exception';
|
||||
import { JwtAuthStrategy } from 'src/engine/core-modules/auth/strategies/jwt.auth.strategy';
|
||||
import { EmailService } from 'src/engine/core-modules/email/email.service';
|
||||
import { JwtWrapperService } from 'src/engine/core-modules/jwt/services/jwt-wrapper.service';
|
||||
import { SSOService } from 'src/engine/core-modules/sso/services/sso.service';
|
||||
import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service';
|
||||
import { UserWorkspace } from 'src/engine/core-modules/user-workspace/user-workspace.entity';
|
||||
import { User } from 'src/engine/core-modules/user/user.entity';
|
||||
@ -74,10 +73,6 @@ describe('AccessTokenService', () => {
|
||||
provide: EmailService,
|
||||
useValue: {},
|
||||
},
|
||||
{
|
||||
provide: SSOService,
|
||||
useValue: {},
|
||||
},
|
||||
{
|
||||
provide: TwentyORMGlobalManager,
|
||||
useValue: {
|
||||
|
||||
@ -12,7 +12,6 @@ import { RefreshTokenService } from 'src/engine/core-modules/auth/token/services
|
||||
import { RenewTokenService } from 'src/engine/core-modules/auth/token/services/renew-token.service';
|
||||
import { WorkspaceAgnosticTokenService } from 'src/engine/core-modules/auth/token/services/workspace-agnostic-token.service';
|
||||
import { JwtModule } from 'src/engine/core-modules/jwt/jwt.module';
|
||||
import { WorkspaceSSOModule } from 'src/engine/core-modules/sso/sso.module';
|
||||
import { UserWorkspace } from 'src/engine/core-modules/user-workspace/user-workspace.entity';
|
||||
import { User } from 'src/engine/core-modules/user/user.entity';
|
||||
import { Workspace } from 'src/engine/core-modules/workspace/workspace.entity';
|
||||
@ -27,7 +26,6 @@ import { DataSourceModule } from 'src/engine/metadata-modules/data-source/data-s
|
||||
),
|
||||
TypeORMModule,
|
||||
DataSourceModule,
|
||||
WorkspaceSSOModule,
|
||||
],
|
||||
providers: [
|
||||
RenewTokenService,
|
||||
|
||||
@ -3,19 +3,13 @@ import { Test, TestingModule } from '@nestjs/testing';
|
||||
import { NodeEnvironment } from 'src/engine/core-modules/twenty-config/interfaces/node-environment.interface';
|
||||
import { SupportDriver } from 'src/engine/core-modules/twenty-config/interfaces/support.interface';
|
||||
|
||||
import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service';
|
||||
import { CaptchaDriverType } from 'src/engine/core-modules/captcha/interfaces';
|
||||
import { ClientConfigService } from 'src/engine/core-modules/client-config/services/client-config.service';
|
||||
import { DomainManagerService } from 'src/engine/core-modules/domain-manager/services/domain-manager.service';
|
||||
import { PUBLIC_FEATURE_FLAGS } from 'src/engine/core-modules/feature-flag/constants/public-feature-flag.const';
|
||||
import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service';
|
||||
|
||||
jest.mock(
|
||||
'src/engine/core-modules/ai/utils/get-ai-models-with-auto.util',
|
||||
() => ({
|
||||
getAIModelsWithAuto: jest.fn(() => []),
|
||||
}),
|
||||
);
|
||||
|
||||
describe('ClientConfigService', () => {
|
||||
let service: ClientConfigService;
|
||||
let twentyConfigService: TwentyConfigService;
|
||||
@ -37,6 +31,12 @@ describe('ClientConfigService', () => {
|
||||
getFrontUrl: jest.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: AiModelRegistryService,
|
||||
useValue: {
|
||||
getAvailableModels: jest.fn().mockReturnValue([]),
|
||||
},
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
|
||||
@ -3,9 +3,12 @@ import { Injectable } from '@nestjs/common';
|
||||
import { NodeEnvironment } from 'src/engine/core-modules/twenty-config/interfaces/node-environment.interface';
|
||||
import { SupportDriver } from 'src/engine/core-modules/twenty-config/interfaces/support.interface';
|
||||
|
||||
import { ModelProvider } from 'src/engine/core-modules/ai/constants/ai-models.const';
|
||||
import {
|
||||
AI_MODELS,
|
||||
ModelProvider,
|
||||
} from 'src/engine/core-modules/ai/constants/ai-models.const';
|
||||
import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service';
|
||||
import { convertCentsToBillingCredits } from 'src/engine/core-modules/ai/utils/convert-cents-to-billing-credits.util';
|
||||
import { getAIModelsWithAuto } from 'src/engine/core-modules/ai/utils/get-ai-models-with-auto.util';
|
||||
import {
|
||||
ClientAIModelConfig,
|
||||
ClientConfig,
|
||||
@ -19,41 +22,49 @@ export class ClientConfigService {
|
||||
constructor(
|
||||
private twentyConfigService: TwentyConfigService,
|
||||
private domainManagerService: DomainManagerService,
|
||||
private aiModelRegistryService: AiModelRegistryService,
|
||||
) {}
|
||||
|
||||
async getClientConfig(): Promise<ClientConfig> {
|
||||
const captchaProvider = this.twentyConfigService.get('CAPTCHA_DRIVER');
|
||||
const supportDriver = this.twentyConfigService.get('SUPPORT_DRIVER');
|
||||
const openaiApiKey = this.twentyConfigService.get('OPENAI_API_KEY');
|
||||
const anthropicApiKey = this.twentyConfigService.get('ANTHROPIC_API_KEY');
|
||||
|
||||
const aiModels = getAIModelsWithAuto().reduce<ClientAIModelConfig[]>(
|
||||
(acc, model) => {
|
||||
const isAvailable =
|
||||
(model.provider === ModelProvider.OPENAI && openaiApiKey) ||
|
||||
(model.provider === ModelProvider.ANTHROPIC && anthropicApiKey);
|
||||
const availableModels = this.aiModelRegistryService.getAvailableModels();
|
||||
|
||||
if (!isAvailable) {
|
||||
return acc;
|
||||
}
|
||||
const aiModels: ClientAIModelConfig[] = availableModels.map(
|
||||
(registeredModel) => {
|
||||
const builtInModel = AI_MODELS.find(
|
||||
(m) => m.modelId === registeredModel.modelId,
|
||||
);
|
||||
|
||||
acc.push({
|
||||
modelId: model.modelId,
|
||||
label: model.label,
|
||||
provider: model.provider,
|
||||
inputCostPer1kTokensInCredits: convertCentsToBillingCredits(
|
||||
model.inputCostPer1kTokensInCents,
|
||||
),
|
||||
outputCostPer1kTokensInCredits: convertCentsToBillingCredits(
|
||||
model.outputCostPer1kTokensInCents,
|
||||
),
|
||||
});
|
||||
|
||||
return acc;
|
||||
return {
|
||||
modelId: registeredModel.modelId,
|
||||
label: builtInModel?.label || registeredModel.modelId,
|
||||
provider: registeredModel.provider,
|
||||
inputCostPer1kTokensInCredits: builtInModel
|
||||
? convertCentsToBillingCredits(
|
||||
builtInModel.inputCostPer1kTokensInCents,
|
||||
)
|
||||
: 0,
|
||||
outputCostPer1kTokensInCredits: builtInModel
|
||||
? convertCentsToBillingCredits(
|
||||
builtInModel.outputCostPer1kTokensInCents,
|
||||
)
|
||||
: 0,
|
||||
};
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
if (aiModels.length > 0) {
|
||||
aiModels.unshift({
|
||||
modelId: 'auto',
|
||||
label: 'Auto',
|
||||
provider: ModelProvider.NONE,
|
||||
inputCostPer1kTokensInCredits: 0,
|
||||
outputCostPer1kTokensInCredits: 0,
|
||||
});
|
||||
}
|
||||
|
||||
const clientConfig: ClientConfig = {
|
||||
billing: {
|
||||
isBillingEnabled: this.twentyConfigService.get('IS_BILLING_ENABLED'),
|
||||
|
||||
@ -6,7 +6,6 @@ import { WorkspaceQueryRunnerModule } from 'src/engine/api/graphql/workspace-que
|
||||
import { ActorModule } from 'src/engine/core-modules/actor/actor.module';
|
||||
import { AdminPanelModule } from 'src/engine/core-modules/admin-panel/admin-panel.module';
|
||||
import { AiModule } from 'src/engine/core-modules/ai/ai.module';
|
||||
import { aiModuleFactory } from 'src/engine/core-modules/ai/ai.module-factory';
|
||||
import { ApiKeyModule } from 'src/engine/core-modules/api-key/api-key.module';
|
||||
import { AppTokenModule } from 'src/engine/core-modules/app-token/app-token.module';
|
||||
import { ApprovedAccessDomainModule } from 'src/engine/core-modules/approved-access-domain/approved-access-domain.module';
|
||||
@ -109,10 +108,7 @@ import { FileModule } from './file/file.module';
|
||||
wildcard: true,
|
||||
}),
|
||||
CacheStorageModule,
|
||||
AiModule.forRoot({
|
||||
useFactory: aiModuleFactory,
|
||||
inject: [TwentyConfigService],
|
||||
}),
|
||||
AiModule,
|
||||
ServerlessModule.forRootAsync({
|
||||
useFactory: serverlessModuleFactory,
|
||||
inject: [TwentyConfigService, FileStorageService],
|
||||
|
||||
@ -11,7 +11,6 @@ import {
|
||||
} from 'class-validator';
|
||||
import { isDefined } from 'twenty-shared/utils';
|
||||
|
||||
import { AiDriver } from 'src/engine/core-modules/ai/interfaces/ai.interface';
|
||||
import { AwsRegion } from 'src/engine/core-modules/twenty-config/interfaces/aws-region.interface';
|
||||
import { NodeEnvironment } from 'src/engine/core-modules/twenty-config/interfaces/node-environment.interface';
|
||||
import { SupportDriver } from 'src/engine/core-modules/twenty-config/interfaces/support.interface';
|
||||
@ -975,13 +974,12 @@ export class ConfigVariables {
|
||||
|
||||
@ConfigVariablesMetadata({
|
||||
group: ConfigVariablesGroup.LLM,
|
||||
description: 'Driver for the AI chat model',
|
||||
type: ConfigVariableType.ENUM,
|
||||
options: Object.values(AiDriver),
|
||||
isEnvOnly: true,
|
||||
description:
|
||||
'Default model ID for AI operations (can be any available model)',
|
||||
type: ConfigVariableType.STRING,
|
||||
})
|
||||
@CastToUpperSnakeCase()
|
||||
AI_DRIVER: AiDriver;
|
||||
@IsOptional()
|
||||
DEFAULT_MODEL_ID = 'gpt-4o';
|
||||
|
||||
@ConfigVariablesMetadata({
|
||||
group: ConfigVariablesGroup.LLM,
|
||||
@ -989,6 +987,7 @@ export class ConfigVariables {
|
||||
description: 'API key for OpenAI integration',
|
||||
type: ConfigVariableType.STRING,
|
||||
})
|
||||
@IsOptional()
|
||||
OPENAI_API_KEY: string;
|
||||
|
||||
@ConfigVariablesMetadata({
|
||||
@ -997,8 +996,37 @@ export class ConfigVariables {
|
||||
description: 'API key for Anthropic integration',
|
||||
type: ConfigVariableType.STRING,
|
||||
})
|
||||
@IsOptional()
|
||||
ANTHROPIC_API_KEY: string;
|
||||
|
||||
@ConfigVariablesMetadata({
|
||||
group: ConfigVariablesGroup.LLM,
|
||||
description: 'Base URL for OpenAI-compatible LLM provider (e.g., Ollama)',
|
||||
type: ConfigVariableType.STRING,
|
||||
})
|
||||
@IsOptional()
|
||||
@IsUrl({ require_tld: false, require_protocol: true })
|
||||
OPENAI_COMPATIBLE_BASE_URL: string;
|
||||
|
||||
@ConfigVariablesMetadata({
|
||||
group: ConfigVariablesGroup.LLM,
|
||||
description:
|
||||
'Model names for OpenAI-compatible LLM provider (comma-separated, e.g., "llama3.1, mistral, codellama")',
|
||||
type: ConfigVariableType.STRING,
|
||||
})
|
||||
@IsOptional()
|
||||
OPENAI_COMPATIBLE_MODEL_NAMES: string;
|
||||
|
||||
@ConfigVariablesMetadata({
|
||||
group: ConfigVariablesGroup.LLM,
|
||||
isSensitive: true,
|
||||
description:
|
||||
'API key for OpenAI-compatible LLM provider (optional for providers like Ollama)',
|
||||
type: ConfigVariableType.STRING,
|
||||
})
|
||||
@IsOptional()
|
||||
OPENAI_COMPATIBLE_API_KEY: string;
|
||||
|
||||
@ConfigVariablesMetadata({
|
||||
group: ConfigVariablesGroup.ServerConfig,
|
||||
description: 'Enable or disable multi-workspace support',
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { InjectRepository } from '@nestjs/typeorm';
|
||||
|
||||
import { createAnthropic } from '@ai-sdk/anthropic';
|
||||
@ -10,7 +10,7 @@ import {
|
||||
ModelId,
|
||||
ModelProvider,
|
||||
} from 'src/engine/core-modules/ai/constants/ai-models.const';
|
||||
import { getEffectiveModelConfig } from 'src/engine/core-modules/ai/utils/get-effective-model-config.util';
|
||||
import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service';
|
||||
import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service';
|
||||
import {
|
||||
AgentChatMessageEntity,
|
||||
@ -22,7 +22,6 @@ import { AGENT_SYSTEM_PROMPTS } from 'src/engine/metadata-modules/agent/constant
|
||||
import { convertOutputSchemaToZod } from 'src/engine/metadata-modules/agent/utils/convert-output-schema-to-zod';
|
||||
import { OutputSchema } from 'src/modules/workflow/workflow-builder/workflow-schema/types/output-schema.type';
|
||||
import { resolveInput } from 'src/modules/workflow/workflow-executor/utils/variable-resolver.util';
|
||||
import { getAIModelById } from 'src/engine/core-modules/ai/utils/get-ai-model-by-id.util';
|
||||
|
||||
import { AgentEntity } from './agent.entity';
|
||||
import { AgentException, AgentExceptionCode } from './agent.exception';
|
||||
@ -41,21 +40,27 @@ export interface AgentExecutionResult {
|
||||
|
||||
@Injectable()
|
||||
export class AgentExecutionService {
|
||||
private readonly logger = new Logger(AgentExecutionService.name);
|
||||
|
||||
constructor(
|
||||
private readonly twentyConfigService: TwentyConfigService,
|
||||
private readonly agentToolService: AgentToolService,
|
||||
private readonly aiModelRegistryService: AiModelRegistryService,
|
||||
@InjectRepository(AgentEntity, 'core')
|
||||
private readonly agentRepository: Repository<AgentEntity>,
|
||||
) {}
|
||||
|
||||
getModel = (modelId: ModelId, provider: ModelProvider) => {
|
||||
switch (provider) {
|
||||
case ModelProvider.NONE: {
|
||||
case ModelProvider.OPENAI_COMPATIBLE: {
|
||||
const OpenAIProvider = createOpenAI({
|
||||
apiKey: this.twentyConfigService.get('OPENAI_API_KEY'),
|
||||
baseURL: this.twentyConfigService.get('OPENAI_COMPATIBLE_BASE_URL'),
|
||||
apiKey: this.twentyConfigService.get('OPENAI_COMPATIBLE_API_KEY'),
|
||||
});
|
||||
|
||||
return OpenAIProvider(getEffectiveModelConfig(modelId).modelId);
|
||||
return OpenAIProvider(
|
||||
this.aiModelRegistryService.getEffectiveModelConfig(modelId).modelId,
|
||||
);
|
||||
}
|
||||
case ModelProvider.OPENAI: {
|
||||
const OpenAIProvider = createOpenAI({
|
||||
@ -83,9 +88,6 @@ export class AgentExecutionService {
|
||||
let apiKey: string | undefined;
|
||||
|
||||
switch (provider) {
|
||||
case ModelProvider.NONE:
|
||||
apiKey = this.twentyConfigService.get('OPENAI_API_KEY');
|
||||
break;
|
||||
case ModelProvider.OPENAI:
|
||||
apiKey = this.twentyConfigService.get('OPENAI_API_KEY');
|
||||
break;
|
||||
@ -93,14 +95,11 @@ export class AgentExecutionService {
|
||||
apiKey = this.twentyConfigService.get('ANTHROPIC_API_KEY');
|
||||
break;
|
||||
default:
|
||||
throw new AgentException(
|
||||
`Unsupported provider: ${provider}`,
|
||||
AgentExceptionCode.AGENT_EXECUTION_FAILED,
|
||||
);
|
||||
return;
|
||||
}
|
||||
if (!apiKey) {
|
||||
throw new AgentException(
|
||||
`${provider === ModelProvider.NONE ? 'OPENAI' : provider.toUpperCase()} API key not configured`,
|
||||
`${provider.toUpperCase()} API key not configured`,
|
||||
AgentExceptionCode.API_KEY_NOT_CONFIGURED,
|
||||
);
|
||||
}
|
||||
@ -117,31 +116,55 @@ export class AgentExecutionService {
|
||||
prompt?: string;
|
||||
messages?: CoreMessage[];
|
||||
}) {
|
||||
const aiModel = getAIModelById(agent.modelId);
|
||||
|
||||
if (!aiModel) {
|
||||
throw new AgentException(
|
||||
`AI model with id ${agent.modelId} not found`,
|
||||
AgentExceptionCode.AGENT_EXECUTION_FAILED,
|
||||
try {
|
||||
this.logger.log(
|
||||
`Preparing AI request config for agent ${agent.id} with model ${agent.modelId}`,
|
||||
);
|
||||
|
||||
const aiModel = this.aiModelRegistryService.getEffectiveModelConfig(
|
||||
agent.modelId,
|
||||
);
|
||||
|
||||
if (!aiModel) {
|
||||
const error = `AI model with id ${agent.modelId} not found`;
|
||||
|
||||
this.logger.error(error);
|
||||
throw new AgentException(
|
||||
error,
|
||||
AgentExceptionCode.AGENT_EXECUTION_FAILED,
|
||||
);
|
||||
}
|
||||
|
||||
this.logger.log(
|
||||
`Resolved model: ${aiModel.modelId} (provider: ${aiModel.provider})`,
|
||||
);
|
||||
|
||||
const provider = aiModel.provider;
|
||||
|
||||
await this.validateApiKey(provider);
|
||||
|
||||
const tools = await this.agentToolService.generateToolsForAgent(
|
||||
agent.id,
|
||||
agent.workspaceId,
|
||||
);
|
||||
|
||||
this.logger.log(`Generated ${Object.keys(tools).length} tools for agent`);
|
||||
|
||||
return {
|
||||
system,
|
||||
tools,
|
||||
model: this.getModel(aiModel.modelId, aiModel.provider),
|
||||
...(messages && { messages }),
|
||||
...(prompt && { prompt }),
|
||||
maxSteps: AGENT_CONFIG.MAX_STEPS,
|
||||
};
|
||||
} catch (error) {
|
||||
this.logger.error(
|
||||
`Failed to prepare AI request config for agent ${agent.id}:`,
|
||||
error instanceof Error ? error.stack : error,
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
const provider = aiModel.provider;
|
||||
|
||||
await this.validateApiKey(provider);
|
||||
|
||||
const tools = await this.agentToolService.generateToolsForAgent(
|
||||
agent.id,
|
||||
agent.workspaceId,
|
||||
);
|
||||
|
||||
return {
|
||||
system,
|
||||
tools,
|
||||
model: this.getModel(agent.modelId, aiModel.provider),
|
||||
...(messages && { messages }),
|
||||
...(prompt && { prompt }),
|
||||
maxSteps: AGENT_CONFIG.MAX_STEPS,
|
||||
};
|
||||
}
|
||||
|
||||
async streamChatResponse({
|
||||
@ -173,6 +196,10 @@ export class AgentExecutionService {
|
||||
messages: llmMessages,
|
||||
});
|
||||
|
||||
this.logger.log(
|
||||
`Sending request to AI model with ${llmMessages.length} messages`,
|
||||
);
|
||||
|
||||
return streamText(aiRequestConfig);
|
||||
}
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { InjectRepository } from '@nestjs/typeorm';
|
||||
|
||||
import { Response } from 'express';
|
||||
@ -28,6 +28,8 @@ export type StreamAgentChatResult = {
|
||||
|
||||
@Injectable()
|
||||
export class AgentStreamingService {
|
||||
private readonly logger = new Logger(AgentStreamingService.name);
|
||||
|
||||
constructor(
|
||||
@InjectRepository(AgentChatThreadEntity, 'core')
|
||||
private readonly threadRepository: Repository<AgentChatThreadEntity>,
|
||||
@ -89,7 +91,11 @@ export class AgentStreamingService {
|
||||
message: chunk.args?.toolDescription,
|
||||
});
|
||||
break;
|
||||
case 'error':
|
||||
this.logger.error(`Stream error: ${JSON.stringify(chunk)}`);
|
||||
break;
|
||||
default:
|
||||
this.logger.log(`Unknown chunk type: ${chunk.type}`);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -105,6 +111,10 @@ export class AgentStreamingService {
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : 'Unknown error occurred';
|
||||
|
||||
if (error instanceof AgentException) {
|
||||
this.logger.error(`Agent Exception Code: ${error.code}`);
|
||||
}
|
||||
|
||||
if (!res.headersSent) {
|
||||
this.setupStreamingHeaders(res);
|
||||
}
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
import { TypeOrmModule } from '@nestjs/typeorm';
|
||||
|
||||
import { AiDriver } from 'src/engine/core-modules/ai/interfaces/ai.interface';
|
||||
|
||||
import { AiModule } from 'src/engine/core-modules/ai/ai.module';
|
||||
import { AgentEntity } from 'src/engine/metadata-modules/agent/agent.entity';
|
||||
import { AgentModule } from 'src/engine/metadata-modules/agent/agent.module';
|
||||
@ -13,9 +11,7 @@ import { AiAgentWorkflowAction } from './ai-agent.workflow-action';
|
||||
@Module({
|
||||
imports: [
|
||||
AgentModule,
|
||||
AiModule.forRoot({
|
||||
useFactory: () => ({ type: AiDriver.OPENAI }),
|
||||
}),
|
||||
AiModule,
|
||||
TypeOrmModule.forFeature([AgentEntity], 'core'),
|
||||
],
|
||||
providers: [ScopedWorkspaceContextFactory, AiAgentWorkflowAction],
|
||||
|
||||
Reference in New Issue
Block a user