element 0 of tensors does not require grad and does not have a grad_fn in "test_step" and "validation_step"

SongJgit opened this issue · comments

Bug description

This is a very strange bug.
In my model, there is a step where I need to compute the Jacobian matrix as input information for the network. So I used torch.autograd.grad() for the Jacobian matrix calculation.
During training everything works, but `validation_step` and `test_step` raise "element 0 of tensors does not require grad and does not have a grad_fn".
So I wrapped the Jacobian computation in `with torch.set_grad_enabled(True):`. That fixes `validation_step`, but `test_step` still raises the same error.

What version are you seeing the problem on?


How to reproduce the bug

import os

import torch
from lightning import LightningModule
from lightning import Trainer
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

class LitMNIST(LightningModule):
    """Minimal MNIST classifier used to reproduce the Jacobian-in-test_step issue.

    Args:
        hparams: Arbitrary hyperparameter object. It is stored on
            ``self.params`` and not otherwise read by this class.
    """

    def __init__(self, hparams):
        # LightningModule is an nn.Module: its __init__ must run before any
        # submodule is assigned, otherwise nn.Module.__setattr__ raises.
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 10)
        self.params = hparams

    def forward(self, x):
        """Map a 4-D batch ``x`` to log-probabilities of shape ``(batch, 10)``."""
        # The 4-way unpack rejects inputs that are not 4-D; layer_1 then
        # requires channels * height * width == 28 * 28.
        batch_size, _channels, _height, _width = x.size()
        flat = x.view(batch_size, -1)
        return F.log_softmax(self.layer_1(flat), dim=1)

    def test_func(self):
        """Return the 5x5 Jacobian of ``x -> x ** 2`` at a random point.

        The Jacobian is built row by row with ``torch.autograd.grad``. It is
        safe to call from evaluation hooks. Lightning runs ``test_step`` under
        ``torch.inference_mode()`` by default. In that mode, only enabling grad
        mode is not enough for autograd to record a graph, and
        ``torch.autograd.grad`` then fails with "element 0 of tensors does not
        require grad". ``inference_mode(False)`` leaves inference mode locally,
        and ``enable_grad()`` guarantees grad mode is on.
        """
        with torch.inference_mode(False), torch.enable_grad():
            n = 5
            x = torch.randn(n, requires_grad=True)
            y = x ** 2
            # Each basis vector yields one vector-Jacobian product, i.e. one row.
            rows = [
                torch.autograd.grad(y, x, basis, retain_graph=True)[0]
                for basis in torch.eye(n).unbind()
            ]
            jacobian = torch.stack(rows)
        return jacobian

    def _shared_step(self, batch):
        """Return the NLL loss of the model on an ``(inputs, targets)`` batch."""
        x, y = batch
        log_probs = self(x)
        return F.nll_loss(log_probs, y)

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch)

    def test_step(self, batch, batch_idx):
        # this is the test loop
        return self._shared_step(batch)

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch)

    def configure_optimizers(self):
        # Adam with its default hyperparameters; pass e.g. lr=... to tune.
        return Adam(self.parameters())


if __name__ == "__main__":
    # The guard matters with num_workers > 0 on platforms that spawn worker
    # processes (e.g. Windows): each worker re-imports this module and must not
    # re-run data setup and training.
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    mnist_train = MNIST(root=os.getcwd(), train=True, download=True, transform=transform)
    mnist_test = MNIST(root=os.getcwd(), train=False, download=True, transform=transform)
    mnist_train = DataLoader(mnist_train, batch_size=256, num_workers=2)
    mnist_test = DataLoader(mnist_test, batch_size=256, num_workers=2)

    # `hparams` was never defined in the report; LitMNIST only stores it.
    hparams = {}
    model = LitMNIST(hparams)

    trainer = Trainer(max_epochs=3)
    # The test split doubles as the validation set here.
    trainer.fit(model, mnist_train, mnist_test)
    trainer.test(ckpt_path="best", dataloaders=mnist_test)

Error messages and logs

	"name": "RuntimeError",
	"message": "element 0 of tensors does not require grad and does not have a grad_fn",
	"stack": "---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 1
----> 1 trainer.test(ckpt_path='best', dataloaders=mnist_test)

File c:\\Users\\songj\\miniconda3\\envs\\kalman\\lib\\site-packages\\lightning\\pytorch\\trainer\\, in Trainer.test(self, model, dataloaders, ckpt_path, verbose, datamodule)
    704     model = _maybe_unwrap_optimized(model)
    705     self.strategy._lightning_module = model
--> 706 return call._call_and_handle_interrupt(
    707     self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule
    708 )

File c:\\Users\\songj\\miniconda3\\envs\\kalman\\lib\\site-packages\\lightning\\pytorch\\trainer\\, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     43     else:
---> 44         return trainer_fn(*args, **kwargs)
     46 except _TunerExitException:
     47     _call_teardown_hook(trainer)

File c:\\Users\\songj\\miniconda3\\envs\\kalman\\lib\\site-packages\\lightning\\pytorch\\trainer\\, in Trainer._test_impl(self, model, dataloaders, ckpt_path, verbose, datamodule)
    744 self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)
    746 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    747     self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
    748 )
--> 749 results = self._run(model, ckpt_path=ckpt_path)
    750 # remove the tensors from the test results
    751 results = convert_tensors_to_scalars(results)

File c:\\Users\\songj\\miniconda3\\envs\\kalman\\lib\\site-packages\\lightning\\pytorch\\trainer\\, in Trainer._run(self, model, ckpt_path)
    930 self._signal_connector.register_signal_handlers()
    932 # ----------------------------
    934 # ----------------------------
--> 935 results = self._run_stage()
    937 # ----------------------------
    938 # POST-Training CLEAN UP
    939 # ----------------------------
    940 log.debug(f\"{self.__class__.__name__}: trainer tearing down\")

File c:\\Users\\songj\\miniconda3\\envs\\kalman\\lib\\site-packages\\lightning\\pytorch\\trainer\\, in Trainer._run_stage(self)
    968 self.strategy.barrier(\"run-stage\")
    970 if self.evaluating:
--> 971     return
    972 if self.predicting:
    973     return

File c:\\Users\\songj\\miniconda3\\envs\\kalman\\lib\\site-packages\\lightning\\pytorch\\loops\\, in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
    175     context_manager = torch.no_grad
    176 with context_manager():
--> 177     return loop_run(self, *args, **kwargs)

File c:\\Users\\songj\\miniconda3\\envs\\kalman\\lib\\site-packages\\lightning\\pytorch\\loops\\, in
    113     previous_dataloader_idx = dataloader_idx
    114     # run step hooks
--> 115     self._evaluation_step(batch, batch_idx, dataloader_idx)
    116 except StopIteration:
    117     # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support
    118     break

File c:\\Users\\songj\\miniconda3\\envs\\kalman\\lib\\site-packages\\lightning\\pytorch\\loops\\, in _EvaluationLoop._evaluation_step(self, batch, batch_idx, dataloader_idx)
    372 self.batch_progress.increment_started()
    374 hook_name = \"test_step\" if trainer.testing else \"validation_step\"
--> 375 output = call._call_strategy_hook(trainer, hook_name, *step_kwargs.values())
    377 self.batch_progress.increment_processed()
    379 hook_name = \"on_test_batch_end\" if trainer.testing else \"on_validation_batch_end\"

File c:\\Users\\songj\\miniconda3\\envs\\kalman\\lib\\site-packages\\lightning\\pytorch\\trainer\\, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    285     return
    287 with trainer.profiler.profile(f\"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}\"):
--> 288     output = fn(*args, **kwargs)
    290 # restore current_fx when nested context
    291 pl_module._current_fx_name = prev_fx_name

File c:\\Users\\songj\\miniconda3\\envs\\kalman\\lib\\site-packages\\lightning\\pytorch\\strategies\\, in Strategy.test_step(self, *args, **kwargs)
    385 with self.precision_plugin.test_step_context():
    386     assert isinstance(self.model, TestStep)
--> 387     return self.model.test_step(*args, **kwargs)

Cell In[1], line 48, in LitMNIST.test_step(self, batch, batch_idx)
     46 def test_step(self, batch, batch_idx):
     47   # this is the test loop
---> 48   self.test_func()
     49   x, y = batch
     50   logits = self(x)

Cell In[1], line 33, in LitMNIST.test_func(self)
     31     I_N = torch.eye(N)
     32     # Sequential approach
---> 33     jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
     34                     for v in I_N.unbind()]
     35     jacobian = torch.stack(jacobian_rows)
     37 return jacobian

Cell In[1], line 33, in <listcomp>(.0)
     31     I_N = torch.eye(N)
     32     # Sequential approach
---> 33     jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
     34                     for v in I_N.unbind()]
     35     jacobian = torch.stack(jacobian_rows)
     37 return jacobian

File c:\\Users\\songj\\miniconda3\\envs\\kalman\\lib\\site-packages\\torch\\autograd\\, in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched)
    301     return _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(grad_outputs_)
    302 else:
--> 303     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    304         t_outputs, grad_outputs_, retain_graph, create_graph, t_inputs,
    305         allow_unused, accumulate_grad=False)

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn"


Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

It seems to be caused by inference mode.
I printed `torch.is_inference_mode_enabled()` inside the `with torch.enable_grad():` block and found that it is `False` in `validation_step` but `True` in `test_step`.

    # Diagnostic variant of test_func: the print below reports whether inference
    # mode is active while the Jacobian is computed.
    # torch.enable_grad() only toggles grad mode. As the printed value shows, it
    # does not leave inference mode, so in test_step `y` is presumably built
    # without a grad_fn and torch.autograd.grad raises
    # "element 0 of tensors does not require grad".
    with torch.enable_grad():
        N = 5
        f = lambda x: x ** 2
        x = torch.randn(N, requires_grad=True)
        # print(x.requires_grad)
        y = f(x)
        I_N = torch.eye(N)
        print(torch.is_inference_mode_enabled()) # False in train and validation, but True in test.
        # Sequential approach: one vector-Jacobian product per basis vector gives one Jacobian row.
        jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True,allow_unused=True)[0]
                        for v in I_N.unbind()]
        jacobian = torch.stack(jacobian_rows)

    return jacobian
# Case 1: outside inference mode, enable_grad() lets autograd record `x + 1`.
with torch.enable_grad():
    x= torch.randn(3,10).requires_grad_(True)
    y = x+1
# Case 2: the same code nested inside torch.inference_mode(). enable_grad() does
# not leave inference mode, so (matching the error reported in test_step)
# presumably no graph is recorded for `y` here.
with torch.inference_mode():
    with torch.enable_grad():
        x= torch.randn(3,1).requires_grad_(True)
        y = x+1

I need gradients to be computed in `test_step` as well. How do I do that?

I think I have found a solution:
leave inference mode locally with `with torch.inference_mode(mode=False):` instead of relying on `with torch.enable_grad():`.

with torch.inference_mode():
    # Inside inference mode: tensors created here cannot take part in
    # autograd, so `y` is not usable with torch.autograd.grad.
    x = torch.randn(3, 10).requires_grad_(True)
    y = x + 1
    # inference_mode(mode=False) leaves inference mode locally; enable_grad()
    # makes sure grad mode is on, so autograd records `x + 1` again.
    with torch.inference_mode(mode=False), torch.enable_grad():
        x = torch.randn(3, 10).requires_grad_(True)
        y = x + 1

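Inference mode is the default for validation/testing in the Trainer, so you can't take gradients in the `validation_step`/`test_step` methods by default.
There are two ways around this:
- Set `Trainer(inference_mode=False)`. The evaluation loop then uses `torch.no_grad()` instead, and `torch.enable_grad()` inside the step works again.
- Leave inference mode locally with `torch.inference_mode(mode=False)`, as you've done above.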