NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch

torch.cuda.amp > apex.amp

mcarilli opened this issue · comments

For a while now my main focus has been moving mixed precision functionality into Pytorch core. It was merged about a month ago:
https://pytorch.org/docs/master/amp.html
https://pytorch.org/docs/master/notes/amp_examples.html
and is now usable via master or nightly pip/conda packages. (Full features did not make the 1.5 release, unfortunately.)

torch.cuda.amp is more flexible and intuitive, and the native integration brings more future optimizations into scope. It also fixes many of apex.amp's known pain points and handles several cases apex amp can't.

If all you want is to try mixed precision, and you're comfortable using a recent Pytorch, you don't need Apex.

Can torch.cuda.amp be used just for inference on an FP32 model? See #750 and #809
I couldn't find an example in https://pytorch.org/docs/master/notes/amp_examples.html
Maybe just wrap the model call with https://pytorch.org/docs/master/amp.html#torch.cuda.amp.autocast ?

Yes. torch.cuda.amp.autocast can be enabled wherever you want and affects only ops invoked within enabled regions. autocast and torch.cuda.amp.GradScaler are modular in the code. During training you should use both (autocast selects per-op precision, GradScaler scales gradients), but for inference GradScaler is not necessary and you can use autocast by itself. Also, the model does not need to be altered to use autocast for (regions of) inference forward passes (model leaves may remain FP32).
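
For reference, a minimal inference-only sketch along those lines (the model and input here are placeholders, not from the thread):

```python
import torch
from torch.cuda.amp import autocast

# Placeholder FP32 model and input batch.
model = torch.nn.Linear(16, 4).cuda()
x = torch.randn(8, 16, device="cuda")

model.eval()
with torch.no_grad(), autocast():
    # Ops in this region run in whatever dtype autocast selects per op;
    # the FP32 parameters are not modified, and no GradScaler is needed.
    output = model(x)
```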

Before we dive into torch.cuda.amp, should we expect a behavior change relative to issue #475?
Thanks.

@mcarilli what about the opt_level options (O1 / O2, etc.)? I can't tell whether those are natively supported by torch.cuda.amp; it looks like there's no opt_level option in torch.cuda.amp. If so, which opt_level is used by default when using autocast?

Another question: will this be supported by TorchScript?

How can I migrate from apex.amp to torch.cuda.amp if I already have a pre-trained model saved with the Apex wrapper? Can Apex-wrapped models now be loaded like regular PyTorch models?

@Damiox torch.cuda.amp.autocast is similar to O1 in that it casts function inputs on the fly without touching model weights. However, unlike apex O1, autocast only causes casting behavior in regions where the context manager is explicitly enabled. Disabled regions can also be nested inside enabled regions:

```python
with autocast():
    ...  # autocasted ops
    with autocast(enabled=False):
        ...  # ops that run in input dtypes as usual
```

autocast should work with jit tracing if you run your tracing pass under autocast, because the trace will record the casts. I don't think it works with scripting yet; we need to make sure scripting properly parses Python context managers.
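
As a rough sketch of that tracing workflow (the model and example input are placeholders), run the trace inside an autocast region so the recorded graph contains the casts:

```python
import torch
from torch.cuda.amp import autocast

# Placeholder model and example input for tracing.
model = torch.nn.Linear(16, 4).cuda().eval()
example = torch.randn(8, 16, device="cuda")

with torch.no_grad(), autocast():
    # The trace records the casts autocast inserts around each op.
    traced = torch.jit.trace(model, example)
```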

@blizda If you have a model state dict from any source saved on disk, you shouldn't need to do anything special to migrate to native Amp. Create a model in default precision (FP32), call model.load_state_dict(saved_dict), and begin training as shown in the native amp examples. autocast does not touch the model at all; it only affects op execution. GradScaler is self-contained and doesn't alter model or optimizer structure.

After migrating to native amp, for bitwise-accurate saving/restoring, include calls to saved = scaler.state_dict() and scaler.load_state_dict(saved) alongside your usual state_dict/load_state_dict calls.
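
A minimal sketch of that save/restore pattern (the objects and file name are placeholders, assuming a model/optimizer/scaler set up as in the native amp examples):

```python
import torch
from torch.cuda.amp import GradScaler

# Placeholder model, optimizer, and scaler standing in for your training setup.
model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

# Save scaler state alongside the usual model/optimizer state.
torch.save({
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict(),
}, "checkpoint.pt")

# Restore: recreate everything in default (FP32) precision, then load.
ckpt = torch.load("checkpoint.pt")
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
scaler.load_state_dict(ckpt["scaler"])
```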

@vince62s apex.optimizers.FusedAdam and torch.optim.Adam should both work out of the box with native Amp following the documented control flow (create the model in default precision, i.e. FP32). If you also need gradient clipping, see the gradient clipping example in the docs (a minimal sketch follows below).
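
A minimal sketch of that gradient-clipping pattern (model, data, and the clipping threshold are placeholders): unscale the gradients before clipping, then step through the scaler.

```python
import torch
from torch.cuda.amp import GradScaler, autocast

# Placeholder FP32 model, optimizer, and data.
model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler()

x = torch.randn(8, 16, device="cuda")
target = torch.randn(8, 4, device="cuda")

optimizer.zero_grad()
with autocast():
    loss = torch.nn.functional.mse_loss(model(x), target)
scaler.scale(loss).backward()

# Unscale so clipping sees gradients at their true magnitude, then clip and step.
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()
```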

However, there may be a problem with apex.optimizers.FusedAdam that we never bottomed out on. I'm not sure what it could be, because we use it internally and it works. If apex.optimizers.FusedAdam does not improve end-to-end performance vs torch.optim.Adam, definitely prefer torch.optim.Adam.

I don't believe apex.contrib.optimizers.FusedAdam will work, because it takes control of gradient scaling in an incompatible way. Frankly, I don't know what's using that at this point.

It seems like the model's dtype after training with torch.cuda.amp's autocast() is still FP32. If I want to deploy the model, does it need to be converted to FP16 manually? It's a little confusing.

@mcarilli It's clear how to switch from O1, but how can I use O2-style optimization with torch.cuda.amp?

@ysystudio Autocast does not touch the model object itself, so its dtype (param dtype) remains as you created it (leaving it at the default FP32 is recommended). Save the trained model, then deploy it in whatever format you want.

  • If you believe it's safe to run inference with the entire network in FP16, load the trained (FP32) model, call .half() on it, and run it.
  • If you think certain ops internal to the network require FP32 even during inference, the safest approach is to run forward passes the same way you did during training: load the trained (FP32) model and run the inference forward passes under autocast.
  • If you think it's ok for the model's leaves to be FP16 during inference but are still concerned that some internal ops need FP32, there's a compromise between the above two options: load the saved FP32 model, call .half() on it, then run forward passes under autocast. Autocast affects ops individually; it doesn't care whether the params start as half or float (see the sketch after this list).
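
A sketch of the second and third options (the model, checkpoint path, and input are placeholders):

```python
import torch
from torch.cuda.amp import autocast

# Placeholder model and checkpoint path; substitute your own architecture and file.
model = torch.nn.Linear(16, 4).cuda()
model.load_state_dict(torch.load("trained_fp32.pt"))
model.eval()

x = torch.randn(8, 16, device="cuda")

# Option 2: keep FP32 leaves and let autocast pick per-op precision, as in training.
with torch.no_grad(), autocast():
    out = model(x)

# Option 3: halve the leaves to save memory, but still run under autocast so ops
# that need FP32 get it. Autocast works per op regardless of the param dtype.
model.half()
with torch.no_grad(), autocast():
    out = model(x)
```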

@trytolose O2 isn't a thing in torch.cuda.amp. O2 is more brittle (it doesn't make any per-op casting decisions), so it isn't fit for upstream (or anyone, tbh). We are identifying performance opportunities that don't endanger convergence and upstreaming them gradually. Prefer native amp for stability and future-proofing; the native implementation will get faster as you update PyTorch, without changes to your network. We already observe that torch.cuda.amp is often faster than apex O1 due to reduced Python overhead.

@mcarilli
Thanks for a great job in mixed-precision!
I'm trying both apex.amp and torch.cuda.amp and both of them turn out to be effective in terms of memory reduction and speed improvements.
But currently I see torch.cuda.amp.GradScaler is a bit limited compared to apex.
For example, in apex we can set max_loss_scale at amp.initialize(), but I don't find such a feature in GradScaler.
Also, there are many other possible options in apex.amp that are not currently supported in torch.cuda.amp. Will they be implemented in torch.cuda.amp?

@SeungjunNah The options available in native Amp are a better representation of what users should control. Some apex options, like opt-level O2, are unsafe for general use. If an option is present in apex amp but not in native amp, it's probably not an important knob for the user to experiment with, and including it would only make the API more cluttered and confusing. For example, I'm not aware of any network where setting max_loss_scale was required for convergence. If you have evidence that max_loss_scale is required, I can add it.

In general, torch.cuda.amp tries to add support for use cases that people complained were unsupported in Apex, and hide options that people should not or did not care about.

@mcarilli
I use max_loss_scale to avoid gradient explosion when training my models in this line of my repository.
(loss scaling here)
From my experiments, I know that gradients usually explode when the scale factor is 2048 for similar tasks, and setting the upper limit to 1024 would work.
Otherwise, the amp would skip a gradient-overflowed batch every N intervals and I don't want to lose a batch during training if possible.
(PyTorch master doc says optimizer.step() is skipped when inf/NaNs are found.)

A workaround could be to recompute the loss scaling until the overflow is avoided but I didn't find a way to implement it myself.

I'd appreciate it if you could add a max_loss_scale option to torch.cuda.amp.

amp would skip a gradient-overflowed batch every N intervals

That's true, but N is a large value (2000 by default). After the initial few iterations where GradScaler calibrates, it settles to a steady state where step skipping should only occur once every 2000 iterations (when it attempts a higher scale value). Generally, letting GradScaler dynamically find a steady state scale value is the best approach. Skipping one out of every 2000 iterations should have a negligible effect on both convergence and performance.

What you're suggesting is more like "static loss scaling": locking the scale to a user-defined value rather than letting GradScaler adjust it dynamically. This is also possible (though not recommended) with the native API without an additional max_loss_scale constructor arg: call scaler.update(1024.) instead of scaler.update() at the end of each iteration.
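
A minimal sketch of that static-scale approach (model and data are placeholders; 1024. stands in for whatever cap you would otherwise pass as max_loss_scale):

```python
import torch
from torch.cuda.amp import GradScaler, autocast

# Placeholder model, optimizer, and data.
model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(init_scale=1024.)  # start at the desired fixed scale

x = torch.randn(8, 16, device="cuda")
target = torch.randn(8, 4, device="cuda")

for _ in range(10):
    optimizer.zero_grad()
    with autocast():
        loss = torch.nn.functional.mse_loss(model(x), target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    # Passing a value pins the scale instead of letting GradScaler adjust it dynamically.
    scaler.update(1024.)
```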

Ok, skipping with 1/2000 ratio doesn't hurt practically. I wanted to see if there were ways to control the number of iterations completely, though. Thanks for the explanation!

@mcarilli I just watched a video that says you can use FusedAdam, FusedSGD, etc. as faster optimizers when using amp. How do we use these in native PyTorch 1.6 with amp? Thanks.

@mcarilli
Hi, thanks for your great work!
In my task, compared to opt-level O1, opt-level O2 trains faster with no loss in accuracy. Is there any workaround to get O2-like behavior with native amp? Can I just cast the model weights to FP16 (except batch norm, etc.) before training?
For example:

```python
model = convert_most_weights_to_half(model)  # hypothetical helper from the question
with autocast():
    output = model(input)
    loss = loss_fn(output, target)
loss.backward()
optimizer.step()
```

I can't find an example that tests ImageNet performance with torch.cuda.amp.
In my case, I tested with NVIDIA's DALI dataloader, ImageNet, ResNet50, and torch.cuda.amp, but I can only reach an accuracy of ~0.68 after 90 epochs.
Here is my code:
```python

import torch
import torchvision.transforms as transforms
from PIL import Image
import io
import argparse
import os
import random
import shutil
import time
import warnings
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.models as models
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from torch.cuda.amp import GradScaler,autocast

class HybridTrainPipe(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, data_dir, crop,
                shard_id, num_shards, dali_cpu=False):
        super(HybridTrainPipe, self).__init__(
            batch_size, num_threads, device_id, seed=12 + device_id
        )
        self.input = ops.FileReader(
            file_root=data_dir,
            shard_id=shard_id,
            num_shards=num_shards,
            random_shuffle=True,
        )
        # let the user decide which pipeline works best for the ResNet version being run
        dali_device = 'cpu' if dali_cpu else 'gpu'
        decoder_device = 'cpu' if dali_cpu else 'mixed'
        # This padding sets the size of the internal nvJPEG buffers to be able to handle all images from full-sized ImageNet
        # without additional reallocations
        device_memory_padding = 211025920 if decoder_device == 'mixed' else 0
        host_memory_padding = 140544512 if decoder_device == 'mixed' else 0
        if dali_cpu:
            self.decode = ops.ImageDecoder(device=dali_device, output_type=types.RGB)
        else:
            self.decode = ops.ImageDecoderRandomCrop(device=decoder_device, output_type=types.RGB,
                                                device_memory_padding=device_memory_padding,
                                                host_memory_padding=host_memory_padding,
                                                random_aspect_ratio=[0.8, 1.25],
                                                random_area=[0.1, 1.0],
                                                num_attempts=100)
        self.res = ops.RandomResizedCrop(
            device=dali_device,
            size=[crop, crop],
            interp_type=types.INTERP_LINEAR,
            random_aspect_ratio=[0.75, 4.0 / 3.0],
            random_area=[0.08, 1.0],
            num_attempts=100,
        )
        self.cmnp = ops.CropMirrorNormalize(
            device="gpu",
            output_layout=types.NCHW,
            crop=(crop, crop),
            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
            std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
        )
        self.coin = ops.CoinFlip(probability=0.5)
        print('DALI "{0}" variant'.format(dali_device))

    def define_graph(self):
        rng = self.coin()
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        output = self.cmnp(images.gpu(), mirror=rng)
        labels_gpu=self.labels.gpu()
        return [output, labels_gpu]


class HybridValPipe(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, data_dir, crop,
                size, shard_id, num_shards):
        super(HybridValPipe, self).__init__(batch_size,
                                        num_threads,
                                            device_id,
                                            seed=12 + device_id)
        self.input = ops.FileReader(file_root=data_dir,
                                    shard_id=shard_id,
                                    num_shards=num_shards,
                                    random_shuffle=False,
                                    pad_last_batch=True)
        self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
        self.res = ops.Resize(device="gpu",
                            resize_shorter=size)
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            dtype=types.FLOAT,
                                            crop=(crop, crop),
                                            output_layout=types.NCHW,
                                            mean=[0.485 * 255,0.456 * 255,0.406 * 255],
                                            std=[0.229 * 255,0.224 * 255,0.225 * 255])

    def define_graph(self):
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        output = self.cmnp(images)
        return [output, self.labels.gpu()]



model_names = sorted(name for name in models.__dict__
                    if name.islower() and not name.startswith("__")
                    and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                    ' | '.join(model_names) +
                    ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=96, type=int, metavar='N',
                    help='number of data loading workers (default: 96)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=512, type=int,
                    metavar='N',
                    help='mini-batch size (default: 512), this is the total '
                        'batch size of all GPUs on the current node when '
                        'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=30, type=int,
                    metavar='N', help='print frequency (default: 30)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                        'N processes per node, which has N GPUs. This is the '
                        'fastest way to use PyTorch for either single node or '
                        'multi node data parallel training')
parser.add_argument('--dali_cpu', action='store_true',
                    help='run the DALI pipeline on CPU')
parser.add_argument('--amp', action='store_true',
                    help='use torch.cuda.amp mixed precision training')
best_acc1 = 0


def main():
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                    'This will turn on the CUDNN deterministic setting, '
                    'which can slow down your training considerably! '
                    'You may see unexpected behavior when restarting '
                    'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                    'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node,
                args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)


def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()
    # model=nn.DataParallel(model)
    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
                print(loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            # if args.gpu is not None:
            #     # best_acc1 may be from a checkpoint from a different GPU
            #     best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True
    crop_size = 224
    val_size = 256
    # Data loading code

    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    pipe = HybridTrainPipe(batch_size=args.batch_size,
                        num_threads=args.workers,
                        device_id=args.rank,
                        data_dir=traindir,
                        crop=crop_size,
                        dali_cpu=args.dali_cpu,
                        shard_id=args.rank,
                        num_shards=args.world_size)
    pipe.build()
    train_loader = DALIClassificationIterator(pipe, reader_name="Reader")

    pipe = HybridValPipe(batch_size=args.batch_size,
                        num_threads=args.workers,
                        device_id=args.rank,
                        data_dir=valdir,
                        crop=crop_size,
                        size=val_size,
                        shard_id=args.rank,
                        num_shards=args.world_size)
    pipe.build()
    val_loader = DALIClassificationIterator(pipe, reader_name="Reader")

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)
        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                    and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
        train_loader.reset()
        val_loader.reset()

def to_python_float(t):
    if hasattr(t, 'item'):
        return t.item()
    else:
        return t[0]

def train(train_loader, model, criterion, optimizer, epoch, args):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    train_loader_len=int(train_loader._size / args.batch_size)

    # switch to train mode
    model.train()
    if args.amp:
        scaler = GradScaler()
    end = time.time()
    for i, dict_data in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = dict_data[0]['data']
        target = dict_data[0]['label'].squeeze().long()

        # compute output
        if args.amp:
            with autocast():
                output = model(images)
                loss = criterion(output, target)
        else:
            output = model(images)
            loss = criterion(output, target)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        if args.amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # measure elapsed time
        if i%args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.

            # Measure accuracy
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

            # Average loss and accuracy across processes for logging
            if args.distributed:
                reduced_loss = reduce_tensor(loss.data,args)
                prec1 = reduce_tensor(prec1,args)
                prec5 = reduce_tensor(prec5,args)
            else:
                reduced_loss = loss.data

            # to_python_float incurs a host<->device sync
            losses.update(to_python_float(reduced_loss), images.size(0))
            top1.update(to_python_float(prec1), images.size(0))
            top5.update(to_python_float(prec5), images.size(0))

            torch.cuda.synchronize()
            batch_time.update((time.time() - end)/args.print_freq)
            end = time.time()

            if args.rank == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Speed {3:.3f} ({4:.3f})\t'
                    'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                    'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                    'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                    epoch, i, train_loader_len,
                    args.world_size*args.batch_size/batch_time.val,
                    args.world_size*args.batch_size/batch_time.avg,
                    batch_time=batch_time,
                    loss=losses, top1=top1, top5=top5))
    


def validate(val_loader, model, criterion, args):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    val_loader_len=int(val_loader._size / args.batch_size)

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, dict_data in enumerate(val_loader):
            images = dict_data[0]['data']
            target = dict_data[0]['label'].squeeze().long()

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data,args)
                prec1 = reduce_tensor(prec1,args)
                prec5 = reduce_tensor(prec5,args)
            else:
                reduced_loss = loss.data

            losses.update(to_python_float(reduced_loss), images.size(0))
            top1.update(to_python_float(prec1), images.size(0))
            top5.update(to_python_float(prec5), images.size(0))


            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if args.rank == 0 and i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Speed {2:.3f} ({3:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                    'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                    i, val_loader_len,
                    args.world_size * args.batch_size / batch_time.val,
                    args.world_size * args.batch_size / batch_time.avg,
                    batch_time=batch_time, loss=losses,
                    top1=top1, top5=top5))

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
            .format(top1=top1, top5=top5))

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def adjust_learning_rate(optimizer, epoch, args):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    print(lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
def reduce_tensor(tensor,args):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= args.world_size
    return rt

if __name__ == '__main__':
    main()
```

@mcarilli I found one case where we might need min_loss_scale. In my training with AMP, the first several iterations quite often have NaN gradients, so the first usable scale value becomes 0.0325 (or something like that). Does a minimum scale option make sense?

"O2" is stable for me where "O1" and native amp give me NaNs. It would be really nice if there were some way to duplicate 02 behavior using native torch.cuda.amp. I've tried casting all batch norms to 32, but that didn't do it. So I guess something else is happening under the hood.