What's wrong with this code? Edit: It works on CPU, but fails when run on GPU. It runs for a few iterations, then fails with one of the following errors (GitHub issue here):
2019-12-02 12:59:29.727966: F tensorflow/core/framework/tensor_shape.cc:445] Check failed: end <= dims() (1 vs. 0)
Process finished with exit code -1073740791 (0xC0000409)
or
tensorflow.python.framework.errors_impl.InvalidArgumentError: Tried to set a tensor with incompatible shape at a list index. Item element shape: [3,3] list shape: [3]
[[{{node while/body/_1/TensorArrayV2Write/TensorListSetItem}}]] [Op:__inference_computeElement_73]
import tensorflow as tf

@tf.function
def computeElement_byBin():
    c = tf.TensorArray(tf.int64, size=1, infer_shape=False, element_shape=(3,))
    const = tf.cast(tf.constant([1, 2, 3]), tf.int64)
    c = c.write(0, const)
    c_c = c.concat()
    return c_c

@tf.function
def computeElement():
    c = tf.TensorArray(tf.int64, size=1, infer_shape=False, element_shape=(3,))
    # AutoGraph converts this loop over a tensor into a graph-level while loop.
    for x in tf.range(50):
        byBinVariant = computeElement_byBin()
        c = c.write(0, byBinVariant)
    return c.concat()

# Call the function repeatedly; on GPU it crashes after a few iterations.
k = 0
while True:
    k += 1
    r = computeElement()
    print('iteration: %s, result: %s' % (k, r))
I played around with it more and narrowed it down a bit:
import tensorflow as tf

@tf.function
def computeElement():
    tarr = tf.TensorArray(tf.int32, size=1, clear_after_read=False)
    tarr = tarr.write(0, [1])
    concat = tarr.concat()
    # PROBLEM HERE: reading the TensorArray repeatedly inside a loop over a tensor
    for x in tf.range(50):
        concat = tarr.concat()
    return concat
If you call tf.config.threading.set_inter_op_parallelism_threads(1) before running anything, the bug goes away, which suggests it has to do with iterations of the graph-mode TensorFlow loop being executed in parallel. Knowing that TensorFlow unrolls a loop statically at trace time when it loops over a Python variable rather than a tensor, I confirmed that this code works:
import tensorflow as tf

@tf.function
def computeElement(arr):
    tarr = tf.TensorArray(tf.int32, size=1)
    tarr = tarr.write(0, [1])
    concat = tarr.concat()
    # arr is a Python int, so this loop runs during tracing and is unrolled into the graph.
    a = 0
    while a < arr:
        concat = tarr.concat()
        a += 1
    return concat

k = 0
while True:
    k += 1
    r = computeElement(50)
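To check which kind of loop each version produces, you can trace both and look for a while-loop op in the resulting graph. The sketch below uses two hypothetical toy functions, `tensor_loop` and `python_loop` (and a hypothetical helper `while_op_types`), instead of the TensorArray code, and assumes TF 2.x, where loops inside `tf.function` are built with control flow v2 and show up as `While` or `StatelessWhile` ops.

```python
import tensorflow as tf


@tf.function
def tensor_loop():
    total = tf.constant(0)
    # Looping over a tensor: AutoGraph converts this into a graph-level while loop.
    for _ in tf.range(50):
        total += 1
    return total


@tf.function
def python_loop(n):
    total = tf.constant(0)
    i = 0
    # n is a Python int, so this loop executes during tracing and is unrolled.
    while i < n:
        total += 1
        i += 1
    return total


def while_op_types(concrete_fn):
    """Return the types of any while-loop ops in a traced function's graph."""
    return [op.type for op in concrete_fn.graph.get_operations()
            if op.type in ("While", "StatelessWhile")]


# One while-loop op for the tensor loop, none for the unrolled Python loop.
print(while_op_types(tensor_loop.get_concrete_function()))
print(while_op_types(python_loop.get_concrete_function(50)))
```

The trade-off of the unrolled version: because `n` is a Python value, every distinct value passed in triggers a fresh trace, and the size of the traced graph grows linearly with `n`.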
So the workaround for now is to loop over a Python variable rather than a tensor, so that the loop is unrolled when the function is traced.
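If the crash really does come from loop iterations running concurrently, another option that keeps the loop in the graph (instead of unrolling it) might be an explicit `tf.while_loop` with `parallel_iterations=1`. This is an untested sketch: `compute_element_serial` is a hypothetical rewrite of the narrowed repro, and it has not been checked against the GPU failure.

```python
import tensorflow as tf


@tf.function
def compute_element_serial():
    tarr = tf.TensorArray(tf.int32, size=1, clear_after_read=False)
    tarr = tarr.write(0, [1])
    concat = tarr.concat()

    def cond(i, concat):
        return i < 50

    def body(i, concat):
        # Same work as the loop body in the narrowed repro.
        return i + 1, tarr.concat()

    # parallel_iterations=1 asks TensorFlow to run one iteration at a time.
    _, concat = tf.while_loop(cond, body, (tf.constant(0), concat),
                              parallel_iterations=1)
    return concat


print(compute_element_serial())
```

The single-thread setting has a similar effect process-wide, but tf.config.threading.set_inter_op_parallelism_threads(1) must be called before TensorFlow executes any op; calling it afterwards raises a RuntimeError.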
User contributions licensed under CC BY-SA 3.0