feat: text to image impl (#6437)

fix CLOUD-18
fix CLOUD-28
fix CLOUD-29
This commit is contained in:
darkskygit
2024-04-10 12:13:39 +00:00
parent 7c38a54f81
commit 9f349a2300
19 changed files with 601 additions and 99 deletions

View File

@@ -23,12 +23,13 @@ import {
import { Public } from '../../core/auth';
import { CurrentUser } from '../../core/auth/current-user';
import { CopilotProviderService } from './providers';
import { ChatSessionService } from './session';
import { ChatSession, ChatSessionService } from './session';
import { CopilotCapability } from './types';
export interface ChatEvent {
data: string;
type: 'attachment' | 'message';
id?: string;
data: string;
}
@Controller('/api/copilot')
@@ -38,13 +39,54 @@ export class CopilotController {
private readonly provider: CopilotProviderService
) {}
private async hasAttachment(sessionId: string, messageId?: string) {
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
}
if (messageId) {
const message = await session.getMessageById(messageId);
if (Array.isArray(message.attachments) && message.attachments.length) {
return true;
}
}
return false;
}
private async appendSessionMessage(
sessionId: string,
message?: string,
messageId?: string
): Promise<ChatSession> {
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
}
if (messageId) {
await session.pushByMessageId(messageId);
} else {
if (!message || !message.trim()) {
throw new BadRequestException('Message is empty');
}
session.push({
role: 'user',
content: decodeURIComponent(message),
createdAt: new Date(),
});
}
return session;
}
@Public()
@Get('/chat/:sessionId')
async chat(
@CurrentUser() user: CurrentUser,
@Req() req: Request,
@Param('sessionId') sessionId: string,
@Query('message') content: string,
@Query('message') message: string | undefined,
@Query('messageId') messageId: string | undefined,
@Query() params: Record<string, string | string[]>
): Promise<string> {
const provider = this.provider.getProviderByCapability(
@@ -53,21 +95,16 @@ export class CopilotController {
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
}
if (!content || !content.trim()) {
throw new BadRequestException('Message is empty');
}
session.push({
role: 'user',
content: decodeURIComponent(content),
createdAt: new Date(),
});
const session = await this.appendSessionMessage(
sessionId,
message,
messageId
);
try {
delete params.message;
delete params.messageId;
const content = await provider.generateText(
session.finish(params),
session.model,
@@ -98,7 +135,8 @@ export class CopilotController {
@CurrentUser() user: CurrentUser,
@Req() req: Request,
@Param('sessionId') sessionId: string,
@Query('message') content: string,
@Query('message') message: string | undefined,
@Query('messageId') messageId: string | undefined,
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
const provider = this.provider.getProviderByCapability(
@@ -107,20 +145,15 @@ export class CopilotController {
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
}
if (!content || !content.trim()) {
throw new BadRequestException('Message is empty');
}
session.push({
role: 'user',
content: decodeURIComponent(content),
createdAt: new Date(),
});
const session = await this.appendSessionMessage(
sessionId,
message,
messageId
);
delete params.message;
delete params.messageId;
return from(
provider.generateTextStream(session.finish(params), session.model, {
signal: req.signal,
@@ -130,7 +163,9 @@ export class CopilotController {
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(map(data => ({ id: sessionId, data }))),
shared$.pipe(
map(data => ({ type: 'message' as const, id: sessionId, data }))
),
// save the generated text to the session
shared$.pipe(
toArray(),
@@ -148,4 +183,66 @@ export class CopilotController {
)
);
}
@Public()
@Sse('/chat/:sessionId/images')
async chatImagesStream(
@CurrentUser() user: CurrentUser | undefined,
@Req() req: Request,
@Param('sessionId') sessionId: string,
@Query('message') message: string | undefined,
@Query('messageId') messageId: string | undefined,
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
const provider = this.provider.getProviderByCapability(
(await this.hasAttachment(sessionId, messageId))
? CopilotCapability.ImageToImage
: CopilotCapability.TextToImage
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.appendSessionMessage(
sessionId,
message,
messageId
);
delete params.message;
delete params.messageId;
return from(
provider.generateImagesStream(session.finish(params), session.model, {
signal: req.signal,
user: user?.id,
})
).pipe(
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(
map(attachment => ({
type: 'attachment' as const,
id: sessionId,
data: attachment,
}))
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(attachments => {
session.push({
role: 'assistant',
content: '',
attachments: attachments,
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
)
);
}
}

View File

@@ -3,16 +3,19 @@ import { QuotaService } from '../../core/quota';
import { PermissionService } from '../../core/workspaces/permission';
import { Plugin } from '../registry';
import { CopilotController } from './controller';
import { ChatMessageCache } from './message';
import { PromptService } from './prompt';
import {
assertProvidersConfigs,
CopilotProviderService,
FalProvider,
OpenAIProvider,
registerCopilotProvider,
} from './providers';
import { CopilotResolver, UserCopilotResolver } from './resolver';
import { ChatSessionService } from './session';
registerCopilotProvider(FalProvider);
registerCopilotProvider(OpenAIProvider);
@Plugin({
@@ -22,6 +25,7 @@ registerCopilotProvider(OpenAIProvider);
QuotaService,
ChatSessionService,
CopilotResolver,
ChatMessageCache,
UserCopilotResolver,
PromptService,
CopilotProviderService,

View File

@@ -0,0 +1,35 @@
import { randomUUID } from 'node:crypto';
import { Injectable, Logger } from '@nestjs/common';
import { SessionCache } from '../../fundamentals';
import { SubmittedMessage, SubmittedMessageSchema } from './types';
const CHAT_MESSAGE_KEY = 'chat-message';
const CHAT_MESSAGE_TTL = 3600 * 1 * 1000; // 1 hours
@Injectable()
export class ChatMessageCache {
private readonly logger = new Logger(ChatMessageCache.name);
constructor(private readonly cache: SessionCache) {}
async get(id: string): Promise<SubmittedMessage | undefined> {
return await this.cache.get(`${CHAT_MESSAGE_KEY}:${id}`);
}
async set(message: SubmittedMessage): Promise<string | undefined> {
try {
const parsed = SubmittedMessageSchema.safeParse(message);
if (parsed.success) {
const id = randomUUID();
await this.cache.set(`${CHAT_MESSAGE_KEY}:${id}`, parsed.data, {
ttl: CHAT_MESSAGE_TTL,
});
return id;
}
} catch (e: any) {
this.logger.error(`Failed to get chat message from cache: ${e.message}`);
}
return undefined;
}
}

View File

@@ -0,0 +1,92 @@
import assert from 'node:assert';
import {
CopilotCapability,
CopilotImageToImageProvider,
CopilotProviderType,
PromptMessage,
} from '../types';
export type FalConfig = {
apiKey: string;
};
export type FalResponse = {
images: Array<{ url: string }>;
};
export class FalProvider implements CopilotImageToImageProvider {
static readonly type = CopilotProviderType.FAL;
static readonly capabilities = [CopilotCapability.ImageToImage];
readonly availableModels = [
// image to image
// https://blog.fal.ai/building-applications-with-real-time-stable-diffusion-apis/
'110602490-lcm-sd15-i2i',
];
constructor(private readonly config: FalConfig) {
assert(FalProvider.assetsConfig(config));
}
static assetsConfig(config: FalConfig) {
return !!config.apiKey;
}
getCapabilities(): CopilotCapability[] {
return FalProvider.capabilities;
}
// ====== image to image ======
async generateImages(
messages: PromptMessage[],
model: string = this.availableModels[0],
options: {
signal?: AbortSignal;
user?: string;
} = {}
): Promise<Array<string>> {
const { content, attachments } = messages.pop() || {};
if (!this.availableModels.includes(model)) {
throw new Error(`Invalid model: ${model}`);
}
if (!content) {
throw new Error('Prompt is required');
}
if (!Array.isArray(attachments) || !attachments.length) {
throw new Error('Attachments is required');
}
const data = (await fetch(`https://${model}.gateway.alpha.fal.ai/`, {
method: 'POST',
headers: {
Authorization: `key ${this.config.apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
image_url: attachments[0],
prompt: content,
sync_mode: true,
seed: 42,
enable_safety_checks: false,
}),
signal: options.signal,
}).then(res => res.json())) as FalResponse;
return data.images.map(image => image.url);
}
async *generateImagesStream(
messages: PromptMessage[],
model: string = this.availableModels[0],
options: {
signal?: AbortSignal;
user?: string;
} = {}
): AsyncIterable<string> {
const ret = await this.generateImages(messages, model, options);
for (const url of ret) {
yield url;
}
}
}

View File

@@ -134,4 +134,5 @@ export class CopilotProviderService {
}
}
export { FalProvider } from './fal';
export { OpenAIProvider } from './openai';

View File

@@ -5,22 +5,31 @@ import { ClientOptions, OpenAI } from 'openai';
import {
ChatMessageRole,
CopilotCapability,
CopilotImageToTextProvider,
CopilotProviderType,
CopilotTextToEmbeddingProvider,
CopilotTextToImageProvider,
CopilotTextToTextProvider,
PromptMessage,
} from '../types';
const DEFAULT_DIMENSIONS = 256;
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;
export class OpenAIProvider
implements CopilotTextToTextProvider, CopilotTextToEmbeddingProvider
implements
CopilotTextToTextProvider,
CopilotTextToEmbeddingProvider,
CopilotTextToImageProvider,
CopilotImageToTextProvider
{
static readonly type = CopilotProviderType.OpenAI;
static readonly capabilities = [
CopilotCapability.TextToText,
CopilotCapability.TextToEmbedding,
CopilotCapability.TextToImage,
CopilotCapability.ImageToText,
];
readonly availableModels = [
@@ -35,6 +44,8 @@ export class OpenAIProvider
// moderation
'text-moderation-latest',
'text-moderation-stable',
// text to image
'dall-e-3',
];
private readonly instance: OpenAI;
@@ -52,12 +63,29 @@ export class OpenAIProvider
return OpenAIProvider.capabilities;
}
private chatToGPTMessage(messages: PromptMessage[]) {
private chatToGPTMessage(
messages: PromptMessage[]
): OpenAI.Chat.Completions.ChatCompletionMessageParam[] {
// filter redundant fields
return messages.map(message => ({
role: message.role,
content: message.content,
}));
return messages.map(({ role, content, attachments }) => {
if (Array.isArray(attachments)) {
const contents = [
{ type: 'text', text: content },
...attachments
.filter(url => SIMPLE_IMAGE_URL_REGEX.test(url))
.map(url => ({
type: 'image_url',
image_url: { url, detail: 'low' },
})),
];
return {
role,
content: contents,
} as OpenAI.Chat.Completions.ChatCompletionMessageParam;
} else {
return { role, content };
}
});
}
private checkParams({
@@ -194,4 +222,44 @@ export class OpenAIProvider
});
return result.data.map(e => e.embedding);
}
// ====== text to image ======
async generateImages(
messages: PromptMessage[],
model: string = 'dall-e-3',
options: {
signal?: AbortSignal;
user?: string;
} = {}
): Promise<Array<string>> {
const { content: prompt } = messages.pop() || {};
if (!prompt) {
throw new Error('Prompt is required');
}
const result = await this.instance.images.generate(
{
prompt,
model,
response_format: 'url',
user: options.user,
},
{ signal: options.signal }
);
return result.data.map(image => image.url).filter((v): v is string => !!v);
}
async *generateImagesStream(
messages: PromptMessage[],
model: string = 'dall-e-3',
options: {
signal?: AbortSignal;
user?: string;
} = {}
): AsyncIterable<string> {
const ret = await this.generateImages(messages, model, options);
for (const url of ret) {
yield url;
}
}
}

View File

@@ -1,3 +1,4 @@
import { Logger } from '@nestjs/common';
import {
Args,
Field,
@@ -12,7 +13,7 @@ import {
} from '@nestjs/graphql';
import { SafeIntResolver } from 'graphql-scalars';
import { CurrentUser, Public } from '../../core/auth';
import { CurrentUser } from '../../core/auth';
import { QuotaService } from '../../core/quota';
import { UserType } from '../../core/user';
import { PermissionService } from '../../core/workspaces/permission';
@@ -21,11 +22,19 @@ import {
PaymentRequiredException,
TooManyRequestsException,
} from '../../fundamentals';
import { ChatSessionService, ListHistoriesOptions } from './session';
import { AvailableModels, type ChatHistory, type ChatMessage } from './types';
import { ChatSessionService } from './session';
import {
AvailableModels,
type ChatHistory,
type ChatMessage,
type ListHistoriesOptions,
SubmittedMessage,
} from './types';
registerEnumType(AvailableModels, { name: 'CopilotModel' });
const COPILOT_LOCKER = 'copilot';
// ================== Input Types ==================
@InputType()
@@ -48,6 +57,21 @@ class CreateChatSessionInput {
promptName!: string;
}
@InputType()
class CreateChatMessageInput implements Omit<SubmittedMessage, 'params'> {
@Field(() => String)
sessionId!: string;
@Field(() => String)
content!: string;
@Field(() => [String], { nullable: true })
attachments!: string[] | undefined;
@Field(() => String, { nullable: true })
params!: string | undefined;
}
@InputType()
class QueryChatHistoriesInput implements Partial<ListHistoriesOptions> {
@Field(() => Boolean, { nullable: true })
@@ -118,6 +142,8 @@ export class CopilotType {
@Resolver(() => CopilotType)
export class CopilotResolver {
private readonly logger = new Logger(CopilotResolver.name);
constructor(
private readonly permissions: PermissionService,
private readonly quota: QuotaService,
@@ -208,7 +234,6 @@ export class CopilotResolver {
);
}
@Public()
@Mutation(() => String, {
description: 'Create a chat session',
})
@@ -222,7 +247,7 @@ export class CopilotResolver {
options.docId,
user.id
);
const lockFlag = `session:${user.id}:${options.workspaceId}`;
const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${options.workspaceId}`;
await using lock = await this.mutex.lock(lockFlag);
if (!lock) {
return new TooManyRequestsException('Server is busy');
@@ -241,6 +266,32 @@ export class CopilotResolver {
});
return session;
}
@Mutation(() => String, {
description: 'Create a chat message',
})
async createCopilotMessage(
@CurrentUser() user: CurrentUser,
@Args({ name: 'options', type: () => CreateChatMessageInput })
options: CreateChatMessageInput
) {
const lockFlag = `${COPILOT_LOCKER}:message:${user?.id}:${options.sessionId}`;
await using lock = await this.mutex.lock(lockFlag);
if (!lock) {
return new TooManyRequestsException('Server is busy');
}
try {
const { params, ...rest } = options;
const record: SubmittedMessage['params'] = {};
new URLSearchParams(params).forEach((value, key) => {
record[key] = value;
});
return await this.chatSession.createMessage({ ...rest, params: record });
} catch (e: any) {
this.logger.error(`Failed to create chat message: ${e.message}`);
throw new Error('Failed to create chat message');
}
}
}
@Resolver(() => UserType)

View File

@@ -3,43 +3,26 @@ import { randomUUID } from 'node:crypto';
import { Injectable, Logger } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
import { ChatMessageCache } from './message';
import { ChatPrompt, PromptService } from './prompt';
import {
AvailableModel,
ChatHistory,
ChatMessage,
ChatMessageSchema,
ChatSessionOptions,
ChatSessionState,
getTokenEncoder,
ListHistoriesOptions,
PromptMessage,
PromptMessageSchema,
PromptParams,
SubmittedMessage,
} from './types';
export interface ChatSessionOptions {
userId: string;
workspaceId: string;
docId: string;
promptName: string;
}
export interface ChatSessionState
extends Omit<ChatSessionOptions, 'promptName'> {
// connect ids
sessionId: string;
// states
prompt: ChatPrompt;
messages: ChatMessage[];
}
export type ListHistoriesOptions = {
action: boolean | undefined;
limit: number | undefined;
skip: number | undefined;
sessionId: string | undefined;
};
export class ChatSession implements AsyncDisposable {
constructor(
private readonly messageCache: ChatMessageCache,
private readonly state: ChatSessionState,
private readonly dispose?: (state: ChatSessionState) => Promise<void>,
private readonly maxTokenSize = 3840
@@ -60,6 +43,29 @@ export class ChatSession implements AsyncDisposable {
this.state.messages.push(message);
}
async getMessageById(messageId: string) {
const message = await this.messageCache.get(messageId);
if (!message || message.sessionId !== this.state.sessionId) {
throw new Error(`Message not found: ${messageId}`);
}
return message;
}
async pushByMessageId(messageId: string) {
const message = await this.messageCache.get(messageId);
if (!message || message.sessionId !== this.state.sessionId) {
throw new Error(`Message not found: ${messageId}`);
}
this.push({
role: 'user',
content: message.content,
attachments: message.attachments,
params: message.params,
createdAt: new Date(),
});
}
pop() {
this.state.messages.pop();
}
@@ -109,6 +115,7 @@ export class ChatSessionService {
constructor(
private readonly db: PrismaClient,
private readonly messageCache: ChatMessageCache,
private readonly prompt: PromptService
) {}
@@ -326,6 +333,10 @@ export class ChatSessionService {
});
}
async createMessage(message: SubmittedMessage): Promise<string | undefined> {
return await this.messageCache.set(message);
}
/**
* usage:
* ``` typescript
@@ -342,7 +353,7 @@ export class ChatSessionService {
async get(sessionId: string): Promise<ChatSession | null> {
const state = await this.getSession(sessionId);
if (state) {
return new ChatSession(state, async state => {
return new ChatSession(this.messageCache, state, async state => {
await this.setSession(state);
});
}

View File

@@ -8,10 +8,12 @@ import {
} from 'tiktoken';
import { z } from 'zod';
import type { ChatPrompt } from './prompt';
export interface CopilotConfig {
openai: OpenAIClientOptions;
fal: {
secret: string;
apiKey: string;
};
}
@@ -27,6 +29,8 @@ export enum AvailableModels {
// moderation
TextModerationLatest = 'text-moderation-latest',
TextModerationStable = 'text-moderation-stable',
// text to image
DallE3 = 'dall-e-3',
}
export type AvailableModel = keyof typeof AvailableModels;
@@ -53,8 +57,7 @@ export const ChatMessageRole = Object.values(AiPromptRole) as [
'user',
];
export const PromptMessageSchema = z.object({
role: z.enum(ChatMessageRole),
const PureMessageSchema = z.object({
content: z.string(),
attachments: z.array(z.string()).optional(),
params: z
@@ -63,6 +66,10 @@ export const PromptMessageSchema = z.object({
.nullable(),
});
export const PromptMessageSchema = PureMessageSchema.extend({
role: z.enum(ChatMessageRole),
}).strict();
export type PromptMessage = z.infer<typeof PromptMessageSchema>;
export type PromptParams = NonNullable<PromptMessage['params']>;
@@ -73,6 +80,12 @@ export const ChatMessageSchema = PromptMessageSchema.extend({
export type ChatMessage = z.infer<typeof ChatMessageSchema>;
export const SubmittedMessageSchema = PureMessageSchema.extend({
sessionId: z.string(),
}).strict();
export type SubmittedMessage = z.infer<typeof SubmittedMessageSchema>;
export const ChatHistorySchema = z
.object({
sessionId: z.string(),
@@ -84,6 +97,32 @@ export const ChatHistorySchema = z
export type ChatHistory = z.infer<typeof ChatHistorySchema>;
// ======== Chat Session ========
export interface ChatSessionOptions {
// connect ids
userId: string;
workspaceId: string;
docId: string;
promptName: string;
}
export interface ChatSessionState
extends Omit<ChatSessionOptions, 'promptName'> {
// connect ids
sessionId: string;
// states
prompt: ChatPrompt;
messages: ChatMessage[];
}
export type ListHistoriesOptions = {
action: boolean | undefined;
limit: number | undefined;
skip: number | undefined;
sessionId: string | undefined;
};
// ======== Provider Interface ========
export enum CopilotProviderType {
@@ -96,6 +135,7 @@ export enum CopilotCapability {
TextToEmbedding = 'text-to-embedding',
TextToImage = 'text-to-image',
ImageToImage = 'image-to-image',
ImageToText = 'image-to-text',
}
export interface CopilotProvider {
@@ -137,13 +177,71 @@ export interface CopilotTextToEmbeddingProvider extends CopilotProvider {
): Promise<number[][]>;
}
export interface CopilotTextToImageProvider extends CopilotProvider {}
export interface CopilotTextToImageProvider extends CopilotProvider {
generateImages(
messages: PromptMessage[],
model: string,
options: {
signal?: AbortSignal;
user?: string;
}
): Promise<Array<string>>;
generateImagesStream(
messages: PromptMessage[],
model?: string,
options?: {
signal?: AbortSignal;
user?: string;
}
): AsyncIterable<string>;
}
export interface CopilotImageToImageProvider extends CopilotProvider {}
export interface CopilotImageToTextProvider extends CopilotProvider {
generateText(
messages: PromptMessage[],
model: string,
options: {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
user?: string;
}
): Promise<string>;
generateTextStream(
messages: PromptMessage[],
model: string,
options: {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
user?: string;
}
): AsyncIterable<string>;
}
export interface CopilotImageToImageProvider extends CopilotProvider {
generateImages(
messages: PromptMessage[],
model: string,
options: {
signal?: AbortSignal;
user?: string;
}
): Promise<Array<string>>;
generateImagesStream(
messages: PromptMessage[],
model?: string,
options?: {
signal?: AbortSignal;
user?: string;
}
): AsyncIterable<string>;
}
export type CapabilityToCopilotProvider = {
[CopilotCapability.TextToText]: CopilotTextToTextProvider;
[CopilotCapability.TextToEmbedding]: CopilotTextToEmbeddingProvider;
[CopilotCapability.TextToImage]: CopilotTextToImageProvider;
[CopilotCapability.ImageToText]: CopilotImageToTextProvider;
[CopilotCapability.ImageToImage]: CopilotImageToImageProvider;
};