GitHub Repository: ninjaneural/webui
Path: blob/master/misc/animatediff/cli.py
import logging
from datetime import datetime
from pathlib import Path
from typing import Annotated, Optional

import torch
import typer
from diffusers.utils.logging import set_verbosity_error as set_diffusers_verbosity_error
from rich.logging import RichHandler

from animatediff import __version__, console, get_dir
from animatediff.generate import create_pipeline, run_inference
from animatediff.pipelines import AnimationPipeline, load_text_embeddings
from animatediff.settings import (
    CKPT_EXTENSIONS,
    InferenceConfig,
    ModelConfig,
    get_infer_config,
    get_model_config,
)
from animatediff.utils.model import checkpoint_to_pipeline, get_base_model
from animatediff.utils.pipeline import get_context_params, send_to_device
from animatediff.utils.util import path_from_cwd, save_frames, save_video

cli: typer.Typer = typer.Typer(
    context_settings=dict(help_option_names=["-h", "--help"]),
    rich_markup_mode="rich",
    no_args_is_help=True,
    pretty_exceptions_show_locals=False,
)
data_dir = get_dir("data")
checkpoint_dir = data_dir.joinpath("models/sd")
pipeline_dir = data_dir.joinpath("models/huggingface")
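
# Example invocation (a sketch; assumes the package exposes an `animatediff`
# console script -- adjust to `python -m animatediff` if it does not):
#   animatediff generate -c config/prompts/01-ToonYou.json -W 512 -H 512 -L 16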

try:
    import google.colab

    IN_COLAB = True
except ImportError:
    IN_COLAB = False
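
# (assumed rationale: Rich's live console rendering tends to garble notebook
# output capture, so Colab falls back to plain stdout logging below)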

if IN_COLAB:
    import sys

    logging.basicConfig(
        level=logging.INFO,
        stream=sys.stdout,
        format="%(message)s",
        datefmt="%H:%M:%S",
        force=True,
    )
else:
    logging.basicConfig(
        level=logging.INFO,
        format="%(message)s",
        handlers=[
            RichHandler(console=console, rich_tracebacks=True),
        ],
        datefmt="%H:%M:%S",
        force=True,
    )

logger = logging.getLogger(__name__)


try:
    from animatediff.rife import app as rife_app

    cli.add_typer(rife_app, name="rife")
except ImportError:
    logger.debug("RIFE not available, skipping...", exc_info=True)
    rife_app = None

# mildly cursed globals to allow for reuse of the pipeline if we're being called as a module
pipeline: Optional[AnimationPipeline] = None
last_model_path: Optional[Path] = None
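
# Reuse sketch (an assumption about intended use, not part of this file): when
# imported as a module, calling generate() twice with the same base model skips
# the reload because `pipeline` and `last_model_path` persist between calls:
#
#   from animatediff.cli import generate
#   generate(config_path=Path("config/prompts/01-ToonYou.json"), length=16)
#   generate(config_path=Path("config/prompts/01-ToonYou.json"), length=32)  # reuses pipeline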


def version_callback(value: bool):
    if value:
        console.print(f"AnimateDiff v{__version__}")
        raise typer.Exit()


@cli.command()
def generate(
    model_name_or_path: Annotated[
        Path,
        typer.Option(
            ...,
            "--model-path",
            "-m",
            path_type=Path,
            help="Base model to use (path or HF repo ID). You probably don't need to change this.",
        ),
    ] = Path("runwayml/stable-diffusion-v1-5"),
    config_path: Annotated[
        Path,
        typer.Option(
            "--config-path",
            "-c",
            path_type=Path,
            exists=True,
            readable=True,
            dir_okay=False,
            help="Path to a prompt configuration JSON file",
        ),
    ] = Path("config/prompts/01-ToonYou.json"),
    width: Annotated[
        int,
        typer.Option(
            "--width",
            "-W",
            min=512,
            max=3840,
            help="Width of generated frames",
            rich_help_panel="Generation",
        ),
    ] = 512,
    height: Annotated[
        int,
        typer.Option(
            "--height",
            "-H",
            min=512,
            max=2160,
            help="Height of generated frames",
            rich_help_panel="Generation",
        ),
    ] = 512,
    length: Annotated[
        int,
        typer.Option(
            "--length",
            "-L",
            min=1,
            max=999,
            help="Number of frames to generate",
            rich_help_panel="Generation",
        ),
    ] = 16,
    context: Annotated[
        Optional[int],
        typer.Option(
            "--context",
            "-C",
            min=1,
            max=24,
            help="Number of frames to condition on (default: max of <length> or 24)",
            show_default=False,
            rich_help_panel="Generation",
        ),
    ] = None,
    overlap: Annotated[
        Optional[int],
        typer.Option(
            "--overlap",
            "-O",
            min=1,
            max=12,
            help="Number of frames to overlap in context (default: context//2)",
            show_default=False,
            rich_help_panel="Generation",
        ),
    ] = None,
    stride: Annotated[
        Optional[int],
        typer.Option(
            "--stride",
            "-S",
            min=1,
            max=8,
            help="Max motion stride as a power of 2 (default: 4)",
            show_default=False,
            rich_help_panel="Generation",
        ),
    ] = None,
    repeats: Annotated[
        int,
        typer.Option(
            "--repeats",
            "-r",
            min=1,
            max=99,
            help="Number of times to repeat the prompt (default: 1)",
            show_default=False,
            rich_help_panel="Generation",
        ),
    ] = 1,
    device: Annotated[
        str,
        typer.Option("--device", "-d", help="Device to run on (cpu, cuda, cuda:id)", rich_help_panel="Advanced"),
    ] = "cuda",
    use_xformers: Annotated[
        bool,
        typer.Option(
            "--xformers",
            "-x",
            is_flag=True,
            help="Use XFormers instead of SDP Attention",
            rich_help_panel="Advanced",
        ),
    ] = False,
    force_half_vae: Annotated[
        bool,
        typer.Option(
            "--half-vae",
            is_flag=True,
            help="Force VAE to use fp16 (not recommended)",
            rich_help_panel="Advanced",
        ),
    ] = False,
    out_dir: Annotated[
        Path,
        typer.Option(
            "--out-dir",
            "-o",
            path_type=Path,
            file_okay=False,
            help="Directory for output folders (frames, gifs, etc)",
            rich_help_panel="Output",
        ),
    ] = Path("output/"),
    no_frames: Annotated[
        bool,
        typer.Option(
            "--no-frames",
            "-N",
            is_flag=True,
            help="Don't save frames, only the animation",
            rich_help_panel="Output",
        ),
    ] = False,
    save_merged: Annotated[
        bool,
        typer.Option(
            "--save-merged",
            "-m",
            is_flag=True,
            help="Save a merged animation of all prompts",
            rich_help_panel="Output",
        ),
    ] = False,
    version: Annotated[
        Optional[bool],
        typer.Option(
            "--version",
            "-v",
            callback=version_callback,
            is_eager=True,
            is_flag=True,
            help="Show version",
        ),
    ] = None,
):
    """
    Do the thing. Make the animation happen. Waow.
    """

    # be quiet, diffusers. we care not for your safety checker
    set_diffusers_verbosity_error()

    config_path = config_path.absolute()
    logger.info(f"Using generation config: {path_from_cwd(config_path)}")
    model_config: ModelConfig = get_model_config(config_path)
    infer_config: InferenceConfig = get_infer_config()

    # set sane defaults for context, overlap, and stride if not supplied
    context, overlap, stride = get_context_params(length, context, overlap, stride)
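    # (assumed defaults, inferred from the option help text above: context is
    # derived from length and capped at 24, overlap falls back to context // 2,
    # and stride to 4 -- get_context_params is the source of truth here)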

    # turn the device string into a torch.device
    device: torch.device = torch.device(device)

    # Get the base model if we don't have it already
    logger.info(f"Using base model: {model_name_or_path}")
    base_model_path: Path = get_base_model(model_name_or_path, local_dir=get_dir("data/models/huggingface"))

    # get a timestamp for the output directory
    time_str = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    # make the output directory
    save_dir = out_dir.joinpath(f"{time_str}-{model_config.save_name}")
    save_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Will save outputs to ./{path_from_cwd(save_dir)}")

    # beware the pipeline
    global pipeline
    global last_model_path
    if pipeline is None or last_model_path != model_config.base.resolve():
        pipeline = create_pipeline(
            base_model=base_model_path,
            model_config=model_config,
            infer_config=infer_config,
            use_xformers=use_xformers,
        )
        last_model_path = model_config.base.resolve()
    else:
        logger.info("Pipeline already loaded, skipping initialization")
        # reload TIs; create_pipeline does this for us, but they may have changed
        # since load time if we're being called from another package
        load_text_embeddings(pipeline)

    if pipeline.device == device:
        logger.info("Pipeline already on the correct device, skipping device transfer")
    else:
        pipeline = send_to_device(pipeline, device, freeze=True, force_half=force_half_vae, compile=model_config.compile)

    # save config to output directory
    logger.info("Saving prompt config to output directory")
    save_config_path = save_dir.joinpath("prompt.json")
    save_config_path.write_text(model_config.json(), encoding="utf-8")

    num_prompts = len(model_config.prompt)
    num_negatives = len(model_config.n_prompt)
    num_seeds = len(model_config.seed)
    gen_total = num_prompts * repeats  # total number of generations

    logger.info("Initialization complete!")
    logger.info(f"Generating {gen_total} animations from {num_prompts} prompts")
    outputs = []

    gen_num = 0  # global generation index
    # repeat the prompts if we're doing multiple runs
    for _ in range(repeats):
        for prompt in model_config.prompt:
            # get the index of the prompt, negative, and seed
            idx = gen_num % num_prompts
            logger.info(f"Running generation {gen_num + 1} of {gen_total} (prompt {idx + 1})")

            # allow for reusing the same negative prompt(s) and seed(s) for multiple prompts
            n_prompt = model_config.n_prompt[idx % num_negatives]
            seed = model_config.seed[idx % num_seeds]

            # duplicated in run_inference, but this lets us use it for frame save dirs
            # TODO: Move GIF output out of run_inference...
            if seed == -1:
                seed = torch.seed()
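                # (torch.seed() re-seeds torch's RNG non-deterministically and
                # returns the new 64-bit seed, so it can be logged below)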
            logger.info(f"Generation seed: {seed}")

            output = run_inference(
                pipeline=pipeline,
                prompt=prompt,
                n_prompt=n_prompt,
                seed=seed,
                steps=model_config.steps,
                guidance_scale=model_config.guidance_scale,
                width=width,
                height=height,
                duration=length,
                idx=gen_num,
                out_dir=save_dir,
                context_frames=context,
                context_overlap=overlap,
                context_stride=stride,
                clip_skip=model_config.clip_skip,
            )
            outputs.append(output)
            torch.cuda.empty_cache()
            if not no_frames:
                save_frames(output, save_dir.joinpath(f"{gen_num:02d}-{seed}"))

            # increment the generation number
            gen_num += 1

    logger.info("Generation complete!")
    if save_merged:
        logger.info("Saving merged output video...")
        merged_output = torch.concat(outputs, dim=0)
        save_video(merged_output, save_dir.joinpath("final.gif"))

    logger.info("Done, exiting...")

    return save_dir


@cli.command()
def convert(
    checkpoint: Annotated[
        Path,
        typer.Option(
            "--checkpoint",
            "-i",
            path_type=Path,
            dir_okay=False,
            exists=True,
            help="Path to a model checkpoint file",
        ),
    ] = ...,
    out_dir: Annotated[
        Optional[Path],
        typer.Option(
            "--out-dir",
            "-o",
            path_type=Path,
            file_okay=False,
            help="Target directory for converted model",
        ),
    ] = None,
):
    """Convert a StableDiffusion checkpoint into a Diffusers pipeline"""
    logger.info(f"Converting checkpoint: {checkpoint}")
    _, pipeline_dir = checkpoint_to_pipeline(checkpoint, target_dir=out_dir)
    logger.info(f"Converted to HuggingFace pipeline at {pipeline_dir}")


@cli.command()
def merge(
    checkpoint: Annotated[
        Path,
        typer.Option(
            "--checkpoint",
            "-i",
            path_type=Path,
            dir_okay=False,
            exists=True,
            help="Path to a model checkpoint file",
        ),
    ] = ...,
    out_dir: Annotated[
        Optional[Path],
        typer.Option(
            "--out-dir",
            "-o",
            path_type=Path,
            file_okay=False,
            help="Target directory for converted model",
        ),
    ] = None,
):
    """Convert a StableDiffusion checkpoint into an AnimationPipeline"""
    raise NotImplementedError("Sorry, haven't implemented this yet!")

    # if we have a checkpoint, convert it to HF automagically
    if checkpoint.is_file() and checkpoint.suffix in CKPT_EXTENSIONS:
        logger.info(f"Loading model from checkpoint: {checkpoint}")
        # check if we've already converted this model
        model_dir = pipeline_dir.joinpath(checkpoint.stem)
        if model_dir.joinpath("model_index.json").exists():
            # we have, so just use that
            logger.info(f"Found converted model in {model_dir}, will not convert")
            logger.info("Delete the output directory to re-run conversion.")
        else:
            # we haven't, so convert it
            logger.info("Converting checkpoint to HuggingFace pipeline...")
            pipeline, model_dir = checkpoint_to_pipeline(checkpoint)
    logger.info("Done!")