fix(core): unable to explain image when network search is active (#10228)

Fix issue [PD-2316](https://linear.app/affine-design/issue/PD-2316).
This commit is contained in:
akumatus
2025-02-18 02:20:20 +00:00
parent eed00e0b26
commit 015452e8fb
3 changed files with 116 additions and 66 deletions

View File

@@ -288,7 +288,7 @@ export class ChatPanelInput extends SignalWatcher(WithDisposable(LitElement)) {
); );
} }
private get _promptName() { private _getPromptName() {
if (this._isNetworkDisabled) { if (this._isNetworkDisabled) {
return PROMPT_NAME_AFFINE_AI; return PROMPT_NAME_AFFINE_AI;
} }
@@ -297,12 +297,12 @@ export class ChatPanelInput extends SignalWatcher(WithDisposable(LitElement)) {
: PROMPT_NAME_AFFINE_AI; : PROMPT_NAME_AFFINE_AI;
} }
private async _updatePromptName() { private async _updatePromptName(promptName: string) {
if (this._lastPromptName !== this._promptName) { if (this._lastPromptName !== promptName) {
this._lastPromptName = this._promptName;
const sessionId = await this.getSessionId(); const sessionId = await this.getSessionId();
if (sessionId) { if (sessionId && AIProvider.session) {
await AIProvider.session?.updateSession(sessionId, this._promptName); await AIProvider.session.updateSession(sessionId, promptName);
this._lastPromptName = promptName;
} }
} }
} }
@@ -457,7 +457,7 @@ export class ChatPanelInput extends SignalWatcher(WithDisposable(LitElement)) {
}} }}
@keydown=${async (evt: KeyboardEvent) => { @keydown=${async (evt: KeyboardEvent) => {
if (evt.key === 'Enter' && !evt.shiftKey && !evt.isComposing) { if (evt.key === 'Enter' && !evt.shiftKey && !evt.isComposing) {
this._onTextareaSend(evt); await this._onTextareaSend(evt);
} }
}} }}
@focus=${() => { @focus=${() => {
@@ -538,7 +538,7 @@ export class ChatPanelInput extends SignalWatcher(WithDisposable(LitElement)) {
</div>`; </div>`;
} }
private readonly _onTextareaSend = (e: MouseEvent | KeyboardEvent) => { private readonly _onTextareaSend = async (e: MouseEvent | KeyboardEvent) => {
e.preventDefault(); e.preventDefault();
e.stopPropagation(); e.stopPropagation();
@@ -549,17 +549,17 @@ export class ChatPanelInput extends SignalWatcher(WithDisposable(LitElement)) {
this.isInputEmpty = true; this.isInputEmpty = true;
this.textarea.style.height = 'unset'; this.textarea.style.height = 'unset';
this.send(value).catch(console.error); await this.send(value);
}; };
send = async (text: string) => { send = async (text: string) => {
const { status, markdown, chips } = this.chatContextValue; const { status, markdown, chips, images } = this.chatContextValue;
if (status === 'loading' || status === 'transmitting') return; if (status === 'loading' || status === 'transmitting') return;
if (!text) return; if (!text) return;
try { try {
const { images } = this.chatContextValue;
const { doc } = this.host; const { doc } = this.host;
const promptName = this._getPromptName();
this.updateContext({ this.updateContext({
images: [], images: [],
@@ -593,7 +593,9 @@ export class ChatPanelInput extends SignalWatcher(WithDisposable(LitElement)) {
], ],
}); });
await this._updatePromptName(); // must update prompt name after local chat message is updated
// otherwise, the unauthorized error can not be rendered properly
await this._updatePromptName(promptName);
const abortController = new AbortController(); const abortController = new AbortController();
const sessionId = await this.getSessionId(); const sessionId = await this.getSessionId();

View File

@@ -241,7 +241,7 @@ export class ChatBlockInput extends SignalWatcher(LitElement) {
${status === 'transmitting' ${status === 'transmitting'
? html`<div @click=${this._handleAbort}>${ChatAbortIcon}</div>` ? html`<div @click=${this._handleAbort}>${ChatAbortIcon}</div>`
: html`<div : html`<div
@click="${this._send}" @click=${this._onTextareaSend}
class="chat-panel-send" class="chat-panel-send"
aria-disabled=${this._isInputEmpty} aria-disabled=${this._isInputEmpty}
> >
@@ -306,7 +306,7 @@ export class ChatBlockInput extends SignalWatcher(LitElement) {
return !!this.chatContext.images.length; return !!this.chatContext.images.length;
} }
private get _promptName() { private _getPromptName() {
if (this._isNetworkDisabled) { if (this._isNetworkDisabled) {
return PROMPT_NAME_AFFINE_AI; return PROMPT_NAME_AFFINE_AI;
} }
@@ -315,15 +315,12 @@ export class ChatBlockInput extends SignalWatcher(LitElement) {
: PROMPT_NAME_AFFINE_AI; : PROMPT_NAME_AFFINE_AI;
} }
private async _updatePromptName() { private async _updatePromptName(promptName: string) {
if (this._lastPromptName !== this._promptName) { if (this._lastPromptName !== promptName) {
this._lastPromptName = this._promptName;
const { currentSessionId } = this.chatContext; const { currentSessionId } = this.chatContext;
if (currentSessionId) { if (currentSessionId && AIProvider.session) {
await AIProvider.session?.updateSession( await AIProvider.session.updateSession(currentSessionId, promptName);
currentSessionId, this._lastPromptName = promptName;
this._promptName
);
} }
} }
} }
@@ -346,7 +343,7 @@ export class ChatBlockInput extends SignalWatcher(LitElement) {
private readonly _handleKeyDown = async (evt: KeyboardEvent) => { private readonly _handleKeyDown = async (evt: KeyboardEvent) => {
if (evt.key === 'Enter' && !evt.shiftKey && !evt.isComposing) { if (evt.key === 'Enter' && !evt.shiftKey && !evt.isComposing) {
evt.preventDefault(); evt.preventDefault();
await this._send(); await this._onTextareaSend(evt);
} }
}; };
@@ -452,56 +449,70 @@ export class ChatBlockInput extends SignalWatcher(LitElement) {
`; `;
} }
private readonly _send = async () => { private readonly _onTextareaSend = async (e: MouseEvent | KeyboardEvent) => {
const { images, status } = this.chatContext; e.preventDefault();
if (status === 'loading' || status === 'transmitting') return; e.stopPropagation();
const text = this.textarea.value; const value = this.textarea.value.trim();
if (!text && !images.length) { if (value.length === 0) return;
return;
}
const { doc } = this.host;
this.textarea.value = ''; this.textarea.value = '';
this._isInputEmpty = true; this._isInputEmpty = true;
this.textarea.style.height = 'unset'; this.textarea.style.height = 'unset';
this.updateContext({
images: [],
status: 'loading',
error: null,
});
const attachments = await Promise.all( await this._send(value);
images?.map(image => readBlobAsURL(image)) };
);
const userInfo = await AIProvider.userInfo; private readonly _send = async (text: string) => {
this.updateContext({ const { images, status, currentChatBlockId, currentSessionId } =
messages: [ this.chatContext;
...this.chatContext.messages,
{
id: '',
content: text,
role: 'user',
createdAt: new Date().toISOString(),
attachments,
userId: userInfo?.id,
userName: userInfo?.name,
avatarUrl: userInfo?.avatarUrl ?? undefined,
},
{
id: '',
content: '',
role: 'assistant',
createdAt: new Date().toISOString(),
},
],
});
const { currentChatBlockId, currentSessionId } = this.chatContext;
let content = '';
const chatBlockExists = !!currentChatBlockId; const chatBlockExists = !!currentChatBlockId;
let content = '';
if (status === 'loading' || status === 'transmitting') return;
if (!text) return;
try { try {
const { doc } = this.host;
const promptName = this._getPromptName();
this.updateContext({
images: [],
status: 'loading',
error: null,
});
const attachments = await Promise.all(
images?.map(image => readBlobAsURL(image))
);
const userInfo = await AIProvider.userInfo;
this.updateContext({
messages: [
...this.chatContext.messages,
{
id: '',
content: text,
role: 'user',
createdAt: new Date().toISOString(),
attachments,
userId: userInfo?.id,
userName: userInfo?.name,
avatarUrl: userInfo?.avatarUrl ?? undefined,
},
{
id: '',
content: '',
role: 'assistant',
createdAt: new Date().toISOString(),
},
],
});
// must update prompt name after local chat message is updated
// otherwise, the unauthorized error can not be rendered properly
await this._updatePromptName(promptName);
// If has not forked a chat session, fork a new one // If has not forked a chat session, fork a new one
let chatSessionId = currentSessionId; let chatSessionId = currentSessionId;
if (!chatSessionId) { if (!chatSessionId) {
@@ -518,8 +529,6 @@ export class ChatBlockInput extends SignalWatcher(LitElement) {
chatSessionId = forkSessionId; chatSessionId = forkSessionId;
} }
await this._updatePromptName();
const abortController = new AbortController(); const abortController = new AbortController();
const stream = AIProvider.actions.chat?.({ const stream = AIProvider.actions.chat?.({
input: text, input: text,

View File

@@ -517,6 +517,45 @@ test.describe('chat panel', () => {
); );
}); });
test('can identify shape color, even if network search is active', async ({
page,
}) => {
await page.reload();
await clickSideBarAllPageButton(page);
await page.waitForTimeout(200);
await createLocalWorkspace({ name: 'test' }, page);
await clickNewPageButton(page);
await openChat(page);
await page.getByTestId('chat-network-search').click();
await switchToEdgelessMode(page);
const shapeButton = await page.waitForSelector(
'edgeless-shape-tool-button'
);
await shapeButton.click();
await page.mouse.click(400, 400);
const askAIButton = await page.waitForSelector('.copilot-icon-button');
await askAIButton.click();
await page.waitForTimeout(1000);
await page.keyboard.type('What color is this shape?');
await page.keyboard.press('Enter');
const history = await collectChat(page);
expect(history[0]).toEqual({
name: 'You',
content: 'What color is this shape?',
});
expect(history[1].name).toBe('AFFiNE AI');
expect(history[1].content).toContain('yellow');
expect(await page.locator('chat-panel affine-footnote-node').count()).toBe(
0
);
});
test('can trigger inline ai input and action panel by clicking Start with AI button', async ({ test('can trigger inline ai input and action panel by clicking Start with AI button', async ({
page, page,
}) => { }) => {