google / compare_gan

Compare GAN code.

Need trainable version for pretrained SSGAN in tfhub

RuiLiFeng opened this issue

@Marvin182 Hi, the TF Hub team has just uploaded your SSGAN module. It's wonderful, but it does not seem to have a trainable version.
I set `m = hub.Module(spec_name, name="gen_module", tags={"gen", "bsNone"}, trainable=True)`, but the module provides no gradients when the optimizer is applied.
Below is part of my code.
```python
import tensorflow as tf
import tensorflow_hub as hub


class Generator(object):

    def __init__(self, module_spec, trainable=True):
        self._module_spec = module_spec
        self._trainable = trainable
        # Load the SSGAN generator signature from TF Hub; trainable=True lets
        # hub.Module add the module's trainable variables to TRAINABLE_VARIABLES.
        self._module = hub.Module(self._module_spec, name="gen_module",
                                  tags={"gen", "bsNone"}, trainable=self._trainable)
        self.input_info = self._module.get_input_info_dict()

    def build_graph(self, input_dict):
        """
        Build tensorflow graph for Generator
        :param input_dict: {'z_': <hub.ParsedTensorInfo shape=(?, 120) dtype=float32 is_sparse=False>,
                            'labels': None or (?,)}
        :return: {'generated': <hub.ParsedTensorInfo shape=(?, 128, 128, 3) dtype=float32 is_sparse=False>}
        """
        inv_input = {}
        # G_mapping_ND is my own mapping network (definition not shown here).
        inv_input['z'] = G_mapping_ND(input_dict['z_'], 120, 120)
        # inv_input['labels'] = input_dict.get('labels', None)
        self.samples = self._module(inputs=inv_input, as_dict=True)['generated']
        return self.samples

    @property
    def trainable_variables(self):
        # Select every trainable variable whose name contains 'generator'.
        return [var for var in tf.trainable_variables() if 'generator' in var.name]
```
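One way to narrow down the missing-gradient problem is to check which variable names the `trainable_variables` filter actually selects. The sketch below is plain Python so it runs without TensorFlow; `select_scope_variables` and `report_selection` are hypothetical helpers, and the variable names in the usage example are made up for illustration. In a real graph the input list would be `[v.name for v in tf.trainable_variables()]`. Because `hub.Module(..., name="gen_module")` creates the module under the `gen_module` variable scope, filtering on that prefix is more targeted than a substring match on `'generator'`, which can also catch variables outside the module and selects nothing if the module's internal names happen not to contain that word.

```python
def select_scope_variables(var_names, scope):
    """Return the variable names that live under `scope/`.

    `var_names` is a plain list of strings, e.g. the `.name` of each entry in
    tf.trainable_variables(). Matching on the scope prefix avoids catching
    unrelated variables that merely contain the same substring.
    """
    prefix = scope.rstrip("/") + "/"
    return [name for name in var_names if name.startswith(prefix)]


def report_selection(var_names, scope):
    """Print how many variables a scope filter selects and flag an empty result."""
    selected = select_scope_variables(var_names, scope)
    if not selected:
        print(f"No trainable variables under '{scope}/': the optimizer "
              "has nothing from this scope to update.")
    else:
        print(f"{len(selected)} trainable variable(s) under '{scope}/'.")
    return selected


if __name__ == "__main__":
    # Hypothetical variable names, for illustration only.
    names = [
        "mapping/dense/kernel:0",
        "gen_module/generator/fc_noise/kernel:0",
        "gen_module/generator/fc_noise/bias:0",
    ]
    report_selection(names, "gen_module")  # 2 trainable variable(s) under 'gen_module/'.
    report_selection(names, "generator")   # No trainable variables under 'generator/': ...
```

If the `gen_module/` selection is empty even with `trainable=True`, the module's variables never reached the trainable collection, which points at the module rather than the optimizer code.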

I wonder whether my implementation is wrong or the module itself is not trainable.

The Hub module is for inference only.
If you want to continue training, I would advise you to run the code and load the variable values from the checkpoint in the Hub module. It is a bit tricky but doable.
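The checkpoint route can be sketched as follows, under several assumptions. A TF1-format Hub module directory is assumed to keep its variable values in a checkpoint with the prefix `<module_dir>/variables/variables`, and the names stored there are assumed to match the variables of a compare_gan graph built for training, up to an extra scope prefix. `build_assignment_map` is a hypothetical plain-Python helper: its checkpoint names could come from `tf.train.list_variables(ckpt_prefix)`, which returns `(name, shape)` pairs, and the dict it returns has the shape of mapping that `tf.train.init_from_checkpoint(ckpt_prefix, assignment_map)` takes.

```python
def build_assignment_map(ckpt_names, graph_var_names, graph_prefix=""):
    """Map checkpoint variable names onto variables of a freshly built graph.

    ckpt_names: names stored in the checkpoint (no ':0' suffix).
    graph_var_names: graph variable names, e.g. v.name for v in
        tf.global_variables() (these carry a ':0' suffix).
    graph_prefix: scope the training graph adds in front of the checkpoint
        names, e.g. "train/" (hypothetical).

    Returns (assignment_map, unmatched): assignment_map maps each checkpoint
    name to the matching graph variable name; unmatched lists graph variables
    with no counterpart in the checkpoint.
    """
    ckpt_set = set(ckpt_names)
    assignment_map = {}
    unmatched = []
    for graph_name in graph_var_names:
        # Drop the output index (":0") that variable names carry.
        base = graph_name.split(":")[0]
        # Strip the extra scope so the name lines up with the checkpoint.
        if graph_prefix and base.startswith(graph_prefix):
            key = base[len(graph_prefix):]
        else:
            key = base
        if key in ckpt_set:
            assignment_map[key] = base
        else:
            unmatched.append(base)
    return assignment_map, unmatched


if __name__ == "__main__":
    # Hypothetical names, for illustration only.
    ckpt = ["generator/fc_noise/kernel", "generator/fc_noise/bias"]
    graph = ["train/generator/fc_noise/kernel:0",
             "train/generator/fc_noise/bias:0",
             "train/generator/new_layer/kernel:0"]
    amap, missing = build_assignment_map(ckpt, graph, "train/")
    print(amap)
    print(missing)  # ['train/generator/new_layer/kernel']
```

Variables left in `unmatched` (new layers, optimizer slots) keep their usual initializers; a long `unmatched` list for the generator itself suggests the rebuilt architecture or scope names differ from the exported module.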

Thank you very much.
However, there may be minor mistakes in the TF Hub documentation for SSGAN.

In https://tfhub.dev/google/compare_gan/ssgan_128x128/1, the example usage suggests passing a labels tensor to the module, which in my case returns an error: `TypeError: Cannot convert dict_inputs: missing ['images'], extra given ['labels'].`

The Colab also suggests a conditional version of SSGAN, but `get_input_info_dict()` reports that the model accepts only `'z'` as input.
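Because the documented usage passes a `labels` input that this signature does not declare, a pre-flight check of the inputs dict against `get_input_info_dict()` can surface such mismatches before the module is called. The sketch below is a hypothetical plain-Python helper, not part of TF Hub: `accepted_keys` stands for the keys of the dict returned by `get_input_info_dict()`. It silently drops optional extras such as `labels` when the signature does not list them, and still fails on missing required keys or unknown extras.

```python
def prepare_module_inputs(inputs, accepted_keys, optional=("labels",)):
    """Fit an inputs dict to the keys a module signature declares.

    inputs: dict of candidate inputs, e.g. {"z": z, "labels": y}.
    accepted_keys: keys from module.get_input_info_dict().
    optional: keys that are dropped when the signature does not list them.
    Raises ValueError naming missing and unexpected keys otherwise.
    """
    accepted = set(accepted_keys)
    # Keep accepted keys; drop optional keys the signature lacks; keep other
    # unknown keys so they are reported as extras below.
    kept = {k: v for k, v in inputs.items()
            if k in accepted or k not in optional}
    missing = sorted(accepted - set(kept))
    extra = sorted(set(kept) - accepted)
    if missing or extra:
        raise ValueError(f"missing {missing}, extra given {extra}")
    return kept


if __name__ == "__main__":
    # Unconditional signature: 'labels' is dropped.
    print(prepare_module_inputs({"z": 1, "labels": 2}, ["z"]))  # {'z': 1}
    # Conditional signature: 'labels' is kept.
    print(prepare_module_inputs({"z": 1, "labels": 2}, ["z", "labels"]))
```

The returned dict is what would then be passed as `inputs=` to the module call, so a conditional and an unconditional module can share the same calling code.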