Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions src/ninetoothed/ascendifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import ast


class Ascendifier(ast.NodeTransformer):
    """Rewrite a Triton kernel AST with Ascend-oriented substitutions.

    The transformer applies the following rewrites while walking a tree:

    * ``triton.language.float64`` attributes become ``triton.language.float32``.
    * ``from triton.language.extra import libdevice`` is redirected to the
      ``triton.language.extra.cann`` module.
    * ``triton.language.load(..., other=None)`` gets ``other=0.0``.
    * ``triton.language.clamp(x, lo, hi)`` becomes
      ``triton.language.minimum(triton.language.maximum(x, lo), hi)``.
    * ``triton.autotune(key=[...])`` literals are filtered down to "size"
      keys and truncated to ``max_axes`` entries.

    Nodes are mutated in place where possible; ``visit`` returns the
    (possibly replaced) node, as usual for ``ast.NodeTransformer``.
    """

    def __init__(self):
        super().__init__()
        # ``None`` means "no limit": slicing with ``[:None]`` keeps every key.
        self.max_axes = None
        try:
            from triton.backends.ascend.runtime.utils import valid_axis_names
        except ImportError:
            # The Ascend backend is not importable; keep the limit unset.
            pass
        else:
            self.max_axes = len(valid_axis_names)

    @staticmethod
    def _is_triton_language_name(node):
        """Return whether ``node`` is a ``Name`` whose id is ``"triton.language"``.

        NOTE(review): this matches a single ``ast.Name`` carrying the dotted
        string as its identifier, which the parser never produces; presumably
        such nodes are built programmatically elsewhere -- confirm.
        """
        return isinstance(node, ast.Name) and node.id == "triton.language"

    @staticmethod
    def _is_triton_language(node):
        """Return whether ``node`` is the attribute access ``triton.language``."""
        return (
            isinstance(node, ast.Attribute)
            and node.attr == "language"
            and isinstance(node.value, ast.Name)
            and node.value.id == "triton"
        )

    @classmethod
    def _is_triton_language_member(cls, node, member):
        """Return whether ``node`` is ``triton.language.<member>``."""
        return (
            isinstance(node, ast.Attribute)
            and node.attr == member
            and (
                cls._is_triton_language_name(node.value)
                or cls._is_triton_language(node.value)
            )
        )

    @staticmethod
    def _clone(node):
        """Return an independent copy of an expression node.

        The copy is made by unparsing ``node`` and re-parsing the text in
        ``eval`` mode, so the result shares no sub-nodes with the original.
        """
        return ast.fix_missing_locations(ast.parse(ast.unparse(node), mode="eval").body)

    @classmethod
    def _make_member_call(cls, namespace, member, *args):
        """Build the call ``<namespace>.<member>(*args)`` from cloned nodes."""
        return ast.Call(
            func=ast.Attribute(
                value=cls._clone(namespace), attr=member, ctx=ast.Load()
            ),
            args=[cls._clone(arg) for arg in args],
            keywords=[],
        )

    @staticmethod
    def _triton_language_namespace():
        """Build a fresh ``triton.language`` attribute node."""
        return ast.Attribute(
            value=ast.Name(id="triton", ctx=ast.Load()),
            attr="language",
            ctx=ast.Load(),
        )

    @staticmethod
    def _is_triton_autotune_call(node):
        """Return whether ``node`` is a call of the form ``triton.autotune(...)``."""
        return (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == "triton"
            and node.func.attr == "autotune"
        )

    @staticmethod
    def _is_sequence_literal(node):
        """Return whether ``node`` is a list or tuple literal."""
        return isinstance(node, (ast.List, ast.Tuple))

    @staticmethod
    def _autotune_key_priority(item):
        """Sort key for ``(index, key_node)`` pairs of autotune keys.

        Lower priorities sort first: plain keys (0), then keys mentioning
        ``constexpr`` (1), then keys mentioning ``next_power_of_2`` (2).
        The original index breaks ties so the input order is preserved.
        """
        index, key_node = item
        value = str(key_node.value)
        if "next_power_of_2" in value:
            priority = 2
        elif "constexpr" in value:
            priority = 1
        else:
            priority = 0

        return priority, index

    @classmethod
    def _filter_autotune_keys(cls, key_nodes, max_axes):
        """Select the autotune keys to keep.

        Only ``ast.Constant`` keys whose string form contains ``"size"`` are
        retained. They are ordered by ``_autotune_key_priority`` and truncated
        to ``max_axes`` entries (``None`` keeps all of them).
        """
        size_keys = [
            key_node
            for key_node in key_nodes
            if isinstance(key_node, ast.Constant) and "size" in str(key_node.value)
        ]
        ranked = sorted(enumerate(size_keys), key=cls._autotune_key_priority)
        return [key_node for _, key_node in ranked][:max_axes]

    @classmethod
    def _rewrite_square_block_autotune_configs(cls, config_nodes):
        """Hook for rewriting the elements of an autotune ``configs`` literal.

        This method was referenced but never defined, so every
        ``triton.autotune(configs=[...])`` call raised ``AttributeError``.
        It is provided as a no-op; ``config_nodes`` is left untouched.
        """
        return None

    @classmethod
    def _rewrite_tail_key_boundary_masks(cls, node):
        """Hook for a module-level rewrite applied after all other visits.

        This method was referenced by ``visit_Module`` but never defined, so
        visiting any module raised ``AttributeError``. It is provided as a
        no-op that returns ``node`` unchanged.
        """
        return node

    @classmethod
    def _rewrite_autotune_keyword(cls, keyword, max_axes):
        """Rewrite one keyword argument of a ``triton.autotune`` call in place.

        Only ``configs`` and ``key`` keywords whose value is a list/tuple
        literal are touched; everything else is left as is.
        """
        if not cls._is_sequence_literal(keyword.value):
            return

        if keyword.arg == "configs":
            cls._rewrite_square_block_autotune_configs(keyword.value.elts)
        elif keyword.arg == "key":
            keyword.value.elts = cls._filter_autotune_keys(keyword.value.elts, max_axes)

    @classmethod
    def _rewrite_autotune_call(cls, node, max_axes):
        """Rewrite every keyword of a ``triton.autotune`` call in place."""
        for keyword in node.keywords:
            cls._rewrite_autotune_keyword(keyword, max_axes)

    @classmethod
    def _rewrite_load_call(cls, node):
        """Replace ``other=None`` with ``other=0.0`` in ``triton.language.load``.

        Calls to anything else, and ``other`` values that are not the literal
        ``None``, are left untouched. The node is mutated in place.
        """
        if not cls._is_triton_language_member(node.func, "load"):
            return

        for keyword in node.keywords:
            if (
                keyword.arg == "other"
                and isinstance(keyword.value, ast.Constant)
                and keyword.value.value is None
            ):
                keyword.value.value = 0.0

    @classmethod
    def _rewrite_clamp_call(cls, node):
        """Expand ``triton.language.clamp(x, lo, hi)`` into min/max calls.

        Returns ``minimum(maximum(x, lo), hi)`` in the same namespace as the
        original call. Calls with fewer than three positional arguments or
        with any keyword arguments are returned unchanged. Positional
        arguments beyond the third are not carried over.
        """
        if not cls._is_triton_language_member(node.func, "clamp"):
            return node

        if len(node.args) < 3 or node.keywords:
            return node

        namespace = node.func.value
        maximum = cls._make_member_call(
            namespace, "maximum", node.args[0], node.args[1]
        )
        return cls._make_member_call(namespace, "minimum", maximum, node.args[2])

    def visit_Attribute(self, node):
        """Rename ``triton.language.float64`` to ``triton.language.float32``."""
        self.generic_visit(node)

        if type(self)._is_triton_language_member(node, "float64"):
            node.attr = "float32"

        return node

    def visit_ImportFrom(self, node):
        """Redirect the ``libdevice`` import to ``triton.language.extra.cann``.

        The module is rewritten whenever ``libdevice`` is among the imported
        names of ``from triton.language.extra import ...``; the whole import
        statement (including any other names) then points at the new module.
        """
        self.generic_visit(node)

        if node.module == "triton.language.extra" and any(
            alias.name == "libdevice" for alias in node.names
        ):
            node.module = "triton.language.extra.cann"

        return node

    def visit_Call(self, node):
        """Apply the autotune, load and clamp rewrites to a call node."""
        self.generic_visit(node)

        cls = type(self)
        if cls._is_triton_autotune_call(node):
            cls._rewrite_autotune_call(node, self.max_axes)

        cls._rewrite_load_call(node)
        # The clamp rewrite may return a brand-new node replacing ``node``.
        return cls._rewrite_clamp_call(node)

    def visit_Module(self, node):
        """Visit the module body, then run the module-level rewrite hook."""
        self.generic_visit(node)

        return type(self)._rewrite_tail_key_boundary_masks(node)
55 changes: 48 additions & 7 deletions src/ninetoothed/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,48 @@
from triton.language.extra import libdevice

import ninetoothed.naming as naming
from ninetoothed.ascendifier import Ascendifier
from ninetoothed.cudaifier import Cudaifier
from ninetoothed.language import attribute, call
from ninetoothed.symbol import Symbol
from ninetoothed.tensor import Tensor
from ninetoothed.torchifier import Torchifier

CACHE_DIR = pathlib.Path.home() / ".ninetoothed"
CACHE_DIR.mkdir(exist_ok=True)

def _resolve_cache_dir():
for cache_dir in (
pathlib.Path.home() / ".ninetoothed",
pathlib.Path("/tmp/.ninetoothed"),
):
try:
cache_dir.mkdir(exist_ok=True)
probe = cache_dir / ".write_probe"
probe.write_text("", encoding="utf-8")
probe.unlink()
return cache_dir
except OSError:
continue

raise OSError("Failed to find a writable cache directory for ninetoothed.")


CACHE_DIR = _resolve_cache_dir()


class CodeGenerator(ast.NodeTransformer):
def __init__(self):
super().__init__()

device = triton.runtime.driver.active.get_current_device()
properties = triton.runtime.driver.active.utils.get_device_properties(device)

self._min_num_elements = 1
properties = {}

try:
device = triton.runtime.driver.active.get_current_device()
properties = triton.runtime.driver.active.utils.get_device_properties(
device
)
except Exception:
properties = {}

if "max_num_regs" in properties:
max_innermost_size = 4 * properties["max_num_regs"]
Expand Down Expand Up @@ -109,11 +133,28 @@ def _find_dependencies(func):
name_collector = _SimplifiedNameCollector()
name_collector.visit(tree)

unparsed = ast.unparse(tree).replace("None:", ":").replace(":None", ":")
npu_tree = copy.deepcopy(tree)
Ascendifier().visit(npu_tree)
ast.fix_missing_locations(npu_tree)

if prettify:
name_collector.visit(npu_tree)

dependencies = _find_dependencies(func)
unparsed = ast.unparse(tree).replace("None:", ":").replace(":None", ":")
source = "\n\n".join((unparsed, dependencies)).strip()
source = source.replace(func.__name__, kernel_name)
source += "\n"
unparsed_npu = ast.unparse(npu_tree).replace("None:", ":").replace(":None", ":")
source_npu = "\n\n".join((unparsed_npu, dependencies)).strip()
source_npu = source_npu.replace(func.__name__, kernel_name + "_npu")

guard_header = (
"import torch\n"
"_IS_NPU = hasattr(torch, 'npu') and torch.npu.is_available()\n\n"
)
source_cuda_guarded = "if not _IS_NPU:\n" + textwrap.indent(source, " ")
source_npu_guarded = "if _IS_NPU:\n" + textwrap.indent(source_npu, " ")
source = guard_header + source_cuda_guarded + "\n\n" + source_npu_guarded + "\n"

if prettify:
for original, simplified in name_collector.simplified_names.items():
Expand Down
14 changes: 12 additions & 2 deletions src/ninetoothed/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,19 @@ def __call__(self):
module = import_from_path(source_file, source_file)
module_vars = vars(module)

target_kernel_name = self._kernel_name
target_launch_name = code_generator.launch_func_name

if self._caller == "torch":
import torch

if hasattr(torch, "npu") and torch.npu.is_available():
target_kernel_name = f"{self._kernel_name}_npu"
target_launch_name = f"{code_generator.launch_func_name}_npu"

handle = _Handle(
module_vars[self._kernel_name],
module_vars[code_generator.launch_func_name],
module_vars[target_kernel_name],
module_vars[target_launch_name],
source_file,
)

Expand Down
24 changes: 16 additions & 8 deletions src/ninetoothed/make.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ def make(
:param arrangement: The arrangement of the tensors.
:param application: The application of the tensors.
:param tensors: The tensors.
:param caller: Who will call the compute kernel.
:param caller: Kernel build route selector.
- ``torch``: JIT
- ``cuda``: CUDA AOT
- ``ascend``: Ascend AOT
:param kernel_name: The name for the generated kernel.
:param output_dir: The directory to store the generated files.
:param num_warps: The number of warps to use.
Expand All @@ -46,11 +49,16 @@ def make(
max_num_configs=max_num_configs,
)

return aot(
application,
caller=caller,
kernel_name=kernel_name,
output_dir=output_dir,
num_warps=num_warps,
num_stages=num_stages,
if caller in ("cuda", "ascend"):
return aot(
application,
caller=caller,
kernel_name=kernel_name,
output_dir=output_dir,
num_warps=num_warps,
num_stages=num_stages,
)

raise ValueError(
f"Unsupported caller '{caller}'. Expected one of: 'torch', 'cuda', 'ascend'."
)
13 changes: 8 additions & 5 deletions src/ninetoothed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@


def calculate_default_configs():
    """Return the default ``(num_warps, num_stages)`` pair.

    ``num_warps`` is always 8. ``num_stages`` is derived from the active
    device's ``max_shared_mem`` property (one stage per 32 KiB, at least 1).
    If the device or its properties cannot be queried, ``num_stages`` falls
    back to 1.

    :return: A ``(num_warps, num_stages)`` tuple.
    """
    num_warps = 8

    try:
        driver = triton.runtime.driver.active
        device = driver.get_current_device()
        properties = driver.utils.get_device_properties(device)
        num_stages = max(1, properties["max_shared_mem"] // 2**15)
    except Exception:
        # Deliberate best effort: the query can fail in backend-specific ways
        # (no active driver, missing property, ...), and a conservative
        # default is preferable to failing here.
        num_stages = 1

    return num_warps, num_stages
6 changes: 6 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,17 @@ def get_available_devices():
if torch.cuda.is_available():
devices.append("cuda")

if hasattr(torch, "npu") and torch.npu.is_available():
devices.append("npu")

if hasattr(torch, "mlu") and torch.mlu.is_available():
devices.append("mlu")

return tuple(devices)


with contextlib.suppress(ImportError, ModuleNotFoundError):
import torch_npu # noqa: F401

with contextlib.suppress(ImportError, ModuleNotFoundError):
import torch_mlu # noqa: F401