最新下载
热门教程
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
Keras实现支持masking的Flatten层代码示例
时间:2020-06-16 编辑:袖梨 来源:一聚教程网
本篇文章小编给大家分享一下Keras实现支持masking的Flatten层代码示例,代码介绍的很详细,小编觉得挺不错的,现在分享给大家供大家参考,有需要的小伙伴们可以来看看。
Keras原本Flatten的实现
class Flatten(Layer): def __init__(self, **kwargs): super(Flatten, self).__init__(**kwargs) self.input_spec = InputSpec(min_ndim=3) def compute_output_shape(self, input_shape): if not all(input_shape[1:]): raise ValueError('The shape of the input to "Flatten" ' 'is not fully defined ' '(got ' + str(input_shape[1:]) + '. ' 'Make sure to pass a complete "input_shape" ' 'or "batch_input_shape" argument to the first ' 'layer in your model.') return (input_shape[0], np.prod(input_shape[1:])) def call(self, inputs): return K.batch_flatten(inputs)
自定义支持masking的实现
事实上,Keras层的mask有时候是需要参与运算的,比如Dense之类的,有时候则只是做某种变换然后传递给后面的层。Flatten属于后者,因为mask总是与input有相同的shape,所以我们要做的就是在compute_mask函数里对mask也做flatten。
from keras import backend as K from keras.engine.topology import Layer import tensorflow as tf import numpy as np class MyFlatten(Layer): def __init__(self, **kwargs): self.supports_masking = True super(MyFlatten, self).__init__(**kwargs) def compute_mask(self, inputs, mask=None): if mask==None: return mask return K.batch_flatten(mask) def call(self, inputs, mask=None): return K.batch_flatten(inputs) def compute_output_shape(self, input_shape): return (input_shape[0], np.prod(input_shape[1:]))
正确性检验
from keras.layers import * from keras.models import Model from MyFlatten import MyFlatten from MySumLayer import MySumLayer from keras.initializers import ones data = [[1,0,0,0], [1,2,0,0], [1,2,3,0], [1,2,3,4]] A = Input(shape=[4]) # None * 4 emb = Embedding(5, 3, mask_zero=True, embeddings_initializer=ones())(A) # None * 4 * 3 fla = MyFlatten()(emb) # None * 12 out = MySumLayer(axis=1)(fla) # None * 1 model = Model(inputs=[A], outputs=[out]) print model.predict(data)
输出:
[ 3. 6. 9. 12.]
相关文章
- Golang ProtoBuf的基本语法详解 10-20
- Python识别MySQL中的冗余索引解析 10-20
- Python+Pygame绘制小球代码展示 10-18
- Python中的数据精度问题介绍 10-18
- Python随机值生成的常用方法介绍 10-18
- python3解压缩.gz文件分析 09-27