入门必看|详解TensorFlow里的随机函数

选择“星标”公众号

重磅干货,第一时间送达!

相信小伙伴们平时在训练神经网络时经常是需要对网络的权重进行初始化,大部分情况下通常的做法都是需要随机初始化,TensorFlow提供了不同类型的随机函数来满足我们的实际需求,如果不加以详细分析,经常容易混淆其中的作用原理和实际产生的输出,今天我们就为小伙伴们详细解释下TensorFlow里的随机函数。

tf.random_normal 正态分布随机数

函数定义:tf.random_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)函数说明:从正态分布中输出随机值。返回一个指定形状,被随机正态分布值填充的tensor。参数:

  1. shape: 输出的形状。一个1维整形tensor类型或者是python的array类型,也是输出的张量。

  2. mean: 正态分布的均值,一个0维的tensor或者python值,类型为dtype。

  3. stddev: 正态分布的标准差,一个0维的tensor或者python值,类型为dtype。

  4. dtype: 输出的元素类型。

  5. seed: 随机种子,整数,设置后每次生成的随机数都一样。

  6. name: 操作的名字。

实操代码:norm = tf.random_normal([2, 4], seed=1234)sess = tf.Session()print(sess.run(norm))print(sess.run(norm))实际输出:

tf.truncated_normal 截断正态分布随机数

函数定义:tf.truncated_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)函数说明:从截断的正态分布中输出随机值,生成的值服从具有指定平均值和标准偏差的正态分布,如果生成的值大于平均值2个标准偏差的值则丢弃重新选择。参数:

  1. shape: 输出的形状。一个1维整形tensor类型或者是python的array类型,也是输出的张量。

  2. mean: 正态分布的均值,一个0维的tensor或者python值,类型为dtype。

  3. stddev: 正态分布的标准差,一个0维的tensor或者python值,类型为dtype。

  4. dtype: 输出的元素类型。

  5. seed: 随机种子,整数,设置后每次生成的随机数都一样。

  6. name: 操作的名字。

实操代码:norm = tf.truncated_normal([2, 4], seed=1234)sess = tf.Session()print(sess.run(norm))print(sess.run(norm))实际输出:

tf.random_uniform 均匀分布随机数

函数定义:tf.random_uniform(shape, minval=0, maxval=None, dtype=tf.float32, seed=None, name=None)函数说明:均匀分布随机数,范围为[minval,maxval)参数:

  1. shape: 输出的形状。一个1维整形tensor类型或者是python的array类型,也是输出的张量。

  2. minval:小的边界(下界),默认是0,一个0维的tensor或者python值,类型为dtype。

  3. maxval:大的那个边界(上界),一个0维的tensor或者python值,类型为dtype。

  4. dtype: 输出的元素类型。

  5. seed: 随机种子,整数,设置后每次生成的随机数都一样。

  6. name: 操作的名字。

实操代码:norm = tf.random_uniform([2,4],0,2,dtype=tf.float32)sess = tf.Session()print(sess.run(norm))print(sess.run(norm))实际输出:

tf.random_shuffle

函数定义:tf.random_shuffle(value, seed=None, name=None)函数说明:对value的第一维进行随机洗牌操作。参数:

  1. value: 要被打乱的tensor

  2. seed: 随机种子,整数,设置后每次生成的随机数都一样。

  3. name: 操作的名字。

实操代码:norm = tf.random_shuffle([[1,2,3],[4,5,6],[7,8,9]], seed=134)sess = tf.Session()print(sess.run(norm))print(sess.run(norm))实际输出:

(0)

相关推荐