Project author: wangyida

Project description:
Generative Model with Coordinate Metric Learning
Language: Python
Project address: git://github.com/wangyida/gm-cml.git
Created: 2017-05-12T13:44:39Z
Project community: https://github.com/wangyida/gm-cml

License: BSD 3-Clause "New" or "Revised" License


Generative Model with Coordinate Metric Learning for Object Recognition Based on 3D Models

One of the bottlenecks in acquiring a perfect database for deep learning is the tedious process of collecting and labeling data.
In this paper, we propose a generative model trained with synthetic images rendered from 3D models, which reduces the burden of collecting real training data and makes the background conditions more realistic.
Our architecture is composed of two sub-networks: a semantic foreground object reconstruction network based on Bayesian inference, and a classification network based on multi-triplet cost training for avoiding over-fitting on the monotonous synthetic object surfaces and utilizing accurate information of synthetic images, such as object poses and lighting conditions, which is helpful for recognizing regular photos.
Firstly, our generative model with metric learning utilizes additional foreground object channels generated from the semantic foreground object reconstruction sub-network for recognizing the original input images.
A multi-triplet cost function based on poses is used for metric learning, which makes it possible to train an effective categorical classifier purely on synthetic data.
Secondly, we design a coordinate training strategy with adaptive noise applied to the inputs of both concatenated sub-networks, so that they benefit from each other and avoid inharmonious parameter tuning caused by the different convergence speeds of the two sub-networks.
Our architecture achieves state-of-the-art accuracy of 50.5% on the ShapeNet database, despite the data-migration obstacle from synthetic images to real photos.
This pipeline makes it possible to perform recognition on real images based purely on 3D models.

Copyright (c) 2017, Yida Wang
All rights reserved.

Please see our paper for more details about the method and data.

Author info

Yida Wang, Ph.D. candidate at Technische Universität München (TUM), Munich, Germany.
read more

Figures in the paper

Pipeline

This is the basic pipeline of the Generative Model with Coordinate Metric Learning (GM-CML).

Pipeline

Our GM-CML is a concatenated architecture, shown below as a reconstruction sub-network and a classification sub-network:

reconstruction sub-network
classification sub-network

Samples for the triplet training

Nearest Neighbor Classification results

Code Explanation

Training and Testing Strategies

Network input / placeholders for train (batch normalization) and dropout

    x_img = tf.placeholder(tf.float32, input_shape, 'x_img')
    x_obj = tf.placeholder(tf.float32, input_shape, 'x_obj')
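
The feed dictionaries later on this page also reference ae['train'], ae['keep_prob'], ae['corrupt_rec'] and ae['corrupt_cls']; a minimal sketch of how those remaining placeholders might be declared, continuing the snippet above (names taken from the feed dictionaries, exact shapes assumed):

    # assumed declarations continuing the snippet above
    phase_train = tf.placeholder(tf.bool, name='train')
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')
    corrupt_rec = tf.placeholder(tf.float32, name='corrupt_rec')
    corrupt_cls = tf.placeholder(tf.float32, name='corrupt_cls')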

Input of the reconstruction network

    current_input1 = utils.corrupt(x_img)*corrupt_rec + x_img*(1-corrupt_rec) \
        if (denoising and phase_train is not None) else x_img
    current_input1.set_shape(x_img.get_shape())
    # 2d -> 4d if convolution
    current_input1 = utils.to_tensor(current_input1) \
        if convolutional else current_input1
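
Here utils.corrupt injects noise into the input and corrupt_rec blends the corrupted and clean images; the exact noise model lives in the repository's utils module. A minimal stand-in, assuming multiplicative binary (salt-and-pepper style) noise:

    import tensorflow as tf

    def corrupt(x):
        """Assumed noise model: multiply each pixel by a random 0/1 mask."""
        mask = tf.cast(tf.random_uniform(tf.shape(x),
                                         minval=0, maxval=2,
                                         dtype=tf.int32), tf.float32)
        return tf.multiply(x, mask)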

Encoder

    for layer_i, n_output in enumerate(n_filters):
        with tf.variable_scope('encoder/{}'.format(layer_i)):
            shapes.append(current_input1.get_shape().as_list())
            if convolutional:
                ...  # convolutional layer body omitted here
    with tf.variable_scope('variational'):
        if variational:
            ...  # variational bottleneck body omitted here
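
Only the scaffolding of the encoder is shown above; a hedged sketch of a full encoder with a variational bottleneck, written with plain tf.layers calls instead of the repository's utils helpers (layer names and the reparameterization details are assumptions):

    import tensorflow as tf

    def encoder_sketch(x, n_filters=(64, 128, 256), filter_sizes=(3, 3, 3), n_code=512):
        """Hedged sketch: strided convolutions followed by a variational bottleneck."""
        current, shapes = x, []
        for layer_i, n_output in enumerate(n_filters):
            with tf.variable_scope('encoder/{}'.format(layer_i)):
                shapes.append(current.get_shape().as_list())
                current = tf.layers.conv2d(
                    current, filters=n_output,
                    kernel_size=filter_sizes[layer_i],
                    strides=2, padding='same', activation=tf.nn.relu)
        with tf.variable_scope('variational'):
            flat = tf.layers.flatten(current)
            z_mu = tf.layers.dense(flat, n_code, name='mu')
            z_log_sigma = tf.layers.dense(flat, n_code, name='log_sigma')
            # reparameterization trick: sample z from N(mu, sigma^2)
            epsilon = tf.random_normal(tf.shape(z_mu))
            z = z_mu + tf.exp(z_log_sigma) * epsilon
        return z, z_mu, z_log_sigma, shapes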

Decoder

    for layer_i, n_output in enumerate(n_filters[1:]):
        with tf.variable_scope('decoder/{}'.format(layer_i)):
            shape = shapes[layer_i + 1]
            if convolutional:
                ...  # transposed-convolution layer body omitted here
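
Again only the loop scaffolding is shown; a hedged sketch of a mirrored decoder that reuses the shapes recorded by the encoder (the projection layer and transposed convolutions are assumptions, not the repository code):

    import numpy as np
    import tensorflow as tf

    def decoder_sketch(z, shapes, kernel_size=3):
        """Hedged sketch: project the code back to the smallest recorded
        feature map, then upsample with transposed convolutions."""
        h = tf.layers.dense(z, int(np.prod(shapes[-1][1:])), name='decoder_project')
        h = tf.reshape(h, [-1] + shapes[-1][1:])
        # walk the recorded shapes backwards, doubling the spatial size each step
        for layer_i, shape in enumerate(reversed(shapes[:-1])):
            with tf.variable_scope('decoder/{}'.format(layer_i)):
                h = tf.layers.conv2d_transpose(
                    h, filters=shape[3], kernel_size=kernel_size,
                    strides=2, padding='same', activation=tf.nn.relu)
        return h  # ends with the same channel count as the input image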

Loss functions of the VAE and softmax

    # l2 loss
    loss_x = tf.reduce_mean(
        tf.reduce_sum(tf.squared_difference(x_obj_flat, y_flat), 1))
    loss_z = 0
    if variational:
        # Variational lower bound, kl-divergence
        loss_z = tf.reduce_mean(-0.5 * tf.reduce_sum(
            1.0 + 2.0 * z_log_sigma -
            tf.square(z_mu) - tf.exp(2.0 * z_log_sigma), 1))
    # Add l2 loss
    cost_vae = tf.reduce_mean(loss_x + loss_z)
    # AlexNet for classification based on softmax using TensorFlow slim
    if softmax:
        ...  # classification sub-network body omitted here
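
The branch under `if softmax:` is not shown above; a hedged sketch of how a classification loss might be added on top of cost_vae (the equal weighting and the label tensor are assumptions):

    import tensorflow as tf

    def classification_cost_sketch(logits, labels, cost_vae):
        """Hedged sketch: cross-entropy for the classification sub-network
        added to the VAE cost with an assumed weight of 1."""
        loss_cls = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=labels, logits=logits))
        return cost_vae + loss_cls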

There are several optional choices for the classification network; just modify the parameter classifier.

    # The following are optional networks for the classification sub-network
    if classifier == 'squeezenet':
        predictions, net = squeezenet.squeezenet(
            y_concat, num_classes=13)
    elif classifier == 'zigzagnet':
        predictions, net = squeezenet.zigzagnet(
            y_concat, num_classes=13)
    elif classifier == 'alexnet_v2':
        predictions, end_points = alexnet.alexnet_v2(
            y_concat, num_classes=13)
    elif classifier == 'inception_v1':
        predictions, end_points = inception.inception_v1(
            y_concat, num_classes=13)
    elif classifier == 'inception_v2':
        predictions, end_points = inception.inception_v2(
            y_concat, num_classes=13)
    elif classifier == 'inception_v3':
        predictions, end_points = inception.inception_v3(
            y_concat, num_classes=13)
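
Adding another backbone only needs one more branch in this chain; for example, a hypothetical branch for the TF-slim VGG-16 implementation (not part of the repository, assuming `from nets import vgg` is available) could look like:

    # hypothetical extension of the chain above, not in the repository
    elif classifier == 'vgg_16':
        predictions, end_points = vgg.vgg_16(
            y_concat, num_classes=13)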

Here we must set corrupt_rec and corrupt_cls to 0 in order to obtain a proper variance ratio to feed into the variable var_prob. We use tanh as the non-linear function for the ratio of variances between the reconstructed channels and the original channels.

    var_prob = sess.run(ae['var_prob'],
                        feed_dict={
                            ae['x_img']: test_xs_img,
                            ae['x_label']: test_xs_label,
                            ae['train']: True,
                            ae['keep_prob']: 1.0,
                            ae['corrupt_rec']: 0,
                            ae['corrupt_cls']: 0})
    # Here is a fast training process
    corrupt_rec = np.tanh(0.25 * var_prob)
    corrupt_cls = np.tanh(1 - np.tanh(2 * var_prob))
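
As a quick sanity check of this schedule, the two corruption rates move in opposite directions as var_prob grows, so the reconstruction input gets noisier while the classification input gets cleaner:

    import numpy as np

    # illustrative arithmetic only: sweep var_prob and watch the two rates
    for var_prob in (0.0, 0.5, 1.0, 2.0):
        corrupt_rec = np.tanh(0.25 * var_prob)            # rises from 0 towards 1
        corrupt_cls = np.tanh(1 - np.tanh(2 * var_prob))  # falls from tanh(1) towards 0
        print(var_prob, corrupt_rec, corrupt_cls)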

Main API for training and testing. General purpose training of a (Variational) (Convolutional) Autoencoder.

    def train_vae(files_img, files_obj, input_shape):
        """
        Parameters
        ----------
        files : list of strings
            List of paths to images.
        input_shape : list
            Must define what the input image's shape is.
        use_csv : bool, optional
            Use csv files to train conditional VAE or just VAE.
        learning_rate : float, optional
            Learning rate.
        batch_size : int, optional
            Batch size.
        n_epochs : int, optional
            Number of epochs.
        n_examples : int, optional
            Number of examples to use while demonstrating the current training
            iteration's reconstruction. Creates a square montage, so make
            sure int(sqrt(n_examples))**2 = n_examples, e.g. 16, 25, 36, ..., 100.
        crop_shape : list, optional
            Size to centrally crop the image to.
        crop_factor : float, optional
            Resize factor to apply before cropping.
        n_filters : list, optional
            Same as VAE's n_filters.
        n_hidden : int, optional
            Same as VAE's n_hidden.
        n_code : int, optional
            Same as VAE's n_code.
        convolutional : bool, optional
            Use convolution or not.
        variational : bool, optional
            Use variational layer or not.
        softmax : bool, optional
            Use the classification network or not.
        classifier : str, optional
            Network for classification.
        filter_sizes : list, optional
            Same as VAE's filter_sizes.
        dropout : bool, optional
            Use dropout or not.
        keep_prob : float, optional
            Percent of keep for dropout.
        activation : function, optional
            Which activation function to use.
        img_step : int, optional
            How often to save training images showing the manifold and
            reconstruction.
        save_step : int, optional
            How often to save checkpoints.
        ckpt_name : str, optional
            Checkpoints will be named as this, e.g. 'model.ckpt'
        """

Visualization for Outputs and Parameters

We can visualize filters, reconstruction channels, and outputs according to the latent variables.

Plot example reconstructions

    recon = sess.run(
        ae['y'], feed_dict={
            ae['x_img']: test_xs_img,
            ae['train']: False,
            ae['keep_prob']: 1.0,
            ae['corrupt_rec']: 0,
            ae['corrupt_cls']: 0})
    utils.montage(recon.reshape([-1] + crop_shape),
                  'recon_%08d.png' % t_i)

Plot filters

    filters = sess.run(
        ae['Ws'], feed_dict={
            ae['x_img']: test_xs_img,
            ae['train']: False,
            ae['keep_prob']: 1.0,
            ae['corrupt_rec']: 0,
            ae['corrupt_cls']: 0})
    # for filter_element in filters:
    utils.montage_filters(filters[-1],
                          'filter_%08d.png' % t_i)
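
Both utils.montage and utils.montage_filters tile a batch of images into one square grid image and save it; a minimal stand-in with the assumed behavior (not the repository implementation) could be:

    import numpy as np
    import matplotlib.pyplot as plt

    def montage_sketch(images, save_path):
        """Assumed behavior: tile a batch [N, H, W, C] into a square grid and save it."""
        n = int(np.ceil(np.sqrt(images.shape[0])))
        h, w, c = images.shape[1:]
        grid = np.zeros((n * h, n * w, c), dtype=images.dtype)
        for idx in range(images.shape[0]):
            row, col = divmod(idx, n)
            grid[row * h:(row + 1) * h, col * w:(col + 1) * w] = images[idx]
        plt.imsave(save_path, np.squeeze(grid))
        return grid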