Custom op gradient crashes TensorFlow


I created a custom op with a gradient for TensorFlow 2.0 beta1, but when I run the training loop, TensorFlow crashes with exit code -1073740791 (0xC0000409).
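The two numbers in that exit code are the same value: Windows reports the 32-bit status as a signed integer. 0xC0000409 is STATUS_STACK_BUFFER_OVERRUN, which Windows also reports when a process "fails fast", and the MSVC runtime's abort() usually ends a process that way. A TensorFlow log line starting with `F` is a fatal CHECK that calls abort(), so the exit code is most likely just the consequence of the int32/int64 CHECK failure logged right before the crash rather than a separate fault. A small conversion sketch (`to_ntstatus` is a hypothetical helper name):

```python
def to_ntstatus(exit_code: int) -> str:
    """Reinterpret a signed 32-bit Windows exit code as an unsigned NTSTATUS hex string."""
    return f"0x{exit_code & 0xFFFFFFFF:08X}"


# The exit code from the question
print(to_ntstatus(-1073740791))  # 0xC0000409
```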

Right before the crash, TensorFlow logs an error saying it expected int32 but got int64:

F tensorflow/core/framework/tensor.cc:631] Check failed: dtype() == expected_dtype (9 vs. 3) int32 expected, got int64

So I tried casting the only int64 values of my op to int32, but the error persists. I have no idea why, unless there is some unrelated error in my gradient. The unit test of my custom op passes.
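For reference, the numbers in "(9 vs. 3)" are TensorFlow DataType enum values: the tensor actually holds int64 (9) while the C++ code asked for int32 (3). This CHECK typically fires when kernel code reads an input through a typed accessor such as `flat<int32>()` or `matrix<int32>()` while the tensor is int64. Since `tf.SparseTensor` always stores `indices` and `dense_shape` as int64, any kernel that reads them as int32 needs them cast on every path that reaches it. A short sketch with made-up sparse data:

```python
import tensorflow as tf

# "(9 vs. 3)" in the CHECK message are DataType enum values:
# the tensor's actual dtype comes first, the dtype the kernel asked for second.
actual = tf.as_dtype(9)    # tf.int64
expected = tf.as_dtype(3)  # tf.int32

# tf.SparseTensor stores indices and dense_shape as int64, even when built from Python lists.
st = tf.SparseTensor(indices=[[0, 0, 1], [0, 2, 3]], values=[1.0, 2.0], dense_shape=[1, 3, 4])
print(st.indices.dtype, st.dense_shape.dtype)

# A kernel that reads these inputs as int32 needs them cast before the call.
indices32 = tf.cast(st.indices, tf.int32)
shape32 = tf.cast(st.dense_shape, tf.int32)
```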

The error appears to occur after the train function has executed but before it returns (this might not be accurate, since the function runs several times before returning or crashing). The train function worked before; the only thing I have changed since then is switching part of my neural network to sparse tensors because of memory constraints.

This is my gradient calculation:

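import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops, math_ops

# batch_matmul_module is the compiled custom op library, loaded with tf.load_op_library(...)

@ops.RegisterGradient("BatchSparseDenseMatmul")
def _batch_sparse_dense_matmul_grad(op, grad):
    # Forward inputs: sparse A as (indices, values, shape), then the dense b
    a_indices, a_values, a_shape = op.inputs[:3]
    b = op.inputs[3]
    # Gradient w.r.t. the dense input b: A^T @ grad
    b_grad = batch_matmul_module.batch_sparse_dense_matmul(a_indices, a_values, a_shape, grad, transpose_a=True)
    # Columns 1 and 2 of the indices hold each nonzero's row and column
    rows = a_indices[:, 1]
    cols = a_indices[:, 2]
    # Gradient w.r.t. the nonzero values of A
    parts_a = array_ops.gather(grad, rows)
    parts_b = array_ops.gather(b, cols)
    a_values_grad = math_ops.reduce_sum(parts_a * parts_b, axis=1)
    # One gradient per forward input, in order: (a_indices, a_values, a_shape, b)
    return tf.cast(a_indices, tf.int32), a_values_grad, tf.cast(a_shape, tf.int32), b_grad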
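For the 2-D case this code was adapted from (sparse A of shape [M, K] with nonzeros at positions (rows[n], cols[n]), dense b of shape [K, N], C = A @ b), the gradient of the loss with respect to A is grad @ b^T, so for a single nonzero dL/dA[i, j] = sum_k grad[i, k] * b[j, k]. The gather calls fetch row i of grad and row j of b for every nonzero at once, and the elementwise product followed by reduce_sum over axis 1 is that dot product, so only the nnz entries that A actually stores are ever computed. A NumPy sketch with made-up sizes, checked against the dense gradient:

```python
import numpy as np

rng = np.random.default_rng(0)
M, K, N = 3, 4, 5
# Positions of the nonzeros of a sparse A with dense shape (M, K)
rows = np.array([0, 1, 2, 2])
cols = np.array([1, 3, 0, 2])
b = rng.standard_normal((K, N))
grad = rng.standard_normal((M, N))  # incoming gradient dL/dC for C = A @ b

# gather: for each nonzero (i, j), fetch row i of grad and row j of b
parts_a = grad[rows]  # shape (nnz, N), like array_ops.gather(grad, rows)
parts_b = b[cols]     # shape (nnz, N), like array_ops.gather(b, cols)

# elementwise product + reduce_sum over axis 1 = one dot product per nonzero
a_values_grad = np.sum(parts_a * parts_b, axis=1)  # shape (nnz,)

# Dense reference: dL/dA = grad @ b.T, read off at the nonzero positions
dense_grad = grad @ b.T
print(np.allclose(a_values_grad, dense_grad[rows, cols]))  # True
```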

Since I copied much of the gradient code from TensorFlow's sparse_dense_matmul gradient, I may have done something wrong when adapting it; I do not fully understand what the gather and reduce_sum calls are for. Any help is appreciated!
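Two differences from the 2-D original may be worth checking, sketched here under an assumed layout that the question does not confirm: indices of shape [nnz, 3] holding (batch, row, col), grad of shape [batch, M, N] and b of shape [batch, K, N]. First, `array_ops.gather(grad, rows)` gathers along axis 0 by default, which for a 3-D grad is the batch axis rather than the row axis; pairing each nonzero's batch index with its row (or column) via `tf.gather_nd` keeps the batches apart. Second, TensorFlow's built-in SparseTensorDenseMatMul gradient returns None for the integer indices and shape inputs instead of a tensor, since they are not differentiable. `batched_values_grad` is a hypothetical helper name; the sketch compares it with TensorFlow's own gradient of a dense batched `tf.matmul`:

```python
import numpy as np
import tensorflow as tf


def batched_values_grad(a_indices, grad, b):
    """Gradient w.r.t. the nonzero values of a batched sparse A in C = A @ b.

    Assumed (hypothetical) layout: a_indices is [nnz, 3] holding (batch, row, col),
    grad is [batch, M, N] and b is [batch, K, N].
    """
    batch_row = tf.stack([a_indices[:, 0], a_indices[:, 1]], axis=1)  # (batch, row) per nonzero
    batch_col = tf.stack([a_indices[:, 0], a_indices[:, 2]], axis=1)  # (batch, col) per nonzero
    parts_a = tf.gather_nd(grad, batch_row)  # [nnz, N]: grad[n, i, :]
    parts_b = tf.gather_nd(b, batch_col)     # [nnz, N]: b[n, j, :]
    return tf.reduce_sum(parts_a * parts_b, axis=1)  # [nnz]

# Inside a RegisterGradient function this would be returned as
#   return None, a_values_grad, None, b_grad
# i.e. None for the integer inputs a_indices and a_shape.

# Check against TensorFlow's own gradient of a dense batched matmul.
a_indices = tf.constant([[0, 0, 1], [0, 2, 3], [1, 1, 0], [1, 2, 2]], dtype=tf.int64)
a_values = tf.Variable([1.0, -2.0, 0.5, 3.0], dtype=tf.float64)
rng = np.random.default_rng(0)
b = tf.constant(rng.standard_normal((2, 4, 5)))     # [batch, K, N]
grad = tf.constant(rng.standard_normal((2, 3, 5)))  # stands in for dL/dC, [batch, M, N]

with tf.GradientTape() as tape:
    a_dense = tf.scatter_nd(a_indices, a_values, tf.constant([2, 3, 4], dtype=tf.int64))
    loss = tf.reduce_sum(tf.matmul(a_dense, b) * grad)  # chosen so that dL/dC == grad

expected = tape.gradient(loss, a_values)
manual = batched_values_grad(a_indices, grad, b)
print(np.allclose(manual.numpy(), expected.numpy()))  # True
```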

EDIT:

I cast the inputs in my forward pass like this:

batch_matmul_module.batch_sparse_dense_matmul(tf.cast(a.indices, tf.int32), a.values, tf.cast(a.dense_shape, tf.int32), b)

The dtype error is gone, but now I get:

Cannot reshape a tensor with 0 elements to shape [1,1,11520] (11520 elements) for 'Reshape_18' (op: 'Reshape') with input shapes: [0,0], [3] and with input tensors computed as partial shapes: input[1] = [1,1,11520].

I think this is unrelated, though.

Another Edit:

After fixing the error from the last edit, I get the original error again. I am sure it does not come from the op shown above, because when testing with small data I could replace it with to_dense and matmul. Is there a way to debug this?
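One way to narrow it down, sketched under the assumption that every call to the custom op goes through Python code you control (the forward call and the one inside the registered gradient): check the dtypes of the integer inputs right before each call. Dtypes are known statically, so the check also runs while a `tf.function` is being traced, and a wrong dtype then surfaces as an ordinary Python traceback at the offending call instead of a fatal CHECK. `check_inputs` is a hypothetical helper, and the expected int32 dtypes are taken from the CHECK message:

```python
import tensorflow as tf


def check_inputs(op_name, **inputs):
    """Hypothetical debugging helper: raise a Python TypeError that names the
    offending argument instead of letting a C++ CHECK abort the process.

    Each keyword maps an argument name to a (tensor, expected_dtype) pair.
    """
    for arg, (tensor, expected) in inputs.items():
        actual = tf.convert_to_tensor(tensor).dtype
        if actual != expected:
            raise TypeError(f"{op_name}: {arg} is {actual.name}, expected {expected.name}")


st = tf.SparseTensor(indices=[[0, 0, 1]], values=[1.0], dense_shape=[1, 2, 2])
try:
    # A forgotten cast is caught here, with a traceback pointing at this call
    check_inputs("BatchSparseDenseMatmul",
                 a_indices=(st.indices, tf.int32),
                 a_shape=(st.dense_shape, tf.int32))
except TypeError as err:
    message = str(err)
    print(message)  # BatchSparseDenseMatmul: a_indices is int64, expected int32
```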

Tags: python, tensorflow, crash, sparse-matrix
asked on Stack Overflow Sep 23, 2019 by McLP • edited Sep 24, 2019 by McLP



User contributions licensed under CC BY-SA 3.0