
Tensor-type analysis misses parameters that receive tensors at their call sites #618

Description

@khatchad

Summary

The tensor-type analysis does not associate a TensorType with function parameters that are bound to tensor values at their call sites. As a result, downstream consumers treat the affected functions as having no tensor parameters.

Verified Cases (Call-Site Evidence)

Inspecting each call site confirms that the parameter receives a tensor, yet the analysis assigns it no tensor type:

| Function | Parameter(s) | Bound to at the call site |
|---|---|---|
| Gpt2.get_loss(self, real, pred) | real, pred | get_loss(targets, predictions): model-output and label tensors |
| BiLSTM.call(self, inputs) | inputs | token-ID tensor that flows into a Keras Embedding |
| TextCNN.predict(self, inputs) | inputs | token-ID tensor that flows into a Keras Embedding |
| TuckERLoader.target_convert(self, targets, ...) | targets | int32 element of a tf.data padded_batch dataset |

These cases cover three propagation shapes: interprocedural (a tensor passed as an argument to a callee), Keras call dispatch (a tensor passed to __call__ that reaches an Embedding), and a tf.data dataset element.
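The Keras call-dispatch shape hinges on one resolution step: calling an instance of a Layer or Model subclass, obj(x), goes through Keras' __call__, which invokes the subclass's call method, so the argument's tensor type should end up on call's first non-self parameter. The toy, pure-Python model below illustrates only that binding rule; it is not the analysis's code. TensorType is a stand-in dataclass, ClassInfo, resolve, bind, and KERAS_BASES are hypothetical names, and BiLSTM's base class and the int32 dtype for token IDs are assumptions.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Set


@dataclass(frozen=True)
class TensorType:
    dtype: str


# Hypothetical stand-ins for the Keras base classes whose __call__ forwards to call().
KERAS_BASES = {"tf.keras.layers.Layer", "tf.keras.Model"}


@dataclass
class ClassInfo:
    name: str
    bases: Set[str]
    methods: Dict[str, List[str]]  # method name -> parameter names, including self


def resolve(info: ClassInfo, method: str) -> Optional[str]:
    """Resolve a method invoked on an instance of the class described by `info`.

    Keras' __call__ wraps the user-defined call(), so obj(x) on a Layer or
    Model subclass ultimately runs obj.call(x).
    """
    if method == "__call__" and info.bases & KERAS_BASES and "call" in info.methods:
        return "call"
    return method if method in info.methods else None


def bind(info: ClassInfo, method: str,
         arg_types: List[Optional[TensorType]]) -> Dict[str, TensorType]:
    """Map each explicit parameter of the resolved method to its bound tensor type."""
    target = resolve(info, method)
    if target is None:
        return {}
    params = info.methods[target][1:]  # skip self
    return {
        f"{info.name}.{target}.{p}": t
        for p, t in zip(params, arg_types)
        if t is not None  # non-tensor arguments bind nothing
    }


# Hypothetical encoding of the BiLSTM row (base class and dtype are assumptions).
bilstm = ClassInfo("BiLSTM", {"tf.keras.Model"}, {"call": ["self", "inputs"]})
print(bind(bilstm, "__call__", [TensorType("int32")]))
# {'BiLSTM.call.inputs': TensorType(dtype='int32')}
```

Without the __call__-to-call rule, obj(x) resolves to no user-defined method and inputs receives nothing, which matches the observed under-approximation for this shape.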

Suggested Minimal Repro (Interprocedural)

```python
import tensorflow as tf

class Model:
    def call(self, x):
        return x * 2.0

    def get_loss(self, real, pred):
        return tf.reduce_mean(tf.square(pred - real))

    def train_step(self, inputs, targets):
        predictions = self.call(inputs)
        return self.get_loss(targets, predictions)  # both args are tensors

def main():
    Model().train_step(tf.constant([1., 2., 3.]), tf.constant([1., 1., 1.]))
```

Expected: get_loss's real and pred each carry a TensorType(float32). Observed: neither parameter receives a tensor type.
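The expected outcome can be reproduced with a toy, pure-Python model of the repro. This is not how the analysis is implemented: TensorType is a stand-in dataclass, and Func, Call, and propagate are hypothetical names for a mini-IR and a fixpoint that (1) binds each call-site argument's tensor types to the callee's parameter, which is the step this issue reports as missing, and (2) flows the callee's return type back into the caller's receiving variable. The model omits self and assumes x * 2.0 preserves x's float32 tensor type.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set, Union


@dataclass(frozen=True)
class TensorType:
    dtype: str


# An operand is either a known tensor type (e.g. from tf.constant) or a local/parameter name.
Operand = Union[TensorType, str]


@dataclass
class Call:
    callee: str
    args: List[Operand]
    target: Optional[str] = None  # caller local that receives the return value


@dataclass
class Func:
    params: List[str]                                # explicit parameters; self omitted
    calls: List[Call] = field(default_factory=list)
    returns: Optional[Operand] = None                # operand whose type is returned


Env = Dict[str, Dict[str, Set[TensorType]]]


def types_of(op: Operand, local: Dict[str, Set[TensorType]]) -> Set[TensorType]:
    if isinstance(op, TensorType):
        return {op}
    return local.get(op, set())


def propagate(funcs: Dict[str, Func]) -> Env:
    """Iterate until no variable gains a new tensor type."""
    env: Env = {name: {} for name in funcs}
    changed = True
    while changed:
        changed = False
        for caller, func in funcs.items():
            local = env[caller]
            for call in func.calls:
                callee = funcs[call.callee]
                callee_env = env[call.callee]
                # Argument-to-parameter binding at the call site.
                for param, arg in zip(callee.params, call.args):
                    new = types_of(arg, local) - callee_env.get(param, set())
                    if new:
                        callee_env.setdefault(param, set()).update(new)
                        changed = True
                # Return-value flow back into the caller's local variable.
                if call.target is not None and callee.returns is not None:
                    new = types_of(callee.returns, callee_env) - local.get(call.target, set())
                    if new:
                        local.setdefault(call.target, set()).update(new)
                        changed = True
    return env


F32 = TensorType("float32")
repro = {
    "main": Func(params=[], calls=[Call("train_step", [F32, F32])]),
    "train_step": Func(
        params=["inputs", "targets"],
        calls=[
            Call("call", ["inputs"], target="predictions"),
            Call("get_loss", ["targets", "predictions"]),
        ],
    ),
    "call": Func(params=["x"], returns="x"),  # assumption: x * 2.0 keeps x's type
    "get_loss": Func(params=["real", "pred"]),
}
env = propagate(repro)
print(env["get_loss"])
# {'real': {TensorType(dtype='float32')}, 'pred': {TensorType(dtype='float32')}}
```

Tracking sets of types keeps the fixpoint monotone, so it terminates even when different call sites bind different dtypes to the same parameter.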

Context

This surfaced while regenerating a @tf.function refactoring corpus: these functions are left undecorated because their parameters are not recognized as tensors. An earlier evaluation classified some of these parameters as tensors, but that pipeline predates a rewrite of the consumer's classification code. This is therefore reported as a current under-approximation, not as a verified regression between versions.
