feat: improve copilot plugin (#3459)

Alex Yang
2023-07-29 00:37:01 -07:00
committed by GitHub
parent 52809a2783
commit ce0c1c39e2
7 changed files with 101 additions and 86 deletions

View File

@@ -17,7 +17,7 @@ export const HeaderItem = (): ReactElement => {
     return {
       direction: 'horizontal',
       first: 'editor',
-      second: '@affine/copilot',
+      second: '@affine/copilot-plugin',
       splitPercentage: 70,
     };
   } else {

View File

@@ -11,22 +11,29 @@ import {
 import { IndexedDBChatMessageHistory } from './langchain/message-history';
 import { chatPrompt, followupQuestionPrompt } from './prompts';
+import { followupQuestionParser } from './prompts/output-parser';
 
-declare global {
-  interface WindowEventMap {
-    'llm-start': CustomEvent;
-    'llm-new-token': CustomEvent<{ token: string }>;
-  }
-}
+type ChatAI = {
+  // Core chat AI
+  conversationChain: ConversationChain;
+  // Followup AI, used to generate followup questions
+  followupChain: LLMChain<string>;
+  // Chat history, used to store messages
+  chatHistory: IndexedDBChatMessageHistory;
+};
+
+export type ChatAIConfig = {
+  events: {
+    llmStart: () => void;
+    llmNewToken: (token: string) => void;
+  };
+};
 
 export async function createChatAI(
   room: string,
-  openAIApiKey: string
-): Promise<{
-  conversationChain: ConversationChain;
-  followupChain: LLMChain<string>;
-  chatHistory: IndexedDBChatMessageHistory;
-}> {
+  openAIApiKey: string,
+  config: ChatAIConfig
+): Promise<ChatAI> {
   if (!openAIApiKey) {
     console.warn('OpenAI API key not set, chat will not work');
   }
@@ -44,25 +51,11 @@ export async function createChatAI(
     openAIApiKey: openAIApiKey,
     callbacks: [
       {
-        async handleLLMStart(llm, prompts, runId, parentRunId, extraParams) {
-          console.log(
-            'handleLLMStart',
-            llm,
-            prompts,
-            runId,
-            parentRunId,
-            extraParams
-          );
-          window.dispatchEvent(new CustomEvent('llm-start'));
+        async handleLLMStart() {
+          config.events.llmStart();
         },
-        async handleLLMNewToken(token, runId, parentRunId) {
-          console.log('handleLLMNewToken', token, runId, parentRunId);
-          window.dispatchEvent(
-            new CustomEvent('llm-new-token', { detail: { token } })
-          );
-        },
-        async handleLLMEnd(output, runId, parentRunId) {
-          console.log('handleLLMEnd', output, runId, parentRunId);
+        async handleLLMNewToken(token) {
+          config.events.llmNewToken(token);
         },
       },
     ],
@@ -77,6 +70,9 @@ export async function createChatAI(
   const followupPromptTemplate = new PromptTemplate({
     template: followupQuestionPrompt,
     inputVariables: ['human_conversation', 'ai_conversation'],
+    partialVariables: {
+      format_instructions: followupQuestionParser.getFormatInstructions(),
+    },
   });
 
   const followupChain = new LLMChain({
@@ -101,5 +97,5 @@ export async function createChatAI(
     conversationChain,
     followupChain,
     chatHistory,
-  } as const;
+  };
 }
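
The new ChatAIConfig replaces the global window CustomEvents with explicit callbacks, so a host can subscribe to streaming without touching WindowEventMap. A minimal caller sketch, assuming the caller already holds an OpenAI key (the handler bodies and the key literal are illustrative, not part of this commit):

import { createChatAI, type ChatAIConfig } from './chat';

let pending = '';
const config: ChatAIConfig = {
  events: {
    // Fired once when the model starts streaming a reply.
    llmStart: () => {
      pending = '';
    },
    // Fired for each streamed token; accumulate it into the pending message.
    llmNewToken: token => {
      pending += token;
    },
  },
};

const { conversationChain } = await createChatAI(
  'default-copilot',
  'sk-placeholder', // assumption: the key the plugin has stored for the user
  config
);
await conversationChain.call({ input: 'Hello' });
console.log(pending); // the fully streamed reply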

View File

@@ -1,51 +1,56 @@
-import type { IndexedDBChatMessageHistory } from '@affine/copilot/core/langchain/message-history';
 import { atom, useAtomValue } from 'jotai';
 import { atomWithDefault, atomWithStorage } from 'jotai/utils';
 import type { WritableAtom } from 'jotai/vanilla';
+import type { PrimitiveAtom } from 'jotai/vanilla';
 import type { LLMChain } from 'langchain/chains';
 import { type ConversationChain } from 'langchain/chains';
 import { type BufferMemory } from 'langchain/memory';
 import type { BaseMessage } from 'langchain/schema';
 import { AIMessage } from 'langchain/schema';
 import { HumanMessage } from 'langchain/schema';
-import { z } from 'zod';
+import type { ChatAIConfig } from '../chat';
 import { createChatAI } from '../chat';
-
-const followupResponseSchema = z.array(z.string());
+import type { IndexedDBChatMessageHistory } from '../langchain/message-history';
+import { followupQuestionParser } from '../prompts/output-parser';
 
 export const openAIApiKeyAtom = atomWithStorage<string | null>(
   'com.affine.copilot.openai.token',
   null
 );
 
-export const chatAtom = atom(async get => {
-  const openAIApiKey = get(openAIApiKeyAtom);
-  if (!openAIApiKey) {
-    throw new Error('OpenAI API key not set, chat will not work');
-  }
-  return createChatAI('default-copilot', openAIApiKey);
-});
-
+const conversationBaseWeakMap = new WeakMap<
+  ConversationChain,
+  PrimitiveAtom<BaseMessage[]>
+>();
 const conversationWeakMap = new WeakMap<
   ConversationChain,
   WritableAtom<BaseMessage[], [string], Promise<void>>
 >();
-const getConversationAtom = (chat: ConversationChain) => {
-  if (conversationWeakMap.has(chat)) {
-    return conversationWeakMap.get(chat) as WritableAtom<
-      BaseMessage[],
-      [string],
-      Promise<void>
-    >;
+
+export const chatAtom = atom(async get => {
+  const openAIApiKey = get(openAIApiKeyAtom);
+  if (!openAIApiKey) {
+    throw new Error('OpenAI API key not set, chat will not work');
   }
-  const conversationBaseAtom = atom<BaseMessage[]>([]);
-  conversationBaseAtom.onMount = setAtom => {
-    if (!chat) {
-      throw new Error();
-    }
-    const memory = chat.memory as BufferMemory;
+  const events: ChatAIConfig['events'] = {
+    llmStart: () => {
+      throw new Error('llmStart not set');
+    },
+    llmNewToken: () => {
+      throw new Error('llmNewToken not set');
+    },
+  };
+  const chatAI = await createChatAI('default-copilot', openAIApiKey, {
+    events,
+  });
+  getOrCreateConversationAtom(chatAI.conversationChain);
+  const baseAtom = conversationBaseWeakMap.get(chatAI.conversationChain);
+  if (!baseAtom) {
+    throw new TypeError();
+  }
+  baseAtom.onMount = setAtom => {
+    const memory = chatAI.conversationChain.memory as BufferMemory;
     memory.chatHistory
       .getMessages()
       .then(messages => {
@@ -54,23 +59,27 @@ const getConversationAtom = (chat: ConversationChain) => {
       .catch(err => {
         console.error(err);
       });
-    const llmStart = (): void => {
+    events.llmStart = () => {
       setAtom(conversations => [...conversations, new AIMessage('')]);
     };
-    const llmNewToken = (event: CustomEvent<{ token: string }>): void => {
+    events.llmNewToken = token => {
       setAtom(conversations => {
         const last = conversations[conversations.length - 1] as AIMessage;
-        last.content += event.detail.token;
+        last.content += token;
         return [...conversations];
       });
     };
-    window.addEventListener('llm-start', llmStart);
-    window.addEventListener('llm-new-token', llmNewToken);
-    return () => {
-      window.removeEventListener('llm-start', llmStart);
-      window.removeEventListener('llm-new-token', llmNewToken);
-    };
   };
+  return chatAI;
+});
+
+const getOrCreateConversationAtom = (chat: ConversationChain) => {
+  if (conversationWeakMap.has(chat)) {
+    // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+    return conversationWeakMap.get(chat)!;
+  }
+  const conversationBaseAtom = atom<BaseMessage[]>([]);
+  conversationBaseWeakMap.set(chat, conversationBaseAtom);
   const conversationAtom = atom<BaseMessage[], [string], Promise<void>>(
     get => get(conversationBaseAtom),
@@ -105,7 +114,9 @@ const getConversationAtom = (chat: ConversationChain) => {
 const followingUpWeakMap = new WeakMap<
   LLMChain<string>,
   {
-    questionsAtom: ReturnType<typeof atomWithDefault<Promise<string[]>>>;
+    questionsAtom: ReturnType<
+      typeof atomWithDefault<Promise<string[]> | string[]>
+    >;
     generateChatAtom: WritableAtom<null, [], void>;
   }
 >();
@@ -115,12 +126,10 @@ const getFollowingUpAtoms = (
   chatHistory: IndexedDBChatMessageHistory
 ) => {
   if (followingUpWeakMap.has(followupLLMChain)) {
-    return followingUpWeakMap.get(followupLLMChain) as {
-      questionsAtom: ReturnType<typeof atomWithDefault<Promise<string[]>>>;
-      generateChatAtom: WritableAtom<null, [], void>;
-    };
+    // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+    return followingUpWeakMap.get(followupLLMChain)!;
   }
-  const baseAtom = atomWithDefault<Promise<string[]>>(async () => {
+  const baseAtom = atomWithDefault<Promise<string[]> | string[]>(async () => {
     return chatHistory?.getFollowingUp() ?? [];
   });
   const setAtom = atom<null, [], void>(null, async (get, set) => {
@@ -137,10 +146,9 @@ const getFollowingUpAtoms = (
       ai_conversation: aiMessage,
       human_conversation: humanMessage,
     });
-    const followingUp = JSON.parse(response.text);
-    followupResponseSchema.parse(followingUp);
-    set(baseAtom, followingUp);
-    chatHistory.saveFollowingUp(followingUp).catch(() => {
+    const followingUp = await followupQuestionParser.parse(response.text);
+    set(baseAtom, followingUp.followupQuestions);
+    chatHistory.saveFollowingUp(followingUp.followupQuestions).catch(() => {
       console.error('failed to save followup');
     });
   });
@@ -155,11 +163,11 @@ const getFollowingUpAtoms = (
 };
 
 export function useChatAtoms(): {
-  conversationAtom: ReturnType<typeof getConversationAtom>;
+  conversationAtom: ReturnType<typeof getOrCreateConversationAtom>;
   followingUpAtoms: ReturnType<typeof getFollowingUpAtoms>;
 } {
   const chat = useAtomValue(chatAtom);
-  const conversationAtom = getConversationAtom(chat.conversationChain);
+  const conversationAtom = getOrCreateConversationAtom(chat.conversationChain);
   const followingUpAtoms = getFollowingUpAtoms(
     chat.followupChain,
     chat.chatHistory
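
For orientation, a hypothetical React consumer of the renamed useChatAtoms helpers (the component name, module path, and rendering are assumptions, not part of this diff):

import { useAtom, useAtomValue, useSetAtom } from 'jotai';
import { useChatAtoms } from './atoms'; // assumed path to this file

const ChatPanel = () => {
  const { conversationAtom, followingUpAtoms } = useChatAtoms();
  // Writing a string into conversationAtom sends a human message.
  const [messages, sendMessage] = useAtom(conversationAtom);
  // questionsAtom suspends until the stored followups resolve.
  const questions = useAtomValue(followingUpAtoms.questionsAtom);
  const generateFollowup = useSetAtom(followingUpAtoms.generateChatAtom);

  const onSubmit = async (text: string) => {
    await sendMessage(text);
    generateFollowup(); // refresh suggestions for the new turn
  };
  return null; // rendering of messages/questions/onSubmit omitted in this sketch
};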

View File

@@ -18,12 +18,10 @@ You can only give one reply for each conversation turn.
 `;
 
 export const followupQuestionPrompt = `Rules you must follow:
-- You only respond in JSON format
-- Read the following conversation between AI and Human and generate at most 3 follow-up messages or questions the Human can ask
-- Your response MUST be a valid JSON array of strings like this: ["some question", "another question"]
-- Each message in your response should be concise, no more than 15 words
-- You MUST reply in the same written language as the conversation
-- Don't output anything other text
+Read the following conversation between AI and Human and generate at most 3 follow-up messages or questions the Human can ask
+Each message in your response should be concise, no more than 15 words
+You MUST reply in the same written language as the conversation
+{format_instructions}
 
 The conversation is inside triple quotes:
 \`\`\`
 Human: {human_conversation}

View File

@@ -0,0 +1,8 @@
+import { StructuredOutputParser } from 'langchain/output_parsers';
+import { z } from 'zod';
+
+export const followupQuestionParser = StructuredOutputParser.fromZodSchema(
+  z.object({
+    followupQuestions: z.array(z.string()),
+  })
+);
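
A quick round-trip of the new parser, with a made-up completion string (the import path and sample values are illustrative):

import { followupQuestionParser } from './output-parser'; // assumed relative path

// This text is what gets injected into followupQuestionPrompt
// as {format_instructions} via partialVariables.
console.log(followupQuestionParser.getFormatInstructions());

// parse() validates the completion against the zod schema and returns a
// typed object; non-conforming output rejects with an OutputParserException.
const parsed = await followupQuestionParser.parse(
  '{"followupQuestions": ["What is AFFiNE?", "How do plugins work?"]}'
);
console.log(parsed.followupQuestions); // ['What is AFFiNE?', 'How do plugins work?']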