diff --git a/comfy/sd.py b/comfy/sd.py
index 178f52e8b10..470bd717bb9 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -299,15 +299,24 @@ def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
         return output
 
     def decode_tiled_1d(self, samples, tile_x=128, overlap=64):
-        output = torch.empty((samples.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples.shape[2:])), device=self.output_device)
+        output = torch.zeros((samples.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples.shape[2:])), device=self.output_device)
+        output_mult = torch.zeros((samples.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples.shape[2:])), device=self.output_device)
 
         for j in range(samples.shape[0]):
             for i in range(0, samples.shape[-1], tile_x - overlap):
                 f = i
                 t = i + tile_x
-                output[j:j+1,:,f * self.upscale_ratio:t * self.upscale_ratio] = self.first_stage_model.decode(samples[j:j+1,:,f:t].to(self.vae_dtype).to(self.device)).float()
-
-        return output
+                o = output[j:j+1,:,f * self.upscale_ratio:t * self.upscale_ratio]
+                m = torch.ones_like(o)
+                l = m.shape[-1]
+                for x in range(overlap):
+                    c = ((x + 1) / overlap)
+                    m[:,:,x:x+1] *= c
+                    m[:,:,l-x-1:l-x] *= c
+                o += self.first_stage_model.decode(samples[j:j+1,:,f:t].to(self.vae_dtype).to(self.device)).float().to(self.output_device) * m
+                output_mult[j:j+1,:,f * self.upscale_ratio:t * self.upscale_ratio] += m
+
+        return output / output_mult
 
     def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
         steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
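
For readers skimming the patch: decode_tiled_1d previously wrote each decoded tile directly into the output, producing hard seams where tiles abut. The change instead multiplies each tile by a mask that ramps linearly over the overlap region at both edges, accumulates the weighted tiles into output, accumulates the masks into output_mult, and divides at the end so every position is a properly weighted average of the tiles covering it. Below is a minimal standalone sketch of that idea; blend_tiled_1d and its parameters are hypothetical illustration names, not ComfyUI API, and it assumes a shape-preserving per-tile process (the real patch uses first_stage_model.decode, which also upscales by upscale_ratio).

    import torch

    def blend_tiled_1d(process, x, tile=128, overlap=64):
        # Hypothetical sketch (not ComfyUI API) of the weighted overlap-blend
        # the patch adds to VAE.decode_tiled_1d. `process` is any per-tile
        # callable that preserves the tensor's shape.
        out = torch.zeros_like(x)
        mult = torch.zeros_like(x)
        for i in range(0, x.shape[-1], tile - overlap):
            tile_out = process(x[..., i:i + tile])
            m = torch.ones_like(tile_out)
            l = m.shape[-1]
            for k in range(overlap):
                c = (k + 1) / overlap            # linear ramp: 1/overlap .. 1
                m[..., k:k + 1] *= c             # fade in on the left edge
                m[..., l - k - 1:l - k] *= c     # fade out on the right edge
            out[..., i:i + l] += tile_out * m    # weighted accumulation
            mult[..., i:i + l] += m              # remember the total weight
        return out / mult                        # normalize overlapping regions

    # With an identity "decoder", blending reproduces the input exactly,
    # because the final division normalizes away the mask weights:
    x = torch.randn(1, 4, 512)
    y = blend_tiled_1d(lambda t: t, x)
    assert torch.allclose(y, x, atol=1e-5)

The division by the accumulated mask sum is what keeps edge regions correct: positions covered by only a partial ramp (such as the left edge of the first tile) still divide by exactly the weight they received, so no position is attenuated and no divide-by-zero occurs, since every ramp coefficient is at least 1/overlap.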