Novix committed on
Commit
1eaa4cc
·
verified ·
1 Parent(s): 8940c35

Upload 2 files

Browse files
Files changed (2) hide show
  1. generate.py +625 -0
  2. generate.sh +75 -0
generate.py ADDED
@@ -0,0 +1,625 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from hmac import new
2
+ import sys
3
+ import os
4
+ import argparse
5
+
6
+ import time
7
+ import json
8
+ import torch
9
+ import torchaudio
10
+ import numpy as np
11
+ from omegaconf import OmegaConf
12
+ from codeclm.models import builders
13
+ import gc
14
+ from codeclm.trainer.codec_song_pl import CodecLM_PL
15
+ from codeclm.models import CodecLM
16
+ from third_party.demucs.models.pretrained import get_model_from_yaml
17
+ import re
18
+ import librosa
19
+
20
+ auto_prompt_type = ['Pop', 'Latin', 'Rock', 'Electronic', 'Metal', 'Country', 'R&B/Soul', 'Ballad', 'Jazz', 'World', 'Hip-Hop', 'Funk', 'Soundtrack','Auto']
21
+
22
def check_language_by_text(text):
    """Guess whether lyrics are Chinese or English.

    Args:
        text: Lyric text to classify.

    Returns:
        "zh" when at least 20% of the characters fall in the CJK Unified
        Ideographs range U+4E00..U+9FFF, otherwise "en". Empty text is
        treated as "en" instead of dividing by zero.
    """
    if not text:
        return "en"
    chinese_count = len(re.findall(r'[\u4e00-\u9fff]', text))
    if chinese_count / len(text) >= 0.2:
        return "zh"
    # Every non-Chinese outcome maps to English, whatever the share of
    # Latin letters is, so no separate English ratio is needed.
    return "en"
35
+
36
def load_audio_by_librosa(f):
    """Load an audio file with librosa as a 48 kHz tensor of at most 10 seconds.

    Args:
        f: Path (or any source librosa.load accepts) of the audio to read.

    Returns:
        A ``(waveform, sample_rate)`` tuple. ``waveform`` is a torch tensor
        with a leading axis of size 1 added by ``unsqueeze(0)``, truncated to
        the first 10 seconds; ``sample_rate`` is always 48000.
    """
    target_sr = 48000
    max_samples = target_sr * 10  # keep only the first 10 seconds
    # librosa resamples to `sr` while loading, so no second resampling pass
    # is needed afterwards.
    samples, _ = librosa.load(f, sr=target_sr)
    waveform = torch.tensor(samples).unsqueeze(0)
    return waveform[:, :max_samples], target_sr
44
+
45
class Separator:
    """Split a prompt clip into full-mix, vocal and accompaniment waveforms.

    Wraps the Demucs model loaded through ``get_model_from_yaml``; audio is
    handled at 48 kHz and limited to the first 10 seconds.
    """

    def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
        """Load the Demucs model onto ``cuda:{gpu_id}`` when that GPU exists, else CPU.

        Args:
            dm_model_path: Path to the Demucs checkpoint.
            dm_config_path: Path to the Demucs YAML configuration.
            gpu_id: Index of the CUDA device to use.
        """
        if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
            self.device = torch.device(f"cuda:{gpu_id}")
        else:
            self.device = torch.device("cpu")
        self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)

    def init_demucs_model(self, model_path, config_path):
        """Build the Demucs model, move it to ``self.device`` and set eval mode."""
        model = get_model_from_yaml(config_path, model_path)
        model.to(self.device)
        model.eval()
        return model

    def load_audio(self, f):
        """Load ``f`` as a 48 kHz waveform truncated to its first 10 seconds.

        torchaudio is tried first; if it cannot read the file the librosa
        loader is used as a best-effort fallback.
        """
        try:
            a, fs = torchaudio.load(f)
        except Exception:
            # Only ordinary errors trigger the fallback; a bare ``except``
            # would also swallow KeyboardInterrupt / SystemExit.
            a, fs = load_audio_by_librosa(f)
        if fs != 48000:
            a = torchaudio.functional.resample(a, fs, 48000)
        return a[:, :48000 * 10]

    def run(self, audio_path, output_dir='tmp', ext=".flac"):
        """Separate ``audio_path`` and return ``(full, vocal, bgm)`` waveforms.

        Args:
            audio_path: Path of the prompt audio.
            output_dir: Directory receiving the separated stems (created if
                missing).
            ext: Extension used when looking for previously written stems.

        Returns:
            Tuple ``(full_audio, vocal_audio, bgm_audio)`` where ``bgm_audio``
            is ``full_audio - vocal_audio``.
        """
        os.makedirs(output_dir, exist_ok=True)
        name, _ = os.path.splitext(os.path.split(audio_path)[-1])
        output_paths = [
            candidate
            for candidate in (
                os.path.join(output_dir, f"{name}_{stem}{ext}")
                for stem in self.demucs_model.sources
            )
            if os.path.exists(candidate)
        ]

        if len(output_paths) == 1:
            # Exactly one stem on disk is taken to be the vocal track kept by
            # an earlier run (the other stems are deleted below).
            vocal_path = output_paths[0]
        else:
            drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
            for path in (drums_path, bass_path, other_path):
                os.remove(path)
        full_audio = self.load_audio(audio_path)
        vocal_audio = self.load_audio(vocal_path)
        bgm_audio = full_audio - vocal_audio
        return full_audio, vocal_audio, bgm_audio
89
+
90
+
91
def parse_args():
    """Parse the command-line options of the song generation script.

    Returns:
        argparse.Namespace with ``ckpt_path``, ``input_jsonl``, ``save_dir``,
        ``generate_type``, ``use_flash_attn`` and ``low_mem``.
    """
    parser = argparse.ArgumentParser(description='Song Generation Script')

    # Required arguments
    parser.add_argument('--ckpt_path', type=str, required=True,
                        help='Path to the checkpoint directory containing config.yaml and model.pt')
    parser.add_argument('--input_jsonl', type=str, required=True,
                        help='Path to input JSONL file containing generation tasks')
    parser.add_argument('--save_dir', type=str, required=True,
                        help='Directory to save generated audio files and results')
    # Optional arguments
    # Restricting the choices rejects a mistyped type up front instead of
    # failing after the models have been loaded.
    parser.add_argument('--generate_type', type=str, default='mixed',
                        choices=['vocal', 'bgm', 'separate', 'mixed'],
                        help='Type of generation: "vocal" or "bgm" or "separate" or "mixed" (default: "mixed")')
    parser.add_argument('--use_flash_attn', action='store_true',
                        help='Whether to use flash attention (default: False)')
    parser.add_argument('--low_mem', action='store_true',
                        help='Whether to use low memory mode (default: False)')
    return parser.parse_args()
109
+
110
def _load_config(args):
    """Load ``config.yaml`` from the checkpoint directory and set inference options."""
    cfg = OmegaConf.load(os.path.join(args.ckpt_path, 'config.yaml'))
    cfg.lm.use_flash_attn_2 = args.use_flash_attn
    print(f"use_flash_attn: {args.use_flash_attn}")
    cfg.mode = 'inference'
    return cfg


def _read_tasks(input_jsonl):
    """Read generation tasks from a JSONL file, skipping blank lines."""
    # Lyrics may contain non-ASCII text, so do not rely on the locale encoding.
    with open(input_jsonl, "r", encoding="utf-8") as fp:
        return [json.loads(line) for line in fp if line.strip()]


def _as_batch(wav, label):
    """Add a leading axis to a 2-D waveform and require a 3-D result."""
    if wav.dim() == 2:
        wav = wav[None]
    if wav.dim() != 3:
        raise ValueError(f"{label} wavs should have a shape [B, C, T].")
    return wav


def _prepare_items(tasks, cfg, save_dir):
    """Attach prompt conditioning (``pmt_wav``/``vocal_wav``/``bgm_wav``) to each task.

    The separator and the prompt tokenizer are only built when at least one
    task carries ``prompt_audio_path``; they are released when this function
    returns so the caller can reclaim GPU memory.
    """
    separator = None
    audio_tokenizer = None
    if any("prompt_audio_path" in task for task in tasks):
        separator = Separator()
        audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
        audio_tokenizer = audio_tokenizer.eval().cuda()
    # NOTE(security): torch.load unpickles the file; only use trusted prompt files.
    auto_prompt = torch.load('tools/new_auto_prompt.pt')

    new_items = []
    for item in tasks:
        target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
        if "prompt_audio_path" in item:
            assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
            assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
            with torch.no_grad():
                pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
            # The raw waveforms are kept for the decoding stage.
            item['raw_pmt_wav'] = pmt_wav
            item['raw_vocal_wav'] = vocal_wav
            item['raw_bgm_wav'] = bgm_wav
            pmt_wav = _as_batch(pmt_wav, "Melody")
            vocal_wav = _as_batch(vocal_wav, "Vocal")
            bgm_wav = _as_batch(bgm_wav, "BGM")
            with torch.no_grad():
                pmt_wav, _ = audio_tokenizer.encode(pmt_wav.cuda())
            melody_is_wav = False
        elif "auto_prompt_audio_type" in item:
            assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
            lang = check_language_by_text(item['gt_lyric'])
            candidates = auto_prompt[item["auto_prompt_audio_type"]][lang]
            prompt_token = candidates[np.random.randint(0, len(candidates))]
            pmt_wav = prompt_token[:, [0], :]
            vocal_wav = prompt_token[:, [1], :]
            bgm_wav = prompt_token[:, [2], :]
            melody_is_wav = False
        else:
            pmt_wav = None
            vocal_wav = None
            bgm_wav = None
            melody_is_wav = True
        item['pmt_wav'] = pmt_wav
        item['vocal_wav'] = vocal_wav
        item['bgm_wav'] = bgm_wav
        item['melody_is_wav'] = melody_is_wav
        item["idx"] = f"{item['idx']}"
        item["wav_path"] = target_wav_name
        new_items.append(item)
    return new_items


def _encode_prompt_stems(items, seperate_tokenizer):
    """Tokenize the vocal/BGM prompt waveforms of items that carry prompt audio."""
    for item in items:
        if "prompt_audio_path" not in item:
            continue
        if seperate_tokenizer is None:
            raise ValueError("prompt_audio_path requires 'audio_tokenizer_checkpoint_sep' in the config")
        with torch.no_grad():
            vocal_wav, bgm_wav = seperate_tokenizer.encode(item['vocal_wav'].cuda(), item['bgm_wav'].cuda())
        item['vocal_wav'] = vocal_wav
        item['bgm_wav'] = bgm_wav


def _load_audiolm(cfg, ckpt_path, version):
    """Build the language model and load its weights from ``ckpt_path`` (on CPU, eval mode)."""
    audiolm = builders.get_lm_model(cfg, version=version)
    # NOTE(security): torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    audiolm_state_dict = {k.replace('audiolm.', ''): v for k, v in checkpoint.items() if k.startswith('audiolm')}
    audiolm.load_state_dict(audiolm_state_dict, strict=False)
    return audiolm.eval()


def _build_descriptions(item, version, gen_type):
    """Return the text description passed to the model for one task."""
    if version == 'v1':
        return item["descriptions"].lower() if "descriptions" in item else None
    tags = '[Musicality-very-high]' + ', '
    if gen_type == 'bgm':
        tags += '[Pure-Music]' + ', '
    # The fallback '.' only replaces the user description; the tags are always
    # kept. (Without the parentheses the conditional expression would swallow
    # the tags whenever "descriptions" is missing.)
    return tags + (item["descriptions"].lower() if "descriptions" in item else '.')


def _build_generate_inputs(item, version, gen_type):
    """Assemble the keyword arguments for ``model.generate`` from a prepared item."""
    lyric = item["gt_lyric"]
    return {
        # Collapse double spaces in the lyric (presumed intent of the replace).
        'lyrics': [lyric.replace("  ", " ")] if gen_type != 'bgm' else '.',
        'descriptions': [_build_descriptions(item, version, gen_type)],
        'melody_wavs': item['pmt_wav'],
        'vocal_wavs': item['vocal_wav'],
        'bgm_wavs': item['bgm_wav'],
        'melody_is_wav': item['melody_is_wav'],
    }


def _decode_audio(model, tokens, item, gen_type, stems_use_prompt):
    """Decode tokens to audio; returns ``(main, vocal, bgm)`` with ``None`` for unused stems.

    ``stems_use_prompt`` controls whether the separate vocal/BGM renders also
    receive the raw prompt waveforms.
    NOTE(review): the full-memory path passes them and the low-memory path
    does not; that difference is preserved here -- confirm which is intended.
    """
    prompt_args = ()
    if 'raw_pmt_wav' in item:
        prompt_args = (item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'])
    if gen_type == 'separate':
        stem_args = prompt_args if stems_use_prompt else ()
        wav_mixed = model.generate_audio(tokens, *prompt_args, chunked=True, gen_type='mixed')
        wav_vocal = model.generate_audio(tokens, *stem_args, chunked=True, gen_type='vocal')
        wav_bgm = model.generate_audio(tokens, *stem_args, chunked=True, gen_type='bgm')
        return wav_mixed, wav_vocal, wav_bgm
    if gen_type == 'mixed':
        return model.generate_audio(tokens, *prompt_args, chunked=True, gen_type=gen_type), None, None
    # Single-stem renders ('vocal' / 'bgm') never receive the prompt waveforms.
    return model.generate_audio(tokens, chunked=True, gen_type=gen_type), None, None


def _save_audio(wav_path, wavs, gen_type, sample_rate):
    """Write the decoded waveform(s) next to ``wav_path`` (which ends in ``.flac``)."""
    wav_main, wav_vocal, wav_bgm = wavs
    if gen_type == 'separate':
        # Swap only the trailing extension so a '.flac' elsewhere in the path is untouched.
        stem = wav_path[:-len('.flac')]
        torchaudio.save(stem + '_vocal.flac', wav_vocal[0].cpu().float(), sample_rate)
        torchaudio.save(stem + '_bgm.flac', wav_bgm[0].cpu().float(), sample_rate)
    torchaudio.save(wav_path, wav_main[0].cpu().float(), sample_rate)


def _drop_runtime_fields(item):
    """Remove tensor-valued working fields so the item can be dumped as JSON."""
    for key in ('raw_pmt_wav', 'raw_vocal_wav', 'raw_bgm_wav', 'tokens',
                'pmt_wav', 'vocal_wav', 'bgm_wav', 'melody_is_wav'):
        item.pop(key, None)


def _make_output_dirs(save_dir):
    """Create the ``audios`` and ``jsonl`` output directories under ``save_dir``."""
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(save_dir + "/audios", exist_ok=True)
    os.makedirs(save_dir + "/jsonl", exist_ok=True)


def _write_results(items, save_dir, input_jsonl):
    """Write the processed items as JSONL under ``{save_dir}/jsonl``."""
    # The output name is the input file name with '.jsonl' appended (kept as-is
    # so existing consumers still find the file).
    src_jsonl_name = os.path.split(input_jsonl)[-1]
    with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
        for item in items:
            fw.write(json.dumps(item, ensure_ascii=False) + "\n")


def generate(args, version = 'v1'):
    """Generate songs for every task in ``args.input_jsonl`` with all models resident on the GPU.

    Args:
        args: Namespace providing ``ckpt_path``, ``input_jsonl``, ``save_dir``,
            ``generate_type`` and ``use_flash_attn``.
        version: Model version; 'v1' passes descriptions through unchanged,
            any other value prefixes them with quality tags.
    """
    torch.set_num_threads(1)
    cfg = _load_config(args)
    ckpt_path = os.path.join(args.ckpt_path, 'model.pt')
    input_jsonl = args.input_jsonl
    save_dir = args.save_dir
    max_duration = cfg.max_dur
    gen_type = args.generate_type

    new_items = _prepare_items(_read_tasks(input_jsonl), cfg, save_dir)
    torch.cuda.empty_cache()

    if "audio_tokenizer_checkpoint_sep" in cfg.keys():
        seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
        seperate_tokenizer = seperate_tokenizer.eval().cuda()
    else:
        seperate_tokenizer = None
    _encode_prompt_stems(new_items, seperate_tokenizer)
    torch.cuda.empty_cache()

    audiolm = _load_audiolm(cfg, ckpt_path, version).cuda().to(torch.float16)
    model = CodecLM(name = "tmp",
        lm = audiolm,
        audiotokenizer = None,
        max_duration = max_duration,
        seperate_tokenizer = seperate_tokenizer,
    )

    cfg_coef = 1.5
    temp = 0.8
    top_k = 5000
    top_p = 0.0
    record_tokens = True
    record_window = 50
    model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef,
                                top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window)
    _make_output_dirs(save_dir)

    for item in new_items:
        generate_inp = _build_generate_inputs(item, version, gen_type)
        start_time = time.time()
        with torch.autocast(device_type="cuda", dtype=torch.float16), torch.no_grad():
            tokens = model.generate(**generate_inp, return_tokens=True)
        mid_time = time.time()

        with torch.no_grad():
            wavs = _decode_audio(model, tokens, item, gen_type, stems_use_prompt=True)
        end_time = time.time()
        _drop_runtime_fields(item)
        _save_audio(item['wav_path'], wavs, gen_type, cfg.sample_rate)
        print(f"process{item['idx']}, lm cost {mid_time - start_time}s, diffusion cost {end_time - mid_time}")

    _write_results(new_items, save_dir, input_jsonl)


def generate_lowmem(args, version = 'v1'):
    """Generate songs while keeping only one large model on the GPU at a time.

    Tokens for all tasks are produced first; the language model is then
    released before the decoder is loaded to render the audio.

    Args:
        args: Namespace providing ``ckpt_path``, ``input_jsonl``, ``save_dir``,
            ``generate_type`` and ``use_flash_attn``.
        version: Model version; 'v1' passes descriptions through unchanged,
            any other value prefixes them with quality tags.
    """
    torch.set_num_threads(1)
    cfg = _load_config(args)
    ckpt_path = os.path.join(args.ckpt_path, 'model.pt')
    input_jsonl = args.input_jsonl
    save_dir = args.save_dir
    max_duration = cfg.max_dur
    gen_type = args.generate_type

    tasks = _read_tasks(input_jsonl)
    use_audio_tokenizer = any("prompt_audio_path" in task for task in tasks)
    new_items = _prepare_items(tasks, cfg, save_dir)
    torch.cuda.empty_cache()

    if use_audio_tokenizer:
        # The prompt tokenizer is only needed for this step; drop it right away.
        if "audio_tokenizer_checkpoint_sep" in cfg.keys():
            seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
            seperate_tokenizer = seperate_tokenizer.eval().cuda()
        else:
            seperate_tokenizer = None
        _encode_prompt_stems(new_items, seperate_tokenizer)
        del seperate_tokenizer
    torch.cuda.empty_cache()

    # Define model or load pretrained model
    audiolm = _load_audiolm(cfg, ckpt_path, version)

    offload_audiolm = 'offload' in cfg.keys() and 'audiolm' in cfg.offload
    if offload_audiolm:
        # Imported here so the function also works when this module is not
        # executed as a script.
        from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse
        audiolm_offload_param = OffloadParamParse.parse_config(audiolm, cfg.offload.audiolm)
        audiolm_offload_param.show()
        offload_profiler = OffloadProfiler(device_index=0, **(audiolm_offload_param.init_param_dict()))
        offload_profiler.offload_layer(**(audiolm_offload_param.offload_layer_param_dict()))
        offload_profiler.clean_cache_wrapper(**(audiolm_offload_param.clean_cache_param_dict()))
    else:
        audiolm = audiolm.cuda().to(torch.float16)

    model = CodecLM(name = "tmp",
        lm = audiolm,
        audiotokenizer = None,
        max_duration = max_duration,
        seperate_tokenizer = None,
    )

    cfg_coef = 1.5
    temp = 0.9
    top_k = 50
    top_p = 0.0
    record_tokens = True
    record_window = 50
    model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef,
                                top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window)
    _make_output_dirs(save_dir)

    # Stage 1: produce tokens for every task with the language model.
    for item in new_items:
        generate_inp = _build_generate_inputs(item, version, gen_type)
        with torch.autocast(device_type="cuda", dtype=torch.float16), torch.no_grad():
            tokens = model.generate(**generate_inp, return_tokens=True)
        if offload_audiolm:
            offload_profiler.reset_empty_cache_mem_line()
        item['tokens'] = tokens

    # Release the language model before the decoder is brought onto the GPU.
    if offload_audiolm:
        offload_profiler.stop()
        del offload_profiler
        del audiolm_offload_param
    del model
    audiolm = audiolm.cpu()
    del audiolm
    gc.collect()
    torch.cuda.empty_cache()

    # Stage 2: load the decoder on CPU, then move its parts to the GPU.
    seperate_tokenizer = builders.get_audio_tokenizer_model_cpu(cfg.audio_tokenizer_checkpoint_sep, cfg)
    device = "cuda:0"
    seperate_tokenizer.model.device = device
    seperate_tokenizer.model.vae = seperate_tokenizer.model.vae.to(device)
    seperate_tokenizer.model.model.device = torch.device(device)
    seperate_tokenizer = seperate_tokenizer.eval()

    # Offloading of the decoder is currently switched off; the config-driven
    # alternative would be:
    # 'offload' in cfg.keys() and 'wav_tokenizer_diffusion' in cfg.offload
    offload_wav_tokenizer_diffusion = False
    if offload_wav_tokenizer_diffusion:
        from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse
        sep_offload_param = OffloadParamParse.parse_config(seperate_tokenizer, cfg.offload.wav_tokenizer_diffusion)
        sep_offload_param.show()
        sep_offload_profiler = OffloadProfiler(device_index=0, **(sep_offload_param.init_param_dict()))
        sep_offload_profiler.offload_layer(**(sep_offload_param.offload_layer_param_dict()))
        sep_offload_profiler.clean_cache_wrapper(**(sep_offload_param.clean_cache_param_dict()))
    else:
        seperate_tokenizer.model.model = seperate_tokenizer.model.model.to(device)

    model = CodecLM(name = "tmp",
        lm = None,
        audiotokenizer = None,
        max_duration = max_duration,
        seperate_tokenizer = seperate_tokenizer,
    )

    for item in new_items:
        with torch.no_grad():
            wavs = _decode_audio(model, item['tokens'], item, gen_type, stems_use_prompt=False)
        _save_audio(item['wav_path'], wavs, gen_type, cfg.sample_rate)
        _drop_runtime_fields(item)
        if offload_wav_tokenizer_diffusion:
            sep_offload_profiler.reset_empty_cache_mem_line()

    if offload_wav_tokenizer_diffusion:
        sep_offload_profiler.stop()
    torch.cuda.empty_cache()
    _write_results(new_items, save_dir, input_jsonl)
571
+
572
+
573
if __name__ == "__main__":
    torch.backends.cudnn.enabled = False
    # NOTE(security): the "eval" resolver executes arbitrary Python found in
    # the loaded config; only use trusted checkpoint directories.
    OmegaConf.register_new_resolver("eval", lambda x: eval(x))
    OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
    OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
    OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
    np.random.seed(int(time.time()))
    # Parse command-line arguments
    args = parse_args()
    if not torch.cuda.is_available():
        print("CUDA is not available")
        # Non-zero status so wrapper scripts can detect the failure.
        sys.exit(1)

    device = torch.cuda.current_device()
    reserved = torch.cuda.memory_reserved(device)
    total = torch.cuda.get_device_properties(device).total_memory
    # Despite the label below, this is total minus reserved, i.e. the memory
    # still available, in GiB.
    res_mem = (total - reserved) / 1024 / 1024 / 1024
    print(f"reserved memory: {res_mem}GB")

    # Checkpoint directory name -> (GiB that must be exceeded to use the
    # full-memory path, model version).
    model_profiles = {
        'songgeneration_base': (24, 'v1'),
        'songgeneration_base_new': (24, 'v1'),
        'songgeneration_base_full': (24, 'v1'),
        'songgeneration_large': (36, 'v1'),
        'songgeneration_v2_large': (32, 'v2'),
    }
    # normpath drops a trailing separator, so "ckpt/model/" resolves to "model"
    # instead of an empty name.
    model_name = os.path.basename(os.path.normpath(args.ckpt_path)).lower().replace('-', '_')
    # Unknown models have no memory threshold and are treated as v2.
    min_mem, version = model_profiles.get(model_name, (None, 'v2'))

    if not args.low_mem and (min_mem is None or res_mem > min_mem):
        print("use generate")
        generate(args, version=version)
    else:
        # generate_lowmem looks these names up as module globals.
        from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse
        print("use generate_lowmem")
        generate_lowmem(args, version=version)
625
+
generate.sh ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Launch generate.py for a checkpoint directory, an input JSONL and an output directory.
#
# Usage: generate.sh CKPT_PATH JSONL SAVE_DIR [--not_use_flash_attn] [--low_mem]
#                    [--separate | --bgm | --vocal]
export USER=root
export PYTHONDONTWRITEBYTECODE=1
export TRANSFORMERS_CACHE="$(pwd)/third_party/hub"
export NCCL_HOME=/usr/local/tccl
export PYTHONPATH="$(pwd)/codeclm/tokenizer/":"$(pwd)":"$(pwd)/codeclm/tokenizer/Flow1dVAE/":"$(pwd)/codeclm/tokenizer/":$PYTHONPATH
export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
export CUDA_LAUNCH_BLOCKING=0

if [ "$#" -lt 3 ]; then
    echo "Usage: $0 CKPT_PATH JSONL SAVE_DIR [--not_use_flash_attn] [--low_mem] [--separate|--bgm|--vocal]" >&2
    exit 1
fi

CKPT_PATH=$1
JSONL=$2
SAVE_DIR=$3
USE_FLASH_ATTN="True"
LOW_MEM="False"
GENERATE_TYPE="mixed"

# has_flag FLAG ARGS...: succeed when FLAG appears anywhere in ARGS.
has_flag() {
    local wanted=$1
    shift
    local arg
    for arg in "$@"; do
        if [[ $arg == "$wanted" ]]; then
            return 0
        fi
    done
    return 1
}

if has_flag --not_use_flash_attn "$@"; then USE_FLASH_ATTN="False"; fi
if has_flag --low_mem "$@"; then LOW_MEM="True"; fi
# When several type flags are given the later check wins:
# --vocal over --bgm over --separate.
if has_flag --separate "$@"; then GENERATE_TYPE="separate"; fi
if has_flag --bgm "$@"; then GENERATE_TYPE="bgm"; fi
if has_flag --vocal "$@"; then GENERATE_TYPE="vocal"; fi

EXTRA_ARGS=()
if [ "$USE_FLASH_ATTN" == "True" ]; then
    FLASH_LABEL="Use Flash Attention"
    EXTRA_ARGS+=(--use_flash_attn)
else
    FLASH_LABEL="Not Use Flash Attention"
fi
if [ "$LOW_MEM" == "True" ]; then
    MEM_LABEL="Low Memory Mode"
    EXTRA_ARGS+=(--low_mem)
else
    MEM_LABEL="Auto Memory Mode"
fi

echo "$FLASH_LABEL + $MEM_LABEL"
# Paths are quoted so directories containing spaces are passed intact.
python3 generate.py \
    --ckpt_path "$CKPT_PATH" \
    --input_jsonl "$JSONL" \
    --save_dir "$SAVE_DIR" \
    --generate_type "$GENERATE_TYPE" \
    "${EXTRA_ARGS[@]}"