首页 > 代码库 > keras 入门之 regression

keras 入门之 regression

本实验分三步:

1. 建立数据集

2. 建立网络并训练

3. 可视化

import numpy as npfrom keras.models import Sequentialfrom keras.layers import Densefrom keras.optimizers import SGD# 构建数据集X_data = http://www.mamicode.com/np.linspace(-1,1,100)[:, np.newaxis]noise = np.random.normal(0,0.05,X_data.shape)y_data = np.square(X_data) + noise + 0.5# 构建神经网络model = Sequential()model.add(Dense(10, input_dim=1, init=normal, activation=relu))model.add(Dense(1, init=normal))sgd = SGD(lr=0.1)model.compile(loss=mean_squared_error, optimizer=sgd)# 训练model.fit(X_data, y_data, nb_epoch=1000, batch_size=100, verbose=0)# 在原数据上预测y_predict=model.predict(X_data,batch_size=100,verbose=1)# 可视化import matplotlib.pyplot as pltfig = plt.figure()ax = fig.add_subplot(1,1,1)ax.scatter(X_data, y_data)ax.plot(X_data,y_predict,r-,lw=5)plt.show()

训练结果:

技术分享

keras 入门之 regression