test: copilot unit & e2e test (#6649)

fix CLOUD-31
This commit is contained in:
darkskygit
2024-04-26 09:43:35 +00:00
parent f015a11181
commit 850bbee629
12 changed files with 1145 additions and 134 deletions

View File

@@ -45,6 +45,7 @@ if (env.R2_OBJECT_STORAGE_ACCOUNT_ID) {
AFFiNE.plugins.use('copilot', {
openai: {},
fal: {},
});
AFFiNE.plugins.use('redis');
AFFiNE.plugins.use('payment', {

View File

@@ -42,6 +42,11 @@ export interface ChatEvent {
data: string;
}
type CheckResult = {
model: string | undefined;
hasAttachment?: boolean;
};
@Controller('/api/copilot')
export class CopilotController {
private readonly logger = new Logger(CopilotController.name);
@@ -53,17 +58,26 @@ export class CopilotController {
private readonly storage: CopilotStorage
) {}
private async hasAttachment(sessionId: string, messageId: string) {
private async checkRequest(
userId: string,
sessionId: string,
messageId?: string
): Promise<CheckResult> {
await this.chatSession.checkQuota(userId);
const session = await this.chatSession.get(sessionId);
if (!session) {
if (!session || session.config.userId !== userId) {
throw new BadRequestException('Session not found');
}
const message = await session.getMessageById(messageId);
if (Array.isArray(message.attachments) && message.attachments.length) {
return true;
const ret: CheckResult = { model: session.model };
if (messageId) {
const message = await session.getMessageById(messageId);
ret.hasAttachment =
Array.isArray(message.attachments) && !!message.attachments.length;
}
return false;
return ret;
}
private async appendSessionMessage(
@@ -107,9 +121,7 @@ export class CopilotController {
@Query('messageId') messageId: string,
@Query() params: Record<string, string | string[]>
): Promise<string> {
await this.chatSession.checkQuota(user.id);
const model = await this.chatSession.get(sessionId).then(s => s?.model);
const { model } = await this.checkRequest(user.id, sessionId);
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
@@ -155,60 +167,58 @@ export class CopilotController {
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
try {
await this.chatSession.checkQuota(user.id);
const { model } = await this.checkRequest(user.id, sessionId);
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;
return from(
provider.generateTextStream(session.finish(params), session.model, {
signal: this.getSignal(req),
user: user.id,
})
).pipe(
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(
map(data => ({ type: 'message' as const, id: messageId, data }))
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(values => {
session.push({
role: 'assistant',
content: values.join(''),
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
),
catchError(err =>
of({
type: 'error' as const,
data: this.handleError(err),
})
)
);
} catch (err) {
return of({
type: 'error' as const,
data: this.handleError(err),
});
}
const model = await this.chatSession.get(sessionId).then(s => s?.model);
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;
return from(
provider.generateTextStream(session.finish(params), session.model, {
signal: this.getSignal(req),
user: user.id,
})
).pipe(
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(
map(data => ({ type: 'message' as const, id: sessionId, data }))
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(values => {
session.push({
role: 'assistant',
content: values.join(''),
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
),
catchError(err =>
of({
type: 'error' as const,
data: this.handleError(err),
})
)
);
}
@Sse('/chat/:sessionId/images')
@@ -220,75 +230,76 @@ export class CopilotController {
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
try {
await this.chatSession.checkQuota(user.id);
const { model, hasAttachment } = await this.checkRequest(
user.id,
sessionId,
messageId
);
const provider = this.provider.getProviderByCapability(
hasAttachment
? CopilotCapability.ImageToImage
: CopilotCapability.TextToImage,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;
const handleRemoteLink = this.storage.handleRemoteLink.bind(
this.storage,
user.id,
sessionId
);
return from(
provider.generateImagesStream(session.finish(params), session.model, {
signal: this.getSignal(req),
user: user.id,
})
).pipe(
mergeMap(handleRemoteLink),
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(
map(attachment => ({
type: 'attachment' as const,
id: messageId,
data: attachment,
}))
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(attachments => {
session.push({
role: 'assistant',
content: '',
attachments: attachments,
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
),
catchError(err =>
of({
type: 'error' as const,
data: this.handleError(err),
})
)
);
} catch (err) {
return of({
type: 'error' as const,
data: this.handleError(err),
});
}
const hasAttachment = await this.hasAttachment(sessionId, messageId);
const model = await this.chatSession.get(sessionId).then(s => s?.model);
const provider = this.provider.getProviderByCapability(
hasAttachment
? CopilotCapability.ImageToImage
: CopilotCapability.TextToImage,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;
const handleRemoteLink = this.storage.handleRemoteLink.bind(
this.storage,
user.id,
sessionId
);
return from(
provider.generateImagesStream(session.finish(params), session.model, {
signal: this.getSignal(req),
user: user.id,
})
).pipe(
mergeMap(handleRemoteLink),
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(
map(attachment => ({
type: 'attachment' as const,
id: sessionId,
data: attachment,
}))
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(attachments => {
session.push({
role: 'assistant',
content: '',
attachments: attachments,
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
),
catchError(err =>
of({
type: 'error' as const,
data: this.handleError(err),
})
)
);
}
@Get('/unsplash/photos')

View File

@@ -193,11 +193,12 @@ export class PromptService {
return null;
}
async set(name: string, messages: PromptMessage[]) {
async set(name: string, model: string, messages: PromptMessage[]) {
return await this.db.aiPrompt
.create({
data: {
name,
model,
messages: {
create: messages.map((m, idx) => ({
idx,

View File

@@ -41,6 +41,10 @@ export class FalProvider
return !!config.apiKey;
}
get type(): CopilotProviderType {
return FalProvider.type;
}
getCapabilities(): CopilotCapability[] {
return FalProvider.capabilities;
}

View File

@@ -13,7 +13,7 @@ import {
PromptMessage,
} from '../types';
const DEFAULT_DIMENSIONS = 256;
export const DEFAULT_DIMENSIONS = 256;
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;
@@ -59,6 +59,10 @@ export class OpenAIProvider
return !!config.apiKey;
}
get type(): CopilotProviderType {
return OpenAIProvider.type;
}
getCapabilities(): CopilotCapability[] {
return OpenAIProvider.capabilities;
}
@@ -67,7 +71,7 @@ export class OpenAIProvider
return this.availableModels.includes(model);
}
private chatToGPTMessage(
protected chatToGPTMessage(
messages: PromptMessage[]
): OpenAI.Chat.Completions.ChatCompletionMessageParam[] {
// filter redundant fields
@@ -92,7 +96,7 @@ export class OpenAIProvider
});
}
private checkParams({
protected checkParams({
messages,
embeddings,
model,

View File

@@ -278,7 +278,9 @@ export class CopilotResolver {
return new TooManyRequestsException('Server is busy');
}
const session = await this.chatSession.get(options.sessionId);
if (!session) return new BadRequestException('Session not found');
if (!session || session.config.userId !== user.id) {
return new BadRequestException('Session not found');
}
if (options.blobs) {
options.attachments = options.attachments || [];

View File

@@ -81,7 +81,7 @@ export class ChatSession implements AsyncDisposable {
}
pop() {
this.state.messages.pop();
return this.state.messages.pop();
}
private takeMessages(): ChatMessage[] {
@@ -115,7 +115,7 @@ export class ChatSession implements AsyncDisposable {
Object.keys(params).length ? params : messages[0]?.params || {},
this.config.sessionId
),
...messages.filter(m => m.content || m.attachments?.length),
...messages.filter(m => m.content?.trim() || m.attachments?.length),
];
}

View File

@@ -15,6 +15,7 @@ export interface CopilotConfig {
openai: OpenAIClientOptions;
fal: FalConfig;
unsplashKey: string;
test: never;
}
export enum AvailableModels {
@@ -130,6 +131,8 @@ export type ListHistoriesOptions = {
export enum CopilotProviderType {
FAL = 'fal',
OpenAI = 'openai',
// only for test
Test = 'test',
}
export enum CopilotCapability {
@@ -141,6 +144,7 @@ export enum CopilotCapability {
}
export interface CopilotProvider {
readonly type: CopilotProviderType;
getCapabilities(): CopilotCapability[];
isModelAvailable(model: string): boolean;
}