adrianhust / pytorch_NEG_loss

NEG loss implemented in PyTorch

PyTorch Negative Sampling Loss

Negative Sampling Loss implemented in PyTorch.

NEG Loss Equation
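
The equation image is not reproduced in this text. As a reference point, and assuming the image showed the standard negative-sampling objective from word2vec (an assumption, not something confirmed here), the quantity maximized for one input word $w_I$ and one observed target word $w_O$ is:

$$
\log \sigma\left({v'_{w_O}}^{\top} v_{w_I}\right) + \sum_{i=1}^{k} \mathbb{E}_{w_i \sim P_n(w)}\left[\log \sigma\left(-{v'_{w_i}}^{\top} v_{w_I}\right)\right]
$$

Here $v$ and $v'$ denote input and output embeddings, $\sigma$ is the logistic sigmoid, $k$ is the number of negative samples, and $P_n(w)$ is the noise distribution the negatives are drawn from. A loss to minimize is the negative of this expression.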

Usage

    from torch.optim import SGD

    # NEG_loss is the module provided by this repository
    neg_loss = NEG_loss(num_classes, embedding_size)

    optimizer = SGD(neg_loss.parameters(), lr=0.1)

    for i in range(num_iterations):
        # input and target are LongTensors of shape [batch_size]
        input, target = next_batch(batch_size)

        # num_sample: number of negative samples drawn per example
        loss = neg_loss(input, target, num_sample).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # learned input (word) embeddings
    word_embeddings = neg_loss.input_embeddings()
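
To make the call `neg_loss(input, target, num_sample)` less opaque, here is a minimal sketch of what a negative-sampling loss module can look like. It is not the repository's `NEG_loss`: the class name `SketchNEGLoss`, the attribute names, and the uniform noise distribution are all assumptions made for illustration (word2vec itself draws negatives from a smoothed unigram distribution).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchNEGLoss(nn.Module):
    """Illustrative negative-sampling loss (hypothetical, not the repo's NEG_loss)."""

    def __init__(self, num_classes, embedding_size):
        super().__init__()
        self.num_classes = num_classes
        # Separate tables for input words and output (target/noise) words.
        self.in_embed = nn.Embedding(num_classes, embedding_size)
        self.out_embed = nn.Embedding(num_classes, embedding_size)

    def forward(self, input_ids, target_ids, num_sampled):
        batch_size = input_ids.size(0)
        v_in = self.in_embed(input_ids)      # [batch, dim]
        v_out = self.out_embed(target_ids)   # [batch, dim]

        # Assumption: negatives are drawn uniformly over the vocabulary.
        noise_ids = torch.randint(
            0, self.num_classes, (batch_size, num_sampled),
            device=input_ids.device,
        )
        v_noise = self.out_embed(noise_ids)  # [batch, k, dim]

        # log sigmoid(v_out . v_in) for the observed pair.
        pos = F.logsigmoid((v_in * v_out).sum(dim=1))                  # [batch]
        # Sum over the k negatives of log sigmoid(-v_noise . v_in).
        neg_scores = torch.bmm(v_noise, v_in.unsqueeze(2)).squeeze(2)  # [batch, k]
        neg = F.logsigmoid(-neg_scores).sum(dim=1)                     # [batch]

        # Negate the objective so that minimizing the loss maximizes it.
        return -(pos + neg)


torch.manual_seed(0)
loss_fn = SketchNEGLoss(num_classes=50, embedding_size=8)
inputs = torch.randint(0, 50, (4,))
targets = torch.randint(0, 50, (4,))
loss = loss_fn(inputs, targets, num_sampled=5)  # one loss value per example
loss.mean().backward()
```

The sketch returns one value per example, which matches the `.mean()` reduction applied in the training loop above.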
        

License: MIT License


Languages

Language: Python 100.0%