diff --git a/ai_edge_torch/odml_torch/jax_bridge/_wrap.py b/ai_edge_torch/odml_torch/jax_bridge/_wrap.py
index 32853253..1cea1f47 100644
--- a/ai_edge_torch/odml_torch/jax_bridge/_wrap.py
+++ b/ai_edge_torch/odml_torch/jax_bridge/_wrap.py
@@ -24,6 +24,7 @@
 import jax
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func
+from jax._src.lib.mlir.dialects import hlo as stablehlo
 import torch.utils._pytree as pytree
 
 # Jax double (64bit) precision is required to generate StableHLO mlir with
@@ -143,8 +144,39 @@ def wrapped(lctx, *args, **kwargs):
       ir_inputs = []
 
     results = func.CallOp(cloned_func, ir_inputs).results
+
+    if lctx.node is None:
+      return results[0] if len(results) == 1 else results
+
+    out_avals = lctx.node.meta.get("tensor_meta") or lctx.node.meta.get("val")
+
+    if out_avals is None:
+      return results[0] if len(results) == 1 else results
+
+    def sanitize_result_elty(result, aval):
+      # The JAX implementation may not respect the aten op's output dtype. For
+      # example, JAX may implement slightly different dtype upcast rules, which
+      # leads to the bridged lowering's result dtype differing from the torch
+      # op's output dtype. Here we add an additional `stablehlo.convert` op
+      # when the dtypes do not match, to ensure the lowering's result dtype is
+      # always the same as the torch op's output dtype.
+      if aval is None:
+        return result
+
+      target_elty = export_utils.torch_dtype_to_ir_element_type(
+          lctx.ir_context, aval.dtype
+      )
+      if result.type.element_type == target_elty:
+        return result
+      return stablehlo.convert(
+          ir.RankedTensorType.get(result.type.shape, target_elty), result
+      )
+
     if len(results) == 1:
-      return results[0]
-    return results
+      return sanitize_result_elty(results[0], out_avals)
+    return [
+        sanitize_result_elty(result, aval)
+        for result, aval in zip(results, out_avals)
+    ]
 
   return wrapped
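Note on the `_wrap.py` change above: `out_avals` is read from the FX node metadata that torch.export attaches during tracing, and `sanitize_result_elty` compares the lowered result's element type against that recorded dtype. A minimal sketch of where this metadata lives; the `Add` module, shapes, and dtypes below are illustrative only and not part of the patch:

import torch

class Add(torch.nn.Module):

  def forward(self, x, y):
    return x + y

ep = torch.export.export(
    Add(), (torch.ones(4, dtype=torch.int32), torch.ones(4, dtype=torch.int32))
)
for node in ep.graph.nodes:
  if node.op == "call_function":
    # node.meta["val"] holds a FakeTensor; its dtype is the aten op's output
    # dtype that the bridged lowering's result is now converted to match.
    print(node.target, node.meta["val"].dtype)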
diff --git a/ai_edge_torch/odml_torch/lowerings/_basic.py b/ai_edge_torch/odml_torch/lowerings/_basic.py
index d248988c..db22b898 100644
--- a/ai_edge_torch/odml_torch/lowerings/_basic.py
+++ b/ai_edge_torch/odml_torch/lowerings/_basic.py
@@ -15,13 +15,17 @@
 import math
 from typing import Optional, Union
 
+from ai_edge_torch.odml_torch import export_utils
+from ai_edge_torch.odml_torch.lowerings import context
+from ai_edge_torch.odml_torch.lowerings import registry
 from ai_edge_torch.odml_torch.lowerings import utils
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo as stablehlo
 import numpy as np
 import torch
 
-from .registry import lower
+LoweringContext = context.LoweringContext
+lower = registry.lower
 
 
 # add(Tensor self, Tensor other) -> Tensor
@@ -204,6 +208,31 @@ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
   return stablehlo.divide(x, y)
 
 
+# Schema:
+#   - aten::cat(Tensor[] tensors, int dim=0) -> Tensor
+# Torch Reference:
+#   - https://pytorch.org/docs/main/generated/torch.cat.html
+@lower(torch.ops.aten.cat.default)
+def _aten_cat(lctx: LoweringContext, tensors, dim=0):
+  assert tensors
+  non_empty_tensors = [t for t in tensors if np.prod(t.type.shape) != 0]
+  out_aval = lctx.node.meta.get("tensor_meta") or lctx.node.meta.get("val")
+  if not non_empty_tensors:
+    return utils.splat(
+        0,
+        export_utils.torch_dtype_to_ir_element_type(
+            lctx.ir_context, out_aval.dtype
+        ),
+        out_aval.shape,
+    )
+
+  if dim < 0:
+    dim = dim + len(out_aval.shape)
+  dim = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), dim)
+
+  return stablehlo.concatenate(non_empty_tensors, dim)
+
+
 # Schema:
 #   - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
 #       start=None, SymInt? end=None, SymInt step=1) -> Tensor
diff --git a/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py b/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
index d86e2b13..c9a4feb1 100644
--- a/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
+++ b/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
@@ -98,7 +98,6 @@ def lower_by_torch_xla2(op):
 lower_by_torch_xla2(torch.ops.aten.bitwise_or)
 lower_by_torch_xla2(torch.ops.aten.bitwise_xor)
 lower_by_torch_xla2(torch.ops.aten.bmm)
-lower_by_torch_xla2(torch.ops.aten.cat)
 lower_by_torch_xla2(torch.ops.aten.ceil)
 lower_by_torch_xla2(torch.ops.aten.clamp.Tensor)
 lower_by_torch_xla2(torch.ops.aten.clamp.default)
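Context for `_aten_cat` and the updated cat tests below: `torch.cat` effectively ignores operands with zero elements, including the legacy case of a 1-D `(0,)` tensor mixed with higher-rank inputs, while `stablehlo.concatenate` requires rank-matched operands; hence the filtering of zero-size tensors and the `utils.splat` fallback when every operand is empty. A quick illustration in plain PyTorch (shapes are illustrative and not part of the patch):

import torch

x = torch.randn(10, 10)
# A zero-row operand contributes nothing to the concatenation result.
print(torch.cat([x, torch.empty(0, 10)], dim=0).shape)  # torch.Size([10, 10])
# Legacy behavior: a 1-D empty tensor is accepted even though its rank differs.
print(torch.cat([x, torch.empty(0)], dim=0).shape)  # torch.Size([10, 10])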
diff --git a/ai_edge_torch/odml_torch/test/test_core_aten_ops.py b/ai_edge_torch/odml_torch/test/test_core_aten_ops.py
index 6647fece..e238ae72 100644
--- a/ai_edge_torch/odml_torch/test/test_core_aten_ops.py
+++ b/ai_edge_torch/odml_torch/test/test_core_aten_ops.py
@@ -65,7 +65,7 @@ def forward(self, *export_args):
 def rnd(dtype, shape, min_v=None, max_v=None):
   """Shortcut for creating a random torch tensor."""
   if dtype in (torch.int32, torch.int64, torch.bool):
-    min_v = min_v if min_v else 0
+    min_v = min_v if min_v else 1
     max_v = max_v if max_v else 10
     return torch.randint(min_v, max_v, shape).to(dtype)
   else:
@@ -137,7 +137,7 @@ def _run_export_and_compare(
 
   @parameterized.named_parameters(
       # fmt: off
       # pyformat: disable
       ("aten_abs_0", torch.ops.aten.abs, (rnd(torch.float32, (10, 10)),), dict()),
       ("aten_acos_0", torch.ops.aten.acos, (rnd(torch.float32, (10, 10)),), dict()),
       ("aten_acosh_0", torch.ops.aten.acosh, (rnd(torch.float32, (10, 10)),), dict()),
@@ -182,9 +182,10 @@
       ("aten_avg_pool2d_padding_num", torch.ops.aten.avg_pool2d, (rnd(torch.float32, (1, 3, 6, 6)), [3, 3], [1, 1], 1, False, True, None), dict()),
       # ("aten_avg_pool3d_0", torch.ops.aten.avg_pool3d, (rnd(torch.float32, (1, 3, 10, 10, 10)), [2, 2, 2], [2, 2, 2], [0, 0, 0], False, False,), dict()),
       ("aten_bmm_0", torch.ops.aten.bmm, (rnd(torch.float32, (10, 10, 10)), rnd(torch.float32, (10, 10, 10)),), dict()),
-      ("aten_cat_0", torch.ops.aten.cat, ([torch.randn((10, 10)).to(torch.float32)], 1,), dict()),
-      ("aten_cat_1", torch.ops.aten.cat, ([torch.randn((10, 10)).to(torch.float32)], 1,), dict()),
-      ("aten_cat_2", torch.ops.aten.cat, ([torch.randn((10, 10)).to(torch.float32)], 1,), dict()),
+      ("aten_cat_0", torch.ops.aten.cat, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (1, 10))], 0,), dict()),
+      ("aten_cat_1", torch.ops.aten.cat, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (0, 10))], 0,), dict()),
+      ("aten_cat_2", torch.ops.aten.cat, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (0,))], 0,), dict()),
+      ("aten_cat_3", torch.ops.aten.cat, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10))], 0,), dict()),
       ("aten__cdist_forward_0", torch.ops.aten._cdist_forward, (rnd(torch.float32, (5, 7, 10)), rnd(torch.float32, (5, 8, 10)), 1.0, None,), dict()),
       ("aten_ceil_0", torch.ops.aten.ceil, (rnd(torch.float32, (10, 10)),), dict()),
       ("aten_clamp_0", torch.ops.aten.clamp, (rnd(torch.float32, (10, 10)), 0, 1,), dict()),
@@ -229,7 +230,8 @@
       ("aten_div_Scalar_0", torch.ops.aten.div.Scalar, (rnd(torch.float32, (10, 10)), 0.5,), dict()),
       ("aten_div_Scalar_mode_0", torch.ops.aten.div.Scalar_mode, (rnd(torch.float32, (10, 10)), 0.123,), {"rounding_mode": "trunc"}),
       ("aten_div_Tensor_0", torch.ops.aten.div.Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
-      ("aten_div_Tensor_mode_0", torch.ops.aten.div.Tensor_mode, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), {"rounding_mode": "trunc"}),
+      ("aten_div_Tensor_mode_trunc_0", torch.ops.aten.div.Tensor_mode, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), {"rounding_mode": "trunc"}),
+      ("aten_div_Tensor_mode_trunc_1", torch.ops.aten.div.Tensor_mode, (rnd(torch.int32, (10, 10)), rnd(torch.int32, (10, 10)),), {"rounding_mode": "trunc"}),
       ("aten_embedding_0", torch.ops.aten.embedding, (rnd(torch.float32, (10, 10)), rnd(torch.int64, (10,)),), dict()),
       ("aten_eq_Scalar_2", torch.ops.aten.eq.Scalar, (rnd(torch.float32, (10, 10)), 1,), dict()),
       ("aten_eq_Tensor_0", torch.ops.aten.eq.Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
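The new `aten_div_Tensor_mode_trunc_1` case exercises the dtype fix in `_wrap.py`: with integer inputs and `rounding_mode="trunc"`, torch keeps the integer dtype, while a JAX-bridged implementation may produce a floating-point result that the added `stablehlo.convert` now casts back. The `rnd` default of `min_v = 1` presumably also keeps zeros out of the integer divisor. A small sanity check in plain PyTorch (not part of the patch):

import torch

a = torch.randint(1, 10, (10, 10), dtype=torch.int32)
b = torch.randint(1, 10, (10, 10), dtype=torch.int32)
# Truncating division on integer tensors stays integer-typed in torch; the
# bridged lowering's result must match this dtype.
out = torch.ops.aten.div.Tensor_mode(a, b, rounding_mode="trunc")
print(out.dtype)  # torch.int32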