Train Large Models Without Fear: the Lightweight TorchShard Library Cuts GPU Memory Consumption with the Same API as PyTorch
Source: 机器之心
How do you gracefully reduce GPU memory consumption when training large models? Give the TorchShard library a try: it supports both model parallelism and data parallelism, and it keeps the same API design as PyTorch.
TorchShard has two goals:
- build a standard PyTorch extension library for scaling up training with model parallelism;
- use it in a simple, natural, PyTorch-like way.

The snippet below shows its basic usage:
import torch
import torchshard as ts

ts.init_process_group(group_size=2)                     # init parallel groups

m = torch.nn.Sequential(
    torch.nn.Linear(20, 30, bias=True),
    ts.nn.ParallelLinear(30, 30, bias=True, dim=None),  # equal to nn.Linear()
    ts.nn.ParallelLinear(30, 30, bias=True, dim=0),     # parallel in row dimension
    ts.nn.ParallelLinear(30, 30, bias=True, dim=1),     # parallel in column dimension
).cuda()

x = m(x)                                                # forward
loss = ts.nn.functional.parallel_cross_entropy(x, y)    # parallel loss function
loss.backward()                                         # backward

torch.save(
    ts.collect_state_dict(m, m.state_dict()), 'm.pt')   # save model state
Like PyTorch itself, TorchShard is organized into a few sub-packages (a short usage sketch follows this list):
- torchshard contains the essential functions and operations, like the torch package;
- torchshard.nn contains the basic building blocks for graphs, like the torch.nn package;
- torchshard.nn.functional contains the functional counterparts of torchshard.nn, like the torch.nn.functional package;
- torchshard.distributed contains the basic functions for handling distributed tensors and groups, like the torch.distributed package, but easier to use.
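To make the mirrored layout concrete, here is a minimal sketch that touches each sub-package. It only reuses the calls that appear elsewhere in this article; the tensor shapes and the two-process group size are illustrative, and the script is assumed to run under a two-process distributed launch.

import torch
import torchshard as ts

# torchshard.distributed: set up and query the parallel groups
ts.distributed.init_process_group(group_size=2)
rank = ts.distributed.get_rank()

# torchshard.nn: a drop-in parallel replacement for torch.nn.Linear
layer = ts.nn.ParallelLinear(128, 10, bias=True, dim=1).cuda()

# torchshard.nn.functional: functional counterpart, e.g. the parallel loss
x = torch.randn(4, 128).cuda()
y = torch.randint(0, 10, (4,)).cuda()
loss = ts.nn.functional.parallel_cross_entropy(layer(x), y)
print(f'rank {rank}: loss = {loss.item()}')

# torchshard (top level): helpers such as collecting sharded states
full_state = ts.collect_state_dict(layer, layer.state_dict())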
TorchShard can be installed from PyPI:

pip install torchshard
The rest of the article adds TorchShard to the standard PyTorch ImageNet training script step by step. First, initialize the parallel process groups:

import torchshard as ts

ts.distributed.init_process_group(group_size=args.world_size)
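For context, here is a hedged sketch of where this call typically sits in a distributed training script. The torch.distributed setup is the standard environment-variable launch, and the assumption that TorchShard's groups are created after the default process group is an inference from the ImageNet example, not something the article states.

import os
import torch
import torch.distributed as dist
import torchshard as ts

# standard PyTorch distributed setup, e.g. launched with:
#   torchrun --nproc_per_node=2 main.py
dist.init_process_group(backend='nccl', init_method='env://')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
world_size = dist.get_world_size()

# TorchShard parallel groups; here group_size equals the world size,
# matching the article's ImageNet example
ts.distributed.init_process_group(group_size=world_size)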
Next, build the model, convert its linear layers into parallel versions, and switch to the parallel cross-entropy loss:

import resnet

model = resnet.__dict__[args.arch](pretrained=args.pretrained)

ts.nn.ParallelLinear.convert_parallel_linear(
    model, dim=args.model_parallel_dim
)
print("=> paralleling model '{}'".format(args.arch))

criterion = ts.nn.ParallelCrossEntropyLoss().cuda(args.gpu)
Inside the ResNet forward function, the input is gathered along the batch dimension right before the final fully connected layer:

        x = ts.distributed.gather(x, dim=0)  # gather input along the dim of batch size
        x = self.fc(x)
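For orientation, here is a sketch of where these two lines land, assuming a torchvision-style ResNet forward pass; the resnet module used by the article may differ in details.

import torch
import torchshard as ts

def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
    # standard torchvision-style ResNet trunk
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.avgpool(x)
    x = torch.flatten(x, 1)

    # TorchShard change (from the article): gather the inputs along the
    # batch dimension before the parallel fc layer
    x = ts.distributed.gather(x, dim=0)
    x = self.fc(x)
    return x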
Back in the training loop, the targets are gathered in the same way before computing the loss:

output = model(images)

if args.enable_model_parallel:
    target = ts.distributed.gather(target, dim=0)

loss = criterion(output, target)
To save a checkpoint, collect the sharded states from all ranks and write the file on rank 0:

state_dict = model.state_dict()

# collect states across all ranks
state_dict = ts.collect_state_dict(model, state_dict)

if ts.distributed.get_rank() == 0:
    torch.save(state_dict, 'resnet50.pt')  # save as before
To load it back, read the checkpoint and relocate the state dict onto each rank:

if ts.distributed.get_rank() == 0:
    state_dict = torch.load('resnet50.pt')

# relocate state_dict() for all ranks
state_dict = ts.relocate_state_dict(model, state_dict)

model.load_state_dict(state_dict)  # load as before
TorchShard also works together with automatic mixed precision (AMP):

# gradscaler
scaler = torch.cuda.amp.GradScaler(enabled=args.enable_amp_mode)

with torch.cuda.amp.autocast(enabled=args.enable_amp_mode):
    # compute output
    output = model(images)
    if args.enable_model_parallel:
        target = ts.distributed.gather(target, dim=0)
    loss = criterion(output, target)

# compute gradient and do SGD step
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
It can likewise be combined with the ZeroRedundancyOptimizer from torch.distributed.optim:

from torch.distributed.optim import ZeroRedundancyOptimizer

if args.enable_zero_optim:
    print('=> using ZeroRedundancyOptimizer')
    optimizer = ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=torch.optim.SGD,
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay)
else:
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)