参考链接:
TensorLayer/tutorial_cifar10_cnn_static.py at master · tensorlayer/TensorLayer (github.com)
1.错误部分
使用TensorLayer来进行CIFAR-10 数据集上的图像分类,直接运行源文件的时候出现了如下错误:
1 2
| InvalidArgumentError: Input to reshape is a tensor with 128 values, but the requested shape has 1 [[{{node Reshape}}]]
|
出现错误部分的代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
| def _map_fn_train(img, target): img = tf.image.random_crop(img, [24, 24, 3]) img = tf.image.random_flip_left_right(img) img = tf.image.random_brightness(img, max_delta=63) img = tf.image.random_contrast(img, lower=0.2, upper=1.8) img = tf.image.per_image_standardization(img) target = tf.reshape(target, ()) return img, target
train_ds = tf.data.Dataset.from_generator( generator_train, output_types=(tf.float32, tf.int32) )
train_ds = train_ds.shuffle(shuffle_buffer_size) train_ds = train_ds.prefetch(buffer_size=4096) train_ds = train_ds.batch(batch_size) train_ds = train_ds.map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count())
|
2.具体分析
问题的具体描述是出现了数据形状转换错误,而通过排查可以确定问题就出现在map进行的数据转换上,其具体原因为train_ds在进行map转换时,首先进行了batch操作,将数据集转化为了小批量数据的格式,而map函数进行操作时的操作对象是单一的数据,因此数据格式出现了冲突,导致了该问题的发生。
3.解决办法
我们需要在进行batch前先进行map操作,完成转换后再进行小批量处理。