首页 > 代码库 > keras 入门之 regression
keras 入门之 regression
本实验分三步:
1. 建立数据集
2. 建立网络并训练
3. 可视化
import numpy as npfrom keras.models import Sequentialfrom keras.layers import Densefrom keras.optimizers import SGD# 构建数据集X_data = http://www.mamicode.com/np.linspace(-1,1,100)[:, np.newaxis]noise = np.random.normal(0,0.05,X_data.shape)y_data = np.square(X_data) + noise + 0.5# 构建神经网络model = Sequential()model.add(Dense(10, input_dim=1, init=‘normal‘, activation=‘relu‘))model.add(Dense(1, init=‘normal‘))sgd = SGD(lr=0.1)model.compile(loss=‘mean_squared_error‘, optimizer=sgd)# 训练model.fit(X_data, y_data, nb_epoch=1000, batch_size=100, verbose=0)# 在原数据上预测y_predict=model.predict(X_data,batch_size=100,verbose=1)# 可视化import matplotlib.pyplot as pltfig = plt.figure()ax = fig.add_subplot(1,1,1)ax.scatter(X_data, y_data)ax.plot(X_data,y_predict,‘r-‘,lw=5)plt.show()
训练结果:
keras 入门之 regression
声明:以上内容来自用户投稿及互联网公开渠道收集整理发布,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任,若内容有误或涉及侵权可进行投诉: 投诉/举报 工作人员会在5个工作日内联系你,一经查实,本站将立刻删除涉嫌侵权内容。