mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-11 20:08:37 +00:00
@@ -45,6 +45,7 @@ if (env.R2_OBJECT_STORAGE_ACCOUNT_ID) {
|
||||
|
||||
AFFiNE.plugins.use('copilot', {
|
||||
openai: {},
|
||||
fal: {},
|
||||
});
|
||||
AFFiNE.plugins.use('redis');
|
||||
AFFiNE.plugins.use('payment', {
|
||||
|
||||
@@ -42,6 +42,11 @@ export interface ChatEvent {
|
||||
data: string;
|
||||
}
|
||||
|
||||
type CheckResult = {
|
||||
model: string | undefined;
|
||||
hasAttachment?: boolean;
|
||||
};
|
||||
|
||||
@Controller('/api/copilot')
|
||||
export class CopilotController {
|
||||
private readonly logger = new Logger(CopilotController.name);
|
||||
@@ -53,17 +58,26 @@ export class CopilotController {
|
||||
private readonly storage: CopilotStorage
|
||||
) {}
|
||||
|
||||
private async hasAttachment(sessionId: string, messageId: string) {
|
||||
private async checkRequest(
|
||||
userId: string,
|
||||
sessionId: string,
|
||||
messageId?: string
|
||||
): Promise<CheckResult> {
|
||||
await this.chatSession.checkQuota(userId);
|
||||
const session = await this.chatSession.get(sessionId);
|
||||
if (!session) {
|
||||
if (!session || session.config.userId !== userId) {
|
||||
throw new BadRequestException('Session not found');
|
||||
}
|
||||
|
||||
const message = await session.getMessageById(messageId);
|
||||
if (Array.isArray(message.attachments) && message.attachments.length) {
|
||||
return true;
|
||||
const ret: CheckResult = { model: session.model };
|
||||
|
||||
if (messageId) {
|
||||
const message = await session.getMessageById(messageId);
|
||||
ret.hasAttachment =
|
||||
Array.isArray(message.attachments) && !!message.attachments.length;
|
||||
}
|
||||
return false;
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
private async appendSessionMessage(
|
||||
@@ -107,9 +121,7 @@ export class CopilotController {
|
||||
@Query('messageId') messageId: string,
|
||||
@Query() params: Record<string, string | string[]>
|
||||
): Promise<string> {
|
||||
await this.chatSession.checkQuota(user.id);
|
||||
|
||||
const model = await this.chatSession.get(sessionId).then(s => s?.model);
|
||||
const { model } = await this.checkRequest(user.id, sessionId);
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
CopilotCapability.TextToText,
|
||||
model
|
||||
@@ -155,60 +167,58 @@ export class CopilotController {
|
||||
@Query() params: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
try {
|
||||
await this.chatSession.checkQuota(user.id);
|
||||
const { model } = await this.checkRequest(user.id, sessionId);
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
CopilotCapability.TextToText,
|
||||
model
|
||||
);
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
}
|
||||
|
||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||
delete params.messageId;
|
||||
|
||||
return from(
|
||||
provider.generateTextStream(session.finish(params), session.model, {
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
})
|
||||
).pipe(
|
||||
connect(shared$ =>
|
||||
merge(
|
||||
// actual chat event stream
|
||||
shared$.pipe(
|
||||
map(data => ({ type: 'message' as const, id: messageId, data }))
|
||||
),
|
||||
// save the generated text to the session
|
||||
shared$.pipe(
|
||||
toArray(),
|
||||
concatMap(values => {
|
||||
session.push({
|
||||
role: 'assistant',
|
||||
content: values.join(''),
|
||||
createdAt: new Date(),
|
||||
});
|
||||
return from(session.save());
|
||||
}),
|
||||
switchMap(() => EMPTY)
|
||||
)
|
||||
)
|
||||
),
|
||||
catchError(err =>
|
||||
of({
|
||||
type: 'error' as const,
|
||||
data: this.handleError(err),
|
||||
})
|
||||
)
|
||||
);
|
||||
} catch (err) {
|
||||
return of({
|
||||
type: 'error' as const,
|
||||
data: this.handleError(err),
|
||||
});
|
||||
}
|
||||
|
||||
const model = await this.chatSession.get(sessionId).then(s => s?.model);
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
CopilotCapability.TextToText,
|
||||
model
|
||||
);
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
}
|
||||
|
||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||
delete params.messageId;
|
||||
|
||||
return from(
|
||||
provider.generateTextStream(session.finish(params), session.model, {
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
})
|
||||
).pipe(
|
||||
connect(shared$ =>
|
||||
merge(
|
||||
// actual chat event stream
|
||||
shared$.pipe(
|
||||
map(data => ({ type: 'message' as const, id: sessionId, data }))
|
||||
),
|
||||
// save the generated text to the session
|
||||
shared$.pipe(
|
||||
toArray(),
|
||||
concatMap(values => {
|
||||
session.push({
|
||||
role: 'assistant',
|
||||
content: values.join(''),
|
||||
createdAt: new Date(),
|
||||
});
|
||||
return from(session.save());
|
||||
}),
|
||||
switchMap(() => EMPTY)
|
||||
)
|
||||
)
|
||||
),
|
||||
catchError(err =>
|
||||
of({
|
||||
type: 'error' as const,
|
||||
data: this.handleError(err),
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Sse('/chat/:sessionId/images')
|
||||
@@ -220,75 +230,76 @@ export class CopilotController {
|
||||
@Query() params: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
try {
|
||||
await this.chatSession.checkQuota(user.id);
|
||||
const { model, hasAttachment } = await this.checkRequest(
|
||||
user.id,
|
||||
sessionId,
|
||||
messageId
|
||||
);
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
hasAttachment
|
||||
? CopilotCapability.ImageToImage
|
||||
: CopilotCapability.TextToImage,
|
||||
model
|
||||
);
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
}
|
||||
|
||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||
delete params.messageId;
|
||||
|
||||
const handleRemoteLink = this.storage.handleRemoteLink.bind(
|
||||
this.storage,
|
||||
user.id,
|
||||
sessionId
|
||||
);
|
||||
|
||||
return from(
|
||||
provider.generateImagesStream(session.finish(params), session.model, {
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
})
|
||||
).pipe(
|
||||
mergeMap(handleRemoteLink),
|
||||
connect(shared$ =>
|
||||
merge(
|
||||
// actual chat event stream
|
||||
shared$.pipe(
|
||||
map(attachment => ({
|
||||
type: 'attachment' as const,
|
||||
id: messageId,
|
||||
data: attachment,
|
||||
}))
|
||||
),
|
||||
// save the generated text to the session
|
||||
shared$.pipe(
|
||||
toArray(),
|
||||
concatMap(attachments => {
|
||||
session.push({
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
attachments: attachments,
|
||||
createdAt: new Date(),
|
||||
});
|
||||
return from(session.save());
|
||||
}),
|
||||
switchMap(() => EMPTY)
|
||||
)
|
||||
)
|
||||
),
|
||||
catchError(err =>
|
||||
of({
|
||||
type: 'error' as const,
|
||||
data: this.handleError(err),
|
||||
})
|
||||
)
|
||||
);
|
||||
} catch (err) {
|
||||
return of({
|
||||
type: 'error' as const,
|
||||
data: this.handleError(err),
|
||||
});
|
||||
}
|
||||
|
||||
const hasAttachment = await this.hasAttachment(sessionId, messageId);
|
||||
const model = await this.chatSession.get(sessionId).then(s => s?.model);
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
hasAttachment
|
||||
? CopilotCapability.ImageToImage
|
||||
: CopilotCapability.TextToImage,
|
||||
model
|
||||
);
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
}
|
||||
|
||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||
delete params.messageId;
|
||||
|
||||
const handleRemoteLink = this.storage.handleRemoteLink.bind(
|
||||
this.storage,
|
||||
user.id,
|
||||
sessionId
|
||||
);
|
||||
|
||||
return from(
|
||||
provider.generateImagesStream(session.finish(params), session.model, {
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
})
|
||||
).pipe(
|
||||
mergeMap(handleRemoteLink),
|
||||
connect(shared$ =>
|
||||
merge(
|
||||
// actual chat event stream
|
||||
shared$.pipe(
|
||||
map(attachment => ({
|
||||
type: 'attachment' as const,
|
||||
id: sessionId,
|
||||
data: attachment,
|
||||
}))
|
||||
),
|
||||
// save the generated text to the session
|
||||
shared$.pipe(
|
||||
toArray(),
|
||||
concatMap(attachments => {
|
||||
session.push({
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
attachments: attachments,
|
||||
createdAt: new Date(),
|
||||
});
|
||||
return from(session.save());
|
||||
}),
|
||||
switchMap(() => EMPTY)
|
||||
)
|
||||
)
|
||||
),
|
||||
catchError(err =>
|
||||
of({
|
||||
type: 'error' as const,
|
||||
data: this.handleError(err),
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Get('/unsplash/photos')
|
||||
|
||||
@@ -193,11 +193,12 @@ export class PromptService {
|
||||
return null;
|
||||
}
|
||||
|
||||
async set(name: string, messages: PromptMessage[]) {
|
||||
async set(name: string, model: string, messages: PromptMessage[]) {
|
||||
return await this.db.aiPrompt
|
||||
.create({
|
||||
data: {
|
||||
name,
|
||||
model,
|
||||
messages: {
|
||||
create: messages.map((m, idx) => ({
|
||||
idx,
|
||||
|
||||
@@ -41,6 +41,10 @@ export class FalProvider
|
||||
return !!config.apiKey;
|
||||
}
|
||||
|
||||
get type(): CopilotProviderType {
|
||||
return FalProvider.type;
|
||||
}
|
||||
|
||||
getCapabilities(): CopilotCapability[] {
|
||||
return FalProvider.capabilities;
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ import {
|
||||
PromptMessage,
|
||||
} from '../types';
|
||||
|
||||
const DEFAULT_DIMENSIONS = 256;
|
||||
export const DEFAULT_DIMENSIONS = 256;
|
||||
|
||||
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;
|
||||
|
||||
@@ -59,6 +59,10 @@ export class OpenAIProvider
|
||||
return !!config.apiKey;
|
||||
}
|
||||
|
||||
get type(): CopilotProviderType {
|
||||
return OpenAIProvider.type;
|
||||
}
|
||||
|
||||
getCapabilities(): CopilotCapability[] {
|
||||
return OpenAIProvider.capabilities;
|
||||
}
|
||||
@@ -67,7 +71,7 @@ export class OpenAIProvider
|
||||
return this.availableModels.includes(model);
|
||||
}
|
||||
|
||||
private chatToGPTMessage(
|
||||
protected chatToGPTMessage(
|
||||
messages: PromptMessage[]
|
||||
): OpenAI.Chat.Completions.ChatCompletionMessageParam[] {
|
||||
// filter redundant fields
|
||||
@@ -92,7 +96,7 @@ export class OpenAIProvider
|
||||
});
|
||||
}
|
||||
|
||||
private checkParams({
|
||||
protected checkParams({
|
||||
messages,
|
||||
embeddings,
|
||||
model,
|
||||
|
||||
@@ -278,7 +278,9 @@ export class CopilotResolver {
|
||||
return new TooManyRequestsException('Server is busy');
|
||||
}
|
||||
const session = await this.chatSession.get(options.sessionId);
|
||||
if (!session) return new BadRequestException('Session not found');
|
||||
if (!session || session.config.userId !== user.id) {
|
||||
return new BadRequestException('Session not found');
|
||||
}
|
||||
|
||||
if (options.blobs) {
|
||||
options.attachments = options.attachments || [];
|
||||
|
||||
@@ -81,7 +81,7 @@ export class ChatSession implements AsyncDisposable {
|
||||
}
|
||||
|
||||
pop() {
|
||||
this.state.messages.pop();
|
||||
return this.state.messages.pop();
|
||||
}
|
||||
|
||||
private takeMessages(): ChatMessage[] {
|
||||
@@ -115,7 +115,7 @@ export class ChatSession implements AsyncDisposable {
|
||||
Object.keys(params).length ? params : messages[0]?.params || {},
|
||||
this.config.sessionId
|
||||
),
|
||||
...messages.filter(m => m.content || m.attachments?.length),
|
||||
...messages.filter(m => m.content?.trim() || m.attachments?.length),
|
||||
];
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ export interface CopilotConfig {
|
||||
openai: OpenAIClientOptions;
|
||||
fal: FalConfig;
|
||||
unsplashKey: string;
|
||||
test: never;
|
||||
}
|
||||
|
||||
export enum AvailableModels {
|
||||
@@ -130,6 +131,8 @@ export type ListHistoriesOptions = {
|
||||
export enum CopilotProviderType {
|
||||
FAL = 'fal',
|
||||
OpenAI = 'openai',
|
||||
// only for test
|
||||
Test = 'test',
|
||||
}
|
||||
|
||||
export enum CopilotCapability {
|
||||
@@ -141,6 +144,7 @@ export enum CopilotCapability {
|
||||
}
|
||||
|
||||
export interface CopilotProvider {
|
||||
readonly type: CopilotProviderType;
|
||||
getCapabilities(): CopilotCapability[];
|
||||
isModelAvailable(model: string): boolean;
|
||||
}
|
||||
|
||||
382
packages/backend/server/tests/copilot.e2e.ts
Normal file
382
packages/backend/server/tests/copilot.e2e.ts
Normal file
@@ -0,0 +1,382 @@
|
||||
/// <reference types="../src/global.d.ts" />
|
||||
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
import { INestApplication } from '@nestjs/common';
|
||||
import type { TestFn } from 'ava';
|
||||
import ava from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
|
||||
import { AuthService } from '../src/core/auth';
|
||||
import { WorkspaceModule } from '../src/core/workspaces';
|
||||
import { ConfigModule } from '../src/fundamentals/config';
|
||||
import { CopilotModule } from '../src/plugins/copilot';
|
||||
import { PromptService } from '../src/plugins/copilot/prompt';
|
||||
import {
|
||||
CopilotProviderService,
|
||||
registerCopilotProvider,
|
||||
} from '../src/plugins/copilot/providers';
|
||||
import { CopilotStorage } from '../src/plugins/copilot/storage';
|
||||
import {
|
||||
acceptInviteById,
|
||||
createTestingApp,
|
||||
createWorkspace,
|
||||
inviteUser,
|
||||
signUp,
|
||||
} from './utils';
|
||||
import {
|
||||
chatWithImages,
|
||||
chatWithText,
|
||||
chatWithTextStream,
|
||||
createCopilotMessage,
|
||||
createCopilotSession,
|
||||
getHistories,
|
||||
MockCopilotTestProvider,
|
||||
textToEventStream,
|
||||
} from './utils/copilot';
|
||||
|
||||
const test = ava as TestFn<{
|
||||
auth: AuthService;
|
||||
app: INestApplication;
|
||||
prompt: PromptService;
|
||||
provider: CopilotProviderService;
|
||||
storage: CopilotStorage;
|
||||
}>;
|
||||
|
||||
test.beforeEach(async t => {
|
||||
const { app } = await createTestingApp({
|
||||
imports: [
|
||||
ConfigModule.forRoot({
|
||||
plugins: {
|
||||
copilot: {
|
||||
openai: {
|
||||
apiKey: '1',
|
||||
},
|
||||
fal: {
|
||||
apiKey: '1',
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
WorkspaceModule,
|
||||
CopilotModule,
|
||||
],
|
||||
});
|
||||
|
||||
const auth = app.get(AuthService);
|
||||
const prompt = app.get(PromptService);
|
||||
const storage = app.get(CopilotStorage);
|
||||
|
||||
t.context.app = app;
|
||||
t.context.auth = auth;
|
||||
t.context.prompt = prompt;
|
||||
t.context.storage = storage;
|
||||
});
|
||||
|
||||
let token: string;
|
||||
const promptName = 'prompt';
|
||||
test.beforeEach(async t => {
|
||||
const { app, prompt } = t.context;
|
||||
const user = await signUp(app, 'test', 'darksky@affine.pro', '123456');
|
||||
token = user.token.token;
|
||||
|
||||
registerCopilotProvider(MockCopilotTestProvider);
|
||||
|
||||
await prompt.set(promptName, 'test', [
|
||||
{ role: 'system', content: 'hello {{word}}' },
|
||||
]);
|
||||
});
|
||||
|
||||
test.afterEach.always(async t => {
|
||||
await t.context.app.close();
|
||||
});
|
||||
|
||||
// ==================== session ====================
|
||||
|
||||
test('should create session correctly', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const assertCreateSession = async (
|
||||
workspaceId: string,
|
||||
error: string,
|
||||
asserter = async (x: any) => {
|
||||
t.truthy(await x, error);
|
||||
}
|
||||
) => {
|
||||
await asserter(
|
||||
createCopilotSession(app, token, workspaceId, randomUUID(), promptName)
|
||||
);
|
||||
};
|
||||
|
||||
{
|
||||
const { id } = await createWorkspace(app, token);
|
||||
await assertCreateSession(
|
||||
id,
|
||||
'should be able to create session with cloud workspace that user can access'
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
await assertCreateSession(
|
||||
randomUUID(),
|
||||
'should be able to create session with local workspace'
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
const {
|
||||
token: { token },
|
||||
} = await signUp(app, 'test', 'test@affine.pro', '123456');
|
||||
const { id } = await createWorkspace(app, token);
|
||||
await assertCreateSession(id, '', async x => {
|
||||
await t.throwsAsync(
|
||||
x,
|
||||
{ instanceOf: Error },
|
||||
'should not able to create session with cloud workspace that user cannot access'
|
||||
);
|
||||
});
|
||||
|
||||
const inviteId = await inviteUser(
|
||||
app,
|
||||
token,
|
||||
id,
|
||||
'darksky@affine.pro',
|
||||
'Admin'
|
||||
);
|
||||
await acceptInviteById(app, id, inviteId, false);
|
||||
await assertCreateSession(
|
||||
id,
|
||||
'should able to create session after user have permission'
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
test('should be able to use test provider', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const { id } = await createWorkspace(app, token);
|
||||
t.truthy(
|
||||
await createCopilotSession(app, token, id, randomUUID(), promptName),
|
||||
'failed to create session'
|
||||
);
|
||||
});
|
||||
|
||||
// ==================== message ====================
|
||||
|
||||
test('should create message correctly', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
{
|
||||
const { id } = await createWorkspace(app, token);
|
||||
const sessionId = await createCopilotSession(
|
||||
app,
|
||||
token,
|
||||
id,
|
||||
randomUUID(),
|
||||
promptName
|
||||
);
|
||||
const messageId = await createCopilotMessage(app, token, sessionId);
|
||||
t.truthy(messageId, 'should be able to create message with valid session');
|
||||
}
|
||||
|
||||
{
|
||||
await t.throwsAsync(
|
||||
createCopilotMessage(app, token, randomUUID()),
|
||||
{ instanceOf: Error },
|
||||
'should not able to create message with invalid session'
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
// ==================== chat ====================
|
||||
|
||||
test('should be able to chat with api', async t => {
|
||||
const { app, storage } = t.context;
|
||||
|
||||
Sinon.stub(storage, 'handleRemoteLink').resolvesArg(2);
|
||||
|
||||
const { id } = await createWorkspace(app, token);
|
||||
const sessionId = await createCopilotSession(
|
||||
app,
|
||||
token,
|
||||
id,
|
||||
randomUUID(),
|
||||
promptName
|
||||
);
|
||||
const messageId = await createCopilotMessage(app, token, sessionId);
|
||||
const ret = await chatWithText(app, token, sessionId, messageId);
|
||||
t.is(ret, 'generate text to text', 'should be able to chat with text');
|
||||
|
||||
const ret2 = await chatWithTextStream(app, token, sessionId, messageId);
|
||||
t.is(
|
||||
ret2,
|
||||
textToEventStream('generate text to text stream', messageId),
|
||||
'should be able to chat with text stream'
|
||||
);
|
||||
|
||||
const ret3 = await chatWithImages(app, token, sessionId, messageId);
|
||||
t.is(
|
||||
ret3,
|
||||
textToEventStream(
|
||||
['https://example.com/image.jpg'],
|
||||
messageId,
|
||||
'attachment'
|
||||
),
|
||||
'should be able to chat with images'
|
||||
);
|
||||
|
||||
Sinon.restore();
|
||||
});
|
||||
|
||||
test('should reject message from different session', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const { id } = await createWorkspace(app, token);
|
||||
const sessionId = await createCopilotSession(
|
||||
app,
|
||||
token,
|
||||
id,
|
||||
randomUUID(),
|
||||
promptName
|
||||
);
|
||||
const anotherSessionId = await createCopilotSession(
|
||||
app,
|
||||
token,
|
||||
id,
|
||||
randomUUID(),
|
||||
promptName
|
||||
);
|
||||
const anotherMessageId = await createCopilotMessage(
|
||||
app,
|
||||
token,
|
||||
anotherSessionId
|
||||
);
|
||||
await t.throwsAsync(
|
||||
chatWithText(app, token, sessionId, anotherMessageId),
|
||||
{ instanceOf: Error },
|
||||
'should reject message from different session'
|
||||
);
|
||||
});
|
||||
|
||||
test('should reject request from different user', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const { id } = await createWorkspace(app, token);
|
||||
const sessionId = await createCopilotSession(
|
||||
app,
|
||||
token,
|
||||
id,
|
||||
randomUUID(),
|
||||
promptName
|
||||
);
|
||||
|
||||
// should reject message from different user
|
||||
{
|
||||
const { token } = await signUp(app, 'a1', 'a1@affine.pro', '123456');
|
||||
await t.throwsAsync(
|
||||
createCopilotMessage(app, token.token, sessionId),
|
||||
{ instanceOf: Error },
|
||||
'should reject message from different user'
|
||||
);
|
||||
}
|
||||
|
||||
// should reject chat from different user
|
||||
{
|
||||
const messageId = await createCopilotMessage(app, token, sessionId);
|
||||
{
|
||||
const { token } = await signUp(app, 'a2', 'a2@affine.pro', '123456');
|
||||
await t.throwsAsync(
|
||||
chatWithText(app, token.token, sessionId, messageId),
|
||||
{ instanceOf: Error },
|
||||
'should reject chat from different user'
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// ==================== history ====================
|
||||
|
||||
test('should be able to list history', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const { id: workspaceId } = await createWorkspace(app, token);
|
||||
const sessionId = await createCopilotSession(
|
||||
app,
|
||||
token,
|
||||
workspaceId,
|
||||
randomUUID(),
|
||||
promptName
|
||||
);
|
||||
|
||||
const messageId = await createCopilotMessage(app, token, sessionId);
|
||||
await chatWithText(app, token, sessionId, messageId);
|
||||
|
||||
const histories = await getHistories(app, token, { workspaceId });
|
||||
t.deepEqual(
|
||||
histories.map(h => h.messages.map(m => m.content)),
|
||||
[['generate text to text']],
|
||||
'should be able to list history'
|
||||
);
|
||||
});
|
||||
|
||||
test('should reject request that user have not permission', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const {
|
||||
token: { token: anotherToken },
|
||||
} = await signUp(app, 'a1', 'a1@affine.pro', '123456');
|
||||
const { id: workspaceId } = await createWorkspace(app, anotherToken);
|
||||
|
||||
// should reject request that user have not permission
|
||||
{
|
||||
await t.throwsAsync(
|
||||
getHistories(app, token, { workspaceId }),
|
||||
{ instanceOf: Error },
|
||||
'should reject request that user have not permission'
|
||||
);
|
||||
}
|
||||
|
||||
// should able to list history after user have permission
|
||||
{
|
||||
const inviteId = await inviteUser(
|
||||
app,
|
||||
anotherToken,
|
||||
workspaceId,
|
||||
'darksky@affine.pro',
|
||||
'Admin'
|
||||
);
|
||||
await acceptInviteById(app, workspaceId, inviteId, false);
|
||||
|
||||
t.deepEqual(
|
||||
await getHistories(app, token, { workspaceId }),
|
||||
[],
|
||||
'should able to list history after user have permission'
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
const sessionId = await createCopilotSession(
|
||||
app,
|
||||
anotherToken,
|
||||
workspaceId,
|
||||
randomUUID(),
|
||||
promptName
|
||||
);
|
||||
|
||||
const messageId = await createCopilotMessage(app, anotherToken, sessionId);
|
||||
await chatWithText(app, anotherToken, sessionId, messageId);
|
||||
|
||||
const histories = await getHistories(app, anotherToken, { workspaceId });
|
||||
t.deepEqual(
|
||||
histories.map(h => h.messages.map(m => m.content)),
|
||||
[['generate text to text']],
|
||||
'should able to list history'
|
||||
);
|
||||
|
||||
t.deepEqual(
|
||||
await getHistories(app, token, { workspaceId }),
|
||||
[],
|
||||
'should not list history created by another user'
|
||||
);
|
||||
}
|
||||
});
|
||||
@@ -5,17 +5,28 @@ import type { TestFn } from 'ava';
|
||||
import ava from 'ava';
|
||||
|
||||
import { AuthService } from '../src/core/auth';
|
||||
import { QuotaManagementService, QuotaModule } from '../src/core/quota';
|
||||
import { QuotaModule } from '../src/core/quota';
|
||||
import { ConfigModule } from '../src/fundamentals/config';
|
||||
import { CopilotModule } from '../src/plugins/copilot';
|
||||
import { PromptService } from '../src/plugins/copilot/prompt';
|
||||
import {
|
||||
CopilotProviderService,
|
||||
registerCopilotProvider,
|
||||
} from '../src/plugins/copilot/providers';
|
||||
import { ChatSessionService } from '../src/plugins/copilot/session';
|
||||
import {
|
||||
CopilotCapability,
|
||||
CopilotProviderType,
|
||||
} from '../src/plugins/copilot/types';
|
||||
import { createTestingModule } from './utils';
|
||||
import { MockCopilotTestProvider } from './utils/copilot';
|
||||
|
||||
const test = ava as TestFn<{
|
||||
auth: AuthService;
|
||||
quotaManager: QuotaManagementService;
|
||||
module: TestingModule;
|
||||
prompt: PromptService;
|
||||
provider: CopilotProviderService;
|
||||
session: ChatSessionService;
|
||||
}>;
|
||||
|
||||
test.beforeEach(async t => {
|
||||
@@ -27,6 +38,9 @@ test.beforeEach(async t => {
|
||||
openai: {
|
||||
apiKey: '1',
|
||||
},
|
||||
fal: {
|
||||
apiKey: '1',
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
@@ -35,26 +49,37 @@ test.beforeEach(async t => {
|
||||
],
|
||||
});
|
||||
|
||||
const quotaManager = module.get(QuotaManagementService);
|
||||
const auth = module.get(AuthService);
|
||||
const prompt = module.get(PromptService);
|
||||
const provider = module.get(CopilotProviderService);
|
||||
const session = module.get(ChatSessionService);
|
||||
|
||||
t.context.module = module;
|
||||
t.context.quotaManager = quotaManager;
|
||||
t.context.auth = auth;
|
||||
t.context.prompt = prompt;
|
||||
t.context.provider = provider;
|
||||
t.context.session = session;
|
||||
});
|
||||
|
||||
test.afterEach.always(async t => {
|
||||
await t.context.module.close();
|
||||
});
|
||||
|
||||
let userId: string;
|
||||
test.beforeEach(async t => {
|
||||
const { auth } = t.context;
|
||||
const user = await auth.signUp('test', 'darksky@affine.pro', '123456');
|
||||
userId = user.id;
|
||||
});
|
||||
|
||||
// ==================== prompt ====================
|
||||
|
||||
test('should be able to manage prompt', async t => {
|
||||
const { prompt } = t.context;
|
||||
|
||||
t.is((await prompt.list()).length, 0, 'should have no prompt');
|
||||
|
||||
await prompt.set('test', [
|
||||
await prompt.set('test', 'test', [
|
||||
{ role: 'system', content: 'hello' },
|
||||
{ role: 'user', content: 'hello' },
|
||||
]);
|
||||
@@ -91,7 +116,7 @@ test('should be able to render prompt', async t => {
|
||||
content: 'hello world',
|
||||
};
|
||||
|
||||
await prompt.set('test', [msg]);
|
||||
await prompt.set('test', 'test', [msg]);
|
||||
const testPrompt = await prompt.get('test');
|
||||
t.assert(testPrompt, 'should have prompt');
|
||||
t.is(
|
||||
@@ -126,7 +151,7 @@ test('should be able to render listed prompt', async t => {
|
||||
links: ['https://affine.pro', 'https://github.com/toeverything/affine'],
|
||||
};
|
||||
|
||||
await prompt.set('test', [msg]);
|
||||
await prompt.set('test', 'test', [msg]);
|
||||
const testPrompt = await prompt.get('test');
|
||||
|
||||
t.is(
|
||||
@@ -135,3 +160,265 @@ test('should be able to render listed prompt', async t => {
|
||||
'should render the prompt'
|
||||
);
|
||||
});
|
||||
|
||||
// ==================== session ====================
|
||||
|
||||
test('should be able to manage chat session', async t => {
|
||||
const { prompt, session } = t.context;
|
||||
|
||||
await prompt.set('prompt', 'model', [
|
||||
{ role: 'system', content: 'hello {{word}}' },
|
||||
]);
|
||||
|
||||
const sessionId = await session.create({
|
||||
docId: 'test',
|
||||
workspaceId: 'test',
|
||||
userId,
|
||||
promptName: 'prompt',
|
||||
});
|
||||
t.truthy(sessionId, 'should create session');
|
||||
|
||||
const s = (await session.get(sessionId))!;
|
||||
t.is(s.config.sessionId, sessionId, 'should get session');
|
||||
t.is(s.config.promptName, 'prompt', 'should have prompt name');
|
||||
t.is(s.model, 'model', 'should have model');
|
||||
|
||||
const params = { word: 'world' };
|
||||
|
||||
s.push({ role: 'user', content: 'hello', createdAt: new Date() });
|
||||
// @ts-expect-error
|
||||
const finalMessages = s.finish(params).map(({ createdAt: _, ...m }) => m);
|
||||
t.deepEqual(
|
||||
finalMessages,
|
||||
[
|
||||
{ content: 'hello world', params, role: 'system' },
|
||||
{ content: 'hello', role: 'user' },
|
||||
],
|
||||
'should generate the final message'
|
||||
);
|
||||
await s.save();
|
||||
|
||||
const s1 = (await session.get(sessionId))!;
|
||||
t.deepEqual(
|
||||
// @ts-expect-error
|
||||
s1.finish(params).map(({ createdAt: _, ...m }) => m),
|
||||
finalMessages,
|
||||
'should same as before message'
|
||||
);
|
||||
t.deepEqual(
|
||||
// @ts-expect-error
|
||||
s1.finish({}).map(({ createdAt: _, ...m }) => m),
|
||||
[
|
||||
{ content: 'hello ', params: {}, role: 'system' },
|
||||
{ content: 'hello', role: 'user' },
|
||||
],
|
||||
'should generate different message with another params'
|
||||
);
|
||||
});
|
||||
|
||||
test('should be able to process message id', async t => {
|
||||
const { prompt, session } = t.context;
|
||||
|
||||
await prompt.set('prompt', 'model', [
|
||||
{ role: 'system', content: 'hello {{word}}' },
|
||||
]);
|
||||
|
||||
const sessionId = await session.create({
|
||||
docId: 'test',
|
||||
workspaceId: 'test',
|
||||
userId,
|
||||
promptName: 'prompt',
|
||||
});
|
||||
const s = (await session.get(sessionId))!;
|
||||
|
||||
const textMessage = (await session.createMessage({
|
||||
sessionId,
|
||||
content: 'hello',
|
||||
}))!;
|
||||
const anotherSessionMessage = (await session.createMessage({
|
||||
sessionId: 'another-session-id',
|
||||
}))!;
|
||||
|
||||
await t.notThrowsAsync(
|
||||
s.pushByMessageId(textMessage),
|
||||
'should push by message id'
|
||||
);
|
||||
await t.throwsAsync(
|
||||
s.pushByMessageId(anotherSessionMessage),
|
||||
{
|
||||
instanceOf: Error,
|
||||
},
|
||||
'should throw error if push by another session message id'
|
||||
);
|
||||
await t.throwsAsync(
|
||||
s.pushByMessageId('invalid'),
|
||||
{ instanceOf: Error },
|
||||
'should throw error if push by invalid message id'
|
||||
);
|
||||
});
|
||||
|
||||
test('should be able to generate with message id', async t => {
|
||||
const { prompt, session } = t.context;
|
||||
|
||||
await prompt.set('prompt', 'model', [
|
||||
{ role: 'system', content: 'hello {{word}}' },
|
||||
]);
|
||||
|
||||
// text message
|
||||
{
|
||||
const sessionId = await session.create({
|
||||
docId: 'test',
|
||||
workspaceId: 'test',
|
||||
userId,
|
||||
promptName: 'prompt',
|
||||
});
|
||||
const s = (await session.get(sessionId))!;
|
||||
|
||||
const message = (await session.createMessage({
|
||||
sessionId,
|
||||
content: 'hello',
|
||||
}))!;
|
||||
|
||||
await s.pushByMessageId(message);
|
||||
const finalMessages = s
|
||||
.finish({ word: 'world' })
|
||||
.map(({ content }) => content);
|
||||
t.deepEqual(finalMessages, ['hello world', 'hello']);
|
||||
}
|
||||
|
||||
// attachment message
|
||||
{
|
||||
const sessionId = await session.create({
|
||||
docId: 'test',
|
||||
workspaceId: 'test',
|
||||
userId,
|
||||
promptName: 'prompt',
|
||||
});
|
||||
const s = (await session.get(sessionId))!;
|
||||
|
||||
const message = (await session.createMessage({
|
||||
sessionId,
|
||||
attachments: ['https://affine.pro/example.jpg'],
|
||||
}))!;
|
||||
|
||||
await s.pushByMessageId(message);
|
||||
const finalMessages = s
|
||||
.finish({ word: 'world' })
|
||||
.map(({ attachments }) => attachments);
|
||||
t.deepEqual(finalMessages, [
|
||||
// system prompt
|
||||
undefined,
|
||||
// user prompt
|
||||
['https://affine.pro/example.jpg'],
|
||||
]);
|
||||
}
|
||||
|
||||
// empty message
|
||||
{
|
||||
const sessionId = await session.create({
|
||||
docId: 'test',
|
||||
workspaceId: 'test',
|
||||
userId,
|
||||
promptName: 'prompt',
|
||||
});
|
||||
const s = (await session.get(sessionId))!;
|
||||
|
||||
const message = (await session.createMessage({
|
||||
sessionId,
|
||||
}))!;
|
||||
|
||||
await s.pushByMessageId(message);
|
||||
const finalMessages = s
|
||||
.finish({ word: 'world' })
|
||||
.map(({ content }) => content);
|
||||
// empty message should be filtered
|
||||
t.deepEqual(finalMessages, ['hello world']);
|
||||
}
|
||||
});
|
||||
|
||||
// ==================== provider ====================
|
||||
|
||||
test('should be able to get provider', async t => {
|
||||
const { provider } = t.context;
|
||||
|
||||
{
|
||||
const p = provider.getProviderByCapability(CopilotCapability.TextToText);
|
||||
t.is(
|
||||
p?.type.toString(),
|
||||
'openai',
|
||||
'should get provider support text-to-text'
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
const p = provider.getProviderByCapability(
|
||||
CopilotCapability.TextToEmbedding
|
||||
);
|
||||
t.is(
|
||||
p?.type.toString(),
|
||||
'openai',
|
||||
'should get provider support text-to-embedding'
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
const p = provider.getProviderByCapability(CopilotCapability.TextToImage);
|
||||
t.is(
|
||||
p?.type.toString(),
|
||||
'fal',
|
||||
'should get provider support text-to-image'
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
const p = provider.getProviderByCapability(CopilotCapability.ImageToImage);
|
||||
t.is(
|
||||
p?.type.toString(),
|
||||
'fal',
|
||||
'should get provider support image-to-image'
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
const p = provider.getProviderByCapability(CopilotCapability.ImageToText);
|
||||
t.is(
|
||||
p?.type.toString(),
|
||||
'openai',
|
||||
'should get provider support image-to-text'
|
||||
);
|
||||
}
|
||||
|
||||
// text-to-image use fal by default, but this case can use
|
||||
// model dall-e-3 to select openai provider
|
||||
{
|
||||
const p = provider.getProviderByCapability(
|
||||
CopilotCapability.TextToImage,
|
||||
'dall-e-3'
|
||||
);
|
||||
t.is(
|
||||
p?.type.toString(),
|
||||
'openai',
|
||||
'should get provider support text-to-image and model'
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
test('should be able to register test provider', async t => {
|
||||
const { provider } = t.context;
|
||||
registerCopilotProvider(MockCopilotTestProvider);
|
||||
|
||||
const assertProvider = (cap: CopilotCapability) => {
|
||||
const p = provider.getProviderByCapability(cap, 'test');
|
||||
t.is(
|
||||
p?.type,
|
||||
CopilotProviderType.Test,
|
||||
`should get test provider with ${cap}`
|
||||
);
|
||||
};
|
||||
|
||||
assertProvider(CopilotCapability.TextToText);
|
||||
assertProvider(CopilotCapability.TextToEmbedding);
|
||||
assertProvider(CopilotCapability.TextToImage);
|
||||
assertProvider(CopilotCapability.ImageToImage);
|
||||
assertProvider(CopilotCapability.ImageToText);
|
||||
});
|
||||
|
||||
305
packages/backend/server/tests/utils/copilot.ts
Normal file
305
packages/backend/server/tests/utils/copilot.ts
Normal file
@@ -0,0 +1,305 @@
|
||||
import { randomBytes } from 'node:crypto';
|
||||
|
||||
import { INestApplication } from '@nestjs/common';
|
||||
import request from 'supertest';
|
||||
|
||||
import {
|
||||
DEFAULT_DIMENSIONS,
|
||||
OpenAIProvider,
|
||||
} from '../../src/plugins/copilot/providers/openai';
|
||||
import {
|
||||
CopilotCapability,
|
||||
CopilotImageToImageProvider,
|
||||
CopilotImageToTextProvider,
|
||||
CopilotProviderType,
|
||||
CopilotTextToEmbeddingProvider,
|
||||
CopilotTextToImageProvider,
|
||||
CopilotTextToTextProvider,
|
||||
PromptMessage,
|
||||
} from '../../src/plugins/copilot/types';
|
||||
import { gql } from './common';
|
||||
import { handleGraphQLError } from './utils';
|
||||
|
||||
export class MockCopilotTestProvider
|
||||
extends OpenAIProvider
|
||||
implements
|
||||
CopilotTextToTextProvider,
|
||||
CopilotTextToEmbeddingProvider,
|
||||
CopilotTextToImageProvider,
|
||||
CopilotImageToImageProvider,
|
||||
CopilotImageToTextProvider
|
||||
{
|
||||
override readonly availableModels = ['test'];
|
||||
static override readonly capabilities = [
|
||||
CopilotCapability.TextToText,
|
||||
CopilotCapability.TextToEmbedding,
|
||||
CopilotCapability.TextToImage,
|
||||
CopilotCapability.ImageToImage,
|
||||
CopilotCapability.ImageToText,
|
||||
];
|
||||
|
||||
override get type(): CopilotProviderType {
|
||||
return CopilotProviderType.Test;
|
||||
}
|
||||
|
||||
override getCapabilities(): CopilotCapability[] {
|
||||
return MockCopilotTestProvider.capabilities;
|
||||
}
|
||||
|
||||
override isModelAvailable(model: string): boolean {
|
||||
return this.availableModels.includes(model);
|
||||
}
|
||||
|
||||
// ====== text to text ======
|
||||
|
||||
override async generateText(
|
||||
messages: PromptMessage[],
|
||||
model: string = 'test',
|
||||
_options: {
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
} = {}
|
||||
): Promise<string> {
|
||||
this.checkParams({ messages, model });
|
||||
return 'generate text to text';
|
||||
}
|
||||
|
||||
override async *generateTextStream(
|
||||
messages: PromptMessage[],
|
||||
model: string = 'gpt-3.5-turbo',
|
||||
options: {
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
} = {}
|
||||
): AsyncIterable<string> {
|
||||
this.checkParams({ messages, model });
|
||||
|
||||
const result = 'generate text to text stream';
|
||||
for await (const message of result) {
|
||||
yield message;
|
||||
if (options.signal?.aborted) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ====== text to embedding ======
|
||||
|
||||
override async generateEmbedding(
|
||||
messages: string | string[],
|
||||
model: string,
|
||||
options: {
|
||||
dimensions: number;
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
} = { dimensions: DEFAULT_DIMENSIONS }
|
||||
): Promise<number[][]> {
|
||||
messages = Array.isArray(messages) ? messages : [messages];
|
||||
this.checkParams({ embeddings: messages, model });
|
||||
|
||||
return [Array.from(randomBytes(options.dimensions)).map(v => v % 128)];
|
||||
}
|
||||
|
||||
// ====== text to image ======
|
||||
override async generateImages(
|
||||
messages: PromptMessage[],
|
||||
_model: string = 'test',
|
||||
_options: {
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
} = {}
|
||||
): Promise<Array<string>> {
|
||||
const { content: prompt } = messages.pop() || {};
|
||||
if (!prompt) {
|
||||
throw new Error('Prompt is required');
|
||||
}
|
||||
|
||||
return ['https://example.com/image.jpg'];
|
||||
}
|
||||
|
||||
override async *generateImagesStream(
|
||||
messages: PromptMessage[],
|
||||
model: string = 'dall-e-3',
|
||||
options: {
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
} = {}
|
||||
): AsyncIterable<string> {
|
||||
const ret = await this.generateImages(messages, model, options);
|
||||
for (const url of ret) {
|
||||
yield url;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export async function createCopilotSession(
|
||||
app: INestApplication,
|
||||
userToken: string,
|
||||
workspaceId: string,
|
||||
docId: string,
|
||||
promptName: string
|
||||
): Promise<string> {
|
||||
const res = await request(app.getHttpServer())
|
||||
.post(gql)
|
||||
.auth(userToken, { type: 'bearer' })
|
||||
.set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
|
||||
.send({
|
||||
query: `
|
||||
mutation createCopilotSession($options: CreateChatSessionInput!) {
|
||||
createCopilotSession(options: $options)
|
||||
}
|
||||
`,
|
||||
variables: { options: { workspaceId, docId, promptName } },
|
||||
})
|
||||
.expect(200);
|
||||
|
||||
handleGraphQLError(res);
|
||||
|
||||
return res.body.data.createCopilotSession;
|
||||
}
|
||||
|
||||
export async function createCopilotMessage(
|
||||
app: INestApplication,
|
||||
userToken: string,
|
||||
sessionId: string,
|
||||
content?: string,
|
||||
attachments?: string[],
|
||||
blobs?: ArrayBuffer[],
|
||||
params?: Record<string, string>
|
||||
): Promise<string> {
|
||||
const res = await request(app.getHttpServer())
|
||||
.post(gql)
|
||||
.auth(userToken, { type: 'bearer' })
|
||||
.set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
|
||||
.send({
|
||||
query: `
|
||||
mutation createCopilotMessage($options: CreateChatMessageInput!) {
|
||||
createCopilotMessage(options: $options)
|
||||
}
|
||||
`,
|
||||
variables: {
|
||||
options: { sessionId, content, attachments, blobs, params },
|
||||
},
|
||||
})
|
||||
.expect(200);
|
||||
|
||||
handleGraphQLError(res);
|
||||
|
||||
return res.body.data.createCopilotMessage;
|
||||
}
|
||||
|
||||
export async function chatWithText(
|
||||
app: INestApplication,
|
||||
userToken: string,
|
||||
sessionId: string,
|
||||
messageId: string,
|
||||
prefix = ''
|
||||
): Promise<string> {
|
||||
const res = await request(app.getHttpServer())
|
||||
.get(`/api/copilot/chat/${sessionId}${prefix}?messageId=${messageId}`)
|
||||
.auth(userToken, { type: 'bearer' })
|
||||
.expect(200);
|
||||
|
||||
return res.text;
|
||||
}
|
||||
|
||||
export async function chatWithTextStream(
|
||||
app: INestApplication,
|
||||
userToken: string,
|
||||
sessionId: string,
|
||||
messageId: string
|
||||
) {
|
||||
return chatWithText(app, userToken, sessionId, messageId, '/stream');
|
||||
}
|
||||
|
||||
export async function chatWithImages(
|
||||
app: INestApplication,
|
||||
userToken: string,
|
||||
sessionId: string,
|
||||
messageId: string
|
||||
) {
|
||||
return chatWithText(app, userToken, sessionId, messageId, '/images');
|
||||
}
|
||||
|
||||
export function textToEventStream(
|
||||
content: string | string[],
|
||||
id: string,
|
||||
event = 'message'
|
||||
): string {
|
||||
return (
|
||||
Array.from(content)
|
||||
.map(x => `\nevent: ${event}\nid: ${id}\ndata: ${x}`)
|
||||
.join('\n') + '\n\n'
|
||||
);
|
||||
}
|
||||
|
||||
type ChatMessage = {
|
||||
role: string;
|
||||
content: string;
|
||||
attachments: string[] | null;
|
||||
createdAt: string;
|
||||
};
|
||||
|
||||
type History = {
|
||||
sessionId: string;
|
||||
tokens: number;
|
||||
action: string | null;
|
||||
createdAt: string;
|
||||
messages: ChatMessage[];
|
||||
};
|
||||
|
||||
export async function getHistories(
|
||||
app: INestApplication,
|
||||
userToken: string,
|
||||
variables: {
|
||||
workspaceId: string;
|
||||
docId?: string;
|
||||
options?: {
|
||||
sessionId?: string;
|
||||
action?: boolean;
|
||||
limit?: number;
|
||||
skip?: number;
|
||||
};
|
||||
}
|
||||
): Promise<History[]> {
|
||||
const res = await request(app.getHttpServer())
|
||||
.post(gql)
|
||||
.auth(userToken, { type: 'bearer' })
|
||||
.set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
|
||||
.send({
|
||||
query: `
|
||||
query getCopilotHistories(
|
||||
$workspaceId: String!
|
||||
$docId: String
|
||||
$options: QueryChatHistoriesInput
|
||||
) {
|
||||
currentUser {
|
||||
copilot(workspaceId: $workspaceId) {
|
||||
histories(docId: $docId, options: $options) {
|
||||
sessionId
|
||||
tokens
|
||||
action
|
||||
createdAt
|
||||
messages {
|
||||
role
|
||||
content
|
||||
attachments
|
||||
createdAt
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
`,
|
||||
variables,
|
||||
})
|
||||
.expect(200);
|
||||
|
||||
handleGraphQLError(res);
|
||||
|
||||
return res.body.data.currentUser?.copilot?.histories || [];
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import { Test, TestingModuleBuilder } from '@nestjs/testing';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import cookieParser from 'cookie-parser';
|
||||
import graphqlUploadExpress from 'graphql-upload/graphqlUploadExpress.mjs';
|
||||
import type { Response } from 'supertest';
|
||||
|
||||
import { AppModule, FunctionalityModules } from '../../src/app.module';
|
||||
import { AuthGuard, AuthModule } from '../../src/core/auth';
|
||||
@@ -136,3 +137,12 @@ export async function createTestingApp(moduleDef: TestingModuleMeatdata = {}) {
|
||||
app,
|
||||
};
|
||||
}
|
||||
|
||||
export function handleGraphQLError(resp: Response) {
|
||||
const { errors } = resp.body;
|
||||
if (errors) {
|
||||
const cause = errors[0];
|
||||
const stacktrace = cause.extensions?.stacktrace;
|
||||
throw new Error(stacktrace ? stacktrace.join('\n') : cause.message, cause);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user