From db2d26d6170bec51f406cf466bd6c37cf74a9264 Mon Sep 17 00:00:00 2001 From: Chun-nien Chan Date: Wed, 9 Oct 2024 16:01:39 -0700 Subject: [PATCH] add jax bridge results elty sanitizer PiperOrigin-RevId: 684202212 --- ai_edge_torch/odml_torch/jax_bridge/_wrap.py | 20 ++++++++++-- ai_edge_torch/odml_torch/lowerings/_basic.py | 31 ++++++++++++++++++- .../odml_torch/lowerings/_jax_lowerings.py | 1 - .../odml_torch/test/test_core_aten_ops.py | 7 +++-- 4 files changed, 52 insertions(+), 7 deletions(-) diff --git a/ai_edge_torch/odml_torch/jax_bridge/_wrap.py b/ai_edge_torch/odml_torch/jax_bridge/_wrap.py index 32853253..fae27325 100644 --- a/ai_edge_torch/odml_torch/jax_bridge/_wrap.py +++ b/ai_edge_torch/odml_torch/jax_bridge/_wrap.py @@ -24,6 +24,7 @@ import jax from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func +from jax._src.lib.mlir.dialects import hlo as stablehlo import torch.utils._pytree as pytree # Jax double (64bit) precision is required to generate StableHLO mlir with @@ -143,8 +144,23 @@ def wrapped(lctx, *args, **kwargs): ir_inputs = [] results = func.CallOp(cloned_func, ir_inputs).results + out_avals = lctx.node.meta.get("tensor_meta") or lctx.node.meta.get("val") + + def sanitize_result_elty(result, aval): + target_elty = export_utils.torch_dtype_to_ir_element_type( + lctx.ir_context, aval.dtype + ) + if result.type.element_type == target_elty: + return result + return stablehlo.convert( + ir.RankedTensorType.get(result.type.shape, target_elty), result + ) + if len(results) == 1: - return results[0] - return results + return sanitize_result_elty(results[0], out_avals) + return [ + sanitize_result_elty(result, aval) + for result, aval in zip(results, out_avals) + ] return wrapped diff --git a/ai_edge_torch/odml_torch/lowerings/_basic.py b/ai_edge_torch/odml_torch/lowerings/_basic.py index d248988c..db22b898 100644 --- a/ai_edge_torch/odml_torch/lowerings/_basic.py +++ b/ai_edge_torch/odml_torch/lowerings/_basic.py @@ -15,13 +15,17 @@ import math from typing import Optional, Union +from ai_edge_torch.odml_torch import export_utils +from ai_edge_torch.odml_torch.lowerings import context +from ai_edge_torch.odml_torch.lowerings import registry from ai_edge_torch.odml_torch.lowerings import utils from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo as stablehlo import numpy as np import torch -from .registry import lower +LoweringContext = context.LoweringContext +lower = registry.lower # add(Tensor self, Tensor other) -> Tensor @@ -204,6 +208,31 @@ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value: return stablehlo.divide(x, y) +# Schema: +# - aten::cat(Tensor[] tensors, int dim=0) -> Tensor +# Torch Reference: +# - https://pytorch.org/docs/main/generated/torch.cat.html +@lower(torch.ops.aten.cat.default) +def _aten_cat(lctx: LoweringContext, tensors, dim=0): + assert tensors + non_empty_tensors = [t for t in tensors if np.prod(t.type.shape) != 0] + out_aval = lctx.node.meta.get("tensor_meta") or lctx.node.meta.get("val") + if not non_empty_tensors: + return utils.splat( + 0, + export_utils.torch_dtype_to_ir_element_type( + lctx.ir_context, out_aval.dtype + ), + out_aval.shape, + ) + + if dim < 0: + dim = dim + len(out_aval.shape) + dim = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), dim) + + return stablehlo.concatenate(non_empty_tensors, dim) + + # Schema: # - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? # start=None, SymInt? end=None, SymInt step=1) -> Tensor diff --git a/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py b/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py index d86e2b13..c9a4feb1 100644 --- a/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +++ b/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py @@ -98,7 +98,6 @@ def lower_by_torch_xla2(op): lower_by_torch_xla2(torch.ops.aten.bitwise_or) lower_by_torch_xla2(torch.ops.aten.bitwise_xor) lower_by_torch_xla2(torch.ops.aten.bmm) -lower_by_torch_xla2(torch.ops.aten.cat) lower_by_torch_xla2(torch.ops.aten.ceil) lower_by_torch_xla2(torch.ops.aten.clamp.Tensor) lower_by_torch_xla2(torch.ops.aten.clamp.default) diff --git a/ai_edge_torch/odml_torch/test/test_core_aten_ops.py b/ai_edge_torch/odml_torch/test/test_core_aten_ops.py index 6647fece..52fd2009 100644 --- a/ai_edge_torch/odml_torch/test/test_core_aten_ops.py +++ b/ai_edge_torch/odml_torch/test/test_core_aten_ops.py @@ -182,9 +182,10 @@ def _run_export_and_compare( ("aten_avg_pool2d_padding_num", torch.ops.aten.avg_pool2d, (rnd(torch.float32, (1, 3, 6, 6)), [3, 3], [1, 1], 1, False, True, None), dict()), # ("aten_avg_pool3d_0", torch.ops.aten.avg_pool3d, (rnd(torch.float32, (1, 3, 10, 10, 10)), [2, 2, 2], [2, 2, 2], [0, 0, 0], False, False,), dict()), ("aten_bmm_0", torch.ops.aten.bmm, (rnd(torch.float32, (10, 10, 10)), rnd(torch.float32, (10, 10, 10)),), dict()), - ("aten_cat_0", torch.ops.aten.cat, ([torch.randn((10, 10)).to(torch.float32)], 1,), dict()), - ("aten_cat_1", torch.ops.aten.cat, ([torch.randn((10, 10)).to(torch.float32)], 1,), dict()), - ("aten_cat_2", torch.ops.aten.cat, ([torch.randn((10, 10)).to(torch.float32)], 1,), dict()), + ("aten_cat_0", torch.ops.aten.cat, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (1, 10))], 0,), dict()), + ("aten_cat_1", torch.ops.aten.cat, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (0, 10))], 0,), dict()), + ("aten_cat_2", torch.ops.aten.cat, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (0,))], 0,), dict()), + ("aten_cat_3", torch.ops.aten.cat, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10))], 0,), dict()), ("aten__cdist_forward_0", torch.ops.aten._cdist_forward, (rnd(torch.float32, (5, 7, 10)), rnd(torch.float32, (5, 8, 10)), 1.0, None,), dict()), ("aten_ceil_0", torch.ops.aten.ceil, (rnd(torch.float32, (10, 10)),), dict()), ("aten_clamp_0", torch.ops.aten.clamp, (rnd(torch.float32, (10, 10)), 0, 1,), dict()),