From 9f349a2300092c61902fcb2aec4349da5aad1504 Mon Sep 17 00:00:00 2001 From: darkskygit Date: Wed, 10 Apr 2024 12:13:39 +0000 Subject: [PATCH] feat: text to image impl (#6437) fix CLOUD-18 fix CLOUD-28 fix CLOUD-29 --- .github/actions/deploy/deploy.mjs | 2 + .../graphql/templates/copilot-secret.yaml | 1 + .../charts/graphql/templates/deployment.yaml | 5 + .github/workflows/deploy.yml | 1 + .../migration.sql | 3 + packages/backend/server/schema.prisma | 27 +-- .../backend/server/src/config/affine.env.ts | 1 + .../src/data/migrations/utils/prompts.ts | 48 ++++-- .../server/src/plugins/copilot/controller.ts | 155 ++++++++++++++---- .../server/src/plugins/copilot/index.ts | 4 + .../server/src/plugins/copilot/message.ts | 35 ++++ .../src/plugins/copilot/providers/fal.ts | 92 +++++++++++ .../src/plugins/copilot/providers/index.ts | 1 + .../src/plugins/copilot/providers/openai.ts | 80 ++++++++- .../server/src/plugins/copilot/resolver.ts | 61 ++++++- .../server/src/plugins/copilot/session.ts | 59 ++++--- .../server/src/plugins/copilot/types.ts | 108 +++++++++++- packages/backend/server/src/schema.gql | 10 ++ packages/frontend/graphql/src/schema.ts | 7 + 19 files changed, 601 insertions(+), 99 deletions(-) create mode 100644 packages/backend/server/src/plugins/copilot/message.ts create mode 100644 packages/backend/server/src/plugins/copilot/providers/fal.ts diff --git a/.github/actions/deploy/deploy.mjs b/.github/actions/deploy/deploy.mjs index 33b5707c8c..e1583cf540 100644 --- a/.github/actions/deploy/deploy.mjs +++ b/.github/actions/deploy/deploy.mjs @@ -15,6 +15,7 @@ const { R2_SECRET_ACCESS_KEY, CAPTCHA_TURNSTILE_SECRET, COPILOT_OPENAI_API_KEY, + COPILOT_FAL_API_KEY, MAILER_SENDER, MAILER_USER, MAILER_PASSWORD, @@ -101,6 +102,7 @@ const createHelmCommand = ({ isDryRun }) => { `--set-string graphql.app.captcha.turnstile.secret="${CAPTCHA_TURNSTILE_SECRET}"`, `--set graphql.app.copilot.enabled=true`, `--set-string graphql.app.copilot.openai.key="${COPILOT_OPENAI_API_KEY}"`, + `--set-string graphql.app.copilot.fal.key="${COPILOT_FAL_API_KEY}"`, `--set graphql.app.objectStorage.r2.enabled=true`, `--set-string graphql.app.objectStorage.r2.accountId="${R2_ACCOUNT_ID}"`, `--set-string graphql.app.objectStorage.r2.accessKeyId="${R2_ACCESS_KEY_ID}"`, diff --git a/.github/helm/affine/charts/graphql/templates/copilot-secret.yaml b/.github/helm/affine/charts/graphql/templates/copilot-secret.yaml index 277b1ff965..26858e63dc 100644 --- a/.github/helm/affine/charts/graphql/templates/copilot-secret.yaml +++ b/.github/helm/affine/charts/graphql/templates/copilot-secret.yaml @@ -6,4 +6,5 @@ metadata: type: Opaque data: openaiSecret: {{ .Values.app.copilot.openai.key | b64enc }} + falSecret: {{ .Values.app.copilot.fal.key | b64enc }} {{- end }} diff --git a/.github/helm/affine/charts/graphql/templates/deployment.yaml b/.github/helm/affine/charts/graphql/templates/deployment.yaml index faa9b02fb6..56f575206d 100644 --- a/.github/helm/affine/charts/graphql/templates/deployment.yaml +++ b/.github/helm/affine/charts/graphql/templates/deployment.yaml @@ -154,6 +154,11 @@ spec: secretKeyRef: name: "{{ .Values.app.copilot.secretName }}" key: openaiSecret + - name: COPILOT_FAL_API_KEY + valueFrom: + secretKeyRef: + name: "{{ .Values.app.copilot.secretName }}" + key: falSecret {{ end }} {{ if .Values.app.oauth.google.enabled }} - name: OAUTH_GOOGLE_ENABLED diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index bc3253da6a..a66ec6e7c5 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -135,6 +135,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} CAPTCHA_TURNSTILE_SECRET: ${{ secrets.CAPTCHA_TURNSTILE_SECRET }} COPILOT_OPENAI_API_KEY: ${{ secrets.COPILOT_OPENAI_API_KEY }} + COPILOT_FAL_API_KEY: ${{ secrets.COPILOT_FAL_API_KEY }} MAILER_SENDER: ${{ secrets.OAUTH_EMAIL_SENDER }} MAILER_USER: ${{ secrets.OAUTH_EMAIL_LOGIN }} MAILER_PASSWORD: ${{ secrets.OAUTH_EMAIL_PASSWORD }} diff --git a/packages/backend/server/migrations/20240402100608_ai_prompt_session_metadata/migration.sql b/packages/backend/server/migrations/20240402100608_ai_prompt_session_metadata/migration.sql index 837d9601ea..1c41993c5c 100644 --- a/packages/backend/server/migrations/20240402100608_ai_prompt_session_metadata/migration.sql +++ b/packages/backend/server/migrations/20240402100608_ai_prompt_session_metadata/migration.sql @@ -26,6 +26,7 @@ CREATE TABLE "ai_prompts_messages" ( "idx" INTEGER NOT NULL, "role" "AiPromptRole" NOT NULL, "content" TEXT NOT NULL, + "attachments" JSON, "params" JSON, "created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP ); @@ -47,6 +48,8 @@ CREATE TABLE "ai_sessions_messages" ( "session_id" VARCHAR(36) NOT NULL, "role" "AiPromptRole" NOT NULL, "content" TEXT NOT NULL, + "attachments" JSON, + "params" JSON, "created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, "updated_at" TIMESTAMPTZ(6) NOT NULL, diff --git a/packages/backend/server/schema.prisma b/packages/backend/server/schema.prisma index f9f5ae0696..920268a1c1 100644 --- a/packages/backend/server/schema.prisma +++ b/packages/backend/server/schema.prisma @@ -430,15 +430,16 @@ enum AiPromptRole { } model AiPromptMessage { - promptId Int @map("prompt_id") @db.Integer + promptId Int @map("prompt_id") @db.Integer // if a group of prompts contains multiple sentences, idx specifies the order of each sentence - idx Int @db.Integer + idx Int @db.Integer // system/assistant/user - role AiPromptRole + role AiPromptRole // prompt content - content String @db.Text - params Json? @db.Json - createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) + content String @db.Text + attachments Json? @db.Json + params Json? @db.Json + createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) prompt AiPrompt @relation(fields: [promptId], references: [id], onDelete: Cascade) @@ -462,12 +463,14 @@ model AiPrompt { } model AiSessionMessage { - id String @id @default(uuid()) @db.VarChar(36) - sessionId String @map("session_id") @db.VarChar(36) - role AiPromptRole - content String @db.Text - createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) - updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(6) + id String @id @default(uuid()) @db.VarChar(36) + sessionId String @map("session_id") @db.VarChar(36) + role AiPromptRole + content String @db.Text + attachments Json? @db.Json + params Json? @db.Json + createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) + updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(6) session AiSession @relation(fields: [sessionId], references: [id], onDelete: Cascade) diff --git a/packages/backend/server/src/config/affine.env.ts b/packages/backend/server/src/config/affine.env.ts index 148cb6585e..c393f6b89d 100644 --- a/packages/backend/server/src/config/affine.env.ts +++ b/packages/backend/server/src/config/affine.env.ts @@ -20,6 +20,7 @@ AFFiNE.ENV_MAP = { THROTTLE_TTL: ['rateLimiter.ttl', 'int'], THROTTLE_LIMIT: ['rateLimiter.limit', 'int'], COPILOT_OPENAI_API_KEY: 'plugins.copilot.openai.apiKey', + COPILOT_FAL_API_KEY: 'plugins.copilot.fal.apiKey', REDIS_SERVER_HOST: 'plugins.redis.host', REDIS_SERVER_PORT: ['plugins.redis.port', 'int'], REDIS_SERVER_USER: 'plugins.redis.username', diff --git a/packages/backend/server/src/data/migrations/utils/prompts.ts b/packages/backend/server/src/data/migrations/utils/prompts.ts index 6423658a0d..6818b0728c 100644 --- a/packages/backend/server/src/data/migrations/utils/prompts.ts +++ b/packages/backend/server/src/data/migrations/utils/prompts.ts @@ -31,10 +31,22 @@ export const prompts: Prompt[] = [ model: 'gpt-4-vision-preview', messages: [], }, + { + name: 'debug:action:dalle3', + action: 'image', + model: 'dall-e-3', + messages: [], + }, + { + name: 'debug:action:fal-sd15', + action: 'image', + model: '110602490-lcm-sd15-i2i', + messages: [], + }, { name: 'Summary', action: 'text', - model: 'gpt-3.5-turbo', + model: 'gpt-4-turbo-preview', messages: [ { role: 'assistant', @@ -46,7 +58,7 @@ export const prompts: Prompt[] = [ { name: 'Summary the webpage', action: 'text', - model: 'gpt-3.5-turbo', + model: 'gpt-4-turbo-preview', messages: [ { role: 'assistant', @@ -58,7 +70,7 @@ export const prompts: Prompt[] = [ { name: 'Explain this image', action: 'text', - model: 'gpt-3.5-turbo', + model: 'gpt-4-vision-preview', messages: [ { role: 'assistant', @@ -70,7 +82,7 @@ export const prompts: Prompt[] = [ { name: 'Explain this code', action: 'text', - model: 'gpt-3.5-turbo', + model: 'gpt-4-turbo-preview', messages: [ { role: 'assistant', @@ -82,7 +94,7 @@ export const prompts: Prompt[] = [ { name: 'Translate to', action: 'text', - model: 'gpt-3.5-turbo', + model: 'gpt-4-turbo-preview', messages: [ { role: 'assistant', @@ -108,7 +120,7 @@ export const prompts: Prompt[] = [ { name: 'Write an article about this', action: 'text', - model: 'gpt-3.5-turbo', + model: 'gpt-4-turbo-preview', messages: [ { role: 'assistant', @@ -119,7 +131,7 @@ export const prompts: Prompt[] = [ { name: 'Write a twitter about this', action: 'text', - model: 'gpt-3.5-turbo', + model: 'gpt-4-turbo-preview', messages: [ { role: 'assistant', @@ -130,7 +142,7 @@ export const prompts: Prompt[] = [ { name: 'Write a poem about this', action: 'text', - model: 'gpt-3.5-turbo', + model: 'gpt-4-turbo-preview', messages: [ { role: 'assistant', @@ -141,7 +153,7 @@ export const prompts: Prompt[] = [ { name: 'Write a blog post about this', action: 'text', - model: 'gpt-3.5-turbo', + model: 'gpt-4-turbo-preview', messages: [ { role: 'assistant', @@ -152,7 +164,7 @@ export const prompts: Prompt[] = [ { name: 'Change tone to', action: 'text', - model: 'gpt-3.5-turbo', + model: 'gpt-4-turbo-preview', messages: [ { role: 'assistant', @@ -165,7 +177,7 @@ export const prompts: Prompt[] = [ { name: 'Brainstorm ideas about this', action: 'text', - model: 'gpt-3.5-turbo', + model: 'gpt-4-turbo-preview', messages: [ { role: 'assistant', @@ -177,7 +189,7 @@ export const prompts: Prompt[] = [ { name: 'Improve writing for it', action: 'text', - model: 'gpt-3.5-turbo', + model: 'gpt-4-turbo-preview', messages: [ { role: 'assistant', @@ -189,7 +201,7 @@ export const prompts: Prompt[] = [ { name: 'Improve grammar for it', action: 'text', - model: 'gpt-3.5-turbo', + model: 'gpt-4-turbo-preview', messages: [ { role: 'assistant', @@ -201,7 +213,7 @@ export const prompts: Prompt[] = [ { name: 'Fix spelling for it', action: 'text', - model: 'gpt-3.5-turbo', + model: 'gpt-4-turbo-preview', messages: [ { role: 'assistant', @@ -227,7 +239,7 @@ export const prompts: Prompt[] = [ { name: 'Find action items from it', action: 'todo-list', - model: 'gpt-3.5-turbo', + model: 'gpt-4-turbo-preview', messages: [ { role: 'assistant', @@ -239,7 +251,7 @@ export const prompts: Prompt[] = [ { name: 'Check code error', action: 'text', - model: 'gpt-3.5-turbo', + model: 'gpt-4-turbo-preview', messages: [ { role: 'assistant', @@ -251,7 +263,7 @@ export const prompts: Prompt[] = [ { name: 'Create a presentation', action: 'text', - model: 'gpt-3.5-turbo', + model: 'gpt-4-turbo-preview', messages: [ { role: 'assistant', @@ -263,7 +275,7 @@ export const prompts: Prompt[] = [ { name: 'Create headings', action: 'text', - model: 'gpt-3.5-turbo', + model: 'gpt-4-turbo-preview', messages: [ { role: 'assistant', diff --git a/packages/backend/server/src/plugins/copilot/controller.ts b/packages/backend/server/src/plugins/copilot/controller.ts index 58678efd1a..bc527ebd79 100644 --- a/packages/backend/server/src/plugins/copilot/controller.ts +++ b/packages/backend/server/src/plugins/copilot/controller.ts @@ -23,12 +23,13 @@ import { import { Public } from '../../core/auth'; import { CurrentUser } from '../../core/auth/current-user'; import { CopilotProviderService } from './providers'; -import { ChatSessionService } from './session'; +import { ChatSession, ChatSessionService } from './session'; import { CopilotCapability } from './types'; export interface ChatEvent { - data: string; + type: 'attachment' | 'message'; id?: string; + data: string; } @Controller('/api/copilot') @@ -38,13 +39,54 @@ export class CopilotController { private readonly provider: CopilotProviderService ) {} + private async hasAttachment(sessionId: string, messageId?: string) { + const session = await this.chatSession.get(sessionId); + if (!session) { + throw new BadRequestException('Session not found'); + } + + if (messageId) { + const message = await session.getMessageById(messageId); + if (Array.isArray(message.attachments) && message.attachments.length) { + return true; + } + } + return false; + } + + private async appendSessionMessage( + sessionId: string, + message?: string, + messageId?: string + ): Promise { + const session = await this.chatSession.get(sessionId); + if (!session) { + throw new BadRequestException('Session not found'); + } + + if (messageId) { + await session.pushByMessageId(messageId); + } else { + if (!message || !message.trim()) { + throw new BadRequestException('Message is empty'); + } + session.push({ + role: 'user', + content: decodeURIComponent(message), + createdAt: new Date(), + }); + } + return session; + } + @Public() @Get('/chat/:sessionId') async chat( @CurrentUser() user: CurrentUser, @Req() req: Request, @Param('sessionId') sessionId: string, - @Query('message') content: string, + @Query('message') message: string | undefined, + @Query('messageId') messageId: string | undefined, @Query() params: Record ): Promise { const provider = this.provider.getProviderByCapability( @@ -53,21 +95,16 @@ export class CopilotController { if (!provider) { throw new InternalServerErrorException('No provider available'); } - const session = await this.chatSession.get(sessionId); - if (!session) { - throw new BadRequestException('Session not found'); - } - if (!content || !content.trim()) { - throw new BadRequestException('Message is empty'); - } - session.push({ - role: 'user', - content: decodeURIComponent(content), - createdAt: new Date(), - }); + + const session = await this.appendSessionMessage( + sessionId, + message, + messageId + ); try { delete params.message; + delete params.messageId; const content = await provider.generateText( session.finish(params), session.model, @@ -98,7 +135,8 @@ export class CopilotController { @CurrentUser() user: CurrentUser, @Req() req: Request, @Param('sessionId') sessionId: string, - @Query('message') content: string, + @Query('message') message: string | undefined, + @Query('messageId') messageId: string | undefined, @Query() params: Record ): Promise> { const provider = this.provider.getProviderByCapability( @@ -107,20 +145,15 @@ export class CopilotController { if (!provider) { throw new InternalServerErrorException('No provider available'); } - const session = await this.chatSession.get(sessionId); - if (!session) { - throw new BadRequestException('Session not found'); - } - if (!content || !content.trim()) { - throw new BadRequestException('Message is empty'); - } - session.push({ - role: 'user', - content: decodeURIComponent(content), - createdAt: new Date(), - }); + + const session = await this.appendSessionMessage( + sessionId, + message, + messageId + ); delete params.message; + delete params.messageId; return from( provider.generateTextStream(session.finish(params), session.model, { signal: req.signal, @@ -130,7 +163,9 @@ export class CopilotController { connect(shared$ => merge( // actual chat event stream - shared$.pipe(map(data => ({ id: sessionId, data }))), + shared$.pipe( + map(data => ({ type: 'message' as const, id: sessionId, data })) + ), // save the generated text to the session shared$.pipe( toArray(), @@ -148,4 +183,66 @@ export class CopilotController { ) ); } + + @Public() + @Sse('/chat/:sessionId/images') + async chatImagesStream( + @CurrentUser() user: CurrentUser | undefined, + @Req() req: Request, + @Param('sessionId') sessionId: string, + @Query('message') message: string | undefined, + @Query('messageId') messageId: string | undefined, + @Query() params: Record + ): Promise> { + const provider = this.provider.getProviderByCapability( + (await this.hasAttachment(sessionId, messageId)) + ? CopilotCapability.ImageToImage + : CopilotCapability.TextToImage + ); + if (!provider) { + throw new InternalServerErrorException('No provider available'); + } + + const session = await this.appendSessionMessage( + sessionId, + message, + messageId + ); + + delete params.message; + delete params.messageId; + return from( + provider.generateImagesStream(session.finish(params), session.model, { + signal: req.signal, + user: user?.id, + }) + ).pipe( + connect(shared$ => + merge( + // actual chat event stream + shared$.pipe( + map(attachment => ({ + type: 'attachment' as const, + id: sessionId, + data: attachment, + })) + ), + // save the generated text to the session + shared$.pipe( + toArray(), + concatMap(attachments => { + session.push({ + role: 'assistant', + content: '', + attachments: attachments, + createdAt: new Date(), + }); + return from(session.save()); + }), + switchMap(() => EMPTY) + ) + ) + ) + ); + } } diff --git a/packages/backend/server/src/plugins/copilot/index.ts b/packages/backend/server/src/plugins/copilot/index.ts index d3f7185f93..370e17cec5 100644 --- a/packages/backend/server/src/plugins/copilot/index.ts +++ b/packages/backend/server/src/plugins/copilot/index.ts @@ -3,16 +3,19 @@ import { QuotaService } from '../../core/quota'; import { PermissionService } from '../../core/workspaces/permission'; import { Plugin } from '../registry'; import { CopilotController } from './controller'; +import { ChatMessageCache } from './message'; import { PromptService } from './prompt'; import { assertProvidersConfigs, CopilotProviderService, + FalProvider, OpenAIProvider, registerCopilotProvider, } from './providers'; import { CopilotResolver, UserCopilotResolver } from './resolver'; import { ChatSessionService } from './session'; +registerCopilotProvider(FalProvider); registerCopilotProvider(OpenAIProvider); @Plugin({ @@ -22,6 +25,7 @@ registerCopilotProvider(OpenAIProvider); QuotaService, ChatSessionService, CopilotResolver, + ChatMessageCache, UserCopilotResolver, PromptService, CopilotProviderService, diff --git a/packages/backend/server/src/plugins/copilot/message.ts b/packages/backend/server/src/plugins/copilot/message.ts new file mode 100644 index 0000000000..2810143eb8 --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/message.ts @@ -0,0 +1,35 @@ +import { randomUUID } from 'node:crypto'; + +import { Injectable, Logger } from '@nestjs/common'; + +import { SessionCache } from '../../fundamentals'; +import { SubmittedMessage, SubmittedMessageSchema } from './types'; + +const CHAT_MESSAGE_KEY = 'chat-message'; +const CHAT_MESSAGE_TTL = 3600 * 1 * 1000; // 1 hours + +@Injectable() +export class ChatMessageCache { + private readonly logger = new Logger(ChatMessageCache.name); + constructor(private readonly cache: SessionCache) {} + + async get(id: string): Promise { + return await this.cache.get(`${CHAT_MESSAGE_KEY}:${id}`); + } + + async set(message: SubmittedMessage): Promise { + try { + const parsed = SubmittedMessageSchema.safeParse(message); + if (parsed.success) { + const id = randomUUID(); + await this.cache.set(`${CHAT_MESSAGE_KEY}:${id}`, parsed.data, { + ttl: CHAT_MESSAGE_TTL, + }); + return id; + } + } catch (e: any) { + this.logger.error(`Failed to get chat message from cache: ${e.message}`); + } + return undefined; + } +} diff --git a/packages/backend/server/src/plugins/copilot/providers/fal.ts b/packages/backend/server/src/plugins/copilot/providers/fal.ts new file mode 100644 index 0000000000..addb8d8b14 --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/providers/fal.ts @@ -0,0 +1,92 @@ +import assert from 'node:assert'; + +import { + CopilotCapability, + CopilotImageToImageProvider, + CopilotProviderType, + PromptMessage, +} from '../types'; + +export type FalConfig = { + apiKey: string; +}; + +export type FalResponse = { + images: Array<{ url: string }>; +}; + +export class FalProvider implements CopilotImageToImageProvider { + static readonly type = CopilotProviderType.FAL; + static readonly capabilities = [CopilotCapability.ImageToImage]; + + readonly availableModels = [ + // image to image + // https://blog.fal.ai/building-applications-with-real-time-stable-diffusion-apis/ + '110602490-lcm-sd15-i2i', + ]; + + constructor(private readonly config: FalConfig) { + assert(FalProvider.assetsConfig(config)); + } + + static assetsConfig(config: FalConfig) { + return !!config.apiKey; + } + + getCapabilities(): CopilotCapability[] { + return FalProvider.capabilities; + } + + // ====== image to image ====== + async generateImages( + messages: PromptMessage[], + model: string = this.availableModels[0], + options: { + signal?: AbortSignal; + user?: string; + } = {} + ): Promise> { + const { content, attachments } = messages.pop() || {}; + if (!this.availableModels.includes(model)) { + throw new Error(`Invalid model: ${model}`); + } + if (!content) { + throw new Error('Prompt is required'); + } + if (!Array.isArray(attachments) || !attachments.length) { + throw new Error('Attachments is required'); + } + + const data = (await fetch(`https://${model}.gateway.alpha.fal.ai/`, { + method: 'POST', + headers: { + Authorization: `key ${this.config.apiKey}`, + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + image_url: attachments[0], + prompt: content, + sync_mode: true, + seed: 42, + enable_safety_checks: false, + }), + signal: options.signal, + }).then(res => res.json())) as FalResponse; + + return data.images.map(image => image.url); + } + + async *generateImagesStream( + messages: PromptMessage[], + model: string = this.availableModels[0], + options: { + signal?: AbortSignal; + user?: string; + } = {} + ): AsyncIterable { + const ret = await this.generateImages(messages, model, options); + for (const url of ret) { + yield url; + } + } +} diff --git a/packages/backend/server/src/plugins/copilot/providers/index.ts b/packages/backend/server/src/plugins/copilot/providers/index.ts index 52164d2a3d..0baeb5d2b1 100644 --- a/packages/backend/server/src/plugins/copilot/providers/index.ts +++ b/packages/backend/server/src/plugins/copilot/providers/index.ts @@ -134,4 +134,5 @@ export class CopilotProviderService { } } +export { FalProvider } from './fal'; export { OpenAIProvider } from './openai'; diff --git a/packages/backend/server/src/plugins/copilot/providers/openai.ts b/packages/backend/server/src/plugins/copilot/providers/openai.ts index af85794466..2084d6a5cb 100644 --- a/packages/backend/server/src/plugins/copilot/providers/openai.ts +++ b/packages/backend/server/src/plugins/copilot/providers/openai.ts @@ -5,22 +5,31 @@ import { ClientOptions, OpenAI } from 'openai'; import { ChatMessageRole, CopilotCapability, + CopilotImageToTextProvider, CopilotProviderType, CopilotTextToEmbeddingProvider, + CopilotTextToImageProvider, CopilotTextToTextProvider, PromptMessage, } from '../types'; const DEFAULT_DIMENSIONS = 256; +const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/; + export class OpenAIProvider - implements CopilotTextToTextProvider, CopilotTextToEmbeddingProvider + implements + CopilotTextToTextProvider, + CopilotTextToEmbeddingProvider, + CopilotTextToImageProvider, + CopilotImageToTextProvider { static readonly type = CopilotProviderType.OpenAI; static readonly capabilities = [ CopilotCapability.TextToText, CopilotCapability.TextToEmbedding, CopilotCapability.TextToImage, + CopilotCapability.ImageToText, ]; readonly availableModels = [ @@ -35,6 +44,8 @@ export class OpenAIProvider // moderation 'text-moderation-latest', 'text-moderation-stable', + // text to image + 'dall-e-3', ]; private readonly instance: OpenAI; @@ -52,12 +63,29 @@ export class OpenAIProvider return OpenAIProvider.capabilities; } - private chatToGPTMessage(messages: PromptMessage[]) { + private chatToGPTMessage( + messages: PromptMessage[] + ): OpenAI.Chat.Completions.ChatCompletionMessageParam[] { // filter redundant fields - return messages.map(message => ({ - role: message.role, - content: message.content, - })); + return messages.map(({ role, content, attachments }) => { + if (Array.isArray(attachments)) { + const contents = [ + { type: 'text', text: content }, + ...attachments + .filter(url => SIMPLE_IMAGE_URL_REGEX.test(url)) + .map(url => ({ + type: 'image_url', + image_url: { url, detail: 'low' }, + })), + ]; + return { + role, + content: contents, + } as OpenAI.Chat.Completions.ChatCompletionMessageParam; + } else { + return { role, content }; + } + }); } private checkParams({ @@ -194,4 +222,44 @@ export class OpenAIProvider }); return result.data.map(e => e.embedding); } + + // ====== text to image ====== + async generateImages( + messages: PromptMessage[], + model: string = 'dall-e-3', + options: { + signal?: AbortSignal; + user?: string; + } = {} + ): Promise> { + const { content: prompt } = messages.pop() || {}; + if (!prompt) { + throw new Error('Prompt is required'); + } + const result = await this.instance.images.generate( + { + prompt, + model, + response_format: 'url', + user: options.user, + }, + { signal: options.signal } + ); + + return result.data.map(image => image.url).filter((v): v is string => !!v); + } + + async *generateImagesStream( + messages: PromptMessage[], + model: string = 'dall-e-3', + options: { + signal?: AbortSignal; + user?: string; + } = {} + ): AsyncIterable { + const ret = await this.generateImages(messages, model, options); + for (const url of ret) { + yield url; + } + } } diff --git a/packages/backend/server/src/plugins/copilot/resolver.ts b/packages/backend/server/src/plugins/copilot/resolver.ts index 4126def0c1..ae0a544485 100644 --- a/packages/backend/server/src/plugins/copilot/resolver.ts +++ b/packages/backend/server/src/plugins/copilot/resolver.ts @@ -1,3 +1,4 @@ +import { Logger } from '@nestjs/common'; import { Args, Field, @@ -12,7 +13,7 @@ import { } from '@nestjs/graphql'; import { SafeIntResolver } from 'graphql-scalars'; -import { CurrentUser, Public } from '../../core/auth'; +import { CurrentUser } from '../../core/auth'; import { QuotaService } from '../../core/quota'; import { UserType } from '../../core/user'; import { PermissionService } from '../../core/workspaces/permission'; @@ -21,11 +22,19 @@ import { PaymentRequiredException, TooManyRequestsException, } from '../../fundamentals'; -import { ChatSessionService, ListHistoriesOptions } from './session'; -import { AvailableModels, type ChatHistory, type ChatMessage } from './types'; +import { ChatSessionService } from './session'; +import { + AvailableModels, + type ChatHistory, + type ChatMessage, + type ListHistoriesOptions, + SubmittedMessage, +} from './types'; registerEnumType(AvailableModels, { name: 'CopilotModel' }); +const COPILOT_LOCKER = 'copilot'; + // ================== Input Types ================== @InputType() @@ -48,6 +57,21 @@ class CreateChatSessionInput { promptName!: string; } +@InputType() +class CreateChatMessageInput implements Omit { + @Field(() => String) + sessionId!: string; + + @Field(() => String) + content!: string; + + @Field(() => [String], { nullable: true }) + attachments!: string[] | undefined; + + @Field(() => String, { nullable: true }) + params!: string | undefined; +} + @InputType() class QueryChatHistoriesInput implements Partial { @Field(() => Boolean, { nullable: true }) @@ -118,6 +142,8 @@ export class CopilotType { @Resolver(() => CopilotType) export class CopilotResolver { + private readonly logger = new Logger(CopilotResolver.name); + constructor( private readonly permissions: PermissionService, private readonly quota: QuotaService, @@ -208,7 +234,6 @@ export class CopilotResolver { ); } - @Public() @Mutation(() => String, { description: 'Create a chat session', }) @@ -222,7 +247,7 @@ export class CopilotResolver { options.docId, user.id ); - const lockFlag = `session:${user.id}:${options.workspaceId}`; + const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${options.workspaceId}`; await using lock = await this.mutex.lock(lockFlag); if (!lock) { return new TooManyRequestsException('Server is busy'); @@ -241,6 +266,32 @@ export class CopilotResolver { }); return session; } + + @Mutation(() => String, { + description: 'Create a chat message', + }) + async createCopilotMessage( + @CurrentUser() user: CurrentUser, + @Args({ name: 'options', type: () => CreateChatMessageInput }) + options: CreateChatMessageInput + ) { + const lockFlag = `${COPILOT_LOCKER}:message:${user?.id}:${options.sessionId}`; + await using lock = await this.mutex.lock(lockFlag); + if (!lock) { + return new TooManyRequestsException('Server is busy'); + } + try { + const { params, ...rest } = options; + const record: SubmittedMessage['params'] = {}; + new URLSearchParams(params).forEach((value, key) => { + record[key] = value; + }); + return await this.chatSession.createMessage({ ...rest, params: record }); + } catch (e: any) { + this.logger.error(`Failed to create chat message: ${e.message}`); + throw new Error('Failed to create chat message'); + } + } } @Resolver(() => UserType) diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts index 6cf1656496..6fca6d688e 100644 --- a/packages/backend/server/src/plugins/copilot/session.ts +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -3,43 +3,26 @@ import { randomUUID } from 'node:crypto'; import { Injectable, Logger } from '@nestjs/common'; import { PrismaClient } from '@prisma/client'; +import { ChatMessageCache } from './message'; import { ChatPrompt, PromptService } from './prompt'; import { AvailableModel, ChatHistory, ChatMessage, ChatMessageSchema, + ChatSessionOptions, + ChatSessionState, getTokenEncoder, + ListHistoriesOptions, PromptMessage, PromptMessageSchema, PromptParams, + SubmittedMessage, } from './types'; -export interface ChatSessionOptions { - userId: string; - workspaceId: string; - docId: string; - promptName: string; -} - -export interface ChatSessionState - extends Omit { - // connect ids - sessionId: string; - // states - prompt: ChatPrompt; - messages: ChatMessage[]; -} - -export type ListHistoriesOptions = { - action: boolean | undefined; - limit: number | undefined; - skip: number | undefined; - sessionId: string | undefined; -}; - export class ChatSession implements AsyncDisposable { constructor( + private readonly messageCache: ChatMessageCache, private readonly state: ChatSessionState, private readonly dispose?: (state: ChatSessionState) => Promise, private readonly maxTokenSize = 3840 @@ -60,6 +43,29 @@ export class ChatSession implements AsyncDisposable { this.state.messages.push(message); } + async getMessageById(messageId: string) { + const message = await this.messageCache.get(messageId); + if (!message || message.sessionId !== this.state.sessionId) { + throw new Error(`Message not found: ${messageId}`); + } + return message; + } + + async pushByMessageId(messageId: string) { + const message = await this.messageCache.get(messageId); + if (!message || message.sessionId !== this.state.sessionId) { + throw new Error(`Message not found: ${messageId}`); + } + + this.push({ + role: 'user', + content: message.content, + attachments: message.attachments, + params: message.params, + createdAt: new Date(), + }); + } + pop() { this.state.messages.pop(); } @@ -109,6 +115,7 @@ export class ChatSessionService { constructor( private readonly db: PrismaClient, + private readonly messageCache: ChatMessageCache, private readonly prompt: PromptService ) {} @@ -326,6 +333,10 @@ export class ChatSessionService { }); } + async createMessage(message: SubmittedMessage): Promise { + return await this.messageCache.set(message); + } + /** * usage: * ``` typescript @@ -342,7 +353,7 @@ export class ChatSessionService { async get(sessionId: string): Promise { const state = await this.getSession(sessionId); if (state) { - return new ChatSession(state, async state => { + return new ChatSession(this.messageCache, state, async state => { await this.setSession(state); }); } diff --git a/packages/backend/server/src/plugins/copilot/types.ts b/packages/backend/server/src/plugins/copilot/types.ts index 86a73a86df..59870d0888 100644 --- a/packages/backend/server/src/plugins/copilot/types.ts +++ b/packages/backend/server/src/plugins/copilot/types.ts @@ -8,10 +8,12 @@ import { } from 'tiktoken'; import { z } from 'zod'; +import type { ChatPrompt } from './prompt'; + export interface CopilotConfig { openai: OpenAIClientOptions; fal: { - secret: string; + apiKey: string; }; } @@ -27,6 +29,8 @@ export enum AvailableModels { // moderation TextModerationLatest = 'text-moderation-latest', TextModerationStable = 'text-moderation-stable', + // text to image + DallE3 = 'dall-e-3', } export type AvailableModel = keyof typeof AvailableModels; @@ -53,8 +57,7 @@ export const ChatMessageRole = Object.values(AiPromptRole) as [ 'user', ]; -export const PromptMessageSchema = z.object({ - role: z.enum(ChatMessageRole), +const PureMessageSchema = z.object({ content: z.string(), attachments: z.array(z.string()).optional(), params: z @@ -63,6 +66,10 @@ export const PromptMessageSchema = z.object({ .nullable(), }); +export const PromptMessageSchema = PureMessageSchema.extend({ + role: z.enum(ChatMessageRole), +}).strict(); + export type PromptMessage = z.infer; export type PromptParams = NonNullable; @@ -73,6 +80,12 @@ export const ChatMessageSchema = PromptMessageSchema.extend({ export type ChatMessage = z.infer; +export const SubmittedMessageSchema = PureMessageSchema.extend({ + sessionId: z.string(), +}).strict(); + +export type SubmittedMessage = z.infer; + export const ChatHistorySchema = z .object({ sessionId: z.string(), @@ -84,6 +97,32 @@ export const ChatHistorySchema = z export type ChatHistory = z.infer; +// ======== Chat Session ======== + +export interface ChatSessionOptions { + // connect ids + userId: string; + workspaceId: string; + docId: string; + promptName: string; +} + +export interface ChatSessionState + extends Omit { + // connect ids + sessionId: string; + // states + prompt: ChatPrompt; + messages: ChatMessage[]; +} + +export type ListHistoriesOptions = { + action: boolean | undefined; + limit: number | undefined; + skip: number | undefined; + sessionId: string | undefined; +}; + // ======== Provider Interface ======== export enum CopilotProviderType { @@ -96,6 +135,7 @@ export enum CopilotCapability { TextToEmbedding = 'text-to-embedding', TextToImage = 'text-to-image', ImageToImage = 'image-to-image', + ImageToText = 'image-to-text', } export interface CopilotProvider { @@ -137,13 +177,71 @@ export interface CopilotTextToEmbeddingProvider extends CopilotProvider { ): Promise; } -export interface CopilotTextToImageProvider extends CopilotProvider {} +export interface CopilotTextToImageProvider extends CopilotProvider { + generateImages( + messages: PromptMessage[], + model: string, + options: { + signal?: AbortSignal; + user?: string; + } + ): Promise>; + generateImagesStream( + messages: PromptMessage[], + model?: string, + options?: { + signal?: AbortSignal; + user?: string; + } + ): AsyncIterable; +} -export interface CopilotImageToImageProvider extends CopilotProvider {} +export interface CopilotImageToTextProvider extends CopilotProvider { + generateText( + messages: PromptMessage[], + model: string, + options: { + temperature?: number; + maxTokens?: number; + signal?: AbortSignal; + user?: string; + } + ): Promise; + generateTextStream( + messages: PromptMessage[], + model: string, + options: { + temperature?: number; + maxTokens?: number; + signal?: AbortSignal; + user?: string; + } + ): AsyncIterable; +} + +export interface CopilotImageToImageProvider extends CopilotProvider { + generateImages( + messages: PromptMessage[], + model: string, + options: { + signal?: AbortSignal; + user?: string; + } + ): Promise>; + generateImagesStream( + messages: PromptMessage[], + model?: string, + options?: { + signal?: AbortSignal; + user?: string; + } + ): AsyncIterable; +} export type CapabilityToCopilotProvider = { [CopilotCapability.TextToText]: CopilotTextToTextProvider; [CopilotCapability.TextToEmbedding]: CopilotTextToEmbeddingProvider; [CopilotCapability.TextToImage]: CopilotTextToImageProvider; + [CopilotCapability.ImageToText]: CopilotImageToTextProvider; [CopilotCapability.ImageToImage]: CopilotImageToImageProvider; }; diff --git a/packages/backend/server/src/schema.gql b/packages/backend/server/src/schema.gql index e46289db84..ba1e50db18 100644 --- a/packages/backend/server/src/schema.gql +++ b/packages/backend/server/src/schema.gql @@ -37,6 +37,13 @@ type CopilotQuota { used: SafeInt! } +input CreateChatMessageInput { + attachments: [String!] + content: String! + params: String + sessionId: String! +} + input CreateChatSessionInput { """An mark identifying which view to use to display the session""" action: String @@ -167,6 +174,9 @@ type Mutation { """Create a subscription checkout link of stripe""" createCheckoutSession(input: CreateCheckoutSessionInput!): String! + """Create a chat message""" + createCopilotMessage(options: CreateChatMessageInput!): String! + """Create a chat session""" createCopilotSession(options: CreateChatSessionInput!): String! diff --git a/packages/frontend/graphql/src/schema.ts b/packages/frontend/graphql/src/schema.ts index 9c1bd6f58c..035d5dd09c 100644 --- a/packages/frontend/graphql/src/schema.ts +++ b/packages/frontend/graphql/src/schema.ts @@ -34,6 +34,13 @@ export interface Scalars { Upload: { input: File; output: File }; } +export interface CreateChatMessageInput { + attachments: InputMaybe>; + content: Scalars['String']['input']; + params: InputMaybe; + sessionId: Scalars['String']['input']; +} + export interface CreateChatSessionInput { /** An mark identifying which view to use to display the session */ action: InputMaybe;