TensorFlow 2: Nested TensorArray

What's wrong with this code? Edit: It works on CPU but fails when run on GPU. It runs for a few iterations, then fails with one of the following errors (GitHub issue here):

2019-12-02 12:59:29.727966: F tensorflow/core/framework/tensor_shape.cc:445] Check failed: end <= dims() (1 vs. 0)

Process finished with exit code -1073740791 (0xC0000409)

or

tensorflow.python.framework.errors_impl.InvalidArgumentError:  Tried to set a tensor with incompatible shape at a list index. Item element shape: [3,3] list shape: [3]
     [[{{node while/body/_1/TensorArrayV2Write/TensorListSetItem}}]] [Op:__inference_computeElement_73]

import tensorflow as tf

@tf.function
def computeElement_byBin():
    c = tf.TensorArray(tf.int64, size=1, infer_shape=False, element_shape=(3,))
    const = tf.cast(tf.constant([1, 2, 3]), tf.int64)
    c = c.write(0, const)
    c_c = c.concat()
    return c_c

@tf.function
def computeElement():
    c = tf.TensorArray(tf.int64, size=1, infer_shape=False, element_shape=(3,))
    for x in tf.range(50):
        byBinVariant = computeElement_byBin()
        c = c.write(0, byBinVariant)
    return c.concat()

k = 0
while True:
    k += 1
    r = computeElement()
    print('iteration: %s, result: %s' % (k, r))
asked on Stack Overflow Dec 2, 2019 by Alex • edited Dec 2, 2019 by Alex

1 Answer

I played around with it more and narrowed it down a bit:

import tensorflow as tf

@tf.function
def computeElement():
    tarr = tf.TensorArray(tf.int32, size=1, clear_after_read=False)
    tarr = tarr.write(0, [1])
    concat = tarr.concat()

    # PROBLEM HERE
    for x in tf.range(50):
        concat = tarr.concat()

    return concat
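To check whether the failure depends on inter-op parallelism, the narrowed-down repro can be rerun with the runtime pinned to a single inter-op thread. This is a minimal sketch, not code from the original post: it assumes a fresh Python process (the setting must be applied before TensorFlow runs any op), and the name `compute_element` and the count of 20 calls are arbitrary choices for illustration.

```python
import tensorflow as tf

# Must run before TensorFlow executes any op in this process; calling it after
# the runtime has been initialized raises a RuntimeError.
tf.config.threading.set_inter_op_parallelism_threads(1)


@tf.function
def compute_element():
    # Same narrowed-down repro: repeatedly concat a one-element TensorArray.
    tarr = tf.TensorArray(tf.int32, size=1, clear_after_read=False)
    tarr = tarr.write(0, [1])
    concat = tarr.concat()
    for _ in tf.range(50):  # AutoGraph turns this into a tf.while_loop
        concat = tarr.concat()
    return concat


# Call the function repeatedly; the GPU failure reportedly appeared after a few
# iterations when inter-op parallelism was left at its default.
results = [compute_element() for _ in range(20)]
print([r.numpy().tolist() for r in results[:3]])
```

If the crash disappears with one inter-op thread but returns with the default setting, that points at concurrent execution rather than at the TensorArray logic itself.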

If you call tf.config.threading.set_inter_op_parallelism_threads(1), the bug goes away, which suggests it has to do with the parallel execution of the graph-mode loop (inside a tf.function, AutoGraph converts a for loop over tf.range into a tf.while_loop rather than unrolling it). Knowing that TensorFlow unrolls the loop statically at trace time when looping over a Python value rather than a tensor, I confirmed that this code works:

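import tensorflow as tf

@tf.function
def computeElement(arr):
    tarr = tf.TensorArray(tf.int32, size=1)
    tarr = tarr.write(0, [1])
    concat = tarr.concat()

    # arr is a Python int, so this while loop runs in Python during tracing
    # and is unrolled into the graph instead of becoming a tf.while_loop
    a = 0
    while a < arr:
        concat = tarr.concat()
        a += 1

    return concat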

k = 0
while True:
    k += 1
    r = computeElement(50)

So the workaround for now is to loop over a Python value rather than a tensor.
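The difference between the two loop forms is visible in the traced graph. The sketch below is an illustration with hypothetical function names (`tensor_loop`, `python_loop`, `op_types`), not code from the original post: it traces one function that loops over `tf.range` and one that loops over a Python `range`, then lists the op types in each concrete function's graph. The tensor loop should contain a single `While`/`StatelessWhile` op, while the Python loop has no while op and one `AddV2` per iteration instead.

```python
import tensorflow as tf


@tf.function
def tensor_loop():
    total = tf.constant(0)
    # Iterating over a tensor: AutoGraph emits one graph-level while loop.
    for i in tf.range(5):
        total += i
    return total


@tf.function
def python_loop():
    total = tf.constant(0)
    # Iterating over a Python range: the body is traced once per iteration
    # (static unrolling), so no while op is emitted.
    for i in range(5):
        total += i
    return total


def op_types(fn):
    """Return the op types in the traced (pre-optimization) graph of fn."""
    graph = fn.get_concrete_function().graph
    return [op.type for op in graph.get_operations()]


tensor_ops = op_types(tensor_loop)
python_ops = op_types(python_loop)
print("while op in tensor loop:", any("While" in t for t in tensor_ops))
print("while op in python loop:", any("While" in t for t in python_ops))
print("AddV2 ops in python loop:", python_ops.count("AddV2"))
```

Unrolling presumably sidesteps the GPU failure because there is no while loop left to execute in parallel, but every iteration adds ops to the graph, so it is only practical for small, fixed iteration counts.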

answered on Stack Overflow Dec 5, 2019 by Alex

User contributions licensed under CC BY-SA 3.0