mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-12 12:28:42 +00:00
feat: improve copilot plugin (#3459)
This commit is contained in:
@@ -17,7 +17,7 @@ export const HeaderItem = (): ReactElement => {
|
||||
return {
|
||||
direction: 'horizontal',
|
||||
first: 'editor',
|
||||
second: '@affine/copilot',
|
||||
second: '@affine/copilot-plugin',
|
||||
splitPercentage: 70,
|
||||
};
|
||||
} else {
|
||||
|
||||
@@ -11,22 +11,29 @@ import {
|
||||
|
||||
import { IndexedDBChatMessageHistory } from './langchain/message-history';
|
||||
import { chatPrompt, followupQuestionPrompt } from './prompts';
|
||||
import { followupQuestionParser } from './prompts/output-parser';
|
||||
|
||||
declare global {
|
||||
interface WindowEventMap {
|
||||
'llm-start': CustomEvent;
|
||||
'llm-new-token': CustomEvent<{ token: string }>;
|
||||
}
|
||||
}
|
||||
type ChatAI = {
|
||||
// Core chat AI
|
||||
conversationChain: ConversationChain;
|
||||
// Followup AI, used to generate followup questions
|
||||
followupChain: LLMChain<string>;
|
||||
// Chat history, used to store messages
|
||||
chatHistory: IndexedDBChatMessageHistory;
|
||||
};
|
||||
|
||||
export type ChatAIConfig = {
|
||||
events: {
|
||||
llmStart: () => void;
|
||||
llmNewToken: (token: string) => void;
|
||||
};
|
||||
};
|
||||
|
||||
export async function createChatAI(
|
||||
room: string,
|
||||
openAIApiKey: string
|
||||
): Promise<{
|
||||
conversationChain: ConversationChain;
|
||||
followupChain: LLMChain<string>;
|
||||
chatHistory: IndexedDBChatMessageHistory;
|
||||
}> {
|
||||
openAIApiKey: string,
|
||||
config: ChatAIConfig
|
||||
): Promise<ChatAI> {
|
||||
if (!openAIApiKey) {
|
||||
console.warn('OpenAI API key not set, chat will not work');
|
||||
}
|
||||
@@ -44,25 +51,11 @@ export async function createChatAI(
|
||||
openAIApiKey: openAIApiKey,
|
||||
callbacks: [
|
||||
{
|
||||
async handleLLMStart(llm, prompts, runId, parentRunId, extraParams) {
|
||||
console.log(
|
||||
'handleLLMStart',
|
||||
llm,
|
||||
prompts,
|
||||
runId,
|
||||
parentRunId,
|
||||
extraParams
|
||||
);
|
||||
window.dispatchEvent(new CustomEvent('llm-start'));
|
||||
async handleLLMStart() {
|
||||
config.events.llmStart();
|
||||
},
|
||||
async handleLLMNewToken(token, runId, parentRunId) {
|
||||
console.log('handleLLMNewToken', token, runId, parentRunId);
|
||||
window.dispatchEvent(
|
||||
new CustomEvent('llm-new-token', { detail: { token } })
|
||||
);
|
||||
},
|
||||
async handleLLMEnd(output, runId, parentRunId) {
|
||||
console.log('handleLLMEnd', output, runId, parentRunId);
|
||||
async handleLLMNewToken(token) {
|
||||
config.events.llmNewToken(token);
|
||||
},
|
||||
},
|
||||
],
|
||||
@@ -77,6 +70,9 @@ export async function createChatAI(
|
||||
const followupPromptTemplate = new PromptTemplate({
|
||||
template: followupQuestionPrompt,
|
||||
inputVariables: ['human_conversation', 'ai_conversation'],
|
||||
partialVariables: {
|
||||
format_instructions: followupQuestionParser.getFormatInstructions(),
|
||||
},
|
||||
});
|
||||
|
||||
const followupChain = new LLMChain({
|
||||
@@ -101,5 +97,5 @@ export async function createChatAI(
|
||||
conversationChain,
|
||||
followupChain,
|
||||
chatHistory,
|
||||
} as const;
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,51 +1,56 @@
|
||||
import type { IndexedDBChatMessageHistory } from '@affine/copilot/core/langchain/message-history';
|
||||
import { atom, useAtomValue } from 'jotai';
|
||||
import { atomWithDefault, atomWithStorage } from 'jotai/utils';
|
||||
import type { WritableAtom } from 'jotai/vanilla';
|
||||
import type { PrimitiveAtom } from 'jotai/vanilla';
|
||||
import type { LLMChain } from 'langchain/chains';
|
||||
import { type ConversationChain } from 'langchain/chains';
|
||||
import { type BufferMemory } from 'langchain/memory';
|
||||
import type { BaseMessage } from 'langchain/schema';
|
||||
import { AIMessage } from 'langchain/schema';
|
||||
import { HumanMessage } from 'langchain/schema';
|
||||
import { z } from 'zod';
|
||||
|
||||
import type { ChatAIConfig } from '../chat';
|
||||
import { createChatAI } from '../chat';
|
||||
|
||||
const followupResponseSchema = z.array(z.string());
|
||||
import type { IndexedDBChatMessageHistory } from '../langchain/message-history';
|
||||
import { followupQuestionParser } from '../prompts/output-parser';
|
||||
|
||||
export const openAIApiKeyAtom = atomWithStorage<string | null>(
|
||||
'com.affine.copilot.openai.token',
|
||||
null
|
||||
);
|
||||
|
||||
export const chatAtom = atom(async get => {
|
||||
const openAIApiKey = get(openAIApiKeyAtom);
|
||||
if (!openAIApiKey) {
|
||||
throw new Error('OpenAI API key not set, chat will not work');
|
||||
}
|
||||
return createChatAI('default-copilot', openAIApiKey);
|
||||
});
|
||||
|
||||
const conversationBaseWeakMap = new WeakMap<
|
||||
ConversationChain,
|
||||
PrimitiveAtom<BaseMessage[]>
|
||||
>();
|
||||
const conversationWeakMap = new WeakMap<
|
||||
ConversationChain,
|
||||
WritableAtom<BaseMessage[], [string], Promise<void>>
|
||||
>();
|
||||
|
||||
const getConversationAtom = (chat: ConversationChain) => {
|
||||
if (conversationWeakMap.has(chat)) {
|
||||
return conversationWeakMap.get(chat) as WritableAtom<
|
||||
BaseMessage[],
|
||||
[string],
|
||||
Promise<void>
|
||||
>;
|
||||
export const chatAtom = atom(async get => {
|
||||
const openAIApiKey = get(openAIApiKeyAtom);
|
||||
if (!openAIApiKey) {
|
||||
throw new Error('OpenAI API key not set, chat will not work');
|
||||
}
|
||||
const conversationBaseAtom = atom<BaseMessage[]>([]);
|
||||
conversationBaseAtom.onMount = setAtom => {
|
||||
if (!chat) {
|
||||
throw new Error();
|
||||
}
|
||||
const memory = chat.memory as BufferMemory;
|
||||
const events: ChatAIConfig['events'] = {
|
||||
llmStart: () => {
|
||||
throw new Error('llmStart not set');
|
||||
},
|
||||
llmNewToken: () => {
|
||||
throw new Error('llmNewToken not set');
|
||||
},
|
||||
};
|
||||
const chatAI = await createChatAI('default-copilot', openAIApiKey, {
|
||||
events,
|
||||
});
|
||||
getOrCreateConversationAtom(chatAI.conversationChain);
|
||||
const baseAtom = conversationBaseWeakMap.get(chatAI.conversationChain);
|
||||
if (!baseAtom) {
|
||||
throw new TypeError();
|
||||
}
|
||||
baseAtom.onMount = setAtom => {
|
||||
const memory = chatAI.conversationChain.memory as BufferMemory;
|
||||
memory.chatHistory
|
||||
.getMessages()
|
||||
.then(messages => {
|
||||
@@ -54,23 +59,27 @@ const getConversationAtom = (chat: ConversationChain) => {
|
||||
.catch(err => {
|
||||
console.error(err);
|
||||
});
|
||||
const llmStart = (): void => {
|
||||
events.llmStart = () => {
|
||||
setAtom(conversations => [...conversations, new AIMessage('')]);
|
||||
};
|
||||
const llmNewToken = (event: CustomEvent<{ token: string }>): void => {
|
||||
events.llmNewToken = token => {
|
||||
setAtom(conversations => {
|
||||
const last = conversations[conversations.length - 1] as AIMessage;
|
||||
last.content += event.detail.token;
|
||||
last.content += token;
|
||||
return [...conversations];
|
||||
});
|
||||
};
|
||||
window.addEventListener('llm-start', llmStart);
|
||||
window.addEventListener('llm-new-token', llmNewToken);
|
||||
return () => {
|
||||
window.removeEventListener('llm-start', llmStart);
|
||||
window.removeEventListener('llm-new-token', llmNewToken);
|
||||
};
|
||||
};
|
||||
return chatAI;
|
||||
});
|
||||
|
||||
const getOrCreateConversationAtom = (chat: ConversationChain) => {
|
||||
if (conversationWeakMap.has(chat)) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
return conversationWeakMap.get(chat)!;
|
||||
}
|
||||
const conversationBaseAtom = atom<BaseMessage[]>([]);
|
||||
conversationBaseWeakMap.set(chat, conversationBaseAtom);
|
||||
|
||||
const conversationAtom = atom<BaseMessage[], [string], Promise<void>>(
|
||||
get => get(conversationBaseAtom),
|
||||
@@ -105,7 +114,9 @@ const getConversationAtom = (chat: ConversationChain) => {
|
||||
const followingUpWeakMap = new WeakMap<
|
||||
LLMChain<string>,
|
||||
{
|
||||
questionsAtom: ReturnType<typeof atomWithDefault<Promise<string[]>>>;
|
||||
questionsAtom: ReturnType<
|
||||
typeof atomWithDefault<Promise<string[]> | string[]>
|
||||
>;
|
||||
generateChatAtom: WritableAtom<null, [], void>;
|
||||
}
|
||||
>();
|
||||
@@ -115,12 +126,10 @@ const getFollowingUpAtoms = (
|
||||
chatHistory: IndexedDBChatMessageHistory
|
||||
) => {
|
||||
if (followingUpWeakMap.has(followupLLMChain)) {
|
||||
return followingUpWeakMap.get(followupLLMChain) as {
|
||||
questionsAtom: ReturnType<typeof atomWithDefault<Promise<string[]>>>;
|
||||
generateChatAtom: WritableAtom<null, [], void>;
|
||||
};
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
return followingUpWeakMap.get(followupLLMChain)!;
|
||||
}
|
||||
const baseAtom = atomWithDefault<Promise<string[]>>(async () => {
|
||||
const baseAtom = atomWithDefault<Promise<string[]> | string[]>(async () => {
|
||||
return chatHistory?.getFollowingUp() ?? [];
|
||||
});
|
||||
const setAtom = atom<null, [], void>(null, async (get, set) => {
|
||||
@@ -137,10 +146,9 @@ const getFollowingUpAtoms = (
|
||||
ai_conversation: aiMessage,
|
||||
human_conversation: humanMessage,
|
||||
});
|
||||
const followingUp = JSON.parse(response.text);
|
||||
followupResponseSchema.parse(followingUp);
|
||||
set(baseAtom, followingUp);
|
||||
chatHistory.saveFollowingUp(followingUp).catch(() => {
|
||||
const followingUp = await followupQuestionParser.parse(response.text);
|
||||
set(baseAtom, followingUp.followupQuestions);
|
||||
chatHistory.saveFollowingUp(followingUp.followupQuestions).catch(() => {
|
||||
console.error('failed to save followup');
|
||||
});
|
||||
});
|
||||
@@ -155,11 +163,11 @@ const getFollowingUpAtoms = (
|
||||
};
|
||||
|
||||
export function useChatAtoms(): {
|
||||
conversationAtom: ReturnType<typeof getConversationAtom>;
|
||||
conversationAtom: ReturnType<typeof getOrCreateConversationAtom>;
|
||||
followingUpAtoms: ReturnType<typeof getFollowingUpAtoms>;
|
||||
} {
|
||||
const chat = useAtomValue(chatAtom);
|
||||
const conversationAtom = getConversationAtom(chat.conversationChain);
|
||||
const conversationAtom = getOrCreateConversationAtom(chat.conversationChain);
|
||||
const followingUpAtoms = getFollowingUpAtoms(
|
||||
chat.followupChain,
|
||||
chat.chatHistory
|
||||
|
||||
@@ -18,12 +18,10 @@ You can only give one reply for each conversation turn.
|
||||
`;
|
||||
|
||||
export const followupQuestionPrompt = `Rules you must follow:
|
||||
- You only respond in JSON format
|
||||
- Read the following conversation between AI and Human and generate at most 3 follow-up messages or questions the Human can ask
|
||||
- Your response MUST be a valid JSON array of strings like this: ["some question", "another question"]
|
||||
- Each message in your response should be concise, no more than 15 words
|
||||
- You MUST reply in the same written language as the conversation
|
||||
- Don't output anything other text
|
||||
Read the following conversation between AI and Human and generate at most 3 follow-up messages or questions the Human can ask
|
||||
Each message in your response should be concise, no more than 15 words
|
||||
You MUST reply in the same written language as the conversation
|
||||
{format_instructions}
|
||||
The conversation is inside triple quotes:
|
||||
\`\`\`
|
||||
Human: {human_conversation}
|
||||
|
||||
8
plugins/copilot/src/core/prompts/output-parser.ts
Normal file
8
plugins/copilot/src/core/prompts/output-parser.ts
Normal file
@@ -0,0 +1,8 @@
|
||||
import { StructuredOutputParser } from 'langchain/output_parsers';
|
||||
import { z } from 'zod';
|
||||
|
||||
export const followupQuestionParser = StructuredOutputParser.fromZodSchema(
|
||||
z.object({
|
||||
followupQuestions: z.array(z.string()),
|
||||
})
|
||||
);
|
||||
Reference in New Issue
Block a user