首页 > AI资讯 > 最新资讯 > GitHub7.5kstar量,各种视觉Transformer的PyTorch实现合集整理好了

GitHub7.5kstar量,各种视觉Transformer的PyTorch实现合集整理好了

新火种    2023-11-01

机器之心报道

编辑:杜伟

这个项目登上了今天的GitHub Trending。

近一两年,Transformer 跨界 CV 任务不再是什么新鲜事了。

自 2020 年 10 月谷歌提出 Vision Transformer (ViT) 以来,各式各样视觉 Transformer 开始在图像合成、点云处理、视觉 - 语言建模等领域大显身手。

之后,在 PyTorch 中实现 Vision Transformer 成为了研究热点。GitHub 中也出现了很多优秀的项目,今天要介绍的就是其中之一。

该项目名为「vit-pytorch」, 它是一个 Vision Transformer 实现,展示了一种在 PyTorch 中仅使用单个 transformer 编码器来实现视觉分类 SOTA 结果的简单方法。

项目当前的 star 量已经达到了 7.5k,创建者为 Phil Wang,ta 在 GitHub 上有 147 个资源库。

项目地址:https://github.com/lucidrains/vit-pytorch

项目作者还提供了一段动图展示:

项目介绍

首先来看 Vision Transformer-PyTorch 的安装、使用、参数、蒸馏等步骤。

第一步是安装:

$ pip install vit-pytorch

第二步是使用:

import torch

from vit_pytorch import ViT

v = ViT(

image_size = 256,

patch_size = 32,

num_classes = 1000,

dim = 1024,

depth = 6,

heads = 16,

mlp_dim = 2048,

dropout = 0.1,

emb_dropout = 0.1

)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

第三步是所需参数,包括如下:

image_size:图像大小

patch_size:patch 数量

num_classes:分类类别的数量

dim:线性变换 nn.Linear(..., dim) 后输出张量的最后维

depth:Transformer 块的数量

heads:多头注意力层中头的数量

mlp_dim:MLP(前馈)层的维数

channels:图像通道的数量

dropout:Dropout rate

emb_dropout:嵌入 dropout rate

……

最后是蒸馏,采用的流程出自 Facebook AI 和索邦大学的论文《Training data-efficient image transformers & distillation through attention》。

论文地址:https://arxiv.org/pdf/2012.12877.pdf

从 ResNet50(或任何教师网络)蒸馏到 vision transformer 的代码如下:

import torchfrom torchvision.models import resnet50from vit_pytorch.distill import DistillableViT, DistillWrapperteacher = resnet50(pretrained = True)

v = DistillableViT(

image_size = 256,

patch_size = 32,

num_classes = 1000,

dim = 1024,

depth = 6,

heads = 8,

mlp_dim = 2048,

dropout = 0.1,

emb_dropout = 0.1

)

distiller = DistillWrapper(

student = v,

teacher = teacher,

temperature = 3, # temperature of distillationalpha = 0.5, # trade between main loss and distillation losshard = False # whether to use soft or hard distillation

)

img = torch.randn(2, 3, 256, 256)labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)loss.backward()

# after lots of training above ...pred = v(img) # (2, 1000)

除了 Vision Transformer 之外,该项目还提供了 Deep ViT、CaiT、Token-to-Token ViT、PiT 等其他 ViT 变体模型的 PyTorch 实现。

对 ViT 模型 PyTorch 实现感兴趣的读者可以参阅原项目。

Tags:
相关推荐
免责声明
本文所包含的观点仅代表作者个人看法,不代表新火种的观点。在新火种上获取的所有信息均不应被视为投资建议。新火种对本文可能提及或链接的任何项目不表示认可。 交易和投资涉及高风险,读者在采取与本文内容相关的任何行动之前,请务必进行充分的尽职调查。最终的决策应该基于您自己的独立判断。新火种不对因依赖本文观点而产生的任何金钱损失负任何责任。