| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import evaluate |
| | import torch |
| | from datasets import load_dataset |
| | from torch.optim import AdamW |
| | from torch.utils.data import DataLoader |
| | from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup |
| |
|
| | from accelerate import Accelerator, DistributedType |
| | from accelerate.utils import set_seed |
| |
|
| | import transformers |
| |
|
| | transformers.logging.set_verbosity_error() |
| |
|
| | import os |
| | from torch.nn.parallel import DistributedDataParallel |
| | import torch.distributed as torch_distributed |
| |
|
| |
|
| |
|
def get_dataloaders(batch_size: int = 16, eval_batch_size: int = 32):
    """
    Create train/validation `DataLoader`s for the GLUE MRPC dataset,
    tokenized with the "bert-base-cased" tokenizer.

    Args:
        batch_size (`int`, *optional*, defaults to 16):
            Batch size for the training DataLoader.
        eval_batch_size (`int`, *optional*, defaults to 32):
            Batch size for the validation DataLoader.

    Returns:
        `tuple(DataLoader, DataLoader)`: the train and validation DataLoaders.
    """
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    datasets = load_dataset("glue", "mrpc")

    def tokenize_function(examples):
        # Truncate to the model's max length; padding is deferred to the collator
        # so each batch is only padded to its own longest sequence.
        return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)

    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        remove_columns=["idx", "sentence1", "sentence2"],
    )
    # The model's forward expects the targets under the "labels" key.
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    def collate_fn(examples):
        # Dynamic per-batch padding; pad_to_multiple_of=8 keeps shapes friendly
        # to tensor-core kernels.
        return tokenizer.pad(
            examples,
            padding="longest",
            max_length=None,
            pad_to_multiple_of=8,
            return_tensors="pt",
        )

    train_dataloader = DataLoader(
        tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True
    )
    eval_dataloader = DataLoader(
        tokenized_datasets["validation"],
        shuffle=False,
        collate_fn=collate_fn,
        batch_size=eval_batch_size,
        drop_last=False,
    )

    return train_dataloader, eval_dataloader
| |
|
| |
|
def training_function():
    """
    Fine-tune "bert-base-cased" on GLUE MRPC with raw PyTorch DDP.

    Expects one process per GPU (e.g. launched via `torchrun`) with
    `LOCAL_RANK` set in the environment; uses the NCCL backend.
    Rank 0 prints the evaluation metric and mean train loss per epoch.
    """
    torch_distributed.init_process_group(backend="nccl")
    process_index = torch_distributed.get_rank()
    local_process_index = int(os.environ.get("LOCAL_RANK", -1))
    device = torch.device("cuda", local_process_index)
    torch.cuda.set_device(device)

    config = {"lr": 2e-5, "num_epochs": 3, "seed": 42}
    seed = int(config["seed"])
    batch_size = 32
    config["batch_size"] = batch_size
    metric = evaluate.load("glue", "mrpc")

    # Same seed on every rank (device_specific=False), so all ranks shuffle
    # identically.
    set_seed(seed, device_specific=False)
    # NOTE(review): no DistributedSampler is used, so every rank iterates the
    # FULL train/eval sets in the same order — gradients are identical across
    # ranks and the metric is computed over the whole validation set on each
    # rank. Confirm this data-replicated setup is intended rather than
    # data-parallel sharding.
    train_dataloader, eval_dataloader = get_dataloaders(batch_size)
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", return_dict=True).to(device)
    model = DistributedDataParallel(model, device_ids=[local_process_index], output_device=local_process_index)

    optimizer = AdamW(params=model.parameters(), lr=config["lr"])
    # Linear decay to zero over all training steps, no warmup.
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=len(train_dataloader) * config["num_epochs"],
    )

    try:
        for epoch in range(config["num_epochs"]):
            model.train()
            total_loss = 0
            for batch in train_dataloader:
                batch = batch.to(device)
                outputs = model(**batch)
                loss = outputs.loss
                # Detach before accumulating so the graph is not kept alive.
                total_loss += loss.detach().cpu().float()
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            model.eval()
            for batch in eval_dataloader:
                batch = batch.to(device)
                with torch.no_grad():
                    outputs = model(**batch)
                predictions = outputs.logits.argmax(dim=-1)
                metric.add_batch(
                    predictions=predictions,
                    references=batch["labels"],
                )

            eval_metric = metric.compute()
            if process_index == 0:
                print(
                    f"epoch {epoch}: {eval_metric}\n"
                    f"train_loss: {total_loss.item()/len(train_dataloader)}"
                )
    finally:
        # Always tear down the process group, even on failure, so NCCL
        # resources are released and other ranks are not left hanging.
        torch_distributed.destroy_process_group()
| |
|
| |
|
def main():
    """Script entry point: run the full distributed training loop."""
    training_function()
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|