扫二维码与项目经理沟通
我们在微信上24小时期待你的声音
解答本文疑问/技术咨询/运营咨询/技术建议/互联网交流
这篇文章主要讲解了Keras将两个模型连接到一起的实现方法,内容清晰明了,对此有兴趣的小伙伴可以学习一下,相信大家阅读完之后会有帮助。
让胡路网站建设公司创新互联建站,让胡路网站设计制作,有大型网站制作公司丰富经验。已为让胡路超过千家提供企业网站建设服务。企业网站搭建\外贸网站制作要多少钱,请找那个售后服务好的让胡路做网站的公司定做!神经网络玩得越久就越会尝试一些网络结构上的大改动。
先说意图
有两个模型:模型A和模型B。模型A的输出可以连接B的输入。将两个小模型连接成一个大模型,A-B,既可以同时训练又可以分离训练。
流行的算法里经常有这么关系的两个模型,对GAN来说,生成器和判别器就是这样子;对VAE来说,编码器和解码器就是这样子;对目标检测网络来说,backbone和整体也是可以拆分的。所以,应用范围还是挺广的。
实现方法
首先说明,我的实现方法不一定是最佳方法。也是实在没有借鉴到比较好的方法,所以才自己手动写了一个。
第一步,我们有现成的两个模型A和B;我们想把A的输出连到B的输入,组成一个整体C。
第二步, 重构新模型C;我的方法是:读出A和B各有哪些layer,然后一层一层重新搭成C。
可以看一个自编码器的代码(本人所编写):
class AE: def __init__(self, dim, img_dim, batch_size): self.dim = dim self.img_dim = img_dim self.batch_size = batch_size self.encoder = self.encoder_construct() self.decoder = self.decoder_construct() def encoder_construct(self): x_in = Input(shape=(self.img_dim, self.img_dim, 3)) x = x_in x = Conv2D(self.dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x) x = BatchNormalization()(x) x = LeakyReLU(0.2)(x) x = Conv2D(self.dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x) x = BatchNormalization()(x) x = LeakyReLU(0.2)(x) x = Conv2D(self.dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x) x = BatchNormalization()(x) x = LeakyReLU(0.2)(x) x = Conv2D(self.dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x) x = BatchNormalization()(x) x = LeakyReLU(0.2)(x) x = Conv2D(self.dim, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x) x = BatchNormalization()(x) x = LeakyReLU(0.2)(x) x = GlobalAveragePooling2D()(x) encoder = Model(x_in, x) return encoder def decoder_construct(self): map_size = K.int_shape(self.encoder.layers[-2].output)[1:-1] # print(type(map_size)) z_in = Input(shape=K.int_shape(self.encoder.output)[1:]) z = z_in z_dim = self.dim z = Dense(np.prod(map_size) * z_dim)(z) z = Reshape(map_size + (z_dim,))(z) z = Conv2DTranspose(z_dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z) z = BatchNormalization()(z) z = Activation('relu')(z) z = Conv2DTranspose(z_dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z) z = BatchNormalization()(z) z = Activation('relu')(z) z = Conv2DTranspose(z_dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z) z = BatchNormalization()(z) z = Activation('relu')(z) z = Conv2DTranspose(z_dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z) z = BatchNormalization()(z) z = Activation('relu')(z) z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z) z = Activation('tanh')(z) decoder = Model(z_in, z) return decoder def build_ae(self): input_x = Input(shape=(self.img_dim, self.img_dim, 3)) x = input_x for i in range(1, len(self.encoder.layers)): x = self.encoder.layers[i](x) for j in range(1, len(self.decoder.layers)): x = self.decoder.layers[j](x) y = x auto_encoder = Model(input_x, y) return auto_encoder
另外有需要云服务器可以了解下创新互联scvps.cn,海内外云服务器15元起步,三天无理由+7*72小时售后在线,公司持有idc许可证,提供“云服务器、裸金属服务器、高防服务器、香港服务器、美国服务器、虚拟主机、免备案服务器”等云主机租用服务以及企业上云的综合解决方案,具有“安全稳定、简单易用、服务可用性高、性价比高”等特点与优势,专为企业上云打造定制,能够满足用户丰富、多元化的应用场景需求。
我们在微信上24小时期待你的声音
解答本文疑问/技术咨询/运营咨询/技术建议/互联网交流