refactor(infra): refactor copilot client (#8813)

This commit is contained in:
EYHN
2024-11-27 06:44:47 +00:00
parent 6b4a1aa917
commit 6e25243868
8 changed files with 149 additions and 85 deletions

View File

@@ -142,10 +142,6 @@ export class AIProvider {
...options: Parameters<BlockSuitePresets.AIActions[T]>
) => ReturnType<BlockSuitePresets.AIActions[T]>
): void {
if (this.actions[id]) {
console.warn(`AI action ${id} is already provided`);
}
// @ts-expect-error TODO: maybe fix this
this.actions[id] = (
...args: Parameters<BlockSuitePresets.AIActions[T]>

View File

@@ -7,10 +7,10 @@ import {
getCopilotHistoriesQuery,
getCopilotHistoryIdsQuery,
getCopilotSessionsQuery,
gqlFetcherFactory,
GraphQLError,
type GraphQLQuery,
type QueryOptions,
type QueryResponse,
type RequestOptions,
UserFriendlyError,
} from '@affine/graphql';
@@ -21,26 +21,6 @@ import {
} from '@blocksuite/affine/blocks';
import { getCurrentStore } from '@toeverything/infra';
/**
* @deprecated will be removed soon
*/
export function getBaseUrl(): string {
if (BUILD_CONFIG.isElectron || BUILD_CONFIG.isIOS || BUILD_CONFIG.isAndroid) {
return BUILD_CONFIG.serverUrlPrefix;
}
if (typeof window === 'undefined') {
// is nodejs
return '';
}
const { protocol, hostname, port } = window.location;
return `${protocol}//${hostname}${port ? `:${port}` : ''}`;
}
/**
* @deprecated will be removed soon
*/
const defaultFetcher = gqlFetcherFactory(getBaseUrl() + '/graphql');
type OptionsField<T extends GraphQLQuery> =
RequestOptions<T>['variables'] extends { options: infer U } ? U : never;
@@ -76,23 +56,22 @@ export function handleError(src: any) {
return err;
}
const fetcher = async <Query extends GraphQLQuery>(
options: QueryOptions<Query>
) => {
try {
return await defaultFetcher<Query>(options);
} catch (err) {
throw handleError(err);
}
};
export class CopilotClient {
readonly backendUrl = getBaseUrl();
constructor(
readonly gql: <Query extends GraphQLQuery>(
options: QueryOptions<Query>
) => Promise<QueryResponse<Query>>,
readonly fetcher: (input: string, init?: RequestInit) => Promise<Response>,
readonly eventSource: (
url: string,
eventSourceInitDict?: EventSourceInit
) => EventSource
) {}
async createSession(
options: OptionsField<typeof createCopilotSessionMutation>
) {
const res = await fetcher({
const res = await this.gql({
query: createCopilotSessionMutation,
variables: {
options,
@@ -102,7 +81,7 @@ export class CopilotClient {
}
async forkSession(options: OptionsField<typeof forkCopilotSessionMutation>) {
const res = await fetcher({
const res = await this.gql({
query: forkCopilotSessionMutation,
variables: {
options,
@@ -114,7 +93,7 @@ export class CopilotClient {
async createMessage(
options: OptionsField<typeof createCopilotMessageMutation>
) {
const res = await fetcher({
const res = await this.gql({
query: createCopilotMessageMutation,
variables: {
options,
@@ -124,7 +103,7 @@ export class CopilotClient {
}
async getSessions(workspaceId: string) {
const res = await fetcher({
const res = await this.gql({
query: getCopilotSessionsQuery,
variables: {
workspaceId,
@@ -140,7 +119,7 @@ export class CopilotClient {
typeof getCopilotHistoriesQuery
>['variables']['options']
) {
const res = await fetcher({
const res = await this.gql({
query: getCopilotHistoriesQuery,
variables: {
workspaceId,
@@ -159,7 +138,7 @@ export class CopilotClient {
typeof getCopilotHistoriesQuery
>['variables']['options']
) {
const res = await fetcher({
const res = await this.gql({
query: getCopilotHistoryIdsQuery,
variables: {
workspaceId,
@@ -176,7 +155,7 @@ export class CopilotClient {
docId: string;
sessionIds: string[];
}) {
const res = await fetcher({
const res = await this.gql({
query: cleanupCopilotSessionMutation,
variables: {
input,
@@ -194,11 +173,11 @@ export class CopilotClient {
messageId?: string;
signal?: AbortSignal;
}) {
const url = new URL(`${this.backendUrl}/api/copilot/chat/${sessionId}`);
let url = `/api/copilot/chat/${sessionId}`;
if (messageId) {
url.searchParams.set('messageId', messageId);
url += `?messageId=${encodeURIComponent(messageId)}`;
}
const response = await fetch(url.toString(), { signal });
const response = await this.fetcher(url.toString(), { signal });
return response.text();
}
@@ -213,11 +192,11 @@ export class CopilotClient {
},
endpoint = 'stream'
) {
const url = new URL(
`${this.backendUrl}/api/copilot/chat/${sessionId}/${endpoint}`
);
if (messageId) url.searchParams.set('messageId', messageId);
return new EventSource(url.toString());
let url = `/api/copilot/chat/${sessionId}/${endpoint}`;
if (messageId) {
url += `?messageId=${encodeURIComponent(messageId)}`;
}
return this.eventSource(url);
}
// Text or image to images
@@ -227,15 +206,18 @@ export class CopilotClient {
seed?: string,
endpoint = 'images'
) {
const url = new URL(
`${this.backendUrl}/api/copilot/chat/${sessionId}/${endpoint}`
);
if (messageId) {
url.searchParams.set('messageId', messageId);
let url = `/api/copilot/chat/${sessionId}/${endpoint}`;
if (messageId || seed) {
url += '?';
url += new URLSearchParams(
Object.fromEntries(
Object.entries({ messageId, seed }).filter(
([_, v]) => v !== undefined
)
) as Record<string, string>
).toString();
}
if (seed) {
url.searchParams.set('seed', seed);
}
return new EventSource(url);
return this.eventSource(url);
}
}

View File

@@ -3,15 +3,14 @@ import type { ForkChatSessionInput } from '@affine/graphql';
import { assertExists } from '@blocksuite/affine/global/utils';
import { partition } from 'lodash-es';
import { CopilotClient } from './copilot-client';
import type { CopilotClient } from './copilot-client';
import { delay, toTextStream } from './event-source';
import type { PromptKey } from './prompt';
const TIMEOUT = 50000;
const client = new CopilotClient();
export type TextToTextOptions = {
client: CopilotClient;
docId: string;
workspaceId: string;
promptName?: PromptKey;
@@ -33,9 +32,11 @@ export type ToImageOptions = TextToTextOptions & {
};
export function createChatSession({
client,
workspaceId,
docId,
}: {
client: CopilotClient;
workspaceId: string;
docId: string;
}) {
@@ -46,7 +47,10 @@ export function createChatSession({
});
}
export function forkCopilotSession(forkChatSessionInput: ForkChatSessionInput) {
export function forkCopilotSession(
client: CopilotClient,
forkChatSessionInput: ForkChatSessionInput
) {
return client.forkSession(forkChatSessionInput);
}
@@ -83,6 +87,7 @@ async function resizeImage(blob: Blob | File): Promise<Blob | null> {
}
async function createSessionMessage({
client,
docId,
workspaceId,
promptName,
@@ -140,6 +145,7 @@ async function createSessionMessage({
}
export function textToText({
client,
docId,
workspaceId,
promptName,
@@ -169,6 +175,7 @@ export function textToText({
_messageId = undefined;
} else {
const message = await createSessionMessage({
client,
docId,
workspaceId,
promptName,
@@ -242,6 +249,7 @@ export function textToText({
_messageId = undefined;
} else {
const message = await createSessionMessage({
client,
docId,
workspaceId,
promptName,
@@ -268,10 +276,6 @@ export function textToText({
}
}
export const listHistories = client.getHistories;
export const listHistoryIds = client.getHistoryIds;
// Only one image is currently being processed
export function toImage({
docId,
@@ -286,6 +290,7 @@ export function toImage({
timeout = TIMEOUT,
retry = false,
workflow = false,
client,
}: ToImageOptions) {
let _sessionId: string;
let _messageId: string | undefined;
@@ -305,6 +310,7 @@ export function toImage({
content,
attachments,
params,
client,
});
_sessionId = sessionId;
_messageId = messageId;
@@ -334,10 +340,12 @@ export function cleanupSessions({
workspaceId,
docId,
sessionIds,
client,
}: {
workspaceId: string;
docId: string;
sessionIds: string[];
client: CopilotClient;
}) {
return client.cleanupSessions({ workspaceId, docId, sessionIds });
}

View File

@@ -10,13 +10,12 @@ import { assertExists } from '@blocksuite/affine/global/utils';
import { getCurrentStore } from '@toeverything/infra';
import { z } from 'zod';
import { getBaseUrl } from './copilot-client';
import type { CopilotClient } from './copilot-client';
import type { PromptKey } from './prompt';
import {
cleanupSessions,
createChatSession,
forkCopilotSession,
listHistories,
textToText,
toImage,
} from './request';
@@ -39,11 +38,11 @@ const processTypeToPromptName = new Map(
})
);
export function setupAIProvider() {
// a single workspace should have only a single chat session
// user-id:workspace-id:doc-id -> chat session id
const chatSessions = new Map<string, Promise<string>>();
// a single workspace should have only a single chat session
// user-id:workspace-id:doc-id -> chat session id
const chatSessions = new Map<string, Promise<string>>();
export function setupAIProvider(client: CopilotClient) {
async function getChatSessionId(workspaceId: string, docId: string) {
const userId = (await AIProvider.userInfo)?.id;
@@ -56,6 +55,7 @@ export function setupAIProvider() {
chatSessions.set(
storeKey,
createChatSession({
client,
workspaceId,
docId,
})
@@ -78,6 +78,7 @@ export function setupAIProvider() {
options.sessionId ?? getChatSessionId(options.workspaceId, options.docId);
return textToText({
...options,
client,
content: options.input,
sessionId,
});
@@ -86,6 +87,7 @@ export function setupAIProvider() {
AIProvider.provide('summary', options => {
return textToText({
...options,
client,
content: options.input,
promptName: 'Summary',
});
@@ -94,6 +96,7 @@ export function setupAIProvider() {
AIProvider.provide('translate', options => {
return textToText({
...options,
client,
promptName: 'Translate to',
content: options.input,
params: {
@@ -105,6 +108,7 @@ export function setupAIProvider() {
AIProvider.provide('changeTone', options => {
return textToText({
...options,
client,
params: {
tone: options.tone.toLowerCase(),
},
@@ -116,6 +120,7 @@ export function setupAIProvider() {
AIProvider.provide('improveWriting', options => {
return textToText({
...options,
client,
content: options.input,
promptName: 'Improve writing for it',
});
@@ -124,6 +129,7 @@ export function setupAIProvider() {
AIProvider.provide('improveGrammar', options => {
return textToText({
...options,
client,
content: options.input,
promptName: 'Improve grammar for it',
});
@@ -132,6 +138,7 @@ export function setupAIProvider() {
AIProvider.provide('fixSpelling', options => {
return textToText({
...options,
client,
content: options.input,
promptName: 'Fix spelling for it',
});
@@ -140,6 +147,7 @@ export function setupAIProvider() {
AIProvider.provide('createHeadings', options => {
return textToText({
...options,
client,
content: options.input,
promptName: 'Create headings',
});
@@ -148,6 +156,7 @@ export function setupAIProvider() {
AIProvider.provide('makeLonger', options => {
return textToText({
...options,
client,
content: options.input,
promptName: 'Make it longer',
});
@@ -156,6 +165,7 @@ export function setupAIProvider() {
AIProvider.provide('makeShorter', options => {
return textToText({
...options,
client,
content: options.input,
promptName: 'Make it shorter',
});
@@ -164,6 +174,7 @@ export function setupAIProvider() {
AIProvider.provide('checkCodeErrors', options => {
return textToText({
...options,
client,
content: options.input,
promptName: 'Check code error',
});
@@ -172,6 +183,7 @@ export function setupAIProvider() {
AIProvider.provide('explainCode', options => {
return textToText({
...options,
client,
content: options.input,
promptName: 'Explain this code',
});
@@ -180,6 +192,7 @@ export function setupAIProvider() {
AIProvider.provide('writeArticle', options => {
return textToText({
...options,
client,
content: options.input,
promptName: 'Write an article about this',
});
@@ -188,6 +201,7 @@ export function setupAIProvider() {
AIProvider.provide('writeTwitterPost', options => {
return textToText({
...options,
client,
content: options.input,
promptName: 'Write a twitter about this',
});
@@ -196,6 +210,7 @@ export function setupAIProvider() {
AIProvider.provide('writePoem', options => {
return textToText({
...options,
client,
content: options.input,
promptName: 'Write a poem about this',
});
@@ -204,6 +219,7 @@ export function setupAIProvider() {
AIProvider.provide('writeOutline', options => {
return textToText({
...options,
client,
content: options.input,
promptName: 'Write outline',
});
@@ -212,6 +228,7 @@ export function setupAIProvider() {
AIProvider.provide('writeBlogPost', options => {
return textToText({
...options,
client,
content: options.input,
promptName: 'Write a blog post about this',
});
@@ -220,6 +237,7 @@ export function setupAIProvider() {
AIProvider.provide('brainstorm', options => {
return textToText({
...options,
client,
content: options.input,
promptName: 'Brainstorm ideas about this',
});
@@ -228,6 +246,7 @@ export function setupAIProvider() {
AIProvider.provide('findActions', options => {
return textToText({
...options,
client,
content: options.input,
promptName: 'Find action items from it',
});
@@ -236,6 +255,7 @@ export function setupAIProvider() {
AIProvider.provide('brainstormMindmap', options => {
return textToText({
...options,
client,
content: options.input,
promptName: 'workflow:brainstorm',
workflow: true,
@@ -246,6 +266,7 @@ export function setupAIProvider() {
assertExists(options.input, 'expandMindmap action requires input');
return textToText({
...options,
client,
params: {
mindmap: options.mindmap,
node: options.input,
@@ -258,6 +279,7 @@ export function setupAIProvider() {
AIProvider.provide('explain', options => {
return textToText({
...options,
client,
content: options.input,
promptName: 'Explain this',
});
@@ -266,6 +288,7 @@ export function setupAIProvider() {
AIProvider.provide('explainImage', options => {
return textToText({
...options,
client,
content: options.input,
promptName: 'Explain this image',
});
@@ -288,6 +311,7 @@ Could you make a new website based on these notes and send back just the html fi
return textToText({
...options,
client,
content,
promptName,
});
@@ -332,6 +356,7 @@ Could you make a new website based on these notes and send back just the html fi
};
return textToText({
...options,
client,
content: options.input,
promptName: 'workflow:presentation',
workflow: true,
@@ -348,6 +373,7 @@ Could you make a new website based on these notes and send back just the html fi
}
return toImage({
...options,
client,
promptName,
});
});
@@ -357,6 +383,7 @@ Could you make a new website based on these notes and send back just the html fi
const promptName = filterStyleToPromptName.get(options.style as string);
return toImage({
...options,
client,
timeout: 120000,
promptName: promptName as PromptKey,
workflow: !!promptName?.startsWith('workflow:'),
@@ -370,6 +397,7 @@ Could you make a new website based on these notes and send back just the html fi
) as PromptKey;
return toImage({
...options,
client,
timeout: 120000,
promptName,
});
@@ -378,6 +406,7 @@ Could you make a new website based on these notes and send back just the html fi
AIProvider.provide('generateCaption', options => {
return textToText({
...options,
client,
content: options.input,
promptName: 'Generate a caption',
});
@@ -386,6 +415,7 @@ Could you make a new website based on these notes and send back just the html fi
AIProvider.provide('continueWriting', options => {
return textToText({
...options,
client,
content: options.input,
promptName: 'Continue writing',
});
@@ -399,7 +429,7 @@ Could you make a new website based on these notes and send back just the html fi
): Promise<BlockSuitePresets.AIHistory[]> => {
// @ts-expect-error - 'action' is missing in server impl
return (
(await listHistories(workspaceId, docId, {
(await client.getHistories(workspaceId, docId, {
action: true,
})) ?? []
);
@@ -412,14 +442,14 @@ Could you make a new website based on these notes and send back just the html fi
>['variables']['options']
): Promise<BlockSuitePresets.AIHistory[]> => {
// @ts-expect-error - 'action' is missing in server impl
return (await listHistories(workspaceId, docId, options)) ?? [];
return (await client.getHistories(workspaceId, docId, options)) ?? [];
},
cleanup: async (
workspaceId: string,
docId: string,
sessionIds: string[]
) => {
await cleanupSessions({ workspaceId, docId, sessionIds });
await cleanupSessions({ workspaceId, docId, sessionIds, client });
},
ids: async (
workspaceId: string,
@@ -429,21 +459,23 @@ Could you make a new website based on these notes and send back just the html fi
>['variables']['options']
): Promise<BlockSuitePresets.AIHistoryIds[]> => {
// @ts-expect-error - 'role' is missing type in server impl
return await listHistories(workspaceId, docId, options);
return await client.getHistoryIds(workspaceId, docId, options);
},
});
AIProvider.provide('photoEngine', {
async searchImages(options): Promise<string[]> {
const url = new URL(getBaseUrl() + '/api/copilot/unsplash/photos');
url.searchParams.set('query', options.query);
let url = '/api/copilot/unsplash/photos';
if (options.query) {
url += `?query=${encodeURIComponent(options.query)}`;
}
const result: {
results?: {
urls: {
regular: string;
};
}[];
} = await fetch(url.toString()).then(res => res.json());
} = await client.fetcher(url.toString()).then(res => res.json());
if (!result.results) return [];
return result.results.map(r => {
const url = new URL(r.urls.regular);
@@ -460,10 +492,10 @@ Could you make a new website based on these notes and send back just the html fi
AIProvider.provide('onboarding', toggleGeneralAIOnboarding);
AIProvider.provide('forkChat', options => {
return forkCopilotSession(options);
return forkCopilotSession(client, options);
});
AIProvider.slots.requestLogin.on(() => {
const disposeRequestLoginHandler = AIProvider.slots.requestLogin.on(() => {
getCurrentStore().set(authAtom, s => ({
...s,
openModal: true,
@@ -471,4 +503,8 @@ Could you make a new website based on these notes and send back just the html fi
});
setupTracker();
return () => {
disposeRequestLoginHandler.dispose();
};
}

View File

@@ -1,13 +1,11 @@
import { registerBlocksuitePresetsCustomComponents } from '@affine/core/blocksuite/presets/effects';
import { effects as bsEffects } from '@blocksuite/affine/effects';
import { setupAIProvider } from './ai/setup-provider';
import { effects as edgelessEffects } from './specs/edgeless';
import { effects as patchEffects } from './specs/preview';
bsEffects();
patchEffects();
setupAIProvider();
edgelessEffects();
registerBlocksuitePresetsCustomComponents();

View File

@@ -8,6 +8,11 @@ import { SyncAwareness } from '@affine/core/components/affine/awareness';
import { useRegisterFindInPageCommands } from '@affine/core/components/hooks/affine/use-register-find-in-page-commands';
import { useRegisterWorkspaceCommands } from '@affine/core/components/hooks/use-register-workspace-commands';
import { OverCapacityNotification } from '@affine/core/components/over-capacity';
import {
EventSourceService,
FetchService,
GraphQLService,
} from '@affine/core/modules/cloud';
import { GlobalDialogService } from '@affine/core/modules/dialogs';
import { EditorSettingService } from '@affine/core/modules/editor-setting';
import { useRegisterNavigationCommands } from '@affine/core/modules/navigation/view/use-register-navigation-commands';
@@ -38,6 +43,9 @@ import {
} from 'rxjs';
import { Map as YMap } from 'yjs';
import { CopilotClient } from '../blocksuite/block-suite-editor/ai/copilot-client';
import { setupAIProvider } from '../blocksuite/block-suite-editor/ai/setup-provider';
/**
* @deprecated just for legacy code, will be removed in the future
*/
@@ -129,6 +137,23 @@ export const WorkspaceSideEffects = () => {
};
}, [globalDialogService]);
const graphqlService = useService(GraphQLService);
const eventSourceService = useService(EventSourceService);
const fetchService = useService(FetchService);
useEffect(() => {
const dispose = setupAIProvider(
new CopilotClient(
graphqlService.gql,
fetchService.fetch,
eventSourceService.eventSource
)
);
return () => {
dispose();
};
}, [eventSourceService, fetchService, graphqlService]);
useRegisterWorkspaceCommands();
useRegisterNavigationCommands();
useRegisterFindInPageCommands();

View File

@@ -13,6 +13,7 @@ export { WebSocketAuthProvider } from './provider/websocket-auth';
export { AccountChanged, AuthService } from './services/auth';
export { CaptchaService } from './services/captcha';
export { DefaultServerService } from './services/default-server';
export { EventSourceService } from './services/eventsource';
export { FetchService } from './services/fetch';
export { GraphQLService } from './services/graphql';
export { InvoicesService } from './services/invoices';
@@ -53,6 +54,7 @@ import { AuthService } from './services/auth';
import { CaptchaService } from './services/captcha';
import { CloudDocMetaService } from './services/cloud-doc-meta';
import { DefaultServerService } from './services/default-server';
import { EventSourceService } from './services/eventsource';
import { FetchService } from './services/fetch';
import { GraphQLService } from './services/graphql';
import { InvoicesService } from './services/invoices';
@@ -84,6 +86,7 @@ export function configureCloudModule(framework: Framework) {
.scope(ServerScope)
.service(ServerService, [ServerScope])
.service(FetchService, [RawFetchProvider, ServerService])
.service(EventSourceService, [ServerService])
.service(GraphQLService, [FetchService])
.service(
WebSocketService,

View File

@@ -0,0 +1,16 @@
import { Service } from '@toeverything/infra';
import type { ServerService } from './server';
export class EventSourceService extends Service {
constructor(private readonly serverService: ServerService) {
super();
}
eventSource = (url: string, eventSourceInitDict?: EventSourceInit) => {
return new EventSource(
new URL(url, this.serverService.server.baseUrl),
eventSourceInitDict
);
};
}