sugarme / gotch

Go binding for Pytorch C++ API (libtorch)

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Float64Values() shows an error 'Unsupported Go type: []float64'

luxiant opened this issue · comments

// bertSentimentProcess classifies the sentiment of a single-row gota dataframe
// and returns that row together with its per-class probabilities.
//
// The first record of the "text" column is converted to an input tensor by
// processSentenceIntoInput and passed through useModels.bertModel with
// gradient tracking disabled. The model output is softmaxed over its last
// dimension as float64; elements 0, 1 and 2 of the result are reported as the
// "long", "neutral" and "short" probabilities respectively. The sentiment
// label is the class whose probability is strictly greater than both others;
// any tie falls through to "short".
//
// The function panics with a descriptive message when the model yields no
// output tensor or fewer than three probabilities. It cannot return an error
// without changing its signature, and both situations would otherwise crash
// later with a less helpful message.
func bertSentimentProcess(dataframe dataframe.DataFrame) sentimentRow {
	// Read the text once; it is needed for both the model input and the result.
	text := dataframe.Col("text").Records()[0]

	// The input tensor is only needed for the forward pass, so release it when
	// this function returns instead of leaving it allocated.
	input := processSentenceIntoInput(text)
	defer input.MustDrop()

	var torchResult *ts.Tensor
	ts.NoGrad(func() {
		// NOTE(review): the second and third return values are discarded as in
		// the original code. If they hold tensors (e.g. hidden states or
		// attentions), they should be dropped as well — confirm against the
		// model's ForwardT signature.
		torchResult, _, _ = useModels.bertModel.ForwardT(
			input,
			ts.None,
			ts.None,
			ts.None,
			ts.None,
			false,
		)
	})
	if torchResult == nil {
		panic("bertSentimentProcess: model returned no output tensor")
	}

	// The final `true` asks MustSoftmax to delete torchResult; the softmax
	// tensor itself is dropped explicitly once its values are copied out.
	probs := torchResult.MustSoftmax(-1, gotch.Double, true)
	categoryProb := probs.Float64Values()
	probs.MustDrop()

	if len(categoryProb) < 3 {
		panic("bertSentimentProcess: expected at least 3 class probabilities from the model")
	}
	long, neutral, short := categoryProb[0], categoryProb[1], categoryProb[2]

	var sentiment string
	switch {
	case long > neutral && long > short:
		sentiment = "long"
	case neutral > long && neutral > short:
		sentiment = "neutral"
	default:
		sentiment = "short"
	}

	return sentimentRow{
		post_num:  dataframe.Col("post_num").Records()[0],
		time:      dataframe.Col("time").Records()[0],
		text:      text,
		long:      long,
		neutral:   neutral,
		short:     short,
		sentiment: sentiment,
	}
}

// sentenceCleanupPattern matches every character that is not a Hangul
// syllable (가-힣), a Hangul jamo in the ranges ㄱ-ㅎ or ㅏ-ㅣ, an ASCII letter
// or digit, or one of the punctuation marks - % ? .
// It is compiled once at package initialization instead of on every call.
var sentenceCleanupPattern = regexp.MustCompile("[^가-힣ㄱ-ㅎㅏ-ㅣa-zA-Z0-9\\-\\%\\?\\.]")

// processSentenceIntoInput cleans a sentence, tokenizes it with
// useModels.tokenizer, fits the token ids to maxLength and returns them as a
// tensor with a leading dimension of size 1 on the configured device.
//
// Cleaning replaces the "- dc official App" signature, the jamo "ㅋ" and "ㅡ",
// newlines and every character matched by sentenceCleanupPattern with a
// space, then collapses runs of spaces and trims the ends.
//
// Encodings longer than maxLength are truncated (stride 2); shorter ones are
// padded using paddingParameter. Slots of the id buffer that the encoding does
// not fill stay zero.
//
// The function panics if encoding or truncation fails, since it has no error
// return and continuing would dereference an unusable encoding.
func processSentenceIntoInput(sentence string) *ts.Tensor {
	// "ㅋ" and "ㅡ" fall inside the jamo ranges the pattern keeps, so they have
	// to be removed explicitly before the pattern pass.
	sentence = strings.ReplaceAll(sentence, "- dc official App", " ")
	sentence = strings.ReplaceAll(sentence, "ㅋ", " ")
	sentence = strings.ReplaceAll(sentence, "\n", " ")
	sentence = strings.ReplaceAll(sentence, "ㅡ", " ")
	sentence = sentenceCleanupPattern.ReplaceAllString(sentence, " ")

	// After the pattern pass the only whitespace left is the ASCII space, so
	// Fields is equivalent to splitting on " " and discarding empty pieces.
	sentence = strings.Join(strings.Fields(sentence), " ")

	finalEncode, err := useModels.tokenizer.Encode(
		tokenizer.NewSingleEncodeInput(
			tokenizer.NewInputSequence(sentence),
		),
		true,
	)
	if err != nil {
		panic("processSentenceIntoInput: encoding sentence: " + err.Error())
	}

	switch {
	case finalEncode.Len() > maxLength:
		finalEncode, err = finalEncode.Truncate(maxLength, 2)
		if err != nil {
			panic("processSentenceIntoInput: truncating encoding: " + err.Error())
		}
	case finalEncode.Len() < maxLength:
		finalEncode = &tokenizer.PadEncodings(
			[]tokenizer.Encoding{*finalEncode},
			*paddingParameter,
		)[0]
	}

	// Copy the ids into a fixed-size buffer. The copy is bounded by the buffer
	// length so an encoding that is still longer than maxLength cannot index
	// out of range.
	tokInput := make([]int64, maxLength)
	for i := 0; i < len(finalEncode.Ids) && i < len(tokInput); i++ {
		tokInput[i] = int64(finalEncode.Ids[i])
	}

	// The flat id tensor is only an intermediate: stacking produces a new
	// tensor, so drop the flat one once the return expression is evaluated.
	ids := ts.TensorFrom(tokInput)
	defer ids.MustDrop()

	// MustTo's `true` deletes the stacked tensor after moving it to device.
	return ts.MustStack(
		[]ts.Tensor{*ids},
		0,
	).MustTo(device, true)
}

I'm working on my project and was hit by this error while debugging.

'''
root@codespaces-e73b16:/workspaces/KoBERT# go run main.go
2023/03/05 07:00:02 INFO: CachedDir="/root/.cache/transformer"
Successfully loaded model
0% | | (0/100, 0 it/hr) [0s:0s]2023/03/05 07:00:06
Unsupported Go type: []float64
exit status 1
'''

I searched my code for every variable of type []float64, and in this part of my code the following line inside the function 'bertSentimentProcess'

categoryProb := torchResult.MustSoftmax(-1, gotch.Double, true).Float64Values()

is the only one in which a variable takes the []float64 type. I'm trying to figure out why the bug happens but still can't work it out. []float64 is a supported Go type, so I thought this error should not happen. I first assumed that this was a language error, but even after reinstalling Go 1.19 I'm still getting the same error.

@luxiant,

can you try to print out tensors to check? Something like this:

// Print the raw model output before any post-processing is applied.
fmt.Printf("torchResult: %i\n", torchResult)
// Keep the softmax result in its own variable so it can be printed separately.
sm := torchResult.MustSoftmax(-1, gotch.Double, true)
fmt.Printf("softmax tensor: %i\n", sm)

// Convert to a Go slice as a separate step to see whether this call is the one that fails.
categoryProb := sm.Float64Values() 
fmt.Printf("categoryProb: %v\n", categoryProb) // If the line above fails, this log is never printed.

There are some memory leaks in your code, but we can discuss those later.

Please report the logs. Thanks.

Uhhhhhh.....I feel bad for bothering you with this. I finally found the cause after a day and a half of troubleshooting, and it was a silly mistake of mine.

The problem was not in the code I pasted, but in the code that loads my pretrained BERT model. The correct way to load the BERT config file, as you suggested in the example, was

// Working variant: load the config, then attach a label map before enabling the extra outputs.
// The second return value of ConfigFromFile (presumably an error — confirm) is discarded here.
bertConfig, _ := bert.ConfigFromFile("model/bert_config.json") // load bert config file
	// Build the class-index-to-label map for the three sentiment classes.
	var dummyLabelMap map[int64]string = make(map[int64]string)
	dummyLabelMap[0] = "long"
	dummyLabelMap[1] = "neutral"
	dummyLabelMap[2] = "short"
	// Attach the label map to the config.
	bertConfig.Id2Label = dummyLabelMap
	bertConfig.OutputAttentions = true
	bertConfig.OutputHiddenStates = true

But I mistakenly deleted part of it and tried to run this code instead:

// Broken variant: the label map is created but never filled, and it is never
// assigned to bertConfig.Id2Label.
bertConfig, _ := bert.ConfigFromFile("model/bert_config.json") // load bert config file
	// dummyLabelMap stays empty and unused in this variant.
	var dummyLabelMap map[int64]string = make(map[int64]string)
	bertConfig.OutputAttentions = true
	bertConfig.OutputHiddenStates = true

I corrected this and the problem was solved.