How to do sparse input text classification(dnn) using tensorflow
You can get the complete example code from
https://github.com/chenghuige/tensorflow-example
It includes:
- How to parse a libsvm dataset file into tfrecords
- How to read tfrecords and do DNN/logistic regression for classification or regression
- How to train and evaluate
- How to view the training process (loss and metric tracking) in TensorBoard
- How to use melt.train_flow to handle everything else (optimizer, learning rate, model saving, logging …)
The main related code:
melt.tfrecords/libsvm_decode  # parses the libsvm file
melt.models.mlp
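For reference, each line of a libsvm file has the form `label index:value index:value ...`. The sketch below is a hypothetical pure-Python parser (`parse_libsvm_line` is an illustrative name, not melt's `libsvm_decode`) that splits one line into the label plus the parallel index and value lists that a tfrecords writer would store.

```python
from typing import List, Tuple


def parse_libsvm_line(line: str) -> Tuple[float, List[int], List[float]]:
    """Split a libsvm line 'label idx:val idx:val ...' into its parts.

    Hypothetical helper for illustration; it is not melt's libsvm_decode.
    """
    parts = line.split()
    label = float(parts[0])
    indices: List[int] = []
    values: List[float] = []
    for item in parts[1:]:
        idx, val = item.split(":", 1)
        indices.append(int(idx))
        values.append(float(val))
    return label, indices, values


label, indices, values = parse_libsvm_line("1 3:0.5 10:1.0 42:2.0")
print(label, indices, values)  # 1.0 [3, 10, 42] [0.5, 1.0, 2.0]
```

In a tfrecords pipeline, `indices` and `values` would typically be stored as an int64 list and a float list inside a `tf.train.Example`; the feature names melt uses for them are not shown here.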
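# TensorFlow 1.x imports (assumed; tf.contrib was removed in TF 2.x)
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.python.ops import init_ops


def forward(inputs,
            num_outputs,
            input_dim=None,
            hiddens=[200],
            activation_fn=tf.nn.relu,
            weights_initializer=initializers.xavier_initializer(),
            weights_regularizer=None,
            biases_initializer=init_ops.zeros_initializer(),
            biases_regularizer=None,
            reuse=None,
            scope=None):
    ...  # body omitted here; see melt.models.mlp in the repo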
text-classfication/model.py shows how to use this
You must specify num_outputs and input_dim for a sparse input dataset:
- num_outputs is the number of classes for classification, e.g. num_outputs=10 for a 10-class problem; for regression, use num_outputs=1.
- input_dim must equal the number of input features in your dataset.
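If you don't know input_dim offhand, one way to find it is to scan the libsvm file for the largest feature index. This is a minimal sketch, assuming the feature indices are used directly as row ids of the weight matrix (so input_dim = max index + 1); `infer_input_dim` is a hypothetical helper, not part of melt.

```python
from typing import Iterable


def infer_input_dim(lines: Iterable[str]) -> int:
    """Return (max feature index + 1) over libsvm lines.

    Assumes indices are used directly as weight-matrix rows, so the
    largest index must be strictly less than input_dim. Hypothetical helper.
    """
    max_index = -1
    for line in lines:
        for item in line.split()[1:]:  # skip the label
            idx = int(item.split(":", 1)[0])
            max_index = max(max_index, idx)
    return max_index + 1


data = ["1 3:0.5 10:1.0", "0 42:2.0", "1 7:1.0"]
print(infer_input_dim(data))  # 43
```

If your file is 1-based and you subtract 1 from every index before the lookup, input_dim would be the maximum index itself instead.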
You may change hiddens. The default is [200], which means a single hidden layer of size 200.
You can use more hidden layers: for example, [200, 100, 100] means 3 hidden layers of sizes 200, 100 and 100.
You may also set hiddens to [] (empty), which means no hidden layer at all, i.e. plain logistic regression.
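To make the hiddens semantics concrete, here is a NumPy sketch (not melt's implementation) of a dense forward pass: one ReLU layer per entry in hiddens, followed by a linear output layer of size num_outputs. With hiddens=[] only the output layer remains, which is the linear model behind logistic/softmax regression. The small random weights stand in for Xavier initialization and are an assumption of the sketch.

```python
import numpy as np


def mlp_forward(x, num_outputs, hiddens=(200,), seed=0):
    """Dense MLP forward pass: ReLU hidden layers, then a linear output layer.

    Illustrative sketch only; weights are random rather than trained.
    Returns the logits and the list of weight-matrix shapes.
    """
    rng = np.random.default_rng(seed)
    sizes = [x.shape[1]] + list(hiddens) + [num_outputs]
    shapes = []
    h = x
    for i in range(len(sizes) - 1):
        w = rng.normal(scale=0.01, size=(sizes[i], sizes[i + 1]))
        b = np.zeros(sizes[i + 1])
        shapes.append(w.shape)
        h = h @ w + b
        if i < len(sizes) - 2:  # hidden layers get ReLU; the output layer stays linear
            h = np.maximum(h, 0.0)
    return h, shapes


x = np.ones((4, 50))  # batch of 4 examples, input_dim=50
logits, shapes = mlp_forward(x, num_outputs=10, hiddens=[200, 100, 100])
print(logits.shape)  # (4, 10)
print(shapes)        # [(50, 200), (200, 100), (100, 100), (100, 10)]
_, shapes = mlp_forward(x, num_outputs=10, hiddens=[])
print(shapes)        # [(50, 10)]
```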
What's the difference between melt.layers.fully_connected and tf.contrib.layers.fully_connected?
They are very similar, but melt's version also handles sparse input. The main difference is here: we use melt.matmul.
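import tensorflow as tf

def matmul(X, w):
    if isinstance(X, tf.Tensor):
        # Dense input: ordinary matrix multiply.
        return tf.matmul(X, w)
    else:
        # Sparse input: X[0] holds the feature indices, X[1] the feature values.
        # Summing the value-weighted rows of w equals a sparse-dense matmul.
        return tf.nn.embedding_lookup_sparse(w, X[0], X[1], combiner='sum')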
From <https://github.com/chenghuige/tensorflow-example/blob/master/util/melt/ops/ops.py>
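Why does embedding_lookup_sparse with combiner='sum' behave like a matrix multiply? For each example it gathers the rows of w named by the feature indices, scales them by the feature values, and sums them, which is exactly the product of a sparse row with w. The NumPy sketch below checks that equivalence on a toy batch; it mirrors the math rather than TensorFlow's implementation, and `sparse_matmul_sum` is a hypothetical name.

```python
import numpy as np


def sparse_matmul_sum(indices, values, w):
    """For each example, sum value * w[index] over its nonzero features.

    indices/values are per-example lists, like the X[0]/X[1] pair in matmul.
    """
    out = np.zeros((len(indices), w.shape[1]))
    for row, (idx, val) in enumerate(zip(indices, values)):
        for i, v in zip(idx, val):
            out[row] += v * w[i]
    return out


w = np.arange(12, dtype=float).reshape(6, 2)  # input_dim=6, 2 outputs
indices = [[0, 3], [5]]
values = [[1.0, 2.0], [0.5]]

dense = np.zeros((2, 6))  # the same batch written out as a dense matrix
for row, (idx, val) in enumerate(zip(indices, values)):
    dense[row, idx] = val

print(np.allclose(sparse_matmul_sum(indices, values, w), dense @ w))  # True
```

The practical benefit of the sparse path is that it only touches the rows of w for features that are present, instead of multiplying through a dense input_dim-wide matrix.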
[Figure: TensorBoard view of the training loss and metrics]