BiliSakura committed · Commit de7c0d6 · verified · 1 Parent(s): 124234e

Upload folder using huggingface_hub

README.md CHANGED
@@ -32,10 +32,13 @@ Each subfolder is a self-contained Diffusers model repo with:
 
 - `pipeline.py`
 - `transformer/transformer_sit.py`
-- `scheduler/scheduling_flow_match_sit.py`
+- `scheduler/scheduler_config.json` (`FlowMatchEulerDiscreteScheduler`)
 - `transformer/diffusion_pytorch_model.safetensors`
 - `vae/diffusion_pytorch_model.safetensors`
 
+Each variant embeds English `id2label` directly in `model_index.json` (DiT-style), so class labels can be passed as
+ImageNet ids or English synonym strings.
+
 ## Demo
 
 ![SiT-XL-2-512 demo](SiT-XL-2-512/demo.png)
@@ -47,7 +50,7 @@ Class-conditional sample (ImageNet class **207**, golden retriever), `SiT-XL/2`
 Use paths relative to this root README:
 
 | Model | Resolution | Local path |
-|---|---:|---|
+| --- | ---: | --- |
 | SiT-S/2 | 256x256 | `./SiT-S-2-256` |
 | SiT-B/2 | 256x256 | `./SiT-B-2-256` |
 | SiT-L/2 | 256x256 | `./SiT-L-2-256` |
@@ -73,8 +76,10 @@ pipe = DiffusionPipeline.from_pretrained(
 generator = torch.Generator(device=device).manual_seed(0)
 
 # ImageNet class example: 207 = golden retriever
+print(pipe.id2label[207])
+print(pipe.get_label_ids("golden retriever")) # [207]
 result = pipe(
-    class_labels=207,
+    class_labels="golden retriever",
     height=512,
     width=512,
     num_inference_steps=250, # official SiT comparisons commonly use 250 steps
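The README change states that class labels can be passed as ImageNet ids or English synonym strings. As a rough illustration of how such a lookup could be derived from an `id2label` mapping, here is a minimal sketch. It assumes each `id2label` value is a comma-separated list of synonyms, as in the `model_index.json` entries in this commit. The helpers `build_label_lookup` and `resolve_label_ids` are hypothetical names for this example; the repo's `pipeline.py` is not shown in the diff, so its actual implementation may differ.

```python
from typing import Dict, List, Sequence, Union

LabelInput = Union[int, str, Sequence[Union[int, str]]]


def build_label_lookup(id2label: Dict[str, str]) -> Dict[str, int]:
    """Map each comma-separated synonym to its integer class id."""
    lookup: Dict[str, int] = {}
    for class_id, names in id2label.items():
        for name in names.split(","):
            name = name.strip()
            if name:
                # Assumption: when a synonym repeats (e.g. "crane"),
                # keep the first id seen. A real pipeline may choose otherwise.
                lookup.setdefault(name, int(class_id))
    return lookup


def resolve_label_ids(labels: LabelInput, lookup: Dict[str, int]) -> List[int]:
    """Accept ids, synonym strings, or a mix, and return a list of ids."""
    if isinstance(labels, (int, str)):
        labels = [labels]
    ids: List[int] = []
    for label in labels:
        if isinstance(label, int):
            ids.append(label)
        elif label in lookup:
            ids.append(lookup[label])
        else:
            raise ValueError(f"Unknown class label: {label!r}")
    return ids


# A few entries copied from the id2label mapping in this commit.
sample_id2label = {
    "1": "goldfish, Carassius auratus",
    "134": "crane",
    "207": "golden retriever",
    "517": "crane",
}
lookup = build_label_lookup(sample_id2label)
print(resolve_label_ids("golden retriever", lookup))  # [207]
```

Ambiguous synonyms such as `crane` (ids 134 and 517) or `maillot` (ids 638 and 639) cannot be resolved from the string alone, so passing the integer id is the safer choice for those classes.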
SiT-B-2-256/__pycache__/pipeline.cpython-312.pyc ADDED
Binary file (7.51 kB)
SiT-B-2-256/model_index.json CHANGED
@@ -1,19 +1,1021 @@
-{
-  "_class_name": [
-    "pipeline",
-    "SiTPipeline"
-  ],
-  "_diffusers_version": "0.36.0",
-  "scheduler": [
-    "scheduling_flow_match_sit",
-    "SiTFlowMatchScheduler"
-  ],
-  "transformer": [
-    "transformer_sit",
-    "SiTTransformer2DModel"
-  ],
-  "vae": [
-    "diffusers",
-    "AutoencoderKL"
-  ]
-}
+{
+  "_class_name": [
+    "pipeline",
+    "SiTPipeline"
+  ],
+  "_diffusers_version": "0.36.0",
+  "scheduler": [
+    "diffusers",
+    "FlowMatchEulerDiscreteScheduler"
+  ],
+  "transformer": [
+    "transformer_sit",
+    "SiTTransformer2DModel"
+  ],
+  "vae": [
+    "diffusers",
+    "AutoencoderKL"
+  ],
+  "id2label": {
+    "0": "tench, Tinca tinca",
+    "1": "goldfish, Carassius auratus",
+    "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
+    "3": "tiger shark, Galeocerdo cuvieri",
+    "4": "hammerhead, hammerhead shark",
+    "5": "electric ray, crampfish, numbfish, torpedo",
+    "6": "stingray",
+    "7": "cock",
+    "8": "hen",
+    "9": "ostrich, Struthio camelus",
+    "10": "brambling, Fringilla montifringilla",
+    "11": "goldfinch, Carduelis carduelis",
+    "12": "house finch, linnet, Carpodacus mexicanus",
+    "13": "junco, snowbird",
+    "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
+    "15": "robin, American robin, Turdus migratorius",
+    "16": "bulbul",
+    "17": "jay",
+    "18": "magpie",
+    "19": "chickadee",
+    "20": "water ouzel, dipper",
+    "21": "kite",
+    "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
+    "23": "vulture",
+    "24": "great grey owl, great gray owl, Strix nebulosa",
+    "25": "European fire salamander, Salamandra salamandra",
+    "26": "common newt, Triturus vulgaris",
+    "27": "eft",
+    "28": "spotted salamander, Ambystoma maculatum",
+    "29": "axolotl, mud puppy, Ambystoma mexicanum",
+    "30": "bullfrog, Rana catesbeiana",
+    "31": "tree frog, tree-frog",
+    "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
+    "33": "loggerhead, loggerhead turtle, Caretta caretta",
+    "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
+    "35": "mud turtle",
+    "36": "terrapin",
+    "37": "box turtle, box tortoise",
+    "38": "banded gecko",
+    "39": "common iguana, iguana, Iguana iguana",
+    "40": "American chameleon, anole, Anolis carolinensis",
+    "41": "whiptail, whiptail lizard",
+    "42": "agama",
+    "43": "frilled lizard, Chlamydosaurus kingi",
+    "44": "alligator lizard",
+    "45": "Gila monster, Heloderma suspectum",
+    "46": "green lizard, Lacerta viridis",
+    "47": "African chameleon, Chamaeleo chamaeleon",
+    "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
+    "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
+    "50": "American alligator, Alligator mississipiensis",
+    "51": "triceratops",
+    "52": "thunder snake, worm snake, Carphophis amoenus",
+    "53": "ringneck snake, ring-necked snake, ring snake",
+    "54": "hognose snake, puff adder, sand viper",
+    "55": "green snake, grass snake",
+    "56": "king snake, kingsnake",
+    "57": "garter snake, grass snake",
+    "58": "water snake",
+    "59": "vine snake",
+    "60": "night snake, Hypsiglena torquata",
+    "61": "boa constrictor, Constrictor constrictor",
+    "62": "rock python, rock snake, Python sebae",
+    "63": "Indian cobra, Naja naja",
+    "64": "green mamba",
+    "65": "sea snake",
+    "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
+    "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
+    "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
+    "69": "trilobite",
+    "70": "harvestman, daddy longlegs, Phalangium opilio",
+    "71": "scorpion",
+    "72": "black and gold garden spider, Argiope aurantia",
+    "73": "barn spider, Araneus cavaticus",
+    "74": "garden spider, Aranea diademata",
+    "75": "black widow, Latrodectus mactans",
+    "76": "tarantula",
+    "77": "wolf spider, hunting spider",
+    "78": "tick",
+    "79": "centipede",
+    "80": "black grouse",
+    "81": "ptarmigan",
+    "82": "ruffed grouse, partridge, Bonasa umbellus",
+    "83": "prairie chicken, prairie grouse, prairie fowl",
+    "84": "peacock",
+    "85": "quail",
+    "86": "partridge",
+    "87": "African grey, African gray, Psittacus erithacus",
+    "88": "macaw",
+    "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
+    "90": "lorikeet",
+    "91": "coucal",
+    "92": "bee eater",
+    "93": "hornbill",
+    "94": "hummingbird",
+    "95": "jacamar",
+    "96": "toucan",
+    "97": "drake",
+    "98": "red-breasted merganser, Mergus serrator",
+    "99": "goose",
+    "100": "black swan, Cygnus atratus",
+    "101": "tusker",
+    "102": "echidna, spiny anteater, anteater",
+    "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
+    "104": "wallaby, brush kangaroo",
+    "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
+    "106": "wombat",
+    "107": "jellyfish",
+    "108": "sea anemone, anemone",
+    "109": "brain coral",
+    "110": "flatworm, platyhelminth",
+    "111": "nematode, nematode worm, roundworm",
+    "112": "conch",
+    "113": "snail",
+    "114": "slug",
+    "115": "sea slug, nudibranch",
+    "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
+    "117": "chambered nautilus, pearly nautilus, nautilus",
+    "118": "Dungeness crab, Cancer magister",
+    "119": "rock crab, Cancer irroratus",
+    "120": "fiddler crab",
+    "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
+    "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
+    "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
+    "124": "crayfish, crawfish, crawdad, crawdaddy",
+    "125": "hermit crab",
+    "126": "isopod",
+    "127": "white stork, Ciconia ciconia",
+    "128": "black stork, Ciconia nigra",
+    "129": "spoonbill",
+    "130": "flamingo",
+    "131": "little blue heron, Egretta caerulea",
+    "132": "American egret, great white heron, Egretta albus",
+    "133": "bittern",
+    "134": "crane",
+    "135": "limpkin, Aramus pictus",
+    "136": "European gallinule, Porphyrio porphyrio",
+    "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
+    "138": "bustard",
+    "139": "ruddy turnstone, Arenaria interpres",
+    "140": "red-backed sandpiper, dunlin, Erolia alpina",
+    "141": "redshank, Tringa totanus",
+    "142": "dowitcher",
+    "143": "oystercatcher, oyster catcher",
+    "144": "pelican",
+    "145": "king penguin, Aptenodytes patagonica",
+    "146": "albatross, mollymawk",
+    "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
+    "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
+    "149": "dugong, Dugong dugon",
+    "150": "sea lion",
+    "151": "Chihuahua",
+    "152": "Japanese spaniel",
+    "153": "Maltese dog, Maltese terrier, Maltese",
+    "154": "Pekinese, Pekingese, Peke",
+    "155": "Shih-Tzu",
+    "156": "Blenheim spaniel",
+    "157": "papillon",
+    "158": "toy terrier",
+    "159": "Rhodesian ridgeback",
+    "160": "Afghan hound, Afghan",
+    "161": "basset, basset hound",
+    "162": "beagle",
+    "163": "bloodhound, sleuthhound",
+    "164": "bluetick",
+    "165": "black-and-tan coonhound",
+    "166": "Walker hound, Walker foxhound",
+    "167": "English foxhound",
+    "168": "redbone",
+    "169": "borzoi, Russian wolfhound",
+    "170": "Irish wolfhound",
+    "171": "Italian greyhound",
+    "172": "whippet",
+    "173": "Ibizan hound, Ibizan Podenco",
+    "174": "Norwegian elkhound, elkhound",
+    "175": "otterhound, otter hound",
+    "176": "Saluki, gazelle hound",
+    "177": "Scottish deerhound, deerhound",
+    "178": "Weimaraner",
+    "179": "Staffordshire bullterrier, Staffordshire bull terrier",
+    "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
+    "181": "Bedlington terrier",
+    "182": "Border terrier",
+    "183": "Kerry blue terrier",
+    "184": "Irish terrier",
+    "185": "Norfolk terrier",
+    "186": "Norwich terrier",
+    "187": "Yorkshire terrier",
+    "188": "wire-haired fox terrier",
+    "189": "Lakeland terrier",
+    "190": "Sealyham terrier, Sealyham",
+    "191": "Airedale, Airedale terrier",
+    "192": "cairn, cairn terrier",
+    "193": "Australian terrier",
+    "194": "Dandie Dinmont, Dandie Dinmont terrier",
+    "195": "Boston bull, Boston terrier",
+    "196": "miniature schnauzer",
+    "197": "giant schnauzer",
+    "198": "standard schnauzer",
+    "199": "Scotch terrier, Scottish terrier, Scottie",
+    "200": "Tibetan terrier, chrysanthemum dog",
+    "201": "silky terrier, Sydney silky",
+    "202": "soft-coated wheaten terrier",
+    "203": "West Highland white terrier",
+    "204": "Lhasa, Lhasa apso",
+    "205": "flat-coated retriever",
+    "206": "curly-coated retriever",
+    "207": "golden retriever",
+    "208": "Labrador retriever",
+    "209": "Chesapeake Bay retriever",
+    "210": "German short-haired pointer",
+    "211": "vizsla, Hungarian pointer",
+    "212": "English setter",
+    "213": "Irish setter, red setter",
+    "214": "Gordon setter",
+    "215": "Brittany spaniel",
+    "216": "clumber, clumber spaniel",
+    "217": "English springer, English springer spaniel",
+    "218": "Welsh springer spaniel",
+    "219": "cocker spaniel, English cocker spaniel, cocker",
+    "220": "Sussex spaniel",
+    "221": "Irish water spaniel",
+    "222": "kuvasz",
+    "223": "schipperke",
+    "224": "groenendael",
+    "225": "malinois",
+    "226": "briard",
+    "227": "kelpie",
+    "228": "komondor",
+    "229": "Old English sheepdog, bobtail",
+    "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
+    "231": "collie",
+    "232": "Border collie",
+    "233": "Bouvier des Flandres, Bouviers des Flandres",
+    "234": "Rottweiler",
+    "235": "German shepherd, German shepherd dog, German police dog, alsatian",
+    "236": "Doberman, Doberman pinscher",
+    "237": "miniature pinscher",
+    "238": "Greater Swiss Mountain dog",
+    "239": "Bernese mountain dog",
+    "240": "Appenzeller",
+    "241": "EntleBucher",
+    "242": "boxer",
+    "243": "bull mastiff",
+    "244": "Tibetan mastiff",
+    "245": "French bulldog",
+    "246": "Great Dane",
+    "247": "Saint Bernard, St Bernard",
+    "248": "Eskimo dog, husky",
+    "249": "malamute, malemute, Alaskan malamute",
+    "250": "Siberian husky",
+    "251": "dalmatian, coach dog, carriage dog",
+    "252": "affenpinscher, monkey pinscher, monkey dog",
+    "253": "basenji",
+    "254": "pug, pug-dog",
+    "255": "Leonberg",
+    "256": "Newfoundland, Newfoundland dog",
+    "257": "Great Pyrenees",
+    "258": "Samoyed, Samoyede",
+    "259": "Pomeranian",
+    "260": "chow, chow chow",
+    "261": "keeshond",
+    "262": "Brabancon griffon",
+    "263": "Pembroke, Pembroke Welsh corgi",
+    "264": "Cardigan, Cardigan Welsh corgi",
+    "265": "toy poodle",
+    "266": "miniature poodle",
+    "267": "standard poodle",
+    "268": "Mexican hairless",
+    "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
+    "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
+    "271": "red wolf, maned wolf, Canis rufus, Canis niger",
+    "272": "coyote, prairie wolf, brush wolf, Canis latrans",
+    "273": "dingo, warrigal, warragal, Canis dingo",
+    "274": "dhole, Cuon alpinus",
+    "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
+    "276": "hyena, hyaena",
+    "277": "red fox, Vulpes vulpes",
+    "278": "kit fox, Vulpes macrotis",
+    "279": "Arctic fox, white fox, Alopex lagopus",
+    "280": "grey fox, gray fox, Urocyon cinereoargenteus",
+    "281": "tabby, tabby cat",
+    "282": "tiger cat",
+    "283": "Persian cat",
+    "284": "Siamese cat, Siamese",
+    "285": "Egyptian cat",
+    "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
+    "287": "lynx, catamount",
+    "288": "leopard, Panthera pardus",
+    "289": "snow leopard, ounce, Panthera uncia",
+    "290": "jaguar, panther, Panthera onca, Felis onca",
+    "291": "lion, king of beasts, Panthera leo",
+    "292": "tiger, Panthera tigris",
+    "293": "cheetah, chetah, Acinonyx jubatus",
+    "294": "brown bear, bruin, Ursus arctos",
+    "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
+    "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
+    "297": "sloth bear, Melursus ursinus, Ursus ursinus",
+    "298": "mongoose",
+    "299": "meerkat, mierkat",
+    "300": "tiger beetle",
+    "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
+    "302": "ground beetle, carabid beetle",
+    "303": "long-horned beetle, longicorn, longicorn beetle",
+    "304": "leaf beetle, chrysomelid",
+    "305": "dung beetle",
+    "306": "rhinoceros beetle",
+    "307": "weevil",
+    "308": "fly",
+    "309": "bee",
+    "310": "ant, emmet, pismire",
+    "311": "grasshopper, hopper",
+    "312": "cricket",
+    "313": "walking stick, walkingstick, stick insect",
+    "314": "cockroach, roach",
+    "315": "mantis, mantid",
+    "316": "cicada, cicala",
+    "317": "leafhopper",
+    "318": "lacewing, lacewing fly",
+    "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
+    "320": "damselfly",
+    "321": "admiral",
+    "322": "ringlet, ringlet butterfly",
+    "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
+    "324": "cabbage butterfly",
+    "325": "sulphur butterfly, sulfur butterfly",
+    "326": "lycaenid, lycaenid butterfly",
+    "327": "starfish, sea star",
+    "328": "sea urchin",
+    "329": "sea cucumber, holothurian",
+    "330": "wood rabbit, cottontail, cottontail rabbit",
+    "331": "hare",
+    "332": "Angora, Angora rabbit",
+    "333": "hamster",
+    "334": "porcupine, hedgehog",
+    "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
+    "336": "marmot",
+    "337": "beaver",
+    "338": "guinea pig, Cavia cobaya",
+    "339": "sorrel",
+    "340": "zebra",
+    "341": "hog, pig, grunter, squealer, Sus scrofa",
+    "342": "wild boar, boar, Sus scrofa",
+    "343": "warthog",
+    "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
+    "345": "ox",
+    "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
+    "347": "bison",
+    "348": "ram, tup",
+    "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
+    "350": "ibex, Capra ibex",
+    "351": "hartebeest",
+    "352": "impala, Aepyceros melampus",
+    "353": "gazelle",
+    "354": "Arabian camel, dromedary, Camelus dromedarius",
+    "355": "llama",
+    "356": "weasel",
+    "357": "mink",
+    "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
+    "359": "black-footed ferret, ferret, Mustela nigripes",
+    "360": "otter",
+    "361": "skunk, polecat, wood pussy",
+    "362": "badger",
+    "363": "armadillo",
+    "364": "three-toed sloth, ai, Bradypus tridactylus",
+    "365": "orangutan, orang, orangutang, Pongo pygmaeus",
+    "366": "gorilla, Gorilla gorilla",
+    "367": "chimpanzee, chimp, Pan troglodytes",
+    "368": "gibbon, Hylobates lar",
+    "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
+    "370": "guenon, guenon monkey",
+    "371": "patas, hussar monkey, Erythrocebus patas",
+    "372": "baboon",
+    "373": "macaque",
+    "374": "langur",
+    "375": "colobus, colobus monkey",
+    "376": "proboscis monkey, Nasalis larvatus",
+    "377": "marmoset",
+    "378": "capuchin, ringtail, Cebus capucinus",
+    "379": "howler monkey, howler",
+    "380": "titi, titi monkey",
+    "381": "spider monkey, Ateles geoffroyi",
+    "382": "squirrel monkey, Saimiri sciureus",
+    "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
+    "384": "indri, indris, Indri indri, Indri brevicaudatus",
+    "385": "Indian elephant, Elephas maximus",
+    "386": "African elephant, Loxodonta africana",
+    "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
+    "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
+    "389": "barracouta, snoek",
+    "390": "eel",
+    "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
+    "392": "rock beauty, Holocanthus tricolor",
+    "393": "anemone fish",
+    "394": "sturgeon",
+    "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
+    "396": "lionfish",
+    "397": "puffer, pufferfish, blowfish, globefish",
+    "398": "abacus",
+    "399": "abaya",
+    "400": "academic gown, academic robe, judge robe",
+    "401": "accordion, piano accordion, squeeze box",
+    "402": "acoustic guitar",
+    "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
+    "404": "airliner",
+    "405": "airship, dirigible",
+    "406": "altar",
+    "407": "ambulance",
+    "408": "amphibian, amphibious vehicle",
+    "409": "analog clock",
+    "410": "apiary, bee house",
+    "411": "apron",
+    "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
+    "413": "assault rifle, assault gun",
+    "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
+    "415": "bakery, bakeshop, bakehouse",
+    "416": "balance beam, beam",
+    "417": "balloon",
+    "418": "ballpoint, ballpoint pen, ballpen, Biro",
+    "419": "Band Aid",
+    "420": "banjo",
+    "421": "bannister, banister, balustrade, balusters, handrail",
+    "422": "barbell",
+    "423": "barber chair",
+    "424": "barbershop",
+    "425": "barn",
+    "426": "barometer",
+    "427": "barrel, cask",
+    "428": "barrow, garden cart, lawn cart, wheelbarrow",
+    "429": "baseball",
+    "430": "basketball",
+    "431": "bassinet",
+    "432": "bassoon",
+    "433": "bathing cap, swimming cap",
+    "434": "bath towel",
+    "435": "bathtub, bathing tub, bath, tub",
+    "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
+    "437": "beacon, lighthouse, beacon light, pharos",
+    "438": "beaker",
+    "439": "bearskin, busby, shako",
+    "440": "beer bottle",
+    "441": "beer glass",
+    "442": "bell cote, bell cot",
+    "443": "bib",
+    "444": "bicycle-built-for-two, tandem bicycle, tandem",
+    "445": "bikini, two-piece",
+    "446": "binder, ring-binder",
+    "447": "binoculars, field glasses, opera glasses",
+    "448": "birdhouse",
+    "449": "boathouse",
+    "450": "bobsled, bobsleigh, bob",
+    "451": "bolo tie, bolo, bola tie, bola",
+    "452": "bonnet, poke bonnet",
+    "453": "bookcase",
+    "454": "bookshop, bookstore, bookstall",
+    "455": "bottlecap",
+    "456": "bow",
+    "457": "bow tie, bow-tie, bowtie",
+    "458": "brass, memorial tablet, plaque",
+    "459": "brassiere, bra, bandeau",
+    "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
+    "461": "breastplate, aegis, egis",
+    "462": "broom",
+    "463": "bucket, pail",
+    "464": "buckle",
+    "465": "bulletproof vest",
+    "466": "bullet train, bullet",
+    "467": "butcher shop, meat market",
+    "468": "cab, hack, taxi, taxicab",
+    "469": "caldron, cauldron",
+    "470": "candle, taper, wax light",
+    "471": "cannon",
+    "472": "canoe",
+    "473": "can opener, tin opener",
+    "474": "cardigan",
+    "475": "car mirror",
+    "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
+    "477": "carpenters kit, tool kit",
+    "478": "carton",
+    "479": "car wheel",
+    "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
+    "481": "cassette",
+    "482": "cassette player",
+    "483": "castle",
+    "484": "catamaran",
+    "485": "CD player",
+    "486": "cello, violoncello",
+    "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
+    "488": "chain",
+    "489": "chainlink fence",
+    "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
+    "491": "chain saw, chainsaw",
+    "492": "chest",
+    "493": "chiffonier, commode",
+    "494": "chime, bell, gong",
+    "495": "china cabinet, china closet",
+    "496": "Christmas stocking",
+    "497": "church, church building",
+    "498": "cinema, movie theater, movie theatre, movie house, picture palace",
+    "499": "cleaver, meat cleaver, chopper",
+    "500": "cliff dwelling",
+    "501": "cloak",
+    "502": "clog, geta, patten, sabot",
+    "503": "cocktail shaker",
+    "504": "coffee mug",
+    "505": "coffeepot",
+    "506": "coil, spiral, volute, whorl, helix",
+    "507": "combination lock",
+    "508": "computer keyboard, keypad",
+    "509": "confectionery, confectionary, candy store",
+    "510": "container ship, containership, container vessel",
+    "511": "convertible",
+    "512": "corkscrew, bottle screw",
+    "513": "cornet, horn, trumpet, trump",
+    "514": "cowboy boot",
+    "515": "cowboy hat, ten-gallon hat",
+    "516": "cradle",
+    "517": "crane",
+    "518": "crash helmet",
+    "519": "crate",
+    "520": "crib, cot",
+    "521": "Crock Pot",
+    "522": "croquet ball",
+    "523": "crutch",
+    "524": "cuirass",
+    "525": "dam, dike, dyke",
+    "526": "desk",
+    "527": "desktop computer",
+    "528": "dial telephone, dial phone",
+    "529": "diaper, nappy, napkin",
+    "530": "digital clock",
+    "531": "digital watch",
+    "532": "dining table, board",
+    "533": "dishrag, dishcloth",
+    "534": "dishwasher, dish washer, dishwashing machine",
+    "535": "disk brake, disc brake",
+    "536": "dock, dockage, docking facility",
+    "537": "dogsled, dog sled, dog sleigh",
+    "538": "dome",
+    "539": "doormat, welcome mat",
+    "540": "drilling platform, offshore rig",
+    "541": "drum, membranophone, tympan",
+    "542": "drumstick",
+    "543": "dumbbell",
+    "544": "Dutch oven",
+    "545": "electric fan, blower",
+    "546": "electric guitar",
+    "547": "electric locomotive",
+    "548": "entertainment center",
+    "549": "envelope",
+    "550": "espresso maker",
+    "551": "face powder",
+    "552": "feather boa, boa",
+    "553": "file, file cabinet, filing cabinet",
+    "554": "fireboat",
+    "555": "fire engine, fire truck",
+    "556": "fire screen, fireguard",
+    "557": "flagpole, flagstaff",
+    "558": "flute, transverse flute",
+    "559": "folding chair",
+    "560": "football helmet",
+    "561": "forklift",
+    "562": "fountain",
+    "563": "fountain pen",
+    "564": "four-poster",
+    "565": "freight car",
+    "566": "French horn, horn",
+    "567": "frying pan, frypan, skillet",
+    "568": "fur coat",
+    "569": "garbage truck, dustcart",
+    "570": "gasmask, respirator, gas helmet",
+    "571": "gas pump, gasoline pump, petrol pump, island dispenser",
+    "572": "goblet",
+    "573": "go-kart",
+    "574": "golf ball",
+    "575": "golfcart, golf cart",
+    "576": "gondola",
+    "577": "gong, tam-tam",
+    "578": "gown",
+    "579": "grand piano, grand",
+    "580": "greenhouse, nursery, glasshouse",
+    "581": "grille, radiator grille",
+    "582": "grocery store, grocery, food market, market",
+    "583": "guillotine",
+    "584": "hair slide",
+    "585": "hair spray",
+    "586": "half track",
+    "587": "hammer",
+    "588": "hamper",
+    "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
+    "590": "hand-held computer, hand-held microcomputer",
+    "591": "handkerchief, hankie, hanky, hankey",
+    "592": "hard disc, hard disk, fixed disk",
+    "593": "harmonica, mouth organ, harp, mouth harp",
+    "594": "harp",
+    "595": "harvester, reaper",
+    "596": "hatchet",
+    "597": "holster",
+    "598": "home theater, home theatre",
+    "599": "honeycomb",
+    "600": "hook, claw",
+    "601": "hoopskirt, crinoline",
+    "602": "horizontal bar, high bar",
+    "603": "horse cart, horse-cart",
+    "604": "hourglass",
+    "605": "iPod",
+    "606": "iron, smoothing iron",
+    "607": "jack-o-lantern",
+    "608": "jean, blue jean, denim",
+    "609": "jeep, landrover",
+    "610": "jersey, T-shirt, tee shirt",
+    "611": "jigsaw puzzle",
+    "612": "jinrikisha, ricksha, rickshaw",
+    "613": "joystick",
+    "614": "kimono",
+    "615": "knee pad",
+    "616": "knot",
+    "617": "lab coat, laboratory coat",
+    "618": "ladle",
+    "619": "lampshade, lamp shade",
+    "620": "laptop, laptop computer",
+    "621": "lawn mower, mower",
+    "622": "lens cap, lens cover",
+    "623": "letter opener, paper knife, paperknife",
+    "624": "library",
+    "625": "lifeboat",
+    "626": "lighter, light, igniter, ignitor",
+    "627": "limousine, limo",
+    "628": "liner, ocean liner",
+    "629": "lipstick, lip rouge",
+    "630": "Loafer",
+    "631": "lotion",
+    "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
+    "633": "loupe, jewelers loupe",
+    "634": "lumbermill, sawmill",
+    "635": "magnetic compass",
+    "636": "mailbag, postbag",
+    "637": "mailbox, letter box",
+    "638": "maillot",
+    "639": "maillot, tank suit",
+    "640": "manhole cover",
+    "641": "maraca",
+    "642": "marimba, xylophone",
+    "643": "mask",
+    "644": "matchstick",
+    "645": "maypole",
+    "646": "maze, labyrinth",
+    "647": "measuring cup",
+    "648": "medicine chest, medicine cabinet",
+    "649": "megalith, megalithic structure",
+    "650": "microphone, mike",
+    "651": "microwave, microwave oven",
+    "652": "military uniform",
+    "653": "milk can",
+    "654": "minibus",
+    "655": "miniskirt, mini",
+    "656": "minivan",
+    "657": "missile",
+    "658": "mitten",
+    "659": "mixing bowl",
+    "660": "mobile home, manufactured home",
+    "661": "Model T",
+    "662": "modem",
+    "663": "monastery",
+    "664": "monitor",
+    "665": "moped",
+    "666": "mortar",
+    "667": "mortarboard",
+    "668": "mosque",
+    "669": "mosquito net",
+    "670": "motor scooter, scooter",
+    "671": "mountain bike, all-terrain bike, off-roader",
+    "672": "mountain tent",
+    "673": "mouse, computer mouse",
+    "674": "mousetrap",
+    "675": "moving van",
+    "676": "muzzle",
+    "677": "nail",
+    "678": "neck brace",
+    "679": "necklace",
+    "680": "nipple",
+    "681": "notebook, notebook computer",
+    "682": "obelisk",
+    "683": "oboe, hautboy, hautbois",
+    "684": "ocarina, sweet potato",
+    "685": "odometer, hodometer, mileometer, milometer",
+    "686": "oil filter",
+    "687": "organ, pipe organ",
+    "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
+    "689": "overskirt",
+    "690": "oxcart",
+    "691": "oxygen mask",
+    "692": "packet",
+    "693": "paddle, boat paddle",
+    "694": "paddlewheel, paddle wheel",
+    "695": "padlock",
+    "696": "paintbrush",
+    "697": "pajama, pyjama, pjs, jammies",
+    "698": "palace",
+    "699": "panpipe, pandean pipe, syrinx",
+    "700": "paper towel",
+    "701": "parachute, chute",
+    "702": "parallel bars, bars",
+    "703": "park bench",
+    "704": "parking meter",
+    "705": "passenger car, coach, carriage",
+    "706": "patio, terrace",
+    "707": "pay-phone, pay-station",
+    "708": "pedestal, plinth, footstall",
+    "709": "pencil box, pencil case",
+    "710": "pencil sharpener",
+    "711": "perfume, essence",
+    "712": "Petri dish",
+    "713": "photocopier",
+    "714": "pick, plectrum, plectron",
+    "715": "pickelhaube",
+    "716": "picket fence, paling",
+    "717": "pickup, pickup truck",
+    "718": "pier",
+    "719": "piggy bank, penny bank",
+    "720": "pill bottle",
+    "721": "pillow",
+    "722": "ping-pong ball",
+    "723": "pinwheel",
+    "724": "pirate, pirate ship",
+    "725": "pitcher, ewer",
+    "726": "plane, carpenters plane, woodworking plane",
+    "727": "planetarium",
+    "728": "plastic bag",
+    "729": "plate rack",
+    "730": "plow, plough",
+    "731": "plunger, plumbers helper",
+    "732": "Polaroid camera, Polaroid Land camera",
+    "733": "pole",
+    "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
+    "735": "poncho",
+    "736": "pool table, billiard table, snooker table",
+    "737": "pop bottle, soda bottle",
758
+ "738": "pot, flowerpot",
759
+ "739": "potters wheel",
760
+ "740": "power drill",
761
+ "741": "prayer rug, prayer mat",
762
+ "742": "printer",
763
+ "743": "prison, prison house",
764
+ "744": "projectile, missile",
765
+ "745": "projector",
766
+ "746": "puck, hockey puck",
767
+ "747": "punching bag, punch bag, punching ball, punchball",
768
+ "748": "purse",
769
+ "749": "quill, quill pen",
770
+ "750": "quilt, comforter, comfort, puff",
771
+ "751": "racer, race car, racing car",
772
+ "752": "racket, racquet",
773
+ "753": "radiator",
774
+ "754": "radio, wireless",
775
+ "755": "radio telescope, radio reflector",
776
+ "756": "rain barrel",
777
+ "757": "recreational vehicle, RV, R.V.",
778
+ "758": "reel",
779
+ "759": "reflex camera",
780
+ "760": "refrigerator, icebox",
781
+ "761": "remote control, remote",
782
+ "762": "restaurant, eating house, eating place, eatery",
783
+ "763": "revolver, six-gun, six-shooter",
784
+ "764": "rifle",
785
+ "765": "rocking chair, rocker",
786
+ "766": "rotisserie",
787
+ "767": "rubber eraser, rubber, pencil eraser",
788
+ "768": "rugby ball",
789
+ "769": "rule, ruler",
790
+ "770": "running shoe",
791
+ "771": "safe",
792
+ "772": "safety pin",
793
+ "773": "saltshaker, salt shaker",
794
+ "774": "sandal",
795
+ "775": "sarong",
796
+ "776": "sax, saxophone",
797
+ "777": "scabbard",
798
+ "778": "scale, weighing machine",
799
+ "779": "school bus",
800
+ "780": "schooner",
801
+ "781": "scoreboard",
802
+ "782": "screen, CRT screen",
803
+ "783": "screw",
804
+ "784": "screwdriver",
805
+ "785": "seat belt, seatbelt",
806
+ "786": "sewing machine",
807
+ "787": "shield, buckler",
808
+ "788": "shoe shop, shoe-shop, shoe store",
809
+ "789": "shoji",
810
+ "790": "shopping basket",
811
+ "791": "shopping cart",
812
+ "792": "shovel",
813
+ "793": "shower cap",
814
+ "794": "shower curtain",
815
+ "795": "ski",
816
+ "796": "ski mask",
817
+ "797": "sleeping bag",
818
+ "798": "slide rule, slipstick",
819
+ "799": "sliding door",
820
+ "800": "slot, one-armed bandit",
821
+ "801": "snorkel",
822
+ "802": "snowmobile",
823
+ "803": "snowplow, snowplough",
824
+ "804": "soap dispenser",
825
+ "805": "soccer ball",
826
+ "806": "sock",
827
+ "807": "solar dish, solar collector, solar furnace",
828
+ "808": "sombrero",
829
+ "809": "soup bowl",
830
+ "810": "space bar",
831
+ "811": "space heater",
832
+ "812": "space shuttle",
833
+ "813": "spatula",
834
+ "814": "speedboat",
835
+ "815": "spider web, spiders web",
836
+ "816": "spindle",
837
+ "817": "sports car, sport car",
838
+ "818": "spotlight, spot",
839
+ "819": "stage",
840
+ "820": "steam locomotive",
841
+ "821": "steel arch bridge",
842
+ "822": "steel drum",
843
+ "823": "stethoscope",
844
+ "824": "stole",
845
+ "825": "stone wall",
846
+ "826": "stopwatch, stop watch",
847
+ "827": "stove",
848
+ "828": "strainer",
849
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
850
+ "830": "stretcher",
851
+ "831": "studio couch, day bed",
852
+ "832": "stupa, tope",
853
+ "833": "submarine, pigboat, sub, U-boat",
854
+ "834": "suit, suit of clothes",
855
+ "835": "sundial",
856
+ "836": "sunglass",
857
+ "837": "sunglasses, dark glasses, shades",
858
+ "838": "sunscreen, sunblock, sun blocker",
859
+ "839": "suspension bridge",
860
+ "840": "swab, swob, mop",
861
+ "841": "sweatshirt",
862
+ "842": "swimming trunks, bathing trunks",
863
+ "843": "swing",
864
+ "844": "switch, electric switch, electrical switch",
865
+ "845": "syringe",
866
+ "846": "table lamp",
867
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
868
+ "848": "tape player",
869
+ "849": "teapot",
870
+ "850": "teddy, teddy bear",
871
+ "851": "television, television system",
872
+ "852": "tennis ball",
873
+ "853": "thatch, thatched roof",
874
+ "854": "theater curtain, theatre curtain",
875
+ "855": "thimble",
876
+ "856": "thresher, thrasher, threshing machine",
877
+ "857": "throne",
878
+ "858": "tile roof",
879
+ "859": "toaster",
880
+ "860": "tobacco shop, tobacconist shop, tobacconist",
881
+ "861": "toilet seat",
882
+ "862": "torch",
883
+ "863": "totem pole",
884
+ "864": "tow truck, tow car, wrecker",
885
+ "865": "toyshop",
886
+ "866": "tractor",
887
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
888
+ "868": "tray",
889
+ "869": "trench coat",
890
+ "870": "tricycle, trike, velocipede",
891
+ "871": "trimaran",
892
+ "872": "tripod",
893
+ "873": "triumphal arch",
894
+ "874": "trolleybus, trolley coach, trackless trolley",
895
+ "875": "trombone",
896
+ "876": "tub, vat",
897
+ "877": "turnstile",
898
+ "878": "typewriter keyboard",
899
+ "879": "umbrella",
900
+ "880": "unicycle, monocycle",
901
+ "881": "upright, upright piano",
902
+ "882": "vacuum, vacuum cleaner",
903
+ "883": "vase",
904
+ "884": "vault",
905
+ "885": "velvet",
906
+ "886": "vending machine",
907
+ "887": "vestment",
908
+ "888": "viaduct",
909
+ "889": "violin, fiddle",
910
+ "890": "volleyball",
911
+ "891": "waffle iron",
912
+ "892": "wall clock",
913
+ "893": "wallet, billfold, notecase, pocketbook",
914
+ "894": "wardrobe, closet, press",
915
+ "895": "warplane, military plane",
916
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
917
+ "897": "washer, automatic washer, washing machine",
918
+ "898": "water bottle",
919
+ "899": "water jug",
920
+ "900": "water tower",
921
+ "901": "whiskey jug",
922
+ "902": "whistle",
923
+ "903": "wig",
924
+ "904": "window screen",
925
+ "905": "window shade",
926
+ "906": "Windsor tie",
927
+ "907": "wine bottle",
928
+ "908": "wing",
929
+ "909": "wok",
930
+ "910": "wooden spoon",
931
+ "911": "wool, woolen, woollen",
932
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
933
+ "913": "wreck",
934
+ "914": "yawl",
935
+ "915": "yurt",
936
+ "916": "web site, website, internet site, site",
937
+ "917": "comic book",
938
+ "918": "crossword puzzle, crossword",
939
+ "919": "street sign",
940
+ "920": "traffic light, traffic signal, stoplight",
941
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
942
+ "922": "menu",
943
+ "923": "plate",
944
+ "924": "guacamole",
945
+ "925": "consomme",
946
+ "926": "hot pot, hotpot",
947
+ "927": "trifle",
948
+ "928": "ice cream, icecream",
949
+ "929": "ice lolly, lolly, lollipop, popsicle",
950
+ "930": "French loaf",
951
+ "931": "bagel, beigel",
952
+ "932": "pretzel",
953
+ "933": "cheeseburger",
954
+ "934": "hotdog, hot dog, red hot",
955
+ "935": "mashed potato",
956
+ "936": "head cabbage",
957
+ "937": "broccoli",
958
+ "938": "cauliflower",
959
+ "939": "zucchini, courgette",
960
+ "940": "spaghetti squash",
961
+ "941": "acorn squash",
962
+ "942": "butternut squash",
963
+ "943": "cucumber, cuke",
964
+ "944": "artichoke, globe artichoke",
965
+ "945": "bell pepper",
966
+ "946": "cardoon",
967
+ "947": "mushroom",
968
+ "948": "Granny Smith",
969
+ "949": "strawberry",
970
+ "950": "orange",
971
+ "951": "lemon",
972
+ "952": "fig",
973
+ "953": "pineapple, ananas",
974
+ "954": "banana",
975
+ "955": "jackfruit, jak, jack",
976
+ "956": "custard apple",
977
+ "957": "pomegranate",
978
+ "958": "hay",
979
+ "959": "carbonara",
980
+ "960": "chocolate sauce, chocolate syrup",
981
+ "961": "dough",
982
+ "962": "meat loaf, meatloaf",
983
+ "963": "pizza, pizza pie",
984
+ "964": "potpie",
985
+ "965": "burrito",
986
+ "966": "red wine",
987
+ "967": "espresso",
988
+ "968": "cup",
989
+ "969": "eggnog",
990
+ "970": "alp",
991
+ "971": "bubble",
992
+ "972": "cliff, drop, drop-off",
993
+ "973": "coral reef",
994
+ "974": "geyser",
995
+ "975": "lakeside, lakeshore",
996
+ "976": "promontory, headland, head, foreland",
997
+ "977": "sandbar, sand bar",
998
+ "978": "seashore, coast, seacoast, sea-coast",
999
+ "979": "valley, vale",
1000
+ "980": "volcano",
1001
+ "981": "ballplayer, baseball player",
1002
+ "982": "groom, bridegroom",
1003
+ "983": "scuba diver",
1004
+ "984": "rapeseed",
1005
+ "985": "daisy",
1006
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1007
+ "987": "corn",
1008
+ "988": "acorn",
1009
+ "989": "hip, rose hip, rosehip",
1010
+ "990": "buckeye, horse chestnut, conker",
1011
+ "991": "coral fungus",
1012
+ "992": "agaric",
1013
+ "993": "gyromitra",
1014
+ "994": "stinkhorn, carrion fungus",
1015
+ "995": "earthstar",
1016
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1017
+ "997": "bolete",
1018
+ "998": "ear, spike, capitulum",
1019
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1020
+ }
1021
+ }
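The `id2label` values above pack several synonyms into one comma-separated string. The following is a minimal sketch of how such a mapping can be inverted into a synonym lookup, in the spirit of the `_build_label2id` helper in `pipeline.py`. The three-entry dictionary is an excerpt chosen for illustration, and `build_label2id` is a hypothetical standalone name, not part of the committed files.

```python
from typing import Dict


def build_label2id(id2label: Dict[str, str]) -> Dict[str, int]:
    """Invert an id -> "synonym, synonym, ..." mapping into synonym -> id."""
    label2id: Dict[str, int] = {}
    for class_id, value in id2label.items():
        for synonym in value.split(","):
            synonym = synonym.strip()
            if synonym:
                # A synonym that repeats under a later id overwrites the earlier one.
                label2id[synonym] = int(class_id)
    # Sort by synonym so the lookup table has a stable order.
    return dict(sorted(label2id.items()))


# Excerpt of the mapping above; keys are strings, as they are in JSON.
sample_id2label = {
    "638": "maillot",
    "639": "maillot, tank suit",
    "673": "mouse, computer mouse",
}

label2id = build_label2id(sample_id2label)
```

Because `maillot` is listed under both 638 and 639, the string resolves to 639 with this approach; passing the integer id avoids that ambiguity.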
SiT-B-2-256/pipeline.py CHANGED
@@ -1,82 +1,349 @@
1
- from typing import List, Optional, Union
2
-
3
- import torch
4
-
5
- from diffusers.image_processor import VaeImageProcessor
6
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
7
- from diffusers.utils.torch_utils import randn_tensor
8
-
9
-
10
- class SiTPipeline(DiffusionPipeline):
11
- model_cpu_offload_seq = "transformer->vae"
12
-
13
- def __init__(self, transformer, scheduler, vae):
14
- super().__init__()
15
- self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
16
- self.vae_scale_factor = 8
17
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
18
-
19
- @torch.no_grad()
20
- def __call__(
21
- self,
22
- class_labels: Union[int, List[int]] = 207,
23
- height: int = 256,
24
- width: int = 256,
25
- num_inference_steps: int = 250,
26
- guidance_scale: float = 4.0,
27
- generator: Optional[torch.Generator] = None,
28
- output_type: str = "pil",
29
- return_dict: bool = True,
30
- ):
31
- device = self._execution_device
32
- if isinstance(class_labels, int):
33
- class_labels = [class_labels]
34
- batch_size = len(class_labels)
35
-
36
- latent_h = height // self.vae_scale_factor
37
- latent_w = width // self.vae_scale_factor
38
- latents = randn_tensor(
39
- (batch_size, self.transformer.config.in_channels, latent_h, latent_w),
40
- generator=generator,
41
- device=device,
42
- dtype=self.transformer.dtype,
43
- )
44
-
45
- labels = torch.tensor(class_labels, device=device, dtype=torch.long)
46
- do_cfg = guidance_scale is not None and guidance_scale > 1.0
47
- if do_cfg:
48
- null_label = torch.full((batch_size,), self.transformer.config.num_classes, device=device, dtype=torch.long)
49
- labels = torch.cat([labels, null_label], dim=0)
50
-
51
- self.scheduler.set_timesteps(num_inference_steps, device=device)
52
- timesteps = self.scheduler.timesteps
53
-
54
- for t in self.progress_bar(timesteps):
55
- t_batch = torch.full((batch_size,), t, device=device, dtype=latents.dtype)
56
- model_input = latents
57
- if do_cfg:
58
- model_input = torch.cat([latents, latents], dim=0)
59
- t_batch = torch.cat([t_batch, t_batch], dim=0)
60
-
61
- model_pred = self.transformer(
62
- hidden_states=model_input,
63
- timestep=t_batch,
64
- class_labels=labels,
65
- ).sample
66
-
67
- if do_cfg:
68
- cond, uncond = model_pred.chunk(2, dim=0)
69
- model_pred = uncond + guidance_scale * (cond - uncond)
70
-
71
- latents = self.scheduler.step(model_pred, t, latents, generator=generator).prev_sample
72
-
73
- image = self.vae.decode(latents / 0.18215).sample
74
- # Keep PyTorch outputs in raw VAE range [-1, 1] to match original SiT scripts.
75
- if output_type == "pt":
76
- image = image
77
- else:
78
- image = self.image_processor.postprocess(image, output_type=output_type)
79
-
80
- if not return_dict:
81
- return (image,)
82
- return ImagePipelineOutput(images=image)
1
+ """Hub custom pipeline: SiTPipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ import json
22
+ from pathlib import Path
23
+ from typing import Dict, List, Optional, Tuple, Union
24
+
25
+ import torch
26
+
27
+ from diffusers.image_processor import VaeImageProcessor
28
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+
31
+ EXAMPLE_DOC_STRING = """
32
+ Examples:
33
+ ```py
34
+ >>> from pathlib import Path
35
+ >>> from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
36
+ >>> import torch
37
+
38
+ >>> model_dir = Path("./SiT-XL-2-256").resolve()
39
+ >>> pipe = DiffusionPipeline.from_pretrained(
40
+ ... str(model_dir),
41
+ ... local_files_only=True,
42
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
43
+ ... trust_remote_code=True,
44
+ ... torch_dtype=torch.bfloat16,
45
+ ... )
46
+ >>> pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
47
+ >>> pipe.to("cuda")
48
+
49
+ >>> print(pipe.id2label[207])
50
+ >>> print(pipe.get_label_ids("golden retriever"))
51
+
52
+ >>> generator = torch.Generator(device="cuda").manual_seed(42)
53
+ >>> image = pipe(
54
+ ... class_labels="golden retriever",
55
+ ... height=256,
56
+ ... width=256,
57
+ ... num_inference_steps=250,
58
+ ... guidance_scale=4.0,
59
+ ... generator=generator,
60
+ ... ).images[0]
61
+ ```
62
+ """
63
+
64
+ class SiTPipeline(DiffusionPipeline):
65
+ r"""
66
+ Pipeline for class-conditional image generation with Scalable Interpolant Transformers (SiT).
67
+
68
+ Parameters:
69
+ transformer ([`SiTTransformer2DModel`]):
70
+ Class-conditional SiT transformer that predicts flow-matching velocity in latent space.
71
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
72
+ Flow-matching Euler scheduler. The denoising loop assumes its sigma convention (1 at noise, 0 at data) and deterministic stepping (`stochastic_sampling=False`).
73
+ vae ([`AutoencoderKL`]):
74
+ Variational autoencoder used to decode transformer latents to pixels.
75
+ id2label (`dict[int, str]`, *optional*):
76
+ ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
77
+ """
78
+
79
+ model_cpu_offload_seq = "transformer->vae"
80
+
81
+ def __init__(
82
+ self,
83
+ transformer,
84
+ scheduler,
85
+ vae,
86
+ id2label: Optional[Dict[Union[int, str], str]] = None,
87
+ ):
88
+ super().__init__()
89
+ if scheduler is None:
90
+ scheduler = FlowMatchEulerDiscreteScheduler(
91
+ num_train_timesteps=1000,
92
+ shift=1.0,
93
+ stochastic_sampling=False,
94
+ )
95
+ self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
96
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
97
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
98
+ self._id2label = self._normalize_id2label(id2label)
99
+ self.labels = self._build_label2id(self._id2label)
100
+ self._labels_loaded_from_model_index = bool(self._id2label)
101
+
102
+ def _ensure_labels_loaded(self) -> None:
103
+ if self._labels_loaded_from_model_index:
104
+ return
105
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
106
+ if loaded:
107
+ self._id2label = loaded
108
+ self.labels = self._build_label2id(self._id2label)
109
+ self._labels_loaded_from_model_index = True
110
+
111
+ @staticmethod
112
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
113
+ if not id2label:
114
+ return {}
115
+ return {int(key): value for key, value in id2label.items()}
116
+
117
+ @staticmethod
118
+ def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
119
+ if not variant_path:
120
+ return {}
121
+ variant_dir = Path(variant_path).resolve()
122
+ model_index_path = variant_dir / "model_index.json"
123
+ if not model_index_path.exists():
124
+ return {}
125
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
126
+ id2label = raw.get("id2label")
127
+ if not isinstance(id2label, dict):
128
+ return {}
129
+ return {int(key): value for key, value in id2label.items()}
130
+
131
+ @staticmethod
132
+ def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
133
+ label2id: Dict[str, int] = {}
134
+ for class_id, value in id2label.items():
135
+ for synonym in value.split(","):
136
+ synonym = synonym.strip()
137
+ if synonym:
138
+ label2id[synonym] = int(class_id)
139
+ return dict(sorted(label2id.items()))
140
+
141
+ @property
142
+ def id2label(self) -> Dict[int, str]:
143
+ r"""ImageNet class id to English label string (comma-separated synonyms)."""
144
+ self._ensure_labels_loaded()
145
+ return self._id2label
146
+
147
+ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
148
+ r"""
149
+ Map ImageNet label strings to class ids.
150
+
151
+ Args:
152
+ label (`str` or `list[str]`):
153
+ One or more English label strings. Each string must match a synonym in `id2label`.
154
+ """
155
+ self._ensure_labels_loaded()
156
+ label2id = self.labels
157
+ if not label2id:
158
+ raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
159
+
160
+ if isinstance(label, str):
161
+ label = [label]
162
+
163
+ missing = [item for item in label if item not in label2id]
164
+ if missing:
165
+ preview = ", ".join(list(label2id.keys())[:8])
166
+ raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
167
+ return [label2id[item] for item in label]
168
+
169
+ def _normalize_class_labels(
170
+ self,
171
+ class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
172
+ ) -> torch.LongTensor:
173
+ if torch.is_tensor(class_labels):
174
+ return class_labels.to(device=self._execution_device, dtype=torch.long).reshape(-1)
175
+
176
+ if isinstance(class_labels, int):
177
+ class_label_ids = [class_labels]
178
+ elif isinstance(class_labels, str):
179
+ class_label_ids = self.get_label_ids(class_labels)
180
+ elif class_labels and isinstance(class_labels[0], str):
181
+ class_label_ids = self.get_label_ids(class_labels)
182
+ else:
183
+ class_label_ids = list(class_labels)
184
+
185
+ return torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
186
+
187
+ def _default_image_size(self) -> int:
188
+ return int(self.transformer.config.input_size) * self.vae_scale_factor
189
+
190
+ def check_inputs(
191
+ self,
192
+ height: int,
193
+ width: int,
194
+ num_inference_steps: int,
195
+ output_type: str,
196
+ ) -> None:
197
+ if num_inference_steps < 1:
198
+ raise ValueError("num_inference_steps must be >= 1.")
199
+ if output_type not in {"pil", "np", "pt", "latent"}:
200
+ raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
201
+
202
+ if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
203
+ raise ValueError(
204
+ f"height and width must be divisible by the VAE downsample factor {self.vae_scale_factor}."
205
+ )
206
+
207
+ latent_height = height // self.vae_scale_factor
208
+ latent_width = width // self.vae_scale_factor
209
+ expected_size = int(self.transformer.config.input_size)
210
+ if latent_height != expected_size or latent_width != expected_size:
211
+ raise ValueError(
212
+ f"Requested latent size {(latent_height, latent_width)} does not match the pretrained "
213
+ f"transformer input_size={expected_size}. Use height=width={self._default_image_size()}."
214
+ )
215
+
216
+ def prepare_latents(
217
+ self,
218
+ batch_size: int,
219
+ height: int,
220
+ width: int,
221
+ dtype: torch.dtype,
222
+ device: torch.device,
223
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
224
+ ) -> torch.Tensor:
225
+ latent_height = height // self.vae_scale_factor
226
+ latent_width = width // self.vae_scale_factor
227
+ return randn_tensor(
228
+ (batch_size, self.transformer.config.in_channels, latent_height, latent_width),
229
+ generator=generator,
230
+ device=device,
231
+ dtype=dtype,
232
+ )
233
+
234
+ @staticmethod
235
+ def _apply_classifier_free_guidance(model_output: torch.Tensor, guidance_scale: float) -> torch.Tensor:
236
+ if guidance_scale <= 1.0:
237
+ return model_output
238
+ model_output_cond, model_output_uncond = model_output.chunk(2)
239
+ return model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
240
+
241
+ def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
242
+ if output_type == "latent":
243
+ return latents
244
+
245
+ scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
246
+ image = self.vae.decode(latents / scaling_factor).sample
247
+ if output_type == "pt":
248
+ return image
249
+ return self.image_processor.postprocess(image, output_type=output_type)
250
+
251
+ @torch.inference_mode()
252
+ def __call__(
253
+ self,
254
+ class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
255
+ height: Optional[int] = None,
256
+ width: Optional[int] = None,
257
+ num_inference_steps: int = 250,
258
+ guidance_scale: float = 4.0,
259
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
260
+ output_type: str = "pil",
261
+ return_dict: bool = True,
262
+ ) -> Union[ImagePipelineOutput, Tuple]:
263
+ r"""
264
+ Generate class-conditional images with SiT.
265
+
266
+ Args:
267
+ class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
268
+ ImageNet class indices or human-readable English label strings.
269
+ height (`int`, *optional*):
270
+ Output image height in pixels. Defaults to the pretrained native resolution.
271
+ width (`int`, *optional*):
272
+ Output image width in pixels. Defaults to the pretrained native resolution.
273
+ num_inference_steps (`int`, defaults to `250`):
274
+ Number of denoising steps.
275
+ guidance_scale (`float`, defaults to `4.0`):
276
+ Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
277
+ generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
277
+ Random generator(s) used to sample the initial latents, for reproducibility.
279
+ output_type (`str`, defaults to `"pil"`):
280
+ `"pil"`, `"np"`, `"pt"`, or `"latent"`.
281
+ return_dict (`bool`, defaults to `True`):
282
+ Whether to return an [`ImagePipelineOutput`] instead of a plain tuple.
283
+ """
284
+ default_size = self._default_image_size()
285
+ height = int(height or default_size)
286
+ width = int(width or default_size)
287
+ self.check_inputs(height, width, num_inference_steps, output_type)
288
+
289
+ device = self._execution_device
290
+ model_dtype = next(self.transformer.parameters()).dtype
291
+ class_labels_tensor = self._normalize_class_labels(class_labels)
292
+ batch_size = class_labels_tensor.numel()
293
+ do_cfg = guidance_scale > 1.0
294
+
295
+ latents = self.prepare_latents(
296
+ batch_size=batch_size,
297
+ height=height,
298
+ width=width,
299
+ dtype=model_dtype,
300
+ device=device,
301
+ generator=generator,
302
+ )
303
+
304
+ labels = class_labels_tensor
305
+ if do_cfg:
306
+ null_labels = torch.full_like(class_labels_tensor, self.transformer.config.num_classes)
307
+ labels = torch.cat([class_labels_tensor, null_labels], dim=0)
308
+
309
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
310
+ num_train_timesteps = self.scheduler.config.num_train_timesteps
311
+
312
+ if getattr(self.scheduler.config, "stochastic_sampling", False):
313
+ raise ValueError(
314
+ "SiT expects deterministic FlowMatchEulerDiscreteScheduler stepping "
315
+ "(scheduler.config.stochastic_sampling=False)."
316
+ )
317
+
318
+ for t in self.progress_bar(self.scheduler.timesteps):
319
+ flow_time = 1.0 - float(t) / num_train_timesteps
320
+ if do_cfg:
321
+ model_input = torch.cat([latents, latents], dim=0)
322
+ else:
323
+ model_input = latents
324
+
325
+ timestep_batch = torch.full((model_input.shape[0],), flow_time, device=device, dtype=model_dtype)
326
+ model_output = self.transformer(
327
+ hidden_states=model_input,
328
+ timestep=timestep_batch,
329
+ class_labels=labels,
330
+ return_dict=True,
331
+ ).sample
332
+ model_output = self._apply_classifier_free_guidance(model_output, guidance_scale=guidance_scale)
333
+ # SiT predicts dx/d(flow_time) with flow_time increasing from noise (0) to data (1).
334
+ # FlowMatchEulerDiscreteScheduler integrates over sigma decreasing from 1 to 0, so flip sign.
335
+ model_output = -model_output
336
+ latents = self.scheduler.step(
337
+ model_output=model_output,
338
+ timestep=t,
339
+ sample=latents,
340
+ generator=generator,
341
+ return_dict=True,
342
+ ).prev_sample
343
+
344
+ image = self.decode_latents(latents, output_type=output_type)
345
+
346
+ self.maybe_free_model_hooks()
347
+ if not return_dict:
348
+ return (image,)
349
+ return ImagePipelineOutput(images=image)
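The new `check_inputs` and `_default_image_size` methods tie the requested pixel size to the transformer's latent `input_size` through the VAE downsample factor. A minimal standalone sketch of that arithmetic follows. The `block_out_channels` tuple is an assumed four-block VAE configuration, `input_size=32` is the default shown in `transformer_sit.py`, and the function names are hypothetical.

```python
from typing import Sequence


def vae_scale_factor(block_out_channels: Sequence[int]) -> int:
    """Downsample factor: every VAE block after the first halves the resolution."""
    return 2 ** (len(block_out_channels) - 1)


def default_image_size(input_size: int, scale: int) -> int:
    """Native pixel resolution implied by the latent grid size."""
    return input_size * scale


def validate_size(height: int, width: int, input_size: int, scale: int) -> None:
    """Raise ValueError unless the request maps exactly onto the latent grid."""
    if height % scale != 0 or width % scale != 0:
        raise ValueError(f"height and width must be divisible by {scale}.")
    latent = (height // scale, width // scale)
    if latent != (input_size, input_size):
        raise ValueError(
            f"Latent size {latent} does not match input_size={input_size}; "
            f"use height=width={default_image_size(input_size, scale)}."
        )


# Assumed VAE layout with four blocks, giving a factor of 8.
scale = vae_scale_factor((128, 256, 512, 512))
native = default_image_size(32, scale)
validate_size(native, native, 32, scale)  # passes silently
```

Under these assumptions a 256x256 checkpoint rejects a 512x512 request rather than silently running the transformer on a latent grid it was not trained for.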
SiT-B-2-256/scheduler/scheduler_config.json CHANGED
@@ -1,9 +1,7 @@
1
- {
2
- "_class_name": "SiTFlowMatchScheduler",
3
- "_diffusers_version": "0.36.0",
4
- "diffusion_form": "sigma",
5
- "diffusion_norm": 1.0,
6
- "mode": "ode",
7
- "num_train_timesteps": 1000,
8
- "shift": 1.0
9
- }
 
1
+ {
2
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
+ "shift": 1.0,
6
+ "stochastic_sampling": false
7
+ }
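The pipeline comments explain that SiT predicts a velocity with respect to a flow time that increases from noise (0) to data (1), while the scheduler integrates over sigma decreasing from 1 to 0, so the model output is negated before stepping. The toy sketch below illustrates why the sign flip is needed. It assumes the deterministic update has the form `x + (sigma_next - sigma) * model_output`, uses evenly spaced sigmas as a simplification of the real schedule, and replaces the transformer with the exact constant velocity of a straight-line path between two scalars; all names are hypothetical.

```python
from typing import List


def make_sigmas(num_steps: int) -> List[float]:
    """Evenly spaced sigmas from 1.0 (pure noise) down to 0.0 (data)."""
    return [1.0 - i / num_steps for i in range(num_steps + 1)]


def sample(noise: float, data: float, num_steps: int) -> float:
    """Euler-integrate a straight-line flow from noise to data over sigma."""
    # Stand-in for the model: for x(tau) = (1 - tau) * noise + tau * data,
    # the velocity dx/d(tau) is the constant (data - noise).
    velocity = data - noise
    x = noise
    sigmas = make_sigmas(num_steps)
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        # sigma = 1 - tau, so d(sigma) = -d(tau): negate the velocity
        # before applying a sigma-based Euler step.
        model_output = -velocity
        x = x + (sigma_next - sigma) * model_output
    return x


result = sample(noise=-0.7, data=1.3, num_steps=10)
```

With the negation, the integration ends at the data value (up to floating-point error); without it, the same loop would move away from the data instead.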
 
 
SiT-B-2-256/transformer/transformer_sit.py CHANGED
@@ -1,224 +1,240 @@
1
- import math
2
- from dataclasses import dataclass
3
- from typing import Optional
4
-
5
- import numpy as np
6
- import torch
7
- import torch.nn as nn
8
- from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
9
-
10
- from diffusers.configuration_utils import ConfigMixin, register_to_config
11
- from diffusers.models.modeling_utils import ModelMixin
12
- from diffusers.utils import BaseOutput
13
-
14
-
15
- def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
16
- return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
17
-
18
-
19
- @dataclass
20
- class SiTTransformer2DModelOutput(BaseOutput):
21
- sample: torch.Tensor
22
-
23
-
24
- class TimestepEmbedder(nn.Module):
25
- def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
26
- super().__init__()
27
- self.mlp = nn.Sequential(
28
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
29
- nn.SiLU(),
30
- nn.Linear(hidden_size, hidden_size, bias=True),
31
- )
32
- self.frequency_embedding_size = frequency_embedding_size
33
-
34
- @staticmethod
35
- def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
36
- half = dim // 2
37
- freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
38
- device=t.device
39
- )
40
- args = t[:, None].float() * freqs[None]
41
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
42
- if dim % 2:
43
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
44
- return embedding
45
-
46
- def forward(self, t: torch.Tensor) -> torch.Tensor:
47
- return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
48
-
49
-
50
- class LabelEmbedder(nn.Module):
51
- def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
52
- super().__init__()
53
- use_cfg_embedding = dropout_prob > 0
54
- self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
55
- self.num_classes = num_classes
56
- self.dropout_prob = dropout_prob
57
-
58
- def token_drop(self, labels: torch.Tensor, force_drop_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
59
- if force_drop_ids is None:
60
- drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
61
- else:
62
- drop_ids = force_drop_ids == 1
63
- labels = torch.where(drop_ids, self.num_classes, labels)
64
- return labels
65
-
66
- def forward(
67
- self,
68
- labels: torch.Tensor,
69
- train: bool,
70
- force_drop_ids: Optional[torch.Tensor] = None,
71
- ) -> torch.Tensor:
72
- use_dropout = self.dropout_prob > 0
73
- if (train and use_dropout) or (force_drop_ids is not None):
74
- labels = self.token_drop(labels, force_drop_ids)
75
- return self.embedding_table(labels)
76
-
77
-
78
- class SiTBlock(nn.Module):
79
- def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, **block_kwargs):
80
- super().__init__()
81
- self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
82
- self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
83
- self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
84
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
85
- approx_gelu = lambda: nn.GELU(approximate="tanh")
86
- self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
87
- self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
88
-
89
- def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
90
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
91
- x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
92
- x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
93
- return x
94
-
95
-
96
- class FinalLayer(nn.Module):
97
- def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
98
- super().__init__()
99
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
100
- self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
101
- self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
102
-
103
- def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
104
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
105
- x = modulate(self.norm_final(x), shift, scale)
106
- return self.linear(x)
107
-
108
-
109
- class SiTTransformer2DModel(ModelMixin, ConfigMixin):
110
- @register_to_config
111
- def __init__(
112
- self,
113
- input_size: int = 32,
114
- patch_size: int = 2,
115
- in_channels: int = 4,
116
- hidden_size: int = 1152,
117
- depth: int = 28,
118
- num_heads: int = 16,
119
- mlp_ratio: float = 4.0,
120
- class_dropout_prob: float = 0.1,
121
- num_classes: int = 1000,
122
- learn_sigma: bool = True,
123
- ):
124
- super().__init__()
125
- self.learn_sigma = learn_sigma
126
- self.in_channels = in_channels
127
- self.out_channels = in_channels * 2 if learn_sigma else in_channels
128
- self.patch_size = patch_size
129
- self.num_classes = num_classes
130
-
131
- self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
132
- self.t_embedder = TimestepEmbedder(hidden_size)
133
- self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
134
- num_patches = self.x_embedder.num_patches
135
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
136
-
137
- self.blocks = nn.ModuleList([SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
138
- self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
139
- self.initialize_weights()
140
-
141
- def initialize_weights(self) -> None:
142
- def _basic_init(module: nn.Module):
143
- if isinstance(module, nn.Linear):
144
- torch.nn.init.xavier_uniform_(module.weight)
145
- if module.bias is not None:
146
- nn.init.constant_(module.bias, 0)
147
-
148
- self.apply(_basic_init)
149
- pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
150
- self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
151
-
152
- w = self.x_embedder.proj.weight.data
153
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
154
- nn.init.constant_(self.x_embedder.proj.bias, 0)
155
- nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
156
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
157
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
158
- for block in self.blocks:
159
- nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
160
- nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
161
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
162
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
163
- nn.init.constant_(self.final_layer.linear.weight, 0)
164
- nn.init.constant_(self.final_layer.linear.bias, 0)
165
-
166
- def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
167
- c = self.out_channels
168
- p = self.x_embedder.patch_size[0]
169
- h = w = int(x.shape[1] ** 0.5)
170
- x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
171
- x = torch.einsum("nhwpqc->nchpwq", x)
172
- return x.reshape(shape=(x.shape[0], c, h * p, h * p))
173
-
174
- def forward(
175
- self,
176
- hidden_states: torch.Tensor,
177
- timestep: torch.Tensor,
178
- class_labels: torch.Tensor,
179
- force_drop_ids: Optional[torch.Tensor] = None,
180
- return_dict: bool = True,
181
- ) -> SiTTransformer2DModelOutput:
182
- x = self.x_embedder(hidden_states) + self.pos_embed
183
- t = self.t_embedder(timestep)
184
- y = self.y_embedder(class_labels, self.training, force_drop_ids=force_drop_ids)
185
- c = t + y
186
- for block in self.blocks:
187
- x = block(x, c)
188
- x = self.final_layer(x, c)
189
- x = self.unpatchify(x)
190
- if self.learn_sigma:
191
- x, _ = x.chunk(2, dim=1)
192
- if not return_dict:
193
- return (x,)
194
- return SiTTransformer2DModelOutput(sample=x)
195
-
196
-
197
- def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False, extra_tokens: int = 0):
198
- grid_h = np.arange(grid_size, dtype=np.float32)
199
- grid_w = np.arange(grid_size, dtype=np.float32)
200
- grid = np.meshgrid(grid_w, grid_h)
201
- grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])
202
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
203
- if cls_token and extra_tokens > 0:
204
- pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
205
- return pos_embed
206
-
207
-
208
- def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray):
209
- assert embed_dim % 2 == 0
210
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
211
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
212
- return np.concatenate([emb_h, emb_w], axis=1)
213
-
214
-
215
- def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray):
216
- assert embed_dim % 2 == 0
217
- omega = np.arange(embed_dim // 2, dtype=np.float64)
218
- omega /= embed_dim / 2.0
219
- omega = 1.0 / 10000**omega
220
- pos = pos.reshape(-1)
221
- out = np.einsum("m,d->md", pos, omega)
222
- emb_sin = np.sin(out)
223
- emb_cos = np.cos(out)
224
- return np.concatenate([emb_sin, emb_cos], axis=1)
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import Optional
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
23
+
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+ from diffusers.utils import BaseOutput
27
+
28
+
29
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
30
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
31
+
32
+
33
+ @dataclass
34
+ class SiTTransformer2DModelOutput(BaseOutput):
35
+ sample: torch.Tensor
36
+
37
+
38
+ class TimestepEmbedder(nn.Module):
39
+ def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
40
+ super().__init__()
41
+ self.mlp = nn.Sequential(
42
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
43
+ nn.SiLU(),
44
+ nn.Linear(hidden_size, hidden_size, bias=True),
45
+ )
46
+ self.frequency_embedding_size = frequency_embedding_size
47
+
48
+ @staticmethod
49
+ def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
50
+ half = dim // 2
51
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
52
+ device=t.device
53
+ )
54
+ args = t[:, None].float() * freqs[None]
55
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
56
+ if dim % 2:
57
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
58
+ return embedding
59
+
60
+ def forward(self, t: torch.Tensor) -> torch.Tensor:
61
+ emb = self.timestep_embedding(t.float(), self.frequency_embedding_size)
62
+ weight_dtype = self.mlp[0].weight.dtype
63
+ return self.mlp(emb.to(dtype=weight_dtype))
64
+
65
+
66
+ class LabelEmbedder(nn.Module):
67
+ def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
68
+ super().__init__()
69
+ use_cfg_embedding = dropout_prob > 0
70
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
71
+ self.num_classes = num_classes
72
+ self.dropout_prob = dropout_prob
73
+
74
+ def token_drop(self, labels: torch.Tensor, force_drop_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
75
+ if force_drop_ids is None:
76
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
77
+ else:
78
+ drop_ids = force_drop_ids == 1
79
+ labels = torch.where(drop_ids, self.num_classes, labels)
80
+ return labels
81
+
82
+ def forward(
83
+ self,
84
+ labels: torch.Tensor,
85
+ train: bool,
86
+ force_drop_ids: Optional[torch.Tensor] = None,
87
+ ) -> torch.Tensor:
88
+ use_dropout = self.dropout_prob > 0
89
+ if (train and use_dropout) or (force_drop_ids is not None):
90
+ labels = self.token_drop(labels, force_drop_ids)
91
+ return self.embedding_table(labels)
92
+
93
+
94
+ class SiTBlock(nn.Module):
95
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, **block_kwargs):
96
+ super().__init__()
97
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
98
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
99
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
100
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
101
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
102
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
103
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
104
+
105
+ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
106
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
107
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
108
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
109
+ return x
110
+
111
+
112
+ class FinalLayer(nn.Module):
113
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
114
+ super().__init__()
115
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
116
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
117
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
118
+
119
+ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
120
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
121
+ x = modulate(self.norm_final(x), shift, scale)
122
+ return self.linear(x)
123
+
124
+
125
+ class SiTTransformer2DModel(ModelMixin, ConfigMixin):
126
+ @register_to_config
127
+ def __init__(
128
+ self,
129
+ input_size: int = 32,
130
+ patch_size: int = 2,
131
+ in_channels: int = 4,
132
+ hidden_size: int = 1152,
133
+ depth: int = 28,
134
+ num_heads: int = 16,
135
+ mlp_ratio: float = 4.0,
136
+ class_dropout_prob: float = 0.1,
137
+ num_classes: int = 1000,
138
+ learn_sigma: bool = True,
139
+ ):
140
+ super().__init__()
141
+ self.learn_sigma = learn_sigma
142
+ self.in_channels = in_channels
143
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
144
+ self.patch_size = patch_size
145
+ self.num_classes = num_classes
146
+
147
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
148
+ self.t_embedder = TimestepEmbedder(hidden_size)
149
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
150
+ num_patches = self.x_embedder.num_patches
151
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
152
+
153
+ self.blocks = nn.ModuleList([SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
154
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
155
+ self.initialize_weights()
156
+
157
+ def initialize_weights(self) -> None:
158
+ def _basic_init(module: nn.Module):
159
+ if isinstance(module, nn.Linear):
160
+ torch.nn.init.xavier_uniform_(module.weight)
161
+ if module.bias is not None:
162
+ nn.init.constant_(module.bias, 0)
163
+
164
+ self.apply(_basic_init)
165
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
166
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
167
+
168
+ w = self.x_embedder.proj.weight.data
169
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
170
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
171
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
172
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
173
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
174
+ for block in self.blocks:
175
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
176
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
177
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
178
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
179
+ nn.init.constant_(self.final_layer.linear.weight, 0)
180
+ nn.init.constant_(self.final_layer.linear.bias, 0)
181
+
182
+ def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
183
+ c = self.out_channels
184
+ p = self.x_embedder.patch_size[0]
185
+ h = w = int(x.shape[1] ** 0.5)
186
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
187
+ x = torch.einsum("nhwpqc->nchpwq", x)
188
+ return x.reshape(shape=(x.shape[0], c, h * p, h * p))
189
+
190
+ def forward(
191
+ self,
192
+ hidden_states: torch.Tensor,
193
+ timestep: torch.Tensor,
194
+ class_labels: torch.Tensor,
195
+ force_drop_ids: Optional[torch.Tensor] = None,
196
+ return_dict: bool = True,
197
+ ) -> SiTTransformer2DModelOutput:
198
+ x = self.x_embedder(hidden_states) + self.pos_embed
199
+ t = self.t_embedder(timestep)
200
+ y = self.y_embedder(class_labels, self.training, force_drop_ids=force_drop_ids)
201
+ c = t + y
202
+ for block in self.blocks:
203
+ x = block(x, c)
204
+ x = self.final_layer(x, c)
205
+ x = self.unpatchify(x)
206
+ if self.learn_sigma:
207
+ x, _ = x.chunk(2, dim=1)
208
+ if not return_dict:
209
+ return (x,)
210
+ return SiTTransformer2DModelOutput(sample=x)
211
+
212
+
213
+ def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False, extra_tokens: int = 0):
214
+ grid_h = np.arange(grid_size, dtype=np.float32)
215
+ grid_w = np.arange(grid_size, dtype=np.float32)
216
+ grid = np.meshgrid(grid_w, grid_h)
217
+ grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])
218
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
219
+ if cls_token and extra_tokens > 0:
220
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
221
+ return pos_embed
222
+
223
+
224
+ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray):
225
+ assert embed_dim % 2 == 0
226
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
227
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
228
+ return np.concatenate([emb_h, emb_w], axis=1)
229
+
230
+
231
+ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray):
232
+ assert embed_dim % 2 == 0
233
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
234
+ omega /= embed_dim / 2.0
235
+ omega = 1.0 / 10000**omega
236
+ pos = pos.reshape(-1)
237
+ out = np.einsum("m,d->md", pos, omega)
238
+ emb_sin = np.sin(out)
239
+ emb_cos = np.cos(out)
240
+ return np.concatenate([emb_sin, emb_cos], axis=1)
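Review note: the 1D sin/cos helper at the end of `transformer_sit.py` can be checked without NumPy or PyTorch. The sketch below is a pure-Python mirror of `get_1d_sincos_pos_embed_from_grid`, written only for illustration; the name `sincos_1d` and the list-based return type are assumptions and are not part of this commit. It uses the same frequency schedule, `omega_k = 1 / 10000**(k / (embed_dim / 2))`, and the same layout, with all sine terms first and all cosine terms second.

```python
import math


def sincos_1d(embed_dim: int, positions: list[float]) -> list[list[float]]:
    """Hypothetical pure-Python mirror of get_1d_sincos_pos_embed_from_grid."""
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")
    half = embed_dim // 2
    # Frequencies decay geometrically from 1 down towards 1/10000.
    omega = [1.0 / (10000.0 ** (k / half)) for k in range(half)]
    rows = []
    for pos in positions:
        angles = [pos * w for w in omega]
        # Sine half first, cosine half second, as in the file above.
        rows.append([math.sin(a) for a in angles] + [math.cos(a) for a in angles])
    return rows


# Three positions, eight channels per position.
emb = sincos_1d(8, [0.0, 1.0, 2.0])
```

Position 0 maps to zeros in the sine half and ones in the cosine half. In the 2D case the file concatenates two such embeddings, one per grid axis, each using `embed_dim // 2` channels.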
SiT-L-2-256/__pycache__/pipeline.cpython-312.pyc ADDED
Binary file (7.51 kB).
 
SiT-L-2-256/model_index.json CHANGED
@@ -1,19 +1,1021 @@
1
- {
2
- "_class_name": [
3
- "pipeline",
4
- "SiTPipeline"
5
- ],
6
- "_diffusers_version": "0.36.0",
7
- "scheduler": [
8
- "scheduling_flow_match_sit",
9
- "SiTFlowMatchScheduler"
10
- ],
11
- "transformer": [
12
- "transformer_sit",
13
- "SiTTransformer2DModel"
14
- ],
15
- "vae": [
16
- "diffusers",
17
- "AutoencoderKL"
18
- ]
19
- }
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "SiTPipeline"
5
+ ],
6
+ "_diffusers_version": "0.36.0",
7
+ "scheduler": [
8
+ "diffusers",
9
+ "FlowMatchEulerDiscreteScheduler"
10
+ ],
11
+ "transformer": [
12
+ "transformer_sit",
13
+ "SiTTransformer2DModel"
14
+ ],
15
+ "vae": [
16
+ "diffusers",
17
+ "AutoencoderKL"
18
+ ],
19
+ "id2label": {
20
+ "0": "tench, Tinca tinca",
21
+ "1": "goldfish, Carassius auratus",
22
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
23
+ "3": "tiger shark, Galeocerdo cuvieri",
24
+ "4": "hammerhead, hammerhead shark",
25
+ "5": "electric ray, crampfish, numbfish, torpedo",
26
+ "6": "stingray",
27
+ "7": "cock",
28
+ "8": "hen",
29
+ "9": "ostrich, Struthio camelus",
30
+ "10": "brambling, Fringilla montifringilla",
31
+ "11": "goldfinch, Carduelis carduelis",
32
+ "12": "house finch, linnet, Carpodacus mexicanus",
33
+ "13": "junco, snowbird",
34
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
35
+ "15": "robin, American robin, Turdus migratorius",
36
+ "16": "bulbul",
37
+ "17": "jay",
38
+ "18": "magpie",
39
+ "19": "chickadee",
40
+ "20": "water ouzel, dipper",
41
+ "21": "kite",
42
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
43
+ "23": "vulture",
44
+ "24": "great grey owl, great gray owl, Strix nebulosa",
45
+ "25": "European fire salamander, Salamandra salamandra",
46
+ "26": "common newt, Triturus vulgaris",
47
+ "27": "eft",
48
+ "28": "spotted salamander, Ambystoma maculatum",
49
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
50
+ "30": "bullfrog, Rana catesbeiana",
51
+ "31": "tree frog, tree-frog",
52
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
53
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
54
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
55
+ "35": "mud turtle",
56
+ "36": "terrapin",
57
+ "37": "box turtle, box tortoise",
58
+ "38": "banded gecko",
59
+ "39": "common iguana, iguana, Iguana iguana",
60
+ "40": "American chameleon, anole, Anolis carolinensis",
61
+ "41": "whiptail, whiptail lizard",
62
+ "42": "agama",
63
+ "43": "frilled lizard, Chlamydosaurus kingi",
64
+ "44": "alligator lizard",
65
+ "45": "Gila monster, Heloderma suspectum",
66
+ "46": "green lizard, Lacerta viridis",
67
+ "47": "African chameleon, Chamaeleo chamaeleon",
68
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
69
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
70
+ "50": "American alligator, Alligator mississipiensis",
71
+ "51": "triceratops",
72
+ "52": "thunder snake, worm snake, Carphophis amoenus",
73
+ "53": "ringneck snake, ring-necked snake, ring snake",
74
+ "54": "hognose snake, puff adder, sand viper",
75
+ "55": "green snake, grass snake",
76
+ "56": "king snake, kingsnake",
77
+ "57": "garter snake, grass snake",
78
+ "58": "water snake",
79
+ "59": "vine snake",
80
+ "60": "night snake, Hypsiglena torquata",
81
+ "61": "boa constrictor, Constrictor constrictor",
82
+ "62": "rock python, rock snake, Python sebae",
83
+ "63": "Indian cobra, Naja naja",
84
+ "64": "green mamba",
85
+ "65": "sea snake",
86
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
87
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
88
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
89
+ "69": "trilobite",
90
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
91
+ "71": "scorpion",
92
+ "72": "black and gold garden spider, Argiope aurantia",
93
+ "73": "barn spider, Araneus cavaticus",
94
+ "74": "garden spider, Aranea diademata",
95
+ "75": "black widow, Latrodectus mactans",
96
+ "76": "tarantula",
97
+ "77": "wolf spider, hunting spider",
98
+ "78": "tick",
99
+ "79": "centipede",
100
+ "80": "black grouse",
101
+ "81": "ptarmigan",
102
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
103
+ "83": "prairie chicken, prairie grouse, prairie fowl",
104
+ "84": "peacock",
105
+ "85": "quail",
106
+ "86": "partridge",
107
+ "87": "African grey, African gray, Psittacus erithacus",
108
+ "88": "macaw",
109
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
110
+ "90": "lorikeet",
111
+ "91": "coucal",
112
+ "92": "bee eater",
113
+ "93": "hornbill",
114
+ "94": "hummingbird",
115
+ "95": "jacamar",
116
+ "96": "toucan",
117
+ "97": "drake",
118
+ "98": "red-breasted merganser, Mergus serrator",
119
+ "99": "goose",
120
+ "100": "black swan, Cygnus atratus",
121
+ "101": "tusker",
122
+ "102": "echidna, spiny anteater, anteater",
123
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
124
+ "104": "wallaby, brush kangaroo",
125
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
126
+ "106": "wombat",
127
+ "107": "jellyfish",
128
+ "108": "sea anemone, anemone",
129
+ "109": "brain coral",
130
+ "110": "flatworm, platyhelminth",
131
+ "111": "nematode, nematode worm, roundworm",
132
+ "112": "conch",
133
+ "113": "snail",
134
+ "114": "slug",
135
+ "115": "sea slug, nudibranch",
136
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
137
+ "117": "chambered nautilus, pearly nautilus, nautilus",
138
+ "118": "Dungeness crab, Cancer magister",
139
+ "119": "rock crab, Cancer irroratus",
140
+ "120": "fiddler crab",
141
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
142
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
143
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
144
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
145
+ "125": "hermit crab",
146
+ "126": "isopod",
147
+ "127": "white stork, Ciconia ciconia",
148
+ "128": "black stork, Ciconia nigra",
149
+ "129": "spoonbill",
150
+ "130": "flamingo",
151
+ "131": "little blue heron, Egretta caerulea",
152
+ "132": "American egret, great white heron, Egretta albus",
153
+ "133": "bittern",
154
+ "134": "crane",
155
+ "135": "limpkin, Aramus pictus",
156
+ "136": "European gallinule, Porphyrio porphyrio",
157
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
158
+ "138": "bustard",
159
+ "139": "ruddy turnstone, Arenaria interpres",
160
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
161
+ "141": "redshank, Tringa totanus",
162
+ "142": "dowitcher",
163
+ "143": "oystercatcher, oyster catcher",
164
+ "144": "pelican",
165
+ "145": "king penguin, Aptenodytes patagonica",
166
+ "146": "albatross, mollymawk",
167
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
168
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
169
+ "149": "dugong, Dugong dugon",
170
+ "150": "sea lion",
171
+ "151": "Chihuahua",
172
+ "152": "Japanese spaniel",
173
+ "153": "Maltese dog, Maltese terrier, Maltese",
174
+ "154": "Pekinese, Pekingese, Peke",
175
+ "155": "Shih-Tzu",
176
+ "156": "Blenheim spaniel",
177
+ "157": "papillon",
178
+ "158": "toy terrier",
179
+ "159": "Rhodesian ridgeback",
180
+ "160": "Afghan hound, Afghan",
181
+ "161": "basset, basset hound",
182
+ "162": "beagle",
183
+ "163": "bloodhound, sleuthhound",
184
+ "164": "bluetick",
185
+ "165": "black-and-tan coonhound",
186
+ "166": "Walker hound, Walker foxhound",
187
+ "167": "English foxhound",
188
+ "168": "redbone",
189
+ "169": "borzoi, Russian wolfhound",
190
+ "170": "Irish wolfhound",
191
+ "171": "Italian greyhound",
192
+ "172": "whippet",
193
+ "173": "Ibizan hound, Ibizan Podenco",
194
+ "174": "Norwegian elkhound, elkhound",
195
+ "175": "otterhound, otter hound",
196
+ "176": "Saluki, gazelle hound",
197
+ "177": "Scottish deerhound, deerhound",
198
+ "178": "Weimaraner",
199
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
200
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
201
+ "181": "Bedlington terrier",
202
+ "182": "Border terrier",
203
+ "183": "Kerry blue terrier",
204
+ "184": "Irish terrier",
205
+ "185": "Norfolk terrier",
206
+ "186": "Norwich terrier",
207
+ "187": "Yorkshire terrier",
208
+ "188": "wire-haired fox terrier",
209
+ "189": "Lakeland terrier",
210
+ "190": "Sealyham terrier, Sealyham",
211
+ "191": "Airedale, Airedale terrier",
212
+ "192": "cairn, cairn terrier",
213
+ "193": "Australian terrier",
214
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
215
+ "195": "Boston bull, Boston terrier",
216
+ "196": "miniature schnauzer",
217
+ "197": "giant schnauzer",
218
+ "198": "standard schnauzer",
219
+ "199": "Scotch terrier, Scottish terrier, Scottie",
220
+ "200": "Tibetan terrier, chrysanthemum dog",
221
+ "201": "silky terrier, Sydney silky",
222
+ "202": "soft-coated wheaten terrier",
223
+ "203": "West Highland white terrier",
224
+ "204": "Lhasa, Lhasa apso",
225
+ "205": "flat-coated retriever",
226
+ "206": "curly-coated retriever",
227
+ "207": "golden retriever",
228
+ "208": "Labrador retriever",
229
+ "209": "Chesapeake Bay retriever",
230
+ "210": "German short-haired pointer",
231
+ "211": "vizsla, Hungarian pointer",
232
+ "212": "English setter",
233
+ "213": "Irish setter, red setter",
234
+ "214": "Gordon setter",
235
+ "215": "Brittany spaniel",
236
+ "216": "clumber, clumber spaniel",
237
+ "217": "English springer, English springer spaniel",
238
+ "218": "Welsh springer spaniel",
239
+ "219": "cocker spaniel, English cocker spaniel, cocker",
240
+ "220": "Sussex spaniel",
241
+ "221": "Irish water spaniel",
242
+ "222": "kuvasz",
243
+ "223": "schipperke",
244
+ "224": "groenendael",
245
+ "225": "malinois",
246
+ "226": "briard",
247
+ "227": "kelpie",
248
+ "228": "komondor",
249
+ "229": "Old English sheepdog, bobtail",
250
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
251
+ "231": "collie",
252
+ "232": "Border collie",
253
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
254
+ "234": "Rottweiler",
255
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
256
+ "236": "Doberman, Doberman pinscher",
257
+ "237": "miniature pinscher",
258
+ "238": "Greater Swiss Mountain dog",
259
+ "239": "Bernese mountain dog",
260
+ "240": "Appenzeller",
261
+ "241": "EntleBucher",
262
+ "242": "boxer",
263
+ "243": "bull mastiff",
264
+ "244": "Tibetan mastiff",
265
+ "245": "French bulldog",
266
+ "246": "Great Dane",
267
+ "247": "Saint Bernard, St Bernard",
268
+ "248": "Eskimo dog, husky",
269
+ "249": "malamute, malemute, Alaskan malamute",
270
+ "250": "Siberian husky",
271
+ "251": "dalmatian, coach dog, carriage dog",
272
+ "252": "affenpinscher, monkey pinscher, monkey dog",
273
+ "253": "basenji",
274
+ "254": "pug, pug-dog",
275
+ "255": "Leonberg",
276
+ "256": "Newfoundland, Newfoundland dog",
277
+ "257": "Great Pyrenees",
278
+ "258": "Samoyed, Samoyede",
279
+ "259": "Pomeranian",
280
+ "260": "chow, chow chow",
281
+ "261": "keeshond",
282
+ "262": "Brabancon griffon",
283
+ "263": "Pembroke, Pembroke Welsh corgi",
284
+ "264": "Cardigan, Cardigan Welsh corgi",
285
+ "265": "toy poodle",
286
+ "266": "miniature poodle",
287
+ "267": "standard poodle",
288
+ "268": "Mexican hairless",
289
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
290
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
291
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
292
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
293
+ "273": "dingo, warrigal, warragal, Canis dingo",
294
+ "274": "dhole, Cuon alpinus",
295
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
296
+ "276": "hyena, hyaena",
297
+ "277": "red fox, Vulpes vulpes",
298
+ "278": "kit fox, Vulpes macrotis",
299
+ "279": "Arctic fox, white fox, Alopex lagopus",
300
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
301
+ "281": "tabby, tabby cat",
302
+ "282": "tiger cat",
303
+ "283": "Persian cat",
304
+ "284": "Siamese cat, Siamese",
305
+ "285": "Egyptian cat",
306
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
307
+ "287": "lynx, catamount",
308
+ "288": "leopard, Panthera pardus",
309
+ "289": "snow leopard, ounce, Panthera uncia",
310
+ "290": "jaguar, panther, Panthera onca, Felis onca",
311
+ "291": "lion, king of beasts, Panthera leo",
312
+ "292": "tiger, Panthera tigris",
313
+ "293": "cheetah, chetah, Acinonyx jubatus",
314
+ "294": "brown bear, bruin, Ursus arctos",
315
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
316
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
317
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
318
+ "298": "mongoose",
319
+ "299": "meerkat, mierkat",
320
+ "300": "tiger beetle",
321
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
322
+ "302": "ground beetle, carabid beetle",
323
+ "303": "long-horned beetle, longicorn, longicorn beetle",
324
+ "304": "leaf beetle, chrysomelid",
325
+ "305": "dung beetle",
326
+ "306": "rhinoceros beetle",
327
+ "307": "weevil",
328
+ "308": "fly",
329
+ "309": "bee",
330
+ "310": "ant, emmet, pismire",
331
+ "311": "grasshopper, hopper",
332
+ "312": "cricket",
333
+ "313": "walking stick, walkingstick, stick insect",
334
+ "314": "cockroach, roach",
335
+ "315": "mantis, mantid",
336
+ "316": "cicada, cicala",
337
+ "317": "leafhopper",
338
+ "318": "lacewing, lacewing fly",
339
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
340
+ "320": "damselfly",
341
+ "321": "admiral",
342
+ "322": "ringlet, ringlet butterfly",
343
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
344
+ "324": "cabbage butterfly",
345
+ "325": "sulphur butterfly, sulfur butterfly",
346
+ "326": "lycaenid, lycaenid butterfly",
347
+ "327": "starfish, sea star",
348
+ "328": "sea urchin",
349
+ "329": "sea cucumber, holothurian",
350
+ "330": "wood rabbit, cottontail, cottontail rabbit",
351
+ "331": "hare",
352
+ "332": "Angora, Angora rabbit",
353
+ "333": "hamster",
354
+ "334": "porcupine, hedgehog",
355
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
356
+ "336": "marmot",
357
+ "337": "beaver",
358
+ "338": "guinea pig, Cavia cobaya",
359
+ "339": "sorrel",
360
+ "340": "zebra",
361
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
362
+ "342": "wild boar, boar, Sus scrofa",
363
+ "343": "warthog",
364
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
365
+ "345": "ox",
366
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
367
+ "347": "bison",
368
+ "348": "ram, tup",
369
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
370
+ "350": "ibex, Capra ibex",
371
+ "351": "hartebeest",
372
+ "352": "impala, Aepyceros melampus",
373
+ "353": "gazelle",
374
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
375
+ "355": "llama",
376
+ "356": "weasel",
377
+ "357": "mink",
378
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
379
+ "359": "black-footed ferret, ferret, Mustela nigripes",
380
+ "360": "otter",
381
+ "361": "skunk, polecat, wood pussy",
382
+ "362": "badger",
383
+ "363": "armadillo",
384
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
385
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
386
+ "366": "gorilla, Gorilla gorilla",
387
+ "367": "chimpanzee, chimp, Pan troglodytes",
388
+ "368": "gibbon, Hylobates lar",
389
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
390
+ "370": "guenon, guenon monkey",
391
+ "371": "patas, hussar monkey, Erythrocebus patas",
392
+ "372": "baboon",
393
+ "373": "macaque",
394
+ "374": "langur",
395
+ "375": "colobus, colobus monkey",
396
+ "376": "proboscis monkey, Nasalis larvatus",
397
+ "377": "marmoset",
398
+ "378": "capuchin, ringtail, Cebus capucinus",
399
+ "379": "howler monkey, howler",
400
+ "380": "titi, titi monkey",
401
+ "381": "spider monkey, Ateles geoffroyi",
402
+ "382": "squirrel monkey, Saimiri sciureus",
403
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
404
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
405
+ "385": "Indian elephant, Elephas maximus",
406
+ "386": "African elephant, Loxodonta africana",
407
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
408
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
409
+ "389": "barracouta, snoek",
410
+ "390": "eel",
411
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
412
+ "392": "rock beauty, Holocanthus tricolor",
413
+ "393": "anemone fish",
414
+ "394": "sturgeon",
415
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
416
+ "396": "lionfish",
417
+ "397": "puffer, pufferfish, blowfish, globefish",
418
+ "398": "abacus",
419
+ "399": "abaya",
420
+ "400": "academic gown, academic robe, judge robe",
421
+ "401": "accordion, piano accordion, squeeze box",
422
+ "402": "acoustic guitar",
423
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
424
+ "404": "airliner",
425
+ "405": "airship, dirigible",
426
+ "406": "altar",
427
+ "407": "ambulance",
428
+ "408": "amphibian, amphibious vehicle",
429
+ "409": "analog clock",
430
+ "410": "apiary, bee house",
431
+ "411": "apron",
432
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
433
+ "413": "assault rifle, assault gun",
434
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
435
+ "415": "bakery, bakeshop, bakehouse",
436
+ "416": "balance beam, beam",
437
+ "417": "balloon",
438
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
439
+ "419": "Band Aid",
440
+ "420": "banjo",
441
+ "421": "bannister, banister, balustrade, balusters, handrail",
442
+ "422": "barbell",
443
+ "423": "barber chair",
444
+ "424": "barbershop",
445
+ "425": "barn",
446
+ "426": "barometer",
447
+ "427": "barrel, cask",
448
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
449
+ "429": "baseball",
450
+ "430": "basketball",
451
+ "431": "bassinet",
452
+ "432": "bassoon",
453
+ "433": "bathing cap, swimming cap",
454
+ "434": "bath towel",
455
+ "435": "bathtub, bathing tub, bath, tub",
456
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
457
+ "437": "beacon, lighthouse, beacon light, pharos",
458
+ "438": "beaker",
459
+ "439": "bearskin, busby, shako",
460
+ "440": "beer bottle",
461
+ "441": "beer glass",
462
+ "442": "bell cote, bell cot",
463
+ "443": "bib",
464
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
465
+ "445": "bikini, two-piece",
466
+ "446": "binder, ring-binder",
467
+ "447": "binoculars, field glasses, opera glasses",
468
+ "448": "birdhouse",
469
+ "449": "boathouse",
470
+ "450": "bobsled, bobsleigh, bob",
471
+ "451": "bolo tie, bolo, bola tie, bola",
472
+ "452": "bonnet, poke bonnet",
473
+ "453": "bookcase",
474
+ "454": "bookshop, bookstore, bookstall",
475
+ "455": "bottlecap",
476
+ "456": "bow",
477
+ "457": "bow tie, bow-tie, bowtie",
478
+ "458": "brass, memorial tablet, plaque",
479
+ "459": "brassiere, bra, bandeau",
480
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
481
+ "461": "breastplate, aegis, egis",
482
+ "462": "broom",
483
+ "463": "bucket, pail",
484
+ "464": "buckle",
485
+ "465": "bulletproof vest",
486
+ "466": "bullet train, bullet",
487
+ "467": "butcher shop, meat market",
488
+ "468": "cab, hack, taxi, taxicab",
489
+ "469": "caldron, cauldron",
490
+ "470": "candle, taper, wax light",
491
+ "471": "cannon",
492
+ "472": "canoe",
493
+ "473": "can opener, tin opener",
494
+ "474": "cardigan",
495
+ "475": "car mirror",
496
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
497
+ "477": "carpenters kit, tool kit",
498
+ "478": "carton",
499
+ "479": "car wheel",
500
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
501
+ "481": "cassette",
502
+ "482": "cassette player",
503
+ "483": "castle",
504
+ "484": "catamaran",
505
+ "485": "CD player",
506
+ "486": "cello, violoncello",
507
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
508
+ "488": "chain",
509
+ "489": "chainlink fence",
510
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
511
+ "491": "chain saw, chainsaw",
512
+ "492": "chest",
513
+ "493": "chiffonier, commode",
514
+ "494": "chime, bell, gong",
515
+ "495": "china cabinet, china closet",
516
+ "496": "Christmas stocking",
517
+ "497": "church, church building",
518
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
519
+ "499": "cleaver, meat cleaver, chopper",
520
+ "500": "cliff dwelling",
521
+ "501": "cloak",
522
+ "502": "clog, geta, patten, sabot",
523
+ "503": "cocktail shaker",
524
+ "504": "coffee mug",
525
+ "505": "coffeepot",
526
+ "506": "coil, spiral, volute, whorl, helix",
527
+ "507": "combination lock",
528
+ "508": "computer keyboard, keypad",
529
+ "509": "confectionery, confectionary, candy store",
530
+ "510": "container ship, containership, container vessel",
531
+ "511": "convertible",
532
+ "512": "corkscrew, bottle screw",
533
+ "513": "cornet, horn, trumpet, trump",
534
+ "514": "cowboy boot",
535
+ "515": "cowboy hat, ten-gallon hat",
536
+ "516": "cradle",
537
+ "517": "crane",
538
+ "518": "crash helmet",
539
+ "519": "crate",
540
+ "520": "crib, cot",
541
+ "521": "Crock Pot",
542
+ "522": "croquet ball",
543
+ "523": "crutch",
544
+ "524": "cuirass",
545
+ "525": "dam, dike, dyke",
546
+ "526": "desk",
547
+ "527": "desktop computer",
548
+ "528": "dial telephone, dial phone",
549
+ "529": "diaper, nappy, napkin",
550
+ "530": "digital clock",
551
+ "531": "digital watch",
552
+ "532": "dining table, board",
553
+ "533": "dishrag, dishcloth",
554
+ "534": "dishwasher, dish washer, dishwashing machine",
555
+ "535": "disk brake, disc brake",
556
+ "536": "dock, dockage, docking facility",
557
+ "537": "dogsled, dog sled, dog sleigh",
558
+ "538": "dome",
559
+ "539": "doormat, welcome mat",
560
+ "540": "drilling platform, offshore rig",
561
+ "541": "drum, membranophone, tympan",
562
+ "542": "drumstick",
563
+ "543": "dumbbell",
564
+ "544": "Dutch oven",
565
+ "545": "electric fan, blower",
566
+ "546": "electric guitar",
567
+ "547": "electric locomotive",
568
+ "548": "entertainment center",
569
+ "549": "envelope",
570
+ "550": "espresso maker",
571
+ "551": "face powder",
572
+ "552": "feather boa, boa",
573
+ "553": "file, file cabinet, filing cabinet",
574
+ "554": "fireboat",
575
+ "555": "fire engine, fire truck",
576
+ "556": "fire screen, fireguard",
577
+ "557": "flagpole, flagstaff",
578
+ "558": "flute, transverse flute",
579
+ "559": "folding chair",
580
+ "560": "football helmet",
581
+ "561": "forklift",
582
+ "562": "fountain",
583
+ "563": "fountain pen",
584
+ "564": "four-poster",
585
+ "565": "freight car",
586
+ "566": "French horn, horn",
587
+ "567": "frying pan, frypan, skillet",
588
+ "568": "fur coat",
589
+ "569": "garbage truck, dustcart",
590
+ "570": "gasmask, respirator, gas helmet",
591
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
592
+ "572": "goblet",
593
+ "573": "go-kart",
594
+ "574": "golf ball",
595
+ "575": "golfcart, golf cart",
596
+ "576": "gondola",
597
+ "577": "gong, tam-tam",
598
+ "578": "gown",
599
+ "579": "grand piano, grand",
600
+ "580": "greenhouse, nursery, glasshouse",
601
+ "581": "grille, radiator grille",
602
+ "582": "grocery store, grocery, food market, market",
603
+ "583": "guillotine",
604
+ "584": "hair slide",
605
+ "585": "hair spray",
606
+ "586": "half track",
607
+ "587": "hammer",
608
+ "588": "hamper",
609
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
610
+ "590": "hand-held computer, hand-held microcomputer",
611
+ "591": "handkerchief, hankie, hanky, hankey",
612
+ "592": "hard disc, hard disk, fixed disk",
613
+ "593": "harmonica, mouth organ, harp, mouth harp",
614
+ "594": "harp",
615
+ "595": "harvester, reaper",
616
+ "596": "hatchet",
617
+ "597": "holster",
618
+ "598": "home theater, home theatre",
619
+ "599": "honeycomb",
620
+ "600": "hook, claw",
621
+ "601": "hoopskirt, crinoline",
622
+ "602": "horizontal bar, high bar",
623
+ "603": "horse cart, horse-cart",
624
+ "604": "hourglass",
625
+ "605": "iPod",
626
+ "606": "iron, smoothing iron",
627
+ "607": "jack-o-lantern",
628
+ "608": "jean, blue jean, denim",
629
+ "609": "jeep, landrover",
630
+ "610": "jersey, T-shirt, tee shirt",
631
+ "611": "jigsaw puzzle",
632
+ "612": "jinrikisha, ricksha, rickshaw",
633
+ "613": "joystick",
634
+ "614": "kimono",
635
+ "615": "knee pad",
636
+ "616": "knot",
637
+ "617": "lab coat, laboratory coat",
638
+ "618": "ladle",
639
+ "619": "lampshade, lamp shade",
640
+ "620": "laptop, laptop computer",
641
+ "621": "lawn mower, mower",
642
+ "622": "lens cap, lens cover",
643
+ "623": "letter opener, paper knife, paperknife",
644
+ "624": "library",
645
+ "625": "lifeboat",
646
+ "626": "lighter, light, igniter, ignitor",
647
+ "627": "limousine, limo",
648
+ "628": "liner, ocean liner",
649
+ "629": "lipstick, lip rouge",
650
+ "630": "Loafer",
651
+ "631": "lotion",
652
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
653
+ "633": "loupe, jewelers loupe",
654
+ "634": "lumbermill, sawmill",
655
+ "635": "magnetic compass",
656
+ "636": "mailbag, postbag",
657
+ "637": "mailbox, letter box",
658
+ "638": "maillot",
659
+ "639": "maillot, tank suit",
660
+ "640": "manhole cover",
661
+ "641": "maraca",
662
+ "642": "marimba, xylophone",
663
+ "643": "mask",
664
+ "644": "matchstick",
665
+ "645": "maypole",
666
+ "646": "maze, labyrinth",
667
+ "647": "measuring cup",
668
+ "648": "medicine chest, medicine cabinet",
669
+ "649": "megalith, megalithic structure",
670
+ "650": "microphone, mike",
671
+ "651": "microwave, microwave oven",
672
+ "652": "military uniform",
673
+ "653": "milk can",
674
+ "654": "minibus",
675
+ "655": "miniskirt, mini",
676
+ "656": "minivan",
677
+ "657": "missile",
678
+ "658": "mitten",
679
+ "659": "mixing bowl",
680
+ "660": "mobile home, manufactured home",
681
+ "661": "Model T",
682
+ "662": "modem",
683
+ "663": "monastery",
684
+ "664": "monitor",
685
+ "665": "moped",
686
+ "666": "mortar",
687
+ "667": "mortarboard",
688
+ "668": "mosque",
689
+ "669": "mosquito net",
690
+ "670": "motor scooter, scooter",
691
+ "671": "mountain bike, all-terrain bike, off-roader",
692
+ "672": "mountain tent",
693
+ "673": "mouse, computer mouse",
694
+ "674": "mousetrap",
695
+ "675": "moving van",
696
+ "676": "muzzle",
697
+ "677": "nail",
698
+ "678": "neck brace",
699
+ "679": "necklace",
700
+ "680": "nipple",
701
+ "681": "notebook, notebook computer",
702
+ "682": "obelisk",
703
+ "683": "oboe, hautboy, hautbois",
704
+ "684": "ocarina, sweet potato",
705
+ "685": "odometer, hodometer, mileometer, milometer",
706
+ "686": "oil filter",
707
+ "687": "organ, pipe organ",
708
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
709
+ "689": "overskirt",
710
+ "690": "oxcart",
711
+ "691": "oxygen mask",
712
+ "692": "packet",
713
+ "693": "paddle, boat paddle",
714
+ "694": "paddlewheel, paddle wheel",
715
+ "695": "padlock",
716
+ "696": "paintbrush",
717
+ "697": "pajama, pyjama, pjs, jammies",
718
+ "698": "palace",
719
+ "699": "panpipe, pandean pipe, syrinx",
720
+ "700": "paper towel",
721
+ "701": "parachute, chute",
722
+ "702": "parallel bars, bars",
723
+ "703": "park bench",
724
+ "704": "parking meter",
725
+ "705": "passenger car, coach, carriage",
726
+ "706": "patio, terrace",
727
+ "707": "pay-phone, pay-station",
728
+ "708": "pedestal, plinth, footstall",
729
+ "709": "pencil box, pencil case",
730
+ "710": "pencil sharpener",
731
+ "711": "perfume, essence",
732
+ "712": "Petri dish",
733
+ "713": "photocopier",
734
+ "714": "pick, plectrum, plectron",
735
+ "715": "pickelhaube",
736
+ "716": "picket fence, paling",
737
+ "717": "pickup, pickup truck",
738
+ "718": "pier",
739
+ "719": "piggy bank, penny bank",
740
+ "720": "pill bottle",
741
+ "721": "pillow",
742
+ "722": "ping-pong ball",
743
+ "723": "pinwheel",
744
+ "724": "pirate, pirate ship",
745
+ "725": "pitcher, ewer",
746
+ "726": "plane, carpenters plane, woodworking plane",
747
+ "727": "planetarium",
748
+ "728": "plastic bag",
749
+ "729": "plate rack",
750
+ "730": "plow, plough",
751
+ "731": "plunger, plumbers helper",
752
+ "732": "Polaroid camera, Polaroid Land camera",
753
+ "733": "pole",
754
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
755
+ "735": "poncho",
756
+ "736": "pool table, billiard table, snooker table",
757
+ "737": "pop bottle, soda bottle",
758
+ "738": "pot, flowerpot",
759
+ "739": "potters wheel",
760
+ "740": "power drill",
761
+ "741": "prayer rug, prayer mat",
762
+ "742": "printer",
763
+ "743": "prison, prison house",
764
+ "744": "projectile, missile",
765
+ "745": "projector",
766
+ "746": "puck, hockey puck",
767
+ "747": "punching bag, punch bag, punching ball, punchball",
768
+ "748": "purse",
769
+ "749": "quill, quill pen",
770
+ "750": "quilt, comforter, comfort, puff",
771
+ "751": "racer, race car, racing car",
772
+ "752": "racket, racquet",
773
+ "753": "radiator",
774
+ "754": "radio, wireless",
775
+ "755": "radio telescope, radio reflector",
776
+ "756": "rain barrel",
777
+ "757": "recreational vehicle, RV, R.V.",
778
+ "758": "reel",
779
+ "759": "reflex camera",
780
+ "760": "refrigerator, icebox",
781
+ "761": "remote control, remote",
782
+ "762": "restaurant, eating house, eating place, eatery",
783
+ "763": "revolver, six-gun, six-shooter",
784
+ "764": "rifle",
785
+ "765": "rocking chair, rocker",
786
+ "766": "rotisserie",
787
+ "767": "rubber eraser, rubber, pencil eraser",
788
+ "768": "rugby ball",
789
+ "769": "rule, ruler",
790
+ "770": "running shoe",
791
+ "771": "safe",
792
+ "772": "safety pin",
793
+ "773": "saltshaker, salt shaker",
794
+ "774": "sandal",
795
+ "775": "sarong",
796
+ "776": "sax, saxophone",
797
+ "777": "scabbard",
798
+ "778": "scale, weighing machine",
799
+ "779": "school bus",
800
+ "780": "schooner",
801
+ "781": "scoreboard",
802
+ "782": "screen, CRT screen",
803
+ "783": "screw",
804
+ "784": "screwdriver",
805
+ "785": "seat belt, seatbelt",
806
+ "786": "sewing machine",
807
+ "787": "shield, buckler",
808
+ "788": "shoe shop, shoe-shop, shoe store",
809
+ "789": "shoji",
810
+ "790": "shopping basket",
811
+ "791": "shopping cart",
812
+ "792": "shovel",
813
+ "793": "shower cap",
814
+ "794": "shower curtain",
815
+ "795": "ski",
816
+ "796": "ski mask",
817
+ "797": "sleeping bag",
818
+ "798": "slide rule, slipstick",
819
+ "799": "sliding door",
820
+ "800": "slot, one-armed bandit",
821
+ "801": "snorkel",
822
+ "802": "snowmobile",
823
+ "803": "snowplow, snowplough",
824
+ "804": "soap dispenser",
825
+ "805": "soccer ball",
826
+ "806": "sock",
827
+ "807": "solar dish, solar collector, solar furnace",
828
+ "808": "sombrero",
829
+ "809": "soup bowl",
830
+ "810": "space bar",
831
+ "811": "space heater",
832
+ "812": "space shuttle",
833
+ "813": "spatula",
834
+ "814": "speedboat",
835
+ "815": "spider web, spiders web",
836
+ "816": "spindle",
837
+ "817": "sports car, sport car",
838
+ "818": "spotlight, spot",
839
+ "819": "stage",
840
+ "820": "steam locomotive",
841
+ "821": "steel arch bridge",
842
+ "822": "steel drum",
843
+ "823": "stethoscope",
844
+ "824": "stole",
845
+ "825": "stone wall",
846
+ "826": "stopwatch, stop watch",
847
+ "827": "stove",
848
+ "828": "strainer",
849
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
850
+ "830": "stretcher",
851
+ "831": "studio couch, day bed",
852
+ "832": "stupa, tope",
853
+ "833": "submarine, pigboat, sub, U-boat",
854
+ "834": "suit, suit of clothes",
855
+ "835": "sundial",
856
+ "836": "sunglass",
857
+ "837": "sunglasses, dark glasses, shades",
858
+ "838": "sunscreen, sunblock, sun blocker",
859
+ "839": "suspension bridge",
860
+ "840": "swab, swob, mop",
861
+ "841": "sweatshirt",
862
+ "842": "swimming trunks, bathing trunks",
863
+ "843": "swing",
864
+ "844": "switch, electric switch, electrical switch",
865
+ "845": "syringe",
866
+ "846": "table lamp",
867
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
868
+ "848": "tape player",
869
+ "849": "teapot",
870
+ "850": "teddy, teddy bear",
871
+ "851": "television, television system",
872
+ "852": "tennis ball",
873
+ "853": "thatch, thatched roof",
874
+ "854": "theater curtain, theatre curtain",
875
+ "855": "thimble",
876
+ "856": "thresher, thrasher, threshing machine",
877
+ "857": "throne",
878
+ "858": "tile roof",
879
+ "859": "toaster",
880
+ "860": "tobacco shop, tobacconist shop, tobacconist",
881
+ "861": "toilet seat",
882
+ "862": "torch",
883
+ "863": "totem pole",
884
+ "864": "tow truck, tow car, wrecker",
885
+ "865": "toyshop",
886
+ "866": "tractor",
887
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
888
+ "868": "tray",
889
+ "869": "trench coat",
890
+ "870": "tricycle, trike, velocipede",
891
+ "871": "trimaran",
892
+ "872": "tripod",
893
+ "873": "triumphal arch",
894
+ "874": "trolleybus, trolley coach, trackless trolley",
895
+ "875": "trombone",
896
+ "876": "tub, vat",
897
+ "877": "turnstile",
898
+ "878": "typewriter keyboard",
899
+ "879": "umbrella",
900
+ "880": "unicycle, monocycle",
901
+ "881": "upright, upright piano",
902
+ "882": "vacuum, vacuum cleaner",
903
+ "883": "vase",
904
+ "884": "vault",
905
+ "885": "velvet",
906
+ "886": "vending machine",
907
+ "887": "vestment",
908
+ "888": "viaduct",
909
+ "889": "violin, fiddle",
910
+ "890": "volleyball",
911
+ "891": "waffle iron",
912
+ "892": "wall clock",
913
+ "893": "wallet, billfold, notecase, pocketbook",
914
+ "894": "wardrobe, closet, press",
915
+ "895": "warplane, military plane",
916
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
917
+ "897": "washer, automatic washer, washing machine",
918
+ "898": "water bottle",
919
+ "899": "water jug",
920
+ "900": "water tower",
921
+ "901": "whiskey jug",
922
+ "902": "whistle",
923
+ "903": "wig",
924
+ "904": "window screen",
925
+ "905": "window shade",
926
+ "906": "Windsor tie",
927
+ "907": "wine bottle",
928
+ "908": "wing",
929
+ "909": "wok",
930
+ "910": "wooden spoon",
931
+ "911": "wool, woolen, woollen",
932
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
933
+ "913": "wreck",
934
+ "914": "yawl",
935
+ "915": "yurt",
936
+ "916": "web site, website, internet site, site",
937
+ "917": "comic book",
938
+ "918": "crossword puzzle, crossword",
939
+ "919": "street sign",
940
+ "920": "traffic light, traffic signal, stoplight",
941
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
942
+ "922": "menu",
943
+ "923": "plate",
944
+ "924": "guacamole",
945
+ "925": "consomme",
946
+ "926": "hot pot, hotpot",
947
+ "927": "trifle",
948
+ "928": "ice cream, icecream",
949
+ "929": "ice lolly, lolly, lollipop, popsicle",
950
+ "930": "French loaf",
951
+ "931": "bagel, beigel",
952
+ "932": "pretzel",
953
+ "933": "cheeseburger",
954
+ "934": "hotdog, hot dog, red hot",
955
+ "935": "mashed potato",
956
+ "936": "head cabbage",
957
+ "937": "broccoli",
958
+ "938": "cauliflower",
959
+ "939": "zucchini, courgette",
960
+ "940": "spaghetti squash",
961
+ "941": "acorn squash",
962
+ "942": "butternut squash",
963
+ "943": "cucumber, cuke",
964
+ "944": "artichoke, globe artichoke",
965
+ "945": "bell pepper",
966
+ "946": "cardoon",
967
+ "947": "mushroom",
968
+ "948": "Granny Smith",
969
+ "949": "strawberry",
970
+ "950": "orange",
971
+ "951": "lemon",
972
+ "952": "fig",
973
+ "953": "pineapple, ananas",
974
+ "954": "banana",
975
+ "955": "jackfruit, jak, jack",
976
+ "956": "custard apple",
977
+ "957": "pomegranate",
978
+ "958": "hay",
979
+ "959": "carbonara",
980
+ "960": "chocolate sauce, chocolate syrup",
981
+ "961": "dough",
982
+ "962": "meat loaf, meatloaf",
983
+ "963": "pizza, pizza pie",
984
+ "964": "potpie",
985
+ "965": "burrito",
986
+ "966": "red wine",
987
+ "967": "espresso",
988
+ "968": "cup",
989
+ "969": "eggnog",
990
+ "970": "alp",
991
+ "971": "bubble",
992
+ "972": "cliff, drop, drop-off",
993
+ "973": "coral reef",
994
+ "974": "geyser",
995
+ "975": "lakeside, lakeshore",
996
+ "976": "promontory, headland, head, foreland",
997
+ "977": "sandbar, sand bar",
998
+ "978": "seashore, coast, seacoast, sea-coast",
999
+ "979": "valley, vale",
1000
+ "980": "volcano",
1001
+ "981": "ballplayer, baseball player",
1002
+ "982": "groom, bridegroom",
1003
+ "983": "scuba diver",
1004
+ "984": "rapeseed",
1005
+ "985": "daisy",
1006
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1007
+ "987": "corn",
1008
+ "988": "acorn",
1009
+ "989": "hip, rose hip, rosehip",
1010
+ "990": "buckeye, horse chestnut, conker",
1011
+ "991": "coral fungus",
1012
+ "992": "agaric",
1013
+ "993": "gyromitra",
1014
+ "994": "stinkhorn, carrion fungus",
1015
+ "995": "earthstar",
1016
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1017
+ "997": "bolete",
1018
+ "998": "ear, spike, capitulum",
1019
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1020
+ }
1021
+ }
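Note on the `id2label` table above: several synonyms occur under more than one class id (for example `catamount` under 286 and 287, and `panther` under 286 and 290). The sketch below mirrors the inversion loop of `_build_label2id` from the `pipeline.py` change in this commit (the final sort is omitted, since it does not change which id a synonym resolves to). It suggests that a shared synonym resolves to the class id that comes last in iteration order. `find_ambiguous_synonyms` is a hypothetical helper added for illustration only and is not part of the commit.

```python
from collections import defaultdict
from typing import Dict, List


def build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
    """Invert id2label: every comma-separated synonym becomes a key (last id wins)."""
    label2id: Dict[str, int] = {}
    for class_id, value in id2label.items():
        for synonym in value.split(","):
            synonym = synonym.strip()
            if synonym:
                label2id[synonym] = int(class_id)
    return label2id


def find_ambiguous_synonyms(id2label: Dict[int, str]) -> Dict[str, List[int]]:
    """Hypothetical helper: list synonyms that appear under more than one class id."""
    owners: Dict[str, List[int]] = defaultdict(list)
    for class_id, value in id2label.items():
        for synonym in value.split(","):
            synonym = synonym.strip()
            if synonym and int(class_id) not in owners[synonym]:
                owners[synonym].append(int(class_id))
    return {name: ids for name, ids in owners.items() if len(ids) > 1}


# Three entries copied from the mapping above.
excerpt = {
    286: "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
    287: "lynx, catamount",
    290: "jaguar, panther, Panthera onca, Felis onca",
}

label2id = build_label2id(excerpt)
ambiguous = find_ambiguous_synonyms(excerpt)

print(label2id["panther"])  # resolves to the jaguar id, not the cougar id
print(ambiguous)
```

Passing the integer class id (286 for cougar) avoids the ambiguity; so does an unambiguous synonym such as `"cougar"` or `"mountain lion"`.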
SiT-L-2-256/pipeline.py CHANGED
@@ -1,82 +1,349 @@
1
- from typing import List, Optional, Union
2
-
3
- import torch
4
-
5
- from diffusers.image_processor import VaeImageProcessor
6
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
7
- from diffusers.utils.torch_utils import randn_tensor
8
-
9
-
10
- class SiTPipeline(DiffusionPipeline):
11
-     model_cpu_offload_seq = "transformer->vae"
12
-
13
-     def __init__(self, transformer, scheduler, vae):
14
-         super().__init__()
15
-         self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
16
-         self.vae_scale_factor = 8
17
-         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
18
-
19
-     @torch.no_grad()
20
-     def __call__(
21
-         self,
22
-         class_labels: Union[int, List[int]] = 207,
23
-         height: int = 256,
24
-         width: int = 256,
25
-         num_inference_steps: int = 250,
26
-         guidance_scale: float = 4.0,
27
-         generator: Optional[torch.Generator] = None,
28
-         output_type: str = "pil",
29
-         return_dict: bool = True,
30
-     ):
31
-         device = self._execution_device
32
-         if isinstance(class_labels, int):
33
-             class_labels = [class_labels]
34
-         batch_size = len(class_labels)
35
-
36
-         latent_h = height // self.vae_scale_factor
37
-         latent_w = width // self.vae_scale_factor
38
-         latents = randn_tensor(
39
-             (batch_size, self.transformer.config.in_channels, latent_h, latent_w),
40
-             generator=generator,
41
-             device=device,
42
-             dtype=self.transformer.dtype,
43
-         )
44
-
45
-         labels = torch.tensor(class_labels, device=device, dtype=torch.long)
46
-         do_cfg = guidance_scale is not None and guidance_scale > 1.0
47
-         if do_cfg:
48
-             null_label = torch.full((batch_size,), self.transformer.config.num_classes, device=device, dtype=torch.long)
49
-             labels = torch.cat([labels, null_label], dim=0)
50
-
51
-         self.scheduler.set_timesteps(num_inference_steps, device=device)
52
-         timesteps = self.scheduler.timesteps
53
-
54
-         for t in self.progress_bar(timesteps):
55
-             t_batch = torch.full((batch_size,), t, device=device, dtype=latents.dtype)
56
-             model_input = latents
57
-             if do_cfg:
58
-                 model_input = torch.cat([latents, latents], dim=0)
59
-                 t_batch = torch.cat([t_batch, t_batch], dim=0)
60
-
61
-             model_pred = self.transformer(
62
-                 hidden_states=model_input,
63
-                 timestep=t_batch,
64
-                 class_labels=labels,
65
-             ).sample
66
-
67
-             if do_cfg:
68
-                 cond, uncond = model_pred.chunk(2, dim=0)
69
-                 model_pred = uncond + guidance_scale * (cond - uncond)
70
-
71
-             latents = self.scheduler.step(model_pred, t, latents, generator=generator).prev_sample
72
-
73
-         image = self.vae.decode(latents / 0.18215).sample
74
-         # Keep PyTorch outputs in raw VAE range [-1, 1] to match original SiT scripts.
75
-         if output_type == "pt":
76
-             image = image
77
-         else:
78
-             image = self.image_processor.postprocess(image, output_type=output_type)
79
-
80
-         if not return_dict:
81
-             return (image,)
82
-         return ImagePipelineOutput(images=image)
1
+ """Hub custom pipeline: SiTPipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ import json
22
+ from pathlib import Path
23
+ from typing import Dict, List, Optional, Tuple, Union
24
+
25
+ import torch
26
+
27
+ from diffusers.image_processor import VaeImageProcessor
28
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+
31
+ EXAMPLE_DOC_STRING = """
32
+     Examples:
33
+         ```py
34
+         >>> from pathlib import Path
35
+         >>> from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
36
+         >>> import torch
37
+
38
+         >>> model_dir = Path("./SiT-XL-2-256").resolve()
39
+         >>> pipe = DiffusionPipeline.from_pretrained(
40
+         ...     str(model_dir),
41
+         ...     local_files_only=True,
42
+         ...     custom_pipeline=str(model_dir / "pipeline.py"),
43
+         ...     trust_remote_code=True,
44
+         ...     torch_dtype=torch.bfloat16,
45
+         ... )
46
+         >>> pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
47
+         >>> pipe.to("cuda")
48
+
49
+         >>> print(pipe.id2label[207])
50
+         >>> print(pipe.get_label_ids("golden retriever"))
51
+
52
+         >>> generator = torch.Generator(device="cuda").manual_seed(42)
53
+         >>> image = pipe(
54
+         ...     class_labels="golden retriever",
55
+         ...     height=256,
56
+         ...     width=256,
57
+         ...     num_inference_steps=250,
58
+         ...     guidance_scale=4.0,
59
+         ...     generator=generator,
60
+         ... ).images[0]
61
+         ```
62
+ """
63
+
64
+ class SiTPipeline(DiffusionPipeline):
65
+     r"""
66
+     Pipeline for class-conditional image generation with Scalable Interpolant Transformers (SiT).
67
+
68
+     Parameters:
69
+         transformer ([`SiTTransformer2DModel`]):
70
+             Class-conditional SiT transformer that predicts flow-matching velocity in latent space.
71
+         scheduler ([`FlowMatchEulerDiscreteScheduler`]):
72
+             Flow-matching Euler scheduler. Other [`KarrasDiffusionSchedulers`] can be swapped at inference time.
73
+         vae ([`AutoencoderKL`]):
74
+             Variational autoencoder used to decode transformer latents to pixels.
75
+         id2label (`dict[int, str]`, *optional*):
76
+             ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
77
+     """
78
+
79
+     model_cpu_offload_seq = "transformer->vae"
80
+
81
+     def __init__(
82
+         self,
83
+         transformer,
84
+         scheduler,
85
+         vae,
86
+         id2label: Optional[Dict[Union[int, str], str]] = None,
87
+     ):
88
+         super().__init__()
89
+         if scheduler is None:
90
+             scheduler = FlowMatchEulerDiscreteScheduler(
91
+                 num_train_timesteps=1000,
92
+                 shift=1.0,
93
+                 stochastic_sampling=False,
94
+             )
95
+         self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
96
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
97
+         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
98
+         self._id2label = self._normalize_id2label(id2label)
99
+         self.labels = self._build_label2id(self._id2label)
100
+         self._labels_loaded_from_model_index = bool(self._id2label)
102
+ def _ensure_labels_loaded(self) -> None:
103
+ if self._labels_loaded_from_model_index:
104
+ return
105
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
106
+ if loaded:
107
+ self._id2label = loaded
108
+ self.labels = self._build_label2id(self._id2label)
109
+ self._labels_loaded_from_model_index = True
110
+
111
+ @staticmethod
112
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
113
+ if not id2label:
114
+ return {}
115
+ return {int(key): value for key, value in id2label.items()}
116
+
117
+ @staticmethod
118
+ def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
119
+ if not variant_path:
120
+ return {}
121
+ variant_dir = Path(variant_path).resolve()
122
+ model_index_path = variant_dir / "model_index.json"
123
+ if not model_index_path.exists():
124
+ return {}
125
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
126
+ id2label = raw.get("id2label")
127
+ if not isinstance(id2label, dict):
128
+ return {}
129
+ return {int(key): value for key, value in id2label.items()}
130
+
131
+ @staticmethod
132
+ def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
133
+ label2id: Dict[str, int] = {}
134
+ for class_id, value in id2label.items():
135
+ for synonym in value.split(","):
136
+ synonym = synonym.strip()
137
+ if synonym:
138
+ label2id[synonym] = int(class_id)
139
+ return dict(sorted(label2id.items()))
140
+
141
+ @property
142
+ def id2label(self) -> Dict[int, str]:
143
+ r"""ImageNet class id to English label string (comma-separated synonyms)."""
144
+ self._ensure_labels_loaded()
145
+ return self._id2label
146
+
147
+ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
148
+ r"""
149
+ Map ImageNet label strings to class ids.
150
+
151
+ Args:
152
+ label (`str` or `list[str]`):
153
+ One or more English label strings. Each string must match a synonym in `id2label`.
154
+ """
155
+ self._ensure_labels_loaded()
156
+ label2id = self.labels
157
+ if not label2id:
158
+ raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
159
+
160
+ if isinstance(label, str):
161
+ label = [label]
162
+
163
+ missing = [item for item in label if item not in label2id]
164
+ if missing:
165
+ preview = ", ".join(list(label2id.keys())[:8])
166
+ raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
167
+ return [label2id[item] for item in label]
168
+
169
+ def _normalize_class_labels(
170
+ self,
171
+ class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
172
+ ) -> torch.LongTensor:
173
+ if torch.is_tensor(class_labels):
174
+ return class_labels.to(device=self._execution_device, dtype=torch.long).reshape(-1)
175
+
176
+ if isinstance(class_labels, int):
177
+ class_label_ids = [class_labels]
178
+ elif isinstance(class_labels, str):
179
+ class_label_ids = self.get_label_ids(class_labels)
180
+ else:
181
+ # Resolve entries one by one so mixed int/str lists (allowed by the type hint) also work.
182
+ ids = (self.get_label_ids(x)[0] if isinstance(x, str) else int(x) for x in class_labels)
183
+ class_label_ids = list(ids)
184
+
185
+ return torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
186
+
187
+ def _default_image_size(self) -> int:
188
+ return int(self.transformer.config.input_size) * self.vae_scale_factor
189
+
190
+ def check_inputs(
191
+ self,
192
+ height: int,
193
+ width: int,
194
+ num_inference_steps: int,
195
+ output_type: str,
196
+ ) -> None:
197
+ if num_inference_steps < 1:
198
+ raise ValueError("num_inference_steps must be >= 1.")
199
+ if output_type not in {"pil", "np", "pt", "latent"}:
200
+ raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
201
+
202
+ if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
203
+ raise ValueError(
204
+ f"height and width must be divisible by the VAE downsample factor {self.vae_scale_factor}."
205
+ )
206
+
207
+ latent_height = height // self.vae_scale_factor
208
+ latent_width = width // self.vae_scale_factor
209
+ expected_size = int(self.transformer.config.input_size)
210
+ if latent_height != expected_size or latent_width != expected_size:
211
+ raise ValueError(
212
+ f"Requested latent size {(latent_height, latent_width)} does not match the pretrained "
213
+ f"transformer input_size={expected_size}. Use height=width={self._default_image_size()}."
214
+ )
215
+
216
+ def prepare_latents(
217
+ self,
218
+ batch_size: int,
219
+ height: int,
220
+ width: int,
221
+ dtype: torch.dtype,
222
+ device: torch.device,
223
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
224
+ ) -> torch.Tensor:
225
+ latent_height = height // self.vae_scale_factor
226
+ latent_width = width // self.vae_scale_factor
227
+ return randn_tensor(
228
+ (batch_size, self.transformer.config.in_channels, latent_height, latent_width),
229
+ generator=generator,
230
+ device=device,
231
+ dtype=dtype,
232
+ )
233
+
234
+ @staticmethod
235
+ def _apply_classifier_free_guidance(model_output: torch.Tensor, guidance_scale: float) -> torch.Tensor:
236
+ if guidance_scale <= 1.0:
237
+ return model_output
238
+ model_output_cond, model_output_uncond = model_output.chunk(2)
239
+ return model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
240
+
241
+ def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
242
+ if output_type == "latent":
243
+ return latents
244
+
245
+ scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
246
+ image = self.vae.decode(latents / scaling_factor).sample
247
+ if output_type == "pt":
248
+ return image
249
+ return self.image_processor.postprocess(image, output_type=output_type)
250
+
251
+ @torch.inference_mode()
252
+ def __call__(
253
+ self,
254
+ class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
255
+ height: Optional[int] = None,
256
+ width: Optional[int] = None,
257
+ num_inference_steps: int = 250,
258
+ guidance_scale: float = 4.0,
259
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
260
+ output_type: str = "pil",
261
+ return_dict: bool = True,
262
+ ) -> Union[ImagePipelineOutput, Tuple]:
263
+ r"""
264
+ Generate class-conditional images with SiT.
265
+
266
+ Args:
267
+ class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
268
+ ImageNet class indices or human-readable English label strings.
269
+ height (`int`, *optional*):
270
+ Output image height in pixels. Defaults to the pretrained native resolution.
271
+ width (`int`, *optional*):
272
+ Output image width in pixels. Defaults to the pretrained native resolution.
273
+ num_inference_steps (`int`, defaults to `250`):
274
+ Number of denoising steps.
275
+ guidance_scale (`float`, defaults to `4.0`):
276
+ Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
277
+ generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
278
+ Random number generator(s) used to sample the initial latents, for reproducibility.
279
+ output_type (`str`, defaults to `"pil"`):
280
+ `"pil"`, `"np"`, `"pt"`, or `"latent"`.
281
+ return_dict (`bool`, defaults to `True`):
282
+ If `True`, return an [`ImagePipelineOutput`]; otherwise return a tuple `(images,)`.
283
+ """
284
+ default_size = self._default_image_size()
285
+ height = int(height or default_size)
286
+ width = int(width or default_size)
287
+ self.check_inputs(height, width, num_inference_steps, output_type)
288
+
289
+ device = self._execution_device
290
+ model_dtype = next(self.transformer.parameters()).dtype
291
+ class_labels_tensor = self._normalize_class_labels(class_labels)
292
+ batch_size = class_labels_tensor.numel()
293
+ do_cfg = guidance_scale > 1.0
294
+
295
+ latents = self.prepare_latents(
296
+ batch_size=batch_size,
297
+ height=height,
298
+ width=width,
299
+ dtype=model_dtype,
300
+ device=device,
301
+ generator=generator,
302
+ )
303
+
304
+ labels = class_labels_tensor
305
+ if do_cfg:
306
+ null_labels = torch.full_like(class_labels_tensor, self.transformer.config.num_classes)
307
+ labels = torch.cat([class_labels_tensor, null_labels], dim=0)
308
+
309
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
310
+ num_train_timesteps = self.scheduler.config.num_train_timesteps
311
+
312
+ if getattr(self.scheduler.config, "stochastic_sampling", False):
313
+ raise ValueError(
314
+ "SiT expects deterministic FlowMatchEulerDiscreteScheduler stepping "
315
+ "(scheduler.config.stochastic_sampling=False)."
316
+ )
317
+
318
+ for t in self.progress_bar(self.scheduler.timesteps):
319
+ flow_time = 1.0 - float(t) / num_train_timesteps
320
+ if do_cfg:
321
+ model_input = torch.cat([latents, latents], dim=0)
322
+ else:
323
+ model_input = latents
324
+
325
+ timestep_batch = torch.full((model_input.shape[0],), flow_time, device=device, dtype=model_dtype)
326
+ model_output = self.transformer(
327
+ hidden_states=model_input,
328
+ timestep=timestep_batch,
329
+ class_labels=labels,
330
+ return_dict=True,
331
+ ).sample
332
+ model_output = self._apply_classifier_free_guidance(model_output, guidance_scale=guidance_scale)
333
+ # SiT predicts dx/d(flow_time) with flow_time increasing from noise (0) to data (1).
334
+ # FlowMatchEulerDiscreteScheduler integrates over sigma decreasing from 1 to 0, so flip sign.
335
+ model_output = -model_output
336
+ latents = self.scheduler.step(
337
+ model_output=model_output,
338
+ timestep=t,
339
+ sample=latents,
340
+ generator=generator,
341
+ return_dict=True,
342
+ ).prev_sample
343
+
344
+ image = self.decode_latents(latents, output_type=output_type)
345
+
346
+ self.maybe_free_model_hooks()
347
+ if not return_dict:
348
+ return (image,)
349
+ return ImagePipelineOutput(images=image)
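The label handling in `pipeline.py` above splits each `id2label` value on commas so that every synonym resolves to the same class id. Below is a minimal standalone sketch of that lookup, assuming a two-entry excerpt of `id2label`. `build_label2id` and `get_label_ids` are hypothetical free-function versions of the pipeline methods, not the repo's API.

```python
from typing import Dict, List, Union


def build_label2id(id2label: Dict[Union[int, str], str]) -> Dict[str, int]:
    """Map every comma-separated synonym to its integer class id."""
    label2id: Dict[str, int] = {}
    for class_id, value in id2label.items():
        for synonym in value.split(","):
            synonym = synonym.strip()
            if synonym:
                label2id[synonym] = int(class_id)
    # Sorted only so that previews in error messages are stable.
    return dict(sorted(label2id.items()))


def get_label_ids(label: Union[str, List[str]], label2id: Dict[str, int]) -> List[int]:
    """Resolve one or more label strings; unknown labels raise ValueError."""
    labels = [label] if isinstance(label, str) else list(label)
    missing = [item for item in labels if item not in label2id]
    if missing:
        raise ValueError(f"Unknown English label(s): {missing}")
    return [label2id[item] for item in labels]


# Illustrative excerpt: JSON keys arrive as strings and are converted to ints.
id2label = {"1": "goldfish, Carassius auratus", "207": "golden retriever"}
label2id = build_label2id(id2label)
ids = get_label_ids(["Carassius auratus", "golden retriever"], label2id)
```

Matching is exact and case-sensitive, so a string such as `"Goldfish"` would be rejected in this sketch.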
SiT-L-2-256/scheduler/scheduler_config.json CHANGED
@@ -1,9 +1,7 @@
1
- {
2
- "_class_name": "SiTFlowMatchScheduler",
3
- "_diffusers_version": "0.36.0",
4
- "diffusion_form": "sigma",
5
- "diffusion_norm": 1.0,
6
- "mode": "ode",
7
- "num_train_timesteps": 1000,
8
- "shift": 1.0
9
- }
 
1
+ {
2
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
+ "shift": 1.0,
6
+ "stochastic_sampling": false
7
+ }
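With this configuration the pipeline takes plain Euler steps over `sigma` decreasing from 1 to 0, while SiT's velocity is defined along a flow time that increases from 0 (noise) to 1 (data). The scalar sketch below shows why `pipeline.py` negates the model output. It assumes a straight-line interpolant and a uniform `sigma` grid, and `sit_velocity` and `euler_sample` are hypothetical helpers, not the scheduler's actual code.

```python
def sit_velocity(flow_time: float, data: float, noise: float) -> float:
    """Toy stand-in for the model: dx/d(flow_time) on the straight path
    x(flow_time) = flow_time * data + (1 - flow_time) * noise."""
    return data - noise


def euler_sample(data: float, noise: float, num_steps: int, flip_sign: bool = True) -> float:
    """Integrate from sigma=1 (pure noise) towards sigma=0 with Euler steps."""
    sigmas = [1.0 - i / num_steps for i in range(num_steps + 1)]
    x = noise
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        flow_time = 1.0 - sigma  # same mapping as in the pipeline loop
        v = sit_velocity(flow_time, data, noise)
        if flip_sign:
            v = -v  # convert dx/d(flow_time) into dx/d(sigma)
        x = x + (sigma_next - sigma) * v  # Euler update over decreasing sigma
    return x


with_flip = euler_sample(data=2.0, noise=-1.0, num_steps=10)
without_flip = euler_sample(data=2.0, noise=-1.0, num_steps=10, flip_sign=False)
```

In this toy setting the flipped run ends at the data value, while the unflipped run moves away from it by the same distance.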
 
 
SiT-L-2-256/transformer/transformer_sit.py CHANGED
@@ -1,224 +1,240 @@
1
- import math
2
- from dataclasses import dataclass
3
- from typing import Optional
4
-
5
- import numpy as np
6
- import torch
7
- import torch.nn as nn
8
- from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
9
-
10
- from diffusers.configuration_utils import ConfigMixin, register_to_config
11
- from diffusers.models.modeling_utils import ModelMixin
12
- from diffusers.utils import BaseOutput
13
-
14
-
15
- def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
16
- return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
17
-
18
-
19
- @dataclass
20
- class SiTTransformer2DModelOutput(BaseOutput):
21
- sample: torch.Tensor
22
-
23
-
24
- class TimestepEmbedder(nn.Module):
25
- def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
26
- super().__init__()
27
- self.mlp = nn.Sequential(
28
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
29
- nn.SiLU(),
30
- nn.Linear(hidden_size, hidden_size, bias=True),
31
- )
32
- self.frequency_embedding_size = frequency_embedding_size
33
-
34
- @staticmethod
35
- def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
36
- half = dim // 2
37
- freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
38
- device=t.device
39
- )
40
- args = t[:, None].float() * freqs[None]
41
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
42
- if dim % 2:
43
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
44
- return embedding
45
-
46
- def forward(self, t: torch.Tensor) -> torch.Tensor:
47
- return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
48
-
49
-
50
- class LabelEmbedder(nn.Module):
51
- def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
52
- super().__init__()
53
- use_cfg_embedding = dropout_prob > 0
54
- self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
55
- self.num_classes = num_classes
56
- self.dropout_prob = dropout_prob
57
-
58
- def token_drop(self, labels: torch.Tensor, force_drop_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
59
- if force_drop_ids is None:
60
- drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
61
- else:
62
- drop_ids = force_drop_ids == 1
63
- labels = torch.where(drop_ids, self.num_classes, labels)
64
- return labels
65
-
66
- def forward(
67
- self,
68
- labels: torch.Tensor,
69
- train: bool,
70
- force_drop_ids: Optional[torch.Tensor] = None,
71
- ) -> torch.Tensor:
72
- use_dropout = self.dropout_prob > 0
73
- if (train and use_dropout) or (force_drop_ids is not None):
74
- labels = self.token_drop(labels, force_drop_ids)
75
- return self.embedding_table(labels)
76
-
77
-
78
- class SiTBlock(nn.Module):
79
- def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, **block_kwargs):
80
- super().__init__()
81
- self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
82
- self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
83
- self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
84
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
85
- approx_gelu = lambda: nn.GELU(approximate="tanh")
86
- self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
87
- self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
88
-
89
- def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
90
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
91
- x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
92
- x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
93
- return x
94
-
95
-
96
- class FinalLayer(nn.Module):
97
- def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
98
- super().__init__()
99
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
100
- self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
101
- self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
102
-
103
- def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
104
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
105
- x = modulate(self.norm_final(x), shift, scale)
106
- return self.linear(x)
107
-
108
-
109
- class SiTTransformer2DModel(ModelMixin, ConfigMixin):
110
- @register_to_config
111
- def __init__(
112
- self,
113
- input_size: int = 32,
114
- patch_size: int = 2,
115
- in_channels: int = 4,
116
- hidden_size: int = 1152,
117
- depth: int = 28,
118
- num_heads: int = 16,
119
- mlp_ratio: float = 4.0,
120
- class_dropout_prob: float = 0.1,
121
- num_classes: int = 1000,
122
- learn_sigma: bool = True,
123
- ):
124
- super().__init__()
125
- self.learn_sigma = learn_sigma
126
- self.in_channels = in_channels
127
- self.out_channels = in_channels * 2 if learn_sigma else in_channels
128
- self.patch_size = patch_size
129
- self.num_classes = num_classes
130
-
131
- self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
132
- self.t_embedder = TimestepEmbedder(hidden_size)
133
- self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
134
- num_patches = self.x_embedder.num_patches
135
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
136
-
137
- self.blocks = nn.ModuleList([SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
138
- self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
139
- self.initialize_weights()
140
-
141
- def initialize_weights(self) -> None:
142
- def _basic_init(module: nn.Module):
143
- if isinstance(module, nn.Linear):
144
- torch.nn.init.xavier_uniform_(module.weight)
145
- if module.bias is not None:
146
- nn.init.constant_(module.bias, 0)
147
-
148
- self.apply(_basic_init)
149
- pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
150
- self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
151
-
152
- w = self.x_embedder.proj.weight.data
153
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
154
- nn.init.constant_(self.x_embedder.proj.bias, 0)
155
- nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
156
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
157
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
158
- for block in self.blocks:
159
- nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
160
- nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
161
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
162
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
163
- nn.init.constant_(self.final_layer.linear.weight, 0)
164
- nn.init.constant_(self.final_layer.linear.bias, 0)
165
-
166
- def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
167
- c = self.out_channels
168
- p = self.x_embedder.patch_size[0]
169
- h = w = int(x.shape[1] ** 0.5)
170
- x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
171
- x = torch.einsum("nhwpqc->nchpwq", x)
172
- return x.reshape(shape=(x.shape[0], c, h * p, h * p))
173
-
174
- def forward(
175
- self,
176
- hidden_states: torch.Tensor,
177
- timestep: torch.Tensor,
178
- class_labels: torch.Tensor,
179
- force_drop_ids: Optional[torch.Tensor] = None,
180
- return_dict: bool = True,
181
- ) -> SiTTransformer2DModelOutput:
182
- x = self.x_embedder(hidden_states) + self.pos_embed
183
- t = self.t_embedder(timestep)
184
- y = self.y_embedder(class_labels, self.training, force_drop_ids=force_drop_ids)
185
- c = t + y
186
- for block in self.blocks:
187
- x = block(x, c)
188
- x = self.final_layer(x, c)
189
- x = self.unpatchify(x)
190
- if self.learn_sigma:
191
- x, _ = x.chunk(2, dim=1)
192
- if not return_dict:
193
- return (x,)
194
- return SiTTransformer2DModelOutput(sample=x)
195
-
196
-
197
- def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False, extra_tokens: int = 0):
198
- grid_h = np.arange(grid_size, dtype=np.float32)
199
- grid_w = np.arange(grid_size, dtype=np.float32)
200
- grid = np.meshgrid(grid_w, grid_h)
201
- grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])
202
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
203
- if cls_token and extra_tokens > 0:
204
- pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
205
- return pos_embed
206
-
207
-
208
- def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray):
209
- assert embed_dim % 2 == 0
210
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
211
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
212
- return np.concatenate([emb_h, emb_w], axis=1)
213
-
214
-
215
- def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray):
216
- assert embed_dim % 2 == 0
217
- omega = np.arange(embed_dim // 2, dtype=np.float64)
218
- omega /= embed_dim / 2.0
219
- omega = 1.0 / 10000**omega
220
- pos = pos.reshape(-1)
221
- out = np.einsum("m,d->md", pos, omega)
222
- emb_sin = np.sin(out)
223
- emb_cos = np.cos(out)
224
- return np.concatenate([emb_sin, emb_cos], axis=1)
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import Optional
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
23
+
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+ from diffusers.utils import BaseOutput
27
+
28
+
29
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
30
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
31
+
32
+
33
+ @dataclass
34
+ class SiTTransformer2DModelOutput(BaseOutput):
35
+ sample: torch.Tensor
36
+
37
+
38
+ class TimestepEmbedder(nn.Module):
39
+ def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
40
+ super().__init__()
41
+ self.mlp = nn.Sequential(
42
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
43
+ nn.SiLU(),
44
+ nn.Linear(hidden_size, hidden_size, bias=True),
45
+ )
46
+ self.frequency_embedding_size = frequency_embedding_size
47
+
48
+ @staticmethod
49
+ def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
50
+ half = dim // 2
51
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
52
+ device=t.device
53
+ )
54
+ args = t[:, None].float() * freqs[None]
55
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
56
+ if dim % 2:
57
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
58
+ return embedding
59
+
60
+ def forward(self, t: torch.Tensor) -> torch.Tensor:
61
+ emb = self.timestep_embedding(t.float(), self.frequency_embedding_size)
62
+ weight_dtype = self.mlp[0].weight.dtype
63
+ return self.mlp(emb.to(dtype=weight_dtype))
64
+
65
+
66
+ class LabelEmbedder(nn.Module):
67
+ def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
68
+ super().__init__()
69
+ use_cfg_embedding = dropout_prob > 0
70
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
71
+ self.num_classes = num_classes
72
+ self.dropout_prob = dropout_prob
73
+
74
+ def token_drop(self, labels: torch.Tensor, force_drop_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
75
+ if force_drop_ids is None:
76
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
77
+ else:
78
+ drop_ids = force_drop_ids == 1
79
+ labels = torch.where(drop_ids, self.num_classes, labels)
80
+ return labels
81
+
82
+ def forward(
83
+ self,
84
+ labels: torch.Tensor,
85
+ train: bool,
86
+ force_drop_ids: Optional[torch.Tensor] = None,
87
+ ) -> torch.Tensor:
88
+ use_dropout = self.dropout_prob > 0
89
+ if (train and use_dropout) or (force_drop_ids is not None):
90
+ labels = self.token_drop(labels, force_drop_ids)
91
+ return self.embedding_table(labels)
92
+
93
+
94
+ class SiTBlock(nn.Module):
95
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, **block_kwargs):
96
+ super().__init__()
97
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
98
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
99
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
100
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
101
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
102
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
103
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
104
+
105
+ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
106
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
107
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
108
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
109
+ return x
110
+
111
+
112
+ class FinalLayer(nn.Module):
113
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
114
+ super().__init__()
115
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
116
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
117
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
118
+
119
+ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
120
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
121
+ x = modulate(self.norm_final(x), shift, scale)
122
+ return self.linear(x)
123
+
124
+
125
+ class SiTTransformer2DModel(ModelMixin, ConfigMixin):
126
+ @register_to_config
127
+ def __init__(
128
+ self,
129
+ input_size: int = 32,
130
+ patch_size: int = 2,
131
+ in_channels: int = 4,
132
+ hidden_size: int = 1152,
133
+ depth: int = 28,
134
+ num_heads: int = 16,
135
+ mlp_ratio: float = 4.0,
136
+ class_dropout_prob: float = 0.1,
137
+ num_classes: int = 1000,
138
+ learn_sigma: bool = True,
139
+ ):
140
+ super().__init__()
141
+ self.learn_sigma = learn_sigma
142
+ self.in_channels = in_channels
143
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
144
+ self.patch_size = patch_size
145
+ self.num_classes = num_classes
146
+
147
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
148
+ self.t_embedder = TimestepEmbedder(hidden_size)
149
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
150
+ num_patches = self.x_embedder.num_patches
151
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
152
+
153
+ self.blocks = nn.ModuleList([SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
154
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
155
+ self.initialize_weights()
156
+
157
+ def initialize_weights(self) -> None:
158
+ def _basic_init(module: nn.Module):
159
+ if isinstance(module, nn.Linear):
160
+ torch.nn.init.xavier_uniform_(module.weight)
161
+ if module.bias is not None:
162
+ nn.init.constant_(module.bias, 0)
163
+
164
+ self.apply(_basic_init)
165
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
166
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
167
+
168
+ w = self.x_embedder.proj.weight.data
169
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
170
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
171
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
172
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
173
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
174
+ for block in self.blocks:
175
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
176
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
177
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
178
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
179
+ nn.init.constant_(self.final_layer.linear.weight, 0)
180
+ nn.init.constant_(self.final_layer.linear.bias, 0)
181
+
182
+ def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
183
+ c = self.out_channels
184
+ p = self.x_embedder.patch_size[0]
185
+ h = w = int(x.shape[1] ** 0.5)
186
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
187
+ x = torch.einsum("nhwpqc->nchpwq", x)
188
+ return x.reshape(shape=(x.shape[0], c, h * p, h * p))
189
+
190
+ def forward(
191
+ self,
192
+ hidden_states: torch.Tensor,
193
+ timestep: torch.Tensor,
194
+ class_labels: torch.Tensor,
195
+ force_drop_ids: Optional[torch.Tensor] = None,
196
+ return_dict: bool = True,
197
+ ) -> SiTTransformer2DModelOutput:
198
+ x = self.x_embedder(hidden_states) + self.pos_embed
199
+ t = self.t_embedder(timestep)
200
+ y = self.y_embedder(class_labels, self.training, force_drop_ids=force_drop_ids)
201
+ c = t + y
202
+ for block in self.blocks:
203
+ x = block(x, c)
204
+ x = self.final_layer(x, c)
205
+ x = self.unpatchify(x)
206
+ if self.learn_sigma:
207
+ x, _ = x.chunk(2, dim=1)
208
+ if not return_dict:
209
+ return (x,)
210
+ return SiTTransformer2DModelOutput(sample=x)
211
+
212
+
213
+ def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False, extra_tokens: int = 0):
214
+ grid_h = np.arange(grid_size, dtype=np.float32)
215
+ grid_w = np.arange(grid_size, dtype=np.float32)
216
+ grid = np.meshgrid(grid_w, grid_h)
217
+ grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])
218
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
219
+ if cls_token and extra_tokens > 0:
220
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
221
+ return pos_embed
222
+
223
+
224
+ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray):
225
+ assert embed_dim % 2 == 0
226
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
227
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
228
+ return np.concatenate([emb_h, emb_w], axis=1)
229
+
230
+
231
+ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray):
232
+ assert embed_dim % 2 == 0
233
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
234
+ omega /= embed_dim / 2.0
235
+ omega = 1.0 / 10000**omega
236
+ pos = pos.reshape(-1)
237
+ out = np.einsum("m,d->md", pos, omega)
238
+ emb_sin = np.sin(out)
239
+ emb_cos = np.cos(out)
240
+ return np.concatenate([emb_sin, emb_cos], axis=1)
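`TimestepEmbedder.timestep_embedding` above turns a scalar time into cosine and sine features at geometrically spaced frequencies. Below is a dependency-free sketch of the same formula for a single time value, assuming an even `dim`. The name `sinusoidal_embedding` is illustrative and not part of the repo.

```python
import math
from typing import List


def sinusoidal_embedding(t: float, dim: int, max_period: int = 10000) -> List[float]:
    """Return [cos(t * f_0), ..., cos(t * f_{h-1}), sin(t * f_0), ..., sin(t * f_{h-1})]
    with f_i = exp(-log(max_period) * i / h) and h = dim // 2."""
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    return [math.cos(a) for a in args] + [math.sin(a) for a in args]


emb = sinusoidal_embedding(0.0, dim=8)
```

In the updated `forward`, this embedding is computed in float32 and then cast to the MLP weight dtype before it enters the MLP.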
SiT-S-2-256/__pycache__/pipeline.cpython-312.pyc ADDED
Binary file (7.51 kB).
 
SiT-S-2-256/model_index.json CHANGED
@@ -1,19 +1,1021 @@
1
- {
2
- "_class_name": [
3
- "pipeline",
- "SiTPipeline"
- ],
- "_diffusers_version": "0.36.0",
- "scheduler": [
- "scheduling_flow_match_sit",
- "SiTFlowMatchScheduler"
- ],
- "transformer": [
- "transformer_sit",
- "SiTTransformer2DModel"
- ],
- "vae": [
- "diffusers",
- "AutoencoderKL"
- ]
- }
 
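The updated `model_index.json` that follows adds an `id2label` map whose values are comma-separated English synonyms for each ImageNet class id. As a rough illustration of how such a map could be inverted so that a synonym string resolves to an id, here is a minimal sketch. `build_label2id` is a hypothetical helper, not the pipeline's actual code, and keeping the first id when a synonym repeats (for example `crane` appears under both 134 and 517) is an assumption.

```python
import json


def build_label2id(id2label):
    """Map each comma-separated synonym (lowercased) to its integer class id."""
    label2id = {}
    for class_id, names in id2label.items():
        for name in names.split(","):
            # Assumption: when a synonym repeats across classes, keep the first id seen.
            label2id.setdefault(name.strip().lower(), int(class_id))
    return label2id


# Small excerpt of the id2label entries from model_index.json.
sample = json.loads(
    '{"id2label": {"1": "goldfish, Carassius auratus", '
    '"207": "golden retriever", "134": "crane", "517": "crane"}}'
)
label2id = build_label2id(sample["id2label"])
```

Ambiguous synonyms such as `crane` or `maillot` are safer to pass as integer ids.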
+ {
+ "_class_name": [
+ "pipeline",
+ "SiTPipeline"
+ ],
+ "_diffusers_version": "0.36.0",
+ "scheduler": [
+ "diffusers",
+ "FlowMatchEulerDiscreteScheduler"
+ ],
+ "transformer": [
+ "transformer_sit",
+ "SiTTransformer2DModel"
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKL"
+ ],
+ "id2label": {
+ "0": "tench, Tinca tinca",
+ "1": "goldfish, Carassius auratus",
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
+ "3": "tiger shark, Galeocerdo cuvieri",
+ "4": "hammerhead, hammerhead shark",
+ "5": "electric ray, crampfish, numbfish, torpedo",
+ "6": "stingray",
+ "7": "cock",
+ "8": "hen",
+ "9": "ostrich, Struthio camelus",
+ "10": "brambling, Fringilla montifringilla",
+ "11": "goldfinch, Carduelis carduelis",
+ "12": "house finch, linnet, Carpodacus mexicanus",
+ "13": "junco, snowbird",
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
+ "15": "robin, American robin, Turdus migratorius",
+ "16": "bulbul",
+ "17": "jay",
+ "18": "magpie",
+ "19": "chickadee",
+ "20": "water ouzel, dipper",
+ "21": "kite",
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
+ "23": "vulture",
+ "24": "great grey owl, great gray owl, Strix nebulosa",
+ "25": "European fire salamander, Salamandra salamandra",
+ "26": "common newt, Triturus vulgaris",
+ "27": "eft",
+ "28": "spotted salamander, Ambystoma maculatum",
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
+ "30": "bullfrog, Rana catesbeiana",
+ "31": "tree frog, tree-frog",
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
+ "35": "mud turtle",
+ "36": "terrapin",
+ "37": "box turtle, box tortoise",
+ "38": "banded gecko",
+ "39": "common iguana, iguana, Iguana iguana",
+ "40": "American chameleon, anole, Anolis carolinensis",
+ "41": "whiptail, whiptail lizard",
+ "42": "agama",
+ "43": "frilled lizard, Chlamydosaurus kingi",
+ "44": "alligator lizard",
+ "45": "Gila monster, Heloderma suspectum",
+ "46": "green lizard, Lacerta viridis",
+ "47": "African chameleon, Chamaeleo chamaeleon",
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
+ "50": "American alligator, Alligator mississipiensis",
+ "51": "triceratops",
+ "52": "thunder snake, worm snake, Carphophis amoenus",
+ "53": "ringneck snake, ring-necked snake, ring snake",
+ "54": "hognose snake, puff adder, sand viper",
+ "55": "green snake, grass snake",
+ "56": "king snake, kingsnake",
+ "57": "garter snake, grass snake",
+ "58": "water snake",
+ "59": "vine snake",
+ "60": "night snake, Hypsiglena torquata",
+ "61": "boa constrictor, Constrictor constrictor",
+ "62": "rock python, rock snake, Python sebae",
+ "63": "Indian cobra, Naja naja",
+ "64": "green mamba",
+ "65": "sea snake",
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
+ "69": "trilobite",
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
+ "71": "scorpion",
+ "72": "black and gold garden spider, Argiope aurantia",
+ "73": "barn spider, Araneus cavaticus",
+ "74": "garden spider, Aranea diademata",
+ "75": "black widow, Latrodectus mactans",
+ "76": "tarantula",
+ "77": "wolf spider, hunting spider",
+ "78": "tick",
+ "79": "centipede",
+ "80": "black grouse",
+ "81": "ptarmigan",
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
+ "83": "prairie chicken, prairie grouse, prairie fowl",
+ "84": "peacock",
+ "85": "quail",
+ "86": "partridge",
+ "87": "African grey, African gray, Psittacus erithacus",
+ "88": "macaw",
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
+ "90": "lorikeet",
+ "91": "coucal",
+ "92": "bee eater",
+ "93": "hornbill",
+ "94": "hummingbird",
+ "95": "jacamar",
+ "96": "toucan",
+ "97": "drake",
+ "98": "red-breasted merganser, Mergus serrator",
+ "99": "goose",
+ "100": "black swan, Cygnus atratus",
+ "101": "tusker",
+ "102": "echidna, spiny anteater, anteater",
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
+ "104": "wallaby, brush kangaroo",
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
+ "106": "wombat",
+ "107": "jellyfish",
+ "108": "sea anemone, anemone",
+ "109": "brain coral",
+ "110": "flatworm, platyhelminth",
+ "111": "nematode, nematode worm, roundworm",
+ "112": "conch",
+ "113": "snail",
+ "114": "slug",
+ "115": "sea slug, nudibranch",
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
+ "117": "chambered nautilus, pearly nautilus, nautilus",
+ "118": "Dungeness crab, Cancer magister",
+ "119": "rock crab, Cancer irroratus",
+ "120": "fiddler crab",
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
+ "125": "hermit crab",
+ "126": "isopod",
+ "127": "white stork, Ciconia ciconia",
+ "128": "black stork, Ciconia nigra",
+ "129": "spoonbill",
+ "130": "flamingo",
+ "131": "little blue heron, Egretta caerulea",
+ "132": "American egret, great white heron, Egretta albus",
+ "133": "bittern",
+ "134": "crane",
+ "135": "limpkin, Aramus pictus",
+ "136": "European gallinule, Porphyrio porphyrio",
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
+ "138": "bustard",
+ "139": "ruddy turnstone, Arenaria interpres",
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
+ "141": "redshank, Tringa totanus",
+ "142": "dowitcher",
+ "143": "oystercatcher, oyster catcher",
+ "144": "pelican",
+ "145": "king penguin, Aptenodytes patagonica",
+ "146": "albatross, mollymawk",
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
+ "149": "dugong, Dugong dugon",
+ "150": "sea lion",
+ "151": "Chihuahua",
+ "152": "Japanese spaniel",
+ "153": "Maltese dog, Maltese terrier, Maltese",
+ "154": "Pekinese, Pekingese, Peke",
+ "155": "Shih-Tzu",
+ "156": "Blenheim spaniel",
+ "157": "papillon",
+ "158": "toy terrier",
+ "159": "Rhodesian ridgeback",
+ "160": "Afghan hound, Afghan",
+ "161": "basset, basset hound",
+ "162": "beagle",
+ "163": "bloodhound, sleuthhound",
+ "164": "bluetick",
+ "165": "black-and-tan coonhound",
+ "166": "Walker hound, Walker foxhound",
+ "167": "English foxhound",
+ "168": "redbone",
+ "169": "borzoi, Russian wolfhound",
+ "170": "Irish wolfhound",
+ "171": "Italian greyhound",
+ "172": "whippet",
+ "173": "Ibizan hound, Ibizan Podenco",
+ "174": "Norwegian elkhound, elkhound",
+ "175": "otterhound, otter hound",
+ "176": "Saluki, gazelle hound",
+ "177": "Scottish deerhound, deerhound",
+ "178": "Weimaraner",
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
+ "181": "Bedlington terrier",
+ "182": "Border terrier",
+ "183": "Kerry blue terrier",
+ "184": "Irish terrier",
+ "185": "Norfolk terrier",
+ "186": "Norwich terrier",
+ "187": "Yorkshire terrier",
+ "188": "wire-haired fox terrier",
+ "189": "Lakeland terrier",
+ "190": "Sealyham terrier, Sealyham",
+ "191": "Airedale, Airedale terrier",
+ "192": "cairn, cairn terrier",
+ "193": "Australian terrier",
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
+ "195": "Boston bull, Boston terrier",
+ "196": "miniature schnauzer",
+ "197": "giant schnauzer",
+ "198": "standard schnauzer",
+ "199": "Scotch terrier, Scottish terrier, Scottie",
+ "200": "Tibetan terrier, chrysanthemum dog",
+ "201": "silky terrier, Sydney silky",
+ "202": "soft-coated wheaten terrier",
+ "203": "West Highland white terrier",
+ "204": "Lhasa, Lhasa apso",
+ "205": "flat-coated retriever",
+ "206": "curly-coated retriever",
+ "207": "golden retriever",
+ "208": "Labrador retriever",
+ "209": "Chesapeake Bay retriever",
+ "210": "German short-haired pointer",
+ "211": "vizsla, Hungarian pointer",
+ "212": "English setter",
+ "213": "Irish setter, red setter",
+ "214": "Gordon setter",
+ "215": "Brittany spaniel",
+ "216": "clumber, clumber spaniel",
+ "217": "English springer, English springer spaniel",
+ "218": "Welsh springer spaniel",
+ "219": "cocker spaniel, English cocker spaniel, cocker",
+ "220": "Sussex spaniel",
+ "221": "Irish water spaniel",
+ "222": "kuvasz",
+ "223": "schipperke",
+ "224": "groenendael",
+ "225": "malinois",
+ "226": "briard",
+ "227": "kelpie",
+ "228": "komondor",
+ "229": "Old English sheepdog, bobtail",
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
+ "231": "collie",
+ "232": "Border collie",
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
+ "234": "Rottweiler",
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
+ "236": "Doberman, Doberman pinscher",
+ "237": "miniature pinscher",
+ "238": "Greater Swiss Mountain dog",
+ "239": "Bernese mountain dog",
+ "240": "Appenzeller",
+ "241": "EntleBucher",
+ "242": "boxer",
+ "243": "bull mastiff",
+ "244": "Tibetan mastiff",
+ "245": "French bulldog",
+ "246": "Great Dane",
+ "247": "Saint Bernard, St Bernard",
+ "248": "Eskimo dog, husky",
+ "249": "malamute, malemute, Alaskan malamute",
+ "250": "Siberian husky",
+ "251": "dalmatian, coach dog, carriage dog",
+ "252": "affenpinscher, monkey pinscher, monkey dog",
+ "253": "basenji",
+ "254": "pug, pug-dog",
+ "255": "Leonberg",
+ "256": "Newfoundland, Newfoundland dog",
+ "257": "Great Pyrenees",
+ "258": "Samoyed, Samoyede",
+ "259": "Pomeranian",
+ "260": "chow, chow chow",
+ "261": "keeshond",
+ "262": "Brabancon griffon",
+ "263": "Pembroke, Pembroke Welsh corgi",
+ "264": "Cardigan, Cardigan Welsh corgi",
+ "265": "toy poodle",
+ "266": "miniature poodle",
+ "267": "standard poodle",
+ "268": "Mexican hairless",
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
+ "273": "dingo, warrigal, warragal, Canis dingo",
+ "274": "dhole, Cuon alpinus",
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
+ "276": "hyena, hyaena",
+ "277": "red fox, Vulpes vulpes",
+ "278": "kit fox, Vulpes macrotis",
+ "279": "Arctic fox, white fox, Alopex lagopus",
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
+ "281": "tabby, tabby cat",
+ "282": "tiger cat",
+ "283": "Persian cat",
+ "284": "Siamese cat, Siamese",
+ "285": "Egyptian cat",
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
+ "287": "lynx, catamount",
+ "288": "leopard, Panthera pardus",
+ "289": "snow leopard, ounce, Panthera uncia",
+ "290": "jaguar, panther, Panthera onca, Felis onca",
+ "291": "lion, king of beasts, Panthera leo",
+ "292": "tiger, Panthera tigris",
+ "293": "cheetah, chetah, Acinonyx jubatus",
+ "294": "brown bear, bruin, Ursus arctos",
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
+ "298": "mongoose",
+ "299": "meerkat, mierkat",
+ "300": "tiger beetle",
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
+ "302": "ground beetle, carabid beetle",
+ "303": "long-horned beetle, longicorn, longicorn beetle",
+ "304": "leaf beetle, chrysomelid",
+ "305": "dung beetle",
+ "306": "rhinoceros beetle",
+ "307": "weevil",
+ "308": "fly",
+ "309": "bee",
+ "310": "ant, emmet, pismire",
+ "311": "grasshopper, hopper",
+ "312": "cricket",
+ "313": "walking stick, walkingstick, stick insect",
+ "314": "cockroach, roach",
+ "315": "mantis, mantid",
+ "316": "cicada, cicala",
+ "317": "leafhopper",
+ "318": "lacewing, lacewing fly",
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
+ "320": "damselfly",
+ "321": "admiral",
+ "322": "ringlet, ringlet butterfly",
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
+ "324": "cabbage butterfly",
+ "325": "sulphur butterfly, sulfur butterfly",
+ "326": "lycaenid, lycaenid butterfly",
+ "327": "starfish, sea star",
+ "328": "sea urchin",
+ "329": "sea cucumber, holothurian",
+ "330": "wood rabbit, cottontail, cottontail rabbit",
+ "331": "hare",
+ "332": "Angora, Angora rabbit",
+ "333": "hamster",
+ "334": "porcupine, hedgehog",
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
+ "336": "marmot",
+ "337": "beaver",
+ "338": "guinea pig, Cavia cobaya",
+ "339": "sorrel",
+ "340": "zebra",
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
+ "342": "wild boar, boar, Sus scrofa",
+ "343": "warthog",
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
+ "345": "ox",
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
+ "347": "bison",
+ "348": "ram, tup",
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
+ "350": "ibex, Capra ibex",
+ "351": "hartebeest",
+ "352": "impala, Aepyceros melampus",
+ "353": "gazelle",
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
+ "355": "llama",
+ "356": "weasel",
+ "357": "mink",
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
+ "359": "black-footed ferret, ferret, Mustela nigripes",
+ "360": "otter",
+ "361": "skunk, polecat, wood pussy",
+ "362": "badger",
+ "363": "armadillo",
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
+ "366": "gorilla, Gorilla gorilla",
+ "367": "chimpanzee, chimp, Pan troglodytes",
+ "368": "gibbon, Hylobates lar",
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
+ "370": "guenon, guenon monkey",
+ "371": "patas, hussar monkey, Erythrocebus patas",
+ "372": "baboon",
+ "373": "macaque",
+ "374": "langur",
+ "375": "colobus, colobus monkey",
+ "376": "proboscis monkey, Nasalis larvatus",
+ "377": "marmoset",
+ "378": "capuchin, ringtail, Cebus capucinus",
+ "379": "howler monkey, howler",
+ "380": "titi, titi monkey",
+ "381": "spider monkey, Ateles geoffroyi",
+ "382": "squirrel monkey, Saimiri sciureus",
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
+ "385": "Indian elephant, Elephas maximus",
+ "386": "African elephant, Loxodonta africana",
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
+ "389": "barracouta, snoek",
+ "390": "eel",
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
+ "392": "rock beauty, Holocanthus tricolor",
+ "393": "anemone fish",
+ "394": "sturgeon",
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
+ "396": "lionfish",
+ "397": "puffer, pufferfish, blowfish, globefish",
+ "398": "abacus",
+ "399": "abaya",
+ "400": "academic gown, academic robe, judge robe",
+ "401": "accordion, piano accordion, squeeze box",
+ "402": "acoustic guitar",
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
+ "404": "airliner",
+ "405": "airship, dirigible",
+ "406": "altar",
+ "407": "ambulance",
+ "408": "amphibian, amphibious vehicle",
+ "409": "analog clock",
+ "410": "apiary, bee house",
+ "411": "apron",
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
+ "413": "assault rifle, assault gun",
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
+ "415": "bakery, bakeshop, bakehouse",
+ "416": "balance beam, beam",
+ "417": "balloon",
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
+ "419": "Band Aid",
+ "420": "banjo",
+ "421": "bannister, banister, balustrade, balusters, handrail",
+ "422": "barbell",
+ "423": "barber chair",
+ "424": "barbershop",
+ "425": "barn",
+ "426": "barometer",
+ "427": "barrel, cask",
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
+ "429": "baseball",
+ "430": "basketball",
+ "431": "bassinet",
+ "432": "bassoon",
+ "433": "bathing cap, swimming cap",
+ "434": "bath towel",
+ "435": "bathtub, bathing tub, bath, tub",
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
+ "437": "beacon, lighthouse, beacon light, pharos",
+ "438": "beaker",
+ "439": "bearskin, busby, shako",
+ "440": "beer bottle",
+ "441": "beer glass",
+ "442": "bell cote, bell cot",
+ "443": "bib",
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
+ "445": "bikini, two-piece",
+ "446": "binder, ring-binder",
+ "447": "binoculars, field glasses, opera glasses",
+ "448": "birdhouse",
+ "449": "boathouse",
+ "450": "bobsled, bobsleigh, bob",
+ "451": "bolo tie, bolo, bola tie, bola",
+ "452": "bonnet, poke bonnet",
+ "453": "bookcase",
+ "454": "bookshop, bookstore, bookstall",
+ "455": "bottlecap",
+ "456": "bow",
+ "457": "bow tie, bow-tie, bowtie",
+ "458": "brass, memorial tablet, plaque",
+ "459": "brassiere, bra, bandeau",
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
+ "461": "breastplate, aegis, egis",
+ "462": "broom",
+ "463": "bucket, pail",
+ "464": "buckle",
+ "465": "bulletproof vest",
+ "466": "bullet train, bullet",
+ "467": "butcher shop, meat market",
+ "468": "cab, hack, taxi, taxicab",
+ "469": "caldron, cauldron",
+ "470": "candle, taper, wax light",
+ "471": "cannon",
+ "472": "canoe",
+ "473": "can opener, tin opener",
+ "474": "cardigan",
+ "475": "car mirror",
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
+ "477": "carpenters kit, tool kit",
+ "478": "carton",
+ "479": "car wheel",
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
+ "481": "cassette",
+ "482": "cassette player",
+ "483": "castle",
+ "484": "catamaran",
+ "485": "CD player",
+ "486": "cello, violoncello",
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
+ "488": "chain",
+ "489": "chainlink fence",
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
+ "491": "chain saw, chainsaw",
+ "492": "chest",
+ "493": "chiffonier, commode",
+ "494": "chime, bell, gong",
+ "495": "china cabinet, china closet",
+ "496": "Christmas stocking",
+ "497": "church, church building",
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
+ "499": "cleaver, meat cleaver, chopper",
+ "500": "cliff dwelling",
+ "501": "cloak",
+ "502": "clog, geta, patten, sabot",
+ "503": "cocktail shaker",
+ "504": "coffee mug",
+ "505": "coffeepot",
+ "506": "coil, spiral, volute, whorl, helix",
+ "507": "combination lock",
+ "508": "computer keyboard, keypad",
+ "509": "confectionery, confectionary, candy store",
+ "510": "container ship, containership, container vessel",
+ "511": "convertible",
+ "512": "corkscrew, bottle screw",
+ "513": "cornet, horn, trumpet, trump",
+ "514": "cowboy boot",
+ "515": "cowboy hat, ten-gallon hat",
+ "516": "cradle",
+ "517": "crane",
+ "518": "crash helmet",
+ "519": "crate",
+ "520": "crib, cot",
+ "521": "Crock Pot",
+ "522": "croquet ball",
+ "523": "crutch",
+ "524": "cuirass",
+ "525": "dam, dike, dyke",
+ "526": "desk",
+ "527": "desktop computer",
+ "528": "dial telephone, dial phone",
+ "529": "diaper, nappy, napkin",
+ "530": "digital clock",
+ "531": "digital watch",
+ "532": "dining table, board",
+ "533": "dishrag, dishcloth",
+ "534": "dishwasher, dish washer, dishwashing machine",
+ "535": "disk brake, disc brake",
+ "536": "dock, dockage, docking facility",
+ "537": "dogsled, dog sled, dog sleigh",
+ "538": "dome",
+ "539": "doormat, welcome mat",
+ "540": "drilling platform, offshore rig",
+ "541": "drum, membranophone, tympan",
+ "542": "drumstick",
+ "543": "dumbbell",
+ "544": "Dutch oven",
+ "545": "electric fan, blower",
+ "546": "electric guitar",
+ "547": "electric locomotive",
+ "548": "entertainment center",
+ "549": "envelope",
+ "550": "espresso maker",
+ "551": "face powder",
+ "552": "feather boa, boa",
+ "553": "file, file cabinet, filing cabinet",
+ "554": "fireboat",
+ "555": "fire engine, fire truck",
+ "556": "fire screen, fireguard",
+ "557": "flagpole, flagstaff",
+ "558": "flute, transverse flute",
+ "559": "folding chair",
+ "560": "football helmet",
+ "561": "forklift",
+ "562": "fountain",
+ "563": "fountain pen",
+ "564": "four-poster",
+ "565": "freight car",
+ "566": "French horn, horn",
+ "567": "frying pan, frypan, skillet",
+ "568": "fur coat",
+ "569": "garbage truck, dustcart",
+ "570": "gasmask, respirator, gas helmet",
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
+ "572": "goblet",
+ "573": "go-kart",
+ "574": "golf ball",
+ "575": "golfcart, golf cart",
+ "576": "gondola",
+ "577": "gong, tam-tam",
+ "578": "gown",
+ "579": "grand piano, grand",
+ "580": "greenhouse, nursery, glasshouse",
+ "581": "grille, radiator grille",
+ "582": "grocery store, grocery, food market, market",
+ "583": "guillotine",
+ "584": "hair slide",
+ "585": "hair spray",
+ "586": "half track",
+ "587": "hammer",
+ "588": "hamper",
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
+ "590": "hand-held computer, hand-held microcomputer",
+ "591": "handkerchief, hankie, hanky, hankey",
+ "592": "hard disc, hard disk, fixed disk",
+ "593": "harmonica, mouth organ, harp, mouth harp",
+ "594": "harp",
+ "595": "harvester, reaper",
+ "596": "hatchet",
+ "597": "holster",
+ "598": "home theater, home theatre",
+ "599": "honeycomb",
+ "600": "hook, claw",
+ "601": "hoopskirt, crinoline",
+ "602": "horizontal bar, high bar",
+ "603": "horse cart, horse-cart",
+ "604": "hourglass",
+ "605": "iPod",
+ "606": "iron, smoothing iron",
+ "607": "jack-o-lantern",
+ "608": "jean, blue jean, denim",
+ "609": "jeep, landrover",
+ "610": "jersey, T-shirt, tee shirt",
+ "611": "jigsaw puzzle",
+ "612": "jinrikisha, ricksha, rickshaw",
+ "613": "joystick",
+ "614": "kimono",
+ "615": "knee pad",
+ "616": "knot",
+ "617": "lab coat, laboratory coat",
+ "618": "ladle",
+ "619": "lampshade, lamp shade",
+ "620": "laptop, laptop computer",
+ "621": "lawn mower, mower",
+ "622": "lens cap, lens cover",
+ "623": "letter opener, paper knife, paperknife",
+ "624": "library",
+ "625": "lifeboat",
+ "626": "lighter, light, igniter, ignitor",
+ "627": "limousine, limo",
+ "628": "liner, ocean liner",
+ "629": "lipstick, lip rouge",
+ "630": "Loafer",
+ "631": "lotion",
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
+ "633": "loupe, jewelers loupe",
+ "634": "lumbermill, sawmill",
+ "635": "magnetic compass",
+ "636": "mailbag, postbag",
+ "637": "mailbox, letter box",
+ "638": "maillot",
+ "639": "maillot, tank suit",
+ "640": "manhole cover",
+ "641": "maraca",
+ "642": "marimba, xylophone",
+ "643": "mask",
+ "644": "matchstick",
+ "645": "maypole",
+ "646": "maze, labyrinth",
+ "647": "measuring cup",
+ "648": "medicine chest, medicine cabinet",
+ "649": "megalith, megalithic structure",
+ "650": "microphone, mike",
+ "651": "microwave, microwave oven",
+ "652": "military uniform",
+ "653": "milk can",
+ "654": "minibus",
+ "655": "miniskirt, mini",
+ "656": "minivan",
+ "657": "missile",
+ "658": "mitten",
+ "659": "mixing bowl",
+ "660": "mobile home, manufactured home",
+ "661": "Model T",
+ "662": "modem",
+ "663": "monastery",
+ "664": "monitor",
+ "665": "moped",
+ "666": "mortar",
+ "667": "mortarboard",
+ "668": "mosque",
+ "669": "mosquito net",
+ "670": "motor scooter, scooter",
+ "671": "mountain bike, all-terrain bike, off-roader",
+ "672": "mountain tent",
+ "673": "mouse, computer mouse",
+ "674": "mousetrap",
+ "675": "moving van",
+ "676": "muzzle",
+ "677": "nail",
+ "678": "neck brace",
+ "679": "necklace",
+ "680": "nipple",
+ "681": "notebook, notebook computer",
+ "682": "obelisk",
+ "683": "oboe, hautboy, hautbois",
+ "684": "ocarina, sweet potato",
+ "685": "odometer, hodometer, mileometer, milometer",
+ "686": "oil filter",
+ "687": "organ, pipe organ",
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
+ "689": "overskirt",
+ "690": "oxcart",
+ "691": "oxygen mask",
+ "692": "packet",
+ "693": "paddle, boat paddle",
+ "694": "paddlewheel, paddle wheel",
+ "695": "padlock",
+ "696": "paintbrush",
+ "697": "pajama, pyjama, pjs, jammies",
+ "698": "palace",
+ "699": "panpipe, pandean pipe, syrinx",
+ "700": "paper towel",
+ "701": "parachute, chute",
+ "702": "parallel bars, bars",
+ "703": "park bench",
+ "704": "parking meter",
+ "705": "passenger car, coach, carriage",
+ "706": "patio, terrace",
+ "707": "pay-phone, pay-station",
+ "708": "pedestal, plinth, footstall",
+ "709": "pencil box, pencil case",
+ "710": "pencil sharpener",
+ "711": "perfume, essence",
+ "712": "Petri dish",
+ "713": "photocopier",
+ "714": "pick, plectrum, plectron",
+ "715": "pickelhaube",
+ "716": "picket fence, paling",
+ "717": "pickup, pickup truck",
+ "718": "pier",
+ "719": "piggy bank, penny bank",
+ "720": "pill bottle",
+ "721": "pillow",
+ "722": "ping-pong ball",
+ "723": "pinwheel",
+ "724": "pirate, pirate ship",
+ "725": "pitcher, ewer",
+ "726": "plane, carpenters plane, woodworking plane",
+ "727": "planetarium",
+ "728": "plastic bag",
+ "729": "plate rack",
+ "730": "plow, plough",
+ "731": "plunger, plumbers helper",
+ "732": "Polaroid camera, Polaroid Land camera",
+ "733": "pole",
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
+ "735": "poncho",
+ "736": "pool table, billiard table, snooker table",
+ "737": "pop bottle, soda bottle",
+ "738": "pot, flowerpot",
+ "739": "potters wheel",
+ "740": "power drill",
+ "741": "prayer rug, prayer mat",
+ "742": "printer",
+ "743": "prison, prison house",
+ "744": "projectile, missile",
+ "745": "projector",
+ "746": "puck, hockey puck",
+ "747": "punching bag, punch bag, punching ball, punchball",
+ "748": "purse",
+ "749": "quill, quill pen",
+ "750": "quilt, comforter, comfort, puff",
+ "751": "racer, race car, racing car",
+ "752": "racket, racquet",
+ "753": "radiator",
+ "754": "radio, wireless",
+ "755": "radio telescope, radio reflector",
+ "756": "rain barrel",
+ "757": "recreational vehicle, RV, R.V.",
+ "758": "reel",
+ "759": "reflex camera",
+ "760": "refrigerator, icebox",
+ "761": "remote control, remote",
+ "762": "restaurant, eating house, eating place, eatery",
+ "763": "revolver, six-gun, six-shooter",
+ "764": "rifle",
+ "765": "rocking chair, rocker",
+ "766": "rotisserie",
+ "767": "rubber eraser, rubber, pencil eraser",
+ "768": "rugby ball",
+ "769": "rule, ruler",
+ "770": "running shoe",
+ "771": "safe",
+ "772": "safety pin",
+ "773": "saltshaker, salt shaker",
+ "774": "sandal",
+ "775": "sarong",
+ "776": "sax, saxophone",
+ "777": "scabbard",
+ "778": "scale, weighing machine",
+ "779": "school bus",
+ "780": "schooner",
+ "781": "scoreboard",
+ "782": "screen, CRT screen",
+ "783": "screw",
+ "784": "screwdriver",
+ "785": "seat belt, seatbelt",
+ "786": "sewing machine",
+ "787": "shield, buckler",
+ "788": "shoe shop, shoe-shop, shoe store",
+ "789": "shoji",
+ "790": "shopping basket",
+ "791": "shopping cart",
+ "792": "shovel",
+ "793": "shower cap",
+ "794": "shower curtain",
+ "795": "ski",
+ "796": "ski mask",
+ "797": "sleeping bag",
+ "798": "slide rule, slipstick",
+ "799": "sliding door",
+ "800": "slot, one-armed bandit",
+ "801": "snorkel",
+ "802": "snowmobile",
+ "803": "snowplow, snowplough",
+ "804": "soap dispenser",
+ "805": "soccer ball",
+ "806": "sock",
+ "807": "solar dish, solar collector, solar furnace",
+ "808": "sombrero",
+ "809": "soup bowl",
+ "810": "space bar",
+ "811": "space heater",
+ "812": "space shuttle",
+ "813": "spatula",
+ "814": "speedboat",
+ "815": "spider web, spiders web",
+ "816": "spindle",
+ "817": "sports car, sport car",
+ "818": "spotlight, spot",
+ "819": "stage",
+ "820": "steam locomotive",
+ "821": "steel arch bridge",
+ "822": "steel drum",
+ "823": "stethoscope",
+ "824": "stole",
+ "825": "stone wall",
+ "826": "stopwatch, stop watch",
+ "827": "stove",
+ "828": "strainer",
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
+ "830": "stretcher",
+ "831": "studio couch, day bed",
+ "832": "stupa, tope",
+ "833": "submarine, pigboat, sub, U-boat",
+ "834": "suit, suit of clothes",
+ "835": "sundial",
+ "836": "sunglass",
+ "837": "sunglasses, dark glasses, shades",
+ "838": "sunscreen, sunblock, sun blocker",
+ "839": "suspension bridge",
+ "840": "swab, swob, mop",
+ "841": "sweatshirt",
+ "842": "swimming trunks, bathing trunks",
+ "843": "swing",
+ "844": "switch, electric switch, electrical switch",
+ "845": "syringe",
+ "846": "table lamp",
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
+ "848": "tape player",
+ "849": "teapot",
+ "850": "teddy, teddy bear",
+ "851": "television, television system",
+ "852": "tennis ball",
+ "853": "thatch, thatched roof",
+ "854": "theater curtain, theatre curtain",
+ "855": "thimble",
+ "856": "thresher, thrasher, threshing machine",
+ "857": "throne",
+ "858": "tile roof",
+ "859": "toaster",
+ "860": "tobacco shop, tobacconist shop, tobacconist",
+ "861": "toilet seat",
+ "862": "torch",
+ "863": "totem pole",
+ "864": "tow truck, tow car, wrecker",
+ "865": "toyshop",
+ "866": "tractor",
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
+ "868": "tray",
+ "869": "trench coat",
+ "870": "tricycle, trike, velocipede",
+ "871": "trimaran",
+ "872": "tripod",
+ "873": "triumphal arch",
894
+ "874": "trolleybus, trolley coach, trackless trolley",
895
+ "875": "trombone",
896
+ "876": "tub, vat",
897
+ "877": "turnstile",
898
+ "878": "typewriter keyboard",
899
+ "879": "umbrella",
900
+ "880": "unicycle, monocycle",
901
+ "881": "upright, upright piano",
902
+ "882": "vacuum, vacuum cleaner",
903
+ "883": "vase",
904
+ "884": "vault",
905
+ "885": "velvet",
906
+ "886": "vending machine",
907
+ "887": "vestment",
908
+ "888": "viaduct",
909
+ "889": "violin, fiddle",
910
+ "890": "volleyball",
911
+ "891": "waffle iron",
912
+ "892": "wall clock",
913
+ "893": "wallet, billfold, notecase, pocketbook",
914
+ "894": "wardrobe, closet, press",
915
+ "895": "warplane, military plane",
916
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
917
+ "897": "washer, automatic washer, washing machine",
918
+ "898": "water bottle",
919
+ "899": "water jug",
920
+ "900": "water tower",
921
+ "901": "whiskey jug",
922
+ "902": "whistle",
923
+ "903": "wig",
924
+ "904": "window screen",
925
+ "905": "window shade",
926
+ "906": "Windsor tie",
927
+ "907": "wine bottle",
928
+ "908": "wing",
929
+ "909": "wok",
930
+ "910": "wooden spoon",
931
+ "911": "wool, woolen, woollen",
932
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
933
+ "913": "wreck",
934
+ "914": "yawl",
935
+ "915": "yurt",
936
+ "916": "web site, website, internet site, site",
937
+ "917": "comic book",
938
+ "918": "crossword puzzle, crossword",
939
+ "919": "street sign",
940
+ "920": "traffic light, traffic signal, stoplight",
941
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
942
+ "922": "menu",
943
+ "923": "plate",
944
+ "924": "guacamole",
945
+ "925": "consomme",
946
+ "926": "hot pot, hotpot",
947
+ "927": "trifle",
948
+ "928": "ice cream, icecream",
949
+ "929": "ice lolly, lolly, lollipop, popsicle",
950
+ "930": "French loaf",
951
+ "931": "bagel, beigel",
952
+ "932": "pretzel",
953
+ "933": "cheeseburger",
954
+ "934": "hotdog, hot dog, red hot",
955
+ "935": "mashed potato",
956
+ "936": "head cabbage",
957
+ "937": "broccoli",
958
+ "938": "cauliflower",
959
+ "939": "zucchini, courgette",
960
+ "940": "spaghetti squash",
961
+ "941": "acorn squash",
962
+ "942": "butternut squash",
963
+ "943": "cucumber, cuke",
964
+ "944": "artichoke, globe artichoke",
965
+ "945": "bell pepper",
966
+ "946": "cardoon",
967
+ "947": "mushroom",
968
+ "948": "Granny Smith",
969
+ "949": "strawberry",
970
+ "950": "orange",
971
+ "951": "lemon",
972
+ "952": "fig",
973
+ "953": "pineapple, ananas",
974
+ "954": "banana",
975
+ "955": "jackfruit, jak, jack",
976
+ "956": "custard apple",
977
+ "957": "pomegranate",
978
+ "958": "hay",
979
+ "959": "carbonara",
980
+ "960": "chocolate sauce, chocolate syrup",
981
+ "961": "dough",
982
+ "962": "meat loaf, meatloaf",
983
+ "963": "pizza, pizza pie",
984
+ "964": "potpie",
985
+ "965": "burrito",
986
+ "966": "red wine",
987
+ "967": "espresso",
988
+ "968": "cup",
989
+ "969": "eggnog",
990
+ "970": "alp",
991
+ "971": "bubble",
992
+ "972": "cliff, drop, drop-off",
993
+ "973": "coral reef",
994
+ "974": "geyser",
995
+ "975": "lakeside, lakeshore",
996
+ "976": "promontory, headland, head, foreland",
997
+ "977": "sandbar, sand bar",
998
+ "978": "seashore, coast, seacoast, sea-coast",
999
+ "979": "valley, vale",
1000
+ "980": "volcano",
1001
+ "981": "ballplayer, baseball player",
1002
+ "982": "groom, bridegroom",
1003
+ "983": "scuba diver",
1004
+ "984": "rapeseed",
1005
+ "985": "daisy",
1006
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1007
+ "987": "corn",
1008
+ "988": "acorn",
1009
+ "989": "hip, rose hip, rosehip",
1010
+ "990": "buckeye, horse chestnut, conker",
1011
+ "991": "coral fungus",
1012
+ "992": "agaric",
1013
+ "993": "gyromitra",
1014
+ "994": "stinkhorn, carrion fungus",
1015
+ "995": "earthstar",
1016
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1017
+ "997": "bolete",
1018
+ "998": "ear, spike, capitulum",
1019
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1020
+ }
1021
+ }
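The `id2label` values above pack several synonyms into one comma-separated string. The following is a minimal sketch of how such a mapping could be inverted into a synonym lookup, in the spirit of the `_build_label2id` helper added to `pipeline.py` in this commit. The function name `build_label2id` and the three-entry `sample` dict are illustrative only and are not part of the repo.

```python
from typing import Dict


def build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
    """Invert an id -> "syn1, syn2" mapping into synonym -> id."""
    label2id: Dict[str, int] = {}
    for class_id, names in id2label.items():
        for synonym in names.split(","):
            synonym = synonym.strip()
            if synonym:
                # If two classes share a synonym, the later entry wins.
                label2id[synonym] = int(class_id)
    return label2id


# Hypothetical excerpt of the mapping shown in the diff above.
sample = {
    752: "racket, racquet",
    754: "radio, wireless",
    999: "toilet tissue, toilet paper, bathroom tissue",
}
lookup = build_label2id(sample)
print(lookup["racquet"], lookup["toilet paper"])
```

Lookups in this sketch are exact and case-sensitive, so `"Racket"` would not resolve; a caller would need to pass one of the synonyms as written in `id2label`.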
SiT-S-2-256/pipeline.py CHANGED
@@ -1,82 +1,349 @@
1
- from typing import List, Optional, Union
2
-
3
- import torch
4
-
5
- from diffusers.image_processor import VaeImageProcessor
6
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
7
- from diffusers.utils.torch_utils import randn_tensor
8
-
9
-
10
- class SiTPipeline(DiffusionPipeline):
11
- model_cpu_offload_seq = "transformer->vae"
12
-
13
- def __init__(self, transformer, scheduler, vae):
14
- super().__init__()
15
- self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
16
- self.vae_scale_factor = 8
17
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
18
-
19
- @torch.no_grad()
20
- def __call__(
21
- self,
22
- class_labels: Union[int, List[int]] = 207,
23
- height: int = 256,
24
- width: int = 256,
25
- num_inference_steps: int = 250,
26
- guidance_scale: float = 4.0,
27
- generator: Optional[torch.Generator] = None,
28
- output_type: str = "pil",
29
- return_dict: bool = True,
30
- ):
31
- device = self._execution_device
32
- if isinstance(class_labels, int):
33
- class_labels = [class_labels]
34
- batch_size = len(class_labels)
35
-
36
- latent_h = height // self.vae_scale_factor
37
- latent_w = width // self.vae_scale_factor
38
- latents = randn_tensor(
39
- (batch_size, self.transformer.config.in_channels, latent_h, latent_w),
40
- generator=generator,
41
- device=device,
42
- dtype=self.transformer.dtype,
43
- )
44
-
45
- labels = torch.tensor(class_labels, device=device, dtype=torch.long)
46
- do_cfg = guidance_scale is not None and guidance_scale > 1.0
47
- if do_cfg:
48
- null_label = torch.full((batch_size,), self.transformer.config.num_classes, device=device, dtype=torch.long)
49
- labels = torch.cat([labels, null_label], dim=0)
50
-
51
- self.scheduler.set_timesteps(num_inference_steps, device=device)
52
- timesteps = self.scheduler.timesteps
53
-
54
- for t in self.progress_bar(timesteps):
55
- t_batch = torch.full((batch_size,), t, device=device, dtype=latents.dtype)
56
- model_input = latents
57
- if do_cfg:
58
- model_input = torch.cat([latents, latents], dim=0)
59
- t_batch = torch.cat([t_batch, t_batch], dim=0)
60
-
61
- model_pred = self.transformer(
62
- hidden_states=model_input,
63
- timestep=t_batch,
64
- class_labels=labels,
65
- ).sample
66
-
67
- if do_cfg:
68
- cond, uncond = model_pred.chunk(2, dim=0)
69
- model_pred = uncond + guidance_scale * (cond - uncond)
70
-
71
- latents = self.scheduler.step(model_pred, t, latents, generator=generator).prev_sample
72
-
73
- image = self.vae.decode(latents / 0.18215).sample
74
- # Keep PyTorch outputs in raw VAE range [-1, 1] to match original SiT scripts.
75
- if output_type == "pt":
76
- image = image
77
- else:
78
- image = self.image_processor.postprocess(image, output_type=output_type)
79
-
80
- if not return_dict:
81
- return (image,)
82
- return ImagePipelineOutput(images=image)
1
+ """Hub custom pipeline: SiTPipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ import json
22
+ from pathlib import Path
23
+ from typing import Dict, List, Optional, Tuple, Union
24
+
25
+ import torch
26
+
27
+ from diffusers.image_processor import VaeImageProcessor
28
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+
31
+ EXAMPLE_DOC_STRING = """
32
+ Examples:
33
+ ```py
34
+ >>> from pathlib import Path
35
+ >>> from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
36
+ >>> import torch
37
+
38
+ >>> model_dir = Path("./SiT-XL-2-256").resolve()
39
+ >>> pipe = DiffusionPipeline.from_pretrained(
40
+ ... str(model_dir),
41
+ ... local_files_only=True,
42
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
43
+ ... trust_remote_code=True,
44
+ ... torch_dtype=torch.bfloat16,
45
+ ... )
46
+ >>> pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
47
+ >>> pipe.to("cuda")
48
+
49
+ >>> print(pipe.id2label[207])
50
+ >>> print(pipe.get_label_ids("golden retriever"))
51
+
52
+ >>> generator = torch.Generator(device="cuda").manual_seed(42)
53
+ >>> image = pipe(
54
+ ... class_labels="golden retriever",
55
+ ... height=256,
56
+ ... width=256,
57
+ ... num_inference_steps=250,
58
+ ... guidance_scale=4.0,
59
+ ... generator=generator,
60
+ ... ).images[0]
61
+ ```
62
+ """
63
+
64
+ class SiTPipeline(DiffusionPipeline):
65
+ r"""
66
+ Pipeline for class-conditional image generation with Scalable Interpolant Transformers (SiT).
67
+
68
+ Parameters:
69
+ transformer ([`SiTTransformer2DModel`]):
70
+ Class-conditional SiT transformer that predicts flow-matching velocity in latent space.
71
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
72
+ Flow-matching Euler scheduler. The pipeline assumes deterministic sigma-based stepping (`stochastic_sampling=False`).
73
+ vae ([`AutoencoderKL`]):
74
+ Variational autoencoder used to decode transformer latents to pixels.
75
+ id2label (`dict[int, str]`, *optional*):
76
+ ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
77
+ """
78
+
79
+ model_cpu_offload_seq = "transformer->vae"
80
+
81
+ def __init__(
82
+ self,
83
+ transformer,
84
+ scheduler,
85
+ vae,
86
+ id2label: Optional[Dict[Union[int, str], str]] = None,
87
+ ):
88
+ super().__init__()
89
+ if scheduler is None:
90
+ scheduler = FlowMatchEulerDiscreteScheduler(
91
+ num_train_timesteps=1000,
92
+ shift=1.0,
93
+ stochastic_sampling=False,
94
+ )
95
+ self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
96
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
97
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
98
+ self._id2label = self._normalize_id2label(id2label)
99
+ self.labels = self._build_label2id(self._id2label)
100
+ self._labels_loaded_from_model_index = bool(self._id2label)
101
+
102
+ def _ensure_labels_loaded(self) -> None:
103
+ if self._labels_loaded_from_model_index:
104
+ return
105
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
106
+ if loaded:
107
+ self._id2label = loaded
108
+ self.labels = self._build_label2id(self._id2label)
109
+ self._labels_loaded_from_model_index = True
110
+
111
+ @staticmethod
112
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
113
+ if not id2label:
114
+ return {}
115
+ return {int(key): value for key, value in id2label.items()}
116
+
117
+ @staticmethod
118
+ def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
119
+ if not variant_path:
120
+ return {}
121
+ variant_dir = Path(variant_path).resolve()
122
+ model_index_path = variant_dir / "model_index.json"
123
+ if not model_index_path.exists():
124
+ return {}
125
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
126
+ id2label = raw.get("id2label")
127
+ if not isinstance(id2label, dict):
128
+ return {}
129
+ return {int(key): value for key, value in id2label.items()}
130
+
131
+ @staticmethod
132
+ def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
133
+ label2id: Dict[str, int] = {}
134
+ for class_id, value in id2label.items():
135
+ for synonym in value.split(","):
136
+ synonym = synonym.strip()
137
+ if synonym:
138
+ label2id[synonym] = int(class_id)
139
+ return dict(sorted(label2id.items()))
140
+
141
+ @property
142
+ def id2label(self) -> Dict[int, str]:
143
+ r"""ImageNet class id to English label string (comma-separated synonyms)."""
144
+ self._ensure_labels_loaded()
145
+ return self._id2label
146
+
147
+ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
148
+ r"""
149
+ Map ImageNet label strings to class ids.
150
+
151
+ Args:
152
+ label (`str` or `list[str]`):
153
+ One or more English label strings. Each string must match a synonym in `id2label`.
154
+ """
155
+ self._ensure_labels_loaded()
156
+ label2id = self.labels
157
+ if not label2id:
158
+ raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
159
+
160
+ if isinstance(label, str):
161
+ label = [label]
162
+
163
+ missing = [item for item in label if item not in label2id]
164
+ if missing:
165
+ preview = ", ".join(list(label2id.keys())[:8])
166
+ raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
167
+ return [label2id[item] for item in label]
168
+
169
+ def _normalize_class_labels(
170
+ self,
171
+ class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
172
+ ) -> torch.LongTensor:
173
+ if torch.is_tensor(class_labels):
174
+ return class_labels.to(device=self._execution_device, dtype=torch.long).reshape(-1)
175
+
176
+ if isinstance(class_labels, int):
177
+ class_label_ids = [class_labels]
178
+ elif isinstance(class_labels, str):
179
+ class_label_ids = self.get_label_ids(class_labels)
180
+ elif class_labels and isinstance(class_labels[0], str):
181
+ class_label_ids = self.get_label_ids(class_labels)
182
+ else:
183
+ class_label_ids = list(class_labels)
184
+
185
+ return torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
186
+
187
+ def _default_image_size(self) -> int:
188
+ return int(self.transformer.config.input_size) * self.vae_scale_factor
189
+
190
+ def check_inputs(
191
+ self,
192
+ height: int,
193
+ width: int,
194
+ num_inference_steps: int,
195
+ output_type: str,
196
+ ) -> None:
197
+ if num_inference_steps < 1:
198
+ raise ValueError("num_inference_steps must be >= 1.")
199
+ if output_type not in {"pil", "np", "pt", "latent"}:
200
+ raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
201
+
202
+ if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
203
+ raise ValueError(
204
+ f"height and width must be divisible by the VAE downsample factor {self.vae_scale_factor}."
205
+ )
206
+
207
+ latent_height = height // self.vae_scale_factor
208
+ latent_width = width // self.vae_scale_factor
209
+ expected_size = int(self.transformer.config.input_size)
210
+ if latent_height != expected_size or latent_width != expected_size:
211
+ raise ValueError(
212
+ f"Requested latent size {(latent_height, latent_width)} does not match the pretrained "
213
+ f"transformer input_size={expected_size}. Use height=width={self._default_image_size()}."
214
+ )
215
+
216
+ def prepare_latents(
217
+ self,
218
+ batch_size: int,
219
+ height: int,
220
+ width: int,
221
+ dtype: torch.dtype,
222
+ device: torch.device,
223
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
224
+ ) -> torch.Tensor:
225
+ latent_height = height // self.vae_scale_factor
226
+ latent_width = width // self.vae_scale_factor
227
+ return randn_tensor(
228
+ (batch_size, self.transformer.config.in_channels, latent_height, latent_width),
229
+ generator=generator,
230
+ device=device,
231
+ dtype=dtype,
232
+ )
233
+
234
+ @staticmethod
235
+ def _apply_classifier_free_guidance(model_output: torch.Tensor, guidance_scale: float) -> torch.Tensor:
236
+ if guidance_scale <= 1.0:
237
+ return model_output
238
+ model_output_cond, model_output_uncond = model_output.chunk(2)
239
+ return model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
240
+
241
+ def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
242
+ if output_type == "latent":
243
+ return latents
244
+
245
+ scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
246
+ image = self.vae.decode(latents / scaling_factor).sample
247
+ if output_type == "pt":
248
+ return image
249
+ return self.image_processor.postprocess(image, output_type=output_type)
250
+
251
+ @torch.inference_mode()
252
+ def __call__(
253
+ self,
254
+ class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
255
+ height: Optional[int] = None,
256
+ width: Optional[int] = None,
257
+ num_inference_steps: int = 250,
258
+ guidance_scale: float = 4.0,
259
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
260
+ output_type: str = "pil",
261
+ return_dict: bool = True,
262
+ ) -> Union[ImagePipelineOutput, Tuple]:
263
+ r"""
264
+ Generate class-conditional images with SiT.
265
+
266
+ Args:
267
+ class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
268
+ ImageNet class indices or human-readable English label strings.
269
+ height (`int`, *optional*):
270
+ Output image height in pixels. Defaults to the pretrained native resolution.
271
+ width (`int`, *optional*):
272
+ Output image width in pixels. Defaults to the pretrained native resolution.
273
+ num_inference_steps (`int`, defaults to `250`):
274
+ Number of denoising steps.
275
+ guidance_scale (`float`, defaults to `4.0`):
276
+ Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
277
+ generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
278
+ RNG for reproducibility.
279
+ output_type (`str`, defaults to `"pil"`):
280
+ `"pil"`, `"np"`, `"pt"`, or `"latent"`.
281
+ return_dict (`bool`, defaults to `True`):
282
+ Return [`ImagePipelineOutput`] if True.
283
+ """
284
+ default_size = self._default_image_size()
285
+ height = int(height or default_size)
286
+ width = int(width or default_size)
287
+ self.check_inputs(height, width, num_inference_steps, output_type)
288
+
289
+ device = self._execution_device
290
+ model_dtype = next(self.transformer.parameters()).dtype
291
+ class_labels_tensor = self._normalize_class_labels(class_labels)
292
+ batch_size = class_labels_tensor.numel()
293
+ do_cfg = guidance_scale > 1.0
294
+
295
+ latents = self.prepare_latents(
296
+ batch_size=batch_size,
297
+ height=height,
298
+ width=width,
299
+ dtype=model_dtype,
300
+ device=device,
301
+ generator=generator,
302
+ )
303
+
304
+ labels = class_labels_tensor
305
+ if do_cfg:
306
+ null_labels = torch.full_like(class_labels_tensor, self.transformer.config.num_classes)
307
+ labels = torch.cat([class_labels_tensor, null_labels], dim=0)
308
+
309
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
310
+ num_train_timesteps = self.scheduler.config.num_train_timesteps
311
+
312
+ if getattr(self.scheduler.config, "stochastic_sampling", False):
313
+ raise ValueError(
314
+ "SiT expects deterministic FlowMatchEulerDiscreteScheduler stepping "
315
+ "(scheduler.config.stochastic_sampling=False)."
316
+ )
317
+
318
+ for t in self.progress_bar(self.scheduler.timesteps):
319
+ flow_time = 1.0 - float(t) / num_train_timesteps
320
+ if do_cfg:
321
+ model_input = torch.cat([latents, latents], dim=0)
322
+ else:
323
+ model_input = latents
324
+
325
+ timestep_batch = torch.full((model_input.shape[0],), flow_time, device=device, dtype=model_dtype)
326
+ model_output = self.transformer(
327
+ hidden_states=model_input,
328
+ timestep=timestep_batch,
329
+ class_labels=labels,
330
+ return_dict=True,
331
+ ).sample
332
+ model_output = self._apply_classifier_free_guidance(model_output, guidance_scale=guidance_scale)
333
+ # SiT predicts dx/d(flow_time) with flow_time increasing from noise (0) to data (1).
334
+ # FlowMatchEulerDiscreteScheduler integrates over sigma decreasing from 1 to 0, so flip sign.
335
+ model_output = -model_output
336
+ latents = self.scheduler.step(
337
+ model_output=model_output,
338
+ timestep=t,
339
+ sample=latents,
340
+ generator=generator,
341
+ return_dict=True,
342
+ ).prev_sample
343
+
344
+ image = self.decode_latents(latents, output_type=output_type)
345
+
346
+ self.maybe_free_model_hooks()
347
+ if not return_dict:
348
+ return (image,)
349
+ return ImagePipelineOutput(images=image)
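The sign flip in the sampling loop above can be checked on a toy problem. The sketch below assumes the scheduler's deterministic update has the form `x + (sigma_next - sigma) * model_output`, which is my understanding of a flow-matching Euler step, and uses a scalar straight-line path in place of the transformer. The helper names `euler_sigma_step` and `integrate_toy_flow` are hypothetical and do not appear in the commit.

```python
def euler_sigma_step(sample: float, model_output: float, sigma: float, sigma_next: float) -> float:
    """One Euler step in the sigma convention: x <- x + (sigma_next - sigma) * model_output."""
    return sample + (sigma_next - sigma) * model_output


def integrate_toy_flow(noise: float, data: float, num_steps: int) -> float:
    # Sigma runs from 1 (pure noise) down to 0 (data).
    sigmas = [1.0 - i / num_steps for i in range(num_steps + 1)]
    x = noise
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        # For the straight-line path x_t = (1 - t) * noise + t * data the
        # velocity dx/dt is constant, so this toy "model" ignores t = 1 - sigma.
        velocity = data - noise
        # Negate the velocity so that a decreasing sigma moves x toward the data.
        x = euler_sigma_step(x, -velocity, sigma, sigma_next)
    return x


result = integrate_toy_flow(noise=-2.0, data=3.0, num_steps=10)
print(result)
```

With the negation the toy trajectory ends at the data point; passing the raw velocity instead would walk away from it, which is the failure the comment in the loop guards against.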
SiT-S-2-256/scheduler/scheduler_config.json CHANGED
@@ -1,9 +1,7 @@
1
- {
2
- "_class_name": "SiTFlowMatchScheduler",
3
- "_diffusers_version": "0.36.0",
4
- "diffusion_form": "sigma",
5
- "diffusion_norm": 1.0,
6
- "mode": "ode",
7
- "num_train_timesteps": 1000,
8
- "shift": 1.0
9
- }
 
1
+ {
2
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
+ "shift": 1.0,
6
+ "stochastic_sampling": false
7
+ }
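As a rough illustration of what these config values imply for the pipeline, the sketch below relates a scheduler sigma to the timestep and to the `flow_time` passed to the transformer. It assumes the usual time-shift formula `shift * sigma / (1 + (shift - 1) * sigma)` and a timestep equal to `sigma * num_train_timesteps`; both are assumptions about the scheduler's internals rather than something stated in this commit, and the function names are hypothetical.

```python
def shift_sigma(sigma: float, shift: float = 1.0) -> float:
    """Assumed time-shift formula; with shift=1.0 it is the identity."""
    return shift * sigma / (1.0 + (shift - 1.0) * sigma)


def flow_time(timestep: float, num_train_timesteps: int = 1000) -> float:
    """Convert a scheduler timestep into SiT's time, where 0 is noise and 1 is data."""
    return 1.0 - timestep / num_train_timesteps


for sigma in (1.0, 0.5, 0.0):
    # Values taken from the config above: shift=1.0, num_train_timesteps=1000.
    timestep = shift_sigma(sigma, shift=1.0) * 1000
    print(sigma, timestep, flow_time(timestep))
```

Under these assumptions `shift=1.0` leaves the sigma grid untouched, so the scheduler's uniform grid maps linearly onto SiT's time axis.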
 
 
SiT-S-2-256/transformer/transformer_sit.py CHANGED
@@ -1,224 +1,240 @@
1
- import math
2
- from dataclasses import dataclass
3
- from typing import Optional
4
-
5
- import numpy as np
6
- import torch
7
- import torch.nn as nn
8
- from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
9
-
10
- from diffusers.configuration_utils import ConfigMixin, register_to_config
11
- from diffusers.models.modeling_utils import ModelMixin
12
- from diffusers.utils import BaseOutput
13
-
14
-
15
- def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
16
- return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
17
-
18
-
19
- @dataclass
20
- class SiTTransformer2DModelOutput(BaseOutput):
21
- sample: torch.Tensor
22
-
23
-
24
- class TimestepEmbedder(nn.Module):
25
- def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
26
- super().__init__()
27
- self.mlp = nn.Sequential(
28
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
29
- nn.SiLU(),
30
- nn.Linear(hidden_size, hidden_size, bias=True),
31
- )
32
- self.frequency_embedding_size = frequency_embedding_size
33
-
34
- @staticmethod
35
- def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
36
- half = dim // 2
37
- freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
38
- device=t.device
39
- )
40
- args = t[:, None].float() * freqs[None]
41
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
42
- if dim % 2:
43
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
44
- return embedding
45
-
46
- def forward(self, t: torch.Tensor) -> torch.Tensor:
47
- return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
48
-
49
-
50
- class LabelEmbedder(nn.Module):
51
- def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
52
- super().__init__()
53
- use_cfg_embedding = dropout_prob > 0
54
- self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
55
- self.num_classes = num_classes
56
- self.dropout_prob = dropout_prob
57
-
58
- def token_drop(self, labels: torch.Tensor, force_drop_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
59
- if force_drop_ids is None:
60
- drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
61
- else:
62
- drop_ids = force_drop_ids == 1
63
- labels = torch.where(drop_ids, self.num_classes, labels)
64
- return labels
65
-
66
- def forward(
67
- self,
68
- labels: torch.Tensor,
69
- train: bool,
70
- force_drop_ids: Optional[torch.Tensor] = None,
71
- ) -> torch.Tensor:
72
- use_dropout = self.dropout_prob > 0
73
- if (train and use_dropout) or (force_drop_ids is not None):
74
- labels = self.token_drop(labels, force_drop_ids)
75
- return self.embedding_table(labels)
76
-
77
-
78
- class SiTBlock(nn.Module):
79
- def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, **block_kwargs):
80
- super().__init__()
81
- self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
82
- self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
83
- self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
84
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
85
- approx_gelu = lambda: nn.GELU(approximate="tanh")
86
- self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
87
- self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
88
-
89
- def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
90
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
91
- x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
92
- x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
93
- return x
94
-
95
-
96
- class FinalLayer(nn.Module):
97
- def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
98
- super().__init__()
99
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
100
- self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
101
- self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
102
-
103
- def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
104
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
105
- x = modulate(self.norm_final(x), shift, scale)
106
- return self.linear(x)
107
-
108
-
109
- class SiTTransformer2DModel(ModelMixin, ConfigMixin):
110
- @register_to_config
111
- def __init__(
112
- self,
113
- input_size: int = 32,
114
- patch_size: int = 2,
115
- in_channels: int = 4,
116
- hidden_size: int = 1152,
117
- depth: int = 28,
118
- num_heads: int = 16,
119
- mlp_ratio: float = 4.0,
120
- class_dropout_prob: float = 0.1,
121
- num_classes: int = 1000,
122
- learn_sigma: bool = True,
123
- ):
124
- super().__init__()
125
- self.learn_sigma = learn_sigma
126
- self.in_channels = in_channels
127
- self.out_channels = in_channels * 2 if learn_sigma else in_channels
128
- self.patch_size = patch_size
129
- self.num_classes = num_classes
130
-
131
- self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
132
- self.t_embedder = TimestepEmbedder(hidden_size)
133
- self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
134
- num_patches = self.x_embedder.num_patches
135
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
136
-
137
- self.blocks = nn.ModuleList([SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
138
- self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
139
- self.initialize_weights()
140
-
141
- def initialize_weights(self) -> None:
142
- def _basic_init(module: nn.Module):
143
- if isinstance(module, nn.Linear):
144
- torch.nn.init.xavier_uniform_(module.weight)
145
- if module.bias is not None:
146
- nn.init.constant_(module.bias, 0)
147
-
148
- self.apply(_basic_init)
149
- pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
150
-         self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
-
-         w = self.x_embedder.proj.weight.data
-         nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
-         nn.init.constant_(self.x_embedder.proj.bias, 0)
-         nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
-         nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
-         nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-         for block in self.blocks:
-             nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
-             nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
-         nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
-         nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
-         nn.init.constant_(self.final_layer.linear.weight, 0)
-         nn.init.constant_(self.final_layer.linear.bias, 0)
-
-     def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
-         c = self.out_channels
-         p = self.x_embedder.patch_size[0]
-         h = w = int(x.shape[1] ** 0.5)
-         x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
-         x = torch.einsum("nhwpqc->nchpwq", x)
-         return x.reshape(shape=(x.shape[0], c, h * p, h * p))
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         timestep: torch.Tensor,
-         class_labels: torch.Tensor,
-         force_drop_ids: Optional[torch.Tensor] = None,
-         return_dict: bool = True,
-     ) -> SiTTransformer2DModelOutput:
-         x = self.x_embedder(hidden_states) + self.pos_embed
-         t = self.t_embedder(timestep)
-         y = self.y_embedder(class_labels, self.training, force_drop_ids=force_drop_ids)
-         c = t + y
-         for block in self.blocks:
-             x = block(x, c)
-         x = self.final_layer(x, c)
-         x = self.unpatchify(x)
-         if self.learn_sigma:
-             x, _ = x.chunk(2, dim=1)
-         if not return_dict:
-             return (x,)
-         return SiTTransformer2DModelOutput(sample=x)
-
-
- def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False, extra_tokens: int = 0):
-     grid_h = np.arange(grid_size, dtype=np.float32)
-     grid_w = np.arange(grid_size, dtype=np.float32)
-     grid = np.meshgrid(grid_w, grid_h)
-     grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])
-     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
-     if cls_token and extra_tokens > 0:
-         pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
-     return pos_embed
-
-
- def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray):
-     assert embed_dim % 2 == 0
-     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
-     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
-     return np.concatenate([emb_h, emb_w], axis=1)
-
-
- def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray):
-     assert embed_dim % 2 == 0
-     omega = np.arange(embed_dim // 2, dtype=np.float64)
-     omega /= embed_dim / 2.0
-     omega = 1.0 / 10000**omega
-     pos = pos.reshape(-1)
-     out = np.einsum("m,d->md", pos, omega)
-     emb_sin = np.sin(out)
-     emb_cos = np.cos(out)
-     return np.concatenate([emb_sin, emb_cos], axis=1)
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.utils import BaseOutput
+
+
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+ @dataclass
+ class SiTTransformer2DModelOutput(BaseOutput):
+     sample: torch.Tensor
+
+
+ class TimestepEmbedder(nn.Module):
+     def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+             nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size, bias=True),
+         )
+         self.frequency_embedding_size = frequency_embedding_size
+
+     @staticmethod
+     def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
+         half = dim // 2
+         freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+             device=t.device
+         )
+         args = t[:, None].float() * freqs[None]
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if dim % 2:
+             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+         return embedding
+
+     def forward(self, t: torch.Tensor) -> torch.Tensor:
+         emb = self.timestep_embedding(t.float(), self.frequency_embedding_size)
+         weight_dtype = self.mlp[0].weight.dtype
+         return self.mlp(emb.to(dtype=weight_dtype))
+
+
+ class LabelEmbedder(nn.Module):
+     def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
+         super().__init__()
+         use_cfg_embedding = dropout_prob > 0
+         self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+         self.num_classes = num_classes
+         self.dropout_prob = dropout_prob
+
+     def token_drop(self, labels: torch.Tensor, force_drop_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
+         if force_drop_ids is None:
+             drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+         else:
+             drop_ids = force_drop_ids == 1
+         labels = torch.where(drop_ids, self.num_classes, labels)
+         return labels
+
+     def forward(
+         self,
+         labels: torch.Tensor,
+         train: bool,
+         force_drop_ids: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         use_dropout = self.dropout_prob > 0
+         if (train and use_dropout) or (force_drop_ids is not None):
+             labels = self.token_drop(labels, force_drop_ids)
+         return self.embedding_table(labels)
+
+
+ class SiTBlock(nn.Module):
+     def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, **block_kwargs):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
+         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         mlp_hidden_dim = int(hidden_size * mlp_ratio)
+         approx_gelu = lambda: nn.GELU(approximate="tanh")
+         self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
+         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
+
+     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
+         x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
+         x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+         return x
+
+
+ class FinalLayer(nn.Module):
+     def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+         super().__init__()
+         self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+         shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+         x = modulate(self.norm_final(x), shift, scale)
+         return self.linear(x)
+
+
+ class SiTTransformer2DModel(ModelMixin, ConfigMixin):
+     @register_to_config
+     def __init__(
+         self,
+         input_size: int = 32,
+         patch_size: int = 2,
+         in_channels: int = 4,
+         hidden_size: int = 1152,
+         depth: int = 28,
+         num_heads: int = 16,
+         mlp_ratio: float = 4.0,
+         class_dropout_prob: float = 0.1,
+         num_classes: int = 1000,
+         learn_sigma: bool = True,
+     ):
+         super().__init__()
+         self.learn_sigma = learn_sigma
+         self.in_channels = in_channels
+         self.out_channels = in_channels * 2 if learn_sigma else in_channels
+         self.patch_size = patch_size
+         self.num_classes = num_classes
+
+         self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
+         self.t_embedder = TimestepEmbedder(hidden_size)
+         self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
+         num_patches = self.x_embedder.num_patches
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
+
+         self.blocks = nn.ModuleList([SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
+         self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
+         self.initialize_weights()
+
+     def initialize_weights(self) -> None:
+         def _basic_init(module: nn.Module):
+             if isinstance(module, nn.Linear):
+                 torch.nn.init.xavier_uniform_(module.weight)
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, 0)
+
+         self.apply(_basic_init)
+         pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
+         self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+         w = self.x_embedder.proj.weight.data
+         nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+         nn.init.constant_(self.x_embedder.proj.bias, 0)
+         nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+         nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+         nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+         for block in self.blocks:
+             nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+             nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+         nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+         nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+         nn.init.constant_(self.final_layer.linear.weight, 0)
+         nn.init.constant_(self.final_layer.linear.bias, 0)
+
+     def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
+         c = self.out_channels
+         p = self.x_embedder.patch_size[0]
+         h = w = int(x.shape[1] ** 0.5)
+         x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+         x = torch.einsum("nhwpqc->nchpwq", x)
+         return x.reshape(shape=(x.shape[0], c, h * p, h * p))
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         timestep: torch.Tensor,
+         class_labels: torch.Tensor,
+         force_drop_ids: Optional[torch.Tensor] = None,
+         return_dict: bool = True,
+     ) -> SiTTransformer2DModelOutput:
+         x = self.x_embedder(hidden_states) + self.pos_embed
+         t = self.t_embedder(timestep)
+         y = self.y_embedder(class_labels, self.training, force_drop_ids=force_drop_ids)
+         c = t + y
+         for block in self.blocks:
+             x = block(x, c)
+         x = self.final_layer(x, c)
+         x = self.unpatchify(x)
+         if self.learn_sigma:
+             x, _ = x.chunk(2, dim=1)
+         if not return_dict:
+             return (x,)
+         return SiTTransformer2DModelOutput(sample=x)
+
+
+ def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False, extra_tokens: int = 0):
+     grid_h = np.arange(grid_size, dtype=np.float32)
+     grid_w = np.arange(grid_size, dtype=np.float32)
+     grid = np.meshgrid(grid_w, grid_h)
+     grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token and extra_tokens > 0:
+         pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+     return pos_embed
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray):
+     assert embed_dim % 2 == 0
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
+     return np.concatenate([emb_h, emb_w], axis=1)
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray):
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=np.float64)
+     omega /= embed_dim / 2.0
+     omega = 1.0 / 10000**omega
+     pos = pos.reshape(-1)
+     out = np.einsum("m,d->md", pos, omega)
+     emb_sin = np.sin(out)
+     emb_cos = np.cos(out)
+     return np.concatenate([emb_sin, emb_cos], axis=1)
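
Note on the positional-embedding helpers that close `transformer_sit.py`: the sketch below restates them in plain Python (standard library only) as an illustration, not as the file's implementation. The names `sincos_1d` and `sincos_2d` are hypothetical and are not part of this commit. It assumes the default "xy" indexing of `np.meshgrid`, under which the first half of each patch embedding encodes the column index and the second half the row index.

```python
import math


def sincos_1d(embed_dim, positions):
    """Sin-cos embedding of scalar positions; one row of length embed_dim per position."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Frequencies 1 / 10000**(i / half), as in get_1d_sincos_pos_embed_from_grid.
    omega = [1.0 / 10000 ** (i / half) for i in range(half)]
    table = []
    for pos in positions:
        angles = [pos * w for w in omega]
        # All sines first, then all cosines.
        table.append([math.sin(a) for a in angles] + [math.cos(a) for a in angles])
    return table


def sincos_2d(embed_dim, grid_size):
    """One embedding row per patch, with patches listed in row-major order."""
    assert embed_dim % 2 == 0
    # Assumption: mirrors np.meshgrid(grid_w, grid_h), so grid[0] holds the
    # column index of each patch and grid[1] holds the row index.
    cols = [float(j) for _ in range(grid_size) for j in range(grid_size)]
    rows = [float(i) for i in range(grid_size) for _ in range(grid_size)]
    first_half = sincos_1d(embed_dim // 2, cols)
    second_half = sincos_1d(embed_dim // 2, rows)
    return [a + b for a, b in zip(first_half, second_half)]


table = sincos_2d(embed_dim=8, grid_size=2)
print(len(table), len(table[0]))  # prints: 4 8
```

With the default config in the diff (`input_size=32`, `patch_size=2`, `hidden_size=1152`), the same construction would give a 16x16 grid, that is 256 rows of length 1152, which matches the `(1, num_patches, hidden_size)` shape of `pos_embed`.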
SiT-XL-2-256/__pycache__/pipeline.cpython-312.pyc ADDED
Binary file (7.51 kB).
 
SiT-XL-2-256/model_index.json CHANGED
@@ -1,19 +1,1021 @@
- {
-   "_class_name": [
-     "pipeline",
-     "SiTPipeline"
-   ],
-   "_diffusers_version": "0.36.0",
-   "scheduler": [
-     "scheduling_flow_match_sit",
-     "SiTFlowMatchScheduler"
-   ],
-   "transformer": [
-     "transformer_sit",
-     "SiTTransformer2DModel"
-   ],
-   "vae": [
-     "diffusers",
-     "AutoencoderKL"
-   ]
- }
+ {
+   "_class_name": [
+     "pipeline",
+     "SiTPipeline"
+   ],
+   "_diffusers_version": "0.36.0",
+   "scheduler": [
+     "diffusers",
+     "FlowMatchEulerDiscreteScheduler"
+   ],
+   "transformer": [
+     "transformer_sit",
+     "SiTTransformer2DModel"
+   ],
+   "vae": [
+     "diffusers",
+     "AutoencoderKL"
+   ],
+   "id2label": {
+     "0": "tench, Tinca tinca",
+     "1": "goldfish, Carassius auratus",
+     "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
+     "3": "tiger shark, Galeocerdo cuvieri",
+     "4": "hammerhead, hammerhead shark",
+     "5": "electric ray, crampfish, numbfish, torpedo",
+     "6": "stingray",
+     "7": "cock",
+     "8": "hen",
+     "9": "ostrich, Struthio camelus",
+     "10": "brambling, Fringilla montifringilla",
+     "11": "goldfinch, Carduelis carduelis",
+     "12": "house finch, linnet, Carpodacus mexicanus",
+     "13": "junco, snowbird",
+     "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
+     "15": "robin, American robin, Turdus migratorius",
+     "16": "bulbul",
+     "17": "jay",
+     "18": "magpie",
+     "19": "chickadee",
+     "20": "water ouzel, dipper",
+     "21": "kite",
+     "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
+     "23": "vulture",
+     "24": "great grey owl, great gray owl, Strix nebulosa",
+     "25": "European fire salamander, Salamandra salamandra",
+     "26": "common newt, Triturus vulgaris",
+     "27": "eft",
+     "28": "spotted salamander, Ambystoma maculatum",
+     "29": "axolotl, mud puppy, Ambystoma mexicanum",
+     "30": "bullfrog, Rana catesbeiana",
+     "31": "tree frog, tree-frog",
+     "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
+     "33": "loggerhead, loggerhead turtle, Caretta caretta",
+     "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
+     "35": "mud turtle",
+     "36": "terrapin",
+     "37": "box turtle, box tortoise",
+     "38": "banded gecko",
+     "39": "common iguana, iguana, Iguana iguana",
+     "40": "American chameleon, anole, Anolis carolinensis",
+     "41": "whiptail, whiptail lizard",
+     "42": "agama",
+     "43": "frilled lizard, Chlamydosaurus kingi",
+     "44": "alligator lizard",
+     "45": "Gila monster, Heloderma suspectum",
+     "46": "green lizard, Lacerta viridis",
+     "47": "African chameleon, Chamaeleo chamaeleon",
+     "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
+     "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
+     "50": "American alligator, Alligator mississipiensis",
+     "51": "triceratops",
+     "52": "thunder snake, worm snake, Carphophis amoenus",
+     "53": "ringneck snake, ring-necked snake, ring snake",
+     "54": "hognose snake, puff adder, sand viper",
+     "55": "green snake, grass snake",
+     "56": "king snake, kingsnake",
+     "57": "garter snake, grass snake",
+     "58": "water snake",
+     "59": "vine snake",
+     "60": "night snake, Hypsiglena torquata",
+     "61": "boa constrictor, Constrictor constrictor",
+     "62": "rock python, rock snake, Python sebae",
+     "63": "Indian cobra, Naja naja",
+     "64": "green mamba",
+     "65": "sea snake",
+     "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
+     "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
+     "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
+     "69": "trilobite",
+     "70": "harvestman, daddy longlegs, Phalangium opilio",
+     "71": "scorpion",
+     "72": "black and gold garden spider, Argiope aurantia",
+     "73": "barn spider, Araneus cavaticus",
+     "74": "garden spider, Aranea diademata",
+     "75": "black widow, Latrodectus mactans",
+     "76": "tarantula",
+     "77": "wolf spider, hunting spider",
+     "78": "tick",
+     "79": "centipede",
+     "80": "black grouse",
+     "81": "ptarmigan",
+     "82": "ruffed grouse, partridge, Bonasa umbellus",
+     "83": "prairie chicken, prairie grouse, prairie fowl",
+     "84": "peacock",
+     "85": "quail",
+     "86": "partridge",
+     "87": "African grey, African gray, Psittacus erithacus",
+     "88": "macaw",
+     "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
+     "90": "lorikeet",
+     "91": "coucal",
+     "92": "bee eater",
+     "93": "hornbill",
+     "94": "hummingbird",
+     "95": "jacamar",
+     "96": "toucan",
+     "97": "drake",
+     "98": "red-breasted merganser, Mergus serrator",
+     "99": "goose",
+     "100": "black swan, Cygnus atratus",
+     "101": "tusker",
+     "102": "echidna, spiny anteater, anteater",
+     "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
+     "104": "wallaby, brush kangaroo",
+     "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
+     "106": "wombat",
+     "107": "jellyfish",
+     "108": "sea anemone, anemone",
+     "109": "brain coral",
+     "110": "flatworm, platyhelminth",
+     "111": "nematode, nematode worm, roundworm",
+     "112": "conch",
+     "113": "snail",
+     "114": "slug",
+     "115": "sea slug, nudibranch",
+     "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
+     "117": "chambered nautilus, pearly nautilus, nautilus",
+     "118": "Dungeness crab, Cancer magister",
+     "119": "rock crab, Cancer irroratus",
+     "120": "fiddler crab",
+     "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
+     "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
+     "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
+     "124": "crayfish, crawfish, crawdad, crawdaddy",
+     "125": "hermit crab",
+     "126": "isopod",
+     "127": "white stork, Ciconia ciconia",
+     "128": "black stork, Ciconia nigra",
+     "129": "spoonbill",
+     "130": "flamingo",
+     "131": "little blue heron, Egretta caerulea",
+     "132": "American egret, great white heron, Egretta albus",
+     "133": "bittern",
+     "134": "crane",
+     "135": "limpkin, Aramus pictus",
+     "136": "European gallinule, Porphyrio porphyrio",
+     "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
+     "138": "bustard",
+     "139": "ruddy turnstone, Arenaria interpres",
+     "140": "red-backed sandpiper, dunlin, Erolia alpina",
+     "141": "redshank, Tringa totanus",
+     "142": "dowitcher",
+     "143": "oystercatcher, oyster catcher",
+     "144": "pelican",
+     "145": "king penguin, Aptenodytes patagonica",
+     "146": "albatross, mollymawk",
+     "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
+     "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
+     "149": "dugong, Dugong dugon",
+     "150": "sea lion",
+     "151": "Chihuahua",
+     "152": "Japanese spaniel",
+     "153": "Maltese dog, Maltese terrier, Maltese",
+     "154": "Pekinese, Pekingese, Peke",
+     "155": "Shih-Tzu",
+     "156": "Blenheim spaniel",
+     "157": "papillon",
+     "158": "toy terrier",
+     "159": "Rhodesian ridgeback",
+     "160": "Afghan hound, Afghan",
+     "161": "basset, basset hound",
+     "162": "beagle",
+     "163": "bloodhound, sleuthhound",
+     "164": "bluetick",
+     "165": "black-and-tan coonhound",
+     "166": "Walker hound, Walker foxhound",
+     "167": "English foxhound",
+     "168": "redbone",
+     "169": "borzoi, Russian wolfhound",
+     "170": "Irish wolfhound",
+     "171": "Italian greyhound",
+     "172": "whippet",
+     "173": "Ibizan hound, Ibizan Podenco",
+     "174": "Norwegian elkhound, elkhound",
+     "175": "otterhound, otter hound",
+     "176": "Saluki, gazelle hound",
+     "177": "Scottish deerhound, deerhound",
+     "178": "Weimaraner",
+     "179": "Staffordshire bullterrier, Staffordshire bull terrier",
+     "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
+     "181": "Bedlington terrier",
+     "182": "Border terrier",
+     "183": "Kerry blue terrier",
+     "184": "Irish terrier",
+     "185": "Norfolk terrier",
+     "186": "Norwich terrier",
+     "187": "Yorkshire terrier",
+     "188": "wire-haired fox terrier",
+     "189": "Lakeland terrier",
+     "190": "Sealyham terrier, Sealyham",
+     "191": "Airedale, Airedale terrier",
+     "192": "cairn, cairn terrier",
+     "193": "Australian terrier",
+     "194": "Dandie Dinmont, Dandie Dinmont terrier",
+     "195": "Boston bull, Boston terrier",
+     "196": "miniature schnauzer",
+     "197": "giant schnauzer",
+     "198": "standard schnauzer",
+     "199": "Scotch terrier, Scottish terrier, Scottie",
+     "200": "Tibetan terrier, chrysanthemum dog",
+     "201": "silky terrier, Sydney silky",
+     "202": "soft-coated wheaten terrier",
+     "203": "West Highland white terrier",
+     "204": "Lhasa, Lhasa apso",
+     "205": "flat-coated retriever",
+     "206": "curly-coated retriever",
+     "207": "golden retriever",
+     "208": "Labrador retriever",
+     "209": "Chesapeake Bay retriever",
+     "210": "German short-haired pointer",
+     "211": "vizsla, Hungarian pointer",
+     "212": "English setter",
+     "213": "Irish setter, red setter",
+     "214": "Gordon setter",
+     "215": "Brittany spaniel",
+     "216": "clumber, clumber spaniel",
+     "217": "English springer, English springer spaniel",
+     "218": "Welsh springer spaniel",
+     "219": "cocker spaniel, English cocker spaniel, cocker",
+     "220": "Sussex spaniel",
+     "221": "Irish water spaniel",
+     "222": "kuvasz",
+     "223": "schipperke",
+     "224": "groenendael",
+     "225": "malinois",
+     "226": "briard",
+     "227": "kelpie",
+     "228": "komondor",
+     "229": "Old English sheepdog, bobtail",
+     "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
+     "231": "collie",
+     "232": "Border collie",
+     "233": "Bouvier des Flandres, Bouviers des Flandres",
+     "234": "Rottweiler",
+     "235": "German shepherd, German shepherd dog, German police dog, alsatian",
+     "236": "Doberman, Doberman pinscher",
+     "237": "miniature pinscher",
+     "238": "Greater Swiss Mountain dog",
+     "239": "Bernese mountain dog",
+     "240": "Appenzeller",
+     "241": "EntleBucher",
+     "242": "boxer",
+     "243": "bull mastiff",
+     "244": "Tibetan mastiff",
+     "245": "French bulldog",
+     "246": "Great Dane",
+     "247": "Saint Bernard, St Bernard",
+     "248": "Eskimo dog, husky",
+     "249": "malamute, malemute, Alaskan malamute",
+     "250": "Siberian husky",
+     "251": "dalmatian, coach dog, carriage dog",
+     "252": "affenpinscher, monkey pinscher, monkey dog",
+     "253": "basenji",
+     "254": "pug, pug-dog",
+     "255": "Leonberg",
+     "256": "Newfoundland, Newfoundland dog",
+     "257": "Great Pyrenees",
+     "258": "Samoyed, Samoyede",
+     "259": "Pomeranian",
+     "260": "chow, chow chow",
+     "261": "keeshond",
+     "262": "Brabancon griffon",
+     "263": "Pembroke, Pembroke Welsh corgi",
+     "264": "Cardigan, Cardigan Welsh corgi",
+     "265": "toy poodle",
+     "266": "miniature poodle",
+     "267": "standard poodle",
+     "268": "Mexican hairless",
+     "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
+     "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
+     "271": "red wolf, maned wolf, Canis rufus, Canis niger",
+     "272": "coyote, prairie wolf, brush wolf, Canis latrans",
+     "273": "dingo, warrigal, warragal, Canis dingo",
+     "274": "dhole, Cuon alpinus",
+     "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
+     "276": "hyena, hyaena",
+     "277": "red fox, Vulpes vulpes",
+     "278": "kit fox, Vulpes macrotis",
+     "279": "Arctic fox, white fox, Alopex lagopus",
+     "280": "grey fox, gray fox, Urocyon cinereoargenteus",
+     "281": "tabby, tabby cat",
+     "282": "tiger cat",
+     "283": "Persian cat",
+     "284": "Siamese cat, Siamese",
+     "285": "Egyptian cat",
+     "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
+     "287": "lynx, catamount",
+     "288": "leopard, Panthera pardus",
+     "289": "snow leopard, ounce, Panthera uncia",
+     "290": "jaguar, panther, Panthera onca, Felis onca",
+     "291": "lion, king of beasts, Panthera leo",
+     "292": "tiger, Panthera tigris",
+     "293": "cheetah, chetah, Acinonyx jubatus",
+     "294": "brown bear, bruin, Ursus arctos",
+     "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
+     "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
+     "297": "sloth bear, Melursus ursinus, Ursus ursinus",
+     "298": "mongoose",
+     "299": "meerkat, mierkat",
+     "300": "tiger beetle",
+     "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
+     "302": "ground beetle, carabid beetle",
+     "303": "long-horned beetle, longicorn, longicorn beetle",
+     "304": "leaf beetle, chrysomelid",
+     "305": "dung beetle",
+     "306": "rhinoceros beetle",
+     "307": "weevil",
+     "308": "fly",
+     "309": "bee",
+     "310": "ant, emmet, pismire",
+     "311": "grasshopper, hopper",
+     "312": "cricket",
+     "313": "walking stick, walkingstick, stick insect",
+     "314": "cockroach, roach",
+     "315": "mantis, mantid",
+     "316": "cicada, cicala",
+     "317": "leafhopper",
+     "318": "lacewing, lacewing fly",
+     "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
+     "320": "damselfly",
+     "321": "admiral",
+     "322": "ringlet, ringlet butterfly",
+     "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
+     "324": "cabbage butterfly",
+     "325": "sulphur butterfly, sulfur butterfly",
+     "326": "lycaenid, lycaenid butterfly",
+     "327": "starfish, sea star",
+     "328": "sea urchin",
+     "329": "sea cucumber, holothurian",
+     "330": "wood rabbit, cottontail, cottontail rabbit",
+     "331": "hare",
+     "332": "Angora, Angora rabbit",
+     "333": "hamster",
+     "334": "porcupine, hedgehog",
+     "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
+     "336": "marmot",
+     "337": "beaver",
+     "338": "guinea pig, Cavia cobaya",
+     "339": "sorrel",
+     "340": "zebra",
+     "341": "hog, pig, grunter, squealer, Sus scrofa",
+     "342": "wild boar, boar, Sus scrofa",
+     "343": "warthog",
+     "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
+     "345": "ox",
+     "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
+     "347": "bison",
+     "348": "ram, tup",
+     "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
+     "350": "ibex, Capra ibex",
+     "351": "hartebeest",
+     "352": "impala, Aepyceros melampus",
+     "353": "gazelle",
+     "354": "Arabian camel, dromedary, Camelus dromedarius",
+     "355": "llama",
+     "356": "weasel",
+     "357": "mink",
+     "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
+     "359": "black-footed ferret, ferret, Mustela nigripes",
+     "360": "otter",
+     "361": "skunk, polecat, wood pussy",
+     "362": "badger",
+     "363": "armadillo",
+     "364": "three-toed sloth, ai, Bradypus tridactylus",
+     "365": "orangutan, orang, orangutang, Pongo pygmaeus",
+     "366": "gorilla, Gorilla gorilla",
+     "367": "chimpanzee, chimp, Pan troglodytes",
+     "368": "gibbon, Hylobates lar",
+     "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
+     "370": "guenon, guenon monkey",
+     "371": "patas, hussar monkey, Erythrocebus patas",
+     "372": "baboon",
+     "373": "macaque",
+     "374": "langur",
+     "375": "colobus, colobus monkey",
+     "376": "proboscis monkey, Nasalis larvatus",
+     "377": "marmoset",
+     "378": "capuchin, ringtail, Cebus capucinus",
+     "379": "howler monkey, howler",
+     "380": "titi, titi monkey",
+     "381": "spider monkey, Ateles geoffroyi",
+     "382": "squirrel monkey, Saimiri sciureus",
+     "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
+     "384": "indri, indris, Indri indri, Indri brevicaudatus",
+     "385": "Indian elephant, Elephas maximus",
+     "386": "African elephant, Loxodonta africana",
+     "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
+     "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
+     "389": "barracouta, snoek",
+     "390": "eel",
+     "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
+     "392": "rock beauty, Holocanthus tricolor",
+     "393": "anemone fish",
+     "394": "sturgeon",
+     "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
+     "396": "lionfish",
+     "397": "puffer, pufferfish, blowfish, globefish",
+     "398": "abacus",
+     "399": "abaya",
+     "400": "academic gown, academic robe, judge robe",
+     "401": "accordion, piano accordion, squeeze box",
+     "402": "acoustic guitar",
+     "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
+     "404": "airliner",
+     "405": "airship, dirigible",
+     "406": "altar",
+     "407": "ambulance",
+     "408": "amphibian, amphibious vehicle",
+     "409": "analog clock",
+     "410": "apiary, bee house",
+     "411": "apron",
+     "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
+     "413": "assault rifle, assault gun",
+     "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
+     "415": "bakery, bakeshop, bakehouse",
+     "416": "balance beam, beam",
+     "417": "balloon",
+     "418": "ballpoint, ballpoint pen, ballpen, Biro",
+     "419": "Band Aid",
+     "420": "banjo",
+     "421": "bannister, banister, balustrade, balusters, handrail",
+     "422": "barbell",
+     "423": "barber chair",
+     "424": "barbershop",
+     "425": "barn",
+     "426": "barometer",
+     "427": "barrel, cask",
+     "428": "barrow, garden cart, lawn cart, wheelbarrow",
+     "429": "baseball",
+     "430": "basketball",
+     "431": "bassinet",
+     "432": "bassoon",
+     "433": "bathing cap, swimming cap",
+     "434": "bath towel",
+     "435": "bathtub, bathing tub, bath, tub",
+     "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
+     "437": "beacon, lighthouse, beacon light, pharos",
+     "438": "beaker",
+     "439": "bearskin, busby, shako",
+     "440": "beer bottle",
+     "441": "beer glass",
+     "442": "bell cote, bell cot",
+     "443": "bib",
+     "444": "bicycle-built-for-two, tandem bicycle, tandem",
+     "445": "bikini, two-piece",
+     "446": "binder, ring-binder",
+     "447": "binoculars, field glasses, opera glasses",
+     "448": "birdhouse",
+     "449": "boathouse",
+     "450": "bobsled, bobsleigh, bob",
+     "451": "bolo tie, bolo, bola tie, bola",
+     "452": "bonnet, poke bonnet",
+     "453": "bookcase",
+     "454": "bookshop, bookstore, bookstall",
+     "455": "bottlecap",
+     "456": "bow",
+     "457": "bow tie, bow-tie, bowtie",
+     "458": "brass, memorial tablet, plaque",
+     "459": "brassiere, bra, bandeau",
+     "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
+     "461": "breastplate, aegis, egis",
+     "462": "broom",
+     "463": "bucket, pail",
+     "464": "buckle",
+     "465": "bulletproof vest",
+     "466": "bullet train, bullet",
+     "467": "butcher shop, meat market",
+     "468": "cab, hack, taxi, taxicab",
+     "469": "caldron, cauldron",
+     "470": "candle, taper, wax light",
+     "471": "cannon",
+     "472": "canoe",
+     "473": "can opener, tin opener",
+     "474": "cardigan",
+     "475": "car mirror",
496
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
497
+ "477": "carpenters kit, tool kit",
498
+ "478": "carton",
499
+ "479": "car wheel",
500
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
501
+ "481": "cassette",
502
+ "482": "cassette player",
503
+ "483": "castle",
504
+ "484": "catamaran",
505
+ "485": "CD player",
506
+ "486": "cello, violoncello",
507
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
508
+ "488": "chain",
509
+ "489": "chainlink fence",
510
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
511
+ "491": "chain saw, chainsaw",
512
+ "492": "chest",
513
+ "493": "chiffonier, commode",
514
+ "494": "chime, bell, gong",
515
+ "495": "china cabinet, china closet",
516
+ "496": "Christmas stocking",
517
+ "497": "church, church building",
518
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
519
+ "499": "cleaver, meat cleaver, chopper",
520
+ "500": "cliff dwelling",
521
+ "501": "cloak",
522
+ "502": "clog, geta, patten, sabot",
523
+ "503": "cocktail shaker",
524
+ "504": "coffee mug",
525
+ "505": "coffeepot",
526
+ "506": "coil, spiral, volute, whorl, helix",
527
+ "507": "combination lock",
528
+ "508": "computer keyboard, keypad",
529
+ "509": "confectionery, confectionary, candy store",
530
+ "510": "container ship, containership, container vessel",
531
+ "511": "convertible",
532
+ "512": "corkscrew, bottle screw",
533
+ "513": "cornet, horn, trumpet, trump",
534
+ "514": "cowboy boot",
535
+ "515": "cowboy hat, ten-gallon hat",
536
+ "516": "cradle",
537
+ "517": "crane",
538
+ "518": "crash helmet",
539
+ "519": "crate",
540
+ "520": "crib, cot",
541
+ "521": "Crock Pot",
542
+ "522": "croquet ball",
543
+ "523": "crutch",
544
+ "524": "cuirass",
545
+ "525": "dam, dike, dyke",
546
+ "526": "desk",
547
+ "527": "desktop computer",
548
+ "528": "dial telephone, dial phone",
549
+ "529": "diaper, nappy, napkin",
550
+ "530": "digital clock",
551
+ "531": "digital watch",
552
+ "532": "dining table, board",
553
+ "533": "dishrag, dishcloth",
554
+ "534": "dishwasher, dish washer, dishwashing machine",
555
+ "535": "disk brake, disc brake",
556
+ "536": "dock, dockage, docking facility",
557
+ "537": "dogsled, dog sled, dog sleigh",
558
+ "538": "dome",
559
+ "539": "doormat, welcome mat",
560
+ "540": "drilling platform, offshore rig",
561
+ "541": "drum, membranophone, tympan",
562
+ "542": "drumstick",
563
+ "543": "dumbbell",
564
+ "544": "Dutch oven",
565
+ "545": "electric fan, blower",
566
+ "546": "electric guitar",
567
+ "547": "electric locomotive",
568
+ "548": "entertainment center",
569
+ "549": "envelope",
570
+ "550": "espresso maker",
571
+ "551": "face powder",
572
+ "552": "feather boa, boa",
573
+ "553": "file, file cabinet, filing cabinet",
574
+ "554": "fireboat",
575
+ "555": "fire engine, fire truck",
576
+ "556": "fire screen, fireguard",
577
+ "557": "flagpole, flagstaff",
578
+ "558": "flute, transverse flute",
579
+ "559": "folding chair",
580
+ "560": "football helmet",
581
+ "561": "forklift",
582
+ "562": "fountain",
583
+ "563": "fountain pen",
584
+ "564": "four-poster",
585
+ "565": "freight car",
586
+ "566": "French horn, horn",
587
+ "567": "frying pan, frypan, skillet",
588
+ "568": "fur coat",
589
+ "569": "garbage truck, dustcart",
590
+ "570": "gasmask, respirator, gas helmet",
591
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
592
+ "572": "goblet",
593
+ "573": "go-kart",
594
+ "574": "golf ball",
595
+ "575": "golfcart, golf cart",
596
+ "576": "gondola",
597
+ "577": "gong, tam-tam",
598
+ "578": "gown",
599
+ "579": "grand piano, grand",
600
+ "580": "greenhouse, nursery, glasshouse",
601
+ "581": "grille, radiator grille",
602
+ "582": "grocery store, grocery, food market, market",
603
+ "583": "guillotine",
604
+ "584": "hair slide",
605
+ "585": "hair spray",
606
+ "586": "half track",
607
+ "587": "hammer",
608
+ "588": "hamper",
609
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
610
+ "590": "hand-held computer, hand-held microcomputer",
611
+ "591": "handkerchief, hankie, hanky, hankey",
612
+ "592": "hard disc, hard disk, fixed disk",
613
+ "593": "harmonica, mouth organ, harp, mouth harp",
614
+ "594": "harp",
615
+ "595": "harvester, reaper",
616
+ "596": "hatchet",
617
+ "597": "holster",
618
+ "598": "home theater, home theatre",
619
+ "599": "honeycomb",
620
+ "600": "hook, claw",
621
+ "601": "hoopskirt, crinoline",
622
+ "602": "horizontal bar, high bar",
623
+ "603": "horse cart, horse-cart",
624
+ "604": "hourglass",
625
+ "605": "iPod",
626
+ "606": "iron, smoothing iron",
627
+ "607": "jack-o-lantern",
628
+ "608": "jean, blue jean, denim",
629
+ "609": "jeep, landrover",
630
+ "610": "jersey, T-shirt, tee shirt",
631
+ "611": "jigsaw puzzle",
632
+ "612": "jinrikisha, ricksha, rickshaw",
633
+ "613": "joystick",
634
+ "614": "kimono",
635
+ "615": "knee pad",
636
+ "616": "knot",
637
+ "617": "lab coat, laboratory coat",
638
+ "618": "ladle",
639
+ "619": "lampshade, lamp shade",
640
+ "620": "laptop, laptop computer",
641
+ "621": "lawn mower, mower",
642
+ "622": "lens cap, lens cover",
643
+ "623": "letter opener, paper knife, paperknife",
644
+ "624": "library",
645
+ "625": "lifeboat",
646
+ "626": "lighter, light, igniter, ignitor",
647
+ "627": "limousine, limo",
648
+ "628": "liner, ocean liner",
649
+ "629": "lipstick, lip rouge",
650
+ "630": "Loafer",
651
+ "631": "lotion",
652
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
653
+ "633": "loupe, jewelers loupe",
654
+ "634": "lumbermill, sawmill",
655
+ "635": "magnetic compass",
656
+ "636": "mailbag, postbag",
657
+ "637": "mailbox, letter box",
658
+ "638": "maillot",
659
+ "639": "maillot, tank suit",
660
+ "640": "manhole cover",
661
+ "641": "maraca",
662
+ "642": "marimba, xylophone",
663
+ "643": "mask",
664
+ "644": "matchstick",
665
+ "645": "maypole",
666
+ "646": "maze, labyrinth",
667
+ "647": "measuring cup",
668
+ "648": "medicine chest, medicine cabinet",
669
+ "649": "megalith, megalithic structure",
670
+ "650": "microphone, mike",
671
+ "651": "microwave, microwave oven",
672
+ "652": "military uniform",
673
+ "653": "milk can",
674
+ "654": "minibus",
675
+ "655": "miniskirt, mini",
676
+ "656": "minivan",
677
+ "657": "missile",
678
+ "658": "mitten",
679
+ "659": "mixing bowl",
680
+ "660": "mobile home, manufactured home",
681
+ "661": "Model T",
682
+ "662": "modem",
683
+ "663": "monastery",
684
+ "664": "monitor",
685
+ "665": "moped",
686
+ "666": "mortar",
687
+ "667": "mortarboard",
688
+ "668": "mosque",
689
+ "669": "mosquito net",
690
+ "670": "motor scooter, scooter",
691
+ "671": "mountain bike, all-terrain bike, off-roader",
692
+ "672": "mountain tent",
693
+ "673": "mouse, computer mouse",
694
+ "674": "mousetrap",
695
+ "675": "moving van",
696
+ "676": "muzzle",
697
+ "677": "nail",
698
+ "678": "neck brace",
699
+ "679": "necklace",
700
+ "680": "nipple",
701
+ "681": "notebook, notebook computer",
702
+ "682": "obelisk",
703
+ "683": "oboe, hautboy, hautbois",
704
+ "684": "ocarina, sweet potato",
705
+ "685": "odometer, hodometer, mileometer, milometer",
706
+ "686": "oil filter",
707
+ "687": "organ, pipe organ",
708
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
709
+ "689": "overskirt",
710
+ "690": "oxcart",
711
+ "691": "oxygen mask",
712
+ "692": "packet",
713
+ "693": "paddle, boat paddle",
714
+ "694": "paddlewheel, paddle wheel",
715
+ "695": "padlock",
716
+ "696": "paintbrush",
717
+ "697": "pajama, pyjama, pjs, jammies",
718
+ "698": "palace",
719
+ "699": "panpipe, pandean pipe, syrinx",
720
+ "700": "paper towel",
721
+ "701": "parachute, chute",
722
+ "702": "parallel bars, bars",
723
+ "703": "park bench",
724
+ "704": "parking meter",
725
+ "705": "passenger car, coach, carriage",
726
+ "706": "patio, terrace",
727
+ "707": "pay-phone, pay-station",
728
+ "708": "pedestal, plinth, footstall",
729
+ "709": "pencil box, pencil case",
730
+ "710": "pencil sharpener",
731
+ "711": "perfume, essence",
732
+ "712": "Petri dish",
733
+ "713": "photocopier",
734
+ "714": "pick, plectrum, plectron",
735
+ "715": "pickelhaube",
736
+ "716": "picket fence, paling",
737
+ "717": "pickup, pickup truck",
738
+ "718": "pier",
739
+ "719": "piggy bank, penny bank",
740
+ "720": "pill bottle",
741
+ "721": "pillow",
742
+ "722": "ping-pong ball",
743
+ "723": "pinwheel",
744
+ "724": "pirate, pirate ship",
745
+ "725": "pitcher, ewer",
746
+ "726": "plane, carpenters plane, woodworking plane",
747
+ "727": "planetarium",
748
+ "728": "plastic bag",
749
+ "729": "plate rack",
750
+ "730": "plow, plough",
751
+ "731": "plunger, plumbers helper",
752
+ "732": "Polaroid camera, Polaroid Land camera",
753
+ "733": "pole",
754
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
755
+ "735": "poncho",
756
+ "736": "pool table, billiard table, snooker table",
757
+ "737": "pop bottle, soda bottle",
758
+ "738": "pot, flowerpot",
759
+ "739": "potters wheel",
760
+ "740": "power drill",
761
+ "741": "prayer rug, prayer mat",
762
+ "742": "printer",
763
+ "743": "prison, prison house",
764
+ "744": "projectile, missile",
765
+ "745": "projector",
766
+ "746": "puck, hockey puck",
767
+ "747": "punching bag, punch bag, punching ball, punchball",
768
+ "748": "purse",
769
+ "749": "quill, quill pen",
770
+ "750": "quilt, comforter, comfort, puff",
771
+ "751": "racer, race car, racing car",
772
+ "752": "racket, racquet",
773
+ "753": "radiator",
774
+ "754": "radio, wireless",
775
+ "755": "radio telescope, radio reflector",
776
+ "756": "rain barrel",
777
+ "757": "recreational vehicle, RV, R.V.",
778
+ "758": "reel",
779
+ "759": "reflex camera",
780
+ "760": "refrigerator, icebox",
781
+ "761": "remote control, remote",
782
+ "762": "restaurant, eating house, eating place, eatery",
783
+ "763": "revolver, six-gun, six-shooter",
784
+ "764": "rifle",
785
+ "765": "rocking chair, rocker",
786
+ "766": "rotisserie",
787
+ "767": "rubber eraser, rubber, pencil eraser",
788
+ "768": "rugby ball",
789
+ "769": "rule, ruler",
790
+ "770": "running shoe",
791
+ "771": "safe",
792
+ "772": "safety pin",
793
+ "773": "saltshaker, salt shaker",
794
+ "774": "sandal",
795
+ "775": "sarong",
796
+ "776": "sax, saxophone",
797
+ "777": "scabbard",
798
+ "778": "scale, weighing machine",
799
+ "779": "school bus",
800
+ "780": "schooner",
801
+ "781": "scoreboard",
802
+ "782": "screen, CRT screen",
803
+ "783": "screw",
804
+ "784": "screwdriver",
805
+ "785": "seat belt, seatbelt",
806
+ "786": "sewing machine",
807
+ "787": "shield, buckler",
808
+ "788": "shoe shop, shoe-shop, shoe store",
809
+ "789": "shoji",
810
+ "790": "shopping basket",
811
+ "791": "shopping cart",
812
+ "792": "shovel",
813
+ "793": "shower cap",
814
+ "794": "shower curtain",
815
+ "795": "ski",
816
+ "796": "ski mask",
817
+ "797": "sleeping bag",
818
+ "798": "slide rule, slipstick",
819
+ "799": "sliding door",
820
+ "800": "slot, one-armed bandit",
821
+ "801": "snorkel",
822
+ "802": "snowmobile",
823
+ "803": "snowplow, snowplough",
824
+ "804": "soap dispenser",
825
+ "805": "soccer ball",
826
+ "806": "sock",
827
+ "807": "solar dish, solar collector, solar furnace",
828
+ "808": "sombrero",
829
+ "809": "soup bowl",
830
+ "810": "space bar",
831
+ "811": "space heater",
832
+ "812": "space shuttle",
833
+ "813": "spatula",
834
+ "814": "speedboat",
835
+ "815": "spider web, spiders web",
836
+ "816": "spindle",
837
+ "817": "sports car, sport car",
838
+ "818": "spotlight, spot",
839
+ "819": "stage",
840
+ "820": "steam locomotive",
841
+ "821": "steel arch bridge",
842
+ "822": "steel drum",
843
+ "823": "stethoscope",
844
+ "824": "stole",
845
+ "825": "stone wall",
846
+ "826": "stopwatch, stop watch",
847
+ "827": "stove",
848
+ "828": "strainer",
849
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
850
+ "830": "stretcher",
851
+ "831": "studio couch, day bed",
852
+ "832": "stupa, tope",
853
+ "833": "submarine, pigboat, sub, U-boat",
854
+ "834": "suit, suit of clothes",
855
+ "835": "sundial",
856
+ "836": "sunglass",
857
+ "837": "sunglasses, dark glasses, shades",
858
+ "838": "sunscreen, sunblock, sun blocker",
859
+ "839": "suspension bridge",
860
+ "840": "swab, swob, mop",
861
+ "841": "sweatshirt",
862
+ "842": "swimming trunks, bathing trunks",
863
+ "843": "swing",
864
+ "844": "switch, electric switch, electrical switch",
865
+ "845": "syringe",
866
+ "846": "table lamp",
867
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
868
+ "848": "tape player",
869
+ "849": "teapot",
870
+ "850": "teddy, teddy bear",
871
+ "851": "television, television system",
872
+ "852": "tennis ball",
873
+ "853": "thatch, thatched roof",
874
+ "854": "theater curtain, theatre curtain",
875
+ "855": "thimble",
876
+ "856": "thresher, thrasher, threshing machine",
877
+ "857": "throne",
878
+ "858": "tile roof",
879
+ "859": "toaster",
880
+ "860": "tobacco shop, tobacconist shop, tobacconist",
881
+ "861": "toilet seat",
882
+ "862": "torch",
883
+ "863": "totem pole",
884
+ "864": "tow truck, tow car, wrecker",
885
+ "865": "toyshop",
886
+ "866": "tractor",
887
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
888
+ "868": "tray",
889
+ "869": "trench coat",
890
+ "870": "tricycle, trike, velocipede",
891
+ "871": "trimaran",
892
+ "872": "tripod",
893
+ "873": "triumphal arch",
894
+ "874": "trolleybus, trolley coach, trackless trolley",
895
+ "875": "trombone",
896
+ "876": "tub, vat",
897
+ "877": "turnstile",
898
+ "878": "typewriter keyboard",
899
+ "879": "umbrella",
900
+ "880": "unicycle, monocycle",
901
+ "881": "upright, upright piano",
902
+ "882": "vacuum, vacuum cleaner",
903
+ "883": "vase",
904
+ "884": "vault",
905
+ "885": "velvet",
906
+ "886": "vending machine",
907
+ "887": "vestment",
908
+ "888": "viaduct",
909
+ "889": "violin, fiddle",
910
+ "890": "volleyball",
911
+ "891": "waffle iron",
912
+ "892": "wall clock",
913
+ "893": "wallet, billfold, notecase, pocketbook",
914
+ "894": "wardrobe, closet, press",
915
+ "895": "warplane, military plane",
916
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
917
+ "897": "washer, automatic washer, washing machine",
918
+ "898": "water bottle",
919
+ "899": "water jug",
920
+ "900": "water tower",
921
+ "901": "whiskey jug",
922
+ "902": "whistle",
923
+ "903": "wig",
924
+ "904": "window screen",
925
+ "905": "window shade",
926
+ "906": "Windsor tie",
927
+ "907": "wine bottle",
928
+ "908": "wing",
929
+ "909": "wok",
930
+ "910": "wooden spoon",
931
+ "911": "wool, woolen, woollen",
932
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
933
+ "913": "wreck",
934
+ "914": "yawl",
935
+ "915": "yurt",
936
+ "916": "web site, website, internet site, site",
937
+ "917": "comic book",
938
+ "918": "crossword puzzle, crossword",
939
+ "919": "street sign",
940
+ "920": "traffic light, traffic signal, stoplight",
941
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
942
+ "922": "menu",
943
+ "923": "plate",
944
+ "924": "guacamole",
945
+ "925": "consomme",
946
+ "926": "hot pot, hotpot",
947
+ "927": "trifle",
948
+ "928": "ice cream, icecream",
949
+ "929": "ice lolly, lolly, lollipop, popsicle",
950
+ "930": "French loaf",
951
+ "931": "bagel, beigel",
952
+ "932": "pretzel",
953
+ "933": "cheeseburger",
954
+ "934": "hotdog, hot dog, red hot",
955
+ "935": "mashed potato",
956
+ "936": "head cabbage",
957
+ "937": "broccoli",
958
+ "938": "cauliflower",
959
+ "939": "zucchini, courgette",
960
+ "940": "spaghetti squash",
961
+ "941": "acorn squash",
962
+ "942": "butternut squash",
963
+ "943": "cucumber, cuke",
964
+ "944": "artichoke, globe artichoke",
965
+ "945": "bell pepper",
966
+ "946": "cardoon",
967
+ "947": "mushroom",
968
+ "948": "Granny Smith",
969
+ "949": "strawberry",
970
+ "950": "orange",
971
+ "951": "lemon",
972
+ "952": "fig",
973
+ "953": "pineapple, ananas",
974
+ "954": "banana",
975
+ "955": "jackfruit, jak, jack",
976
+ "956": "custard apple",
977
+ "957": "pomegranate",
978
+ "958": "hay",
979
+ "959": "carbonara",
980
+ "960": "chocolate sauce, chocolate syrup",
981
+ "961": "dough",
982
+ "962": "meat loaf, meatloaf",
983
+ "963": "pizza, pizza pie",
984
+ "964": "potpie",
985
+ "965": "burrito",
986
+ "966": "red wine",
987
+ "967": "espresso",
988
+ "968": "cup",
989
+ "969": "eggnog",
990
+ "970": "alp",
991
+ "971": "bubble",
992
+ "972": "cliff, drop, drop-off",
993
+ "973": "coral reef",
994
+ "974": "geyser",
995
+ "975": "lakeside, lakeshore",
996
+ "976": "promontory, headland, head, foreland",
997
+ "977": "sandbar, sand bar",
998
+ "978": "seashore, coast, seacoast, sea-coast",
999
+ "979": "valley, vale",
1000
+ "980": "volcano",
1001
+ "981": "ballplayer, baseball player",
1002
+ "982": "groom, bridegroom",
1003
+ "983": "scuba diver",
1004
+ "984": "rapeseed",
1005
+ "985": "daisy",
1006
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1007
+ "987": "corn",
1008
+ "988": "acorn",
1009
+ "989": "hip, rose hip, rosehip",
1010
+ "990": "buckeye, horse chestnut, conker",
1011
+ "991": "coral fungus",
1012
+ "992": "agaric",
1013
+ "993": "gyromitra",
1014
+ "994": "stinkhorn, carrion fungus",
1015
+ "995": "earthstar",
1016
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1017
+ "997": "bolete",
1018
+ "998": "ear, spike, capitulum",
1019
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1020
+ }
1021
+ }
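Several synonyms in the `id2label` mapping above are shared between classes: `panda` is listed under both 387 and 388, and `maillot` under both 638 and 639. A synonym-to-id dictionary built by plain assignment keeps only the last id it sees for such a name. The sketch below is a minimal, hypothetical helper (the name `synonym_index` and the use of integer keys are assumptions made for illustration, not part of this commit) that lists every class id per synonym so ambiguous names can be spotted before they are passed as `class_labels`.

```python
from collections import defaultdict
from typing import Dict, List

# A few entries copied from the `id2label` mapping above.
# Assumption: keys are written as ints here; the JSON file stores them as strings.
id2label: Dict[int, str] = {
    387: "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
    388: "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
    638: "maillot",
    639: "maillot, tank suit",
    999: "toilet tissue, toilet paper, bathroom tissue",
}


def synonym_index(mapping: Dict[int, str]) -> Dict[str, List[int]]:
    """Map each comma-separated synonym to every class id that lists it."""
    index: Dict[str, List[int]] = defaultdict(list)
    for class_id, value in mapping.items():
        for synonym in value.split(","):
            synonym = synonym.strip()
            if synonym:  # skip empty fragments such as trailing commas
                index[synonym].append(class_id)
    return dict(index)


index = synonym_index(id2label)

# Synonyms claimed by more than one class id are ambiguous as string labels.
ambiguous = {name: ids for name, ids in index.items() if len(ids) > 1}
print(ambiguous)
```

For ambiguous names, passing the integer class id (for example `387` for the red panda) avoids depending on which entry a string lookup happens to resolve to.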
SiT-XL-2-256/pipeline.py CHANGED
@@ -1,82 +1,349 @@
- from typing import List, Optional, Union
-
- import torch
-
- from diffusers.image_processor import VaeImageProcessor
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
- from diffusers.utils.torch_utils import randn_tensor
-
-
- class SiTPipeline(DiffusionPipeline):
-     model_cpu_offload_seq = "transformer->vae"
-
-     def __init__(self, transformer, scheduler, vae):
-         super().__init__()
-         self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
-         self.vae_scale_factor = 8
-         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
-
-     @torch.no_grad()
-     def __call__(
-         self,
-         class_labels: Union[int, List[int]] = 207,
-         height: int = 256,
-         width: int = 256,
-         num_inference_steps: int = 250,
-         guidance_scale: float = 4.0,
-         generator: Optional[torch.Generator] = None,
-         output_type: str = "pil",
-         return_dict: bool = True,
-     ):
-         device = self._execution_device
-         if isinstance(class_labels, int):
-             class_labels = [class_labels]
-         batch_size = len(class_labels)
-
-         latent_h = height // self.vae_scale_factor
-         latent_w = width // self.vae_scale_factor
-         latents = randn_tensor(
-             (batch_size, self.transformer.config.in_channels, latent_h, latent_w),
-             generator=generator,
-             device=device,
-             dtype=self.transformer.dtype,
-         )
-
-         labels = torch.tensor(class_labels, device=device, dtype=torch.long)
-         do_cfg = guidance_scale is not None and guidance_scale > 1.0
-         if do_cfg:
-             null_label = torch.full((batch_size,), self.transformer.config.num_classes, device=device, dtype=torch.long)
-             labels = torch.cat([labels, null_label], dim=0)
-
-         self.scheduler.set_timesteps(num_inference_steps, device=device)
-         timesteps = self.scheduler.timesteps
-
-         for t in self.progress_bar(timesteps):
-             t_batch = torch.full((batch_size,), t, device=device, dtype=latents.dtype)
-             model_input = latents
-             if do_cfg:
-                 model_input = torch.cat([latents, latents], dim=0)
-                 t_batch = torch.cat([t_batch, t_batch], dim=0)
-
-             model_pred = self.transformer(
-                 hidden_states=model_input,
-                 timestep=t_batch,
-                 class_labels=labels,
-             ).sample
-
-             if do_cfg:
-                 cond, uncond = model_pred.chunk(2, dim=0)
-                 model_pred = uncond + guidance_scale * (cond - uncond)
-
-             latents = self.scheduler.step(model_pred, t, latents, generator=generator).prev_sample
-
-         image = self.vae.decode(latents / 0.18215).sample
-         # Keep PyTorch outputs in raw VAE range [-1, 1] to match original SiT scripts.
-         if output_type == "pt":
-             image = image
-         else:
-             image = self.image_processor.postprocess(image, output_type=output_type)
-
-         if not return_dict:
-             return (image,)
-         return ImagePipelineOutput(images=image)
+ """Hub custom pipeline: SiTPipeline.
+ Load with native Hugging Face diffusers and trust_remote_code=True.
+ """
+
+ from __future__ import annotations
+
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import torch
+
+ from diffusers.image_processor import VaeImageProcessor
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler  # default scheduler used in __init__
+ from diffusers.utils.torch_utils import randn_tensor
+
+ EXAMPLE_DOC_STRING = """
+     Examples:
+         ```py
+         >>> from pathlib import Path
+         >>> from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
+         >>> import torch
+
+         >>> model_dir = Path("./SiT-XL-2-256").resolve()
+         >>> pipe = DiffusionPipeline.from_pretrained(
+         ...     str(model_dir),
+         ...     local_files_only=True,
+         ...     custom_pipeline=str(model_dir / "pipeline.py"),
+         ...     trust_remote_code=True,
+         ...     torch_dtype=torch.bfloat16,
+         ... )
+         >>> pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
+         >>> pipe.to("cuda")
+
+         >>> print(pipe.id2label[207])
+         >>> print(pipe.get_label_ids("golden retriever"))
+
+         >>> generator = torch.Generator(device="cuda").manual_seed(42)
+         >>> image = pipe(
+         ...     class_labels="golden retriever",
+         ...     height=256,
+         ...     width=256,
+         ...     num_inference_steps=250,
+         ...     guidance_scale=4.0,
+         ...     generator=generator,
+         ... ).images[0]
+         ```
+ """
+
+ class SiTPipeline(DiffusionPipeline):
+     r"""
+     Pipeline for class-conditional image generation with Scalable Interpolant Transformers (SiT).
+
+     Parameters:
+         transformer ([`SiTTransformer2DModel`]):
+             Class-conditional SiT transformer that predicts flow-matching velocity in latent space.
+         scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+             Flow-matching Euler scheduler. Other [`KarrasDiffusionSchedulers`] can be swapped at inference time.
+         vae ([`AutoencoderKL`]):
+             Variational autoencoder used to decode transformer latents to pixels.
+         id2label (`dict[int, str]`, *optional*):
+             ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
+     """
+
+     model_cpu_offload_seq = "transformer->vae"
+
+     def __init__(
+         self,
+         transformer,
+         scheduler,
+         vae,
+         id2label: Optional[Dict[Union[int, str], str]] = None,
+     ):
+         super().__init__()
+         if scheduler is None:
+             scheduler = FlowMatchEulerDiscreteScheduler(
+                 num_train_timesteps=1000,
+                 shift=1.0,
+                 stochastic_sampling=False,
+             )
+         self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+         self._id2label = self._normalize_id2label(id2label)
+         self.labels = self._build_label2id(self._id2label)
+         self._labels_loaded_from_model_index = bool(self._id2label)
+
+     def _ensure_labels_loaded(self) -> None:
+         if self._labels_loaded_from_model_index:
+             return
+         loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
+         if loaded:
+             self._id2label = loaded
+             self.labels = self._build_label2id(self._id2label)
+             self._labels_loaded_from_model_index = True
+
+     @staticmethod
+     def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
+         if not id2label:
+             return {}
+         return {int(key): value for key, value in id2label.items()}
+
+     @staticmethod
+     def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
+         if not variant_path:
+             return {}
+         variant_dir = Path(variant_path).resolve()
+         model_index_path = variant_dir / "model_index.json"
+         if not model_index_path.exists():
+             return {}
+         raw = json.loads(model_index_path.read_text(encoding="utf-8"))
+         id2label = raw.get("id2label")
+         if not isinstance(id2label, dict):
+             return {}
+         return {int(key): value for key, value in id2label.items()}
+
+     @staticmethod
+     def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
+         label2id: Dict[str, int] = {}
+         for class_id, value in id2label.items():
+             for synonym in value.split(","):
+                 synonym = synonym.strip()
+                 if synonym:
+                     label2id[synonym] = int(class_id)
+         return dict(sorted(label2id.items()))
+
+     @property
+     def id2label(self) -> Dict[int, str]:
+         r"""ImageNet class id to English label string (comma-separated synonyms)."""
+         self._ensure_labels_loaded()
+         return self._id2label
+
+     def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
+         r"""
+         Map ImageNet label strings to class ids.
+
+         Args:
+             label (`str` or `list[str]`):
+                 One or more English label strings. Each string must match a synonym in `id2label`.
+         """
+         self._ensure_labels_loaded()
+         label2id = self.labels
+         if not label2id:
+             raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
+
+         if isinstance(label, str):
+             label = [label]
+
+         missing = [item for item in label if item not in label2id]
+         if missing:
+             preview = ", ".join(list(label2id.keys())[:8])
+             raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
+         return [label2id[item] for item in label]
+
+     def _normalize_class_labels(
+         self,
+         class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
+     ) -> torch.LongTensor:
+         if torch.is_tensor(class_labels):
+             return class_labels.to(device=self._execution_device, dtype=torch.long).reshape(-1)
+
+         if isinstance(class_labels, int):
+             class_label_ids = [class_labels]
+         elif isinstance(class_labels, str):
+             class_label_ids = self.get_label_ids(class_labels)
+         elif class_labels and isinstance(class_labels[0], str):
+             class_label_ids = self.get_label_ids(class_labels)
+         else:
+             class_label_ids = list(class_labels)
+
+         return torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
+
+     def _default_image_size(self) -> int:
+         return int(self.transformer.config.input_size) * self.vae_scale_factor
+
+     def check_inputs(
+         self,
+         height: int,
+         width: int,
+         num_inference_steps: int,
+         output_type: str,
+     ) -> None:
+         if num_inference_steps < 1:
+             raise ValueError("num_inference_steps must be >= 1.")
+         if output_type not in {"pil", "np", "pt", "latent"}:
+             raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
+
+         if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
+             raise ValueError(
+                 f"height and width must be divisible by the VAE downsample factor {self.vae_scale_factor}."
+             )
+
+         latent_height = height // self.vae_scale_factor
+         latent_width = width // self.vae_scale_factor
+         expected_size = int(self.transformer.config.input_size)
+         if latent_height != expected_size or latent_width != expected_size:
+             raise ValueError(
+                 f"Requested latent size {(latent_height, latent_width)} does not match the pretrained "
+                 f"transformer input_size={expected_size}. Use height=width={self._default_image_size()}."
+             )
+
+     def prepare_latents(
+         self,
+         batch_size: int,
+         height: int,
+         width: int,
+         dtype: torch.dtype,
+         device: torch.device,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]],
+     ) -> torch.Tensor:
+         latent_height = height // self.vae_scale_factor
+         latent_width = width // self.vae_scale_factor
+         return randn_tensor(
+             (batch_size, self.transformer.config.in_channels, latent_height, latent_width),
+             generator=generator,
+             device=device,
+             dtype=dtype,
+         )
+
+     @staticmethod
+     def _apply_classifier_free_guidance(model_output: torch.Tensor, guidance_scale: float) -> torch.Tensor:
+         if guidance_scale <= 1.0:
+             return model_output
+         model_output_cond, model_output_uncond = model_output.chunk(2)
+         return model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
+
+     def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
+         if output_type == "latent":
+             return latents
+
+         scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
+         image = self.vae.decode(latents / scaling_factor).sample
+         if output_type == "pt":
+             return image
+         return self.image_processor.postprocess(image, output_type=output_type)
+
251
+ @torch.inference_mode()
252
+ def __call__(
253
+ self,
254
+ class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
255
+ height: Optional[int] = None,
256
+ width: Optional[int] = None,
257
+ num_inference_steps: int = 250,
258
+ guidance_scale: float = 4.0,
259
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
260
+ output_type: str = "pil",
261
+ return_dict: bool = True,
262
+ ) -> Union[ImagePipelineOutput, Tuple]:
263
+ r"""
264
+ Generate class-conditional images with SiT.
265
+
266
+ Args:
267
+ class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
268
+ ImageNet class indices or human-readable English label strings.
269
+ height (`int`, *optional*):
270
+ Output image height in pixels. Defaults to the pretrained native resolution.
271
+ width (`int`, *optional*):
272
+ Output image width in pixels. Defaults to the pretrained native resolution.
273
+ num_inference_steps (`int`, defaults to `250`):
274
+ Number of denoising steps.
275
+ guidance_scale (`float`, defaults to `4.0`):
276
+ Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
277
+ generator (`torch.Generator`, *optional*):
278
+ RNG for reproducibility.
279
+ output_type (`str`, defaults to `"pil"`):
280
+ `"pil"`, `"np"`, `"pt"`, or `"latent"`.
281
+ return_dict (`bool`, defaults to `True`):
282
+ Return [`ImagePipelineOutput`] if True.
283
+ """
284
+ default_size = self._default_image_size()
285
+ height = int(height or default_size)
286
+ width = int(width or default_size)
287
+ self.check_inputs(height, width, num_inference_steps, output_type)
288
+
289
+ device = self._execution_device
290
+ model_dtype = next(self.transformer.parameters()).dtype
291
+ class_labels_tensor = self._normalize_class_labels(class_labels)
292
+ batch_size = class_labels_tensor.numel()
293
+ do_cfg = guidance_scale > 1.0
294
+
295
+ latents = self.prepare_latents(
296
+ batch_size=batch_size,
297
+ height=height,
298
+ width=width,
299
+ dtype=model_dtype,
300
+ device=device,
301
+ generator=generator,
302
+ )
303
+
304
+ labels = class_labels_tensor
305
+ if do_cfg:
306
+ null_labels = torch.full_like(class_labels_tensor, self.transformer.config.num_classes)
307
+ labels = torch.cat([class_labels_tensor, null_labels], dim=0)
308
+
309
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
310
+ num_train_timesteps = self.scheduler.config.num_train_timesteps
311
+
312
+ if getattr(self.scheduler.config, "stochastic_sampling", False):
313
+ raise ValueError(
314
+ "SiT expects deterministic FlowMatchEulerDiscreteScheduler stepping "
315
+ "(scheduler.config.stochastic_sampling=False)."
316
+ )
317
+
318
+ for t in self.progress_bar(self.scheduler.timesteps):
319
+ flow_time = 1.0 - float(t) / num_train_timesteps
320
+ if do_cfg:
321
+ model_input = torch.cat([latents, latents], dim=0)
322
+ else:
323
+ model_input = latents
324
+
325
+ timestep_batch = torch.full((model_input.shape[0],), flow_time, device=device, dtype=model_dtype)
326
+ model_output = self.transformer(
327
+ hidden_states=model_input,
328
+ timestep=timestep_batch,
329
+ class_labels=labels,
330
+ return_dict=True,
331
+ ).sample
332
+ model_output = self._apply_classifier_free_guidance(model_output, guidance_scale=guidance_scale)
333
+ # SiT predicts dx/d(flow_time) with flow_time increasing from noise (0) to data (1).
334
+ # FlowMatchEulerDiscreteScheduler integrates over sigma decreasing from 1 to 0, so flip sign.
335
+ model_output = -model_output
336
+ latents = self.scheduler.step(
337
+ model_output=model_output,
338
+ timestep=t,
339
+ sample=latents,
340
+ generator=generator,
341
+ return_dict=True,
342
+ ).prev_sample
343
+
344
+ image = self.decode_latents(latents, output_type=output_type)
345
+
346
+ self.maybe_free_model_hooks()
347
+ if not return_dict:
348
+ return (image,)
349
+ return ImagePipelineOutput(images=image)
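
The sampling loop in `__call__` combines two ideas: classifier-free guidance on the model output, and an Euler update over a decreasing `sigma` with the sign of the predicted velocity flipped. The sketch below illustrates both on plain Python lists rather than tensors. It is not the pipeline's code: `apply_cfg`, `euler_sample`, and the evenly spaced `sigmas` schedule are hypothetical names and simplifications introduced only for illustration.

```python
from typing import Callable, List


def apply_cfg(cond: List[float], uncond: List[float], guidance_scale: float) -> List[float]:
    """Blend conditional and unconditional predictions (classifier-free guidance)."""
    if guidance_scale <= 1.0:
        # Guidance is inactive; keep the conditional prediction unchanged.
        return cond
    return [u + guidance_scale * (c - u) for c, u in zip(cond, uncond)]


def euler_sample(
    noise: List[float],
    velocity_fn: Callable[[List[float], float], List[float]],
    num_steps: int,
) -> List[float]:
    """Integrate from sigma=1 (noise) to sigma=0 (data) with plain Euler steps."""
    # Assumption: an evenly spaced schedule; a real scheduler may space sigmas differently.
    sigmas = [1.0 - i / num_steps for i in range(num_steps + 1)]
    x = list(noise)
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        flow_time = 1.0 - sigma              # 0 at noise, 1 at data
        v = velocity_fn(x, flow_time)        # stands in for the model's dx/d(flow_time)
        v = [-vi for vi in v]                # flip sign because sigma runs opposite to flow_time
        dt = sigma_next - sigma              # negative step in sigma
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


# Toy check: a constant velocity pointing from noise to data follows a straight path.
noise = [2.0, -1.0]
data = [5.0, 0.5]
sample = euler_sample(noise, lambda x, t: [d - n for d, n in zip(data, noise)], num_steps=10)
guided = apply_cfg([1.0, 2.0], [0.0, 1.0], guidance_scale=4.0)
```

With the toy constant velocity, the two sign changes (negated velocity, negative `dt`) cancel, so the sample moves from `noise` toward `data` as `sigma` falls to zero.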
SiT-XL-2-256/scheduler/scheduler_config.json CHANGED
@@ -1,9 +1,7 @@
- {
-   "_class_name": "SiTFlowMatchScheduler",
-   "_diffusers_version": "0.36.0",
-   "diffusion_form": "sigma",
-   "diffusion_norm": 1.0,
-   "mode": "ode",
-   "num_train_timesteps": 1000,
-   "shift": 1.0
- }
+ {
+   "_class_name": "FlowMatchEulerDiscreteScheduler",
+   "_diffusers_version": "0.36.0",
+   "num_train_timesteps": 1000,
+   "shift": 1.0,
+   "stochastic_sampling": false
+ }
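
The new config keeps `num_train_timesteps=1000` and `shift=1.0`. As a rough illustration of what those two values imply, the sketch below uses the time-shift formula that, to my understanding, flow-matching Euler schedulers apply to their sigmas, plus the timestep-to-`flow_time` mapping used in `pipeline.py`. The helper names `shift_sigma` and `flow_time_from_timestep` are hypothetical and not part of any library.

```python
def shift_sigma(sigma: float, shift: float) -> float:
    """Time-shift a sigma in [0, 1]; shift=1.0 leaves it unchanged (assumed formula)."""
    return shift * sigma / (1.0 + (shift - 1.0) * sigma)


def flow_time_from_timestep(t: float, num_train_timesteps: int = 1000) -> float:
    """Map a scheduler timestep to SiT's flow time, as the pipeline loop does."""
    return 1.0 - t / num_train_timesteps


unshifted = shift_sigma(0.5, shift=1.0)      # identity when shift is 1.0
shifted = shift_sigma(0.5, shift=3.0)        # larger shift keeps sigma higher for longer
late_step = flow_time_from_timestep(250.0)   # a late timestep is close to the data end
```

With `shift=1.0`, sigma equals `t / num_train_timesteps`, which is why the pipeline can recover `flow_time` as `1 - t / num_train_timesteps`.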
 
 
SiT-XL-2-256/transformer/transformer_sit.py CHANGED
@@ -1,224 +1,240 @@
- import math
- from dataclasses import dataclass
- from typing import Optional
-
- import numpy as np
- import torch
- import torch.nn as nn
- from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
-
- from diffusers.configuration_utils import ConfigMixin, register_to_config
- from diffusers.models.modeling_utils import ModelMixin
- from diffusers.utils import BaseOutput
-
-
- def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
-     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
-
-
- @dataclass
- class SiTTransformer2DModelOutput(BaseOutput):
-     sample: torch.Tensor
-
-
- class TimestepEmbedder(nn.Module):
-     def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
-         super().__init__()
-         self.mlp = nn.Sequential(
-             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
-             nn.SiLU(),
-             nn.Linear(hidden_size, hidden_size, bias=True),
-         )
-         self.frequency_embedding_size = frequency_embedding_size
-
-     @staticmethod
-     def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
-         half = dim // 2
-         freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-             device=t.device
-         )
-         args = t[:, None].float() * freqs[None]
-         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-         if dim % 2:
-             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-         return embedding
-
-     def forward(self, t: torch.Tensor) -> torch.Tensor:
-         return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
-
-
- class LabelEmbedder(nn.Module):
-     def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
-         super().__init__()
-         use_cfg_embedding = dropout_prob > 0
-         self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
-         self.num_classes = num_classes
-         self.dropout_prob = dropout_prob
-
-     def token_drop(self, labels: torch.Tensor, force_drop_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
-         if force_drop_ids is None:
-             drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
-         else:
-             drop_ids = force_drop_ids == 1
-         labels = torch.where(drop_ids, self.num_classes, labels)
-         return labels
-
-     def forward(
-         self,
-         labels: torch.Tensor,
-         train: bool,
-         force_drop_ids: Optional[torch.Tensor] = None,
-     ) -> torch.Tensor:
-         use_dropout = self.dropout_prob > 0
-         if (train and use_dropout) or (force_drop_ids is not None):
-             labels = self.token_drop(labels, force_drop_ids)
-         return self.embedding_table(labels)
-
-
- class SiTBlock(nn.Module):
-     def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, **block_kwargs):
-         super().__init__()
-         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-         self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
-         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-         mlp_hidden_dim = int(hidden_size * mlp_ratio)
-         approx_gelu = lambda: nn.GELU(approximate="tanh")
-         self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
-         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
-
-     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
-         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
-         x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
-         x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
-         return x
-
-
- class FinalLayer(nn.Module):
-     def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
-         super().__init__()
-         self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-         self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
-         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
-
-     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
-         shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
-         x = modulate(self.norm_final(x), shift, scale)
-         return self.linear(x)
-
-
- class SiTTransformer2DModel(ModelMixin, ConfigMixin):
-     @register_to_config
-     def __init__(
-         self,
-         input_size: int = 32,
-         patch_size: int = 2,
-         in_channels: int = 4,
-         hidden_size: int = 1152,
-         depth: int = 28,
-         num_heads: int = 16,
-         mlp_ratio: float = 4.0,
-         class_dropout_prob: float = 0.1,
-         num_classes: int = 1000,
-         learn_sigma: bool = True,
-     ):
-         super().__init__()
-         self.learn_sigma = learn_sigma
-         self.in_channels = in_channels
-         self.out_channels = in_channels * 2 if learn_sigma else in_channels
-         self.patch_size = patch_size
-         self.num_classes = num_classes
-
-         self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
-         self.t_embedder = TimestepEmbedder(hidden_size)
-         self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
-         num_patches = self.x_embedder.num_patches
-         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
-
-         self.blocks = nn.ModuleList([SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
-         self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
-         self.initialize_weights()
-
-     def initialize_weights(self) -> None:
-         def _basic_init(module: nn.Module):
-             if isinstance(module, nn.Linear):
-                 torch.nn.init.xavier_uniform_(module.weight)
-                 if module.bias is not None:
-                     nn.init.constant_(module.bias, 0)
-
-         self.apply(_basic_init)
-         pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
-         self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
-
-         w = self.x_embedder.proj.weight.data
-         nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
-         nn.init.constant_(self.x_embedder.proj.bias, 0)
-         nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
-         nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
-         nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-         for block in self.blocks:
-             nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
-             nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
-         nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
-         nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
-         nn.init.constant_(self.final_layer.linear.weight, 0)
-         nn.init.constant_(self.final_layer.linear.bias, 0)
-
-     def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
-         c = self.out_channels
-         p = self.x_embedder.patch_size[0]
-         h = w = int(x.shape[1] ** 0.5)
-         x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
-         x = torch.einsum("nhwpqc->nchpwq", x)
-         return x.reshape(shape=(x.shape[0], c, h * p, h * p))
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         timestep: torch.Tensor,
-         class_labels: torch.Tensor,
-         force_drop_ids: Optional[torch.Tensor] = None,
-         return_dict: bool = True,
-     ) -> SiTTransformer2DModelOutput:
-         x = self.x_embedder(hidden_states) + self.pos_embed
-         t = self.t_embedder(timestep)
-         y = self.y_embedder(class_labels, self.training, force_drop_ids=force_drop_ids)
-         c = t + y
-         for block in self.blocks:
-             x = block(x, c)
-         x = self.final_layer(x, c)
-         x = self.unpatchify(x)
-         if self.learn_sigma:
-             x, _ = x.chunk(2, dim=1)
-         if not return_dict:
-             return (x,)
-         return SiTTransformer2DModelOutput(sample=x)
-
-
- def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False, extra_tokens: int = 0):
-     grid_h = np.arange(grid_size, dtype=np.float32)
-     grid_w = np.arange(grid_size, dtype=np.float32)
-     grid = np.meshgrid(grid_w, grid_h)
-     grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])
-     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
-     if cls_token and extra_tokens > 0:
-         pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
-     return pos_embed
-
-
- def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray):
-     assert embed_dim % 2 == 0
-     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
-     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
-     return np.concatenate([emb_h, emb_w], axis=1)
-
-
- def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray):
-     assert embed_dim % 2 == 0
-     omega = np.arange(embed_dim // 2, dtype=np.float64)
-     omega /= embed_dim / 2.0
-     omega = 1.0 / 10000**omega
-     pos = pos.reshape(-1)
-     out = np.einsum("m,d->md", pos, omega)
-     emb_sin = np.sin(out)
-     emb_cos = np.cos(out)
-     return np.concatenate([emb_sin, emb_cos], axis=1)
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.utils import BaseOutput
+
+
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+ @dataclass
+ class SiTTransformer2DModelOutput(BaseOutput):
+     sample: torch.Tensor
+
+
+ class TimestepEmbedder(nn.Module):
+     def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+             nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size, bias=True),
+         )
+         self.frequency_embedding_size = frequency_embedding_size
+
+     @staticmethod
+     def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
+         half = dim // 2
+         freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+             device=t.device
+         )
+         args = t[:, None].float() * freqs[None]
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if dim % 2:
+             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+         return embedding
+
+     def forward(self, t: torch.Tensor) -> torch.Tensor:
+         emb = self.timestep_embedding(t.float(), self.frequency_embedding_size)
+         weight_dtype = self.mlp[0].weight.dtype
+         return self.mlp(emb.to(dtype=weight_dtype))
+
+
+ class LabelEmbedder(nn.Module):
+     def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
+         super().__init__()
+         use_cfg_embedding = dropout_prob > 0
+         self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+         self.num_classes = num_classes
+         self.dropout_prob = dropout_prob
+
+     def token_drop(self, labels: torch.Tensor, force_drop_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
+         if force_drop_ids is None:
+             drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+         else:
+             drop_ids = force_drop_ids == 1
+         labels = torch.where(drop_ids, self.num_classes, labels)
+         return labels
+
+     def forward(
+         self,
+         labels: torch.Tensor,
+         train: bool,
+         force_drop_ids: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         use_dropout = self.dropout_prob > 0
+         if (train and use_dropout) or (force_drop_ids is not None):
+             labels = self.token_drop(labels, force_drop_ids)
+         return self.embedding_table(labels)
+
+
+ class SiTBlock(nn.Module):
+     def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, **block_kwargs):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
+         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         mlp_hidden_dim = int(hidden_size * mlp_ratio)
+         approx_gelu = lambda: nn.GELU(approximate="tanh")
+         self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
+         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
+
+     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
+         x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
+         x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+         return x
+
+
+ class FinalLayer(nn.Module):
+     def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+         super().__init__()
+         self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+         shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+         x = modulate(self.norm_final(x), shift, scale)
+         return self.linear(x)
+
+
+ class SiTTransformer2DModel(ModelMixin, ConfigMixin):
+     @register_to_config
+     def __init__(
+         self,
+         input_size: int = 32,
+         patch_size: int = 2,
+         in_channels: int = 4,
+         hidden_size: int = 1152,
+         depth: int = 28,
+         num_heads: int = 16,
+         mlp_ratio: float = 4.0,
+         class_dropout_prob: float = 0.1,
+         num_classes: int = 1000,
+         learn_sigma: bool = True,
+     ):
+         super().__init__()
+         self.learn_sigma = learn_sigma
+         self.in_channels = in_channels
+         self.out_channels = in_channels * 2 if learn_sigma else in_channels
+         self.patch_size = patch_size
+         self.num_classes = num_classes
+
+         self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
+         self.t_embedder = TimestepEmbedder(hidden_size)
+         self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
+         num_patches = self.x_embedder.num_patches
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
+
+         self.blocks = nn.ModuleList([SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
+         self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
+         self.initialize_weights()
+
+     def initialize_weights(self) -> None:
+         def _basic_init(module: nn.Module):
+             if isinstance(module, nn.Linear):
+                 torch.nn.init.xavier_uniform_(module.weight)
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, 0)
+
+         self.apply(_basic_init)
+         pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
+         self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+         w = self.x_embedder.proj.weight.data
+         nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+         nn.init.constant_(self.x_embedder.proj.bias, 0)
+         nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+         nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+         nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+         for block in self.blocks:
+             nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+             nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+         nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+         nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+         nn.init.constant_(self.final_layer.linear.weight, 0)
+         nn.init.constant_(self.final_layer.linear.bias, 0)
+
+     def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
+         c = self.out_channels
+         p = self.x_embedder.patch_size[0]
+         h = w = int(x.shape[1] ** 0.5)
+         x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+         x = torch.einsum("nhwpqc->nchpwq", x)
+         return x.reshape(shape=(x.shape[0], c, h * p, h * p))
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         timestep: torch.Tensor,
+         class_labels: torch.Tensor,
+         force_drop_ids: Optional[torch.Tensor] = None,
+         return_dict: bool = True,
+     ) -> SiTTransformer2DModelOutput:
+         x = self.x_embedder(hidden_states) + self.pos_embed
+         t = self.t_embedder(timestep)
+         y = self.y_embedder(class_labels, self.training, force_drop_ids=force_drop_ids)
+         c = t + y
+         for block in self.blocks:
+             x = block(x, c)
+         x = self.final_layer(x, c)
+         x = self.unpatchify(x)
+         if self.learn_sigma:
+             x, _ = x.chunk(2, dim=1)
+         if not return_dict:
+             return (x,)
+         return SiTTransformer2DModelOutput(sample=x)
+
+
+ def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False, extra_tokens: int = 0):
+     grid_h = np.arange(grid_size, dtype=np.float32)
+     grid_w = np.arange(grid_size, dtype=np.float32)
+     grid = np.meshgrid(grid_w, grid_h)
+     grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token and extra_tokens > 0:
+         pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+     return pos_embed
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray):
+     assert embed_dim % 2 == 0
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
+     return np.concatenate([emb_h, emb_w], axis=1)
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray):
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=np.float64)
+     omega /= embed_dim / 2.0
+     omega = 1.0 / 10000**omega
+     pos = pos.reshape(-1)
+     out = np.einsum("m,d->md", pos, omega)
+     emb_sin = np.sin(out)
+     emb_cos = np.cos(out)
+     return np.concatenate([emb_sin, emb_cos], axis=1)
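
`TimestepEmbedder.timestep_embedding` builds a sinusoidal feature vector from a scalar time before the MLP projects it. The sketch below mirrors that arithmetic for a single scalar in pure Python, so the layout (cosines first, then sines, zero-padded when `dim` is odd) is easy to inspect. It is an illustrative re-derivation rather than the model's tensor code, and the function name `scalar_timestep_embedding` is hypothetical.

```python
import math
from typing import List


def scalar_timestep_embedding(t: float, dim: int, max_period: int = 10000) -> List[float]:
    """Sinusoidal embedding of one scalar time, laid out as [cos..., sin..., (pad)]."""
    half = dim // 2
    # Geometric frequency ladder from 1 down toward 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    embedding = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2:
        # Odd target width: pad with a single zero, as the tensor version does.
        embedding.append(0.0)
    return embedding


at_noise = scalar_timestep_embedding(0.0, dim=8)   # flow_time = 0 corresponds to pure noise
odd_width = scalar_timestep_embedding(0.5, dim=5)
```

Note that the pipeline feeds this embedder a `flow_time` in `[0, 1]` rather than an integer timestep, so only the low end of each sinusoid's range is exercised.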
SiT-XL-2-512/__pycache__/pipeline.cpython-312.pyc ADDED
Binary file (7.51 kB).
 
SiT-XL-2-512/demo.png CHANGED

Git LFS Details

  • SHA256: 950aa2c2b87fbe579a7ba22fb1add8fb372d7f7f05829c4ec98b42288e58bf13
  • Pointer size: 131 Bytes
  • Size of remote file: 381 kB

Git LFS Details

  • SHA256: a271541f1bf1b6a7cd0937847ae5b3d26e49be2f289bf06aca8689a97a2ee21c
  • Pointer size: 131 Bytes
  • Size of remote file: 425 kB
SiT-XL-2-512/model_index.json CHANGED
@@ -1,19 +1,1021 @@
- {
-   "_class_name": [
-     "pipeline",
-     "SiTPipeline"
-   ],
-   "_diffusers_version": "0.36.0",
-   "scheduler": [
-     "scheduling_flow_match_sit",
-     "SiTFlowMatchScheduler"
-   ],
-   "transformer": [
-     "transformer_sit",
-     "SiTTransformer2DModel"
-   ],
-   "vae": [
-     "diffusers",
-     "AutoencoderKL"
-   ]
- }
+ {
+   "_class_name": [
+     "pipeline",
+     "SiTPipeline"
+   ],
+   "_diffusers_version": "0.36.0",
+   "scheduler": [
+     "diffusers",
+     "FlowMatchEulerDiscreteScheduler"
+   ],
+   "transformer": [
+     "transformer_sit",
+     "SiTTransformer2DModel"
+   ],
+   "vae": [
+     "diffusers",
+     "AutoencoderKL"
+   ],
+   "id2label": {
+     "0": "tench, Tinca tinca",
+     "1": "goldfish, Carassius auratus",
+     "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
+     "3": "tiger shark, Galeocerdo cuvieri",
+     "4": "hammerhead, hammerhead shark",
+     "5": "electric ray, crampfish, numbfish, torpedo",
+     "6": "stingray",
+     "7": "cock",
+     "8": "hen",
+     "9": "ostrich, Struthio camelus",
+     "10": "brambling, Fringilla montifringilla",
+     "11": "goldfinch, Carduelis carduelis",
+     "12": "house finch, linnet, Carpodacus mexicanus",
+     "13": "junco, snowbird",
+     "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
+     "15": "robin, American robin, Turdus migratorius",
+     "16": "bulbul",
+     "17": "jay",
+     "18": "magpie",
+     "19": "chickadee",
+     "20": "water ouzel, dipper",
+     "21": "kite",
+     "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
+     "23": "vulture",
+     "24": "great grey owl, great gray owl, Strix nebulosa",
+     "25": "European fire salamander, Salamandra salamandra",
+     "26": "common newt, Triturus vulgaris",
+     "27": "eft",
+     "28": "spotted salamander, Ambystoma maculatum",
+     "29": "axolotl, mud puppy, Ambystoma mexicanum",
+     "30": "bullfrog, Rana catesbeiana",
+     "31": "tree frog, tree-frog",
+     "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
+     "33": "loggerhead, loggerhead turtle, Caretta caretta",
+     "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
+     "35": "mud turtle",
+     "36": "terrapin",
+     "37": "box turtle, box tortoise",
+     "38": "banded gecko",
+     "39": "common iguana, iguana, Iguana iguana",
+     "40": "American chameleon, anole, Anolis carolinensis",
+     "41": "whiptail, whiptail lizard",
+     "42": "agama",
+     "43": "frilled lizard, Chlamydosaurus kingi",
+     "44": "alligator lizard",
+     "45": "Gila monster, Heloderma suspectum",
+     "46": "green lizard, Lacerta viridis",
+     "47": "African chameleon, Chamaeleo chamaeleon",
+     "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
+     "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
+     "50": "American alligator, Alligator mississipiensis",
+     "51": "triceratops",
+     "52": "thunder snake, worm snake, Carphophis amoenus",
+     "53": "ringneck snake, ring-necked snake, ring snake",
+     "54": "hognose snake, puff adder, sand viper",
+     "55": "green snake, grass snake",
+     "56": "king snake, kingsnake",
+     "57": "garter snake, grass snake",
+     "58": "water snake",
+     "59": "vine snake",
+     "60": "night snake, Hypsiglena torquata",
+     "61": "boa constrictor, Constrictor constrictor",
+     "62": "rock python, rock snake, Python sebae",
+     "63": "Indian cobra, Naja naja",
+     "64": "green mamba",
+     "65": "sea snake",
+     "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
+     "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
+     "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
+     "69": "trilobite",
+     "70": "harvestman, daddy longlegs, Phalangium opilio",
+     "71": "scorpion",
+     "72": "black and gold garden spider, Argiope aurantia",
+     "73": "barn spider, Araneus cavaticus",
+     "74": "garden spider, Aranea diademata",
+     "75": "black widow, Latrodectus mactans",
+     "76": "tarantula",
+     "77": "wolf spider, hunting spider",
+     "78": "tick",
+     "79": "centipede",
+     "80": "black grouse",
+     "81": "ptarmigan",
+     "82": "ruffed grouse, partridge, Bonasa umbellus",
+     "83": "prairie chicken, prairie grouse, prairie fowl",
+     "84": "peacock",
+     "85": "quail",
+     "86": "partridge",
+     "87": "African grey, African gray, Psittacus erithacus",
+     "88": "macaw",
+     "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
+     "90": "lorikeet",
+     "91": "coucal",
+     "92": "bee eater",
+     "93": "hornbill",
+     "94": "hummingbird",
+     "95": "jacamar",
+     "96": "toucan",
+     "97": "drake",
+     "98": "red-breasted merganser, Mergus serrator",
+     "99": "goose",
+     "100": "black swan, Cygnus atratus",
+     "101": "tusker",
+     "102": "echidna, spiny anteater, anteater",
+     "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
+     "104": "wallaby, brush kangaroo",
+     "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
+     "106": "wombat",
+     "107": "jellyfish",
+     "108": "sea anemone, anemone",
+     "109": "brain coral",
+     "110": "flatworm, platyhelminth",
+     "111": "nematode, nematode worm, roundworm",
+     "112": "conch",
+     "113": "snail",
+     "114": "slug",
+     "115": "sea slug, nudibranch",
+     "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
+     "117": "chambered nautilus, pearly nautilus, nautilus",
+     "118": "Dungeness crab, Cancer magister",
+     "119": "rock crab, Cancer irroratus",
+     "120": "fiddler crab",
+     "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
+     "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
+     "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
+     "124": "crayfish, crawfish, crawdad, crawdaddy",
+     "125": "hermit crab",
+     "126": "isopod",
+     "127": "white stork, Ciconia ciconia",
148
+ "128": "black stork, Ciconia nigra",
149
+ "129": "spoonbill",
150
+ "130": "flamingo",
151
+ "131": "little blue heron, Egretta caerulea",
152
+ "132": "American egret, great white heron, Egretta albus",
153
+ "133": "bittern",
154
+ "134": "crane",
155
+ "135": "limpkin, Aramus pictus",
156
+ "136": "European gallinule, Porphyrio porphyrio",
157
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
158
+ "138": "bustard",
159
+ "139": "ruddy turnstone, Arenaria interpres",
160
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
161
+ "141": "redshank, Tringa totanus",
162
+ "142": "dowitcher",
163
+ "143": "oystercatcher, oyster catcher",
164
+ "144": "pelican",
165
+ "145": "king penguin, Aptenodytes patagonica",
166
+ "146": "albatross, mollymawk",
167
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
168
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
169
+ "149": "dugong, Dugong dugon",
170
+ "150": "sea lion",
171
+ "151": "Chihuahua",
172
+ "152": "Japanese spaniel",
173
+ "153": "Maltese dog, Maltese terrier, Maltese",
174
+ "154": "Pekinese, Pekingese, Peke",
175
+ "155": "Shih-Tzu",
176
+ "156": "Blenheim spaniel",
177
+ "157": "papillon",
178
+ "158": "toy terrier",
179
+ "159": "Rhodesian ridgeback",
180
+ "160": "Afghan hound, Afghan",
181
+ "161": "basset, basset hound",
182
+ "162": "beagle",
183
+ "163": "bloodhound, sleuthhound",
184
+ "164": "bluetick",
185
+ "165": "black-and-tan coonhound",
186
+ "166": "Walker hound, Walker foxhound",
187
+ "167": "English foxhound",
188
+ "168": "redbone",
189
+ "169": "borzoi, Russian wolfhound",
190
+ "170": "Irish wolfhound",
191
+ "171": "Italian greyhound",
192
+ "172": "whippet",
193
+ "173": "Ibizan hound, Ibizan Podenco",
194
+ "174": "Norwegian elkhound, elkhound",
195
+ "175": "otterhound, otter hound",
196
+ "176": "Saluki, gazelle hound",
197
+ "177": "Scottish deerhound, deerhound",
198
+ "178": "Weimaraner",
199
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
200
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
201
+ "181": "Bedlington terrier",
202
+ "182": "Border terrier",
203
+ "183": "Kerry blue terrier",
204
+ "184": "Irish terrier",
205
+ "185": "Norfolk terrier",
206
+ "186": "Norwich terrier",
207
+ "187": "Yorkshire terrier",
208
+ "188": "wire-haired fox terrier",
209
+ "189": "Lakeland terrier",
210
+ "190": "Sealyham terrier, Sealyham",
211
+ "191": "Airedale, Airedale terrier",
212
+ "192": "cairn, cairn terrier",
213
+ "193": "Australian terrier",
214
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
215
+ "195": "Boston bull, Boston terrier",
216
+ "196": "miniature schnauzer",
217
+ "197": "giant schnauzer",
218
+ "198": "standard schnauzer",
219
+ "199": "Scotch terrier, Scottish terrier, Scottie",
220
+ "200": "Tibetan terrier, chrysanthemum dog",
221
+ "201": "silky terrier, Sydney silky",
222
+ "202": "soft-coated wheaten terrier",
223
+ "203": "West Highland white terrier",
224
+ "204": "Lhasa, Lhasa apso",
225
+ "205": "flat-coated retriever",
226
+ "206": "curly-coated retriever",
227
+ "207": "golden retriever",
228
+ "208": "Labrador retriever",
229
+ "209": "Chesapeake Bay retriever",
230
+ "210": "German short-haired pointer",
231
+ "211": "vizsla, Hungarian pointer",
232
+ "212": "English setter",
233
+ "213": "Irish setter, red setter",
234
+ "214": "Gordon setter",
235
+ "215": "Brittany spaniel",
236
+ "216": "clumber, clumber spaniel",
237
+ "217": "English springer, English springer spaniel",
238
+ "218": "Welsh springer spaniel",
239
+ "219": "cocker spaniel, English cocker spaniel, cocker",
240
+ "220": "Sussex spaniel",
241
+ "221": "Irish water spaniel",
242
+ "222": "kuvasz",
243
+ "223": "schipperke",
244
+ "224": "groenendael",
245
+ "225": "malinois",
246
+ "226": "briard",
247
+ "227": "kelpie",
248
+ "228": "komondor",
249
+ "229": "Old English sheepdog, bobtail",
250
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
251
+ "231": "collie",
252
+ "232": "Border collie",
253
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
254
+ "234": "Rottweiler",
255
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
256
+ "236": "Doberman, Doberman pinscher",
257
+ "237": "miniature pinscher",
258
+ "238": "Greater Swiss Mountain dog",
259
+ "239": "Bernese mountain dog",
260
+ "240": "Appenzeller",
261
+ "241": "EntleBucher",
262
+ "242": "boxer",
263
+ "243": "bull mastiff",
264
+ "244": "Tibetan mastiff",
265
+ "245": "French bulldog",
266
+ "246": "Great Dane",
267
+ "247": "Saint Bernard, St Bernard",
268
+ "248": "Eskimo dog, husky",
269
+ "249": "malamute, malemute, Alaskan malamute",
270
+ "250": "Siberian husky",
271
+ "251": "dalmatian, coach dog, carriage dog",
272
+ "252": "affenpinscher, monkey pinscher, monkey dog",
273
+ "253": "basenji",
274
+ "254": "pug, pug-dog",
275
+ "255": "Leonberg",
276
+ "256": "Newfoundland, Newfoundland dog",
277
+ "257": "Great Pyrenees",
278
+ "258": "Samoyed, Samoyede",
279
+ "259": "Pomeranian",
280
+ "260": "chow, chow chow",
281
+ "261": "keeshond",
282
+ "262": "Brabancon griffon",
283
+ "263": "Pembroke, Pembroke Welsh corgi",
284
+ "264": "Cardigan, Cardigan Welsh corgi",
285
+ "265": "toy poodle",
286
+ "266": "miniature poodle",
287
+ "267": "standard poodle",
288
+ "268": "Mexican hairless",
289
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
290
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
291
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
292
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
293
+ "273": "dingo, warrigal, warragal, Canis dingo",
294
+ "274": "dhole, Cuon alpinus",
295
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
296
+ "276": "hyena, hyaena",
297
+ "277": "red fox, Vulpes vulpes",
298
+ "278": "kit fox, Vulpes macrotis",
299
+ "279": "Arctic fox, white fox, Alopex lagopus",
300
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
301
+ "281": "tabby, tabby cat",
302
+ "282": "tiger cat",
303
+ "283": "Persian cat",
304
+ "284": "Siamese cat, Siamese",
305
+ "285": "Egyptian cat",
306
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
307
+ "287": "lynx, catamount",
308
+ "288": "leopard, Panthera pardus",
309
+ "289": "snow leopard, ounce, Panthera uncia",
310
+ "290": "jaguar, panther, Panthera onca, Felis onca",
311
+ "291": "lion, king of beasts, Panthera leo",
312
+ "292": "tiger, Panthera tigris",
313
+ "293": "cheetah, chetah, Acinonyx jubatus",
314
+ "294": "brown bear, bruin, Ursus arctos",
315
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
316
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
317
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
318
+ "298": "mongoose",
319
+ "299": "meerkat, mierkat",
320
+ "300": "tiger beetle",
321
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
322
+ "302": "ground beetle, carabid beetle",
323
+ "303": "long-horned beetle, longicorn, longicorn beetle",
324
+ "304": "leaf beetle, chrysomelid",
325
+ "305": "dung beetle",
326
+ "306": "rhinoceros beetle",
327
+ "307": "weevil",
328
+ "308": "fly",
329
+ "309": "bee",
330
+ "310": "ant, emmet, pismire",
331
+ "311": "grasshopper, hopper",
332
+ "312": "cricket",
333
+ "313": "walking stick, walkingstick, stick insect",
334
+ "314": "cockroach, roach",
335
+ "315": "mantis, mantid",
336
+ "316": "cicada, cicala",
337
+ "317": "leafhopper",
338
+ "318": "lacewing, lacewing fly",
339
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
340
+ "320": "damselfly",
341
+ "321": "admiral",
342
+ "322": "ringlet, ringlet butterfly",
343
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
344
+ "324": "cabbage butterfly",
345
+ "325": "sulphur butterfly, sulfur butterfly",
346
+ "326": "lycaenid, lycaenid butterfly",
347
+ "327": "starfish, sea star",
348
+ "328": "sea urchin",
349
+ "329": "sea cucumber, holothurian",
350
+ "330": "wood rabbit, cottontail, cottontail rabbit",
351
+ "331": "hare",
352
+ "332": "Angora, Angora rabbit",
353
+ "333": "hamster",
354
+ "334": "porcupine, hedgehog",
355
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
356
+ "336": "marmot",
357
+ "337": "beaver",
358
+ "338": "guinea pig, Cavia cobaya",
359
+ "339": "sorrel",
360
+ "340": "zebra",
361
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
362
+ "342": "wild boar, boar, Sus scrofa",
363
+ "343": "warthog",
364
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
365
+ "345": "ox",
366
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
367
+ "347": "bison",
368
+ "348": "ram, tup",
369
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
370
+ "350": "ibex, Capra ibex",
371
+ "351": "hartebeest",
372
+ "352": "impala, Aepyceros melampus",
373
+ "353": "gazelle",
374
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
375
+ "355": "llama",
376
+ "356": "weasel",
377
+ "357": "mink",
378
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
379
+ "359": "black-footed ferret, ferret, Mustela nigripes",
380
+ "360": "otter",
381
+ "361": "skunk, polecat, wood pussy",
382
+ "362": "badger",
383
+ "363": "armadillo",
384
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
385
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
386
+ "366": "gorilla, Gorilla gorilla",
387
+ "367": "chimpanzee, chimp, Pan troglodytes",
388
+ "368": "gibbon, Hylobates lar",
389
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
390
+ "370": "guenon, guenon monkey",
391
+ "371": "patas, hussar monkey, Erythrocebus patas",
392
+ "372": "baboon",
393
+ "373": "macaque",
394
+ "374": "langur",
395
+ "375": "colobus, colobus monkey",
396
+ "376": "proboscis monkey, Nasalis larvatus",
397
+ "377": "marmoset",
398
+ "378": "capuchin, ringtail, Cebus capucinus",
399
+ "379": "howler monkey, howler",
400
+ "380": "titi, titi monkey",
401
+ "381": "spider monkey, Ateles geoffroyi",
402
+ "382": "squirrel monkey, Saimiri sciureus",
403
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
404
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
405
+ "385": "Indian elephant, Elephas maximus",
406
+ "386": "African elephant, Loxodonta africana",
407
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
408
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
409
+ "389": "barracouta, snoek",
410
+ "390": "eel",
411
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
412
+ "392": "rock beauty, Holocanthus tricolor",
413
+ "393": "anemone fish",
414
+ "394": "sturgeon",
415
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
416
+ "396": "lionfish",
417
+ "397": "puffer, pufferfish, blowfish, globefish",
418
+ "398": "abacus",
419
+ "399": "abaya",
420
+ "400": "academic gown, academic robe, judge robe",
421
+ "401": "accordion, piano accordion, squeeze box",
422
+ "402": "acoustic guitar",
423
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
424
+ "404": "airliner",
425
+ "405": "airship, dirigible",
426
+ "406": "altar",
427
+ "407": "ambulance",
428
+ "408": "amphibian, amphibious vehicle",
429
+ "409": "analog clock",
430
+ "410": "apiary, bee house",
431
+ "411": "apron",
432
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
433
+ "413": "assault rifle, assault gun",
434
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
435
+ "415": "bakery, bakeshop, bakehouse",
436
+ "416": "balance beam, beam",
437
+ "417": "balloon",
438
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
439
+ "419": "Band Aid",
440
+ "420": "banjo",
441
+ "421": "bannister, banister, balustrade, balusters, handrail",
442
+ "422": "barbell",
443
+ "423": "barber chair",
444
+ "424": "barbershop",
445
+ "425": "barn",
446
+ "426": "barometer",
447
+ "427": "barrel, cask",
448
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
449
+ "429": "baseball",
450
+ "430": "basketball",
451
+ "431": "bassinet",
452
+ "432": "bassoon",
453
+ "433": "bathing cap, swimming cap",
454
+ "434": "bath towel",
455
+ "435": "bathtub, bathing tub, bath, tub",
456
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
457
+ "437": "beacon, lighthouse, beacon light, pharos",
458
+ "438": "beaker",
459
+ "439": "bearskin, busby, shako",
460
+ "440": "beer bottle",
461
+ "441": "beer glass",
462
+ "442": "bell cote, bell cot",
463
+ "443": "bib",
464
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
465
+ "445": "bikini, two-piece",
466
+ "446": "binder, ring-binder",
467
+ "447": "binoculars, field glasses, opera glasses",
468
+ "448": "birdhouse",
469
+ "449": "boathouse",
470
+ "450": "bobsled, bobsleigh, bob",
471
+ "451": "bolo tie, bolo, bola tie, bola",
472
+ "452": "bonnet, poke bonnet",
473
+ "453": "bookcase",
474
+ "454": "bookshop, bookstore, bookstall",
475
+ "455": "bottlecap",
476
+ "456": "bow",
477
+ "457": "bow tie, bow-tie, bowtie",
478
+ "458": "brass, memorial tablet, plaque",
479
+ "459": "brassiere, bra, bandeau",
480
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
481
+ "461": "breastplate, aegis, egis",
482
+ "462": "broom",
483
+ "463": "bucket, pail",
484
+ "464": "buckle",
485
+ "465": "bulletproof vest",
486
+ "466": "bullet train, bullet",
487
+ "467": "butcher shop, meat market",
488
+ "468": "cab, hack, taxi, taxicab",
489
+ "469": "caldron, cauldron",
490
+ "470": "candle, taper, wax light",
491
+ "471": "cannon",
492
+ "472": "canoe",
493
+ "473": "can opener, tin opener",
494
+ "474": "cardigan",
495
+ "475": "car mirror",
496
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
497
+ "477": "carpenters kit, tool kit",
498
+ "478": "carton",
499
+ "479": "car wheel",
500
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
501
+ "481": "cassette",
502
+ "482": "cassette player",
503
+ "483": "castle",
504
+ "484": "catamaran",
505
+ "485": "CD player",
506
+ "486": "cello, violoncello",
507
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
508
+ "488": "chain",
509
+ "489": "chainlink fence",
510
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
511
+ "491": "chain saw, chainsaw",
512
+ "492": "chest",
513
+ "493": "chiffonier, commode",
514
+ "494": "chime, bell, gong",
515
+ "495": "china cabinet, china closet",
516
+ "496": "Christmas stocking",
517
+ "497": "church, church building",
518
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
519
+ "499": "cleaver, meat cleaver, chopper",
520
+ "500": "cliff dwelling",
521
+ "501": "cloak",
522
+ "502": "clog, geta, patten, sabot",
523
+ "503": "cocktail shaker",
524
+ "504": "coffee mug",
525
+ "505": "coffeepot",
526
+ "506": "coil, spiral, volute, whorl, helix",
527
+ "507": "combination lock",
528
+ "508": "computer keyboard, keypad",
529
+ "509": "confectionery, confectionary, candy store",
530
+ "510": "container ship, containership, container vessel",
531
+ "511": "convertible",
532
+ "512": "corkscrew, bottle screw",
533
+ "513": "cornet, horn, trumpet, trump",
534
+ "514": "cowboy boot",
535
+ "515": "cowboy hat, ten-gallon hat",
536
+ "516": "cradle",
537
+ "517": "crane",
538
+ "518": "crash helmet",
539
+ "519": "crate",
540
+ "520": "crib, cot",
541
+ "521": "Crock Pot",
542
+ "522": "croquet ball",
543
+ "523": "crutch",
544
+ "524": "cuirass",
545
+ "525": "dam, dike, dyke",
546
+ "526": "desk",
547
+ "527": "desktop computer",
548
+ "528": "dial telephone, dial phone",
549
+ "529": "diaper, nappy, napkin",
550
+ "530": "digital clock",
551
+ "531": "digital watch",
552
+ "532": "dining table, board",
553
+ "533": "dishrag, dishcloth",
554
+ "534": "dishwasher, dish washer, dishwashing machine",
555
+ "535": "disk brake, disc brake",
556
+ "536": "dock, dockage, docking facility",
557
+ "537": "dogsled, dog sled, dog sleigh",
558
+ "538": "dome",
559
+ "539": "doormat, welcome mat",
560
+ "540": "drilling platform, offshore rig",
561
+ "541": "drum, membranophone, tympan",
562
+ "542": "drumstick",
563
+ "543": "dumbbell",
564
+ "544": "Dutch oven",
565
+ "545": "electric fan, blower",
566
+ "546": "electric guitar",
567
+ "547": "electric locomotive",
568
+ "548": "entertainment center",
569
+ "549": "envelope",
570
+ "550": "espresso maker",
571
+ "551": "face powder",
572
+ "552": "feather boa, boa",
573
+ "553": "file, file cabinet, filing cabinet",
574
+ "554": "fireboat",
575
+ "555": "fire engine, fire truck",
576
+ "556": "fire screen, fireguard",
577
+ "557": "flagpole, flagstaff",
578
+ "558": "flute, transverse flute",
579
+ "559": "folding chair",
580
+ "560": "football helmet",
581
+ "561": "forklift",
582
+ "562": "fountain",
583
+ "563": "fountain pen",
584
+ "564": "four-poster",
585
+ "565": "freight car",
586
+ "566": "French horn, horn",
587
+ "567": "frying pan, frypan, skillet",
588
+ "568": "fur coat",
589
+ "569": "garbage truck, dustcart",
590
+ "570": "gasmask, respirator, gas helmet",
591
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
592
+ "572": "goblet",
593
+ "573": "go-kart",
594
+ "574": "golf ball",
595
+ "575": "golfcart, golf cart",
596
+ "576": "gondola",
597
+ "577": "gong, tam-tam",
598
+ "578": "gown",
599
+ "579": "grand piano, grand",
600
+ "580": "greenhouse, nursery, glasshouse",
601
+ "581": "grille, radiator grille",
602
+ "582": "grocery store, grocery, food market, market",
603
+ "583": "guillotine",
604
+ "584": "hair slide",
605
+ "585": "hair spray",
606
+ "586": "half track",
607
+ "587": "hammer",
608
+ "588": "hamper",
609
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
610
+ "590": "hand-held computer, hand-held microcomputer",
611
+ "591": "handkerchief, hankie, hanky, hankey",
612
+ "592": "hard disc, hard disk, fixed disk",
613
+ "593": "harmonica, mouth organ, harp, mouth harp",
614
+ "594": "harp",
615
+ "595": "harvester, reaper",
616
+ "596": "hatchet",
617
+ "597": "holster",
618
+ "598": "home theater, home theatre",
619
+ "599": "honeycomb",
620
+ "600": "hook, claw",
621
+ "601": "hoopskirt, crinoline",
622
+ "602": "horizontal bar, high bar",
623
+ "603": "horse cart, horse-cart",
624
+ "604": "hourglass",
625
+ "605": "iPod",
626
+ "606": "iron, smoothing iron",
627
+ "607": "jack-o-lantern",
628
+ "608": "jean, blue jean, denim",
629
+ "609": "jeep, landrover",
630
+ "610": "jersey, T-shirt, tee shirt",
631
+ "611": "jigsaw puzzle",
632
+ "612": "jinrikisha, ricksha, rickshaw",
633
+ "613": "joystick",
634
+ "614": "kimono",
635
+ "615": "knee pad",
636
+ "616": "knot",
637
+ "617": "lab coat, laboratory coat",
638
+ "618": "ladle",
639
+ "619": "lampshade, lamp shade",
640
+ "620": "laptop, laptop computer",
641
+ "621": "lawn mower, mower",
642
+ "622": "lens cap, lens cover",
643
+ "623": "letter opener, paper knife, paperknife",
644
+ "624": "library",
645
+ "625": "lifeboat",
646
+ "626": "lighter, light, igniter, ignitor",
647
+ "627": "limousine, limo",
648
+ "628": "liner, ocean liner",
649
+ "629": "lipstick, lip rouge",
650
+ "630": "Loafer",
651
+ "631": "lotion",
652
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
653
+ "633": "loupe, jewelers loupe",
654
+ "634": "lumbermill, sawmill",
655
+ "635": "magnetic compass",
656
+ "636": "mailbag, postbag",
657
+ "637": "mailbox, letter box",
658
+ "638": "maillot",
659
+ "639": "maillot, tank suit",
660
+ "640": "manhole cover",
661
+ "641": "maraca",
662
+ "642": "marimba, xylophone",
663
+ "643": "mask",
664
+ "644": "matchstick",
665
+ "645": "maypole",
666
+ "646": "maze, labyrinth",
667
+ "647": "measuring cup",
668
+ "648": "medicine chest, medicine cabinet",
669
+ "649": "megalith, megalithic structure",
670
+ "650": "microphone, mike",
671
+ "651": "microwave, microwave oven",
672
+ "652": "military uniform",
673
+ "653": "milk can",
674
+ "654": "minibus",
675
+ "655": "miniskirt, mini",
676
+ "656": "minivan",
677
+ "657": "missile",
678
+ "658": "mitten",
679
+ "659": "mixing bowl",
680
+ "660": "mobile home, manufactured home",
681
+ "661": "Model T",
682
+ "662": "modem",
683
+ "663": "monastery",
684
+ "664": "monitor",
685
+ "665": "moped",
686
+ "666": "mortar",
687
+ "667": "mortarboard",
688
+ "668": "mosque",
689
+ "669": "mosquito net",
690
+ "670": "motor scooter, scooter",
691
+ "671": "mountain bike, all-terrain bike, off-roader",
692
+ "672": "mountain tent",
693
+ "673": "mouse, computer mouse",
694
+ "674": "mousetrap",
695
+ "675": "moving van",
696
+ "676": "muzzle",
697
+ "677": "nail",
698
+ "678": "neck brace",
699
+ "679": "necklace",
700
+ "680": "nipple",
701
+ "681": "notebook, notebook computer",
702
+ "682": "obelisk",
703
+ "683": "oboe, hautboy, hautbois",
704
+ "684": "ocarina, sweet potato",
705
+ "685": "odometer, hodometer, mileometer, milometer",
706
+ "686": "oil filter",
707
+ "687": "organ, pipe organ",
708
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
709
+ "689": "overskirt",
710
+ "690": "oxcart",
711
+ "691": "oxygen mask",
712
+ "692": "packet",
713
+ "693": "paddle, boat paddle",
714
+ "694": "paddlewheel, paddle wheel",
715
+ "695": "padlock",
716
+ "696": "paintbrush",
717
+ "697": "pajama, pyjama, pjs, jammies",
718
+ "698": "palace",
719
+ "699": "panpipe, pandean pipe, syrinx",
720
+ "700": "paper towel",
721
+ "701": "parachute, chute",
722
+ "702": "parallel bars, bars",
723
+ "703": "park bench",
724
+ "704": "parking meter",
725
+ "705": "passenger car, coach, carriage",
726
+ "706": "patio, terrace",
727
+ "707": "pay-phone, pay-station",
728
+ "708": "pedestal, plinth, footstall",
729
+ "709": "pencil box, pencil case",
730
+ "710": "pencil sharpener",
731
+ "711": "perfume, essence",
732
+ "712": "Petri dish",
733
+ "713": "photocopier",
734
+ "714": "pick, plectrum, plectron",
735
+ "715": "pickelhaube",
736
+ "716": "picket fence, paling",
737
+ "717": "pickup, pickup truck",
738
+ "718": "pier",
739
+ "719": "piggy bank, penny bank",
740
+ "720": "pill bottle",
741
+ "721": "pillow",
742
+ "722": "ping-pong ball",
743
+ "723": "pinwheel",
744
+ "724": "pirate, pirate ship",
745
+ "725": "pitcher, ewer",
746
+ "726": "plane, carpenters plane, woodworking plane",
747
+ "727": "planetarium",
748
+ "728": "plastic bag",
749
+ "729": "plate rack",
750
+ "730": "plow, plough",
751
+ "731": "plunger, plumbers helper",
752
+ "732": "Polaroid camera, Polaroid Land camera",
753
+ "733": "pole",
754
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
755
+ "735": "poncho",
756
+ "736": "pool table, billiard table, snooker table",
757
+ "737": "pop bottle, soda bottle",
758
+ "738": "pot, flowerpot",
759
+ "739": "potters wheel",
760
+ "740": "power drill",
761
+ "741": "prayer rug, prayer mat",
762
+ "742": "printer",
763
+ "743": "prison, prison house",
764
+ "744": "projectile, missile",
765
+ "745": "projector",
766
+ "746": "puck, hockey puck",
767
+ "747": "punching bag, punch bag, punching ball, punchball",
768
+ "748": "purse",
769
+ "749": "quill, quill pen",
770
+ "750": "quilt, comforter, comfort, puff",
771
+ "751": "racer, race car, racing car",
772
+ "752": "racket, racquet",
773
+ "753": "radiator",
774
+ "754": "radio, wireless",
775
+ "755": "radio telescope, radio reflector",
776
+ "756": "rain barrel",
777
+ "757": "recreational vehicle, RV, R.V.",
778
+ "758": "reel",
779
+ "759": "reflex camera",
780
+ "760": "refrigerator, icebox",
781
+ "761": "remote control, remote",
782
+ "762": "restaurant, eating house, eating place, eatery",
783
+ "763": "revolver, six-gun, six-shooter",
784
+ "764": "rifle",
785
+ "765": "rocking chair, rocker",
786
+ "766": "rotisserie",
787
+ "767": "rubber eraser, rubber, pencil eraser",
788
+ "768": "rugby ball",
789
+ "769": "rule, ruler",
790
+ "770": "running shoe",
791
+ "771": "safe",
792
+ "772": "safety pin",
793
+ "773": "saltshaker, salt shaker",
794
+ "774": "sandal",
795
+ "775": "sarong",
796
+ "776": "sax, saxophone",
797
+ "777": "scabbard",
798
+ "778": "scale, weighing machine",
799
+ "779": "school bus",
800
+ "780": "schooner",
801
+ "781": "scoreboard",
802
+ "782": "screen, CRT screen",
803
+ "783": "screw",
804
+ "784": "screwdriver",
805
+ "785": "seat belt, seatbelt",
806
+ "786": "sewing machine",
807
+ "787": "shield, buckler",
808
+ "788": "shoe shop, shoe-shop, shoe store",
809
+ "789": "shoji",
810
+ "790": "shopping basket",
811
+ "791": "shopping cart",
812
+ "792": "shovel",
813
+ "793": "shower cap",
814
+ "794": "shower curtain",
815
+ "795": "ski",
816
+ "796": "ski mask",
817
+ "797": "sleeping bag",
818
+ "798": "slide rule, slipstick",
819
+ "799": "sliding door",
820
+ "800": "slot, one-armed bandit",
821
+ "801": "snorkel",
822
+ "802": "snowmobile",
823
+ "803": "snowplow, snowplough",
824
+ "804": "soap dispenser",
825
+ "805": "soccer ball",
826
+ "806": "sock",
827
+ "807": "solar dish, solar collector, solar furnace",
828
+ "808": "sombrero",
829
+ "809": "soup bowl",
830
+ "810": "space bar",
831
+ "811": "space heater",
832
+ "812": "space shuttle",
833
+ "813": "spatula",
834
+ "814": "speedboat",
835
+ "815": "spider web, spiders web",
836
+ "816": "spindle",
837
+ "817": "sports car, sport car",
838
+ "818": "spotlight, spot",
839
+ "819": "stage",
840
+ "820": "steam locomotive",
841
+ "821": "steel arch bridge",
842
+ "822": "steel drum",
843
+ "823": "stethoscope",
844
+ "824": "stole",
845
+ "825": "stone wall",
846
+ "826": "stopwatch, stop watch",
847
+ "827": "stove",
848
+ "828": "strainer",
849
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
850
+ "830": "stretcher",
851
+ "831": "studio couch, day bed",
852
+ "832": "stupa, tope",
853
+ "833": "submarine, pigboat, sub, U-boat",
854
+ "834": "suit, suit of clothes",
855
+ "835": "sundial",
856
+ "836": "sunglass",
857
+ "837": "sunglasses, dark glasses, shades",
858
+ "838": "sunscreen, sunblock, sun blocker",
859
+ "839": "suspension bridge",
860
+ "840": "swab, swob, mop",
861
+ "841": "sweatshirt",
862
+ "842": "swimming trunks, bathing trunks",
863
+ "843": "swing",
864
+ "844": "switch, electric switch, electrical switch",
865
+ "845": "syringe",
866
+ "846": "table lamp",
867
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
868
+ "848": "tape player",
869
+ "849": "teapot",
870
+ "850": "teddy, teddy bear",
871
+ "851": "television, television system",
872
+ "852": "tennis ball",
873
+ "853": "thatch, thatched roof",
874
+ "854": "theater curtain, theatre curtain",
875
+ "855": "thimble",
876
+ "856": "thresher, thrasher, threshing machine",
877
+ "857": "throne",
878
+ "858": "tile roof",
879
+ "859": "toaster",
880
+ "860": "tobacco shop, tobacconist shop, tobacconist",
881
+ "861": "toilet seat",
882
+ "862": "torch",
883
+ "863": "totem pole",
884
+ "864": "tow truck, tow car, wrecker",
885
+ "865": "toyshop",
886
+ "866": "tractor",
887
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
888
+ "868": "tray",
889
+ "869": "trench coat",
890
+ "870": "tricycle, trike, velocipede",
891
+ "871": "trimaran",
892
+ "872": "tripod",
893
+ "873": "triumphal arch",
894
+ "874": "trolleybus, trolley coach, trackless trolley",
895
+ "875": "trombone",
896
+ "876": "tub, vat",
897
+ "877": "turnstile",
898
+ "878": "typewriter keyboard",
899
+ "879": "umbrella",
900
+ "880": "unicycle, monocycle",
901
+ "881": "upright, upright piano",
902
+ "882": "vacuum, vacuum cleaner",
903
+ "883": "vase",
904
+ "884": "vault",
905
+ "885": "velvet",
906
+ "886": "vending machine",
907
+ "887": "vestment",
908
+ "888": "viaduct",
909
+ "889": "violin, fiddle",
910
+ "890": "volleyball",
911
+ "891": "waffle iron",
912
+ "892": "wall clock",
913
+ "893": "wallet, billfold, notecase, pocketbook",
914
+ "894": "wardrobe, closet, press",
915
+ "895": "warplane, military plane",
916
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
917
+ "897": "washer, automatic washer, washing machine",
918
+ "898": "water bottle",
919
+ "899": "water jug",
920
+ "900": "water tower",
921
+ "901": "whiskey jug",
922
+ "902": "whistle",
923
+ "903": "wig",
924
+ "904": "window screen",
925
+ "905": "window shade",
926
+ "906": "Windsor tie",
927
+ "907": "wine bottle",
928
+ "908": "wing",
929
+ "909": "wok",
930
+ "910": "wooden spoon",
931
+ "911": "wool, woolen, woollen",
932
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
933
+ "913": "wreck",
934
+ "914": "yawl",
935
+ "915": "yurt",
936
+ "916": "web site, website, internet site, site",
937
+ "917": "comic book",
938
+ "918": "crossword puzzle, crossword",
939
+ "919": "street sign",
940
+ "920": "traffic light, traffic signal, stoplight",
941
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
942
+ "922": "menu",
943
+ "923": "plate",
944
+ "924": "guacamole",
945
+ "925": "consomme",
946
+ "926": "hot pot, hotpot",
947
+ "927": "trifle",
948
+ "928": "ice cream, icecream",
949
+ "929": "ice lolly, lolly, lollipop, popsicle",
950
+ "930": "French loaf",
951
+ "931": "bagel, beigel",
952
+ "932": "pretzel",
953
+ "933": "cheeseburger",
954
+ "934": "hotdog, hot dog, red hot",
955
+ "935": "mashed potato",
956
+ "936": "head cabbage",
957
+ "937": "broccoli",
958
+ "938": "cauliflower",
959
+ "939": "zucchini, courgette",
960
+ "940": "spaghetti squash",
961
+ "941": "acorn squash",
962
+ "942": "butternut squash",
963
+ "943": "cucumber, cuke",
964
+ "944": "artichoke, globe artichoke",
965
+ "945": "bell pepper",
966
+ "946": "cardoon",
967
+ "947": "mushroom",
968
+ "948": "Granny Smith",
969
+ "949": "strawberry",
970
+ "950": "orange",
971
+ "951": "lemon",
972
+ "952": "fig",
973
+ "953": "pineapple, ananas",
974
+ "954": "banana",
975
+ "955": "jackfruit, jak, jack",
976
+ "956": "custard apple",
977
+ "957": "pomegranate",
978
+ "958": "hay",
979
+ "959": "carbonara",
980
+ "960": "chocolate sauce, chocolate syrup",
981
+ "961": "dough",
982
+ "962": "meat loaf, meatloaf",
983
+ "963": "pizza, pizza pie",
984
+ "964": "potpie",
985
+ "965": "burrito",
986
+ "966": "red wine",
987
+ "967": "espresso",
988
+ "968": "cup",
989
+ "969": "eggnog",
990
+ "970": "alp",
991
+ "971": "bubble",
992
+ "972": "cliff, drop, drop-off",
993
+ "973": "coral reef",
994
+ "974": "geyser",
995
+ "975": "lakeside, lakeshore",
996
+ "976": "promontory, headland, head, foreland",
997
+ "977": "sandbar, sand bar",
998
+ "978": "seashore, coast, seacoast, sea-coast",
999
+ "979": "valley, vale",
1000
+ "980": "volcano",
1001
+ "981": "ballplayer, baseball player",
1002
+ "982": "groom, bridegroom",
1003
+ "983": "scuba diver",
1004
+ "984": "rapeseed",
1005
+ "985": "daisy",
1006
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1007
+ "987": "corn",
1008
+ "988": "acorn",
1009
+ "989": "hip, rose hip, rosehip",
1010
+ "990": "buckeye, horse chestnut, conker",
1011
+ "991": "coral fungus",
1012
+ "992": "agaric",
1013
+ "993": "gyromitra",
1014
+ "994": "stinkhorn, carrion fungus",
1015
+ "995": "earthstar",
1016
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1017
+ "997": "bolete",
1018
+ "998": "ear, spike, capitulum",
1019
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1020
+ }
1021
+ }
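The `id2label` values above pack several synonyms into one comma-separated string, and the pipeline in this commit turns them into a reverse lookup. A minimal sketch of that expansion, assuming exact, case-sensitive matching; `build_label2id` and `sample` are illustrative names, not part of the repository:

```python
from typing import Dict


def build_label2id(id2label: Dict[str, str]) -> Dict[str, int]:
    """Expand comma-separated synonyms into a synonym -> class id lookup."""
    label2id: Dict[str, int] = {}
    for class_id, names in id2label.items():
        for synonym in names.split(","):
            synonym = synonym.strip()
            if synonym:
                # If two classes share a synonym, the later id wins in this sketch.
                label2id[synonym] = int(class_id)
    return label2id


# Three entries copied from the id2label block above.
sample = {
    "889": "violin, fiddle",
    "934": "hotdog, hot dog, red hot",
    "980": "volcano",
}
label2id = build_label2id(sample)
print(label2id["fiddle"])   # 889
print(label2id["hot dog"])  # 934
```

Because matching is exact, a label such as `"Granny Smith"` or `"French loaf"` would have to be passed with the same capitalization as in the mapping.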
SiT-XL-2-512/pipeline.py CHANGED
@@ -1,82 +1,349 @@
1
- from typing import List, Optional, Union
2
-
3
- import torch
4
-
5
- from diffusers.image_processor import VaeImageProcessor
6
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
7
- from diffusers.utils.torch_utils import randn_tensor
8
-
9
-
10
- class SiTPipeline(DiffusionPipeline):
11
- model_cpu_offload_seq = "transformer->vae"
12
-
13
- def __init__(self, transformer, scheduler, vae):
14
- super().__init__()
15
- self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
16
- self.vae_scale_factor = 8
17
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
18
-
19
- @torch.no_grad()
20
- def __call__(
21
- self,
22
- class_labels: Union[int, List[int]] = 207,
23
- height: int = 256,
24
- width: int = 256,
25
- num_inference_steps: int = 250,
26
- guidance_scale: float = 4.0,
27
- generator: Optional[torch.Generator] = None,
28
- output_type: str = "pil",
29
- return_dict: bool = True,
30
- ):
31
- device = self._execution_device
32
- if isinstance(class_labels, int):
33
- class_labels = [class_labels]
34
- batch_size = len(class_labels)
35
-
36
- latent_h = height // self.vae_scale_factor
37
- latent_w = width // self.vae_scale_factor
38
- latents = randn_tensor(
39
- (batch_size, self.transformer.config.in_channels, latent_h, latent_w),
40
- generator=generator,
41
- device=device,
42
- dtype=self.transformer.dtype,
43
- )
44
-
45
- labels = torch.tensor(class_labels, device=device, dtype=torch.long)
46
- do_cfg = guidance_scale is not None and guidance_scale > 1.0
47
- if do_cfg:
48
- null_label = torch.full((batch_size,), self.transformer.config.num_classes, device=device, dtype=torch.long)
49
- labels = torch.cat([labels, null_label], dim=0)
50
-
51
- self.scheduler.set_timesteps(num_inference_steps, device=device)
52
- timesteps = self.scheduler.timesteps
53
-
54
- for t in self.progress_bar(timesteps):
55
- t_batch = torch.full((batch_size,), t, device=device, dtype=latents.dtype)
56
- model_input = latents
57
- if do_cfg:
58
- model_input = torch.cat([latents, latents], dim=0)
59
- t_batch = torch.cat([t_batch, t_batch], dim=0)
60
-
61
- model_pred = self.transformer(
62
- hidden_states=model_input,
63
- timestep=t_batch,
64
- class_labels=labels,
65
- ).sample
66
-
67
- if do_cfg:
68
- cond, uncond = model_pred.chunk(2, dim=0)
69
- model_pred = uncond + guidance_scale * (cond - uncond)
70
-
71
- latents = self.scheduler.step(model_pred, t, latents, generator=generator).prev_sample
72
-
73
- image = self.vae.decode(latents / 0.18215).sample
74
- # Keep PyTorch outputs in raw VAE range [-1, 1] to match original SiT scripts.
75
- if output_type == "pt":
76
- image = image
77
- else:
78
- image = self.image_processor.postprocess(image, output_type=output_type)
79
-
80
- if not return_dict:
81
- return (image,)
82
- return ImagePipelineOutput(images=image)
1
+ """Hub custom pipeline: SiTPipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ import json
22
+ from pathlib import Path
23
+ from typing import Dict, List, Optional, Tuple, Union
24
+
25
+ import torch
26
+
27
+ from diffusers.image_processor import VaeImageProcessor
28
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
29
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+ from diffusers.utils.torch_utils import randn_tensor
30
+
31
+ EXAMPLE_DOC_STRING = """
32
+ Examples:
33
+ ```py
34
+ >>> from pathlib import Path
35
+ >>> from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
36
+ >>> import torch
37
+
38
+ >>> model_dir = Path("./SiT-XL-2-256").resolve()
39
+ >>> pipe = DiffusionPipeline.from_pretrained(
40
+ ... str(model_dir),
41
+ ... local_files_only=True,
42
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
43
+ ... trust_remote_code=True,
44
+ ... torch_dtype=torch.bfloat16,
45
+ ... )
46
+ >>> pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
47
+ >>> pipe.to("cuda")
48
+
49
+ >>> print(pipe.id2label[207])
50
+ >>> print(pipe.get_label_ids("golden retriever"))
51
+
52
+ >>> generator = torch.Generator(device="cuda").manual_seed(42)
53
+ >>> image = pipe(
54
+ ... class_labels="golden retriever",
55
+ ... height=256,
56
+ ... width=256,
57
+ ... num_inference_steps=250,
58
+ ... guidance_scale=4.0,
59
+ ... generator=generator,
60
+ ... ).images[0]
61
+ ```
62
+ """
63
+
64
+ class SiTPipeline(DiffusionPipeline):
65
+ r"""
66
+ Pipeline for class-conditional image generation with Scalable Interpolant Transformers (SiT).
67
+
68
+ Parameters:
69
+ transformer ([`SiTTransformer2DModel`]):
70
+ Class-conditional SiT transformer that predicts flow-matching velocity in latent space.
71
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
72
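+ Flow-matching Euler scheduler. The pipeline requires deterministic stepping (`stochastic_sampling=False`) and negates the model output before each step, so any replacement scheduler must follow the same flow-matching convention.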
73
+ vae ([`AutoencoderKL`]):
74
+ Variational autoencoder used to decode transformer latents to pixels.
75
+ id2label (`dict[int, str]`, *optional*):
76
+ ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
77
+ """
78
+
79
+ model_cpu_offload_seq = "transformer->vae"
80
+
81
+ def __init__(
82
+ self,
83
+ transformer,
84
+ scheduler,
85
+ vae,
86
+ id2label: Optional[Dict[Union[int, str], str]] = None,
87
+ ):
88
+ super().__init__()
89
+ if scheduler is None:
90
+ scheduler = FlowMatchEulerDiscreteScheduler(
91
+ num_train_timesteps=1000,
92
+ shift=1.0,
93
+ stochastic_sampling=False,
94
+ )
95
+ self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
96
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
97
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
98
+ self._id2label = self._normalize_id2label(id2label)
99
+ self.labels = self._build_label2id(self._id2label)
100
+ self._labels_loaded_from_model_index = bool(self._id2label)
101
+
102
+ def _ensure_labels_loaded(self) -> None:
103
+ if self._labels_loaded_from_model_index:
104
+ return
105
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
106
+ if loaded:
107
+ self._id2label = loaded
108
+ self.labels = self._build_label2id(self._id2label)
109
+ self._labels_loaded_from_model_index = True
110
+
111
+ @staticmethod
112
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
113
+ if not id2label:
114
+ return {}
115
+ return {int(key): value for key, value in id2label.items()}
116
+
117
+ @staticmethod
118
+ def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
119
+ if not variant_path:
120
+ return {}
121
+ variant_dir = Path(variant_path).resolve()
122
+ model_index_path = variant_dir / "model_index.json"
123
+ if not model_index_path.exists():
124
+ return {}
125
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
126
+ id2label = raw.get("id2label")
127
+ if not isinstance(id2label, dict):
128
+ return {}
129
+ return {int(key): value for key, value in id2label.items()}
130
+
131
+ @staticmethod
132
+ def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
133
+ label2id: Dict[str, int] = {}
134
+ for class_id, value in id2label.items():
135
+ for synonym in value.split(","):
136
+ synonym = synonym.strip()
137
+ if synonym:
138
+ label2id[synonym] = int(class_id)
139
+ return dict(sorted(label2id.items()))
140
+
141
+ @property
142
+ def id2label(self) -> Dict[int, str]:
143
+ r"""ImageNet class id to English label string (comma-separated synonyms)."""
144
+ self._ensure_labels_loaded()
145
+ return self._id2label
146
+
147
+ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
148
+ r"""
149
+ Map ImageNet label strings to class ids.
150
+
151
+ Args:
152
+ label (`str` or `list[str]`):
153
+ One or more English label strings. Each string must match a synonym in `id2label`.
154
+ """
155
+ self._ensure_labels_loaded()
156
+ label2id = self.labels
157
+ if not label2id:
158
+ raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
159
+
160
+ if isinstance(label, str):
161
+ label = [label]
162
+
163
+ missing = [item for item in label if item not in label2id]
164
+ if missing:
165
+ preview = ", ".join(list(label2id.keys())[:8])
166
+ raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
167
+ return [label2id[item] for item in label]
168
+
169
+ def _normalize_class_labels(
170
+ self,
171
+ class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
172
+ ) -> torch.LongTensor:
173
+ if torch.is_tensor(class_labels):
174
+ return class_labels.to(device=self._execution_device, dtype=torch.long).reshape(-1)
175
+
176
+ if isinstance(class_labels, int):
177
+ class_label_ids = [class_labels]
178
+ elif isinstance(class_labels, str):
179
+ class_label_ids = self.get_label_ids(class_labels)
180
+ elif class_labels and isinstance(class_labels[0], str):
181
+ class_label_ids = self.get_label_ids(class_labels)
182
+ else:
183
+ class_label_ids = list(class_labels)
184
+
185
+ return torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
186
+
187
+ def _default_image_size(self) -> int:
188
+ return int(self.transformer.config.input_size) * self.vae_scale_factor
189
+
190
+ def check_inputs(
191
+ self,
192
+ height: int,
193
+ width: int,
194
+ num_inference_steps: int,
195
+ output_type: str,
196
+ ) -> None:
197
+ if num_inference_steps < 1:
198
+ raise ValueError("num_inference_steps must be >= 1.")
199
+ if output_type not in {"pil", "np", "pt", "latent"}:
200
+ raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
201
+
202
+ if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
203
+ raise ValueError(
204
+ f"height and width must be divisible by the VAE downsample factor {self.vae_scale_factor}."
205
+ )
206
+
207
+ latent_height = height // self.vae_scale_factor
208
+ latent_width = width // self.vae_scale_factor
209
+ expected_size = int(self.transformer.config.input_size)
210
+ if latent_height != expected_size or latent_width != expected_size:
211
+ raise ValueError(
212
+ f"Requested latent size {(latent_height, latent_width)} does not match the pretrained "
213
+ f"transformer input_size={expected_size}. Use height=width={self._default_image_size()}."
214
+ )
215
+
216
+ def prepare_latents(
217
+ self,
218
+ batch_size: int,
219
+ height: int,
220
+ width: int,
221
+ dtype: torch.dtype,
222
+ device: torch.device,
223
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
224
+ ) -> torch.Tensor:
225
+ latent_height = height // self.vae_scale_factor
226
+ latent_width = width // self.vae_scale_factor
227
+ return randn_tensor(
228
+ (batch_size, self.transformer.config.in_channels, latent_height, latent_width),
229
+ generator=generator,
230
+ device=device,
231
+ dtype=dtype,
232
+ )
233
+
234
+ @staticmethod
235
+ def _apply_classifier_free_guidance(model_output: torch.Tensor, guidance_scale: float) -> torch.Tensor:
236
+ if guidance_scale <= 1.0:
237
+ return model_output
238
+ model_output_cond, model_output_uncond = model_output.chunk(2)
239
+ return model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
240
+
241
+ def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
242
+ if output_type == "latent":
243
+ return latents
244
+
245
+ scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
246
+ image = self.vae.decode(latents / scaling_factor).sample
247
+ if output_type == "pt":
248
+ return image
249
+ return self.image_processor.postprocess(image, output_type=output_type)
250
+
251
+ @torch.inference_mode()
252
+ def __call__(
253
+ self,
254
+ class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
255
+ height: Optional[int] = None,
256
+ width: Optional[int] = None,
257
+ num_inference_steps: int = 250,
258
+ guidance_scale: float = 4.0,
259
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
260
+ output_type: str = "pil",
261
+ return_dict: bool = True,
262
+ ) -> Union[ImagePipelineOutput, Tuple]:
263
+ r"""
264
+ Generate class-conditional images with SiT.
265
+
266
+ Args:
267
+ class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
268
+ ImageNet class indices or human-readable English label strings.
269
+ height (`int`, *optional*):
270
+ Output image height in pixels. Defaults to the pretrained native resolution.
271
+ width (`int`, *optional*):
272
+ Output image width in pixels. Defaults to the pretrained native resolution.
273
+ num_inference_steps (`int`, defaults to `250`):
274
+ Number of denoising steps.
275
+ guidance_scale (`float`, defaults to `4.0`):
276
+ Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
277
+ generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
278
+ RNG for reproducibility.
279
+ output_type (`str`, defaults to `"pil"`):
280
+ `"pil"`, `"np"`, `"pt"`, or `"latent"`.
281
+ return_dict (`bool`, defaults to `True`):
282
+ Return [`ImagePipelineOutput`] if True.
283
+ """
284
+ default_size = self._default_image_size()
285
+ height = int(height or default_size)
286
+ width = int(width or default_size)
287
+ self.check_inputs(height, width, num_inference_steps, output_type)
288
+
289
+ device = self._execution_device
290
+ model_dtype = next(self.transformer.parameters()).dtype
291
+ class_labels_tensor = self._normalize_class_labels(class_labels)
292
+ batch_size = class_labels_tensor.numel()
293
+ do_cfg = guidance_scale > 1.0
294
+
295
+ latents = self.prepare_latents(
296
+ batch_size=batch_size,
297
+ height=height,
298
+ width=width,
299
+ dtype=model_dtype,
300
+ device=device,
301
+ generator=generator,
302
+ )
303
+
304
+ labels = class_labels_tensor
305
+ if do_cfg:
306
+ null_labels = torch.full_like(class_labels_tensor, self.transformer.config.num_classes)
307
+ labels = torch.cat([class_labels_tensor, null_labels], dim=0)
308
+
309
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
310
+ num_train_timesteps = self.scheduler.config.num_train_timesteps
311
+
312
+ if getattr(self.scheduler.config, "stochastic_sampling", False):
313
+ raise ValueError(
314
+ "SiT expects deterministic FlowMatchEulerDiscreteScheduler stepping "
315
+ "(scheduler.config.stochastic_sampling=False)."
316
+ )
317
+
318
+ for t in self.progress_bar(self.scheduler.timesteps):
319
+ flow_time = 1.0 - float(t) / num_train_timesteps
320
+ if do_cfg:
321
+ model_input = torch.cat([latents, latents], dim=0)
322
+ else:
323
+ model_input = latents
324
+
325
+ timestep_batch = torch.full((model_input.shape[0],), flow_time, device=device, dtype=model_dtype)
326
+ model_output = self.transformer(
327
+ hidden_states=model_input,
328
+ timestep=timestep_batch,
329
+ class_labels=labels,
330
+ return_dict=True,
331
+ ).sample
332
+ model_output = self._apply_classifier_free_guidance(model_output, guidance_scale=guidance_scale)
333
+ # SiT predicts dx/d(flow_time) with flow_time increasing from noise (0) to data (1).
334
+ # FlowMatchEulerDiscreteScheduler integrates over sigma decreasing from 1 to 0, so flip sign.
335
+ model_output = -model_output
336
+ latents = self.scheduler.step(
337
+ model_output=model_output,
338
+ timestep=t,
339
+ sample=latents,
340
+ generator=generator,
341
+ return_dict=True,
342
+ ).prev_sample
343
+
344
+ image = self.decode_latents(latents, output_type=output_type)
345
+
346
+ self.maybe_free_model_hooks()
347
+ if not return_dict:
348
+ return (image,)
349
+ return ImagePipelineOutput(images=image)
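The sign flip in the sampling loop above can be checked on a toy problem. The sketch below is not the pipeline's code: it assumes a scalar linear path `x(t) = t * data + (1 - t) * noise` and assumes the scheduler's deterministic update has the form `x + (sigma_next - sigma) * model_output`. The function `euler_sample` and its arguments are hypothetical.

```python
def euler_sample(noise: float, data: float, num_steps: int = 10) -> float:
    """Integrate a toy linear flow from pure noise to data.

    Assumed toy path: x(t) = t * data + (1 - t) * noise, so the SiT-style
    velocity dx/dt is the constant (data - noise).
    """
    sit_velocity = data - noise
    # Sigma runs from 1 (noise) down to 0 (data), i.e. sigma = 1 - t.
    sigmas = [1.0 - i / num_steps for i in range(num_steps + 1)]
    x = noise
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        # Flip the sign because dx/dsigma = -dx/dt when sigma = 1 - t.
        scheduler_input = -sit_velocity
        x = x + (sigma_next - sigma) * scheduler_input
    return x


result = euler_sample(noise=-2.0, data=3.0)
print(result)  # close to 3.0, the data value
```

Without the negation, the same loop would move away from `data` instead of toward it, which is why the pipeline negates the transformer output before calling `scheduler.step`.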
SiT-XL-2-512/scheduler/scheduler_config.json CHANGED
@@ -1,9 +1,7 @@
1
- {
2
- "_class_name": "SiTFlowMatchScheduler",
3
- "_diffusers_version": "0.36.0",
4
- "diffusion_form": "sigma",
5
- "diffusion_norm": 1.0,
6
- "mode": "ode",
7
- "num_train_timesteps": 1000,
8
- "shift": 1.0
9
- }
1
+ {
2
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
+ "shift": 1.0,
6
+ "stochastic_sampling": false
7
+ }
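With `num_train_timesteps` set to 1000 and `shift` set to 1.0, the pipeline maps each scheduler timestep to a SiT flow time via `1.0 - t / num_train_timesteps`. A small sketch of that mapping follows. The `apply_shift` helper uses the time-shift formula commonly used by flow-matching schedulers, which is an assumption here and not taken from this commit, and the sample timesteps are hypothetical.

```python
NUM_TRAIN_TIMESTEPS = 1000  # value from scheduler_config.json
SHIFT = 1.0                 # value from scheduler_config.json


def apply_shift(sigma: float, shift: float = SHIFT) -> float:
    """Assumed time-shift formula; shift=1.0 leaves sigma unchanged."""
    return shift * sigma / (1.0 + (shift - 1.0) * sigma)


def to_flow_time(timestep: float, num_train_timesteps: int = NUM_TRAIN_TIMESTEPS) -> float:
    """Mirror the pipeline's mapping: 0.0 at pure noise, approaching 1.0 at data."""
    return 1.0 - float(timestep) / num_train_timesteps


# Hypothetical timesteps, chosen only to illustrate the mapping.
for t in (1000.0, 750.0, 500.0, 250.0):
    sigma = apply_shift(t / NUM_TRAIN_TIMESTEPS)
    print(t, sigma, to_flow_time(t))
```

Under these assumptions the flow time is simply `1 - sigma`, so the transformer sees values that increase as the scheduler's sigma decreases.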
 
 
SiT-XL-2-512/transformer/transformer_sit.py CHANGED
@@ -1,224 +1,240 @@
1
- import math
2
- from dataclasses import dataclass
3
- from typing import Optional
4
-
5
- import numpy as np
6
- import torch
7
- import torch.nn as nn
8
- from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
9
-
10
- from diffusers.configuration_utils import ConfigMixin, register_to_config
11
- from diffusers.models.modeling_utils import ModelMixin
12
- from diffusers.utils import BaseOutput
13
-
14
-
15
- def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
16
- return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
17
-
18
-
19
- @dataclass
20
- class SiTTransformer2DModelOutput(BaseOutput):
21
- sample: torch.Tensor
22
-
23
-
24
- class TimestepEmbedder(nn.Module):
25
- def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
26
- super().__init__()
27
- self.mlp = nn.Sequential(
28
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
29
- nn.SiLU(),
30
- nn.Linear(hidden_size, hidden_size, bias=True),
31
- )
32
- self.frequency_embedding_size = frequency_embedding_size
33
-
34
- @staticmethod
35
- def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
36
- half = dim // 2
37
- freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
38
- device=t.device
39
- )
40
- args = t[:, None].float() * freqs[None]
41
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
42
- if dim % 2:
43
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
44
- return embedding
45
-
46
- def forward(self, t: torch.Tensor) -> torch.Tensor:
47
- return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
48
-
49
-
50
- class LabelEmbedder(nn.Module):
51
- def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
52
- super().__init__()
53
- use_cfg_embedding = dropout_prob > 0
54
- self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
55
- self.num_classes = num_classes
56
- self.dropout_prob = dropout_prob
57
-
58
- def token_drop(self, labels: torch.Tensor, force_drop_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
59
- if force_drop_ids is None:
60
- drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
61
- else:
62
- drop_ids = force_drop_ids == 1
63
- labels = torch.where(drop_ids, self.num_classes, labels)
64
- return labels
65
-
66
- def forward(
67
- self,
68
- labels: torch.Tensor,
69
- train: bool,
70
- force_drop_ids: Optional[torch.Tensor] = None,
71
- ) -> torch.Tensor:
72
- use_dropout = self.dropout_prob > 0
73
- if (train and use_dropout) or (force_drop_ids is not None):
74
- labels = self.token_drop(labels, force_drop_ids)
75
- return self.embedding_table(labels)
76
-
77
-
78
- class SiTBlock(nn.Module):
79
- def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, **block_kwargs):
80
- super().__init__()
81
- self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
82
- self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
83
- self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
84
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
85
- approx_gelu = lambda: nn.GELU(approximate="tanh")
86
- self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
87
- self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
88
-
89
- def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
90
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
91
- x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
92
- x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
93
- return x
94
-
95
-
96
- class FinalLayer(nn.Module):
97
- def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
98
- super().__init__()
99
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
100
- self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
101
- self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
102
-
103
- def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
104
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
105
- x = modulate(self.norm_final(x), shift, scale)
106
- return self.linear(x)
107
-
108
-
109
- class SiTTransformer2DModel(ModelMixin, ConfigMixin):
110
- @register_to_config
111
- def __init__(
112
- self,
113
- input_size: int = 32,
114
- patch_size: int = 2,
115
- in_channels: int = 4,
116
- hidden_size: int = 1152,
117
- depth: int = 28,
118
- num_heads: int = 16,
119
- mlp_ratio: float = 4.0,
120
- class_dropout_prob: float = 0.1,
121
- num_classes: int = 1000,
122
- learn_sigma: bool = True,
123
- ):
124
- super().__init__()
125
- self.learn_sigma = learn_sigma
126
- self.in_channels = in_channels
127
- self.out_channels = in_channels * 2 if learn_sigma else in_channels
128
- self.patch_size = patch_size
129
- self.num_classes = num_classes
130
-
131
- self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
132
- self.t_embedder = TimestepEmbedder(hidden_size)
133
- self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
134
- num_patches = self.x_embedder.num_patches
135
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
136
-
137
- self.blocks = nn.ModuleList([SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
138
- self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
139
- self.initialize_weights()
140
-
141
- def initialize_weights(self) -> None:
142
- def _basic_init(module: nn.Module):
143
- if isinstance(module, nn.Linear):
144
- torch.nn.init.xavier_uniform_(module.weight)
145
- if module.bias is not None:
146
- nn.init.constant_(module.bias, 0)
147
-
148
- self.apply(_basic_init)
149
- pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
150
- self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
151
-
152
- w = self.x_embedder.proj.weight.data
153
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
154
- nn.init.constant_(self.x_embedder.proj.bias, 0)
155
- nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
156
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
157
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
158
- for block in self.blocks:
159
- nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
160
- nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
161
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
162
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
163
- nn.init.constant_(self.final_layer.linear.weight, 0)
164
- nn.init.constant_(self.final_layer.linear.bias, 0)
165
-
166
- def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
167
- c = self.out_channels
168
- p = self.x_embedder.patch_size[0]
169
- h = w = int(x.shape[1] ** 0.5)
170
- x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
171
- x = torch.einsum("nhwpqc->nchpwq", x)
172
- return x.reshape(shape=(x.shape[0], c, h * p, h * p))
173
-
174
- def forward(
175
- self,
176
- hidden_states: torch.Tensor,
177
- timestep: torch.Tensor,
178
- class_labels: torch.Tensor,
179
- force_drop_ids: Optional[torch.Tensor] = None,
180
- return_dict: bool = True,
181
- ) -> SiTTransformer2DModelOutput:
182
- x = self.x_embedder(hidden_states) + self.pos_embed
183
- t = self.t_embedder(timestep)
184
- y = self.y_embedder(class_labels, self.training, force_drop_ids=force_drop_ids)
185
- c = t + y
186
- for block in self.blocks:
187
- x = block(x, c)
188
- x = self.final_layer(x, c)
189
- x = self.unpatchify(x)
190
- if self.learn_sigma:
191
- x, _ = x.chunk(2, dim=1)
192
- if not return_dict:
193
- return (x,)
194
- return SiTTransformer2DModelOutput(sample=x)
195
-
196
-
197
- def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False, extra_tokens: int = 0):
198
- grid_h = np.arange(grid_size, dtype=np.float32)
199
- grid_w = np.arange(grid_size, dtype=np.float32)
200
- grid = np.meshgrid(grid_w, grid_h)
201
- grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])
202
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
203
- if cls_token and extra_tokens > 0:
204
- pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
205
- return pos_embed
206
-
207
-
208
- def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray):
209
- assert embed_dim % 2 == 0
210
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
211
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
212
- return np.concatenate([emb_h, emb_w], axis=1)
213
-
214
-
215
- def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray):
216
- assert embed_dim % 2 == 0
217
- omega = np.arange(embed_dim // 2, dtype=np.float64)
218
- omega /= embed_dim / 2.0
219
- omega = 1.0 / 10000**omega
220
- pos = pos.reshape(-1)
221
- out = np.einsum("m,d->md", pos, omega)
222
- emb_sin = np.sin(out)
223
- emb_cos = np.cos(out)
224
- return np.concatenate([emb_sin, emb_cos], axis=1)
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import Optional
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
23
+
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+ from diffusers.utils import BaseOutput
27
+
28
+
29
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
30
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
31
+
32
+
33
+ @dataclass
34
+ class SiTTransformer2DModelOutput(BaseOutput):
35
+ sample: torch.Tensor
36
+
37
+
38
+ class TimestepEmbedder(nn.Module):
39
+ def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
40
+ super().__init__()
41
+ self.mlp = nn.Sequential(
42
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
43
+ nn.SiLU(),
44
+ nn.Linear(hidden_size, hidden_size, bias=True),
45
+ )
46
+ self.frequency_embedding_size = frequency_embedding_size
47
+
48
+ @staticmethod
49
+ def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
50
+ half = dim // 2
51
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
52
+ device=t.device
53
+ )
54
+ args = t[:, None].float() * freqs[None]
55
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
56
+ if dim % 2:
57
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
58
+ return embedding
59
+
60
+ def forward(self, t: torch.Tensor) -> torch.Tensor:
61
+ emb = self.timestep_embedding(t.float(), self.frequency_embedding_size)
62
+ weight_dtype = self.mlp[0].weight.dtype
63
+ return self.mlp(emb.to(dtype=weight_dtype))
64
+
65
+
+ class LabelEmbedder(nn.Module):
+     def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
+         super().__init__()
+         use_cfg_embedding = dropout_prob > 0
+         self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+         self.num_classes = num_classes
+         self.dropout_prob = dropout_prob
+
+     def token_drop(self, labels: torch.Tensor, force_drop_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
+         if force_drop_ids is None:
+             drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+         else:
+             drop_ids = force_drop_ids == 1
+         labels = torch.where(drop_ids, self.num_classes, labels)
+         return labels
+
+     def forward(
+         self,
+         labels: torch.Tensor,
+         train: bool,
+         force_drop_ids: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         use_dropout = self.dropout_prob > 0
+         if (train and use_dropout) or (force_drop_ids is not None):
+             labels = self.token_drop(labels, force_drop_ids)
+         return self.embedding_table(labels)
+
+
+ class SiTBlock(nn.Module):
+     def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, **block_kwargs):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
+         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         mlp_hidden_dim = int(hidden_size * mlp_ratio)
+         approx_gelu = lambda: nn.GELU(approximate="tanh")
+         self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
+         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
+
+     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
+         x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
+         x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+         return x
+
+
+ class FinalLayer(nn.Module):
+     def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+         super().__init__()
+         self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+         shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+         x = modulate(self.norm_final(x), shift, scale)
+         return self.linear(x)
+
+
+ class SiTTransformer2DModel(ModelMixin, ConfigMixin):
+     @register_to_config
+     def __init__(
+         self,
+         input_size: int = 32,
+         patch_size: int = 2,
+         in_channels: int = 4,
+         hidden_size: int = 1152,
+         depth: int = 28,
+         num_heads: int = 16,
+         mlp_ratio: float = 4.0,
+         class_dropout_prob: float = 0.1,
+         num_classes: int = 1000,
+         learn_sigma: bool = True,
+     ):
+         super().__init__()
+         self.learn_sigma = learn_sigma
+         self.in_channels = in_channels
+         self.out_channels = in_channels * 2 if learn_sigma else in_channels
+         self.patch_size = patch_size
+         self.num_classes = num_classes
+
+         self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
+         self.t_embedder = TimestepEmbedder(hidden_size)
+         self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
+         num_patches = self.x_embedder.num_patches
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
+
+         self.blocks = nn.ModuleList([SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
+         self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
+         self.initialize_weights()
+
+     def initialize_weights(self) -> None:
+         def _basic_init(module: nn.Module):
+             if isinstance(module, nn.Linear):
+                 torch.nn.init.xavier_uniform_(module.weight)
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, 0)
+
+         self.apply(_basic_init)
+         pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
+         self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+         w = self.x_embedder.proj.weight.data
+         nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+         nn.init.constant_(self.x_embedder.proj.bias, 0)
+         nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+         nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+         nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+         for block in self.blocks:
+             nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+             nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+         nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+         nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+         nn.init.constant_(self.final_layer.linear.weight, 0)
+         nn.init.constant_(self.final_layer.linear.bias, 0)
+
+     def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
+         c = self.out_channels
+         p = self.x_embedder.patch_size[0]
+         h = w = int(x.shape[1] ** 0.5)
+         x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+         x = torch.einsum("nhwpqc->nchpwq", x)
+         return x.reshape(shape=(x.shape[0], c, h * p, h * p))
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         timestep: torch.Tensor,
+         class_labels: torch.Tensor,
+         force_drop_ids: Optional[torch.Tensor] = None,
+         return_dict: bool = True,
+     ) -> SiTTransformer2DModelOutput:
+         x = self.x_embedder(hidden_states) + self.pos_embed
+         t = self.t_embedder(timestep)
+         y = self.y_embedder(class_labels, self.training, force_drop_ids=force_drop_ids)
+         c = t + y
+         for block in self.blocks:
+             x = block(x, c)
+         x = self.final_layer(x, c)
+         x = self.unpatchify(x)
+         if self.learn_sigma:
+             x, _ = x.chunk(2, dim=1)
+         if not return_dict:
+             return (x,)
+         return SiTTransformer2DModelOutput(sample=x)
+
+
+ def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False, extra_tokens: int = 0):
+     grid_h = np.arange(grid_size, dtype=np.float32)
+     grid_w = np.arange(grid_size, dtype=np.float32)
+     grid = np.meshgrid(grid_w, grid_h)
+     grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token and extra_tokens > 0:
+         pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+     return pos_embed
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray):
+     assert embed_dim % 2 == 0
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
+     return np.concatenate([emb_h, emb_w], axis=1)
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray):
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=np.float64)
+     omega /= embed_dim / 2.0
+     omega = 1.0 / 10000**omega
+     pos = pos.reshape(-1)
+     out = np.einsum("m,d->md", pos, omega)
+     emb_sin = np.sin(out)
+     emb_cos = np.cos(out)
+     return np.concatenate([emb_sin, emb_cos], axis=1)
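
The layout produced by `get_1d_sincos_pos_embed_from_grid` can be hard to picture from the NumPy one-liners alone. The following is a minimal, standard-library-only sketch of the same idea, not part of the commit: the helper name `sincos_1d` is hypothetical, and it assumes the same frequency schedule as the function above (`1 / 10000 ** (i / half)`), with sines in the first half of each row and cosines in the second.

```python
import math
from typing import List


def sincos_1d(embed_dim: int, positions: List[float]) -> List[List[float]]:
    """Illustrative pure-Python 1D sine-cosine positional embedding.

    Returns one row per position: the first half holds sines, the second half cosines.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")
    half = embed_dim // 2
    # Frequencies start at 1 and decay geometrically towards 1/10000.
    omega = [1.0 / 10000 ** (i / half) for i in range(half)]
    rows = []
    for pos in positions:
        angles = [pos * w for w in omega]
        rows.append([math.sin(a) for a in angles] + [math.cos(a) for a in angles])
    return rows


emb = sincos_1d(8, [0.0, 1.0, 2.0])
print(len(emb), len(emb[0]))  # 3 8
```

At position 0 every sine entry is 0 and every cosine entry is 1, which is a quick sanity check for the ordering of the two halves. The 2D variant in the diff concatenates two such embeddings, each of size `embed_dim // 2`, one per grid axis.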