mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-19 07:17:00 +08:00
feat(core): support ai network search (#9357)
### What Changed?
- Add `PerplexityProvider` in backend.
- Update session prompt name if user toggle network search mode in chat panel.
- Add experimental flag for AI network search feature.
- Add unit tests and e2e tests.
Search results are streamed and appear word for word:
<div class='graphite__hidden'>
<div>🎥 Video uploaded on Graphite:</div>
<a href="https://app.graphite.dev/media/video/sJGviKxfE3Ap685cl5bj/56f6ec7b-4b21-405f-9612-43e083f6fb84.mov">
<img src="https://app.graphite.dev/api/v1/graphite/video/thumbnail/sJGviKxfE3Ap685cl5bj/56f6ec7b-4b21-405f-9612-43e083f6fb84.mov">
</a>
</div>
<video src="https://graphite-user-uploaded-assets-prod.s3.amazonaws.com/sJGviKxfE3Ap685cl5bj/56f6ec7b-4b21-405f-9612-43e083f6fb84.mov">录屏2024-12-27 18.58.40.mov</video>
Click the little globe icon to manually turn on/off Internet search:
<div class='graphite__hidden'>
<div>🎥 Video uploaded on Graphite:</div>
<a href="https://app.graphite.dev/media/video/sJGviKxfE3Ap685cl5bj/778f1406-bf29-498e-a90d-7dad813392d1.mov">
<img src="https://app.graphite.dev/api/v1/graphite/video/thumbnail/sJGviKxfE3Ap685cl5bj/778f1406-bf29-498e-a90d-7dad813392d1.mov">
</a>
</div>
<video src="https://graphite-user-uploaded-assets-prod.s3.amazonaws.com/sJGviKxfE3Ap685cl5bj/778f1406-bf29-498e-a90d-7dad813392d1.mov">录屏2024-12-27 19.01.16.mov</video>
When there is an image, it will automatically switch to the openai model:
<div class='graphite__hidden'>
<div>🎥 Video uploaded on Graphite:</div>
<a href="https://app.graphite.dev/media/video/sJGviKxfE3Ap685cl5bj/56431d8e-75e1-4d84-ab4a-b6636042cc6a.mov">
<img src="https://app.graphite.dev/api/v1/graphite/video/thumbnail/sJGviKxfE3Ap685cl5bj/56431d8e-75e1-4d84-ab4a-b6636042cc6a.mov">
</a>
</div>
<video src="https://graphite-user-uploaded-assets-prod.s3.amazonaws.com/sJGviKxfE3Ap685cl5bj/56431d8e-75e1-4d84-ab4a-b6636042cc6a.mov">录屏2024-12-27 19.02.13.mov</video>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
# REDIS_SERVER_HOST=localhost
|
||||
# COPILOT_FAL_API_KEY=YOUR_KEY
|
||||
# COPILOT_OPENAI_API_KEY=YOUR_KEY
|
||||
# COPILOT_PERPLEXITY_API_KEY=YOUR_KEY
|
||||
|
||||
# MAILER_HOST=127.0.0.1
|
||||
# MAILER_PORT=1025
|
||||
|
||||
@@ -59,6 +59,7 @@
|
||||
"@socket.io/redis-adapter": "^8.3.0",
|
||||
"cookie-parser": "^1.4.7",
|
||||
"dotenv": "^16.4.7",
|
||||
"eventsource-parser": "^3.0.0",
|
||||
"express": "^4.21.2",
|
||||
"fast-xml-parser": "^4.5.0",
|
||||
"get-stream": "^9.0.1",
|
||||
|
||||
@@ -28,6 +28,7 @@ AFFiNE.ENV_MAP = {
|
||||
CAPTCHA_TURNSTILE_SECRET: ['plugins.captcha.turnstile.secret', 'string'],
|
||||
COPILOT_OPENAI_API_KEY: 'plugins.copilot.openai.apiKey',
|
||||
COPILOT_FAL_API_KEY: 'plugins.copilot.fal.apiKey',
|
||||
COPILOT_PERPLEXITY_API_KEY: 'plugins.copilot.perplexity.apiKey',
|
||||
COPILOT_UNSPLASH_API_KEY: 'plugins.copilot.unsplashKey',
|
||||
REDIS_SERVER_HOST: 'redis.host',
|
||||
REDIS_SERVER_PORT: ['redis.port', 'int'],
|
||||
|
||||
@@ -3,10 +3,12 @@ import type { ClientOptions as OpenAIClientOptions } from 'openai';
|
||||
import { defineStartupConfig, ModuleConfig } from '../../base/config';
|
||||
import { StorageConfig } from '../../base/storage/config';
|
||||
import type { FalConfig } from './providers/fal';
|
||||
import { PerplexityConfig } from './providers/perplexity';
|
||||
|
||||
export interface CopilotStartupConfigurations {
|
||||
openai?: OpenAIClientOptions;
|
||||
fal?: FalConfig;
|
||||
perplexity?: PerplexityConfig;
|
||||
test?: never;
|
||||
unsplashKey?: string;
|
||||
storage: StorageConfig;
|
||||
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
CopilotProviderService,
|
||||
FalProvider,
|
||||
OpenAIProvider,
|
||||
PerplexityProvider,
|
||||
registerCopilotProvider,
|
||||
} from './providers';
|
||||
import {
|
||||
@@ -26,6 +27,7 @@ import { CopilotWorkflowExecutors, CopilotWorkflowService } from './workflow';
|
||||
|
||||
registerCopilotProvider(FalProvider);
|
||||
registerCopilotProvider(OpenAIProvider);
|
||||
registerCopilotProvider(PerplexityProvider);
|
||||
|
||||
@Plugin({
|
||||
name: 'copilot',
|
||||
|
||||
@@ -952,6 +952,11 @@ const chat: Prompt[] = [
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'Search With AFFiNE AI',
|
||||
model: 'llama-3.1-sonar-small-128k-online',
|
||||
messages: [],
|
||||
},
|
||||
// use for believer plan
|
||||
{
|
||||
name: 'Chat With AFFiNE AI - Believer',
|
||||
|
||||
@@ -124,9 +124,7 @@ export class CopilotProviderService {
|
||||
if (!this.cachedProviders.has(provider)) {
|
||||
this.cachedProviders.set(provider, this.create(provider));
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
return this.cachedProviders.get(provider)!;
|
||||
return this.cachedProviders.get(provider) as CopilotProvider;
|
||||
}
|
||||
|
||||
async getProviderByCapability<C extends CopilotCapability>(
|
||||
@@ -196,3 +194,4 @@ export class CopilotProviderService {
|
||||
|
||||
export { FalProvider } from './fal';
|
||||
export { OpenAIProvider } from './openai';
|
||||
export { PerplexityProvider } from './perplexity';
|
||||
|
||||
@@ -0,0 +1,325 @@
|
||||
import assert from 'node:assert';
|
||||
|
||||
import { EventSourceParserStream } from 'eventsource-parser/stream';
|
||||
import { z } from 'zod';
|
||||
|
||||
import {
|
||||
CopilotPromptInvalid,
|
||||
CopilotProviderSideError,
|
||||
metrics,
|
||||
} from '../../../base';
|
||||
import {
|
||||
CopilotCapability,
|
||||
CopilotChatOptions,
|
||||
CopilotProviderType,
|
||||
CopilotTextToTextProvider,
|
||||
PromptMessage,
|
||||
} from '../types';
|
||||
|
||||
export type PerplexityConfig = {
|
||||
apiKey: string;
|
||||
endpoint?: string;
|
||||
};
|
||||
|
||||
const PerplexityErrorSchema = z.object({
|
||||
detail: z.array(
|
||||
z.object({
|
||||
loc: z.array(z.string()),
|
||||
msg: z.string(),
|
||||
type: z.string(),
|
||||
})
|
||||
),
|
||||
});
|
||||
|
||||
const PerplexityDataSchema = z.object({
|
||||
citations: z.array(z.string()),
|
||||
choices: z.array(
|
||||
z.object({
|
||||
message: z.object({
|
||||
content: z.string(),
|
||||
role: z.literal('assistant'),
|
||||
}),
|
||||
delta: z.object({
|
||||
content: z.string(),
|
||||
role: z.literal('assistant'),
|
||||
}),
|
||||
finish_reason: z.union([z.literal('stop'), z.literal(null)]),
|
||||
})
|
||||
),
|
||||
});
|
||||
|
||||
const PerplexitySchema = z.union([PerplexityDataSchema, PerplexityErrorSchema]);
|
||||
|
||||
export class CitationParser {
|
||||
private readonly SQUARE_BRACKET_OPEN = '[';
|
||||
|
||||
private readonly SQUARE_BRACKET_CLOSE = ']';
|
||||
|
||||
private readonly PARENTHESES_OPEN = '(';
|
||||
|
||||
private startToken: string[] = [];
|
||||
|
||||
private endToken: string[] = [];
|
||||
|
||||
private numberToken: string[] = [];
|
||||
|
||||
public parse(content: string, citations: string[]) {
|
||||
let result = '';
|
||||
const contentArray = content.split('');
|
||||
for (const [index, char] of contentArray.entries()) {
|
||||
if (char === this.SQUARE_BRACKET_OPEN) {
|
||||
if (this.numberToken.length === 0) {
|
||||
this.startToken.push(char);
|
||||
} else {
|
||||
result += this.flush() + char;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (char === this.SQUARE_BRACKET_CLOSE) {
|
||||
this.endToken.push(char);
|
||||
if (this.startToken.length === this.endToken.length) {
|
||||
const cIndex = Number(this.numberToken.join('').trim());
|
||||
if (
|
||||
cIndex > 0 &&
|
||||
cIndex <= citations.length &&
|
||||
contentArray[index + 1] !== this.PARENTHESES_OPEN
|
||||
) {
|
||||
const content = `[[${cIndex}](${citations[cIndex - 1]})]`;
|
||||
result += content;
|
||||
this.resetToken();
|
||||
} else {
|
||||
result += this.flush();
|
||||
}
|
||||
} else if (this.startToken.length < this.endToken.length) {
|
||||
result += this.flush();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (this.isNumeric(char)) {
|
||||
if (this.startToken.length > 0) {
|
||||
this.numberToken.push(char);
|
||||
} else {
|
||||
result += this.flush() + char;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (this.startToken.length > 0) {
|
||||
result += this.flush() + char;
|
||||
} else {
|
||||
result += char;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
public flush() {
|
||||
const content = this.getFullContent();
|
||||
this.resetToken();
|
||||
return content;
|
||||
}
|
||||
|
||||
private getFullContent() {
|
||||
return this.startToken.concat(this.numberToken, this.endToken).join('');
|
||||
}
|
||||
|
||||
private resetToken() {
|
||||
this.startToken = [];
|
||||
this.endToken = [];
|
||||
this.numberToken = [];
|
||||
}
|
||||
|
||||
private isNumeric(str: string) {
|
||||
return !isNaN(Number(str)) && str.trim() !== '';
|
||||
}
|
||||
}
|
||||
|
||||
export class PerplexityProvider implements CopilotTextToTextProvider {
|
||||
static readonly type = CopilotProviderType.Perplexity;
|
||||
|
||||
static readonly capabilities = [CopilotCapability.TextToText];
|
||||
|
||||
static assetsConfig(config: PerplexityConfig) {
|
||||
return !!config.apiKey;
|
||||
}
|
||||
|
||||
constructor(private readonly config: PerplexityConfig) {
|
||||
assert(PerplexityProvider.assetsConfig(config));
|
||||
}
|
||||
|
||||
readonly availableModels = [
|
||||
'llama-3.1-sonar-small-128k-online',
|
||||
'llama-3.1-sonar-large-128k-online',
|
||||
'llama-3.1-sonar-huge-128k-online',
|
||||
];
|
||||
|
||||
get type(): CopilotProviderType {
|
||||
return PerplexityProvider.type;
|
||||
}
|
||||
|
||||
getCapabilities(): CopilotCapability[] {
|
||||
return PerplexityProvider.capabilities;
|
||||
}
|
||||
|
||||
async isModelAvailable(model: string): Promise<boolean> {
|
||||
return this.availableModels.includes(model);
|
||||
}
|
||||
|
||||
async generateText(
|
||||
messages: PromptMessage[],
|
||||
model: string = 'llama-3.1-sonar-small-128k-online',
|
||||
options: CopilotChatOptions = {}
|
||||
): Promise<string> {
|
||||
await this.checkParams({ messages, model, options });
|
||||
try {
|
||||
metrics.ai.counter('chat_text_calls').add(1, { model });
|
||||
const sMessages = messages
|
||||
.map(({ content, role }) => ({ content, role }))
|
||||
.filter(({ content }) => typeof content === 'string');
|
||||
|
||||
const params = {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: `Bearer ${this.config.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model,
|
||||
messages: sMessages,
|
||||
max_tokens: options.maxTokens || 4096,
|
||||
}),
|
||||
};
|
||||
const response = await fetch(
|
||||
this.config.endpoint || 'https://api.perplexity.ai/chat/completions',
|
||||
params
|
||||
);
|
||||
const data = PerplexitySchema.parse(await response.json());
|
||||
if ('detail' in data) {
|
||||
throw new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: 'unexpected_response',
|
||||
message: data.detail[0].msg || 'Unexpected perplexity response',
|
||||
});
|
||||
} else {
|
||||
const parser = new CitationParser();
|
||||
const { content } = data.choices[0].message;
|
||||
const { citations } = data;
|
||||
let result = parser.parse(content, citations);
|
||||
result += parser.flush();
|
||||
return result;
|
||||
}
|
||||
} catch (e: any) {
|
||||
metrics.ai.counter('chat_text_errors').add(1, { model });
|
||||
throw this.handleError(e);
|
||||
}
|
||||
}
|
||||
|
||||
async *generateTextStream(
|
||||
messages: PromptMessage[],
|
||||
model: string = 'llama-3.1-sonar-small-128k-online',
|
||||
options: CopilotChatOptions = {}
|
||||
): AsyncIterable<string> {
|
||||
await this.checkParams({ messages, model, options });
|
||||
try {
|
||||
metrics.ai.counter('chat_text_stream_calls').add(1, { model });
|
||||
const sMessages = messages
|
||||
.map(({ content, role }) => ({ content, role }))
|
||||
.filter(({ content }) => typeof content === 'string');
|
||||
|
||||
const params = {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: `Bearer ${this.config.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model,
|
||||
messages: sMessages,
|
||||
max_tokens: options.maxTokens || 4096,
|
||||
stream: true,
|
||||
}),
|
||||
};
|
||||
const response = await fetch(
|
||||
this.config.endpoint || 'https://api.perplexity.ai/chat/completions',
|
||||
params
|
||||
);
|
||||
if (response.body) {
|
||||
const parser = new CitationParser();
|
||||
const provider = this.type;
|
||||
const eventStream = response.body
|
||||
.pipeThrough(new TextDecoderStream())
|
||||
.pipeThrough(new EventSourceParserStream())
|
||||
.pipeThrough(
|
||||
new TransformStream({
|
||||
transform(chunk, controller) {
|
||||
if (options.signal?.aborted) {
|
||||
controller.enqueue(null);
|
||||
return;
|
||||
}
|
||||
const json = JSON.parse(chunk.data);
|
||||
if (json) {
|
||||
const data = PerplexitySchema.parse(json);
|
||||
if ('detail' in data) {
|
||||
throw new CopilotProviderSideError({
|
||||
provider,
|
||||
kind: 'unexpected_response',
|
||||
message:
|
||||
data.detail[0].msg || 'Unexpected perplexity response',
|
||||
});
|
||||
}
|
||||
const { content } = data.choices[0].delta;
|
||||
const { citations } = data;
|
||||
const result = parser.parse(content, citations);
|
||||
controller.enqueue(result);
|
||||
}
|
||||
},
|
||||
flush(controller) {
|
||||
controller.enqueue(parser.flush());
|
||||
controller.enqueue(null);
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
const reader = eventStream.getReader();
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
yield value;
|
||||
}
|
||||
} else {
|
||||
const result = await this.generateText(messages, model, options);
|
||||
yield result;
|
||||
}
|
||||
} catch (e) {
|
||||
metrics.ai.counter('chat_text_stream_errors').add(1, { model });
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
protected async checkParams({
|
||||
model,
|
||||
}: {
|
||||
messages?: PromptMessage[];
|
||||
embeddings?: string[];
|
||||
model: string;
|
||||
options: CopilotChatOptions;
|
||||
}) {
|
||||
if (!(await this.isModelAvailable(model))) {
|
||||
throw new CopilotPromptInvalid(`Invalid model: ${model}`);
|
||||
}
|
||||
}
|
||||
|
||||
private handleError(e: any) {
|
||||
if (e instanceof CopilotProviderSideError) {
|
||||
return e;
|
||||
}
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: 'unexpected_response',
|
||||
message: e?.message || 'Unexpected perplexity response',
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -22,6 +22,7 @@ import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs';
|
||||
import {
|
||||
CallMetric,
|
||||
CopilotFailedToCreateMessage,
|
||||
CopilotSessionNotFound,
|
||||
FileUpload,
|
||||
RequestMutex,
|
||||
Throttle,
|
||||
@@ -62,6 +63,17 @@ class CreateChatSessionInput {
|
||||
promptName!: string;
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class UpdateChatSessionInput {
|
||||
@Field(() => String)
|
||||
sessionId!: string;
|
||||
|
||||
@Field(() => String, {
|
||||
description: 'The prompt name to use for the session',
|
||||
})
|
||||
promptName!: string;
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class ForkChatSessionInput {
|
||||
@Field(() => String)
|
||||
@@ -372,6 +384,38 @@ export class CopilotResolver {
|
||||
});
|
||||
}
|
||||
|
||||
@Mutation(() => String, {
|
||||
description: 'Update a chat session',
|
||||
})
|
||||
@CallMetric('ai', 'chat_session_update')
|
||||
async updateCopilotSession(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args({ name: 'options', type: () => UpdateChatSessionInput })
|
||||
options: UpdateChatSessionInput
|
||||
) {
|
||||
const session = await this.chatSession.get(options.sessionId);
|
||||
if (!session) {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
const { workspaceId, docId } = session.config;
|
||||
await this.permissions.checkCloudPagePermission(
|
||||
workspaceId,
|
||||
docId,
|
||||
user.id
|
||||
);
|
||||
const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${workspaceId}`;
|
||||
await using lock = await this.mutex.acquire(lockFlag);
|
||||
if (!lock) {
|
||||
return new TooManyRequest('Server is busy');
|
||||
}
|
||||
|
||||
await this.chatSession.checkQuota(user.id);
|
||||
return await this.chatSession.updateSessionPrompt({
|
||||
...options,
|
||||
userId: user.id,
|
||||
});
|
||||
}
|
||||
|
||||
@Mutation(() => String, {
|
||||
description: 'Create a chat session',
|
||||
})
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
CopilotQuotaExceeded,
|
||||
CopilotSessionDeleted,
|
||||
CopilotSessionNotFound,
|
||||
PrismaTransaction,
|
||||
} from '../../base';
|
||||
import { FeatureManagementService } from '../../core/features';
|
||||
import { QuotaService } from '../../core/quota';
|
||||
@@ -22,6 +23,7 @@ import {
|
||||
ChatMessageSchema,
|
||||
ChatSessionForkOptions,
|
||||
ChatSessionOptions,
|
||||
ChatSessionPromptUpdateOptions,
|
||||
ChatSessionState,
|
||||
getTokenEncoder,
|
||||
ListHistoriesOptions,
|
||||
@@ -198,6 +200,22 @@ export class ChatSessionService {
|
||||
private readonly prompt: PromptService
|
||||
) {}
|
||||
|
||||
private async haveSession(
|
||||
sessionId: string,
|
||||
userId: string,
|
||||
tx?: PrismaTransaction
|
||||
) {
|
||||
const executor = tx ?? this.db;
|
||||
return await executor.aiSession
|
||||
.count({
|
||||
where: {
|
||||
id: sessionId,
|
||||
userId,
|
||||
},
|
||||
})
|
||||
.then(c => c > 0);
|
||||
}
|
||||
|
||||
private async setSession(state: ChatSessionState): Promise<string> {
|
||||
return await this.db.$transaction(async tx => {
|
||||
let sessionId = state.sessionId;
|
||||
@@ -226,15 +244,7 @@ export class ChatSessionService {
|
||||
if (id) sessionId = id;
|
||||
}
|
||||
|
||||
const haveSession = await tx.aiSession
|
||||
.count({
|
||||
where: {
|
||||
id: sessionId,
|
||||
userId: state.userId,
|
||||
},
|
||||
})
|
||||
.then(c => c > 0);
|
||||
|
||||
const haveSession = await this.haveSession(sessionId, state.userId, tx);
|
||||
if (haveSession) {
|
||||
// message will only exists when setSession call by session.save
|
||||
if (state.messages.length) {
|
||||
@@ -570,6 +580,27 @@ export class ChatSessionService {
|
||||
});
|
||||
}
|
||||
|
||||
async updateSessionPrompt(
|
||||
options: ChatSessionPromptUpdateOptions
|
||||
): Promise<string> {
|
||||
const prompt = await this.prompt.get(options.promptName);
|
||||
if (!prompt) {
|
||||
this.logger.error(`Prompt not found: ${options.promptName}`);
|
||||
throw new CopilotPromptNotFound({ name: options.promptName });
|
||||
}
|
||||
return await this.db.$transaction(async tx => {
|
||||
let sessionId = options.sessionId;
|
||||
const haveSession = await this.haveSession(sessionId, options.userId, tx);
|
||||
if (haveSession) {
|
||||
await tx.aiSession.update({
|
||||
where: { id: sessionId },
|
||||
data: { promptName: prompt.name },
|
||||
});
|
||||
}
|
||||
return sessionId;
|
||||
});
|
||||
}
|
||||
|
||||
async fork(options: ChatSessionForkOptions): Promise<string> {
|
||||
const state = await this.getSession(options.sessionId);
|
||||
if (!state) {
|
||||
|
||||
@@ -123,6 +123,11 @@ export interface ChatSessionOptions {
|
||||
promptName: string;
|
||||
}
|
||||
|
||||
export interface ChatSessionPromptUpdateOptions
|
||||
extends Pick<ChatSessionState, 'sessionId' | 'userId'> {
|
||||
promptName: string;
|
||||
}
|
||||
|
||||
export interface ChatSessionForkOptions
|
||||
extends Omit<ChatSessionOptions, 'promptName'> {
|
||||
sessionId: string;
|
||||
@@ -154,6 +159,7 @@ export type ListHistoriesOptions = {
|
||||
export enum CopilotProviderType {
|
||||
FAL = 'fal',
|
||||
OpenAI = 'openai',
|
||||
Perplexity = 'perplexity',
|
||||
// only for test
|
||||
Test = 'test',
|
||||
}
|
||||
|
||||
@@ -551,6 +551,9 @@ type Mutation {
|
||||
|
||||
"""Update a copilot prompt"""
|
||||
updateCopilotPrompt(messages: [CopilotPromptMessageInput!]!, name: String!): CopilotPromptType!
|
||||
|
||||
"""Update a chat session"""
|
||||
updateCopilotSession(options: UpdateChatSessionInput!): String!
|
||||
updateProfile(input: UpdateUserInput!): UserType!
|
||||
|
||||
"""update server runtime configurable setting"""
|
||||
@@ -865,6 +868,12 @@ type UnsupportedSubscriptionPlanDataType {
|
||||
plan: String!
|
||||
}
|
||||
|
||||
input UpdateChatSessionInput {
|
||||
"""The prompt name to use for the session"""
|
||||
promptName: String!
|
||||
sessionId: String!
|
||||
}
|
||||
|
||||
input UpdateUserInput {
|
||||
"""User name"""
|
||||
name: String
|
||||
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
CopilotProviderService,
|
||||
FalProvider,
|
||||
OpenAIProvider,
|
||||
PerplexityProvider,
|
||||
registerCopilotProvider,
|
||||
unregisterCopilotProvider,
|
||||
} from '../src/plugins/copilot/providers';
|
||||
@@ -47,8 +48,10 @@ const test = ava as TestFn<Tester>;
|
||||
const isCopilotConfigured =
|
||||
!!process.env.COPILOT_OPENAI_API_KEY &&
|
||||
!!process.env.COPILOT_FAL_API_KEY &&
|
||||
!!process.env.COPILOT_PERPLEXITY_API_KEY &&
|
||||
process.env.COPILOT_OPENAI_API_KEY !== '1' &&
|
||||
process.env.COPILOT_FAL_API_KEY !== '1';
|
||||
process.env.COPILOT_FAL_API_KEY !== '1' &&
|
||||
process.env.COPILOT_PERPLEXITY_API_KEY !== '1';
|
||||
const runIfCopilotConfigured = test.macro(
|
||||
async (
|
||||
t,
|
||||
@@ -75,6 +78,9 @@ test.serial.before(async t => {
|
||||
fal: {
|
||||
apiKey: process.env.COPILOT_FAL_API_KEY,
|
||||
},
|
||||
perplexity: {
|
||||
apiKey: process.env.COPILOT_PERPLEXITY_API_KEY,
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
@@ -111,6 +117,7 @@ test.serial.before(async t => {
|
||||
|
||||
registerCopilotProvider(OpenAIProvider);
|
||||
registerCopilotProvider(FalProvider);
|
||||
registerCopilotProvider(PerplexityProvider);
|
||||
|
||||
for (const name of await prompt.listNames()) {
|
||||
await prompt.delete(name);
|
||||
@@ -124,6 +131,7 @@ test.serial.before(async t => {
|
||||
test.after(async _ => {
|
||||
unregisterCopilotProvider(OpenAIProvider.type);
|
||||
unregisterCopilotProvider(FalProvider.type);
|
||||
unregisterCopilotProvider(PerplexityProvider.type);
|
||||
});
|
||||
|
||||
test.after(async t => {
|
||||
@@ -152,7 +160,6 @@ const checkMDList = (text: string) => {
|
||||
return false;
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-asserted-optional-chain
|
||||
const currentIndent = line.match(/^( *)/)?.[0].length!;
|
||||
if (Number.isNaN(currentIndent) || currentIndent % 2 !== 0) {
|
||||
return false;
|
||||
@@ -282,6 +289,8 @@ const actions = [
|
||||
'Make it longer',
|
||||
'Make it shorter',
|
||||
'Continue writing',
|
||||
'Chat With AFFiNE AI',
|
||||
'Search With AFFiNE AI',
|
||||
],
|
||||
messages: [{ role: 'user' as const, content: TestAssets.SSOT }],
|
||||
verifier: (t: ExecutionContext<Tester>, result: string) => {
|
||||
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
CopilotProviderService,
|
||||
FalProvider,
|
||||
OpenAIProvider,
|
||||
PerplexityProvider,
|
||||
registerCopilotProvider,
|
||||
unregisterCopilotProvider,
|
||||
} from '../src/plugins/copilot/providers';
|
||||
@@ -41,6 +42,7 @@ import {
|
||||
sse2array,
|
||||
textToEventStream,
|
||||
unsplashSearch,
|
||||
updateCopilotSession,
|
||||
} from './utils/copilot';
|
||||
|
||||
const test = ava as TestFn<{
|
||||
@@ -63,6 +65,9 @@ test.beforeEach(async t => {
|
||||
fal: {
|
||||
apiKey: '1',
|
||||
},
|
||||
perplexity: {
|
||||
apiKey: '1',
|
||||
},
|
||||
unsplashKey: process.env.UNSPLASH_ACCESS_KEY || '1',
|
||||
},
|
||||
},
|
||||
@@ -91,6 +96,7 @@ test.beforeEach(async t => {
|
||||
|
||||
unregisterCopilotProvider(OpenAIProvider.type);
|
||||
unregisterCopilotProvider(FalProvider.type);
|
||||
unregisterCopilotProvider(PerplexityProvider.type);
|
||||
registerCopilotProvider(MockCopilotTestProvider);
|
||||
|
||||
await prompt.set(promptName, 'test', [
|
||||
@@ -156,6 +162,85 @@ test('should create session correctly', async t => {
|
||||
}
|
||||
});
|
||||
|
||||
test('should update session correctly', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
const assertUpdateSession = async (
|
||||
sessionId: string,
|
||||
error: string,
|
||||
asserter = async (x: any) => {
|
||||
t.truthy(await x, error);
|
||||
}
|
||||
) => {
|
||||
await asserter(updateCopilotSession(app, token, sessionId, promptName));
|
||||
};
|
||||
|
||||
{
|
||||
const { id: workspaceId } = await createWorkspace(app, token);
|
||||
const docId = randomUUID();
|
||||
const sessionId = await createCopilotSession(
|
||||
app,
|
||||
token,
|
||||
workspaceId,
|
||||
docId,
|
||||
promptName
|
||||
);
|
||||
await assertUpdateSession(
|
||||
sessionId,
|
||||
'should be able to update session with cloud workspace that user can access'
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
const sessionId = await createCopilotSession(
|
||||
app,
|
||||
token,
|
||||
randomUUID(),
|
||||
randomUUID(),
|
||||
promptName
|
||||
);
|
||||
await assertUpdateSession(
|
||||
sessionId,
|
||||
'should be able to update session with local workspace'
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
const aToken = (await signUp(app, 'test', 'test@affine.pro', '123456'))
|
||||
.token.token;
|
||||
const { id: workspaceId } = await createWorkspace(app, aToken);
|
||||
const inviteId = await inviteUser(
|
||||
app,
|
||||
aToken,
|
||||
workspaceId,
|
||||
'darksky@affine.pro'
|
||||
);
|
||||
await acceptInviteById(app, workspaceId, inviteId, false);
|
||||
const sessionId = await createCopilotSession(
|
||||
app,
|
||||
token,
|
||||
workspaceId,
|
||||
randomUUID(),
|
||||
promptName
|
||||
);
|
||||
await assertUpdateSession(
|
||||
sessionId,
|
||||
'should able to update session after user have permission'
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
const sessionId = '123456';
|
||||
await assertUpdateSession(sessionId, '', async x => {
|
||||
await t.throwsAsync(
|
||||
x,
|
||||
{ instanceOf: Error },
|
||||
'should not able to update invalid session id'
|
||||
);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
test('should fork session correctly', async t => {
|
||||
const { app } = t.context;
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
registerCopilotProvider,
|
||||
unregisterCopilotProvider,
|
||||
} from '../src/plugins/copilot/providers';
|
||||
import { CitationParser } from '../src/plugins/copilot/providers/perplexity';
|
||||
import { ChatSessionService } from '../src/plugins/copilot/session';
|
||||
import {
|
||||
CopilotCapability,
|
||||
@@ -68,7 +69,10 @@ test.beforeEach(async t => {
|
||||
apiKey: process.env.COPILOT_OPENAI_API_KEY ?? '1',
|
||||
},
|
||||
fal: {
|
||||
apiKey: '1',
|
||||
apiKey: process.env.COPILOT_FAL_API_KEY ?? '1',
|
||||
},
|
||||
perplexity: {
|
||||
apiKey: process.env.COPILOT_PERPLEXITY_API_KEY ?? '1',
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -274,6 +278,41 @@ test('should be able to manage chat session', async t => {
|
||||
}
|
||||
});
|
||||
|
||||
test('should be able to update chat session prompt', async t => {
|
||||
const { prompt, session } = t.context;
|
||||
|
||||
// Set up a prompt to be used in the session
|
||||
await prompt.set('prompt', 'model', [
|
||||
{ role: 'system', content: 'hello {{word}}' },
|
||||
]);
|
||||
|
||||
// Create a session
|
||||
const sessionId = await session.create({
|
||||
promptName: 'prompt',
|
||||
docId: 'test',
|
||||
workspaceId: 'test',
|
||||
userId,
|
||||
});
|
||||
t.truthy(sessionId, 'should create session');
|
||||
|
||||
// Update the session
|
||||
const updatedSessionId = await session.updateSessionPrompt({
|
||||
sessionId,
|
||||
promptName: 'Search With AFFiNE AI',
|
||||
userId,
|
||||
});
|
||||
t.is(updatedSessionId, sessionId, 'should update session with same id');
|
||||
|
||||
// Verify the session was updated
|
||||
const updatedSession = await session.get(sessionId);
|
||||
t.truthy(updatedSession, 'should retrieve updated session');
|
||||
t.is(
|
||||
updatedSession?.config.promptName,
|
||||
'Search With AFFiNE AI',
|
||||
'should have updated prompt name'
|
||||
);
|
||||
});
|
||||
|
||||
test('should be able to fork chat session', async t => {
|
||||
const { auth, prompt, session } = t.context;
|
||||
|
||||
@@ -1050,3 +1089,88 @@ test('should be able to run image executor', async t => {
|
||||
unregisterCopilotProvider(MockCopilotTestProvider.type);
|
||||
registerCopilotProvider(OpenAIProvider);
|
||||
});
|
||||
|
||||
test('CitationParser should replace citation placeholders with URLs', t => {
|
||||
const content =
|
||||
'This is [a] test sentence with [citations [1]] and [[2]] and [3].';
|
||||
const citations = ['https://example1.com', 'https://example2.com'];
|
||||
|
||||
const parser = new CitationParser();
|
||||
const result = parser.parse(content, citations);
|
||||
|
||||
const expected =
|
||||
'This is [a] test sentence with [citations [[1](https://example1.com)]] and [[2](https://example2.com)] and [3].';
|
||||
t.is(result, expected);
|
||||
});
|
||||
|
||||
test('CitationParser should replace chunks of citation placeholders with URLs', t => {
|
||||
const contents = [
|
||||
'[[]]',
|
||||
'This is [',
|
||||
'a] test sentence ',
|
||||
'with citations [1',
|
||||
'] and [',
|
||||
'[2]] and [[',
|
||||
'3]] and [[4',
|
||||
']] and [[5]',
|
||||
'] and [[6]]',
|
||||
' and [7',
|
||||
];
|
||||
const citations = [
|
||||
'https://example1.com',
|
||||
'https://example2.com',
|
||||
'https://example3.com',
|
||||
'https://example4.com',
|
||||
'https://example5.com',
|
||||
'https://example6.com',
|
||||
'https://example7.com',
|
||||
];
|
||||
|
||||
const parser = new CitationParser();
|
||||
let result = contents.reduce((acc, current) => {
|
||||
return acc + parser.parse(current, citations);
|
||||
}, '');
|
||||
result += parser.flush();
|
||||
|
||||
const expected =
|
||||
'[[]]This is [a] test sentence with citations [[1](https://example1.com)] and [[2](https://example2.com)] and [[3](https://example3.com)] and [[4](https://example4.com)] and [[5](https://example5.com)] and [[6](https://example6.com)] and [7';
|
||||
t.is(result, expected);
|
||||
});
|
||||
|
||||
test('CitationParser should not replace citation already with URLs', t => {
|
||||
const content =
|
||||
'This is [a] test sentence with citations [1](https://example1.com) and [[2]](https://example2.com) and [[3](https://example3.com)].';
|
||||
const citations = [
|
||||
'https://example4.com',
|
||||
'https://example5.com',
|
||||
'https://example6.com',
|
||||
];
|
||||
|
||||
const parser = new CitationParser();
|
||||
const result = parser.parse(content, citations);
|
||||
|
||||
const expected = content;
|
||||
t.is(result, expected);
|
||||
});
|
||||
|
||||
test('CitationParser should not replace chunks of citation already with URLs', t => {
|
||||
const contents = [
|
||||
'This is [a] test sentence with citations [1',
|
||||
'](https://example1.com) and [[2]',
|
||||
'](https://example2.com) and [[3](https://example3.com)].',
|
||||
];
|
||||
const citations = [
|
||||
'https://example4.com',
|
||||
'https://example5.com',
|
||||
'https://example6.com',
|
||||
];
|
||||
|
||||
const parser = new CitationParser();
|
||||
let result = contents.reduce((acc, current) => {
|
||||
return acc + parser.parse(current, citations);
|
||||
}, '');
|
||||
result += parser.flush();
|
||||
|
||||
const expected = contents.join('');
|
||||
t.is(result, expected);
|
||||
});
|
||||
|
||||
@@ -184,6 +184,31 @@ export async function createCopilotSession(
|
||||
return res.body.data.createCopilotSession;
|
||||
}
|
||||
|
||||
export async function updateCopilotSession(
|
||||
app: INestApplication,
|
||||
userToken: string,
|
||||
sessionId: string,
|
||||
promptName: string
|
||||
): Promise<string> {
|
||||
const res = await request(app.getHttpServer())
|
||||
.post(gql)
|
||||
.auth(userToken, { type: 'bearer' })
|
||||
.set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
|
||||
.send({
|
||||
query: `
|
||||
mutation updateCopilotSession($options: UpdateChatSessionInput!) {
|
||||
updateCopilotSession(options: $options)
|
||||
}
|
||||
`,
|
||||
variables: { options: { sessionId, promptName } },
|
||||
})
|
||||
.expect(200);
|
||||
|
||||
handleGraphQLError(res);
|
||||
|
||||
return res.body.data.updateCopilotSession;
|
||||
}
|
||||
|
||||
export async function forkCopilotSession(
|
||||
app: INestApplication,
|
||||
userToken: string,
|
||||
|
||||
Reference in New Issue
Block a user