
Commit

add jax bridge results elty sanitizer
PiperOrigin-RevId: 684202212
chunnienc authored and copybara-github committed Oct 9, 2024
1 parent 18d7630 commit 64abab3
Showing 4 changed files with 70 additions and 8 deletions.
36 changes: 34 additions & 2 deletions ai_edge_torch/odml_torch/jax_bridge/_wrap.py
@@ -24,6 +24,7 @@
import jax
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import hlo as stablehlo
import torch.utils._pytree as pytree

# Jax double (64bit) precision is required to generate StableHLO mlir with
@@ -143,8 +144,39 @@ def wrapped(lctx, *args, **kwargs):
ir_inputs = []

results = func.CallOp(cloned_func, ir_inputs).results

if lctx.node is None:
  return results[0] if len(results) == 1 else results

out_avals = lctx.node.meta.get("tensor_meta") or lctx.node.meta.get("val")

if out_avals is None:
  return results[0] if len(results) == 1 else results

def sanitize_result_elty(result, aval):
  # JAX implementations may not respect the aten op's output dtype. For
  # example, JAX may apply slightly different dtype upcast rules, leading to
  # a different result dtype from the bridged lowering than from the torch
  # op's output. Here we add an additional `stablehlo.convert` op when the
  # dtypes do not match, to ensure the lowering's result dtype is always the
  # same as the torch op's output dtype.
  if aval is None:
    return result

  target_elty = export_utils.torch_dtype_to_ir_element_type(
      lctx.ir_context, aval.dtype
  )
  if result.type.element_type == target_elty:
    return result
  return stablehlo.convert(
      ir.RankedTensorType.get(result.type.shape, target_elty), result
  )

if len(results) == 1:
return results[0]
return results
return sanitize_result_elty(results[0], out_avals)
return [
sanitize_result_elty(result, aval)
for result, aval in zip(results, out_avals)
]

return wrapped
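
For context, a minimal illustration of the kind of dtype drift the sanitizer above guards against, comparing plain torch and JAX semantics. This snippet is an illustrative sketch, not part of the commit, and assumes default JAX dtype settings; the new aten_div_Tensor_mode_trunc_1 test below exercises the integer case.

import jax.numpy as jnp
import torch

# aten.div.Tensor_mode with rounding_mode="trunc" keeps the integer dtype
# on integer inputs in torch.
a = torch.arange(6, dtype=torch.int32)
b = torch.full((6,), 4, dtype=torch.int32)
t = torch.ops.aten.div.Tensor_mode(a, b, rounding_mode="trunc")
print(t.dtype)  # torch.int32

# A JAX formulation of the same math promotes int32 / int32 to a float
# dtype under true division, so a bridged lowering can end up with a
# floating-point result type; the sanitizer converts such results back to
# the torch op's output dtype with stablehlo.convert.
j = jnp.trunc(jnp.arange(6, dtype=jnp.int32) / jnp.full((6,), 4, dtype=jnp.int32))
print(j.dtype)  # float32 (float64 if x64 is enabled)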
31 changes: 30 additions & 1 deletion ai_edge_torch/odml_torch/lowerings/_basic.py
@@ -15,13 +15,17 @@
import math
from typing import Optional, Union

from ai_edge_torch.odml_torch import export_utils
from ai_edge_torch.odml_torch.lowerings import context
from ai_edge_torch.odml_torch.lowerings import registry
from ai_edge_torch.odml_torch.lowerings import utils
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo as stablehlo
import numpy as np
import torch

from .registry import lower
LoweringContext = context.LoweringContext
lower = registry.lower


# add(Tensor self, Tensor other) -> Tensor
@@ -204,6 +208,31 @@ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
return stablehlo.divide(x, y)


# Schema:
# - aten::cat(Tensor[] tensors, int dim=0) -> Tensor
# Torch Reference:
# - https://pytorch.org/docs/main/generated/torch.cat.html
@lower(torch.ops.aten.cat.default)
def _aten_cat(lctx: LoweringContext, tensors, dim=0):
  assert tensors
  non_empty_tensors = [t for t in tensors if np.prod(t.type.shape) != 0]
  out_aval = lctx.node.meta.get("tensor_meta") or lctx.node.meta.get("val")
  if not non_empty_tensors:
    return utils.splat(
        0,
        export_utils.torch_dtype_to_ir_element_type(
            lctx.ir_context, out_aval.dtype
        ),
        out_aval.shape,
    )

  if dim < 0:
    dim = dim + len(out_aval.shape)
  dim = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), dim)

  return stablehlo.concatenate(non_empty_tensors, dim)
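
As background for the empty-tensor filtering in _aten_cat, a small spot check of plain torch.cat semantics that the new aten_cat test cases below mirror. This is an illustrative sketch using standard torch APIs, not part of the commit.

import torch

a = torch.randn(10, 10)

# Zero-element operands contribute nothing to the result, and torch.cat
# even accepts the legacy 1-D empty shape (0,) next to 2-D inputs.
# StableHLO's concatenate is stricter about operand ranks and shapes, so
# the lowering drops zero-element operands before emitting it, and falls
# back to a splat of the output shape when every operand is empty.
print(torch.cat([a, torch.zeros(0, 10)], dim=0).shape)  # torch.Size([10, 10])
print(torch.cat([a, torch.zeros(0)], dim=0).shape)      # torch.Size([10, 10])
print(torch.cat([a, torch.randn(1, 10)], dim=0).shape)  # torch.Size([11, 10])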


# Schema:
# - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
# start=None, SymInt? end=None, SymInt step=1) -> Tensor
1 change: 0 additions & 1 deletion ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
@@ -98,7 +98,6 @@ def lower_by_torch_xla2(op):
lower_by_torch_xla2(torch.ops.aten.bitwise_or)
lower_by_torch_xla2(torch.ops.aten.bitwise_xor)
lower_by_torch_xla2(torch.ops.aten.bmm)
lower_by_torch_xla2(torch.ops.aten.cat)
lower_by_torch_xla2(torch.ops.aten.ceil)
lower_by_torch_xla2(torch.ops.aten.clamp.Tensor)
lower_by_torch_xla2(torch.ops.aten.clamp.default)
10 changes: 6 additions & 4 deletions ai_edge_torch/odml_torch/test/test_core_aten_ops.py
@@ -182,9 +182,10 @@ def _run_export_and_compare(
("aten_avg_pool2d_padding_num", torch.ops.aten.avg_pool2d, (rnd(torch.float32, (1, 3, 6, 6)), [3, 3], [1, 1], 1, False, True, None), dict()),
# ("aten_avg_pool3d_0", torch.ops.aten.avg_pool3d, (rnd(torch.float32, (1, 3, 10, 10, 10)), [2, 2, 2], [2, 2, 2], [0, 0, 0], False, False,), dict()),
("aten_bmm_0", torch.ops.aten.bmm, (rnd(torch.float32, (10, 10, 10)), rnd(torch.float32, (10, 10, 10)),), dict()),
("aten_cat_0", torch.ops.aten.cat, ([torch.randn((10, 10)).to(torch.float32)], 1,), dict()),
("aten_cat_1", torch.ops.aten.cat, ([torch.randn((10, 10)).to(torch.float32)], 1,), dict()),
("aten_cat_2", torch.ops.aten.cat, ([torch.randn((10, 10)).to(torch.float32)], 1,), dict()),
("aten_cat_0", torch.ops.aten.cat, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (1, 10))], 0,), dict()),
("aten_cat_1", torch.ops.aten.cat, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (0, 10))], 0,), dict()),
("aten_cat_2", torch.ops.aten.cat, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (0,))], 0,), dict()),
("aten_cat_3", torch.ops.aten.cat, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10))], 0,), dict()),
("aten__cdist_forward_0", torch.ops.aten._cdist_forward, (rnd(torch.float32, (5, 7, 10)), rnd(torch.float32, (5, 8, 10)), 1.0, None,), dict()),
("aten_ceil_0", torch.ops.aten.ceil, (rnd(torch.float32, (10, 10)),), dict()),
("aten_clamp_0", torch.ops.aten.clamp, (rnd(torch.float32, (10, 10)), 0, 1,), dict()),
@@ -229,7 +230,8 @@ def _run_export_and_compare(
("aten_div_Scalar_0", torch.ops.aten.div.Scalar, (rnd(torch.float32, (10, 10)), 0.5,), dict()),
("aten_div_Scalar_mode_0", torch.ops.aten.div.Scalar_mode, (rnd(torch.float32, (10, 10)), 0.123,), {"rounding_mode": "trunc"}),
("aten_div_Tensor_0", torch.ops.aten.div.Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
("aten_div_Tensor_mode_0", torch.ops.aten.div.Tensor_mode, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), {"rounding_mode": "trunc"}),
("aten_div_Tensor_mode_trunc_0", torch.ops.aten.div.Tensor_mode, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), {"rounding_mode": "trunc"}),
("aten_div_Tensor_mode_trunc_1", torch.ops.aten.div.Tensor_mode, (rnd(torch.int32, (10, 10)), rnd(torch.int32, (10, 10)),), {"rounding_mode": "trunc"}),
("aten_embedding_0", torch.ops.aten.embedding, (rnd(torch.float32, (10, 10)), rnd(torch.int64, (10,)),), dict()),
("aten_eq_Scalar_2", torch.ops.aten.eq.Scalar, (rnd(torch.float32, (10, 10)), 1,), dict()),
("aten_eq_Tensor_0", torch.ops.aten.eq.Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
