Varad1707 committed on
Commit
95bb22a
·
verified ·
1 Parent(s): 25eb67d

Upload wan_i2v.py

Browse files
Files changed (1) hide show
  1. wan_i2v.py +178 -0
wan_i2v.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+
4
+ from diffusers import AutoencoderKLWan, WanImageToVideoPipeline,WanTransformer3DModel
5
+ from diffusers.utils import export_to_video, load_image
6
+ from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
7
+
8
+ from transformers import CLIPVisionModel, UMT5EncoderModel
9
+
10
+ from para_attn.context_parallel import init_context_parallel_mesh
11
+ from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
12
+
13
+ import gc
14
+ import time
15
+ import logging
16
+ from PIL import Image
17
+
18
+ logging.basicConfig(level=logging.INFO)
19
+ logger = logging.getLogger(__name__)
20
+
21
class WanI2V(WanImageToVideoPipeline):
    """Wan image-to-video generation pipeline wrapper.

    Supports the 14B 480P and 720P Diffusers checkpoints
    (Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers).
    """

    # Recommended UniPC `flow_shift` per supported checkpoint
    # (3.0 for the 480P model, 5.0 for the 720P model).
    _FLOW_SHIFTS = {
        "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers": 3.0,
        "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers": 5.0,
    }

    def __init__(self, model_id, apply_cache=True):
        """Initialize the distributed process group, then load and optimize the model.

        Args:
            model_id: Hugging Face model id of a supported Wan I2V checkpoint.
            apply_cache: Apply para_attn first-block caching to the pipeline.

        Raises:
            ValueError: If ``model_id`` is not one of the supported checkpoints.
        """
        logger.info("Initializing WanI2V pipeline with model %s", model_id)
        dist.init_process_group()
        torch.cuda.set_device(dist.get_rank())
        start_load_time = time.time()
        self.device = "cuda"
        self.pipe = None
        self.model_id = model_id
        self.dtype = torch.bfloat16
        self.apply_cache = apply_cache
        # Bug fix: the original only assigned flow_shift for the two known
        # checkpoints, deferring failure for any other model_id to
        # load_model() as an opaque AttributeError. Fail fast instead.
        try:
            self.flow_shift = self._FLOW_SHIFTS[model_id]
        except KeyError:
            raise ValueError(
                f"Unsupported model_id {model_id!r}; expected one of "
                f"{sorted(self._FLOW_SHIFTS)}"
            ) from None

        try:
            self.load_model()
            self.optimize_pipe()
            logger.info(
                "Pipeline initialized %s in %s seconds",
                model_id,
                time.time() - start_load_time,
            )
            self.warmup()
        except Exception:
            # logger.exception records the full traceback, unlike logger.error.
            logger.exception("Error initializing %s", model_id)
            raise  # bare raise preserves the original traceback

    def load_model(self):
        """Load all model components and assemble the Diffusers pipeline."""
        self.text_encoder = UMT5EncoderModel.from_pretrained(
            self.model_id, subfolder="text_encoder", torch_dtype=self.dtype
        )
        # VAE and image encoder are loaded in fp32 while the text encoder and
        # transformer use bf16, mirroring the checkpoint's mixed-precision setup.
        self.vae = AutoencoderKLWan.from_pretrained(
            self.model_id, subfolder="vae", torch_dtype=torch.float32
        )
        self.transformer = WanTransformer3DModel.from_pretrained(
            self.model_id, subfolder="transformer", torch_dtype=self.dtype
        )
        self.image_encoder = CLIPVisionModel.from_pretrained(
            self.model_id, subfolder="image_encoder", torch_dtype=torch.float32
        )

        self.pipe = WanImageToVideoPipeline.from_pretrained(
            self.model_id,
            vae=self.vae,
            transformer=self.transformer,
            text_encoder=self.text_encoder,
            image_encoder=self.image_encoder,
            torch_dtype=self.dtype,
        )
        # Swap in UniPC with the checkpoint-appropriate flow shift.
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(
            self.pipe.scheduler.config, flow_shift=self.flow_shift
        )
        self.pipe.to(self.device)

    def optimize_pipe(self):
        """Apply memory and parallelism optimizations to the loaded pipeline."""
        # NOTE(review): gradient checkpointing only affects training; it is a
        # no-op for pure inference. Kept for parity with the original setup.
        self.pipe.transformer.enable_gradient_checkpointing()
        self.pipe.enable_attention_slicing(slice_size="auto")
        # Shard attention across ranks via para_attn context parallelism.
        parallelize_pipe(
            self.pipe,
            mesh=init_context_parallel_mesh(self.pipe.device.type),
        )
        if self.apply_cache:
            # Imported lazily so the dependency is only required when caching
            # is actually enabled.
            from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

            apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.1)

    def clear_memory(self):
        """Release cached CUDA memory and run garbage collection."""
        torch.cuda.empty_cache()
        gc.collect()

    def generate_video(
        self,
        prompt: str,
        negative_prompt: str,
        image: Image.Image,
        num_frames: int = 81,
        guidance_scale: float = 5.0,
        num_inference_steps: int = 30,
        height: int = 576,
        width: int = 1024,
    ):
        """Generate a video conditioned on a text prompt and a reference image.

        Args:
            prompt: Text description of the desired video.
            negative_prompt: Text describing what to avoid.
            image: Conditioning reference image.
            num_frames: Number of video frames to generate.
            guidance_scale: Classifier-free guidance strength.
            num_inference_steps: Number of denoising steps.
            height: Output height in pixels.
            width: Output width in pixels.

        Returns:
            The generated frames: PIL images on rank 0, CPU tensors on
            the other ranks (which request ``output_type="pt"``).
        """
        with torch.inference_mode():
            output = self.pipe(
                image=image,
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                num_frames=num_frames,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                # Only rank 0 needs decodable PIL frames; other ranks keep
                # tensors to skip the PIL conversion.
                output_type="pil" if dist.get_rank() == 0 else "pt",
            ).frames[0]

        if dist.get_rank() == 0:
            self.clear_memory()
        # Move tensor frames off the GPU before handing them to the caller.
        if isinstance(output[0], torch.Tensor):
            output = [
                frame.cpu() if frame.device.type == "cuda" else frame
                for frame in output
            ]
        return output

    def warmup(self):
        """Run one throwaway generation to warm up CUDA kernels and caches."""
        logger.info("Running Warm Up!")
        prompt = "A car driving on a road"
        negative_prompt = "blurry, low quality, dark"
        image = load_image("https://storage.googleapis.com/falserverless/gallery/car_720p.png")
        start_time = time.time()
        with torch.inference_mode():
            self.pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=image,
                num_inference_steps=30,
                height=576,
                width=1024,
                num_frames=81,
                guidance_scale=5.0,
            )
        self.get_matrix(start_time, time.time(), 576, 1024)
        logger.info("Warm Up Completed!")
        self.clear_memory()

    def shutdown(self):
        """Tear down the torch.distributed process group."""
        dist.destroy_process_group()

    def get_matrix(self, start_time: float, end_time: float, height: int, width: int):
        """Log inference-time metrics: wall-clock duration and resolution.

        Args:
            start_time: Epoch seconds at the start of inference.
            end_time: Epoch seconds at the end of inference.
            height: Output height in pixels.
            width: Output width in pixels.
        """
        logger.info("-" * 20)
        logger.info("inference time : %s", end_time - start_time)
        logger.info("height : %s", height)
        logger.info("width : %s", width)
        logger.info("-" * 20)
138
if __name__ == "__main__":
    # Shared negative prompt for every test case (identical in the original,
    # duplicated verbatim per case there).
    NEGATIVE_PROMPT = "bright colors, overexposed, static, blurred details, subtitles, style, artwork, painting, picture, still, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, malformed limbs, fused fingers, still picture, cluttered background, three legs, many people in the background, walking backwards"

    inputs = {
        "1": {
            "prompt": "Cars racing in slow motion",
            "negative_prompt": NEGATIVE_PROMPT,
            "image": load_image("https://storage.googleapis.com/falserverless/gallery/car_720p.png"),
        },
        "2": {
            "prompt": "A cat in a car",
            "negative_prompt": NEGATIVE_PROMPT,
            "image": load_image("https://fancypawscatclinic.com/uploads/SiteAssets/426/images/services/payment-options-cat-720px.jpg"),
        },
    }

    RESOLUTION_CONFIG = {
        "Horizontal": {"height": 576, "width": 1024},
        "Vertical": {"height": 1024, "width": 576},
        "Square": {"height": 768, "width": 768},
    }

    # Bug fix: the original bound the instance to the name `WanI2V`,
    # shadowing the class it had just defined.
    pipeline = WanI2V("Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", apply_cache=True)
    resolution = "Horizontal"
    height = RESOLUTION_CONFIG[resolution]["height"]
    width = RESOLUTION_CONFIG[resolution]["width"]

    # Bug fix: the original's inner loop iterated over the three keys of each
    # case dict, re-running every generation three times. Run each case once.
    for case in inputs.values():
        start_time = time.time()
        # Bug fix: pass the selected resolution to generate_video; the
        # original generated at the method defaults while logging the
        # RESOLUTION_CONFIG values, so the logged resolution could differ
        # from what was actually rendered.
        pipeline.generate_video(
            case["prompt"],
            case["negative_prompt"],
            case["image"],
            height=height,
            width=width,
        )
        end_time = time.time()
        pipeline.get_matrix(start_time, end_time, height, width)
    pipeline.shutdown()