Skip to content

Commit

Permalink
feat: adapt sampler with type
Browse files Browse the repository at this point in the history
  • Loading branch information
shigma committed Oct 22, 2022
1 parent 0074d4e commit 1c7ff26
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 80 deletions.
156 changes: 100 additions & 56 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Context, Dict, Logger, Quester, Schema, segment, Session, Time, trimSlash } from 'koishi'
import { StableDiffusionWebUI } from './types'
import { download, getImageSize, login, NetworkError, project, resizeInput, samplersMapN2S } from './utils'
import { download, getImageSize, login, NetworkError, project, resizeInput } from './utils'
import {} from '@koishijs/plugin-help'

export const reactive = true
Expand Down Expand Up @@ -34,11 +34,30 @@ const badAnatomy = [

type Model = keyof typeof modelMap
type Orient = keyof typeof orientMap
type Sampler = typeof samplers[number]

const models = Object.keys(modelMap) as Model[]
const orients = Object.keys(orientMap) as Orient[]
const samplers = ['k_euler_ancestral', 'k_euler', 'k_lms', 'plms', 'ddim'] as const
const naiSamplers = ['k_euler_ancestral', 'k_euler', 'k_lms', 'plms', 'ddim']
const sdSamplers = {
'k_euler_a': 'Euler a',
'k_euler': 'Euler',
'k_lms': 'LMS',
'k_heun': 'Heun',
'k_dpm_2': 'DPM2',
'k_dpm_2_a': 'DPM2 a',
'k_dpm_fast': 'DPM fast',
'k_dpm_ad': 'DPM adaptive',
'k_lms_ka': 'LMS Karras',
'k_dpm_2_ka': 'DPM2 Karras',
'k_dpm_2_a_ka': 'DPM2 a Karras',
'ddim': 'DDIM',
'plms': 'PLMS',
}

function toNAISampler(sampler: string): string {
if (naiSamplers.includes(sampler)) return sampler
return 'k_euler_ancestral'
}

export interface Config {
type: 'token' | 'login' | 'naifu' | 'sd-webui'
Expand All @@ -47,7 +66,7 @@ export interface Config {
password: string
model?: Model
orient?: Orient
sampler?: Sampler
sampler?: string
anatomy?: boolean
output?: 'minimal' | 'default' | 'verbose'
allowAnlas?: boolean | number
Expand All @@ -71,61 +90,79 @@ export const Config = Schema.intersect([
Schema.const('sd-webui' as const).description('sd-webui'),
] as const).description('登录方式'),
}).description('登录设置'),

Schema.union([
Schema.object({
type: Schema.const('token' as const),
token: Schema.string().description('授权令牌。').role('secret').required(),
endpoint: Schema.string().description('API 服务器地址。').default('https://api.novelai.net'),
headers: Schema.dict(String).description('要附加的额外请求头。').default({
'referer': 'https://novelai.net/',
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36',
}),
}),
Schema.object({
type: Schema.const('login' as const),
email: Schema.string().description('用户名。').required(),
password: Schema.string().description('密码。').role('secret').required(),
endpoint: Schema.string().description('API 服务器地址。').default('https://api.novelai.net'),
headers: Schema.dict(String).description('要附加的额外请求头。').default({
'referer': 'https://novelai.net/',
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36',
Schema.intersect([
Schema.union([
Schema.object({
type: Schema.const('token'),
token: Schema.string().description('授权令牌。').role('secret').required(),
}),
Schema.object({
type: Schema.const('login'),
email: Schema.string().description('用户名。').required(),
password: Schema.string().description('密码。').role('secret').required(),
}),
]),
Schema.object({
endpoint: Schema.string().description('API 服务器地址。').default('https://api.novelai.net'),
headers: Schema.dict(String).description('要附加的额外请求头。').default({
'referer': 'https://novelai.net/',
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36',
}),
allowAnlas: Schema.union([
Schema.const(true).description('允许'),
Schema.const(false).description('禁止'),
Schema.natural().description('权限等级').default(1),
]).default(true).description('是否允许使用点数。禁用后部分功能 (图片增强和手动设置某些参数) 将无法使用。'),
}),
}),
]),
Schema.object({
type: Schema.const('naifu' as const),
type: Schema.const('naifu'),
token: Schema.string().description('授权令牌。').role('secret'),
endpoint: Schema.string().description('API 服务器地址。').required(),
headers: Schema.dict(String).description('要附加的额外请求头。'),
}),
Schema.object({
type: Schema.const('sd-webui' as const),
type: Schema.const('sd-webui'),
endpoint: Schema.string().description('API 服务器地址。').required(),
headers: Schema.dict(String).description('要附加的额外请求头。'),
}),
]),

Schema.union([
Schema.object({
type: Schema.const('sd-webui'),
sampler: Schema.union(Object.entries(sdSamplers).map(([key, value]) => {
return Schema.const(key).description(value)
})).description('默认的采样器。').default('k_euler_a'),
}).description('功能设置'),
Schema.object({
type: Schema.const('naifu'),
sampler: Schema.union(naiSamplers).description('默认的采样器。').default('k_euler_ancestral'),
}).description('功能设置'),
Schema.object({
model: Schema.union(models).description('默认的生成模型。').default('nai'),
sampler: Schema.union(naiSamplers).description('默认的采样器。').default('k_euler_ancestral'),
}).description('功能设置'),
] as const),

Schema.object({
model: Schema.union(models).description('默认的生成模型。').default('nai'),
orient: Schema.union(orients).description('默认的图片方向。').default('portrait'),
sampler: Schema.union(samplers).description('默认的采样器。').default('k_euler_ancestral'),
anatomy: Schema.boolean().default(true).description('是否过滤不合理构图。'),
output: Schema.union([
Schema.const('minimal' as const).description('只发送图片'),
Schema.const('default' as const).description('发送图片和关键信息'),
Schema.const('verbose' as const).description('发送全部信息'),
Schema.const('minimal').description('只发送图片'),
Schema.const('default').description('发送图片和关键信息'),
Schema.const('verbose').description('发送全部信息'),
]).description('输出方式。').default('default'),
allowAnlas: Schema.union([
Schema.const(true).description('允许'),
Schema.const(false).description('禁止'),
Schema.natural().description('权限等级').default(1),
]).default(true).description('是否允许使用点数。禁用后部分功能 (图片增强和手动设置某些参数) 将无法使用。'),
basePrompt: Schema.string().role('textarea').description('默认附加的标签。').default('masterpiece, best quality'),
negativePrompt: Schema.string().role('textarea').description('默认附加的反向标签。').default([lowQuality, badAnatomy].join(', ')),
forbidden: Schema.string().role('textarea').description('违禁词列表。含有违禁词的请求将被拒绝。').default(''),
maxRetryCount: Schema.natural().description('连接失败时最大的重试次数。').default(3),
requestTimeout: Schema.number().role('time').description('当请求超过这个时间时会中止并提示超时。').default(Time.minute),
recallTimeout: Schema.number().role('time').description('图片发送后自动撤回的时间 (设置为 0 以禁用此功能)。').default(0),
maxConcurrency: Schema.number().description('单个频道下的最大并发数量 (设置为 0 以禁用此功能)。').default(0),
}).description('功能设置'),
] as const) as Schema<Config>
}),
]) as Schema<Config>

function handleError(session: Session, err: Error) {
if (Quester.isAxiosError(err)) {
Expand Down Expand Up @@ -196,7 +233,7 @@ export function apply(ctx: Context, config: Config) {
.option('enhance', '-e', { hidden })
.option('model', '-m <model>', { type: models })
.option('orient', '-o <orient>', { type: orients })
.option('sampler', '-s <sampler>', { type: samplers })
.option('sampler', '-s <sampler>')
.option('seed', '-x <seed:number>')
.option('steps', '-t <step:number>', { hidden })
.option('scale', '-c <scale:number>')
Expand Down Expand Up @@ -285,7 +322,6 @@ export function apply(ctx: Context, config: Config) {
const parameters: Dict = {
seed,
n_samples: 1,
sampler: options.sampler,
uc: undesired.join(', '),
ucPreset: 0,
}
Expand Down Expand Up @@ -358,6 +394,29 @@ export function apply(ctx: Context, config: Config) {
globalTasks.delete(id)
}

function getPostData() {
if (config.type !== 'sd-webui') {
parameters.sampler = toNAISampler(options.sampler)
return config.type === 'naifu'
? { ...parameters, prompt: input }
: { model, input, parameters }
}

return {
prompt: input,
sampler_index: sdSamplers[options.sampler],
...project(parameters, {
n_samples: 'n_samples',
seed: 'seed',
negative_prompt: 'uc',
cfg_scale: 'scale',
steps: 'steps',
width: 'width',
height: 'height',
}),
}
}

const path = config.type === 'sd-webui' ? '/sdapi/v1/txt2img' : config.type === 'naifu' ? '/generate-stream' : '/ai/generate-image'
const request = () => ctx.http.axios(trimSlash(config.endpoint) + path, {
method: 'POST',
Expand All @@ -366,23 +425,7 @@ export function apply(ctx: Context, config: Config) {
...config.headers,
authorization: 'Bearer ' + token,
},
data: config.type === 'sd-webui'
? {
prompt: input,
sampler_index: samplersMapN2S(parameters.sampler),
...project(parameters, {
n_samples: 'n_samples',
seed: 'seed',
negative_prompt: 'uc',
cfg_scale: 'scale',
steps: 'steps',
width: 'width',
height: 'height',
}),
}
: config.type === 'naifu'
? { ...parameters, prompt: input }
: { model, input, parameters },
data: getPostData(),
}).then((res) => {
if (config.type === 'sd-webui') {
return (res.data as StableDiffusionWebUI.Response).images[0]
Expand Down Expand Up @@ -452,5 +495,6 @@ export function apply(ctx: Context, config: Config) {
cmd._options.model.fallback = config.model
cmd._options.orient.fallback = config.orient
cmd._options.sampler.fallback = config.sampler
cmd._options.sampler.type = config.type === 'sd-webui' ? Object.keys(sdSamplers) : naiSamplers
}, { immediate: true })
}
6 changes: 3 additions & 3 deletions src/locales/en.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ commands:
options:
enhance: Image Enhance Mode
model: Set Model for Generation (safe, nai, furry)
orient: Set Image Orientation (portrait, landscape, square)
sampler: Set Sampler (k_euler_ancestral, k_euler, k_lms, plms, ddim)
model: Set Model for Generation
orient: Set Image Orientation
sampler: Set Sampler
anatomy.true: Filter Anatomically Incorrect Images
anatomy.false: Allow Anatomically Incorrect Images
seed: Set Random Seed
Expand Down
6 changes: 3 additions & 3 deletions src/locales/fr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ commands:
options:
enhance: Mode d'amélioration de l'image
model: Définir le modèle pour génération (safe, nai, furry)
orient: Définir l'orientation de l'image (portrait, landscape, square)
sampler: Définir l'échantillonneur (k_euler_ancestral, k_euler, k_lms, plms, ddim)
model: Définir le modèle pour génération
orient: Définir l'orientation de l'image
sampler: Définir l'échantillonneur
anatomy.true: Filtrer les images anatomiquement incorrectes
anatomy.false: Autoriser les images anatomiquement incorrectes
seed: Définir une graine aléatoire
Expand Down
6 changes: 3 additions & 3 deletions src/locales/zh-tw.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ commands:
options:
enhance: 圖片增強模式
model: 設定生成模型 (safe, nai, furry)
orient: 設定圖片方向 (portrait, landscape, square)
sampler: 設定取樣器 (k_euler_ancestral, k_euler, k_lms, plms, ddim)
model: 設定生成模型
orient: 設定圖片方向
sampler: 設定取樣器
anatomy.true: 過濾不合理構圖
anatomy.false: 允許不合理構圖
seed: 設置隨機種子
Expand Down
6 changes: 3 additions & 3 deletions src/locales/zh.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ commands:
options:
enhance: 图片增强模式
model: 设定生成模型 (safe, nai, furry)
orient: 设定图片方向 (portrait, landscape, square)
sampler: 设置采样器 (k_euler_ancestral, k_euler, k_lms, plms, ddim)
model: 设定生成模型
orient: 设定图片方向
sampler: 设置采样器
anatomy.true: 过滤不合理构图
anatomy.false: 允许不合理构图
seed: 设置随机种子
Expand Down
12 changes: 0 additions & 12 deletions src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,3 @@ export function resizeInput(size: Size): Size {
return { width, height }
}
}

export function samplersMapN2S(sampler: string): string {
switch (sampler) {
case 'k_euler_ancestral': return 'Euler a'
case 'k_euler': return 'Euler'
case 'k_lms': return 'LMS'
case 'plms': return 'PLMS'
case 'ddim': return 'DDIM'
}

return 'Euler a'
}

0 comments on commit 1c7ff26

Please sign in to comment.