diff --git a/packages/twenty-front/src/modules/apollo/services/apollo.factory.ts b/packages/twenty-front/src/modules/apollo/services/apollo.factory.ts index 2a76e445a..9266ef5ff 100644 --- a/packages/twenty-front/src/modules/apollo/services/apollo.factory.ts +++ b/packages/twenty-front/src/modules/apollo/services/apollo.factory.ts @@ -28,6 +28,7 @@ import { isUndefinedOrNull } from '~/utils/isUndefinedOrNull'; import { ApolloManager } from '../types/apolloManager.interface'; import { getTokenPair } from '../utils/getTokenPair'; import { loggerLink } from '../utils/loggerLink'; +import { StreamingRestLink } from '../utils/streamingRestLink'; const logger = loggerLink(() => 'Twenty'); @@ -42,6 +43,8 @@ export interface Options extends ApolloClientOptions { isDebugMode?: boolean; } +const REST_API_BASE_URL = `${REACT_APP_SERVER_BASE_URL}/rest`; + export class ApolloFactory implements ApolloManager { private client: ApolloClient; private currentWorkspaceMember: CurrentWorkspaceMember | null = null; @@ -69,8 +72,12 @@ export class ApolloFactory implements ApolloManager { uri, }); + const streamingRestLink = new StreamingRestLink({ + uri: REST_API_BASE_URL, + }); + const restLink = new RestLink({ - uri: `${REACT_APP_SERVER_BASE_URL}/rest`, + uri: REST_API_BASE_URL, }); const authLink = setContext(async (_, { headers }) => { @@ -228,6 +235,7 @@ export class ApolloFactory implements ApolloManager { ...(extraLinks || []), isDebugMode ? logger : null, retryLink, + streamingRestLink, restLink, httpLink, ].filter(isDefined), diff --git a/packages/twenty-front/src/modules/apollo/utils/__tests__/streamingRestLink.test.ts b/packages/twenty-front/src/modules/apollo/utils/__tests__/streamingRestLink.test.ts new file mode 100644 index 000000000..f509d16bd --- /dev/null +++ b/packages/twenty-front/src/modules/apollo/utils/__tests__/streamingRestLink.test.ts @@ -0,0 +1,231 @@ +import { gql } from '@apollo/client'; +import { Operation } from '@apollo/client/core'; +import { StreamingRestLink } from '../streamingRestLink'; + +global.fetch = jest.fn(); +describe('StreamingRestLink', () => { + let streamingLink: StreamingRestLink; + let mockForward: jest.MockedFunction<(operation: Operation) => any>; + + beforeEach(() => { + streamingLink = new StreamingRestLink({ + uri: 'https://api.example.com', + }); + mockForward = jest.fn(); + (global.fetch as jest.Mock).mockClear(); + }); + + describe('request', () => { + it('should forward operations without @stream directive', () => { + const operation = { + query: gql` + query Test { + test + } + `, + variables: {}, + getContext: () => ({}), + } as Operation; + + const result = streamingLink.request(operation, mockForward); + + expect(mockForward).toHaveBeenCalledWith(operation); + expect(result).toBe(mockForward(operation)); + }); + + it('should handle operations with @stream directive', async () => { + const operation = { + query: gql` + query StreamTest($threadId: String!) { + streamChatResponse(threadId: $threadId) + @stream( + path: "/agent-chat/stream/{args.threadId}" + method: "POST" + bodyKey: "requestBody" + ) + } + `, + variables: { threadId: '123', requestBody: { threadId: '123' } }, + getContext: () => ({ onChunk: jest.fn() }), + operationName: 'StreamTest', + extensions: {}, + setContext: jest.fn(), + } as Operation; + + const mockResponse = { + ok: true, + body: { + getReader: () => ({ + read: jest.fn().mockResolvedValue({ done: true }), + releaseLock: jest.fn(), + }), + }, + }; + (global.fetch as jest.Mock).mockResolvedValue(mockResponse); + + const observable = streamingLink.request(operation, mockForward); + const observer = { + next: jest.fn(), + error: jest.fn(), + complete: jest.fn(), + }; + + observable.subscribe(observer); + + expect(mockForward).not.toHaveBeenCalled(); + expect(global.fetch).toHaveBeenCalledWith( + 'https://api.example.com/agent-chat/stream/123', + expect.objectContaining({ + method: 'POST', + headers: expect.objectContaining({ + 'Content-Type': 'application/json', + Accept: 'text/event-stream', + }), + body: JSON.stringify({ threadId: '123' }), + }), + ); + }); + + it('should handle network errors', async () => { + const operation = { + query: gql` + query StreamTest { + test @stream(path: "/stream", method: "GET") + } + `, + variables: {}, + getContext: () => ({}), + } as Operation; + + (global.fetch as jest.Mock).mockRejectedValue(new Error('Network error')); + + const observable = streamingLink.request(operation, mockForward); + const observer = { + next: jest.fn(), + error: jest.fn(), + complete: jest.fn(), + }; + + observable.subscribe(observer); + + await new Promise((resolve) => setTimeout(resolve, 0)); + + expect(observer.error).toHaveBeenCalledWith(new Error('Network error')); + }); + + it('should handle non-ok responses', async () => { + const operation = { + query: gql` + query StreamTest { + test @stream(path: "/stream", method: "GET") + } + `, + variables: {}, + getContext: () => ({}), + } as Operation; + + const mockResponse = { ok: false, status: 404 }; + (global.fetch as jest.Mock).mockResolvedValue(mockResponse); + + const observable = streamingLink.request(operation, mockForward); + const observer = { + next: jest.fn(), + error: jest.fn(), + complete: jest.fn(), + }; + + observable.subscribe(observer); + + await new Promise((resolve) => setTimeout(resolve, 0)); + + expect(observer.error).toHaveBeenCalledWith( + new Error('HTTP error! status: 404'), + ); + }); + }); + + describe('extractStreamDirective', () => { + it('should extract directive arguments correctly', () => { + const operation = { + query: gql` + query Test($threadId: String!) { + streamChatResponse(threadId: $threadId) + @stream( + path: "/agent-chat/stream/{args.threadId}" + method: "POST" + bodyKey: "requestBody" + type: "StreamChatResponse" + ) + } + `, + variables: {}, + getContext: () => ({}), + } as Operation; + + const directive = (streamingLink as any).extractStreamDirective( + operation, + ); + + expect(directive).toEqual({ + path: '/agent-chat/stream/{args.threadId}', + method: 'POST', + bodyKey: 'requestBody', + type: 'StreamChatResponse', + }); + }); + + it('should return null for operations without @stream directive', () => { + const operation = { + query: gql` + query Test { + test + } + `, + variables: {}, + getContext: () => ({}), + } as Operation; + + const directive = (streamingLink as any).extractStreamDirective( + operation, + ); + + expect(directive).toBeNull(); + }); + }); + + describe('buildUrl', () => { + it('should build URL with variable substitution', () => { + const operation = { + variables: { threadId: '123' }, + query: gql` + query Test { + test + } + `, + operationName: 'Test', + extensions: {}, + setContext: jest.fn(), + getContext: () => ({}), + } as unknown as Operation; + + const directive = { + path: '/agent-chat/stream/{args.threadId}', + }; + + const url = (streamingLink as any).buildUrl({ + streamDirective: directive, + operation, + }); + + expect(url).toBe('https://api.example.com/agent-chat/stream/123'); + }); + + it('should use uri from context if provided', () => { + const url = (streamingLink as any).buildUrl({ + uri: 'https://custom.example.com/api', + }); + + expect(url).toBe('https://custom.example.com/api'); + }); + }); +}); diff --git a/packages/twenty-front/src/modules/apollo/utils/streamingRestLink.ts b/packages/twenty-front/src/modules/apollo/utils/streamingRestLink.ts new file mode 100644 index 000000000..1fdb541df --- /dev/null +++ b/packages/twenty-front/src/modules/apollo/utils/streamingRestLink.ts @@ -0,0 +1,226 @@ +import { ApolloLink, Observable, Operation } from '@apollo/client/core'; +import { FetchResult } from '@apollo/client/link/core'; +import { ArgumentNode, DirectiveNode } from 'graphql'; +import { isDefined } from 'twenty-shared/utils'; + +type StreamingRestLinkOptions = { + uri: string; + headers?: Record; + credentials?: RequestCredentials; +}; + +type StreamDirective = { + type?: string; + path?: string; + method?: string; + bodyKey?: string; + headers?: Record; +}; + +export class StreamingRestLink extends ApolloLink { + private readonly baseUri: string; + private readonly defaultHeaders: Record; + + constructor(options: StreamingRestLinkOptions) { + super(); + this.baseUri = options.uri; + this.defaultHeaders = options.headers || {}; + } + + public request( + operation: Operation, + forward: (operation: Operation) => Observable, + ): Observable { + const streamDirective = this.extractStreamDirective(operation); + + if (!streamDirective) { + return forward(operation); + } + + const { uri, onChunk, headers } = operation.getContext(); + + return new Observable((observer) => { + const controller = new AbortController(); + const url = this.buildUrl({ + uri, + streamDirective, + operation, + }); + + const requestConfig = this.buildRequestConfig({ + operation, + streamDirective, + headers, + signal: controller.signal, + }); + + fetch(url, requestConfig) + .then(async (response) => { + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + + if (!response.body) { + throw new Error('Response body is null'); + } + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let accumulatedData = ''; + + let isStreaming = true; + while (isStreaming) { + const { done, value } = await reader.read(); + + if (done) { + observer.complete(); + isStreaming = false; + continue; + } + + const decodedChunk = decoder.decode(value, { stream: true }); + accumulatedData += decodedChunk; + + if (isDefined(onChunk) && typeof onChunk === 'function') { + onChunk(accumulatedData); + } + + try { + const parsedData = JSON.parse(decodedChunk); + observer.next({ data: parsedData }); + } catch { + observer.next({ + data: { streamingData: decodedChunk }, + }); + } + } + }) + .catch((error) => { + observer.error(error); + }); + + return () => controller.abort(); + }); + } + + private extractStreamDirective(operation: Operation): StreamDirective | null { + try { + const definition = operation.query.definitions[0]; + + if (!definition || definition.kind !== 'OperationDefinition') { + return null; + } + + if ( + !definition.selectionSet || + !definition.selectionSet.selections || + definition.selectionSet.selections.length === 0 + ) { + return null; + } + + const selection = definition.selectionSet.selections[0]; + + if (!selection || !isDefined(selection.directives)) { + return null; + } + + const streamDirective = selection.directives.find( + (d: DirectiveNode) => d.name.value === 'stream', + ); + + if (!isDefined(streamDirective)) { + return null; + } + + const args = streamDirective.arguments || []; + const directive: StreamDirective = {}; + + args.forEach((arg: ArgumentNode) => { + if (arg.value.kind === 'StringValue') { + const value = arg.value.value; + switch (arg.name.value) { + case 'path': + directive.path = value; + break; + case 'type': + directive.type = value; + break; + case 'method': + directive.method = value; + break; + case 'bodyKey': + directive.bodyKey = value; + break; + } + } + }); + + return directive; + } catch (error) { + return null; + } + } + + private buildUrl({ + uri, + streamDirective, + operation, + }: { + uri?: string; + streamDirective?: StreamDirective | null; + operation?: Operation; + }): string { + if (isDefined(uri)) { + return uri.startsWith('http') ? uri : `${this.baseUri}${uri}`; + } + + if (isDefined(streamDirective?.path)) { + let path = streamDirective.path; + + if (isDefined(operation?.variables)) { + Object.entries(operation.variables).forEach(([key, value]) => { + path = path.replace(`{args.${key}}`, String(value)); + }); + } + + return `${this.baseUri}${path}`; + } + + throw new Error('No valid URL found'); + } + + private buildRequestConfig({ + operation, + streamDirective, + headers, + signal, + }: { + operation: Operation; + streamDirective?: StreamDirective | null; + headers?: Record; + signal?: AbortSignal; + }): RequestInit { + const method = streamDirective?.method || 'POST'; + let body: string | undefined; + + if (isDefined(streamDirective?.bodyKey) && isDefined(operation.variables)) { + body = JSON.stringify(operation.variables[streamDirective.bodyKey]); + } else { + body = JSON.stringify(operation.variables); + } + + return { + method, + headers: { + 'Content-Type': 'application/json', + Accept: 'text/event-stream', + ...this.defaultHeaders, + ...headers, + ...streamDirective?.headers, + }, + body, + signal, + }; + } +} diff --git a/packages/twenty-front/src/modules/workflow/workflow-steps/workflow-actions/ai-agent-action/api/agent-chat-apollo.api.ts b/packages/twenty-front/src/modules/workflow/workflow-steps/workflow-actions/ai-agent-action/api/agent-chat-apollo.api.ts index a322c8bb1..8967488cb 100644 --- a/packages/twenty-front/src/modules/workflow/workflow-steps/workflow-actions/ai-agent-action/api/agent-chat-apollo.api.ts +++ b/packages/twenty-front/src/modules/workflow/workflow-steps/workflow-actions/ai-agent-action/api/agent-chat-apollo.api.ts @@ -30,3 +30,15 @@ export const GET_AGENT_CHAT_MESSAGES = gql` } } `; + +export const STREAM_CHAT_QUERY = gql` + query StreamChatResponse($requestBody: JSON!) { + streamChatResponse(requestBody: $requestBody) + @stream( + type: "StreamChatResponse" + path: "/agent-chat/stream" + method: "POST" + bodyKey: "requestBody" + ) + } +`; diff --git a/packages/twenty-front/src/modules/workflow/workflow-steps/workflow-actions/ai-agent-action/api/streamChatResponse.ts b/packages/twenty-front/src/modules/workflow/workflow-steps/workflow-actions/ai-agent-action/api/streamChatResponse.ts deleted file mode 100644 index cda2afae3..000000000 --- a/packages/twenty-front/src/modules/workflow/workflow-steps/workflow-actions/ai-agent-action/api/streamChatResponse.ts +++ /dev/null @@ -1,105 +0,0 @@ -import { getTokenPair } from '@/apollo/utils/getTokenPair'; -import { renewToken } from '@/auth/services/AuthService'; -import { AppPath } from '@/types/AppPath'; -import { isDefined } from 'twenty-shared/utils'; -import { REACT_APP_SERVER_BASE_URL } from '~/config'; -import { cookieStorage } from '~/utils/cookie-storage'; - -const handleTokenRenewal = async () => { - const tokenPair = getTokenPair(); - if (!isDefined(tokenPair?.refreshToken?.token)) { - throw new Error('No refresh token available'); - } - - const newTokens = await renewToken( - `${REACT_APP_SERVER_BASE_URL}/graphql`, - tokenPair, - ); - - if (!isDefined(newTokens)) { - throw new Error('Token renewal failed'); - } - - cookieStorage.setItem('tokenPair', JSON.stringify(newTokens)); - return newTokens; -}; - -const createStreamRequest = ( - threadId: string, - userMessage: string, - accessToken: string, -) => ({ - method: 'POST', - headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${accessToken}`, - }, - body: JSON.stringify({ - threadId, - userMessage, - }), -}); - -const handleStreamResponse = async ( - response: Response, - onChunk: (chunk: string) => void, -): Promise => { - const reader = response.body?.getReader(); - const decoder = new TextDecoder(); - let accumulated = ''; - - if (isDefined(reader)) { - let done = false; - while (!done) { - const { value, done: isDone } = await reader.read(); - done = isDone; - if (done) break; - - const chunk = decoder.decode(value, { stream: true }); - accumulated += chunk; - onChunk(accumulated); - } - } - - return accumulated; -}; - -export const streamChatResponse = async ( - threadId: string, - userMessage: string, - onChunk: (chunk: string) => void, -) => { - const tokenPair = getTokenPair(); - - if (!isDefined(tokenPair?.accessToken?.token)) { - throw new Error('No access token available'); - } - - const accessToken = tokenPair.accessToken.token; - - const response = await fetch( - `${REACT_APP_SERVER_BASE_URL}/rest/agent-chat/stream`, - createStreamRequest(threadId, userMessage, accessToken), - ); - - if (response.ok) { - return handleStreamResponse(response, onChunk); - } - - if (response.status === 401) { - try { - const newTokens = await handleTokenRenewal(); - const retryResponse = await fetch( - `${REACT_APP_SERVER_BASE_URL}/rest/agent-chat/stream`, - createStreamRequest(threadId, userMessage, newTokens.accessToken.token), - ); - - if (retryResponse.ok) { - return handleStreamResponse(retryResponse, onChunk); - } - } catch (renewalError) { - window.location.href = AppPath.SignInUp; - } - throw new Error('Authentication failed'); - } -}; diff --git a/packages/twenty-front/src/modules/workflow/workflow-steps/workflow-actions/ai-agent-action/hooks/useAgentChat.ts b/packages/twenty-front/src/modules/workflow/workflow-steps/workflow-actions/ai-agent-action/hooks/useAgentChat.ts index 8a14bc208..26da08ff8 100644 --- a/packages/twenty-front/src/modules/workflow/workflow-steps/workflow-actions/ai-agent-action/hooks/useAgentChat.ts +++ b/packages/twenty-front/src/modules/workflow/workflow-steps/workflow-actions/ai-agent-action/hooks/useAgentChat.ts @@ -6,10 +6,11 @@ import { Key } from 'ts-key-enum'; import { useHotkeysOnFocusedElement } from '@/ui/utilities/hotkey/hooks/useHotkeysOnFocusedElement'; import { useScrollWrapperElement } from '@/ui/utilities/scroll/hooks/useScrollWrapperElement'; +import { STREAM_CHAT_QUERY } from '@/workflow/workflow-steps/workflow-actions/ai-agent-action/api/agent-chat-apollo.api'; import { AgentChatMessageRole } from '@/workflow/workflow-steps/workflow-actions/ai-agent-action/constants/agent-chat-message-role'; +import { useApolloClient } from '@apollo/client'; import { isDefined } from 'twenty-shared/utils'; import { v4 } from 'uuid'; -import { streamChatResponse } from '../api/streamChatResponse'; import { agentChatInputState } from '../states/agentChatInputState'; import { agentChatMessagesComponentState } from '../states/agentChatMessagesComponentState'; import { agentStreamingMessageState } from '../states/agentStreamingMessageState'; @@ -21,6 +22,8 @@ interface OptimisticMessage extends AgentChatMessage { } export const useAgentChat = (agentId: string) => { + const apolloClient = useApolloClient(); + const [agentChatMessages, setAgentChatMessages] = useRecoilComponentStateV2( agentChatMessagesComponentState, agentId, @@ -92,9 +95,20 @@ export const useAgentChat = (agentId: string) => { setIsStreaming(true); - await streamChatResponse(currentThreadId, content, (chunk) => { - setAgentStreamingMessage(chunk); - scrollToBottom(); + await apolloClient.query({ + query: STREAM_CHAT_QUERY, + variables: { + requestBody: { + threadId: currentThreadId, + userMessage: content, + }, + }, + context: { + onChunk: (chunk: string) => { + setAgentStreamingMessage(chunk); + scrollToBottom(); + }, + }, }); setIsStreaming(false);