Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity] Replacing unary ops with LookUpTable and Take op to improve performance. #15614

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions python/tvm/contrib/hexagon/generate_take_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Pass to replace unary ops with Look Up Table and take op"""
import tvm
import tvm.testing
from tvm import relax
from tvm.contrib.hexagon import hexagon_unary_ops


def op_replace(call_node):
    """Tell whether ``call_node`` is a ``call_tir`` to a replaceable unary op.

    A node qualifies when it is a ``relax.Call`` whose operator is
    ``relax.call_tir`` and whose callee name hint contains one of the unary
    op names listed below (substring match).

    Parameters
    ----------
    call_node : relax.Expr
        The node to inspect.

    Returns
    -------
    bool
        True when the node matches one of the replaceable unary ops.
    """
    replaceable_ops = (
        "tanh",
        "sqrt",
        "rsqrt",
        "exp",
        "erf",
        "sigmoid",
        "hardswish",
        "log",
        "abs",
    )

    # Anything that is not a relax call cannot be a call_tir.
    if not isinstance(call_node, relax.Call):
        return False
    if call_node.op != tvm.ir.Op.get("relax.call_tir"):
        return False

    # For call_tir the first argument is the callee; match on its name hint.
    callee_name = call_node.args[0].name_hint
    return any(candidate in callee_name for candidate in replaceable_ops)


@relax.expr_functor.mutator
class Tanh2TakeReplace(tvm.relax.PyExprMutator):
    """Mutator that rewrites quantized unary ops into a LUT plus a take op.

    Every ``call_tir`` accepted by ``op_replace`` whose first tensor argument
    is ``uint8`` is replaced by a ``call_tir`` to a generated "take" PrimFunc
    that indexes a 256-entry look-up table.
    """

    def __init__(self, mod: tvm.IRModule) -> None:
        super().__init__(mod)
        self.mod_ = mod

    def transform(self) -> tvm.IRModule:
        """Rewrite every relax function of the module and return the result."""
        for global_var, func in self.mod_.functions.items():
            # Skip non-relax functions
            if not isinstance(func, relax.Function):
                continue
            updated_func = self.visit_expr(func)
            # NOTE(review): the value returned by normalize() is not used here;
            # confirm whether the normalized function should be what is stored.
            self.builder_.normalize(updated_func)
            self.builder_.update_func(global_var, updated_func)
        # At the end of the transformation we return the updated IRModule from the BlockBuilder.
        return self.builder_.get()

    def visit_call_(self, call_node: relax.Call) -> relax.Call:
        """Replace a matching quantized unary ``call_tir`` with a take over a LUT.

        Calls that do not match are returned unchanged.
        """
        # Establish that this is a call_tir to a supported unary op *before*
        # touching args[1]; other calls may have fewer arguments or a
        # non-tuple second argument.
        if not op_replace(call_node):
            return call_node

        tir_args = call_node.args[1]
        # The rewrite needs exactly (input, inp_scale, inp_zp, out_scale, out_zp).
        if not isinstance(tir_args, relax.expr.Tuple) or len(tir_args.fields) != 5:
            return call_node

        inp, inp_scale, inp_zp, out_scale, out_zp = list(tir_args.fields)
        if inp.struct_info.dtype != "uint8":
            return call_node

        # LUT node creation
        lut = hexagon_unary_ops.lut_generation(
            inp_scale, inp_zp, out_scale, out_zp, call_node.args[0].name_hint
        )
        # Take operation node creation
        take_func = hexagon_unary_ops.generate_take_primfunc(inp, call_node.struct_info)
        take_func_gv = self.builder_.add_func(take_func, "take")
        take_node = relax.call_tir(
            take_func_gv,
            relax.expr.Tuple([inp, relax.expr.Constant(tvm.nd.array(lut))]),
            call_node.struct_info,
        )
        return take_node


@tvm.ir.transform.module_pass(opt_level=2, name="replace_unaryop_take")
class PassReplaceWithTakeOpPrimFuncs:
    """Module pass that runs ``Tanh2TakeReplace`` over an IRModule."""

    def transform_module(self, mod, ctx):
        """Return the module produced by ``Tanh2TakeReplace(mod).transform()``.

        ``ctx`` is accepted as part of the pass signature and is not used.
        """
        rewriter = Tanh2TakeReplace(mod)
        return rewriter.transform()
105 changes: 105 additions & 0 deletions python/tvm/contrib/hexagon/hexagon_unary_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Primitive Function for lut and Take Op"""
import numpy as np
from scipy import special
from typing import List
from tvm import te
from tvm.tir.function import PrimFunc


def saturate(x: te.Tensor, dtype: str):
    """Clamp ``x`` to the minimum and maximum values of ``dtype``.

    Parameters
    ----------
    x : te.Tensor
        The value to clamp.

    dtype : str
        Data type whose ``te.min_value`` / ``te.max_value`` bound the result.

    Returns
    -------
    The expression ``max(min_value(dtype), min(x, max_value(dtype)))``.
    """
    lower_bound = te.min_value(dtype)
    upper_bound = te.max_value(dtype)
    capped = te.min(x, upper_bound)
    return te.max(lower_bound, capped)


def hardswish_func(x):
    """Evaluate hardswish, ``x * clip(x + 3, 0, 6) / 6``, with NumPy."""
    gate = np.clip(np.add(x, 3.0), 0.0, 6.0)
    return x * gate / 6.0


# Unary ops for which a look-up table can be generated.
_LUT_UNARY_OPS = ("tanh", "sqrt", "rsqrt", "exp", "erf", "sigmoid", "hardswish", "log", "abs")


def _resolve_lut_op(op_name):
    """Map a PrimFunc name (e.g. ``"tanh1"``) onto one of ``_LUT_UNARY_OPS``."""
    if op_name in _LUT_UNARY_OPS:
        return op_name
    contained = [name for name in _LUT_UNARY_OPS if name in op_name]
    if not contained:
        raise ValueError(f"Unsupported unary op for LUT generation: {op_name!r}")
    # The longest match wins so that e.g. "rsqrt" is not mistaken for "sqrt".
    return max(contained, key=len)


def _scalar_constant(const, label):
    """Return the NumPy scalar stored in a rank-0 constant, else raise ValueError."""
    if const.data.shape != ():
        raise ValueError(
            f"{label} must be a scalar constant, got shape {tuple(const.data.shape)}"
        )
    return const.data.numpy()[()]


def _apply_unary_op(kind, value):
    """Evaluate the unary op named ``kind`` on a dequantized value."""
    if kind == "tanh":
        return np.tanh(value)
    if kind == "sqrt":
        return np.sqrt(value)
    if kind == "rsqrt":
        return 1 / np.sqrt(value)
    if kind == "exp":
        return np.exp(value)
    if kind == "erf":
        return special.erf(value)
    if kind == "sigmoid":
        return 1 / (1 + np.exp(np.negative(value)))
    if kind == "hardswish":
        return hardswish_func(value)
    if kind == "log":
        return np.log(value)
    if kind == "abs":
        return np.abs(value)
    raise ValueError(f"Unsupported unary op for LUT generation: {kind!r}")


def lut_generation(inp_scale, inp_zp, out_scale, out_zp, op_name) -> List[np.uint8]:
    """Generating the Look Up Table for unary ops

    For every possible uint8 input code (0..255) the value is dequantized
    with the input scale / zero point, the unary op is applied, and the
    result is re-quantized with the output scale / zero point and clamped
    to [0, 255].

    Parameters
    ----------
    inp_scale, inp_zp, out_scale, out_zp
        Constants exposing ``.data.shape`` and ``.data.numpy()``; each must
        hold a scalar (shape ``()``).

    op_name : str
        Name of the unary op, or a function name containing it (for example
        a PrimFunc name hint such as ``"tanh1"``).

    Returns
    -------
    lut : List[np.uint8]
        The 256 table entries, indexed by the quantized input code.

    Raises
    ------
    ValueError
        If ``op_name`` does not name a supported unary op, or if any
        quantization parameter is not a scalar.
    """
    kind = _resolve_lut_op(op_name)

    # The quantization parameters do not depend on the table index, so they
    # are converted to NumPy scalars once, up front.
    i_zp = _scalar_constant(inp_zp, "inp_zp")
    i_scale = _scalar_constant(inp_scale, "inp_scale")
    o_zp = _scalar_constant(out_zp, "out_zp")
    o_scale = _scalar_constant(out_scale, "out_scale")

    lut = []
    for code in range(256):
        # Dequantization followed by computing the op value
        dequant = (np.int32(code) - i_zp) * i_scale
        op_val = _apply_unary_op(kind, dequant)
        # NOTE(review): NaN results (e.g. sqrt/log of a negative input) reach
        # the uint8 cast unchanged; confirm the desired table entry for them.
        # Quantizing the value generated and appending in the Look Up Table
        quant = np.round(op_val / o_scale) + o_zp
        lut.append(np.maximum(0, np.minimum(quant, 255)).astype(np.uint8))
    return lut


def generate_take_primfunc(inp, struct_info) -> PrimFunc:
    """Build the primitive function that looks values up in the LUT.

    Parameters
    ----------
    inp : expr.Var
        The input whose elements are used as indices into the lut. Its
        ``struct_info.shape`` is unpacked into exactly four dimensions.

    struct_info : TensorStructInfo
        Struct info supplying the dtype of the data placeholder and the
        shape and dtype of the computed output.

    Returns
    -------
    PrimFunc
        The take op primitive function with parameters (data, lut, output).
    """
    out_dtype = struct_info.dtype
    dim_n, dim_h, dim_w, dim_c = inp.struct_info.shape
    data = te.placeholder((dim_n, dim_h, dim_w, dim_c), dtype=out_dtype, name="data")
    lut_func = te.placeholder((256,), dtype="uint8", name="lut")

    def _lookup(*indices):
        # The data element, cast to uint8, is the index into the 256-entry table.
        table_index = data[indices].astype("uint8")
        return saturate(lut_func[table_index], out_dtype).astype(out_dtype)

    take = te.compute(struct_info.shape, _lookup, name="take_op")
    return te.create_prim_func([data, lut_func, take])
Loading