Text-to-SQL proof of concept (#5788)

Added:
- An "Ask AI" command to the command menu.
- A simple GraphQL resolver that converts the user's question into a
relevant SQL query using an LLM, runs the query, and returns the result.

<img width="428" alt="Screenshot 2024-06-09 at 20 53 09"
src="https://github.com/twentyhq/twenty/assets/171685816/57127f37-d4a6-498d-b253-733ffa0d209f">

No security concerns have been addressed, this is only a
proof-of-concept and not intended to be enabled in production.

All changes are behind a feature flag called `IS_ASK_AI_ENABLED`.

---------

Co-authored-by: Félix Malfait <felix.malfait@gmail.com>
This commit is contained in:
ad-elias
2024-07-04 08:57:26 +02:00
committed by GitHub
parent 25fce27fe3
commit 4c642a0bb8
46 changed files with 1463 additions and 40 deletions

View File

@ -7,6 +7,7 @@ import {
import { EventEmitter2 } from '@nestjs/event-emitter';
import isEmpty from 'lodash.isempty';
import { DataSource } from 'typeorm';
import { IConnection } from 'src/engine/api/graphql/workspace-query-runner/interfaces/connection.interface';
import {
@ -620,15 +621,12 @@ export class WorkspaceQueryRunnerService {
return sanitizedRecord;
}
async execute(
query: string,
async executeSQL(
workspaceDataSource: DataSource,
workspaceId: string,
): Promise<PGGraphQLResult | undefined> {
const workspaceDataSource =
await this.workspaceDataSourceService.connectToWorkspaceDataSource(
workspaceId,
);
sqlQuery: string,
parameters?: any[],
) {
try {
return await workspaceDataSource?.transaction(
async (transactionManager) => {
@ -638,10 +636,7 @@ export class WorkspaceQueryRunnerService {
)};
`);
const results = transactionManager.query<PGGraphQLResult>(
`SELECT graphql.resolve($1);`,
[query],
);
const results = transactionManager.query(sqlQuery, parameters);
return results;
},
@ -655,6 +650,23 @@ export class WorkspaceQueryRunnerService {
}
}
async execute(
query: string,
workspaceId: string,
): Promise<PGGraphQLResult | undefined> {
const workspaceDataSource =
await this.workspaceDataSourceService.connectToWorkspaceDataSource(
workspaceId,
);
return this.executeSQL(
workspaceDataSource,
workspaceId,
`SELECT graphql.resolve($1);`,
[query],
);
}
private async parseResult<Result>(
graphqlResult: PGGraphQLResult | undefined,
objectMetadataItem: ObjectMetadataInterface,

View File

@ -0,0 +1,30 @@
import { Module } from '@nestjs/common';
import { TypeOrmModule } from '@nestjs/typeorm';
import { WorkspaceDataSourceModule } from 'src/engine/workspace-datasource/workspace-datasource.module';
import { UserModule } from 'src/engine/core-modules/user/user.module';
import { AISQLQueryResolver } from 'src/engine/core-modules/ai-sql-query/ai-sql-query.resolver';
import { AISQLQueryService } from 'src/engine/core-modules/ai-sql-query/ai-sql-query.service';
import { FeatureFlagEntity } from 'src/engine/core-modules/feature-flag/feature-flag.entity';
import { WorkspaceQueryRunnerModule } from 'src/engine/api/graphql/workspace-query-runner/workspace-query-runner.module';
import { LLMChatModelModule } from 'src/engine/integrations/llm-chat-model/llm-chat-model.module';
import { EnvironmentModule } from 'src/engine/integrations/environment/environment.module';
import { LLMTracingModule } from 'src/engine/integrations/llm-tracing/llm-tracing.module';
import { ObjectMetadataModule } from 'src/engine/metadata-modules/object-metadata/object-metadata.module';
import { WorkspaceSyncMetadataModule } from 'src/engine/workspace-manager/workspace-sync-metadata/workspace-sync-metadata.module';
@Module({
imports: [
WorkspaceDataSourceModule,
WorkspaceQueryRunnerModule,
UserModule,
TypeOrmModule.forFeature([FeatureFlagEntity], 'core'),
LLMChatModelModule,
LLMTracingModule,
EnvironmentModule,
ObjectMetadataModule,
WorkspaceSyncMetadataModule,
],
exports: [],
providers: [AISQLQueryResolver, AISQLQueryService],
})
export class AISQLQueryModule {}

View File

@ -0,0 +1,14 @@
import { PromptTemplate } from '@langchain/core/prompts';
export const sqlGenerationPromptTemplate = PromptTemplate.fromTemplate<{
llmOutputJsonSchema: string;
sqlCreateTableStatements: string;
userQuestion: string;
}>(`Always respond following this JSON Schema: {llmOutputJsonSchema}
Based on the table schema below, write a PostgreSQL query that would answer the user's question. All column names must be enclosed in double quotes.
{sqlCreateTableStatements}
Question: {userQuestion}
SQL Query:`);

View File

@ -0,0 +1,64 @@
import { Args, Query, Resolver, ArgsType, Field } from '@nestjs/graphql';
import { ForbiddenException, UseGuards } from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm';
import { Repository } from 'typeorm';
import { User } from 'src/engine/core-modules/user/user.entity';
import { JwtAuthGuard } from 'src/engine/guards/jwt.auth.guard';
import { Workspace } from 'src/engine/core-modules/workspace/workspace.entity';
import { AuthWorkspace } from 'src/engine/decorators/auth/auth-workspace.decorator';
import {
FeatureFlagEntity,
FeatureFlagKeys,
} from 'src/engine/core-modules/feature-flag/feature-flag.entity';
import { AuthUser } from 'src/engine/decorators/auth/auth-user.decorator';
import { AISQLQueryResult } from 'src/engine/core-modules/ai-sql-query/dtos/ai-sql-query-result.dto';
import { AISQLQueryService } from 'src/engine/core-modules/ai-sql-query/ai-sql-query.service';
@ArgsType()
class GetAISQLQueryArgs {
@Field(() => String)
text: string;
}
@UseGuards(JwtAuthGuard)
@Resolver(() => AISQLQueryResult)
export class AISQLQueryResolver {
constructor(
private readonly aiSqlQueryService: AISQLQueryService,
@InjectRepository(FeatureFlagEntity, 'core')
private readonly featureFlagRepository: Repository<FeatureFlagEntity>,
) {}
@Query(() => AISQLQueryResult)
async getAISQLQuery(
@AuthWorkspace() { id: workspaceId }: Workspace,
@AuthUser() user: User,
@Args() { text }: GetAISQLQueryArgs,
) {
const isCopilotEnabledFeatureFlag =
await this.featureFlagRepository.findOneBy({
workspaceId,
key: FeatureFlagKeys.IsCopilotEnabled,
value: true,
});
if (!isCopilotEnabledFeatureFlag?.value) {
throw new ForbiddenException(
`${FeatureFlagKeys.IsCopilotEnabled} feature flag is disabled`,
);
}
const traceMetadata = {
userId: user.id,
userEmail: user.email,
};
return this.aiSqlQueryService.generateAndExecute(
workspaceId,
text,
traceMetadata,
);
}
}

View File

@ -0,0 +1,253 @@
import { Injectable, Logger } from '@nestjs/common';
import { RunnableSequence } from '@langchain/core/runnables';
import { StructuredOutputParser } from '@langchain/core/output_parsers';
import { DataSource, QueryFailedError } from 'typeorm';
import { z } from 'zod';
import { zodToJsonSchema } from 'zod-to-json-schema';
import { PostgresConnectionOptions } from 'typeorm/driver/postgres/PostgresConnectionOptions';
import groupBy from 'lodash.groupby';
import { PartialFieldMetadata } from 'src/engine/workspace-manager/workspace-sync-metadata/interfaces/partial-field-metadata.interface';
import { WorkspaceDataSourceService } from 'src/engine/workspace-datasource/workspace-datasource.service';
import { WorkspaceQueryRunnerService } from 'src/engine/api/graphql/workspace-query-runner/workspace-query-runner.service';
import { LLMChatModelService } from 'src/engine/integrations/llm-chat-model/llm-chat-model.service';
import { LLMTracingService } from 'src/engine/integrations/llm-tracing/llm-tracing.service';
import { ObjectMetadataEntity } from 'src/engine/metadata-modules/object-metadata/object-metadata.entity';
import { DEFAULT_LABEL_IDENTIFIER_FIELD_NAME } from 'src/engine/metadata-modules/object-metadata/object-metadata.constants';
import { StandardObjectFactory } from 'src/engine/workspace-manager/workspace-sync-metadata/factories/standard-object.factory';
import { standardObjectMetadataDefinitions } from 'src/engine/workspace-manager/workspace-sync-metadata/standard-objects';
import { AISQLQueryResult } from 'src/engine/core-modules/ai-sql-query/dtos/ai-sql-query-result.dto';
import { sqlGenerationPromptTemplate } from 'src/engine/core-modules/ai-sql-query/ai-sql-query.prompt-templates';
@Injectable()
export class AISQLQueryService {
private readonly logger = new Logger(AISQLQueryService.name);
constructor(
private readonly workspaceDataSourceService: WorkspaceDataSourceService,
private readonly workspaceQueryRunnerService: WorkspaceQueryRunnerService,
private readonly llmChatModelService: LLMChatModelService,
private readonly llmTracingService: LLMTracingService,
private readonly standardObjectFactory: StandardObjectFactory,
) {}
private getLabelIdentifierName(
objectMetadata: ObjectMetadataEntity,
dataSourceId,
workspaceId,
workspaceFeatureFlagsMap,
): string | undefined {
const customObjectLabelIdentifierFieldMetadata = objectMetadata.fields.find(
(fieldMetadata) =>
fieldMetadata.id === objectMetadata.labelIdentifierFieldMetadataId,
);
const standardObjectMetadataCollection = this.standardObjectFactory.create(
standardObjectMetadataDefinitions,
{ workspaceId, dataSourceId },
workspaceFeatureFlagsMap,
);
const standardObjectLabelIdentifierFieldMetadata =
standardObjectMetadataCollection
.find(
(standardObjectMetadata) =>
standardObjectMetadata.nameSingular === objectMetadata.nameSingular,
)
?.fields.find(
(field: PartialFieldMetadata) =>
field.name === DEFAULT_LABEL_IDENTIFIER_FIELD_NAME,
) as PartialFieldMetadata;
const labelIdentifierFieldMetadata =
customObjectLabelIdentifierFieldMetadata ??
standardObjectLabelIdentifierFieldMetadata;
return (
labelIdentifierFieldMetadata?.name ?? DEFAULT_LABEL_IDENTIFIER_FIELD_NAME
);
}
private async getColInfosByTableName(dataSource: DataSource) {
const { schema } = dataSource.options as PostgresConnectionOptions;
// From LangChain sql_utils.ts
const sqlQuery = `SELECT
t.table_name,
c.*
FROM
information_schema.tables t
JOIN information_schema.columns c
ON t.table_name = c.table_name
WHERE
t.table_schema = '${schema}'
AND c.table_schema = '${schema}'
ORDER BY
t.table_name,
c.ordinal_position;`;
const colInfos = await dataSource.query<
{
table_name: string;
column_name: string;
data_type: string | undefined;
is_nullable: 'YES' | 'NO';
}[]
>(sqlQuery);
return groupBy(colInfos, (colInfo) => colInfo.table_name);
}
private getCreateTableStatement(tableName: string, colInfos: any[]) {
return `${`CREATE TABLE ${tableName} (\n`} ${colInfos
.map(
(colInfo) =>
`${colInfo.column_name} ${colInfo.data_type} ${
colInfo.is_nullable === 'YES' ? '' : 'NOT NULL'
}`,
)
.join(', ')});`;
}
private getRelationDescriptions() {
// TODO - Construct sentences like the following:
// investorId: a foreign key referencing the person table, indicating the investor who owns this portfolio company.
return '';
}
private getTableDescription(tableName: string, colInfos: any[]) {
return [
this.getCreateTableStatement(tableName, colInfos),
this.getRelationDescriptions(),
].join('\n');
}
private async getWorkspaceSchemaDescription(
dataSource: DataSource,
): Promise<string> {
const colInfoByTableName = await this.getColInfosByTableName(dataSource);
return Object.entries(colInfoByTableName)
.map(([tableName, colInfos]) =>
this.getTableDescription(tableName, colInfos),
)
.join('\n\n');
}
private async generateWithDataSource(
workspaceId: string,
workspaceDataSource: DataSource,
userQuestion: string,
traceMetadata: Record<string, string> = {},
) {
const workspaceSchemaName =
this.workspaceDataSourceService.getSchemaName(workspaceId);
workspaceDataSource.setOptions({
schema: workspaceSchemaName,
});
const workspaceSchemaDescription =
await this.getWorkspaceSchemaDescription(workspaceDataSource);
const llmOutputSchema = z.object({
sqlQuery: z.string(),
});
const llmOutputJsonSchema = JSON.stringify(
zodToJsonSchema(llmOutputSchema),
);
const structuredOutputParser =
StructuredOutputParser.fromZodSchema(llmOutputSchema);
const sqlQueryGeneratorChain = RunnableSequence.from([
sqlGenerationPromptTemplate,
this.llmChatModelService.getJSONChatModel(),
structuredOutputParser,
]);
const metadata = {
workspaceId,
...traceMetadata,
};
const tracingCallbackHandler =
this.llmTracingService.getCallbackHandler(metadata);
const { sqlQuery } = await sqlQueryGeneratorChain.invoke(
{
llmOutputJsonSchema,
sqlCreateTableStatements: workspaceSchemaDescription,
userQuestion,
},
{
callbacks: [tracingCallbackHandler],
},
);
return sqlQuery;
}
async generate(
workspaceId: string,
userQuestion: string,
traceMetadata: Record<string, string> = {},
) {
const workspaceDataSource =
await this.workspaceDataSourceService.connectToWorkspaceDataSource(
workspaceId,
);
return this.generateWithDataSource(
workspaceId,
workspaceDataSource,
userQuestion,
traceMetadata,
);
}
async generateAndExecute(
workspaceId: string,
userQuestion: string,
traceMetadata: Record<string, string> = {},
): Promise<AISQLQueryResult> {
const workspaceDataSource =
await this.workspaceDataSourceService.connectToWorkspaceDataSource(
workspaceId,
);
const sqlQuery = await this.generateWithDataSource(
workspaceId,
workspaceDataSource,
userQuestion,
traceMetadata,
);
try {
const sqlQueryResult: Record<string, any>[] =
await this.workspaceQueryRunnerService.executeSQL(
workspaceDataSource,
workspaceId,
sqlQuery,
);
return {
sqlQuery,
sqlQueryResult: JSON.stringify(sqlQueryResult),
};
} catch (error) {
if (error instanceof QueryFailedError) {
return {
sqlQuery,
queryFailedErrorMessage: error.message,
};
}
this.logger.error(error.message, error.stack);
return {
sqlQuery,
};
}
}
}

View File

@ -0,0 +1,17 @@
import { Field, ObjectType } from '@nestjs/graphql';
import { IsOptional } from 'class-validator';
@ObjectType('AISQLQueryResult')
export class AISQLQueryResult {
@Field(() => String)
sqlQuery: string;
@Field(() => String, { nullable: true })
@IsOptional()
sqlQueryResult?: string;
@Field(() => String, { nullable: true })
@IsOptional()
queryFailedErrorMessage?: string;
}

View File

@ -10,6 +10,7 @@ import { TimelineMessagingModule } from 'src/engine/core-modules/messaging/timel
import { TimelineCalendarEventModule } from 'src/engine/core-modules/calendar/timeline-calendar-event.module';
import { BillingModule } from 'src/engine/core-modules/billing/billing.module';
import { HealthModule } from 'src/engine/core-modules/health/health.module';
import { AISQLQueryModule } from 'src/engine/core-modules/ai-sql-query/ai-sql-query.module';
import { PostgresCredentialsModule } from 'src/engine/core-modules/postgres-credentials/postgres-credentials.module';
import { AnalyticsModule } from './analytics/analytics.module';
@ -31,6 +32,7 @@ import { ClientConfigModule } from './client-config/client-config.module';
TimelineCalendarEventModule,
UserModule,
WorkspaceModule,
AISQLQueryModule,
PostgresCredentialsModule,
],
exports: [

View File

@ -22,6 +22,7 @@ export enum FeatureFlagKeys {
IsPostgreSQLIntegrationEnabled = 'IS_POSTGRESQL_INTEGRATION_ENABLED',
IsStripeIntegrationEnabled = 'IS_STRIPE_INTEGRATION_ENABLED',
IsContactCreationForSentAndReceivedEmailsEnabled = 'IS_CONTACT_CREATION_FOR_SENT_AND_RECEIVED_EMAILS_ENABLED',
IsCopilotEnabled = 'IS_COPILOT_ENABLED',
IsMessagingAliasFetchingEnabled = 'IS_MESSAGING_ALIAS_FETCHING_ENABLED',
IsGoogleCalendarSyncV2Enabled = 'IS_GOOGLE_CALENDAR_SYNC_V2_ENABLED',
IsFreeAccessEnabled = 'IS_FREE_ACCESS_ENABLED',

View File

@ -17,6 +17,8 @@ import {
import { EmailDriver } from 'src/engine/integrations/email/interfaces/email.interface';
import { NodeEnvironment } from 'src/engine/integrations/environment/interfaces/node-environment.interface';
import { LLMChatModelDriver } from 'src/engine/integrations/llm-chat-model/interfaces/llm-chat-model.interface';
import { LLMTracingDriver } from 'src/engine/integrations/llm-tracing/interfaces/llm-tracing.interface';
import { assert } from 'src/utils/assert';
import { CastToStringArray } from 'src/engine/integrations/environment/decorators/cast-to-string-array.decorator';
@ -369,6 +371,16 @@ export class EnvironmentVariables {
OPENROUTER_API_KEY: string;
LLM_CHAT_MODEL_DRIVER: LLMChatModelDriver = LLMChatModelDriver.OpenAI;
OPENAI_API_KEY: string;
LANGFUSE_SECRET_KEY: string;
LANGFUSE_PUBLIC_KEY: string;
LLM_TRACING_DRIVER: LLMTracingDriver = LLMTracingDriver.Console;
@CastToPositiveNumber()
API_RATE_LIMITING_TTL = 100;

View File

@ -12,6 +12,10 @@ import { emailModuleFactory } from 'src/engine/integrations/email/email.module-f
import { CacheStorageModule } from 'src/engine/integrations/cache-storage/cache-storage.module';
import { CaptchaModule } from 'src/engine/integrations/captcha/captcha.module';
import { captchaModuleFactory } from 'src/engine/integrations/captcha/captcha.module-factory';
import { LLMChatModelModule } from 'src/engine/integrations/llm-chat-model/llm-chat-model.module';
import { llmChatModelModuleFactory } from 'src/engine/integrations/llm-chat-model/llm-chat-model.module-factory';
import { LLMTracingModule } from 'src/engine/integrations/llm-tracing/llm-tracing.module';
import { llmTracingModuleFactory } from 'src/engine/integrations/llm-tracing/llm-tracing.module-factory';
import { EnvironmentModule } from './environment/environment.module';
import { EnvironmentService } from './environment/environment.service';
@ -50,6 +54,14 @@ import { MessageQueueModule } from './message-queue/message-queue.module';
wildcard: true,
}),
CacheStorageModule,
LLMChatModelModule.forRoot({
useFactory: llmChatModelModuleFactory,
inject: [EnvironmentService],
}),
LLMTracingModule.forRoot({
useFactory: llmTracingModuleFactory,
inject: [EnvironmentService],
}),
],
exports: [],
providers: [],

View File

@ -0,0 +1,5 @@
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
export interface LLMChatModelDriver {
getJSONChatModel(): BaseChatModel;
}

View File

@ -0,0 +1,22 @@
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { ChatOpenAI } from '@langchain/openai';
import { LLMChatModelDriver } from 'src/engine/integrations/llm-chat-model/drivers/interfaces/llm-prompt-template-driver.interface';
export class OpenAIDriver implements LLMChatModelDriver {
private chatModel: BaseChatModel;
constructor() {
this.chatModel = new ChatOpenAI({
model: 'gpt-4o',
}).bind({
response_format: {
type: 'json_object',
},
}) as unknown as BaseChatModel;
}
getJSONChatModel() {
return this.chatModel;
}
}

View File

@ -0,0 +1,14 @@
import { ModuleMetadata, FactoryProvider } from '@nestjs/common';
export enum LLMChatModelDriver {
OpenAI = 'openai',
}
export interface LLMChatModelModuleOptions {
type: LLMChatModelDriver;
}
export type LLMChatModelModuleAsyncOptions = {
useFactory: (...args: any[]) => LLMChatModelModuleOptions;
} & Pick<ModuleMetadata, 'imports'> &
Pick<FactoryProvider, 'inject'>;

View File

@ -0,0 +1 @@
export const LLM_CHAT_MODEL_DRIVER = Symbol('LLM_CHAT_MODEL_DRIVER');

View File

@ -0,0 +1,19 @@
import { LLMChatModelDriver } from 'src/engine/integrations/llm-chat-model/interfaces/llm-chat-model.interface';
import { EnvironmentService } from 'src/engine/integrations/environment/environment.service';
export const llmChatModelModuleFactory = (
environmentService: EnvironmentService,
) => {
const driver = environmentService.get('LLM_CHAT_MODEL_DRIVER');
switch (driver) {
case LLMChatModelDriver.OpenAI: {
return { type: LLMChatModelDriver.OpenAI };
}
default:
throw new Error(
`Invalid LLM chat model driver (${driver}), check your .env file`,
);
}
};

View File

@ -0,0 +1,35 @@
import { DynamicModule, Global } from '@nestjs/common';
import {
LLMChatModelDriver,
LLMChatModelModuleAsyncOptions,
} from 'src/engine/integrations/llm-chat-model/interfaces/llm-chat-model.interface';
import { LLM_CHAT_MODEL_DRIVER } from 'src/engine/integrations/llm-chat-model/llm-chat-model.constants';
import { OpenAIDriver } from 'src/engine/integrations/llm-chat-model/drivers/openai.driver';
import { LLMChatModelService } from 'src/engine/integrations/llm-chat-model/llm-chat-model.service';
@Global()
export class LLMChatModelModule {
static forRoot(options: LLMChatModelModuleAsyncOptions): DynamicModule {
const provider = {
provide: LLM_CHAT_MODEL_DRIVER,
useFactory: (...args: any[]) => {
const config = options.useFactory(...args);
switch (config.type) {
case LLMChatModelDriver.OpenAI: {
return new OpenAIDriver();
}
}
},
inject: options.inject || [],
};
return {
module: LLMChatModelModule,
providers: [LLMChatModelService, provider],
exports: [LLMChatModelService],
};
}
}

View File

@ -0,0 +1,16 @@
import { Injectable, Inject } from '@nestjs/common';
import { LLMChatModelDriver } from 'src/engine/integrations/llm-chat-model/drivers/interfaces/llm-prompt-template-driver.interface';
import { LLM_CHAT_MODEL_DRIVER } from 'src/engine/integrations/llm-chat-model/llm-chat-model.constants';
@Injectable()
export class LLMChatModelService {
constructor(
@Inject(LLM_CHAT_MODEL_DRIVER) private driver: LLMChatModelDriver,
) {}
getJSONChatModel() {
return this.driver.getJSONChatModel();
}
}

View File

@ -0,0 +1,25 @@
import { BaseCallbackHandler } from '@langchain/core/callbacks/base';
import { ConsoleCallbackHandler } from '@langchain/core/tracers/console';
import { Run } from '@langchain/core/tracers/base';
import { LLMTracingDriver } from 'src/engine/integrations/llm-tracing/drivers/interfaces/llm-tracing-driver.interface';
class WithMetadataConsoleCallbackHandler extends ConsoleCallbackHandler {
private metadata: Record<string, unknown>;
constructor(metadata: Record<string, unknown>) {
super();
this.metadata = metadata;
}
onChainStart(run: Run) {
console.log(`Chain metadata: ${JSON.stringify(this.metadata)}`);
super.onChainStart(run);
}
}
export class ConsoleDriver implements LLMTracingDriver {
getCallbackHandler(metadata: Record<string, unknown>): BaseCallbackHandler {
return new WithMetadataConsoleCallbackHandler(metadata);
}
}

View File

@ -0,0 +1,5 @@
import { BaseCallbackHandler } from '@langchain/core/callbacks/base';
export interface LLMTracingDriver {
getCallbackHandler(metadata: Record<string, unknown>): BaseCallbackHandler;
}

View File

@ -0,0 +1,26 @@
import { BaseCallbackHandler } from '@langchain/core/callbacks/base';
import CallbackHandler from 'langfuse-langchain';
import { LLMTracingDriver } from 'src/engine/integrations/llm-tracing/drivers/interfaces/llm-tracing-driver.interface';
export interface LangfuseDriverOptions {
secretKey: string;
publicKey: string;
}
export class LangfuseDriver implements LLMTracingDriver {
private options: LangfuseDriverOptions;
constructor(options: LangfuseDriverOptions) {
this.options = options;
}
getCallbackHandler(metadata: Record<string, unknown>): BaseCallbackHandler {
return new CallbackHandler({
secretKey: this.options.secretKey,
publicKey: this.options.publicKey,
baseUrl: 'https://cloud.langfuse.com',
metadata: metadata,
});
}
}

View File

@ -0,0 +1,26 @@
import { ModuleMetadata, FactoryProvider } from '@nestjs/common';
import { LangfuseDriverOptions } from 'src/engine/integrations/llm-tracing/drivers/langfuse.driver';
export enum LLMTracingDriver {
Langfuse = 'langfuse',
Console = 'console',
}
export interface LangfuseDriverFactoryOptions {
type: LLMTracingDriver.Langfuse;
options: LangfuseDriverOptions;
}
export interface ConsoleDriverFactoryOptions {
type: LLMTracingDriver.Console;
}
export type LLMTracingModuleOptions =
| LangfuseDriverFactoryOptions
| ConsoleDriverFactoryOptions;
export type LLMTracingModuleAsyncOptions = {
useFactory: (...args: any[]) => LLMTracingModuleOptions;
} & Pick<ModuleMetadata, 'imports'> &
Pick<FactoryProvider, 'inject'>;

View File

@ -0,0 +1 @@
export const LLM_TRACING_DRIVER = Symbol('LLM_TRACING_DRIVER');

View File

@ -0,0 +1,34 @@
import { LLMTracingDriver } from 'src/engine/integrations/llm-tracing/interfaces/llm-tracing.interface';
import { EnvironmentService } from 'src/engine/integrations/environment/environment.service';
export const llmTracingModuleFactory = (
environmentService: EnvironmentService,
) => {
const driver = environmentService.get('LLM_TRACING_DRIVER');
switch (driver) {
case LLMTracingDriver.Console: {
return { type: LLMTracingDriver.Console as const };
}
case LLMTracingDriver.Langfuse: {
const secretKey = environmentService.get('LANGFUSE_SECRET_KEY');
const publicKey = environmentService.get('LANGFUSE_PUBLIC_KEY');
if (!(secretKey && publicKey)) {
throw new Error(
`${driver} LLM tracing driver requires LANGFUSE_SECRET_KEY and LANGFUSE_PUBLIC_KEY to be defined, check your .env file`,
);
}
return {
type: LLMTracingDriver.Langfuse as const,
options: { secretKey, publicKey },
};
}
default:
throw new Error(
`Invalid LLM tracing driver (${driver}), check your .env file`,
);
}
};

View File

@ -0,0 +1,39 @@
import { Global, DynamicModule } from '@nestjs/common';
import {
LLMTracingModuleAsyncOptions,
LLMTracingDriver,
} from 'src/engine/integrations/llm-tracing/interfaces/llm-tracing.interface';
import { LangfuseDriver } from 'src/engine/integrations/llm-tracing/drivers/langfuse.driver';
import { ConsoleDriver } from 'src/engine/integrations/llm-tracing/drivers/console.driver';
import { LLMTracingService } from 'src/engine/integrations/llm-tracing/llm-tracing.service';
import { LLM_TRACING_DRIVER } from 'src/engine/integrations/llm-tracing/llm-tracing.constants';
@Global()
export class LLMTracingModule {
static forRoot(options: LLMTracingModuleAsyncOptions): DynamicModule {
const provider = {
provide: LLM_TRACING_DRIVER,
useFactory: (...args: any[]) => {
const config = options.useFactory(...args);
switch (config.type) {
case LLMTracingDriver.Langfuse: {
return new LangfuseDriver(config.options);
}
case LLMTracingDriver.Console: {
return new ConsoleDriver();
}
}
},
inject: options.inject || [],
};
return {
module: LLMTracingModule,
providers: [LLMTracingService, provider],
exports: [LLMTracingService],
};
}
}

View File

@ -0,0 +1,16 @@
import { Injectable, Inject } from '@nestjs/common';
import { BaseCallbackHandler } from '@langchain/core/callbacks/base';
import { LLMTracingDriver } from 'src/engine/integrations/llm-tracing/drivers/interfaces/llm-tracing-driver.interface';
import { LLM_TRACING_DRIVER } from 'src/engine/integrations/llm-tracing/llm-tracing.constants';
@Injectable()
export class LLMTracingService {
constructor(@Inject(LLM_TRACING_DRIVER) private driver: LLMTracingDriver) {}
getCallbackHandler(metadata: Record<string, unknown>): BaseCallbackHandler {
return this.driver.getCallbackHandler(metadata);
}
}

View File

@ -0,0 +1 @@
export const DEFAULT_LABEL_IDENTIFIER_FIELD_NAME = 'name';

View File

@ -59,6 +59,7 @@ export class AddStandardIdCommand extends CommandRunner {
IS_POSTGRESQL_INTEGRATION_ENABLED: true,
IS_STRIPE_INTEGRATION_ENABLED: false,
IS_CONTACT_CREATION_FOR_SENT_AND_RECEIVED_EMAILS_ENABLED: true,
IS_COPILOT_ENABLED: false,
IS_MESSAGING_ALIAS_FETCHING_ENABLED: true,
IS_GOOGLE_CALENDAR_SYNC_V2_ENABLED: true,
IS_FREE_ACCESS_ENABLED: false,
@ -77,6 +78,7 @@ export class AddStandardIdCommand extends CommandRunner {
IS_POSTGRESQL_INTEGRATION_ENABLED: true,
IS_STRIPE_INTEGRATION_ENABLED: false,
IS_CONTACT_CREATION_FOR_SENT_AND_RECEIVED_EMAILS_ENABLED: true,
IS_COPILOT_ENABLED: false,
IS_MESSAGING_ALIAS_FETCHING_ENABLED: true,
IS_GOOGLE_CALENDAR_SYNC_V2_ENABLED: true,
IS_FREE_ACCESS_ENABLED: false,