0
点赞
收藏
分享

微信扫一扫

PyTorch 二维卷积（Conv2d）

小禹说财 2022-02-01 阅读 40

在这里插入图片描述

import torch
import torchvision
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Test split (train=False) of CIFAR-10, stored under ./dataset/ and downloaded
# on first run; each image is converted to a tensor by ToTensor().
data = torchvision.datasets.CIFAR10(
    './dataset/',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)
# Batches of 64 images in dataset order (no shuffle is requested).
data_lodaer = DataLoader(data, batch_size=64)


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv2d = nn.Conv2d(in_channels=3, out_channels=6,
                                kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.conv2d(x)
        return x


# Pass every batch through the (untrained) conv layer and log both the input
# images and the conv outputs to TensorBoard under ./cifa10/.
model = Model()
# The context manager closes the writer even if the loop raises; no_grad()
# skips building an autograd graph, since this is inference only.
with SummaryWriter('./cifa10/') as writer, torch.no_grad():
    # Unpack the batch directly instead of rebinding the module-level
    # ``data`` name, which holds the dataset.
    for step, (imgs, _targets) in enumerate(data_lodaer):
        # Full batch: imgs (64, 3, 32, 32) -> out (64, 6, 32, 32); the final
        # batch may be smaller because drop_last is not set on the loader.
        out = model(imgs)
        writer.add_images('imgs', imgs, step)
        # add_images cannot show 6-channel images, so each 6-channel output
        # is split into two 3-channel images (the batch dimension doubles).
        # H and W come from the output rather than a hard-coded 32x32.
        writer.add_images('out', out.reshape(-1, 3, *out.shape[-2:]), step)

举报

相关推荐

0 条评论