Refresh AI model setup (#13171)
Instead of initializing model at start time we do it at run time to be able to swap model provider more easily. Also introduce a third driver for openai-compatible providers, which among other allows for local models with Ollama
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { InjectRepository } from '@nestjs/typeorm';
|
||||
|
||||
import { createAnthropic } from '@ai-sdk/anthropic';
|
||||
@ -10,7 +10,7 @@ import {
|
||||
ModelId,
|
||||
ModelProvider,
|
||||
} from 'src/engine/core-modules/ai/constants/ai-models.const';
|
||||
import { getEffectiveModelConfig } from 'src/engine/core-modules/ai/utils/get-effective-model-config.util';
|
||||
import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service';
|
||||
import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service';
|
||||
import {
|
||||
AgentChatMessageEntity,
|
||||
@ -22,7 +22,6 @@ import { AGENT_SYSTEM_PROMPTS } from 'src/engine/metadata-modules/agent/constant
|
||||
import { convertOutputSchemaToZod } from 'src/engine/metadata-modules/agent/utils/convert-output-schema-to-zod';
|
||||
import { OutputSchema } from 'src/modules/workflow/workflow-builder/workflow-schema/types/output-schema.type';
|
||||
import { resolveInput } from 'src/modules/workflow/workflow-executor/utils/variable-resolver.util';
|
||||
import { getAIModelById } from 'src/engine/core-modules/ai/utils/get-ai-model-by-id.util';
|
||||
|
||||
import { AgentEntity } from './agent.entity';
|
||||
import { AgentException, AgentExceptionCode } from './agent.exception';
|
||||
@ -41,21 +40,27 @@ export interface AgentExecutionResult {
|
||||
|
||||
@Injectable()
|
||||
export class AgentExecutionService {
|
||||
private readonly logger = new Logger(AgentExecutionService.name);
|
||||
|
||||
constructor(
|
||||
private readonly twentyConfigService: TwentyConfigService,
|
||||
private readonly agentToolService: AgentToolService,
|
||||
private readonly aiModelRegistryService: AiModelRegistryService,
|
||||
@InjectRepository(AgentEntity, 'core')
|
||||
private readonly agentRepository: Repository<AgentEntity>,
|
||||
) {}
|
||||
|
||||
getModel = (modelId: ModelId, provider: ModelProvider) => {
|
||||
switch (provider) {
|
||||
case ModelProvider.NONE: {
|
||||
case ModelProvider.OPENAI_COMPATIBLE: {
|
||||
const OpenAIProvider = createOpenAI({
|
||||
apiKey: this.twentyConfigService.get('OPENAI_API_KEY'),
|
||||
baseURL: this.twentyConfigService.get('OPENAI_COMPATIBLE_BASE_URL'),
|
||||
apiKey: this.twentyConfigService.get('OPENAI_COMPATIBLE_API_KEY'),
|
||||
});
|
||||
|
||||
return OpenAIProvider(getEffectiveModelConfig(modelId).modelId);
|
||||
return OpenAIProvider(
|
||||
this.aiModelRegistryService.getEffectiveModelConfig(modelId).modelId,
|
||||
);
|
||||
}
|
||||
case ModelProvider.OPENAI: {
|
||||
const OpenAIProvider = createOpenAI({
|
||||
@ -83,9 +88,6 @@ export class AgentExecutionService {
|
||||
let apiKey: string | undefined;
|
||||
|
||||
switch (provider) {
|
||||
case ModelProvider.NONE:
|
||||
apiKey = this.twentyConfigService.get('OPENAI_API_KEY');
|
||||
break;
|
||||
case ModelProvider.OPENAI:
|
||||
apiKey = this.twentyConfigService.get('OPENAI_API_KEY');
|
||||
break;
|
||||
@ -93,14 +95,11 @@ export class AgentExecutionService {
|
||||
apiKey = this.twentyConfigService.get('ANTHROPIC_API_KEY');
|
||||
break;
|
||||
default:
|
||||
throw new AgentException(
|
||||
`Unsupported provider: ${provider}`,
|
||||
AgentExceptionCode.AGENT_EXECUTION_FAILED,
|
||||
);
|
||||
return;
|
||||
}
|
||||
if (!apiKey) {
|
||||
throw new AgentException(
|
||||
`${provider === ModelProvider.NONE ? 'OPENAI' : provider.toUpperCase()} API key not configured`,
|
||||
`${provider.toUpperCase()} API key not configured`,
|
||||
AgentExceptionCode.API_KEY_NOT_CONFIGURED,
|
||||
);
|
||||
}
|
||||
@ -117,31 +116,55 @@ export class AgentExecutionService {
|
||||
prompt?: string;
|
||||
messages?: CoreMessage[];
|
||||
}) {
|
||||
const aiModel = getAIModelById(agent.modelId);
|
||||
|
||||
if (!aiModel) {
|
||||
throw new AgentException(
|
||||
`AI model with id ${agent.modelId} not found`,
|
||||
AgentExceptionCode.AGENT_EXECUTION_FAILED,
|
||||
try {
|
||||
this.logger.log(
|
||||
`Preparing AI request config for agent ${agent.id} with model ${agent.modelId}`,
|
||||
);
|
||||
|
||||
const aiModel = this.aiModelRegistryService.getEffectiveModelConfig(
|
||||
agent.modelId,
|
||||
);
|
||||
|
||||
if (!aiModel) {
|
||||
const error = `AI model with id ${agent.modelId} not found`;
|
||||
|
||||
this.logger.error(error);
|
||||
throw new AgentException(
|
||||
error,
|
||||
AgentExceptionCode.AGENT_EXECUTION_FAILED,
|
||||
);
|
||||
}
|
||||
|
||||
this.logger.log(
|
||||
`Resolved model: ${aiModel.modelId} (provider: ${aiModel.provider})`,
|
||||
);
|
||||
|
||||
const provider = aiModel.provider;
|
||||
|
||||
await this.validateApiKey(provider);
|
||||
|
||||
const tools = await this.agentToolService.generateToolsForAgent(
|
||||
agent.id,
|
||||
agent.workspaceId,
|
||||
);
|
||||
|
||||
this.logger.log(`Generated ${Object.keys(tools).length} tools for agent`);
|
||||
|
||||
return {
|
||||
system,
|
||||
tools,
|
||||
model: this.getModel(aiModel.modelId, aiModel.provider),
|
||||
...(messages && { messages }),
|
||||
...(prompt && { prompt }),
|
||||
maxSteps: AGENT_CONFIG.MAX_STEPS,
|
||||
};
|
||||
} catch (error) {
|
||||
this.logger.error(
|
||||
`Failed to prepare AI request config for agent ${agent.id}:`,
|
||||
error instanceof Error ? error.stack : error,
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
const provider = aiModel.provider;
|
||||
|
||||
await this.validateApiKey(provider);
|
||||
|
||||
const tools = await this.agentToolService.generateToolsForAgent(
|
||||
agent.id,
|
||||
agent.workspaceId,
|
||||
);
|
||||
|
||||
return {
|
||||
system,
|
||||
tools,
|
||||
model: this.getModel(agent.modelId, aiModel.provider),
|
||||
...(messages && { messages }),
|
||||
...(prompt && { prompt }),
|
||||
maxSteps: AGENT_CONFIG.MAX_STEPS,
|
||||
};
|
||||
}
|
||||
|
||||
async streamChatResponse({
|
||||
@ -173,6 +196,10 @@ export class AgentExecutionService {
|
||||
messages: llmMessages,
|
||||
});
|
||||
|
||||
this.logger.log(
|
||||
`Sending request to AI model with ${llmMessages.length} messages`,
|
||||
);
|
||||
|
||||
return streamText(aiRequestConfig);
|
||||
}
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { InjectRepository } from '@nestjs/typeorm';
|
||||
|
||||
import { Response } from 'express';
|
||||
@ -28,6 +28,8 @@ export type StreamAgentChatResult = {
|
||||
|
||||
@Injectable()
|
||||
export class AgentStreamingService {
|
||||
private readonly logger = new Logger(AgentStreamingService.name);
|
||||
|
||||
constructor(
|
||||
@InjectRepository(AgentChatThreadEntity, 'core')
|
||||
private readonly threadRepository: Repository<AgentChatThreadEntity>,
|
||||
@ -89,7 +91,11 @@ export class AgentStreamingService {
|
||||
message: chunk.args?.toolDescription,
|
||||
});
|
||||
break;
|
||||
case 'error':
|
||||
this.logger.error(`Stream error: ${JSON.stringify(chunk)}`);
|
||||
break;
|
||||
default:
|
||||
this.logger.log(`Unknown chunk type: ${chunk.type}`);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -105,6 +111,10 @@ export class AgentStreamingService {
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : 'Unknown error occurred';
|
||||
|
||||
if (error instanceof AgentException) {
|
||||
this.logger.error(`Agent Exception Code: ${error.code}`);
|
||||
}
|
||||
|
||||
if (!res.headersSent) {
|
||||
this.setupStreamingHeaders(res);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user