tensorflow android Remember LSTM state for next batch (stateful LSTM)

Eli Leszczynski

I'm trying to implement a language model for speech recognition with TensorFlow on Android. I have trained an RNN LSTM language model that works quite well in Python. The problem is that I need to keep the state outside of the graph after each prediction, and on Android I can only use flat arrays (arrays with one dimension) as input or output.

I have built my solution based on: TensorFlow: Remember LSTM state for next batch (stateful LSTM)

and added:

    fs = tf.stack([self.final_state[0].c, self.final_state[0].h,
                   self.final_state[1].c, self.final_state[1].h])
    self.flat_fs = tf.reshape(fs, [-1], name="flat_output")

as an output, which works properly.
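For reference, the memory layout this flattening produces can be checked outside the graph with plain NumPy (a sketch only; `rnn_size = 128` is assumed here to match the error message further down, and `c0`/`h0`/`c1`/`h1` stand in for the per-layer cell and hidden states):

```python
import numpy as np

rnn_size = 128  # assumed size, matching the [2, 2, ?, 128] shape in the error below

# Stand-ins for final_state[layer].c / final_state[layer].h (batch of 1, squeezed)
c0 = np.full(rnn_size, 0.0)
h0 = np.full(rnn_size, 1.0)
c1 = np.full(rnn_size, 2.0)
h1 = np.full(rnn_size, 3.0)

# Same order as tf.stack([c0, h0, c1, h1]) followed by tf.reshape(fs, [-1])
flat = np.stack([c0, h0, c1, h1]).reshape(-1)

# Element i of tensor k lands at flat index k * rnn_size + i
assert flat.shape == (4 * rnn_size,)
assert np.array_equal(flat[:rnn_size], c0)
assert np.array_equal(flat[3 * rnn_size:], h1)
```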

This is how I handled the state tuple:

    self.state_in = tf.placeholder(tf.float32,
                                   [args.num_layers, 2, None, args.rnn_size],
                                   name='state_in')
    l = tf.unstack(self.state_in, axis=0)
    self.state_tup = tuple(
        [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1])
         for idx in range(args.num_layers)])

I'm trying to find a way to feed the state to the graph as a one-dimensional array:

    self.flat_stat_in = tf.placeholder(tf.float32,
                                       [args.rnn_size * args.num_layers * 2],
                                       name="flat_stat_in")
    self.state_in = tf.reshape(self.flat_stat_in, [args.num_layers, 2, 1, args.rnn_size])
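The reverse reshape is at least shape-consistent with the state tuple above; a NumPy sketch of the intended round trip (again assuming `num_layers = 2` and `rnn_size = 128`):

```python
import numpy as np

num_layers, rnn_size = 2, 128  # assumed values, matching the error message below

# A fake flat state vector, as it would come back from Android
flat_state = np.arange(num_layers * 2 * rnn_size, dtype=np.float32)

# Mirror of tf.reshape(self.flat_stat_in, [num_layers, 2, 1, rnn_size])
state_in = flat_state.reshape(num_layers, 2, 1, rnn_size)

# state_in[layer][0] is c, state_in[layer][1] is h, each with batch size 1
c_layer1 = state_in[1][0]
assert c_layer1.shape == (1, rnn_size)
# Layer 1's c starts right after layer 0's c and h in the flat vector
assert c_layer1[0, 0] == 2 * rnn_size
```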

but this is not working, and this is the error I'm getting:

  You must feed a value for placeholder tensor 'state_in' with dtype float and shape [2,2,?,128]
 [[Node: state_in = Placeholder[dtype=DT_FLOAT, shape=[2,2,?,128], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Does anyone have a good suggestion for building this correctly in a way that will work on Android?
