| | import argparse |
| | import datetime |
| | import inspect |
| | import os |
| | from omegaconf import OmegaConf |
| |
|
| | import torch |
| |
|
| | import diffusers |
| | from diffusers import AutoencoderKL, DDIMScheduler |
| |
|
| | from tqdm.auto import tqdm |
| | from transformers import CLIPTextModel, CLIPTokenizer |
| |
|
| | from animatediff.models.unet import UNet3DConditionModel |
| | from animatediff.pipelines.pipeline_animation import AnimationPipeline |
| | from animatediff.utils.util import save_videos_grid |
| | from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint |
| | from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora |
| | from diffusers.utils.import_utils import is_xformers_available |
| |
|
| | from einops import rearrange, repeat |
| |
|
| | import csv, pdb, glob |
| | from safetensors import safe_open |
| | import math |
| | from pathlib import Path |
| |
|
| |
|
| | def main(args): |
| | *_, func_args = inspect.getargvalues(inspect.currentframe()) |
| | func_args = dict(func_args) |
| | |
| | time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| | savedir = f"samples/{Path(args.config).stem}-{time_str}" |
| | os.makedirs(savedir) |
| | inference_config = OmegaConf.load(args.inference_config) |
| |
|
| | config = OmegaConf.load(args.config) |
| | samples = [] |
| | |
| | sample_idx = 0 |
| | for model_idx, (config_key, model_config) in enumerate(list(config.items())): |
| | |
| | motion_modules = model_config.motion_module |
| | motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules) |
| | for motion_module in motion_modules: |
| | |
| | |
| | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer") |
| | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder") |
| | vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae") |
| | unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) |
| |
|
| | if is_xformers_available(): unet.enable_xformers_memory_efficient_attention() |
| | else: assert False |
| |
|
| | pipeline = AnimationPipeline( |
| | vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, |
| | scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), |
| | ).to("cuda") |
| |
|
| | |
| | |
| | motion_module_state_dict = torch.load(motion_module, map_location="cpu") |
| | if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]}) |
| | missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False) |
| | assert len(unexpected) == 0 |
| | |
| | |
| | if model_config.path != "": |
| | if model_config.path.endswith(".ckpt"): |
| | state_dict = torch.load(model_config.path) |
| | pipeline.unet.load_state_dict(state_dict) |
| | |
| | elif model_config.path.endswith(".safetensors"): |
| | state_dict = {} |
| | with safe_open(model_config.path, framework="pt", device="cpu") as f: |
| | for key in f.keys(): |
| | state_dict[key] = f.get_tensor(key) |
| | |
| | is_lora = all("lora" in k for k in state_dict.keys()) |
| | if not is_lora: |
| | base_state_dict = state_dict |
| | else: |
| | base_state_dict = {} |
| | with safe_open(model_config.base, framework="pt", device="cpu") as f: |
| | for key in f.keys(): |
| | base_state_dict[key] = f.get_tensor(key) |
| | |
| | |
| | converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config) |
| | pipeline.vae.load_state_dict(converted_vae_checkpoint) |
| | |
| | converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config) |
| | pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) |
| | |
| | pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict) |
| | |
| | |
| | |
| | if is_lora: |
| | pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha) |
| |
|
| | pipeline.to("cuda") |
| | |
| |
|
| | prompts = model_config.prompt |
| | n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt |
| | |
| | random_seeds = model_config.get("seed", [-1]) |
| | random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds) |
| | random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds |
| | |
| | config[config_key].random_seed = [] |
| | for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)): |
| | |
| | |
| | if random_seed != -1: torch.manual_seed(random_seed) |
| | else: torch.seed() |
| | config[config_key].random_seed.append(torch.initial_seed()) |
| | |
| | print(f"current seed: {torch.initial_seed()}") |
| | print(f"sampling {prompt} ...") |
| | sample = pipeline( |
| | prompt, |
| | negative_prompt = n_prompt, |
| | num_inference_steps = model_config.steps, |
| | guidance_scale = model_config.guidance_scale, |
| | width = args.W, |
| | height = args.H, |
| | video_length = args.L, |
| | ).videos |
| | samples.append(sample) |
| |
|
| | prompt = "-".join((prompt.replace("/", "").split(" ")[:10])) |
| | save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif") |
| | print(f"save to {savedir}/sample/{prompt}.gif") |
| | |
| | sample_idx += 1 |
| |
|
| | samples = torch.concat(samples) |
| | save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4) |
| |
|
| | OmegaConf.save(config, f"{savedir}/config.yaml") |
| |
|
| |
|
| | if __name__ == "__main__": |
| | parser = argparse.ArgumentParser() |
| | parser.add_argument("--pretrained_model_path", type=str, default="models/StableDiffusion/stable-diffusion-v1-5",) |
| | parser.add_argument("--inference_config", type=str, default="configs/inference/inference.yaml") |
| | parser.add_argument("--config", type=str, required=True) |
| | |
| | parser.add_argument("--L", type=int, default=16 ) |
| | parser.add_argument("--W", type=int, default=512) |
| | parser.add_argument("--H", type=int, default=512) |
| |
|
| | args = parser.parse_args() |
| | main(args) |
| |
|