Text-to-SQL proof of concept (#5788)
Added: - An "Ask AI" command to the command menu. - A simple GraphQL resolver that converts the user's question into a relevant SQL query using an LLM, runs the query, and returns the result. <img width="428" alt="Screenshot 2024-06-09 at 20 53 09" src="https://github.com/twentyhq/twenty/assets/171685816/57127f37-d4a6-498d-b253-733ffa0d209f"> No security concerns have been addressed, this is only a proof-of-concept and not intended to be enabled in production. All changes are behind a feature flag called `IS_ASK_AI_ENABLED`. --------- Co-authored-by: Félix Malfait <felix.malfait@gmail.com>
This commit is contained in:
@ -20,6 +20,13 @@ export type Scalars = {
|
||||
Upload: any;
|
||||
};
|
||||
|
||||
export type AisqlQueryResult = {
|
||||
__typename?: 'AISQLQueryResult';
|
||||
queryFailedErrorMessage?: Maybe<Scalars['String']>;
|
||||
sqlQuery: Scalars['String'];
|
||||
sqlQueryResult?: Maybe<Scalars['String']>;
|
||||
};
|
||||
|
||||
export type ActivateWorkspaceInput = {
|
||||
displayName?: InputMaybe<Scalars['String']>;
|
||||
};
|
||||
@ -526,6 +533,7 @@ export type Query = {
|
||||
currentUser: User;
|
||||
currentWorkspace: Workspace;
|
||||
findWorkspaceFromInviteHash: Workspace;
|
||||
getAISQLQuery: AisqlQueryResult;
|
||||
getPostgresCredentials?: Maybe<PostgresCredentials>;
|
||||
getProductPrices: ProductPricesEntity;
|
||||
getTimelineCalendarEventsFromCompanyId: TimelineCalendarEventsWithTotal;
|
||||
@ -559,6 +567,11 @@ export type QueryFindWorkspaceFromInviteHashArgs = {
|
||||
};
|
||||
|
||||
|
||||
export type QueryGetAisqlQueryArgs = {
|
||||
text: Scalars['String'];
|
||||
};
|
||||
|
||||
|
||||
export type QueryGetProductPricesArgs = {
|
||||
product: Scalars['String'];
|
||||
};
|
||||
@ -1264,6 +1277,13 @@ export type SkipSyncEmailOnboardingStepMutationVariables = Exact<{ [key: string]
|
||||
|
||||
export type SkipSyncEmailOnboardingStepMutation = { __typename?: 'Mutation', skipSyncEmailOnboardingStep: { __typename?: 'OnboardingStepSuccess', success: boolean } };
|
||||
|
||||
export type GetAisqlQueryQueryVariables = Exact<{
|
||||
text: Scalars['String'];
|
||||
}>;
|
||||
|
||||
|
||||
export type GetAisqlQueryQuery = { __typename?: 'Query', getAISQLQuery: { __typename?: 'AISQLQueryResult', sqlQuery: string, sqlQueryResult?: string | null, queryFailedErrorMessage?: string | null } };
|
||||
|
||||
export type UserQueryFragmentFragment = { __typename?: 'User', id: any, firstName: string, lastName: string, email: string, canImpersonate: boolean, supportUserHash?: string | null, onboardingStatus?: OnboardingStatus | null, workspaceMember?: { __typename?: 'WorkspaceMember', id: any, colorScheme: string, avatarUrl?: string | null, locale: string, name: { __typename?: 'FullName', firstName: string, lastName: string } } | null, defaultWorkspace: { __typename?: 'Workspace', id: any, displayName?: string | null, logo?: string | null, domainName?: string | null, inviteHash?: string | null, allowImpersonation: boolean, activationStatus: string, currentCacheVersion?: string | null, workspaceMembersCount?: number | null, featureFlags?: Array<{ __typename?: 'FeatureFlag', id: any, key: string, value: boolean, workspaceId: string }> | null, currentBillingSubscription?: { __typename?: 'BillingSubscription', id: any, status: SubscriptionStatus, interval?: SubscriptionInterval | null } | null }, workspaces: Array<{ __typename?: 'UserWorkspace', workspace?: { __typename?: 'Workspace', id: any, logo?: string | null, displayName?: string | null, domainName?: string | null } | null }> };
|
||||
|
||||
export type DeleteUserAccountMutationVariables = Exact<{ [key: string]: never; }>;
|
||||
@ -2449,6 +2469,43 @@ export function useSkipSyncEmailOnboardingStepMutation(baseOptions?: Apollo.Muta
|
||||
export type SkipSyncEmailOnboardingStepMutationHookResult = ReturnType<typeof useSkipSyncEmailOnboardingStepMutation>;
|
||||
export type SkipSyncEmailOnboardingStepMutationResult = Apollo.MutationResult<SkipSyncEmailOnboardingStepMutation>;
|
||||
export type SkipSyncEmailOnboardingStepMutationOptions = Apollo.BaseMutationOptions<SkipSyncEmailOnboardingStepMutation, SkipSyncEmailOnboardingStepMutationVariables>;
|
||||
export const GetAisqlQueryDocument = gql`
|
||||
query GetAISQLQuery($text: String!) {
|
||||
getAISQLQuery(text: $text) {
|
||||
sqlQuery
|
||||
sqlQueryResult
|
||||
queryFailedErrorMessage
|
||||
}
|
||||
}
|
||||
`;
|
||||
|
||||
/**
|
||||
* __useGetAisqlQueryQuery__
|
||||
*
|
||||
* To run a query within a React component, call `useGetAisqlQueryQuery` and pass it any options that fit your needs.
|
||||
* When your component renders, `useGetAisqlQueryQuery` returns an object from Apollo Client that contains loading, error, and data properties
|
||||
* you can use to render your UI.
|
||||
*
|
||||
* @param baseOptions options that will be passed into the query, supported options are listed on: https://www.apollographql.com/docs/react/api/react-hooks/#options;
|
||||
*
|
||||
* @example
|
||||
* const { data, loading, error } = useGetAisqlQueryQuery({
|
||||
* variables: {
|
||||
* text: // value for 'text'
|
||||
* },
|
||||
* });
|
||||
*/
|
||||
export function useGetAisqlQueryQuery(baseOptions: Apollo.QueryHookOptions<GetAisqlQueryQuery, GetAisqlQueryQueryVariables>) {
|
||||
const options = {...defaultOptions, ...baseOptions}
|
||||
return Apollo.useQuery<GetAisqlQueryQuery, GetAisqlQueryQueryVariables>(GetAisqlQueryDocument, options);
|
||||
}
|
||||
export function useGetAisqlQueryLazyQuery(baseOptions?: Apollo.LazyQueryHookOptions<GetAisqlQueryQuery, GetAisqlQueryQueryVariables>) {
|
||||
const options = {...defaultOptions, ...baseOptions}
|
||||
return Apollo.useLazyQuery<GetAisqlQueryQuery, GetAisqlQueryQueryVariables>(GetAisqlQueryDocument, options);
|
||||
}
|
||||
export type GetAisqlQueryQueryHookResult = ReturnType<typeof useGetAisqlQueryQuery>;
|
||||
export type GetAisqlQueryLazyQueryHookResult = ReturnType<typeof useGetAisqlQueryLazyQuery>;
|
||||
export type GetAisqlQueryQueryResult = Apollo.QueryResult<GetAisqlQueryQuery, GetAisqlQueryQueryVariables>;
|
||||
export const DeleteUserAccountDocument = gql`
|
||||
mutation DeleteUserAccount {
|
||||
deleteUser {
|
||||
|
||||
@ -0,0 +1,54 @@
|
||||
import styled from '@emotion/styled';
|
||||
import { useSetRecoilState } from 'recoil';
|
||||
|
||||
import { copilotQueryState } from '@/activities/copilot/right-drawer/states/copilotQueryState';
|
||||
import {
|
||||
AutosizeTextInput,
|
||||
AutosizeTextInputVariant,
|
||||
} from '@/ui/input/components/AutosizeTextInput';
|
||||
|
||||
const StyledContainer = styled.div`
|
||||
box-sizing: border-box;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
height: 100%;
|
||||
justify-content: flex-start;
|
||||
overflow-y: auto;
|
||||
position: relative;
|
||||
`;
|
||||
|
||||
const StyledChatArea = styled.div`
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
overflow-y: scroll;
|
||||
padding: ${({ theme }) => theme.spacing(6)};
|
||||
padding-bottom: 0px;
|
||||
`;
|
||||
|
||||
const StyledNewMessageArea = styled.div`
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
padding: ${({ theme }) => theme.spacing(6)};
|
||||
padding-top: 0px;
|
||||
`;
|
||||
|
||||
export const RightDrawerAIChat = () => {
|
||||
const setCopilotQuery = useSetRecoilState(copilotQueryState);
|
||||
|
||||
return (
|
||||
<StyledContainer>
|
||||
<StyledChatArea>{/* TODO */}</StyledChatArea>
|
||||
<StyledNewMessageArea>
|
||||
<AutosizeTextInput
|
||||
autoFocus
|
||||
placeholder="Ask anything"
|
||||
variant={AutosizeTextInputVariant.Icon}
|
||||
onValidate={(text) => {
|
||||
setCopilotQuery(text);
|
||||
}}
|
||||
/>
|
||||
</StyledNewMessageArea>
|
||||
</StyledContainer>
|
||||
);
|
||||
};
|
||||
@ -0,0 +1,14 @@
|
||||
import { useRightDrawer } from '@/ui/layout/right-drawer/hooks/useRightDrawer';
|
||||
import { RightDrawerHotkeyScope } from '@/ui/layout/right-drawer/types/RightDrawerHotkeyScope';
|
||||
import { RightDrawerPages } from '@/ui/layout/right-drawer/types/RightDrawerPages';
|
||||
import { useSetHotkeyScope } from '@/ui/utilities/hotkey/hooks/useSetHotkeyScope';
|
||||
|
||||
export const useOpenCopilotRightDrawer = () => {
|
||||
const { openRightDrawer } = useRightDrawer();
|
||||
const setHotkeyScope = useSetHotkeyScope();
|
||||
|
||||
return () => {
|
||||
setHotkeyScope(RightDrawerHotkeyScope.RightDrawer, { goto: false });
|
||||
openRightDrawer(RightDrawerPages.Copilot);
|
||||
};
|
||||
};
|
||||
@ -0,0 +1,6 @@
|
||||
import { createState } from 'twenty-ui';
|
||||
|
||||
export const copilotQueryState = createState({
|
||||
key: 'activities/copilot-query',
|
||||
defaultValue: '',
|
||||
});
|
||||
@ -1,10 +1,12 @@
|
||||
import { useMemo, useRef } from 'react';
|
||||
import styled from '@emotion/styled';
|
||||
import { isNonEmptyString } from '@sniptt/guards';
|
||||
import { useRecoilState, useRecoilValue } from 'recoil';
|
||||
import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil';
|
||||
import { Key } from 'ts-key-enum';
|
||||
import { Avatar, IconNotes } from 'twenty-ui';
|
||||
import { Avatar, IconNotes, IconSparkles } from 'twenty-ui';
|
||||
|
||||
import { useOpenCopilotRightDrawer } from '@/activities/copilot/right-drawer/hooks/useOpenCopilotRightDrawer';
|
||||
import { copilotQueryState } from '@/activities/copilot/right-drawer/states/copilotQueryState';
|
||||
import { useOpenActivityRightDrawer } from '@/activities/hooks/useOpenActivityRightDrawer';
|
||||
import { Activity } from '@/activities/types/Activity';
|
||||
import { commandMenuSearchState } from '@/command-menu/states/commandMenuSearchState';
|
||||
@ -21,6 +23,7 @@ import { AppHotkeyScope } from '@/ui/utilities/hotkey/types/AppHotkeyScope';
|
||||
import { useListenClickOutside } from '@/ui/utilities/pointer-event/hooks/useListenClickOutside';
|
||||
import { useIsMobile } from '@/ui/utilities/responsive/hooks/useIsMobile';
|
||||
import { ScrollWrapper } from '@/ui/utilities/scroll/components/ScrollWrapper';
|
||||
import { useIsFeatureEnabled } from '@/workspace/hooks/useIsFeatureEnabled';
|
||||
import { getLogoUrlFromDomainName } from '~/utils';
|
||||
import { generateILikeFiltersForCompositeFields } from '~/utils/array/generateILikeFiltersForCompositeFields';
|
||||
import { isDefined } from '~/utils/isDefined';
|
||||
@ -248,8 +251,27 @@ export const CommandMenu = () => {
|
||||
callback: closeCommandMenu,
|
||||
});
|
||||
|
||||
const selectableItemIds = matchingCreateCommand
|
||||
const isCopilotEnabled = useIsFeatureEnabled('IS_COPILOT_ENABLED');
|
||||
const setCopilotQuery = useSetRecoilState(copilotQueryState);
|
||||
const openCopilotRightDrawer = useOpenCopilotRightDrawer();
|
||||
|
||||
const copilotCommand: Command = {
|
||||
id: 'copilot',
|
||||
to: '', // TODO
|
||||
Icon: IconSparkles,
|
||||
label: 'Open Copilot',
|
||||
type: CommandType.Navigate,
|
||||
onCommandClick: () => {
|
||||
setCopilotQuery(commandMenuSearch);
|
||||
openCopilotRightDrawer();
|
||||
},
|
||||
};
|
||||
|
||||
const copilotCommands: Command[] = isCopilotEnabled ? [copilotCommand] : [];
|
||||
|
||||
const selectableItemIds = copilotCommands
|
||||
.map((cmd) => cmd.id)
|
||||
.concat(matchingCreateCommand.map((cmd) => cmd.id))
|
||||
.concat(matchingNavigateCommand.map((cmd) => cmd.id))
|
||||
.concat(people.map((person) => person.id))
|
||||
.concat(companies.map((company) => company.id))
|
||||
@ -275,6 +297,7 @@ export const CommandMenu = () => {
|
||||
hotkeyScope={AppHotkeyScope.CommandMenu}
|
||||
onEnter={(itemId) => {
|
||||
const command = [
|
||||
...copilotCommands,
|
||||
...commandMenuCommands,
|
||||
...otherCommands,
|
||||
].find((cmd) => cmd.id === itemId);
|
||||
@ -292,6 +315,22 @@ export const CommandMenu = () => {
|
||||
!activities.length && (
|
||||
<StyledEmpty>No results found</StyledEmpty>
|
||||
)}
|
||||
{isCopilotEnabled && (
|
||||
<CommandGroup heading="Copilot">
|
||||
<SelectableItem itemId={copilotCommand.id}>
|
||||
<CommandMenuItem
|
||||
id={copilotCommand.id}
|
||||
Icon={copilotCommand.Icon}
|
||||
label={`${copilotCommand.label} ${
|
||||
commandMenuSearch.length > 2
|
||||
? `"${commandMenuSearch}"`
|
||||
: ''
|
||||
}`}
|
||||
onClick={copilotCommand.onCommandClick}
|
||||
/>
|
||||
</SelectableItem>
|
||||
</CommandGroup>
|
||||
)}
|
||||
<CommandGroup heading="Create">
|
||||
{matchingCreateCommand.map((cmd) => (
|
||||
<SelectableItem itemId={cmd.id} key={cmd.id}>
|
||||
|
||||
@ -0,0 +1,11 @@
|
||||
import { gql } from '@apollo/client';
|
||||
|
||||
export const getCopilot = gql`
|
||||
query GetAISQLQuery($text: String!) {
|
||||
getAISQLQuery(text: $text) {
|
||||
sqlQuery
|
||||
sqlQueryResult
|
||||
queryFailedErrorMessage
|
||||
}
|
||||
}
|
||||
`;
|
||||
@ -30,6 +30,8 @@ type AutosizeTextInputProps = {
|
||||
value?: string;
|
||||
className?: string;
|
||||
onBlur?: () => void;
|
||||
autoFocus?: boolean;
|
||||
disabled?: boolean;
|
||||
};
|
||||
|
||||
const StyledContainer = styled.div`
|
||||
@ -123,6 +125,8 @@ export const AutosizeTextInput = ({
|
||||
value = '',
|
||||
className,
|
||||
onBlur,
|
||||
autoFocus,
|
||||
disabled,
|
||||
}: AutosizeTextInputProps) => {
|
||||
const [isFocused, setIsFocused] = useState(false);
|
||||
const [isHidden, setIsHidden] = useState(
|
||||
@ -212,7 +216,9 @@ export const AutosizeTextInput = ({
|
||||
{!isHidden && (
|
||||
<StyledTextArea
|
||||
ref={textInputRef}
|
||||
autoFocus={variant === AutosizeTextInputVariant.Button}
|
||||
autoFocus={
|
||||
autoFocus || variant === AutosizeTextInputVariant.Button
|
||||
}
|
||||
placeholder={placeholder ?? 'Write a comment'}
|
||||
maxRows={MAX_ROWS}
|
||||
minRows={computedMinRows}
|
||||
@ -221,6 +227,7 @@ export const AutosizeTextInput = ({
|
||||
onFocus={handleFocus}
|
||||
onBlur={handleBlur}
|
||||
variant={variant}
|
||||
disabled={disabled}
|
||||
/>
|
||||
)}
|
||||
{variant === AutosizeTextInputVariant.Icon && (
|
||||
|
||||
@ -2,6 +2,7 @@ import styled from '@emotion/styled';
|
||||
import { useRecoilState, useRecoilValue } from 'recoil';
|
||||
|
||||
import { RightDrawerCalendarEvent } from '@/activities/calendar/right-drawer/components/RightDrawerCalendarEvent';
|
||||
import { RightDrawerAIChat } from '@/activities/copilot/right-drawer/components/RightDrawerAIChat';
|
||||
import { RightDrawerEmailThread } from '@/activities/emails/right-drawer/components/RightDrawerEmailThread';
|
||||
import { RightDrawerCreateActivity } from '@/activities/right-drawer/components/create/RightDrawerCreateActivity';
|
||||
import { RightDrawerEditActivity } from '@/activities/right-drawer/components/edit/RightDrawerEditActivity';
|
||||
@ -50,6 +51,10 @@ const RIGHT_DRAWER_PAGES_CONFIG = {
|
||||
page: <RightDrawerRecord />,
|
||||
topBar: <RightDrawerTopBar page={RightDrawerPages.ViewRecord} />,
|
||||
},
|
||||
[RightDrawerPages.Copilot]: {
|
||||
page: <RightDrawerAIChat />,
|
||||
topBar: <RightDrawerTopBar page={RightDrawerPages.Copilot} />,
|
||||
},
|
||||
};
|
||||
|
||||
export const RightDrawerRouter = () => {
|
||||
|
||||
@ -6,4 +6,5 @@ export const RIGHT_DRAWER_PAGE_ICONS = {
|
||||
[RightDrawerPages.ViewEmailThread]: 'IconMail',
|
||||
[RightDrawerPages.ViewCalendarEvent]: 'IconCalendarEvent',
|
||||
[RightDrawerPages.ViewRecord]: 'Icon123',
|
||||
[RightDrawerPages.Copilot]: 'IconSparkles',
|
||||
};
|
||||
|
||||
@ -6,4 +6,5 @@ export const RIGHT_DRAWER_PAGE_TITLES = {
|
||||
[RightDrawerPages.ViewEmailThread]: 'Email Thread',
|
||||
[RightDrawerPages.ViewCalendarEvent]: 'Calendar Event',
|
||||
[RightDrawerPages.ViewRecord]: 'Record Editor',
|
||||
[RightDrawerPages.Copilot]: 'Copilot',
|
||||
};
|
||||
|
||||
@ -4,4 +4,5 @@ export enum RightDrawerPages {
|
||||
ViewEmailThread = 'view-email-thread',
|
||||
ViewCalendarEvent = 'view-calendar-event',
|
||||
ViewRecord = 'view-record',
|
||||
Copilot = 'copilot',
|
||||
}
|
||||
|
||||
@ -11,7 +11,7 @@ type TabProps = {
|
||||
className?: string;
|
||||
onClick?: () => void;
|
||||
disabled?: boolean;
|
||||
hasBetaPill?: boolean;
|
||||
pill?: string;
|
||||
};
|
||||
|
||||
const StyledTab = styled.div<{ active?: boolean; disabled?: boolean }>`
|
||||
@ -59,7 +59,7 @@ export const Tab = ({
|
||||
onClick,
|
||||
className,
|
||||
disabled,
|
||||
hasBetaPill,
|
||||
pill,
|
||||
}: TabProps) => {
|
||||
const theme = useTheme();
|
||||
return (
|
||||
@ -73,7 +73,7 @@ export const Tab = ({
|
||||
<StyledHover>
|
||||
{Icon && <Icon size={theme.icon.size.md} />}
|
||||
{title}
|
||||
{hasBetaPill && <Pill label="Beta" />}
|
||||
{pill && <Pill label={pill} />}
|
||||
</StyledHover>
|
||||
</StyledTab>
|
||||
);
|
||||
|
||||
@ -15,7 +15,7 @@ type SingleTabProps = {
|
||||
id: string;
|
||||
hide?: boolean;
|
||||
disabled?: boolean;
|
||||
hasBetaPill?: boolean;
|
||||
pill?: string;
|
||||
};
|
||||
|
||||
type TabListProps = {
|
||||
@ -62,7 +62,7 @@ export const TabList = ({ tabs, tabListId, loading }: TabListProps) => {
|
||||
setActiveTabId(tab.id);
|
||||
}}
|
||||
disabled={tab.disabled ?? loading}
|
||||
hasBetaPill={tab.hasBetaPill}
|
||||
pill={tab.pill}
|
||||
/>
|
||||
))}
|
||||
</StyledContainer>
|
||||
|
||||
@ -4,4 +4,5 @@ export type FeatureFlagKey =
|
||||
| 'IS_EVENT_OBJECT_ENABLED'
|
||||
| 'IS_AIRTABLE_INTEGRATION_ENABLED'
|
||||
| 'IS_POSTGRESQL_INTEGRATION_ENABLED'
|
||||
| 'IS_STRIPE_INTEGRATION_ENABLED';
|
||||
| 'IS_STRIPE_INTEGRATION_ENABLED'
|
||||
| 'IS_COPILOT_ENABLED';
|
||||
|
||||
@ -72,4 +72,4 @@ SIGN_IN_PREFILLED=true
|
||||
# API_RATE_LIMITING_LIMIT=
|
||||
# MUTATION_MAXIMUM_AFFECTED_RECORDS=100
|
||||
# CHROME_EXTENSION_ID=bggmipldbceihilonnbpgoeclgbkblkp
|
||||
# PG_SSL_ALLOW_SELF_SIGNED=true
|
||||
# PG_SSL_ALLOW_SELF_SIGNED=true
|
||||
@ -15,6 +15,8 @@
|
||||
},
|
||||
"dependencies": {
|
||||
"@graphql-yoga/nestjs": "patch:@graphql-yoga/nestjs@2.1.0#./patches/@graphql-yoga-nestjs-npm-2.1.0-cb509e6047.patch",
|
||||
"@langchain/mistralai": "^0.0.24",
|
||||
"@langchain/openai": "^0.1.3",
|
||||
"@nestjs/cache-manager": "^2.2.1",
|
||||
"@nestjs/devtools-integration": "^0.1.6",
|
||||
"@nestjs/graphql": "patch:@nestjs/graphql@12.1.1#./patches/@nestjs+graphql+12.1.1.patch",
|
||||
@ -25,13 +27,16 @@
|
||||
"graphql-middleware": "^6.1.35",
|
||||
"jsdom": "~22.1.0",
|
||||
"jwt-decode": "^4.0.0",
|
||||
"langchain": "^0.2.6",
|
||||
"langfuse-langchain": "^3.11.2",
|
||||
"lodash.differencewith": "^4.5.0",
|
||||
"lodash.omitby": "^4.6.0",
|
||||
"lodash.uniq": "^4.5.0",
|
||||
"lodash.uniqby": "^4.7.0",
|
||||
"passport": "^0.7.0",
|
||||
"psl": "^1.9.0",
|
||||
"tsconfig-paths": "^4.2.0"
|
||||
"tsconfig-paths": "^4.2.0",
|
||||
"zod-to-json-schema": "^3.23.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@nestjs/cli": "10.3.0",
|
||||
|
||||
@ -7,6 +7,7 @@ import {
|
||||
import { EventEmitter2 } from '@nestjs/event-emitter';
|
||||
|
||||
import isEmpty from 'lodash.isempty';
|
||||
import { DataSource } from 'typeorm';
|
||||
|
||||
import { IConnection } from 'src/engine/api/graphql/workspace-query-runner/interfaces/connection.interface';
|
||||
import {
|
||||
@ -620,15 +621,12 @@ export class WorkspaceQueryRunnerService {
|
||||
return sanitizedRecord;
|
||||
}
|
||||
|
||||
async execute(
|
||||
query: string,
|
||||
async executeSQL(
|
||||
workspaceDataSource: DataSource,
|
||||
workspaceId: string,
|
||||
): Promise<PGGraphQLResult | undefined> {
|
||||
const workspaceDataSource =
|
||||
await this.workspaceDataSourceService.connectToWorkspaceDataSource(
|
||||
workspaceId,
|
||||
);
|
||||
|
||||
sqlQuery: string,
|
||||
parameters?: any[],
|
||||
) {
|
||||
try {
|
||||
return await workspaceDataSource?.transaction(
|
||||
async (transactionManager) => {
|
||||
@ -638,10 +636,7 @@ export class WorkspaceQueryRunnerService {
|
||||
)};
|
||||
`);
|
||||
|
||||
const results = transactionManager.query<PGGraphQLResult>(
|
||||
`SELECT graphql.resolve($1);`,
|
||||
[query],
|
||||
);
|
||||
const results = transactionManager.query(sqlQuery, parameters);
|
||||
|
||||
return results;
|
||||
},
|
||||
@ -655,6 +650,23 @@ export class WorkspaceQueryRunnerService {
|
||||
}
|
||||
}
|
||||
|
||||
async execute(
|
||||
query: string,
|
||||
workspaceId: string,
|
||||
): Promise<PGGraphQLResult | undefined> {
|
||||
const workspaceDataSource =
|
||||
await this.workspaceDataSourceService.connectToWorkspaceDataSource(
|
||||
workspaceId,
|
||||
);
|
||||
|
||||
return this.executeSQL(
|
||||
workspaceDataSource,
|
||||
workspaceId,
|
||||
`SELECT graphql.resolve($1);`,
|
||||
[query],
|
||||
);
|
||||
}
|
||||
|
||||
private async parseResult<Result>(
|
||||
graphqlResult: PGGraphQLResult | undefined,
|
||||
objectMetadataItem: ObjectMetadataInterface,
|
||||
|
||||
@ -0,0 +1,30 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
import { TypeOrmModule } from '@nestjs/typeorm';
|
||||
|
||||
import { WorkspaceDataSourceModule } from 'src/engine/workspace-datasource/workspace-datasource.module';
|
||||
import { UserModule } from 'src/engine/core-modules/user/user.module';
|
||||
import { AISQLQueryResolver } from 'src/engine/core-modules/ai-sql-query/ai-sql-query.resolver';
|
||||
import { AISQLQueryService } from 'src/engine/core-modules/ai-sql-query/ai-sql-query.service';
|
||||
import { FeatureFlagEntity } from 'src/engine/core-modules/feature-flag/feature-flag.entity';
|
||||
import { WorkspaceQueryRunnerModule } from 'src/engine/api/graphql/workspace-query-runner/workspace-query-runner.module';
|
||||
import { LLMChatModelModule } from 'src/engine/integrations/llm-chat-model/llm-chat-model.module';
|
||||
import { EnvironmentModule } from 'src/engine/integrations/environment/environment.module';
|
||||
import { LLMTracingModule } from 'src/engine/integrations/llm-tracing/llm-tracing.module';
|
||||
import { ObjectMetadataModule } from 'src/engine/metadata-modules/object-metadata/object-metadata.module';
|
||||
import { WorkspaceSyncMetadataModule } from 'src/engine/workspace-manager/workspace-sync-metadata/workspace-sync-metadata.module';
|
||||
@Module({
|
||||
imports: [
|
||||
WorkspaceDataSourceModule,
|
||||
WorkspaceQueryRunnerModule,
|
||||
UserModule,
|
||||
TypeOrmModule.forFeature([FeatureFlagEntity], 'core'),
|
||||
LLMChatModelModule,
|
||||
LLMTracingModule,
|
||||
EnvironmentModule,
|
||||
ObjectMetadataModule,
|
||||
WorkspaceSyncMetadataModule,
|
||||
],
|
||||
exports: [],
|
||||
providers: [AISQLQueryResolver, AISQLQueryService],
|
||||
})
|
||||
export class AISQLQueryModule {}
|
||||
@ -0,0 +1,14 @@
|
||||
import { PromptTemplate } from '@langchain/core/prompts';
|
||||
|
||||
export const sqlGenerationPromptTemplate = PromptTemplate.fromTemplate<{
|
||||
llmOutputJsonSchema: string;
|
||||
sqlCreateTableStatements: string;
|
||||
userQuestion: string;
|
||||
}>(`Always respond following this JSON Schema: {llmOutputJsonSchema}
|
||||
|
||||
Based on the table schema below, write a PostgreSQL query that would answer the user's question. All column names must be enclosed in double quotes.
|
||||
|
||||
{sqlCreateTableStatements}
|
||||
|
||||
Question: {userQuestion}
|
||||
SQL Query:`);
|
||||
@ -0,0 +1,64 @@
|
||||
import { Args, Query, Resolver, ArgsType, Field } from '@nestjs/graphql';
|
||||
import { ForbiddenException, UseGuards } from '@nestjs/common';
|
||||
import { InjectRepository } from '@nestjs/typeorm';
|
||||
|
||||
import { Repository } from 'typeorm';
|
||||
|
||||
import { User } from 'src/engine/core-modules/user/user.entity';
|
||||
import { JwtAuthGuard } from 'src/engine/guards/jwt.auth.guard';
|
||||
import { Workspace } from 'src/engine/core-modules/workspace/workspace.entity';
|
||||
import { AuthWorkspace } from 'src/engine/decorators/auth/auth-workspace.decorator';
|
||||
import {
|
||||
FeatureFlagEntity,
|
||||
FeatureFlagKeys,
|
||||
} from 'src/engine/core-modules/feature-flag/feature-flag.entity';
|
||||
import { AuthUser } from 'src/engine/decorators/auth/auth-user.decorator';
|
||||
import { AISQLQueryResult } from 'src/engine/core-modules/ai-sql-query/dtos/ai-sql-query-result.dto';
|
||||
import { AISQLQueryService } from 'src/engine/core-modules/ai-sql-query/ai-sql-query.service';
|
||||
|
||||
@ArgsType()
|
||||
class GetAISQLQueryArgs {
|
||||
@Field(() => String)
|
||||
text: string;
|
||||
}
|
||||
|
||||
@UseGuards(JwtAuthGuard)
|
||||
@Resolver(() => AISQLQueryResult)
|
||||
export class AISQLQueryResolver {
|
||||
constructor(
|
||||
private readonly aiSqlQueryService: AISQLQueryService,
|
||||
@InjectRepository(FeatureFlagEntity, 'core')
|
||||
private readonly featureFlagRepository: Repository<FeatureFlagEntity>,
|
||||
) {}
|
||||
|
||||
@Query(() => AISQLQueryResult)
|
||||
async getAISQLQuery(
|
||||
@AuthWorkspace() { id: workspaceId }: Workspace,
|
||||
@AuthUser() user: User,
|
||||
@Args() { text }: GetAISQLQueryArgs,
|
||||
) {
|
||||
const isCopilotEnabledFeatureFlag =
|
||||
await this.featureFlagRepository.findOneBy({
|
||||
workspaceId,
|
||||
key: FeatureFlagKeys.IsCopilotEnabled,
|
||||
value: true,
|
||||
});
|
||||
|
||||
if (!isCopilotEnabledFeatureFlag?.value) {
|
||||
throw new ForbiddenException(
|
||||
`${FeatureFlagKeys.IsCopilotEnabled} feature flag is disabled`,
|
||||
);
|
||||
}
|
||||
|
||||
const traceMetadata = {
|
||||
userId: user.id,
|
||||
userEmail: user.email,
|
||||
};
|
||||
|
||||
return this.aiSqlQueryService.generateAndExecute(
|
||||
workspaceId,
|
||||
text,
|
||||
traceMetadata,
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,253 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
|
||||
import { RunnableSequence } from '@langchain/core/runnables';
|
||||
import { StructuredOutputParser } from '@langchain/core/output_parsers';
|
||||
import { DataSource, QueryFailedError } from 'typeorm';
|
||||
import { z } from 'zod';
|
||||
import { zodToJsonSchema } from 'zod-to-json-schema';
|
||||
import { PostgresConnectionOptions } from 'typeorm/driver/postgres/PostgresConnectionOptions';
|
||||
import groupBy from 'lodash.groupby';
|
||||
|
||||
import { PartialFieldMetadata } from 'src/engine/workspace-manager/workspace-sync-metadata/interfaces/partial-field-metadata.interface';
|
||||
|
||||
import { WorkspaceDataSourceService } from 'src/engine/workspace-datasource/workspace-datasource.service';
|
||||
import { WorkspaceQueryRunnerService } from 'src/engine/api/graphql/workspace-query-runner/workspace-query-runner.service';
|
||||
import { LLMChatModelService } from 'src/engine/integrations/llm-chat-model/llm-chat-model.service';
|
||||
import { LLMTracingService } from 'src/engine/integrations/llm-tracing/llm-tracing.service';
|
||||
import { ObjectMetadataEntity } from 'src/engine/metadata-modules/object-metadata/object-metadata.entity';
|
||||
import { DEFAULT_LABEL_IDENTIFIER_FIELD_NAME } from 'src/engine/metadata-modules/object-metadata/object-metadata.constants';
|
||||
import { StandardObjectFactory } from 'src/engine/workspace-manager/workspace-sync-metadata/factories/standard-object.factory';
|
||||
import { standardObjectMetadataDefinitions } from 'src/engine/workspace-manager/workspace-sync-metadata/standard-objects';
|
||||
import { AISQLQueryResult } from 'src/engine/core-modules/ai-sql-query/dtos/ai-sql-query-result.dto';
|
||||
import { sqlGenerationPromptTemplate } from 'src/engine/core-modules/ai-sql-query/ai-sql-query.prompt-templates';
|
||||
|
||||
@Injectable()
|
||||
export class AISQLQueryService {
|
||||
private readonly logger = new Logger(AISQLQueryService.name);
|
||||
constructor(
|
||||
private readonly workspaceDataSourceService: WorkspaceDataSourceService,
|
||||
private readonly workspaceQueryRunnerService: WorkspaceQueryRunnerService,
|
||||
private readonly llmChatModelService: LLMChatModelService,
|
||||
private readonly llmTracingService: LLMTracingService,
|
||||
private readonly standardObjectFactory: StandardObjectFactory,
|
||||
) {}
|
||||
|
||||
private getLabelIdentifierName(
|
||||
objectMetadata: ObjectMetadataEntity,
|
||||
dataSourceId,
|
||||
workspaceId,
|
||||
workspaceFeatureFlagsMap,
|
||||
): string | undefined {
|
||||
const customObjectLabelIdentifierFieldMetadata = objectMetadata.fields.find(
|
||||
(fieldMetadata) =>
|
||||
fieldMetadata.id === objectMetadata.labelIdentifierFieldMetadataId,
|
||||
);
|
||||
|
||||
const standardObjectMetadataCollection = this.standardObjectFactory.create(
|
||||
standardObjectMetadataDefinitions,
|
||||
{ workspaceId, dataSourceId },
|
||||
workspaceFeatureFlagsMap,
|
||||
);
|
||||
|
||||
const standardObjectLabelIdentifierFieldMetadata =
|
||||
standardObjectMetadataCollection
|
||||
.find(
|
||||
(standardObjectMetadata) =>
|
||||
standardObjectMetadata.nameSingular === objectMetadata.nameSingular,
|
||||
)
|
||||
?.fields.find(
|
||||
(field: PartialFieldMetadata) =>
|
||||
field.name === DEFAULT_LABEL_IDENTIFIER_FIELD_NAME,
|
||||
) as PartialFieldMetadata;
|
||||
|
||||
const labelIdentifierFieldMetadata =
|
||||
customObjectLabelIdentifierFieldMetadata ??
|
||||
standardObjectLabelIdentifierFieldMetadata;
|
||||
|
||||
return (
|
||||
labelIdentifierFieldMetadata?.name ?? DEFAULT_LABEL_IDENTIFIER_FIELD_NAME
|
||||
);
|
||||
}
|
||||
|
||||
private async getColInfosByTableName(dataSource: DataSource) {
|
||||
const { schema } = dataSource.options as PostgresConnectionOptions;
|
||||
|
||||
// From LangChain sql_utils.ts
|
||||
const sqlQuery = `SELECT
|
||||
t.table_name,
|
||||
c.*
|
||||
FROM
|
||||
information_schema.tables t
|
||||
JOIN information_schema.columns c
|
||||
ON t.table_name = c.table_name
|
||||
WHERE
|
||||
t.table_schema = '${schema}'
|
||||
AND c.table_schema = '${schema}'
|
||||
ORDER BY
|
||||
t.table_name,
|
||||
c.ordinal_position;`;
|
||||
const colInfos = await dataSource.query<
|
||||
{
|
||||
table_name: string;
|
||||
column_name: string;
|
||||
data_type: string | undefined;
|
||||
is_nullable: 'YES' | 'NO';
|
||||
}[]
|
||||
>(sqlQuery);
|
||||
|
||||
return groupBy(colInfos, (colInfo) => colInfo.table_name);
|
||||
}
|
||||
|
||||
private getCreateTableStatement(tableName: string, colInfos: any[]) {
|
||||
return `${`CREATE TABLE ${tableName} (\n`} ${colInfos
|
||||
.map(
|
||||
(colInfo) =>
|
||||
`${colInfo.column_name} ${colInfo.data_type} ${
|
||||
colInfo.is_nullable === 'YES' ? '' : 'NOT NULL'
|
||||
}`,
|
||||
)
|
||||
.join(', ')});`;
|
||||
}
|
||||
|
||||
private getRelationDescriptions() {
|
||||
// TODO - Construct sentences like the following:
|
||||
// investorId: a foreign key referencing the person table, indicating the investor who owns this portfolio company.
|
||||
return '';
|
||||
}
|
||||
|
||||
private getTableDescription(tableName: string, colInfos: any[]) {
|
||||
return [
|
||||
this.getCreateTableStatement(tableName, colInfos),
|
||||
this.getRelationDescriptions(),
|
||||
].join('\n');
|
||||
}
|
||||
|
||||
private async getWorkspaceSchemaDescription(
|
||||
dataSource: DataSource,
|
||||
): Promise<string> {
|
||||
const colInfoByTableName = await this.getColInfosByTableName(dataSource);
|
||||
|
||||
return Object.entries(colInfoByTableName)
|
||||
.map(([tableName, colInfos]) =>
|
||||
this.getTableDescription(tableName, colInfos),
|
||||
)
|
||||
.join('\n\n');
|
||||
}
|
||||
|
||||
private async generateWithDataSource(
|
||||
workspaceId: string,
|
||||
workspaceDataSource: DataSource,
|
||||
userQuestion: string,
|
||||
traceMetadata: Record<string, string> = {},
|
||||
) {
|
||||
const workspaceSchemaName =
|
||||
this.workspaceDataSourceService.getSchemaName(workspaceId);
|
||||
|
||||
workspaceDataSource.setOptions({
|
||||
schema: workspaceSchemaName,
|
||||
});
|
||||
|
||||
const workspaceSchemaDescription =
|
||||
await this.getWorkspaceSchemaDescription(workspaceDataSource);
|
||||
|
||||
const llmOutputSchema = z.object({
|
||||
sqlQuery: z.string(),
|
||||
});
|
||||
|
||||
const llmOutputJsonSchema = JSON.stringify(
|
||||
zodToJsonSchema(llmOutputSchema),
|
||||
);
|
||||
|
||||
const structuredOutputParser =
|
||||
StructuredOutputParser.fromZodSchema(llmOutputSchema);
|
||||
|
||||
const sqlQueryGeneratorChain = RunnableSequence.from([
|
||||
sqlGenerationPromptTemplate,
|
||||
this.llmChatModelService.getJSONChatModel(),
|
||||
structuredOutputParser,
|
||||
]);
|
||||
|
||||
const metadata = {
|
||||
workspaceId,
|
||||
...traceMetadata,
|
||||
};
|
||||
const tracingCallbackHandler =
|
||||
this.llmTracingService.getCallbackHandler(metadata);
|
||||
|
||||
const { sqlQuery } = await sqlQueryGeneratorChain.invoke(
|
||||
{
|
||||
llmOutputJsonSchema,
|
||||
sqlCreateTableStatements: workspaceSchemaDescription,
|
||||
userQuestion,
|
||||
},
|
||||
{
|
||||
callbacks: [tracingCallbackHandler],
|
||||
},
|
||||
);
|
||||
|
||||
return sqlQuery;
|
||||
}
|
||||
|
||||
async generate(
|
||||
workspaceId: string,
|
||||
userQuestion: string,
|
||||
traceMetadata: Record<string, string> = {},
|
||||
) {
|
||||
const workspaceDataSource =
|
||||
await this.workspaceDataSourceService.connectToWorkspaceDataSource(
|
||||
workspaceId,
|
||||
);
|
||||
|
||||
return this.generateWithDataSource(
|
||||
workspaceId,
|
||||
workspaceDataSource,
|
||||
userQuestion,
|
||||
traceMetadata,
|
||||
);
|
||||
}
|
||||
|
||||
async generateAndExecute(
|
||||
workspaceId: string,
|
||||
userQuestion: string,
|
||||
traceMetadata: Record<string, string> = {},
|
||||
): Promise<AISQLQueryResult> {
|
||||
const workspaceDataSource =
|
||||
await this.workspaceDataSourceService.connectToWorkspaceDataSource(
|
||||
workspaceId,
|
||||
);
|
||||
|
||||
const sqlQuery = await this.generateWithDataSource(
|
||||
workspaceId,
|
||||
workspaceDataSource,
|
||||
userQuestion,
|
||||
traceMetadata,
|
||||
);
|
||||
|
||||
try {
|
||||
const sqlQueryResult: Record<string, any>[] =
|
||||
await this.workspaceQueryRunnerService.executeSQL(
|
||||
workspaceDataSource,
|
||||
workspaceId,
|
||||
sqlQuery,
|
||||
);
|
||||
|
||||
return {
|
||||
sqlQuery,
|
||||
sqlQueryResult: JSON.stringify(sqlQueryResult),
|
||||
};
|
||||
} catch (error) {
|
||||
if (error instanceof QueryFailedError) {
|
||||
return {
|
||||
sqlQuery,
|
||||
queryFailedErrorMessage: error.message,
|
||||
};
|
||||
}
|
||||
|
||||
this.logger.error(error.message, error.stack);
|
||||
|
||||
return {
|
||||
sqlQuery,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,17 @@
|
||||
import { Field, ObjectType } from '@nestjs/graphql';
|
||||
|
||||
import { IsOptional } from 'class-validator';
|
||||
|
||||
@ObjectType('AISQLQueryResult')
|
||||
export class AISQLQueryResult {
|
||||
@Field(() => String)
|
||||
sqlQuery: string;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
@IsOptional()
|
||||
sqlQueryResult?: string;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
@IsOptional()
|
||||
queryFailedErrorMessage?: string;
|
||||
}
|
||||
@ -10,6 +10,7 @@ import { TimelineMessagingModule } from 'src/engine/core-modules/messaging/timel
|
||||
import { TimelineCalendarEventModule } from 'src/engine/core-modules/calendar/timeline-calendar-event.module';
|
||||
import { BillingModule } from 'src/engine/core-modules/billing/billing.module';
|
||||
import { HealthModule } from 'src/engine/core-modules/health/health.module';
|
||||
import { AISQLQueryModule } from 'src/engine/core-modules/ai-sql-query/ai-sql-query.module';
|
||||
import { PostgresCredentialsModule } from 'src/engine/core-modules/postgres-credentials/postgres-credentials.module';
|
||||
|
||||
import { AnalyticsModule } from './analytics/analytics.module';
|
||||
@ -31,6 +32,7 @@ import { ClientConfigModule } from './client-config/client-config.module';
|
||||
TimelineCalendarEventModule,
|
||||
UserModule,
|
||||
WorkspaceModule,
|
||||
AISQLQueryModule,
|
||||
PostgresCredentialsModule,
|
||||
],
|
||||
exports: [
|
||||
|
||||
@ -22,6 +22,7 @@ export enum FeatureFlagKeys {
|
||||
IsPostgreSQLIntegrationEnabled = 'IS_POSTGRESQL_INTEGRATION_ENABLED',
|
||||
IsStripeIntegrationEnabled = 'IS_STRIPE_INTEGRATION_ENABLED',
|
||||
IsContactCreationForSentAndReceivedEmailsEnabled = 'IS_CONTACT_CREATION_FOR_SENT_AND_RECEIVED_EMAILS_ENABLED',
|
||||
IsCopilotEnabled = 'IS_COPILOT_ENABLED',
|
||||
IsMessagingAliasFetchingEnabled = 'IS_MESSAGING_ALIAS_FETCHING_ENABLED',
|
||||
IsGoogleCalendarSyncV2Enabled = 'IS_GOOGLE_CALENDAR_SYNC_V2_ENABLED',
|
||||
IsFreeAccessEnabled = 'IS_FREE_ACCESS_ENABLED',
|
||||
|
||||
@ -17,6 +17,8 @@ import {
|
||||
|
||||
import { EmailDriver } from 'src/engine/integrations/email/interfaces/email.interface';
|
||||
import { NodeEnvironment } from 'src/engine/integrations/environment/interfaces/node-environment.interface';
|
||||
import { LLMChatModelDriver } from 'src/engine/integrations/llm-chat-model/interfaces/llm-chat-model.interface';
|
||||
import { LLMTracingDriver } from 'src/engine/integrations/llm-tracing/interfaces/llm-tracing.interface';
|
||||
|
||||
import { assert } from 'src/utils/assert';
|
||||
import { CastToStringArray } from 'src/engine/integrations/environment/decorators/cast-to-string-array.decorator';
|
||||
@ -369,6 +371,16 @@ export class EnvironmentVariables {
|
||||
|
||||
OPENROUTER_API_KEY: string;
|
||||
|
||||
LLM_CHAT_MODEL_DRIVER: LLMChatModelDriver = LLMChatModelDriver.OpenAI;
|
||||
|
||||
OPENAI_API_KEY: string;
|
||||
|
||||
LANGFUSE_SECRET_KEY: string;
|
||||
|
||||
LANGFUSE_PUBLIC_KEY: string;
|
||||
|
||||
LLM_TRACING_DRIVER: LLMTracingDriver = LLMTracingDriver.Console;
|
||||
|
||||
@CastToPositiveNumber()
|
||||
API_RATE_LIMITING_TTL = 100;
|
||||
|
||||
|
||||
@ -12,6 +12,10 @@ import { emailModuleFactory } from 'src/engine/integrations/email/email.module-f
|
||||
import { CacheStorageModule } from 'src/engine/integrations/cache-storage/cache-storage.module';
|
||||
import { CaptchaModule } from 'src/engine/integrations/captcha/captcha.module';
|
||||
import { captchaModuleFactory } from 'src/engine/integrations/captcha/captcha.module-factory';
|
||||
import { LLMChatModelModule } from 'src/engine/integrations/llm-chat-model/llm-chat-model.module';
|
||||
import { llmChatModelModuleFactory } from 'src/engine/integrations/llm-chat-model/llm-chat-model.module-factory';
|
||||
import { LLMTracingModule } from 'src/engine/integrations/llm-tracing/llm-tracing.module';
|
||||
import { llmTracingModuleFactory } from 'src/engine/integrations/llm-tracing/llm-tracing.module-factory';
|
||||
|
||||
import { EnvironmentModule } from './environment/environment.module';
|
||||
import { EnvironmentService } from './environment/environment.service';
|
||||
@ -50,6 +54,14 @@ import { MessageQueueModule } from './message-queue/message-queue.module';
|
||||
wildcard: true,
|
||||
}),
|
||||
CacheStorageModule,
|
||||
LLMChatModelModule.forRoot({
|
||||
useFactory: llmChatModelModuleFactory,
|
||||
inject: [EnvironmentService],
|
||||
}),
|
||||
LLMTracingModule.forRoot({
|
||||
useFactory: llmTracingModuleFactory,
|
||||
inject: [EnvironmentService],
|
||||
}),
|
||||
],
|
||||
exports: [],
|
||||
providers: [],
|
||||
|
||||
@ -0,0 +1,5 @@
|
||||
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
|
||||
|
||||
export interface LLMChatModelDriver {
|
||||
getJSONChatModel(): BaseChatModel;
|
||||
}
|
||||
@ -0,0 +1,22 @@
|
||||
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
|
||||
import { ChatOpenAI } from '@langchain/openai';
|
||||
|
||||
import { LLMChatModelDriver } from 'src/engine/integrations/llm-chat-model/drivers/interfaces/llm-prompt-template-driver.interface';
|
||||
|
||||
export class OpenAIDriver implements LLMChatModelDriver {
|
||||
private chatModel: BaseChatModel;
|
||||
|
||||
constructor() {
|
||||
this.chatModel = new ChatOpenAI({
|
||||
model: 'gpt-4o',
|
||||
}).bind({
|
||||
response_format: {
|
||||
type: 'json_object',
|
||||
},
|
||||
}) as unknown as BaseChatModel;
|
||||
}
|
||||
|
||||
getJSONChatModel() {
|
||||
return this.chatModel;
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,14 @@
|
||||
import { ModuleMetadata, FactoryProvider } from '@nestjs/common';
|
||||
|
||||
export enum LLMChatModelDriver {
|
||||
OpenAI = 'openai',
|
||||
}
|
||||
|
||||
export interface LLMChatModelModuleOptions {
|
||||
type: LLMChatModelDriver;
|
||||
}
|
||||
|
||||
export type LLMChatModelModuleAsyncOptions = {
|
||||
useFactory: (...args: any[]) => LLMChatModelModuleOptions;
|
||||
} & Pick<ModuleMetadata, 'imports'> &
|
||||
Pick<FactoryProvider, 'inject'>;
|
||||
@ -0,0 +1 @@
|
||||
export const LLM_CHAT_MODEL_DRIVER = Symbol('LLM_CHAT_MODEL_DRIVER');
|
||||
@ -0,0 +1,19 @@
|
||||
import { LLMChatModelDriver } from 'src/engine/integrations/llm-chat-model/interfaces/llm-chat-model.interface';
|
||||
|
||||
import { EnvironmentService } from 'src/engine/integrations/environment/environment.service';
|
||||
|
||||
export const llmChatModelModuleFactory = (
|
||||
environmentService: EnvironmentService,
|
||||
) => {
|
||||
const driver = environmentService.get('LLM_CHAT_MODEL_DRIVER');
|
||||
|
||||
switch (driver) {
|
||||
case LLMChatModelDriver.OpenAI: {
|
||||
return { type: LLMChatModelDriver.OpenAI };
|
||||
}
|
||||
default:
|
||||
throw new Error(
|
||||
`Invalid LLM chat model driver (${driver}), check your .env file`,
|
||||
);
|
||||
}
|
||||
};
|
||||
@ -0,0 +1,35 @@
|
||||
import { DynamicModule, Global } from '@nestjs/common';
|
||||
|
||||
import {
|
||||
LLMChatModelDriver,
|
||||
LLMChatModelModuleAsyncOptions,
|
||||
} from 'src/engine/integrations/llm-chat-model/interfaces/llm-chat-model.interface';
|
||||
|
||||
import { LLM_CHAT_MODEL_DRIVER } from 'src/engine/integrations/llm-chat-model/llm-chat-model.constants';
|
||||
import { OpenAIDriver } from 'src/engine/integrations/llm-chat-model/drivers/openai.driver';
|
||||
import { LLMChatModelService } from 'src/engine/integrations/llm-chat-model/llm-chat-model.service';
|
||||
|
||||
@Global()
|
||||
export class LLMChatModelModule {
|
||||
static forRoot(options: LLMChatModelModuleAsyncOptions): DynamicModule {
|
||||
const provider = {
|
||||
provide: LLM_CHAT_MODEL_DRIVER,
|
||||
useFactory: (...args: any[]) => {
|
||||
const config = options.useFactory(...args);
|
||||
|
||||
switch (config.type) {
|
||||
case LLMChatModelDriver.OpenAI: {
|
||||
return new OpenAIDriver();
|
||||
}
|
||||
}
|
||||
},
|
||||
inject: options.inject || [],
|
||||
};
|
||||
|
||||
return {
|
||||
module: LLMChatModelModule,
|
||||
providers: [LLMChatModelService, provider],
|
||||
exports: [LLMChatModelService],
|
||||
};
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,16 @@
|
||||
import { Injectable, Inject } from '@nestjs/common';
|
||||
|
||||
import { LLMChatModelDriver } from 'src/engine/integrations/llm-chat-model/drivers/interfaces/llm-prompt-template-driver.interface';
|
||||
|
||||
import { LLM_CHAT_MODEL_DRIVER } from 'src/engine/integrations/llm-chat-model/llm-chat-model.constants';
|
||||
|
||||
@Injectable()
|
||||
export class LLMChatModelService {
|
||||
constructor(
|
||||
@Inject(LLM_CHAT_MODEL_DRIVER) private driver: LLMChatModelDriver,
|
||||
) {}
|
||||
|
||||
getJSONChatModel() {
|
||||
return this.driver.getJSONChatModel();
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,25 @@
|
||||
import { BaseCallbackHandler } from '@langchain/core/callbacks/base';
|
||||
import { ConsoleCallbackHandler } from '@langchain/core/tracers/console';
|
||||
import { Run } from '@langchain/core/tracers/base';
|
||||
|
||||
import { LLMTracingDriver } from 'src/engine/integrations/llm-tracing/drivers/interfaces/llm-tracing-driver.interface';
|
||||
|
||||
class WithMetadataConsoleCallbackHandler extends ConsoleCallbackHandler {
|
||||
private metadata: Record<string, unknown>;
|
||||
|
||||
constructor(metadata: Record<string, unknown>) {
|
||||
super();
|
||||
this.metadata = metadata;
|
||||
}
|
||||
|
||||
onChainStart(run: Run) {
|
||||
console.log(`Chain metadata: ${JSON.stringify(this.metadata)}`);
|
||||
super.onChainStart(run);
|
||||
}
|
||||
}
|
||||
|
||||
export class ConsoleDriver implements LLMTracingDriver {
|
||||
getCallbackHandler(metadata: Record<string, unknown>): BaseCallbackHandler {
|
||||
return new WithMetadataConsoleCallbackHandler(metadata);
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,5 @@
|
||||
import { BaseCallbackHandler } from '@langchain/core/callbacks/base';
|
||||
|
||||
export interface LLMTracingDriver {
|
||||
getCallbackHandler(metadata: Record<string, unknown>): BaseCallbackHandler;
|
||||
}
|
||||
@ -0,0 +1,26 @@
|
||||
import { BaseCallbackHandler } from '@langchain/core/callbacks/base';
|
||||
import CallbackHandler from 'langfuse-langchain';
|
||||
|
||||
import { LLMTracingDriver } from 'src/engine/integrations/llm-tracing/drivers/interfaces/llm-tracing-driver.interface';
|
||||
|
||||
export interface LangfuseDriverOptions {
|
||||
secretKey: string;
|
||||
publicKey: string;
|
||||
}
|
||||
|
||||
export class LangfuseDriver implements LLMTracingDriver {
|
||||
private options: LangfuseDriverOptions;
|
||||
|
||||
constructor(options: LangfuseDriverOptions) {
|
||||
this.options = options;
|
||||
}
|
||||
|
||||
getCallbackHandler(metadata: Record<string, unknown>): BaseCallbackHandler {
|
||||
return new CallbackHandler({
|
||||
secretKey: this.options.secretKey,
|
||||
publicKey: this.options.publicKey,
|
||||
baseUrl: 'https://cloud.langfuse.com',
|
||||
metadata: metadata,
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,26 @@
|
||||
import { ModuleMetadata, FactoryProvider } from '@nestjs/common';
|
||||
|
||||
import { LangfuseDriverOptions } from 'src/engine/integrations/llm-tracing/drivers/langfuse.driver';
|
||||
|
||||
export enum LLMTracingDriver {
|
||||
Langfuse = 'langfuse',
|
||||
Console = 'console',
|
||||
}
|
||||
|
||||
export interface LangfuseDriverFactoryOptions {
|
||||
type: LLMTracingDriver.Langfuse;
|
||||
options: LangfuseDriverOptions;
|
||||
}
|
||||
|
||||
export interface ConsoleDriverFactoryOptions {
|
||||
type: LLMTracingDriver.Console;
|
||||
}
|
||||
|
||||
export type LLMTracingModuleOptions =
|
||||
| LangfuseDriverFactoryOptions
|
||||
| ConsoleDriverFactoryOptions;
|
||||
|
||||
export type LLMTracingModuleAsyncOptions = {
|
||||
useFactory: (...args: any[]) => LLMTracingModuleOptions;
|
||||
} & Pick<ModuleMetadata, 'imports'> &
|
||||
Pick<FactoryProvider, 'inject'>;
|
||||
@ -0,0 +1 @@
|
||||
export const LLM_TRACING_DRIVER = Symbol('LLM_TRACING_DRIVER');
|
||||
@ -0,0 +1,34 @@
|
||||
import { LLMTracingDriver } from 'src/engine/integrations/llm-tracing/interfaces/llm-tracing.interface';
|
||||
|
||||
import { EnvironmentService } from 'src/engine/integrations/environment/environment.service';
|
||||
|
||||
export const llmTracingModuleFactory = (
|
||||
environmentService: EnvironmentService,
|
||||
) => {
|
||||
const driver = environmentService.get('LLM_TRACING_DRIVER');
|
||||
|
||||
switch (driver) {
|
||||
case LLMTracingDriver.Console: {
|
||||
return { type: LLMTracingDriver.Console as const };
|
||||
}
|
||||
case LLMTracingDriver.Langfuse: {
|
||||
const secretKey = environmentService.get('LANGFUSE_SECRET_KEY');
|
||||
const publicKey = environmentService.get('LANGFUSE_PUBLIC_KEY');
|
||||
|
||||
if (!(secretKey && publicKey)) {
|
||||
throw new Error(
|
||||
`${driver} LLM tracing driver requires LANGFUSE_SECRET_KEY and LANGFUSE_PUBLIC_KEY to be defined, check your .env file`,
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
type: LLMTracingDriver.Langfuse as const,
|
||||
options: { secretKey, publicKey },
|
||||
};
|
||||
}
|
||||
default:
|
||||
throw new Error(
|
||||
`Invalid LLM tracing driver (${driver}), check your .env file`,
|
||||
);
|
||||
}
|
||||
};
|
||||
@ -0,0 +1,39 @@
|
||||
import { Global, DynamicModule } from '@nestjs/common';
|
||||
|
||||
import {
|
||||
LLMTracingModuleAsyncOptions,
|
||||
LLMTracingDriver,
|
||||
} from 'src/engine/integrations/llm-tracing/interfaces/llm-tracing.interface';
|
||||
|
||||
import { LangfuseDriver } from 'src/engine/integrations/llm-tracing/drivers/langfuse.driver';
|
||||
import { ConsoleDriver } from 'src/engine/integrations/llm-tracing/drivers/console.driver';
|
||||
import { LLMTracingService } from 'src/engine/integrations/llm-tracing/llm-tracing.service';
|
||||
import { LLM_TRACING_DRIVER } from 'src/engine/integrations/llm-tracing/llm-tracing.constants';
|
||||
|
||||
@Global()
|
||||
export class LLMTracingModule {
|
||||
static forRoot(options: LLMTracingModuleAsyncOptions): DynamicModule {
|
||||
const provider = {
|
||||
provide: LLM_TRACING_DRIVER,
|
||||
useFactory: (...args: any[]) => {
|
||||
const config = options.useFactory(...args);
|
||||
|
||||
switch (config.type) {
|
||||
case LLMTracingDriver.Langfuse: {
|
||||
return new LangfuseDriver(config.options);
|
||||
}
|
||||
case LLMTracingDriver.Console: {
|
||||
return new ConsoleDriver();
|
||||
}
|
||||
}
|
||||
},
|
||||
inject: options.inject || [],
|
||||
};
|
||||
|
||||
return {
|
||||
module: LLMTracingModule,
|
||||
providers: [LLMTracingService, provider],
|
||||
exports: [LLMTracingService],
|
||||
};
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,16 @@
|
||||
import { Injectable, Inject } from '@nestjs/common';
|
||||
|
||||
import { BaseCallbackHandler } from '@langchain/core/callbacks/base';
|
||||
|
||||
import { LLMTracingDriver } from 'src/engine/integrations/llm-tracing/drivers/interfaces/llm-tracing-driver.interface';
|
||||
|
||||
import { LLM_TRACING_DRIVER } from 'src/engine/integrations/llm-tracing/llm-tracing.constants';
|
||||
|
||||
@Injectable()
|
||||
export class LLMTracingService {
|
||||
constructor(@Inject(LLM_TRACING_DRIVER) private driver: LLMTracingDriver) {}
|
||||
|
||||
getCallbackHandler(metadata: Record<string, unknown>): BaseCallbackHandler {
|
||||
return this.driver.getCallbackHandler(metadata);
|
||||
}
|
||||
}
|
||||
@ -0,0 +1 @@
|
||||
export const DEFAULT_LABEL_IDENTIFIER_FIELD_NAME = 'name';
|
||||
@ -59,6 +59,7 @@ export class AddStandardIdCommand extends CommandRunner {
|
||||
IS_POSTGRESQL_INTEGRATION_ENABLED: true,
|
||||
IS_STRIPE_INTEGRATION_ENABLED: false,
|
||||
IS_CONTACT_CREATION_FOR_SENT_AND_RECEIVED_EMAILS_ENABLED: true,
|
||||
IS_COPILOT_ENABLED: false,
|
||||
IS_MESSAGING_ALIAS_FETCHING_ENABLED: true,
|
||||
IS_GOOGLE_CALENDAR_SYNC_V2_ENABLED: true,
|
||||
IS_FREE_ACCESS_ENABLED: false,
|
||||
@ -77,6 +78,7 @@ export class AddStandardIdCommand extends CommandRunner {
|
||||
IS_POSTGRESQL_INTEGRATION_ENABLED: true,
|
||||
IS_STRIPE_INTEGRATION_ENABLED: false,
|
||||
IS_CONTACT_CREATION_FOR_SENT_AND_RECEIVED_EMAILS_ENABLED: true,
|
||||
IS_COPILOT_ENABLED: false,
|
||||
IS_MESSAGING_ALIAS_FETCHING_ENABLED: true,
|
||||
IS_GOOGLE_CALENDAR_SYNC_V2_ENABLED: true,
|
||||
IS_FREE_ACCESS_ENABLED: false,
|
||||
|
||||
@ -33,6 +33,7 @@ export {
|
||||
IconCalendarEvent,
|
||||
IconCalendarTime,
|
||||
IconCalendarX,
|
||||
IconChartCandle,
|
||||
IconCheck,
|
||||
IconCheckbox,
|
||||
IconChevronDown,
|
||||
@ -137,10 +138,13 @@ export {
|
||||
IconReload,
|
||||
IconRepeat,
|
||||
IconRocket,
|
||||
IconRotate,
|
||||
IconSearch,
|
||||
IconSend,
|
||||
IconSettings,
|
||||
IconSortDescending,
|
||||
IconSparkles,
|
||||
IconSql,
|
||||
IconSquareRoundedCheck,
|
||||
IconTable,
|
||||
IconTag,
|
||||
|
||||
@ -170,8 +170,13 @@ yarn command:prod cron:calendar:google-calendar-sync
|
||||
### Data enrichment and AI
|
||||
|
||||
<ArticleTable options={[
|
||||
['OPENROUTER_API_KEY', '', "The API key for openrouter.ai, an abstraction layer over models from Mistral, OpenAI and more"]
|
||||
]}></ArticleTable>
|
||||
['OPENROUTER_API_KEY', '', "The API key for openrouter.ai, an abstraction layer over models from Mistral, OpenAI and more"],
|
||||
['OPENAI_API_KEY', 'sk-proj-abcdabcd...', "OpenAI API key"],
|
||||
['LLM_CHAT_MODEL_DRIVER', 'openai', "LLM provider"],
|
||||
['LLM_TRACING_DRIVER', 'langfuse', "Where to output LangChain logs. 'langfuse' or 'console'."],
|
||||
['LANGFUSE_SECRET_KEY', 'sk-lf-abcdabcd-abcd...', "Langfuse secret key"],
|
||||
['LANGFUSE_PUBLIC_KEY', 'pk-lf-abcdabcd-abcd...', "Langfuse public key"],
|
||||
]}></ArticleTable>
|
||||
|
||||
|
||||
### Support Chat
|
||||
|
||||
Reference in New Issue
Block a user