Project author: lucidrains

Project description: Implementation of Uformer, Attention-based Unet, in Pytorch
Language: Python
Clone URL: git://github.com/lucidrains/uformer-pytorch.git
Created: 2021-06-17T00:56:03Z
Project homepage: https://github.com/lucidrains/uformer-pytorch

License: MIT License



Uformer - Pytorch

Implementation of Uformer, Attention-based Unet, in Pytorch. It will only offer the concat-cross-skip connection.

This repository will be geared towards use in a project for learning protein structures. Specifically, it will include the ability to condition on time steps (needed for DDPM), as well as 2d relative positional encoding using rotary embeddings (instead of the bias on the attention matrix in the paper).
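The repository's own 2d rotary embedding applies the rotation along both spatial axes; as a rough illustration of the underlying mechanism, here is a minimal 1d rotary embedding sketch in plain PyTorch. The function name `apply_rotary` and all details are illustrative, not this repo's API:

```python
import torch

def apply_rotary(x, base = 10000):
    # x: (..., seq, dim) with dim even; rotates each consecutive pair of
    # channels by a position-dependent angle, encoding relative positions
    seq, dim = x.shape[-2], x.shape[-1]
    inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
    pos = torch.arange(seq).float()
    freqs = torch.einsum('i, j -> i j', pos, inv_freq)  # (seq, dim / 2)
    cos, sin = freqs.cos(), freqs.sin()

    x1, x2 = x[..., 0::2], x[..., 1::2]
    # 2d rotation of each (x1, x2) pair by its angle
    rotated = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim = -1)
    return rotated.flatten(-2)  # re-interleave back to (..., seq, dim)
```

Because each channel pair is rotated rather than scaled, the embedding preserves token norms, and position 0 (angle 0) is left unchanged.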

Install

```bash
$ pip install uformer-pytorch
```

Usage

```python
import torch
from uformer_pytorch import Uformer

model = Uformer(
    dim = 64,           # initial dimensions after input projection, which increases by 2x each stage
    stages = 4,         # number of stages
    num_blocks = 2,     # number of transformer blocks per stage
    window_size = 16,   # set window size (along one side) for which to do the attention within
    dim_head = 64,
    heads = 8,
    ff_mult = 4
)

x = torch.randn(1, 3, 256, 256)
pred = model(x) # (1, 3, 256, 256)
```

To condition on time for DDPM training:

```python
import torch
from uformer_pytorch import Uformer

model = Uformer(
    dim = 64,
    stages = 4,
    num_blocks = 2,
    window_size = 16,
    dim_head = 64,
    heads = 8,
    ff_mult = 4,
    time_emb = True    # set this to True
)

x = torch.randn(1, 3, 256, 256)
time = torch.arange(1)
pred = model(x, time = time) # (1, 3, 256, 256)
```
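The `time` argument takes one integer timestep per batch element. As a sketch of how this could slot into a DDPM training step, here is the standard forward-diffusion (noising) computation in plain PyTorch; the linear beta schedule and its endpoints are the common DDPM defaults, not values taken from this repo, and the final two lines are commented out because they assume the time-conditioned `model` constructed above:

```python
import torch

# standard DDPM linear beta schedule (common defaults, assumed here)
num_timesteps = 1000
betas = torch.linspace(1e-4, 0.02, num_timesteps)
alphas_cumprod = torch.cumprod(1. - betas, dim = 0)

batch = 4
x0 = torch.randn(batch, 3, 256, 256)            # stand-in for clean training images
t = torch.randint(0, num_timesteps, (batch,))   # one random timestep per sample
noise = torch.randn_like(x0)

# forward diffusion: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise
a_bar = alphas_cumprod[t].view(batch, 1, 1, 1)
x_t = a_bar.sqrt() * x0 + (1. - a_bar).sqrt() * noise

# pred = model(x_t, time = t)                   # time-conditioned Uformer from above
# loss = torch.nn.functional.mse_loss(pred, noise)
```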

Citations

```bibtex
@misc{wang2021uformer,
    title         = {Uformer: A General U-Shaped Transformer for Image Restoration},
    author        = {Zhendong Wang and Xiaodong Cun and Jianmin Bao and Jianzhuang Liu},
    year          = {2021},
    eprint        = {2106.03106},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV}
}
```