feat: text to image impl (#6437)

fix CLOUD-18
fix CLOUD-28
fix CLOUD-29
This commit is contained in:
darkskygit
2024-04-10 12:13:39 +00:00
parent 7c38a54f81
commit 9f349a2300
19 changed files with 601 additions and 99 deletions

View File

@@ -23,12 +23,13 @@ import {
import { Public } from '../../core/auth';
import { CurrentUser } from '../../core/auth/current-user';
import { CopilotProviderService } from './providers';
import { ChatSessionService } from './session';
import { ChatSession, ChatSessionService } from './session';
import { CopilotCapability } from './types';
export interface ChatEvent {
data: string;
type: 'attachment' | 'message';
id?: string;
data: string;
}
@Controller('/api/copilot')
@@ -38,13 +39,54 @@ export class CopilotController {
private readonly provider: CopilotProviderService
) {}
private async hasAttachment(sessionId: string, messageId?: string) {
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
}
if (messageId) {
const message = await session.getMessageById(messageId);
if (Array.isArray(message.attachments) && message.attachments.length) {
return true;
}
}
return false;
}
private async appendSessionMessage(
sessionId: string,
message?: string,
messageId?: string
): Promise<ChatSession> {
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
}
if (messageId) {
await session.pushByMessageId(messageId);
} else {
if (!message || !message.trim()) {
throw new BadRequestException('Message is empty');
}
session.push({
role: 'user',
content: decodeURIComponent(message),
createdAt: new Date(),
});
}
return session;
}
@Public()
@Get('/chat/:sessionId')
async chat(
@CurrentUser() user: CurrentUser,
@Req() req: Request,
@Param('sessionId') sessionId: string,
@Query('message') content: string,
@Query('message') message: string | undefined,
@Query('messageId') messageId: string | undefined,
@Query() params: Record<string, string | string[]>
): Promise<string> {
const provider = this.provider.getProviderByCapability(
@@ -53,21 +95,16 @@ export class CopilotController {
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
}
if (!content || !content.trim()) {
throw new BadRequestException('Message is empty');
}
session.push({
role: 'user',
content: decodeURIComponent(content),
createdAt: new Date(),
});
const session = await this.appendSessionMessage(
sessionId,
message,
messageId
);
try {
delete params.message;
delete params.messageId;
const content = await provider.generateText(
session.finish(params),
session.model,
@@ -98,7 +135,8 @@ export class CopilotController {
@CurrentUser() user: CurrentUser,
@Req() req: Request,
@Param('sessionId') sessionId: string,
@Query('message') content: string,
@Query('message') message: string | undefined,
@Query('messageId') messageId: string | undefined,
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
const provider = this.provider.getProviderByCapability(
@@ -107,20 +145,15 @@ export class CopilotController {
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
}
if (!content || !content.trim()) {
throw new BadRequestException('Message is empty');
}
session.push({
role: 'user',
content: decodeURIComponent(content),
createdAt: new Date(),
});
const session = await this.appendSessionMessage(
sessionId,
message,
messageId
);
delete params.message;
delete params.messageId;
return from(
provider.generateTextStream(session.finish(params), session.model, {
signal: req.signal,
@@ -130,7 +163,9 @@ export class CopilotController {
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(map(data => ({ id: sessionId, data }))),
shared$.pipe(
map(data => ({ type: 'message' as const, id: sessionId, data }))
),
// save the generated text to the session
shared$.pipe(
toArray(),
@@ -148,4 +183,66 @@ export class CopilotController {
)
);
}
@Public()
@Sse('/chat/:sessionId/images')
async chatImagesStream(
@CurrentUser() user: CurrentUser | undefined,
@Req() req: Request,
@Param('sessionId') sessionId: string,
@Query('message') message: string | undefined,
@Query('messageId') messageId: string | undefined,
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
const provider = this.provider.getProviderByCapability(
(await this.hasAttachment(sessionId, messageId))
? CopilotCapability.ImageToImage
: CopilotCapability.TextToImage
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.appendSessionMessage(
sessionId,
message,
messageId
);
delete params.message;
delete params.messageId;
return from(
provider.generateImagesStream(session.finish(params), session.model, {
signal: req.signal,
user: user?.id,
})
).pipe(
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(
map(attachment => ({
type: 'attachment' as const,
id: sessionId,
data: attachment,
}))
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(attachments => {
session.push({
role: 'assistant',
content: '',
attachments: attachments,
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
)
);
}
}