BiliSakura commited on
Commit
1b57e93
·
verified ·
1 Parent(s): e63da7d

Delete SiT-XL-2-256/scheduler/scheduling_flow_match_sit.py

Browse files
SiT-XL-2-256/scheduler/scheduling_flow_match_sit.py DELETED
@@ -1,98 +0,0 @@
1
- from dataclasses import dataclass
2
- from typing import Optional, Tuple, Union
3
-
4
- import torch
5
-
6
- from diffusers.configuration_utils import ConfigMixin, register_to_config
7
- from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
8
- from diffusers.utils import BaseOutput
9
-
10
-
11
- @dataclass
12
- class SiTFlowMatchSchedulerOutput(BaseOutput):
13
- prev_sample: torch.Tensor
14
-
15
-
16
- class SiTFlowMatchScheduler(SchedulerMixin, ConfigMixin):
17
- _compatibles = [e.name for e in KarrasDiffusionSchedulers]
18
- order = 1
19
-
20
- @register_to_config
21
- def __init__(
22
- self,
23
- mode: str = "ode",
24
- num_train_timesteps: int = 1000,
25
- shift: float = 1.0,
26
- diffusion_form: str = "sigma",
27
- diffusion_norm: float = 1.0,
28
- ):
29
- self.timesteps = None
30
- self.sigmas = None
31
- self._step_index = None
32
-
33
- def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
34
- # Flow matching integrates from noise (t=0) to data (t=1).
35
- ts = torch.linspace(0.0, 1.0, num_inference_steps + 1, device=device, dtype=torch.float32)
36
- self.timesteps = ts[:-1]
37
- self.sigmas = 1.0 - self.timesteps
38
- self._step_index = 0
39
- return self.timesteps
40
-
41
- def scale_model_input(self, sample: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
42
- return sample
43
-
44
- def _diffusion(self, t: torch.Tensor) -> torch.Tensor:
45
- form = self.config.diffusion_form
46
- norm = self.config.diffusion_norm
47
- if form == "constant":
48
- return torch.full_like(t, norm)
49
- if form == "sigma":
50
- return norm * (1.0 - t)
51
- if form == "linear":
52
- return norm * (1.0 - t)
53
- if form == "decreasing":
54
- return 0.25 * (norm * torch.cos(torch.pi * t) + 1) ** 2
55
- if form == "increasing-decreasing":
56
- return norm * torch.sin(torch.pi * t) ** 2
57
- # "SBDM" approximated with sigma-based schedule for compatibility.
58
- return norm * (1.0 - t)
59
-
60
- def step(
61
- self,
62
- model_output: torch.Tensor,
63
- timestep: Union[float, torch.Tensor],
64
- sample: torch.Tensor,
65
- generator: Optional[torch.Generator] = None,
66
- return_dict: bool = True,
67
- ) -> Union[SiTFlowMatchSchedulerOutput, Tuple[torch.Tensor]]:
68
- if self.timesteps is None:
69
- raise ValueError("Call `set_timesteps` before `step`.")
70
- if self._step_index is None:
71
- self._step_index = 0
72
-
73
- step_index = min(self._step_index, len(self.timesteps) - 1)
74
- t = self.timesteps[step_index].to(sample.device)
75
- next_t = 1.0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1].to(sample.device)
76
- dt = next_t - t
77
-
78
- prev_sample = sample + model_output * dt
79
- if self.config.mode.lower() == "sde":
80
- diffusion = self._diffusion(torch.full((sample.shape[0],), t, device=sample.device, dtype=sample.dtype))
81
- while diffusion.dim() < sample.dim():
82
- diffusion = diffusion.unsqueeze(-1)
83
- noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
84
- prev_sample = prev_sample + torch.sqrt(torch.clamp(2.0 * diffusion * torch.abs(dt), min=0.0)) * noise
85
-
86
- self._step_index += 1
87
- if not return_dict:
88
- return (prev_sample,)
89
- return SiTFlowMatchSchedulerOutput(prev_sample=prev_sample)
90
-
91
- def add_noise(
92
- self,
93
- original_samples: torch.Tensor,
94
- noise: torch.Tensor,
95
- timesteps: torch.Tensor,
96
- ) -> torch.Tensor:
97
- sigma = (1.0 - timesteps).view(-1, *([1] * (original_samples.ndim - 1)))
98
- return (1 - sigma) * original_samples + sigma * noise