From 587ae863cd1b6211db8fd1848432e9833079d5aa Mon Sep 17 00:00:00 2001
From: Chun-nien Chan
Date: Wed, 9 Oct 2024 13:19:03 -0700
Subject: [PATCH] add aten.cat direct lowering

PiperOrigin-RevId: 684148603
---
 ai_edge_torch/odml_torch/lowerings/_basic.py | 31 ++++++++++++++++++-
 .../odml_torch/lowerings/_jax_lowerings.py   |  1 -
 .../odml_torch/test/test_core_aten_ops.py    |  7 +++--
 3 files changed, 34 insertions(+), 5 deletions(-)

diff --git a/ai_edge_torch/odml_torch/lowerings/_basic.py b/ai_edge_torch/odml_torch/lowerings/_basic.py
index d248988c..01b38c36 100644
--- a/ai_edge_torch/odml_torch/lowerings/_basic.py
+++ b/ai_edge_torch/odml_torch/lowerings/_basic.py
@@ -15,13 +15,17 @@
 import math
 from typing import Optional, Union
 
+from ai_edge_torch.odml_torch import export_utils
+from ai_edge_torch.odml_torch.lowerings import context
+from ai_edge_torch.odml_torch.lowerings import registry
 from ai_edge_torch.odml_torch.lowerings import utils
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo as stablehlo
 import numpy as np
 import torch
 
-from .registry import lower
+LoweringContext = context.LoweringContext
+lower = registry.lower
 
 
 # add(Tensor self, Tensor other) -> Tensor
@@ -204,6 +208,31 @@ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
   return stablehlo.divide(x, y)
 
 
+# Schema:
+#  - aten::cat(Tensor[] tensors, int dim=0) -> Tensor
+# Torch Reference:
+#  - https://pytorch.org/docs/main/generated/torch.cat.html
+@lower(torch.ops.aten.cat.default)
+def _aten_cat(lctx: LoweringContext, tensors, dim=0):
+  assert tensors
+  non_empty_tensors = [t for t in tensors if np.prod(t.type.shape) != 0]
+  out_meta = lctx.node.meta["tensor_meta"]
+  if not non_empty_tensors:
+    return utils.splat(
+        0,
+        export_utils.torch_dtype_to_ir_element_type(
+            lctx.ir_context, out_meta.dtype
+        ),
+        out_meta.shape,
+    )
+
+  if dim < 0:
+    dim = dim + len(out_meta.shape)
+  dim = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), dim)
+
+  return stablehlo.concatenate(non_empty_tensors, dim)
+
+
 # Schema:
 #  - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
 #      start=None, SymInt? end=None, SymInt step=1) -> Tensor
diff --git a/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py b/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
index d86e2b13..c9a4feb1 100644
--- a/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
+++ b/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
@@ -98,7 +98,6 @@ def lower_by_torch_xla2(op):
 lower_by_torch_xla2(torch.ops.aten.bitwise_or)
 lower_by_torch_xla2(torch.ops.aten.bitwise_xor)
 lower_by_torch_xla2(torch.ops.aten.bmm)
-lower_by_torch_xla2(torch.ops.aten.cat)
 lower_by_torch_xla2(torch.ops.aten.ceil)
 lower_by_torch_xla2(torch.ops.aten.clamp.Tensor)
 lower_by_torch_xla2(torch.ops.aten.clamp.default)
diff --git a/ai_edge_torch/odml_torch/test/test_core_aten_ops.py b/ai_edge_torch/odml_torch/test/test_core_aten_ops.py
index 6647fece..52fd2009 100644
--- a/ai_edge_torch/odml_torch/test/test_core_aten_ops.py
+++ b/ai_edge_torch/odml_torch/test/test_core_aten_ops.py
@@ -182,9 +182,10 @@ def _run_export_and_compare(
       ("aten_avg_pool2d_padding_num", torch.ops.aten.avg_pool2d, (rnd(torch.float32, (1, 3, 6, 6)), [3, 3], [1, 1], 1, False, True, None), dict()),
       # ("aten_avg_pool3d_0", torch.ops.aten.avg_pool3d, (rnd(torch.float32, (1, 3, 10, 10, 10)), [2, 2, 2], [2, 2, 2], [0, 0, 0], False, False,), dict()),
       ("aten_bmm_0", torch.ops.aten.bmm, (rnd(torch.float32, (10, 10, 10)), rnd(torch.float32, (10, 10, 10)),), dict()),
-      ("aten_cat_0", torch.ops.aten.cat, ([torch.randn((10, 10)).to(torch.float32)], 1,), dict()),
-      ("aten_cat_1", torch.ops.aten.cat, ([torch.randn((10, 10)).to(torch.float32)], 1,), dict()),
-      ("aten_cat_2", torch.ops.aten.cat, ([torch.randn((10, 10)).to(torch.float32)], 1,), dict()),
+      ("aten_cat_0", torch.ops.aten.cat, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (1, 10))], 0,), dict()),
+      ("aten_cat_1", torch.ops.aten.cat, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (0, 10))], 0,), dict()),
+      ("aten_cat_2", torch.ops.aten.cat, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (0,))], 0,), dict()),
+      ("aten_cat_3", torch.ops.aten.cat, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10))], 0,), dict()),
       ("aten__cdist_forward_0", torch.ops.aten._cdist_forward, (rnd(torch.float32, (5, 7, 10)), rnd(torch.float32, (5, 8, 10)), 1.0, None,), dict()),
       ("aten_ceil_0", torch.ops.aten.ceil, (rnd(torch.float32, (10, 10)),), dict()),
       ("aten_clamp_0", torch.ops.aten.clamp, (rnd(torch.float32, (10, 10)), 0, 1,), dict()),
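
Note (not part of the patch; plain PyTorch, using only torch.randn and
torch.cat): a minimal sketch of the edge cases the new aten_cat tests
exercise. Zero-sized inputs contribute nothing to a concatenation, which is
why _aten_cat filters them out before emitting stablehlo.concatenate; when
every input is zero-sized, the lowering instead returns a zero splat with the
shape and dtype recorded in the node's tensor_meta.

    import torch

    x = torch.randn(10, 10)
    empty_rows = torch.randn(0, 10)  # zero elements along dim 0 (test aten_cat_1)
    legacy_empty = torch.randn(0)    # legacy 1-D empty tensor, accepted by cat (test aten_cat_2)

    # Zero-sized inputs do not affect the result, so the lowering can skip
    # them when building the stablehlo.concatenate op.
    print(torch.cat([x, empty_rows], dim=0).shape)    # torch.Size([10, 10])
    print(torch.cat([x, legacy_empty], dim=0).shape)  # torch.Size([10, 10])
    print(torch.cat([x, x, x], dim=0).shape)          # torch.Size([30, 10]) (test aten_cat_3)

    # All-empty case: the result is itself zero-sized, matching the splat the
    # lowering builds from out_meta.shape.
    print(torch.cat([torch.randn(0, 10), torch.randn(0, 10)], dim=0).shape)  # torch.Size([0, 10])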