isl-org / lang-seg

Language-Driven Semantic Segmentation

How to get pixel-level embeddings

loris2222 opened this issue

Hi, I am trying to use your model for research purposes on Explainable AI.
After struggling for longer than I'd like to admit, I finally managed to get it up and running. However, I can't find an easy way to get the pixel-level embeddings from your framework, since the interfaces are quite convoluted.

Right now, starting from your notebook, I've been able to do so with `evaluator._modules['module'].net.get_image_features(image)`, where `get_image_features` is a function I wrote as a modified version of `forward` that stops at the image features. I don't think this is the best way, though.
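A possibly less invasive alternative to writing a modified copy of `forward` is a PyTorch forward hook, which records a submodule's output during an ordinary forward pass. The sketch below shows the mechanism on a small hypothetical model (`ToyNet`, whose `head` layer stands in for LSeg's `scratch.head1`); `capture_output` is likewise a hypothetical helper. Applying it to LSeg would presumably mean passing the evaluator's net and `net.scratch.head1`, but this has not been verified against the repository.

```python
import torch
import torch.nn as nn


class ToyNet(nn.Module):
    """Hypothetical stand-in for the LSeg net: a feature head followed by a classifier."""

    def __init__(self, in_ch: int = 3, feat_ch: int = 8, n_classes: int = 4):
        super().__init__()
        self.head = nn.Conv2d(in_ch, feat_ch, kernel_size=1)  # plays the role of scratch.head1
        self.classifier = nn.Conv2d(feat_ch, n_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.head(x))


def capture_output(model: nn.Module, layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Run a normal forward pass and return what `layer` produced along the way."""
    captured = {}

    def hook(module, inputs, output):
        captured["out"] = output.detach()

    handle = layer.register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(x)
    finally:
        handle.remove()  # always detach the hook so later forward calls are unaffected
    return captured["out"]


model = ToyNet().eval()
x = torch.randn(2, 3, 16, 16)
feats = capture_output(model, model.head, x)                 # (B, C, H, W)
feats = feats.permute(0, 2, 3, 1)                            # (B, H, W, C), as in get_image_features
feats = torch.nn.functional.normalize(feats, p=2.0, dim=-1)  # unit-length per-pixel embeddings
print(feats.shape)  # torch.Size([2, 16, 16, 8])
```

Because the hook only observes the forward pass, the model code stays untouched; the permute and normalize steps are simply applied to the captured tensor afterwards.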

Do you have any suggestion on how to proceed? Maybe some general instructions on how to try to do so?

Thank you in advance!

Hi @loris2222 ,

Thanks for your interest in LSeg!

Yeah, LSeg was built one year ago, so you might need to install older versions of the tools. I guess the quickest solution is to revise the code.

Hope this helps!

Best,
Boyi

Hi @loris2222,
Have you found a way to get pixel-level embeddings with newer versions of the tools?

Hi @yhyang-myron, yes, I was able to get pixel-level embeddings, but the solution is quite hacky and I'm not ready to share the full code. I'm not sure what you mean by 'newer version tools', though.

By the way, to get it working I had to add a function to lseg_net.py that returns the per-pixel image embeddings, i.e. the output of the embedding head before it is compared with the text features. However, this only seems to work for batch size = 1, since it must be run through the evaluator. Anyway, here is the code for that function:

```python
# Method added to the model class in lseg_net.py; torch and forward_vit are
# already available there because the existing forward() uses them.
def get_image_features(self, x):
    # Optional memory-layout optimization; the result must be reassigned to take effect.
    if self.channels_last:
        x = x.contiguous(memory_format=torch.channels_last)

    # ViT backbone: four intermediate feature maps.
    layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)

    # Project each feature map to a common channel width.
    layer_1_rn = self.scratch.layer1_rn(layer_1)
    layer_2_rn = self.scratch.layer2_rn(layer_2)
    layer_3_rn = self.scratch.layer3_rn(layer_3)
    layer_4_rn = self.scratch.layer4_rn(layer_4)

    # Fuse the levels top-down with the refinement blocks.
    path_4 = self.scratch.refinenet4(layer_4_rn)
    path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
    path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
    path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

    # Per-pixel embeddings, shape (B, out_c, H, W).
    image_features = self.scratch.head1(path_1)

    # Move channels last -> (B, H, W, out_c) and L2-normalize each pixel's embedding.
    image_features = image_features.permute(0, 2, 3, 1).contiguous()
    image_features = torch.nn.functional.normalize(image_features, p=2.0, dim=-1)

    return image_features
```
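Once the per-pixel embeddings are L2-normalized, a label map can be read off them by comparing each pixel with a set of normalized text embeddings and keeping the best match; LSeg's own forward makes the same kind of comparison (scaled by a learned logit scale, which does not change the argmax) before its remaining output layers. The sketch below uses random tensors as stand-ins for both the image features and the CLIP text embeddings, the 512-dimensional size is purely illustrative, and `pixel_labels` is a hypothetical helper. Any upsampling back to the input resolution is omitted.

```python
import torch
import torch.nn.functional as F


def pixel_labels(image_features: torch.Tensor, text_features: torch.Tensor) -> torch.Tensor:
    """Assign each pixel the index of its most similar text embedding.

    image_features: (B, H, W, C) per-pixel embeddings, e.g. from get_image_features.
    text_features:  (K, C), one embedding per candidate label.
    Returns a (B, H, W) tensor of label indices.
    """
    image_features = F.normalize(image_features, p=2.0, dim=-1)  # no-op if already unit length
    text_features = F.normalize(text_features, p=2.0, dim=-1)
    # Cosine similarity between every pixel and every label: (B, H, W, K).
    similarity = image_features @ text_features.t()
    return similarity.argmax(dim=-1)


torch.manual_seed(0)
text = torch.randn(3, 512)  # stand-in for 3 text embeddings (illustrative size)
# A 1x2x2 "image" whose pixels are slightly noisy copies of labels 2, 0, 1, 2.
target = torch.tensor([[[2, 0], [1, 2]]])
feats = text[target] + 0.01 * torch.randn(1, 2, 2, 512)
print(pixel_labels(feats, text).tolist())  # [[[2, 0], [1, 2]]]
```

Working with unit-length vectors on both sides makes the dot product a cosine similarity, so the same embeddings can also be reused for other per-pixel analyses without rescaling.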

@loris2222
Thank you!