Text-to-SQL proof of concept (#5788)
Added: - An "Ask AI" command to the command menu. - A simple GraphQL resolver that converts the user's question into a relevant SQL query using an LLM, runs the query, and returns the result. <img width="428" alt="Screenshot 2024-06-09 at 20 53 09" src="https://github.com/twentyhq/twenty/assets/171685816/57127f37-d4a6-498d-b253-733ffa0d209f"> No security concerns have been addressed, this is only a proof-of-concept and not intended to be enabled in production. All changes are behind a feature flag called `IS_ASK_AI_ENABLED`. --------- Co-authored-by: Félix Malfait <felix.malfait@gmail.com>
This commit is contained in:
@ -17,6 +17,8 @@ import {
|
||||
|
||||
import { EmailDriver } from 'src/engine/integrations/email/interfaces/email.interface';
|
||||
import { NodeEnvironment } from 'src/engine/integrations/environment/interfaces/node-environment.interface';
|
||||
import { LLMChatModelDriver } from 'src/engine/integrations/llm-chat-model/interfaces/llm-chat-model.interface';
|
||||
import { LLMTracingDriver } from 'src/engine/integrations/llm-tracing/interfaces/llm-tracing.interface';
|
||||
|
||||
import { assert } from 'src/utils/assert';
|
||||
import { CastToStringArray } from 'src/engine/integrations/environment/decorators/cast-to-string-array.decorator';
|
||||
@ -369,6 +371,16 @@ export class EnvironmentVariables {
|
||||
|
||||
OPENROUTER_API_KEY: string;
|
||||
|
||||
LLM_CHAT_MODEL_DRIVER: LLMChatModelDriver = LLMChatModelDriver.OpenAI;
|
||||
|
||||
OPENAI_API_KEY: string;
|
||||
|
||||
LANGFUSE_SECRET_KEY: string;
|
||||
|
||||
LANGFUSE_PUBLIC_KEY: string;
|
||||
|
||||
LLM_TRACING_DRIVER: LLMTracingDriver = LLMTracingDriver.Console;
|
||||
|
||||
@CastToPositiveNumber()
|
||||
API_RATE_LIMITING_TTL = 100;
|
||||
|
||||
|
||||
@ -12,6 +12,10 @@ import { emailModuleFactory } from 'src/engine/integrations/email/email.module-f
|
||||
import { CacheStorageModule } from 'src/engine/integrations/cache-storage/cache-storage.module';
|
||||
import { CaptchaModule } from 'src/engine/integrations/captcha/captcha.module';
|
||||
import { captchaModuleFactory } from 'src/engine/integrations/captcha/captcha.module-factory';
|
||||
import { LLMChatModelModule } from 'src/engine/integrations/llm-chat-model/llm-chat-model.module';
|
||||
import { llmChatModelModuleFactory } from 'src/engine/integrations/llm-chat-model/llm-chat-model.module-factory';
|
||||
import { LLMTracingModule } from 'src/engine/integrations/llm-tracing/llm-tracing.module';
|
||||
import { llmTracingModuleFactory } from 'src/engine/integrations/llm-tracing/llm-tracing.module-factory';
|
||||
|
||||
import { EnvironmentModule } from './environment/environment.module';
|
||||
import { EnvironmentService } from './environment/environment.service';
|
||||
@ -50,6 +54,14 @@ import { MessageQueueModule } from './message-queue/message-queue.module';
|
||||
wildcard: true,
|
||||
}),
|
||||
CacheStorageModule,
|
||||
LLMChatModelModule.forRoot({
|
||||
useFactory: llmChatModelModuleFactory,
|
||||
inject: [EnvironmentService],
|
||||
}),
|
||||
LLMTracingModule.forRoot({
|
||||
useFactory: llmTracingModuleFactory,
|
||||
inject: [EnvironmentService],
|
||||
}),
|
||||
],
|
||||
exports: [],
|
||||
providers: [],
|
||||
|
||||
@ -0,0 +1,5 @@
|
||||
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
|
||||
|
||||
export interface LLMChatModelDriver {
|
||||
getJSONChatModel(): BaseChatModel;
|
||||
}
|
||||
@ -0,0 +1,22 @@
|
||||
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
|
||||
import { ChatOpenAI } from '@langchain/openai';
|
||||
|
||||
import { LLMChatModelDriver } from 'src/engine/integrations/llm-chat-model/drivers/interfaces/llm-prompt-template-driver.interface';
|
||||
|
||||
export class OpenAIDriver implements LLMChatModelDriver {
|
||||
private chatModel: BaseChatModel;
|
||||
|
||||
constructor() {
|
||||
this.chatModel = new ChatOpenAI({
|
||||
model: 'gpt-4o',
|
||||
}).bind({
|
||||
response_format: {
|
||||
type: 'json_object',
|
||||
},
|
||||
}) as unknown as BaseChatModel;
|
||||
}
|
||||
|
||||
getJSONChatModel() {
|
||||
return this.chatModel;
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,14 @@
|
||||
import { ModuleMetadata, FactoryProvider } from '@nestjs/common';
|
||||
|
||||
export enum LLMChatModelDriver {
|
||||
OpenAI = 'openai',
|
||||
}
|
||||
|
||||
export interface LLMChatModelModuleOptions {
|
||||
type: LLMChatModelDriver;
|
||||
}
|
||||
|
||||
export type LLMChatModelModuleAsyncOptions = {
|
||||
useFactory: (...args: any[]) => LLMChatModelModuleOptions;
|
||||
} & Pick<ModuleMetadata, 'imports'> &
|
||||
Pick<FactoryProvider, 'inject'>;
|
||||
@ -0,0 +1 @@
|
||||
export const LLM_CHAT_MODEL_DRIVER = Symbol('LLM_CHAT_MODEL_DRIVER');
|
||||
@ -0,0 +1,19 @@
|
||||
import { LLMChatModelDriver } from 'src/engine/integrations/llm-chat-model/interfaces/llm-chat-model.interface';
|
||||
|
||||
import { EnvironmentService } from 'src/engine/integrations/environment/environment.service';
|
||||
|
||||
export const llmChatModelModuleFactory = (
|
||||
environmentService: EnvironmentService,
|
||||
) => {
|
||||
const driver = environmentService.get('LLM_CHAT_MODEL_DRIVER');
|
||||
|
||||
switch (driver) {
|
||||
case LLMChatModelDriver.OpenAI: {
|
||||
return { type: LLMChatModelDriver.OpenAI };
|
||||
}
|
||||
default:
|
||||
throw new Error(
|
||||
`Invalid LLM chat model driver (${driver}), check your .env file`,
|
||||
);
|
||||
}
|
||||
};
|
||||
@ -0,0 +1,35 @@
|
||||
import { DynamicModule, Global } from '@nestjs/common';
|
||||
|
||||
import {
|
||||
LLMChatModelDriver,
|
||||
LLMChatModelModuleAsyncOptions,
|
||||
} from 'src/engine/integrations/llm-chat-model/interfaces/llm-chat-model.interface';
|
||||
|
||||
import { LLM_CHAT_MODEL_DRIVER } from 'src/engine/integrations/llm-chat-model/llm-chat-model.constants';
|
||||
import { OpenAIDriver } from 'src/engine/integrations/llm-chat-model/drivers/openai.driver';
|
||||
import { LLMChatModelService } from 'src/engine/integrations/llm-chat-model/llm-chat-model.service';
|
||||
|
||||
@Global()
|
||||
export class LLMChatModelModule {
|
||||
static forRoot(options: LLMChatModelModuleAsyncOptions): DynamicModule {
|
||||
const provider = {
|
||||
provide: LLM_CHAT_MODEL_DRIVER,
|
||||
useFactory: (...args: any[]) => {
|
||||
const config = options.useFactory(...args);
|
||||
|
||||
switch (config.type) {
|
||||
case LLMChatModelDriver.OpenAI: {
|
||||
return new OpenAIDriver();
|
||||
}
|
||||
}
|
||||
},
|
||||
inject: options.inject || [],
|
||||
};
|
||||
|
||||
return {
|
||||
module: LLMChatModelModule,
|
||||
providers: [LLMChatModelService, provider],
|
||||
exports: [LLMChatModelService],
|
||||
};
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,16 @@
|
||||
import { Injectable, Inject } from '@nestjs/common';
|
||||
|
||||
import { LLMChatModelDriver } from 'src/engine/integrations/llm-chat-model/drivers/interfaces/llm-prompt-template-driver.interface';
|
||||
|
||||
import { LLM_CHAT_MODEL_DRIVER } from 'src/engine/integrations/llm-chat-model/llm-chat-model.constants';
|
||||
|
||||
@Injectable()
|
||||
export class LLMChatModelService {
|
||||
constructor(
|
||||
@Inject(LLM_CHAT_MODEL_DRIVER) private driver: LLMChatModelDriver,
|
||||
) {}
|
||||
|
||||
getJSONChatModel() {
|
||||
return this.driver.getJSONChatModel();
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,25 @@
|
||||
import { BaseCallbackHandler } from '@langchain/core/callbacks/base';
|
||||
import { ConsoleCallbackHandler } from '@langchain/core/tracers/console';
|
||||
import { Run } from '@langchain/core/tracers/base';
|
||||
|
||||
import { LLMTracingDriver } from 'src/engine/integrations/llm-tracing/drivers/interfaces/llm-tracing-driver.interface';
|
||||
|
||||
class WithMetadataConsoleCallbackHandler extends ConsoleCallbackHandler {
|
||||
private metadata: Record<string, unknown>;
|
||||
|
||||
constructor(metadata: Record<string, unknown>) {
|
||||
super();
|
||||
this.metadata = metadata;
|
||||
}
|
||||
|
||||
onChainStart(run: Run) {
|
||||
console.log(`Chain metadata: ${JSON.stringify(this.metadata)}`);
|
||||
super.onChainStart(run);
|
||||
}
|
||||
}
|
||||
|
||||
export class ConsoleDriver implements LLMTracingDriver {
|
||||
getCallbackHandler(metadata: Record<string, unknown>): BaseCallbackHandler {
|
||||
return new WithMetadataConsoleCallbackHandler(metadata);
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,5 @@
|
||||
import { BaseCallbackHandler } from '@langchain/core/callbacks/base';
|
||||
|
||||
export interface LLMTracingDriver {
|
||||
getCallbackHandler(metadata: Record<string, unknown>): BaseCallbackHandler;
|
||||
}
|
||||
@ -0,0 +1,26 @@
|
||||
import { BaseCallbackHandler } from '@langchain/core/callbacks/base';
|
||||
import CallbackHandler from 'langfuse-langchain';
|
||||
|
||||
import { LLMTracingDriver } from 'src/engine/integrations/llm-tracing/drivers/interfaces/llm-tracing-driver.interface';
|
||||
|
||||
export interface LangfuseDriverOptions {
|
||||
secretKey: string;
|
||||
publicKey: string;
|
||||
}
|
||||
|
||||
export class LangfuseDriver implements LLMTracingDriver {
|
||||
private options: LangfuseDriverOptions;
|
||||
|
||||
constructor(options: LangfuseDriverOptions) {
|
||||
this.options = options;
|
||||
}
|
||||
|
||||
getCallbackHandler(metadata: Record<string, unknown>): BaseCallbackHandler {
|
||||
return new CallbackHandler({
|
||||
secretKey: this.options.secretKey,
|
||||
publicKey: this.options.publicKey,
|
||||
baseUrl: 'https://cloud.langfuse.com',
|
||||
metadata: metadata,
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,26 @@
|
||||
import { ModuleMetadata, FactoryProvider } from '@nestjs/common';
|
||||
|
||||
import { LangfuseDriverOptions } from 'src/engine/integrations/llm-tracing/drivers/langfuse.driver';
|
||||
|
||||
export enum LLMTracingDriver {
|
||||
Langfuse = 'langfuse',
|
||||
Console = 'console',
|
||||
}
|
||||
|
||||
export interface LangfuseDriverFactoryOptions {
|
||||
type: LLMTracingDriver.Langfuse;
|
||||
options: LangfuseDriverOptions;
|
||||
}
|
||||
|
||||
export interface ConsoleDriverFactoryOptions {
|
||||
type: LLMTracingDriver.Console;
|
||||
}
|
||||
|
||||
export type LLMTracingModuleOptions =
|
||||
| LangfuseDriverFactoryOptions
|
||||
| ConsoleDriverFactoryOptions;
|
||||
|
||||
export type LLMTracingModuleAsyncOptions = {
|
||||
useFactory: (...args: any[]) => LLMTracingModuleOptions;
|
||||
} & Pick<ModuleMetadata, 'imports'> &
|
||||
Pick<FactoryProvider, 'inject'>;
|
||||
@ -0,0 +1 @@
|
||||
export const LLM_TRACING_DRIVER = Symbol('LLM_TRACING_DRIVER');
|
||||
@ -0,0 +1,34 @@
|
||||
import { LLMTracingDriver } from 'src/engine/integrations/llm-tracing/interfaces/llm-tracing.interface';
|
||||
|
||||
import { EnvironmentService } from 'src/engine/integrations/environment/environment.service';
|
||||
|
||||
export const llmTracingModuleFactory = (
|
||||
environmentService: EnvironmentService,
|
||||
) => {
|
||||
const driver = environmentService.get('LLM_TRACING_DRIVER');
|
||||
|
||||
switch (driver) {
|
||||
case LLMTracingDriver.Console: {
|
||||
return { type: LLMTracingDriver.Console as const };
|
||||
}
|
||||
case LLMTracingDriver.Langfuse: {
|
||||
const secretKey = environmentService.get('LANGFUSE_SECRET_KEY');
|
||||
const publicKey = environmentService.get('LANGFUSE_PUBLIC_KEY');
|
||||
|
||||
if (!(secretKey && publicKey)) {
|
||||
throw new Error(
|
||||
`${driver} LLM tracing driver requires LANGFUSE_SECRET_KEY and LANGFUSE_PUBLIC_KEY to be defined, check your .env file`,
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
type: LLMTracingDriver.Langfuse as const,
|
||||
options: { secretKey, publicKey },
|
||||
};
|
||||
}
|
||||
default:
|
||||
throw new Error(
|
||||
`Invalid LLM tracing driver (${driver}), check your .env file`,
|
||||
);
|
||||
}
|
||||
};
|
||||
@ -0,0 +1,39 @@
|
||||
import { Global, DynamicModule } from '@nestjs/common';
|
||||
|
||||
import {
|
||||
LLMTracingModuleAsyncOptions,
|
||||
LLMTracingDriver,
|
||||
} from 'src/engine/integrations/llm-tracing/interfaces/llm-tracing.interface';
|
||||
|
||||
import { LangfuseDriver } from 'src/engine/integrations/llm-tracing/drivers/langfuse.driver';
|
||||
import { ConsoleDriver } from 'src/engine/integrations/llm-tracing/drivers/console.driver';
|
||||
import { LLMTracingService } from 'src/engine/integrations/llm-tracing/llm-tracing.service';
|
||||
import { LLM_TRACING_DRIVER } from 'src/engine/integrations/llm-tracing/llm-tracing.constants';
|
||||
|
||||
@Global()
|
||||
export class LLMTracingModule {
|
||||
static forRoot(options: LLMTracingModuleAsyncOptions): DynamicModule {
|
||||
const provider = {
|
||||
provide: LLM_TRACING_DRIVER,
|
||||
useFactory: (...args: any[]) => {
|
||||
const config = options.useFactory(...args);
|
||||
|
||||
switch (config.type) {
|
||||
case LLMTracingDriver.Langfuse: {
|
||||
return new LangfuseDriver(config.options);
|
||||
}
|
||||
case LLMTracingDriver.Console: {
|
||||
return new ConsoleDriver();
|
||||
}
|
||||
}
|
||||
},
|
||||
inject: options.inject || [],
|
||||
};
|
||||
|
||||
return {
|
||||
module: LLMTracingModule,
|
||||
providers: [LLMTracingService, provider],
|
||||
exports: [LLMTracingService],
|
||||
};
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,16 @@
|
||||
import { Injectable, Inject } from '@nestjs/common';
|
||||
|
||||
import { BaseCallbackHandler } from '@langchain/core/callbacks/base';
|
||||
|
||||
import { LLMTracingDriver } from 'src/engine/integrations/llm-tracing/drivers/interfaces/llm-tracing-driver.interface';
|
||||
|
||||
import { LLM_TRACING_DRIVER } from 'src/engine/integrations/llm-tracing/llm-tracing.constants';
|
||||
|
||||
@Injectable()
|
||||
export class LLMTracingService {
|
||||
constructor(@Inject(LLM_TRACING_DRIVER) private driver: LLMTracingDriver) {}
|
||||
|
||||
getCallbackHandler(metadata: Record<string, unknown>): BaseCallbackHandler {
|
||||
return this.driver.getCallbackHandler(metadata);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user