首页 > 代码库 > Softmax回归(使用tensorflow)
Softmax回归(使用tensorflow)
# coding: utf8
"""Softmax regression trained with the TensorFlow 1.x graph API.

Running the file as a script loads ``mnist.pkl``, trains a ``SoftMax`` model
on the training split, and prints the accuracy on the validation and training
splits.
"""
import os

try:
    import cPickle  # Python 2
except ImportError:
    import pickle as cPickle  # Python 3: same load/dump interface

import numpy as np

try:
    import tensorflow as tf
except ImportError:
    # TensorFlow is only needed for process_train/process_test; the NumPy
    # inference path (load_theta/h/predict) keeps working without it.
    tf = None


def _require_tf():
    """Raise ImportError if TensorFlow could not be imported."""
    if tf is None:
        raise ImportError(
            "tensorflow is required for SoftMax.process_train/process_test")


class SoftMax(object):
    """Multiclass softmax (multinomial logistic) regression without a bias term.

    Attributes:
        MAXT: number of passes over the training data.
        step: learning rate handed to the gradient-descent optimizer.
        theta: weight array of shape (features, classes); set by
            ``process_train`` or ``load_theta``.
    """

    def __init__(self, MAXT=30, step=0.0025):
        self.MAXT = MAXT
        self.step = step

    def load_theta(self, datapath="data/softmax.pkl"):
        """Load previously pickled weights from ``datapath`` into ``self.theta``.

        NOTE(review): unpickling executes arbitrary code; only load files
        written by ``process_train`` or another trusted source.
        """
        with open(datapath, 'rb') as f:
            self.theta = cPickle.load(f)

    def process_train(self, data, label, typenum=10, batch_size=500):
        """Fit the weights by minibatch gradient descent and return them.

        Args:
            data: 2-D array with one sample per row.
            label: 1-D array of integer class ids (one-hot encoded here), or
                an already one-hot 2-D array with ``typenum`` columns.
            typenum: number of classes.
            batch_size: number of samples per gradient step.

        Returns:
            The learned weight array, also stored in ``self.theta``.
        """
        _require_tf()
        samplenum = data.shape[0]
        valuenum = data.shape[1]
        # Ceiling division: the trailing partial batch is trained on too, and
        # inputs smaller than one batch still produce a training step.
        batches = (samplenum + batch_size - 1) // batch_size
        if len(label.shape) == 1:
            label = self.reshape_data(label, typenum)

        x = tf.placeholder("float", [None, valuenum])
        theta = tf.Variable(tf.zeros([valuenum, typenum]))
        y = tf.nn.softmax(tf.matmul(x, theta))
        y_ = tf.placeholder("float", [None, typenum])
        # Cross-entropy summed (not averaged) over the batch.
        cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
        train_step = tf.train.GradientDescentOptimizer(
            self.step).minimize(cross_entropy)
        init = tf.initialize_all_variables()

        # The context manager releases the session even if a step raises.
        with tf.Session() as sess:
            sess.run(init)
            for epoch in range(self.MAXT):
                cost_ = []
                for index in range(batches):
                    start = index * batch_size
                    stop = start + batch_size
                    c_, _ = sess.run(
                        [cross_entropy, train_step],
                        feed_dict={x: data[start:stop],
                                   y_: label[start:stop]})
                    cost_.append(c_)
                # Report every fifth epoch; skipped when there was no data.
                if epoch % 5 == 0 and cost_:
                    print(('epoch %i, minibatch %i/%i,averange cost is %f') %
                          (epoch, batches, batches, np.mean(cost_)))
            self.theta = sess.run(theta)

        self._save_theta_if_absent('data/softmax.pkl')
        return self.theta

    def _save_theta_if_absent(self, datapath):
        """Pickle ``self.theta`` to ``datapath`` unless that file already exists."""
        if os.path.exists(datapath):
            # An existing weights file is never overwritten.
            return
        folder = os.path.dirname(datapath)
        if folder and not os.path.isdir(folder):
            os.makedirs(folder)
        with open(datapath, 'wb') as f:
            cPickle.dump(self.theta, f)

    def process_test(self, data, label, typenum=10):
        """Print and return the classification accuracy on ``data``.

        Args:
            data: 2-D array with one sample per row.
            label: 1-D array of integer class ids, or a one-hot 2-D array.
            typenum: number of classes.

        Returns:
            The fraction of rows whose most probable class matches the label.
        """
        _require_tf()
        valuenum = data.shape[1]
        if len(label.shape) == 1:
            label = self.reshape_data(label, typenum)

        x = tf.placeholder("float", [None, valuenum])
        theta = self.theta
        y = tf.nn.softmax(tf.matmul(x, theta))
        y_ = tf.placeholder("float", [None, typenum])
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
        init = tf.initialize_all_variables()

        with tf.Session() as sess:
            sess.run(init)
            result = sess.run(accuracy, feed_dict={x: data, y_: label})
        # Single-argument form prints the same text on Python 2 and 3
        # (matches the Python 2 statement `print "Accuracy: ", value`).
        print("Accuracy:  %s" % (result,))
        return result

    def h(self, x):
        """Return the softmax class probabilities for ``x``, one row per sample.

        A single 1-D sample is treated as a one-row batch.
        """
        scores = np.atleast_2d(np.asarray(np.dot(x, self.theta)))
        # Shifting by the row maximum leaves the softmax unchanged but keeps
        # exp() from overflowing on large scores.
        scores = scores - np.max(scores, axis=1, keepdims=True)
        m = np.exp(scores)
        # keepdims gives an (n, 1) column so each row is divided by its own
        # sum; a flat (n,) vector would broadcast across columns instead.
        return m / np.sum(m, axis=1, keepdims=True)

    def predict(self, x):
        """Return the index of the most probable class for each row of ``x``."""
        return np.argmax(self.h(x), axis=1)

    def reshape_data(self, label, typenum):
        """One-hot encode integer class ids into a (len(label), typenum) matrix."""
        ids = np.asarray(label, dtype=np.intp)
        onehot = np.zeros((ids.shape[0], typenum))
        onehot[np.arange(ids.shape[0]), ids] = 1.0
        # np.asmatrix returns the same np.matrix type as the removed np.mat.
        return np.asmatrix(onehot)


def main():
    """Train on the MNIST training split and report validation/training accuracy."""
    with open('mnist.pkl', 'rb') as f:
        training_data, validation_data, test_data = cPickle.load(f)
    # Flatten every image to a 784-element feature vector.
    data = np.array([np.reshape(x, 784) for x in training_data[0]])
    vdata = np.array([np.reshape(x, 784) for x in validation_data[0]])

    softmax = SoftMax()
    softmax.process_train(data, training_data[1])
    softmax.process_test(vdata, validation_data[1])  # Accuracy: 0.9269
    softmax.process_test(data, training_data[1])  # Accuracy: 0.92718


if __name__ == '__main__':
    main()
Softmax回归(使用tensorflow)
声明:以上内容来自用户投稿及互联网公开渠道收集整理发布,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任,若内容有误或涉及侵权可进行投诉: 投诉/举报 工作人员会在5个工作日内联系你,一经查实,本站将立刻删除涉嫌侵权内容。