Update app.py
Browse files
app.py
CHANGED
|
@@ -189,10 +189,8 @@ def generate_music(prime,
|
|
| 189 |
num_gen_tokens,
|
| 190 |
num_mem_tokens,
|
| 191 |
num_gen_batches,
|
| 192 |
-
gen_outro,
|
| 193 |
-
gen_drums,
|
| 194 |
model_temperature,
|
| 195 |
-
model_sampling_top_p
|
| 196 |
):
|
| 197 |
|
| 198 |
if not prime:
|
|
@@ -233,10 +231,8 @@ def generate_callback(input_midi,
|
|
| 233 |
num_prime_tokens,
|
| 234 |
num_gen_tokens,
|
| 235 |
num_mem_tokens,
|
| 236 |
-
gen_outro,
|
| 237 |
-
gen_drums,
|
| 238 |
model_temperature,
|
| 239 |
-
model_sampling_top_p,
|
| 240 |
final_composition,
|
| 241 |
generated_batches,
|
| 242 |
block_lines
|
|
@@ -253,10 +249,8 @@ def generate_callback(input_midi,
|
|
| 253 |
num_gen_tokens,
|
| 254 |
num_mem_tokens,
|
| 255 |
NUM_OUT_BATCHES,
|
| 256 |
-
gen_outro,
|
| 257 |
-
gen_drums,
|
| 258 |
model_temperature,
|
| 259 |
-
model_sampling_top_p
|
| 260 |
)
|
| 261 |
|
| 262 |
outputs = []
|
|
@@ -306,10 +300,8 @@ def generate_callback_wrapper(input_midi,
|
|
| 306 |
num_prime_tokens,
|
| 307 |
num_gen_tokens,
|
| 308 |
num_mem_tokens,
|
| 309 |
-
gen_outro,
|
| 310 |
-
gen_drums,
|
| 311 |
model_temperature,
|
| 312 |
-
model_sampling_top_p,
|
| 313 |
final_composition,
|
| 314 |
generated_batches,
|
| 315 |
block_lines
|
|
@@ -328,21 +320,17 @@ def generate_callback_wrapper(input_midi,
|
|
| 328 |
print('Num prime tokens:', num_prime_tokens)
|
| 329 |
print('Num gen tokens:', num_gen_tokens)
|
| 330 |
print('Num mem tokens:', num_mem_tokens)
|
| 331 |
-
print('Gen drums:', gen_drums)
|
| 332 |
-
print('Gen outro:', gen_outro)
|
| 333 |
|
| 334 |
print('Model temp:', model_temperature)
|
| 335 |
-
print('Model top_p:', model_sampling_top_p)
|
| 336 |
print('=' * 70)
|
| 337 |
|
| 338 |
result = generate_callback(input_midi,
|
| 339 |
num_prime_tokens,
|
| 340 |
num_gen_tokens,
|
| 341 |
num_mem_tokens,
|
| 342 |
-
gen_outro,
|
| 343 |
-
gen_drums,
|
| 344 |
model_temperature,
|
| 345 |
-
model_sampling_top_p,
|
| 346 |
final_composition,
|
| 347 |
generated_batches,
|
| 348 |
block_lines
|
|
@@ -494,13 +482,11 @@ with gr.Blocks() as demo:
|
|
| 494 |
|
| 495 |
gr.Markdown("## Generate")
|
| 496 |
|
| 497 |
-
num_prime_tokens = gr.Slider(15,
|
| 498 |
-
num_gen_tokens = gr.Slider(15,
|
| 499 |
-
num_mem_tokens = gr.Slider(15,
|
| 500 |
-
gen_drums = gr.Checkbox(value=False, label="Introduce drums")
|
| 501 |
-
gen_outro = gr.Radio(["Auto", "Disable", "Force"], value="Auto", label="Outro options")
|
| 502 |
model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
|
| 503 |
-
model_sampling_top_p = gr.Slider(0.1, 1, value=0.96, step=0.01, label="Model sampling top p value")
|
| 504 |
|
| 505 |
generate_btn = gr.Button("Generate", variant="primary")
|
| 506 |
|
|
@@ -521,10 +507,8 @@ with gr.Blocks() as demo:
|
|
| 521 |
num_prime_tokens,
|
| 522 |
num_gen_tokens,
|
| 523 |
num_mem_tokens,
|
| 524 |
-
gen_outro,
|
| 525 |
-
gen_drums,
|
| 526 |
model_temperature,
|
| 527 |
-
model_sampling_top_p,
|
| 528 |
final_composition,
|
| 529 |
generated_batches,
|
| 530 |
block_lines
|
|
|
|
| 189 |
num_gen_tokens,
|
| 190 |
num_mem_tokens,
|
| 191 |
num_gen_batches,
|
|
|
|
|
|
|
| 192 |
model_temperature,
|
| 193 |
+
# model_sampling_top_p
|
| 194 |
):
|
| 195 |
|
| 196 |
if not prime:
|
|
|
|
| 231 |
num_prime_tokens,
|
| 232 |
num_gen_tokens,
|
| 233 |
num_mem_tokens,
|
|
|
|
|
|
|
| 234 |
model_temperature,
|
| 235 |
+
# model_sampling_top_p,
|
| 236 |
final_composition,
|
| 237 |
generated_batches,
|
| 238 |
block_lines
|
|
|
|
| 249 |
num_gen_tokens,
|
| 250 |
num_mem_tokens,
|
| 251 |
NUM_OUT_BATCHES,
|
|
|
|
|
|
|
| 252 |
model_temperature,
|
| 253 |
+
# model_sampling_top_p
|
| 254 |
)
|
| 255 |
|
| 256 |
outputs = []
|
|
|
|
| 300 |
num_prime_tokens,
|
| 301 |
num_gen_tokens,
|
| 302 |
num_mem_tokens,
|
|
|
|
|
|
|
| 303 |
model_temperature,
|
| 304 |
+
# model_sampling_top_p,
|
| 305 |
final_composition,
|
| 306 |
generated_batches,
|
| 307 |
block_lines
|
|
|
|
| 320 |
print('Num prime tokens:', num_prime_tokens)
|
| 321 |
print('Num gen tokens:', num_gen_tokens)
|
| 322 |
print('Num mem tokens:', num_mem_tokens)
|
|
|
|
|
|
|
| 323 |
|
| 324 |
print('Model temp:', model_temperature)
|
| 325 |
+
# print('Model top_p:', model_sampling_top_p)
|
| 326 |
print('=' * 70)
|
| 327 |
|
| 328 |
result = generate_callback(input_midi,
|
| 329 |
num_prime_tokens,
|
| 330 |
num_gen_tokens,
|
| 331 |
num_mem_tokens,
|
|
|
|
|
|
|
| 332 |
model_temperature,
|
| 333 |
+
# model_sampling_top_p,
|
| 334 |
final_composition,
|
| 335 |
generated_batches,
|
| 336 |
block_lines
|
|
|
|
| 482 |
|
| 483 |
gr.Markdown("## Generate")
|
| 484 |
|
| 485 |
+
num_prime_tokens = gr.Slider(15, 1024, value=1024, step=1, label="Number of prime tokens")
|
| 486 |
+
num_gen_tokens = gr.Slider(15, 1024, value=1024, step=1, label="Number of tokens to generate")
|
| 487 |
+
num_mem_tokens = gr.Slider(15, 2048, value=2048, step=1, label="Number of memory tokens")
|
|
|
|
|
|
|
| 488 |
model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
|
| 489 |
+
# model_sampling_top_p = gr.Slider(0.1, 1, value=0.96, step=0.01, label="Model sampling top p value")
|
| 490 |
|
| 491 |
generate_btn = gr.Button("Generate", variant="primary")
|
| 492 |
|
|
|
|
| 507 |
num_prime_tokens,
|
| 508 |
num_gen_tokens,
|
| 509 |
num_mem_tokens,
|
|
|
|
|
|
|
| 510 |
model_temperature,
|
| 511 |
+
# model_sampling_top_p,
|
| 512 |
final_composition,
|
| 513 |
generated_batches,
|
| 514 |
block_lines
|