Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4738,11 +4738,15 @@ public void testDecoratedCallDepth()

/**
* A {@code @tf.function(input_signature=[(None,) int32])} passes its parameter to {@code g} (the
* FUT). At runtime {@code g} receives the signature's {@code (None,)} int32, since the signature
* governs the parameter and propagates to the callee.
*
* <p>TODO: Ariadne ignores {@code input_signature} and types {@code g}'s parameter from the
* call-site argument {@code (3,)} int32, which is <em>unsound</em> here. Tracked by <a
* FUT). What {@code g} receives depends on the execution mode, which a static analysis cannot
* determine: traced (the default) the signature governs and {@code g} receives {@code (None,)}
* int32; under {@code run_functions_eagerly} the signature is ignored and {@code g} receives the
* call-site argument's {@code (3,)} int32. So the sound type of {@code g}'s parameter is the set
* {@code {(None,), (3,)}} int32.
*
* <p>TODO: this pins the current behavior. Ariadne does not consume {@code input_signature}, so
* it produces only the argument-derived {@code (3,)} element and misses the signature-derived
* {@code (None,)} one; the sound result is the set {@code {(None,), (3,)}} int32. Tracked by <a
* href="https://github.com/wala/ML/issues/638">wala/ML#638</a>.
*/
@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@ def g(b):

@tf.function(input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.int32)])
def f(x):
# `input_signature` governs `x`, so at `g`'s call site the argument is (None,) int32,
# NOT the (3,) of the value passed to `f` below.
# Traced (the default): `input_signature` governs `x`, so `g` receives (None,) int32.
assert x.shape.as_list() == [None]
assert x.dtype == tf.int32
g(x)


f(tf.constant([1, 2, 3], dtype=tf.int32))
# Under `run_functions_eagerly` the signature would be ignored and `g` would instead receive this
# argument's (3,) int32. A static analysis cannot know the execution mode, so the sound type of
# `g`'s parameter is the set {(None,), (3,)} int32.
arg = tf.constant([1, 2, 3], dtype=tf.int32)
assert arg.shape == (3,)
assert arg.dtype == tf.int32
f(arg)
Loading