How to freeze parameters
lbotecur opened this issue · comments
Hello,
I would like to know how to freeze parameters in a model, that is, how to train only a subset of parameters.
Thank you.
The update!() function takes an optional argument ignore: a set of field paths that should not be updated. A field path is a tuple of symbols representing the path to a specific parameter. For example, if your model looks like this:
mutable struct Foo
    x
    y
end
mutable struct Bar
    foo::Foo
    z
end
m = Bar(Foo(1, 2), 3)
And you want to ignore Foo's x and Bar's z, use it like this:
ignore = Set([
    (:foo, :x),
    (:z,)
])
update!(m, gm, ignore=ignore)
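To make the notion of a field path concrete, here is a minimal sketch of how such a tuple of symbols can be followed down to a value. The helper name resolve_path is hypothetical and is not part of Yota or Avalon; it only illustrates the path convention and makes no claim about how update!() resolves paths internally.

```julia
# Same toy model as in the example above.
mutable struct Foo
    x
    y
end
mutable struct Bar
    foo::Foo
    z
end

m = Bar(Foo(1, 2), 3)

# Hypothetical helper (not a library function): walk a field path,
# applying getproperty once per symbol, starting from `obj`.
resolve_path(obj, path) = foldl(getproperty, path; init=obj)

x_value = resolve_path(m, (:foo, :x))   # the field x inside m.foo
z_value = resolve_path(m, (:z,))        # the field z of m itself
```

Note that (:z,) needs the trailing comma: without it, (:z) is just the symbol :z rather than a one-element tuple.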
Note that this is a low-level and unstable API. I'm currently working on such small things, including this very specific task of freezing parameters, but I have several use cases and no specific design yet. I'd be grateful if you could describe your use case so that I can make the API more convenient.
Thank you for the answer. This solution is great. My use case is exactly the one you have shown: use a pretrained model (Foo) as part of a new model (Bar) and train the new model with Foo's parameters frozen. After that, fine-tune the whole model with Foo's parameters unfrozen.
I don't know if it is possible to pass to Yötä only the parameters for which to calculate gradients, similar to how JAX does it.
Thanks.
Great, pretraining is a very important use case for Avalon, so we will definitely have a more concise syntax for freezing parameters, but the exact API will arrive later, perhaps shortly after the high-level training API.
Please note that the ignore list expects full field paths, so using just (:foo,) in the example above won't have any effect. To recursively collect the list of field paths, you can use the following:
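function collect_fields(obj)
    paths = []
    for p in propertynames(obj)
        # recurse into the field; leaves (no properties) give an empty list
        subpaths = collect_fields(getproperty(obj, p))
        if !isempty(subpaths)
            # prepend the current field name to every nested path
            for subpath in subpaths
                path = [p; subpath...]
                push!(paths, path)
            end
        else
            # leaf field: the path is just its own name
            push!(paths, [p])
        end
    end
    # convert each path to a tuple of symbols, as expected by ignore
    return [(path...,) for path in paths]
end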
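For the pretraining use case described earlier, the collected paths can be filtered by their first element to freeze everything under foo. The following is a minimal, self-contained sketch of that idea; leaf_paths is a hypothetical helper written for this example (a variant of the collection logic above that carries the prefix along), not a library function, and the final update!() call is left as a comment because the gradient object gm is not defined here.

```julia
mutable struct Foo
    x
    y
end
mutable struct Bar
    foo::Foo
    z
end

# Hypothetical helper: collect full field paths of all leaves of `obj`.
# A value with no properties (e.g. a number) is treated as a leaf.
function leaf_paths(obj, prefix=())
    names = propertynames(obj)
    isempty(names) && return Any[prefix]
    paths = Any[]
    for name in names
        append!(paths, leaf_paths(getproperty(obj, name), (prefix..., name)))
    end
    return paths
end

m = Bar(Foo(1, 2), 3)

# Phase 1: freeze every parameter whose path starts with :foo.
frozen = Set(p for p in leaf_paths(m) if first(p) == :foo)

# Phase 2 (fine-tuning): nothing is frozen.
unfrozen = Set()

# With Yota loaded and gradients `gm` available, these sets would be
# passed as in the example above, e.g. update!(m, gm, ignore=frozen)
```

Whether a field counts as a leaf depends on propertynames, so models that store parameters in arrays or other container types may need a different stopping condition than the one assumed here.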
I don't know if it is possible to pass to Yötä only the parameters for which to calculate gradients, similar to how JAX does it.
I'm not sure I understood you correctly, but if you are looking for semantics like:
f(x) = ...
gf = grad(f)
gf(x)
Unfortunately it's not possible out of the box, because without concrete arguments Yota doesn't really know which method of f() to trace. Yet it should be possible to make a simple wrapper, something like:
grad_fn(f) = (args...) -> grad(f, args...)
Is this what you were asking about?
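The wrapper idea can be tried without Yota installed. The sketch below uses a stand-in called toy_grad, which is an assumption of this example: it approximates derivatives of scalar functions by central finite differences and is not Yota's grad, whose return format may differ. The point is only the closure: the function is captured first, and tracing (here, differentiation) is deferred until concrete arguments arrive.

```julia
# Stand-in for an AD entry point (NOT Yota's grad): returns the value of
# f at args and a tuple of central-difference partial derivatives.
function toy_grad(f, args::Real...; h=1e-6)
    val = f(args...)
    grads = ntuple(length(args)) do i
        up = ntuple(j -> j == i ? args[j] + h : args[j], length(args))
        down = ntuple(j -> j == i ? args[j] - h : args[j], length(args))
        (f(up...) - f(down...)) / (2h)
    end
    return val, grads
end

# JAX-style wrapper: capture f now, supply the arguments later.
# The closure is variadic, so any number of arguments is forwarded.
grad_fn(f) = (args...) -> toy_grad(f, args...)

f(x, y) = x^2 + 3y
gf = grad_fn(f)
val, grads = gf(2.0, 1.0)
```

In real use, toy_grad would be replaced by the library's grad; the wrapper itself stays the same.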