mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-14 13:25:12 +00:00
@@ -45,6 +45,7 @@ if (env.R2_OBJECT_STORAGE_ACCOUNT_ID) {
|
||||
|
||||
AFFiNE.plugins.use('copilot', {
|
||||
openai: {},
|
||||
fal: {},
|
||||
});
|
||||
AFFiNE.plugins.use('redis');
|
||||
AFFiNE.plugins.use('payment', {
|
||||
|
||||
@@ -42,6 +42,11 @@ export interface ChatEvent {
|
||||
data: string;
|
||||
}
|
||||
|
||||
type CheckResult = {
|
||||
model: string | undefined;
|
||||
hasAttachment?: boolean;
|
||||
};
|
||||
|
||||
@Controller('/api/copilot')
|
||||
export class CopilotController {
|
||||
private readonly logger = new Logger(CopilotController.name);
|
||||
@@ -53,17 +58,26 @@ export class CopilotController {
|
||||
private readonly storage: CopilotStorage
|
||||
) {}
|
||||
|
||||
private async hasAttachment(sessionId: string, messageId: string) {
|
||||
private async checkRequest(
|
||||
userId: string,
|
||||
sessionId: string,
|
||||
messageId?: string
|
||||
): Promise<CheckResult> {
|
||||
await this.chatSession.checkQuota(userId);
|
||||
const session = await this.chatSession.get(sessionId);
|
||||
if (!session) {
|
||||
if (!session || session.config.userId !== userId) {
|
||||
throw new BadRequestException('Session not found');
|
||||
}
|
||||
|
||||
const message = await session.getMessageById(messageId);
|
||||
if (Array.isArray(message.attachments) && message.attachments.length) {
|
||||
return true;
|
||||
const ret: CheckResult = { model: session.model };
|
||||
|
||||
if (messageId) {
|
||||
const message = await session.getMessageById(messageId);
|
||||
ret.hasAttachment =
|
||||
Array.isArray(message.attachments) && !!message.attachments.length;
|
||||
}
|
||||
return false;
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
private async appendSessionMessage(
|
||||
@@ -107,9 +121,7 @@ export class CopilotController {
|
||||
@Query('messageId') messageId: string,
|
||||
@Query() params: Record<string, string | string[]>
|
||||
): Promise<string> {
|
||||
await this.chatSession.checkQuota(user.id);
|
||||
|
||||
const model = await this.chatSession.get(sessionId).then(s => s?.model);
|
||||
const { model } = await this.checkRequest(user.id, sessionId);
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
CopilotCapability.TextToText,
|
||||
model
|
||||
@@ -155,60 +167,58 @@ export class CopilotController {
|
||||
@Query() params: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
try {
|
||||
await this.chatSession.checkQuota(user.id);
|
||||
const { model } = await this.checkRequest(user.id, sessionId);
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
CopilotCapability.TextToText,
|
||||
model
|
||||
);
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
}
|
||||
|
||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||
delete params.messageId;
|
||||
|
||||
return from(
|
||||
provider.generateTextStream(session.finish(params), session.model, {
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
})
|
||||
).pipe(
|
||||
connect(shared$ =>
|
||||
merge(
|
||||
// actual chat event stream
|
||||
shared$.pipe(
|
||||
map(data => ({ type: 'message' as const, id: messageId, data }))
|
||||
),
|
||||
// save the generated text to the session
|
||||
shared$.pipe(
|
||||
toArray(),
|
||||
concatMap(values => {
|
||||
session.push({
|
||||
role: 'assistant',
|
||||
content: values.join(''),
|
||||
createdAt: new Date(),
|
||||
});
|
||||
return from(session.save());
|
||||
}),
|
||||
switchMap(() => EMPTY)
|
||||
)
|
||||
)
|
||||
),
|
||||
catchError(err =>
|
||||
of({
|
||||
type: 'error' as const,
|
||||
data: this.handleError(err),
|
||||
})
|
||||
)
|
||||
);
|
||||
} catch (err) {
|
||||
return of({
|
||||
type: 'error' as const,
|
||||
data: this.handleError(err),
|
||||
});
|
||||
}
|
||||
|
||||
const model = await this.chatSession.get(sessionId).then(s => s?.model);
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
CopilotCapability.TextToText,
|
||||
model
|
||||
);
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
}
|
||||
|
||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||
delete params.messageId;
|
||||
|
||||
return from(
|
||||
provider.generateTextStream(session.finish(params), session.model, {
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
})
|
||||
).pipe(
|
||||
connect(shared$ =>
|
||||
merge(
|
||||
// actual chat event stream
|
||||
shared$.pipe(
|
||||
map(data => ({ type: 'message' as const, id: sessionId, data }))
|
||||
),
|
||||
// save the generated text to the session
|
||||
shared$.pipe(
|
||||
toArray(),
|
||||
concatMap(values => {
|
||||
session.push({
|
||||
role: 'assistant',
|
||||
content: values.join(''),
|
||||
createdAt: new Date(),
|
||||
});
|
||||
return from(session.save());
|
||||
}),
|
||||
switchMap(() => EMPTY)
|
||||
)
|
||||
)
|
||||
),
|
||||
catchError(err =>
|
||||
of({
|
||||
type: 'error' as const,
|
||||
data: this.handleError(err),
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Sse('/chat/:sessionId/images')
|
||||
@@ -220,75 +230,76 @@ export class CopilotController {
|
||||
@Query() params: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
try {
|
||||
await this.chatSession.checkQuota(user.id);
|
||||
const { model, hasAttachment } = await this.checkRequest(
|
||||
user.id,
|
||||
sessionId,
|
||||
messageId
|
||||
);
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
hasAttachment
|
||||
? CopilotCapability.ImageToImage
|
||||
: CopilotCapability.TextToImage,
|
||||
model
|
||||
);
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
}
|
||||
|
||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||
delete params.messageId;
|
||||
|
||||
const handleRemoteLink = this.storage.handleRemoteLink.bind(
|
||||
this.storage,
|
||||
user.id,
|
||||
sessionId
|
||||
);
|
||||
|
||||
return from(
|
||||
provider.generateImagesStream(session.finish(params), session.model, {
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
})
|
||||
).pipe(
|
||||
mergeMap(handleRemoteLink),
|
||||
connect(shared$ =>
|
||||
merge(
|
||||
// actual chat event stream
|
||||
shared$.pipe(
|
||||
map(attachment => ({
|
||||
type: 'attachment' as const,
|
||||
id: messageId,
|
||||
data: attachment,
|
||||
}))
|
||||
),
|
||||
// save the generated text to the session
|
||||
shared$.pipe(
|
||||
toArray(),
|
||||
concatMap(attachments => {
|
||||
session.push({
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
attachments: attachments,
|
||||
createdAt: new Date(),
|
||||
});
|
||||
return from(session.save());
|
||||
}),
|
||||
switchMap(() => EMPTY)
|
||||
)
|
||||
)
|
||||
),
|
||||
catchError(err =>
|
||||
of({
|
||||
type: 'error' as const,
|
||||
data: this.handleError(err),
|
||||
})
|
||||
)
|
||||
);
|
||||
} catch (err) {
|
||||
return of({
|
||||
type: 'error' as const,
|
||||
data: this.handleError(err),
|
||||
});
|
||||
}
|
||||
|
||||
const hasAttachment = await this.hasAttachment(sessionId, messageId);
|
||||
const model = await this.chatSession.get(sessionId).then(s => s?.model);
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
hasAttachment
|
||||
? CopilotCapability.ImageToImage
|
||||
: CopilotCapability.TextToImage,
|
||||
model
|
||||
);
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
}
|
||||
|
||||
const session = await this.appendSessionMessage(sessionId, messageId);
|
||||
delete params.messageId;
|
||||
|
||||
const handleRemoteLink = this.storage.handleRemoteLink.bind(
|
||||
this.storage,
|
||||
user.id,
|
||||
sessionId
|
||||
);
|
||||
|
||||
return from(
|
||||
provider.generateImagesStream(session.finish(params), session.model, {
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
})
|
||||
).pipe(
|
||||
mergeMap(handleRemoteLink),
|
||||
connect(shared$ =>
|
||||
merge(
|
||||
// actual chat event stream
|
||||
shared$.pipe(
|
||||
map(attachment => ({
|
||||
type: 'attachment' as const,
|
||||
id: sessionId,
|
||||
data: attachment,
|
||||
}))
|
||||
),
|
||||
// save the generated text to the session
|
||||
shared$.pipe(
|
||||
toArray(),
|
||||
concatMap(attachments => {
|
||||
session.push({
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
attachments: attachments,
|
||||
createdAt: new Date(),
|
||||
});
|
||||
return from(session.save());
|
||||
}),
|
||||
switchMap(() => EMPTY)
|
||||
)
|
||||
)
|
||||
),
|
||||
catchError(err =>
|
||||
of({
|
||||
type: 'error' as const,
|
||||
data: this.handleError(err),
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Get('/unsplash/photos')
|
||||
|
||||
@@ -193,11 +193,12 @@ export class PromptService {
|
||||
return null;
|
||||
}
|
||||
|
||||
async set(name: string, messages: PromptMessage[]) {
|
||||
async set(name: string, model: string, messages: PromptMessage[]) {
|
||||
return await this.db.aiPrompt
|
||||
.create({
|
||||
data: {
|
||||
name,
|
||||
model,
|
||||
messages: {
|
||||
create: messages.map((m, idx) => ({
|
||||
idx,
|
||||
|
||||
@@ -41,6 +41,10 @@ export class FalProvider
|
||||
return !!config.apiKey;
|
||||
}
|
||||
|
||||
get type(): CopilotProviderType {
|
||||
return FalProvider.type;
|
||||
}
|
||||
|
||||
getCapabilities(): CopilotCapability[] {
|
||||
return FalProvider.capabilities;
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ import {
|
||||
PromptMessage,
|
||||
} from '../types';
|
||||
|
||||
const DEFAULT_DIMENSIONS = 256;
|
||||
export const DEFAULT_DIMENSIONS = 256;
|
||||
|
||||
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;
|
||||
|
||||
@@ -59,6 +59,10 @@ export class OpenAIProvider
|
||||
return !!config.apiKey;
|
||||
}
|
||||
|
||||
get type(): CopilotProviderType {
|
||||
return OpenAIProvider.type;
|
||||
}
|
||||
|
||||
getCapabilities(): CopilotCapability[] {
|
||||
return OpenAIProvider.capabilities;
|
||||
}
|
||||
@@ -67,7 +71,7 @@ export class OpenAIProvider
|
||||
return this.availableModels.includes(model);
|
||||
}
|
||||
|
||||
private chatToGPTMessage(
|
||||
protected chatToGPTMessage(
|
||||
messages: PromptMessage[]
|
||||
): OpenAI.Chat.Completions.ChatCompletionMessageParam[] {
|
||||
// filter redundant fields
|
||||
@@ -92,7 +96,7 @@ export class OpenAIProvider
|
||||
});
|
||||
}
|
||||
|
||||
private checkParams({
|
||||
protected checkParams({
|
||||
messages,
|
||||
embeddings,
|
||||
model,
|
||||
|
||||
@@ -278,7 +278,9 @@ export class CopilotResolver {
|
||||
return new TooManyRequestsException('Server is busy');
|
||||
}
|
||||
const session = await this.chatSession.get(options.sessionId);
|
||||
if (!session) return new BadRequestException('Session not found');
|
||||
if (!session || session.config.userId !== user.id) {
|
||||
return new BadRequestException('Session not found');
|
||||
}
|
||||
|
||||
if (options.blobs) {
|
||||
options.attachments = options.attachments || [];
|
||||
|
||||
@@ -81,7 +81,7 @@ export class ChatSession implements AsyncDisposable {
|
||||
}
|
||||
|
||||
pop() {
|
||||
this.state.messages.pop();
|
||||
return this.state.messages.pop();
|
||||
}
|
||||
|
||||
private takeMessages(): ChatMessage[] {
|
||||
@@ -115,7 +115,7 @@ export class ChatSession implements AsyncDisposable {
|
||||
Object.keys(params).length ? params : messages[0]?.params || {},
|
||||
this.config.sessionId
|
||||
),
|
||||
...messages.filter(m => m.content || m.attachments?.length),
|
||||
...messages.filter(m => m.content?.trim() || m.attachments?.length),
|
||||
];
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ export interface CopilotConfig {
|
||||
openai: OpenAIClientOptions;
|
||||
fal: FalConfig;
|
||||
unsplashKey: string;
|
||||
test: never;
|
||||
}
|
||||
|
||||
export enum AvailableModels {
|
||||
@@ -130,6 +131,8 @@ export type ListHistoriesOptions = {
|
||||
export enum CopilotProviderType {
|
||||
FAL = 'fal',
|
||||
OpenAI = 'openai',
|
||||
// only for test
|
||||
Test = 'test',
|
||||
}
|
||||
|
||||
export enum CopilotCapability {
|
||||
@@ -141,6 +144,7 @@ export enum CopilotCapability {
|
||||
}
|
||||
|
||||
export interface CopilotProvider {
|
||||
readonly type: CopilotProviderType;
|
||||
getCapabilities(): CopilotCapability[];
|
||||
isModelAvailable(model: string): boolean;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user