Create custom StreamingRestLink for streaming API requests using @stream directive (#13114)

Co-authored-by: Félix Malfait <felix.malfait@gmail.com>
Co-authored-by: Jean-Baptiste Ronssin <65334819+jbronssin@users.noreply.github.com>
Co-authored-by: Paul Rastoin <45004772+prastoin@users.noreply.github.com>
Co-authored-by: Naifer <161821705+omarNaifer12@users.noreply.github.com>
Co-authored-by: Lucas Bordeau <bordeau.lucas@gmail.com>
Co-authored-by: Dmitry Moiseenko <36731450+cxdima@users.noreply.github.com>
Co-authored-by: prastoin <paul@twenty.com>
Co-authored-by: Etienne <45695613+etiennejouan@users.noreply.github.com>
Co-authored-by: Guillim <guillim@users.noreply.github.com>
Co-authored-by: Charles Bochet <charles@twenty.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@twenty.com>
Co-authored-by: Niklas Korz <niklas@niklaskorz.de>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Raphaël Bosi <71827178+bosiraphael@users.noreply.github.com>
Co-authored-by: Charles Bochet <charlesBochet@users.noreply.github.com>
This commit is contained in:
Abdul Rahman
2025-07-09 02:29:04 +05:30
committed by GitHub
parent 39f6f3c4bb
commit 6057bdd389
6 changed files with 496 additions and 110 deletions

View File

@ -28,6 +28,7 @@ import { isUndefinedOrNull } from '~/utils/isUndefinedOrNull';
import { ApolloManager } from '../types/apolloManager.interface';
import { getTokenPair } from '../utils/getTokenPair';
import { loggerLink } from '../utils/loggerLink';
import { StreamingRestLink } from '../utils/streamingRestLink';
const logger = loggerLink(() => 'Twenty');
@ -42,6 +43,8 @@ export interface Options<TCacheShape> extends ApolloClientOptions<TCacheShape> {
isDebugMode?: boolean;
}
const REST_API_BASE_URL = `${REACT_APP_SERVER_BASE_URL}/rest`;
export class ApolloFactory<TCacheShape> implements ApolloManager<TCacheShape> {
private client: ApolloClient<TCacheShape>;
private currentWorkspaceMember: CurrentWorkspaceMember | null = null;
@ -69,8 +72,12 @@ export class ApolloFactory<TCacheShape> implements ApolloManager<TCacheShape> {
uri,
});
const streamingRestLink = new StreamingRestLink({
uri: REST_API_BASE_URL,
});
const restLink = new RestLink({
uri: `${REACT_APP_SERVER_BASE_URL}/rest`,
uri: REST_API_BASE_URL,
});
const authLink = setContext(async (_, { headers }) => {
@ -228,6 +235,7 @@ export class ApolloFactory<TCacheShape> implements ApolloManager<TCacheShape> {
...(extraLinks || []),
isDebugMode ? logger : null,
retryLink,
streamingRestLink,
restLink,
httpLink,
].filter(isDefined),

View File

@ -0,0 +1,231 @@
import { gql } from '@apollo/client';
import { Operation } from '@apollo/client/core';
import { StreamingRestLink } from '../streamingRestLink';
global.fetch = jest.fn();
describe('StreamingRestLink', () => {
let streamingLink: StreamingRestLink;
let mockForward: jest.MockedFunction<(operation: Operation) => any>;
beforeEach(() => {
streamingLink = new StreamingRestLink({
uri: 'https://api.example.com',
});
mockForward = jest.fn();
(global.fetch as jest.Mock).mockClear();
});
describe('request', () => {
it('should forward operations without @stream directive', () => {
const operation = {
query: gql`
query Test {
test
}
`,
variables: {},
getContext: () => ({}),
} as Operation;
const result = streamingLink.request(operation, mockForward);
expect(mockForward).toHaveBeenCalledWith(operation);
expect(result).toBe(mockForward(operation));
});
it('should handle operations with @stream directive', async () => {
const operation = {
query: gql`
query StreamTest($threadId: String!) {
streamChatResponse(threadId: $threadId)
@stream(
path: "/agent-chat/stream/{args.threadId}"
method: "POST"
bodyKey: "requestBody"
)
}
`,
variables: { threadId: '123', requestBody: { threadId: '123' } },
getContext: () => ({ onChunk: jest.fn() }),
operationName: 'StreamTest',
extensions: {},
setContext: jest.fn(),
} as Operation;
const mockResponse = {
ok: true,
body: {
getReader: () => ({
read: jest.fn().mockResolvedValue({ done: true }),
releaseLock: jest.fn(),
}),
},
};
(global.fetch as jest.Mock).mockResolvedValue(mockResponse);
const observable = streamingLink.request(operation, mockForward);
const observer = {
next: jest.fn(),
error: jest.fn(),
complete: jest.fn(),
};
observable.subscribe(observer);
expect(mockForward).not.toHaveBeenCalled();
expect(global.fetch).toHaveBeenCalledWith(
'https://api.example.com/agent-chat/stream/123',
expect.objectContaining({
method: 'POST',
headers: expect.objectContaining({
'Content-Type': 'application/json',
Accept: 'text/event-stream',
}),
body: JSON.stringify({ threadId: '123' }),
}),
);
});
it('should handle network errors', async () => {
const operation = {
query: gql`
query StreamTest {
test @stream(path: "/stream", method: "GET")
}
`,
variables: {},
getContext: () => ({}),
} as Operation;
(global.fetch as jest.Mock).mockRejectedValue(new Error('Network error'));
const observable = streamingLink.request(operation, mockForward);
const observer = {
next: jest.fn(),
error: jest.fn(),
complete: jest.fn(),
};
observable.subscribe(observer);
await new Promise((resolve) => setTimeout(resolve, 0));
expect(observer.error).toHaveBeenCalledWith(new Error('Network error'));
});
it('should handle non-ok responses', async () => {
const operation = {
query: gql`
query StreamTest {
test @stream(path: "/stream", method: "GET")
}
`,
variables: {},
getContext: () => ({}),
} as Operation;
const mockResponse = { ok: false, status: 404 };
(global.fetch as jest.Mock).mockResolvedValue(mockResponse);
const observable = streamingLink.request(operation, mockForward);
const observer = {
next: jest.fn(),
error: jest.fn(),
complete: jest.fn(),
};
observable.subscribe(observer);
await new Promise((resolve) => setTimeout(resolve, 0));
expect(observer.error).toHaveBeenCalledWith(
new Error('HTTP error! status: 404'),
);
});
});
describe('extractStreamDirective', () => {
it('should extract directive arguments correctly', () => {
const operation = {
query: gql`
query Test($threadId: String!) {
streamChatResponse(threadId: $threadId)
@stream(
path: "/agent-chat/stream/{args.threadId}"
method: "POST"
bodyKey: "requestBody"
type: "StreamChatResponse"
)
}
`,
variables: {},
getContext: () => ({}),
} as Operation;
const directive = (streamingLink as any).extractStreamDirective(
operation,
);
expect(directive).toEqual({
path: '/agent-chat/stream/{args.threadId}',
method: 'POST',
bodyKey: 'requestBody',
type: 'StreamChatResponse',
});
});
it('should return null for operations without @stream directive', () => {
const operation = {
query: gql`
query Test {
test
}
`,
variables: {},
getContext: () => ({}),
} as Operation;
const directive = (streamingLink as any).extractStreamDirective(
operation,
);
expect(directive).toBeNull();
});
});
describe('buildUrl', () => {
it('should build URL with variable substitution', () => {
const operation = {
variables: { threadId: '123' },
query: gql`
query Test {
test
}
`,
operationName: 'Test',
extensions: {},
setContext: jest.fn(),
getContext: () => ({}),
} as unknown as Operation;
const directive = {
path: '/agent-chat/stream/{args.threadId}',
};
const url = (streamingLink as any).buildUrl({
streamDirective: directive,
operation,
});
expect(url).toBe('https://api.example.com/agent-chat/stream/123');
});
it('should use uri from context if provided', () => {
const url = (streamingLink as any).buildUrl({
uri: 'https://custom.example.com/api',
});
expect(url).toBe('https://custom.example.com/api');
});
});
});

View File

@ -0,0 +1,226 @@
import { ApolloLink, Observable, Operation } from '@apollo/client/core';
import { FetchResult } from '@apollo/client/link/core';
import { ArgumentNode, DirectiveNode } from 'graphql';
import { isDefined } from 'twenty-shared/utils';
type StreamingRestLinkOptions = {
uri: string;
headers?: Record<string, string>;
credentials?: RequestCredentials;
};
type StreamDirective = {
type?: string;
path?: string;
method?: string;
bodyKey?: string;
headers?: Record<string, string>;
};
export class StreamingRestLink extends ApolloLink {
private readonly baseUri: string;
private readonly defaultHeaders: Record<string, string>;
constructor(options: StreamingRestLinkOptions) {
super();
this.baseUri = options.uri;
this.defaultHeaders = options.headers || {};
}
public request(
operation: Operation,
forward: (operation: Operation) => Observable<FetchResult>,
): Observable<FetchResult> {
const streamDirective = this.extractStreamDirective(operation);
if (!streamDirective) {
return forward(operation);
}
const { uri, onChunk, headers } = operation.getContext();
return new Observable((observer) => {
const controller = new AbortController();
const url = this.buildUrl({
uri,
streamDirective,
operation,
});
const requestConfig = this.buildRequestConfig({
operation,
streamDirective,
headers,
signal: controller.signal,
});
fetch(url, requestConfig)
.then(async (response) => {
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
if (!response.body) {
throw new Error('Response body is null');
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
let accumulatedData = '';
let isStreaming = true;
while (isStreaming) {
const { done, value } = await reader.read();
if (done) {
observer.complete();
isStreaming = false;
continue;
}
const decodedChunk = decoder.decode(value, { stream: true });
accumulatedData += decodedChunk;
if (isDefined(onChunk) && typeof onChunk === 'function') {
onChunk(accumulatedData);
}
try {
const parsedData = JSON.parse(decodedChunk);
observer.next({ data: parsedData });
} catch {
observer.next({
data: { streamingData: decodedChunk },
});
}
}
})
.catch((error) => {
observer.error(error);
});
return () => controller.abort();
});
}
private extractStreamDirective(operation: Operation): StreamDirective | null {
try {
const definition = operation.query.definitions[0];
if (!definition || definition.kind !== 'OperationDefinition') {
return null;
}
if (
!definition.selectionSet ||
!definition.selectionSet.selections ||
definition.selectionSet.selections.length === 0
) {
return null;
}
const selection = definition.selectionSet.selections[0];
if (!selection || !isDefined(selection.directives)) {
return null;
}
const streamDirective = selection.directives.find(
(d: DirectiveNode) => d.name.value === 'stream',
);
if (!isDefined(streamDirective)) {
return null;
}
const args = streamDirective.arguments || [];
const directive: StreamDirective = {};
args.forEach((arg: ArgumentNode) => {
if (arg.value.kind === 'StringValue') {
const value = arg.value.value;
switch (arg.name.value) {
case 'path':
directive.path = value;
break;
case 'type':
directive.type = value;
break;
case 'method':
directive.method = value;
break;
case 'bodyKey':
directive.bodyKey = value;
break;
}
}
});
return directive;
} catch (error) {
return null;
}
}
private buildUrl({
uri,
streamDirective,
operation,
}: {
uri?: string;
streamDirective?: StreamDirective | null;
operation?: Operation;
}): string {
if (isDefined(uri)) {
return uri.startsWith('http') ? uri : `${this.baseUri}${uri}`;
}
if (isDefined(streamDirective?.path)) {
let path = streamDirective.path;
if (isDefined(operation?.variables)) {
Object.entries(operation.variables).forEach(([key, value]) => {
path = path.replace(`{args.${key}}`, String(value));
});
}
return `${this.baseUri}${path}`;
}
throw new Error('No valid URL found');
}
private buildRequestConfig({
operation,
streamDirective,
headers,
signal,
}: {
operation: Operation;
streamDirective?: StreamDirective | null;
headers?: Record<string, string>;
signal?: AbortSignal;
}): RequestInit {
const method = streamDirective?.method || 'POST';
let body: string | undefined;
if (isDefined(streamDirective?.bodyKey) && isDefined(operation.variables)) {
body = JSON.stringify(operation.variables[streamDirective.bodyKey]);
} else {
body = JSON.stringify(operation.variables);
}
return {
method,
headers: {
'Content-Type': 'application/json',
Accept: 'text/event-stream',
...this.defaultHeaders,
...headers,
...streamDirective?.headers,
},
body,
signal,
};
}
}