Add file support to agent chat (#13187)
https://github.com/user-attachments/assets/911d5d8d-cc2e-4c18-9f93-2663d84ff9ef --------- Co-authored-by: Raphaël Bosi <71827178+bosiraphael@users.noreply.github.com> Co-authored-by: neo773 <62795688+neo773@users.noreply.github.com> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Co-authored-by: Félix Malfait <felix.malfait@gmail.com> Co-authored-by: Félix Malfait <felix@twenty.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions <github-actions@twenty.com> Co-authored-by: MD Readul Islam <99027968+readul-islam@users.noreply.github.com> Co-authored-by: readul-islam <developer.readul@gamil.com> Co-authored-by: Thomas des Francs <tdesfrancs@gmail.com> Co-authored-by: Guillim <guillim@users.noreply.github.com> Co-authored-by: Lucas Bordeau <bordeau.lucas@gmail.com>
This commit is contained in:
@ -5,10 +5,12 @@ import {
|
||||
Index,
|
||||
JoinColumn,
|
||||
ManyToOne,
|
||||
OneToMany,
|
||||
PrimaryGeneratedColumn,
|
||||
Relation,
|
||||
} from 'typeorm';
|
||||
|
||||
import { FileEntity } from 'src/engine/core-modules/file/entities/file.entity';
|
||||
import { AgentChatThreadEntity } from 'src/engine/metadata-modules/agent/agent-chat-thread.entity';
|
||||
|
||||
export enum AgentChatMessageRole {
|
||||
@ -37,6 +39,9 @@ export class AgentChatMessageEntity {
|
||||
@Column('text')
|
||||
content: string;
|
||||
|
||||
@OneToMany(() => FileEntity, (file) => file.message)
|
||||
files: Relation<FileEntity[]>;
|
||||
|
||||
@CreateDateColumn()
|
||||
createdAt: Date;
|
||||
}
|
||||
|
||||
@ -58,7 +58,7 @@ export class AgentChatController {
|
||||
@Post('stream')
|
||||
async streamAgentChat(
|
||||
@Body()
|
||||
body: { threadId: string; userMessage: string },
|
||||
body: { threadId: string; userMessage: string; fileIds?: string[] },
|
||||
@AuthUserWorkspaceId() userWorkspaceId: string,
|
||||
@Res() res: Response,
|
||||
) {
|
||||
@ -67,6 +67,7 @@ export class AgentChatController {
|
||||
threadId: body.threadId,
|
||||
userMessage: body.userMessage,
|
||||
userWorkspaceId,
|
||||
fileIds: body.fileIds || [],
|
||||
res,
|
||||
});
|
||||
} catch (error) {
|
||||
|
||||
@ -3,6 +3,7 @@ import { InjectRepository } from '@nestjs/typeorm';
|
||||
|
||||
import { Repository } from 'typeorm';
|
||||
|
||||
import { FileEntity } from 'src/engine/core-modules/file/entities/file.entity';
|
||||
import {
|
||||
AgentChatMessageEntity,
|
||||
AgentChatMessageRole,
|
||||
@ -20,6 +21,8 @@ export class AgentChatService {
|
||||
private readonly threadRepository: Repository<AgentChatThreadEntity>,
|
||||
@InjectRepository(AgentChatMessageEntity, 'core')
|
||||
private readonly messageRepository: Repository<AgentChatMessageEntity>,
|
||||
@InjectRepository(FileEntity, 'core')
|
||||
private readonly fileRepository: Repository<FileEntity>,
|
||||
) {}
|
||||
|
||||
async createThread(agentId: string, userWorkspaceId: string) {
|
||||
@ -45,10 +48,12 @@ export class AgentChatService {
|
||||
threadId,
|
||||
role,
|
||||
content,
|
||||
fileIds,
|
||||
}: {
|
||||
threadId: string;
|
||||
role: AgentChatMessageRole;
|
||||
content: string;
|
||||
fileIds?: string[];
|
||||
}) {
|
||||
const message = this.messageRepository.create({
|
||||
threadId,
|
||||
@ -56,7 +61,17 @@ export class AgentChatService {
|
||||
content,
|
||||
});
|
||||
|
||||
return this.messageRepository.save(message);
|
||||
const savedMessage = await this.messageRepository.save(message);
|
||||
|
||||
if (fileIds && fileIds.length > 0) {
|
||||
for (const fileId of fileIds) {
|
||||
await this.fileRepository.update(fileId, {
|
||||
messageId: savedMessage.id,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return savedMessage;
|
||||
}
|
||||
|
||||
async getMessagesForThread(threadId: string, userWorkspaceId: string) {
|
||||
@ -77,6 +92,7 @@ export class AgentChatService {
|
||||
return this.messageRepository.find({
|
||||
where: { threadId },
|
||||
order: { createdAt: 'ASC' },
|
||||
relations: ['files'],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,16 +1,30 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { InjectRepository } from '@nestjs/typeorm';
|
||||
|
||||
import { Readable } from 'stream';
|
||||
|
||||
import { createAnthropic } from '@ai-sdk/anthropic';
|
||||
import { createOpenAI } from '@ai-sdk/openai';
|
||||
import { CoreMessage, generateObject, generateText, streamText } from 'ai';
|
||||
import { Repository } from 'typeorm';
|
||||
import {
|
||||
CoreMessage,
|
||||
CoreUserMessage,
|
||||
FilePart,
|
||||
generateObject,
|
||||
generateText,
|
||||
ImagePart,
|
||||
streamText,
|
||||
TextPart,
|
||||
} from 'ai';
|
||||
import { In, Repository } from 'typeorm';
|
||||
|
||||
import {
|
||||
ModelId,
|
||||
ModelProvider,
|
||||
} from 'src/engine/core-modules/ai/constants/ai-models.const';
|
||||
import { AiModelRegistryService } from 'src/engine/core-modules/ai/services/ai-model-registry.service';
|
||||
import { FileEntity } from 'src/engine/core-modules/file/entities/file.entity';
|
||||
import { FileService } from 'src/engine/core-modules/file/services/file.service';
|
||||
import { extractFolderPathAndFilename } from 'src/engine/core-modules/file/utils/extract-folderpath-and-filename.utils';
|
||||
import { TwentyConfigService } from 'src/engine/core-modules/twenty-config/twenty-config.service';
|
||||
import {
|
||||
AgentChatMessageEntity,
|
||||
@ -22,6 +36,7 @@ import { AGENT_SYSTEM_PROMPTS } from 'src/engine/metadata-modules/agent/constant
|
||||
import { convertOutputSchemaToZod } from 'src/engine/metadata-modules/agent/utils/convert-output-schema-to-zod';
|
||||
import { OutputSchema } from 'src/modules/workflow/workflow-builder/workflow-schema/types/output-schema.type';
|
||||
import { resolveInput } from 'src/modules/workflow/workflow-executor/utils/variable-resolver.util';
|
||||
import { streamToBuffer } from 'src/utils/stream-to-buffer';
|
||||
|
||||
import { AgentEntity } from './agent.entity';
|
||||
import { AgentException, AgentExceptionCode } from './agent.exception';
|
||||
@ -45,9 +60,12 @@ export class AgentExecutionService {
|
||||
constructor(
|
||||
private readonly twentyConfigService: TwentyConfigService,
|
||||
private readonly agentToolService: AgentToolService,
|
||||
private readonly fileService: FileService,
|
||||
private readonly aiModelRegistryService: AiModelRegistryService,
|
||||
@InjectRepository(AgentEntity, 'core')
|
||||
private readonly agentRepository: Repository<AgentEntity>,
|
||||
@InjectRepository(FileEntity, 'core')
|
||||
private readonly fileRepository: Repository<FileEntity>,
|
||||
) {}
|
||||
|
||||
getModel = (modelId: ModelId, provider: ModelProvider) => {
|
||||
@ -58,9 +76,7 @@ export class AgentExecutionService {
|
||||
apiKey: this.twentyConfigService.get('OPENAI_COMPATIBLE_API_KEY'),
|
||||
});
|
||||
|
||||
return OpenAIProvider(
|
||||
this.aiModelRegistryService.getEffectiveModelConfig(modelId).modelId,
|
||||
);
|
||||
return OpenAIProvider(modelId);
|
||||
}
|
||||
case ModelProvider.OPENAI: {
|
||||
const OpenAIProvider = createOpenAI({
|
||||
@ -167,14 +183,73 @@ export class AgentExecutionService {
|
||||
}
|
||||
}
|
||||
|
||||
private async buildUserMessageWithFiles(
|
||||
userMessage: string,
|
||||
fileIds?: string[],
|
||||
): Promise<CoreUserMessage> {
|
||||
if (!fileIds || fileIds.length === 0) {
|
||||
return { role: AgentChatMessageRole.USER, content: userMessage };
|
||||
}
|
||||
|
||||
const files = await this.fileRepository.find({
|
||||
where: {
|
||||
id: In(fileIds),
|
||||
},
|
||||
});
|
||||
|
||||
const textPart: TextPart = {
|
||||
type: 'text',
|
||||
text: userMessage,
|
||||
};
|
||||
|
||||
const fileParts = await Promise.all(
|
||||
files.map((file) => this.createFilePart(file)),
|
||||
);
|
||||
|
||||
return {
|
||||
role: AgentChatMessageRole.USER,
|
||||
content: [textPart, ...fileParts],
|
||||
};
|
||||
}
|
||||
|
||||
private async createFilePart(
|
||||
file: FileEntity,
|
||||
): Promise<ImagePart | FilePart> {
|
||||
const { folderPath, filename } = extractFolderPathAndFilename(
|
||||
file.fullPath,
|
||||
);
|
||||
const fileStream = await this.fileService.getFileStream(
|
||||
folderPath,
|
||||
filename,
|
||||
file.workspaceId,
|
||||
);
|
||||
const fileBuffer = await streamToBuffer(fileStream as Readable);
|
||||
|
||||
if (file.type.startsWith('image')) {
|
||||
return {
|
||||
type: 'image',
|
||||
image: fileBuffer,
|
||||
mimeType: file.type,
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
type: 'file',
|
||||
data: fileBuffer,
|
||||
mimeType: file.type,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async streamChatResponse({
|
||||
agentId,
|
||||
userMessage,
|
||||
messages,
|
||||
fileIds,
|
||||
}: {
|
||||
agentId: string;
|
||||
userMessage: string;
|
||||
messages: AgentChatMessageEntity[];
|
||||
fileIds?: string[];
|
||||
}) {
|
||||
const agent = await this.agentRepository.findOneOrFail({
|
||||
where: { id: agentId },
|
||||
@ -185,10 +260,12 @@ export class AgentExecutionService {
|
||||
content,
|
||||
}));
|
||||
|
||||
llmMessages.push({
|
||||
role: AgentChatMessageRole.USER,
|
||||
content: userMessage,
|
||||
});
|
||||
const userMessageWithFiles = await this.buildUserMessageWithFiles(
|
||||
userMessage,
|
||||
fileIds,
|
||||
);
|
||||
|
||||
llmMessages.push(userMessageWithFiles);
|
||||
|
||||
const aiRequestConfig = await this.prepareAIRequestConfig({
|
||||
system: `${AGENT_SYSTEM_PROMPTS.AGENT_CHAT}\n\n${agent.prompt}`,
|
||||
|
||||
@ -17,6 +17,7 @@ export type StreamAgentChatOptions = {
|
||||
threadId: string;
|
||||
userMessage: string;
|
||||
userWorkspaceId: string;
|
||||
fileIds?: string[];
|
||||
res: Response;
|
||||
};
|
||||
|
||||
@ -41,6 +42,7 @@ export class AgentStreamingService {
|
||||
threadId,
|
||||
userMessage,
|
||||
userWorkspaceId,
|
||||
fileIds = [],
|
||||
res,
|
||||
}: StreamAgentChatOptions) {
|
||||
try {
|
||||
@ -59,12 +61,6 @@ export class AgentStreamingService {
|
||||
);
|
||||
}
|
||||
|
||||
await this.agentChatService.addMessage({
|
||||
threadId,
|
||||
role: AgentChatMessageRole.USER,
|
||||
content: userMessage,
|
||||
});
|
||||
|
||||
this.setupStreamingHeaders(res);
|
||||
|
||||
const { fullStream } =
|
||||
@ -72,6 +68,7 @@ export class AgentStreamingService {
|
||||
agentId: thread.agent.id,
|
||||
userMessage,
|
||||
messages: thread.messages,
|
||||
fileIds,
|
||||
});
|
||||
|
||||
let aiResponse = '';
|
||||
@ -92,6 +89,20 @@ export class AgentStreamingService {
|
||||
});
|
||||
break;
|
||||
case 'error':
|
||||
{
|
||||
const errorMessage =
|
||||
chunk.error &&
|
||||
typeof chunk.error === 'object' &&
|
||||
'message' in chunk.error
|
||||
? chunk.error.message
|
||||
: 'Something went wrong. Please try again.';
|
||||
|
||||
this.sendStreamEvent(res, {
|
||||
type: 'error',
|
||||
message: errorMessage as string,
|
||||
});
|
||||
res.end();
|
||||
}
|
||||
this.logger.error(`Stream error: ${JSON.stringify(chunk)}`);
|
||||
break;
|
||||
default:
|
||||
@ -100,6 +111,19 @@ export class AgentStreamingService {
|
||||
}
|
||||
}
|
||||
|
||||
if (!aiResponse) {
|
||||
res.end();
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
await this.agentChatService.addMessage({
|
||||
threadId,
|
||||
role: AgentChatMessageRole.USER,
|
||||
content: userMessage,
|
||||
fileIds,
|
||||
});
|
||||
|
||||
await this.agentChatService.addMessage({
|
||||
threadId,
|
||||
role: AgentChatMessageRole.ASSISTANT,
|
||||
|
||||
@ -5,6 +5,9 @@ import { AiModule } from 'src/engine/core-modules/ai/ai.module';
|
||||
import { AuditModule } from 'src/engine/core-modules/audit/audit.module';
|
||||
import { TokenModule } from 'src/engine/core-modules/auth/token/token.module';
|
||||
import { FeatureFlagModule } from 'src/engine/core-modules/feature-flag/feature-flag.module';
|
||||
import { FileEntity } from 'src/engine/core-modules/file/entities/file.entity';
|
||||
import { FileUploadModule } from 'src/engine/core-modules/file/file-upload/file-upload.module';
|
||||
import { FileModule } from 'src/engine/core-modules/file/file.module';
|
||||
import { ThrottlerModule } from 'src/engine/core-modules/throttler/throttler.module';
|
||||
import { UserWorkspace } from 'src/engine/core-modules/user-workspace/user-workspace.entity';
|
||||
import { AgentChatController } from 'src/engine/metadata-modules/agent/agent-chat.controller';
|
||||
@ -33,6 +36,7 @@ import { AgentService } from './agent.service';
|
||||
RoleTargetsEntity,
|
||||
AgentChatMessageEntity,
|
||||
AgentChatThreadEntity,
|
||||
FileEntity,
|
||||
UserWorkspace,
|
||||
],
|
||||
'core',
|
||||
@ -41,6 +45,8 @@ import { AgentService } from './agent.service';
|
||||
ThrottlerModule,
|
||||
AuditModule,
|
||||
FeatureFlagModule,
|
||||
FileUploadModule,
|
||||
FileModule,
|
||||
ObjectMetadataModule,
|
||||
WorkspacePermissionsCacheModule,
|
||||
WorkspaceCacheStorageModule,
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import { Field, ID, ObjectType } from '@nestjs/graphql';
|
||||
|
||||
import { FileDTO } from 'src/engine/core-modules/file/dtos/file.dto';
|
||||
|
||||
@ObjectType('AgentChatMessage')
|
||||
export class AgentChatMessageDTO {
|
||||
@Field(() => ID)
|
||||
@ -14,6 +16,9 @@ export class AgentChatMessageDTO {
|
||||
@Field()
|
||||
content: string;
|
||||
|
||||
@Field(() => [FileDTO], { nullable: true })
|
||||
files?: FileDTO[];
|
||||
|
||||
@Field()
|
||||
createdAt: Date;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user