Author: 4uiiurz1

Description:
PyTorch implementation of Scale-Aware Triplet Networks
Language: Python
Repository: git://github.com/4uiiurz1/pytorch-scale-aware-triplet.git
Created: 2019-01-10T11:19:23Z
Project page: https://github.com/4uiiurz1/pytorch-scale-aware-triplet

License: MIT License


PyTorch implementation of Scale-Aware Triplet Networks

This repository contains a PyTorch implementation of Scale-Aware Triplet Networks, based on the paper "Learning Deep Descriptors with Scale-Aware Triplet Networks".

Requirements

  • Python 3.6
  • PyTorch 1.0

Usage

Args:

  • theta_glo (float, default: 1.15): Global context in all triplets.
  • delta (int, default: 5): Scale correction parameter.
  • gamma (float, default: 0.5): Mixing ratio between the siamese and triplet losses; 0 selects the siamese loss and 1 the triplet loss.
  • scale_aware (bool, default: True): Whether to use scale-aware sampling.

Input:

  • y_a: Anchor samples.
  • y_p: Positive samples. Each positive sample has the same class label as its corresponding anchor sample.
  • targets: Class labels of y_a and y_p.

    # Mixed context loss with the default hyperparameters
    criterion = MixedContextLoss(theta_glo=1.15, delta=5, gamma=0.5, scale_aware=True)

    # Embed anchors and positives with the same model
    y_a = model(input1)
    y_p = model(input2)
    loss = criterion(y_a, y_p, targets)

    # Standard optimization step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
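
The input contract above (each positive shares its anchor's class label) can be illustrated with a small pairing helper. This is a minimal sketch in plain Python, not code from the repository: `make_anchor_positive_pairs` is a hypothetical name, and it only returns indices, which could then be used to select `input1` and `input2` from a labelled batch.

```python
import random
from collections import defaultdict


def make_anchor_positive_pairs(labels, seed=0):
    """Return (anchor_idx, positive_idx, target) triples.

    Hypothetical helper for illustration: every anchor is paired with a
    different sample that carries the same class label. Samples whose class
    occurs only once are skipped because no positive exists for them.
    """
    rng = random.Random(seed)

    # Group sample indices by class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    pairs = []
    for anchor_idx, label in enumerate(labels):
        candidates = [i for i in by_class[label] if i != anchor_idx]
        if not candidates:
            continue  # singleton class: no valid positive
        positive_idx = rng.choice(candidates)
        pairs.append((anchor_idx, positive_idx, label))
    return pairs


labels = [3, 7, 3, 1, 7, 3]
pairs = make_anchor_positive_pairs(labels)
for anchor_idx, positive_idx, target in pairs:
    print(anchor_idx, positive_idx, target)
```

The sample with label 1 has no partner in this toy batch, so it yields no pair. How the repository itself builds its pairs may differ.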

Training

MNIST

Use scale-aware siamese loss:

    python train.py --gamma 0 --scale-aware True

Use scale-aware triplet loss:

    python train.py --gamma 1 --scale-aware True

Use scale-aware mixed context loss (gamma=0.5):

    python train.py --gamma 0.5 --scale-aware True
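
To compare the three losses in one go, the commands above can be generated from their gamma values. This is a minimal sketch, assuming `train.py` is in the current directory and accepts the `--gamma` and `--scale-aware` flags exactly as shown above. `VARIANTS` and `build_command` are hypothetical names, and the `subprocess` call is left commented out so the snippet only prints the commands.

```python
import subprocess  # only needed if the run line below is uncommented
import sys

# gamma value per loss variant, taken from the commands above.
VARIANTS = {"siamese": 0, "triplet": 1, "mixed": 0.5}


def build_command(gamma, scale_aware=True):
    """Build the argument list for one training run (hypothetical helper)."""
    return [
        sys.executable, "train.py",
        "--gamma", str(gamma),
        "--scale-aware", str(scale_aware),
    ]


commands = {name: build_command(gamma) for name, gamma in VARIANTS.items()}
for name, cmd in commands.items():
    print(name, " ".join(cmd[1:]))
    # subprocess.run(cmd, check=True)  # uncomment to actually train
```

Using `sys.executable` keeps every run on the same interpreter as the launching script.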

Results

MNIST

[Figures: mnist_siamese, mnist_mixed, mnist_triplet]