Pytorch错误Expected input batch_size (324) to match target batch_size (4) Log In

参考链接:

https://blog.csdn.net/qq_41429220/article/details/104973805

Pytorch Error: ValueError: Expected input batch_size (324) to match target batch_size (4) Log In

1.ERROR原因

使用pytorch训练一个自定义的模型,参照网上的博客直接照搬网络,但是在修改自定义数据集时,出现这个错误。很明显是一个图像参数不匹配问题,自定义数据集的图片大小规格不统一且与网络接受的大小不匹配。

1
ValueError: Expected input batch_size (324) to match target batch_size (4) Log In

2.解决思路

首先,在错误的网络结构处前后加入print来查看网络结构。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 构建CNN模型
class CNNNet(nn.Module):

def __init__(self):
super(CNNNet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(64, 128, 5)
self.fc1 = nn.Linear(128*53*53, 1024)
self.fc2 = nn.Linear(1024, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
# print(x.shape)
x = x.view(-1, 128*53*53)
# print(x.shape)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

即我注释的这个地方,可以得到输入前的数据格式。

1
torch.Size([4, 128, 53, 53])

根据输出的形状来更改view里的参数。

1
x = x.view(-1, 128 * 53 * 53)

后面的Linear层也需要对应修改,使其与数据输入匹配:

1
self.fc1 = nn.Linear(128 * 53 * 53, 1024)

Pytorch错误Expected input batch_size (324) to match target batch_size (4) Log In
https://fulequn.github.io/2021/01/Article202101101/
作者
Fulequn
发布于
2021年1月10日
许可协议