tianyic / only_train_once

OTOv1-v3, NeurIPS, ICLR, TMLR, DNN Training, Compression, Structured Pruning, Erasing Operators, CNN, Diffusion, LLM

Allowing dummy_input to be tuple / dict

iamanigeeit opened this issue

I know I should open a pull request, but here is a quick edit. The current code is:

    def _get_trace_graph(self, model, dummy_input, optimized_onnx=False):
        # Run the Pytorch graph to get a trace and generate a graph from it
        trace_graph = None
        with torch.no_grad():
            trace_graph, _ = torch.jit._get_trace_graph(model, dummy_input)

I modified it to work with a sequence of tensors or a dict mapping argument names to tensors. This is very common when the model is called as model(**batch).

import inspect
    def _get_trace_graph(self, model, dummy_input, optimized_onnx=False):
        # Run the Pytorch graph to get a trace and generate a graph from it
        trace_graph = None
        with torch.no_grad():
            if isinstance(dummy_input, dict):
                # Order the dict values by the parameter order of model.forward,
                # since they are passed to the tracer as positional args
                forward_args = inspect.signature(model.forward).parameters.keys()
                input_tensors = []
                for argname in forward_args:
                    # Skip the conventionally named *args / **kwargs catch-alls
                    if argname not in ['args', 'kwargs']:
                        if argname in dummy_input:
                            input_tensor = dummy_input[argname]
                            input_tensors.append(input_tensor)
                            print(argname, input_tensor.shape)  # debug output
                        else:
                            # Placeholder for an argument that was not supplied
                            input_tensors.append(None)
                input_tensors = tuple(input_tensors)
            elif isinstance(dummy_input, torch.Tensor):
                # A single tensor becomes a one-element tuple
                input_tensors = (dummy_input,)
            else:
                # Any other sequence of tensors (list, tuple)
                input_tensors = tuple(dummy_input)
            trace_graph, _ = torch.jit._get_trace_graph(model, args=input_tensors)
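
The dict branch above can be tried on its own, without torch. The following is a minimal sketch, not code from only_train_once: `dict_to_positional_args` and `ToyModel` are hypothetical names, and plain lists stand in for tensors. It assumes forward uses ordinary positional-or-keyword parameters. Unlike the edit above, it detects *args / **kwargs by parameter kind rather than by name, rejects keys that forward does not accept, and drops trailing None placeholders.

```python
import inspect


def dict_to_positional_args(forward, inputs):
    """Order the values of `inputs` by the parameter order of `forward`.

    Hypothetical helper for illustration; not part of only_train_once.
    """
    skip_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    params = inspect.signature(forward).parameters
    # Named parameters only; *args / **kwargs cannot be matched by name.
    names = [n for n, p in params.items() if p.kind not in skip_kinds]
    unknown = set(inputs) - set(names)
    if unknown:
        raise KeyError(f"inputs not accepted by forward: {sorted(unknown)}")
    # None marks an argument that was not supplied.
    args = [inputs.get(name) for name in names]
    # Drop trailing None placeholders so the defaults of forward apply.
    while args and args[-1] is None:
        args.pop()
    return tuple(args)


class ToyModel:
    # Stand-in for an nn.Module; only the forward signature matters here.
    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        return input_ids


batch = {"labels": [1, 0], "input_ids": [[5, 6], [7, 8]]}
args = dict_to_positional_args(ToyModel().forward, batch)
print(args)  # ([[5, 6], [7, 8]], None, [1, 0])
```

The key order in batch does not matter; the result follows the signature of forward, with None filling the gap left by attention_mask.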

@iamanigeeit

Thanks for the modifications, which look good to me. Yes, I recommend creating a pull request, and we appreciate contributions from the community. If you are willing, you could create a pull request here, or wait for us to finish migrating the repo to the Microsoft open-source affiliation (expected soon).