ltrottier / deep-collaboration-network

Deep Collaboration Network in pytorch

Deep Collaboration Network

This is a PyTorch implementation of the Deep Collaboration Network (DCNet) from "Multi-Task Learning by Deep Collaboration and Application in Facial Landmark Detection" by Trottier et al. (2017).

Requirements

  1. pytorch
  2. torchvision==0.1.8

Update

The collaborative block now applies an additional ReLU after summing its input (from the identity skip connection) with the output of the task aggregation. We obtained better performance with this change.
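The ordering of operations described in this update can be sketched as follows. This is a minimal illustration only, not the repository's code: `ResidualReLUBlock` and its `aggregation` argument are hypothetical names, and a single 1x1 convolution stands in for the task aggregation. The sketch assumes a recent PyTorch, where tensors no longer need to be wrapped in `Variable`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualReLUBlock(nn.Module):
    """Hypothetical sketch: skip connection plus aggregation, followed by a ReLU."""

    def __init__(self, aggregation):
        super(ResidualReLUBlock, self).__init__()
        # any module whose output has the same shape as its input
        self.aggregation = aggregation

    def forward(self, x):
        # sum the identity skip connection with the aggregation output,
        # then apply the additional ReLU
        return F.relu(x + self.aggregation(x))


# placeholder aggregation: a 1x1 convolution that keeps the channel count
block = ResidualReLUBlock(nn.Conv2d(8, 8, kernel_size=1))
x = torch.randn(2, 8, 4, 4)
y = block(x)
```

Because the ReLU comes after the sum, the block's output is non-negative even when the skip-connection input contains negative values.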

Example

Here is an example of how to use DCNet and train it with MultiTaskCriterion:

from dcnet import DeepCollaborationNetwork, MultiTaskCriterion
from torch.autograd import Variable
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn

# define (input, targets) pair
dims = [10, 5]
bs = 32
input = Variable(torch.randn(bs, 3, 112, 112))
targets = []
for dim in dims:
    # random class labels in [0, dim), one per sample
    target = np.random.randint(0, dim, (bs, ))
    # CrossEntropyLoss expects int64 (long) targets
    target = torch.from_numpy(target).long()
    target = Variable(target)
    targets.append(target)

# create training criterion
criterions = [nn.CrossEntropyLoss() for _ in dims]
weights = [1] * len(dims)
criterion = MultiTaskCriterion(criterions, weights)

# create network
net = DeepCollaborationNetwork('resnet18', dims, pretrained=True)

# optimize
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, nesterov=True)
optimizer.zero_grad()
output = net(input)
loss = criterion(output, targets)
loss.backward()
optimizer.step()
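For intuition about how per-task losses can be combined with weights, the sketch below computes a weighted sum of one loss per task. This is an assumption about the general technique, not the repository's MultiTaskCriterion: `WeightedSumCriterion` is a hypothetical name, random logits stand in for the network outputs, and the weights are arbitrary. It assumes a recent PyTorch.

```python
import torch
import torch.nn as nn


class WeightedSumCriterion(nn.Module):
    """Hypothetical stand-in: weighted sum of one loss per task."""

    def __init__(self, criterions, weights):
        super(WeightedSumCriterion, self).__init__()
        self.criterions = nn.ModuleList(criterions)
        self.weights = weights

    def forward(self, outputs, targets):
        # outputs and targets are lists with one entry per task
        losses = [
            w * c(o, t)
            for w, c, o, t in zip(self.weights, self.criterions, outputs, targets)
        ]
        return sum(losses)


dims = [10, 5]
bs = 32
# random logits stand in for the per-task network outputs
outputs = [torch.randn(bs, dim, requires_grad=True) for dim in dims]
targets = [torch.randint(0, dim, (bs,)) for dim in dims]

criterion = WeightedSumCriterion([nn.CrossEntropyLoss() for _ in dims], [1.0, 0.5])
loss = criterion(outputs, targets)
loss.backward()
```

The result is a single scalar, so one `backward()` call propagates gradients to every task's output.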

License: MIT License

