From 1694cb7a87c4e738a03557330521d209ba658821 Mon Sep 17 00:00:00 2001 From: Shigma Date: Sun, 23 Oct 2022 03:46:20 +0800 Subject: [PATCH] refa: implement config.ts --- src/config.ts | 246 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/index.ts | 230 ++++------------------------------------------ 2 files changed, 262 insertions(+), 214 deletions(-) create mode 100644 src/config.ts diff --git a/src/config.ts b/src/config.ts new file mode 100644 index 0000000..2ba9977 --- /dev/null +++ b/src/config.ts @@ -0,0 +1,246 @@ +import { Dict, Schema, Time } from 'koishi' +import { Size } from './utils' + +export const modelMap = { + safe: 'safe-diffusion', + nai: 'nai-diffusion', + furry: 'nai-diffusion-furry', +} as const + +export const orientMap = { + landscape: { height: 512, width: 768 }, + portrait: { height: 768, width: 512 }, + square: { height: 640, width: 640 }, +} as const + +const lowQuality = [ + 'nsfw, text, cropped, jpeg artifacts, signature, watermark, username, blurry', + 'lowres, polar lores, worst quality, low quality, normal quality', +].join(', ') + +const badAnatomy = [ + 'bad anatomy, error, long neck, cross-eyed, mutation, deformed', + 'bad hands, bad feet, malformed limbs, fused fingers, mutated hands', + 'missing fingers, fewer digits, to many fingers, extra fingers, extra digit, extra limbs, extra arms, extra legs', + 'poorly drawn hands, poorly drawn face, poorly drawn limbs', +].join(', ') + +type Model = keyof typeof modelMap +type Orient = keyof typeof orientMap + +export const models = Object.keys(modelMap) as Model[] +export const orients = Object.keys(orientMap) as Orient[] + +export namespace sampler { + export const nai = { + 'k_euler_a': 'Euler ancestral', + 'k_euler': 'Euler', + 'k_lms': 'LMS', + 'ddim': 'DDIM', + 'plms': 'PLMS', + } + + export const sd = { + 'k_euler_a': 'Euler ancestral', + 'k_euler': 'Euler', + 'k_lms': 'LMS', + 'k_heun': 'Heun', + 'k_dpm_2': 'DPM2', + 'k_dpm_2_a': 'DPM2 ancestral', + 'k_dpm_fast': 'DPM Fast', + 'k_dpm_ad': 'DPM adaptive', + 'k_lms_ka': 'LMS karras', + 'k_dpm_2_ka': 'DPM2 karras', + 'k_dpm_2_a_ka': 'DPM2 ancestral karras', + 'ddim': 'DDIM', + 'plms': 'PLMS', + } + + export function createSchema(map: Dict) { + return Schema.union(Object.entries(map).map(([key, value]) => { + return Schema.const(key).description(value) + })).description('默认的采样器。').default('k_euler_a') + } + + export function sd2nai(sampler: string): string { + if (sampler === 'k_euler_a') return 'k_euler_ancestral' + if (sampler in nai) return sampler + return 'k_euler_ancestral' + } +} + +export interface Options { + enhance: boolean + model: string + viewport: Size + sampler: string + seed: string + steps: number + scale: number + noise: number + strength: number +} + +export interface Config { + type: 'token' | 'login' | 'naifu' | 'sd-webui' + token: string + email: string + password: string + model?: Model + orient?: Orient + sampler?: string + anatomy?: boolean + output?: 'minimal' | 'default' | 'verbose' + allowAnlas?: boolean | number + basePrompt?: string + negativePrompt?: string + forbidden?: string + endpoint?: string + headers?: Dict + maxRetryCount?: number + requestTimeout?: number + recallTimeout?: number + maxConcurrency?: number +} + +export const Config = Schema.intersect([ + Schema.object({ + type: Schema.union([ + Schema.const('token' as const).description('授权令牌'), + ...process.env.KOISHI_ENV === 'browser' ? [] : [Schema.const('login' as const).description('账号密码')], + Schema.const('naifu' as const).description('naifu'), + Schema.const('sd-webui' as const).description('sd-webui'), + ] as const).description('登录方式'), + }).description('登录设置'), + + Schema.union([ + Schema.intersect([ + Schema.union([ + Schema.object({ + type: Schema.const('token'), + token: Schema.string().description('授权令牌。').role('secret').required(), + }), + Schema.object({ + type: Schema.const('login'), + email: Schema.string().description('用户名。').required(), + password: Schema.string().description('密码。').role('secret').required(), + }), + ]), + Schema.object({ + endpoint: Schema.string().description('API 服务器地址。').default('https://api.novelai.net'), + headers: Schema.dict(String).description('要附加的额外请求头。').default({ + 'referer': 'https://novelai.net/', + 'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36', + }), + allowAnlas: Schema.union([ + Schema.const(true).description('允许'), + Schema.const(false).description('禁止'), + Schema.natural().description('权限等级').default(1), + ]).default(true).description('是否允许使用点数。禁用后部分功能 (图片增强和手动设置某些参数) 将无法使用。'), + }), + ]), + Schema.object({ + type: Schema.const('naifu'), + token: Schema.string().description('授权令牌。').role('secret'), + endpoint: Schema.string().description('API 服务器地址。').required(), + headers: Schema.dict(String).description('要附加的额外请求头。'), + }), + Schema.object({ + type: Schema.const('sd-webui'), + endpoint: Schema.string().description('API 服务器地址。').required(), + headers: Schema.dict(String).description('要附加的额外请求头。'), + }), + ]), + + Schema.union([ + Schema.object({ + type: Schema.const('sd-webui'), + sampler: sampler.createSchema(sampler.sd), + }).description('功能设置'), + Schema.object({ + type: Schema.const('naifu'), + sampler: sampler.createSchema(sampler.nai), + }).description('功能设置'), + Schema.object({ + model: Schema.union(models).description('默认的生成模型。').default('nai'), + sampler: sampler.createSchema(sampler.nai), + }).description('功能设置'), + ] as const), + + Schema.object({ + orient: Schema.union(orients).description('默认的图片方向。').default('portrait'), + output: Schema.union([ + Schema.const('minimal').description('只发送图片'), + Schema.const('default').description('发送图片和关键信息'), + Schema.const('verbose').description('发送全部信息'), + ]).description('输出方式。').default('default'), + basePrompt: Schema.string().role('textarea').description('默认附加的标签。').default('masterpiece, best quality'), + negativePrompt: Schema.string().role('textarea').description('默认附加的反向标签。').default([lowQuality, badAnatomy].join(', ')), + forbidden: Schema.string().role('textarea').description('违禁词列表。含有违禁词的请求将被拒绝。').default(''), + maxRetryCount: Schema.natural().description('连接失败时最大的重试次数。').default(3), + requestTimeout: Schema.number().role('time').description('当请求超过这个时间时会中止并提示超时。').default(Time.minute), + recallTimeout: Schema.number().role('time').description('图片发送后自动撤回的时间 (设置为 0 以禁用此功能)。').default(0), + maxConcurrency: Schema.number().description('单个频道下的最大并发数量 (设置为 0 以禁用此功能)。').default(0), + }), +]) as Schema + +interface Forbidden { + pattern: string + strict: boolean +} + +export function parseForbidden(input: string) { + return input.trim() + .toLowerCase() + .replace(/,/g, ',') + .split(/(?:,\s*|\s*\n\s*)/g) + .filter(Boolean) + .map((pattern: string) => { + const strict = pattern.endsWith('!') + if (strict) pattern = pattern.slice(0, -1) + pattern = pattern.replace(/[^a-z0-9]+/g, ' ').trim() + return { pattern, strict } + }) +} + +export function parseInput(input: string, config: Config, forbidden: Forbidden[]): string[] { + input = input.toLowerCase() + .replace(/[,,]/g, ', ') + .replace(/[((]/g, '{') + .replace(/[))]/g, '}') + .replace(/\s+/g, ' ') + + if (/[^\s\w"'“”‘’.,:|()\[\]{}-]/.test(input)) { + return ['.invalid-input'] + } + + // extract negative prompts + const undesired = [config.negativePrompt] + const capture = input.match(/(,\s*|\s+)(-u\s+|negative prompts?:)\s*([\s\S]+)/m) + if (capture?.[3]) { + input = input.slice(0, capture.index).trim() + undesired.push(capture[3]) + } + + // remove forbidden words + const words = input.split(/, /g).filter((word) => { + word = word.replace(/[^a-z0-9]+/g, ' ').trim() + if (!word) return false + for (const { pattern, strict } of forbidden) { + if (strict && word.split(/\W+/g).includes(pattern)) { + return false + } else if (!strict && word.includes(pattern)) { + return false + } + } + return true + }) + + // append base prompt when input does not include it + for (let tag of config.basePrompt.split(/,\s*/g)) { + tag = tag.trim().toLowerCase() + if (tag && !words.includes(tag)) words.push(tag) + } + input = words.join(', ') + return [null, input, undesired.join(', ')] +} diff --git a/src/index.ts b/src/index.ts index 402b30a..7dd142f 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,169 +1,16 @@ -import { Context, Dict, Logger, Quester, Schema, segment, Session, Time, trimSlash } from 'koishi' +import { Context, Dict, Logger, omit, Quester, segment, Session, trimSlash } from 'koishi' +import { Config, modelMap, models, orientMap, parseForbidden, parseInput, sampler } from './config' import { StableDiffusionWebUI } from './types' import { download, getImageSize, login, NetworkError, project, resizeInput, Size } from './utils' import {} from '@koishijs/plugin-help' +export * from './config' + export const reactive = true export const name = 'novelai' const logger = new Logger('novelai') -const modelMap = { - safe: 'safe-diffusion', - nai: 'nai-diffusion', - furry: 'nai-diffusion-furry', -} as const - -const orientMap = { - landscape: { height: 512, width: 768 }, - portrait: { height: 768, width: 512 }, - square: { height: 640, width: 640 }, -} as const - -const lowQuality = [ - 'nsfw, text, cropped, jpeg artifacts, signature, watermark, username, blurry', - 'lowres, polar lores, worst quality, low quality, normal quality', -].join(', ') - -const badAnatomy = [ - 'bad anatomy, error, long neck, cross-eyed, mutation, deformed', - 'bad hands, bad feet, malformed limbs, fused fingers, mutated hands', - 'missing fingers, fewer digits, to many fingers, extra fingers, extra digit, extra limbs, extra arms, extra legs', - 'poorly drawn hands, poorly drawn face, poorly drawn limbs', -].join(', ') - -type Model = keyof typeof modelMap -type Orient = keyof typeof orientMap - -const models = Object.keys(modelMap) as Model[] -const orients = Object.keys(orientMap) as Orient[] -const naiSamplers = ['k_euler_ancestral', 'k_euler', 'k_lms', 'plms', 'ddim'] -const sdSamplers = { - 'k_euler_a': 'Euler a', - 'k_euler': 'Euler', - 'k_lms': 'LMS', - 'k_heun': 'Heun', - 'k_dpm_2': 'DPM2', - 'k_dpm_2_a': 'DPM2 a', - 'k_dpm_fast': 'DPM fast', - 'k_dpm_ad': 'DPM adaptive', - 'k_lms_ka': 'LMS Karras', - 'k_dpm_2_ka': 'DPM2 Karras', - 'k_dpm_2_a_ka': 'DPM2 a Karras', - 'ddim': 'DDIM', - 'plms': 'PLMS', -} - -function toNAISampler(sampler: string): string { - if (naiSamplers.includes(sampler)) return sampler - return 'k_euler_ancestral' -} - -export interface Config { - type: 'token' | 'login' | 'naifu' | 'sd-webui' - token: string - email: string - password: string - model?: Model - orient?: Orient - sampler?: string - anatomy?: boolean - output?: 'minimal' | 'default' | 'verbose' - allowAnlas?: boolean | number - basePrompt?: string - negativePrompt?: string - forbidden?: string - endpoint?: string - headers?: Dict - maxRetryCount?: number - requestTimeout?: number - recallTimeout?: number - maxConcurrency?: number -} - -export const Config = Schema.intersect([ - Schema.object({ - type: Schema.union([ - Schema.const('token' as const).description('授权令牌'), - ...process.env.KOISHI_ENV === 'browser' ? [] : [Schema.const('login' as const).description('账号密码')], - Schema.const('naifu' as const).description('naifu'), - Schema.const('sd-webui' as const).description('sd-webui'), - ] as const).description('登录方式'), - }).description('登录设置'), - - Schema.union([ - Schema.intersect([ - Schema.union([ - Schema.object({ - type: Schema.const('token'), - token: Schema.string().description('授权令牌。').role('secret').required(), - }), - Schema.object({ - type: Schema.const('login'), - email: Schema.string().description('用户名。').required(), - password: Schema.string().description('密码。').role('secret').required(), - }), - ]), - Schema.object({ - endpoint: Schema.string().description('API 服务器地址。').default('https://api.novelai.net'), - headers: Schema.dict(String).description('要附加的额外请求头。').default({ - 'referer': 'https://novelai.net/', - 'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36', - }), - allowAnlas: Schema.union([ - Schema.const(true).description('允许'), - Schema.const(false).description('禁止'), - Schema.natural().description('权限等级').default(1), - ]).default(true).description('是否允许使用点数。禁用后部分功能 (图片增强和手动设置某些参数) 将无法使用。'), - }), - ]), - Schema.object({ - type: Schema.const('naifu'), - token: Schema.string().description('授权令牌。').role('secret'), - endpoint: Schema.string().description('API 服务器地址。').required(), - headers: Schema.dict(String).description('要附加的额外请求头。'), - }), - Schema.object({ - type: Schema.const('sd-webui'), - endpoint: Schema.string().description('API 服务器地址。').required(), - headers: Schema.dict(String).description('要附加的额外请求头。'), - }), - ]), - - Schema.union([ - Schema.object({ - type: Schema.const('sd-webui'), - sampler: Schema.union(Object.entries(sdSamplers).map(([key, value]) => { - return Schema.const(key).description(value) - })).description('默认的采样器。').default('k_euler_a'), - }).description('功能设置'), - Schema.object({ - type: Schema.const('naifu'), - sampler: Schema.union(naiSamplers).description('默认的采样器。').default('k_euler_ancestral'), - }).description('功能设置'), - Schema.object({ - model: Schema.union(models).description('默认的生成模型。').default('nai'), - sampler: Schema.union(naiSamplers).description('默认的采样器。').default('k_euler_ancestral'), - }).description('功能设置'), - ] as const), - - Schema.object({ - orient: Schema.union(orients).description('默认的图片方向。').default('portrait'), - output: Schema.union([ - Schema.const('minimal').description('只发送图片'), - Schema.const('default').description('发送图片和关键信息'), - Schema.const('verbose').description('发送全部信息'), - ]).description('输出方式。').default('default'), - basePrompt: Schema.string().role('textarea').description('默认附加的标签。').default('masterpiece, best quality'), - negativePrompt: Schema.string().role('textarea').description('默认附加的反向标签。').default([lowQuality, badAnatomy].join(', ')), - forbidden: Schema.string().role('textarea').description('违禁词列表。含有违禁词的请求将被拒绝。').default(''), - maxRetryCount: Schema.natural().description('连接失败时最大的重试次数。').default(3), - requestTimeout: Schema.number().role('time').description('当请求超过这个时间时会中止并提示超时。').default(Time.minute), - recallTimeout: Schema.number().role('time').description('图片发送后自动撤回的时间 (设置为 0 以禁用此功能)。').default(0), - maxConcurrency: Schema.number().description('单个频道下的最大并发数量 (设置为 0 以禁用此功能)。').default(0), - }), -]) as Schema - function handleError(session: Session, err: Error) { if (Quester.isAxiosError(err)) { if (err.response?.status === 402) { @@ -196,17 +43,7 @@ export function apply(ctx: Context, config: Config) { const globalTasks = new Set() ctx.accept(['forbidden'], (config) => { - forbidden = config.forbidden.trim() - .toLowerCase() - .replace(/,/g, ',') - .split(/(?:,\s*|\s*\n\s*)/g) - .filter(Boolean) - .map((pattern: string) => { - const strict = pattern.endsWith('!') - if (strict) pattern = pattern.slice(0, -1) - pattern = pattern.replace(/[^a-z0-9]+/g, ' ').trim() - return { pattern, strict } - }) + forbidden = parseForbidden(config.forbidden) }, { immediate: true }) let tokenTask: Promise = null @@ -275,44 +112,8 @@ export function apply(ctx: Context, config: Config) { delete options.steps } - input = input.toLowerCase() - .replace(/[,,]/g, ', ') - .replace(/[((]/g, '{') - .replace(/[))]/g, '}') - .replace(/\s+/g, ' ') - - if (/[^\s\w"'“”‘’.,:|()\[\]{}-]/.test(input)) { - return session.text('.invalid-input') - } - - // extract negative prompts - const undesired = [config.negativePrompt] - const capture = input.match(/(,\s*|\s+)(-u\s+|negative prompts?:)\s*([\s\S]+)/m) - if (capture?.[3]) { - input = input.slice(0, capture.index).trim() - undesired.push(capture[3]) - } - - // remove forbidden words - const words = input.split(/, /g).filter((word) => { - word = word.replace(/[^a-z0-9]+/g, ' ').trim() - if (!word) return false - for (const { pattern, strict } of forbidden) { - if (strict && word.split(/\W+/g).includes(pattern)) { - return false - } else if (!strict && word.includes(pattern)) { - return false - } - } - return true - }) - - // append base prompt when input does not include it - for (let tag of config.basePrompt.split(/,\s*/g)) { - tag = tag.trim().toLowerCase() - if (tag && !words.includes(tag)) words.push(tag) - } - input = words.join(', ') + const [errPath, prompt, uc] = parseInput(input, config, forbidden) + if (errPath) return session.text(errPath) let token: string try { @@ -330,8 +131,9 @@ export function apply(ctx: Context, config: Config) { const parameters: Dict = { seed, + prompt, n_samples: 1, - uc: undesired.join(', '), + uc, ucPreset: 0, } @@ -405,16 +207,16 @@ export function apply(ctx: Context, config: Config) { function getPostData() { if (config.type !== 'sd-webui') { - parameters.sampler = toNAISampler(options.sampler) + parameters.sampler = sampler.sd2nai(options.sampler) return config.type === 'naifu' - ? { ...parameters, prompt: input } - : { model, input, parameters } + ? parameters + : { model, input, parameters: omit(parameters, ['prompt']) } } return { - prompt: input, - sampler_index: sdSamplers[options.sampler], + sampler_index: sampler.sd[options.sampler], ...project(parameters, { + prompt: 'prompt', n_samples: 'n_samples', seed: 'seed', negative_prompt: 'uc', @@ -483,7 +285,7 @@ export function apply(ctx: Context, config: Config) { result.children.push(segment('message', attrs, params.join('\n'))) result.children.push(segment('message', attrs, `prompt = ${input}`)) if (config.output === 'verbose') { - result.children.push(segment('message', attrs, `undesired = ${undesired.join(', ')}`)) + result.children.push(segment('message', attrs, `undesired = ${uc}`)) } result.children.push(segment('message', attrs, segment.image('base64://' + base64))) return result @@ -504,6 +306,6 @@ export function apply(ctx: Context, config: Config) { cmd._options.model.fallback = config.model cmd._options.viewport.fallback = config.orient cmd._options.sampler.fallback = config.sampler - cmd._options.sampler.type = config.type === 'sd-webui' ? Object.keys(sdSamplers) : naiSamplers + cmd._options.sampler.type = Object.keys(config.type === 'sd-webui' ? sampler.sd : sampler.nai) }, { immediate: true }) }