add aten.cat direct lowering #290
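Replaces the torch_xla2 fallback lowering for aten.cat with a direct StableHLO lowering in odml_torch, and reworks the core-aten cat test cases to exercise empty-tensor inputs and multi-tensor concatenation.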

Open · wants to merge 1 commit into base: main
ai_edge_torch/odml_torch/lowerings/_basic.py (30 additions, 1 deletion)
@@ -15,13 +15,17 @@
import math
from typing import Optional, Union

from ai_edge_torch.odml_torch import export_utils
from ai_edge_torch.odml_torch.lowerings import context
from ai_edge_torch.odml_torch.lowerings import registry
from ai_edge_torch.odml_torch.lowerings import utils
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo as stablehlo
import numpy as np
import torch

from .registry import lower
LoweringContext = context.LoweringContext
lower = registry.lower


# add(Tensor self, Tensor other) -> Tensor
@@ -204,6 +208,31 @@ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
return stablehlo.divide(x, y)


# Schema:
# - aten::cat(Tensor[] tensors, int dim=0) -> Tensor
# Torch Reference:
# - https://pytorch.org/docs/main/generated/torch.cat.html
@lower(torch.ops.aten.cat.default)
def _aten_cat(lctx: LoweringContext, tensors, dim=0):
  assert tensors
  non_empty_tensors = [t for t in tensors if np.prod(t.type.shape) != 0]
  out_meta = lctx.node.meta["tensor_meta"]
  if not non_empty_tensors:
    return utils.splat(
        0,
        export_utils.torch_dtype_to_ir_element_type(
            lctx.ir_context, out_meta.dtype
        ),
        out_meta.shape,
    )

  if dim < 0:
    dim = dim + len(out_meta.shape)
  dim = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), dim)

  return stablehlo.concatenate(non_empty_tensors, dim)


# Schema:
# - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
# start=None, SymInt? end=None, SymInt step=1) -> Tensor
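Why the lowering filters empty inputs: torch.cat accepts zero-element tensors, including the legacy 1-D empty tensor of shape (0,) whose rank need not match the other inputs, whereas StableHLO concatenate requires all operands to agree in rank and in every dimension except the concatenation one. Dropping zero-element operands first keeps the emitted op valid; if nothing is left, the lowering falls back to a zero splat of the node's output shape. A minimal sketch of the PyTorch-level semantics being mirrored (plain torch, runnable outside the converter):

import torch

a = torch.randn(10, 10)

# The legacy 1-D empty tensor is a legal cat input even though its rank
# differs from a's; it contributes no elements to the result.
assert torch.cat([a, torch.zeros(0)], dim=0).shape == (10, 10)

# All-empty input: the result itself has zero elements, which is why the
# lowering can return a zero-filled splat of the output meta shape.
assert torch.cat([torch.zeros(0, 10), torch.zeros(0, 10)], dim=0).shape == (0, 10)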
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py (0 additions, 1 deletion)
@@ -98,7 +98,6 @@ def lower_by_torch_xla2(op):
lower_by_torch_xla2(torch.ops.aten.bitwise_or)
lower_by_torch_xla2(torch.ops.aten.bitwise_xor)
lower_by_torch_xla2(torch.ops.aten.bmm)
lower_by_torch_xla2(torch.ops.aten.cat)
lower_by_torch_xla2(torch.ops.aten.ceil)
lower_by_torch_xla2(torch.ops.aten.clamp.Tensor)
lower_by_torch_xla2(torch.ops.aten.clamp.default)
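With this registration removed, aten.cat no longer routes through the torch_xla2 bridge; it dispatches to the direct lowering registered in _basic.py above.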
ai_edge_torch/odml_torch/test/test_core_aten_ops.py (4 additions, 3 deletions)
@@ -182,9 +182,10 @@ def _run_export_and_compare(
("aten_avg_pool2d_padding_num", torch.ops.aten.avg_pool2d, (rnd(torch.float32, (1, 3, 6, 6)), [3, 3], [1, 1], 1, False, True, None), dict()),
# ("aten_avg_pool3d_0", torch.ops.aten.avg_pool3d, (rnd(torch.float32, (1, 3, 10, 10, 10)), [2, 2, 2], [2, 2, 2], [0, 0, 0], False, False,), dict()),
("aten_bmm_0", torch.ops.aten.bmm, (rnd(torch.float32, (10, 10, 10)), rnd(torch.float32, (10, 10, 10)),), dict()),
("aten_cat_0", torch.ops.aten.cat, ([torch.randn((10, 10)).to(torch.float32)], 1,), dict()),
("aten_cat_1", torch.ops.aten.cat, ([torch.randn((10, 10)).to(torch.float32)], 1,), dict()),
("aten_cat_2", torch.ops.aten.cat, ([torch.randn((10, 10)).to(torch.float32)], 1,), dict()),
("aten_cat_0", torch.ops.aten.cat, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (1, 10))], 0,), dict()),
("aten_cat_1", torch.ops.aten.cat, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (0, 10))], 0,), dict()),
("aten_cat_2", torch.ops.aten.cat, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (0,))], 0,), dict()),
("aten_cat_3", torch.ops.aten.cat, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10))], 0,), dict()),
("aten__cdist_forward_0", torch.ops.aten._cdist_forward, (rnd(torch.float32, (5, 7, 10)), rnd(torch.float32, (5, 8, 10)), 1.0, None,), dict()),
("aten_ceil_0", torch.ops.aten.ceil, (rnd(torch.float32, (10, 10)),), dict()),
("aten_clamp_0", torch.ops.aten.clamp, (rnd(torch.float32, (10, 10)), 0, 1,), dict()),
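The three previous cat cases were identical single-tensor concats; the rewritten set covers a genuinely growing concat, a zero-row input, the legacy 1-D empty tensor, and a three-way concat. The expected output shapes, checked against plain torch.cat (a sketch; rnd is this test file's helper for building a random tensor of the given dtype and shape):

import torch

# aten_cat_0: (10, 10) + (1, 10) along dim 0 -> (11, 10)
assert torch.cat([torch.randn(10, 10), torch.randn(1, 10)], 0).shape == (11, 10)

# aten_cat_1: a (0, 10) input adds zero rows -> (10, 10)
assert torch.cat([torch.randn(10, 10), torch.randn(0, 10)], 0).shape == (10, 10)

# aten_cat_2: the 1-D empty tensor is accepted and ignored -> (10, 10)
assert torch.cat([torch.randn(10, 10), torch.randn(0)], 0).shape == (10, 10)

# aten_cat_3: three (10, 10) blocks stack to (30, 10)
assert torch.cat([torch.randn(10, 10)] * 3, 0).shape == (30, 10)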