diff --git a/packages/backend/server/migrations/20240506051856_add_user_and_features_index/migration.sql b/packages/backend/server/migrations/20240506051856_add_user_and_features_index/migration.sql new file mode 100644 index 0000000000..2e9b63c446 --- /dev/null +++ b/packages/backend/server/migrations/20240506051856_add_user_and_features_index/migration.sql @@ -0,0 +1,5 @@ +-- CreateIndex +CREATE INDEX "user_features_user_id_idx" ON "user_features"("user_id"); + +-- CreateIndex +CREATE INDEX "users_email_idx" ON "users"("email"); diff --git a/packages/backend/server/schema.prisma b/packages/backend/server/schema.prisma index 920268a1c1..f429304306 100644 --- a/packages/backend/server/schema.prisma +++ b/packages/backend/server/schema.prisma @@ -32,6 +32,7 @@ model User { sessions UserSession[] aiSessions AiSession[] + @@index([email]) @@map("users") } @@ -195,6 +196,7 @@ model UserFeatures { feature Features @relation(fields: [featureId], references: [id], onDelete: Cascade) user User @relation(fields: [userId], references: [id], onDelete: Cascade) + @@index([userId]) @@map("user_features") } diff --git a/packages/backend/server/src/core/features/management.ts b/packages/backend/server/src/core/features/management.ts index b0c23bd2f3..3b34c279d3 100644 --- a/packages/backend/server/src/core/features/management.ts +++ b/packages/backend/server/src/core/features/management.ts @@ -138,19 +138,11 @@ export class FeatureManagementService { async addWorkspaceFeatures( workspaceId: string, feature: FeatureType, - version?: number, reason?: string ) { - const latestVersions = await this.feature.getFeaturesVersion(); - // use latest version if not specified - const latestVersion = version || latestVersions[feature]; - if (!Number.isInteger(latestVersion)) { - throw new Error(`Version of feature ${feature} not found`); - } return this.feature.addWorkspaceFeature( workspaceId, feature, - latestVersion, reason || 'add feature by api' ); } diff --git a/packages/backend/server/src/core/features/service.ts b/packages/backend/server/src/core/features/service.ts index 4cb4a62da6..4c27a22d9c 100644 --- a/packages/backend/server/src/core/features/service.ts +++ b/packages/backend/server/src/core/features/service.ts @@ -8,33 +8,6 @@ import { FeatureKind, FeatureType } from './types'; @Injectable() export class FeatureService { constructor(private readonly prisma: PrismaClient) {} - - async getFeaturesVersion() { - const features = await this.prisma.features.findMany({ - where: { - type: FeatureKind.Feature, - }, - select: { - feature: true, - version: true, - }, - }); - return features.reduce( - (acc, feature) => { - // only keep the latest version - if (acc[feature.feature]) { - if (acc[feature.feature] < feature.version) { - acc[feature.feature] = feature.version; - } - } else { - acc[feature.feature] = feature.version; - } - return acc; - }, - {} as Record - ); - } - async getFeature( feature: F ): Promise | undefined> { @@ -80,14 +53,15 @@ export class FeatureService { if (latestFlag) { return latestFlag.id; } else { - const latestVersion = await tx.features - .aggregate({ - where: { feature }, - _max: { version: true }, + const featureId = await tx.features + .findFirst({ + where: { feature, type: FeatureKind.Feature }, + orderBy: { version: 'desc' }, + select: { id: true }, }) - .then(r => r._max.version); + .then(r => r?.id); - if (!latestVersion) { + if (!featureId) { throw new Error(`Feature ${feature} not found`); } @@ -97,20 +71,8 @@ export class FeatureService { reason, expiredAt, activated: true, - user: { - connect: { - id: userId, - }, - }, - feature: { - connect: { - feature_version: { - feature, - version: latestVersion, - }, - type: FeatureKind.Feature, - }, - }, + userId, + featureId, }, }) .then(r => r.id); @@ -144,10 +106,8 @@ export class FeatureService { async getUserFeatures(userId: string) { const features = await this.prisma.userFeatures.findMany({ where: { - user: { id: userId }, - feature: { - type: FeatureKind.Feature, - }, + userId, + feature: { type: FeatureKind.Feature }, }, select: { activated: true, @@ -171,7 +131,7 @@ export class FeatureService { async getActivatedUserFeatures(userId: string) { const features = await this.prisma.userFeatures.findMany({ where: { - user: { id: userId }, + userId, feature: { type: FeatureKind.Feature }, activated: true, OR: [{ expiredAt: null }, { expiredAt: { gt: new Date() } }], @@ -242,7 +202,6 @@ export class FeatureService { async addWorkspaceFeature( workspaceId: string, feature: FeatureType, - version: number, reason: string, expiredAt?: Date | string ) { @@ -263,26 +222,27 @@ export class FeatureService { if (latestFlag) { return latestFlag.id; } else { + // use latest version of feature + const featureId = await tx.features + .findFirst({ + where: { feature, type: FeatureKind.Feature }, + select: { id: true }, + orderBy: { version: 'desc' }, + }) + .then(r => r?.id); + + if (!featureId) { + throw new Error(`Feature ${feature} not found`); + } + return tx.workspaceFeatures .create({ data: { reason, expiredAt, activated: true, - workspace: { - connect: { - id: workspaceId, - }, - }, - feature: { - connect: { - feature_version: { - feature, - version, - }, - type: FeatureKind.Feature, - }, - }, + workspaceId, + featureId, }, }) .then(r => r.id); diff --git a/packages/backend/server/src/core/quota/service.ts b/packages/backend/server/src/core/quota/service.ts index b19a2729d1..7ad6464dd4 100644 --- a/packages/backend/server/src/core/quota/service.ts +++ b/packages/backend/server/src/core/quota/service.ts @@ -19,9 +19,7 @@ export class QuotaService { async getUserQuota(userId: string) { const quota = await this.prisma.userFeatures.findFirst({ where: { - user: { - id: userId, - }, + userId, feature: { type: FeatureKind.Quota, }, @@ -48,9 +46,7 @@ export class QuotaService { async getUserQuotas(userId: string) { const quotas = await this.prisma.userFeatures.findMany({ where: { - user: { - id: userId, - }, + userId, feature: { type: FeatureKind.Quota, }, @@ -96,14 +92,17 @@ export class QuotaService { return; } - const latestPlanVersion = await tx.features.aggregate({ - where: { - feature: quota, - }, - _max: { - version: true, - }, - }); + const featureId = await tx.features + .findFirst({ + where: { feature: quota, type: FeatureKind.Quota }, + select: { id: true }, + orderBy: { version: 'desc' }, + }) + .then(f => f?.id); + + if (!featureId) { + throw new Error(`Quota ${quota} not found`); + } // we will deactivate all exists quota for this user await tx.userFeatures.updateMany({ @@ -121,20 +120,8 @@ export class QuotaService { await tx.userFeatures.create({ data: { - user: { - connect: { - id: userId, - }, - }, - feature: { - connect: { - feature_version: { - feature: quota, - version: latestPlanVersion._max.version || 1, - }, - type: FeatureKind.Quota, - }, - }, + userId, + featureId, reason: reason ?? 'switch quota', activated: true, expiredAt, diff --git a/packages/backend/server/src/core/workspaces/management.ts b/packages/backend/server/src/core/workspaces/management.ts index e28932f970..942dc62df8 100644 --- a/packages/backend/server/src/core/workspaces/management.ts +++ b/packages/backend/server/src/core/workspaces/management.ts @@ -81,7 +81,6 @@ export class WorkspaceManagementResolver { .addWorkspaceFeatures( workspaceId, feature, - undefined, 'add by experimental feature api' ) .then(id => id > 0); diff --git a/packages/backend/server/src/core/workspaces/resolvers/workspace.ts b/packages/backend/server/src/core/workspaces/resolvers/workspace.ts index a16b2d9e68..9bf0bdbba3 100644 --- a/packages/backend/server/src/core/workspaces/resolvers/workspace.ts +++ b/packages/backend/server/src/core/workspaces/resolvers/workspace.ts @@ -218,11 +218,7 @@ export class WorkspaceResolver { permissions: { create: { type: Permission.Owner, - user: { - connect: { - id: user.id, - }, - }, + userId: user.id, accepted: true, }, }, diff --git a/packages/backend/server/src/data/migrations/utils/user-features.ts b/packages/backend/server/src/data/migrations/utils/user-features.ts index 35510e5547..fdc7c9130d 100644 --- a/packages/backend/server/src/data/migrations/utils/user-features.ts +++ b/packages/backend/server/src/data/migrations/utils/user-features.ts @@ -46,6 +46,16 @@ export async function upsertLatestFeatureVersion( export async function migrateNewFeatureTable(prisma: PrismaClient) { const waitingList = await prisma.newFeaturesWaitingList.findMany(); + const latestEarlyAccessFeatureId = await prisma.features + .findFirst({ + where: { feature: FeatureType.EarlyAccess, type: FeatureKind.Feature }, + select: { id: true }, + orderBy: { version: 'desc' }, + }) + .then(r => r?.id); + if (!latestEarlyAccessFeatureId) { + throw new Error('Feature EarlyAccess not found'); + } for (const oldUser of waitingList) { const user = await prisma.user.findFirst({ where: { @@ -85,20 +95,8 @@ export async function migrateNewFeatureTable(prisma: PrismaClient) { data: { reason: 'Early access user', activated: true, - user: { - connect: { - id: user.id, - }, - }, - feature: { - connect: { - feature_version: { - feature: FeatureType.EarlyAccess, - version: 1, - }, - type: FeatureKind.Feature, - }, - }, + userId: user.id, + featureId: latestEarlyAccessFeatureId, }, }) .then(r => r.id); diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts index 92c1388bc5..d313e31a34 100644 --- a/packages/backend/server/src/plugins/copilot/session.ts +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -23,6 +23,7 @@ import { } from './types'; export class ChatSession implements AsyncDisposable { + private stashMessageCount = 0; constructor( private readonly messageCache: ChatMessageCache, private readonly state: ChatSessionState, @@ -46,6 +47,11 @@ export class ChatSession implements AsyncDisposable { return { sessionId, userId, workspaceId, docId, promptName }; } + get stashMessages() { + if (!this.stashMessageCount) return []; + return this.state.messages.slice(-this.stashMessageCount); + } + push(message: ChatMessage) { if ( this.state.prompt.action && @@ -55,6 +61,7 @@ export class ChatSession implements AsyncDisposable { throw new Error('Action has been taken, no more messages allowed'); } this.state.messages.push(message); + this.stashMessageCount += 1; } async getMessageById(messageId: string) { @@ -141,7 +148,12 @@ export class ChatSession implements AsyncDisposable { } async save() { - await this.dispose?.(this.state); + await this.dispose?.({ + ...this.state, + // only provide new messages + messages: this.stashMessages, + }); + this.stashMessageCount = 0; } async [Symbol.asyncDispose]() { @@ -181,36 +193,40 @@ export class ChatSessionService { if (id) sessionId = id; } - const messages = state.messages.map(m => ({ - ...m, - attachments: m.attachments || undefined, - params: m.params || undefined, - })); + const haveSession = await tx.aiSession + .count({ + where: { + id: sessionId, + userId: state.userId, + }, + }) + .then(c => c > 0); - await tx.aiSession.upsert({ - where: { - id: sessionId, - userId: state.userId, - }, - update: { - messages: { - // skip delete old messages if no new messages - deleteMany: messages.length ? {} : undefined, - create: messages, + if (haveSession) { + // message will only exists when setSession call by session.save + if (state.messages.length) { + await tx.aiSessionMessage.createMany({ + data: state.messages.map(m => ({ + ...m, + attachments: m.attachments || undefined, + params: m.params || undefined, + sessionId, + })), + }); + } + } else { + await tx.aiSession.create({ + data: { + id: sessionId, + workspaceId: state.workspaceId, + docId: state.docId, + // connect + userId: state.userId, + promptName: state.prompt.name, }, - }, - create: { - id: sessionId, - workspaceId: state.workspaceId, - docId: state.docId, - messages: { - create: messages, - }, - // connect - user: { connect: { id: state.userId } }, - prompt: { connect: { name: state.prompt.name } }, - }, - }); + }); + } + return sessionId; }); } diff --git a/packages/backend/server/tests/copilot.spec.ts b/packages/backend/server/tests/copilot.spec.ts index 40f4bae9ec..f9976e8b17 100644 --- a/packages/backend/server/tests/copilot.spec.ts +++ b/packages/backend/server/tests/copilot.spec.ts @@ -336,6 +336,32 @@ test('should be able to generate with message id', async t => { } }); +test('should save message correctly', async t => { + const { prompt, session } = t.context; + + await prompt.set('prompt', 'model', [ + { role: 'system', content: 'hello {{word}}' }, + ]); + + const sessionId = await session.create({ + docId: 'test', + workspaceId: 'test', + userId, + promptName: 'prompt', + }); + const s = (await session.get(sessionId))!; + + const message = (await session.createMessage({ + sessionId, + content: 'hello', + }))!; + + await s.pushByMessageId(message); + t.is(s.stashMessages.length, 1, 'should get stash messages'); + await s.save(); + t.is(s.stashMessages.length, 0, 'should empty stash messages after save'); +}); + // ==================== provider ==================== test('should be able to get provider', async t => { diff --git a/packages/backend/server/tests/feature.spec.ts b/packages/backend/server/tests/feature.spec.ts index 52be2609fc..d8e19bc32b 100644 --- a/packages/backend/server/tests/feature.spec.ts +++ b/packages/backend/server/tests/feature.spec.ts @@ -29,11 +29,7 @@ class WorkspaceResolverMock { permissions: { create: { type: Permission.Owner, - user: { - connect: { - id: user.id, - }, - }, + userId: user.id, accepted: true, }, }, @@ -163,7 +159,7 @@ test('should be able to set workspace feature', async t => { const f1 = await feature.getWorkspaceFeatures(w1.id); t.is(f1.length, 0, 'should be empty'); - await feature.addWorkspaceFeature(w1.id, FeatureType.Copilot, 1, 'test'); + await feature.addWorkspaceFeature(w1.id, FeatureType.Copilot, 'test'); const f2 = await feature.getWorkspaceFeatures(w1.id); t.is(f2.length, 1, 'should have 1 feature'); @@ -178,7 +174,7 @@ test('should be able to check workspace feature', async t => { const f1 = await management.hasWorkspaceFeature(w1.id, FeatureType.Copilot); t.false(f1, 'should not have copilot'); - await management.addWorkspaceFeatures(w1.id, FeatureType.Copilot, 1, 'test'); + await management.addWorkspaceFeatures(w1.id, FeatureType.Copilot, 'test'); const f2 = await management.hasWorkspaceFeature(w1.id, FeatureType.Copilot); t.true(f2, 'should have copilot'); @@ -195,7 +191,7 @@ test('should be able revert workspace feature', async t => { const f1 = await management.hasWorkspaceFeature(w1.id, FeatureType.Copilot); t.false(f1, 'should not have feature'); - await management.addWorkspaceFeatures(w1.id, FeatureType.Copilot, 1, 'test'); + await management.addWorkspaceFeatures(w1.id, FeatureType.Copilot, 'test'); const f2 = await management.hasWorkspaceFeature(w1.id, FeatureType.Copilot); t.true(f2, 'should have feature'); diff --git a/tests/kit/utils/cloud.ts b/tests/kit/utils/cloud.ts index 683e941986..c61b9cd565 100644 --- a/tests/kit/utils/cloud.ts +++ b/tests/kit/utils/cloud.ts @@ -97,6 +97,14 @@ export async function createRandomUser(): Promise<{ password: '123456', }; const result = await runPrisma(async client => { + const featureId = await client.features + .findFirst({ + where: { feature: 'free_plan_v1' }, + select: { id: true }, + orderBy: { version: 'desc' }, + }) + .then(f => f!.id); + await client.user.create({ data: { ...user, @@ -106,14 +114,7 @@ export async function createRandomUser(): Promise<{ create: { reason: 'created by test case', activated: true, - feature: { - connect: { - feature_version: { - feature: 'free_plan_v1', - version: 1, - }, - }, - }, + featureId, }, }, },