cbovar / ConvNetSharp

Deep Learning in C#

[Question] Network with different output kinds (softmax & scalar)

MelnikovIG opened this issue · comments

Is it possible to create a network with different output kinds? For example, some outputs should be probability outputs (using softmax) and some should be scalar values (from 0 to 1). If it is possible, it would be great if anyone could provide an example. Thanks.

Do you want to output the two output kinds at the same time?
It's already possible to have probabilities as output using SoftmaxLayer as the last layer (e.g. the Mnist demo), or a scalar as output using RegressionLayer as the last layer (e.g. Regression1DDemo).
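For reference, here is a minimal sketch of those two stacked-layer options, loosely adapted from the Mnist and Regression1D demos. The input dimensions and layer sizes below are arbitrary placeholders, not taken from the actual demos:

using ConvNetSharp.Core;
using ConvNetSharp.Core.Layers.Double;

// Probability output: SoftmaxLayer as the last layer (as in the Mnist demo)
var classNet = new Net<double>();
classNet.AddLayer(new InputLayer(28, 28, 1));
classNet.AddLayer(new FullyConnLayer(10));
classNet.AddLayer(new SoftmaxLayer(10));

// Scalar output: RegressionLayer as the last layer (as in Regression1DDemo)
var regNet = new Net<double>();
regNet.AddLayer(new InputLayer(1, 1, 1));
regNet.AddLayer(new FullyConnLayer(20));
regNet.AddLayer(new ReluLayer());
regNet.AddLayer(new FullyConnLayer(1));
regNet.AddLayer(new RegressionLayer());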

Yes, I want two output kinds at the same time. Is that possible?

Do you want to get the output of the softmax layer and the output of the layer just before the softmax?
Or do you want to get the output of two distinct branches of a network? (In that case it is not a stacked-layer network anymore, and you might be able to do it by using a computation graph.)
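(For the first case, no extra machinery should be needed: with the Flow API used in the example further down, you can keep a reference to the op just before the softmax and evaluate it directly. A minimal sketch, assuming the same cns, session, and dico objects as that example:)

var logits = cns.Dense(x, 10);   // op just before the softmax
var proba = cns.Softmax(logits); // softmax output

// Both ops can then be evaluated against the same feed dictionary:
// session.Run(logits, dico) -> raw pre-softmax scores
// session.Run(proba, dico)  -> probabilities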

@cbovar Thanks, yes, I need 2 distinct branches. Is there any example with a computation graph?

Here is a quick example:

(To keep it simple, I don't use any dataset: I use one constant input and two constant outputs.)

using System;
using System.Collections.Generic;
using System.Linq;
using ConvNetSharp.Flow;
using ConvNetSharp.Flow.Ops;
using ConvNetSharp.Flow.Training;
using ConvNetSharp.Volume;

namespace Demo
{
    internal static class TwoOutputsExample
    {
        public static void Main()
        {
            var cns = new ConvNetSharp<double>();

            // Graph Creation
            var y_class = cns.PlaceHolder("class_ground_truth");
            var y_scalar = cns.PlaceHolder("scalar_ground_truth");

            Op<double> x = cns.PlaceHolder("x");
            x = cns.Conv(x, 3, 3, 16);
            x = cns.Flatten(x);

            var output1 = cns.Softmax(cns.Dense(x, 10)); // 1st branch: proba
            var output2 = cns.Dense(x, 1); // 2nd branch: scalar

            // Losses
            var classification_loss = cns.CrossEntropyLoss(output1, y_class); // cross-entropy between predicted probabilities and y_class
            var scalar_loss = (output2 - y_scalar) * (output2 - y_scalar); // squared error: drives output2 towards y_scalar
            var total_loss = classification_loss + scalar_loss;

            var optimizer = new GradientDescentOptimizer<double>(cns, learningRate: 0.0001);

            var input = BuilderInstance<double>.Volume.Random(Shape.From(16, 16, 1, 1));

            using (var session = new Session<double>())
            {
                session.Differentiate(total_loss); // computes dCost/dW at every node of the graph

                // Training
                double currentLoss;
                do
                {
                    // Build the feed dictionary ('dico') containing 1 input and 2 ground truths (1 for each output branch).
                    var class_ground_truth = BuilderInstance<double>.Volume.SameAs(Shape.From(1, 1, 10, 1));
                    class_ground_truth.Set(0, 0, 5, 0, 1.0); // target is class 5
                    var scalar_ground_truth = 0.5;

                    var dico = new Dictionary<string, Volume<double>> {
                        { "x", input },
                        { "class_ground_truth", class_ground_truth },
                        { "scalar_ground_truth", scalar_ground_truth },
                    };

                    // Compute loss (the resulting 1x1x1x1 volume implicitly converts to double)
                    currentLoss = session.Run(total_loss, dico);
                    Console.WriteLine($"cost: {currentLoss}");

                    // Run optimizer (will update weights in the network)
                    session.Run(optimizer, dico);
                } while (currentLoss > 0.1);

                // Test
                var result1 = session.Run(output1, new Dictionary<string, Volume<double>> { { "x", input } });
                var probabilities = result1.ToArray();
                var class_output = probabilities.ToList().IndexOf(probabilities.Max()); // should be 5
                var result2 = (double)session.Run(output2, new Dictionary<string, Volume<double>> { { "x", input } }); // should be ~0.5
            }

            Console.ReadLine();
        }
    }
}
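One design note on the example above: total_loss sums the two branch losses with equal weight, so whichever loss has the larger scale can dominate training. A hypothetical way to rebalance the branches, assuming a Const op is available on the graph builder (the 0.5 factor is an arbitrary illustration, not a recommended value):

// Down-weight the scalar branch relative to the classification branch
var lambda = cns.Const(0.5, "lambda"); // hypothetical weighting constant
var total_loss = classification_loss + lambda * scalar_loss;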