mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-12 12:28:42 +00:00
feat: text to image impl (#6437)
fix CLOUD-18 fix CLOUD-28 fix CLOUD-29
This commit is contained in:
@@ -23,12 +23,13 @@ import {
|
||||
import { Public } from '../../core/auth';
|
||||
import { CurrentUser } from '../../core/auth/current-user';
|
||||
import { CopilotProviderService } from './providers';
|
||||
import { ChatSessionService } from './session';
|
||||
import { ChatSession, ChatSessionService } from './session';
|
||||
import { CopilotCapability } from './types';
|
||||
|
||||
export interface ChatEvent {
|
||||
data: string;
|
||||
type: 'attachment' | 'message';
|
||||
id?: string;
|
||||
data: string;
|
||||
}
|
||||
|
||||
@Controller('/api/copilot')
|
||||
@@ -38,13 +39,54 @@ export class CopilotController {
|
||||
private readonly provider: CopilotProviderService
|
||||
) {}
|
||||
|
||||
private async hasAttachment(sessionId: string, messageId?: string) {
|
||||
const session = await this.chatSession.get(sessionId);
|
||||
if (!session) {
|
||||
throw new BadRequestException('Session not found');
|
||||
}
|
||||
|
||||
if (messageId) {
|
||||
const message = await session.getMessageById(messageId);
|
||||
if (Array.isArray(message.attachments) && message.attachments.length) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private async appendSessionMessage(
|
||||
sessionId: string,
|
||||
message?: string,
|
||||
messageId?: string
|
||||
): Promise<ChatSession> {
|
||||
const session = await this.chatSession.get(sessionId);
|
||||
if (!session) {
|
||||
throw new BadRequestException('Session not found');
|
||||
}
|
||||
|
||||
if (messageId) {
|
||||
await session.pushByMessageId(messageId);
|
||||
} else {
|
||||
if (!message || !message.trim()) {
|
||||
throw new BadRequestException('Message is empty');
|
||||
}
|
||||
session.push({
|
||||
role: 'user',
|
||||
content: decodeURIComponent(message),
|
||||
createdAt: new Date(),
|
||||
});
|
||||
}
|
||||
return session;
|
||||
}
|
||||
|
||||
@Public()
|
||||
@Get('/chat/:sessionId')
|
||||
async chat(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Req() req: Request,
|
||||
@Param('sessionId') sessionId: string,
|
||||
@Query('message') content: string,
|
||||
@Query('message') message: string | undefined,
|
||||
@Query('messageId') messageId: string | undefined,
|
||||
@Query() params: Record<string, string | string[]>
|
||||
): Promise<string> {
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
@@ -53,21 +95,16 @@ export class CopilotController {
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
}
|
||||
const session = await this.chatSession.get(sessionId);
|
||||
if (!session) {
|
||||
throw new BadRequestException('Session not found');
|
||||
}
|
||||
if (!content || !content.trim()) {
|
||||
throw new BadRequestException('Message is empty');
|
||||
}
|
||||
session.push({
|
||||
role: 'user',
|
||||
content: decodeURIComponent(content),
|
||||
createdAt: new Date(),
|
||||
});
|
||||
|
||||
const session = await this.appendSessionMessage(
|
||||
sessionId,
|
||||
message,
|
||||
messageId
|
||||
);
|
||||
|
||||
try {
|
||||
delete params.message;
|
||||
delete params.messageId;
|
||||
const content = await provider.generateText(
|
||||
session.finish(params),
|
||||
session.model,
|
||||
@@ -98,7 +135,8 @@ export class CopilotController {
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Req() req: Request,
|
||||
@Param('sessionId') sessionId: string,
|
||||
@Query('message') content: string,
|
||||
@Query('message') message: string | undefined,
|
||||
@Query('messageId') messageId: string | undefined,
|
||||
@Query() params: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
@@ -107,20 +145,15 @@ export class CopilotController {
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
}
|
||||
const session = await this.chatSession.get(sessionId);
|
||||
if (!session) {
|
||||
throw new BadRequestException('Session not found');
|
||||
}
|
||||
if (!content || !content.trim()) {
|
||||
throw new BadRequestException('Message is empty');
|
||||
}
|
||||
session.push({
|
||||
role: 'user',
|
||||
content: decodeURIComponent(content),
|
||||
createdAt: new Date(),
|
||||
});
|
||||
|
||||
const session = await this.appendSessionMessage(
|
||||
sessionId,
|
||||
message,
|
||||
messageId
|
||||
);
|
||||
|
||||
delete params.message;
|
||||
delete params.messageId;
|
||||
return from(
|
||||
provider.generateTextStream(session.finish(params), session.model, {
|
||||
signal: req.signal,
|
||||
@@ -130,7 +163,9 @@ export class CopilotController {
|
||||
connect(shared$ =>
|
||||
merge(
|
||||
// actual chat event stream
|
||||
shared$.pipe(map(data => ({ id: sessionId, data }))),
|
||||
shared$.pipe(
|
||||
map(data => ({ type: 'message' as const, id: sessionId, data }))
|
||||
),
|
||||
// save the generated text to the session
|
||||
shared$.pipe(
|
||||
toArray(),
|
||||
@@ -148,4 +183,66 @@ export class CopilotController {
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Public()
|
||||
@Sse('/chat/:sessionId/images')
|
||||
async chatImagesStream(
|
||||
@CurrentUser() user: CurrentUser | undefined,
|
||||
@Req() req: Request,
|
||||
@Param('sessionId') sessionId: string,
|
||||
@Query('message') message: string | undefined,
|
||||
@Query('messageId') messageId: string | undefined,
|
||||
@Query() params: Record<string, string>
|
||||
): Promise<Observable<ChatEvent>> {
|
||||
const provider = this.provider.getProviderByCapability(
|
||||
(await this.hasAttachment(sessionId, messageId))
|
||||
? CopilotCapability.ImageToImage
|
||||
: CopilotCapability.TextToImage
|
||||
);
|
||||
if (!provider) {
|
||||
throw new InternalServerErrorException('No provider available');
|
||||
}
|
||||
|
||||
const session = await this.appendSessionMessage(
|
||||
sessionId,
|
||||
message,
|
||||
messageId
|
||||
);
|
||||
|
||||
delete params.message;
|
||||
delete params.messageId;
|
||||
return from(
|
||||
provider.generateImagesStream(session.finish(params), session.model, {
|
||||
signal: req.signal,
|
||||
user: user?.id,
|
||||
})
|
||||
).pipe(
|
||||
connect(shared$ =>
|
||||
merge(
|
||||
// actual chat event stream
|
||||
shared$.pipe(
|
||||
map(attachment => ({
|
||||
type: 'attachment' as const,
|
||||
id: sessionId,
|
||||
data: attachment,
|
||||
}))
|
||||
),
|
||||
// save the generated text to the session
|
||||
shared$.pipe(
|
||||
toArray(),
|
||||
concatMap(attachments => {
|
||||
session.push({
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
attachments: attachments,
|
||||
createdAt: new Date(),
|
||||
});
|
||||
return from(session.save());
|
||||
}),
|
||||
switchMap(() => EMPTY)
|
||||
)
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,16 +3,19 @@ import { QuotaService } from '../../core/quota';
|
||||
import { PermissionService } from '../../core/workspaces/permission';
|
||||
import { Plugin } from '../registry';
|
||||
import { CopilotController } from './controller';
|
||||
import { ChatMessageCache } from './message';
|
||||
import { PromptService } from './prompt';
|
||||
import {
|
||||
assertProvidersConfigs,
|
||||
CopilotProviderService,
|
||||
FalProvider,
|
||||
OpenAIProvider,
|
||||
registerCopilotProvider,
|
||||
} from './providers';
|
||||
import { CopilotResolver, UserCopilotResolver } from './resolver';
|
||||
import { ChatSessionService } from './session';
|
||||
|
||||
registerCopilotProvider(FalProvider);
|
||||
registerCopilotProvider(OpenAIProvider);
|
||||
|
||||
@Plugin({
|
||||
@@ -22,6 +25,7 @@ registerCopilotProvider(OpenAIProvider);
|
||||
QuotaService,
|
||||
ChatSessionService,
|
||||
CopilotResolver,
|
||||
ChatMessageCache,
|
||||
UserCopilotResolver,
|
||||
PromptService,
|
||||
CopilotProviderService,
|
||||
|
||||
35
packages/backend/server/src/plugins/copilot/message.ts
Normal file
35
packages/backend/server/src/plugins/copilot/message.ts
Normal file
@@ -0,0 +1,35 @@
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
|
||||
import { SessionCache } from '../../fundamentals';
|
||||
import { SubmittedMessage, SubmittedMessageSchema } from './types';
|
||||
|
||||
const CHAT_MESSAGE_KEY = 'chat-message';
|
||||
const CHAT_MESSAGE_TTL = 3600 * 1 * 1000; // 1 hours
|
||||
|
||||
@Injectable()
|
||||
export class ChatMessageCache {
|
||||
private readonly logger = new Logger(ChatMessageCache.name);
|
||||
constructor(private readonly cache: SessionCache) {}
|
||||
|
||||
async get(id: string): Promise<SubmittedMessage | undefined> {
|
||||
return await this.cache.get(`${CHAT_MESSAGE_KEY}:${id}`);
|
||||
}
|
||||
|
||||
async set(message: SubmittedMessage): Promise<string | undefined> {
|
||||
try {
|
||||
const parsed = SubmittedMessageSchema.safeParse(message);
|
||||
if (parsed.success) {
|
||||
const id = randomUUID();
|
||||
await this.cache.set(`${CHAT_MESSAGE_KEY}:${id}`, parsed.data, {
|
||||
ttl: CHAT_MESSAGE_TTL,
|
||||
});
|
||||
return id;
|
||||
}
|
||||
} catch (e: any) {
|
||||
this.logger.error(`Failed to get chat message from cache: ${e.message}`);
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
92
packages/backend/server/src/plugins/copilot/providers/fal.ts
Normal file
92
packages/backend/server/src/plugins/copilot/providers/fal.ts
Normal file
@@ -0,0 +1,92 @@
|
||||
import assert from 'node:assert';
|
||||
|
||||
import {
|
||||
CopilotCapability,
|
||||
CopilotImageToImageProvider,
|
||||
CopilotProviderType,
|
||||
PromptMessage,
|
||||
} from '../types';
|
||||
|
||||
export type FalConfig = {
|
||||
apiKey: string;
|
||||
};
|
||||
|
||||
export type FalResponse = {
|
||||
images: Array<{ url: string }>;
|
||||
};
|
||||
|
||||
export class FalProvider implements CopilotImageToImageProvider {
|
||||
static readonly type = CopilotProviderType.FAL;
|
||||
static readonly capabilities = [CopilotCapability.ImageToImage];
|
||||
|
||||
readonly availableModels = [
|
||||
// image to image
|
||||
// https://blog.fal.ai/building-applications-with-real-time-stable-diffusion-apis/
|
||||
'110602490-lcm-sd15-i2i',
|
||||
];
|
||||
|
||||
constructor(private readonly config: FalConfig) {
|
||||
assert(FalProvider.assetsConfig(config));
|
||||
}
|
||||
|
||||
static assetsConfig(config: FalConfig) {
|
||||
return !!config.apiKey;
|
||||
}
|
||||
|
||||
getCapabilities(): CopilotCapability[] {
|
||||
return FalProvider.capabilities;
|
||||
}
|
||||
|
||||
// ====== image to image ======
|
||||
async generateImages(
|
||||
messages: PromptMessage[],
|
||||
model: string = this.availableModels[0],
|
||||
options: {
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
} = {}
|
||||
): Promise<Array<string>> {
|
||||
const { content, attachments } = messages.pop() || {};
|
||||
if (!this.availableModels.includes(model)) {
|
||||
throw new Error(`Invalid model: ${model}`);
|
||||
}
|
||||
if (!content) {
|
||||
throw new Error('Prompt is required');
|
||||
}
|
||||
if (!Array.isArray(attachments) || !attachments.length) {
|
||||
throw new Error('Attachments is required');
|
||||
}
|
||||
|
||||
const data = (await fetch(`https://${model}.gateway.alpha.fal.ai/`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: `key ${this.config.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
image_url: attachments[0],
|
||||
prompt: content,
|
||||
sync_mode: true,
|
||||
seed: 42,
|
||||
enable_safety_checks: false,
|
||||
}),
|
||||
signal: options.signal,
|
||||
}).then(res => res.json())) as FalResponse;
|
||||
|
||||
return data.images.map(image => image.url);
|
||||
}
|
||||
|
||||
async *generateImagesStream(
|
||||
messages: PromptMessage[],
|
||||
model: string = this.availableModels[0],
|
||||
options: {
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
} = {}
|
||||
): AsyncIterable<string> {
|
||||
const ret = await this.generateImages(messages, model, options);
|
||||
for (const url of ret) {
|
||||
yield url;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -134,4 +134,5 @@ export class CopilotProviderService {
|
||||
}
|
||||
}
|
||||
|
||||
export { FalProvider } from './fal';
|
||||
export { OpenAIProvider } from './openai';
|
||||
|
||||
@@ -5,22 +5,31 @@ import { ClientOptions, OpenAI } from 'openai';
|
||||
import {
|
||||
ChatMessageRole,
|
||||
CopilotCapability,
|
||||
CopilotImageToTextProvider,
|
||||
CopilotProviderType,
|
||||
CopilotTextToEmbeddingProvider,
|
||||
CopilotTextToImageProvider,
|
||||
CopilotTextToTextProvider,
|
||||
PromptMessage,
|
||||
} from '../types';
|
||||
|
||||
const DEFAULT_DIMENSIONS = 256;
|
||||
|
||||
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;
|
||||
|
||||
export class OpenAIProvider
|
||||
implements CopilotTextToTextProvider, CopilotTextToEmbeddingProvider
|
||||
implements
|
||||
CopilotTextToTextProvider,
|
||||
CopilotTextToEmbeddingProvider,
|
||||
CopilotTextToImageProvider,
|
||||
CopilotImageToTextProvider
|
||||
{
|
||||
static readonly type = CopilotProviderType.OpenAI;
|
||||
static readonly capabilities = [
|
||||
CopilotCapability.TextToText,
|
||||
CopilotCapability.TextToEmbedding,
|
||||
CopilotCapability.TextToImage,
|
||||
CopilotCapability.ImageToText,
|
||||
];
|
||||
|
||||
readonly availableModels = [
|
||||
@@ -35,6 +44,8 @@ export class OpenAIProvider
|
||||
// moderation
|
||||
'text-moderation-latest',
|
||||
'text-moderation-stable',
|
||||
// text to image
|
||||
'dall-e-3',
|
||||
];
|
||||
|
||||
private readonly instance: OpenAI;
|
||||
@@ -52,12 +63,29 @@ export class OpenAIProvider
|
||||
return OpenAIProvider.capabilities;
|
||||
}
|
||||
|
||||
private chatToGPTMessage(messages: PromptMessage[]) {
|
||||
private chatToGPTMessage(
|
||||
messages: PromptMessage[]
|
||||
): OpenAI.Chat.Completions.ChatCompletionMessageParam[] {
|
||||
// filter redundant fields
|
||||
return messages.map(message => ({
|
||||
role: message.role,
|
||||
content: message.content,
|
||||
}));
|
||||
return messages.map(({ role, content, attachments }) => {
|
||||
if (Array.isArray(attachments)) {
|
||||
const contents = [
|
||||
{ type: 'text', text: content },
|
||||
...attachments
|
||||
.filter(url => SIMPLE_IMAGE_URL_REGEX.test(url))
|
||||
.map(url => ({
|
||||
type: 'image_url',
|
||||
image_url: { url, detail: 'low' },
|
||||
})),
|
||||
];
|
||||
return {
|
||||
role,
|
||||
content: contents,
|
||||
} as OpenAI.Chat.Completions.ChatCompletionMessageParam;
|
||||
} else {
|
||||
return { role, content };
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private checkParams({
|
||||
@@ -194,4 +222,44 @@ export class OpenAIProvider
|
||||
});
|
||||
return result.data.map(e => e.embedding);
|
||||
}
|
||||
|
||||
// ====== text to image ======
|
||||
async generateImages(
|
||||
messages: PromptMessage[],
|
||||
model: string = 'dall-e-3',
|
||||
options: {
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
} = {}
|
||||
): Promise<Array<string>> {
|
||||
const { content: prompt } = messages.pop() || {};
|
||||
if (!prompt) {
|
||||
throw new Error('Prompt is required');
|
||||
}
|
||||
const result = await this.instance.images.generate(
|
||||
{
|
||||
prompt,
|
||||
model,
|
||||
response_format: 'url',
|
||||
user: options.user,
|
||||
},
|
||||
{ signal: options.signal }
|
||||
);
|
||||
|
||||
return result.data.map(image => image.url).filter((v): v is string => !!v);
|
||||
}
|
||||
|
||||
async *generateImagesStream(
|
||||
messages: PromptMessage[],
|
||||
model: string = 'dall-e-3',
|
||||
options: {
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
} = {}
|
||||
): AsyncIterable<string> {
|
||||
const ret = await this.generateImages(messages, model, options);
|
||||
for (const url of ret) {
|
||||
yield url;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import {
|
||||
Args,
|
||||
Field,
|
||||
@@ -12,7 +13,7 @@ import {
|
||||
} from '@nestjs/graphql';
|
||||
import { SafeIntResolver } from 'graphql-scalars';
|
||||
|
||||
import { CurrentUser, Public } from '../../core/auth';
|
||||
import { CurrentUser } from '../../core/auth';
|
||||
import { QuotaService } from '../../core/quota';
|
||||
import { UserType } from '../../core/user';
|
||||
import { PermissionService } from '../../core/workspaces/permission';
|
||||
@@ -21,11 +22,19 @@ import {
|
||||
PaymentRequiredException,
|
||||
TooManyRequestsException,
|
||||
} from '../../fundamentals';
|
||||
import { ChatSessionService, ListHistoriesOptions } from './session';
|
||||
import { AvailableModels, type ChatHistory, type ChatMessage } from './types';
|
||||
import { ChatSessionService } from './session';
|
||||
import {
|
||||
AvailableModels,
|
||||
type ChatHistory,
|
||||
type ChatMessage,
|
||||
type ListHistoriesOptions,
|
||||
SubmittedMessage,
|
||||
} from './types';
|
||||
|
||||
registerEnumType(AvailableModels, { name: 'CopilotModel' });
|
||||
|
||||
const COPILOT_LOCKER = 'copilot';
|
||||
|
||||
// ================== Input Types ==================
|
||||
|
||||
@InputType()
|
||||
@@ -48,6 +57,21 @@ class CreateChatSessionInput {
|
||||
promptName!: string;
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class CreateChatMessageInput implements Omit<SubmittedMessage, 'params'> {
|
||||
@Field(() => String)
|
||||
sessionId!: string;
|
||||
|
||||
@Field(() => String)
|
||||
content!: string;
|
||||
|
||||
@Field(() => [String], { nullable: true })
|
||||
attachments!: string[] | undefined;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
params!: string | undefined;
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class QueryChatHistoriesInput implements Partial<ListHistoriesOptions> {
|
||||
@Field(() => Boolean, { nullable: true })
|
||||
@@ -118,6 +142,8 @@ export class CopilotType {
|
||||
|
||||
@Resolver(() => CopilotType)
|
||||
export class CopilotResolver {
|
||||
private readonly logger = new Logger(CopilotResolver.name);
|
||||
|
||||
constructor(
|
||||
private readonly permissions: PermissionService,
|
||||
private readonly quota: QuotaService,
|
||||
@@ -208,7 +234,6 @@ export class CopilotResolver {
|
||||
);
|
||||
}
|
||||
|
||||
@Public()
|
||||
@Mutation(() => String, {
|
||||
description: 'Create a chat session',
|
||||
})
|
||||
@@ -222,7 +247,7 @@ export class CopilotResolver {
|
||||
options.docId,
|
||||
user.id
|
||||
);
|
||||
const lockFlag = `session:${user.id}:${options.workspaceId}`;
|
||||
const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${options.workspaceId}`;
|
||||
await using lock = await this.mutex.lock(lockFlag);
|
||||
if (!lock) {
|
||||
return new TooManyRequestsException('Server is busy');
|
||||
@@ -241,6 +266,32 @@ export class CopilotResolver {
|
||||
});
|
||||
return session;
|
||||
}
|
||||
|
||||
@Mutation(() => String, {
|
||||
description: 'Create a chat message',
|
||||
})
|
||||
async createCopilotMessage(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args({ name: 'options', type: () => CreateChatMessageInput })
|
||||
options: CreateChatMessageInput
|
||||
) {
|
||||
const lockFlag = `${COPILOT_LOCKER}:message:${user?.id}:${options.sessionId}`;
|
||||
await using lock = await this.mutex.lock(lockFlag);
|
||||
if (!lock) {
|
||||
return new TooManyRequestsException('Server is busy');
|
||||
}
|
||||
try {
|
||||
const { params, ...rest } = options;
|
||||
const record: SubmittedMessage['params'] = {};
|
||||
new URLSearchParams(params).forEach((value, key) => {
|
||||
record[key] = value;
|
||||
});
|
||||
return await this.chatSession.createMessage({ ...rest, params: record });
|
||||
} catch (e: any) {
|
||||
this.logger.error(`Failed to create chat message: ${e.message}`);
|
||||
throw new Error('Failed to create chat message');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Resolver(() => UserType)
|
||||
|
||||
@@ -3,43 +3,26 @@ import { randomUUID } from 'node:crypto';
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
|
||||
import { ChatMessageCache } from './message';
|
||||
import { ChatPrompt, PromptService } from './prompt';
|
||||
import {
|
||||
AvailableModel,
|
||||
ChatHistory,
|
||||
ChatMessage,
|
||||
ChatMessageSchema,
|
||||
ChatSessionOptions,
|
||||
ChatSessionState,
|
||||
getTokenEncoder,
|
||||
ListHistoriesOptions,
|
||||
PromptMessage,
|
||||
PromptMessageSchema,
|
||||
PromptParams,
|
||||
SubmittedMessage,
|
||||
} from './types';
|
||||
|
||||
export interface ChatSessionOptions {
|
||||
userId: string;
|
||||
workspaceId: string;
|
||||
docId: string;
|
||||
promptName: string;
|
||||
}
|
||||
|
||||
export interface ChatSessionState
|
||||
extends Omit<ChatSessionOptions, 'promptName'> {
|
||||
// connect ids
|
||||
sessionId: string;
|
||||
// states
|
||||
prompt: ChatPrompt;
|
||||
messages: ChatMessage[];
|
||||
}
|
||||
|
||||
export type ListHistoriesOptions = {
|
||||
action: boolean | undefined;
|
||||
limit: number | undefined;
|
||||
skip: number | undefined;
|
||||
sessionId: string | undefined;
|
||||
};
|
||||
|
||||
export class ChatSession implements AsyncDisposable {
|
||||
constructor(
|
||||
private readonly messageCache: ChatMessageCache,
|
||||
private readonly state: ChatSessionState,
|
||||
private readonly dispose?: (state: ChatSessionState) => Promise<void>,
|
||||
private readonly maxTokenSize = 3840
|
||||
@@ -60,6 +43,29 @@ export class ChatSession implements AsyncDisposable {
|
||||
this.state.messages.push(message);
|
||||
}
|
||||
|
||||
async getMessageById(messageId: string) {
|
||||
const message = await this.messageCache.get(messageId);
|
||||
if (!message || message.sessionId !== this.state.sessionId) {
|
||||
throw new Error(`Message not found: ${messageId}`);
|
||||
}
|
||||
return message;
|
||||
}
|
||||
|
||||
async pushByMessageId(messageId: string) {
|
||||
const message = await this.messageCache.get(messageId);
|
||||
if (!message || message.sessionId !== this.state.sessionId) {
|
||||
throw new Error(`Message not found: ${messageId}`);
|
||||
}
|
||||
|
||||
this.push({
|
||||
role: 'user',
|
||||
content: message.content,
|
||||
attachments: message.attachments,
|
||||
params: message.params,
|
||||
createdAt: new Date(),
|
||||
});
|
||||
}
|
||||
|
||||
pop() {
|
||||
this.state.messages.pop();
|
||||
}
|
||||
@@ -109,6 +115,7 @@ export class ChatSessionService {
|
||||
|
||||
constructor(
|
||||
private readonly db: PrismaClient,
|
||||
private readonly messageCache: ChatMessageCache,
|
||||
private readonly prompt: PromptService
|
||||
) {}
|
||||
|
||||
@@ -326,6 +333,10 @@ export class ChatSessionService {
|
||||
});
|
||||
}
|
||||
|
||||
async createMessage(message: SubmittedMessage): Promise<string | undefined> {
|
||||
return await this.messageCache.set(message);
|
||||
}
|
||||
|
||||
/**
|
||||
* usage:
|
||||
* ``` typescript
|
||||
@@ -342,7 +353,7 @@ export class ChatSessionService {
|
||||
async get(sessionId: string): Promise<ChatSession | null> {
|
||||
const state = await this.getSession(sessionId);
|
||||
if (state) {
|
||||
return new ChatSession(state, async state => {
|
||||
return new ChatSession(this.messageCache, state, async state => {
|
||||
await this.setSession(state);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -8,10 +8,12 @@ import {
|
||||
} from 'tiktoken';
|
||||
import { z } from 'zod';
|
||||
|
||||
import type { ChatPrompt } from './prompt';
|
||||
|
||||
export interface CopilotConfig {
|
||||
openai: OpenAIClientOptions;
|
||||
fal: {
|
||||
secret: string;
|
||||
apiKey: string;
|
||||
};
|
||||
}
|
||||
|
||||
@@ -27,6 +29,8 @@ export enum AvailableModels {
|
||||
// moderation
|
||||
TextModerationLatest = 'text-moderation-latest',
|
||||
TextModerationStable = 'text-moderation-stable',
|
||||
// text to image
|
||||
DallE3 = 'dall-e-3',
|
||||
}
|
||||
|
||||
export type AvailableModel = keyof typeof AvailableModels;
|
||||
@@ -53,8 +57,7 @@ export const ChatMessageRole = Object.values(AiPromptRole) as [
|
||||
'user',
|
||||
];
|
||||
|
||||
export const PromptMessageSchema = z.object({
|
||||
role: z.enum(ChatMessageRole),
|
||||
const PureMessageSchema = z.object({
|
||||
content: z.string(),
|
||||
attachments: z.array(z.string()).optional(),
|
||||
params: z
|
||||
@@ -63,6 +66,10 @@ export const PromptMessageSchema = z.object({
|
||||
.nullable(),
|
||||
});
|
||||
|
||||
export const PromptMessageSchema = PureMessageSchema.extend({
|
||||
role: z.enum(ChatMessageRole),
|
||||
}).strict();
|
||||
|
||||
export type PromptMessage = z.infer<typeof PromptMessageSchema>;
|
||||
|
||||
export type PromptParams = NonNullable<PromptMessage['params']>;
|
||||
@@ -73,6 +80,12 @@ export const ChatMessageSchema = PromptMessageSchema.extend({
|
||||
|
||||
export type ChatMessage = z.infer<typeof ChatMessageSchema>;
|
||||
|
||||
export const SubmittedMessageSchema = PureMessageSchema.extend({
|
||||
sessionId: z.string(),
|
||||
}).strict();
|
||||
|
||||
export type SubmittedMessage = z.infer<typeof SubmittedMessageSchema>;
|
||||
|
||||
export const ChatHistorySchema = z
|
||||
.object({
|
||||
sessionId: z.string(),
|
||||
@@ -84,6 +97,32 @@ export const ChatHistorySchema = z
|
||||
|
||||
export type ChatHistory = z.infer<typeof ChatHistorySchema>;
|
||||
|
||||
// ======== Chat Session ========
|
||||
|
||||
export interface ChatSessionOptions {
|
||||
// connect ids
|
||||
userId: string;
|
||||
workspaceId: string;
|
||||
docId: string;
|
||||
promptName: string;
|
||||
}
|
||||
|
||||
export interface ChatSessionState
|
||||
extends Omit<ChatSessionOptions, 'promptName'> {
|
||||
// connect ids
|
||||
sessionId: string;
|
||||
// states
|
||||
prompt: ChatPrompt;
|
||||
messages: ChatMessage[];
|
||||
}
|
||||
|
||||
export type ListHistoriesOptions = {
|
||||
action: boolean | undefined;
|
||||
limit: number | undefined;
|
||||
skip: number | undefined;
|
||||
sessionId: string | undefined;
|
||||
};
|
||||
|
||||
// ======== Provider Interface ========
|
||||
|
||||
export enum CopilotProviderType {
|
||||
@@ -96,6 +135,7 @@ export enum CopilotCapability {
|
||||
TextToEmbedding = 'text-to-embedding',
|
||||
TextToImage = 'text-to-image',
|
||||
ImageToImage = 'image-to-image',
|
||||
ImageToText = 'image-to-text',
|
||||
}
|
||||
|
||||
export interface CopilotProvider {
|
||||
@@ -137,13 +177,71 @@ export interface CopilotTextToEmbeddingProvider extends CopilotProvider {
|
||||
): Promise<number[][]>;
|
||||
}
|
||||
|
||||
export interface CopilotTextToImageProvider extends CopilotProvider {}
|
||||
export interface CopilotTextToImageProvider extends CopilotProvider {
|
||||
generateImages(
|
||||
messages: PromptMessage[],
|
||||
model: string,
|
||||
options: {
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
}
|
||||
): Promise<Array<string>>;
|
||||
generateImagesStream(
|
||||
messages: PromptMessage[],
|
||||
model?: string,
|
||||
options?: {
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
}
|
||||
): AsyncIterable<string>;
|
||||
}
|
||||
|
||||
export interface CopilotImageToImageProvider extends CopilotProvider {}
|
||||
export interface CopilotImageToTextProvider extends CopilotProvider {
|
||||
generateText(
|
||||
messages: PromptMessage[],
|
||||
model: string,
|
||||
options: {
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
}
|
||||
): Promise<string>;
|
||||
generateTextStream(
|
||||
messages: PromptMessage[],
|
||||
model: string,
|
||||
options: {
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
}
|
||||
): AsyncIterable<string>;
|
||||
}
|
||||
|
||||
export interface CopilotImageToImageProvider extends CopilotProvider {
|
||||
generateImages(
|
||||
messages: PromptMessage[],
|
||||
model: string,
|
||||
options: {
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
}
|
||||
): Promise<Array<string>>;
|
||||
generateImagesStream(
|
||||
messages: PromptMessage[],
|
||||
model?: string,
|
||||
options?: {
|
||||
signal?: AbortSignal;
|
||||
user?: string;
|
||||
}
|
||||
): AsyncIterable<string>;
|
||||
}
|
||||
|
||||
export type CapabilityToCopilotProvider = {
|
||||
[CopilotCapability.TextToText]: CopilotTextToTextProvider;
|
||||
[CopilotCapability.TextToEmbedding]: CopilotTextToEmbeddingProvider;
|
||||
[CopilotCapability.TextToImage]: CopilotTextToImageProvider;
|
||||
[CopilotCapability.ImageToText]: CopilotImageToTextProvider;
|
||||
[CopilotCapability.ImageToImage]: CopilotImageToImageProvider;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user