BiliSakura commited on
Commit
af7c4b8
·
verified ·
1 Parent(s): 1071e0d

Delete JiT-L-32/scheduler/scheduling_jit.py

Browse files
Files changed (1) hide show
  1. JiT-L-32/scheduler/scheduling_jit.py +0 -161
JiT-L-32/scheduler/scheduling_jit.py DELETED
@@ -1,161 +0,0 @@
1
- # Copyright 2026 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from dataclasses import dataclass
16
- from typing import List, Optional, Tuple, Union
17
-
18
- import torch
19
-
20
- from diffusers.configuration_utils import ConfigMixin, register_to_config
21
- from diffusers.schedulers.scheduling_utils import SchedulerMixin
22
- from diffusers.utils import BaseOutput
23
-
24
-
25
- @dataclass
26
- class JiTSchedulerOutput(BaseOutput):
27
- """
28
- Output class for the JiT scheduler's `step` function.
29
-
30
- Args:
31
- prev_sample (`torch.Tensor`):
32
- Updated sample after one solver step along the JiT flow-time grid.
33
- """
34
-
35
- prev_sample: torch.Tensor
36
-
37
-
38
- class JiTScheduler(SchedulerMixin, ConfigMixin):
39
- """
40
- Manual flow-matching scheduler for JiT checkpoints.
41
-
42
- Uses a linear flow-time grid `t in [0, 1]` (increasing), matching the official JiT
43
- sampler. Velocity is `v = (x_pred - z) / (1 - t)`; integration is explicit Euler or
44
- Heun along that grid.
45
- """
46
-
47
- order = 2
48
-
49
- @register_to_config
50
- def __init__(
51
- self,
52
- num_train_timesteps: int = 1000,
53
- t_eps: float = 5e-2,
54
- solver: str = "heun",
55
- ):
56
- if solver not in {"heun", "euler"}:
57
- raise ValueError("solver must be one of: 'heun', 'euler'.")
58
- self.timesteps: Optional[torch.Tensor] = None
59
- self.sigmas: Optional[List[float]] = None
60
- self.num_inference_steps: Optional[int] = None
61
- self._step_index: Optional[int] = None
62
-
63
- @property
64
- def init_noise_sigma(self) -> float:
65
- return 1.0
66
-
67
- def set_timesteps(
68
- self,
69
- num_inference_steps: int,
70
- device: Union[str, torch.device, None] = None,
71
- solver: Optional[str] = None,
72
- ) -> None:
73
- if num_inference_steps < 2:
74
- raise ValueError("num_inference_steps must be >= 2.")
75
-
76
- self.num_inference_steps = num_inference_steps
77
- self.timesteps = torch.linspace(
78
- 0.0,
79
- 1.0,
80
- num_inference_steps + 1,
81
- device=device,
82
- dtype=torch.float32,
83
- )
84
- sigma_grid = torch.linspace(0.0, 1.0, num_inference_steps, device=device, dtype=torch.float32)
85
- self.sigmas = (1.0 - sigma_grid).tolist()
86
- self._step_index = 0
87
- if solver is not None:
88
- self.register_to_config(solver=solver)
89
-
90
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
91
- del timestep
92
- return sample
93
-
94
- def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int:
95
- if self._step_index is not None:
96
- return self._step_index
97
- if self.timesteps is None:
98
- raise ValueError("Call `set_timesteps` before `step`.")
99
- if timestep is None:
100
- return 0
101
- t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0])
102
- matches = (self.timesteps - t_value).abs() < 1e-6
103
- if matches.any():
104
- return int(matches.nonzero(as_tuple=False)[0].item())
105
- return 0
106
-
107
- def step(
108
- self,
109
- model_output: torch.Tensor,
110
- timestep: Union[float, torch.Tensor, None],
111
- sample: torch.Tensor,
112
- model_output_next: Optional[torch.Tensor] = None,
113
- return_dict: bool = True,
114
- ) -> Union[JiTSchedulerOutput, Tuple[torch.Tensor]]:
115
- """
116
- Integrate one step on the linear `t` grid.
117
-
118
- Args:
119
- model_output (`torch.Tensor`):
120
- Velocity `v = (x_pred - z) / (1 - t)` at the current time.
121
- timestep (`float` or `torch.Tensor`, *optional*):
122
- Current flow time `t`. When omitted, uses the internal step index.
123
- sample (`torch.Tensor`):
124
- Current noisy latent `z`.
125
- model_output_next (`torch.Tensor`, *optional*):
126
- Velocity at `t_next` (required for Heun intermediate steps).
127
- """
128
- if self.timesteps is None:
129
- raise ValueError("Call `set_timesteps` before `step`.")
130
-
131
- step_index = self._resolve_step_index(timestep)
132
- if step_index >= len(self.timesteps) - 1:
133
- raise ValueError("Scheduler has already reached the final timestep.")
134
-
135
- t = self.timesteps[step_index]
136
- t_next = self.timesteps[step_index + 1]
137
- dt = t_next - t
138
-
139
- if self.config.solver == "heun" and model_output_next is not None:
140
- prev_sample = sample + dt * 0.5 * (model_output + model_output_next)
141
- else:
142
- prev_sample = sample + dt * model_output
143
-
144
- self._step_index = step_index + 1
145
-
146
- if not return_dict:
147
- return (prev_sample,)
148
- return JiTSchedulerOutput(prev_sample=prev_sample)
149
-
150
- def velocity_from_prediction(
151
- self,
152
- sample: torch.Tensor,
153
- x_pred: torch.Tensor,
154
- timestep: Union[float, torch.Tensor],
155
- ) -> torch.Tensor:
156
- """Compute JiT velocity `v = (x_pred - z) / (1 - t)` with denominator clamp."""
157
- t = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype)
158
- while t.ndim < sample.ndim:
159
- t = t.unsqueeze(-1)
160
- denom = (1.0 - t).clamp_min(self.config.t_eps)
161
- return (x_pred - sample) / denom