首页 > 代码库 > 分布式TensorFlow 采坑记
分布式TensorFlow 采坑记
单机版的TF没毛病,但是当大家在Tensorflow Github里面找到可用的模型,想分布式跑到时候,就会跑出来各种奇怪的问题。我尝试了几种不同构造TF的方式,算是成功渡过了踩坑期,特别记录一下。如果能帮助到各位TF boy最好。
方法一:自己手动写分布式协议
比如logistic regression
在master上运行的伪代码如下
with tf.Session(‘grpc://vm1:2222‘) as sess:
sess.run(initialization)
while not stop:
run(train_op)
master负责的是初始化session,以及将parameter发送给其他的worker。如果有saver也定义在这里。
下面的逻辑就是每个worker收到parameter,计算gradient,然后到master上进行aggregate。最常用的aggregate的方式就是将每一个worker的gradient求和。
Tip: 下面都默认task_index是0的节点为master。在TF中,这个master节点也叫做chief node。
伪代码如下:
with tf.device(‘/job:worker/task:%d" % FLAGS.task_index):
read_data
compute gradient
with tf.device(‘/job:worker/task:0‘):
aggregate weight
但是既然parameter server在sparse的数据集上非常好用,那么我们不妨尝试利用这个特性。
1. 首先每一个worker得到sparse index,传给master。
2. master根据对应的sparse index,得到对应的sparse data。这个也叫做working set,不知道是不是system方向的叫法。
3. 每一个worker从master得到working set。更新gradient。
4. master将每一个worker的gradient进行aggregate。
其实相比于前一个方法,就是多了一次信息传递(1、2步),来获取sparse信息。
伪代码如下:
with tf.device(‘/job:worker/task:%d‘ % FLAGS.task_index):
read data
get sparse index
with tf.device(‘/job:worker/task:0‘):
get sparse index from each worker
generate working set for each worker
with tf.device(‘/job:worker/task:%d‘ % FLAGS.task_index):
get working set from master
compute gradient
with tf.device(‘/job:worker/task:0‘):
aggregate gradient
完整代码在这里.
Tip: 实现的时候,我直接用list存储每一个worker的gradient,sparse index,和working set。在下面提到的TensorFlow中已经实现的类中,使用的是内置的queue。(Python里queue和list差别不大)
方法二:使用MonitoredTrainingSession
TF有内置的类,supervisor和MonotoredTrainingSession是最常用的两个。
MonitoredTrainingSession是MonitoredSession的子类,多增加的功能是为master/chief 节点增加断点功能,以及创建session,分配给其他worker。
如果要保证同步更新,主要“下手”在optimizer上,类似我们在上面的第一个算法,所以TF有一个叫做SyncReplicasOptimizer
的类。
这里一个让我踩了许久的坑:如果用了MonotoredTrainingSession
和SyncReplicasOptimizer
,如果ps相对于worker的分布不是均匀的,那么有的worker会跑的特别快。比如这样设置:
parser.add_argument(
"--ps_hosts",
type=str,
default="vm1:2233",
help="Comma-separated list of hostname:port pairs"
)
parser.add_argument(
"--worker_hosts",
type=str,
default="vm1:2222,vm2:2222,vm3:2222",
help="Comma-separated list of hostname:port pairs"
)
虽说理论上快的worker在每次iteration结束之后应该等queue挤满,但是很神奇的是!这个快的worker会自己再多跑几份session来把queue填充满,从而进入下一次iteration。
所以千万不要把ps和worker分布不均匀。可以每台worker都是ps,也可以固定几个ps与worker不重合。
parser.add_argument(
"--ps_hosts",
type=str,
default="vm1:2233",
help="Comma-separated list of hostname:port pairs"
)
parser.add_argument(
"--worker_hosts",
type=str,
default="vm2:2222,vm3:2222",
help="Comma-separated list of hostname:port pairs"
)
分布式TensorFlow 采坑记