Upload wan_i2v.py
Browse files- wan_i2v.py +178 -0
wan_i2v.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.distributed as dist
|
| 3 |
+
|
| 4 |
+
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline,WanTransformer3DModel
|
| 5 |
+
from diffusers.utils import export_to_video, load_image
|
| 6 |
+
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
|
| 7 |
+
|
| 8 |
+
from transformers import CLIPVisionModel, UMT5EncoderModel
|
| 9 |
+
|
| 10 |
+
from para_attn.context_parallel import init_context_parallel_mesh
|
| 11 |
+
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
|
| 12 |
+
|
| 13 |
+
import gc
|
| 14 |
+
import time
|
| 15 |
+
import logging
|
| 16 |
+
from PIL import Image
|
| 17 |
+
|
| 18 |
+
logging.basicConfig(level=logging.INFO)
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
class WanI2V(WanImageToVideoPipeline):
|
| 22 |
+
"""
|
| 23 |
+
Wan Image to Video Pipeline
|
| 24 |
+
Supports 14B and 720P models (Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers)
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self,model_id,apply_cache=True):
|
| 28 |
+
logger.info(f"Initializing WanI2V pipeline with model {model_id}")
|
| 29 |
+
dist.init_process_group()
|
| 30 |
+
torch.cuda.set_device(dist.get_rank())
|
| 31 |
+
start_load_time = time.time()
|
| 32 |
+
self.device = "cuda"
|
| 33 |
+
self.pipe = None
|
| 34 |
+
self.model_id = model_id
|
| 35 |
+
self.dtype = torch.bfloat16
|
| 36 |
+
self.apply_cache = apply_cache
|
| 37 |
+
if self.model_id == "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers":
|
| 38 |
+
self.flow_shift = 3.0
|
| 39 |
+
elif self.model_id == "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers":
|
| 40 |
+
self.flow_shift = 5.0
|
| 41 |
+
|
| 42 |
+
try:
|
| 43 |
+
self.load_model()
|
| 44 |
+
self.optimize_pipe()
|
| 45 |
+
logger.info(f"Pipeline initialized {model_id} in {time.time() - start_load_time} seconds")
|
| 46 |
+
self.warmup()
|
| 47 |
+
except Exception as e:
|
| 48 |
+
logger.error(f"Error initializing {model_id}: {e}")
|
| 49 |
+
raise e
|
| 50 |
+
|
| 51 |
+
def load_model(self):
|
| 52 |
+
self.text_encoder = UMT5EncoderModel.from_pretrained(self.model_id, subfolder="text_encoder", torch_dtype=self.dtype)
|
| 53 |
+
self.vae = AutoencoderKLWan.from_pretrained(self.model_id, subfolder="vae", torch_dtype=torch.float32)
|
| 54 |
+
self.transformer = WanTransformer3DModel.from_pretrained(self.model_id, subfolder="transformer", torch_dtype=self.dtype)
|
| 55 |
+
self.image_encoder = CLIPVisionModel.from_pretrained(self.model_id, subfolder="image_encoder", torch_dtype=torch.float32)
|
| 56 |
+
|
| 57 |
+
self.pipe = WanImageToVideoPipeline.from_pretrained(
|
| 58 |
+
self.model_id,
|
| 59 |
+
vae=self.vae,
|
| 60 |
+
transformer=self.transformer,
|
| 61 |
+
text_encoder=self.text_encoder,
|
| 62 |
+
image_encoder=self.image_encoder,
|
| 63 |
+
torch_dtype=self.dtype
|
| 64 |
+
)
|
| 65 |
+
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config, flow_shift=self.flow_shift)
|
| 66 |
+
self.pipe.to(self.device)
|
| 67 |
+
|
| 68 |
+
def optimize_pipe(self):
|
| 69 |
+
self.pipe.transformer.enable_gradient_checkpointing() # Enable gradient checkpointing
|
| 70 |
+
self.pipe.enable_attention_slicing(slice_size="auto")
|
| 71 |
+
parallelize_pipe(
|
| 72 |
+
self.pipe,
|
| 73 |
+
mesh=init_context_parallel_mesh(
|
| 74 |
+
self.pipe.device.type,
|
| 75 |
+
),
|
| 76 |
+
)
|
| 77 |
+
if self.apply_cache:
|
| 78 |
+
from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
|
| 79 |
+
apply_cache_on_pipe(self.pipe , residual_diff_threshold=0.1)
|
| 80 |
+
|
| 81 |
+
def clear_memory(self):
|
| 82 |
+
torch.cuda.empty_cache()
|
| 83 |
+
gc.collect()
|
| 84 |
+
|
| 85 |
+
def generate_video(self,prompt : str,negative_prompt : str,image : Image.Image,num_frames : int = 81,guidance_scale : float = 5.0,num_inference_steps : int = 30,height : int = 576,width : int = 1024):
|
| 86 |
+
with torch.inference_mode():
|
| 87 |
+
output = self.pipe(
|
| 88 |
+
image=image,
|
| 89 |
+
prompt=prompt,
|
| 90 |
+
negative_prompt=negative_prompt,
|
| 91 |
+
height=height,
|
| 92 |
+
width=width,
|
| 93 |
+
num_frames=num_frames,
|
| 94 |
+
guidance_scale=guidance_scale,
|
| 95 |
+
num_inference_steps=num_inference_steps,
|
| 96 |
+
output_type="pil" if dist.get_rank() == 0 else "pt",
|
| 97 |
+
).frames[0]
|
| 98 |
+
|
| 99 |
+
if dist.get_rank() == 0:
|
| 100 |
+
self.clear_memory()
|
| 101 |
+
if isinstance(output[0], torch.Tensor):
|
| 102 |
+
output = [frame.cpu() if frame.device.type == 'cuda' else frame for frame in output]
|
| 103 |
+
return output
|
| 104 |
+
|
| 105 |
+
def warmup(self):
|
| 106 |
+
logger.info("Running Warm Up!")
|
| 107 |
+
prompt = "A car driving on a road"
|
| 108 |
+
negative_prompt = "blurry, low quality, dark"
|
| 109 |
+
image = load_image("https://storage.googleapis.com/falserverless/gallery/car_720p.png")
|
| 110 |
+
start_time = time.time()
|
| 111 |
+
with torch.inference_mode():
|
| 112 |
+
self.pipe(
|
| 113 |
+
prompt=prompt,
|
| 114 |
+
negative_prompt=negative_prompt,
|
| 115 |
+
image=image,
|
| 116 |
+
num_inference_steps=30,
|
| 117 |
+
height=576,
|
| 118 |
+
width=1024,
|
| 119 |
+
num_frames=81,
|
| 120 |
+
guidance_scale=5.0
|
| 121 |
+
)
|
| 122 |
+
self.get_matrix(start_time,time.time(),576,1024)
|
| 123 |
+
logger.info("Warm Up Completed!")
|
| 124 |
+
self.clear_memory()
|
| 125 |
+
|
| 126 |
+
def shutdown(self):
|
| 127 |
+
dist.destroy_process_group()
|
| 128 |
+
|
| 129 |
+
def get_matrix(self,start_time : int,end_time : int,height : int,width : int):
|
| 130 |
+
logger.info("-"*20)
|
| 131 |
+
logger.info(f"inference time : {end_time - start_time}")
|
| 132 |
+
logger.info(f"height : {height}")
|
| 133 |
+
logger.info(f"width : {width}")
|
| 134 |
+
logger.info("-"*20)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
if __name__ == "__main__":
|
| 139 |
+
inputs = {
|
| 140 |
+
"1" : {
|
| 141 |
+
"prompt" : "Cars racing in slow motion",
|
| 142 |
+
"negative_prompt" : "bright colors, overexposed, static, blurred details, subtitles, style, artwork, painting, picture, still, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, malformed limbs, fused fingers, still picture, cluttered background, three legs, many people in the background, walking backwards",
|
| 143 |
+
"image" : load_image("https://storage.googleapis.com/falserverless/gallery/car_720p.png")
|
| 144 |
+
},
|
| 145 |
+
"2" : {
|
| 146 |
+
"prompt" : "A cat in a car",
|
| 147 |
+
"negative_prompt" : "bright colors, overexposed, static, blurred details, subtitles, style, artwork, painting, picture, still, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, malformed limbs, fused fingers, still picture, cluttered background, three legs, many people in the background, walking backwards",
|
| 148 |
+
"image" : load_image("https://fancypawscatclinic.com/uploads/SiteAssets/426/images/services/payment-options-cat-720px.jpg")
|
| 149 |
+
}
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
RESOLUTION_CONFIG = {
|
| 153 |
+
"Horizontal" : {
|
| 154 |
+
"height" : 576,
|
| 155 |
+
"width" : 1024
|
| 156 |
+
},
|
| 157 |
+
"Vertical" : {
|
| 158 |
+
"height" : 1024,
|
| 159 |
+
"width" : 576
|
| 160 |
+
},
|
| 161 |
+
"Square" : {
|
| 162 |
+
"height" : 768,
|
| 163 |
+
"width" : 768
|
| 164 |
+
}
|
| 165 |
+
}
|
| 166 |
+
WanI2V = WanI2V("Wan-AI/Wan2.1-I2V-14B-480P-Diffusers",apply_cache=True)
|
| 167 |
+
resolution = "Horizontal"
|
| 168 |
+
|
| 169 |
+
for i in range(0,len(inputs)):
|
| 170 |
+
for j in range(0,len(inputs[str(i+1)])):
|
| 171 |
+
prompt = inputs[str(i+1)]["prompt"]
|
| 172 |
+
negative_prompt = inputs[str(i+1)]["negative_prompt"]
|
| 173 |
+
image = inputs[str(i+1)]["image"]
|
| 174 |
+
start_time = time.time()
|
| 175 |
+
WanI2V.generate_video(prompt,negative_prompt,image)
|
| 176 |
+
end_time = time.time()
|
| 177 |
+
WanI2V.get_matrix(start_time,end_time,RESOLUTION_CONFIG[resolution]["height"],RESOLUTION_CONFIG[resolution]["width"])
|
| 178 |
+
WanI2V.shutdown()
|