-
Notifications
You must be signed in to change notification settings - Fork 127
Use new Triton runtime #1338
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use new Triton runtime #1338
Changes from all commits
888678f
9dc26e4
69e0421
a1d78ee
9e34b19
307d98e
7e38b79
12cc6e8
86dd713
4d8c87c
c2d9b55
cd7ef4f
5717afb
b1d57b2
ea4868e
1c81cc0
24e5fe9
8534d1c
58d9023
cbd938e
1a3886e
c8252f5
6184f20
9470117
7fdf663
f704f1f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,11 @@ | |
import sysconfig | ||
import tempfile | ||
import types | ||
from concurrent.futures import Future | ||
from concurrent.futures import ThreadPoolExecutor | ||
from ctypes import cdll | ||
from typing import Any | ||
from typing import Dict | ||
|
||
from torch.utils import cpp_extension | ||
|
||
|
@@ -160,9 +164,10 @@ def load(cls, source_code): | |
code = compile(f.read(), path, "exec") | ||
mod = types.ModuleType(f"{__name__}.{key}") | ||
mod.__file__ = path | ||
mod.key = key | ||
exec(code, mod.__dict__, mod.__dict__) | ||
cls.cache[key] = mod | ||
cls.cache[key].key = key | ||
# another thread might set this first | ||
cls.cache.setdefault(key, mod) | ||
return cls.cache[key] | ||
|
||
|
||
|
@@ -174,7 +179,54 @@ def patch_triton_dir(): | |
|
||
|
||
class TritonCodeCache: | ||
@staticmethod | ||
def get_name(mod): | ||
(name,) = [n for n in dir(mod) if n.startswith("kernel")] | ||
return name | ||
|
||
@classmethod | ||
def load(cls, source_code): | ||
patch_triton_dir() | ||
return PyCodeCache.load(source_code) | ||
mod = PyCodeCache.load(source_code) | ||
return getattr(mod, cls.get_name(mod)) | ||
|
||
|
||
class AsyncCompile: | ||
@staticmethod | ||
@functools.lru_cache(1) | ||
def pool(): | ||
assert config.compile_threads > 1 | ||
return ThreadPoolExecutor(config.compile_threads) | ||
|
||
@classmethod | ||
def submit(cls, task): | ||
if config.compile_threads <= 1: | ||
return task() | ||
return cls.pool().submit(task) | ||
|
||
@classmethod | ||
def map(cls, fn, seq): | ||
if config.compile_threads <= 1 or len(seq) <= 1: | ||
return list(map(fn, seq)) | ||
return [t.result() for t in [cls.pool().submit(fn, x) for x in seq]] | ||
|
||
def triton(self, source_code): | ||
kernel = TritonCodeCache.load(source_code) | ||
|
||
def task(): | ||
kernel.precompile() | ||
return kernel | ||
|
||
return self.submit(task) | ||
|
||
def cpp(self, source_code): | ||
def task(): | ||
return CppCodeCache.load(source_code).kernel | ||
Comment on lines
+214
to
+224
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The cache load happens at subtly different times between these two. Triton's loads in the call to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The C++ cache load calls gcc, which is expensive (and also inherently thread safe). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Interesting issue with potential relevance: #1347 |
||
|
||
return self.submit(task) | ||
|
||
def wait(self, scope: Dict[str, Any]): | ||
if config.compile_threads > 1: | ||
for key, result in list(scope.items()): | ||
if isinstance(result, Future): | ||
voznesenskym marked this conversation as resolved.
Show resolved
Hide resolved
|
||
scope[key] = result.result() |
Uh oh!
There was an error while loading. Please reload this page.