Need trainable version for pretrained SSGAN in tfhub
RuiLiFeng opened this issue
@Marvin182 Hi, the TF Hub team has just uploaded your SSGAN module. It's wonderful, but it does not seem to have a trainable version.
I set `m = hub.Module(spec_name, name="gen_module", tags={"gen", "bsNone"}, trainable=True)`, but the module provides no gradients when the optimizer is applied.
Below is part of my code.
```python
import tensorflow as tf  # TF 1.x graph-mode API
import tensorflow_hub as hub


class Generator(object):
    def __init__(self, module_spec, trainable=True):
        self._module_spec = module_spec
        self._trainable = trainable
        # Load the generator graph ("gen" tag, unspecified batch size) from the Hub module.
        self._module = hub.Module(self._module_spec, name="gen_module",
                                  tags={"gen", "bsNone"}, trainable=self._trainable)
        self.input_info = self._module.get_input_info_dict()

    def build_graph(self, input_dict):
        """
        Build tensorflow graph for Generator
        :param input_dict: {'z_': <hub.ParsedTensorInfo shape=(?, 120) dtype=float32 is_sparse=False>,
                            'labels': None or (?,)}
        :return: {'generated': <hub.ParsedTensorInfo shape=(?, 128, 128, 3) dtype=float32 is_sparse=False>}
        """
        inv_input = {}
        # G_mapping_ND is my own mapping network (defined elsewhere, not shown).
        inv_input['z'] = G_mapping_ND(input_dict['z_'], 120, 120)
        # inv_input['labels'] = input_dict.get('labels', None)
        self.samples = self._module(inputs=inv_input, as_dict=True)['generated']
        return self.samples

    @property
    def trainable_variables(self):
        # Note: variables created by hub.Module(name="gen_module") are scoped
        # under "gen_module/", so this 'generator' filter does not match them.
        return [var for var in tf.trainable_variables() if 'generator' in var.name]
```
I wonder whether my implementation is wrong or the module itself is not trainable.
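One way to narrow this down is to check which trainable variables actually sit under the module's scope and which of them receive a `None` gradient. The sketch below is a hypothetical helper, not part of TF Hub; it assumes TF1 variable naming (e.g. `gen_module/.../kernel:0`) and that you pass it the `(gradient, variable)` pairs returned by `optimizer.compute_gradients(loss, var_list)`, where TF1 reports `None` for variables the loss does not depend on.

```python
def scope_filter(var_names, scope):
    """Return the variable names that live under `scope/` (TF1 naming)."""
    prefix = scope.rstrip("/") + "/"
    return [name for name in var_names if name.startswith(prefix)]


def vars_without_gradients(grad_name_pairs):
    """Given (gradient_or_None, variable_name) pairs, list names whose gradient is None."""
    return [name for grad, name in grad_name_pairs if grad is None]


# Usage inside a TF1 graph (shown as comments so the helpers stay TF-free):
#   names = [v.name for v in tf.trainable_variables()]
#   print(scope_filter(names, "gen_module"))   # empty -> module exposes no trainable vars
#   pairs = [(g, v.name) for g, v in opt.compute_gradients(loss, var_list)]
#   print(vars_without_gradients(pairs))       # variables the loss does not reach
```

If `scope_filter` returns nothing for `"gen_module"`, the loaded module is not contributing trainable variables at all; if it returns names but they all show up in `vars_without_gradients`, the `var_list` or the loss wiring is the more likely culprit.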
The Hub module is for inference only.
If you want to continue training, I would advise you to run the training code and load the variable values from the checkpoint inside the Hub module. It's a bit tricky but doable.
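A minimal sketch of that approach, under several assumptions: the module has been downloaded to a local directory (for example via `hub.resolve()` in newer `tensorflow_hub` releases, or from the `TFHUB_CACHE_DIR` cache); it uses the TF1 Hub layout with its checkpoint at `<module_dir>/variables/variables`; and the training graph (e.g. built from the compare_gan code) has already created its variables. The `"generator/"` prefixes are hypothetical placeholders; list the real checkpoint names with `tf.train.list_variables()` and adjust them before relying on this.

```python
import os


def hub_checkpoint_prefix(module_dir):
    """Checkpoint prefix inside a TF1-format Hub module directory (assumed layout)."""
    return os.path.join(module_dir, "variables", "variables")


def build_assignment_map(checkpoint_names, src_prefix, dst_prefix):
    """Map checkpoint names under `src_prefix` to graph variable names under `dst_prefix`."""
    assignment_map = {}
    for name in checkpoint_names:
        if name.startswith(src_prefix):
            assignment_map[name] = dst_prefix + name[len(src_prefix):]
    return assignment_map


def init_generator_from_hub(module_dir, src_prefix="generator/", dst_prefix="generator/"):
    """Override the initializers of matching graph variables with Hub checkpoint values.

    Call after the training graph's variables exist and before running the
    global initializer (TF1 graph mode). Prefix defaults are hypothetical.
    """
    import tensorflow as tf  # imported lazily so the helpers above stay TF-free

    ckpt = hub_checkpoint_prefix(module_dir)
    names = [name for name, _shape in tf.train.list_variables(ckpt)]
    assignment_map = build_assignment_map(names, src_prefix, dst_prefix)
    tf.compat.v1.train.init_from_checkpoint(ckpt, assignment_map)
    return assignment_map
```

`init_from_checkpoint` only changes how the mapped variables are initialized, so the regular `tf.compat.v1.global_variables_initializer()` run afterwards loads the pretrained values; variables not in the map (e.g. the discriminator, if it is not in the module) keep their normal initialization.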
Thank you very much.
There may also be minor mistakes in the TF Hub documentation for SSGAN.
In https://tfhub.dev/google/compare_gan/ssgan_128x128/1, the example usage suggests passing a label tensor to the module, which in my case raises an error: `TypeError: Cannot convert dict_inputs: missing ['images'], extra given ['labels']`.
The Colab also suggests a conditional version of SSGAN, but `get_input_info_dict()` reports that the module only accepts `'z'` as input.
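Until the docs are fixed, one defensive pattern is to build the input dict from what the signature actually declares rather than from the example. The helper below is hypothetical (not a TF Hub API); it only assumes that `get_input_info_dict()` returns a dict keyed by input name, and it drops extras such as `'labels'` for an unconditional module while failing loudly if a required input is missing.

```python
def select_signature_inputs(candidate_inputs, input_info):
    """Keep only inputs declared by the signature; raise if a declared input is absent.

    Extra keys (e.g. 'labels' for an unconditional module) are silently dropped,
    and None values are treated as not provided.
    """
    accepted = set(input_info)
    kept = {k: v for k, v in candidate_inputs.items() if k in accepted and v is not None}
    missing = sorted(accepted - set(kept))
    if missing:
        raise ValueError("module expects inputs %s that were not provided" % missing)
    return kept


# Usage with a hub.Module (TF1):
#   inputs = select_signature_inputs({'z': z, 'labels': labels},
#                                    module.get_input_info_dict())
#   samples = module(inputs=inputs, as_dict=True)['generated']
```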