Create custom StreamingRestLink for streaming API requests using @stream directive (#13114)

Co-authored-by: Félix Malfait <felix.malfait@gmail.com>
Co-authored-by: Jean-Baptiste Ronssin <65334819+jbronssin@users.noreply.github.com>
Co-authored-by: Paul Rastoin <45004772+prastoin@users.noreply.github.com>
Co-authored-by: Naifer <161821705+omarNaifer12@users.noreply.github.com>
Co-authored-by: Lucas Bordeau <bordeau.lucas@gmail.com>
Co-authored-by: Dmitry Moiseenko <36731450+cxdima@users.noreply.github.com>
Co-authored-by: prastoin <paul@twenty.com>
Co-authored-by: Etienne <45695613+etiennejouan@users.noreply.github.com>
Co-authored-by: Guillim <guillim@users.noreply.github.com>
Co-authored-by: Charles Bochet <charles@twenty.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@twenty.com>
Co-authored-by: Niklas Korz <niklas@niklaskorz.de>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Raphaël Bosi <71827178+bosiraphael@users.noreply.github.com>
Co-authored-by: Charles Bochet <charlesBochet@users.noreply.github.com>
This commit is contained in:
Abdul Rahman
2025-07-09 02:29:04 +05:30
committed by GitHub
parent 39f6f3c4bb
commit 6057bdd389
6 changed files with 496 additions and 110 deletions

View File

@ -28,6 +28,7 @@ import { isUndefinedOrNull } from '~/utils/isUndefinedOrNull';
import { ApolloManager } from '../types/apolloManager.interface';
import { getTokenPair } from '../utils/getTokenPair';
import { loggerLink } from '../utils/loggerLink';
import { StreamingRestLink } from '../utils/streamingRestLink';
const logger = loggerLink(() => 'Twenty');
@ -42,6 +43,8 @@ export interface Options<TCacheShape> extends ApolloClientOptions<TCacheShape> {
isDebugMode?: boolean;
}
const REST_API_BASE_URL = `${REACT_APP_SERVER_BASE_URL}/rest`;
export class ApolloFactory<TCacheShape> implements ApolloManager<TCacheShape> {
private client: ApolloClient<TCacheShape>;
private currentWorkspaceMember: CurrentWorkspaceMember | null = null;
@ -69,8 +72,12 @@ export class ApolloFactory<TCacheShape> implements ApolloManager<TCacheShape> {
uri,
});
const streamingRestLink = new StreamingRestLink({
uri: REST_API_BASE_URL,
});
const restLink = new RestLink({
uri: `${REACT_APP_SERVER_BASE_URL}/rest`,
uri: REST_API_BASE_URL,
});
const authLink = setContext(async (_, { headers }) => {
@ -228,6 +235,7 @@ export class ApolloFactory<TCacheShape> implements ApolloManager<TCacheShape> {
...(extraLinks || []),
isDebugMode ? logger : null,
retryLink,
streamingRestLink,
restLink,
httpLink,
].filter(isDefined),

View File

@ -0,0 +1,231 @@
import { gql } from '@apollo/client';
import { Operation } from '@apollo/client/core';
import { StreamingRestLink } from '../streamingRestLink';
global.fetch = jest.fn();
describe('StreamingRestLink', () => {
let streamingLink: StreamingRestLink;
let mockForward: jest.MockedFunction<(operation: Operation) => any>;
beforeEach(() => {
streamingLink = new StreamingRestLink({
uri: 'https://api.example.com',
});
mockForward = jest.fn();
(global.fetch as jest.Mock).mockClear();
});
describe('request', () => {
it('should forward operations without @stream directive', () => {
const operation = {
query: gql`
query Test {
test
}
`,
variables: {},
getContext: () => ({}),
} as Operation;
const result = streamingLink.request(operation, mockForward);
expect(mockForward).toHaveBeenCalledWith(operation);
expect(result).toBe(mockForward(operation));
});
it('should handle operations with @stream directive', async () => {
const operation = {
query: gql`
query StreamTest($threadId: String!) {
streamChatResponse(threadId: $threadId)
@stream(
path: "/agent-chat/stream/{args.threadId}"
method: "POST"
bodyKey: "requestBody"
)
}
`,
variables: { threadId: '123', requestBody: { threadId: '123' } },
getContext: () => ({ onChunk: jest.fn() }),
operationName: 'StreamTest',
extensions: {},
setContext: jest.fn(),
} as Operation;
const mockResponse = {
ok: true,
body: {
getReader: () => ({
read: jest.fn().mockResolvedValue({ done: true }),
releaseLock: jest.fn(),
}),
},
};
(global.fetch as jest.Mock).mockResolvedValue(mockResponse);
const observable = streamingLink.request(operation, mockForward);
const observer = {
next: jest.fn(),
error: jest.fn(),
complete: jest.fn(),
};
observable.subscribe(observer);
expect(mockForward).not.toHaveBeenCalled();
expect(global.fetch).toHaveBeenCalledWith(
'https://api.example.com/agent-chat/stream/123',
expect.objectContaining({
method: 'POST',
headers: expect.objectContaining({
'Content-Type': 'application/json',
Accept: 'text/event-stream',
}),
body: JSON.stringify({ threadId: '123' }),
}),
);
});
it('should handle network errors', async () => {
const operation = {
query: gql`
query StreamTest {
test @stream(path: "/stream", method: "GET")
}
`,
variables: {},
getContext: () => ({}),
} as Operation;
(global.fetch as jest.Mock).mockRejectedValue(new Error('Network error'));
const observable = streamingLink.request(operation, mockForward);
const observer = {
next: jest.fn(),
error: jest.fn(),
complete: jest.fn(),
};
observable.subscribe(observer);
await new Promise((resolve) => setTimeout(resolve, 0));
expect(observer.error).toHaveBeenCalledWith(new Error('Network error'));
});
it('should handle non-ok responses', async () => {
const operation = {
query: gql`
query StreamTest {
test @stream(path: "/stream", method: "GET")
}
`,
variables: {},
getContext: () => ({}),
} as Operation;
const mockResponse = { ok: false, status: 404 };
(global.fetch as jest.Mock).mockResolvedValue(mockResponse);
const observable = streamingLink.request(operation, mockForward);
const observer = {
next: jest.fn(),
error: jest.fn(),
complete: jest.fn(),
};
observable.subscribe(observer);
await new Promise((resolve) => setTimeout(resolve, 0));
expect(observer.error).toHaveBeenCalledWith(
new Error('HTTP error! status: 404'),
);
});
});
describe('extractStreamDirective', () => {
it('should extract directive arguments correctly', () => {
const operation = {
query: gql`
query Test($threadId: String!) {
streamChatResponse(threadId: $threadId)
@stream(
path: "/agent-chat/stream/{args.threadId}"
method: "POST"
bodyKey: "requestBody"
type: "StreamChatResponse"
)
}
`,
variables: {},
getContext: () => ({}),
} as Operation;
const directive = (streamingLink as any).extractStreamDirective(
operation,
);
expect(directive).toEqual({
path: '/agent-chat/stream/{args.threadId}',
method: 'POST',
bodyKey: 'requestBody',
type: 'StreamChatResponse',
});
});
it('should return null for operations without @stream directive', () => {
const operation = {
query: gql`
query Test {
test
}
`,
variables: {},
getContext: () => ({}),
} as Operation;
const directive = (streamingLink as any).extractStreamDirective(
operation,
);
expect(directive).toBeNull();
});
});
describe('buildUrl', () => {
it('should build URL with variable substitution', () => {
const operation = {
variables: { threadId: '123' },
query: gql`
query Test {
test
}
`,
operationName: 'Test',
extensions: {},
setContext: jest.fn(),
getContext: () => ({}),
} as unknown as Operation;
const directive = {
path: '/agent-chat/stream/{args.threadId}',
};
const url = (streamingLink as any).buildUrl({
streamDirective: directive,
operation,
});
expect(url).toBe('https://api.example.com/agent-chat/stream/123');
});
it('should use uri from context if provided', () => {
const url = (streamingLink as any).buildUrl({
uri: 'https://custom.example.com/api',
});
expect(url).toBe('https://custom.example.com/api');
});
});
});

View File

@ -0,0 +1,226 @@
import { ApolloLink, Observable, Operation } from '@apollo/client/core';
import { FetchResult } from '@apollo/client/link/core';
import { ArgumentNode, DirectiveNode } from 'graphql';
import { isDefined } from 'twenty-shared/utils';
type StreamingRestLinkOptions = {
uri: string;
headers?: Record<string, string>;
credentials?: RequestCredentials;
};
type StreamDirective = {
type?: string;
path?: string;
method?: string;
bodyKey?: string;
headers?: Record<string, string>;
};
export class StreamingRestLink extends ApolloLink {
private readonly baseUri: string;
private readonly defaultHeaders: Record<string, string>;
constructor(options: StreamingRestLinkOptions) {
super();
this.baseUri = options.uri;
this.defaultHeaders = options.headers || {};
}
public request(
operation: Operation,
forward: (operation: Operation) => Observable<FetchResult>,
): Observable<FetchResult> {
const streamDirective = this.extractStreamDirective(operation);
if (!streamDirective) {
return forward(operation);
}
const { uri, onChunk, headers } = operation.getContext();
return new Observable((observer) => {
const controller = new AbortController();
const url = this.buildUrl({
uri,
streamDirective,
operation,
});
const requestConfig = this.buildRequestConfig({
operation,
streamDirective,
headers,
signal: controller.signal,
});
fetch(url, requestConfig)
.then(async (response) => {
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
if (!response.body) {
throw new Error('Response body is null');
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
let accumulatedData = '';
let isStreaming = true;
while (isStreaming) {
const { done, value } = await reader.read();
if (done) {
observer.complete();
isStreaming = false;
continue;
}
const decodedChunk = decoder.decode(value, { stream: true });
accumulatedData += decodedChunk;
if (isDefined(onChunk) && typeof onChunk === 'function') {
onChunk(accumulatedData);
}
try {
const parsedData = JSON.parse(decodedChunk);
observer.next({ data: parsedData });
} catch {
observer.next({
data: { streamingData: decodedChunk },
});
}
}
})
.catch((error) => {
observer.error(error);
});
return () => controller.abort();
});
}
private extractStreamDirective(operation: Operation): StreamDirective | null {
try {
const definition = operation.query.definitions[0];
if (!definition || definition.kind !== 'OperationDefinition') {
return null;
}
if (
!definition.selectionSet ||
!definition.selectionSet.selections ||
definition.selectionSet.selections.length === 0
) {
return null;
}
const selection = definition.selectionSet.selections[0];
if (!selection || !isDefined(selection.directives)) {
return null;
}
const streamDirective = selection.directives.find(
(d: DirectiveNode) => d.name.value === 'stream',
);
if (!isDefined(streamDirective)) {
return null;
}
const args = streamDirective.arguments || [];
const directive: StreamDirective = {};
args.forEach((arg: ArgumentNode) => {
if (arg.value.kind === 'StringValue') {
const value = arg.value.value;
switch (arg.name.value) {
case 'path':
directive.path = value;
break;
case 'type':
directive.type = value;
break;
case 'method':
directive.method = value;
break;
case 'bodyKey':
directive.bodyKey = value;
break;
}
}
});
return directive;
} catch (error) {
return null;
}
}
private buildUrl({
uri,
streamDirective,
operation,
}: {
uri?: string;
streamDirective?: StreamDirective | null;
operation?: Operation;
}): string {
if (isDefined(uri)) {
return uri.startsWith('http') ? uri : `${this.baseUri}${uri}`;
}
if (isDefined(streamDirective?.path)) {
let path = streamDirective.path;
if (isDefined(operation?.variables)) {
Object.entries(operation.variables).forEach(([key, value]) => {
path = path.replace(`{args.${key}}`, String(value));
});
}
return `${this.baseUri}${path}`;
}
throw new Error('No valid URL found');
}
private buildRequestConfig({
operation,
streamDirective,
headers,
signal,
}: {
operation: Operation;
streamDirective?: StreamDirective | null;
headers?: Record<string, string>;
signal?: AbortSignal;
}): RequestInit {
const method = streamDirective?.method || 'POST';
let body: string | undefined;
if (isDefined(streamDirective?.bodyKey) && isDefined(operation.variables)) {
body = JSON.stringify(operation.variables[streamDirective.bodyKey]);
} else {
body = JSON.stringify(operation.variables);
}
return {
method,
headers: {
'Content-Type': 'application/json',
Accept: 'text/event-stream',
...this.defaultHeaders,
...headers,
...streamDirective?.headers,
},
body,
signal,
};
}
}

View File

@ -30,3 +30,15 @@ export const GET_AGENT_CHAT_MESSAGES = gql`
}
}
`;
export const STREAM_CHAT_QUERY = gql`
query StreamChatResponse($requestBody: JSON!) {
streamChatResponse(requestBody: $requestBody)
@stream(
type: "StreamChatResponse"
path: "/agent-chat/stream"
method: "POST"
bodyKey: "requestBody"
)
}
`;

View File

@ -1,105 +0,0 @@
import { getTokenPair } from '@/apollo/utils/getTokenPair';
import { renewToken } from '@/auth/services/AuthService';
import { AppPath } from '@/types/AppPath';
import { isDefined } from 'twenty-shared/utils';
import { REACT_APP_SERVER_BASE_URL } from '~/config';
import { cookieStorage } from '~/utils/cookie-storage';
const handleTokenRenewal = async () => {
const tokenPair = getTokenPair();
if (!isDefined(tokenPair?.refreshToken?.token)) {
throw new Error('No refresh token available');
}
const newTokens = await renewToken(
`${REACT_APP_SERVER_BASE_URL}/graphql`,
tokenPair,
);
if (!isDefined(newTokens)) {
throw new Error('Token renewal failed');
}
cookieStorage.setItem('tokenPair', JSON.stringify(newTokens));
return newTokens;
};
const createStreamRequest = (
threadId: string,
userMessage: string,
accessToken: string,
) => ({
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${accessToken}`,
},
body: JSON.stringify({
threadId,
userMessage,
}),
});
const handleStreamResponse = async (
response: Response,
onChunk: (chunk: string) => void,
): Promise<string> => {
const reader = response.body?.getReader();
const decoder = new TextDecoder();
let accumulated = '';
if (isDefined(reader)) {
let done = false;
while (!done) {
const { value, done: isDone } = await reader.read();
done = isDone;
if (done) break;
const chunk = decoder.decode(value, { stream: true });
accumulated += chunk;
onChunk(accumulated);
}
}
return accumulated;
};
export const streamChatResponse = async (
threadId: string,
userMessage: string,
onChunk: (chunk: string) => void,
) => {
const tokenPair = getTokenPair();
if (!isDefined(tokenPair?.accessToken?.token)) {
throw new Error('No access token available');
}
const accessToken = tokenPair.accessToken.token;
const response = await fetch(
`${REACT_APP_SERVER_BASE_URL}/rest/agent-chat/stream`,
createStreamRequest(threadId, userMessage, accessToken),
);
if (response.ok) {
return handleStreamResponse(response, onChunk);
}
if (response.status === 401) {
try {
const newTokens = await handleTokenRenewal();
const retryResponse = await fetch(
`${REACT_APP_SERVER_BASE_URL}/rest/agent-chat/stream`,
createStreamRequest(threadId, userMessage, newTokens.accessToken.token),
);
if (retryResponse.ok) {
return handleStreamResponse(retryResponse, onChunk);
}
} catch (renewalError) {
window.location.href = AppPath.SignInUp;
}
throw new Error('Authentication failed');
}
};

View File

@ -6,10 +6,11 @@ import { Key } from 'ts-key-enum';
import { useHotkeysOnFocusedElement } from '@/ui/utilities/hotkey/hooks/useHotkeysOnFocusedElement';
import { useScrollWrapperElement } from '@/ui/utilities/scroll/hooks/useScrollWrapperElement';
import { STREAM_CHAT_QUERY } from '@/workflow/workflow-steps/workflow-actions/ai-agent-action/api/agent-chat-apollo.api';
import { AgentChatMessageRole } from '@/workflow/workflow-steps/workflow-actions/ai-agent-action/constants/agent-chat-message-role';
import { useApolloClient } from '@apollo/client';
import { isDefined } from 'twenty-shared/utils';
import { v4 } from 'uuid';
import { streamChatResponse } from '../api/streamChatResponse';
import { agentChatInputState } from '../states/agentChatInputState';
import { agentChatMessagesComponentState } from '../states/agentChatMessagesComponentState';
import { agentStreamingMessageState } from '../states/agentStreamingMessageState';
@ -21,6 +22,8 @@ interface OptimisticMessage extends AgentChatMessage {
}
export const useAgentChat = (agentId: string) => {
const apolloClient = useApolloClient();
const [agentChatMessages, setAgentChatMessages] = useRecoilComponentStateV2(
agentChatMessagesComponentState,
agentId,
@ -92,9 +95,20 @@ export const useAgentChat = (agentId: string) => {
setIsStreaming(true);
await streamChatResponse(currentThreadId, content, (chunk) => {
setAgentStreamingMessage(chunk);
scrollToBottom();
await apolloClient.query({
query: STREAM_CHAT_QUERY,
variables: {
requestBody: {
threadId: currentThreadId,
userMessage: content,
},
},
context: {
onChunk: (chunk: string) => {
setAgentStreamingMessage(chunk);
scrollToBottom();
},
},
});
setIsStreaming(false);