From eb284eae1912a4d46be85710d6851afe4f4e1bef Mon Sep 17 00:00:00 2001 From: yoco Date: Fri, 4 Jun 2021 07:07:51 +0800 Subject: [PATCH] Simplify PyTorch interact feature compuation The original implementation use torch.cat() & view() (a.k.a reshape) to combine feature vectors into a matrix. It create a Reshape operation in the exported ONNX, and create dynamic tesnor. It can be replace by just one torch.stack(). The result graph is simpler and faster. --- dlrm_s_pytorch.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dlrm_s_pytorch.py b/dlrm_s_pytorch.py index eb352664..92404560 100644 --- a/dlrm_s_pytorch.py +++ b/dlrm_s_pytorch.py @@ -474,8 +474,7 @@ def interact_features(self, x, ly): if self.arch_interaction_op == "dot": # concatenate dense and sparse features - (batch_size, d) = x.shape - T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d)) + T = torch.stack([x] + ly, dim=1) # perform a dot product Z = torch.bmm(T, torch.transpose(T, 1, 2)) # append dense feature with the interactions (into a row vector)