test: copilot unit & e2e test (#6649)

fix CLOUD-31
This commit is contained in:
darkskygit
2024-04-26 09:43:35 +00:00
parent f015a11181
commit 850bbee629
12 changed files with 1145 additions and 134 deletions

View File

@@ -45,6 +45,7 @@ if (env.R2_OBJECT_STORAGE_ACCOUNT_ID) {
AFFiNE.plugins.use('copilot', {
openai: {},
fal: {},
});
AFFiNE.plugins.use('redis');
AFFiNE.plugins.use('payment', {

View File

@@ -42,6 +42,11 @@ export interface ChatEvent {
data: string;
}
type CheckResult = {
model: string | undefined;
hasAttachment?: boolean;
};
@Controller('/api/copilot')
export class CopilotController {
private readonly logger = new Logger(CopilotController.name);
@@ -53,17 +58,26 @@ export class CopilotController {
private readonly storage: CopilotStorage
) {}
private async hasAttachment(sessionId: string, messageId: string) {
private async checkRequest(
userId: string,
sessionId: string,
messageId?: string
): Promise<CheckResult> {
await this.chatSession.checkQuota(userId);
const session = await this.chatSession.get(sessionId);
if (!session) {
if (!session || session.config.userId !== userId) {
throw new BadRequestException('Session not found');
}
const message = await session.getMessageById(messageId);
if (Array.isArray(message.attachments) && message.attachments.length) {
return true;
const ret: CheckResult = { model: session.model };
if (messageId) {
const message = await session.getMessageById(messageId);
ret.hasAttachment =
Array.isArray(message.attachments) && !!message.attachments.length;
}
return false;
return ret;
}
private async appendSessionMessage(
@@ -107,9 +121,7 @@ export class CopilotController {
@Query('messageId') messageId: string,
@Query() params: Record<string, string | string[]>
): Promise<string> {
await this.chatSession.checkQuota(user.id);
const model = await this.chatSession.get(sessionId).then(s => s?.model);
const { model } = await this.checkRequest(user.id, sessionId);
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
@@ -155,60 +167,58 @@ export class CopilotController {
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
try {
await this.chatSession.checkQuota(user.id);
const { model } = await this.checkRequest(user.id, sessionId);
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;
return from(
provider.generateTextStream(session.finish(params), session.model, {
signal: this.getSignal(req),
user: user.id,
})
).pipe(
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(
map(data => ({ type: 'message' as const, id: messageId, data }))
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(values => {
session.push({
role: 'assistant',
content: values.join(''),
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
),
catchError(err =>
of({
type: 'error' as const,
data: this.handleError(err),
})
)
);
} catch (err) {
return of({
type: 'error' as const,
data: this.handleError(err),
});
}
const model = await this.chatSession.get(sessionId).then(s => s?.model);
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;
return from(
provider.generateTextStream(session.finish(params), session.model, {
signal: this.getSignal(req),
user: user.id,
})
).pipe(
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(
map(data => ({ type: 'message' as const, id: sessionId, data }))
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(values => {
session.push({
role: 'assistant',
content: values.join(''),
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
),
catchError(err =>
of({
type: 'error' as const,
data: this.handleError(err),
})
)
);
}
@Sse('/chat/:sessionId/images')
@@ -220,75 +230,76 @@ export class CopilotController {
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
try {
await this.chatSession.checkQuota(user.id);
const { model, hasAttachment } = await this.checkRequest(
user.id,
sessionId,
messageId
);
const provider = this.provider.getProviderByCapability(
hasAttachment
? CopilotCapability.ImageToImage
: CopilotCapability.TextToImage,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;
const handleRemoteLink = this.storage.handleRemoteLink.bind(
this.storage,
user.id,
sessionId
);
return from(
provider.generateImagesStream(session.finish(params), session.model, {
signal: this.getSignal(req),
user: user.id,
})
).pipe(
mergeMap(handleRemoteLink),
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(
map(attachment => ({
type: 'attachment' as const,
id: messageId,
data: attachment,
}))
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(attachments => {
session.push({
role: 'assistant',
content: '',
attachments: attachments,
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
),
catchError(err =>
of({
type: 'error' as const,
data: this.handleError(err),
})
)
);
} catch (err) {
return of({
type: 'error' as const,
data: this.handleError(err),
});
}
const hasAttachment = await this.hasAttachment(sessionId, messageId);
const model = await this.chatSession.get(sessionId).then(s => s?.model);
const provider = this.provider.getProviderByCapability(
hasAttachment
? CopilotCapability.ImageToImage
: CopilotCapability.TextToImage,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;
const handleRemoteLink = this.storage.handleRemoteLink.bind(
this.storage,
user.id,
sessionId
);
return from(
provider.generateImagesStream(session.finish(params), session.model, {
signal: this.getSignal(req),
user: user.id,
})
).pipe(
mergeMap(handleRemoteLink),
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(
map(attachment => ({
type: 'attachment' as const,
id: sessionId,
data: attachment,
}))
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(attachments => {
session.push({
role: 'assistant',
content: '',
attachments: attachments,
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
),
catchError(err =>
of({
type: 'error' as const,
data: this.handleError(err),
})
)
);
}
@Get('/unsplash/photos')

View File

@@ -193,11 +193,12 @@ export class PromptService {
return null;
}
async set(name: string, messages: PromptMessage[]) {
async set(name: string, model: string, messages: PromptMessage[]) {
return await this.db.aiPrompt
.create({
data: {
name,
model,
messages: {
create: messages.map((m, idx) => ({
idx,

View File

@@ -41,6 +41,10 @@ export class FalProvider
return !!config.apiKey;
}
get type(): CopilotProviderType {
return FalProvider.type;
}
getCapabilities(): CopilotCapability[] {
return FalProvider.capabilities;
}

View File

@@ -13,7 +13,7 @@ import {
PromptMessage,
} from '../types';
const DEFAULT_DIMENSIONS = 256;
export const DEFAULT_DIMENSIONS = 256;
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;
@@ -59,6 +59,10 @@ export class OpenAIProvider
return !!config.apiKey;
}
get type(): CopilotProviderType {
return OpenAIProvider.type;
}
getCapabilities(): CopilotCapability[] {
return OpenAIProvider.capabilities;
}
@@ -67,7 +71,7 @@ export class OpenAIProvider
return this.availableModels.includes(model);
}
private chatToGPTMessage(
protected chatToGPTMessage(
messages: PromptMessage[]
): OpenAI.Chat.Completions.ChatCompletionMessageParam[] {
// filter redundant fields
@@ -92,7 +96,7 @@ export class OpenAIProvider
});
}
private checkParams({
protected checkParams({
messages,
embeddings,
model,

View File

@@ -278,7 +278,9 @@ export class CopilotResolver {
return new TooManyRequestsException('Server is busy');
}
const session = await this.chatSession.get(options.sessionId);
if (!session) return new BadRequestException('Session not found');
if (!session || session.config.userId !== user.id) {
return new BadRequestException('Session not found');
}
if (options.blobs) {
options.attachments = options.attachments || [];

View File

@@ -81,7 +81,7 @@ export class ChatSession implements AsyncDisposable {
}
pop() {
this.state.messages.pop();
return this.state.messages.pop();
}
private takeMessages(): ChatMessage[] {
@@ -115,7 +115,7 @@ export class ChatSession implements AsyncDisposable {
Object.keys(params).length ? params : messages[0]?.params || {},
this.config.sessionId
),
...messages.filter(m => m.content || m.attachments?.length),
...messages.filter(m => m.content?.trim() || m.attachments?.length),
];
}

View File

@@ -15,6 +15,7 @@ export interface CopilotConfig {
openai: OpenAIClientOptions;
fal: FalConfig;
unsplashKey: string;
test: never;
}
export enum AvailableModels {
@@ -130,6 +131,8 @@ export type ListHistoriesOptions = {
export enum CopilotProviderType {
FAL = 'fal',
OpenAI = 'openai',
// only for test
Test = 'test',
}
export enum CopilotCapability {
@@ -141,6 +144,7 @@ export enum CopilotCapability {
}
export interface CopilotProvider {
readonly type: CopilotProviderType;
getCapabilities(): CopilotCapability[];
isModelAvailable(model: string): boolean;
}

View File

@@ -0,0 +1,382 @@
/// <reference types="../src/global.d.ts" />
import { randomUUID } from 'node:crypto';
import { INestApplication } from '@nestjs/common';
import type { TestFn } from 'ava';
import ava from 'ava';
import Sinon from 'sinon';
import { AuthService } from '../src/core/auth';
import { WorkspaceModule } from '../src/core/workspaces';
import { ConfigModule } from '../src/fundamentals/config';
import { CopilotModule } from '../src/plugins/copilot';
import { PromptService } from '../src/plugins/copilot/prompt';
import {
CopilotProviderService,
registerCopilotProvider,
} from '../src/plugins/copilot/providers';
import { CopilotStorage } from '../src/plugins/copilot/storage';
import {
acceptInviteById,
createTestingApp,
createWorkspace,
inviteUser,
signUp,
} from './utils';
import {
chatWithImages,
chatWithText,
chatWithTextStream,
createCopilotMessage,
createCopilotSession,
getHistories,
MockCopilotTestProvider,
textToEventStream,
} from './utils/copilot';
const test = ava as TestFn<{
auth: AuthService;
app: INestApplication;
prompt: PromptService;
provider: CopilotProviderService;
storage: CopilotStorage;
}>;
test.beforeEach(async t => {
const { app } = await createTestingApp({
imports: [
ConfigModule.forRoot({
plugins: {
copilot: {
openai: {
apiKey: '1',
},
fal: {
apiKey: '1',
},
},
},
}),
WorkspaceModule,
CopilotModule,
],
});
const auth = app.get(AuthService);
const prompt = app.get(PromptService);
const storage = app.get(CopilotStorage);
t.context.app = app;
t.context.auth = auth;
t.context.prompt = prompt;
t.context.storage = storage;
});
let token: string;
const promptName = 'prompt';
test.beforeEach(async t => {
const { app, prompt } = t.context;
const user = await signUp(app, 'test', 'darksky@affine.pro', '123456');
token = user.token.token;
registerCopilotProvider(MockCopilotTestProvider);
await prompt.set(promptName, 'test', [
{ role: 'system', content: 'hello {{word}}' },
]);
});
test.afterEach.always(async t => {
await t.context.app.close();
});
// ==================== session ====================
test('should create session correctly', async t => {
const { app } = t.context;
const assertCreateSession = async (
workspaceId: string,
error: string,
asserter = async (x: any) => {
t.truthy(await x, error);
}
) => {
await asserter(
createCopilotSession(app, token, workspaceId, randomUUID(), promptName)
);
};
{
const { id } = await createWorkspace(app, token);
await assertCreateSession(
id,
'should be able to create session with cloud workspace that user can access'
);
}
{
await assertCreateSession(
randomUUID(),
'should be able to create session with local workspace'
);
}
{
const {
token: { token },
} = await signUp(app, 'test', 'test@affine.pro', '123456');
const { id } = await createWorkspace(app, token);
await assertCreateSession(id, '', async x => {
await t.throwsAsync(
x,
{ instanceOf: Error },
'should not able to create session with cloud workspace that user cannot access'
);
});
const inviteId = await inviteUser(
app,
token,
id,
'darksky@affine.pro',
'Admin'
);
await acceptInviteById(app, id, inviteId, false);
await assertCreateSession(
id,
'should able to create session after user have permission'
);
}
});
test('should be able to use test provider', async t => {
const { app } = t.context;
const { id } = await createWorkspace(app, token);
t.truthy(
await createCopilotSession(app, token, id, randomUUID(), promptName),
'failed to create session'
);
});
// ==================== message ====================
test('should create message correctly', async t => {
const { app } = t.context;
{
const { id } = await createWorkspace(app, token);
const sessionId = await createCopilotSession(
app,
token,
id,
randomUUID(),
promptName
);
const messageId = await createCopilotMessage(app, token, sessionId);
t.truthy(messageId, 'should be able to create message with valid session');
}
{
await t.throwsAsync(
createCopilotMessage(app, token, randomUUID()),
{ instanceOf: Error },
'should not able to create message with invalid session'
);
}
});
// ==================== chat ====================
test('should be able to chat with api', async t => {
const { app, storage } = t.context;
Sinon.stub(storage, 'handleRemoteLink').resolvesArg(2);
const { id } = await createWorkspace(app, token);
const sessionId = await createCopilotSession(
app,
token,
id,
randomUUID(),
promptName
);
const messageId = await createCopilotMessage(app, token, sessionId);
const ret = await chatWithText(app, token, sessionId, messageId);
t.is(ret, 'generate text to text', 'should be able to chat with text');
const ret2 = await chatWithTextStream(app, token, sessionId, messageId);
t.is(
ret2,
textToEventStream('generate text to text stream', messageId),
'should be able to chat with text stream'
);
const ret3 = await chatWithImages(app, token, sessionId, messageId);
t.is(
ret3,
textToEventStream(
['https://example.com/image.jpg'],
messageId,
'attachment'
),
'should be able to chat with images'
);
Sinon.restore();
});
test('should reject message from different session', async t => {
const { app } = t.context;
const { id } = await createWorkspace(app, token);
const sessionId = await createCopilotSession(
app,
token,
id,
randomUUID(),
promptName
);
const anotherSessionId = await createCopilotSession(
app,
token,
id,
randomUUID(),
promptName
);
const anotherMessageId = await createCopilotMessage(
app,
token,
anotherSessionId
);
await t.throwsAsync(
chatWithText(app, token, sessionId, anotherMessageId),
{ instanceOf: Error },
'should reject message from different session'
);
});
test('should reject request from different user', async t => {
const { app } = t.context;
const { id } = await createWorkspace(app, token);
const sessionId = await createCopilotSession(
app,
token,
id,
randomUUID(),
promptName
);
// should reject message from different user
{
const { token } = await signUp(app, 'a1', 'a1@affine.pro', '123456');
await t.throwsAsync(
createCopilotMessage(app, token.token, sessionId),
{ instanceOf: Error },
'should reject message from different user'
);
}
// should reject chat from different user
{
const messageId = await createCopilotMessage(app, token, sessionId);
{
const { token } = await signUp(app, 'a2', 'a2@affine.pro', '123456');
await t.throwsAsync(
chatWithText(app, token.token, sessionId, messageId),
{ instanceOf: Error },
'should reject chat from different user'
);
}
}
});
// ==================== history ====================
test('should be able to list history', async t => {
const { app } = t.context;
const { id: workspaceId } = await createWorkspace(app, token);
const sessionId = await createCopilotSession(
app,
token,
workspaceId,
randomUUID(),
promptName
);
const messageId = await createCopilotMessage(app, token, sessionId);
await chatWithText(app, token, sessionId, messageId);
const histories = await getHistories(app, token, { workspaceId });
t.deepEqual(
histories.map(h => h.messages.map(m => m.content)),
[['generate text to text']],
'should be able to list history'
);
});
test('should reject request that user have not permission', async t => {
const { app } = t.context;
const {
token: { token: anotherToken },
} = await signUp(app, 'a1', 'a1@affine.pro', '123456');
const { id: workspaceId } = await createWorkspace(app, anotherToken);
// should reject request that user have not permission
{
await t.throwsAsync(
getHistories(app, token, { workspaceId }),
{ instanceOf: Error },
'should reject request that user have not permission'
);
}
// should able to list history after user have permission
{
const inviteId = await inviteUser(
app,
anotherToken,
workspaceId,
'darksky@affine.pro',
'Admin'
);
await acceptInviteById(app, workspaceId, inviteId, false);
t.deepEqual(
await getHistories(app, token, { workspaceId }),
[],
'should able to list history after user have permission'
);
}
{
const sessionId = await createCopilotSession(
app,
anotherToken,
workspaceId,
randomUUID(),
promptName
);
const messageId = await createCopilotMessage(app, anotherToken, sessionId);
await chatWithText(app, anotherToken, sessionId, messageId);
const histories = await getHistories(app, anotherToken, { workspaceId });
t.deepEqual(
histories.map(h => h.messages.map(m => m.content)),
[['generate text to text']],
'should able to list history'
);
t.deepEqual(
await getHistories(app, token, { workspaceId }),
[],
'should not list history created by another user'
);
}
});

View File

@@ -5,17 +5,28 @@ import type { TestFn } from 'ava';
import ava from 'ava';
import { AuthService } from '../src/core/auth';
import { QuotaManagementService, QuotaModule } from '../src/core/quota';
import { QuotaModule } from '../src/core/quota';
import { ConfigModule } from '../src/fundamentals/config';
import { CopilotModule } from '../src/plugins/copilot';
import { PromptService } from '../src/plugins/copilot/prompt';
import {
CopilotProviderService,
registerCopilotProvider,
} from '../src/plugins/copilot/providers';
import { ChatSessionService } from '../src/plugins/copilot/session';
import {
CopilotCapability,
CopilotProviderType,
} from '../src/plugins/copilot/types';
import { createTestingModule } from './utils';
import { MockCopilotTestProvider } from './utils/copilot';
const test = ava as TestFn<{
auth: AuthService;
quotaManager: QuotaManagementService;
module: TestingModule;
prompt: PromptService;
provider: CopilotProviderService;
session: ChatSessionService;
}>;
test.beforeEach(async t => {
@@ -27,6 +38,9 @@ test.beforeEach(async t => {
openai: {
apiKey: '1',
},
fal: {
apiKey: '1',
},
},
},
}),
@@ -35,26 +49,37 @@ test.beforeEach(async t => {
],
});
const quotaManager = module.get(QuotaManagementService);
const auth = module.get(AuthService);
const prompt = module.get(PromptService);
const provider = module.get(CopilotProviderService);
const session = module.get(ChatSessionService);
t.context.module = module;
t.context.quotaManager = quotaManager;
t.context.auth = auth;
t.context.prompt = prompt;
t.context.provider = provider;
t.context.session = session;
});
test.afterEach.always(async t => {
await t.context.module.close();
});
let userId: string;
test.beforeEach(async t => {
const { auth } = t.context;
const user = await auth.signUp('test', 'darksky@affine.pro', '123456');
userId = user.id;
});
// ==================== prompt ====================
test('should be able to manage prompt', async t => {
const { prompt } = t.context;
t.is((await prompt.list()).length, 0, 'should have no prompt');
await prompt.set('test', [
await prompt.set('test', 'test', [
{ role: 'system', content: 'hello' },
{ role: 'user', content: 'hello' },
]);
@@ -91,7 +116,7 @@ test('should be able to render prompt', async t => {
content: 'hello world',
};
await prompt.set('test', [msg]);
await prompt.set('test', 'test', [msg]);
const testPrompt = await prompt.get('test');
t.assert(testPrompt, 'should have prompt');
t.is(
@@ -126,7 +151,7 @@ test('should be able to render listed prompt', async t => {
links: ['https://affine.pro', 'https://github.com/toeverything/affine'],
};
await prompt.set('test', [msg]);
await prompt.set('test', 'test', [msg]);
const testPrompt = await prompt.get('test');
t.is(
@@ -135,3 +160,265 @@ test('should be able to render listed prompt', async t => {
'should render the prompt'
);
});
// ==================== session ====================
test('should be able to manage chat session', async t => {
const { prompt, session } = t.context;
await prompt.set('prompt', 'model', [
{ role: 'system', content: 'hello {{word}}' },
]);
const sessionId = await session.create({
docId: 'test',
workspaceId: 'test',
userId,
promptName: 'prompt',
});
t.truthy(sessionId, 'should create session');
const s = (await session.get(sessionId))!;
t.is(s.config.sessionId, sessionId, 'should get session');
t.is(s.config.promptName, 'prompt', 'should have prompt name');
t.is(s.model, 'model', 'should have model');
const params = { word: 'world' };
s.push({ role: 'user', content: 'hello', createdAt: new Date() });
// @ts-expect-error
const finalMessages = s.finish(params).map(({ createdAt: _, ...m }) => m);
t.deepEqual(
finalMessages,
[
{ content: 'hello world', params, role: 'system' },
{ content: 'hello', role: 'user' },
],
'should generate the final message'
);
await s.save();
const s1 = (await session.get(sessionId))!;
t.deepEqual(
// @ts-expect-error
s1.finish(params).map(({ createdAt: _, ...m }) => m),
finalMessages,
'should same as before message'
);
t.deepEqual(
// @ts-expect-error
s1.finish({}).map(({ createdAt: _, ...m }) => m),
[
{ content: 'hello ', params: {}, role: 'system' },
{ content: 'hello', role: 'user' },
],
'should generate different message with another params'
);
});
test('should be able to process message id', async t => {
const { prompt, session } = t.context;
await prompt.set('prompt', 'model', [
{ role: 'system', content: 'hello {{word}}' },
]);
const sessionId = await session.create({
docId: 'test',
workspaceId: 'test',
userId,
promptName: 'prompt',
});
const s = (await session.get(sessionId))!;
const textMessage = (await session.createMessage({
sessionId,
content: 'hello',
}))!;
const anotherSessionMessage = (await session.createMessage({
sessionId: 'another-session-id',
}))!;
await t.notThrowsAsync(
s.pushByMessageId(textMessage),
'should push by message id'
);
await t.throwsAsync(
s.pushByMessageId(anotherSessionMessage),
{
instanceOf: Error,
},
'should throw error if push by another session message id'
);
await t.throwsAsync(
s.pushByMessageId('invalid'),
{ instanceOf: Error },
'should throw error if push by invalid message id'
);
});
test('should be able to generate with message id', async t => {
const { prompt, session } = t.context;
await prompt.set('prompt', 'model', [
{ role: 'system', content: 'hello {{word}}' },
]);
// text message
{
const sessionId = await session.create({
docId: 'test',
workspaceId: 'test',
userId,
promptName: 'prompt',
});
const s = (await session.get(sessionId))!;
const message = (await session.createMessage({
sessionId,
content: 'hello',
}))!;
await s.pushByMessageId(message);
const finalMessages = s
.finish({ word: 'world' })
.map(({ content }) => content);
t.deepEqual(finalMessages, ['hello world', 'hello']);
}
// attachment message
{
const sessionId = await session.create({
docId: 'test',
workspaceId: 'test',
userId,
promptName: 'prompt',
});
const s = (await session.get(sessionId))!;
const message = (await session.createMessage({
sessionId,
attachments: ['https://affine.pro/example.jpg'],
}))!;
await s.pushByMessageId(message);
const finalMessages = s
.finish({ word: 'world' })
.map(({ attachments }) => attachments);
t.deepEqual(finalMessages, [
// system prompt
undefined,
// user prompt
['https://affine.pro/example.jpg'],
]);
}
// empty message
{
const sessionId = await session.create({
docId: 'test',
workspaceId: 'test',
userId,
promptName: 'prompt',
});
const s = (await session.get(sessionId))!;
const message = (await session.createMessage({
sessionId,
}))!;
await s.pushByMessageId(message);
const finalMessages = s
.finish({ word: 'world' })
.map(({ content }) => content);
// empty message should be filtered
t.deepEqual(finalMessages, ['hello world']);
}
});
// ==================== provider ====================
test('should be able to get provider', async t => {
const { provider } = t.context;
{
const p = provider.getProviderByCapability(CopilotCapability.TextToText);
t.is(
p?.type.toString(),
'openai',
'should get provider support text-to-text'
);
}
{
const p = provider.getProviderByCapability(
CopilotCapability.TextToEmbedding
);
t.is(
p?.type.toString(),
'openai',
'should get provider support text-to-embedding'
);
}
{
const p = provider.getProviderByCapability(CopilotCapability.TextToImage);
t.is(
p?.type.toString(),
'fal',
'should get provider support text-to-image'
);
}
{
const p = provider.getProviderByCapability(CopilotCapability.ImageToImage);
t.is(
p?.type.toString(),
'fal',
'should get provider support image-to-image'
);
}
{
const p = provider.getProviderByCapability(CopilotCapability.ImageToText);
t.is(
p?.type.toString(),
'openai',
'should get provider support image-to-text'
);
}
// text-to-image use fal by default, but this case can use
// model dall-e-3 to select openai provider
{
const p = provider.getProviderByCapability(
CopilotCapability.TextToImage,
'dall-e-3'
);
t.is(
p?.type.toString(),
'openai',
'should get provider support text-to-image and model'
);
}
});
test('should be able to register test provider', async t => {
const { provider } = t.context;
registerCopilotProvider(MockCopilotTestProvider);
const assertProvider = (cap: CopilotCapability) => {
const p = provider.getProviderByCapability(cap, 'test');
t.is(
p?.type,
CopilotProviderType.Test,
`should get test provider with ${cap}`
);
};
assertProvider(CopilotCapability.TextToText);
assertProvider(CopilotCapability.TextToEmbedding);
assertProvider(CopilotCapability.TextToImage);
assertProvider(CopilotCapability.ImageToImage);
assertProvider(CopilotCapability.ImageToText);
});

View File

@@ -0,0 +1,305 @@
import { randomBytes } from 'node:crypto';
import { INestApplication } from '@nestjs/common';
import request from 'supertest';
import {
DEFAULT_DIMENSIONS,
OpenAIProvider,
} from '../../src/plugins/copilot/providers/openai';
import {
CopilotCapability,
CopilotImageToImageProvider,
CopilotImageToTextProvider,
CopilotProviderType,
CopilotTextToEmbeddingProvider,
CopilotTextToImageProvider,
CopilotTextToTextProvider,
PromptMessage,
} from '../../src/plugins/copilot/types';
import { gql } from './common';
import { handleGraphQLError } from './utils';
export class MockCopilotTestProvider
extends OpenAIProvider
implements
CopilotTextToTextProvider,
CopilotTextToEmbeddingProvider,
CopilotTextToImageProvider,
CopilotImageToImageProvider,
CopilotImageToTextProvider
{
override readonly availableModels = ['test'];
static override readonly capabilities = [
CopilotCapability.TextToText,
CopilotCapability.TextToEmbedding,
CopilotCapability.TextToImage,
CopilotCapability.ImageToImage,
CopilotCapability.ImageToText,
];
override get type(): CopilotProviderType {
return CopilotProviderType.Test;
}
override getCapabilities(): CopilotCapability[] {
return MockCopilotTestProvider.capabilities;
}
override isModelAvailable(model: string): boolean {
return this.availableModels.includes(model);
}
// ====== text to text ======
override async generateText(
messages: PromptMessage[],
model: string = 'test',
_options: {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
user?: string;
} = {}
): Promise<string> {
this.checkParams({ messages, model });
return 'generate text to text';
}
override async *generateTextStream(
messages: PromptMessage[],
model: string = 'gpt-3.5-turbo',
options: {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
user?: string;
} = {}
): AsyncIterable<string> {
this.checkParams({ messages, model });
const result = 'generate text to text stream';
for await (const message of result) {
yield message;
if (options.signal?.aborted) {
break;
}
}
}
// ====== text to embedding ======
override async generateEmbedding(
messages: string | string[],
model: string,
options: {
dimensions: number;
signal?: AbortSignal;
user?: string;
} = { dimensions: DEFAULT_DIMENSIONS }
): Promise<number[][]> {
messages = Array.isArray(messages) ? messages : [messages];
this.checkParams({ embeddings: messages, model });
return [Array.from(randomBytes(options.dimensions)).map(v => v % 128)];
}
// ====== text to image ======
override async generateImages(
messages: PromptMessage[],
_model: string = 'test',
_options: {
signal?: AbortSignal;
user?: string;
} = {}
): Promise<Array<string>> {
const { content: prompt } = messages.pop() || {};
if (!prompt) {
throw new Error('Prompt is required');
}
return ['https://example.com/image.jpg'];
}
override async *generateImagesStream(
messages: PromptMessage[],
model: string = 'dall-e-3',
options: {
signal?: AbortSignal;
user?: string;
} = {}
): AsyncIterable<string> {
const ret = await this.generateImages(messages, model, options);
for (const url of ret) {
yield url;
}
}
}
export async function createCopilotSession(
app: INestApplication,
userToken: string,
workspaceId: string,
docId: string,
promptName: string
): Promise<string> {
const res = await request(app.getHttpServer())
.post(gql)
.auth(userToken, { type: 'bearer' })
.set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
.send({
query: `
mutation createCopilotSession($options: CreateChatSessionInput!) {
createCopilotSession(options: $options)
}
`,
variables: { options: { workspaceId, docId, promptName } },
})
.expect(200);
handleGraphQLError(res);
return res.body.data.createCopilotSession;
}
export async function createCopilotMessage(
app: INestApplication,
userToken: string,
sessionId: string,
content?: string,
attachments?: string[],
blobs?: ArrayBuffer[],
params?: Record<string, string>
): Promise<string> {
const res = await request(app.getHttpServer())
.post(gql)
.auth(userToken, { type: 'bearer' })
.set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
.send({
query: `
mutation createCopilotMessage($options: CreateChatMessageInput!) {
createCopilotMessage(options: $options)
}
`,
variables: {
options: { sessionId, content, attachments, blobs, params },
},
})
.expect(200);
handleGraphQLError(res);
return res.body.data.createCopilotMessage;
}
export async function chatWithText(
app: INestApplication,
userToken: string,
sessionId: string,
messageId: string,
prefix = ''
): Promise<string> {
const res = await request(app.getHttpServer())
.get(`/api/copilot/chat/${sessionId}${prefix}?messageId=${messageId}`)
.auth(userToken, { type: 'bearer' })
.expect(200);
return res.text;
}
export async function chatWithTextStream(
app: INestApplication,
userToken: string,
sessionId: string,
messageId: string
) {
return chatWithText(app, userToken, sessionId, messageId, '/stream');
}
export async function chatWithImages(
app: INestApplication,
userToken: string,
sessionId: string,
messageId: string
) {
return chatWithText(app, userToken, sessionId, messageId, '/images');
}
export function textToEventStream(
content: string | string[],
id: string,
event = 'message'
): string {
return (
Array.from(content)
.map(x => `\nevent: ${event}\nid: ${id}\ndata: ${x}`)
.join('\n') + '\n\n'
);
}
type ChatMessage = {
role: string;
content: string;
attachments: string[] | null;
createdAt: string;
};
type History = {
sessionId: string;
tokens: number;
action: string | null;
createdAt: string;
messages: ChatMessage[];
};
export async function getHistories(
app: INestApplication,
userToken: string,
variables: {
workspaceId: string;
docId?: string;
options?: {
sessionId?: string;
action?: boolean;
limit?: number;
skip?: number;
};
}
): Promise<History[]> {
const res = await request(app.getHttpServer())
.post(gql)
.auth(userToken, { type: 'bearer' })
.set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
.send({
query: `
query getCopilotHistories(
$workspaceId: String!
$docId: String
$options: QueryChatHistoriesInput
) {
currentUser {
copilot(workspaceId: $workspaceId) {
histories(docId: $docId, options: $options) {
sessionId
tokens
action
createdAt
messages {
role
content
attachments
createdAt
}
}
}
}
}
`,
variables,
})
.expect(200);
handleGraphQLError(res);
return res.body.data.currentUser?.copilot?.histories || [];
}

View File

@@ -5,6 +5,7 @@ import { Test, TestingModuleBuilder } from '@nestjs/testing';
import { PrismaClient } from '@prisma/client';
import cookieParser from 'cookie-parser';
import graphqlUploadExpress from 'graphql-upload/graphqlUploadExpress.mjs';
import type { Response } from 'supertest';
import { AppModule, FunctionalityModules } from '../../src/app.module';
import { AuthGuard, AuthModule } from '../../src/core/auth';
@@ -136,3 +137,12 @@ export async function createTestingApp(moduleDef: TestingModuleMeatdata = {}) {
app,
};
}
export function handleGraphQLError(resp: Response) {
const { errors } = resp.body;
if (errors) {
const cause = errors[0];
const stacktrace = cause.extensions?.stacktrace;
throw new Error(stacktrace ? stacktrace.join('\n') : cause.message, cause);
}
}