Instructions to use BiliSakura/JiT-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/JiT-diffusers with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
Delete JiT-L-32/scheduler/scheduling_jit.py
Browse files
JiT-L-32/scheduler/scheduling_jit.py
DELETED
|
@@ -1,161 +0,0 @@
|
|
| 1 |
-
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
from dataclasses import dataclass
|
| 16 |
-
from typing import List, Optional, Tuple, Union
|
| 17 |
-
|
| 18 |
-
import torch
|
| 19 |
-
|
| 20 |
-
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 21 |
-
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
| 22 |
-
from diffusers.utils import BaseOutput
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
@dataclass
|
| 26 |
-
class JiTSchedulerOutput(BaseOutput):
|
| 27 |
-
"""
|
| 28 |
-
Output class for the JiT scheduler's `step` function.
|
| 29 |
-
|
| 30 |
-
Args:
|
| 31 |
-
prev_sample (`torch.Tensor`):
|
| 32 |
-
Updated sample after one solver step along the JiT flow-time grid.
|
| 33 |
-
"""
|
| 34 |
-
|
| 35 |
-
prev_sample: torch.Tensor
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
class JiTScheduler(SchedulerMixin, ConfigMixin):
|
| 39 |
-
"""
|
| 40 |
-
Manual flow-matching scheduler for JiT checkpoints.
|
| 41 |
-
|
| 42 |
-
Uses a linear flow-time grid `t in [0, 1]` (increasing), matching the official JiT
|
| 43 |
-
sampler. Velocity is `v = (x_pred - z) / (1 - t)`; integration is explicit Euler or
|
| 44 |
-
Heun along that grid.
|
| 45 |
-
"""
|
| 46 |
-
|
| 47 |
-
order = 2
|
| 48 |
-
|
| 49 |
-
@register_to_config
|
| 50 |
-
def __init__(
|
| 51 |
-
self,
|
| 52 |
-
num_train_timesteps: int = 1000,
|
| 53 |
-
t_eps: float = 5e-2,
|
| 54 |
-
solver: str = "heun",
|
| 55 |
-
):
|
| 56 |
-
if solver not in {"heun", "euler"}:
|
| 57 |
-
raise ValueError("solver must be one of: 'heun', 'euler'.")
|
| 58 |
-
self.timesteps: Optional[torch.Tensor] = None
|
| 59 |
-
self.sigmas: Optional[List[float]] = None
|
| 60 |
-
self.num_inference_steps: Optional[int] = None
|
| 61 |
-
self._step_index: Optional[int] = None
|
| 62 |
-
|
| 63 |
-
@property
|
| 64 |
-
def init_noise_sigma(self) -> float:
|
| 65 |
-
return 1.0
|
| 66 |
-
|
| 67 |
-
def set_timesteps(
|
| 68 |
-
self,
|
| 69 |
-
num_inference_steps: int,
|
| 70 |
-
device: Union[str, torch.device, None] = None,
|
| 71 |
-
solver: Optional[str] = None,
|
| 72 |
-
) -> None:
|
| 73 |
-
if num_inference_steps < 2:
|
| 74 |
-
raise ValueError("num_inference_steps must be >= 2.")
|
| 75 |
-
|
| 76 |
-
self.num_inference_steps = num_inference_steps
|
| 77 |
-
self.timesteps = torch.linspace(
|
| 78 |
-
0.0,
|
| 79 |
-
1.0,
|
| 80 |
-
num_inference_steps + 1,
|
| 81 |
-
device=device,
|
| 82 |
-
dtype=torch.float32,
|
| 83 |
-
)
|
| 84 |
-
sigma_grid = torch.linspace(0.0, 1.0, num_inference_steps, device=device, dtype=torch.float32)
|
| 85 |
-
self.sigmas = (1.0 - sigma_grid).tolist()
|
| 86 |
-
self._step_index = 0
|
| 87 |
-
if solver is not None:
|
| 88 |
-
self.register_to_config(solver=solver)
|
| 89 |
-
|
| 90 |
-
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
|
| 91 |
-
del timestep
|
| 92 |
-
return sample
|
| 93 |
-
|
| 94 |
-
def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int:
|
| 95 |
-
if self._step_index is not None:
|
| 96 |
-
return self._step_index
|
| 97 |
-
if self.timesteps is None:
|
| 98 |
-
raise ValueError("Call `set_timesteps` before `step`.")
|
| 99 |
-
if timestep is None:
|
| 100 |
-
return 0
|
| 101 |
-
t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0])
|
| 102 |
-
matches = (self.timesteps - t_value).abs() < 1e-6
|
| 103 |
-
if matches.any():
|
| 104 |
-
return int(matches.nonzero(as_tuple=False)[0].item())
|
| 105 |
-
return 0
|
| 106 |
-
|
| 107 |
-
def step(
|
| 108 |
-
self,
|
| 109 |
-
model_output: torch.Tensor,
|
| 110 |
-
timestep: Union[float, torch.Tensor, None],
|
| 111 |
-
sample: torch.Tensor,
|
| 112 |
-
model_output_next: Optional[torch.Tensor] = None,
|
| 113 |
-
return_dict: bool = True,
|
| 114 |
-
) -> Union[JiTSchedulerOutput, Tuple[torch.Tensor]]:
|
| 115 |
-
"""
|
| 116 |
-
Integrate one step on the linear `t` grid.
|
| 117 |
-
|
| 118 |
-
Args:
|
| 119 |
-
model_output (`torch.Tensor`):
|
| 120 |
-
Velocity `v = (x_pred - z) / (1 - t)` at the current time.
|
| 121 |
-
timestep (`float` or `torch.Tensor`, *optional*):
|
| 122 |
-
Current flow time `t`. When omitted, uses the internal step index.
|
| 123 |
-
sample (`torch.Tensor`):
|
| 124 |
-
Current noisy latent `z`.
|
| 125 |
-
model_output_next (`torch.Tensor`, *optional*):
|
| 126 |
-
Velocity at `t_next` (required for Heun intermediate steps).
|
| 127 |
-
"""
|
| 128 |
-
if self.timesteps is None:
|
| 129 |
-
raise ValueError("Call `set_timesteps` before `step`.")
|
| 130 |
-
|
| 131 |
-
step_index = self._resolve_step_index(timestep)
|
| 132 |
-
if step_index >= len(self.timesteps) - 1:
|
| 133 |
-
raise ValueError("Scheduler has already reached the final timestep.")
|
| 134 |
-
|
| 135 |
-
t = self.timesteps[step_index]
|
| 136 |
-
t_next = self.timesteps[step_index + 1]
|
| 137 |
-
dt = t_next - t
|
| 138 |
-
|
| 139 |
-
if self.config.solver == "heun" and model_output_next is not None:
|
| 140 |
-
prev_sample = sample + dt * 0.5 * (model_output + model_output_next)
|
| 141 |
-
else:
|
| 142 |
-
prev_sample = sample + dt * model_output
|
| 143 |
-
|
| 144 |
-
self._step_index = step_index + 1
|
| 145 |
-
|
| 146 |
-
if not return_dict:
|
| 147 |
-
return (prev_sample,)
|
| 148 |
-
return JiTSchedulerOutput(prev_sample=prev_sample)
|
| 149 |
-
|
| 150 |
-
def velocity_from_prediction(
|
| 151 |
-
self,
|
| 152 |
-
sample: torch.Tensor,
|
| 153 |
-
x_pred: torch.Tensor,
|
| 154 |
-
timestep: Union[float, torch.Tensor],
|
| 155 |
-
) -> torch.Tensor:
|
| 156 |
-
"""Compute JiT velocity `v = (x_pred - z) / (1 - t)` with denominator clamp."""
|
| 157 |
-
t = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype)
|
| 158 |
-
while t.ndim < sample.ndim:
|
| 159 |
-
t = t.unsqueeze(-1)
|
| 160 |
-
denom = (1.0 - t).clamp_min(self.config.t_eps)
|
| 161 |
-
return (x_pred - sample) / denom
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|