feat: text to image impl (#6437)

fix CLOUD-18
fix CLOUD-28
fix CLOUD-29
This commit is contained in:
darkskygit
2024-04-10 12:13:39 +00:00
parent 7c38a54f81
commit 9f349a2300
19 changed files with 601 additions and 99 deletions

View File

@@ -15,6 +15,7 @@ const {
R2_SECRET_ACCESS_KEY,
CAPTCHA_TURNSTILE_SECRET,
COPILOT_OPENAI_API_KEY,
COPILOT_FAL_API_KEY,
MAILER_SENDER,
MAILER_USER,
MAILER_PASSWORD,
@@ -101,6 +102,7 @@ const createHelmCommand = ({ isDryRun }) => {
`--set-string graphql.app.captcha.turnstile.secret="${CAPTCHA_TURNSTILE_SECRET}"`,
`--set graphql.app.copilot.enabled=true`,
`--set-string graphql.app.copilot.openai.key="${COPILOT_OPENAI_API_KEY}"`,
`--set-string graphql.app.copilot.fal.key="${COPILOT_FAL_API_KEY}"`,
`--set graphql.app.objectStorage.r2.enabled=true`,
`--set-string graphql.app.objectStorage.r2.accountId="${R2_ACCOUNT_ID}"`,
`--set-string graphql.app.objectStorage.r2.accessKeyId="${R2_ACCESS_KEY_ID}"`,

View File

@@ -6,4 +6,5 @@ metadata:
type: Opaque
data:
openaiSecret: {{ .Values.app.copilot.openai.key | b64enc }}
falSecret: {{ .Values.app.copilot.fal.key | b64enc }}
{{- end }}

View File

@@ -154,6 +154,11 @@ spec:
secretKeyRef:
name: "{{ .Values.app.copilot.secretName }}"
key: openaiSecret
- name: COPILOT_FAL_API_KEY
valueFrom:
secretKeyRef:
name: "{{ .Values.app.copilot.secretName }}"
key: falSecret
{{ end }}
{{ if .Values.app.oauth.google.enabled }}
- name: OAUTH_GOOGLE_ENABLED

View File

@@ -135,6 +135,7 @@ jobs:
R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }}
CAPTCHA_TURNSTILE_SECRET: ${{ secrets.CAPTCHA_TURNSTILE_SECRET }}
COPILOT_OPENAI_API_KEY: ${{ secrets.COPILOT_OPENAI_API_KEY }}
COPILOT_FAL_API_KEY: ${{ secrets.COPILOT_FAL_API_KEY }}
MAILER_SENDER: ${{ secrets.OAUTH_EMAIL_SENDER }}
MAILER_USER: ${{ secrets.OAUTH_EMAIL_LOGIN }}
MAILER_PASSWORD: ${{ secrets.OAUTH_EMAIL_PASSWORD }}

View File

@@ -26,6 +26,7 @@ CREATE TABLE "ai_prompts_messages" (
"idx" INTEGER NOT NULL,
"role" "AiPromptRole" NOT NULL,
"content" TEXT NOT NULL,
"attachments" JSON,
"params" JSON,
"created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP
);
@@ -47,6 +48,8 @@ CREATE TABLE "ai_sessions_messages" (
"session_id" VARCHAR(36) NOT NULL,
"role" "AiPromptRole" NOT NULL,
"content" TEXT NOT NULL,
"attachments" JSON,
"params" JSON,
"created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMPTZ(6) NOT NULL,

View File

@@ -430,15 +430,16 @@ enum AiPromptRole {
}
model AiPromptMessage {
promptId Int @map("prompt_id") @db.Integer
promptId Int @map("prompt_id") @db.Integer
// if a group of prompts contains multiple sentences, idx specifies the order of each sentence
idx Int @db.Integer
idx Int @db.Integer
// system/assistant/user
role AiPromptRole
role AiPromptRole
// prompt content
content String @db.Text
params Json? @db.Json
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
content String @db.Text
attachments Json? @db.Json
params Json? @db.Json
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
prompt AiPrompt @relation(fields: [promptId], references: [id], onDelete: Cascade)
@@ -462,12 +463,14 @@ model AiPrompt {
}
model AiSessionMessage {
id String @id @default(uuid()) @db.VarChar(36)
sessionId String @map("session_id") @db.VarChar(36)
role AiPromptRole
content String @db.Text
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(6)
id String @id @default(uuid()) @db.VarChar(36)
sessionId String @map("session_id") @db.VarChar(36)
role AiPromptRole
content String @db.Text
attachments Json? @db.Json
params Json? @db.Json
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(6)
session AiSession @relation(fields: [sessionId], references: [id], onDelete: Cascade)

View File

@@ -20,6 +20,7 @@ AFFiNE.ENV_MAP = {
THROTTLE_TTL: ['rateLimiter.ttl', 'int'],
THROTTLE_LIMIT: ['rateLimiter.limit', 'int'],
COPILOT_OPENAI_API_KEY: 'plugins.copilot.openai.apiKey',
COPILOT_FAL_API_KEY: 'plugins.copilot.fal.apiKey',
REDIS_SERVER_HOST: 'plugins.redis.host',
REDIS_SERVER_PORT: ['plugins.redis.port', 'int'],
REDIS_SERVER_USER: 'plugins.redis.username',

View File

@@ -31,10 +31,22 @@ export const prompts: Prompt[] = [
model: 'gpt-4-vision-preview',
messages: [],
},
{
name: 'debug:action:dalle3',
action: 'image',
model: 'dall-e-3',
messages: [],
},
{
name: 'debug:action:fal-sd15',
action: 'image',
model: '110602490-lcm-sd15-i2i',
messages: [],
},
{
name: 'Summary',
action: 'text',
model: 'gpt-3.5-turbo',
model: 'gpt-4-turbo-preview',
messages: [
{
role: 'assistant',
@@ -46,7 +58,7 @@ export const prompts: Prompt[] = [
{
name: 'Summary the webpage',
action: 'text',
model: 'gpt-3.5-turbo',
model: 'gpt-4-turbo-preview',
messages: [
{
role: 'assistant',
@@ -58,7 +70,7 @@ export const prompts: Prompt[] = [
{
name: 'Explain this image',
action: 'text',
model: 'gpt-3.5-turbo',
model: 'gpt-4-vision-preview',
messages: [
{
role: 'assistant',
@@ -70,7 +82,7 @@ export const prompts: Prompt[] = [
{
name: 'Explain this code',
action: 'text',
model: 'gpt-3.5-turbo',
model: 'gpt-4-turbo-preview',
messages: [
{
role: 'assistant',
@@ -82,7 +94,7 @@ export const prompts: Prompt[] = [
{
name: 'Translate to',
action: 'text',
model: 'gpt-3.5-turbo',
model: 'gpt-4-turbo-preview',
messages: [
{
role: 'assistant',
@@ -108,7 +120,7 @@ export const prompts: Prompt[] = [
{
name: 'Write an article about this',
action: 'text',
model: 'gpt-3.5-turbo',
model: 'gpt-4-turbo-preview',
messages: [
{
role: 'assistant',
@@ -119,7 +131,7 @@ export const prompts: Prompt[] = [
{
name: 'Write a twitter about this',
action: 'text',
model: 'gpt-3.5-turbo',
model: 'gpt-4-turbo-preview',
messages: [
{
role: 'assistant',
@@ -130,7 +142,7 @@ export const prompts: Prompt[] = [
{
name: 'Write a poem about this',
action: 'text',
model: 'gpt-3.5-turbo',
model: 'gpt-4-turbo-preview',
messages: [
{
role: 'assistant',
@@ -141,7 +153,7 @@ export const prompts: Prompt[] = [
{
name: 'Write a blog post about this',
action: 'text',
model: 'gpt-3.5-turbo',
model: 'gpt-4-turbo-preview',
messages: [
{
role: 'assistant',
@@ -152,7 +164,7 @@ export const prompts: Prompt[] = [
{
name: 'Change tone to',
action: 'text',
model: 'gpt-3.5-turbo',
model: 'gpt-4-turbo-preview',
messages: [
{
role: 'assistant',
@@ -165,7 +177,7 @@ export const prompts: Prompt[] = [
{
name: 'Brainstorm ideas about this',
action: 'text',
model: 'gpt-3.5-turbo',
model: 'gpt-4-turbo-preview',
messages: [
{
role: 'assistant',
@@ -177,7 +189,7 @@ export const prompts: Prompt[] = [
{
name: 'Improve writing for it',
action: 'text',
model: 'gpt-3.5-turbo',
model: 'gpt-4-turbo-preview',
messages: [
{
role: 'assistant',
@@ -189,7 +201,7 @@ export const prompts: Prompt[] = [
{
name: 'Improve grammar for it',
action: 'text',
model: 'gpt-3.5-turbo',
model: 'gpt-4-turbo-preview',
messages: [
{
role: 'assistant',
@@ -201,7 +213,7 @@ export const prompts: Prompt[] = [
{
name: 'Fix spelling for it',
action: 'text',
model: 'gpt-3.5-turbo',
model: 'gpt-4-turbo-preview',
messages: [
{
role: 'assistant',
@@ -227,7 +239,7 @@ export const prompts: Prompt[] = [
{
name: 'Find action items from it',
action: 'todo-list',
model: 'gpt-3.5-turbo',
model: 'gpt-4-turbo-preview',
messages: [
{
role: 'assistant',
@@ -239,7 +251,7 @@ export const prompts: Prompt[] = [
{
name: 'Check code error',
action: 'text',
model: 'gpt-3.5-turbo',
model: 'gpt-4-turbo-preview',
messages: [
{
role: 'assistant',
@@ -251,7 +263,7 @@ export const prompts: Prompt[] = [
{
name: 'Create a presentation',
action: 'text',
model: 'gpt-3.5-turbo',
model: 'gpt-4-turbo-preview',
messages: [
{
role: 'assistant',
@@ -263,7 +275,7 @@ export const prompts: Prompt[] = [
{
name: 'Create headings',
action: 'text',
model: 'gpt-3.5-turbo',
model: 'gpt-4-turbo-preview',
messages: [
{
role: 'assistant',

View File

@@ -23,12 +23,13 @@ import {
import { Public } from '../../core/auth';
import { CurrentUser } from '../../core/auth/current-user';
import { CopilotProviderService } from './providers';
import { ChatSessionService } from './session';
import { ChatSession, ChatSessionService } from './session';
import { CopilotCapability } from './types';
export interface ChatEvent {
data: string;
type: 'attachment' | 'message';
id?: string;
data: string;
}
@Controller('/api/copilot')
@@ -38,13 +39,54 @@ export class CopilotController {
private readonly provider: CopilotProviderService
) {}
private async hasAttachment(sessionId: string, messageId?: string) {
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
}
if (messageId) {
const message = await session.getMessageById(messageId);
if (Array.isArray(message.attachments) && message.attachments.length) {
return true;
}
}
return false;
}
private async appendSessionMessage(
sessionId: string,
message?: string,
messageId?: string
): Promise<ChatSession> {
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
}
if (messageId) {
await session.pushByMessageId(messageId);
} else {
if (!message || !message.trim()) {
throw new BadRequestException('Message is empty');
}
session.push({
role: 'user',
content: decodeURIComponent(message),
createdAt: new Date(),
});
}
return session;
}
@Public()
@Get('/chat/:sessionId')
async chat(
@CurrentUser() user: CurrentUser,
@Req() req: Request,
@Param('sessionId') sessionId: string,
@Query('message') content: string,
@Query('message') message: string | undefined,
@Query('messageId') messageId: string | undefined,
@Query() params: Record<string, string | string[]>
): Promise<string> {
const provider = this.provider.getProviderByCapability(
@@ -53,21 +95,16 @@ export class CopilotController {
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
}
if (!content || !content.trim()) {
throw new BadRequestException('Message is empty');
}
session.push({
role: 'user',
content: decodeURIComponent(content),
createdAt: new Date(),
});
const session = await this.appendSessionMessage(
sessionId,
message,
messageId
);
try {
delete params.message;
delete params.messageId;
const content = await provider.generateText(
session.finish(params),
session.model,
@@ -98,7 +135,8 @@ export class CopilotController {
@CurrentUser() user: CurrentUser,
@Req() req: Request,
@Param('sessionId') sessionId: string,
@Query('message') content: string,
@Query('message') message: string | undefined,
@Query('messageId') messageId: string | undefined,
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
const provider = this.provider.getProviderByCapability(
@@ -107,20 +145,15 @@ export class CopilotController {
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
}
if (!content || !content.trim()) {
throw new BadRequestException('Message is empty');
}
session.push({
role: 'user',
content: decodeURIComponent(content),
createdAt: new Date(),
});
const session = await this.appendSessionMessage(
sessionId,
message,
messageId
);
delete params.message;
delete params.messageId;
return from(
provider.generateTextStream(session.finish(params), session.model, {
signal: req.signal,
@@ -130,7 +163,9 @@ export class CopilotController {
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(map(data => ({ id: sessionId, data }))),
shared$.pipe(
map(data => ({ type: 'message' as const, id: sessionId, data }))
),
// save the generated text to the session
shared$.pipe(
toArray(),
@@ -148,4 +183,66 @@ export class CopilotController {
)
);
}
@Public()
@Sse('/chat/:sessionId/images')
async chatImagesStream(
@CurrentUser() user: CurrentUser | undefined,
@Req() req: Request,
@Param('sessionId') sessionId: string,
@Query('message') message: string | undefined,
@Query('messageId') messageId: string | undefined,
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
const provider = this.provider.getProviderByCapability(
(await this.hasAttachment(sessionId, messageId))
? CopilotCapability.ImageToImage
: CopilotCapability.TextToImage
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.appendSessionMessage(
sessionId,
message,
messageId
);
delete params.message;
delete params.messageId;
return from(
provider.generateImagesStream(session.finish(params), session.model, {
signal: req.signal,
user: user?.id,
})
).pipe(
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(
map(attachment => ({
type: 'attachment' as const,
id: sessionId,
data: attachment,
}))
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(attachments => {
session.push({
role: 'assistant',
content: '',
attachments: attachments,
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
)
);
}
}

View File

@@ -3,16 +3,19 @@ import { QuotaService } from '../../core/quota';
import { PermissionService } from '../../core/workspaces/permission';
import { Plugin } from '../registry';
import { CopilotController } from './controller';
import { ChatMessageCache } from './message';
import { PromptService } from './prompt';
import {
assertProvidersConfigs,
CopilotProviderService,
FalProvider,
OpenAIProvider,
registerCopilotProvider,
} from './providers';
import { CopilotResolver, UserCopilotResolver } from './resolver';
import { ChatSessionService } from './session';
registerCopilotProvider(FalProvider);
registerCopilotProvider(OpenAIProvider);
@Plugin({
@@ -22,6 +25,7 @@ registerCopilotProvider(OpenAIProvider);
QuotaService,
ChatSessionService,
CopilotResolver,
ChatMessageCache,
UserCopilotResolver,
PromptService,
CopilotProviderService,

View File

@@ -0,0 +1,35 @@
import { randomUUID } from 'node:crypto';
import { Injectable, Logger } from '@nestjs/common';
import { SessionCache } from '../../fundamentals';
import { SubmittedMessage, SubmittedMessageSchema } from './types';
const CHAT_MESSAGE_KEY = 'chat-message';
const CHAT_MESSAGE_TTL = 3600 * 1 * 1000; // 1 hours
@Injectable()
export class ChatMessageCache {
private readonly logger = new Logger(ChatMessageCache.name);
constructor(private readonly cache: SessionCache) {}
async get(id: string): Promise<SubmittedMessage | undefined> {
return await this.cache.get(`${CHAT_MESSAGE_KEY}:${id}`);
}
async set(message: SubmittedMessage): Promise<string | undefined> {
try {
const parsed = SubmittedMessageSchema.safeParse(message);
if (parsed.success) {
const id = randomUUID();
await this.cache.set(`${CHAT_MESSAGE_KEY}:${id}`, parsed.data, {
ttl: CHAT_MESSAGE_TTL,
});
return id;
}
} catch (e: any) {
this.logger.error(`Failed to get chat message from cache: ${e.message}`);
}
return undefined;
}
}

View File

@@ -0,0 +1,92 @@
import assert from 'node:assert';
import {
CopilotCapability,
CopilotImageToImageProvider,
CopilotProviderType,
PromptMessage,
} from '../types';
export type FalConfig = {
apiKey: string;
};
export type FalResponse = {
images: Array<{ url: string }>;
};
export class FalProvider implements CopilotImageToImageProvider {
static readonly type = CopilotProviderType.FAL;
static readonly capabilities = [CopilotCapability.ImageToImage];
readonly availableModels = [
// image to image
// https://blog.fal.ai/building-applications-with-real-time-stable-diffusion-apis/
'110602490-lcm-sd15-i2i',
];
constructor(private readonly config: FalConfig) {
assert(FalProvider.assetsConfig(config));
}
static assetsConfig(config: FalConfig) {
return !!config.apiKey;
}
getCapabilities(): CopilotCapability[] {
return FalProvider.capabilities;
}
// ====== image to image ======
async generateImages(
messages: PromptMessage[],
model: string = this.availableModels[0],
options: {
signal?: AbortSignal;
user?: string;
} = {}
): Promise<Array<string>> {
const { content, attachments } = messages.pop() || {};
if (!this.availableModels.includes(model)) {
throw new Error(`Invalid model: ${model}`);
}
if (!content) {
throw new Error('Prompt is required');
}
if (!Array.isArray(attachments) || !attachments.length) {
throw new Error('Attachments is required');
}
const data = (await fetch(`https://${model}.gateway.alpha.fal.ai/`, {
method: 'POST',
headers: {
Authorization: `key ${this.config.apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
image_url: attachments[0],
prompt: content,
sync_mode: true,
seed: 42,
enable_safety_checks: false,
}),
signal: options.signal,
}).then(res => res.json())) as FalResponse;
return data.images.map(image => image.url);
}
async *generateImagesStream(
messages: PromptMessage[],
model: string = this.availableModels[0],
options: {
signal?: AbortSignal;
user?: string;
} = {}
): AsyncIterable<string> {
const ret = await this.generateImages(messages, model, options);
for (const url of ret) {
yield url;
}
}
}

View File

@@ -134,4 +134,5 @@ export class CopilotProviderService {
}
}
export { FalProvider } from './fal';
export { OpenAIProvider } from './openai';

View File

@@ -5,22 +5,31 @@ import { ClientOptions, OpenAI } from 'openai';
import {
ChatMessageRole,
CopilotCapability,
CopilotImageToTextProvider,
CopilotProviderType,
CopilotTextToEmbeddingProvider,
CopilotTextToImageProvider,
CopilotTextToTextProvider,
PromptMessage,
} from '../types';
const DEFAULT_DIMENSIONS = 256;
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;
export class OpenAIProvider
implements CopilotTextToTextProvider, CopilotTextToEmbeddingProvider
implements
CopilotTextToTextProvider,
CopilotTextToEmbeddingProvider,
CopilotTextToImageProvider,
CopilotImageToTextProvider
{
static readonly type = CopilotProviderType.OpenAI;
static readonly capabilities = [
CopilotCapability.TextToText,
CopilotCapability.TextToEmbedding,
CopilotCapability.TextToImage,
CopilotCapability.ImageToText,
];
readonly availableModels = [
@@ -35,6 +44,8 @@ export class OpenAIProvider
// moderation
'text-moderation-latest',
'text-moderation-stable',
// text to image
'dall-e-3',
];
private readonly instance: OpenAI;
@@ -52,12 +63,29 @@ export class OpenAIProvider
return OpenAIProvider.capabilities;
}
private chatToGPTMessage(messages: PromptMessage[]) {
private chatToGPTMessage(
messages: PromptMessage[]
): OpenAI.Chat.Completions.ChatCompletionMessageParam[] {
// filter redundant fields
return messages.map(message => ({
role: message.role,
content: message.content,
}));
return messages.map(({ role, content, attachments }) => {
if (Array.isArray(attachments)) {
const contents = [
{ type: 'text', text: content },
...attachments
.filter(url => SIMPLE_IMAGE_URL_REGEX.test(url))
.map(url => ({
type: 'image_url',
image_url: { url, detail: 'low' },
})),
];
return {
role,
content: contents,
} as OpenAI.Chat.Completions.ChatCompletionMessageParam;
} else {
return { role, content };
}
});
}
private checkParams({
@@ -194,4 +222,44 @@ export class OpenAIProvider
});
return result.data.map(e => e.embedding);
}
// ====== text to image ======
async generateImages(
messages: PromptMessage[],
model: string = 'dall-e-3',
options: {
signal?: AbortSignal;
user?: string;
} = {}
): Promise<Array<string>> {
const { content: prompt } = messages.pop() || {};
if (!prompt) {
throw new Error('Prompt is required');
}
const result = await this.instance.images.generate(
{
prompt,
model,
response_format: 'url',
user: options.user,
},
{ signal: options.signal }
);
return result.data.map(image => image.url).filter((v): v is string => !!v);
}
async *generateImagesStream(
messages: PromptMessage[],
model: string = 'dall-e-3',
options: {
signal?: AbortSignal;
user?: string;
} = {}
): AsyncIterable<string> {
const ret = await this.generateImages(messages, model, options);
for (const url of ret) {
yield url;
}
}
}

View File

@@ -1,3 +1,4 @@
import { Logger } from '@nestjs/common';
import {
Args,
Field,
@@ -12,7 +13,7 @@ import {
} from '@nestjs/graphql';
import { SafeIntResolver } from 'graphql-scalars';
import { CurrentUser, Public } from '../../core/auth';
import { CurrentUser } from '../../core/auth';
import { QuotaService } from '../../core/quota';
import { UserType } from '../../core/user';
import { PermissionService } from '../../core/workspaces/permission';
@@ -21,11 +22,19 @@ import {
PaymentRequiredException,
TooManyRequestsException,
} from '../../fundamentals';
import { ChatSessionService, ListHistoriesOptions } from './session';
import { AvailableModels, type ChatHistory, type ChatMessage } from './types';
import { ChatSessionService } from './session';
import {
AvailableModels,
type ChatHistory,
type ChatMessage,
type ListHistoriesOptions,
SubmittedMessage,
} from './types';
registerEnumType(AvailableModels, { name: 'CopilotModel' });
const COPILOT_LOCKER = 'copilot';
// ================== Input Types ==================
@InputType()
@@ -48,6 +57,21 @@ class CreateChatSessionInput {
promptName!: string;
}
@InputType()
class CreateChatMessageInput implements Omit<SubmittedMessage, 'params'> {
@Field(() => String)
sessionId!: string;
@Field(() => String)
content!: string;
@Field(() => [String], { nullable: true })
attachments!: string[] | undefined;
@Field(() => String, { nullable: true })
params!: string | undefined;
}
@InputType()
class QueryChatHistoriesInput implements Partial<ListHistoriesOptions> {
@Field(() => Boolean, { nullable: true })
@@ -118,6 +142,8 @@ export class CopilotType {
@Resolver(() => CopilotType)
export class CopilotResolver {
private readonly logger = new Logger(CopilotResolver.name);
constructor(
private readonly permissions: PermissionService,
private readonly quota: QuotaService,
@@ -208,7 +234,6 @@ export class CopilotResolver {
);
}
@Public()
@Mutation(() => String, {
description: 'Create a chat session',
})
@@ -222,7 +247,7 @@ export class CopilotResolver {
options.docId,
user.id
);
const lockFlag = `session:${user.id}:${options.workspaceId}`;
const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${options.workspaceId}`;
await using lock = await this.mutex.lock(lockFlag);
if (!lock) {
return new TooManyRequestsException('Server is busy');
@@ -241,6 +266,32 @@ export class CopilotResolver {
});
return session;
}
@Mutation(() => String, {
description: 'Create a chat message',
})
async createCopilotMessage(
@CurrentUser() user: CurrentUser,
@Args({ name: 'options', type: () => CreateChatMessageInput })
options: CreateChatMessageInput
) {
const lockFlag = `${COPILOT_LOCKER}:message:${user?.id}:${options.sessionId}`;
await using lock = await this.mutex.lock(lockFlag);
if (!lock) {
return new TooManyRequestsException('Server is busy');
}
try {
const { params, ...rest } = options;
const record: SubmittedMessage['params'] = {};
new URLSearchParams(params).forEach((value, key) => {
record[key] = value;
});
return await this.chatSession.createMessage({ ...rest, params: record });
} catch (e: any) {
this.logger.error(`Failed to create chat message: ${e.message}`);
throw new Error('Failed to create chat message');
}
}
}
@Resolver(() => UserType)

View File

@@ -3,43 +3,26 @@ import { randomUUID } from 'node:crypto';
import { Injectable, Logger } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
import { ChatMessageCache } from './message';
import { ChatPrompt, PromptService } from './prompt';
import {
AvailableModel,
ChatHistory,
ChatMessage,
ChatMessageSchema,
ChatSessionOptions,
ChatSessionState,
getTokenEncoder,
ListHistoriesOptions,
PromptMessage,
PromptMessageSchema,
PromptParams,
SubmittedMessage,
} from './types';
export interface ChatSessionOptions {
userId: string;
workspaceId: string;
docId: string;
promptName: string;
}
export interface ChatSessionState
extends Omit<ChatSessionOptions, 'promptName'> {
// connect ids
sessionId: string;
// states
prompt: ChatPrompt;
messages: ChatMessage[];
}
export type ListHistoriesOptions = {
action: boolean | undefined;
limit: number | undefined;
skip: number | undefined;
sessionId: string | undefined;
};
export class ChatSession implements AsyncDisposable {
constructor(
private readonly messageCache: ChatMessageCache,
private readonly state: ChatSessionState,
private readonly dispose?: (state: ChatSessionState) => Promise<void>,
private readonly maxTokenSize = 3840
@@ -60,6 +43,29 @@ export class ChatSession implements AsyncDisposable {
this.state.messages.push(message);
}
async getMessageById(messageId: string) {
const message = await this.messageCache.get(messageId);
if (!message || message.sessionId !== this.state.sessionId) {
throw new Error(`Message not found: ${messageId}`);
}
return message;
}
async pushByMessageId(messageId: string) {
const message = await this.messageCache.get(messageId);
if (!message || message.sessionId !== this.state.sessionId) {
throw new Error(`Message not found: ${messageId}`);
}
this.push({
role: 'user',
content: message.content,
attachments: message.attachments,
params: message.params,
createdAt: new Date(),
});
}
pop() {
this.state.messages.pop();
}
@@ -109,6 +115,7 @@ export class ChatSessionService {
constructor(
private readonly db: PrismaClient,
private readonly messageCache: ChatMessageCache,
private readonly prompt: PromptService
) {}
@@ -326,6 +333,10 @@ export class ChatSessionService {
});
}
async createMessage(message: SubmittedMessage): Promise<string | undefined> {
return await this.messageCache.set(message);
}
/**
* usage:
* ``` typescript
@@ -342,7 +353,7 @@ export class ChatSessionService {
async get(sessionId: string): Promise<ChatSession | null> {
const state = await this.getSession(sessionId);
if (state) {
return new ChatSession(state, async state => {
return new ChatSession(this.messageCache, state, async state => {
await this.setSession(state);
});
}

View File

@@ -8,10 +8,12 @@ import {
} from 'tiktoken';
import { z } from 'zod';
import type { ChatPrompt } from './prompt';
export interface CopilotConfig {
openai: OpenAIClientOptions;
fal: {
secret: string;
apiKey: string;
};
}
@@ -27,6 +29,8 @@ export enum AvailableModels {
// moderation
TextModerationLatest = 'text-moderation-latest',
TextModerationStable = 'text-moderation-stable',
// text to image
DallE3 = 'dall-e-3',
}
export type AvailableModel = keyof typeof AvailableModels;
@@ -53,8 +57,7 @@ export const ChatMessageRole = Object.values(AiPromptRole) as [
'user',
];
export const PromptMessageSchema = z.object({
role: z.enum(ChatMessageRole),
const PureMessageSchema = z.object({
content: z.string(),
attachments: z.array(z.string()).optional(),
params: z
@@ -63,6 +66,10 @@ export const PromptMessageSchema = z.object({
.nullable(),
});
export const PromptMessageSchema = PureMessageSchema.extend({
role: z.enum(ChatMessageRole),
}).strict();
export type PromptMessage = z.infer<typeof PromptMessageSchema>;
export type PromptParams = NonNullable<PromptMessage['params']>;
@@ -73,6 +80,12 @@ export const ChatMessageSchema = PromptMessageSchema.extend({
export type ChatMessage = z.infer<typeof ChatMessageSchema>;
export const SubmittedMessageSchema = PureMessageSchema.extend({
sessionId: z.string(),
}).strict();
export type SubmittedMessage = z.infer<typeof SubmittedMessageSchema>;
export const ChatHistorySchema = z
.object({
sessionId: z.string(),
@@ -84,6 +97,32 @@ export const ChatHistorySchema = z
export type ChatHistory = z.infer<typeof ChatHistorySchema>;
// ======== Chat Session ========
export interface ChatSessionOptions {
// connect ids
userId: string;
workspaceId: string;
docId: string;
promptName: string;
}
export interface ChatSessionState
extends Omit<ChatSessionOptions, 'promptName'> {
// connect ids
sessionId: string;
// states
prompt: ChatPrompt;
messages: ChatMessage[];
}
export type ListHistoriesOptions = {
action: boolean | undefined;
limit: number | undefined;
skip: number | undefined;
sessionId: string | undefined;
};
// ======== Provider Interface ========
export enum CopilotProviderType {
@@ -96,6 +135,7 @@ export enum CopilotCapability {
TextToEmbedding = 'text-to-embedding',
TextToImage = 'text-to-image',
ImageToImage = 'image-to-image',
ImageToText = 'image-to-text',
}
export interface CopilotProvider {
@@ -137,13 +177,71 @@ export interface CopilotTextToEmbeddingProvider extends CopilotProvider {
): Promise<number[][]>;
}
export interface CopilotTextToImageProvider extends CopilotProvider {}
export interface CopilotTextToImageProvider extends CopilotProvider {
generateImages(
messages: PromptMessage[],
model: string,
options: {
signal?: AbortSignal;
user?: string;
}
): Promise<Array<string>>;
generateImagesStream(
messages: PromptMessage[],
model?: string,
options?: {
signal?: AbortSignal;
user?: string;
}
): AsyncIterable<string>;
}
export interface CopilotImageToImageProvider extends CopilotProvider {}
export interface CopilotImageToTextProvider extends CopilotProvider {
generateText(
messages: PromptMessage[],
model: string,
options: {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
user?: string;
}
): Promise<string>;
generateTextStream(
messages: PromptMessage[],
model: string,
options: {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
user?: string;
}
): AsyncIterable<string>;
}
export interface CopilotImageToImageProvider extends CopilotProvider {
generateImages(
messages: PromptMessage[],
model: string,
options: {
signal?: AbortSignal;
user?: string;
}
): Promise<Array<string>>;
generateImagesStream(
messages: PromptMessage[],
model?: string,
options?: {
signal?: AbortSignal;
user?: string;
}
): AsyncIterable<string>;
}
export type CapabilityToCopilotProvider = {
[CopilotCapability.TextToText]: CopilotTextToTextProvider;
[CopilotCapability.TextToEmbedding]: CopilotTextToEmbeddingProvider;
[CopilotCapability.TextToImage]: CopilotTextToImageProvider;
[CopilotCapability.ImageToText]: CopilotImageToTextProvider;
[CopilotCapability.ImageToImage]: CopilotImageToImageProvider;
};

View File

@@ -37,6 +37,13 @@ type CopilotQuota {
used: SafeInt!
}
input CreateChatMessageInput {
attachments: [String!]
content: String!
params: String
sessionId: String!
}
input CreateChatSessionInput {
"""An mark identifying which view to use to display the session"""
action: String
@@ -167,6 +174,9 @@ type Mutation {
"""Create a subscription checkout link of stripe"""
createCheckoutSession(input: CreateCheckoutSessionInput!): String!
"""Create a chat message"""
createCopilotMessage(options: CreateChatMessageInput!): String!
"""Create a chat session"""
createCopilotSession(options: CreateChatSessionInput!): String!

View File

@@ -34,6 +34,13 @@ export interface Scalars {
Upload: { input: File; output: File };
}
export interface CreateChatMessageInput {
attachments: InputMaybe<Array<Scalars['String']['input']>>;
content: Scalars['String']['input'];
params: InputMaybe<Scalars['String']['input']>;
sessionId: Scalars['String']['input'];
}
export interface CreateChatSessionInput {
/** An mark identifying which view to use to display the session */
action: InputMaybe<Scalars['String']['input']>;