forked from koishijs/novelai-bot
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
262 additions
and
214 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
import { Dict, Schema, Time } from 'koishi' | ||
import { Size } from './utils' | ||
|
||
export const modelMap = { | ||
safe: 'safe-diffusion', | ||
nai: 'nai-diffusion', | ||
furry: 'nai-diffusion-furry', | ||
} as const | ||
|
||
export const orientMap = { | ||
landscape: { height: 512, width: 768 }, | ||
portrait: { height: 768, width: 512 }, | ||
square: { height: 640, width: 640 }, | ||
} as const | ||
|
||
const lowQuality = [ | ||
'nsfw, text, cropped, jpeg artifacts, signature, watermark, username, blurry', | ||
'lowres, polar lores, worst quality, low quality, normal quality', | ||
].join(', ') | ||
|
||
const badAnatomy = [ | ||
'bad anatomy, error, long neck, cross-eyed, mutation, deformed', | ||
'bad hands, bad feet, malformed limbs, fused fingers, mutated hands', | ||
'missing fingers, fewer digits, to many fingers, extra fingers, extra digit, extra limbs, extra arms, extra legs', | ||
'poorly drawn hands, poorly drawn face, poorly drawn limbs', | ||
].join(', ') | ||
|
||
type Model = keyof typeof modelMap | ||
type Orient = keyof typeof orientMap | ||
|
||
export const models = Object.keys(modelMap) as Model[] | ||
export const orients = Object.keys(orientMap) as Orient[] | ||
|
||
export namespace sampler { | ||
export const nai = { | ||
'k_euler_a': 'Euler ancestral', | ||
'k_euler': 'Euler', | ||
'k_lms': 'LMS', | ||
'ddim': 'DDIM', | ||
'plms': 'PLMS', | ||
} | ||
|
||
export const sd = { | ||
'k_euler_a': 'Euler ancestral', | ||
'k_euler': 'Euler', | ||
'k_lms': 'LMS', | ||
'k_heun': 'Heun', | ||
'k_dpm_2': 'DPM2', | ||
'k_dpm_2_a': 'DPM2 ancestral', | ||
'k_dpm_fast': 'DPM Fast', | ||
'k_dpm_ad': 'DPM adaptive', | ||
'k_lms_ka': 'LMS karras', | ||
'k_dpm_2_ka': 'DPM2 karras', | ||
'k_dpm_2_a_ka': 'DPM2 ancestral karras', | ||
'ddim': 'DDIM', | ||
'plms': 'PLMS', | ||
} | ||
|
||
export function createSchema(map: Dict<string>) { | ||
return Schema.union(Object.entries(map).map(([key, value]) => { | ||
return Schema.const(key).description(value) | ||
})).description('默认的采样器。').default('k_euler_a') | ||
} | ||
|
||
export function sd2nai(sampler: string): string { | ||
if (sampler === 'k_euler_a') return 'k_euler_ancestral' | ||
if (sampler in nai) return sampler | ||
return 'k_euler_ancestral' | ||
} | ||
} | ||
|
||
export interface Options { | ||
enhance: boolean | ||
model: string | ||
viewport: Size | ||
sampler: string | ||
seed: string | ||
steps: number | ||
scale: number | ||
noise: number | ||
strength: number | ||
} | ||
|
||
export interface Config { | ||
type: 'token' | 'login' | 'naifu' | 'sd-webui' | ||
token: string | ||
email: string | ||
password: string | ||
model?: Model | ||
orient?: Orient | ||
sampler?: string | ||
anatomy?: boolean | ||
output?: 'minimal' | 'default' | 'verbose' | ||
allowAnlas?: boolean | number | ||
basePrompt?: string | ||
negativePrompt?: string | ||
forbidden?: string | ||
endpoint?: string | ||
headers?: Dict<string> | ||
maxRetryCount?: number | ||
requestTimeout?: number | ||
recallTimeout?: number | ||
maxConcurrency?: number | ||
} | ||
|
||
export const Config = Schema.intersect([ | ||
Schema.object({ | ||
type: Schema.union([ | ||
Schema.const('token' as const).description('授权令牌'), | ||
...process.env.KOISHI_ENV === 'browser' ? [] : [Schema.const('login' as const).description('账号密码')], | ||
Schema.const('naifu' as const).description('naifu'), | ||
Schema.const('sd-webui' as const).description('sd-webui'), | ||
] as const).description('登录方式'), | ||
}).description('登录设置'), | ||
|
||
Schema.union([ | ||
Schema.intersect([ | ||
Schema.union([ | ||
Schema.object({ | ||
type: Schema.const('token'), | ||
token: Schema.string().description('授权令牌。').role('secret').required(), | ||
}), | ||
Schema.object({ | ||
type: Schema.const('login'), | ||
email: Schema.string().description('用户名。').required(), | ||
password: Schema.string().description('密码。').role('secret').required(), | ||
}), | ||
]), | ||
Schema.object({ | ||
endpoint: Schema.string().description('API 服务器地址。').default('https://api.novelai.net'), | ||
headers: Schema.dict(String).description('要附加的额外请求头。').default({ | ||
'referer': 'https://novelai.net/', | ||
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36', | ||
}), | ||
allowAnlas: Schema.union([ | ||
Schema.const(true).description('允许'), | ||
Schema.const(false).description('禁止'), | ||
Schema.natural().description('权限等级').default(1), | ||
]).default(true).description('是否允许使用点数。禁用后部分功能 (图片增强和手动设置某些参数) 将无法使用。'), | ||
}), | ||
]), | ||
Schema.object({ | ||
type: Schema.const('naifu'), | ||
token: Schema.string().description('授权令牌。').role('secret'), | ||
endpoint: Schema.string().description('API 服务器地址。').required(), | ||
headers: Schema.dict(String).description('要附加的额外请求头。'), | ||
}), | ||
Schema.object({ | ||
type: Schema.const('sd-webui'), | ||
endpoint: Schema.string().description('API 服务器地址。').required(), | ||
headers: Schema.dict(String).description('要附加的额外请求头。'), | ||
}), | ||
]), | ||
|
||
Schema.union([ | ||
Schema.object({ | ||
type: Schema.const('sd-webui'), | ||
sampler: sampler.createSchema(sampler.sd), | ||
}).description('功能设置'), | ||
Schema.object({ | ||
type: Schema.const('naifu'), | ||
sampler: sampler.createSchema(sampler.nai), | ||
}).description('功能设置'), | ||
Schema.object({ | ||
model: Schema.union(models).description('默认的生成模型。').default('nai'), | ||
sampler: sampler.createSchema(sampler.nai), | ||
}).description('功能设置'), | ||
] as const), | ||
|
||
Schema.object({ | ||
orient: Schema.union(orients).description('默认的图片方向。').default('portrait'), | ||
output: Schema.union([ | ||
Schema.const('minimal').description('只发送图片'), | ||
Schema.const('default').description('发送图片和关键信息'), | ||
Schema.const('verbose').description('发送全部信息'), | ||
]).description('输出方式。').default('default'), | ||
basePrompt: Schema.string().role('textarea').description('默认附加的标签。').default('masterpiece, best quality'), | ||
negativePrompt: Schema.string().role('textarea').description('默认附加的反向标签。').default([lowQuality, badAnatomy].join(', ')), | ||
forbidden: Schema.string().role('textarea').description('违禁词列表。含有违禁词的请求将被拒绝。').default(''), | ||
maxRetryCount: Schema.natural().description('连接失败时最大的重试次数。').default(3), | ||
requestTimeout: Schema.number().role('time').description('当请求超过这个时间时会中止并提示超时。').default(Time.minute), | ||
recallTimeout: Schema.number().role('time').description('图片发送后自动撤回的时间 (设置为 0 以禁用此功能)。').default(0), | ||
maxConcurrency: Schema.number().description('单个频道下的最大并发数量 (设置为 0 以禁用此功能)。').default(0), | ||
}), | ||
]) as Schema<Config> | ||
|
||
interface Forbidden { | ||
pattern: string | ||
strict: boolean | ||
} | ||
|
||
export function parseForbidden(input: string) { | ||
return input.trim() | ||
.toLowerCase() | ||
.replace(/,/g, ',') | ||
.split(/(?:,\s*|\s*\n\s*)/g) | ||
.filter(Boolean) | ||
.map<Forbidden>((pattern: string) => { | ||
const strict = pattern.endsWith('!') | ||
if (strict) pattern = pattern.slice(0, -1) | ||
pattern = pattern.replace(/[^a-z0-9]+/g, ' ').trim() | ||
return { pattern, strict } | ||
}) | ||
} | ||
|
||
export function parseInput(input: string, config: Config, forbidden: Forbidden[]): string[] { | ||
input = input.toLowerCase() | ||
.replace(/[,,]/g, ', ') | ||
.replace(/[((]/g, '{') | ||
.replace(/[))]/g, '}') | ||
.replace(/\s+/g, ' ') | ||
|
||
if (/[^\s\w"'“”‘’.,:|()\[\]{}-]/.test(input)) { | ||
return ['.invalid-input'] | ||
} | ||
|
||
// extract negative prompts | ||
const undesired = [config.negativePrompt] | ||
const capture = input.match(/(,\s*|\s+)(-u\s+|negative prompts?:)\s*([\s\S]+)/m) | ||
if (capture?.[3]) { | ||
input = input.slice(0, capture.index).trim() | ||
undesired.push(capture[3]) | ||
} | ||
|
||
// remove forbidden words | ||
const words = input.split(/, /g).filter((word) => { | ||
word = word.replace(/[^a-z0-9]+/g, ' ').trim() | ||
if (!word) return false | ||
for (const { pattern, strict } of forbidden) { | ||
if (strict && word.split(/\W+/g).includes(pattern)) { | ||
return false | ||
} else if (!strict && word.includes(pattern)) { | ||
return false | ||
} | ||
} | ||
return true | ||
}) | ||
|
||
// append base prompt when input does not include it | ||
for (let tag of config.basePrompt.split(/,\s*/g)) { | ||
tag = tag.trim().toLowerCase() | ||
if (tag && !words.includes(tag)) words.push(tag) | ||
} | ||
input = words.join(', ') | ||
return [null, input, undesired.join(', ')] | ||
} |
Oops, something went wrong.