feat(server): scenario mapping (#13404)
fix AI-404

## Summary by CodeRabbit

* **New Features**
  * Introduced scenario-based configuration for copilot, allowing default model assignments for various AI use cases.
  * Added a new image generation model to the available options.
* **Improvements**
  * Refined copilot provider settings by removing deprecated fallback options and standardizing base URL configuration.
  * Enhanced prompt management to support scenario-driven updates and improved configuration handling.
  * Updated admin and settings interfaces to support new scenario configurations.
* **Bug Fixes**
  * Removed deprecated or unused prompts and related references across platforms for consistency.
* **Other**
  * Improved test coverage and updated test assets to reflect prompt and scenario changes.
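For a concrete picture of the first bullet: the whole feature reduces to one config object, a global switch plus a scenario-to-model map. A minimal sketch of an override (values illustrative; the type and defaults appear in the hunks below):

```ts
// Shape of the new `copilot.scenarios` config item (see the
// CopilotPromptScenario type and the defaults registered below).
// Pin two scenarios to one model; unlisted scenarios keep the defaults.
const scenarios = {
  enabled: true,
  scenarios: {
    chat: 'claude-sonnet-4@20250514',
    coding: 'claude-sonnet-4@20250514',
  },
};
```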
@@ -3,6 +3,7 @@ import {
   StorageJSONSchema,
   StorageProviderConfig,
 } from '../../base';
+import { CopilotPromptScenario } from './prompt/prompts';
 import {
   AnthropicOfficialConfig,
   AnthropicVertexConfig,
@@ -24,6 +25,7 @@ declare global {
       key: string;
     }>;
     storage: ConfigItem<StorageProviderConfig>;
+    scenarios: ConfigItem<CopilotPromptScenario>;
     providers: {
       openai: ConfigItem<OpenAIConfig>;
       fal: ConfigItem<FalConfig>;
@@ -43,17 +45,29 @@ defineModuleConfig('copilot', {
     desc: 'Whether to enable the copilot plugin.',
     default: false,
   },
+  scenarios: {
+    desc: 'The models used in each scenario for the copilot; applied when enabled.',
+    default: {
+      enabled: false,
+      scenarios: {
+        audio: 'gemini-2.5-flash',
+        chat: 'claude-sonnet-4@20250514',
+        embedding: 'gemini-embedding-001',
+        image: 'gpt-image-1',
+        rerank: 'gpt-4.1',
+        brainstorm: 'gpt-4o-2024-08-06',
+        coding: 'claude-sonnet-4@20250514',
+        quick_decision: 'gpt-4.1-mini',
+        quick_written: 'gemini-2.5-flash',
+        summary_inspection: 'gemini-2.5-flash',
+      },
+    },
+  },
   'providers.openai': {
     desc: 'The config for the openai provider.',
     default: {
       apiKey: '',
-      baseUrl: '',
-      fallback: {
-        text: '',
-        structured: '',
-        image: '',
-        embedding: '',
-      },
+      baseURL: 'https://api.openai.com/v1',
     },
     link: 'https://github.com/openai/openai-node',
   },
@@ -67,54 +81,30 @@ defineModuleConfig('copilot', {
     desc: 'The config for the gemini provider.',
     default: {
       apiKey: '',
-      baseUrl: '',
-      fallback: {
-        text: '',
-        structured: '',
-        image: '',
-        embedding: '',
-      },
+      baseURL: 'https://generativelanguage.googleapis.com/v1beta',
     },
   },
   'providers.geminiVertex': {
     desc: 'The config for the gemini provider in Google Vertex AI.',
-    default: {
-      baseURL: '',
-      fallback: {
-        text: '',
-        structured: '',
-        image: '',
-        embedding: '',
-      },
-    },
+    default: {},
     schema: VertexSchema,
   },
   'providers.perplexity': {
     desc: 'The config for the perplexity provider.',
     default: {
       apiKey: '',
-      fallback: {
-        text: '',
-      },
     },
   },
   'providers.anthropic': {
     desc: 'The config for the anthropic provider.',
     default: {
       apiKey: '',
-      fallback: {
-        text: '',
-      },
+      baseURL: 'https://api.anthropic.com/v1',
     },
   },
   'providers.anthropicVertex': {
     desc: 'The config for the anthropic provider in Google Vertex AI.',
-    default: {
-      baseURL: '',
-      fallback: {
-        text: '',
-      },
-    },
+    default: {},
     schema: VertexSchema,
   },
   'providers.morph': {

@@ -2,6 +2,7 @@ import { Logger } from '@nestjs/common';
 import type { ModuleRef } from '@nestjs/core';
 
 import {
+  Config,
   CopilotPromptNotFound,
   CopilotProviderNotSupported,
 } from '../../../base';
@@ -28,6 +29,7 @@ class ProductionEmbeddingClient extends EmbeddingClient {
   private readonly logger = new Logger(ProductionEmbeddingClient.name);
 
   constructor(
+    private readonly config: Config,
     private readonly providerFactory: CopilotProviderFactory,
     private readonly prompt: PromptService
   ) {
@@ -36,7 +38,9 @@ class ProductionEmbeddingClient extends EmbeddingClient {
 
   override async configured(): Promise<boolean> {
     const embedding = await this.providerFactory.getProvider({
-      modelId: EMBEDDING_MODEL,
+      modelId: this.config.copilot?.scenarios?.enabled
+        ? this.config.copilot.scenarios.scenarios?.embedding || EMBEDDING_MODEL
+        : EMBEDDING_MODEL,
       outputType: ModelOutputType.Embedding,
     });
     const result = Boolean(embedding);
@@ -209,12 +213,13 @@ export async function getEmbeddingClient(
   if (EMBEDDING_CLIENT) {
     return EMBEDDING_CLIENT;
   }
+  const config = moduleRef.get(Config, { strict: false });
   const providerFactory = moduleRef.get(CopilotProviderFactory, {
     strict: false,
   });
   const prompt = moduleRef.get(PromptService, { strict: false });
 
-  const client = new ProductionEmbeddingClient(providerFactory, prompt);
+  const client = new ProductionEmbeddingClient(config, providerFactory, prompt);
   if (await client.configured()) {
     EMBEDDING_CLIENT = client;
   }

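The ternary added to `configured()` is the scenario lookup pattern in miniature. Factored out, it might look like the helper below (a sketch, not code from this commit; `resolveModel` is a hypothetical name):

```ts
// Shape mirrors the CopilotPromptScenario type from this diff.
type CopilotPromptScenario = {
  enabled?: boolean;
  scenarios?: Partial<Record<string, string>>;
};

// Hypothetical helper: pick the scenario's model when the mapping is
// enabled and has an entry, otherwise keep the hard-coded default.
function resolveModel(
  config: CopilotPromptScenario | undefined,
  scenario: string,
  fallback: string
): string {
  return config?.enabled
    ? config.scenarios?.[scenario] || fallback
    : fallback;
}

// Equivalent to the inline ternary in configured():
// resolveModel(this.config.copilot?.scenarios, 'embedding', EMBEDDING_MODEL)
```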
@@ -19,6 +19,83 @@ type Prompt = Omit<
   config?: PromptConfig;
 };
 
+export const Scenario: Record<string, string[]> = {
+  audio: ['Transcript audio'],
+  brainstorm: [
+    'Brainstorm mindmap',
+    'Create a presentation',
+    'Expand mind map',
+    'workflow:brainstorm:step2',
+    'workflow:presentation:step2',
+    'workflow:presentation:step4',
+  ],
+  chat: ['Chat With AFFiNE AI'],
+  coding: [
+    'Apply Updates',
+    'Code Artifact',
+    'Make it real',
+    'Make it real with text',
+    'Section Edit',
+  ],
+  // no prompt needed, just a placeholder
+  embedding: [],
+  image: [
+    'Convert to Anime style',
+    'Convert to Clay style',
+    'Convert to Pixel style',
+    'Convert to Sketch style',
+    'Convert to sticker',
+    'Generate image',
+    'Remove background',
+    'Upscale image',
+  ],
+  quick_decision: [
+    'Create headings',
+    'Generate a caption',
+    'Translate to',
+    'workflow:brainstorm:step1',
+    'workflow:presentation:step1',
+    'workflow:image-anime:step2',
+    'workflow:image-clay:step2',
+    'workflow:image-pixel:step2',
+    'workflow:image-sketch:step2',
+  ],
+  quick_written: [
+    'Brainstorm ideas about this',
+    'Continue writing',
+    'Explain this code',
+    'Fix spelling for it',
+    'Improve writing for it',
+    'Make it longer',
+    'Make it shorter',
+    'Write a blog post about this',
+    'Write a poem about this',
+    'Write an article about this',
+    'Write outline',
+  ],
+  rerank: ['Rerank results'],
+  summary_inspection: [
+    'Change tone to',
+    'Check code error',
+    'Conversation Summary',
+    'Explain this',
+    'Explain this image',
+    'Find action for summary',
+    'Find action items from it',
+    'Improve grammar for it',
+    'Summarize the meeting',
+    'Summary',
+    'Summary as title',
+    'Summary the webpage',
+    'Write a twitter about this',
+  ],
+};
+
+export type CopilotPromptScenario = {
+  enabled?: boolean;
+  scenarios?: Partial<Record<keyof typeof Scenario, string>>;
+};
+
 const workflows: Prompt[] = [
   {
     name: 'workflow:presentation',
@@ -1612,31 +1689,6 @@ const imageActions: Prompt[] = [
     model: 'workflowutils/teed',
     messages: [{ role: 'user', content: '{{content}}' }],
   },
-  {
-    name: 'debug:action:dalle3',
-    action: 'image',
-    model: 'dall-e-3',
-    messages: [
-      {
-        role: 'user',
-        content: '{{content}}',
-      },
-    ],
-  },
-  {
-    name: 'debug:action:gpt-image-1',
-    action: 'image',
-    model: 'gpt-image-1',
-    messages: [
-      {
-        role: 'user',
-        content: '{{content}}',
-      },
-    ],
-    config: {
-      requireContent: false,
-    },
-  },
   {
     name: 'debug:action:fal-sd15',
     action: 'image',
@@ -1814,6 +1866,65 @@ Now apply the \`updates\` to the \`content\`, following the intent in \`op\`, an
       },
     ],
   },
+  {
+    name: 'Code Artifact',
+    model: 'claude-sonnet-4@20250514',
+    messages: [
+      {
+        role: 'system',
+        content: `
+When sent new notes, respond ONLY with the contents of the html file.
+DO NOT INCLUDE ANY OTHER TEXT, EXPLANATIONS, APOLOGIES, OR INTRODUCTORY/CLOSING PHRASES.
+IF USER DOES NOT SPECIFY A STYLE, FOLLOW THE DEFAULT STYLE.
+<generate_guide>
+- The results should be a single HTML file.
+- Use tailwindcss to style the website
+- Put any additional CSS styles in a style tag and any JavaScript in a script tag.
+- Use unpkg or skypack to import any required dependencies.
+- Use Google fonts to pull in any open source fonts you require.
+- Use lucide icons for any icons.
+- If you have any images, load them from Unsplash or use solid colored rectangles.
+</generate_guide>
+
+<DO_NOT_USE_COLORS>
+- DO NOT USE ANY COLORS
+</DO_NOT_USE_COLORS>
+<DO_NOT_USE_GRADIENTS>
+- DO NOT USE ANY GRADIENTS
+</DO_NOT_USE_GRADIENTS>
+
+<COLOR_THEME>
+- --affine-blue-300: #93e2fd
+- --affine-blue-400: #60cffa
+- --affine-blue-500: #3ab5f7
+- --affine-blue-600: #1e96eb
+- --affine-blue-700: #1e67af
+- --affine-text-primary-color: #121212
+- --affine-text-secondary-color: #8e8d91
+- --affine-text-disable-color: #a9a9ad
+- --affine-background-overlay-panel-color: #fbfbfc
+- --affine-background-secondary-color: #f4f4f5
+- --affine-background-primary-color: #fff
+</COLOR_THEME>
+<default_style_guide>
+- MUST USE White and Blue(#1e96eb) as the primary color
+- KEEP THE DEFAULT STYLE SIMPLE AND CLEAN
+- DO NOT USE ANY COMPLEX STYLES
+- DO NOT USE ANY GRADIENTS
+- USE LESS SHADOWS
+- USE RADIUS 4px or 8px for rounded corners
+- USE 12px or 16px for padding
+- Use the tailwind color gray, zinc, slate, neutral much more.
+- Use 0.5px border should be better
+</default_style_guide>
+`,
+      },
+      {
+        role: 'user',
+        content: '{{content}}',
+      },
+    ],
+  },
 ];
 
 const CHAT_PROMPT: Omit<Prompt, 'name'> = {
@@ -1973,84 +2084,6 @@ const chat: Prompt[] = [
     name: 'Chat With AFFiNE AI',
     ...CHAT_PROMPT,
   },
-  {
-    name: 'Search With AFFiNE AI',
-    ...CHAT_PROMPT,
-  },
-  // use for believer plan
-  {
-    name: 'Chat With AFFiNE AI - Believer',
-    model: 'gpt-o1',
-    messages: [
-      {
-        role: 'system',
-        content:
-          "You are AFFiNE AI, a professional and humorous copilot within AFFiNE. You are powered by latest GPT model from OpenAI and AFFiNE. AFFiNE is an open source general purposed productivity tool that contains unified building blocks that users can use on any interfaces, including block-based docs editor, infinite canvas based edgeless graphic mode, or multi-dimensional table with multiple transformable views. Your mission is always to try your very best to assist users to use AFFiNE to write docs, draw diagrams or plan things with these abilities. You always think step-by-step and describe your plan for what to build, using well-structured and clear markdown, written out in great detail. Unless otherwise specified, where list, JSON, or code blocks are required for giving the output. Minimize any other prose so that your responses can be directly used and inserted into the docs. You are able to access to API of AFFiNE to finish your job. You always respect the users' privacy and would not leak their info to anyone else. AFFiNE is made by Toeverything .Pte .Ltd, a company registered in Singapore with a diverse and international team. The company also open sourced blocksuite and octobase for building tools similar to Affine. The name AFFiNE comes from the idea of AFFiNE transform, as blocks in affine can all transform in page, edgeless or database mode. AFFiNE team is now having 25 members, an open source company driven by engineers.",
-      },
-    ],
-  },
 ];
 
-const artifactActions: Prompt[] = [
-  {
-    name: 'Code Artifact',
-    model: 'claude-sonnet-4@20250514',
-    messages: [
-      {
-        role: 'system',
-        content: `
-When sent new notes, respond ONLY with the contents of the html file.
-DO NOT INCLUDE ANY OTHER TEXT, EXPLANATIONS, APOLOGIES, OR INTRODUCTORY/CLOSING PHRASES.
-IF USER DOES NOT SPECIFY A STYLE, FOLLOW THE DEFAULT STYLE.
-<generate_guide>
-- The results should be a single HTML file.
-- Use tailwindcss to style the website
-- Put any additional CSS styles in a style tag and any JavaScript in a script tag.
-- Use unpkg or skypack to import any required dependencies.
-- Use Google fonts to pull in any open source fonts you require.
-- Use lucide icons for any icons.
-- If you have any images, load them from Unsplash or use solid colored rectangles.
-</generate_guide>
-
-<DO_NOT_USE_COLORS>
-- DO NOT USE ANY COLORS
-</DO_NOT_USE_COLORS>
-<DO_NOT_USE_GRADIENTS>
-- DO NOT USE ANY GRADIENTS
-</DO_NOT_USE_GRADIENTS>
-
-<COLOR_THEME>
-- --affine-blue-300: #93e2fd
-- --affine-blue-400: #60cffa
-- --affine-blue-500: #3ab5f7
-- --affine-blue-600: #1e96eb
-- --affine-blue-700: #1e67af
-- --affine-text-primary-color: #121212
-- --affine-text-secondary-color: #8e8d91
-- --affine-text-disable-color: #a9a9ad
-- --affine-background-overlay-panel-color: #fbfbfc
-- --affine-background-secondary-color: #f4f4f5
-- --affine-background-primary-color: #fff
-</COLOR_THEME>
-<default_style_guide>
-- MUST USE White and Blue(#1e96eb) as the primary color
-- KEEP THE DEFAULT STYLE SIMPLE AND CLEAN
-- DO NOT USE ANY COMPLEX STYLES
-- DO NOT USE ANY GRADIENTS
-- USE LESS SHADOWS
-- USE RADIUS 4px or 8px for rounded corners
-- USE 12px or 16px for padding
-- Use the tailwind color gray, zinc, slate, neutral much more.
-- Use 0.5px border should be better
-</default_style_guide>
-`,
-      },
-      {
-        role: 'user',
-        content: '{{content}}',
-      },
-    ],
-  },
-];
 
 export const prompts: Prompt[] = [
@@ -2059,7 +2092,6 @@ export const prompts: Prompt[] = [
   ...modelActions,
   ...chat,
   ...workflows,
-  ...artifactActions,
 ];
 
 export async function refreshPrompts(db: PrismaClient) {

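One thing the `Scenario` map makes explicit: a scenario is a group of prompt names, so assigning one model to a scenario fans out to every prompt in that group. A toy illustration of the fan-out that `PromptService.setup()` (next file) performs:

```ts
// Toy subset of the Scenario map above.
const Scenario: Record<string, string[]> = {
  coding: ['Apply Updates', 'Code Artifact', 'Make it real'],
};

// One scenario assignment...
const assignments: Record<string, string> = {
  coding: 'claude-sonnet-4@20250514',
};

// ...fans out to one model update per prompt in the group.
for (const [scenario, model] of Object.entries(assignments)) {
  for (const name of Scenario[scenario] ?? []) {
    console.log(`${name} -> ${model}`);
  }
}
// Apply Updates -> claude-sonnet-4@20250514
// Code Artifact -> claude-sonnet-4@20250514
// Make it real -> claude-sonnet-4@20250514
```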
@@ -1,6 +1,8 @@
-import { Injectable, OnApplicationBootstrap } from '@nestjs/common';
-import { PrismaClient } from '@prisma/client';
+import { Injectable, Logger, OnApplicationBootstrap } from '@nestjs/common';
+import { Transactional } from '@nestjs-cls/transactional';
+import { Prisma, PrismaClient } from '@prisma/client';
 
 import { Config, OnEvent } from '../../../base';
 import {
   PromptConfig,
   PromptConfigSchema,
@@ -8,19 +10,65 @@ import {
   PromptMessageSchema,
 } from '../providers';
 import { ChatPrompt } from './chat-prompt';
-import { refreshPrompts } from './prompts';
+import {
+  CopilotPromptScenario,
+  prompts,
+  refreshPrompts,
+  Scenario,
+} from './prompts';
 
 @Injectable()
 export class PromptService implements OnApplicationBootstrap {
+  private readonly logger = new Logger(PromptService.name);
   private readonly cache = new Map<string, ChatPrompt>();
 
-  constructor(private readonly db: PrismaClient) {}
+  constructor(
+    private readonly config: Config,
+    private readonly db: PrismaClient
+  ) {}
 
   async onApplicationBootstrap() {
     this.cache.clear();
     await refreshPrompts(this.db);
   }
 
+  @OnEvent('config.init')
+  async onConfigInit() {
+    await this.setup(this.config.copilot?.scenarios);
+  }
+
+  @OnEvent('config.changed')
+  async onConfigChanged(event: Events['config.changed']) {
+    if ('copilot' in event.updates) {
+      await this.setup(event.updates.copilot?.scenarios);
+    }
+  }
+
+  protected async setup(scenarios?: CopilotPromptScenario) {
+    if (!!scenarios && scenarios.enabled && scenarios.scenarios) {
+      this.logger.log('Updating prompts based on scenarios...');
+      for (const [scenario, model] of Object.entries(scenarios.scenarios)) {
+        const promptNames = Scenario[scenario];
+        for (const name of promptNames) {
+          const prompt = prompts.find(p => p.name === name);
+          if (prompt && model) {
+            await this.update(
+              prompt.name,
+              { model, modified: true },
+              { model: { not: model } }
+            );
+          }
+        }
+      }
+    } else {
+      this.logger.log('No scenarios enabled, using default prompts.');
+      const prompts = Object.values(Scenario).flat();
+      for (const prompt of prompts) {
+        await this.update(prompt, { modified: false });
+      }
+    }
+  }
+
   /**
    * list prompt names
    * @returns prompt names
@@ -121,33 +169,46 @@ export class PromptService implements OnApplicationBootstrap {
       .then(ret => ret.id);
   }
 
+  @Transactional()
   async update(
     name: string,
-    messages: PromptMessage[],
-    modifyByApi: boolean = false,
-    config?: PromptConfig
+    data: {
+      messages?: PromptMessage[];
+      model?: string;
+      modified?: boolean;
+      config?: PromptConfig;
+    },
+    where?: Prisma.AiPromptWhereInput
   ) {
-    const { id } = await this.db.aiPrompt.update({
-      where: { name },
-      data: {
-        config: config || undefined,
-        updatedAt: new Date(),
-        modified: modifyByApi,
-        messages: {
-          // cleanup old messages
-          deleteMany: {},
-          create: messages.map((m, idx) => ({
-            idx,
-            ...m,
-            attachments: m.attachments || undefined,
-            params: m.params || undefined,
-          })),
-        },
-      },
-    });
-
-    this.cache.delete(name);
-    return id;
+    const { config, messages, model, modified } = data;
+    const existing = await this.db.aiPrompt
+      .count({ where: { ...where, name } })
+      .then(count => count > 0);
+    if (existing) {
+      await this.db.aiPrompt.update({
+        where: { name },
+        data: {
+          config: config || undefined,
+          updatedAt: new Date(),
+          modified,
+          model,
+          messages: messages
+            ? {
+                // cleanup old messages
+                deleteMany: {},
+                create: messages.map((m, idx) => ({
+                  idx,
+                  ...m,
+                  attachments: m.attachments || undefined,
+                  params: m.params || undefined,
+                })),
+              }
+            : undefined,
+        },
+      });
 
+      this.cache.delete(name);
+    }
   }
 
   async delete(name: string) {

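The reworked `update()` is the piece that makes the scenario sync cheap: callers pass a partial-update object plus an optional Prisma filter, and `setup()` passes `{ model: { not: model } }` so prompts already pointing at the target model are left untouched. Hypothetical call sites (import path and wiring assumed, not from this commit):

```ts
import type { PromptService } from './prompt/service';

async function demo(promptService: PromptService) {
  // Scenario sync: only rewrites the prompt's model when it differs,
  // and marks it modified so refreshPrompts() won't overwrite it later.
  await promptService.update(
    'Code Artifact',
    { model: 'claude-sonnet-4@20250514', modified: true },
    { model: { not: 'claude-sonnet-4@20250514' } }
  );

  // Admin mutation path (see the resolver change at the end of this diff):
  // replaces the message list instead of the model.
  await promptService.update('Chat With AFFiNE AI', {
    messages: [{ role: 'user', content: '{{content}}' }],
    modified: true,
  });
}
```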
@@ -2,26 +2,20 @@ import {
   type AnthropicProvider as AnthropicSDKProvider,
   createAnthropic,
 } from '@ai-sdk/anthropic';
+import z from 'zod';
 
-import {
-  CopilotChatOptions,
-  CopilotProviderType,
-  ModelConditions,
-  ModelInputType,
-  ModelOutputType,
-  PromptMessage,
-  StreamObject,
-} from '../types';
+import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
 import { AnthropicProvider } from './anthropic';
 
 export type AnthropicOfficialConfig = {
   apiKey: string;
-  baseUrl?: string;
-  fallback?: {
-    text?: string;
-  };
+  baseURL?: string;
 };
 
+const ModelListSchema = z.object({
+  data: z.array(z.object({ id: z.string() })),
+});
+
 export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOfficialConfig> {
   override readonly type = CopilotProviderType.Anthropic;
@@ -75,34 +69,27 @@ export class AnthropicOfficialProvider extends AnthropicProvider<AnthropicOffici
     super.setup();
     this.instance = createAnthropic({
       apiKey: this.config.apiKey,
-      baseURL: this.config.baseUrl,
+      baseURL: this.config.baseURL,
     });
   }
 
-  override async text(
-    cond: ModelConditions,
-    messages: PromptMessage[],
-    options: CopilotChatOptions = {}
-  ): Promise<string> {
-    const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
-    return super.text(fullCond, messages, options);
-  }
-
-  override async *streamText(
-    cond: ModelConditions,
-    messages: PromptMessage[],
-    options: CopilotChatOptions = {}
-  ): AsyncIterable<string> {
-    const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
-    yield* super.streamText(fullCond, messages, options);
-  }
-
-  override async *streamObject(
-    cond: ModelConditions,
-    messages: PromptMessage[],
-    options: CopilotChatOptions = {}
-  ): AsyncIterable<StreamObject> {
-    const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
-    yield* super.streamObject(fullCond, messages, options);
+  override async refreshOnlineModels() {
+    try {
+      const baseUrl = this.config.baseURL || 'https://api.anthropic.com/v1';
+      if (baseUrl && !this.onlineModelList.length) {
+        const { data } = await fetch(`${baseUrl}/models`, {
+          headers: {
+            'x-api-key': this.config.apiKey,
+            'anthropic-version': '2023-06-01',
+            'Content-Type': 'application/json',
+          },
+        })
+          .then(r => r.json())
+          .then(r => ModelListSchema.parse(r));
+        this.onlineModelList = data.map(model => model.id);
+      }
+    } catch (e) {
+      this.logger.error('Failed to fetch available models', e);
+    }
   }
 }

@@ -4,23 +4,11 @@ import {
   type GoogleVertexAnthropicProviderSettings,
 } from '@ai-sdk/google-vertex/anthropic';
 
-import {
-  CopilotChatOptions,
-  CopilotProviderType,
-  ModelConditions,
-  ModelInputType,
-  ModelOutputType,
-  PromptMessage,
-  StreamObject,
-} from '../types';
+import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
 import { getGoogleAuth, VertexModelListSchema } from '../utils';
 import { AnthropicProvider } from './anthropic';
 
-export type AnthropicVertexConfig = GoogleVertexAnthropicProviderSettings & {
-  fallback?: {
-    text?: string;
-  };
-};
+export type AnthropicVertexConfig = GoogleVertexAnthropicProviderSettings;
 
 export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexConfig> {
   override readonly type = CopilotProviderType.AnthropicVertex;
@@ -76,33 +64,6 @@ export class AnthropicVertexProvider extends AnthropicProvider<AnthropicVertexCo
     this.instance = createVertexAnthropic(this.config);
   }
 
-  override async text(
-    cond: ModelConditions,
-    messages: PromptMessage[],
-    options: CopilotChatOptions = {}
-  ): Promise<string> {
-    const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
-    return super.text(fullCond, messages, options);
-  }
-
-  override async *streamText(
-    cond: ModelConditions,
-    messages: PromptMessage[],
-    options: CopilotChatOptions = {}
-  ): AsyncIterable<string> {
-    const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
-    yield* super.streamText(fullCond, messages, options);
-  }
-
-  override async *streamObject(
-    cond: ModelConditions,
-    messages: PromptMessage[],
-    options: CopilotChatOptions = {}
-  ): AsyncIterable<StreamObject> {
-    const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
-    yield* super.streamObject(fullCond, messages, options);
-  }
-
   override async refreshOnlineModels() {
     try {
       const { baseUrl, headers } = await getGoogleAuth(

@@ -74,6 +74,16 @@ export class FalProvider extends CopilotProvider<FalConfig> {
   override type = CopilotProviderType.FAL;
 
   override readonly models = [
+    {
+      id: 'lcm',
+      capabilities: [
+        {
+          input: [ModelInputType.Text],
+          output: [ModelOutputType.Image],
+          defaultForOutputType: true,
+        },
+      ],
+    },
     // image to image models
     {
       id: 'lcm-sd15-i2i',

@@ -4,27 +4,12 @@ import {
 } from '@ai-sdk/google';
 import z from 'zod';
 
-import {
-  CopilotChatOptions,
-  CopilotEmbeddingOptions,
-  CopilotProviderType,
-  ModelConditions,
-  ModelInputType,
-  ModelOutputType,
-  PromptMessage,
-  StreamObject,
-} from '../types';
+import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
 import { GeminiProvider } from './gemini';
 
 export type GeminiGenerativeConfig = {
   apiKey: string;
-  baseUrl?: string;
-  fallback?: {
-    text?: string;
-    structured?: string;
-    image?: string;
-    embedding?: string;
-  };
+  baseURL?: string;
 };
 
 const ModelListSchema = z.object({
@@ -113,65 +98,14 @@ export class GeminiGenerativeProvider extends GeminiProvider<GeminiGenerativeCon
     super.setup();
     this.instance = createGoogleGenerativeAI({
       apiKey: this.config.apiKey,
-      baseURL: this.config.baseUrl,
+      baseURL: this.config.baseURL,
     });
   }
 
-  override async text(
-    cond: ModelConditions,
-    messages: PromptMessage[],
-    options: CopilotChatOptions = {}
-  ): Promise<string> {
-    const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
-    return super.text(fullCond, messages, options);
-  }
-
-  override async structure(
-    cond: ModelConditions,
-    messages: PromptMessage[],
-    options?: CopilotChatOptions
-  ): Promise<string> {
-    const fullCond = {
-      ...cond,
-      fallbackModel: this.config.fallback?.structured,
-    };
-    return super.structure(fullCond, messages, options);
-  }
-
-  override async *streamText(
-    cond: ModelConditions,
-    messages: PromptMessage[],
-    options: CopilotChatOptions = {}
-  ): AsyncIterable<string> {
-    const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
-    yield* super.streamText(fullCond, messages, options);
-  }
-
-  override async *streamObject(
-    cond: ModelConditions,
-    messages: PromptMessage[],
-    options: CopilotChatOptions = {}
-  ): AsyncIterable<StreamObject> {
-    const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
-    yield* super.streamObject(fullCond, messages, options);
-  }
-
-  override async embedding(
-    cond: ModelConditions,
-    messages: string | string[],
-    options?: CopilotEmbeddingOptions
-  ): Promise<number[][]> {
-    const fullCond = {
-      ...cond,
-      fallbackModel: this.config.fallback?.embedding,
-    };
-    return super.embedding(fullCond, messages, options);
-  }
-
   override async refreshOnlineModels() {
     try {
       const baseUrl =
-        this.config.baseUrl ||
+        this.config.baseURL ||
         'https://generativelanguage.googleapis.com/v1beta';
       if (baseUrl && !this.onlineModelList.length) {
         const { models } = await fetch(

@@ -4,27 +4,11 @@ import {
   type GoogleVertexProviderSettings,
 } from '@ai-sdk/google-vertex';
 
-import {
-  CopilotChatOptions,
-  CopilotEmbeddingOptions,
-  CopilotProviderType,
-  ModelConditions,
-  ModelInputType,
-  ModelOutputType,
-  PromptMessage,
-  StreamObject,
-} from '../types';
+import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types';
 import { getGoogleAuth, VertexModelListSchema } from '../utils';
 import { GeminiProvider } from './gemini';
 
-export type GeminiVertexConfig = GoogleVertexProviderSettings & {
-  fallback?: {
-    text?: string;
-    structured?: string;
-    image?: string;
-    embedding?: string;
-  };
-};
+export type GeminiVertexConfig = GoogleVertexProviderSettings;
 
 export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
   override readonly type = CopilotProviderType.GeminiVertex;
@@ -90,57 +74,6 @@ export class GeminiVertexProvider extends GeminiProvider<GeminiVertexConfig> {
     this.instance = createVertex(this.config);
   }
 
-  override async text(
-    cond: ModelConditions,
-    messages: PromptMessage[],
-    options: CopilotChatOptions = {}
-  ): Promise<string> {
-    const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
-    return super.text(fullCond, messages, options);
-  }
-
-  override async structure(
-    cond: ModelConditions,
-    messages: PromptMessage[],
-    options?: CopilotChatOptions
-  ): Promise<string> {
-    const fullCond = {
-      ...cond,
-      fallbackModel: this.config.fallback?.structured,
-    };
-    return super.structure(fullCond, messages, options);
-  }
-
-  override async *streamText(
-    cond: ModelConditions,
-    messages: PromptMessage[],
-    options: CopilotChatOptions = {}
-  ): AsyncIterable<string> {
-    const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
-    yield* super.streamText(fullCond, messages, options);
-  }
-
-  override async *streamObject(
-    cond: ModelConditions,
-    messages: PromptMessage[],
-    options: CopilotChatOptions = {}
-  ): AsyncIterable<StreamObject> {
-    const fullCond = { ...cond, fallbackModel: this.config.fallback?.text };
-    yield* super.streamObject(fullCond, messages, options);
-  }
-
-  override async embedding(
-    cond: ModelConditions,
-    messages: string | string[],
-    options?: CopilotEmbeddingOptions
-  ): Promise<number[][]> {
-    const fullCond = {
-      ...cond,
-      fallbackModel: this.config.fallback?.embedding,
-    };
-    return super.embedding(fullCond, messages, options);
-  }
-
   override async refreshOnlineModels() {
     try {
       const { baseUrl, headers } = await getGoogleAuth(this.config, 'google');

@@ -45,13 +45,7 @@ export const DEFAULT_DIMENSIONS = 256;
 
 export type OpenAIConfig = {
   apiKey: string;
-  baseUrl?: string;
-  fallback?: {
-    text?: string;
-    structured?: string;
-    image?: string;
-    embedding?: string;
-  };
+  baseURL?: string;
 };
 
 const ModelListSchema = z.object({
@@ -249,7 +243,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
     super.setup();
     this.#instance = createOpenAI({
       apiKey: this.config.apiKey,
-      baseURL: this.config.baseUrl,
+      baseURL: this.config.baseURL,
     });
   }
 
@@ -283,7 +277,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
 
   override async refreshOnlineModels() {
     try {
-      const baseUrl = this.config.baseUrl || 'https://api.openai.com/v1';
+      const baseUrl = this.config.baseURL || 'https://api.openai.com/v1';
       if (baseUrl && !this.onlineModelList.length) {
         const { data } = await fetch(`${baseUrl}/models`, {
           headers: {
@@ -320,7 +314,6 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
     const fullCond = {
       ...cond,
       outputType: ModelOutputType.Text,
-      fallbackModel: this.config.fallback?.text,
     };
     await this.checkParams({ messages, cond: fullCond, options });
     const model = this.selectModel(fullCond);
@@ -361,7 +354,6 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
     const fullCond = {
       ...cond,
       outputType: ModelOutputType.Text,
-      fallbackModel: this.config.fallback?.text,
     };
     await this.checkParams({ messages, cond: fullCond, options });
     const model = this.selectModel(fullCond);
@@ -407,11 +399,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
     messages: PromptMessage[],
     options: CopilotChatOptions = {}
   ): AsyncIterable<StreamObject> {
-    const fullCond = {
-      ...cond,
-      outputType: ModelOutputType.Object,
-      fallbackModel: this.config.fallback?.text,
-    };
+    const fullCond = { ...cond, outputType: ModelOutputType.Object };
     await this.checkParams({ cond: fullCond, messages, options });
     const model = this.selectModel(fullCond);
 
@@ -444,11 +432,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
     messages: PromptMessage[],
     options: CopilotStructuredOptions = {}
   ): Promise<string> {
-    const fullCond = {
-      ...cond,
-      outputType: ModelOutputType.Structured,
-      fallbackModel: this.config.fallback?.structured,
-    };
+    const fullCond = { ...cond, outputType: ModelOutputType.Structured };
     await this.checkParams({ messages, cond: fullCond, options });
     const model = this.selectModel(fullCond);
 
@@ -488,11 +472,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
     chunkMessages: PromptMessage[][],
     options: CopilotChatOptions = {}
   ): Promise<number[]> {
-    const fullCond = {
-      ...cond,
-      outputType: ModelOutputType.Text,
-      fallbackModel: this.config.fallback?.text,
-    };
+    const fullCond = { ...cond, outputType: ModelOutputType.Text };
     await this.checkParams({ messages: [], cond: fullCond, options });
     const model = this.selectModel(fullCond);
     // get the log probability of "yes"/"no"
@@ -605,7 +585,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
       );
     }
 
-    const url = `${this.config.baseUrl || 'https://api.openai.com'}/v1/images/edits`;
+    const url = `${this.config.baseURL || 'https://api.openai.com/v1'}/images/edits`;
     const res = await fetch(url, {
       method: 'POST',
       headers: { Authorization: `Bearer ${this.config.apiKey}` },
@@ -637,11 +617,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
     messages: PromptMessage[],
     options: CopilotImageOptions = {}
   ) {
-    const fullCond = {
-      ...cond,
-      outputType: ModelOutputType.Image,
-      fallbackModel: this.config.fallback?.image,
-    };
+    const fullCond = { ...cond, outputType: ModelOutputType.Image };
     await this.checkParams({ messages, cond: fullCond, options });
     const model = this.selectModel(fullCond);
 
@@ -691,11 +667,7 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
     options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS }
   ): Promise<number[][]> {
     messages = Array.isArray(messages) ? messages : [messages];
-    const fullCond = {
-      ...cond,
-      outputType: ModelOutputType.Embedding,
-      fallbackModel: this.config.fallback?.embedding,
-    };
+    const fullCond = { ...cond, outputType: ModelOutputType.Embedding };
     await this.checkParams({ embeddings: messages, cond: fullCond, options });
     const model = this.selectModel(fullCond);

@@ -20,9 +20,6 @@ import { chatToGPTMessage, CitationParser } from './utils';
 export type PerplexityConfig = {
   apiKey: string;
   endpoint?: string;
-  fallback?: {
-    text?: string;
-  };
 };
 
 const PerplexityErrorSchema = z.union([
@@ -112,11 +109,7 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
     messages: PromptMessage[],
     options: CopilotChatOptions = {}
   ): Promise<string> {
-    const fullCond = {
-      ...cond,
-      outputType: ModelOutputType.Text,
-      fallbackModel: this.config.fallback?.text,
-    };
+    const fullCond = { ...cond, outputType: ModelOutputType.Text };
     await this.checkParams({ cond: fullCond, messages, options });
     const model = this.selectModel(fullCond);
 
@@ -156,11 +149,7 @@ export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
     messages: PromptMessage[],
     options: CopilotChatOptions = {}
   ): AsyncIterable<string> {
-    const fullCond = {
-      ...cond,
-      outputType: ModelOutputType.Text,
-      fallbackModel: this.config.fallback?.text,
-    };
+    const fullCond = { ...cond, outputType: ModelOutputType.Text };
     await this.checkParams({ cond: fullCond, messages, options });
     const model = this.selectModel(fullCond);

@@ -104,22 +104,12 @@ export abstract class CopilotProvider<C = any> {
 
     if (modelId) {
       const hasOnlineModel = this.onlineModelList.includes(modelId);
-      const hasFallbackModel = cond.fallbackModel
-        ? this.onlineModelList.includes(cond.fallbackModel)
-        : undefined;
 
       const model = this.models.find(
         m => m.id === modelId && m.capabilities.some(matcher)
       );
 
-      if (model) {
-        // return fallback model if current model is not alive
-        if (!hasOnlineModel && hasFallbackModel) {
-          // oxlint-disable-next-line typescript-eslint(no-non-null-assertion)
-          return { id: cond.fallbackModel!, capabilities: [] };
-        }
-        return model;
-      }
+      if (model) return model;
       // allow online model without capabilities check
       if (hasOnlineModel) return { id: modelId, capabilities: [] };
       return undefined;

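With per-provider fallback models gone, `selectModel` reduces to a two-step lookup: statically declared models (with a capability check) first, then anything the provider reported as online (no capability check). A standalone sketch of that order (simplified types, not the actual method):

```ts
type ModelCapability = { matches: boolean };
type Model = { id: string; capabilities: ModelCapability[] };

// Sketch of the simplified resolution in CopilotProvider.selectModel.
function selectModel(
  modelId: string,
  declared: Model[],
  onlineModelList: string[]
): Model | undefined {
  // 1. A declared model whose capabilities satisfy the request wins.
  const model = declared.find(
    m => m.id === modelId && m.capabilities.some(c => c.matches)
  );
  if (model) return model;
  // 2. Otherwise trust the online list, skipping the capability check.
  if (onlineModelList.includes(modelId)) {
    return { id: modelId, capabilities: [] };
  }
  // 3. Unknown model: unsupported.
  return undefined;
}
```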
@@ -248,5 +248,4 @@ export type ModelConditions = {
 
 export type ModelFullConditions = ModelConditions & {
   outputType?: ModelOutputType;
-  fallbackModel?: string;
 };

@@ -907,7 +907,7 @@ export class PromptsManagementResolver {
     @Args('messages', { type: () => [CopilotPromptMessageType] })
     messages: CopilotPromptMessageType[]
   ) {
-    await this.promptService.update(name, messages, true);
+    await this.promptService.update(name, { messages, modified: true });
     return this.promptService.get(name);
   }
 }