| import json |
| from pathlib import Path |
|
|
| import gradio as gr |
| import pandas as pd |
| from functools import partial |
| from defaults import DEFAULTS |
| from details import ACCURACY, DETAILS, INSTRUCTIONS, LIMITATIONS |
| from state import Model, Parallelism, Training |
| from calculator import MemoryCalculation |
| from dtypes import DType |
|
|
| |
# Factory for integer inputs that must be at least 1 (layer counts, dimensions,
# parallelism degrees). Being a partial, any keyword here can still be
# overridden per field, e.g. interactive=False for fields that start disabled.
NaturalNumber = partial(
    gr.Number,
    minimum=1,
    step=1,
    precision=0,
    interactive=True,
)
|
|
def create_parallelism_block():
    """Build the "Parallelism" column of the form.

    Creates the tensor / pipeline / context / expert parallelism inputs and
    the FSDP controls, and wires the FSDP checkbox so that unticking it makes
    the FSDP degree and strategy fields read-only and greyed out.

    Must be called inside an active ``gr.Blocks`` context.

    Returns:
        Tuple of components in the order ``(tp, pp, cp, ep, fsdp_enabled,
        fsdp_parallelism, fsdp_strategy)``.
    """
    with gr.Column():
        gr.Markdown("# Parallelism")

        # NOTE(review): nesting was reconstructed; this assumes only the four
        # parallelism degrees sit inside the Group -- confirm against the
        # rendered layout.
        with gr.Group():
            tp = NaturalNumber(label="Tensor Parallelism", value=1)
            pp = NaturalNumber(label="Pipeline Parallelism", value=1)
            cp = NaturalNumber(label="Context Parallelism", value=1)
            ep = NaturalNumber(label="Expert Parallelism", value=1)

        fsdp_enabled = gr.Checkbox(
            label="FSDP (Fully Sharded Data Parallel)",
            value=True,
        )
        fsdp_parallelism = NaturalNumber(label="FSDP Parallelism", value=8)
        fsdp_strategy = gr.Radio(
            label="FSDP Strategy",
            choices=["Zero-1", "Zero-2", "Zero-3"],
            value="Zero-3",
        )

        # One identical update per dependent field: editable when FSDP is on,
        # read-only plus the "disabled-field" CSS class when it is off.
        fsdp_enabled.change(
            fn=lambda enabled: [
                gr.update(
                    interactive=enabled,
                    elem_classes=[] if enabled else ["disabled-field"],
                )
                for _ in range(2)
            ],
            inputs=fsdp_enabled,
            outputs=[fsdp_parallelism, fsdp_strategy],
        )

    return tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism, fsdp_strategy
|
|
|
|
def create_model_block():
    """Build the "Model Architecture" column and its preset logic.

    Creates the architecture inputs, the MoE checkbox that enables or greys
    out the two expert-count fields, and a preset dropdown. Choosing a preset
    fills every field from ``DEFAULTS``; editing any field so that the form no
    longer matches the selected preset switches the dropdown to "Custom".

    Must be called inside an active ``gr.Blocks`` context.

    Returns:
        Tuple of components in the order ``(layers, vocab, hidden,
        intermediate, active_experts, total_experts, is_moe, presets,
        weight_tied_embeddings)``.
    """

    def _expert_field_update(enabled, **changes):
        # Interactivity and the greyed-out CSS class always travel together,
        # so an expert field can never be editable yet styled as disabled
        # (or the reverse), whichever callback produced the update.
        return gr.update(
            interactive=enabled,
            elem_classes=[] if enabled else ["disabled-field"],
            **changes,
        )

    with gr.Column():
        gr.Markdown("# Model Architecture")
        layers = NaturalNumber(label="Number of Layers", value=32)
        vocab = NaturalNumber(label="Vocab Size", value=128256)
        hidden = NaturalNumber(label="Hidden Dim", value=4096)
        intermediate = NaturalNumber(label="Intermediate Dim", value=14336)
        is_moe = gr.Checkbox(label="Mixture of Experts (MoE)", value=False)
        # Expert counts start read-only because MoE is off by default.
        active_experts = NaturalNumber(
            label="Active Experts",
            value=1,
            interactive=False,
            elem_classes="disabled-field",
        )
        total_experts = NaturalNumber(
            label="Total Experts",
            value=1,
            interactive=False,
            elem_classes="disabled-field",
        )
        weight_tied_embeddings = gr.Checkbox(label="Weight Tied Embeddings", value=True)

        # Expert counts are only editable while the MoE box is ticked.
        is_moe.change(
            fn=lambda enabled: [_expert_field_update(enabled) for _ in range(2)],
            inputs=is_moe,
            outputs=[active_experts, total_experts],
        )

        presets = gr.Dropdown(
            ["Custom"] + list(DEFAULTS.keys()),
            label="Presets",
            value="Llama3 8B",
            interactive=True,
        )

        # Every component a preset controls, in the order both callbacks use.
        preset_fields = [
            layers,
            vocab,
            hidden,
            intermediate,
            is_moe,
            active_experts,
            total_experts,
            weight_tied_embeddings,
        ]

        def populate_from_preset(preset_name):
            """Return one update per preset field for the chosen preset."""
            if not preset_name or preset_name not in DEFAULTS:
                # Not a known preset (e.g. "Custom"): leave every field as is.
                return [gr.update() for _ in preset_fields]

            model = DEFAULTS[preset_name]
            return [
                gr.update(value=model.num_layers),
                gr.update(value=model.vocab_size),
                gr.update(value=model.hidden_dim),
                gr.update(value=model.intermediate_size),
                gr.update(value=model.is_moe),
                # Set the disabled styling here as well instead of relying on
                # the is_moe change event firing afterwards.
                _expert_field_update(model.is_moe, value=model.active_experts),
                _expert_field_update(model.is_moe, value=model.total_experts),
                gr.update(value=model.weight_tied_embeddings),
            ]

        def switch_to_custom(layers_val, vocab_val, hidden_val, intermediate_val, is_moe_val, active_experts_val, total_experts_val, weight_tied_val, current_preset):
            """Flip the dropdown to "Custom" once the form departs from the preset."""
            if current_preset and current_preset in DEFAULTS:
                model = DEFAULTS[current_preset]
                form_values = (
                    layers_val,
                    vocab_val,
                    hidden_val,
                    intermediate_val,
                    is_moe_val,
                    active_experts_val,
                    total_experts_val,
                    weight_tied_val,
                )
                preset_values = (
                    model.num_layers,
                    model.vocab_size,
                    model.hidden_dim,
                    model.intermediate_size,
                    model.is_moe,
                    model.active_experts,
                    model.total_experts,
                    model.weight_tied_embeddings,
                )
                if form_values == preset_values:
                    # Still exactly the preset (this also covers the change
                    # events caused by populating the form): keep the name.
                    return gr.update()

            return gr.update(value="Custom")

        presets.change(
            fn=populate_from_preset,
            inputs=presets,
            outputs=preset_fields,
        )

        # Any edit to any field re-checks the whole form against the preset.
        for input_component in preset_fields:
            input_component.change(
                fn=switch_to_custom,
                inputs=preset_fields + [presets],
                outputs=presets,
            )

    return layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, presets, weight_tied_embeddings
|
|
|
|
def create_training_block():
    """Build the "Training Config" column of the form.

    Creates the sequence length, batch size, checkpointing / accumulation
    switches and precision controls, and wires the mixed-precision checkbox
    so that the parameter and reduce dtype dropdowns are only editable while
    it is ticked.

    Must be called inside an active ``gr.Blocks`` context.

    Returns:
        Tuple of components in the order ``(seq_len, batch_size,
        gradient_checkpointing, grad_accumulation, precision,
        mixed_precision, param_dtype, reduce_dtype)``.
    """

    def _locked_dtype_dropdown(label):
        # Both mixed-precision dtypes start at FP32, read-only and greyed out.
        return gr.Dropdown(
            DType.values(),
            label=label,
            value=DType.FP32.value,
            interactive=False,
            elem_classes="disabled-field",
        )

    with gr.Column():
        gr.Markdown("# Training Config")
        seq_len = NaturalNumber(label="Sequence Length", value=4096)
        batch_size = NaturalNumber(
            label="Batch Size",
            info="If you are using gradient accumulation, enter microbatch size",
            value=1,
        )

        # NOTE(review): nesting was reconstructed; this assumes the Row holds
        # just the two checkboxes -- confirm against the rendered layout.
        with gr.Row():
            gradient_checkpointing = gr.Checkbox(label="Gradient Checkpointing", value=True)
            grad_accumulation = gr.Checkbox(label="Gradient Accumulation", value=False)

        precision = gr.Dropdown(
            DType.values(),
            label="Precision",
            value=DType.BF16.value,
            interactive=True,
        )
        mixed_precision = gr.Checkbox(label="Mixed Precision", value=False)
        param_dtype = _locked_dtype_dropdown("Parameter Dtype")
        reduce_dtype = _locked_dtype_dropdown("Reduce Dtype")

        # One identical update per dtype dropdown: editable when mixed
        # precision is on, read-only plus the "disabled-field" class when off.
        mixed_precision.change(
            fn=lambda enabled: [
                gr.update(
                    interactive=enabled,
                    elem_classes=[] if enabled else ["disabled-field"],
                )
                for _ in range(2)
            ],
            inputs=mixed_precision,
            outputs=[param_dtype, reduce_dtype],
        )

    return seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype
|
|
|
|
def _require_int(value, label):
    """Return ``value`` as an int, or raise a user-visible error if it is empty."""
    if value is None:
        # An emptied number box arrives as None; int(None) would only surface
        # as an unexplained TypeError in the UI.
        raise gr.Error(f"Please enter a value for {label}.")
    return int(value)


def calculate(tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism, fsdp_strategy, layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, weight_tied_embeddings, seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype):
    """Compute the memory breakdown for the current form and plot it.

    Args:
        tp, pp, cp, ep: Tensor / pipeline / context / expert parallelism
            degrees.
        fsdp_enabled, fsdp_parallelism, fsdp_strategy: FSDP switch, degree
            and strategy label.
        layers, vocab, hidden, intermediate: Model architecture sizes.
        active_experts, total_experts, is_moe: Mixture-of-experts settings.
        weight_tied_embeddings: Whether the weight-tied-embeddings box is
            ticked.
        seq_len, batch_size: Sequence length and batch size.
        gradient_checkpointing, grad_accumulation: Training switches.
        precision, param_dtype, reduce_dtype: Dtype values accepted by
            ``DType(...)``.
        mixed_precision: Whether mixed precision is enabled.

    Returns:
        A ``gr.BarPlot`` with one stacked "Total Memory" bar followed by one
        bar each for parameter, gradient, optimizer and activation memory.
        Values are the calculator results divided by 1e9 and rounded to one
        decimal (presumably bytes to GB -- confirm against the calculator).

    Raises:
        gr.Error: If any numeric field is empty.
    """
    model_config = Model(
        vocab_size=_require_int(vocab, "Vocab Size"),
        num_layers=_require_int(layers, "Number of Layers"),
        hidden_dim=_require_int(hidden, "Hidden Dim"),
        intermediate_size=_require_int(intermediate, "Intermediate Dim"),
        weight_tied_embeddings=weight_tied_embeddings,
        active_experts=_require_int(active_experts, "Active Experts"),
        total_experts=_require_int(total_experts, "Total Experts"),
        is_moe=is_moe,
    )

    parallelism_config = Parallelism(
        tensor_parallelism=_require_int(tp, "Tensor Parallelism"),
        pipeline_parallelism=_require_int(pp, "Pipeline Parallelism"),
        context_parallelism=_require_int(cp, "Context Parallelism"),
        expert_parallelism=_require_int(ep, "Expert Parallelism"),
        fsdp_enabled=fsdp_enabled,
        fsdp_parallelism=_require_int(fsdp_parallelism, "FSDP Parallelism"),
        fsdp_strategy=fsdp_strategy,
    )

    training_config = Training(
        sequence_length=_require_int(seq_len, "Sequence Length"),
        batch_size=_require_int(batch_size, "Batch Size"),
        gradient_checkpointing=gradient_checkpointing,
        grad_accumulation=grad_accumulation,
        precision=DType(precision),
        mixed_precision=mixed_precision,
        param_dtype=DType(param_dtype),
        reduce_dtype=DType(reduce_dtype),
    )

    calc = MemoryCalculation(model_config, parallelism_config, training_config)

    param_memory = calc.calculate_parameter_memory()
    activation_memory = calc.calculate_activation_memory()
    gradient_memory = calc.calculate_gradient_memory()
    optimizer_memory = calc.calculate_optimizer_memory()
    total_memory = param_memory + activation_memory + gradient_memory + optimizer_memory

    # Rounded once here so bar heights and the numbers in the labels agree.
    size_gb = {
        'Parameter': round(param_memory / 1e9, 1),
        'Gradient': round(gradient_memory / 1e9, 1),
        'Optimizer': round(optimizer_memory / 1e9, 1),
        'Activation': round(activation_memory / 1e9, 1),
    }
    total_gb = round(total_memory / 1e9, 1)

    # The x-axis labels carry the size so each bar is readable on its own.
    total_label = f'Total Memory\n{total_gb} GB'
    bar_label = {kind: f'{kind} Memory\n{gb} GB' for kind, gb in size_gb.items()}

    # Row order for the shared "Total Memory" bar, then for the single bars.
    stacked_order = ('Activation', 'Optimizer', 'Gradient', 'Parameter')
    single_order = ('Parameter', 'Gradient', 'Optimizer', 'Activation')

    # All four types share the total's x label so they stack into one bar.
    total_rows = [
        {'Component': total_label, 'Memory (GB)': size_gb[kind], 'Type': kind}
        for kind in stacked_order
    ]
    single_rows = [
        {'Component': bar_label[kind], 'Memory (GB)': size_gb[kind], 'Type': kind}
        for kind in single_order
    ]
    memory_data = pd.DataFrame(total_rows + single_rows)

    return gr.BarPlot(
        value=memory_data,
        x="Component",
        y="Memory (GB)",
        color="Type",
        title="LLM Memory Usage Breakdown",
        container=False,
        y_lim=[0, None],
        # Explicit order: total first, then the per-component bars.
        sort=[total_label] + [bar_label[kind] for kind in single_order],
    )
|
|
# Greyed-out, struck-through look for any field carrying the "disabled-field"
# class, which the form builders attach through elem_classes when a field is
# not applicable.
css = """
/* Style for disabled components to make them visually obvious */
.disabled-field input,
.disabled-field select,
.disabled-field textarea {
    opacity: 0.4 !important;
    background-color: #f5f5f5 !important;
    color: #999 !important;
    cursor: not-allowed !important;
    text-decoration: line-through;
}

.disabled-field label {
    opacity: 0.5 !important;
    color: #999 !important;
}
"""
|
|
# Page layout: header and instructions, the three input columns side by side,
# the Calculate button with its plot, then the reference text.
# NOTE(review): nesting was reconstructed; where the outer Column ends and
# what the "Details" Row contains should be confirmed against the rendered page.
with gr.Blocks(theme='Default', css=css) as demo:
    with gr.Column():
        gr.Markdown("# LLM Training Memory Visualizer")
        gr.Markdown("<sub>🔧 Built by [Ruben Aghayan](https://www.linkedin.com/in/ruben-aghayan-37885690/)</sub>")
        gr.Markdown("---")
        gr.Markdown(INSTRUCTIONS)

        with gr.Row(equal_height=True):
            (
                tp,
                pp,
                cp,
                ep,
                fsdp_enabled,
                fsdp_parallelism,
                fsdp_strategy,
            ) = create_parallelism_block()
            (
                layers,
                vocab,
                hidden,
                intermediate,
                active_experts,
                total_experts,
                is_moe,
                presets,
                weight_tied_embeddings,
            ) = create_model_block()
            (
                seq_len,
                batch_size,
                gradient_checkpointing,
                grad_accumulation,
                precision,
                mixed_precision,
                param_dtype,
                reduce_dtype,
            ) = create_training_block()

        calculate_button = gr.Button("Calculate")
        output = gr.BarPlot(label="Memory Usage Breakdown")

    # Grouped per form column; concatenated they follow calculate()'s
    # positional parameter order. The presets dropdown is deliberately not an
    # input: only the field values it fills in are.
    parallelism_inputs = [tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism, fsdp_strategy]
    model_inputs = [
        layers,
        vocab,
        hidden,
        intermediate,
        active_experts,
        total_experts,
        is_moe,
        weight_tied_embeddings,
    ]
    training_inputs = [
        seq_len,
        batch_size,
        gradient_checkpointing,
        grad_accumulation,
        precision,
        mixed_precision,
        param_dtype,
        reduce_dtype,
    ]

    calculate_button.click(
        fn=calculate,
        inputs=parallelism_inputs + model_inputs + training_inputs,
        outputs=output,
    )

    gr.Markdown("# Details")
    with gr.Row():
        gr.Markdown(LIMITATIONS)
        gr.Markdown(DETAILS)
    gr.Markdown("# Validation")
    gr.Markdown(ACCURACY)

demo.launch(share=True)
|
|