|
18 | 18 | from torch._inductor.ir import ComputedBuffer |
19 | 19 | from torch._inductor.ir import IRNode |
20 | 20 | from torch._inductor.ir import Pointwise |
| 21 | +from torch._inductor.scheduler import BaseSchedulerNode |
21 | 22 | import torch._inductor.lowering # noqa: F401 |
22 | 23 | from torch._inductor.ops_handler import DefaultHandler |
23 | 24 | from torch._inductor.virtualized import V |
@@ -153,18 +154,41 @@ def capture_epilogue(n: str) -> str: |
153 | 154 | epilogue_items = [] |
154 | 155 | elif not tb._fusion_store_map: |
155 | 156 | epilogue_items = list(tb._epilogue_specs.items()) |
| 157 | + if len(epilogue_items) > 1: |
| 158 | + epilogue_items = [] |
156 | 159 | else: |
157 | 160 | epilogue_items = [] |
158 | 161 |
|
159 | 162 | extra_stores: list[ast.expr] = [] |
160 | 163 | for acc_name, nodes in epilogue_items: |
161 | 164 | if nodes: |
162 | | - acc_map = {acc_name: ast.unparse(value)} |
| 165 | + ep_nodes = list(nodes) |
| 166 | + if len(ep_nodes) == 1 and isinstance(ep_nodes[0], BaseSchedulerNode): |
| 167 | + ep_nodes = list(ep_nodes[0].get_nodes()) |
| 168 | + if len(ep_nodes) > 1: |
| 169 | + graph_outputs = set(V.graph.get_output_names()) |
| 170 | + filtered = [ |
| 171 | + n |
| 172 | + for n in ep_nodes |
| 173 | + if any(o.get_name() in graph_outputs for o in n.get_outputs()) |
| 174 | + ] |
| 175 | + if filtered: |
| 176 | + ep_nodes = filtered |
| 177 | + value_str = ast.unparse(value) |
| 178 | + acc_map = {acc_name: value_str} |
| 179 | + # Add aliases to accumulator map for epilogue fusion |
| 180 | + alias_map = getattr(template_buffer, "_helion_alias_map", None) |
| 181 | + if alias_map: |
| 182 | + for alias in alias_map: |
| 183 | + acc_map.setdefault(alias, value_str) |
| 184 | + if template_buffer._output_aliases: |
| 185 | + for alias in template_buffer._output_aliases: |
| 186 | + acc_map.setdefault(alias, value_str) |
163 | 187 | value = _invoke_pointwise_with_ops_handler( |
164 | | - nodes, acc_map, subscript_names, capture_epilogue, "epilogue" |
| 188 | + ep_nodes, acc_map, subscript_names, capture_epilogue, "epilogue" |
165 | 189 | ) |
166 | 190 | # If epilogue changes dtype, store to epilogue output buffer instead |
167 | | - epilogue_out = nodes[-1].node |
| 191 | + epilogue_out = ep_nodes[-1].node |
168 | 192 | if isinstance(epilogue_out, ComputedBuffer): |
169 | 193 | if epilogue_out.get_dtype() != V.graph.get_dtype(acc_name): |
170 | 194 | param = capture_epilogue(epilogue_out.get_name()) |
@@ -292,6 +316,11 @@ def codegen_prologue_fusion( |
292 | 316 | if not nodes: |
293 | 317 | return value |
294 | 318 |
|
| 319 | + if input_name in tb._helion_mutated_input_names: |
| 320 | + if input_name in tb._prologue_fused_once: |
| 321 | + return value |
| 322 | + tb._prologue_fused_once.add(input_name) |
| 323 | + |
295 | 324 | subscript_names = _get_subscript_names(state, subscript) |
296 | 325 | if not subscript_names: |
297 | 326 | return value |
|
0 commit comments