feat: improve copilot (#2758)

This commit is contained in:
Himself65
2023-06-13 10:29:04 +08:00
committed by GitHub
parent 5ba2dff008
commit ace3c37fcc
20 changed files with 413 additions and 79 deletions

View File

@@ -1,14 +1,15 @@
import { Button, Input } from '@affine/component';
import { rootStore } from '@toeverything/plugin-infra/manager';
import type { PluginUIAdapter } from '@toeverything/plugin-infra/type';
import { Provider, useAtom, useAtomValue, useSetAtom } from 'jotai';
import { Provider, useAtomValue, useSetAtom } from 'jotai';
import type { ReactElement } from 'react';
import { Fragment, StrictMode, useState } from 'react';
import { StrictMode, Suspense, useCallback, useState } from 'react';
import { createRoot } from 'react-dom/client';
import { Conversation } from '../core/components/conversation';
import { Divider } from '../core/components/divider';
import { ConversationList } from '../core/components/conversation-list';
import { FollowingUp } from '../core/components/following-up';
import { openAIApiKeyAtom, useChatAtoms } from '../core/hooks';
import { detailContentActionsStyle, detailContentStyle } from './index.css';
if (typeof window === 'undefined') {
import('@blocksuite/blocks').then(({ FormatQuickBar }) => {
@@ -54,25 +55,16 @@ if (typeof window === 'undefined') {
});
}
const DetailContentImpl = () => {
const Actions = () => {
const { conversationAtom, followingUpAtoms } = useChatAtoms();
const call = useSetAtom(conversationAtom);
const questions = useAtomValue(followingUpAtoms.questionsAtom);
const generateFollowingUp = useSetAtom(followingUpAtoms.generateChatAtom);
const [input, setInput] = useState('');
const { conversationAtom } = useChatAtoms();
const [conversations, call] = useAtom(conversationAtom);
return (
<div
style={{
width: '300px',
}}
>
{conversations.map((message, idx) => {
return (
<Fragment key={idx}>
<Conversation text={message.text} />
<Divider />
</Fragment>
);
})}
<div>
<>
<FollowingUp questions={questions} />
<div className={detailContentActionsStyle}>
<Input
value={input}
onChange={text => {
@@ -80,13 +72,28 @@ const DetailContentImpl = () => {
}}
/>
<Button
onClick={() => {
void call(input);
}}
onClick={useCallback(async () => {
await call(input);
await generateFollowingUp();
}, [call, generateFollowingUp, input])}
>
send
</Button>
</div>
</>
);
};
// Detail-panel body: renders the full conversation history plus the
// interactive Actions area (input + follow-up suggestions).
// Actions reads async atoms, so it is wrapped in Suspense with a plain
// string fallback while follow-up questions are being generated/loaded.
const DetailContentImpl = () => {
  const { conversationAtom } = useChatAtoms();
  // Read-only view of the message list; sending happens inside <Actions />.
  const conversations = useAtomValue(conversationAtom);
  return (
    <div className={detailContentStyle}>
      <ConversationList conversations={conversations} />
      <Suspense fallback="generating follow-up question">
        <Actions />
      </Suspense>
    </div>
  );
};

View File

@@ -19,7 +19,7 @@ export const HeaderItem: PluginUIAdapter['headerItem'] = ({
direction: 'horizontal',
first: 'editor',
second: 'com.affine.copilot',
splitPercentage: 80,
splitPercentage: 70,
};
} else {
return 'editor';

View File

@@ -0,0 +1,15 @@
import { style } from '@vanilla-extract/css';
// Scrollable column layout for the copilot detail panel.
export const detailContentStyle = style({
  backgroundColor: 'rgba(0, 0, 0, 0.04)',
  height: '100%',
  display: 'flex',
  flexDirection: 'column',
  overflow: 'auto',
  paddingLeft: '9px',
  paddingRight: '9px',
});

// Styling hook for the input/send action row; intentionally empty for now.
export const detailContentActionsStyle = style({});

View File

@@ -1,16 +1,17 @@
import { ConversationChain } from 'langchain/chains';
import { ConversationChain, LLMChain } from 'langchain/chains';
import { ChatOpenAI } from 'langchain/chat_models/openai';
import { BufferMemory } from 'langchain/memory';
import {
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
PromptTemplate,
SystemMessagePromptTemplate,
} from 'langchain/prompts';
import { type LLMResult } from 'langchain/schema';
import { IndexedDBChatMessageHistory } from './langchain/message-history';
import { chatPrompt } from './prompts';
import { chatPrompt, followupQuestionPrompt } from './prompts';
declare global {
interface WindowEventMap {
@@ -22,13 +23,24 @@ declare global {
export async function createChatAI(
room: string,
openAIApiKey: string
): Promise<ConversationChain> {
): Promise<{
conversationChain: ConversationChain;
followupChain: LLMChain<string>;
chatHistory: IndexedDBChatMessageHistory;
}> {
if (!openAIApiKey) {
console.warn('OpenAI API key not set, chat will not work');
}
const followup = new ChatOpenAI({
streaming: false,
modelName: 'gpt-3.5-turbo',
temperature: 0.5,
openAIApiKey: openAIApiKey,
});
const chat = new ChatOpenAI({
streaming: true,
modelName: 'gpt-4',
modelName: 'gpt-3.5-turbo',
temperature: 0.5,
openAIApiKey: openAIApiKey,
callbacks: [
@@ -77,13 +89,32 @@ export async function createChatAI(
HumanMessagePromptTemplate.fromTemplate('{input}'),
]);
return new ConversationChain({
const followupPromptTemplate = new PromptTemplate({
template: followupQuestionPrompt,
inputVariables: ['human_conversation', 'ai_conversation'],
});
const followupChain = new LLMChain({
llm: followup,
prompt: followupPromptTemplate,
memory: undefined,
});
const chatHistory = new IndexedDBChatMessageHistory(room);
const conversationChain = new ConversationChain({
memory: new BufferMemory({
returnMessages: true,
memoryKey: 'history',
chatHistory: new IndexedDBChatMessageHistory(room),
chatHistory,
}),
prompt: chatPromptTemplate,
llm: chat,
});
return {
conversationChain,
followupChain,
chatHistory,
} as const;
}

View File

@@ -0,0 +1,7 @@
import { style } from '@vanilla-extract/css';
// Vertical stack of conversation bubbles with uniform spacing between them.
export const conversationListStyle = style({
  display: 'flex',
  flexDirection: 'column',
  gap: '30px',
});

View File

@@ -0,0 +1,22 @@
import type { BaseChatMessage } from 'langchain/schema';
import { Conversation } from '../conversation';
import { conversationListStyle } from './index.css';
export type ConversationListProps = {
  conversations: BaseChatMessage[];
};

/**
 * Renders the chat transcript as a column of <Conversation /> bubbles,
 * one per langchain message, styled by the message's author type.
 */
export const ConversationList = ({ conversations }: ConversationListProps) => {
  return (
    <div className={conversationListStyle}>
      {conversations.map((message, index) => (
        <Conversation
          key={index}
          type={message._getType()}
          text={message.text}
        />
      ))}
    </div>
  );
};

View File

@@ -1,19 +0,0 @@
import { marked } from 'marked';
import { type ReactElement, useMemo } from 'react';
export interface ConversationProps {
  text: string;
}

/**
 * Renders a single chat message body as markdown-derived HTML.
 * The parse result is memoized so markdown is only re-rendered when
 * the message text actually changes.
 */
export const Conversation = ({ text }: ConversationProps): ReactElement => {
  const rendered = useMemo(() => marked.parse(text), [text]);
  return (
    <div>
      <div
        dangerouslySetInnerHTML={{
          __html: rendered,
        }}
      />
    </div>
  );
};

View File

@@ -0,0 +1,15 @@
import { style } from '@vanilla-extract/css';
export const conversationStyle = style({
padding: '10px 18px',
});
export const aiMessageStyle = style({
backgroundColor: 'rgba(207, 252, 255, 0.3)',
borderRadius: '18px 18px 18px 2px',
});
export const humanMessageStyle = style({
borderRadius: '18px 18px 2px 18px',
backgroundColor: 'white',
});

View File

@@ -0,0 +1,42 @@
import { clsx } from 'clsx';
import type { MessageType } from 'langchain/schema';
import { marked } from 'marked';
import { gfmHeadingId } from 'marked-gfm-heading-id';
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-expect-error
import { mangle } from 'marked-mangle';
import { type ReactElement, useMemo } from 'react';
import {
aiMessageStyle,
conversationStyle,
humanMessageStyle,
} from './index.css';
// One-time module-level configuration of the shared marked instance.
// Prefix generated heading ids so markdown-derived anchors cannot collide
// with the app's own element ids.
marked.use(
  gfmHeadingId({
    prefix: 'affine-',
  })
);
// Enable the mangle extension (autolinked email obfuscation) —
// NOTE(review): presumably added to keep marked's legacy behavior after it
// was split into a plugin; confirm against the marked version in use.
marked.use(mangle());
export interface ConversationProps {
  // langchain message author type; only 'ai' and 'human' receive a
  // dedicated bubble style below, other types fall back to the base style.
  type: MessageType;
  // Raw markdown body of the message.
  text: string;
}

// Renders one chat message as a styled markdown bubble.
// NOTE(review): marked.parse output is injected via dangerouslySetInnerHTML
// without sanitization; the text originates from model output, which is
// untrusted — consider sanitizing the HTML before rendering.
export const Conversation = (props: ConversationProps): ReactElement => {
  // Re-parse markdown only when the message text changes.
  const html = useMemo(() => marked.parse(props.text), [props.text]);
  return (
    <div
      className={clsx(conversationStyle, {
        [aiMessageStyle]: props.type === 'ai',
        [humanMessageStyle]: props.type === 'human',
      })}
      dangerouslySetInnerHTML={{
        __html: html,
      }}
    ></div>
  );
};

View File

@@ -0,0 +1,17 @@
import { style } from '@vanilla-extract/css';
// Wrapping horizontal row of follow-up suggestion chips.
export const followingUpStyle = style({
  display: 'flex',
  flexDirection: 'row',
  flexWrap: 'wrap',
  gap: '10px',
  alignItems: 'flex-start',
});

// A single suggestion chip; cursor: pointer signals it is meant to be
// clickable.
export const questionStyle = style({
  backgroundColor: 'white',
  borderRadius: '8px',
  color: '#8E8D91',
  padding: '2px 10px',
  cursor: 'pointer',
});

View File

@@ -0,0 +1,19 @@
import type { ReactElement } from 'react';
import { followingUpStyle, questionStyle } from './index.css';
export type FollowingUpProps = {
questions: string[];
};
export const FollowingUp = (props: FollowingUpProps): ReactElement => {
return (
<div className={followingUpStyle}>
{props.questions.map((question, index) => (
<div className={questionStyle} key={index}>
{question}
</div>
))}
</div>
);
};

View File

@@ -1,6 +1,9 @@
import type { IndexedDBChatMessageHistory } from '@affine/copilot/core/langchain/message-history';
import { atom, useAtomValue } from 'jotai';
import { atomFamily } from 'jotai/utils';
import { atomWithDefault } from 'jotai/utils';
import { atomWithStorage } from 'jotai/utils';
import type { WritableAtom } from 'jotai/vanilla';
import type { LLMChain } from 'langchain/chains';
import { type ConversationChain } from 'langchain/chains';
import { type BufferMemory } from 'langchain/memory';
import {
@@ -8,9 +11,12 @@ import {
type BaseChatMessage,
HumanChatMessage,
} from 'langchain/schema';
import { z } from 'zod';
import { createChatAI } from '../chat';
// Expected shape of the follow-up model's JSON reply: an array of question strings.
const followupResponseSchema = z.array(z.string());
export const openAIApiKeyAtom = atomWithStorage<string | null>(
'com.affine.copilot.openai.token',
null
@@ -19,12 +25,24 @@ export const openAIApiKeyAtom = atomWithStorage<string | null>(
export const chatAtom = atom(async get => {
const openAIApiKey = get(openAIApiKeyAtom);
if (!openAIApiKey) {
return null;
throw new Error('OpenAI API key not set, chat will not work');
}
return createChatAI('default-copilot', openAIApiKey);
});
const conversationAtomFamily = atomFamily((chat: ConversationChain | null) => {
const conversationWeakMap = new WeakMap<
ConversationChain,
WritableAtom<BaseChatMessage[], [string], Promise<void>>
>();
const getConversationAtom = (chat: ConversationChain) => {
if (conversationWeakMap.has(chat)) {
return conversationWeakMap.get(chat) as WritableAtom<
BaseChatMessage[],
[string],
Promise<void>
>;
}
const conversationBaseAtom = atom<BaseChatMessage[]>([]);
conversationBaseAtom.onMount = setAtom => {
if (!chat) {
@@ -52,7 +70,7 @@ const conversationAtomFamily = atomFamily((chat: ConversationChain | null) => {
};
};
return atom<BaseChatMessage[], [string], Promise<void>>(
const conversationAtom = atom<BaseChatMessage[], [string], Promise<void>>(
get => get(conversationBaseAtom),
async (get, set, input) => {
if (!chat) {
@@ -73,14 +91,75 @@ const conversationAtomFamily = atomFamily((chat: ConversationChain | null) => {
});
}
);
});
conversationWeakMap.set(chat, conversationAtom);
return conversationAtom;
};
// Per-chain cache so every LLMChain maps to one stable pair of atoms across
// renders; a WeakMap lets the entries be garbage-collected with the chain.
const followingUpWeakMap = new WeakMap<
  LLMChain<string>,
  {
    questionsAtom: ReturnType<typeof atomWithDefault<Promise<string[]>>>;
    generateChatAtom: WritableAtom<null, [], void>;
  }
>();
const getFollowingUpAtoms = (
followupLLMChain: LLMChain<string>,
chatHistory: IndexedDBChatMessageHistory
) => {
if (followingUpWeakMap.has(followupLLMChain)) {
return followingUpWeakMap.get(followupLLMChain) as {
questionsAtom: ReturnType<typeof atomWithDefault<Promise<string[]>>>;
generateChatAtom: WritableAtom<null, [], void>;
};
}
const baseAtom = atomWithDefault<Promise<string[]>>(async () => {
return chatHistory?.getFollowingUp() ?? [];
});
const setAtom = atom<null, [], void>(null, async (get, set) => {
if (!followupLLMChain || !chatHistory) {
throw new Error('followupLLMChain not set');
}
const messages = await chatHistory.getMessages();
const aiMessage = messages.findLast(
message => message._getType() === 'ai'
)?.text;
const humanMessage = messages.findLast(
message => message._getType() === 'human'
)?.text;
const response = await followupLLMChain.call({
ai_conversation: aiMessage,
human_conversation: humanMessage,
});
const followingUp = JSON.parse(response.text);
followupResponseSchema.parse(followingUp);
set(baseAtom, followingUp);
chatHistory.saveFollowingUp(followingUp).catch(() => {
console.error('failed to save followup');
});
});
followingUpWeakMap.set(followupLLMChain, {
questionsAtom: baseAtom,
generateChatAtom: setAtom,
});
return {
questionsAtom: baseAtom,
generateChatAtom: setAtom,
};
};
export function useChatAtoms(): {
conversationAtom: ReturnType<typeof conversationAtomFamily>;
conversationAtom: ReturnType<typeof getConversationAtom>;
followingUpAtoms: ReturnType<typeof getFollowingUpAtoms>;
} {
const chat = useAtomValue(chatAtom);
const conversationAtom = conversationAtomFamily(chat);
const conversationAtom = getConversationAtom(chat.conversationChain);
const followingUpAtoms = getFollowingUpAtoms(
chat.followupChain,
chat.chatHistory
);
return {
conversationAtom,
followingUpAtoms,
};
}

View File

@@ -23,25 +23,42 @@ interface ChatMessageDBV1 extends DBSchema {
};
}
// v2 schema: extends v1 with a store that persists the generated follow-up
// questions per chat.
interface ChatMessageDBV2 extends ChatMessageDBV1 {
  followingUp: {
    // Primary key: the chat/room id, same key space as the chat store.
    key: string;
    value: {
      /**
       * ID of the chat
       */
      id: string;
      // Follow-up question texts generated for this chat.
      question: string[];
    };
  };
}

// IndexedDB database name.
// NOTE(review): the constructor below appears to hard-code this same string
// literal — consider reusing this constant there.
export const conversationHistoryDBName = 'affine-copilot-chat';
export class IndexedDBChatMessageHistory extends BaseChatMessageHistory {
public id: string;
private messages: BaseChatMessage[] = [];
private readonly dbPromise: Promise<IDBPDatabase<ChatMessageDBV1>>;
private readonly dbPromise: Promise<IDBPDatabase<ChatMessageDBV2>>;
private readonly initPromise: Promise<void>;
constructor(id: string) {
super();
this.id = id;
this.messages = [];
this.dbPromise = openDB<ChatMessageDBV1>('affine-copilot-chat', 1, {
this.dbPromise = openDB<ChatMessageDBV2>('affine-copilot-chat', 2, {
upgrade(database, oldVersion) {
if (oldVersion === 0) {
database.createObjectStore('chat', {
keyPath: 'id',
});
} else if (oldVersion === 1) {
database.createObjectStore('followingUp', {
keyPath: 'id',
});
}
},
});
@@ -70,6 +87,31 @@ export class IndexedDBChatMessageHistory extends BaseChatMessageHistory {
});
}
/**
 * Persists the follow-up questions for this chat, overwriting any
 * previously stored set (keyed by the chat id).
 */
public async saveFollowingUp(question: string[]): Promise<void> {
  // Wait for initial load so writes never race initialization.
  await this.initPromise;
  const db = await this.dbPromise;
  const store = db
    .transaction('followingUp', 'readwrite')
    .objectStore('followingUp');
  await store.put({ id: this.id, question });
}
/**
 * Loads the persisted follow-up questions for this chat.
 * Returns an empty array when nothing has been stored yet.
 */
public async getFollowingUp(): Promise<string[]> {
  // Wait for initial load so reads never race initialization.
  await this.initPromise;
  const db = await this.dbPromise;
  const store = db
    .transaction('followingUp', 'readonly')
    .objectStore('followingUp');
  const record = await store.get(this.id);
  return record?.question ?? [];
}
protected async addMessage(message: BaseChatMessage): Promise<void> {
await this.initPromise;
this.messages.push(message);
@@ -104,6 +146,6 @@ export class IndexedDBChatMessageHistory extends BaseChatMessageHistory {
}
async getMessages(): Promise<BaseChatMessage[]> {
return await this.initPromise.then(() => this.messages);
return this.initPromise.then(() => this.messages);
}
}

View File

@@ -14,8 +14,19 @@ Keep your answers short and impersonal.
The user works in an app called AFFiNE, which has a concept for an editor, a page for a single document, workspace for a collection of documents.
The active document is the markdown file the user is looking at.
Use Markdown formatting in your answers.
Wrap your answers into triple backticks.
You can only give one reply for each conversation turn.
You should always generate short suggestions for the next user turns that are relevant to the conversation and not offensive.
You should reply to the users within 150 characters.
`;
// Prompt for the follow-up-question chain. Its output is JSON.parse'd and
// validated against an array-of-strings schema, so the instructions demand a
// bare JSON array and nothing else.
// Fixes: "Don't output anything other text" -> "any other text", and the
// delimiter description now matches the actual triple backticks used below.
export const followupQuestionPrompt = `Rules you must follow:
- You only respond in JSON format
- Read the following conversation between AI and Human and generate at most 3 follow-up messages or questions the Human can ask
- Your response MUST be a valid JSON array of strings like this: ["some question", "another question"]
- Each message in your response should be concise, no more than 15 words
- You MUST reply in the same written language as the conversation
- Don't output any other text

The conversation is inside triple backticks:
\`\`\`
Human: {human_conversation}
AI: {ai_conversation}
\`\`\`
`;