LaurentMazare / tch-rs

Rust bindings for the C++ API of PyTorch.

How to clone a model in rust?

3togo opened this issue · comments

In C++:
auto model = torch::jit::load("your_model.pt");
input = input.clone();

How to do it in Rust?

I think you could do the same as with any non-cloneable object in Rust, i.e. add an Rc or Arc layer for ref counting - nothing specific to tch. Maybe I'm missing something here?

let model = tch::CModule::load_on_device(model_path, device).unwrap();
let model: Rc<tch::CModule> = Rc::new(model); // move the module into a ref-counted handle
let model2 = Rc::clone(&model); // cheap pointer copy, not a deep copy

Rc will only clone a pointer to the same allocation.
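
A minimal standalone sketch of that distinction (plain Rc, no tch involved; the String is just a stand-in for a loaded module):

use std::rc::Rc;

fn main() {
    let model = Rc::new(String::from("pretend this is a CModule"));
    let model2 = Rc::clone(&model);

    // Both handles point at the same heap allocation; nothing was deep-copied.
    assert!(Rc::ptr_eq(&model, &model2));
    assert_eq!(Rc::strong_count(&model), 2);
}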

My problem is that, for an unknown reason, calling model.forward() repeatedly with the same inputs produces different results. So I need a fast way to clone the model before each forward() call. In C++, model.clone() copies the content. How about in Rust?

Below is an extract from cloneable.h in "torch/csrc/api/include/nn/cloneable.h",
but I don't know how to translate it to Rust.

/// Performs a recursive "deep copy" of the `Module`, such that all parameters
/// and submodules in the cloned module are different from those in the
/// original module.

std::shared_ptr<Module> clone(
    const optional<Device>& device = nullopt) const override {
  NoGradGuard no_grad;

  const auto& self = static_cast<const Derived&>(*this);
  auto copy = std::make_shared<Derived>(self);
  copy->parameters_.clear();
  copy->buffers_.clear();
  copy->children_.clear();
  copy->reset();
  TORCH_CHECK(
      copy->parameters_.size() == parameters_.size(),
      "The cloned module does not have the same number of "
      "parameters as the original module after calling reset(). "
      "Are you sure you called register_parameter() inside reset() "
      "and not the constructor?");
  for (const auto& parameter : named_parameters(/*recurse=*/false)) {
    auto& tensor = *parameter;
    auto data = device && tensor.device() != *device
        ? tensor.to(*device)
        : autograd::Variable(tensor).clone();
    copy->parameters_[parameter.key()].set_data(data);
  }
  TORCH_CHECK(
      copy->buffers_.size() == buffers_.size(),
      "The cloned module does not have the same number of "
      "buffers as the original module after calling reset(). "
      "Are you sure you called register_buffer() inside reset() "
      "and not the constructor?");
  for (const auto& buffer : named_buffers(/*recurse=*/false)) {
    auto& tensor = *buffer;
    auto data = device && tensor.device() != *device
        ? tensor.to(*device)
        : autograd::Variable(tensor).clone();
    copy->buffers_[buffer.key()].set_data(data);
  }
  TORCH_CHECK(
      copy->children_.size() == children_.size(),
      "The cloned module does not have the same number of "
      "child modules as the original module after calling reset(). "
      "Are you sure you called register_module() inside reset() "
      "and not the constructor?");
  for (const auto& child : children_) {
    copy->children_[child.key()]->clone_(*child.value(), device);
  }
  return copy;
}

I believe you can do this:

use tch::nn::{self, VarStore};
use tch::{Device, TchError};

fn clone_model(
    model_var_store: &VarStore,
) -> Result<(VarStore, impl nn::Module), TchError> {
    let mut cloned_var_store = VarStore::new(Device::cuda_if_available());

    // Build the same network structure on the new var store first, so that
    // it holds variables with the same names and shapes as the source...
    let net = nn::linear(cloned_var_store.root(), 1337, 1337, Default::default());

    // ...then copy the trained weights across; copy only fills variables
    // that already exist in the destination var store.
    cloned_var_store.copy(model_var_store)?;

    Ok((cloned_var_store, net))
}
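
For example, a usage sketch with the same toy 1337-to-1337 linear model (this fragment assumes a call site that returns Result<_, TchError> so the ? operator works):

let vs = VarStore::new(Device::cuda_if_available());
let net = nn::linear(vs.root(), 1337, 1337, Default::default());
// ... train `net`, updating the variables held by `vs` ...
let (snapshot_vs, snapshot_net) = clone_model(&vs)?;
// `snapshot_net` is backed by `snapshot_vs`, so further training on `vs`
// no longer affects the snapshot.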

In my use case, I need to clone the best trained model into a buffer model and save the buffer to disk when doing early stopping. For the clone and save to work correctly, I need to instantiate networks on BOTH var stores, the best model's and the buffer's:

This works:

  // Instantiate models
  let vs = nn::VarStore::new(Device::Cpu); // model for training
  let net = Net(&vs.root());

  let mut buffer_vs = nn::VarStore::new(Device::Cpu); // buffer model: clones the vs model and is saved to disk
  let buffer_net = Net(&buffer_vs.root()); // Important for correct saving: without this, buffer_vs stays empty after copying, with no panic

  // Do some training here on net...

  // Clone and save to disk
  buffer_vs.copy(&vs).expect("Failed to copy");
  buffer_vs.save("path/to/save").expect("Failed to save");

I guess this is because the .copy method only copies values between var stores whose structures already match (it fills variables that already exist in the destination), and instantiating the model on each var store is what makes the structures the same.
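
To illustrate, a self-contained sketch of that behavior, using a small linear layer in place of Net (names and sizes are illustrative, and it assumes, as observed above, that copy only fills variables already present in the destination):

use tch::nn::{self, Module, VarStore};
use tch::{Device, Kind, TchError, Tensor};

fn main() -> Result<(), TchError> {
    // Source store holding the "trained" (here just initialized) layer.
    let vs = VarStore::new(Device::Cpu);
    let net = nn::linear(vs.root(), 4, 2, Default::default());

    // Copying into an empty var store succeeds but transfers nothing.
    let mut empty_vs = VarStore::new(Device::Cpu);
    empty_vs.copy(&vs)?;
    assert_eq!(empty_vs.variables().len(), 0);

    // Build the same structure on the destination first, then copy.
    let mut buffer_vs = VarStore::new(Device::Cpu);
    let buffer_net = nn::linear(buffer_vs.root(), 4, 2, Default::default());
    buffer_vs.copy(&vs)?;

    // The two nets now agree on any input.
    let x = Tensor::randn(&[1, 4], (Kind::Float, Device::Cpu));
    let max_diff = (net.forward(&x) - buffer_net.forward(&x))
        .abs()
        .max()
        .double_value(&[]);
    assert!(max_diff < 1e-6);
    Ok(())
}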