自定义块需要满足的功能:
将输入数据作为其前向传播函数的参数。
通过前向传播函数来生成输出。请注意,输出的形状可能与输入的形状不同。例如,我们上面模型中的第一个全连接的层接收任意维的输入,但是返回一个维度256的输出。
计算其输出关于输入的梯度,可通过其反向传播函数进行访问。通常这是自动发生的。
存储和访问前向传播计算所需的参数。
根据需要初始化模型参数。
1 2 3 4 5 6 7 8 9 10 11 12 class MLP (tf.keras.Model): def __init__ (self ): super ().__init__() self.hidden = tf.keras.layers.Dense(units=256 , activation=tf.nn.relu) self.out = tf.keras.layers.Dense(units=10 ) def call (self, X ): return self.out(self.hidden((X)))
反向传播系统自动生成。
顺序块
实现以下函数:
一种将块逐个追加到列表中的函数;
一种前向传播函数,用于将输入按追加块的顺序传递给块组成的“链条”。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 class MySequential (tf.keras.Model): def __init__ (self, *args ): super ().__init__() self.modules = [] for block in args: self.modules.append(block) def call (self, X ): for module in self.modules: X = module(X) return X net = MySequential( tf.keras.layers.Dense(units=256 , activation=tf.nn.relu), tf.keras.layers.Dense(10 )) net(X)
嵌套与变量
我们需要一个计算函数 $f(\pmb{x},\pmb{w})=c\pmb{w}^\mathrm T\pmb{x}$ 的层, 其中 $\pmb{x}$ 是输入, $\pmb{w}$ 是参数, $c$ 是某个在优化过程中没有更新的指定常量。 因此我们实现了一个FixedHiddenMLP
类
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 class FixedHiddenMLP (tf.keras.Model): def __init__ (self ): super ().__init__() self.flatten = tf.keras.layers.Flatten() self.rand_weight = tf.constant(tf.random.uniform((20 , 20 ))) self.dense = tf.keras.layers.Dense(20 , activation=tf.nn.relu) def call (self, inputs ): X = self.flatten(inputs) X = tf.nn.relu(tf.matmul(X, self.rand_weight) + 1 ) X = self.dense(X) return X net = FixedHiddenMLP() net(X)
同样可以实现嵌套。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 class NestMLP (tf.keras.Model): def __init__ (self ): super ().__init__() self.net = tf.keras.Sequential() self.net.add(tf.keras.layers.Dense(64 , activation=tf.nn.relu)) self.net.add(tf.keras.layers.Dense(32 , activation=tf.nn.relu)) self.dense = tf.keras.layers.Dense(16 , activation=tf.nn.relu) def call (self, inputs ): return self.dense(self.net(inputs)) chimera = tf.keras.Sequential() chimera.add(NestMLP()) chimera.add(tf.keras.layers.Dense(20 )) chimera.add(FixedHiddenMLP()) chimera(X)
这个网络的结构如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 Input | |--> NestMLP | | | |--> Dense (64 ) - ReLU | | | |--> Dense (32 ) - ReLU | | | |--> Dense (16 ) - ReLU | | |--> Dense (20 ) | | |--> FixedHiddenMLP