feat: allow sort and filter forked session (#7519)

This commit is contained in:
DarkSky
2024-07-18 11:08:47 +08:00
committed by GitHub
parent ccac7a883c
commit dcb9d75db7
10 changed files with 166 additions and 26 deletions

View File

@@ -108,17 +108,33 @@ class CreateChatMessageInput implements Omit<SubmittedMessage, 'content'> {
params!: Record<string, string> | undefined;
}
enum ChatHistoryOrder {
asc = 'asc',
desc = 'desc',
}
registerEnumType(ChatHistoryOrder, { name: 'ChatHistoryOrder' });
@InputType()
class QueryChatHistoriesInput implements Partial<ListHistoriesOptions> {
@Field(() => Boolean, { nullable: true })
action: boolean | undefined;
@Field(() => Boolean, { nullable: true })
fork: boolean | undefined;
@Field(() => Number, { nullable: true })
limit: number | undefined;
@Field(() => Number, { nullable: true })
skip: number | undefined;
@Field(() => ChatHistoryOrder, { nullable: true })
messageOrder: 'asc' | 'desc' | undefined;
@Field(() => ChatHistoryOrder, { nullable: true })
sessionOrder: 'asc' | 'desc' | undefined;
@Field(() => String, { nullable: true })
sessionId: string | undefined;
}

View File

@@ -382,6 +382,21 @@ export class ChatSessionService {
options?: ListHistoriesOptions,
withPrompt = false
): Promise<ChatHistory[]> {
const extraCondition = [];
if (!options?.action && options?.fork) {
// only query forked session if fork == true and action == false
extraCondition.push({
userId: { not: userId },
workspaceId: workspaceId,
docId: workspaceId === docId ? undefined : docId,
id: options?.sessionId ? { equals: options.sessionId } : undefined,
// should only find forked session
parentSessionId: { not: null },
deletedAt: null,
});
}
return await this.db.aiSession
.findMany({
where: {
@@ -395,21 +410,7 @@ export class ChatSessionService {
: undefined,
deletedAt: null,
},
...(options?.action
? []
: [
{
userId: { not: userId },
workspaceId: workspaceId,
docId: workspaceId === docId ? undefined : docId,
id: options?.sessionId
? { equals: options.sessionId }
: undefined,
// should only find forked session
parentSessionId: { not: null },
deletedAt: null,
},
]),
...extraCondition,
],
},
select: {
@@ -428,13 +429,17 @@ export class ChatSessionService {
createdAt: true,
},
orderBy: {
createdAt: 'asc',
// message order is asc by default
createdAt: options?.messageOrder === 'desc' ? 'desc' : 'asc',
},
},
},
take: options?.limit,
skip: options?.skip,
orderBy: { createdAt: 'desc' },
orderBy: {
// session order is desc by default
createdAt: options?.sessionOrder === 'asc' ? 'asc' : 'desc',
},
})
.then(sessions =>
Promise.all(

View File

@@ -131,8 +131,11 @@ export interface ChatSessionState
export type ListHistoriesOptions = {
action: boolean | undefined;
fork: boolean | undefined;
limit: number | undefined;
skip: number | undefined;
sessionOrder: 'asc' | 'desc' | undefined;
messageOrder: 'asc' | 'desc' | undefined;
sessionId: string | undefined;
};

View File

@@ -7,6 +7,11 @@ type BlobNotFoundDataType {
workspaceId: String!
}
enum ChatHistoryOrder {
asc
desc
}
type ChatMessage {
attachments: [String!]
content: String!
@@ -554,8 +559,11 @@ type Query {
input QueryChatHistoriesInput {
action: Boolean
fork: Boolean
limit: Int
messageOrder: ChatHistoryOrder
sessionId: String
sessionOrder: ChatHistoryOrder
skip: Int
}

View File

@@ -564,15 +564,29 @@ test('should be able to list history', async t => {
promptName
);
const messageId = await createCopilotMessage(app, token, sessionId);
const messageId = await createCopilotMessage(app, token, sessionId, 'hello');
await chatWithText(app, token, sessionId, messageId);
const histories = await getHistories(app, token, { workspaceId });
t.deepEqual(
histories.map(h => h.messages.map(m => m.content)),
[['generate text to text']],
'should be able to list history'
);
{
const histories = await getHistories(app, token, { workspaceId });
t.deepEqual(
histories.map(h => h.messages.map(m => m.content)),
[['hello', 'generate text to text']],
'should be able to list history'
);
}
{
const histories = await getHistories(app, token, {
workspaceId,
options: { messageOrder: 'desc' },
});
t.deepEqual(
histories.map(h => h.messages.map(m => m.content)),
[['generate text to text', 'hello']],
'should be able to list history'
);
}
});
test('should reject request that user have not permission', async t => {

View File

@@ -27,7 +27,7 @@ import {
WorkflowParams,
} from '../../src/plugins/copilot/workflow/types';
import { gql } from './common';
import { handleGraphQLError } from './utils';
import { handleGraphQLError, sleep } from './utils';
// @ts-expect-error no error
export class MockCopilotTestProvider
@@ -84,6 +84,8 @@ export class MockCopilotTestProvider
options: CopilotChatOptions = {}
): Promise<string> {
this.checkParams({ messages, model, options });
// make some time gap for history test case
await sleep(100);
return 'generate text to text';
}
@@ -94,6 +96,8 @@ export class MockCopilotTestProvider
): AsyncIterable<string> {
this.checkParams({ messages, model, options });
// make some time gap for history test case
await sleep(100);
const result = 'generate text to text stream';
for await (const message of result) {
yield message;
@@ -113,6 +117,8 @@ export class MockCopilotTestProvider
messages = Array.isArray(messages) ? messages : [messages];
this.checkParams({ embeddings: messages, model, options });
// make some time gap for history test case
await sleep(100);
return [Array.from(randomBytes(options.dimensions)).map(v => v % 128)];
}
@@ -130,6 +136,8 @@ export class MockCopilotTestProvider
throw new Error('Prompt is required');
}
// make some time gap for history test case
await sleep(100);
// just let test case can easily verify the final prompt
return [`https://example.com/${model}.jpg`, prompt];
}
@@ -338,10 +346,13 @@ export async function getHistories(
workspaceId: string;
docId?: string;
options?: {
sessionId?: string;
action?: boolean;
fork?: boolean;
limit?: number;
skip?: number;
sessionOrder?: 'asc' | 'desc';
messageOrder?: 'asc' | 'desc';
sessionId?: string;
};
}
): Promise<History[]> {

View File

@@ -167,3 +167,7 @@ export function gql(app: INestApplication, query?: string) {
return req;
}
export async function sleep(ms: number) {
return new Promise(resolve => setTimeout(resolve, ms));
}