mattjj / pyhsmm

SVI plotting broken

jengelman opened this issue

I tried to use hmm.plot() in the SVI example and got the following:

Imgur

Although the plot shows different emissions, they are all one color, because only the latest state seems to be recorded at the end of training. After digging into it a bit, it looks like the HMMSVI class uses states_list as temporary storage for each minibatch rather than for storing the global HMMStatesEigen object. I tried changing _get_mb_states_list in the _HMMSVI definition in models.py to append the last state from states_list instead of popping it off, as follows:

def _get_mb_states_list(self,minibatch,**kwargs):
    minibatch = minibatch if isinstance(minibatch,list) else [minibatch]
    mb_states_list = []
    for mb in minibatch:
        self.add_data(mb,generate=False,**kwargs)
        mb_states_list.append(self.states_list[-1])
    return mb_states_list

This resulted in the following plot, where the left side looks as expected and the right side is a mess, likely because every overlapping state sequence is being plotted:

Imgur
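
To show what I think is going on without depending on pyhsmm, here is a minimal toy sketch. ToyModel and its methods are hypothetical stand-ins I made up for illustration, not the library's classes. It only contrasts popping a minibatch's states entry off states_list with keeping it, and counts how many entries are left after a few passes over the same minibatches.

```python
class ToyModel:
    """Hypothetical stand-in for a model that keeps a states_list."""

    def __init__(self):
        self.states_list = []

    def add_data(self, data):
        # Each call appends one new states entry for the given data.
        self.states_list.append({"data": list(data)})

    def minibatch_states_popped(self, minibatch):
        # Temporary states: the entry is removed right after it is created,
        # so nothing accumulates on the model.
        self.add_data(minibatch)
        return [self.states_list.pop()]

    def minibatch_states_kept(self, minibatch):
        # Modified behavior: the entry stays in states_list, so every
        # minibatch visit leaves one more entry behind.
        self.add_data(minibatch)
        return [self.states_list[-1]]


def entries_after_training(method_name, n_passes=3):
    """Run n_passes over two fixed minibatches and count leftover entries."""
    model = ToyModel()
    minibatches = [[0, 1, 2], [3, 4, 5]]
    for _ in range(n_passes):
        for mb in minibatches:
            getattr(model, method_name)(mb)
    return len(model.states_list)


popped = entries_after_training("minibatch_states_popped")
kept = entries_after_training("minibatch_states_kept")
print(popped, kept)  # 0 6
```

If the real classes behave like this toy, keeping the entries means the same data shows up once per visit, which would match the overlapping sequences on the right side of the second plot.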