GitHub Repository: huggingface/notebooks
Path: blob/main/diffusers_doc/en/tensorflow/schedulers.ipynb

Schedulers

A scheduler is an algorithm that guides the denoising process, for example by deciding how much noise to remove at each step. It takes the model's prediction at step t and uses it to compute the next, less noisy sample at step t-1. Different schedulers produce different results; some are faster while others are more accurate.
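To make "compute the sample at step t-1 from the prediction at step t" concrete, here is a minimal pure-Python sketch of one deterministic DDIM update (eta = 0) on a single scalar. It assumes the model predicts the noise (epsilon) and that the cumulative alpha products (alpha_bar) come from some noise schedule. It is an illustration of one scheduler's update rule, not the diffusers API. In diffusers, this kind of update happens inside a scheduler's `step()` method.

```python
import math


def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev):
    """One deterministic DDIM update (eta = 0) for a single scalar sample.

    x_t: noisy sample at step t
    eps_pred: model's noise prediction at step t
    alpha_bar_t: cumulative alpha product at step t
    alpha_bar_prev: cumulative alpha product at step t-1 (closer to 1 means less noise)
    """
    # Estimate the clean sample x_0 implied by the noise prediction.
    x0_pred = (x_t - math.sqrt(1.0 - alpha_bar_t) * eps_pred) / math.sqrt(alpha_bar_t)
    # Re-noise that estimate to the lower noise level of step t-1.
    return math.sqrt(alpha_bar_prev) * x0_pred + math.sqrt(1.0 - alpha_bar_prev) * eps_pred


# Usage: build x_t from a known x_0 and noise, then take one step with a "perfect" prediction.
x0, noise = 0.5, -1.2
x_t = math.sqrt(0.3) * x0 + math.sqrt(1.0 - 0.3) * noise
print(ddim_step(x_t, noise, alpha_bar_t=0.3, alpha_bar_prev=0.6))
```

With a perfect noise prediction, the result is exactly the sample you would get by noising x_0 directly to the step t-1 noise level. Real models only approximate the noise, which is where different schedulers' update rules start to matter.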

Diffusers supports many schedulers and allows you to modify their timestep schedules, timestep spacing, and more, to generate high-quality images in fewer steps.

This guide will show you how to load and customize schedulers.

Loading schedulers

Schedulers have no trainable parameters; they are defined entirely by a configuration file. Access the .scheduler attribute of a pipeline to view its scheduler and configuration.

import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, device_map="cuda"
)
pipeline.scheduler

Load a different scheduler with from_pretrained() and set the subfolder argument so the configuration file is loaded from the scheduler subfolder of the pipeline repository. Then pass the new scheduler to the scheduler argument when loading the pipeline.

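import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

# Load only the scheduler configuration from the repository's "scheduler" subfolder
dpm = DPMSolverMultistepScheduler.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler"
)
# Pass the new scheduler in place of the pipeline's default one
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    scheduler=dpm,
    torch_dtype=torch.float16,
    device_map="cuda",
)
pipeline.scheduler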

Timestep schedules

The timestep (or noise) schedule determines how noise levels are distributed across the denoising process. The schedule can be linear, or concentrated more toward the beginning or the end. By default it is a precomputed sequence of noise levels generated from the scheduler's configuration, but you can replace it with a custom schedule.

[!TIP] The timesteps argument is only supported by a select list of schedulers and pipelines. Feel free to open a feature request if you want this argument supported in a scheduler or pipeline that doesn't have it yet!

The example below uses the Align Your Steps (AYS) schedule which can generate a high-quality image in 10 steps, significantly speeding up generation and reducing computation time.

Import the schedule and pass it to the timesteps argument in the pipeline.

import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.schedulers import AysSchedules

sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"]
print(sampling_schedule)
# [999, 845, 730, 587, 443, 310, 193, 116, 53, 13]

pipeline = DiffusionPipeline.from_pretrained(
    "SG161222/RealVisXL_V4.0", torch_dtype=torch.float16, device_map="cuda"
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
    pipeline.scheduler.config, algorithm_type="sde-dpmsolver++"
)

prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
image = pipeline(
    prompt=prompt,
    negative_prompt="",
    timesteps=sampling_schedule,
).images[0]
[Image comparison: AYS timestep schedule, 10 steps; linearly-spaced timestep schedule, 10 steps; linearly-spaced timestep schedule, 25 steps]

Rescaling schedules

Denoising should begin from pure noise, where the signal-to-noise ratio (SNR) is zero. However, the noise schedules of some models never reach zero SNR at the last timestep, so sampling doesn't actually start from pure noise. This makes it difficult to generate very bright or very dark images.

[!TIP] Train your own model with v_prediction by adding the --prediction_type="v_prediction" flag to your training script. You can also search for existing models trained with v_prediction.

This fix requires a model trained with v_prediction. For such a model, enable the following arguments in the scheduler.

  • Set rescale_betas_zero_snr=True to rescale the noise schedule so that the last timestep has exactly zero SNR

  • Set timestep_spacing="trailing" so that sampling starts from the last timestep, where the input is pure noise
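The rescaling behind rescale_betas_zero_snr can be sketched in plain Python. This follows the approach described by Lin et al. (2023), as I understand it. It shifts and scales sqrt(alpha_bar) so that the last value becomes exactly 0 while the first stays unchanged, then converts the result back to per-step betas. The "scaled_linear" betas (0.00085 to 0.012 over 1,000 steps) are an assumption chosen to resemble Stable Diffusion's default schedule. The helper names are hypothetical, not diffusers functions.

```python
import math


def scaled_linear_betas(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012):
    # Assumed "scaled_linear" schedule: linear in sqrt(beta), then squared.
    s, e = math.sqrt(beta_start), math.sqrt(beta_end)
    n = num_train_timesteps
    return [(s + (e - s) * i / (n - 1)) ** 2 for i in range(n)]


def alphas_cumprod(betas):
    # alpha_bar_t = product of (1 - beta_i) for i <= t
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out


def snr(alpha_bar):
    # Signal-to-noise ratio at a timestep with cumulative product alpha_bar.
    return alpha_bar / (1.0 - alpha_bar)


def rescale_zero_terminal_snr(betas):
    """Return new betas whose last timestep has exactly zero SNR."""
    sqrt_ab = [math.sqrt(a) for a in alphas_cumprod(betas)]
    first, last = sqrt_ab[0], sqrt_ab[-1]
    # Shift so the last value is 0, then scale so the first value is unchanged.
    sqrt_ab = [(v - last) * first / (first - last) for v in sqrt_ab]
    ab = [v * v for v in sqrt_ab]
    # Convert cumulative products back into per-step alphas, then betas.
    alphas = [ab[0]] + [ab[i] / ab[i - 1] for i in range(1, len(ab))]
    return [1.0 - a for a in alphas]


betas = scaled_linear_betas()
print(snr(alphas_cumprod(betas)[-1]))  # small but positive: not pure noise
new_betas = rescale_zero_terminal_snr(betas)
print(snr(alphas_cumprod(new_betas)[-1]))  # zero SNR at the last timestep
```

The last rescaled beta becomes 1, so the final timestep carries no signal at all. Sampling from it only makes sense if you also start at that timestep, which is why this option pairs with trailing spacing.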

from diffusers import DiffusionPipeline, DDIMScheduler

pipeline = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", device_map="cuda")
pipeline.scheduler = DDIMScheduler.from_config(
    pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
)

Set guidance_rescale in the pipeline to avoid overexposed images. A lower value increases brightness, but some details may appear washed out.

prompt = """
cinematic photo of a snowy mountain at night with the northern lights aurora borealis
overhead, 35mm photograph, film, professional, 4k, highly detailed
"""
image = pipeline(prompt, guidance_rescale=0.7).images[0]
[Image comparison: default Stable Diffusion v2-1 image; image with zero SNR and trailing timestep spacing enabled]
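The arithmetic behind guidance_rescale is short. The sketch below uses plain lists in place of tensors. It rescales the classifier-free guidance (CFG) output so that its standard deviation matches the text-conditioned prediction, then blends that with the original CFG output using guidance_rescale as the weight. It is modeled on the `rescale_noise_cfg` helper in the diffusers Stable Diffusion pipelines, as I understand it. The list-based signatures and the `apply_cfg` helper are illustrative assumptions, not library code.

```python
import statistics


def apply_cfg(noise_uncond, noise_text, guidance_scale):
    """Classifier-free guidance: move away from the unconditional prediction."""
    return [u + guidance_scale * (t - u) for u, t in zip(noise_uncond, noise_text)]


def rescale_noise_cfg(noise_cfg, noise_text, guidance_rescale=0.0):
    """Blend the CFG output with a copy rescaled to the text prediction's std."""
    std_text = statistics.pstdev(noise_text)
    std_cfg = statistics.pstdev(noise_cfg)
    # Rescale so the spread of the guided prediction matches the text-only one.
    rescaled = [x * (std_text / std_cfg) for x in noise_cfg]
    # guidance_rescale = 0 keeps plain CFG; 1 uses the fully rescaled prediction.
    return [guidance_rescale * r + (1.0 - guidance_rescale) * x for r, x in zip(rescaled, noise_cfg)]


# Usage with toy 4-element "noise predictions".
uncond = [0.1, -0.2, 0.3, -0.4]
text = [0.3, -0.5, 0.6, -0.2]
cfg = apply_cfg(uncond, text, guidance_scale=7.5)
out = rescale_noise_cfg(cfg, text, guidance_rescale=0.7)
print(statistics.pstdev(text), statistics.pstdev(cfg), statistics.pstdev(out))
```

A large guidance scale inflates the spread of the prediction, and this inflation is associated with overexposure. Pulling the spread back toward the text prediction's counters that, and guidance_rescale controls how strongly.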

Timestep spacing

Timestep spacing determines which timesteps t are sampled from the schedule. Diffusers provides three spacing strategies, shown below for 1,000 training timesteps and 10 inference steps.

| spacing strategy | spacing calculation | example timesteps |
|---|---|---|
| leading | evenly spaced steps | [900, 800, 700, ..., 100, 0] |
| linspace | include first and last steps and evenly divide remaining intermediate steps | [999, 888, 777, ..., 111, 0] |
| trailing | include last step and evenly divide remaining intermediate steps beginning from the end | [999, 899, 799, 699, 599, 499, 399, 299, 199, 99] |
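The three strategies can be reproduced in a few lines of plain Python. This sketch follows my reading of how diffusers schedulers such as DDIMScheduler compute them in `set_timesteps`, assuming steps_offset=0. Treat it as an illustration rather than the library's exact code path. The function names are hypothetical.

```python
def leading_timesteps(num_train_timesteps, num_inference_steps, steps_offset=0):
    # Integer stride counting up from 0, then reversed so high noise comes first.
    step_ratio = num_train_timesteps // num_inference_steps
    return [i * step_ratio + steps_offset for i in range(num_inference_steps)][::-1]


def linspace_timesteps(num_train_timesteps, num_inference_steps):
    # Evenly divide [0, num_train_timesteps - 1], round, then reverse.
    # Assumes num_inference_steps > 1.
    last = num_train_timesteps - 1
    n = num_inference_steps
    return [round(i * last / (n - 1)) for i in range(n)][::-1]


def trailing_timesteps(num_train_timesteps, num_inference_steps):
    # Count down from num_train_timesteps in equal strides, then shift by -1
    # so the first step is the final training timestep.
    step_ratio = num_train_timesteps / num_inference_steps
    return [round(num_train_timesteps - i * step_ratio) - 1 for i in range(num_inference_steps)]


for name, fn in [
    ("leading", leading_timesteps),
    ("linspace", linspace_timesteps),
    ("trailing", trailing_timesteps),
]:
    print(name, fn(1000, 10))
```

With only 5 steps, leading spacing never visits a timestep above 800, whereas trailing spacing starts at 999. This may help explain why trailing spacing tends to matter most at low step counts.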

Pass the spacing strategy to the timestep_spacing argument in the scheduler.

[!TIP] The trailing strategy tends to produce higher-quality, more detailed images at low step counts, but the difference is less noticeable at more typical step counts.

import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

pipeline = DiffusionPipeline.from_pretrained(
    "SG161222/RealVisXL_V4.0", torch_dtype=torch.float16, device_map="cuda"
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
    pipeline.scheduler.config, timestep_spacing="trailing"
)

prompt = "A cinematic shot of a cute little black cat sitting on a pumpkin at night"
image = pipeline(
    prompt=prompt,
    negative_prompt="",
    num_inference_steps=5,
).images[0]
image
[Image comparison: trailing spacing after 5 steps; leading spacing after 5 steps]

Sigmas

Sigma measures how noisy a sample is at a given step of the schedule. When you pass custom sigmas, the timesteps are calculated from these values instead of from the default scheduler configuration.

[!TIP] The sigmas argument is only supported by a select list of schedulers and pipelines. Feel free to open a feature request if you want this argument supported in a scheduler or pipeline that doesn't have it yet!

Pass the custom sigmas to the sigmas argument in the pipeline. The example below uses the sigmas from the 10-step AYS schedule.

import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

pipeline = DiffusionPipeline.from_pretrained(
    "SG161222/RealVisXL_V4.0", torch_dtype=torch.float16, device_map="cuda"
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
    pipeline.scheduler.config, algorithm_type="sde-dpmsolver++"
)

sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0]
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
image = pipeline(
    prompt=prompt,
    negative_prompt="",
    sigmas=sigmas,
).images[0]
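To calculate timesteps from custom sigmas, each sigma has to be mapped back onto the training schedule. A common approach, and roughly what I understand diffusers' DPMSolverMultistepScheduler does internally, works in two steps. First, compute sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t) for every training timestep. Then interpolate in log-sigma space. The sketch assumes the default Stable Diffusion "scaled_linear" betas (0.00085 to 0.012 over 1,000 steps), and the function names are hypothetical. The trailing 0.0 from the AYS list is left out because log(0) is undefined.

```python
import bisect
import math


def train_sigmas(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012):
    """sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t) for an assumed scaled_linear schedule."""
    s, e = math.sqrt(beta_start), math.sqrt(beta_end)
    n = num_train_timesteps
    sigmas, alpha_bar = [], 1.0
    for i in range(n):
        beta = (s + (e - s) * i / (n - 1)) ** 2
        alpha_bar *= 1.0 - beta
        sigmas.append(math.sqrt((1.0 - alpha_bar) / alpha_bar))
    return sigmas  # strictly increasing with t


def sigma_to_timestep(sigma, sigmas):
    """Map a sigma to a (fractional) timestep by interpolating in log-sigma space."""
    log_sigmas = [math.log(s) for s in sigmas]
    target = math.log(sigma)
    # Clamp sigmas outside the training range to the first or last timestep.
    if target <= log_sigmas[0]:
        return 0.0
    if target >= log_sigmas[-1]:
        return float(len(sigmas) - 1)
    hi = bisect.bisect_left(log_sigmas, target)
    lo = hi - 1
    w = (target - log_sigmas[lo]) / (log_sigmas[hi] - log_sigmas[lo])
    return lo + w


sigmas = train_sigmas()
print(round(sigmas[0], 4), round(sigmas[-1], 4))  # smallest and largest training sigma
ays_sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113]
print([round(sigma_to_timestep(s, sigmas)) for s in ays_sigmas])
```

Under this assumed schedule, the largest training sigma comes out at about 14.61. That lines up with the first AYS sigma, 14.615.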

Karras sigmas

Karras sigmas resample the noise schedule for more efficient sampling by clustering sigmas more densely in the middle of the sequence, where structure reconstruction is critical, and using fewer sigmas at the beginning and end, where noise changes have less impact. This can increase the level of detail in a generated image.

Set use_karras_sigmas=True in the scheduler to enable it.
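The Karras schedule itself is a short formula. It interpolates linearly between sigma_max**(1/rho) and sigma_min**(1/rho), then raises the result back to the power rho. The sketch below assumes rho = 7, the value proposed by Karras et al. (2022). It uses sigma_min and sigma_max values that approximate the default Stable Diffusion schedule, which is also an assumption. It is not the diffusers implementation; as far as I know, the library takes these bounds from the scheduler's own sigmas.

```python
def karras_sigmas(sigma_min, sigma_max, num_steps, rho=7.0):
    """Return num_steps sigmas from sigma_max down to sigma_min (Karras et al., 2022).

    Assumes num_steps > 1.
    """
    min_inv_rho = sigma_min ** (1.0 / rho)
    max_inv_rho = sigma_max ** (1.0 / rho)
    sigmas = []
    for i in range(num_steps):
        ramp = i / (num_steps - 1)  # goes from 0.0 to 1.0
        # Linear interpolation in sigma**(1/rho) space, then back to sigma space.
        sigmas.append((max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho)
    return sigmas


# Usage with bounds close to the default Stable Diffusion schedule (an assumption).
print([round(s, 3) for s in karras_sigmas(0.0292, 14.6146, 10)])
```

Raising rho makes the spacing between consecutive sigmas shrink faster as sigma approaches sigma_min. A rho of 1 reduces the formula to plain linear spacing in sigma.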

import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

pipeline = DiffusionPipeline.from_pretrained(
    "SG161222/RealVisXL_V4.0", torch_dtype=torch.float16, device_map="cuda"
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
    pipeline.scheduler.config,
    algorithm_type="sde-dpmsolver++",
    use_karras_sigmas=True,
)

prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
# No custom sigmas are passed here; the scheduler computes the Karras sigmas itself
image = pipeline(
    prompt=prompt,
    negative_prompt="",
).images[0]
[Image comparison: Karras sigmas enabled; Karras sigmas disabled]

Refer to the scheduler API overview for a list of schedulers that support Karras sigmas. Karras sigmas should only be used with models trained with Karras sigmas.

Choosing a scheduler

It's important to try different schedulers to find the best one for your use case. Here are a few recommendations to help you get started.

Resources