feat(core): support ai network search (#9357)

### What Changed?
- Add `PerplexityProvider` in backend.
- Update session prompt name if user toggle network search mode in chat panel.
- Add experimental flag for AI network search feature.
- Add unit tests and e2e tests.

Search results are streamed and appear word for word:

<div class='graphite__hidden'>
          <div>🎥 Video uploaded on Graphite:</div>
            <a href="https://app.graphite.dev/media/video/sJGviKxfE3Ap685cl5bj/56f6ec7b-4b21-405f-9612-43e083f6fb84.mov">
              <img src="https://app.graphite.dev/api/v1/graphite/video/thumbnail/sJGviKxfE3Ap685cl5bj/56f6ec7b-4b21-405f-9612-43e083f6fb84.mov">
            </a>
          </div>
<video src="https://graphite-user-uploaded-assets-prod.s3.amazonaws.com/sJGviKxfE3Ap685cl5bj/56f6ec7b-4b21-405f-9612-43e083f6fb84.mov">录屏2024-12-27 18.58.40.mov</video>

Click the little globe icon to manually turn on/off Internet search:
<div class='graphite__hidden'>
          <div>🎥 Video uploaded on Graphite:</div>
            <a href="https://app.graphite.dev/media/video/sJGviKxfE3Ap685cl5bj/778f1406-bf29-498e-a90d-7dad813392d1.mov">
              <img src="https://app.graphite.dev/api/v1/graphite/video/thumbnail/sJGviKxfE3Ap685cl5bj/778f1406-bf29-498e-a90d-7dad813392d1.mov">
            </a>
          </div>
<video src="https://graphite-user-uploaded-assets-prod.s3.amazonaws.com/sJGviKxfE3Ap685cl5bj/778f1406-bf29-498e-a90d-7dad813392d1.mov">录屏2024-12-27 19.01.16.mov</video>

When there is an image, it will automatically switch to the openai model:

<div class='graphite__hidden'>
          <div>🎥 Video uploaded on Graphite:</div>
            <a href="https://app.graphite.dev/media/video/sJGviKxfE3Ap685cl5bj/56431d8e-75e1-4d84-ab4a-b6636042cc6a.mov">
              <img src="https://app.graphite.dev/api/v1/graphite/video/thumbnail/sJGviKxfE3Ap685cl5bj/56431d8e-75e1-4d84-ab4a-b6636042cc6a.mov">
            </a>
          </div>
<video src="https://graphite-user-uploaded-assets-prod.s3.amazonaws.com/sJGviKxfE3Ap685cl5bj/56431d8e-75e1-4d84-ab4a-b6636042cc6a.mov">录屏2024-12-27 19.02.13.mov</video>
This commit is contained in:
akumatus
2025-01-09 04:00:58 +00:00
parent 4f10457815
commit 58ce86533e
49 changed files with 1274 additions and 169 deletions

View File

@@ -28,6 +28,7 @@ AFFiNE.ENV_MAP = {
CAPTCHA_TURNSTILE_SECRET: ['plugins.captcha.turnstile.secret', 'string'],
COPILOT_OPENAI_API_KEY: 'plugins.copilot.openai.apiKey',
COPILOT_FAL_API_KEY: 'plugins.copilot.fal.apiKey',
COPILOT_PERPLEXITY_API_KEY: 'plugins.copilot.perplexity.apiKey',
COPILOT_UNSPLASH_API_KEY: 'plugins.copilot.unsplashKey',
REDIS_SERVER_HOST: 'redis.host',
REDIS_SERVER_PORT: ['redis.port', 'int'],

View File

@@ -3,10 +3,12 @@ import type { ClientOptions as OpenAIClientOptions } from 'openai';
import { defineStartupConfig, ModuleConfig } from '../../base/config';
import { StorageConfig } from '../../base/storage/config';
import type { FalConfig } from './providers/fal';
import { PerplexityConfig } from './providers/perplexity';
export interface CopilotStartupConfigurations {
openai?: OpenAIClientOptions;
fal?: FalConfig;
perplexity?: PerplexityConfig;
test?: never;
unsplashKey?: string;
storage: StorageConfig;

View File

@@ -13,6 +13,7 @@ import {
CopilotProviderService,
FalProvider,
OpenAIProvider,
PerplexityProvider,
registerCopilotProvider,
} from './providers';
import {
@@ -26,6 +27,7 @@ import { CopilotWorkflowExecutors, CopilotWorkflowService } from './workflow';
registerCopilotProvider(FalProvider);
registerCopilotProvider(OpenAIProvider);
registerCopilotProvider(PerplexityProvider);
@Plugin({
name: 'copilot',

View File

@@ -952,6 +952,11 @@ const chat: Prompt[] = [
},
],
},
{
name: 'Search With AFFiNE AI',
model: 'llama-3.1-sonar-small-128k-online',
messages: [],
},
// use for believer plan
{
name: 'Chat With AFFiNE AI - Believer',

View File

@@ -124,9 +124,7 @@ export class CopilotProviderService {
if (!this.cachedProviders.has(provider)) {
this.cachedProviders.set(provider, this.create(provider));
}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return this.cachedProviders.get(provider)!;
return this.cachedProviders.get(provider) as CopilotProvider;
}
async getProviderByCapability<C extends CopilotCapability>(
@@ -196,3 +194,4 @@ export class CopilotProviderService {
export { FalProvider } from './fal';
export { OpenAIProvider } from './openai';
export { PerplexityProvider } from './perplexity';

View File

@@ -0,0 +1,325 @@
import assert from 'node:assert';
import { EventSourceParserStream } from 'eventsource-parser/stream';
import { z } from 'zod';
import {
CopilotPromptInvalid,
CopilotProviderSideError,
metrics,
} from '../../../base';
import {
CopilotCapability,
CopilotChatOptions,
CopilotProviderType,
CopilotTextToTextProvider,
PromptMessage,
} from '../types';
export type PerplexityConfig = {
apiKey: string;
endpoint?: string;
};
const PerplexityErrorSchema = z.object({
detail: z.array(
z.object({
loc: z.array(z.string()),
msg: z.string(),
type: z.string(),
})
),
});
const PerplexityDataSchema = z.object({
citations: z.array(z.string()),
choices: z.array(
z.object({
message: z.object({
content: z.string(),
role: z.literal('assistant'),
}),
delta: z.object({
content: z.string(),
role: z.literal('assistant'),
}),
finish_reason: z.union([z.literal('stop'), z.literal(null)]),
})
),
});
const PerplexitySchema = z.union([PerplexityDataSchema, PerplexityErrorSchema]);
export class CitationParser {
private readonly SQUARE_BRACKET_OPEN = '[';
private readonly SQUARE_BRACKET_CLOSE = ']';
private readonly PARENTHESES_OPEN = '(';
private startToken: string[] = [];
private endToken: string[] = [];
private numberToken: string[] = [];
public parse(content: string, citations: string[]) {
let result = '';
const contentArray = content.split('');
for (const [index, char] of contentArray.entries()) {
if (char === this.SQUARE_BRACKET_OPEN) {
if (this.numberToken.length === 0) {
this.startToken.push(char);
} else {
result += this.flush() + char;
}
continue;
}
if (char === this.SQUARE_BRACKET_CLOSE) {
this.endToken.push(char);
if (this.startToken.length === this.endToken.length) {
const cIndex = Number(this.numberToken.join('').trim());
if (
cIndex > 0 &&
cIndex <= citations.length &&
contentArray[index + 1] !== this.PARENTHESES_OPEN
) {
const content = `[[${cIndex}](${citations[cIndex - 1]})]`;
result += content;
this.resetToken();
} else {
result += this.flush();
}
} else if (this.startToken.length < this.endToken.length) {
result += this.flush();
}
continue;
}
if (this.isNumeric(char)) {
if (this.startToken.length > 0) {
this.numberToken.push(char);
} else {
result += this.flush() + char;
}
continue;
}
if (this.startToken.length > 0) {
result += this.flush() + char;
} else {
result += char;
}
}
return result;
}
public flush() {
const content = this.getFullContent();
this.resetToken();
return content;
}
private getFullContent() {
return this.startToken.concat(this.numberToken, this.endToken).join('');
}
private resetToken() {
this.startToken = [];
this.endToken = [];
this.numberToken = [];
}
private isNumeric(str: string) {
return !isNaN(Number(str)) && str.trim() !== '';
}
}
export class PerplexityProvider implements CopilotTextToTextProvider {
static readonly type = CopilotProviderType.Perplexity;
static readonly capabilities = [CopilotCapability.TextToText];
static assetsConfig(config: PerplexityConfig) {
return !!config.apiKey;
}
constructor(private readonly config: PerplexityConfig) {
assert(PerplexityProvider.assetsConfig(config));
}
readonly availableModels = [
'llama-3.1-sonar-small-128k-online',
'llama-3.1-sonar-large-128k-online',
'llama-3.1-sonar-huge-128k-online',
];
get type(): CopilotProviderType {
return PerplexityProvider.type;
}
getCapabilities(): CopilotCapability[] {
return PerplexityProvider.capabilities;
}
async isModelAvailable(model: string): Promise<boolean> {
return this.availableModels.includes(model);
}
async generateText(
messages: PromptMessage[],
model: string = 'llama-3.1-sonar-small-128k-online',
options: CopilotChatOptions = {}
): Promise<string> {
await this.checkParams({ messages, model, options });
try {
metrics.ai.counter('chat_text_calls').add(1, { model });
const sMessages = messages
.map(({ content, role }) => ({ content, role }))
.filter(({ content }) => typeof content === 'string');
const params = {
method: 'POST',
headers: {
Authorization: `Bearer ${this.config.apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
model,
messages: sMessages,
max_tokens: options.maxTokens || 4096,
}),
};
const response = await fetch(
this.config.endpoint || 'https://api.perplexity.ai/chat/completions',
params
);
const data = PerplexitySchema.parse(await response.json());
if ('detail' in data) {
throw new CopilotProviderSideError({
provider: this.type,
kind: 'unexpected_response',
message: data.detail[0].msg || 'Unexpected perplexity response',
});
} else {
const parser = new CitationParser();
const { content } = data.choices[0].message;
const { citations } = data;
let result = parser.parse(content, citations);
result += parser.flush();
return result;
}
} catch (e: any) {
metrics.ai.counter('chat_text_errors').add(1, { model });
throw this.handleError(e);
}
}
async *generateTextStream(
messages: PromptMessage[],
model: string = 'llama-3.1-sonar-small-128k-online',
options: CopilotChatOptions = {}
): AsyncIterable<string> {
await this.checkParams({ messages, model, options });
try {
metrics.ai.counter('chat_text_stream_calls').add(1, { model });
const sMessages = messages
.map(({ content, role }) => ({ content, role }))
.filter(({ content }) => typeof content === 'string');
const params = {
method: 'POST',
headers: {
Authorization: `Bearer ${this.config.apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
model,
messages: sMessages,
max_tokens: options.maxTokens || 4096,
stream: true,
}),
};
const response = await fetch(
this.config.endpoint || 'https://api.perplexity.ai/chat/completions',
params
);
if (response.body) {
const parser = new CitationParser();
const provider = this.type;
const eventStream = response.body
.pipeThrough(new TextDecoderStream())
.pipeThrough(new EventSourceParserStream())
.pipeThrough(
new TransformStream({
transform(chunk, controller) {
if (options.signal?.aborted) {
controller.enqueue(null);
return;
}
const json = JSON.parse(chunk.data);
if (json) {
const data = PerplexitySchema.parse(json);
if ('detail' in data) {
throw new CopilotProviderSideError({
provider,
kind: 'unexpected_response',
message:
data.detail[0].msg || 'Unexpected perplexity response',
});
}
const { content } = data.choices[0].delta;
const { citations } = data;
const result = parser.parse(content, citations);
controller.enqueue(result);
}
},
flush(controller) {
controller.enqueue(parser.flush());
controller.enqueue(null);
},
})
);
const reader = eventStream.getReader();
while (true) {
const { done, value } = await reader.read();
if (done) break;
yield value;
}
} else {
const result = await this.generateText(messages, model, options);
yield result;
}
} catch (e) {
metrics.ai.counter('chat_text_stream_errors').add(1, { model });
throw e;
}
}
protected async checkParams({
model,
}: {
messages?: PromptMessage[];
embeddings?: string[];
model: string;
options: CopilotChatOptions;
}) {
if (!(await this.isModelAvailable(model))) {
throw new CopilotPromptInvalid(`Invalid model: ${model}`);
}
}
private handleError(e: any) {
if (e instanceof CopilotProviderSideError) {
return e;
}
return new CopilotProviderSideError({
provider: this.type,
kind: 'unexpected_response',
message: e?.message || 'Unexpected perplexity response',
});
}
}

View File

@@ -22,6 +22,7 @@ import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs';
import {
CallMetric,
CopilotFailedToCreateMessage,
CopilotSessionNotFound,
FileUpload,
RequestMutex,
Throttle,
@@ -62,6 +63,17 @@ class CreateChatSessionInput {
promptName!: string;
}
@InputType()
class UpdateChatSessionInput {
@Field(() => String)
sessionId!: string;
@Field(() => String, {
description: 'The prompt name to use for the session',
})
promptName!: string;
}
@InputType()
class ForkChatSessionInput {
@Field(() => String)
@@ -372,6 +384,38 @@ export class CopilotResolver {
});
}
@Mutation(() => String, {
description: 'Update a chat session',
})
@CallMetric('ai', 'chat_session_update')
async updateCopilotSession(
@CurrentUser() user: CurrentUser,
@Args({ name: 'options', type: () => UpdateChatSessionInput })
options: UpdateChatSessionInput
) {
const session = await this.chatSession.get(options.sessionId);
if (!session) {
throw new CopilotSessionNotFound();
}
const { workspaceId, docId } = session.config;
await this.permissions.checkCloudPagePermission(
workspaceId,
docId,
user.id
);
const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${workspaceId}`;
await using lock = await this.mutex.acquire(lockFlag);
if (!lock) {
return new TooManyRequest('Server is busy');
}
await this.chatSession.checkQuota(user.id);
return await this.chatSession.updateSessionPrompt({
...options,
userId: user.id,
});
}
@Mutation(() => String, {
description: 'Create a chat session',
})

View File

@@ -10,6 +10,7 @@ import {
CopilotQuotaExceeded,
CopilotSessionDeleted,
CopilotSessionNotFound,
PrismaTransaction,
} from '../../base';
import { FeatureManagementService } from '../../core/features';
import { QuotaService } from '../../core/quota';
@@ -22,6 +23,7 @@ import {
ChatMessageSchema,
ChatSessionForkOptions,
ChatSessionOptions,
ChatSessionPromptUpdateOptions,
ChatSessionState,
getTokenEncoder,
ListHistoriesOptions,
@@ -198,6 +200,22 @@ export class ChatSessionService {
private readonly prompt: PromptService
) {}
private async haveSession(
sessionId: string,
userId: string,
tx?: PrismaTransaction
) {
const executor = tx ?? this.db;
return await executor.aiSession
.count({
where: {
id: sessionId,
userId,
},
})
.then(c => c > 0);
}
private async setSession(state: ChatSessionState): Promise<string> {
return await this.db.$transaction(async tx => {
let sessionId = state.sessionId;
@@ -226,15 +244,7 @@ export class ChatSessionService {
if (id) sessionId = id;
}
const haveSession = await tx.aiSession
.count({
where: {
id: sessionId,
userId: state.userId,
},
})
.then(c => c > 0);
const haveSession = await this.haveSession(sessionId, state.userId, tx);
if (haveSession) {
// message will only exists when setSession call by session.save
if (state.messages.length) {
@@ -570,6 +580,27 @@ export class ChatSessionService {
});
}
async updateSessionPrompt(
options: ChatSessionPromptUpdateOptions
): Promise<string> {
const prompt = await this.prompt.get(options.promptName);
if (!prompt) {
this.logger.error(`Prompt not found: ${options.promptName}`);
throw new CopilotPromptNotFound({ name: options.promptName });
}
return await this.db.$transaction(async tx => {
let sessionId = options.sessionId;
const haveSession = await this.haveSession(sessionId, options.userId, tx);
if (haveSession) {
await tx.aiSession.update({
where: { id: sessionId },
data: { promptName: prompt.name },
});
}
return sessionId;
});
}
async fork(options: ChatSessionForkOptions): Promise<string> {
const state = await this.getSession(options.sessionId);
if (!state) {

View File

@@ -123,6 +123,11 @@ export interface ChatSessionOptions {
promptName: string;
}
export interface ChatSessionPromptUpdateOptions
extends Pick<ChatSessionState, 'sessionId' | 'userId'> {
promptName: string;
}
export interface ChatSessionForkOptions
extends Omit<ChatSessionOptions, 'promptName'> {
sessionId: string;
@@ -154,6 +159,7 @@ export type ListHistoriesOptions = {
export enum CopilotProviderType {
FAL = 'fal',
OpenAI = 'openai',
Perplexity = 'perplexity',
// only for test
Test = 'test',
}

View File

@@ -551,6 +551,9 @@ type Mutation {
"""Update a copilot prompt"""
updateCopilotPrompt(messages: [CopilotPromptMessageInput!]!, name: String!): CopilotPromptType!
"""Update a chat session"""
updateCopilotSession(options: UpdateChatSessionInput!): String!
updateProfile(input: UpdateUserInput!): UserType!
"""update server runtime configurable setting"""
@@ -865,6 +868,12 @@ type UnsupportedSubscriptionPlanDataType {
plan: String!
}
input UpdateChatSessionInput {
"""The prompt name to use for the session"""
promptName: String!
sessionId: String!
}
input UpdateUserInput {
"""User name"""
name: String