mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-12 20:38:52 +00:00
feat: text to image impl (#6437)
fix CLOUD-18 fix CLOUD-28 fix CLOUD-29
This commit is contained in:
@@ -23,12 +23,13 @@ import {
|
||||
import { Public } from '../../core/auth';
|
||||
import { CurrentUser } from '../../core/auth/current-user';
|
||||
import { CopilotProviderService } from './providers';
|
||||
import { ChatSessionService } from './session';
|
||||
import { ChatSession, ChatSessionService } from './session';
|
||||
import { CopilotCapability } from './types';
|
||||
|
||||
export interface ChatEvent {
|
||||
data: string;
|
||||
type: 'attachment' | 'message';
|
||||
id?: string;
|
||||
data: string;
|
||||
}
|
||||
|
||||
@Controller('/api/copilot')
|
||||
@@ -38,13 +39,54 @@ export class CopilotController {
|
||||
private readonly provider: CopilotProviderService
|
||||
) {}
|
||||
|
||||
private async hasAttachment(sessionId: string, messageId?: string) {
|
||||
const session = await this.chatSession.get(sessionId);
|
||||
if (!session) {
|
||||
throw new BadRequestException('Session not found');
|
||||
}
|
||||
|
||||
if (messageId) {
|
||||
const message = await session.getMessageById(messageId);
|
||||
if (Array.isArray(message.attachments) && message.attachments.length) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private async appendSessionMessage(
|
||||
sessionId: string,
|
||||
message?: string,
|
||||
messageId?: string
|
||||
): Promise<ChatSession> {
|
||||
const session = await this.chatSession.get(sessionId);
|
||||
if (!session) {
|
||||
throw new BadRequestException('Session not found');
|
||||
}
|
||||
|
||||
if (messageId) {
|
||||
await session.pushByMessageId(messageId);
|
||||
} else {
|
||||
if (!message || !message.trim()) {
|
||||
throw new BadRequestException('Message is empty');
|
||||
}
|
||||
session.push({
|
||||
role: 'user',
|
||||
content: decodeURIComponent(message),
|
||||
createdAt: new Date(),
|
||||
});
|
||||
}
|
||||
return session;
|
||||
}
|
||||
|
||||
@Public()
|
||||
@Get('/chat/:sessionId')
|
||||
async chat(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Req() req: Request,
|
||||
@Param('sessionId') sessionId: string,
|
||||
@Query('message') content: string,
|
||||
@Query('message') message: string | undefined,
|
||||
@Query('messageId') messageId: string | undefined,
|
||||
@Query() params: Record<string, string | string[]>
|
||||
): Promise<string> {
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
@@ -53,21 +95,16 @@ export class CopilotController {
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
}
|
||||
const session = await this.chatSession.get(sessionId);
|
||||
if (!session) {
|
||||
throw new BadRequestException('Session not found');
|
||||
}
|
||||
if (!content || !content.trim()) {
|
||||
throw new BadRequestException('Message is empty');
|
||||
}
|
||||
session.push({
|
||||
role: 'user',
|
||||
content: decodeURIComponent(content),
|
||||
createdAt: new Date(),
|
||||
});
|
||||
|
||||
const session = await this.appendSessionMessage(
|
||||
sessionId,
|
||||
message,
|
||||
messageId
|
||||
);
|
||||
|
||||
try {
|
||||
delete params.message;
|
||||
delete params.messageId;
|
||||
const content = await provider.generateText(
|
||||
session.finish(params),
|
||||
session.model,
|
||||
@@ -98,7 +135,8 @@ export class CopilotController {
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Req() req: Request,
|
||||
@Param('sessionId') sessionId: string,
|
||||
@Query('message') content: string,
|
||||
@Query('message') message: string | undefined,
|
||||
@Query('messageId') messageId: string | undefined,
|
||||
@Query() params: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
@@ -107,20 +145,15 @@ export class CopilotController {
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
}
|
||||
const session = await this.chatSession.get(sessionId);
|
||||
if (!session) {
|
||||
throw new BadRequestException('Session not found');
|
||||
}
|
||||
if (!content || !content.trim()) {
|
||||
throw new BadRequestException('Message is empty');
|
||||
}
|
||||
session.push({
|
||||
role: 'user',
|
||||
content: decodeURIComponent(content),
|
||||
createdAt: new Date(),
|
||||
});
|
||||
|
||||
const session = await this.appendSessionMessage(
|
||||
sessionId,
|
||||
message,
|
||||
messageId
|
||||
);
|
||||
|
||||
delete params.message;
|
||||
delete params.messageId;
|
||||
return from(
|
||||
provider.generateTextStream(session.finish(params), session.model, {
|
||||
signal: req.signal,
|
||||
@@ -130,7 +163,9 @@ export class CopilotController {
|
||||
connect(shared$ =>
|
||||
merge(
|
||||
// actual chat event stream
|
||||
shared$.pipe(map(data => ({ id: sessionId, data }))),
|
||||
shared$.pipe(
|
||||
map(data => ({ type: 'message' as const, id: sessionId, data }))
|
||||
),
|
||||
// save the generated text to the session
|
||||
shared$.pipe(
|
||||
toArray(),
|
||||
@@ -148,4 +183,66 @@ export class CopilotController {
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Public()
|
||||
@Sse('/chat/:sessionId/images')
|
||||
async chatImagesStream(
|
||||
@CurrentUser() user: CurrentUser | undefined,
|
||||
@Req() req: Request,
|
||||
@Param('sessionId') sessionId: string,
|
||||
@Query('message') message: string | undefined,
|
||||
@Query('messageId') messageId: string | undefined,
|
||||
@Query() params: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
(await this.hasAttachment(sessionId, messageId))
|
||||
? CopilotCapability.ImageToImage
|
||||
: CopilotCapability.TextToImage
|
||||
);
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
}
|
||||
|
||||
const session = await this.appendSessionMessage(
|
||||
sessionId,
|
||||
message,
|
||||
messageId
|
||||
);
|
||||
|
||||
delete params.message;
|
||||
delete params.messageId;
|
||||
return from(
|
||||
provider.generateImagesStream(session.finish(params), session.model, {
|
||||
signal: req.signal,
|
||||
user: user?.id,
|
||||
})
|
||||
).pipe(
|
||||
connect(shared$ =>
|
||||
merge(
|
||||
// actual chat event stream
|
||||
shared$.pipe(
|
||||
map(attachment => ({
|
||||
type: 'attachment' as const,
|
||||
id: sessionId,
|
||||
data: attachment,
|
||||
}))
|
||||
),
|
||||
// save the generated text to the session
|
||||
shared$.pipe(
|
||||
toArray(),
|
||||
concatMap(attachments => {
|
||||
session.push({
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
attachments: attachments,
|
||||
createdAt: new Date(),
|
||||
});
|
||||
return from(session.save());
|
||||
}),
|
||||
switchMap(() => EMPTY)
|
||||
)
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user