import torch
import lightning as L
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import (
    GPT2Config,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    DataCollatorForLanguageModeling,
)

# Configuration
MAX_LENGTH = 1024
BATCH_SIZE = 8
LEARNING_RATE = 5e-5
NUM_EPOCHS = 3
# S3 checkpoint path
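# Note: writing to an s3:// URI goes through fsspec, so the s3fs package must be
# installed and AWS credentials must be available to the training job (assuming
# the default AWS credential chain).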
S3_CHECKPOINT_DIR = "s3://my-training-checkpoints/gpt2-pretrain/"
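
# LightningModule that pretrains a randomly initialized GPT-2 small (~124M
# parameters) from scratch with a causal language-modeling objective.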
class GPT2PretrainingModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        config = GPT2Config(
            vocab_size=50257,
            n_positions=1024,
            n_embd=768,
            n_layer=12,
            n_head=12,
        )
        self.model = GPT2LMHeadModel(config)

    def training_step(self, batch, batch_idx):
        outputs = self.model(**batch)
        loss = outputs.loss
        self.log("train_loss", loss, prog_bar=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=LEARNING_RATE)
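
# Builds the WikiText-103 training DataLoader: fixed-length tokenized blocks,
# collated for causal language modeling.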
def create_dataloader():
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # GPT-2 has no pad token; reuse EOS so fixed-length padding works
    tokenizer.pad_token = tokenizer.eos_token

    dataset = load_dataset("wikitext", "wikitext-103-v1", split="train")
    # Drop the many empty / whitespace-only lines in WikiText so batches are not
    # dominated by padding-only examples
    dataset = dataset.filter(lambda example: len(example["text"].strip()) > 0)

    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=MAX_LENGTH,
            padding="max_length",
        )

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
        num_proc=4,
    )

    # mlm=False -> causal-LM collation: labels are a copy of input_ids with
    # padding positions masked out
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    return DataLoader(
        tokenized_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=data_collator,
        num_workers=4,
    )

if __name__ == "__main__":
    model = GPT2PretrainingModule()
    train_loader = create_dataloader()

    trainer = L.Trainer(
        max_epochs=NUM_EPOCHS,
        accelerator="gpu",
        devices=8,
        strategy="ddp",
        precision="bf16-mixed",
        # S3 checkpointing - Lightning writes directly to S3 via fsspec
        default_root_dir=S3_CHECKPOINT_DIR,
        enable_checkpointing=True,
    )
    trainer.fit(model, train_loader)
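
    # To resume pretraining from a checkpoint already written to S3, pass its URI
    # via ckpt_path; the filename below is illustrative, since actual checkpoint
    # names depend on the ModelCheckpoint settings in use:
    # trainer.fit(model, train_loader, ckpt_path=S3_CHECKPOINT_DIR + "last.ckpt")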