扫二维码与项目经理沟通
我们在微信上24小时期待你的声音
解答本文疑问/技术咨询/运营咨询/技术建议/互联网交流
创新互联www.cdcxhl.cn八线动态BGP香港云服务器提供商,新人活动买多久送多久,划算不套路!
成都创新互联公司主要从事成都网站设计、成都网站建设、外贸网站建设、网页设计、企业做网站、公司建网站等业务。立足成都服务余杭,十余年网站建设经验,价格优惠、服务专业,欢迎来电咨询建站服务:18982081108不懂如何实现keras训练浅层卷积网络并保存和加载模型?其实想解决这个问题也不难,下面让小编带着大家一起学习怎么去解决,希望大家阅读完这篇文章后大所收获。
这里我们使用keras定义简单的神经网络全连接层训练MNIST数据集和cifar10数据集:
keras_mnist.py
from sklearn.preprocessing import LabelBinarizer from sklearn.model_selection import train_test_split from sklearn.metrics import classification_report from keras.models import Sequential from keras.layers.core import Dense from keras.optimizers import SGD from sklearn import datasets import matplotlib.pyplot as plt import numpy as np import argparse # 命令行参数运行 ap = argparse.ArgumentParser() ap.add_argument("-o", "--output", required=True, help="path to the output loss/accuracy plot") args =vars(ap.parse_args()) # 加载数据MNIST,然后归一化到【0,1】,同时使用75%做训练,25%做测试 print("[INFO] loading MNIST (full) dataset") dataset = datasets.fetch_mldata("MNIST Original", data_home="/home/king/test/python/train/pyimagesearch/nn/data/") data = dataset.data.astype("float") / 255.0 (trainX, testX, trainY, testY) = train_test_split(data, dataset.target, test_size=0.25) # 将label进行one-hot编码 lb = LabelBinarizer() trainY = lb.fit_transform(trainY) testY = lb.transform(testY) # keras定义网络结构784--256--128--10 model = Sequential() model.add(Dense(256, input_shape=(784,), activation="relu")) model.add(Dense(128, activation="relu")) model.add(Dense(10, activation="softmax")) # 开始训练 print("[INFO] training network...") # 0.01的学习率 sgd = SGD(0.01) # 交叉验证 model.compile(loss="categorical_crossentropy", optimizer=sgd, metrics=['accuracy']) H = model.fit(trainX, trainY, validation_data=(testX, testY), epochs=100, batch_size=128) # 测试模型和评估 print("[INFO] evaluating network...") predictions = model.predict(testX, batch_size=128) print(classification_report(testY.argmax(axis=1), predictions.argmax(axis=1), target_names=[str(x) for x in lb.classes_])) # 保存可视化训练结果 plt.style.use("ggplot") plt.figure() plt.plot(np.arange(0, 100), H.history["loss"], label="train_loss") plt.plot(np.arange(0, 100), H.history["val_loss"], label="val_loss") plt.plot(np.arange(0, 100), H.history["acc"], label="train_acc") plt.plot(np.arange(0, 100), H.history["val_acc"], label="val_acc") plt.title("Training Loss and Accuracy") plt.xlabel("# Epoch") plt.ylabel("Loss/Accuracy") plt.legend() plt.savefig(args["output"])
我们在微信上24小时期待你的声音
解答本文疑问/技术咨询/运营咨询/技术建议/互联网交流