BiliSakura commited on
Commit
c5fe00d
Β·
verified Β·
1 Parent(s): ab0e5b8

Upload folder using huggingface_hub

Browse files
PixNerd-XL-16-256/README.md CHANGED
@@ -26,16 +26,16 @@ from diffusers import DiffusionPipeline
26
  pipe = DiffusionPipeline.from_pretrained(
27
  "BiliSakura/PixNerd-diffusers/PixNerd-XL-16-256",
28
  trust_remote_code=True,
29
- torch_dtype=torch.float32,
30
  ).to("cuda")
31
 
 
 
32
  images = pipe(
33
- prompt=207,
34
  height=256,
35
  width=256,
36
  num_inference_steps=25,
37
  guidance_scale=4.0,
38
- timeshift=3.0,
39
- order=2,
40
  ).images
41
  ```
 
26
  pipe = DiffusionPipeline.from_pretrained(
27
  "BiliSakura/PixNerd-diffusers/PixNerd-XL-16-256",
28
  trust_remote_code=True,
29
+ torch_dtype=torch.bfloat16,
30
  ).to("cuda")
31
 
32
+ # timeshift=3.0 and order=2 are defaults in scheduler/scheduler_config.json
33
+
34
  images = pipe(
35
+ class_labels="golden retriever",
36
  height=256,
37
  width=256,
38
  num_inference_steps=25,
39
  guidance_scale=4.0,
 
 
40
  ).images
41
  ```
PixNerd-XL-16-256/model_index.json CHANGED
@@ -1,15 +1,1017 @@
1
- {
2
- "_class_name": [
3
- "pipeline",
4
- "PixNerdPipeline"
5
- ],
6
- "_diffusers_version": "0.36.0",
7
- "scheduler": [
8
- "scheduling_pixnerd_flow_match",
9
- "PixNerdFlowMatchScheduler"
10
- ],
11
- "transformer": [
12
- "modeling_pixnerd_transformer_2d",
13
- "PixNerdTransformer2DModel"
14
- ]
15
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "PixNerdPipeline"
5
+ ],
6
+ "_diffusers_version": "0.36.0",
7
+ "scheduler": [
8
+ "scheduling_pixnerd_flow_match",
9
+ "PixNerdFlowMatchScheduler"
10
+ ],
11
+ "transformer": [
12
+ "modeling_pixnerd_transformer_2d",
13
+ "PixNerdTransformer2DModel"
14
+ ],
15
+ "id2label": {
16
+ "0": "tench, Tinca tinca",
17
+ "1": "goldfish, Carassius auratus",
18
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
19
+ "3": "tiger shark, Galeocerdo cuvieri",
20
+ "4": "hammerhead, hammerhead shark",
21
+ "5": "electric ray, crampfish, numbfish, torpedo",
22
+ "6": "stingray",
23
+ "7": "cock",
24
+ "8": "hen",
25
+ "9": "ostrich, Struthio camelus",
26
+ "10": "brambling, Fringilla montifringilla",
27
+ "11": "goldfinch, Carduelis carduelis",
28
+ "12": "house finch, linnet, Carpodacus mexicanus",
29
+ "13": "junco, snowbird",
30
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
31
+ "15": "robin, American robin, Turdus migratorius",
32
+ "16": "bulbul",
33
+ "17": "jay",
34
+ "18": "magpie",
35
+ "19": "chickadee",
36
+ "20": "water ouzel, dipper",
37
+ "21": "kite",
38
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
39
+ "23": "vulture",
40
+ "24": "great grey owl, great gray owl, Strix nebulosa",
41
+ "25": "European fire salamander, Salamandra salamandra",
42
+ "26": "common newt, Triturus vulgaris",
43
+ "27": "eft",
44
+ "28": "spotted salamander, Ambystoma maculatum",
45
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
46
+ "30": "bullfrog, Rana catesbeiana",
47
+ "31": "tree frog, tree-frog",
48
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
49
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
50
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
51
+ "35": "mud turtle",
52
+ "36": "terrapin",
53
+ "37": "box turtle, box tortoise",
54
+ "38": "banded gecko",
55
+ "39": "common iguana, iguana, Iguana iguana",
56
+ "40": "American chameleon, anole, Anolis carolinensis",
57
+ "41": "whiptail, whiptail lizard",
58
+ "42": "agama",
59
+ "43": "frilled lizard, Chlamydosaurus kingi",
60
+ "44": "alligator lizard",
61
+ "45": "Gila monster, Heloderma suspectum",
62
+ "46": "green lizard, Lacerta viridis",
63
+ "47": "African chameleon, Chamaeleo chamaeleon",
64
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
65
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
66
+ "50": "American alligator, Alligator mississipiensis",
67
+ "51": "triceratops",
68
+ "52": "thunder snake, worm snake, Carphophis amoenus",
69
+ "53": "ringneck snake, ring-necked snake, ring snake",
70
+ "54": "hognose snake, puff adder, sand viper",
71
+ "55": "green snake, grass snake",
72
+ "56": "king snake, kingsnake",
73
+ "57": "garter snake, grass snake",
74
+ "58": "water snake",
75
+ "59": "vine snake",
76
+ "60": "night snake, Hypsiglena torquata",
77
+ "61": "boa constrictor, Constrictor constrictor",
78
+ "62": "rock python, rock snake, Python sebae",
79
+ "63": "Indian cobra, Naja naja",
80
+ "64": "green mamba",
81
+ "65": "sea snake",
82
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
83
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
84
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
85
+ "69": "trilobite",
86
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
87
+ "71": "scorpion",
88
+ "72": "black and gold garden spider, Argiope aurantia",
89
+ "73": "barn spider, Araneus cavaticus",
90
+ "74": "garden spider, Aranea diademata",
91
+ "75": "black widow, Latrodectus mactans",
92
+ "76": "tarantula",
93
+ "77": "wolf spider, hunting spider",
94
+ "78": "tick",
95
+ "79": "centipede",
96
+ "80": "black grouse",
97
+ "81": "ptarmigan",
98
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
99
+ "83": "prairie chicken, prairie grouse, prairie fowl",
100
+ "84": "peacock",
101
+ "85": "quail",
102
+ "86": "partridge",
103
+ "87": "African grey, African gray, Psittacus erithacus",
104
+ "88": "macaw",
105
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
106
+ "90": "lorikeet",
107
+ "91": "coucal",
108
+ "92": "bee eater",
109
+ "93": "hornbill",
110
+ "94": "hummingbird",
111
+ "95": "jacamar",
112
+ "96": "toucan",
113
+ "97": "drake",
114
+ "98": "red-breasted merganser, Mergus serrator",
115
+ "99": "goose",
116
+ "100": "black swan, Cygnus atratus",
117
+ "101": "tusker",
118
+ "102": "echidna, spiny anteater, anteater",
119
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
120
+ "104": "wallaby, brush kangaroo",
121
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
122
+ "106": "wombat",
123
+ "107": "jellyfish",
124
+ "108": "sea anemone, anemone",
125
+ "109": "brain coral",
126
+ "110": "flatworm, platyhelminth",
127
+ "111": "nematode, nematode worm, roundworm",
128
+ "112": "conch",
129
+ "113": "snail",
130
+ "114": "slug",
131
+ "115": "sea slug, nudibranch",
132
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
133
+ "117": "chambered nautilus, pearly nautilus, nautilus",
134
+ "118": "Dungeness crab, Cancer magister",
135
+ "119": "rock crab, Cancer irroratus",
136
+ "120": "fiddler crab",
137
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
138
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
139
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
140
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
141
+ "125": "hermit crab",
142
+ "126": "isopod",
143
+ "127": "white stork, Ciconia ciconia",
144
+ "128": "black stork, Ciconia nigra",
145
+ "129": "spoonbill",
146
+ "130": "flamingo",
147
+ "131": "little blue heron, Egretta caerulea",
148
+ "132": "American egret, great white heron, Egretta albus",
149
+ "133": "bittern",
150
+ "134": "crane",
151
+ "135": "limpkin, Aramus pictus",
152
+ "136": "European gallinule, Porphyrio porphyrio",
153
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
154
+ "138": "bustard",
155
+ "139": "ruddy turnstone, Arenaria interpres",
156
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
157
+ "141": "redshank, Tringa totanus",
158
+ "142": "dowitcher",
159
+ "143": "oystercatcher, oyster catcher",
160
+ "144": "pelican",
161
+ "145": "king penguin, Aptenodytes patagonica",
162
+ "146": "albatross, mollymawk",
163
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
164
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
165
+ "149": "dugong, Dugong dugon",
166
+ "150": "sea lion",
167
+ "151": "Chihuahua",
168
+ "152": "Japanese spaniel",
169
+ "153": "Maltese dog, Maltese terrier, Maltese",
170
+ "154": "Pekinese, Pekingese, Peke",
171
+ "155": "Shih-Tzu",
172
+ "156": "Blenheim spaniel",
173
+ "157": "papillon",
174
+ "158": "toy terrier",
175
+ "159": "Rhodesian ridgeback",
176
+ "160": "Afghan hound, Afghan",
177
+ "161": "basset, basset hound",
178
+ "162": "beagle",
179
+ "163": "bloodhound, sleuthhound",
180
+ "164": "bluetick",
181
+ "165": "black-and-tan coonhound",
182
+ "166": "Walker hound, Walker foxhound",
183
+ "167": "English foxhound",
184
+ "168": "redbone",
185
+ "169": "borzoi, Russian wolfhound",
186
+ "170": "Irish wolfhound",
187
+ "171": "Italian greyhound",
188
+ "172": "whippet",
189
+ "173": "Ibizan hound, Ibizan Podenco",
190
+ "174": "Norwegian elkhound, elkhound",
191
+ "175": "otterhound, otter hound",
192
+ "176": "Saluki, gazelle hound",
193
+ "177": "Scottish deerhound, deerhound",
194
+ "178": "Weimaraner",
195
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
196
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
197
+ "181": "Bedlington terrier",
198
+ "182": "Border terrier",
199
+ "183": "Kerry blue terrier",
200
+ "184": "Irish terrier",
201
+ "185": "Norfolk terrier",
202
+ "186": "Norwich terrier",
203
+ "187": "Yorkshire terrier",
204
+ "188": "wire-haired fox terrier",
205
+ "189": "Lakeland terrier",
206
+ "190": "Sealyham terrier, Sealyham",
207
+ "191": "Airedale, Airedale terrier",
208
+ "192": "cairn, cairn terrier",
209
+ "193": "Australian terrier",
210
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
211
+ "195": "Boston bull, Boston terrier",
212
+ "196": "miniature schnauzer",
213
+ "197": "giant schnauzer",
214
+ "198": "standard schnauzer",
215
+ "199": "Scotch terrier, Scottish terrier, Scottie",
216
+ "200": "Tibetan terrier, chrysanthemum dog",
217
+ "201": "silky terrier, Sydney silky",
218
+ "202": "soft-coated wheaten terrier",
219
+ "203": "West Highland white terrier",
220
+ "204": "Lhasa, Lhasa apso",
221
+ "205": "flat-coated retriever",
222
+ "206": "curly-coated retriever",
223
+ "207": "golden retriever",
224
+ "208": "Labrador retriever",
225
+ "209": "Chesapeake Bay retriever",
226
+ "210": "German short-haired pointer",
227
+ "211": "vizsla, Hungarian pointer",
228
+ "212": "English setter",
229
+ "213": "Irish setter, red setter",
230
+ "214": "Gordon setter",
231
+ "215": "Brittany spaniel",
232
+ "216": "clumber, clumber spaniel",
233
+ "217": "English springer, English springer spaniel",
234
+ "218": "Welsh springer spaniel",
235
+ "219": "cocker spaniel, English cocker spaniel, cocker",
236
+ "220": "Sussex spaniel",
237
+ "221": "Irish water spaniel",
238
+ "222": "kuvasz",
239
+ "223": "schipperke",
240
+ "224": "groenendael",
241
+ "225": "malinois",
242
+ "226": "briard",
243
+ "227": "kelpie",
244
+ "228": "komondor",
245
+ "229": "Old English sheepdog, bobtail",
246
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
247
+ "231": "collie",
248
+ "232": "Border collie",
249
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
250
+ "234": "Rottweiler",
251
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
252
+ "236": "Doberman, Doberman pinscher",
253
+ "237": "miniature pinscher",
254
+ "238": "Greater Swiss Mountain dog",
255
+ "239": "Bernese mountain dog",
256
+ "240": "Appenzeller",
257
+ "241": "EntleBucher",
258
+ "242": "boxer",
259
+ "243": "bull mastiff",
260
+ "244": "Tibetan mastiff",
261
+ "245": "French bulldog",
262
+ "246": "Great Dane",
263
+ "247": "Saint Bernard, St Bernard",
264
+ "248": "Eskimo dog, husky",
265
+ "249": "malamute, malemute, Alaskan malamute",
266
+ "250": "Siberian husky",
267
+ "251": "dalmatian, coach dog, carriage dog",
268
+ "252": "affenpinscher, monkey pinscher, monkey dog",
269
+ "253": "basenji",
270
+ "254": "pug, pug-dog",
271
+ "255": "Leonberg",
272
+ "256": "Newfoundland, Newfoundland dog",
273
+ "257": "Great Pyrenees",
274
+ "258": "Samoyed, Samoyede",
275
+ "259": "Pomeranian",
276
+ "260": "chow, chow chow",
277
+ "261": "keeshond",
278
+ "262": "Brabancon griffon",
279
+ "263": "Pembroke, Pembroke Welsh corgi",
280
+ "264": "Cardigan, Cardigan Welsh corgi",
281
+ "265": "toy poodle",
282
+ "266": "miniature poodle",
283
+ "267": "standard poodle",
284
+ "268": "Mexican hairless",
285
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
286
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
287
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
288
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
289
+ "273": "dingo, warrigal, warragal, Canis dingo",
290
+ "274": "dhole, Cuon alpinus",
291
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
292
+ "276": "hyena, hyaena",
293
+ "277": "red fox, Vulpes vulpes",
294
+ "278": "kit fox, Vulpes macrotis",
295
+ "279": "Arctic fox, white fox, Alopex lagopus",
296
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
297
+ "281": "tabby, tabby cat",
298
+ "282": "tiger cat",
299
+ "283": "Persian cat",
300
+ "284": "Siamese cat, Siamese",
301
+ "285": "Egyptian cat",
302
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
303
+ "287": "lynx, catamount",
304
+ "288": "leopard, Panthera pardus",
305
+ "289": "snow leopard, ounce, Panthera uncia",
306
+ "290": "jaguar, panther, Panthera onca, Felis onca",
307
+ "291": "lion, king of beasts, Panthera leo",
308
+ "292": "tiger, Panthera tigris",
309
+ "293": "cheetah, chetah, Acinonyx jubatus",
310
+ "294": "brown bear, bruin, Ursus arctos",
311
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
312
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
313
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
314
+ "298": "mongoose",
315
+ "299": "meerkat, mierkat",
316
+ "300": "tiger beetle",
317
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
318
+ "302": "ground beetle, carabid beetle",
319
+ "303": "long-horned beetle, longicorn, longicorn beetle",
320
+ "304": "leaf beetle, chrysomelid",
321
+ "305": "dung beetle",
322
+ "306": "rhinoceros beetle",
323
+ "307": "weevil",
324
+ "308": "fly",
325
+ "309": "bee",
326
+ "310": "ant, emmet, pismire",
327
+ "311": "grasshopper, hopper",
328
+ "312": "cricket",
329
+ "313": "walking stick, walkingstick, stick insect",
330
+ "314": "cockroach, roach",
331
+ "315": "mantis, mantid",
332
+ "316": "cicada, cicala",
333
+ "317": "leafhopper",
334
+ "318": "lacewing, lacewing fly",
335
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
336
+ "320": "damselfly",
337
+ "321": "admiral",
338
+ "322": "ringlet, ringlet butterfly",
339
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
340
+ "324": "cabbage butterfly",
341
+ "325": "sulphur butterfly, sulfur butterfly",
342
+ "326": "lycaenid, lycaenid butterfly",
343
+ "327": "starfish, sea star",
344
+ "328": "sea urchin",
345
+ "329": "sea cucumber, holothurian",
346
+ "330": "wood rabbit, cottontail, cottontail rabbit",
347
+ "331": "hare",
348
+ "332": "Angora, Angora rabbit",
349
+ "333": "hamster",
350
+ "334": "porcupine, hedgehog",
351
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
352
+ "336": "marmot",
353
+ "337": "beaver",
354
+ "338": "guinea pig, Cavia cobaya",
355
+ "339": "sorrel",
356
+ "340": "zebra",
357
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
358
+ "342": "wild boar, boar, Sus scrofa",
359
+ "343": "warthog",
360
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
361
+ "345": "ox",
362
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
363
+ "347": "bison",
364
+ "348": "ram, tup",
365
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
366
+ "350": "ibex, Capra ibex",
367
+ "351": "hartebeest",
368
+ "352": "impala, Aepyceros melampus",
369
+ "353": "gazelle",
370
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
371
+ "355": "llama",
372
+ "356": "weasel",
373
+ "357": "mink",
374
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
375
+ "359": "black-footed ferret, ferret, Mustela nigripes",
376
+ "360": "otter",
377
+ "361": "skunk, polecat, wood pussy",
378
+ "362": "badger",
379
+ "363": "armadillo",
380
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
381
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
382
+ "366": "gorilla, Gorilla gorilla",
383
+ "367": "chimpanzee, chimp, Pan troglodytes",
384
+ "368": "gibbon, Hylobates lar",
385
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
386
+ "370": "guenon, guenon monkey",
387
+ "371": "patas, hussar monkey, Erythrocebus patas",
388
+ "372": "baboon",
389
+ "373": "macaque",
390
+ "374": "langur",
391
+ "375": "colobus, colobus monkey",
392
+ "376": "proboscis monkey, Nasalis larvatus",
393
+ "377": "marmoset",
394
+ "378": "capuchin, ringtail, Cebus capucinus",
395
+ "379": "howler monkey, howler",
396
+ "380": "titi, titi monkey",
397
+ "381": "spider monkey, Ateles geoffroyi",
398
+ "382": "squirrel monkey, Saimiri sciureus",
399
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
400
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
401
+ "385": "Indian elephant, Elephas maximus",
402
+ "386": "African elephant, Loxodonta africana",
403
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
404
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
405
+ "389": "barracouta, snoek",
406
+ "390": "eel",
407
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
408
+ "392": "rock beauty, Holocanthus tricolor",
409
+ "393": "anemone fish",
410
+ "394": "sturgeon",
411
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
412
+ "396": "lionfish",
413
+ "397": "puffer, pufferfish, blowfish, globefish",
414
+ "398": "abacus",
415
+ "399": "abaya",
416
+ "400": "academic gown, academic robe, judge robe",
417
+ "401": "accordion, piano accordion, squeeze box",
418
+ "402": "acoustic guitar",
419
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
420
+ "404": "airliner",
421
+ "405": "airship, dirigible",
422
+ "406": "altar",
423
+ "407": "ambulance",
424
+ "408": "amphibian, amphibious vehicle",
425
+ "409": "analog clock",
426
+ "410": "apiary, bee house",
427
+ "411": "apron",
428
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
429
+ "413": "assault rifle, assault gun",
430
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
431
+ "415": "bakery, bakeshop, bakehouse",
432
+ "416": "balance beam, beam",
433
+ "417": "balloon",
434
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
435
+ "419": "Band Aid",
436
+ "420": "banjo",
437
+ "421": "bannister, banister, balustrade, balusters, handrail",
438
+ "422": "barbell",
439
+ "423": "barber chair",
440
+ "424": "barbershop",
441
+ "425": "barn",
442
+ "426": "barometer",
443
+ "427": "barrel, cask",
444
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
445
+ "429": "baseball",
446
+ "430": "basketball",
447
+ "431": "bassinet",
448
+ "432": "bassoon",
449
+ "433": "bathing cap, swimming cap",
450
+ "434": "bath towel",
451
+ "435": "bathtub, bathing tub, bath, tub",
452
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
453
+ "437": "beacon, lighthouse, beacon light, pharos",
454
+ "438": "beaker",
455
+ "439": "bearskin, busby, shako",
456
+ "440": "beer bottle",
457
+ "441": "beer glass",
458
+ "442": "bell cote, bell cot",
459
+ "443": "bib",
460
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
461
+ "445": "bikini, two-piece",
462
+ "446": "binder, ring-binder",
463
+ "447": "binoculars, field glasses, opera glasses",
464
+ "448": "birdhouse",
465
+ "449": "boathouse",
466
+ "450": "bobsled, bobsleigh, bob",
467
+ "451": "bolo tie, bolo, bola tie, bola",
468
+ "452": "bonnet, poke bonnet",
469
+ "453": "bookcase",
470
+ "454": "bookshop, bookstore, bookstall",
471
+ "455": "bottlecap",
472
+ "456": "bow",
473
+ "457": "bow tie, bow-tie, bowtie",
474
+ "458": "brass, memorial tablet, plaque",
475
+ "459": "brassiere, bra, bandeau",
476
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
477
+ "461": "breastplate, aegis, egis",
478
+ "462": "broom",
479
+ "463": "bucket, pail",
480
+ "464": "buckle",
481
+ "465": "bulletproof vest",
482
+ "466": "bullet train, bullet",
483
+ "467": "butcher shop, meat market",
484
+ "468": "cab, hack, taxi, taxicab",
485
+ "469": "caldron, cauldron",
486
+ "470": "candle, taper, wax light",
487
+ "471": "cannon",
488
+ "472": "canoe",
489
+ "473": "can opener, tin opener",
490
+ "474": "cardigan",
491
+ "475": "car mirror",
492
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
493
+ "477": "carpenters kit, tool kit",
494
+ "478": "carton",
495
+ "479": "car wheel",
496
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
497
+ "481": "cassette",
498
+ "482": "cassette player",
499
+ "483": "castle",
500
+ "484": "catamaran",
501
+ "485": "CD player",
502
+ "486": "cello, violoncello",
503
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
504
+ "488": "chain",
505
+ "489": "chainlink fence",
506
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
507
+ "491": "chain saw, chainsaw",
508
+ "492": "chest",
509
+ "493": "chiffonier, commode",
510
+ "494": "chime, bell, gong",
511
+ "495": "china cabinet, china closet",
512
+ "496": "Christmas stocking",
513
+ "497": "church, church building",
514
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
515
+ "499": "cleaver, meat cleaver, chopper",
516
+ "500": "cliff dwelling",
517
+ "501": "cloak",
518
+ "502": "clog, geta, patten, sabot",
519
+ "503": "cocktail shaker",
520
+ "504": "coffee mug",
521
+ "505": "coffeepot",
522
+ "506": "coil, spiral, volute, whorl, helix",
523
+ "507": "combination lock",
524
+ "508": "computer keyboard, keypad",
525
+ "509": "confectionery, confectionary, candy store",
526
+ "510": "container ship, containership, container vessel",
527
+ "511": "convertible",
528
+ "512": "corkscrew, bottle screw",
529
+ "513": "cornet, horn, trumpet, trump",
530
+ "514": "cowboy boot",
531
+ "515": "cowboy hat, ten-gallon hat",
532
+ "516": "cradle",
533
+ "517": "crane",
534
+ "518": "crash helmet",
535
+ "519": "crate",
536
+ "520": "crib, cot",
537
+ "521": "Crock Pot",
538
+ "522": "croquet ball",
539
+ "523": "crutch",
540
+ "524": "cuirass",
541
+ "525": "dam, dike, dyke",
542
+ "526": "desk",
543
+ "527": "desktop computer",
544
+ "528": "dial telephone, dial phone",
545
+ "529": "diaper, nappy, napkin",
546
+ "530": "digital clock",
547
+ "531": "digital watch",
548
+ "532": "dining table, board",
549
+ "533": "dishrag, dishcloth",
550
+ "534": "dishwasher, dish washer, dishwashing machine",
551
+ "535": "disk brake, disc brake",
552
+ "536": "dock, dockage, docking facility",
553
+ "537": "dogsled, dog sled, dog sleigh",
554
+ "538": "dome",
555
+ "539": "doormat, welcome mat",
556
+ "540": "drilling platform, offshore rig",
557
+ "541": "drum, membranophone, tympan",
558
+ "542": "drumstick",
559
+ "543": "dumbbell",
560
+ "544": "Dutch oven",
561
+ "545": "electric fan, blower",
562
+ "546": "electric guitar",
563
+ "547": "electric locomotive",
564
+ "548": "entertainment center",
565
+ "549": "envelope",
566
+ "550": "espresso maker",
567
+ "551": "face powder",
568
+ "552": "feather boa, boa",
569
+ "553": "file, file cabinet, filing cabinet",
570
+ "554": "fireboat",
571
+ "555": "fire engine, fire truck",
572
+ "556": "fire screen, fireguard",
573
+ "557": "flagpole, flagstaff",
574
+ "558": "flute, transverse flute",
575
+ "559": "folding chair",
576
+ "560": "football helmet",
577
+ "561": "forklift",
578
+ "562": "fountain",
579
+ "563": "fountain pen",
580
+ "564": "four-poster",
581
+ "565": "freight car",
582
+ "566": "French horn, horn",
583
+ "567": "frying pan, frypan, skillet",
584
+ "568": "fur coat",
585
+ "569": "garbage truck, dustcart",
586
+ "570": "gasmask, respirator, gas helmet",
587
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
588
+ "572": "goblet",
589
+ "573": "go-kart",
590
+ "574": "golf ball",
591
+ "575": "golfcart, golf cart",
592
+ "576": "gondola",
593
+ "577": "gong, tam-tam",
594
+ "578": "gown",
595
+ "579": "grand piano, grand",
596
+ "580": "greenhouse, nursery, glasshouse",
597
+ "581": "grille, radiator grille",
598
+ "582": "grocery store, grocery, food market, market",
599
+ "583": "guillotine",
600
+ "584": "hair slide",
601
+ "585": "hair spray",
602
+ "586": "half track",
603
+ "587": "hammer",
604
+ "588": "hamper",
605
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
606
+ "590": "hand-held computer, hand-held microcomputer",
607
+ "591": "handkerchief, hankie, hanky, hankey",
608
+ "592": "hard disc, hard disk, fixed disk",
609
+ "593": "harmonica, mouth organ, harp, mouth harp",
610
+ "594": "harp",
611
+ "595": "harvester, reaper",
612
+ "596": "hatchet",
613
+ "597": "holster",
614
+ "598": "home theater, home theatre",
615
+ "599": "honeycomb",
616
+ "600": "hook, claw",
617
+ "601": "hoopskirt, crinoline",
618
+ "602": "horizontal bar, high bar",
619
+ "603": "horse cart, horse-cart",
620
+ "604": "hourglass",
621
+ "605": "iPod",
622
+ "606": "iron, smoothing iron",
623
+ "607": "jack-o-lantern",
624
+ "608": "jean, blue jean, denim",
625
+ "609": "jeep, landrover",
626
+ "610": "jersey, T-shirt, tee shirt",
627
+ "611": "jigsaw puzzle",
628
+ "612": "jinrikisha, ricksha, rickshaw",
629
+ "613": "joystick",
630
+ "614": "kimono",
631
+ "615": "knee pad",
632
+ "616": "knot",
633
+ "617": "lab coat, laboratory coat",
634
+ "618": "ladle",
635
+ "619": "lampshade, lamp shade",
636
+ "620": "laptop, laptop computer",
637
+ "621": "lawn mower, mower",
638
+ "622": "lens cap, lens cover",
639
+ "623": "letter opener, paper knife, paperknife",
640
+ "624": "library",
641
+ "625": "lifeboat",
642
+ "626": "lighter, light, igniter, ignitor",
643
+ "627": "limousine, limo",
644
+ "628": "liner, ocean liner",
645
+ "629": "lipstick, lip rouge",
646
+ "630": "Loafer",
647
+ "631": "lotion",
648
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
649
+ "633": "loupe, jewelers loupe",
650
+ "634": "lumbermill, sawmill",
651
+ "635": "magnetic compass",
652
+ "636": "mailbag, postbag",
653
+ "637": "mailbox, letter box",
654
+ "638": "maillot",
655
+ "639": "maillot, tank suit",
656
+ "640": "manhole cover",
657
+ "641": "maraca",
658
+ "642": "marimba, xylophone",
659
+ "643": "mask",
660
+ "644": "matchstick",
661
+ "645": "maypole",
662
+ "646": "maze, labyrinth",
663
+ "647": "measuring cup",
664
+ "648": "medicine chest, medicine cabinet",
665
+ "649": "megalith, megalithic structure",
666
+ "650": "microphone, mike",
667
+ "651": "microwave, microwave oven",
668
+ "652": "military uniform",
669
+ "653": "milk can",
670
+ "654": "minibus",
671
+ "655": "miniskirt, mini",
672
+ "656": "minivan",
673
+ "657": "missile",
674
+ "658": "mitten",
675
+ "659": "mixing bowl",
676
+ "660": "mobile home, manufactured home",
677
+ "661": "Model T",
678
+ "662": "modem",
679
+ "663": "monastery",
680
+ "664": "monitor",
681
+ "665": "moped",
682
+ "666": "mortar",
683
+ "667": "mortarboard",
684
+ "668": "mosque",
685
+ "669": "mosquito net",
686
+ "670": "motor scooter, scooter",
687
+ "671": "mountain bike, all-terrain bike, off-roader",
688
+ "672": "mountain tent",
689
+ "673": "mouse, computer mouse",
690
+ "674": "mousetrap",
691
+ "675": "moving van",
692
+ "676": "muzzle",
693
+ "677": "nail",
694
+ "678": "neck brace",
695
+ "679": "necklace",
696
+ "680": "nipple",
697
+ "681": "notebook, notebook computer",
698
+ "682": "obelisk",
699
+ "683": "oboe, hautboy, hautbois",
700
+ "684": "ocarina, sweet potato",
701
+ "685": "odometer, hodometer, mileometer, milometer",
702
+ "686": "oil filter",
703
+ "687": "organ, pipe organ",
704
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
705
+ "689": "overskirt",
706
+ "690": "oxcart",
707
+ "691": "oxygen mask",
708
+ "692": "packet",
709
+ "693": "paddle, boat paddle",
710
+ "694": "paddlewheel, paddle wheel",
711
+ "695": "padlock",
712
+ "696": "paintbrush",
713
+ "697": "pajama, pyjama, pjs, jammies",
714
+ "698": "palace",
715
+ "699": "panpipe, pandean pipe, syrinx",
716
+ "700": "paper towel",
717
+ "701": "parachute, chute",
718
+ "702": "parallel bars, bars",
719
+ "703": "park bench",
720
+ "704": "parking meter",
721
+ "705": "passenger car, coach, carriage",
722
+ "706": "patio, terrace",
723
+ "707": "pay-phone, pay-station",
724
+ "708": "pedestal, plinth, footstall",
725
+ "709": "pencil box, pencil case",
726
+ "710": "pencil sharpener",
727
+ "711": "perfume, essence",
728
+ "712": "Petri dish",
729
+ "713": "photocopier",
730
+ "714": "pick, plectrum, plectron",
731
+ "715": "pickelhaube",
732
+ "716": "picket fence, paling",
733
+ "717": "pickup, pickup truck",
734
+ "718": "pier",
735
+ "719": "piggy bank, penny bank",
736
+ "720": "pill bottle",
737
+ "721": "pillow",
738
+ "722": "ping-pong ball",
739
+ "723": "pinwheel",
740
+ "724": "pirate, pirate ship",
741
+ "725": "pitcher, ewer",
742
+ "726": "plane, carpenters plane, woodworking plane",
743
+ "727": "planetarium",
744
+ "728": "plastic bag",
745
+ "729": "plate rack",
746
+ "730": "plow, plough",
747
+ "731": "plunger, plumbers helper",
748
+ "732": "Polaroid camera, Polaroid Land camera",
749
+ "733": "pole",
750
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
751
+ "735": "poncho",
752
+ "736": "pool table, billiard table, snooker table",
753
+ "737": "pop bottle, soda bottle",
754
+ "738": "pot, flowerpot",
755
+ "739": "potters wheel",
756
+ "740": "power drill",
757
+ "741": "prayer rug, prayer mat",
758
+ "742": "printer",
759
+ "743": "prison, prison house",
760
+ "744": "projectile, missile",
761
+ "745": "projector",
762
+ "746": "puck, hockey puck",
763
+ "747": "punching bag, punch bag, punching ball, punchball",
764
+ "748": "purse",
765
+ "749": "quill, quill pen",
766
+ "750": "quilt, comforter, comfort, puff",
767
+ "751": "racer, race car, racing car",
768
+ "752": "racket, racquet",
769
+ "753": "radiator",
770
+ "754": "radio, wireless",
771
+ "755": "radio telescope, radio reflector",
772
+ "756": "rain barrel",
773
+ "757": "recreational vehicle, RV, R.V.",
774
+ "758": "reel",
775
+ "759": "reflex camera",
776
+ "760": "refrigerator, icebox",
777
+ "761": "remote control, remote",
778
+ "762": "restaurant, eating house, eating place, eatery",
779
+ "763": "revolver, six-gun, six-shooter",
780
+ "764": "rifle",
781
+ "765": "rocking chair, rocker",
782
+ "766": "rotisserie",
783
+ "767": "rubber eraser, rubber, pencil eraser",
784
+ "768": "rugby ball",
785
+ "769": "rule, ruler",
786
+ "770": "running shoe",
787
+ "771": "safe",
788
+ "772": "safety pin",
789
+ "773": "saltshaker, salt shaker",
790
+ "774": "sandal",
791
+ "775": "sarong",
792
+ "776": "sax, saxophone",
793
+ "777": "scabbard",
794
+ "778": "scale, weighing machine",
795
+ "779": "school bus",
796
+ "780": "schooner",
797
+ "781": "scoreboard",
798
+ "782": "screen, CRT screen",
799
+ "783": "screw",
800
+ "784": "screwdriver",
801
+ "785": "seat belt, seatbelt",
802
+ "786": "sewing machine",
803
+ "787": "shield, buckler",
804
+ "788": "shoe shop, shoe-shop, shoe store",
805
+ "789": "shoji",
806
+ "790": "shopping basket",
807
+ "791": "shopping cart",
808
+ "792": "shovel",
809
+ "793": "shower cap",
810
+ "794": "shower curtain",
811
+ "795": "ski",
812
+ "796": "ski mask",
813
+ "797": "sleeping bag",
814
+ "798": "slide rule, slipstick",
815
+ "799": "sliding door",
816
+ "800": "slot, one-armed bandit",
817
+ "801": "snorkel",
818
+ "802": "snowmobile",
819
+ "803": "snowplow, snowplough",
820
+ "804": "soap dispenser",
821
+ "805": "soccer ball",
822
+ "806": "sock",
823
+ "807": "solar dish, solar collector, solar furnace",
824
+ "808": "sombrero",
825
+ "809": "soup bowl",
826
+ "810": "space bar",
827
+ "811": "space heater",
828
+ "812": "space shuttle",
829
+ "813": "spatula",
830
+ "814": "speedboat",
831
+ "815": "spider web, spiders web",
832
+ "816": "spindle",
833
+ "817": "sports car, sport car",
834
+ "818": "spotlight, spot",
835
+ "819": "stage",
836
+ "820": "steam locomotive",
837
+ "821": "steel arch bridge",
838
+ "822": "steel drum",
839
+ "823": "stethoscope",
840
+ "824": "stole",
841
+ "825": "stone wall",
842
+ "826": "stopwatch, stop watch",
843
+ "827": "stove",
844
+ "828": "strainer",
845
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
846
+ "830": "stretcher",
847
+ "831": "studio couch, day bed",
848
+ "832": "stupa, tope",
849
+ "833": "submarine, pigboat, sub, U-boat",
850
+ "834": "suit, suit of clothes",
851
+ "835": "sundial",
852
+ "836": "sunglass",
853
+ "837": "sunglasses, dark glasses, shades",
854
+ "838": "sunscreen, sunblock, sun blocker",
855
+ "839": "suspension bridge",
856
+ "840": "swab, swob, mop",
857
+ "841": "sweatshirt",
858
+ "842": "swimming trunks, bathing trunks",
859
+ "843": "swing",
860
+ "844": "switch, electric switch, electrical switch",
861
+ "845": "syringe",
862
+ "846": "table lamp",
863
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
864
+ "848": "tape player",
865
+ "849": "teapot",
866
+ "850": "teddy, teddy bear",
867
+ "851": "television, television system",
868
+ "852": "tennis ball",
869
+ "853": "thatch, thatched roof",
870
+ "854": "theater curtain, theatre curtain",
871
+ "855": "thimble",
872
+ "856": "thresher, thrasher, threshing machine",
873
+ "857": "throne",
874
+ "858": "tile roof",
875
+ "859": "toaster",
876
+ "860": "tobacco shop, tobacconist shop, tobacconist",
877
+ "861": "toilet seat",
878
+ "862": "torch",
879
+ "863": "totem pole",
880
+ "864": "tow truck, tow car, wrecker",
881
+ "865": "toyshop",
882
+ "866": "tractor",
883
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
884
+ "868": "tray",
885
+ "869": "trench coat",
886
+ "870": "tricycle, trike, velocipede",
887
+ "871": "trimaran",
888
+ "872": "tripod",
889
+ "873": "triumphal arch",
890
+ "874": "trolleybus, trolley coach, trackless trolley",
891
+ "875": "trombone",
892
+ "876": "tub, vat",
893
+ "877": "turnstile",
894
+ "878": "typewriter keyboard",
895
+ "879": "umbrella",
896
+ "880": "unicycle, monocycle",
897
+ "881": "upright, upright piano",
898
+ "882": "vacuum, vacuum cleaner",
899
+ "883": "vase",
900
+ "884": "vault",
901
+ "885": "velvet",
902
+ "886": "vending machine",
903
+ "887": "vestment",
904
+ "888": "viaduct",
905
+ "889": "violin, fiddle",
906
+ "890": "volleyball",
907
+ "891": "waffle iron",
908
+ "892": "wall clock",
909
+ "893": "wallet, billfold, notecase, pocketbook",
910
+ "894": "wardrobe, closet, press",
911
+ "895": "warplane, military plane",
912
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
913
+ "897": "washer, automatic washer, washing machine",
914
+ "898": "water bottle",
915
+ "899": "water jug",
916
+ "900": "water tower",
917
+ "901": "whiskey jug",
918
+ "902": "whistle",
919
+ "903": "wig",
920
+ "904": "window screen",
921
+ "905": "window shade",
922
+ "906": "Windsor tie",
923
+ "907": "wine bottle",
924
+ "908": "wing",
925
+ "909": "wok",
926
+ "910": "wooden spoon",
927
+ "911": "wool, woolen, woollen",
928
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
929
+ "913": "wreck",
930
+ "914": "yawl",
931
+ "915": "yurt",
932
+ "916": "web site, website, internet site, site",
933
+ "917": "comic book",
934
+ "918": "crossword puzzle, crossword",
935
+ "919": "street sign",
936
+ "920": "traffic light, traffic signal, stoplight",
937
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
938
+ "922": "menu",
939
+ "923": "plate",
940
+ "924": "guacamole",
941
+ "925": "consomme",
942
+ "926": "hot pot, hotpot",
943
+ "927": "trifle",
944
+ "928": "ice cream, icecream",
945
+ "929": "ice lolly, lolly, lollipop, popsicle",
946
+ "930": "French loaf",
947
+ "931": "bagel, beigel",
948
+ "932": "pretzel",
949
+ "933": "cheeseburger",
950
+ "934": "hotdog, hot dog, red hot",
951
+ "935": "mashed potato",
952
+ "936": "head cabbage",
953
+ "937": "broccoli",
954
+ "938": "cauliflower",
955
+ "939": "zucchini, courgette",
956
+ "940": "spaghetti squash",
957
+ "941": "acorn squash",
958
+ "942": "butternut squash",
959
+ "943": "cucumber, cuke",
960
+ "944": "artichoke, globe artichoke",
961
+ "945": "bell pepper",
962
+ "946": "cardoon",
963
+ "947": "mushroom",
964
+ "948": "Granny Smith",
965
+ "949": "strawberry",
966
+ "950": "orange",
967
+ "951": "lemon",
968
+ "952": "fig",
969
+ "953": "pineapple, ananas",
970
+ "954": "banana",
971
+ "955": "jackfruit, jak, jack",
972
+ "956": "custard apple",
973
+ "957": "pomegranate",
974
+ "958": "hay",
975
+ "959": "carbonara",
976
+ "960": "chocolate sauce, chocolate syrup",
977
+ "961": "dough",
978
+ "962": "meat loaf, meatloaf",
979
+ "963": "pizza, pizza pie",
980
+ "964": "potpie",
981
+ "965": "burrito",
982
+ "966": "red wine",
983
+ "967": "espresso",
984
+ "968": "cup",
985
+ "969": "eggnog",
986
+ "970": "alp",
987
+ "971": "bubble",
988
+ "972": "cliff, drop, drop-off",
989
+ "973": "coral reef",
990
+ "974": "geyser",
991
+ "975": "lakeside, lakeshore",
992
+ "976": "promontory, headland, head, foreland",
993
+ "977": "sandbar, sand bar",
994
+ "978": "seashore, coast, seacoast, sea-coast",
995
+ "979": "valley, vale",
996
+ "980": "volcano",
997
+ "981": "ballplayer, baseball player",
998
+ "982": "groom, bridegroom",
999
+ "983": "scuba diver",
1000
+ "984": "rapeseed",
1001
+ "985": "daisy",
1002
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1003
+ "987": "corn",
1004
+ "988": "acorn",
1005
+ "989": "hip, rose hip, rosehip",
1006
+ "990": "buckeye, horse chestnut, conker",
1007
+ "991": "coral fungus",
1008
+ "992": "agaric",
1009
+ "993": "gyromitra",
1010
+ "994": "stinkhorn, carrion fungus",
1011
+ "995": "earthstar",
1012
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1013
+ "997": "bolete",
1014
+ "998": "ear, spike, capitulum",
1015
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1016
+ }
1017
+ }
PixNerd-XL-16-256/pipeline.py CHANGED
@@ -1,353 +1,433 @@
1
- from __future__ import annotations
2
-
3
- import sys
4
- from dataclasses import dataclass
5
- from pathlib import Path
6
- from typing import List, Literal, Optional, Sequence, Union
7
-
8
- import torch
9
- from diffusers import DiffusionPipeline
10
- from diffusers.image_processor import VaeImageProcessor
11
- from diffusers.utils import BaseOutput
12
- from PIL import Image
13
-
14
- ConditioningInput = Union[str, int, Sequence[Union[str, int]]]
15
- Language = Literal["en", "cn"]
16
-
17
-
18
- @dataclass
19
- class PixNerdPipelineOutput(BaseOutput):
20
- images: Union[List[Image.Image], torch.Tensor, "np.ndarray"]
21
-
22
-
23
- class PixNerdPipeline(DiffusionPipeline):
24
- model_cpu_offload_seq = "conditioner->transformer->vae"
25
- _callback_tensor_inputs = ["latents"]
26
-
27
- def __init__(
28
- self,
29
- transformer,
30
- scheduler,
31
- vae=None,
32
- conditioner=None,
33
- id2label: Optional[dict[int, str]] = None,
34
- id2label_cn: Optional[dict[int, str]] = None,
35
- ):
36
- super().__init__()
37
- if vae is None:
38
- vae = getattr(transformer, "vae", None)
39
- if conditioner is None:
40
- conditioner = getattr(transformer, "conditioner", None)
41
- if vae is None or conditioner is None:
42
- raise ValueError("Pipeline requires `vae` and `conditioner` either explicitly or from `transformer`.")
43
- self.register_modules(
44
- vae=vae,
45
- conditioner=conditioner,
46
- transformer=transformer,
47
- scheduler=scheduler,
48
- )
49
- self.image_processor = VaeImageProcessor(vae_scale_factor=1)
50
-
51
- if id2label is None and id2label_cn is None:
52
- id2label, id2label_cn = self._load_repo_labels()
53
- self._id2label = id2label or {}
54
- self._id2label_cn = id2label_cn or {}
55
- self.labels = self._build_label2id(self._id2label)
56
- self.labels_cn = self._build_label2id(self._id2label_cn)
57
- self._labels_loaded_from_path = bool(self._id2label or self._id2label_cn)
58
-
59
- def _ensure_labels_loaded(self) -> None:
60
- if self._labels_loaded_from_path:
61
- return
62
-
63
- path = getattr(getattr(self, "config", None), "_name_or_path", None) or getattr(self, "_name_or_path", None)
64
- if not path:
65
- return
66
-
67
- id2label, id2label_cn = self._load_labels_for_path(path)
68
- if id2label is None and id2label_cn is None:
69
- self._labels_loaded_from_path = True
70
- return
71
-
72
- self._id2label = id2label or {}
73
- self._id2label_cn = id2label_cn or {}
74
- self.labels = self._build_label2id(self._id2label)
75
- self.labels_cn = self._build_label2id(self._id2label_cn)
76
- self._labels_loaded_from_path = True
77
-
78
- @staticmethod
79
- def _resolve_labels_dir(pretrained_model_name_or_path: Union[str, Path]) -> Optional[Path]:
80
- path = Path(pretrained_model_name_or_path)
81
- if not path.exists():
82
- try:
83
- from huggingface_hub import snapshot_download
84
-
85
- path = Path(snapshot_download(pretrained_model_name_or_path))
86
- except Exception:
87
- return None
88
-
89
- if (path / "model_index.json").exists():
90
- labels_dir = path.parent / "labels"
91
- else:
92
- labels_dir = path / "labels"
93
- return labels_dir if labels_dir.is_dir() else None
94
-
95
- @classmethod
96
- def _load_labels_for_path(
97
- cls,
98
- pretrained_model_name_or_path: Union[str, Path],
99
- ) -> tuple[Optional[dict[int, str]], Optional[dict[int, str]]]:
100
- labels_dir = cls._resolve_labels_dir(pretrained_model_name_or_path)
101
- if labels_dir is None:
102
- return None, None
103
-
104
- labels_path = str(labels_dir)
105
- inserted = False
106
- if labels_path not in sys.path:
107
- sys.path.insert(0, labels_path)
108
- inserted = True
109
- try:
110
- from imagenet_labels import load_id2label
111
-
112
- return (
113
- load_id2label(labels_dir, lang="en"),
114
- load_id2label(labels_dir, lang="cn"),
115
- )
116
- finally:
117
- if inserted and labels_path in sys.path:
118
- sys.path.remove(labels_path)
119
-
120
- @staticmethod
121
- def _load_repo_labels() -> tuple[Optional[dict[int, str]], Optional[dict[int, str]]]:
122
- labels_dir = Path(__file__).resolve().parent.parent / "labels"
123
- if not labels_dir.is_dir():
124
- return None, None
125
-
126
- labels_path = str(labels_dir)
127
- inserted = False
128
- if labels_path not in sys.path:
129
- sys.path.insert(0, labels_path)
130
- inserted = True
131
- try:
132
- from imagenet_labels import load_id2label
133
-
134
- return (
135
- load_id2label(labels_dir, lang="en"),
136
- load_id2label(labels_dir, lang="cn"),
137
- )
138
- finally:
139
- if inserted and labels_path in sys.path:
140
- sys.path.remove(labels_path)
141
-
142
- @classmethod
143
- def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
144
- pipe = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
145
- id2label, id2label_cn = cls._load_labels_for_path(pretrained_model_name_or_path)
146
- if id2label is not None or id2label_cn is not None:
147
- pipe._id2label = id2label or {}
148
- pipe._id2label_cn = id2label_cn or {}
149
- pipe.labels = cls._build_label2id(pipe._id2label)
150
- pipe.labels_cn = cls._build_label2id(pipe._id2label_cn)
151
- return pipe
152
-
153
- @staticmethod
154
- def _build_label2id(id2label: dict[int, str]) -> dict[str, int]:
155
- label2id: dict[str, int] = {}
156
- for class_id, value in id2label.items():
157
- for synonym in value.split(","):
158
- synonym = synonym.strip()
159
- if synonym:
160
- label2id[synonym] = int(class_id)
161
- return dict(sorted(label2id.items()))
162
-
163
- @property
164
- def id2label(self) -> dict[int, str]:
165
- self._ensure_labels_loaded()
166
- return self._id2label
167
-
168
- @property
169
- def id2label_cn(self) -> dict[int, str]:
170
- self._ensure_labels_loaded()
171
- return self._id2label_cn
172
-
173
- def get_label_ids(
174
- self,
175
- labels: Union[str, List[str]],
176
- *,
177
- lang: Language = "en",
178
- ) -> List[int]:
179
- self._ensure_labels_loaded()
180
- if isinstance(labels, str):
181
- labels = [labels]
182
-
183
- label2id = self.labels if lang == "en" else self.labels_cn
184
- if not label2id:
185
- raise ValueError(
186
- f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
187
- )
188
-
189
- missing = [label for label in labels if label not in label2id]
190
- if missing:
191
- preview = ", ".join(list(label2id.keys())[:8])
192
- raise ValueError(
193
- f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
194
- )
195
- return [label2id[label] for label in labels]
196
-
197
- def _resolve_prompt_item(self, value: Union[str, int]) -> int:
198
- if isinstance(value, int):
199
- return value
200
- if value.isdigit():
201
- return int(value)
202
- if value in self.labels:
203
- return self.labels[value]
204
- if value in self.labels_cn:
205
- return self.labels_cn[value]
206
- raise ValueError(
207
- f"Unknown class label {value!r}. Pass an ImageNet class id or a synonym from "
208
- "`pipe.labels` / `pipe.labels_cn`."
209
- )
210
-
211
- def _resolve_prompts(self, prompts: List[Union[str, int]]) -> List[int]:
212
- self._ensure_labels_loaded()
213
- return [self._resolve_prompt_item(prompt) for prompt in prompts]
214
-
215
- @staticmethod
216
- def _fp_to_uint8(image: torch.Tensor) -> torch.Tensor:
217
- return torch.clip_((image + 1) * 127.5 + 0.5, 0, 255).to(torch.uint8)
218
-
219
- @staticmethod
220
- def _to_list(y: ConditioningInput) -> List[Union[str, int]]:
221
- if isinstance(y, (str, int)):
222
- return [y]
223
- return list(y)
224
-
225
- @staticmethod
226
- def _repeat(values: List[Union[str, int]], repeats: int) -> List[Union[str, int]]:
227
- if repeats == 1:
228
- return values
229
- expanded: List[Union[str, int]] = []
230
- for value in values:
231
- expanded.extend([value] * repeats)
232
- return expanded
233
-
234
- def encode_prompt(
235
- self,
236
- prompt: ConditioningInput,
237
- num_images_per_prompt: int,
238
- ):
239
- prompts = self._repeat(self._to_list(prompt), num_images_per_prompt)
240
- resolved = self._resolve_prompts(prompts)
241
- metadata = {"device": self._execution_device}
242
- with torch.no_grad():
243
- cond, uncond = self.conditioner(resolved, metadata)
244
- return cond, uncond, resolved
245
-
246
- def prepare_latents(
247
- self,
248
- batch_size: int,
249
- num_channels: int,
250
- height: int,
251
- width: int,
252
- generator: Optional[torch.Generator] = None,
253
- latents: Optional[torch.Tensor] = None,
254
- ) -> torch.Tensor:
255
- if latents is not None:
256
- return latents.to(device=self._execution_device, dtype=torch.float32)
257
- return torch.randn(
258
- (batch_size, num_channels, height, width),
259
- generator=generator,
260
- device=self._execution_device,
261
- dtype=torch.float32,
262
- )
263
-
264
- @torch.no_grad()
265
- def __call__(
266
- self,
267
- prompt: ConditioningInput,
268
- negative_prompt: Optional[ConditioningInput] = None,
269
- num_images_per_prompt: int = 1,
270
- height: int = 512,
271
- width: int = 512,
272
- num_inference_steps: int = 25,
273
- guidance_scale: float = 4.0,
274
- generator: Optional[torch.Generator] = None,
275
- seed: Optional[int] = None,
276
- latents: Optional[torch.Tensor] = None,
277
- output_type: str = "pil",
278
- return_dict: bool = True,
279
- timeshift: float = 3.0,
280
- order: int = 2,
281
- ) -> PixNerdPipelineOutput | tuple:
282
- patch_size = int(getattr(self.transformer, "patch_size", 1))
283
- channels = int(getattr(self.transformer, "in_channels", 3))
284
- height = (height // patch_size) * patch_size
285
- width = (width // patch_size) * patch_size
286
-
287
- if hasattr(self.transformer, "decoder_patch_scaling_h"):
288
- self.transformer.decoder_patch_scaling_h = height / 512
289
- self.transformer.decoder_patch_scaling_w = width / 512
290
-
291
- cond, default_uncond, prompts = self.encode_prompt(prompt, num_images_per_prompt)
292
- if negative_prompt is not None:
293
- negative = self._repeat(self._to_list(negative_prompt), num_images_per_prompt)
294
- resolved_negative = self._resolve_prompts(negative)
295
- metadata = {"device": self._execution_device}
296
- with torch.no_grad():
297
- _, uncond = self.conditioner(resolved_negative, metadata)
298
- else:
299
- uncond = default_uncond
300
- batch_size = len(prompts)
301
- if generator is None and seed is not None:
302
- generator = torch.Generator(device=self._execution_device).manual_seed(seed)
303
- latents = self.prepare_latents(
304
- batch_size=batch_size,
305
- num_channels=channels,
306
- height=height,
307
- width=width,
308
- generator=generator,
309
- latents=latents,
310
- )
311
- self.scheduler.set_timesteps(
312
- num_inference_steps=num_inference_steps,
313
- guidance_scale=guidance_scale,
314
- timeshift=timeshift,
315
- order=order,
316
- device=latents.device,
317
- )
318
- for timestep in self.scheduler.timesteps:
319
- cfg_latents = torch.cat([latents, latents], dim=0)
320
- cfg_t = timestep.repeat(cfg_latents.shape[0]).to(latents.device, dtype=latents.dtype)
321
- cfg_condition = torch.cat([uncond, cond], dim=0)
322
- model_output = self.transformer(
323
- sample=cfg_latents,
324
- timestep=cfg_t,
325
- encoder_hidden_states=cfg_condition,
326
- ).sample
327
- model_output = self.scheduler.classifier_free_guidance(model_output)
328
- latents = self.scheduler.step(
329
- model_output=model_output,
330
- timestep=timestep,
331
- sample=latents,
332
- ).prev_sample
333
-
334
- image = self.vae.decode(latents)
335
- images_uint8 = self._fp_to_uint8(image).permute(0, 2, 3, 1).cpu().numpy()
336
- if output_type == "pil":
337
- output = [Image.fromarray(img) for img in images_uint8]
338
- elif output_type == "pt":
339
- output = torch.from_numpy(images_uint8)
340
- elif output_type == "np":
341
- output = images_uint8
342
- else:
343
- raise ValueError(f"Unsupported output_type: {output_type}")
344
-
345
- if not return_dict:
346
- return (output,)
347
- return PixNerdPipelineOutput(images=output)
348
-
349
-
350
- __all__ = [
351
- "PixNerdPipeline",
352
- "PixNerdPipelineOutput",
353
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import json
18
+ from pathlib import Path
19
+ from typing import Dict, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+
23
+ from diffusers.image_processor import VaeImageProcessor
24
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
25
+ from diffusers.utils.torch_utils import randn_tensor
26
+
27
+ DEFAULT_NATIVE_RESOLUTION = 512
28
+
29
+ EXAMPLE_DOC_STRING = """
30
+ Examples:
31
+ ```py
32
+ >>> from pathlib import Path
33
+ >>> from diffusers import DiffusionPipeline
34
+ >>> import torch
35
+
36
+ >>> model_dir = Path("./PixNerd-XL-16-512").resolve()
37
+ >>> pipe = DiffusionPipeline.from_pretrained(
38
+ ... str(model_dir),
39
+ ... local_files_only=True,
40
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
41
+ ... trust_remote_code=True,
42
+ ... torch_dtype=torch.bfloat16,
43
+ ... )
44
+ >>> pipe.to("cuda")
45
+
46
+ >>> print(pipe.id2label[207])
47
+ >>> print(pipe.get_label_ids("golden retriever"))
48
+
49
+ >>> generator = torch.Generator(device="cuda").manual_seed(42)
50
+ >>> # timeshift=3.0 and order=2 are defaults in scheduler/scheduler_config.json
51
+ >>> image = pipe(
52
+ ... class_labels="golden retriever",
53
+ ... height=512,
54
+ ... width=512,
55
+ ... num_inference_steps=25,
56
+ ... guidance_scale=4.0,
57
+ ... generator=generator,
58
+ ... ).images[0]
59
+ >>> image.save("demo.png")
60
+ ```
61
+ """
62
+
63
+ ConditioningInput = Union[int, str, List[Union[int, str]], torch.LongTensor]
64
+
65
+
66
+ class PixNerdPipeline(DiffusionPipeline):
67
+ r"""
68
+ Pipeline for class-conditional PixNerd pixel-space image generation.
69
+
70
+ Parameters:
71
+ transformer ([`PixNerdTransformer2DModel`]):
72
+ Class-conditional PixNerd denoiser operating in pixel space.
73
+ scheduler ([`PixNerdFlowMatchScheduler`]):
74
+ Flow-matching scheduler with AdamLM multi-step coefficients.
75
+ vae ([`PixNerdPixelVAE`], *optional*):
76
+ Identity pixel autoencoder. May also be attached to `transformer.vae`.
77
+ conditioner ([`PixNerdLabelConditioner`], *optional*):
78
+ ImageNet class-label conditioner. May also be attached to `transformer.conditioner`.
79
+ id2label (`dict[int, str]`, *optional*):
80
+ ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
81
+ """
82
+
83
+ model_cpu_offload_seq = "conditioner->transformer->vae"
84
+ _callback_tensor_inputs = ["latents"]
85
+ _optional_components = ["vae", "conditioner"]
86
+
87
+ def __init__(
88
+ self,
89
+ transformer,
90
+ scheduler,
91
+ vae=None,
92
+ conditioner=None,
93
+ id2label: Optional[Dict[Union[int, str], str]] = None,
94
+ ):
95
+ super().__init__()
96
+ if vae is None:
97
+ vae = getattr(transformer, "vae", None)
98
+ if conditioner is None:
99
+ conditioner = getattr(transformer, "conditioner", None)
100
+ if vae is None or conditioner is None:
101
+ raise ValueError("Pipeline requires `vae` and `conditioner` either explicitly or from `transformer`.")
102
+ self.register_modules(
103
+ vae=vae,
104
+ conditioner=conditioner,
105
+ transformer=transformer,
106
+ scheduler=scheduler,
107
+ )
108
+ self.image_processor = VaeImageProcessor(vae_scale_factor=1, do_normalize=False)
109
+ if id2label is None:
110
+ id2label = self._read_id2label_from_model_index(
111
+ getattr(getattr(self, "config", None), "_name_or_path", None)
112
+ )
113
+ self._id2label = self._normalize_id2label(id2label)
114
+ self.labels = self._build_label2id(self._id2label)
115
+ self._labels_loaded_from_model_index = bool(self._id2label)
116
+
117
+ def _get_device(self) -> torch.device:
118
+ try:
119
+ return self._execution_device
120
+ except AttributeError:
121
+ pass
122
+ for name in ("transformer", "vae", "scheduler"):
123
+ module = getattr(self, name, None)
124
+ if isinstance(module, torch.nn.Module):
125
+ parameter = next(module.parameters(), None)
126
+ if parameter is not None:
127
+ return parameter.device
128
+ return torch.device("cpu")
129
+
130
+ @classmethod
131
+ def from_pretrained(cls, pretrained_model_name_or_path=None, *args, **kwargs):
132
+ id2label_override = kwargs.pop("id2label", None)
133
+ pipe = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
134
+ id2label = id2label_override or cls._read_id2label_from_model_index(pretrained_model_name_or_path)
135
+ if id2label:
136
+ pipe._id2label = cls._normalize_id2label(id2label)
137
+ pipe.labels = cls._build_label2id(pipe._id2label)
138
+ pipe._labels_loaded_from_model_index = True
139
+ return pipe
140
+
141
+ def _ensure_labels_loaded(self) -> None:
142
+ if self._labels_loaded_from_model_index:
143
+ return
144
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
145
+ if loaded:
146
+ self._id2label = loaded
147
+ self.labels = self._build_label2id(self._id2label)
148
+ self._labels_loaded_from_model_index = True
149
+
150
+ @staticmethod
151
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
152
+ if not id2label:
153
+ return {}
154
+ return {int(key): value for key, value in id2label.items()}
155
+
156
+ @staticmethod
157
+ def _read_id2label_from_model_index(variant_path: Optional[Union[str, Path]]) -> Dict[int, str]:
158
+ if not variant_path:
159
+ return {}
160
+ model_index_path = Path(variant_path).resolve() / "model_index.json"
161
+ if not model_index_path.exists():
162
+ return {}
163
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
164
+ id2label = raw.get("id2label")
165
+ if not isinstance(id2label, dict):
166
+ return {}
167
+ return {int(key): value for key, value in id2label.items()}
168
+
169
+ @staticmethod
170
+ def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
171
+ label2id: Dict[str, int] = {}
172
+ for class_id, value in id2label.items():
173
+ for synonym in value.split(","):
174
+ synonym = synonym.strip()
175
+ if synonym:
176
+ label2id[synonym] = int(class_id)
177
+ return dict(sorted(label2id.items()))
178
+
179
+ @property
180
+ def id2label(self) -> Dict[int, str]:
181
+ r"""ImageNet class id to English label string (comma-separated synonyms)."""
182
+ self._ensure_labels_loaded()
183
+ return self._id2label
184
+
185
+ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
186
+ r"""
187
+ Map ImageNet label strings to class ids.
188
+
189
+ Args:
190
+ label (`str` or `list[str]`):
191
+ One or more English label strings. Each string must match a synonym in `id2label`.
192
+ """
193
+ self._ensure_labels_loaded()
194
+ if isinstance(label, str):
195
+ label = [label]
196
+ if not self.labels:
197
+ raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
198
+ missing = [item for item in label if item not in self.labels]
199
+ if missing:
200
+ preview = ", ".join(list(self.labels.keys())[:8])
201
+ raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
202
+ return [self.labels[item] for item in label]
203
+
204
+ def _normalize_class_labels(
205
+ self,
206
+ class_labels: ConditioningInput,
207
+ num_images_per_prompt: int = 1,
208
+ ) -> List[int]:
209
+ if torch.is_tensor(class_labels):
210
+ values = class_labels.to(dtype=torch.long).reshape(-1).tolist()
211
+ elif isinstance(class_labels, int):
212
+ values = [class_labels]
213
+ elif isinstance(class_labels, str):
214
+ values = self.get_label_ids(class_labels)
215
+ elif class_labels and isinstance(class_labels[0], str):
216
+ values = self.get_label_ids(list(class_labels))
217
+ else:
218
+ values = [int(entry) for entry in class_labels]
219
+
220
+ if num_images_per_prompt == 1:
221
+ return values
222
+ expanded: List[int] = []
223
+ for value in values:
224
+ expanded.extend([value] * num_images_per_prompt)
225
+ return expanded
226
+
227
+ def _get_patch_size(self) -> int:
228
+ patch_size = getattr(self.transformer, "patch_size", None)
229
+ if patch_size is not None:
230
+ return int(patch_size)
231
+ return int(getattr(self.transformer.config, "patch_size", 16))
232
+
233
+ def _get_in_channels(self) -> int:
234
+ in_channels = getattr(self.transformer, "in_channels", None)
235
+ if in_channels is not None:
236
+ return int(in_channels)
237
+ return int(getattr(self.transformer.config, "in_channels", 3))
238
+
239
+ def check_inputs(
240
+ self,
241
+ height: int,
242
+ width: int,
243
+ num_inference_steps: int,
244
+ output_type: str,
245
+ ) -> None:
246
+ if num_inference_steps < 1:
247
+ raise ValueError("num_inference_steps must be >= 1.")
248
+ if output_type not in {"pil", "np", "pt", "latent"}:
249
+ raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
250
+ order = int(getattr(self.scheduler.config, "order", getattr(self.scheduler, "order", 2)))
251
+ if order < 1:
252
+ raise ValueError("scheduler.config.order must be >= 1.")
253
+
254
+ patch_size = self._get_patch_size()
255
+ if height % patch_size != 0 or width % patch_size != 0:
256
+ raise ValueError(f"height and width must be divisible by patch_size={patch_size}.")
257
+
258
+ def encode_condition(
259
+ self,
260
+ class_label_ids: List[int],
261
+ negative_class_label_ids: Optional[List[int]] = None,
262
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
263
+ metadata = {"device": self._get_device()}
264
+ with torch.no_grad():
265
+ cond, default_uncond = self.conditioner(class_label_ids, metadata)
266
+ if negative_class_label_ids is not None:
267
+ _, uncond = self.conditioner(negative_class_label_ids, metadata)
268
+ else:
269
+ uncond = default_uncond
270
+ return cond, uncond
271
+
272
+ def prepare_latents(
273
+ self,
274
+ batch_size: int,
275
+ num_channels: int,
276
+ height: int,
277
+ width: int,
278
+ dtype: torch.dtype,
279
+ device: torch.device,
280
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
281
+ latents: Optional[torch.Tensor] = None,
282
+ ) -> torch.Tensor:
283
+ if latents is not None:
284
+ return latents.to(device=device, dtype=dtype)
285
+ return randn_tensor(
286
+ (batch_size, num_channels, height, width),
287
+ generator=generator,
288
+ device=device,
289
+ dtype=dtype,
290
+ )
291
+
292
+ @staticmethod
293
+ def _fp_to_uint8(image: torch.Tensor) -> torch.Tensor:
294
+ return torch.clip_((image + 1) * 127.5 + 0.5, 0, 255).to(torch.uint8)
295
+
296
+ def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
297
+ if output_type == "latent":
298
+ return latents
299
+
300
+ image = self.vae.decode(latents)
301
+ if output_type == "pt":
302
+ return image
303
+ images_uint8 = self._fp_to_uint8(image).permute(0, 2, 3, 1).cpu().numpy()
304
+ if output_type == "np":
305
+ return images_uint8
306
+ if output_type == "pil":
307
+ from PIL import Image
308
+
309
+ return [Image.fromarray(img) for img in images_uint8]
310
+ raise ValueError(f"Unsupported output_type: {output_type}")
311
+
312
+ def _apply_decoder_patch_scaling(self, height: int, width: int) -> None:
313
+ denoiser = getattr(self.transformer, "denoiser", self.transformer)
314
+ if hasattr(denoiser, "decoder_patch_scaling_h"):
315
+ denoiser.decoder_patch_scaling_h = height / DEFAULT_NATIVE_RESOLUTION
316
+ denoiser.decoder_patch_scaling_w = width / DEFAULT_NATIVE_RESOLUTION
317
+
318
+ @torch.inference_mode()
319
+ def __call__(
320
+ self,
321
+ class_labels: Optional[ConditioningInput] = None,
322
+ negative_class_labels: Optional[ConditioningInput] = None,
323
+ num_images_per_prompt: int = 1,
324
+ height: Optional[int] = None,
325
+ width: Optional[int] = None,
326
+ num_inference_steps: int = 25,
327
+ guidance_scale: float = 4.0,
328
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
329
+ latents: Optional[torch.Tensor] = None,
330
+ output_type: str = "pil",
331
+ return_dict: bool = True,
332
+ prompt: Optional[ConditioningInput] = None,
333
+ negative_prompt: Optional[ConditioningInput] = None,
334
+ ) -> Union[ImagePipelineOutput, Tuple]:
335
+ r"""
336
+ Generate class-conditional images with PixNerd.
337
+
338
+ Args:
339
+ class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
340
+ ImageNet class indices or human-readable English label strings.
341
+ negative_class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`, *optional*):
342
+ Optional negative class labels for classifier-free guidance.
343
+ num_images_per_prompt (`int`, defaults to `1`):
344
+ Number of images to generate per class label.
345
+ height (`int`, *optional*):
346
+ Output image height in pixels. Defaults to `512`.
347
+ width (`int`, *optional*):
348
+ Output image width in pixels. Defaults to `512`.
349
+ num_inference_steps (`int`, defaults to `25`):
350
+ Number of denoising steps.
351
+ guidance_scale (`float`, defaults to `4.0`):
352
+ Classifier-free guidance scale applied by the scheduler.
353
+ generator (`torch.Generator`, *optional*):
354
+ RNG for reproducibility.
355
+ latents (`torch.Tensor`, *optional*):
356
+ Pre-generated noisy pixel tensor.
357
+ output_type (`str`, defaults to `"pil"`):
358
+ `"pil"`, `"np"`, `"pt"`, or `"latent"`.
359
+ return_dict (`bool`, defaults to `True`):
360
+ Return [`ImagePipelineOutput`] if True.
361
+ prompt (`int`, `str`, `list`, *optional*):
362
+ Deprecated alias for `class_labels`.
363
+ negative_prompt (`int`, `str`, `list`, *optional*):
364
+ Deprecated alias for `negative_class_labels`.
365
+ """
366
+ if class_labels is None:
367
+ class_labels = prompt
368
+ if negative_class_labels is None:
369
+ negative_class_labels = negative_prompt
370
+ if class_labels is None:
371
+ raise ValueError("`class_labels` (or deprecated `prompt`) must be provided.")
372
+
373
+ height = int(height or DEFAULT_NATIVE_RESOLUTION)
374
+ width = int(width or DEFAULT_NATIVE_RESOLUTION)
375
+ self.check_inputs(height, width, num_inference_steps, output_type)
376
+
377
+ patch_size = self._get_patch_size()
378
+ height = (height // patch_size) * patch_size
379
+ width = (width // patch_size) * patch_size
380
+ self._apply_decoder_patch_scaling(height, width)
381
+
382
+ class_label_ids = self._normalize_class_labels(class_labels, num_images_per_prompt)
383
+ negative_label_ids = None
384
+ if negative_class_labels is not None:
385
+ negative_label_ids = self._normalize_class_labels(negative_class_labels, num_images_per_prompt)
386
+
387
+ device = self._get_device()
388
+ model_dtype = next(self.transformer.parameters()).dtype
389
+ batch_size = len(class_label_ids)
390
+
391
+ cond, uncond = self.encode_condition(class_label_ids, negative_label_ids)
392
+ latents = self.prepare_latents(
393
+ batch_size=batch_size,
394
+ num_channels=self._get_in_channels(),
395
+ height=height,
396
+ width=width,
397
+ dtype=model_dtype,
398
+ device=device,
399
+ generator=generator,
400
+ latents=latents,
401
+ )
402
+
403
+ self.scheduler.set_timesteps(
404
+ num_inference_steps=num_inference_steps,
405
+ guidance_scale=guidance_scale,
406
+ device=device,
407
+ )
408
+
409
+ for timestep in self.progress_bar(self.scheduler.timesteps):
410
+ cfg_latents = torch.cat([latents, latents], dim=0)
411
+ cfg_t = timestep.repeat(cfg_latents.shape[0]).to(device=device, dtype=latents.dtype)
412
+ cfg_condition = torch.cat([uncond, cond], dim=0)
413
+ model_output = self.transformer(
414
+ sample=cfg_latents.to(dtype=model_dtype),
415
+ timestep=cfg_t,
416
+ encoder_hidden_states=cfg_condition,
417
+ ).sample
418
+ model_output = self.scheduler.classifier_free_guidance(model_output)
419
+ latents = self.scheduler.step(
420
+ model_output=model_output,
421
+ timestep=timestep,
422
+ sample=latents,
423
+ ).prev_sample
424
+
425
+ image = self.decode_latents(latents, output_type=output_type)
426
+
427
+ self.maybe_free_model_hooks()
428
+ if not return_dict:
429
+ return (image,)
430
+ return ImagePipelineOutput(images=image)
431
+
432
+
433
+ PixNerdPipelineOutput = ImagePipelineOutput
PixNerd-XL-16-256/scheduler/scheduling_pixnerd_flow_match.py CHANGED
@@ -1,231 +1,237 @@
1
- from __future__ import annotations
2
-
3
- from dataclasses import dataclass
4
- from typing import Any, Dict, List, Optional, Tuple, Union
5
-
6
- import torch
7
- from diffusers.configuration_utils import ConfigMixin, register_to_config
8
- from diffusers.schedulers.scheduling_utils import SchedulerMixin
9
- from diffusers.utils import BaseOutput
10
-
11
- @dataclass
12
- class PixNerdSchedulerOutput(BaseOutput):
13
- prev_sample: torch.Tensor
14
-
15
-
16
- class PixNerdFlowMatchScheduler(SchedulerMixin, ConfigMixin):
17
- """
18
- Diffusers-compatible scheduler wrapper for PixNerd's AdamLM flow-matching sampler.
19
- """
20
-
21
- config_name = "scheduler_config.json"
22
- order = 1
23
- init_noise_sigma = 1.0
24
-
25
- @staticmethod
26
- def _lagrange_coeffs(order: int, pre_ts: torch.Tensor, t_start: torch.Tensor, t_end: torch.Tensor) -> List[float]:
27
- ts = [float(v) for v in pre_ts[-order:].tolist()]
28
- a = float(t_start)
29
- b = float(t_end)
30
-
31
- if order == 1:
32
- return [1.0]
33
- if order == 2:
34
- t1, t2 = ts
35
- int1 = 0.5 / (t1 - t2) * ((b - t2) ** 2 - (a - t2) ** 2)
36
- int2 = 0.5 / (t2 - t1) * ((b - t1) ** 2 - (a - t1) ** 2)
37
- total = int1 + int2
38
- return [int1 / total, int2 / total]
39
- if order == 3:
40
- t1, t2, t3 = ts
41
- int1_denom = (t1 - t2) * (t1 - t3)
42
- int1 = ((1 / 3) * b**3 - 0.5 * (t2 + t3) * b**2 + (t2 * t3) * b) - (
43
- (1 / 3) * a**3 - 0.5 * (t2 + t3) * a**2 + (t2 * t3) * a
44
- )
45
- int1 = int1 / int1_denom
46
- int2_denom = (t2 - t1) * (t2 - t3)
47
- int2 = ((1 / 3) * b**3 - 0.5 * (t1 + t3) * b**2 + (t1 * t3) * b) - (
48
- (1 / 3) * a**3 - 0.5 * (t1 + t3) * a**2 + (t1 * t3) * a
49
- )
50
- int2 = int2 / int2_denom
51
- int3_denom = (t3 - t1) * (t3 - t2)
52
- int3 = ((1 / 3) * b**3 - 0.5 * (t1 + t2) * b**2 + (t1 * t2) * b) - (
53
- (1 / 3) * a**3 - 0.5 * (t1 + t2) * a**2 + (t1 * t2) * a
54
- )
55
- int3 = int3 / int3_denom
56
- total = int1 + int2 + int3
57
- return [int1 / total, int2 / total, int3 / total]
58
- if order == 4:
59
- t1, t2, t3, t4 = ts
60
- int1_denom = (t1 - t2) * (t1 - t3) * (t1 - t4)
61
- int1 = ((1 / 4) * b**4 - (1 / 3) * (t2 + t3 + t4) * b**3 + 0.5 * (t3 * t4 + t2 * t3 + t2 * t4) * b**2 - (t2 * t3 * t4) * b) - (
62
- (1 / 4) * a**4 - (1 / 3) * (t2 + t3 + t4) * a**3 + 0.5 * (t3 * t4 + t2 * t3 + t2 * t4) * a**2 - (t2 * t3 * t4) * a
63
- )
64
- int1 = int1 / int1_denom
65
- int2_denom = (t2 - t1) * (t2 - t3) * (t2 - t4)
66
- int2 = ((1 / 4) * b**4 - (1 / 3) * (t1 + t3 + t4) * b**3 + 0.5 * (t3 * t4 + t1 * t3 + t1 * t4) * b**2 - (t1 * t3 * t4) * b) - (
67
- (1 / 4) * a**4 - (1 / 3) * (t1 + t3 + t4) * a**3 + 0.5 * (t3 * t4 + t1 * t3 + t1 * t4) * a**2 - (t1 * t3 * t4) * a
68
- )
69
- int2 = int2 / int2_denom
70
- int3_denom = (t3 - t1) * (t3 - t2) * (t3 - t4)
71
- int3 = ((1 / 4) * b**4 - (1 / 3) * (t1 + t2 + t4) * b**3 + 0.5 * (t4 * t2 + t1 * t2 + t1 * t4) * b**2 - (t1 * t2 * t4) * b) - (
72
- (1 / 4) * a**4 - (1 / 3) * (t1 + t2 + t4) * a**3 + 0.5 * (t4 * t2 + t1 * t2 + t1 * t4) * a**2 - (t1 * t2 * t4) * a
73
- )
74
- int3 = int3 / int3_denom
75
- int4_denom = (t4 - t1) * (t4 - t2) * (t4 - t3)
76
- int4 = ((1 / 4) * b**4 - (1 / 3) * (t1 + t2 + t3) * b**3 + 0.5 * (t3 * t2 + t1 * t2 + t1 * t3) * b**2 - (t1 * t2 * t3) * b) - (
77
- (1 / 4) * a**4 - (1 / 3) * (t1 + t2 + t3) * a**3 + 0.5 * (t3 * t2 + t1 * t2 + t1 * t3) * a**2 - (t1 * t2 * t3) * a
78
- )
79
- int4 = int4 / int4_denom
80
- total = int1 + int2 + int3 + int4
81
- return [int1 / total, int2 / total, int3 / total, int4 / total]
82
- raise ValueError(f"Unsupported solver order: {order}.")
83
-
84
- @register_to_config
85
- def __init__(
86
- self,
87
- num_train_timesteps: int = 1000,
88
- num_inference_steps: int = 25,
89
- guidance_scale: float = 4.0,
90
- timeshift: float = 3.0,
91
- order: int = 2,
92
- guidance_interval_min: float = 0.0,
93
- guidance_interval_max: float = 1.0,
94
- last_step: Optional[float] = None,
95
- ) -> None:
96
- self.num_inference_steps = int(num_inference_steps)
97
- self.guidance_scale = float(guidance_scale)
98
- self.timeshift = float(timeshift)
99
- self.order = int(order)
100
- self.guidance_interval_min = float(guidance_interval_min)
101
- self.guidance_interval_max = float(guidance_interval_max)
102
- self.last_step = last_step
103
- self._reset_state()
104
-
105
- @classmethod
106
- def from_sampler_spec(cls, sampler_spec: Dict[str, Any]) -> "PixNerdFlowMatchScheduler":
107
- init_args = dict(sampler_spec.get("init_args", {}))
108
- return cls(
109
- num_inference_steps=int(init_args.get("num_steps", 25)),
110
- guidance_scale=float(init_args.get("guidance", 4.0)),
111
- timeshift=float(init_args.get("timeshift", 3.0)),
112
- order=int(init_args.get("order", 2)),
113
- guidance_interval_min=float(init_args.get("guidance_interval_min", 0.0)),
114
- guidance_interval_max=float(init_args.get("guidance_interval_max", 1.0)),
115
- last_step=init_args.get("last_step"),
116
- )
117
-
118
- def _reset_state(self) -> None:
119
- self.timesteps: Optional[torch.Tensor] = None
120
- self._timedeltas: Optional[torch.Tensor] = None
121
- self._solver_coeffs = None
122
- self._model_outputs = []
123
- self._step_index = 0
124
-
125
- @staticmethod
126
- def _shift_respace_fn(t: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
127
- return t / (t + (1 - t) * shift)
128
-
129
- def _build_solver_state(
130
- self,
131
- num_inference_steps: int,
132
- timeshift: float,
133
- device: Optional[Union[str, torch.device]] = None,
134
- ) -> Tuple[torch.Tensor, torch.Tensor, List[List[float]]]:
135
- last_step = self.last_step
136
- if last_step is None:
137
- last_step = 1.0 / float(num_inference_steps)
138
-
139
- endpoints = torch.linspace(0.0, 1 - float(last_step), int(num_inference_steps), dtype=torch.float32)
140
- endpoints = torch.cat([endpoints, torch.tensor([1.0], dtype=torch.float32)], dim=0)
141
- timesteps = self._shift_respace_fn(endpoints, timeshift).to(device=device)
142
- timedeltas = (timesteps[1:] - timesteps[:-1]).to(device=device)
143
-
144
- solver_coeffs: List[List[float]] = [[] for _ in range(int(num_inference_steps))]
145
- for i in range(int(num_inference_steps)):
146
- order = min(self.order, i + 1)
147
- pre_ts = timesteps[: i + 1]
148
- coeffs = self._lagrange_coeffs(order, pre_ts, pre_ts[i], timesteps[i + 1])
149
- solver_coeffs[i] = coeffs
150
- return timesteps[:-1], timedeltas, solver_coeffs
151
-
152
- def set_timesteps(
153
- self,
154
- num_inference_steps: Optional[int] = None,
155
- device: Optional[Union[str, torch.device]] = None,
156
- timeshift: Optional[float] = None,
157
- guidance_scale: Optional[float] = None,
158
- order: Optional[int] = None,
159
- **kwargs: Any,
160
- ) -> None:
161
- if num_inference_steps is not None:
162
- self.num_inference_steps = int(num_inference_steps)
163
- if timeshift is not None:
164
- self.timeshift = float(timeshift)
165
- if guidance_scale is not None:
166
- self.guidance_scale = float(guidance_scale)
167
- if order is not None:
168
- self.order = int(order)
169
-
170
- timesteps, timedeltas, solver_coeffs = self._build_solver_state(
171
- self.num_inference_steps,
172
- self.timeshift,
173
- device=device,
174
- )
175
- self.timesteps = timesteps
176
- self._timedeltas = timedeltas
177
- self._solver_coeffs = solver_coeffs
178
- self._model_outputs = []
179
- self._step_index = 0
180
-
181
- def scale_model_input(self, sample: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
182
- return sample
183
-
184
- def classifier_free_guidance(self, model_output: torch.Tensor) -> torch.Tensor:
185
- if model_output.shape[0] % 2 != 0:
186
- raise ValueError("Classifier-free guidance expects concatenated unconditional/conditional batches.")
187
- uncond, cond = model_output.chunk(2, dim=0)
188
- return uncond + self.guidance_scale * (cond - uncond)
189
-
190
- def step(
191
- self,
192
- model_output: torch.Tensor,
193
- timestep: Union[torch.Tensor, float, int],
194
- sample: torch.Tensor,
195
- return_dict: bool = True,
196
- **kwargs: Any,
197
- ) -> Union[PixNerdSchedulerOutput, Tuple[torch.Tensor]]:
198
- if self.timesteps is None or self._timedeltas is None or self._solver_coeffs is None:
199
- raise RuntimeError("`set_timesteps` must be called before `step`.")
200
- if self._step_index >= len(self._solver_coeffs):
201
- raise RuntimeError("Scheduler step index exceeded configured timesteps.")
202
-
203
- coeffs = self._solver_coeffs[self._step_index]
204
- self._model_outputs.append(model_output)
205
- order = len(coeffs)
206
- pred = torch.zeros_like(model_output)
207
- recent = self._model_outputs[-order:]
208
- for coeff, output in zip(coeffs, recent):
209
- pred = pred + coeff * output
210
-
211
- prev_sample = sample + pred * self._timedeltas[self._step_index]
212
- self._step_index += 1
213
-
214
- if not return_dict:
215
- return (prev_sample,)
216
- return PixNerdSchedulerOutput(prev_sample=prev_sample)
217
-
218
- def add_noise(
219
- self,
220
- original_samples: torch.Tensor,
221
- noise: torch.Tensor,
222
- timesteps: torch.Tensor,
223
- ) -> torch.Tensor:
224
- alpha = timesteps.view(-1, 1, 1, 1)
225
- sigma = (1.0 - timesteps).view(-1, 1, 1, 1)
226
- return alpha * original_samples + sigma * noise
227
-
228
- __all__ = [
229
- "PixNerdFlowMatchScheduler",
230
- "PixNerdSchedulerOutput",
231
- ]
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Any, Dict, List, Optional, Tuple, Union
10
+
11
+ import torch
12
+
13
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
14
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
15
+ from diffusers.utils import BaseOutput
16
+
17
+
18
+ @dataclass
19
+ class PixNerdSchedulerOutput(BaseOutput):
20
+ prev_sample: torch.Tensor
21
+
22
+
23
+ class PixNerdFlowMatchScheduler(SchedulerMixin, ConfigMixin):
24
+ """
25
+ Diffusers-compatible scheduler wrapper for PixNerd's AdamLM flow-matching sampler.
26
+ """
27
+
28
+ config_name = "scheduler_config.json"
29
+ order = 1
30
+ init_noise_sigma = 1.0
31
+
32
+ @staticmethod
33
+ def _lagrange_coeffs(order: int, pre_ts: torch.Tensor, t_start: torch.Tensor, t_end: torch.Tensor) -> List[float]:
34
+ ts = [float(v) for v in pre_ts[-order:].tolist()]
35
+ a = float(t_start)
36
+ b = float(t_end)
37
+
38
+ if order == 1:
39
+ return [1.0]
40
+ if order == 2:
41
+ t1, t2 = ts
42
+ int1 = 0.5 / (t1 - t2) * ((b - t2) ** 2 - (a - t2) ** 2)
43
+ int2 = 0.5 / (t2 - t1) * ((b - t1) ** 2 - (a - t1) ** 2)
44
+ total = int1 + int2
45
+ return [int1 / total, int2 / total]
46
+ if order == 3:
47
+ t1, t2, t3 = ts
48
+ int1_denom = (t1 - t2) * (t1 - t3)
49
+ int1 = ((1 / 3) * b**3 - 0.5 * (t2 + t3) * b**2 + (t2 * t3) * b) - (
50
+ (1 / 3) * a**3 - 0.5 * (t2 + t3) * a**2 + (t2 * t3) * a
51
+ )
52
+ int1 = int1 / int1_denom
53
+ int2_denom = (t2 - t1) * (t2 - t3)
54
+ int2 = ((1 / 3) * b**3 - 0.5 * (t1 + t3) * b**2 + (t1 * t3) * b) - (
55
+ (1 / 3) * a**3 - 0.5 * (t1 + t3) * a**2 + (t1 * t3) * a
56
+ )
57
+ int2 = int2 / int2_denom
58
+ int3_denom = (t3 - t1) * (t3 - t2)
59
+ int3 = ((1 / 3) * b**3 - 0.5 * (t1 + t2) * b**2 + (t1 * t2) * b) - (
60
+ (1 / 3) * a**3 - 0.5 * (t1 + t2) * a**2 + (t1 * t2) * a
61
+ )
62
+ int3 = int3 / int3_denom
63
+ total = int1 + int2 + int3
64
+ return [int1 / total, int2 / total, int3 / total]
65
+ if order == 4:
66
+ t1, t2, t3, t4 = ts
67
+ int1_denom = (t1 - t2) * (t1 - t3) * (t1 - t4)
68
+ int1 = ((1 / 4) * b**4 - (1 / 3) * (t2 + t3 + t4) * b**3 + 0.5 * (t3 * t4 + t2 * t3 + t2 * t4) * b**2 - (t2 * t3 * t4) * b) - (
69
+ (1 / 4) * a**4 - (1 / 3) * (t2 + t3 + t4) * a**3 + 0.5 * (t3 * t4 + t2 * t3 + t2 * t4) * a**2 - (t2 * t3 * t4) * a
70
+ )
71
+ int1 = int1 / int1_denom
72
+ int2_denom = (t2 - t1) * (t2 - t3) * (t2 - t4)
73
+ int2 = ((1 / 4) * b**4 - (1 / 3) * (t1 + t3 + t4) * b**3 + 0.5 * (t3 * t4 + t1 * t3 + t1 * t4) * b**2 - (t1 * t3 * t4) * b) - (
74
+ (1 / 4) * a**4 - (1 / 3) * (t1 + t3 + t4) * a**3 + 0.5 * (t3 * t4 + t1 * t3 + t1 * t4) * a**2 - (t1 * t3 * t4) * a
75
+ )
76
+ int2 = int2 / int2_denom
77
+ int3_denom = (t3 - t1) * (t3 - t2) * (t3 - t4)
78
+ int3 = ((1 / 4) * b**4 - (1 / 3) * (t1 + t2 + t4) * b**3 + 0.5 * (t4 * t2 + t1 * t2 + t1 * t4) * b**2 - (t1 * t2 * t4) * b) - (
79
+ (1 / 4) * a**4 - (1 / 3) * (t1 + t2 + t4) * a**3 + 0.5 * (t4 * t2 + t1 * t2 + t1 * t4) * a**2 - (t1 * t2 * t4) * a
80
+ )
81
+ int3 = int3 / int3_denom
82
+ int4_denom = (t4 - t1) * (t4 - t2) * (t4 - t3)
83
+ int4 = ((1 / 4) * b**4 - (1 / 3) * (t1 + t2 + t3) * b**3 + 0.5 * (t3 * t2 + t1 * t2 + t1 * t3) * b**2 - (t1 * t2 * t3) * b) - (
84
+ (1 / 4) * a**4 - (1 / 3) * (t1 + t2 + t3) * a**3 + 0.5 * (t3 * t2 + t1 * t2 + t1 * t3) * a**2 - (t1 * t2 * t3) * a
85
+ )
86
+ int4 = int4 / int4_denom
87
+ total = int1 + int2 + int3 + int4
88
+ return [int1 / total, int2 / total, int3 / total, int4 / total]
89
+ raise ValueError(f"Unsupported solver order: {order}.")
90
+
91
+ @register_to_config
92
+ def __init__(
93
+ self,
94
+ num_train_timesteps: int = 1000,
95
+ num_inference_steps: int = 25,
96
+ guidance_scale: float = 4.0,
97
+ timeshift: float = 3.0,
98
+ order: int = 2,
99
+ guidance_interval_min: float = 0.0,
100
+ guidance_interval_max: float = 1.0,
101
+ last_step: Optional[float] = None,
102
+ ) -> None:
103
+ self.num_inference_steps = int(num_inference_steps)
104
+ self.guidance_scale = float(guidance_scale)
105
+ self.timeshift = float(timeshift)
106
+ self.order = int(order)
107
+ self.guidance_interval_min = float(guidance_interval_min)
108
+ self.guidance_interval_max = float(guidance_interval_max)
109
+ self.last_step = last_step
110
+ self._reset_state()
111
+
112
+ @classmethod
113
+ def from_sampler_spec(cls, sampler_spec: Dict[str, Any]) -> "PixNerdFlowMatchScheduler":
114
+ init_args = dict(sampler_spec.get("init_args", {}))
115
+ return cls(
116
+ num_inference_steps=int(init_args.get("num_steps", 25)),
117
+ guidance_scale=float(init_args.get("guidance", 4.0)),
118
+ timeshift=float(init_args.get("timeshift", 3.0)),
119
+ order=int(init_args.get("order", 2)),
120
+ guidance_interval_min=float(init_args.get("guidance_interval_min", 0.0)),
121
+ guidance_interval_max=float(init_args.get("guidance_interval_max", 1.0)),
122
+ last_step=init_args.get("last_step"),
123
+ )
124
+
125
+ def _reset_state(self) -> None:
126
+ self.timesteps: Optional[torch.Tensor] = None
127
+ self._timedeltas: Optional[torch.Tensor] = None
128
+ self._solver_coeffs = None
129
+ self._model_outputs = []
130
+ self._step_index = 0
131
+
132
+ @staticmethod
133
+ def _shift_respace_fn(t: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
134
+ return t / (t + (1 - t) * shift)
135
+
136
+ def _build_solver_state(
137
+ self,
138
+ num_inference_steps: int,
139
+ timeshift: float,
140
+ device: Optional[Union[str, torch.device]] = None,
141
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[List[float]]]:
142
+ last_step = self.last_step
143
+ if last_step is None:
144
+ last_step = 1.0 / float(num_inference_steps)
145
+
146
+ endpoints = torch.linspace(0.0, 1 - float(last_step), int(num_inference_steps), dtype=torch.float32)
147
+ endpoints = torch.cat([endpoints, torch.tensor([1.0], dtype=torch.float32)], dim=0)
148
+ timesteps = self._shift_respace_fn(endpoints, timeshift).to(device=device)
149
+ timedeltas = (timesteps[1:] - timesteps[:-1]).to(device=device)
150
+
151
+ solver_coeffs: List[List[float]] = [[] for _ in range(int(num_inference_steps))]
152
+ for i in range(int(num_inference_steps)):
153
+ order = min(self.order, i + 1)
154
+ pre_ts = timesteps[: i + 1]
155
+ coeffs = self._lagrange_coeffs(order, pre_ts, pre_ts[i], timesteps[i + 1])
156
+ solver_coeffs[i] = coeffs
157
+ return timesteps[:-1], timedeltas, solver_coeffs
158
+
159
+ def set_timesteps(
160
+ self,
161
+ num_inference_steps: Optional[int] = None,
162
+ device: Optional[Union[str, torch.device]] = None,
163
+ timeshift: Optional[float] = None,
164
+ guidance_scale: Optional[float] = None,
165
+ order: Optional[int] = None,
166
+ **kwargs: Any,
167
+ ) -> None:
168
+ if num_inference_steps is not None:
169
+ self.num_inference_steps = int(num_inference_steps)
170
+ if timeshift is not None:
171
+ self.timeshift = float(timeshift)
172
+ else:
173
+ self.timeshift = float(getattr(self.config, "timeshift", self.timeshift))
174
+ if guidance_scale is not None:
175
+ self.guidance_scale = float(guidance_scale)
176
+ if order is not None:
177
+ self.order = int(order)
178
+ else:
179
+ self.order = int(getattr(self.config, "order", self.order))
180
+
181
+ timesteps, timedeltas, solver_coeffs = self._build_solver_state(
182
+ self.num_inference_steps,
183
+ self.timeshift,
184
+ device=device,
185
+ )
186
+ self.timesteps = timesteps
187
+ self._timedeltas = timedeltas
188
+ self._solver_coeffs = solver_coeffs
189
+ self._model_outputs = []
190
+ self._step_index = 0
191
+
192
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
193
+ return sample
194
+
195
+ def classifier_free_guidance(self, model_output: torch.Tensor) -> torch.Tensor:
196
+ if model_output.shape[0] % 2 != 0:
197
+ raise ValueError("Classifier-free guidance expects concatenated unconditional/conditional batches.")
198
+ uncond, cond = model_output.chunk(2, dim=0)
199
+ return uncond + self.guidance_scale * (cond - uncond)
200
+
201
+ def step(
202
+ self,
203
+ model_output: torch.Tensor,
204
+ timestep: Union[torch.Tensor, float, int],
205
+ sample: torch.Tensor,
206
+ return_dict: bool = True,
207
+ **kwargs: Any,
208
+ ) -> Union[PixNerdSchedulerOutput, Tuple[torch.Tensor]]:
209
+ if self.timesteps is None or self._timedeltas is None or self._solver_coeffs is None:
210
+ raise RuntimeError("`set_timesteps` must be called before `step`.")
211
+ if self._step_index >= len(self._solver_coeffs):
212
+ raise RuntimeError("Scheduler step index exceeded configured timesteps.")
213
+
214
+ coeffs = self._solver_coeffs[self._step_index]
215
+ self._model_outputs.append(model_output)
216
+ order = len(coeffs)
217
+ pred = torch.zeros_like(model_output)
218
+ recent = self._model_outputs[-order:]
219
+ for coeff, output in zip(coeffs, recent):
220
+ pred = pred + coeff * output
221
+
222
+ prev_sample = sample + pred * self._timedeltas[self._step_index]
223
+ self._step_index += 1
224
+
225
+ if not return_dict:
226
+ return (prev_sample,)
227
+ return PixNerdSchedulerOutput(prev_sample=prev_sample)
228
+
229
+ def add_noise(
230
+ self,
231
+ original_samples: torch.Tensor,
232
+ noise: torch.Tensor,
233
+ timesteps: torch.Tensor,
234
+ ) -> torch.Tensor:
235
+ alpha = timesteps.view(-1, 1, 1, 1)
236
+ sigma = (1.0 - timesteps).view(-1, 1, 1, 1)
237
+ return alpha * original_samples + sigma * noise
PixNerd-XL-16-256/transformer/modeling_pixnerd_transformer_2d.py CHANGED
@@ -20,6 +20,15 @@ class BaseAE(torch.nn.Module):
20
  super().__init__()
21
  self.scale = scale
22
  self.shift = shift
 
 
 
 
 
 
 
 
 
23
 
24
  def encode(self, x):
25
  return self._impl_encode(x) #.to(torch.bfloat16)
@@ -68,6 +77,15 @@ def resolve_conditioner_device(metadata: dict, fallback: torch.device | None = N
68
  class BaseConditioner(nn.Module):
69
  def __init__(self):
70
  super(BaseConditioner, self).__init__()
 
 
 
 
 
 
 
 
 
71
 
72
  def _impl_condition(self, y, metadata)->torch.Tensor:
73
  raise NotImplementedError()
@@ -166,6 +184,7 @@ class TimestepEmbedder(nn.Module):
166
 
167
  def forward(self, t):
168
  t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
 
169
  t_emb = self.mlp(t_freq)
170
  return t_emb
171
 
 
20
  super().__init__()
21
  self.scale = scale
22
  self.shift = shift
23
+ self.register_buffer("_diffusers_device_anchor", torch.zeros(0), persistent=False)
24
+
25
+ @property
26
+ def dtype(self) -> torch.dtype:
27
+ return self._diffusers_device_anchor.dtype
28
+
29
+ @property
30
+ def device(self) -> torch.device:
31
+ return self._diffusers_device_anchor.device
32
 
33
  def encode(self, x):
34
  return self._impl_encode(x) #.to(torch.bfloat16)
 
77
  class BaseConditioner(nn.Module):
78
  def __init__(self):
79
  super(BaseConditioner, self).__init__()
80
+ self.register_buffer("_diffusers_device_anchor", torch.zeros(0), persistent=False)
81
+
82
+ @property
83
+ def dtype(self) -> torch.dtype:
84
+ return self._diffusers_device_anchor.dtype
85
+
86
+ @property
87
+ def device(self) -> torch.device:
88
+ return self._diffusers_device_anchor.device
89
 
90
  def _impl_condition(self, y, metadata)->torch.Tensor:
91
  raise NotImplementedError()
 
184
 
185
  def forward(self, t):
186
  t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
187
+ t_freq = t_freq.to(dtype=self.mlp[0].weight.dtype)
188
  t_emb = self.mlp(t_freq)
189
  return t_emb
190
 
PixNerd-XL-16-512/README.md CHANGED
@@ -26,17 +26,17 @@ from diffusers import DiffusionPipeline
26
  pipe = DiffusionPipeline.from_pretrained(
27
  "BiliSakura/PixNerd-diffusers/PixNerd-XL-16-512",
28
  trust_remote_code=True,
29
- torch_dtype=torch.float32,
30
  ).to("cuda")
31
 
 
 
32
  images = pipe(
33
- prompt=207,
34
  height=512,
35
  width=512,
36
  num_inference_steps=25,
37
  guidance_scale=4.0,
38
- timeshift=3.0,
39
- order=2,
40
  ).images
41
  ```
42
 
@@ -44,4 +44,4 @@ images = pipe(
44
 
45
  ![PixNerd-XL-16-512 demo](demo.png)
46
 
47
- Class 207 β€” golden retriever / ι‡‘ζ―›ηŒŽηŠ¬, 512Γ—512, 25 steps.
 
26
  pipe = DiffusionPipeline.from_pretrained(
27
  "BiliSakura/PixNerd-diffusers/PixNerd-XL-16-512",
28
  trust_remote_code=True,
29
+ torch_dtype=torch.bfloat16,
30
  ).to("cuda")
31
 
32
+ # timeshift=3.0 and order=2 are defaults in scheduler/scheduler_config.json
33
+
34
  images = pipe(
35
+ class_labels="golden retriever",
36
  height=512,
37
  width=512,
38
  num_inference_steps=25,
39
  guidance_scale=4.0,
 
 
40
  ).images
41
  ```
42
 
 
44
 
45
  ![PixNerd-XL-16-512 demo](demo.png)
46
 
47
+ Class 207 β€” golden retriever, 512Γ—512, 25 steps.
PixNerd-XL-16-512/model_index.json CHANGED
@@ -1,15 +1,1017 @@
1
- {
2
- "_class_name": [
3
- "pipeline",
4
- "PixNerdPipeline"
5
- ],
6
- "_diffusers_version": "0.36.0",
7
- "scheduler": [
8
- "scheduling_pixnerd_flow_match",
9
- "PixNerdFlowMatchScheduler"
10
- ],
11
- "transformer": [
12
- "modeling_pixnerd_transformer_2d",
13
- "PixNerdTransformer2DModel"
14
- ]
15
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "PixNerdPipeline"
5
+ ],
6
+ "_diffusers_version": "0.36.0",
7
+ "scheduler": [
8
+ "scheduling_pixnerd_flow_match",
9
+ "PixNerdFlowMatchScheduler"
10
+ ],
11
+ "transformer": [
12
+ "modeling_pixnerd_transformer_2d",
13
+ "PixNerdTransformer2DModel"
14
+ ],
15
+ "id2label": {
16
+ "0": "tench, Tinca tinca",
17
+ "1": "goldfish, Carassius auratus",
18
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
19
+ "3": "tiger shark, Galeocerdo cuvieri",
20
+ "4": "hammerhead, hammerhead shark",
21
+ "5": "electric ray, crampfish, numbfish, torpedo",
22
+ "6": "stingray",
23
+ "7": "cock",
24
+ "8": "hen",
25
+ "9": "ostrich, Struthio camelus",
26
+ "10": "brambling, Fringilla montifringilla",
27
+ "11": "goldfinch, Carduelis carduelis",
28
+ "12": "house finch, linnet, Carpodacus mexicanus",
29
+ "13": "junco, snowbird",
30
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
31
+ "15": "robin, American robin, Turdus migratorius",
32
+ "16": "bulbul",
33
+ "17": "jay",
34
+ "18": "magpie",
35
+ "19": "chickadee",
36
+ "20": "water ouzel, dipper",
37
+ "21": "kite",
38
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
39
+ "23": "vulture",
40
+ "24": "great grey owl, great gray owl, Strix nebulosa",
41
+ "25": "European fire salamander, Salamandra salamandra",
42
+ "26": "common newt, Triturus vulgaris",
43
+ "27": "eft",
44
+ "28": "spotted salamander, Ambystoma maculatum",
45
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
46
+ "30": "bullfrog, Rana catesbeiana",
47
+ "31": "tree frog, tree-frog",
48
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
49
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
50
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
51
+ "35": "mud turtle",
52
+ "36": "terrapin",
53
+ "37": "box turtle, box tortoise",
54
+ "38": "banded gecko",
55
+ "39": "common iguana, iguana, Iguana iguana",
56
+ "40": "American chameleon, anole, Anolis carolinensis",
57
+ "41": "whiptail, whiptail lizard",
58
+ "42": "agama",
59
+ "43": "frilled lizard, Chlamydosaurus kingi",
60
+ "44": "alligator lizard",
61
+ "45": "Gila monster, Heloderma suspectum",
62
+ "46": "green lizard, Lacerta viridis",
63
+ "47": "African chameleon, Chamaeleo chamaeleon",
64
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
65
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
66
+ "50": "American alligator, Alligator mississipiensis",
67
+ "51": "triceratops",
68
+ "52": "thunder snake, worm snake, Carphophis amoenus",
69
+ "53": "ringneck snake, ring-necked snake, ring snake",
70
+ "54": "hognose snake, puff adder, sand viper",
71
+ "55": "green snake, grass snake",
72
+ "56": "king snake, kingsnake",
73
+ "57": "garter snake, grass snake",
74
+ "58": "water snake",
75
+ "59": "vine snake",
76
+ "60": "night snake, Hypsiglena torquata",
77
+ "61": "boa constrictor, Constrictor constrictor",
78
+ "62": "rock python, rock snake, Python sebae",
79
+ "63": "Indian cobra, Naja naja",
80
+ "64": "green mamba",
81
+ "65": "sea snake",
82
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
83
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
84
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
85
+ "69": "trilobite",
86
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
87
+ "71": "scorpion",
88
+ "72": "black and gold garden spider, Argiope aurantia",
89
+ "73": "barn spider, Araneus cavaticus",
90
+ "74": "garden spider, Aranea diademata",
91
+ "75": "black widow, Latrodectus mactans",
92
+ "76": "tarantula",
93
+ "77": "wolf spider, hunting spider",
94
+ "78": "tick",
95
+ "79": "centipede",
96
+ "80": "black grouse",
97
+ "81": "ptarmigan",
98
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
99
+ "83": "prairie chicken, prairie grouse, prairie fowl",
100
+ "84": "peacock",
101
+ "85": "quail",
102
+ "86": "partridge",
103
+ "87": "African grey, African gray, Psittacus erithacus",
104
+ "88": "macaw",
105
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
106
+ "90": "lorikeet",
107
+ "91": "coucal",
108
+ "92": "bee eater",
109
+ "93": "hornbill",
110
+ "94": "hummingbird",
111
+ "95": "jacamar",
112
+ "96": "toucan",
113
+ "97": "drake",
114
+ "98": "red-breasted merganser, Mergus serrator",
115
+ "99": "goose",
116
+ "100": "black swan, Cygnus atratus",
117
+ "101": "tusker",
118
+ "102": "echidna, spiny anteater, anteater",
119
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
120
+ "104": "wallaby, brush kangaroo",
121
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
122
+ "106": "wombat",
123
+ "107": "jellyfish",
124
+ "108": "sea anemone, anemone",
125
+ "109": "brain coral",
126
+ "110": "flatworm, platyhelminth",
127
+ "111": "nematode, nematode worm, roundworm",
128
+ "112": "conch",
129
+ "113": "snail",
130
+ "114": "slug",
131
+ "115": "sea slug, nudibranch",
132
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
133
+ "117": "chambered nautilus, pearly nautilus, nautilus",
134
+ "118": "Dungeness crab, Cancer magister",
135
+ "119": "rock crab, Cancer irroratus",
136
+ "120": "fiddler crab",
137
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
138
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
139
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
140
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
141
+ "125": "hermit crab",
142
+ "126": "isopod",
143
+ "127": "white stork, Ciconia ciconia",
144
+ "128": "black stork, Ciconia nigra",
145
+ "129": "spoonbill",
146
+ "130": "flamingo",
147
+ "131": "little blue heron, Egretta caerulea",
148
+ "132": "American egret, great white heron, Egretta albus",
149
+ "133": "bittern",
150
+ "134": "crane",
151
+ "135": "limpkin, Aramus pictus",
152
+ "136": "European gallinule, Porphyrio porphyrio",
153
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
154
+ "138": "bustard",
155
+ "139": "ruddy turnstone, Arenaria interpres",
156
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
157
+ "141": "redshank, Tringa totanus",
158
+ "142": "dowitcher",
159
+ "143": "oystercatcher, oyster catcher",
160
+ "144": "pelican",
161
+ "145": "king penguin, Aptenodytes patagonica",
162
+ "146": "albatross, mollymawk",
163
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
164
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
165
+ "149": "dugong, Dugong dugon",
166
+ "150": "sea lion",
167
+ "151": "Chihuahua",
168
+ "152": "Japanese spaniel",
169
+ "153": "Maltese dog, Maltese terrier, Maltese",
170
+ "154": "Pekinese, Pekingese, Peke",
171
+ "155": "Shih-Tzu",
172
+ "156": "Blenheim spaniel",
173
+ "157": "papillon",
174
+ "158": "toy terrier",
175
+ "159": "Rhodesian ridgeback",
176
+ "160": "Afghan hound, Afghan",
177
+ "161": "basset, basset hound",
178
+ "162": "beagle",
179
+ "163": "bloodhound, sleuthhound",
180
+ "164": "bluetick",
181
+ "165": "black-and-tan coonhound",
182
+ "166": "Walker hound, Walker foxhound",
183
+ "167": "English foxhound",
184
+ "168": "redbone",
185
+ "169": "borzoi, Russian wolfhound",
186
+ "170": "Irish wolfhound",
187
+ "171": "Italian greyhound",
188
+ "172": "whippet",
189
+ "173": "Ibizan hound, Ibizan Podenco",
190
+ "174": "Norwegian elkhound, elkhound",
191
+ "175": "otterhound, otter hound",
192
+ "176": "Saluki, gazelle hound",
193
+ "177": "Scottish deerhound, deerhound",
194
+ "178": "Weimaraner",
195
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
196
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
197
+ "181": "Bedlington terrier",
198
+ "182": "Border terrier",
199
+ "183": "Kerry blue terrier",
200
+ "184": "Irish terrier",
201
+ "185": "Norfolk terrier",
202
+ "186": "Norwich terrier",
203
+ "187": "Yorkshire terrier",
204
+ "188": "wire-haired fox terrier",
205
+ "189": "Lakeland terrier",
206
+ "190": "Sealyham terrier, Sealyham",
207
+ "191": "Airedale, Airedale terrier",
208
+ "192": "cairn, cairn terrier",
209
+ "193": "Australian terrier",
210
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
211
+ "195": "Boston bull, Boston terrier",
212
+ "196": "miniature schnauzer",
213
+ "197": "giant schnauzer",
214
+ "198": "standard schnauzer",
215
+ "199": "Scotch terrier, Scottish terrier, Scottie",
216
+ "200": "Tibetan terrier, chrysanthemum dog",
217
+ "201": "silky terrier, Sydney silky",
218
+ "202": "soft-coated wheaten terrier",
219
+ "203": "West Highland white terrier",
220
+ "204": "Lhasa, Lhasa apso",
221
+ "205": "flat-coated retriever",
222
+ "206": "curly-coated retriever",
223
+ "207": "golden retriever",
224
+ "208": "Labrador retriever",
225
+ "209": "Chesapeake Bay retriever",
226
+ "210": "German short-haired pointer",
227
+ "211": "vizsla, Hungarian pointer",
228
+ "212": "English setter",
229
+ "213": "Irish setter, red setter",
230
+ "214": "Gordon setter",
231
+ "215": "Brittany spaniel",
232
+ "216": "clumber, clumber spaniel",
233
+ "217": "English springer, English springer spaniel",
234
+ "218": "Welsh springer spaniel",
235
+ "219": "cocker spaniel, English cocker spaniel, cocker",
236
+ "220": "Sussex spaniel",
237
+ "221": "Irish water spaniel",
238
+ "222": "kuvasz",
239
+ "223": "schipperke",
240
+ "224": "groenendael",
241
+ "225": "malinois",
242
+ "226": "briard",
243
+ "227": "kelpie",
244
+ "228": "komondor",
245
+ "229": "Old English sheepdog, bobtail",
246
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
247
+ "231": "collie",
248
+ "232": "Border collie",
249
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
250
+ "234": "Rottweiler",
251
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
252
+ "236": "Doberman, Doberman pinscher",
253
+ "237": "miniature pinscher",
254
+ "238": "Greater Swiss Mountain dog",
255
+ "239": "Bernese mountain dog",
256
+ "240": "Appenzeller",
257
+ "241": "EntleBucher",
258
+ "242": "boxer",
259
+ "243": "bull mastiff",
260
+ "244": "Tibetan mastiff",
261
+ "245": "French bulldog",
262
+ "246": "Great Dane",
263
+ "247": "Saint Bernard, St Bernard",
264
+ "248": "Eskimo dog, husky",
265
+ "249": "malamute, malemute, Alaskan malamute",
266
+ "250": "Siberian husky",
267
+ "251": "dalmatian, coach dog, carriage dog",
268
+ "252": "affenpinscher, monkey pinscher, monkey dog",
269
+ "253": "basenji",
270
+ "254": "pug, pug-dog",
271
+ "255": "Leonberg",
272
+ "256": "Newfoundland, Newfoundland dog",
273
+ "257": "Great Pyrenees",
274
+ "258": "Samoyed, Samoyede",
275
+ "259": "Pomeranian",
276
+ "260": "chow, chow chow",
277
+ "261": "keeshond",
278
+ "262": "Brabancon griffon",
279
+ "263": "Pembroke, Pembroke Welsh corgi",
280
+ "264": "Cardigan, Cardigan Welsh corgi",
281
+ "265": "toy poodle",
282
+ "266": "miniature poodle",
283
+ "267": "standard poodle",
284
+ "268": "Mexican hairless",
285
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
286
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
287
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
288
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
289
+ "273": "dingo, warrigal, warragal, Canis dingo",
290
+ "274": "dhole, Cuon alpinus",
291
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
292
+ "276": "hyena, hyaena",
293
+ "277": "red fox, Vulpes vulpes",
294
+ "278": "kit fox, Vulpes macrotis",
295
+ "279": "Arctic fox, white fox, Alopex lagopus",
296
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
297
+ "281": "tabby, tabby cat",
298
+ "282": "tiger cat",
299
+ "283": "Persian cat",
300
+ "284": "Siamese cat, Siamese",
301
+ "285": "Egyptian cat",
302
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
303
+ "287": "lynx, catamount",
304
+ "288": "leopard, Panthera pardus",
305
+ "289": "snow leopard, ounce, Panthera uncia",
306
+ "290": "jaguar, panther, Panthera onca, Felis onca",
307
+ "291": "lion, king of beasts, Panthera leo",
308
+ "292": "tiger, Panthera tigris",
309
+ "293": "cheetah, chetah, Acinonyx jubatus",
310
+ "294": "brown bear, bruin, Ursus arctos",
311
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
312
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
313
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
314
+ "298": "mongoose",
315
+ "299": "meerkat, mierkat",
316
+ "300": "tiger beetle",
317
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
318
+ "302": "ground beetle, carabid beetle",
319
+ "303": "long-horned beetle, longicorn, longicorn beetle",
320
+ "304": "leaf beetle, chrysomelid",
321
+ "305": "dung beetle",
322
+ "306": "rhinoceros beetle",
323
+ "307": "weevil",
324
+ "308": "fly",
325
+ "309": "bee",
326
+ "310": "ant, emmet, pismire",
327
+ "311": "grasshopper, hopper",
328
+ "312": "cricket",
329
+ "313": "walking stick, walkingstick, stick insect",
330
+ "314": "cockroach, roach",
331
+ "315": "mantis, mantid",
332
+ "316": "cicada, cicala",
333
+ "317": "leafhopper",
334
+ "318": "lacewing, lacewing fly",
335
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
336
+ "320": "damselfly",
337
+ "321": "admiral",
338
+ "322": "ringlet, ringlet butterfly",
339
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
340
+ "324": "cabbage butterfly",
341
+ "325": "sulphur butterfly, sulfur butterfly",
342
+ "326": "lycaenid, lycaenid butterfly",
343
+ "327": "starfish, sea star",
344
+ "328": "sea urchin",
345
+ "329": "sea cucumber, holothurian",
346
+ "330": "wood rabbit, cottontail, cottontail rabbit",
347
+ "331": "hare",
348
+ "332": "Angora, Angora rabbit",
349
+ "333": "hamster",
350
+ "334": "porcupine, hedgehog",
351
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
352
+ "336": "marmot",
353
+ "337": "beaver",
354
+ "338": "guinea pig, Cavia cobaya",
355
+ "339": "sorrel",
356
+ "340": "zebra",
357
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
358
+ "342": "wild boar, boar, Sus scrofa",
359
+ "343": "warthog",
360
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
361
+ "345": "ox",
362
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
363
+ "347": "bison",
364
+ "348": "ram, tup",
365
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
366
+ "350": "ibex, Capra ibex",
367
+ "351": "hartebeest",
368
+ "352": "impala, Aepyceros melampus",
369
+ "353": "gazelle",
370
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
371
+ "355": "llama",
372
+ "356": "weasel",
373
+ "357": "mink",
374
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
375
+ "359": "black-footed ferret, ferret, Mustela nigripes",
376
+ "360": "otter",
377
+ "361": "skunk, polecat, wood pussy",
378
+ "362": "badger",
379
+ "363": "armadillo",
380
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
381
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
382
+ "366": "gorilla, Gorilla gorilla",
383
+ "367": "chimpanzee, chimp, Pan troglodytes",
384
+ "368": "gibbon, Hylobates lar",
385
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
386
+ "370": "guenon, guenon monkey",
387
+ "371": "patas, hussar monkey, Erythrocebus patas",
388
+ "372": "baboon",
389
+ "373": "macaque",
390
+ "374": "langur",
391
+ "375": "colobus, colobus monkey",
392
+ "376": "proboscis monkey, Nasalis larvatus",
393
+ "377": "marmoset",
394
+ "378": "capuchin, ringtail, Cebus capucinus",
395
+ "379": "howler monkey, howler",
396
+ "380": "titi, titi monkey",
397
+ "381": "spider monkey, Ateles geoffroyi",
398
+ "382": "squirrel monkey, Saimiri sciureus",
399
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
400
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
401
+ "385": "Indian elephant, Elephas maximus",
402
+ "386": "African elephant, Loxodonta africana",
403
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
404
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
405
+ "389": "barracouta, snoek",
406
+ "390": "eel",
407
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
408
+ "392": "rock beauty, Holocanthus tricolor",
409
+ "393": "anemone fish",
410
+ "394": "sturgeon",
411
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
412
+ "396": "lionfish",
413
+ "397": "puffer, pufferfish, blowfish, globefish",
414
+ "398": "abacus",
415
+ "399": "abaya",
416
+ "400": "academic gown, academic robe, judge robe",
417
+ "401": "accordion, piano accordion, squeeze box",
418
+ "402": "acoustic guitar",
419
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
420
+ "404": "airliner",
421
+ "405": "airship, dirigible",
422
+ "406": "altar",
423
+ "407": "ambulance",
424
+ "408": "amphibian, amphibious vehicle",
425
+ "409": "analog clock",
426
+ "410": "apiary, bee house",
427
+ "411": "apron",
428
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
429
+ "413": "assault rifle, assault gun",
430
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
431
+ "415": "bakery, bakeshop, bakehouse",
432
+ "416": "balance beam, beam",
433
+ "417": "balloon",
434
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
435
+ "419": "Band Aid",
436
+ "420": "banjo",
437
+ "421": "bannister, banister, balustrade, balusters, handrail",
438
+ "422": "barbell",
439
+ "423": "barber chair",
440
+ "424": "barbershop",
441
+ "425": "barn",
442
+ "426": "barometer",
443
+ "427": "barrel, cask",
444
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
445
+ "429": "baseball",
446
+ "430": "basketball",
447
+ "431": "bassinet",
448
+ "432": "bassoon",
449
+ "433": "bathing cap, swimming cap",
450
+ "434": "bath towel",
451
+ "435": "bathtub, bathing tub, bath, tub",
452
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
453
+ "437": "beacon, lighthouse, beacon light, pharos",
454
+ "438": "beaker",
455
+ "439": "bearskin, busby, shako",
456
+ "440": "beer bottle",
457
+ "441": "beer glass",
458
+ "442": "bell cote, bell cot",
459
+ "443": "bib",
460
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
461
+ "445": "bikini, two-piece",
462
+ "446": "binder, ring-binder",
463
+ "447": "binoculars, field glasses, opera glasses",
464
+ "448": "birdhouse",
465
+ "449": "boathouse",
466
+ "450": "bobsled, bobsleigh, bob",
467
+ "451": "bolo tie, bolo, bola tie, bola",
468
+ "452": "bonnet, poke bonnet",
469
+ "453": "bookcase",
470
+ "454": "bookshop, bookstore, bookstall",
471
+ "455": "bottlecap",
472
+ "456": "bow",
473
+ "457": "bow tie, bow-tie, bowtie",
474
+ "458": "brass, memorial tablet, plaque",
475
+ "459": "brassiere, bra, bandeau",
476
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
477
+ "461": "breastplate, aegis, egis",
478
+ "462": "broom",
479
+ "463": "bucket, pail",
480
+ "464": "buckle",
481
+ "465": "bulletproof vest",
482
+ "466": "bullet train, bullet",
483
+ "467": "butcher shop, meat market",
484
+ "468": "cab, hack, taxi, taxicab",
485
+ "469": "caldron, cauldron",
486
+ "470": "candle, taper, wax light",
487
+ "471": "cannon",
488
+ "472": "canoe",
489
+ "473": "can opener, tin opener",
490
+ "474": "cardigan",
491
+ "475": "car mirror",
492
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
493
+ "477": "carpenters kit, tool kit",
494
+ "478": "carton",
495
+ "479": "car wheel",
496
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
497
+ "481": "cassette",
498
+ "482": "cassette player",
499
+ "483": "castle",
500
+ "484": "catamaran",
501
+ "485": "CD player",
502
+ "486": "cello, violoncello",
503
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
504
+ "488": "chain",
505
+ "489": "chainlink fence",
506
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
507
+ "491": "chain saw, chainsaw",
508
+ "492": "chest",
509
+ "493": "chiffonier, commode",
510
+ "494": "chime, bell, gong",
511
+ "495": "china cabinet, china closet",
512
+ "496": "Christmas stocking",
513
+ "497": "church, church building",
514
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
515
+ "499": "cleaver, meat cleaver, chopper",
516
+ "500": "cliff dwelling",
517
+ "501": "cloak",
518
+ "502": "clog, geta, patten, sabot",
519
+ "503": "cocktail shaker",
520
+ "504": "coffee mug",
521
+ "505": "coffeepot",
522
+ "506": "coil, spiral, volute, whorl, helix",
523
+ "507": "combination lock",
524
+ "508": "computer keyboard, keypad",
525
+ "509": "confectionery, confectionary, candy store",
526
+ "510": "container ship, containership, container vessel",
527
+ "511": "convertible",
528
+ "512": "corkscrew, bottle screw",
529
+ "513": "cornet, horn, trumpet, trump",
530
+ "514": "cowboy boot",
531
+ "515": "cowboy hat, ten-gallon hat",
532
+ "516": "cradle",
533
+ "517": "crane",
534
+ "518": "crash helmet",
535
+ "519": "crate",
536
+ "520": "crib, cot",
537
+ "521": "Crock Pot",
538
+ "522": "croquet ball",
539
+ "523": "crutch",
540
+ "524": "cuirass",
541
+ "525": "dam, dike, dyke",
542
+ "526": "desk",
543
+ "527": "desktop computer",
544
+ "528": "dial telephone, dial phone",
545
+ "529": "diaper, nappy, napkin",
546
+ "530": "digital clock",
547
+ "531": "digital watch",
548
+ "532": "dining table, board",
549
+ "533": "dishrag, dishcloth",
550
+ "534": "dishwasher, dish washer, dishwashing machine",
551
+ "535": "disk brake, disc brake",
552
+ "536": "dock, dockage, docking facility",
553
+ "537": "dogsled, dog sled, dog sleigh",
554
+ "538": "dome",
555
+ "539": "doormat, welcome mat",
556
+ "540": "drilling platform, offshore rig",
557
+ "541": "drum, membranophone, tympan",
558
+ "542": "drumstick",
559
+ "543": "dumbbell",
560
+ "544": "Dutch oven",
561
+ "545": "electric fan, blower",
562
+ "546": "electric guitar",
563
+ "547": "electric locomotive",
564
+ "548": "entertainment center",
565
+ "549": "envelope",
566
+ "550": "espresso maker",
567
+ "551": "face powder",
568
+ "552": "feather boa, boa",
569
+ "553": "file, file cabinet, filing cabinet",
570
+ "554": "fireboat",
571
+ "555": "fire engine, fire truck",
572
+ "556": "fire screen, fireguard",
573
+ "557": "flagpole, flagstaff",
574
+ "558": "flute, transverse flute",
575
+ "559": "folding chair",
576
+ "560": "football helmet",
577
+ "561": "forklift",
578
+ "562": "fountain",
579
+ "563": "fountain pen",
580
+ "564": "four-poster",
581
+ "565": "freight car",
582
+ "566": "French horn, horn",
583
+ "567": "frying pan, frypan, skillet",
584
+ "568": "fur coat",
585
+ "569": "garbage truck, dustcart",
586
+ "570": "gasmask, respirator, gas helmet",
587
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
588
+ "572": "goblet",
589
+ "573": "go-kart",
590
+ "574": "golf ball",
591
+ "575": "golfcart, golf cart",
592
+ "576": "gondola",
593
+ "577": "gong, tam-tam",
594
+ "578": "gown",
595
+ "579": "grand piano, grand",
596
+ "580": "greenhouse, nursery, glasshouse",
597
+ "581": "grille, radiator grille",
598
+ "582": "grocery store, grocery, food market, market",
599
+ "583": "guillotine",
600
+ "584": "hair slide",
601
+ "585": "hair spray",
602
+ "586": "half track",
603
+ "587": "hammer",
604
+ "588": "hamper",
605
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
606
+ "590": "hand-held computer, hand-held microcomputer",
607
+ "591": "handkerchief, hankie, hanky, hankey",
608
+ "592": "hard disc, hard disk, fixed disk",
609
+ "593": "harmonica, mouth organ, harp, mouth harp",
610
+ "594": "harp",
611
+ "595": "harvester, reaper",
612
+ "596": "hatchet",
613
+ "597": "holster",
614
+ "598": "home theater, home theatre",
615
+ "599": "honeycomb",
616
+ "600": "hook, claw",
617
+ "601": "hoopskirt, crinoline",
618
+ "602": "horizontal bar, high bar",
619
+ "603": "horse cart, horse-cart",
620
+ "604": "hourglass",
621
+ "605": "iPod",
622
+ "606": "iron, smoothing iron",
623
+ "607": "jack-o-lantern",
624
+ "608": "jean, blue jean, denim",
625
+ "609": "jeep, landrover",
626
+ "610": "jersey, T-shirt, tee shirt",
627
+ "611": "jigsaw puzzle",
628
+ "612": "jinrikisha, ricksha, rickshaw",
629
+ "613": "joystick",
630
+ "614": "kimono",
631
+ "615": "knee pad",
632
+ "616": "knot",
633
+ "617": "lab coat, laboratory coat",
634
+ "618": "ladle",
635
+ "619": "lampshade, lamp shade",
636
+ "620": "laptop, laptop computer",
637
+ "621": "lawn mower, mower",
638
+ "622": "lens cap, lens cover",
639
+ "623": "letter opener, paper knife, paperknife",
640
+ "624": "library",
641
+ "625": "lifeboat",
642
+ "626": "lighter, light, igniter, ignitor",
643
+ "627": "limousine, limo",
644
+ "628": "liner, ocean liner",
645
+ "629": "lipstick, lip rouge",
646
+ "630": "Loafer",
647
+ "631": "lotion",
648
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
649
+ "633": "loupe, jewelers loupe",
650
+ "634": "lumbermill, sawmill",
651
+ "635": "magnetic compass",
652
+ "636": "mailbag, postbag",
653
+ "637": "mailbox, letter box",
654
+ "638": "maillot",
655
+ "639": "maillot, tank suit",
656
+ "640": "manhole cover",
657
+ "641": "maraca",
658
+ "642": "marimba, xylophone",
659
+ "643": "mask",
660
+ "644": "matchstick",
661
+ "645": "maypole",
662
+ "646": "maze, labyrinth",
663
+ "647": "measuring cup",
664
+ "648": "medicine chest, medicine cabinet",
665
+ "649": "megalith, megalithic structure",
666
+ "650": "microphone, mike",
667
+ "651": "microwave, microwave oven",
668
+ "652": "military uniform",
669
+ "653": "milk can",
670
+ "654": "minibus",
671
+ "655": "miniskirt, mini",
672
+ "656": "minivan",
673
+ "657": "missile",
674
+ "658": "mitten",
675
+ "659": "mixing bowl",
676
+ "660": "mobile home, manufactured home",
677
+ "661": "Model T",
678
+ "662": "modem",
679
+ "663": "monastery",
680
+ "664": "monitor",
681
+ "665": "moped",
682
+ "666": "mortar",
683
+ "667": "mortarboard",
684
+ "668": "mosque",
685
+ "669": "mosquito net",
686
+ "670": "motor scooter, scooter",
687
+ "671": "mountain bike, all-terrain bike, off-roader",
688
+ "672": "mountain tent",
689
+ "673": "mouse, computer mouse",
690
+ "674": "mousetrap",
691
+ "675": "moving van",
692
+ "676": "muzzle",
693
+ "677": "nail",
694
+ "678": "neck brace",
695
+ "679": "necklace",
696
+ "680": "nipple",
697
+ "681": "notebook, notebook computer",
698
+ "682": "obelisk",
699
+ "683": "oboe, hautboy, hautbois",
700
+ "684": "ocarina, sweet potato",
701
+ "685": "odometer, hodometer, mileometer, milometer",
702
+ "686": "oil filter",
703
+ "687": "organ, pipe organ",
704
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
705
+ "689": "overskirt",
706
+ "690": "oxcart",
707
+ "691": "oxygen mask",
708
+ "692": "packet",
709
+ "693": "paddle, boat paddle",
710
+ "694": "paddlewheel, paddle wheel",
711
+ "695": "padlock",
712
+ "696": "paintbrush",
713
+ "697": "pajama, pyjama, pjs, jammies",
714
+ "698": "palace",
715
+ "699": "panpipe, pandean pipe, syrinx",
716
+ "700": "paper towel",
717
+ "701": "parachute, chute",
718
+ "702": "parallel bars, bars",
719
+ "703": "park bench",
720
+ "704": "parking meter",
721
+ "705": "passenger car, coach, carriage",
722
+ "706": "patio, terrace",
723
+ "707": "pay-phone, pay-station",
724
+ "708": "pedestal, plinth, footstall",
725
+ "709": "pencil box, pencil case",
726
+ "710": "pencil sharpener",
727
+ "711": "perfume, essence",
728
+ "712": "Petri dish",
729
+ "713": "photocopier",
730
+ "714": "pick, plectrum, plectron",
731
+ "715": "pickelhaube",
732
+ "716": "picket fence, paling",
733
+ "717": "pickup, pickup truck",
734
+ "718": "pier",
735
+ "719": "piggy bank, penny bank",
736
+ "720": "pill bottle",
737
+ "721": "pillow",
738
+ "722": "ping-pong ball",
739
+ "723": "pinwheel",
740
+ "724": "pirate, pirate ship",
741
+ "725": "pitcher, ewer",
742
+ "726": "plane, carpenters plane, woodworking plane",
743
+ "727": "planetarium",
744
+ "728": "plastic bag",
745
+ "729": "plate rack",
746
+ "730": "plow, plough",
747
+ "731": "plunger, plumbers helper",
748
+ "732": "Polaroid camera, Polaroid Land camera",
749
+ "733": "pole",
750
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
751
+ "735": "poncho",
752
+ "736": "pool table, billiard table, snooker table",
753
+ "737": "pop bottle, soda bottle",
754
+ "738": "pot, flowerpot",
755
+ "739": "potters wheel",
756
+ "740": "power drill",
757
+ "741": "prayer rug, prayer mat",
758
+ "742": "printer",
759
+ "743": "prison, prison house",
760
+ "744": "projectile, missile",
761
+ "745": "projector",
762
+ "746": "puck, hockey puck",
763
+ "747": "punching bag, punch bag, punching ball, punchball",
764
+ "748": "purse",
765
+ "749": "quill, quill pen",
766
+ "750": "quilt, comforter, comfort, puff",
767
+ "751": "racer, race car, racing car",
768
+ "752": "racket, racquet",
769
+ "753": "radiator",
770
+ "754": "radio, wireless",
771
+ "755": "radio telescope, radio reflector",
772
+ "756": "rain barrel",
773
+ "757": "recreational vehicle, RV, R.V.",
774
+ "758": "reel",
775
+ "759": "reflex camera",
776
+ "760": "refrigerator, icebox",
777
+ "761": "remote control, remote",
778
+ "762": "restaurant, eating house, eating place, eatery",
779
+ "763": "revolver, six-gun, six-shooter",
780
+ "764": "rifle",
781
+ "765": "rocking chair, rocker",
782
+ "766": "rotisserie",
783
+ "767": "rubber eraser, rubber, pencil eraser",
784
+ "768": "rugby ball",
785
+ "769": "rule, ruler",
786
+ "770": "running shoe",
787
+ "771": "safe",
788
+ "772": "safety pin",
789
+ "773": "saltshaker, salt shaker",
790
+ "774": "sandal",
791
+ "775": "sarong",
792
+ "776": "sax, saxophone",
793
+ "777": "scabbard",
794
+ "778": "scale, weighing machine",
795
+ "779": "school bus",
796
+ "780": "schooner",
797
+ "781": "scoreboard",
798
+ "782": "screen, CRT screen",
799
+ "783": "screw",
800
+ "784": "screwdriver",
801
+ "785": "seat belt, seatbelt",
802
+ "786": "sewing machine",
803
+ "787": "shield, buckler",
804
+ "788": "shoe shop, shoe-shop, shoe store",
805
+ "789": "shoji",
806
+ "790": "shopping basket",
807
+ "791": "shopping cart",
808
+ "792": "shovel",
809
+ "793": "shower cap",
810
+ "794": "shower curtain",
811
+ "795": "ski",
812
+ "796": "ski mask",
813
+ "797": "sleeping bag",
814
+ "798": "slide rule, slipstick",
815
+ "799": "sliding door",
816
+ "800": "slot, one-armed bandit",
817
+ "801": "snorkel",
818
+ "802": "snowmobile",
819
+ "803": "snowplow, snowplough",
820
+ "804": "soap dispenser",
821
+ "805": "soccer ball",
822
+ "806": "sock",
823
+ "807": "solar dish, solar collector, solar furnace",
824
+ "808": "sombrero",
825
+ "809": "soup bowl",
826
+ "810": "space bar",
827
+ "811": "space heater",
828
+ "812": "space shuttle",
829
+ "813": "spatula",
830
+ "814": "speedboat",
831
+ "815": "spider web, spiders web",
832
+ "816": "spindle",
833
+ "817": "sports car, sport car",
834
+ "818": "spotlight, spot",
835
+ "819": "stage",
836
+ "820": "steam locomotive",
837
+ "821": "steel arch bridge",
838
+ "822": "steel drum",
839
+ "823": "stethoscope",
840
+ "824": "stole",
841
+ "825": "stone wall",
842
+ "826": "stopwatch, stop watch",
843
+ "827": "stove",
844
+ "828": "strainer",
845
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
846
+ "830": "stretcher",
847
+ "831": "studio couch, day bed",
848
+ "832": "stupa, tope",
849
+ "833": "submarine, pigboat, sub, U-boat",
850
+ "834": "suit, suit of clothes",
851
+ "835": "sundial",
852
+ "836": "sunglass",
853
+ "837": "sunglasses, dark glasses, shades",
854
+ "838": "sunscreen, sunblock, sun blocker",
855
+ "839": "suspension bridge",
856
+ "840": "swab, swob, mop",
857
+ "841": "sweatshirt",
858
+ "842": "swimming trunks, bathing trunks",
859
+ "843": "swing",
860
+ "844": "switch, electric switch, electrical switch",
861
+ "845": "syringe",
862
+ "846": "table lamp",
863
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
864
+ "848": "tape player",
865
+ "849": "teapot",
866
+ "850": "teddy, teddy bear",
867
+ "851": "television, television system",
868
+ "852": "tennis ball",
869
+ "853": "thatch, thatched roof",
870
+ "854": "theater curtain, theatre curtain",
871
+ "855": "thimble",
872
+ "856": "thresher, thrasher, threshing machine",
873
+ "857": "throne",
874
+ "858": "tile roof",
875
+ "859": "toaster",
876
+ "860": "tobacco shop, tobacconist shop, tobacconist",
877
+ "861": "toilet seat",
878
+ "862": "torch",
879
+ "863": "totem pole",
880
+ "864": "tow truck, tow car, wrecker",
881
+ "865": "toyshop",
882
+ "866": "tractor",
883
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
884
+ "868": "tray",
885
+ "869": "trench coat",
886
+ "870": "tricycle, trike, velocipede",
887
+ "871": "trimaran",
888
+ "872": "tripod",
889
+ "873": "triumphal arch",
890
+ "874": "trolleybus, trolley coach, trackless trolley",
891
+ "875": "trombone",
892
+ "876": "tub, vat",
893
+ "877": "turnstile",
894
+ "878": "typewriter keyboard",
895
+ "879": "umbrella",
896
+ "880": "unicycle, monocycle",
897
+ "881": "upright, upright piano",
898
+ "882": "vacuum, vacuum cleaner",
899
+ "883": "vase",
900
+ "884": "vault",
901
+ "885": "velvet",
902
+ "886": "vending machine",
903
+ "887": "vestment",
904
+ "888": "viaduct",
905
+ "889": "violin, fiddle",
906
+ "890": "volleyball",
907
+ "891": "waffle iron",
908
+ "892": "wall clock",
909
+ "893": "wallet, billfold, notecase, pocketbook",
910
+ "894": "wardrobe, closet, press",
911
+ "895": "warplane, military plane",
912
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
913
+ "897": "washer, automatic washer, washing machine",
914
+ "898": "water bottle",
915
+ "899": "water jug",
916
+ "900": "water tower",
917
+ "901": "whiskey jug",
918
+ "902": "whistle",
919
+ "903": "wig",
920
+ "904": "window screen",
921
+ "905": "window shade",
922
+ "906": "Windsor tie",
923
+ "907": "wine bottle",
924
+ "908": "wing",
925
+ "909": "wok",
926
+ "910": "wooden spoon",
927
+ "911": "wool, woolen, woollen",
928
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
929
+ "913": "wreck",
930
+ "914": "yawl",
931
+ "915": "yurt",
932
+ "916": "web site, website, internet site, site",
933
+ "917": "comic book",
934
+ "918": "crossword puzzle, crossword",
935
+ "919": "street sign",
936
+ "920": "traffic light, traffic signal, stoplight",
937
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
938
+ "922": "menu",
939
+ "923": "plate",
940
+ "924": "guacamole",
941
+ "925": "consomme",
942
+ "926": "hot pot, hotpot",
943
+ "927": "trifle",
944
+ "928": "ice cream, icecream",
945
+ "929": "ice lolly, lolly, lollipop, popsicle",
946
+ "930": "French loaf",
947
+ "931": "bagel, beigel",
948
+ "932": "pretzel",
949
+ "933": "cheeseburger",
950
+ "934": "hotdog, hot dog, red hot",
951
+ "935": "mashed potato",
952
+ "936": "head cabbage",
953
+ "937": "broccoli",
954
+ "938": "cauliflower",
955
+ "939": "zucchini, courgette",
956
+ "940": "spaghetti squash",
957
+ "941": "acorn squash",
958
+ "942": "butternut squash",
959
+ "943": "cucumber, cuke",
960
+ "944": "artichoke, globe artichoke",
961
+ "945": "bell pepper",
962
+ "946": "cardoon",
963
+ "947": "mushroom",
964
+ "948": "Granny Smith",
965
+ "949": "strawberry",
966
+ "950": "orange",
967
+ "951": "lemon",
968
+ "952": "fig",
969
+ "953": "pineapple, ananas",
970
+ "954": "banana",
971
+ "955": "jackfruit, jak, jack",
972
+ "956": "custard apple",
973
+ "957": "pomegranate",
974
+ "958": "hay",
975
+ "959": "carbonara",
976
+ "960": "chocolate sauce, chocolate syrup",
977
+ "961": "dough",
978
+ "962": "meat loaf, meatloaf",
979
+ "963": "pizza, pizza pie",
980
+ "964": "potpie",
981
+ "965": "burrito",
982
+ "966": "red wine",
983
+ "967": "espresso",
984
+ "968": "cup",
985
+ "969": "eggnog",
986
+ "970": "alp",
987
+ "971": "bubble",
988
+ "972": "cliff, drop, drop-off",
989
+ "973": "coral reef",
990
+ "974": "geyser",
991
+ "975": "lakeside, lakeshore",
992
+ "976": "promontory, headland, head, foreland",
993
+ "977": "sandbar, sand bar",
994
+ "978": "seashore, coast, seacoast, sea-coast",
995
+ "979": "valley, vale",
996
+ "980": "volcano",
997
+ "981": "ballplayer, baseball player",
998
+ "982": "groom, bridegroom",
999
+ "983": "scuba diver",
1000
+ "984": "rapeseed",
1001
+ "985": "daisy",
1002
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1003
+ "987": "corn",
1004
+ "988": "acorn",
1005
+ "989": "hip, rose hip, rosehip",
1006
+ "990": "buckeye, horse chestnut, conker",
1007
+ "991": "coral fungus",
1008
+ "992": "agaric",
1009
+ "993": "gyromitra",
1010
+ "994": "stinkhorn, carrion fungus",
1011
+ "995": "earthstar",
1012
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1013
+ "997": "bolete",
1014
+ "998": "ear, spike, capitulum",
1015
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1016
+ }
1017
+ }
PixNerd-XL-16-512/pipeline.py CHANGED
@@ -1,353 +1,433 @@
1
- from __future__ import annotations
2
-
3
- import sys
4
- from dataclasses import dataclass
5
- from pathlib import Path
6
- from typing import List, Literal, Optional, Sequence, Union
7
-
8
- import torch
9
- from diffusers import DiffusionPipeline
10
- from diffusers.image_processor import VaeImageProcessor
11
- from diffusers.utils import BaseOutput
12
- from PIL import Image
13
-
14
- ConditioningInput = Union[str, int, Sequence[Union[str, int]]]
15
- Language = Literal["en", "cn"]
16
-
17
-
18
- @dataclass
19
- class PixNerdPipelineOutput(BaseOutput):
20
- images: Union[List[Image.Image], torch.Tensor, "np.ndarray"]
21
-
22
-
23
- class PixNerdPipeline(DiffusionPipeline):
24
- model_cpu_offload_seq = "conditioner->transformer->vae"
25
- _callback_tensor_inputs = ["latents"]
26
-
27
- def __init__(
28
- self,
29
- transformer,
30
- scheduler,
31
- vae=None,
32
- conditioner=None,
33
- id2label: Optional[dict[int, str]] = None,
34
- id2label_cn: Optional[dict[int, str]] = None,
35
- ):
36
- super().__init__()
37
- if vae is None:
38
- vae = getattr(transformer, "vae", None)
39
- if conditioner is None:
40
- conditioner = getattr(transformer, "conditioner", None)
41
- if vae is None or conditioner is None:
42
- raise ValueError("Pipeline requires `vae` and `conditioner` either explicitly or from `transformer`.")
43
- self.register_modules(
44
- vae=vae,
45
- conditioner=conditioner,
46
- transformer=transformer,
47
- scheduler=scheduler,
48
- )
49
- self.image_processor = VaeImageProcessor(vae_scale_factor=1)
50
-
51
- if id2label is None and id2label_cn is None:
52
- id2label, id2label_cn = self._load_repo_labels()
53
- self._id2label = id2label or {}
54
- self._id2label_cn = id2label_cn or {}
55
- self.labels = self._build_label2id(self._id2label)
56
- self.labels_cn = self._build_label2id(self._id2label_cn)
57
- self._labels_loaded_from_path = bool(self._id2label or self._id2label_cn)
58
-
59
- def _ensure_labels_loaded(self) -> None:
60
- if self._labels_loaded_from_path:
61
- return
62
-
63
- path = getattr(getattr(self, "config", None), "_name_or_path", None) or getattr(self, "_name_or_path", None)
64
- if not path:
65
- return
66
-
67
- id2label, id2label_cn = self._load_labels_for_path(path)
68
- if id2label is None and id2label_cn is None:
69
- self._labels_loaded_from_path = True
70
- return
71
-
72
- self._id2label = id2label or {}
73
- self._id2label_cn = id2label_cn or {}
74
- self.labels = self._build_label2id(self._id2label)
75
- self.labels_cn = self._build_label2id(self._id2label_cn)
76
- self._labels_loaded_from_path = True
77
-
78
- @staticmethod
79
- def _resolve_labels_dir(pretrained_model_name_or_path: Union[str, Path]) -> Optional[Path]:
80
- path = Path(pretrained_model_name_or_path)
81
- if not path.exists():
82
- try:
83
- from huggingface_hub import snapshot_download
84
-
85
- path = Path(snapshot_download(pretrained_model_name_or_path))
86
- except Exception:
87
- return None
88
-
89
- if (path / "model_index.json").exists():
90
- labels_dir = path.parent / "labels"
91
- else:
92
- labels_dir = path / "labels"
93
- return labels_dir if labels_dir.is_dir() else None
94
-
95
- @classmethod
96
- def _load_labels_for_path(
97
- cls,
98
- pretrained_model_name_or_path: Union[str, Path],
99
- ) -> tuple[Optional[dict[int, str]], Optional[dict[int, str]]]:
100
- labels_dir = cls._resolve_labels_dir(pretrained_model_name_or_path)
101
- if labels_dir is None:
102
- return None, None
103
-
104
- labels_path = str(labels_dir)
105
- inserted = False
106
- if labels_path not in sys.path:
107
- sys.path.insert(0, labels_path)
108
- inserted = True
109
- try:
110
- from imagenet_labels import load_id2label
111
-
112
- return (
113
- load_id2label(labels_dir, lang="en"),
114
- load_id2label(labels_dir, lang="cn"),
115
- )
116
- finally:
117
- if inserted and labels_path in sys.path:
118
- sys.path.remove(labels_path)
119
-
120
- @staticmethod
121
- def _load_repo_labels() -> tuple[Optional[dict[int, str]], Optional[dict[int, str]]]:
122
- labels_dir = Path(__file__).resolve().parent.parent / "labels"
123
- if not labels_dir.is_dir():
124
- return None, None
125
-
126
- labels_path = str(labels_dir)
127
- inserted = False
128
- if labels_path not in sys.path:
129
- sys.path.insert(0, labels_path)
130
- inserted = True
131
- try:
132
- from imagenet_labels import load_id2label
133
-
134
- return (
135
- load_id2label(labels_dir, lang="en"),
136
- load_id2label(labels_dir, lang="cn"),
137
- )
138
- finally:
139
- if inserted and labels_path in sys.path:
140
- sys.path.remove(labels_path)
141
-
142
- @classmethod
143
- def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
144
- pipe = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
145
- id2label, id2label_cn = cls._load_labels_for_path(pretrained_model_name_or_path)
146
- if id2label is not None or id2label_cn is not None:
147
- pipe._id2label = id2label or {}
148
- pipe._id2label_cn = id2label_cn or {}
149
- pipe.labels = cls._build_label2id(pipe._id2label)
150
- pipe.labels_cn = cls._build_label2id(pipe._id2label_cn)
151
- return pipe
152
-
153
- @staticmethod
154
- def _build_label2id(id2label: dict[int, str]) -> dict[str, int]:
155
- label2id: dict[str, int] = {}
156
- for class_id, value in id2label.items():
157
- for synonym in value.split(","):
158
- synonym = synonym.strip()
159
- if synonym:
160
- label2id[synonym] = int(class_id)
161
- return dict(sorted(label2id.items()))
162
-
163
- @property
164
- def id2label(self) -> dict[int, str]:
165
- self._ensure_labels_loaded()
166
- return self._id2label
167
-
168
- @property
169
- def id2label_cn(self) -> dict[int, str]:
170
- self._ensure_labels_loaded()
171
- return self._id2label_cn
172
-
173
- def get_label_ids(
174
- self,
175
- labels: Union[str, List[str]],
176
- *,
177
- lang: Language = "en",
178
- ) -> List[int]:
179
- self._ensure_labels_loaded()
180
- if isinstance(labels, str):
181
- labels = [labels]
182
-
183
- label2id = self.labels if lang == "en" else self.labels_cn
184
- if not label2id:
185
- raise ValueError(
186
- f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
187
- )
188
-
189
- missing = [label for label in labels if label not in label2id]
190
- if missing:
191
- preview = ", ".join(list(label2id.keys())[:8])
192
- raise ValueError(
193
- f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
194
- )
195
- return [label2id[label] for label in labels]
196
-
197
- def _resolve_prompt_item(self, value: Union[str, int]) -> int:
198
- if isinstance(value, int):
199
- return value
200
- if value.isdigit():
201
- return int(value)
202
- if value in self.labels:
203
- return self.labels[value]
204
- if value in self.labels_cn:
205
- return self.labels_cn[value]
206
- raise ValueError(
207
- f"Unknown class label {value!r}. Pass an ImageNet class id or a synonym from "
208
- "`pipe.labels` / `pipe.labels_cn`."
209
- )
210
-
211
- def _resolve_prompts(self, prompts: List[Union[str, int]]) -> List[int]:
212
- self._ensure_labels_loaded()
213
- return [self._resolve_prompt_item(prompt) for prompt in prompts]
214
-
215
- @staticmethod
216
- def _fp_to_uint8(image: torch.Tensor) -> torch.Tensor:
217
- return torch.clip_((image + 1) * 127.5 + 0.5, 0, 255).to(torch.uint8)
218
-
219
- @staticmethod
220
- def _to_list(y: ConditioningInput) -> List[Union[str, int]]:
221
- if isinstance(y, (str, int)):
222
- return [y]
223
- return list(y)
224
-
225
- @staticmethod
226
- def _repeat(values: List[Union[str, int]], repeats: int) -> List[Union[str, int]]:
227
- if repeats == 1:
228
- return values
229
- expanded: List[Union[str, int]] = []
230
- for value in values:
231
- expanded.extend([value] * repeats)
232
- return expanded
233
-
234
- def encode_prompt(
235
- self,
236
- prompt: ConditioningInput,
237
- num_images_per_prompt: int,
238
- ):
239
- prompts = self._repeat(self._to_list(prompt), num_images_per_prompt)
240
- resolved = self._resolve_prompts(prompts)
241
- metadata = {"device": self._execution_device}
242
- with torch.no_grad():
243
- cond, uncond = self.conditioner(resolved, metadata)
244
- return cond, uncond, resolved
245
-
246
- def prepare_latents(
247
- self,
248
- batch_size: int,
249
- num_channels: int,
250
- height: int,
251
- width: int,
252
- generator: Optional[torch.Generator] = None,
253
- latents: Optional[torch.Tensor] = None,
254
- ) -> torch.Tensor:
255
- if latents is not None:
256
- return latents.to(device=self._execution_device, dtype=torch.float32)
257
- return torch.randn(
258
- (batch_size, num_channels, height, width),
259
- generator=generator,
260
- device=self._execution_device,
261
- dtype=torch.float32,
262
- )
263
-
264
- @torch.no_grad()
265
- def __call__(
266
- self,
267
- prompt: ConditioningInput,
268
- negative_prompt: Optional[ConditioningInput] = None,
269
- num_images_per_prompt: int = 1,
270
- height: int = 512,
271
- width: int = 512,
272
- num_inference_steps: int = 25,
273
- guidance_scale: float = 4.0,
274
- generator: Optional[torch.Generator] = None,
275
- seed: Optional[int] = None,
276
- latents: Optional[torch.Tensor] = None,
277
- output_type: str = "pil",
278
- return_dict: bool = True,
279
- timeshift: float = 3.0,
280
- order: int = 2,
281
- ) -> PixNerdPipelineOutput | tuple:
282
- patch_size = int(getattr(self.transformer, "patch_size", 1))
283
- channels = int(getattr(self.transformer, "in_channels", 3))
284
- height = (height // patch_size) * patch_size
285
- width = (width // patch_size) * patch_size
286
-
287
- if hasattr(self.transformer, "decoder_patch_scaling_h"):
288
- self.transformer.decoder_patch_scaling_h = height / 512
289
- self.transformer.decoder_patch_scaling_w = width / 512
290
-
291
- cond, default_uncond, prompts = self.encode_prompt(prompt, num_images_per_prompt)
292
- if negative_prompt is not None:
293
- negative = self._repeat(self._to_list(negative_prompt), num_images_per_prompt)
294
- resolved_negative = self._resolve_prompts(negative)
295
- metadata = {"device": self._execution_device}
296
- with torch.no_grad():
297
- _, uncond = self.conditioner(resolved_negative, metadata)
298
- else:
299
- uncond = default_uncond
300
- batch_size = len(prompts)
301
- if generator is None and seed is not None:
302
- generator = torch.Generator(device=self._execution_device).manual_seed(seed)
303
- latents = self.prepare_latents(
304
- batch_size=batch_size,
305
- num_channels=channels,
306
- height=height,
307
- width=width,
308
- generator=generator,
309
- latents=latents,
310
- )
311
- self.scheduler.set_timesteps(
312
- num_inference_steps=num_inference_steps,
313
- guidance_scale=guidance_scale,
314
- timeshift=timeshift,
315
- order=order,
316
- device=latents.device,
317
- )
318
- for timestep in self.scheduler.timesteps:
319
- cfg_latents = torch.cat([latents, latents], dim=0)
320
- cfg_t = timestep.repeat(cfg_latents.shape[0]).to(latents.device, dtype=latents.dtype)
321
- cfg_condition = torch.cat([uncond, cond], dim=0)
322
- model_output = self.transformer(
323
- sample=cfg_latents,
324
- timestep=cfg_t,
325
- encoder_hidden_states=cfg_condition,
326
- ).sample
327
- model_output = self.scheduler.classifier_free_guidance(model_output)
328
- latents = self.scheduler.step(
329
- model_output=model_output,
330
- timestep=timestep,
331
- sample=latents,
332
- ).prev_sample
333
-
334
- image = self.vae.decode(latents)
335
- images_uint8 = self._fp_to_uint8(image).permute(0, 2, 3, 1).cpu().numpy()
336
- if output_type == "pil":
337
- output = [Image.fromarray(img) for img in images_uint8]
338
- elif output_type == "pt":
339
- output = torch.from_numpy(images_uint8)
340
- elif output_type == "np":
341
- output = images_uint8
342
- else:
343
- raise ValueError(f"Unsupported output_type: {output_type}")
344
-
345
- if not return_dict:
346
- return (output,)
347
- return PixNerdPipelineOutput(images=output)
348
-
349
-
350
- __all__ = [
351
- "PixNerdPipeline",
352
- "PixNerdPipelineOutput",
353
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import json
18
+ from pathlib import Path
19
+ from typing import Dict, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+
23
+ from diffusers.image_processor import VaeImageProcessor
24
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
25
+ from diffusers.utils.torch_utils import randn_tensor
26
+
27
+ DEFAULT_NATIVE_RESOLUTION = 512
28
+
29
+ EXAMPLE_DOC_STRING = """
30
+ Examples:
31
+ ```py
32
+ >>> from pathlib import Path
33
+ >>> from diffusers import DiffusionPipeline
34
+ >>> import torch
35
+
36
+ >>> model_dir = Path("./PixNerd-XL-16-512").resolve()
37
+ >>> pipe = DiffusionPipeline.from_pretrained(
38
+ ... str(model_dir),
39
+ ... local_files_only=True,
40
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
41
+ ... trust_remote_code=True,
42
+ ... torch_dtype=torch.bfloat16,
43
+ ... )
44
+ >>> pipe.to("cuda")
45
+
46
+ >>> print(pipe.id2label[207])
47
+ >>> print(pipe.get_label_ids("golden retriever"))
48
+
49
+ >>> generator = torch.Generator(device="cuda").manual_seed(42)
50
+ >>> # timeshift=3.0 and order=2 are defaults in scheduler/scheduler_config.json
51
+ >>> image = pipe(
52
+ ... class_labels="golden retriever",
53
+ ... height=512,
54
+ ... width=512,
55
+ ... num_inference_steps=25,
56
+ ... guidance_scale=4.0,
57
+ ... generator=generator,
58
+ ... ).images[0]
59
+ >>> image.save("demo.png")
60
+ ```
61
+ """
62
+
63
+ ConditioningInput = Union[int, str, List[Union[int, str]], torch.LongTensor]
64
+
65
+
66
+ class PixNerdPipeline(DiffusionPipeline):
67
+ r"""
68
+ Pipeline for class-conditional PixNerd pixel-space image generation.
69
+
70
+ Parameters:
71
+ transformer ([`PixNerdTransformer2DModel`]):
72
+ Class-conditional PixNerd denoiser operating in pixel space.
73
+ scheduler ([`PixNerdFlowMatchScheduler`]):
74
+ Flow-matching scheduler with AdamLM multi-step coefficients.
75
+ vae ([`PixNerdPixelVAE`], *optional*):
76
+ Identity pixel autoencoder. May also be attached to `transformer.vae`.
77
+ conditioner ([`PixNerdLabelConditioner`], *optional*):
78
+ ImageNet class-label conditioner. May also be attached to `transformer.conditioner`.
79
+ id2label (`dict[int, str]`, *optional*):
80
+ ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
81
+ """
82
+
83
+ model_cpu_offload_seq = "conditioner->transformer->vae"
84
+ _callback_tensor_inputs = ["latents"]
85
+ _optional_components = ["vae", "conditioner"]
86
+
87
+ def __init__(
88
+ self,
89
+ transformer,
90
+ scheduler,
91
+ vae=None,
92
+ conditioner=None,
93
+ id2label: Optional[Dict[Union[int, str], str]] = None,
94
+ ):
95
+ super().__init__()
96
+ if vae is None:
97
+ vae = getattr(transformer, "vae", None)
98
+ if conditioner is None:
99
+ conditioner = getattr(transformer, "conditioner", None)
100
+ if vae is None or conditioner is None:
101
+ raise ValueError("Pipeline requires `vae` and `conditioner` either explicitly or from `transformer`.")
102
+ self.register_modules(
103
+ vae=vae,
104
+ conditioner=conditioner,
105
+ transformer=transformer,
106
+ scheduler=scheduler,
107
+ )
108
+ self.image_processor = VaeImageProcessor(vae_scale_factor=1, do_normalize=False)
109
+ if id2label is None:
110
+ id2label = self._read_id2label_from_model_index(
111
+ getattr(getattr(self, "config", None), "_name_or_path", None)
112
+ )
113
+ self._id2label = self._normalize_id2label(id2label)
114
+ self.labels = self._build_label2id(self._id2label)
115
+ self._labels_loaded_from_model_index = bool(self._id2label)
116
+
117
+ def _get_device(self) -> torch.device:
118
+ try:
119
+ return self._execution_device
120
+ except AttributeError:
121
+ pass
122
+ for name in ("transformer", "vae", "scheduler"):
123
+ module = getattr(self, name, None)
124
+ if isinstance(module, torch.nn.Module):
125
+ parameter = next(module.parameters(), None)
126
+ if parameter is not None:
127
+ return parameter.device
128
+ return torch.device("cpu")
129
+
130
+ @classmethod
131
+ def from_pretrained(cls, pretrained_model_name_or_path=None, *args, **kwargs):
132
+ id2label_override = kwargs.pop("id2label", None)
133
+ pipe = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
134
+ id2label = id2label_override or cls._read_id2label_from_model_index(pretrained_model_name_or_path)
135
+ if id2label:
136
+ pipe._id2label = cls._normalize_id2label(id2label)
137
+ pipe.labels = cls._build_label2id(pipe._id2label)
138
+ pipe._labels_loaded_from_model_index = True
139
+ return pipe
140
+
141
+ def _ensure_labels_loaded(self) -> None:
142
+ if self._labels_loaded_from_model_index:
143
+ return
144
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
145
+ if loaded:
146
+ self._id2label = loaded
147
+ self.labels = self._build_label2id(self._id2label)
148
+ self._labels_loaded_from_model_index = True
149
+
150
+ @staticmethod
151
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
152
+ if not id2label:
153
+ return {}
154
+ return {int(key): value for key, value in id2label.items()}
155
+
156
+ @staticmethod
157
+ def _read_id2label_from_model_index(variant_path: Optional[Union[str, Path]]) -> Dict[int, str]:
158
+ if not variant_path:
159
+ return {}
160
+ model_index_path = Path(variant_path).resolve() / "model_index.json"
161
+ if not model_index_path.exists():
162
+ return {}
163
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
164
+ id2label = raw.get("id2label")
165
+ if not isinstance(id2label, dict):
166
+ return {}
167
+ return {int(key): value for key, value in id2label.items()}
168
+
169
+ @staticmethod
170
+ def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
171
+ label2id: Dict[str, int] = {}
172
+ for class_id, value in id2label.items():
173
+ for synonym in value.split(","):
174
+ synonym = synonym.strip()
175
+ if synonym:
176
+ label2id[synonym] = int(class_id)
177
+ return dict(sorted(label2id.items()))
178
+
179
+ @property
180
+ def id2label(self) -> Dict[int, str]:
181
+ r"""ImageNet class id to English label string (comma-separated synonyms)."""
182
+ self._ensure_labels_loaded()
183
+ return self._id2label
184
+
185
+ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
186
+ r"""
187
+ Map ImageNet label strings to class ids.
188
+
189
+ Args:
190
+ label (`str` or `list[str]`):
191
+ One or more English label strings. Each string must match a synonym in `id2label`.
192
+ """
193
+ self._ensure_labels_loaded()
194
+ if isinstance(label, str):
195
+ label = [label]
196
+ if not self.labels:
197
+ raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
198
+ missing = [item for item in label if item not in self.labels]
199
+ if missing:
200
+ preview = ", ".join(list(self.labels.keys())[:8])
201
+ raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
202
+ return [self.labels[item] for item in label]
203
+
204
+ def _normalize_class_labels(
205
+ self,
206
+ class_labels: ConditioningInput,
207
+ num_images_per_prompt: int = 1,
208
+ ) -> List[int]:
209
+ if torch.is_tensor(class_labels):
210
+ values = class_labels.to(dtype=torch.long).reshape(-1).tolist()
211
+ elif isinstance(class_labels, int):
212
+ values = [class_labels]
213
+ elif isinstance(class_labels, str):
214
+ values = self.get_label_ids(class_labels)
215
+ elif class_labels and isinstance(class_labels[0], str):
216
+ values = self.get_label_ids(list(class_labels))
217
+ else:
218
+ values = [int(entry) for entry in class_labels]
219
+
220
+ if num_images_per_prompt == 1:
221
+ return values
222
+ expanded: List[int] = []
223
+ for value in values:
224
+ expanded.extend([value] * num_images_per_prompt)
225
+ return expanded
226
+
227
+ def _get_patch_size(self) -> int:
228
+ patch_size = getattr(self.transformer, "patch_size", None)
229
+ if patch_size is not None:
230
+ return int(patch_size)
231
+ return int(getattr(self.transformer.config, "patch_size", 16))
232
+
233
+ def _get_in_channels(self) -> int:
234
+ in_channels = getattr(self.transformer, "in_channels", None)
235
+ if in_channels is not None:
236
+ return int(in_channels)
237
+ return int(getattr(self.transformer.config, "in_channels", 3))
238
+
239
+ def check_inputs(
240
+ self,
241
+ height: int,
242
+ width: int,
243
+ num_inference_steps: int,
244
+ output_type: str,
245
+ ) -> None:
246
+ if num_inference_steps < 1:
247
+ raise ValueError("num_inference_steps must be >= 1.")
248
+ if output_type not in {"pil", "np", "pt", "latent"}:
249
+ raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
250
+ order = int(getattr(self.scheduler.config, "order", getattr(self.scheduler, "order", 2)))
251
+ if order < 1:
252
+ raise ValueError("scheduler.config.order must be >= 1.")
253
+
254
+ patch_size = self._get_patch_size()
255
+ if height % patch_size != 0 or width % patch_size != 0:
256
+ raise ValueError(f"height and width must be divisible by patch_size={patch_size}.")
257
+
258
+ def encode_condition(
259
+ self,
260
+ class_label_ids: List[int],
261
+ negative_class_label_ids: Optional[List[int]] = None,
262
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
263
+ metadata = {"device": self._get_device()}
264
+ with torch.no_grad():
265
+ cond, default_uncond = self.conditioner(class_label_ids, metadata)
266
+ if negative_class_label_ids is not None:
267
+ _, uncond = self.conditioner(negative_class_label_ids, metadata)
268
+ else:
269
+ uncond = default_uncond
270
+ return cond, uncond
271
+
272
+ def prepare_latents(
273
+ self,
274
+ batch_size: int,
275
+ num_channels: int,
276
+ height: int,
277
+ width: int,
278
+ dtype: torch.dtype,
279
+ device: torch.device,
280
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
281
+ latents: Optional[torch.Tensor] = None,
282
+ ) -> torch.Tensor:
283
+ if latents is not None:
284
+ return latents.to(device=device, dtype=dtype)
285
+ return randn_tensor(
286
+ (batch_size, num_channels, height, width),
287
+ generator=generator,
288
+ device=device,
289
+ dtype=dtype,
290
+ )
291
+
292
+ @staticmethod
293
+ def _fp_to_uint8(image: torch.Tensor) -> torch.Tensor:
294
+ return torch.clip_((image + 1) * 127.5 + 0.5, 0, 255).to(torch.uint8)
295
+
296
+ def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
297
+ if output_type == "latent":
298
+ return latents
299
+
300
+ image = self.vae.decode(latents)
301
+ if output_type == "pt":
302
+ return image
303
+ images_uint8 = self._fp_to_uint8(image).permute(0, 2, 3, 1).cpu().numpy()
304
+ if output_type == "np":
305
+ return images_uint8
306
+ if output_type == "pil":
307
+ from PIL import Image
308
+
309
+ return [Image.fromarray(img) for img in images_uint8]
310
+ raise ValueError(f"Unsupported output_type: {output_type}")
311
+
312
+ def _apply_decoder_patch_scaling(self, height: int, width: int) -> None:
313
+ denoiser = getattr(self.transformer, "denoiser", self.transformer)
314
+ if hasattr(denoiser, "decoder_patch_scaling_h"):
315
+ denoiser.decoder_patch_scaling_h = height / DEFAULT_NATIVE_RESOLUTION
316
+ denoiser.decoder_patch_scaling_w = width / DEFAULT_NATIVE_RESOLUTION
317
+
318
+ @torch.inference_mode()
319
+ def __call__(
320
+ self,
321
+ class_labels: Optional[ConditioningInput] = None,
322
+ negative_class_labels: Optional[ConditioningInput] = None,
323
+ num_images_per_prompt: int = 1,
324
+ height: Optional[int] = None,
325
+ width: Optional[int] = None,
326
+ num_inference_steps: int = 25,
327
+ guidance_scale: float = 4.0,
328
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
329
+ latents: Optional[torch.Tensor] = None,
330
+ output_type: str = "pil",
331
+ return_dict: bool = True,
332
+ prompt: Optional[ConditioningInput] = None,
333
+ negative_prompt: Optional[ConditioningInput] = None,
334
+ ) -> Union[ImagePipelineOutput, Tuple]:
335
+ r"""
336
+ Generate class-conditional images with PixNerd.
337
+
338
+ Args:
339
+ class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
340
+ ImageNet class indices or human-readable English label strings.
341
+ negative_class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`, *optional*):
342
+ Optional negative class labels for classifier-free guidance.
343
+ num_images_per_prompt (`int`, defaults to `1`):
344
+ Number of images to generate per class label.
345
+ height (`int`, *optional*):
346
+ Output image height in pixels. Defaults to `512`.
347
+ width (`int`, *optional*):
348
+ Output image width in pixels. Defaults to `512`.
349
+ num_inference_steps (`int`, defaults to `25`):
350
+ Number of denoising steps.
351
+ guidance_scale (`float`, defaults to `4.0`):
352
+ Classifier-free guidance scale applied by the scheduler.
353
+ generator (`torch.Generator`, *optional*):
354
+ RNG for reproducibility.
355
+ latents (`torch.Tensor`, *optional*):
356
+ Pre-generated noisy pixel tensor.
357
+ output_type (`str`, defaults to `"pil"`):
358
+ `"pil"`, `"np"`, `"pt"`, or `"latent"`.
359
+ return_dict (`bool`, defaults to `True`):
360
+ Return [`ImagePipelineOutput`] if True.
361
+ prompt (`int`, `str`, `list`, *optional*):
362
+ Deprecated alias for `class_labels`.
363
+ negative_prompt (`int`, `str`, `list`, *optional*):
364
+ Deprecated alias for `negative_class_labels`.
365
+ """
366
+ if class_labels is None:
367
+ class_labels = prompt
368
+ if negative_class_labels is None:
369
+ negative_class_labels = negative_prompt
370
+ if class_labels is None:
371
+ raise ValueError("`class_labels` (or deprecated `prompt`) must be provided.")
372
+
373
+ height = int(height or DEFAULT_NATIVE_RESOLUTION)
374
+ width = int(width or DEFAULT_NATIVE_RESOLUTION)
375
+ self.check_inputs(height, width, num_inference_steps, output_type)
376
+
377
+ patch_size = self._get_patch_size()
378
+ height = (height // patch_size) * patch_size
379
+ width = (width // patch_size) * patch_size
380
+ self._apply_decoder_patch_scaling(height, width)
381
+
382
+ class_label_ids = self._normalize_class_labels(class_labels, num_images_per_prompt)
383
+ negative_label_ids = None
384
+ if negative_class_labels is not None:
385
+ negative_label_ids = self._normalize_class_labels(negative_class_labels, num_images_per_prompt)
386
+
387
+ device = self._get_device()
388
+ model_dtype = next(self.transformer.parameters()).dtype
389
+ batch_size = len(class_label_ids)
390
+
391
+ cond, uncond = self.encode_condition(class_label_ids, negative_label_ids)
392
+ latents = self.prepare_latents(
393
+ batch_size=batch_size,
394
+ num_channels=self._get_in_channels(),
395
+ height=height,
396
+ width=width,
397
+ dtype=model_dtype,
398
+ device=device,
399
+ generator=generator,
400
+ latents=latents,
401
+ )
402
+
403
+ self.scheduler.set_timesteps(
404
+ num_inference_steps=num_inference_steps,
405
+ guidance_scale=guidance_scale,
406
+ device=device,
407
+ )
408
+
409
+ for timestep in self.progress_bar(self.scheduler.timesteps):
410
+ cfg_latents = torch.cat([latents, latents], dim=0)
411
+ cfg_t = timestep.repeat(cfg_latents.shape[0]).to(device=device, dtype=latents.dtype)
412
+ cfg_condition = torch.cat([uncond, cond], dim=0)
413
+ model_output = self.transformer(
414
+ sample=cfg_latents.to(dtype=model_dtype),
415
+ timestep=cfg_t,
416
+ encoder_hidden_states=cfg_condition,
417
+ ).sample
418
+ model_output = self.scheduler.classifier_free_guidance(model_output)
419
+ latents = self.scheduler.step(
420
+ model_output=model_output,
421
+ timestep=timestep,
422
+ sample=latents,
423
+ ).prev_sample
424
+
425
+ image = self.decode_latents(latents, output_type=output_type)
426
+
427
+ self.maybe_free_model_hooks()
428
+ if not return_dict:
429
+ return (image,)
430
+ return ImagePipelineOutput(images=image)
431
+
432
+
433
+ PixNerdPipelineOutput = ImagePipelineOutput
PixNerd-XL-16-512/scheduler/scheduling_pixnerd_flow_match.py CHANGED
@@ -1,231 +1,237 @@
1
- from __future__ import annotations
2
-
3
- from dataclasses import dataclass
4
- from typing import Any, Dict, List, Optional, Tuple, Union
5
-
6
- import torch
7
- from diffusers.configuration_utils import ConfigMixin, register_to_config
8
- from diffusers.schedulers.scheduling_utils import SchedulerMixin
9
- from diffusers.utils import BaseOutput
10
-
11
- @dataclass
12
- class PixNerdSchedulerOutput(BaseOutput):
13
- prev_sample: torch.Tensor
14
-
15
-
16
- class PixNerdFlowMatchScheduler(SchedulerMixin, ConfigMixin):
17
- """
18
- Diffusers-compatible scheduler wrapper for PixNerd's AdamLM flow-matching sampler.
19
- """
20
-
21
- config_name = "scheduler_config.json"
22
- order = 1
23
- init_noise_sigma = 1.0
24
-
25
- @staticmethod
26
- def _lagrange_coeffs(order: int, pre_ts: torch.Tensor, t_start: torch.Tensor, t_end: torch.Tensor) -> List[float]:
27
- ts = [float(v) for v in pre_ts[-order:].tolist()]
28
- a = float(t_start)
29
- b = float(t_end)
30
-
31
- if order == 1:
32
- return [1.0]
33
- if order == 2:
34
- t1, t2 = ts
35
- int1 = 0.5 / (t1 - t2) * ((b - t2) ** 2 - (a - t2) ** 2)
36
- int2 = 0.5 / (t2 - t1) * ((b - t1) ** 2 - (a - t1) ** 2)
37
- total = int1 + int2
38
- return [int1 / total, int2 / total]
39
- if order == 3:
40
- t1, t2, t3 = ts
41
- int1_denom = (t1 - t2) * (t1 - t3)
42
- int1 = ((1 / 3) * b**3 - 0.5 * (t2 + t3) * b**2 + (t2 * t3) * b) - (
43
- (1 / 3) * a**3 - 0.5 * (t2 + t3) * a**2 + (t2 * t3) * a
44
- )
45
- int1 = int1 / int1_denom
46
- int2_denom = (t2 - t1) * (t2 - t3)
47
- int2 = ((1 / 3) * b**3 - 0.5 * (t1 + t3) * b**2 + (t1 * t3) * b) - (
48
- (1 / 3) * a**3 - 0.5 * (t1 + t3) * a**2 + (t1 * t3) * a
49
- )
50
- int2 = int2 / int2_denom
51
- int3_denom = (t3 - t1) * (t3 - t2)
52
- int3 = ((1 / 3) * b**3 - 0.5 * (t1 + t2) * b**2 + (t1 * t2) * b) - (
53
- (1 / 3) * a**3 - 0.5 * (t1 + t2) * a**2 + (t1 * t2) * a
54
- )
55
- int3 = int3 / int3_denom
56
- total = int1 + int2 + int3
57
- return [int1 / total, int2 / total, int3 / total]
58
- if order == 4:
59
- t1, t2, t3, t4 = ts
60
- int1_denom = (t1 - t2) * (t1 - t3) * (t1 - t4)
61
- int1 = ((1 / 4) * b**4 - (1 / 3) * (t2 + t3 + t4) * b**3 + 0.5 * (t3 * t4 + t2 * t3 + t2 * t4) * b**2 - (t2 * t3 * t4) * b) - (
62
- (1 / 4) * a**4 - (1 / 3) * (t2 + t3 + t4) * a**3 + 0.5 * (t3 * t4 + t2 * t3 + t2 * t4) * a**2 - (t2 * t3 * t4) * a
63
- )
64
- int1 = int1 / int1_denom
65
- int2_denom = (t2 - t1) * (t2 - t3) * (t2 - t4)
66
- int2 = ((1 / 4) * b**4 - (1 / 3) * (t1 + t3 + t4) * b**3 + 0.5 * (t3 * t4 + t1 * t3 + t1 * t4) * b**2 - (t1 * t3 * t4) * b) - (
67
- (1 / 4) * a**4 - (1 / 3) * (t1 + t3 + t4) * a**3 + 0.5 * (t3 * t4 + t1 * t3 + t1 * t4) * a**2 - (t1 * t3 * t4) * a
68
- )
69
- int2 = int2 / int2_denom
70
- int3_denom = (t3 - t1) * (t3 - t2) * (t3 - t4)
71
- int3 = ((1 / 4) * b**4 - (1 / 3) * (t1 + t2 + t4) * b**3 + 0.5 * (t4 * t2 + t1 * t2 + t1 * t4) * b**2 - (t1 * t2 * t4) * b) - (
72
- (1 / 4) * a**4 - (1 / 3) * (t1 + t2 + t4) * a**3 + 0.5 * (t4 * t2 + t1 * t2 + t1 * t4) * a**2 - (t1 * t2 * t4) * a
73
- )
74
- int3 = int3 / int3_denom
75
- int4_denom = (t4 - t1) * (t4 - t2) * (t4 - t3)
76
- int4 = ((1 / 4) * b**4 - (1 / 3) * (t1 + t2 + t3) * b**3 + 0.5 * (t3 * t2 + t1 * t2 + t1 * t3) * b**2 - (t1 * t2 * t3) * b) - (
77
- (1 / 4) * a**4 - (1 / 3) * (t1 + t2 + t3) * a**3 + 0.5 * (t3 * t2 + t1 * t2 + t1 * t3) * a**2 - (t1 * t2 * t3) * a
78
- )
79
- int4 = int4 / int4_denom
80
- total = int1 + int2 + int3 + int4
81
- return [int1 / total, int2 / total, int3 / total, int4 / total]
82
- raise ValueError(f"Unsupported solver order: {order}.")
83
-
84
- @register_to_config
85
- def __init__(
86
- self,
87
- num_train_timesteps: int = 1000,
88
- num_inference_steps: int = 25,
89
- guidance_scale: float = 4.0,
90
- timeshift: float = 3.0,
91
- order: int = 2,
92
- guidance_interval_min: float = 0.0,
93
- guidance_interval_max: float = 1.0,
94
- last_step: Optional[float] = None,
95
- ) -> None:
96
- self.num_inference_steps = int(num_inference_steps)
97
- self.guidance_scale = float(guidance_scale)
98
- self.timeshift = float(timeshift)
99
- self.order = int(order)
100
- self.guidance_interval_min = float(guidance_interval_min)
101
- self.guidance_interval_max = float(guidance_interval_max)
102
- self.last_step = last_step
103
- self._reset_state()
104
-
105
- @classmethod
106
- def from_sampler_spec(cls, sampler_spec: Dict[str, Any]) -> "PixNerdFlowMatchScheduler":
107
- init_args = dict(sampler_spec.get("init_args", {}))
108
- return cls(
109
- num_inference_steps=int(init_args.get("num_steps", 25)),
110
- guidance_scale=float(init_args.get("guidance", 4.0)),
111
- timeshift=float(init_args.get("timeshift", 3.0)),
112
- order=int(init_args.get("order", 2)),
113
- guidance_interval_min=float(init_args.get("guidance_interval_min", 0.0)),
114
- guidance_interval_max=float(init_args.get("guidance_interval_max", 1.0)),
115
- last_step=init_args.get("last_step"),
116
- )
117
-
118
- def _reset_state(self) -> None:
119
- self.timesteps: Optional[torch.Tensor] = None
120
- self._timedeltas: Optional[torch.Tensor] = None
121
- self._solver_coeffs = None
122
- self._model_outputs = []
123
- self._step_index = 0
124
-
125
- @staticmethod
126
- def _shift_respace_fn(t: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
127
- return t / (t + (1 - t) * shift)
128
-
129
- def _build_solver_state(
130
- self,
131
- num_inference_steps: int,
132
- timeshift: float,
133
- device: Optional[Union[str, torch.device]] = None,
134
- ) -> Tuple[torch.Tensor, torch.Tensor, List[List[float]]]:
135
- last_step = self.last_step
136
- if last_step is None:
137
- last_step = 1.0 / float(num_inference_steps)
138
-
139
- endpoints = torch.linspace(0.0, 1 - float(last_step), int(num_inference_steps), dtype=torch.float32)
140
- endpoints = torch.cat([endpoints, torch.tensor([1.0], dtype=torch.float32)], dim=0)
141
- timesteps = self._shift_respace_fn(endpoints, timeshift).to(device=device)
142
- timedeltas = (timesteps[1:] - timesteps[:-1]).to(device=device)
143
-
144
- solver_coeffs: List[List[float]] = [[] for _ in range(int(num_inference_steps))]
145
- for i in range(int(num_inference_steps)):
146
- order = min(self.order, i + 1)
147
- pre_ts = timesteps[: i + 1]
148
- coeffs = self._lagrange_coeffs(order, pre_ts, pre_ts[i], timesteps[i + 1])
149
- solver_coeffs[i] = coeffs
150
- return timesteps[:-1], timedeltas, solver_coeffs
151
-
152
- def set_timesteps(
153
- self,
154
- num_inference_steps: Optional[int] = None,
155
- device: Optional[Union[str, torch.device]] = None,
156
- timeshift: Optional[float] = None,
157
- guidance_scale: Optional[float] = None,
158
- order: Optional[int] = None,
159
- **kwargs: Any,
160
- ) -> None:
161
- if num_inference_steps is not None:
162
- self.num_inference_steps = int(num_inference_steps)
163
- if timeshift is not None:
164
- self.timeshift = float(timeshift)
165
- if guidance_scale is not None:
166
- self.guidance_scale = float(guidance_scale)
167
- if order is not None:
168
- self.order = int(order)
169
-
170
- timesteps, timedeltas, solver_coeffs = self._build_solver_state(
171
- self.num_inference_steps,
172
- self.timeshift,
173
- device=device,
174
- )
175
- self.timesteps = timesteps
176
- self._timedeltas = timedeltas
177
- self._solver_coeffs = solver_coeffs
178
- self._model_outputs = []
179
- self._step_index = 0
180
-
181
- def scale_model_input(self, sample: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
182
- return sample
183
-
184
- def classifier_free_guidance(self, model_output: torch.Tensor) -> torch.Tensor:
185
- if model_output.shape[0] % 2 != 0:
186
- raise ValueError("Classifier-free guidance expects concatenated unconditional/conditional batches.")
187
- uncond, cond = model_output.chunk(2, dim=0)
188
- return uncond + self.guidance_scale * (cond - uncond)
189
-
190
- def step(
191
- self,
192
- model_output: torch.Tensor,
193
- timestep: Union[torch.Tensor, float, int],
194
- sample: torch.Tensor,
195
- return_dict: bool = True,
196
- **kwargs: Any,
197
- ) -> Union[PixNerdSchedulerOutput, Tuple[torch.Tensor]]:
198
- if self.timesteps is None or self._timedeltas is None or self._solver_coeffs is None:
199
- raise RuntimeError("`set_timesteps` must be called before `step`.")
200
- if self._step_index >= len(self._solver_coeffs):
201
- raise RuntimeError("Scheduler step index exceeded configured timesteps.")
202
-
203
- coeffs = self._solver_coeffs[self._step_index]
204
- self._model_outputs.append(model_output)
205
- order = len(coeffs)
206
- pred = torch.zeros_like(model_output)
207
- recent = self._model_outputs[-order:]
208
- for coeff, output in zip(coeffs, recent):
209
- pred = pred + coeff * output
210
-
211
- prev_sample = sample + pred * self._timedeltas[self._step_index]
212
- self._step_index += 1
213
-
214
- if not return_dict:
215
- return (prev_sample,)
216
- return PixNerdSchedulerOutput(prev_sample=prev_sample)
217
-
218
- def add_noise(
219
- self,
220
- original_samples: torch.Tensor,
221
- noise: torch.Tensor,
222
- timesteps: torch.Tensor,
223
- ) -> torch.Tensor:
224
- alpha = timesteps.view(-1, 1, 1, 1)
225
- sigma = (1.0 - timesteps).view(-1, 1, 1, 1)
226
- return alpha * original_samples + sigma * noise
227
-
228
- __all__ = [
229
- "PixNerdFlowMatchScheduler",
230
- "PixNerdSchedulerOutput",
231
- ]
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Any, Dict, List, Optional, Tuple, Union
10
+
11
+ import torch
12
+
13
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
14
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
15
+ from diffusers.utils import BaseOutput
16
+
17
+
18
+ @dataclass
19
+ class PixNerdSchedulerOutput(BaseOutput):
20
+ prev_sample: torch.Tensor
21
+
22
+
23
+ class PixNerdFlowMatchScheduler(SchedulerMixin, ConfigMixin):
24
+ """
25
+ Diffusers-compatible scheduler wrapper for PixNerd's AdamLM flow-matching sampler.
26
+ """
27
+
28
+ config_name = "scheduler_config.json"
29
+ order = 1
30
+ init_noise_sigma = 1.0
31
+
32
+ @staticmethod
33
+ def _lagrange_coeffs(order: int, pre_ts: torch.Tensor, t_start: torch.Tensor, t_end: torch.Tensor) -> List[float]:
34
+ ts = [float(v) for v in pre_ts[-order:].tolist()]
35
+ a = float(t_start)
36
+ b = float(t_end)
37
+
38
+ if order == 1:
39
+ return [1.0]
40
+ if order == 2:
41
+ t1, t2 = ts
42
+ int1 = 0.5 / (t1 - t2) * ((b - t2) ** 2 - (a - t2) ** 2)
43
+ int2 = 0.5 / (t2 - t1) * ((b - t1) ** 2 - (a - t1) ** 2)
44
+ total = int1 + int2
45
+ return [int1 / total, int2 / total]
46
+ if order == 3:
47
+ t1, t2, t3 = ts
48
+ int1_denom = (t1 - t2) * (t1 - t3)
49
+ int1 = ((1 / 3) * b**3 - 0.5 * (t2 + t3) * b**2 + (t2 * t3) * b) - (
50
+ (1 / 3) * a**3 - 0.5 * (t2 + t3) * a**2 + (t2 * t3) * a
51
+ )
52
+ int1 = int1 / int1_denom
53
+ int2_denom = (t2 - t1) * (t2 - t3)
54
+ int2 = ((1 / 3) * b**3 - 0.5 * (t1 + t3) * b**2 + (t1 * t3) * b) - (
55
+ (1 / 3) * a**3 - 0.5 * (t1 + t3) * a**2 + (t1 * t3) * a
56
+ )
57
+ int2 = int2 / int2_denom
58
+ int3_denom = (t3 - t1) * (t3 - t2)
59
+ int3 = ((1 / 3) * b**3 - 0.5 * (t1 + t2) * b**2 + (t1 * t2) * b) - (
60
+ (1 / 3) * a**3 - 0.5 * (t1 + t2) * a**2 + (t1 * t2) * a
61
+ )
62
+ int3 = int3 / int3_denom
63
+ total = int1 + int2 + int3
64
+ return [int1 / total, int2 / total, int3 / total]
65
+ if order == 4:
66
+ t1, t2, t3, t4 = ts
67
+ int1_denom = (t1 - t2) * (t1 - t3) * (t1 - t4)
68
+ int1 = ((1 / 4) * b**4 - (1 / 3) * (t2 + t3 + t4) * b**3 + 0.5 * (t3 * t4 + t2 * t3 + t2 * t4) * b**2 - (t2 * t3 * t4) * b) - (
69
+ (1 / 4) * a**4 - (1 / 3) * (t2 + t3 + t4) * a**3 + 0.5 * (t3 * t4 + t2 * t3 + t2 * t4) * a**2 - (t2 * t3 * t4) * a
70
+ )
71
+ int1 = int1 / int1_denom
72
+ int2_denom = (t2 - t1) * (t2 - t3) * (t2 - t4)
73
+ int2 = ((1 / 4) * b**4 - (1 / 3) * (t1 + t3 + t4) * b**3 + 0.5 * (t3 * t4 + t1 * t3 + t1 * t4) * b**2 - (t1 * t3 * t4) * b) - (
74
+ (1 / 4) * a**4 - (1 / 3) * (t1 + t3 + t4) * a**3 + 0.5 * (t3 * t4 + t1 * t3 + t1 * t4) * a**2 - (t1 * t3 * t4) * a
75
+ )
76
+ int2 = int2 / int2_denom
77
+ int3_denom = (t3 - t1) * (t3 - t2) * (t3 - t4)
78
+ int3 = ((1 / 4) * b**4 - (1 / 3) * (t1 + t2 + t4) * b**3 + 0.5 * (t4 * t2 + t1 * t2 + t1 * t4) * b**2 - (t1 * t2 * t4) * b) - (
79
+ (1 / 4) * a**4 - (1 / 3) * (t1 + t2 + t4) * a**3 + 0.5 * (t4 * t2 + t1 * t2 + t1 * t4) * a**2 - (t1 * t2 * t4) * a
80
+ )
81
+ int3 = int3 / int3_denom
82
+ int4_denom = (t4 - t1) * (t4 - t2) * (t4 - t3)
83
+ int4 = ((1 / 4) * b**4 - (1 / 3) * (t1 + t2 + t3) * b**3 + 0.5 * (t3 * t2 + t1 * t2 + t1 * t3) * b**2 - (t1 * t2 * t3) * b) - (
84
+ (1 / 4) * a**4 - (1 / 3) * (t1 + t2 + t3) * a**3 + 0.5 * (t3 * t2 + t1 * t2 + t1 * t3) * a**2 - (t1 * t2 * t3) * a
85
+ )
86
+ int4 = int4 / int4_denom
87
+ total = int1 + int2 + int3 + int4
88
+ return [int1 / total, int2 / total, int3 / total, int4 / total]
89
+ raise ValueError(f"Unsupported solver order: {order}.")
90
+
91
+ @register_to_config
92
+ def __init__(
93
+ self,
94
+ num_train_timesteps: int = 1000,
95
+ num_inference_steps: int = 25,
96
+ guidance_scale: float = 4.0,
97
+ timeshift: float = 3.0,
98
+ order: int = 2,
99
+ guidance_interval_min: float = 0.0,
100
+ guidance_interval_max: float = 1.0,
101
+ last_step: Optional[float] = None,
102
+ ) -> None:
103
+ self.num_inference_steps = int(num_inference_steps)
104
+ self.guidance_scale = float(guidance_scale)
105
+ self.timeshift = float(timeshift)
106
+ self.order = int(order)
107
+ self.guidance_interval_min = float(guidance_interval_min)
108
+ self.guidance_interval_max = float(guidance_interval_max)
109
+ self.last_step = last_step
110
+ self._reset_state()
111
+
112
+ @classmethod
113
+ def from_sampler_spec(cls, sampler_spec: Dict[str, Any]) -> "PixNerdFlowMatchScheduler":
114
+ init_args = dict(sampler_spec.get("init_args", {}))
115
+ return cls(
116
+ num_inference_steps=int(init_args.get("num_steps", 25)),
117
+ guidance_scale=float(init_args.get("guidance", 4.0)),
118
+ timeshift=float(init_args.get("timeshift", 3.0)),
119
+ order=int(init_args.get("order", 2)),
120
+ guidance_interval_min=float(init_args.get("guidance_interval_min", 0.0)),
121
+ guidance_interval_max=float(init_args.get("guidance_interval_max", 1.0)),
122
+ last_step=init_args.get("last_step"),
123
+ )
124
+
125
+ def _reset_state(self) -> None:
126
+ self.timesteps: Optional[torch.Tensor] = None
127
+ self._timedeltas: Optional[torch.Tensor] = None
128
+ self._solver_coeffs = None
129
+ self._model_outputs = []
130
+ self._step_index = 0
131
+
132
+ @staticmethod
133
+ def _shift_respace_fn(t: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
134
+ return t / (t + (1 - t) * shift)
135
+
136
+ def _build_solver_state(
137
+ self,
138
+ num_inference_steps: int,
139
+ timeshift: float,
140
+ device: Optional[Union[str, torch.device]] = None,
141
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[List[float]]]:
142
+ last_step = self.last_step
143
+ if last_step is None:
144
+ last_step = 1.0 / float(num_inference_steps)
145
+
146
+ endpoints = torch.linspace(0.0, 1 - float(last_step), int(num_inference_steps), dtype=torch.float32)
147
+ endpoints = torch.cat([endpoints, torch.tensor([1.0], dtype=torch.float32)], dim=0)
148
+ timesteps = self._shift_respace_fn(endpoints, timeshift).to(device=device)
149
+ timedeltas = (timesteps[1:] - timesteps[:-1]).to(device=device)
150
+
151
+ solver_coeffs: List[List[float]] = [[] for _ in range(int(num_inference_steps))]
152
+ for i in range(int(num_inference_steps)):
153
+ order = min(self.order, i + 1)
154
+ pre_ts = timesteps[: i + 1]
155
+ coeffs = self._lagrange_coeffs(order, pre_ts, pre_ts[i], timesteps[i + 1])
156
+ solver_coeffs[i] = coeffs
157
+ return timesteps[:-1], timedeltas, solver_coeffs
158
+
159
+ def set_timesteps(
160
+ self,
161
+ num_inference_steps: Optional[int] = None,
162
+ device: Optional[Union[str, torch.device]] = None,
163
+ timeshift: Optional[float] = None,
164
+ guidance_scale: Optional[float] = None,
165
+ order: Optional[int] = None,
166
+ **kwargs: Any,
167
+ ) -> None:
168
+ if num_inference_steps is not None:
169
+ self.num_inference_steps = int(num_inference_steps)
170
+ if timeshift is not None:
171
+ self.timeshift = float(timeshift)
172
+ else:
173
+ self.timeshift = float(getattr(self.config, "timeshift", self.timeshift))
174
+ if guidance_scale is not None:
175
+ self.guidance_scale = float(guidance_scale)
176
+ if order is not None:
177
+ self.order = int(order)
178
+ else:
179
+ self.order = int(getattr(self.config, "order", self.order))
180
+
181
+ timesteps, timedeltas, solver_coeffs = self._build_solver_state(
182
+ self.num_inference_steps,
183
+ self.timeshift,
184
+ device=device,
185
+ )
186
+ self.timesteps = timesteps
187
+ self._timedeltas = timedeltas
188
+ self._solver_coeffs = solver_coeffs
189
+ self._model_outputs = []
190
+ self._step_index = 0
191
+
192
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
193
+ return sample
194
+
195
+ def classifier_free_guidance(self, model_output: torch.Tensor) -> torch.Tensor:
196
+ if model_output.shape[0] % 2 != 0:
197
+ raise ValueError("Classifier-free guidance expects concatenated unconditional/conditional batches.")
198
+ uncond, cond = model_output.chunk(2, dim=0)
199
+ return uncond + self.guidance_scale * (cond - uncond)
200
+
201
+ def step(
202
+ self,
203
+ model_output: torch.Tensor,
204
+ timestep: Union[torch.Tensor, float, int],
205
+ sample: torch.Tensor,
206
+ return_dict: bool = True,
207
+ **kwargs: Any,
208
+ ) -> Union[PixNerdSchedulerOutput, Tuple[torch.Tensor]]:
209
+ if self.timesteps is None or self._timedeltas is None or self._solver_coeffs is None:
210
+ raise RuntimeError("`set_timesteps` must be called before `step`.")
211
+ if self._step_index >= len(self._solver_coeffs):
212
+ raise RuntimeError("Scheduler step index exceeded configured timesteps.")
213
+
214
+ coeffs = self._solver_coeffs[self._step_index]
215
+ self._model_outputs.append(model_output)
216
+ order = len(coeffs)
217
+ pred = torch.zeros_like(model_output)
218
+ recent = self._model_outputs[-order:]
219
+ for coeff, output in zip(coeffs, recent):
220
+ pred = pred + coeff * output
221
+
222
+ prev_sample = sample + pred * self._timedeltas[self._step_index]
223
+ self._step_index += 1
224
+
225
+ if not return_dict:
226
+ return (prev_sample,)
227
+ return PixNerdSchedulerOutput(prev_sample=prev_sample)
228
+
229
+ def add_noise(
230
+ self,
231
+ original_samples: torch.Tensor,
232
+ noise: torch.Tensor,
233
+ timesteps: torch.Tensor,
234
+ ) -> torch.Tensor:
235
+ alpha = timesteps.view(-1, 1, 1, 1)
236
+ sigma = (1.0 - timesteps).view(-1, 1, 1, 1)
237
+ return alpha * original_samples + sigma * noise
PixNerd-XL-16-512/transformer/modeling_pixnerd_transformer_2d.py CHANGED
@@ -20,6 +20,15 @@ class BaseAE(torch.nn.Module):
20
  super().__init__()
21
  self.scale = scale
22
  self.shift = shift
 
 
 
 
 
 
 
 
 
23
 
24
  def encode(self, x):
25
  return self._impl_encode(x) #.to(torch.bfloat16)
@@ -68,6 +77,15 @@ def resolve_conditioner_device(metadata: dict, fallback: torch.device | None = N
68
  class BaseConditioner(nn.Module):
69
  def __init__(self):
70
  super(BaseConditioner, self).__init__()
 
 
 
 
 
 
 
 
 
71
 
72
  def _impl_condition(self, y, metadata)->torch.Tensor:
73
  raise NotImplementedError()
@@ -166,6 +184,7 @@ class TimestepEmbedder(nn.Module):
166
 
167
  def forward(self, t):
168
  t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
 
169
  t_emb = self.mlp(t_freq)
170
  return t_emb
171
 
 
20
  super().__init__()
21
  self.scale = scale
22
  self.shift = shift
23
+ self.register_buffer("_diffusers_device_anchor", torch.zeros(0), persistent=False)
24
+
25
+ @property
26
+ def dtype(self) -> torch.dtype:
27
+ return self._diffusers_device_anchor.dtype
28
+
29
+ @property
30
+ def device(self) -> torch.device:
31
+ return self._diffusers_device_anchor.device
32
 
33
  def encode(self, x):
34
  return self._impl_encode(x) #.to(torch.bfloat16)
 
77
  class BaseConditioner(nn.Module):
78
  def __init__(self):
79
  super(BaseConditioner, self).__init__()
80
+ self.register_buffer("_diffusers_device_anchor", torch.zeros(0), persistent=False)
81
+
82
+ @property
83
+ def dtype(self) -> torch.dtype:
84
+ return self._diffusers_device_anchor.dtype
85
+
86
+ @property
87
+ def device(self) -> torch.device:
88
+ return self._diffusers_device_anchor.device
89
 
90
  def _impl_condition(self, y, metadata)->torch.Tensor:
91
  raise NotImplementedError()
 
184
 
185
  def forward(self, t):
186
  t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
187
+ t_freq = t_freq.to(dtype=self.mlp[0].weight.dtype)
188
  t_emb = self.mlp(t_freq)
189
  return t_emb
190
 
README.md CHANGED
@@ -39,30 +39,19 @@ Both checkpoints are ImageNet class-conditional PixNerd-XL/16 exports with flow-
39
 
40
  ![PixNerd-XL-16-512 demo](PixNerd-XL-16-512/demo.png)
41
 
42
- Class 207 β€” golden retriever / ι‡‘ζ―›ηŒŽηŠ¬, 512Γ—512, 25 steps.
43
 
44
  ## ImageNet class labels
45
 
46
- ImageNet-1k labels live in shared [`labels/`](labels/) at the repo root (not duplicated per variant). Format follows Hugging Face / DiT convention:
47
 
48
- | File | Direction | Value format |
49
- | --- | --- | --- |
50
- | `labels/id2label_en.json` | id β†’ English | comma-separated synonyms, e.g. `"207": "golden retriever"` |
51
- | `labels/id2label_cn.json` | id β†’ Chinese | comma-separated synonyms, e.g. `"207": "ι‡‘ζ―›ηŒŽηŠ¬"` |
52
-
53
- After `PixNerdPipeline.from_pretrained(...)`, the pipeline exposes:
54
-
55
- - `pipe.id2label` / `pipe.id2label_cn` β€” inspect id β†’ label correspondence
56
- - `pipe.labels` / `pipe.labels_cn` β€” reverse maps (synonym β†’ id), sorted for browsing
57
- - `pipe.get_label_ids("golden retriever")` or `pipe.get_label_ids("ι‡‘ζ―›ηŒŽηŠ¬", lang="cn")`
58
- - `pipe(prompt="golden retriever", ...)` β€” string labels resolved automatically
59
 
60
- Why JSON at repo root instead of a Python dict in each variant?
61
-
62
- 1. **Explicit correspondence** β€” users can open `id2label_en.json` and see every id without running code.
63
- 2. **Hub-compatible** β€” same shape as `facebook/DiT-XL-2-256` and other vision checkpoints.
64
- 3. **Shared across variants** β€” both PixNerd checkpoints use the same 1000 ImageNet classes.
65
- 4. **Bilingual without duplication** β€” English and Chinese are separate files; the pipeline loads both.
66
 
67
  ## Load from Hugging Face
68
 
@@ -76,23 +65,22 @@ resolution = 256 if variant.endswith("256") else 512
76
  pipe = DiffusionPipeline.from_pretrained(
77
  f"BiliSakura/PixNerd-diffusers/{variant}",
78
  trust_remote_code=True,
79
- torch_dtype=torch.float32,
80
  ).to("cuda")
81
 
 
 
82
  images = pipe(
83
- prompt=207,
84
  height=resolution,
85
  width=resolution,
86
  num_inference_steps=25,
87
  guidance_scale=4.0,
88
- timeshift=3.0,
89
- order=2,
90
  ).images
91
 
92
  print(pipe.id2label[207]) # "golden retriever"
93
- print(pipe.id2label_cn[207]) # "ι‡‘ζ―›ηŒŽηŠ¬"
94
  pipe.get_label_ids("golden retriever") # [207]
95
- images = pipe(prompt="golden retriever", height=resolution, width=resolution).images
96
  ```
97
 
98
  ## Load from a local clone
@@ -107,10 +95,10 @@ variant = "PixNerd-XL-16-256"
107
  pipe = DiffusionPipeline.from_pretrained(
108
  f"{repo}/{variant}",
109
  trust_remote_code=True,
110
- torch_dtype=torch.float32,
111
  ).to("cuda")
112
 
113
- images = pipe(prompt=207, height=256, width=256).images
114
  ```
115
 
116
  ## Repo layout
@@ -118,10 +106,6 @@ images = pipe(prompt=207, height=256, width=256).images
118
  ```text
119
  BiliSakura/PixNerd-diffusers/
120
  β”œβ”€β”€ README.md
121
- β”œβ”€β”€ labels/
122
- β”‚ β”œβ”€β”€ id2label_en.json # ImageNet id -> English synonyms
123
- β”‚ β”œβ”€β”€ id2label_cn.json # ImageNet id -> Chinese synonyms
124
- β”‚ └── imagenet_labels.py # loader helpers
125
  β”œβ”€β”€ PixNerd-XL-16-256/
126
  β”‚ β”œβ”€β”€ README.md
127
  β”‚ β”œβ”€β”€ pipeline.py
@@ -140,7 +124,7 @@ BiliSakura/PixNerd-diffusers/
140
 
141
  ## Interface notes
142
 
143
- - The pipeline uses `prompt` for class conditioning input.
144
  - Pass integer ImageNet ids (`prompt=207`) or human-readable synonyms (`prompt="golden retriever"`).
145
  - `height` and `width` should match checkpoint intent (256 or 512), but custom sizes work if divisible by patch size (16).
146
  - Architecture and conversion provenance are recorded in each checkpoint's `conversion_metadata.json`.
 
39
 
40
  ![PixNerd-XL-16-512 demo](PixNerd-XL-16-512/demo.png)
41
 
42
+ Class 207 β€” golden retriever, 512Γ—512, 25 steps.
43
 
44
  ## ImageNet class labels
45
 
46
+ Each variant keeps an English `id2label` map directly in its own `model_index.json` (DiT-style).
47
 
48
+ - `pipe.id2label` β€” inspect id β†’ English label correspondence
49
+ - `pipe.labels` β€” reverse maps (English synonym β†’ id), sorted for browsing
50
+ - `pipe.get_label_ids("golden retriever")`
51
+ - `pipe(class_labels="golden retriever", ...)` β€” string labels resolved automatically
52
+ - `pipe(prompt="golden retriever", ...)` β€” deprecated alias for `class_labels`
 
 
 
 
 
 
53
 
54
+ Chinese labels are preserved in the main source repo under `src/labels/id2label_cn.json` for reference.
 
 
 
 
 
55
 
56
  ## Load from Hugging Face
57
 
 
65
  pipe = DiffusionPipeline.from_pretrained(
66
  f"BiliSakura/PixNerd-diffusers/{variant}",
67
  trust_remote_code=True,
68
+ torch_dtype=torch.bfloat16,
69
  ).to("cuda")
70
 
71
+ # Scheduler defaults: timeshift=3.0, order=2 (see scheduler/scheduler_config.json)
72
+
73
  images = pipe(
74
+ class_labels="golden retriever",
75
  height=resolution,
76
  width=resolution,
77
  num_inference_steps=25,
78
  guidance_scale=4.0,
 
 
79
  ).images
80
 
81
  print(pipe.id2label[207]) # "golden retriever"
 
82
  pipe.get_label_ids("golden retriever") # [207]
83
+ images = pipe(class_labels="golden retriever", height=resolution, width=resolution).images
84
  ```
85
 
86
  ## Load from a local clone
 
95
  pipe = DiffusionPipeline.from_pretrained(
96
  f"{repo}/{variant}",
97
  trust_remote_code=True,
98
+ torch_dtype=torch.bfloat16,
99
  ).to("cuda")
100
 
101
+ images = pipe(class_labels="golden retriever", height=256, width=256).images
102
  ```
103
 
104
  ## Repo layout
 
106
  ```text
107
  BiliSakura/PixNerd-diffusers/
108
  β”œβ”€β”€ README.md
 
 
 
 
109
  β”œβ”€β”€ PixNerd-XL-16-256/
110
  β”‚ β”œβ”€β”€ README.md
111
  β”‚ β”œβ”€β”€ pipeline.py
 
124
 
125
  ## Interface notes
126
 
127
+ - The pipeline uses `class_labels` for ImageNet class conditioning (`prompt` remains a deprecated alias).
128
  - Pass integer ImageNet ids (`prompt=207`) or human-readable synonyms (`prompt="golden retriever"`).
129
  - `height` and `width` should match checkpoint intent (256 or 512), but custom sizes work if divisible by patch size (16).
130
  - Architecture and conversion provenance are recorded in each checkpoint's `conversion_metadata.json`.