Novix commited on
Commit
8940c35
·
verified ·
1 Parent(s): 9fe1aa4

Upload 6 files

Browse files
tools/gradio/app.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import gradio as gr
3
+ import json
4
+ from datetime import datetime
5
+ import yaml
6
+ import time
7
+ import re
8
+ import os.path as op
9
+ import torch
10
+ from levo_inference_lowmem import LeVoInference
11
+
12
+ EXAMPLE_LYRICS = """
13
+ [intro-short]
14
+
15
+ [verse]
16
+ 夜晚的街灯闪烁
17
+ 我漫步在熟悉的角落
18
+ 回忆像潮水般涌来
19
+ 你的笑容如此清晰
20
+ 在心头无法抹去
21
+ 那些曾经的甜蜜
22
+ 如今只剩我独自回忆
23
+
24
+ [verse]
25
+ 手机屏幕亮起
26
+ 是你发来的消息
27
+ 简单的几个字
28
+ 却让我泪流满面
29
+ 曾经的拥抱温暖
30
+ 如今却变得遥远
31
+ 我多想回到从前
32
+ 重新拥有你的陪伴
33
+
34
+ [chorus]
35
+ 回忆的温度还在
36
+ 你却已不在
37
+ 我的心被爱填满
38
+ 却又被思念刺痛
39
+ 音乐的节奏奏响
40
+ 我的心却在流浪
41
+ 没有你的日子
42
+ 我该如何继续向前
43
+
44
+ [outro-short]
45
+ """.strip()
46
+
47
+ APP_DIR = op.dirname(op.dirname(op.dirname(op.abspath(__file__))))
48
+ MODEL = LeVoInference(sys.argv[1])
49
+ with open(op.join(APP_DIR, 'conf/vocab.yaml'), 'r', encoding='utf-8') as file:
50
+ STRUCTS = yaml.safe_load(file)
51
+
52
+
53
+ def generate_song(lyric, description=None, prompt_audio=None, genre=None, cfg_coef=None, temperature=None, top_k=None, gen_type="mixed", progress=gr.Progress(track_tqdm=True)):
54
+ global MODEL
55
+ global STRUCTS
56
+ params = {'cfg_coef':cfg_coef, 'temperature':temperature, 'top_k':top_k}
57
+ params = {k:v for k,v in params.items() if v is not None}
58
+ vocal_structs = ['[verse]', '[chorus]', '[bridge]']
59
+ sample_rate = MODEL.cfg.sample_rate
60
+
61
+ # format lyric
62
+ lyric = lyric.replace("[intro]", "[intro-short]").replace("[inst]", "[inst-short]").replace("[outro]", "[outro-short]")
63
+ paragraphs = [p.strip() for p in lyric.strip().split('\n\n') if p.strip()]
64
+ if len(paragraphs) < 1:
65
+ return None, json.dumps("Lyrics can not be left blank")
66
+ paragraphs_norm = []
67
+ vocal_flag = False
68
+ for para in paragraphs:
69
+ lines = para.splitlines()
70
+ struct_tag = lines[0].strip().lower()
71
+ if struct_tag not in STRUCTS:
72
+ return None, json.dumps(f"Segments should start with a structure tag in {STRUCTS}")
73
+ if struct_tag in vocal_structs:
74
+ vocal_flag = True
75
+ if len(lines) < 2 or not [line.strip() for line in lines[1:] if line.strip()]:
76
+ return None, json.dumps("The following segments require lyrics: [verse], [chorus], [bridge]")
77
+ else:
78
+ new_para_list = []
79
+ for line in lines[1:]:
80
+ new_para_list.append(re.sub(r"[^\w\s\[\]\-\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\uac00-\ud7af\u00c0-\u017f]", "", line))
81
+ new_para_str = f"{struct_tag} {'.'.join(new_para_list)}"
82
+ else:
83
+ if len(lines) > 1:
84
+ return None, json.dumps("The following segments should not contain lyrics: [intro], [intro-short], [intro-medium], [inst], [inst-short], [inst-medium], [outro], [outro-short], [outro-medium]")
85
+ else:
86
+ new_para_str = struct_tag
87
+ paragraphs_norm.append(new_para_str)
88
+ if not vocal_flag:
89
+ return None, json.dumps(f"The lyric must contain at least one of the following structures: {vocal_structs}")
90
+ lyric_norm = " ; ".join(paragraphs_norm)
91
+
92
+ # format prompt
93
+ if prompt_audio is not None:
94
+ genre = None
95
+ description = None
96
+ elif description is not None and description != "":
97
+ genre = None
98
+ if description[-1] != ".":
99
+ description = description + "."
100
+
101
+ progress(0.0, "Start Generation")
102
+ start = time.time()
103
+
104
+ audio_data = MODEL(lyric_norm, description, prompt_audio, genre, op.join(APP_DIR, "tools/new_auto_prompt.pt"), gen_type, params).cpu().permute(1, 0).float().numpy()
105
+
106
+ end = time.time()
107
+
108
+ # 创建输入配置的JSON
109
+ input_config = {
110
+ "lyric": lyric_norm,
111
+ "genre": genre,
112
+ "prompt_audio": prompt_audio,
113
+ "description": description,
114
+ "params": params,
115
+ "inference_duration": end - start,
116
+ "timestamp": datetime.now().isoformat(),
117
+ }
118
+
119
+ return (sample_rate, audio_data), json.dumps(input_config, indent=2)
120
+
121
+
122
+ # 创建Gradio界面
123
+ with gr.Blocks(title="SongGeneration Demo Space") as demo:
124
+ gr.Markdown("# 🎵 SongGeneration Demo Space")
125
+ gr.Markdown("Demo interface for the song generation model. Provide a lyrics, and optionally an audio or text prompt, to generate a custom song. The code is in [GIT](https://github.com/tencent-ailab/SongGeneration)")
126
+
127
+ with gr.Row():
128
+ with gr.Column():
129
+ lyric = gr.Textbox(
130
+ label="Lyrics",
131
+ lines=5,
132
+ max_lines=15,
133
+ value=EXAMPLE_LYRICS,
134
+ info="Each paragraph represents a segment starting with a structure tag and ending with a blank line, each line is a sentence without punctuation, segments [intro], [inst], [outro] should not contain lyrics, while [verse], [chorus], and [bridge] require lyrics.",
135
+ placeholder="""Lyric Format
136
+ '''
137
+ [structure tag]
138
+ lyrics
139
+
140
+ [structure tag]
141
+ lyrics
142
+ '''
143
+ 1. One paragraph represents one segments, starting with a structure tag and ending with a blank line
144
+ 2. One line represents one sentence, punctuation is not recommended inside the sentence
145
+ 3. The following segments should not contain lyrics: [intro-short], [intro-medium], [inst-short], [inst-medium], [outro-short], [outro-medium]
146
+ 4. The following segments require lyrics: [verse], [chorus], [bridge]
147
+ """
148
+ )
149
+
150
+ with gr.Tabs(elem_id="extra-tabs"):
151
+ with gr.Tab("Genre Select"):
152
+ genre = gr.Radio(
153
+ choices=["Auto", "Pop", "Latin", "Rock", "Electronic", "Metal", "Country", "R&B/Soul", "Ballad", "Jazz", "World", "Hip-Hop", "Funk", "Soundtrack"],
154
+ label="Genre Select(Optional)",
155
+ value="Pop",
156
+ interactive=True,
157
+ elem_id="single-select-radio"
158
+ )
159
+ with gr.Tab("Audio Prompt"):
160
+ prompt_audio = gr.Audio(
161
+ label="Prompt Audio (Optional)",
162
+ type="filepath",
163
+ elem_id="audio-prompt"
164
+ )
165
+ with gr.Tab("Text Prompt"):
166
+ gr.Markdown("For detailed usage, please refer to [here](https://github.com/tencent-ailab/SongGeneration?tab=readme-ov-file#-description-input-format)")
167
+ description = gr.Textbox(
168
+ label="Song Description (Optional)",
169
+ info="Describe the gender, timbre, genre, emotion, instrument and bpm of the song. Only English is supported currently.​",
170
+ placeholder="female, dark, pop, sad, piano and drums, the bpm is 125.",
171
+ lines=1,
172
+ max_lines=2
173
+ )
174
+
175
+ with gr.Accordion("Advanced Config", open=False):
176
+ cfg_coef = gr.Slider(
177
+ label="CFG Coefficient",
178
+ minimum=0.1,
179
+ maximum=3.0,
180
+ step=0.1,
181
+ value=1.5,
182
+ interactive=True,
183
+ elem_id="cfg-coef",
184
+ )
185
+ temperature = gr.Slider(
186
+ label="Temperature",
187
+ minimum=0.1,
188
+ maximum=2.0,
189
+ step=0.1,
190
+ value=0.9,
191
+ interactive=True,
192
+ elem_id="temperature",
193
+ )
194
+ top_k = gr.Slider(
195
+ label="Top-K",
196
+ minimum=1,
197
+ maximum=100,
198
+ step=1,
199
+ value=50,
200
+ interactive=True,
201
+ elem_id="top_k",
202
+ )
203
+ with gr.Row():
204
+ generate_btn = gr.Button("Generate Song", variant="primary")
205
+ generate_bgm_btn = gr.Button("Generate Pure Music", variant="primary")
206
+
207
+ with gr.Column():
208
+ output_audio = gr.Audio(label="Generated Song", type="numpy")
209
+ output_json = gr.JSON(label="Generated Info")
210
+
211
+ # # 示例按钮
212
+ # examples = gr.Examples(
213
+ # examples=[
214
+ # ["male, bright, rock, happy, electric guitar and drums, the bpm is 150."],
215
+ # ["female, warm, jazz, romantic, synthesizer and piano, the bpm is 100."]
216
+ # ],
217
+ # inputs=[description],
218
+ # label="Text Prompt examples"
219
+ # )
220
+
221
+ # examples = gr.Examples(
222
+ # examples=[
223
+ # "[intro-medium]\n\n[verse]\n在这个疯狂的世界里\n谁不渴望一点改变\n在爱情面前\n我们都显得那么不安全\n你紧紧抱着我\n告诉我再靠近一点\n别让这璀璨的夜晚白白浪费\n我那迷茫的眼睛\n看不见未来的路\n在情感消散之前\n我们对爱的渴望永不熄灭\n你给我留下一句誓言\n想知道我们的爱是否能持续到永远\n[chorus]\n\n约定在那最后的夜晚\n不管命运如何摆布\n我们的心是否依然如初\n我会穿上红衬衫\n带着摇滚的激情\n回到我们初遇的地方\n约定在那最后的夜晚\n就算全世界都变了样\n我依然坚守诺言\n铭记这一天\n你永远是我心中的爱恋\n\n[outro-medium]\n",
224
+ # "[intro-short]\n\n[verse]\nThrough emerald canyons where fireflies dwell\nCerulean berries kiss morning's first swell\nCrystalline dew crowns each Vitamin Dawn's confection dissolves slowly on me\nAmbrosia breezes through honeycomb vines\nNature's own candy in Fibonacci lines\n[chorus] Blueberry fruit so sweet\n takes you higher\n can't be beat\n In your lungs\n it starts to swell\n You're under its spell\n [verse] Resin of sunlight in candied retreat\nMarmalade moonbeams melt under bare feet\nNectar spirals bloom chloroplast champagne\nPhotosynthesis sings through my veins\nChlorophyll rhythms pulse warm in my blood\nThe forest's green pharmacy floods every bud[chorus] Blueberry fruit so sweet\n takes you higher\n can't be beat\n In your lungs\n it starts to swell\n You're under its spell\n feel the buzz\n ride the wave\n Limey me\n blueberry\n your mind's enslaved\n In the haze\n lose all time\n floating free\n feeling fine\n Blueberry\n fruit so sweet\n takes you higher\n can't be beat\n In your lungs\n it starts to swell\n cry\n You're under its spell\n\n[outro-short]\n",
225
+ # ],
226
+ # inputs=[lyric],
227
+ # label="Lyrics examples",
228
+ # )
229
+
230
+ # 生成按钮点击事件
231
+ generate_btn.click(
232
+ fn=generate_song,
233
+ inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, top_k],
234
+ outputs=[output_audio, output_json]
235
+ )
236
+ generate_bgm_btn.click(
237
+ fn=generate_song,
238
+ inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, top_k, gr.State("bgm")],
239
+ outputs=[output_audio, output_json]
240
+ )
241
+
242
+
243
+ # 启动应用
244
+ if __name__ == "__main__":
245
+ torch.set_num_threads(1)
246
+ demo.launch(server_name="0.0.0.0", server_port=8081)
tools/gradio/levo_inference.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import sys
4
+
5
+ import torch
6
+
7
+ import json
8
+ import numpy as np
9
+ from omegaconf import OmegaConf
10
+
11
+ from codeclm.trainer.codec_song_pl import CodecLM_PL
12
+ from codeclm.models import CodecLM
13
+
14
+ from separator import Separator
15
+
16
+
17
+ def check_language_by_text(text):
18
+ chinese_pattern = re.compile(r'[\u4e00-\u9fff]')
19
+ english_pattern = re.compile(r'[a-zA-Z]')
20
+ chinese_count = len(re.findall(chinese_pattern, text))
21
+ english_count = len(re.findall(english_pattern, text))
22
+ chinese_ratio = chinese_count / len(text)
23
+ english_ratio = english_count / len(text)
24
+ if chinese_ratio >= 0.2:
25
+ return "zh"
26
+ elif english_ratio >= 0.5:
27
+ return "en"
28
+ else:
29
+ return "en"
30
+
31
+
32
+ class LeVoInference(torch.nn.Module):
33
+ def __init__(self, ckpt_path):
34
+ super().__init__()
35
+
36
+ torch.backends.cudnn.enabled = False
37
+ OmegaConf.register_new_resolver("eval", lambda x: eval(x))
38
+ OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
39
+ OmegaConf.register_new_resolver("get_fname", lambda: 'default')
40
+ OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
41
+
42
+ cfg_path = os.path.join(ckpt_path, 'config.yaml')
43
+ pt_path = os.path.join(ckpt_path, 'model.pt')
44
+
45
+ self.cfg = OmegaConf.load(cfg_path)
46
+ self.cfg.mode = 'inference'
47
+ self.max_duration = self.cfg.max_dur
48
+
49
+ # Define model or load pretrained model
50
+ model_light = CodecLM_PL(self.cfg, pt_path, version="v2")
51
+
52
+ model_light = model_light.eval().cuda()
53
+ model_light.audiolm.cfg = self.cfg
54
+
55
+ self.model_lm = model_light.audiolm
56
+ self.model_audio_tokenizer = model_light.audio_tokenizer
57
+ self.model_seperate_tokenizer = model_light.seperate_tokenizer
58
+
59
+ self.model = CodecLM(name = "tmp",
60
+ lm = self.model_lm,
61
+ audiotokenizer = self.model_audio_tokenizer,
62
+ max_duration = self.max_duration,
63
+ seperate_tokenizer = self.model_seperate_tokenizer,
64
+ )
65
+ self.separator = Separator()
66
+
67
+
68
+ self.default_params = dict(
69
+ cfg_coef = 1.5,
70
+ temperature = 1.0,
71
+ top_k = 50,
72
+ top_p = 0.0,
73
+ record_tokens = True,
74
+ record_window = 50,
75
+ extend_stride = 5,
76
+ duration = self.max_duration,
77
+ )
78
+
79
+ self.model.set_generation_params(**self.default_params)
80
+
81
+ def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, gen_type: str = "mixed", params = dict()):
82
+ params = {**self.default_params, **params}
83
+ self.model.set_generation_params(**params)
84
+
85
+ if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
86
+ pmt_wav, vocal_wav, bgm_wav = self.separator.run(prompt_audio_path)
87
+ melody_is_wav = True
88
+ elif genre is not None and auto_prompt_path is not None:
89
+ auto_prompt = torch.load(auto_prompt_path)
90
+ lang = check_language_by_text(lyric)
91
+ prompt_token = auto_prompt[genre][lang][np.random.randint(0, len(auto_prompt[genre][lang]))]
92
+ pmt_wav = prompt_token[:,[0],:]
93
+ vocal_wav = prompt_token[:,[1],:]
94
+ bgm_wav = prompt_token[:,[2],:]
95
+ melody_is_wav = False
96
+ else:
97
+ pmt_wav = None
98
+ vocal_wav = None
99
+ bgm_wav = None
100
+ melody_is_wav = True
101
+
102
+ if gen_type == 'bgm':
103
+ description = '[Musicality-very-high]' + ', ' + '[Pure-Music]' + ', ' + description.lower() if description else '.'
104
+ else:
105
+ description = description.lower() if description else '.'
106
+ description = '[Musicality-very-high]' + ', ' + description
107
+
108
+ generate_inp = {
109
+ 'lyrics': [lyric.replace(" ", " ")] if gen_type != 'bgm' else '.',
110
+ 'descriptions': [description],
111
+ 'melody_wavs': pmt_wav,
112
+ 'vocal_wavs': vocal_wav,
113
+ 'bgm_wavs': bgm_wav,
114
+ 'melody_is_wav': melody_is_wav,
115
+ }
116
+
117
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
118
+ tokens = self.model.generate(**generate_inp, return_tokens=True)
119
+
120
+ with torch.no_grad():
121
+ if melody_is_wav:
122
+ wav_seperate = self.model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav, gen_type=gen_type)
123
+ else:
124
+ wav_seperate = self.model.generate_audio(tokens, gen_type=gen_type)
125
+
126
+ return wav_seperate[0]
tools/gradio/levo_inference_lowmem.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import re
4
+ import sys
5
+
6
+ import torch
7
+
8
+ import json
9
+ import numpy as np
10
+ from omegaconf import OmegaConf
11
+
12
+ from codeclm.trainer.codec_song_pl import CodecLM_PL
13
+ from codeclm.models import CodecLM
14
+ from codeclm.models import builders
15
+
16
+ from separator import Separator
17
+ from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse
18
+
19
+
20
+ def check_language_by_text(text):
21
+ chinese_pattern = re.compile(r'[\u4e00-\u9fff]')
22
+ english_pattern = re.compile(r'[a-zA-Z]')
23
+ chinese_count = len(re.findall(chinese_pattern, text))
24
+ english_count = len(re.findall(english_pattern, text))
25
+ chinese_ratio = chinese_count / len(text)
26
+ english_ratio = english_count / len(text)
27
+ if chinese_ratio >= 0.2:
28
+ return "zh"
29
+ elif english_ratio >= 0.5:
30
+ return "en"
31
+ else:
32
+ return "en"
33
+
34
+
35
+ class LeVoInference(torch.nn.Module):
36
+ def __init__(self, ckpt_path):
37
+ super().__init__()
38
+
39
+ torch.backends.cudnn.enabled = False
40
+ OmegaConf.register_new_resolver("eval", lambda x: eval(x))
41
+ OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
42
+ OmegaConf.register_new_resolver("get_fname", lambda: 'default')
43
+ OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
44
+
45
+ cfg_path = os.path.join(ckpt_path, 'config.yaml')
46
+ self.pt_path = os.path.join(ckpt_path, 'model.pt')
47
+
48
+ self.cfg = OmegaConf.load(cfg_path)
49
+ self.cfg.mode = 'inference'
50
+ self.max_duration = self.cfg.max_dur
51
+
52
+ self.default_params = dict(
53
+ top_p = 0.0,
54
+ record_tokens = True,
55
+ record_window = 50,
56
+ extend_stride = 5,
57
+ duration = self.max_duration,
58
+ )
59
+
60
+
61
+ def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, gen_type: str = "mixed", params = dict()):
62
+ if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
63
+ separator = Separator()
64
+ audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg)
65
+ audio_tokenizer = audio_tokenizer.eval().cuda()
66
+ pmt_wav, vocal_wav, bgm_wav = separator.run(prompt_audio_path)
67
+ pmt_wav = pmt_wav.cuda()
68
+ vocal_wav = vocal_wav.cuda()
69
+ bgm_wav = bgm_wav.cuda()
70
+ with torch.no_grad():
71
+ pmt_wav, _ = audio_tokenizer.encode(pmt_wav)
72
+ del audio_tokenizer
73
+ del separator
74
+ torch.cuda.empty_cache()
75
+
76
+ seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
77
+ seperate_tokenizer = seperate_tokenizer.eval().cuda()
78
+ with torch.no_grad():
79
+ vocal_wav, bgm_wav = seperate_tokenizer.encode(vocal_wav, bgm_wav)
80
+ del seperate_tokenizer
81
+ melody_is_wav = False
82
+ torch.cuda.empty_cache()
83
+ elif genre is not None and auto_prompt_path is not None:
84
+ auto_prompt = torch.load(auto_prompt_path)
85
+ lang = check_language_by_text(lyric)
86
+ prompt_token = auto_prompt[genre][lang][np.random.randint(0, len(auto_prompt[genre][lang]))]
87
+ pmt_wav = prompt_token[:,[0],:]
88
+ vocal_wav = prompt_token[:,[1],:]
89
+ bgm_wav = prompt_token[:,[2],:]
90
+ melody_is_wav = False
91
+ else:
92
+ pmt_wav = None
93
+ vocal_wav = None
94
+ bgm_wav = None
95
+ melody_is_wav = True
96
+
97
+ audiolm = builders.get_lm_model(self.cfg, version="v2")
98
+ checkpoint = torch.load(self.pt_path, map_location='cpu')
99
+ audiolm_state_dict = {k.replace('audiolm.', ''): v for k, v in checkpoint.items() if k.startswith('audiolm')}
100
+ audiolm.load_state_dict(audiolm_state_dict, strict=False)
101
+ audiolm = audiolm.eval()
102
+
103
+ offload_audiolm = True if 'offload' in self.cfg.keys() and 'audiolm' in self.cfg.offload else False
104
+ if offload_audiolm:
105
+ audiolm_offload_param = OffloadParamParse.parse_config(audiolm, self.cfg.offload.audiolm)
106
+ audiolm_offload_param.show()
107
+ offload_profiler = OffloadProfiler(device_index=0, **(audiolm_offload_param.init_param_dict()))
108
+ offload_profiler.offload_layer(**(audiolm_offload_param.offload_layer_param_dict()))
109
+ offload_profiler.clean_cache_wrapper(**(audiolm_offload_param.clean_cache_param_dict()))
110
+ else:
111
+ audiolm = audiolm.cuda().to(torch.float16)
112
+
113
+ model = CodecLM(name = "tmp",
114
+ lm = audiolm,
115
+ audiotokenizer = None,
116
+ max_duration = self.max_duration,
117
+ seperate_tokenizer = None,
118
+ )
119
+ params = {**self.default_params, **params}
120
+ model.set_generation_params(**params)
121
+
122
+ if gen_type == 'bgm':
123
+ description = '[Musicality-very-high]' + ', ' + '[Pure-Music]' + ', ' + description.lower() if description else '.'
124
+ else:
125
+ description = description.lower() if description else '.'
126
+ description = '[Musicality-very-high]' + ', ' + description
127
+
128
+ generate_inp = {
129
+ 'lyrics': [lyric.replace(" ", " ")] if gen_type != 'bgm' else '.',
130
+ 'descriptions': [description],
131
+ 'melody_wavs': pmt_wav,
132
+ 'vocal_wavs': vocal_wav,
133
+ 'bgm_wavs': bgm_wav,
134
+ 'melody_is_wav': melody_is_wav,
135
+ }
136
+
137
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
138
+ with torch.no_grad():
139
+ tokens = model.generate(**generate_inp, return_tokens=True)
140
+ if offload_audiolm:
141
+ offload_profiler.reset_empty_cache_mem_line()
142
+ offload_profiler.stop()
143
+ del offload_profiler
144
+ del audiolm_offload_param
145
+ del model
146
+ audiolm = audiolm.cpu()
147
+ del audiolm
148
+ del checkpoint
149
+ gc.collect()
150
+ torch.cuda.empty_cache()
151
+
152
+ seperate_tokenizer = builders.get_audio_tokenizer_model_cpu(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
153
+ device = "cuda:0"
154
+ seperate_tokenizer.model.device = device
155
+ seperate_tokenizer.model.vae = seperate_tokenizer.model.vae.to(device)
156
+ seperate_tokenizer.model.model.device = torch.device(device)
157
+ seperate_tokenizer = seperate_tokenizer.eval()
158
+
159
+ offload_wav_tokenizer_diffusion = True if 'offload' in self.cfg.keys() and 'wav_tokenizer_diffusion' in self.cfg.offload else False
160
+ if offload_wav_tokenizer_diffusion:
161
+ sep_offload_param = OffloadParamParse.parse_config(seperate_tokenizer, self.cfg.offload.wav_tokenizer_diffusion)
162
+ sep_offload_param.show()
163
+ sep_offload_profiler = OffloadProfiler(device_index=0, **(sep_offload_param.init_param_dict()))
164
+ sep_offload_profiler.offload_layer(**(sep_offload_param.offload_layer_param_dict()))
165
+ sep_offload_profiler.clean_cache_wrapper(**(sep_offload_param.clean_cache_param_dict()))
166
+ else:
167
+ seperate_tokenizer.model.model = seperate_tokenizer.model.model.to(device)
168
+
169
+ model = CodecLM(name = "tmp",
170
+ lm = None,
171
+ audiotokenizer = None,
172
+ max_duration = self.max_duration,
173
+ seperate_tokenizer = seperate_tokenizer,
174
+ )
175
+
176
+ with torch.no_grad():
177
+ if melody_is_wav:
178
+ wav_seperate = model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav, gen_type=gen_type, chunked=True)
179
+ else:
180
+ wav_seperate = model.generate_audio(tokens, gen_type=gen_type, chunked=True)
181
+
182
+ if offload_wav_tokenizer_diffusion:
183
+ sep_offload_profiler.reset_empty_cache_mem_line()
184
+ sep_offload_profiler.stop()
185
+ torch.cuda.empty_cache()
186
+
187
+ return wav_seperate[0]
tools/gradio/run.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export OMP_NUM_THREADS=1
2
+ export MKL_NUM_THREADS=1
3
+ export CUDA_LAUNCH_BLOCKING=0
4
+
5
+ export USER=root
6
+ export PYTHONDONTWRITEBYTECODE=1
7
+ export TRANSFORMERS_CACHE="$(pwd)/third_party/hub"
8
+ export NCCL_HOME=/usr/local/tccl
9
+ export PYTHONPATH="$(pwd)/codeclm/tokenizer/":"$(pwd)":"$(pwd)/codeclm/tokenizer/Flow1dVAE/":"$(pwd)/codeclm/tokenizer/":$PYTHONPATH
10
+
11
+
12
+ CKPT_PATH=$1
13
+ python3 tools/gradio/app.py $CKPT_PATH
tools/gradio/separator.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchaudio
2
+ import os
3
+ import torch
4
+ from third_party.demucs.models.pretrained import get_model_from_yaml
5
+
6
+
7
+ class Separator(torch.nn.Module):
8
+ def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
9
+ super().__init__()
10
+ if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
11
+ self.device = torch.device(f"cuda:{gpu_id}")
12
+ else:
13
+ self.device = torch.device("cpu")
14
+ self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)
15
+
16
+ def init_demucs_model(self, model_path, config_path):
17
+ model = get_model_from_yaml(config_path, model_path)
18
+ model.to(self.device)
19
+ model.eval()
20
+ return model
21
+
22
+ def load_audio(self, f):
23
+ a, fs = torchaudio.load(f)
24
+ if (fs != 48000):
25
+ a = torchaudio.functional.resample(a, fs, 48000)
26
+ if a.shape[-1] >= 48000*10:
27
+ a = a[..., :48000*10]
28
+ else:
29
+ a = torch.cat([a, a], -1)
30
+ return a[:, 0:48000*10]
31
+
32
+ def run(self, audio_path, output_dir='tmp', ext=".flac"):
33
+ os.makedirs(output_dir, exist_ok=True)
34
+ name, _ = os.path.splitext(os.path.split(audio_path)[-1])
35
+ output_paths = []
36
+
37
+ for stem in self.demucs_model.sources:
38
+ output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
39
+ if os.path.exists(output_path):
40
+ output_paths.append(output_path)
41
+ if len(output_paths) == 1: # 4
42
+ vocal_path = output_paths[0]
43
+ else:
44
+ drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
45
+ for path in [drums_path, bass_path, other_path]:
46
+ os.remove(path)
47
+ full_audio = self.load_audio(audio_path)
48
+ vocal_audio = self.load_audio(vocal_path)
49
+ bgm_audio = full_audio - vocal_audio
50
+ return full_audio, vocal_audio, bgm_audio
tools/new_auto_prompt.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:15d65fd20de92de44fa3b8b867d376d4c1756e56b18c9edf4c235ab66a3dae2e
3
+ size 133