TensorLayer出现数据形状转换错误

参考链接:

TensorLayer/tutorial_cifar10_cnn_static.py at master · tensorlayer/TensorLayer (github.com)

1.错误部分

使用TensorLayer来进行CIFAR-10 数据集上的图像分类,直接运行源文件的时候出现了如下错误:

1
2
InvalidArgumentError: Input to reshape is a tensor with 128 values, but the requested shape has 1
[[{{node Reshape}}]]

出现错误部分的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def _map_fn_train(img, target):
# 1. Randomly crop a [height, width] section of the image.
img = tf.image.random_crop(img, [24, 24, 3])
# 2. Randomly flip the image horizontally.
img = tf.image.random_flip_left_right(img)
# 3. Randomly change brightness.
img = tf.image.random_brightness(img, max_delta=63)
# 4. Randomly change contrast.
img = tf.image.random_contrast(img, lower=0.2, upper=1.8)
# 5. Subtract off the mean and divide by the variance of the pixels.
img = tf.image.per_image_standardization(img)
target = tf.reshape(target, ())
return img, target

# dataset API and augmentation
train_ds = tf.data.Dataset.from_generator(
generator_train, output_types=(tf.float32, tf.int32)
) # , output_shapes=((24, 24, 3), (1)))
# train_ds = train_ds.repeat(n_epoch)
train_ds = train_ds.shuffle(shuffle_buffer_size)
train_ds = train_ds.prefetch(buffer_size=4096)
train_ds = train_ds.batch(batch_size)
train_ds = train_ds.map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count())

2.具体分析

问题的具体描述是出现了数据形状转换错误,而通过排查可以确定问题就出现在map进行的数据转换上,其具体原因为train_ds在进行map转换时,首先进行了batch操作,将数据集转化为了小批量数据的格式,而map函数进行操作时的操作对象是单一的数据,因此数据格式出现了冲突,导致了该问题的发生。

3.解决办法

我们需要在进行batch前先进行map操作,完成转换后再进行小批量处理。


TensorLayer出现数据形状转换错误
https://fulequn.github.io/2022/09/Article202209011/
作者
Fulequn
发布于
2022年9月1日
许可协议