Summary
The tensor-type analysis does not associate a TensorType with function parameters that are bound to tensor values at their call sites. As a result, downstream consumers treat the affected functions as having no tensor parameters.
Verified Cases (Call-Site Evidence)
Inspecting each call site confirms that the parameter receives a tensor, yet the analysis assigns it no tensor type:
| Function | Parameter(s) | Value bound at the call site |
|---|---|---|
| Gpt2.get_loss(self, real, pred) | real, pred | get_loss(targets, predictions): model output and label tensors |
| BiLSTM.call(self, inputs) | inputs | a token-ID tensor fed to a Keras Embedding |
| TextCNN.predict(self, inputs) | inputs | a token-ID tensor fed to a Keras Embedding |
| TuckERLoader.target_convert(self, targets, ...) | targets | an int32 element of a tf.data padded_batch |
The cases cover three propagation patterns: interprocedural (a tensor argument passed to a callee), Keras call dispatch (a tensor passed to __call__ that reaches an Embedding), and a tf.data dataset element.
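For the second pattern, note that calling a Keras layer or model instance goes through Layer.__call__, which eventually invokes the subclass's call method, so the tensor bound to call's inputs never appears in a direct .call(...) expression at the call site. The sketch below mimics that dispatch in plain Python. Layer here is a hypothetical stand-in, not the Keras class (real Keras also builds weights and handles masks before forwarding), the BiLSTM stub only mirrors the signature from the table, and a nested list stands in for the token-ID tensor.

```python
class Layer:
    """Hypothetical stand-in for a Keras layer: __call__ forwards to call()."""

    def __call__(self, *args, **kwargs):
        # Real Keras layers also build weights, handle masks, etc. here.
        return self.call(*args, **kwargs)


class BiLSTM(Layer):
    """Stub with the same call(self, inputs) shape as BiLSTM in the table."""

    def __init__(self):
        self.bound_inputs = []  # records each value bound to `inputs`

    def call(self, inputs):
        self.bound_inputs.append(inputs)
        return inputs


# A nested list stands in for a token-ID tensor (no TensorFlow needed).
token_ids = [[3, 1, 4], [1, 5, 9]]
model = BiLSTM()
output = model(token_ids)  # the call site never names .call(...)
```

This only illustrates the data flow an analysis has to follow (model(...) → __call__ → call(inputs)); it does not establish where the analysis loses the type.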
Suggested Minimal Repro (Interprocedural)
```python
import tensorflow as tf


class Model:
    def call(self, x):
        return x * 2.0

    def get_loss(self, real, pred):
        return tf.reduce_mean(tf.square(pred - real))

    def train_step(self, inputs, targets):
        predictions = self.call(inputs)
        return self.get_loss(targets, predictions)  # both args are tensors


def main():
    Model().train_step(tf.constant([1., 2., 3.]), tf.constant([1., 1., 1.]))
```
Expected: get_loss's real and pred each carry a TensorType(float32). Observed: neither parameter receives a tensor type.
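To spell out the expected facts, the sketch below runs a deliberately naive, flow-insensitive propagation over the repro using Python's standard ast module: every tf.<op>(...) call is treated as producing a tensor, a tensor argument at a call site tags the matching callee parameter, and tensor-valued returns flow back to callers, iterated to a fixpoint. It is not the analysis's algorithm. REPRO and tensor_params_of are hypothetical names, only tensor-or-not is tracked (no dtype or shape), and it assumes function names are unique and that every attribute-style call with a matching name targets a method (so self is skipped).

```python
import ast

# The repro from this issue, embedded as a string (hypothetical name).
REPRO = '''
import tensorflow as tf

class Model:
    def call(self, x):
        return x * 2.0

    def get_loss(self, real, pred):
        return tf.reduce_mean(tf.square(pred - real))

    def train_step(self, inputs, targets):
        predictions = self.call(inputs)
        return self.get_loss(targets, predictions)

def main():
    Model().train_step(tf.constant([1., 2., 3.]), tf.constant([1., 1., 1.]))
'''


def tensor_params_of(source):
    """Toy analysis: return {(function, parameter)} pairs bound to tensors."""
    tree = ast.parse(source)
    # Toy assumption: function and method names are unique in the module.
    funcs = {n.name: n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)}
    tensor_vars = {name: set() for name in funcs}  # tensor-valued names
    tensor_returns = set()  # functions whose return value is a tensor
    tensor_params = set()

    def is_tensor(expr, scope):
        if isinstance(expr, ast.Name):
            return expr.id in tensor_vars[scope]
        if isinstance(expr, ast.BinOp):  # e.g. x * 2.0 or pred - real
            return is_tensor(expr.left, scope) or is_tensor(expr.right, scope)
        if isinstance(expr, ast.Call) and isinstance(expr.func, ast.Attribute):
            receiver = expr.func.value
            if isinstance(receiver, ast.Name) and receiver.id == "tf":
                return True  # toy rule: every tf.<op>(...) yields a tensor
            return expr.func.attr in tensor_returns  # e.g. self.call(...)
        return False

    changed = True
    while changed:  # facts only grow, so this reaches a fixpoint
        changed = False
        for scope, fn in funcs.items():
            for node in ast.walk(fn):
                if isinstance(node, ast.Assign) and is_tensor(node.value, scope):
                    for t in node.targets:
                        if isinstance(t, ast.Name) and t.id not in tensor_vars[scope]:
                            tensor_vars[scope].add(t.id)
                            changed = True
                elif (isinstance(node, ast.Return) and node.value is not None
                      and is_tensor(node.value, scope)
                      and scope not in tensor_returns):
                    tensor_returns.add(scope)
                    changed = True
                elif (isinstance(node, ast.Call)
                      and isinstance(node.func, ast.Attribute)
                      and node.func.attr in funcs):
                    callee = funcs[node.func.attr]
                    # Toy assumption: attribute calls target methods; skip self.
                    params = [a.arg for a in callee.args.args][1:]
                    for param, arg in zip(params, node.args):
                        fact = (callee.name, param)
                        if is_tensor(arg, scope) and fact not in tensor_params:
                            tensor_params.add(fact)
                            tensor_vars[callee.name].add(param)
                            changed = True
    return tensor_params


print(sorted(tensor_params_of(REPRO)))
# [('call', 'x'), ('get_loss', 'pred'), ('get_loss', 'real'),
#  ('train_step', 'inputs'), ('train_step', 'targets')]
```

real is a direct pass-through of a train_step parameter, whereas pred is only tagged once call's return is known to be tensor-valued (inputs → call(x) → predictions → get_loss(pred)), so the repro exercises both argument and return propagation.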
Context
This surfaced while regenerating a @tf.function refactoring corpus, in which these functions are left undecorated because their parameters are not recognized as tensors. An earlier evaluation did classify some of these parameters as tensors, but that pipeline predates a rewrite of the consumer's classification code, so this is reported as a current under-approximation rather than a verified version-to-version regression.