-
Notifications
You must be signed in to change notification settings - Fork 1k
Expand file tree
/
Copy pathcompiler.py
More file actions
514 lines (436 loc) · 16.5 KB
/
Copy pathcompiler.py
File metadata and controls
514 lines (436 loc) · 16.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import logging
from pathlib import Path
from typing import Optional
import executorch.backends.cadence.aot.ops_registrations # noqa
import torch
from executorch.backends.cadence.aot.compiler_funcs import (
prepare as prepare_fn,
QuantFusionPass,
QuantizedInputWrapper,
trace as trace_fn,
)
from executorch.backends.cadence.aot.memory_planning import (
CadenceMemoryPlanning,
print_memory_planning_info,
)
from executorch.backends.cadence.aot.quantizer.passes.fuse_ops import FuseQATConvBN
from executorch.backends.cadence.aot.quantizer.quantizer import (
CadenceDefaultQuantizer,
CadenceQuantizer,
)
from executorch.backends.cadence.aot.utils import (
get_default_memory_config,
MemoryConfig,
)
from executorch.devtools import generate_etrecord
from executorch.exir import (
EdgeCompileConfig,
EdgeProgramManager,
ExecutorchBackendConfig,
ExecutorchProgramManager,
)
from executorch.exir.pass_manager import PassManager
from executorch.exir.passes import ToOutVarPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.exir.program._program import to_edge
from torch.export.exported_program import ExportedProgram
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e
from .pass_utils import EdgePassesConfig
from .passes import apply_exir_ops_passes, apply_torch_ops_passes
from .utils import print_ops_info
# Module-level CadenceDefaultQuantizer instance.
# NOTE(review): no function visible in this file references it (quantize_pt2
# constructs its own CadenceDefaultQuantizer when none is supplied); presumably
# kept for external importers — confirm before removing.
default_quantizer = CadenceDefaultQuantizer()
def trace(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    dump_graphs: bool = False,
    ops_to_keep: Optional[list[torch._ops.OpOverload]] = None,
    is_qat: bool = False,
) -> ExportedProgram:
    """
    Trace the model with export and return an ExportedProgram.

    Args:
        model: The module to export.
        inputs: Example positional inputs used for tracing.
        dump_graphs: If True, log the traced graph.
        ops_to_keep: Forwarded to ``trace_fn`` as ``ops_to_keep`` (callers in
            this file pass the quantizer's
            ``get_ops_to_preserve_from_decomposition()``). Defaults to an
            empty list.
        is_qat: Forwarded to ``trace_fn``.

    Returns:
        The ExportedProgram produced by ``trace_fn`` with ``strict=True``.
    """
    if ops_to_keep is None:
        ops_to_keep = []
    program = trace_fn(
        model, inputs, is_qat=is_qat, strict=True, ops_to_keep=ops_to_keep
    )
    if dump_graphs:
        logging.info("Graph before quantization:")
        # Graph.print_tabular() prints to stdout and returns None, so logging
        # its return value only emitted "None". Log the graph's IR text instead.
        logging.info("%s", program.graph_module.graph)
    return program
def prepare_pt2(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    quantizer: CadenceQuantizer,
    dump_graphs: bool = False,
    is_qat: bool = False,
) -> torch.fx.GraphModule:
    """
    Export ``model`` and prepare it for quantization with ``quantizer``.

    The same quantizer must later be used to fuse the model, if applicable.
    Callers that do not want to manage the quantizer themselves should use
    quantize_pt2, which creates a default quantizer when none is given.

    Args:
        model: The module to trace and prepare.
        inputs: Example positional inputs used for tracing.
        quantizer: Quantizer used both to pick ops preserved from
            decomposition and to prepare the traced program.
        dump_graphs: If True, log intermediate graphs.
        is_qat: Forwarded to tracing and preparation.

    Returns:
        The prepared GraphModule.
    """
    exported = trace(
        model,
        inputs,
        dump_graphs=dump_graphs,
        ops_to_keep=quantizer.get_ops_to_preserve_from_decomposition(),
        is_qat=is_qat,
    )
    return prepare_traced_pt2(
        exported,
        quantizer,
        dump_graphs=dump_graphs,
        is_qat=is_qat,
    )
def prepare_traced_pt2(
    program: ExportedProgram,
    quantizer: CadenceQuantizer,
    dump_graphs: bool = False,
    is_qat: bool = False,
) -> torch.fx.GraphModule:
    """
    Prepare a model using the given quantizer.

    The quantizer must be supplied and be the same as the one used to
    fuse the model later, if applicable. If you do not expect that behavior,
    please use quantize_pt2 instead, which will instantiate a
    default quantizer for you if needed.

    Args:
        program: An already-traced ExportedProgram.
        quantizer: Quantizer forwarded to ``prepare_fn``.
        dump_graphs: If True, log the prepared graph.
        is_qat: Forwarded to ``prepare_fn``.

    Returns:
        A GraphModule with the prepared model.
    """
    prepared_model = prepare_fn(program, quantizer, is_qat=is_qat)
    if dump_graphs:
        logging.info("Graph after preparation:")
        # print_tabular() prints to stdout and returns None; log the IR text.
        logging.info("%s", prepared_model.graph)
    return prepared_model
def convert_pt2(
    graph_module: torch.fx.GraphModule,
    dump_graphs: bool = False,
) -> torch.fx.GraphModule:
    """
    Convert a prepared (and calibrated) model with ``convert_pt2e``.

    Args:
        graph_module: The prepared GraphModule, e.g. from prepare_traced_pt2.
        dump_graphs: If True, log the converted graph.

    Returns:
        A GraphModule with the converted model.
    """
    converted_model = convert_pt2e(graph_module)
    if dump_graphs:
        logging.info("Graph after convert:")
        # print_tabular() prints to stdout and returns None; log the IR text.
        logging.info("%s", converted_model.graph)
    return converted_model
# Note: this is not meant as a primary API since it can create inconsistencies
# if the quantizer here is different from the quantizer used to prepare/convert.
# It is however useful for unit tests to separate the converted model from the
# fused model, to be able to get reference numerics.
# If this does not apply, please use quantize_pt2 instead.
def apply_pre_edge_transform_passes(
    converted_program: ExportedProgram,
    quantizer: CadenceQuantizer,
) -> ExportedProgram:
    """
    Run the passes that must happen before to_edge() on a converted program.

    The steps follow the Cadence AOT compiler flow:
      1. FuseQATConvBN, then QuantFusionPass, which fuses dq->op->q patterns
         taken from the quantizer's sub-quantizers.
      2. apply_torch_ops_passes, which is applied just before to_edge().

    ``quantizer`` must be the one used to convert ``converted_program``. If
    that is not what you want, use quantize_pt2, which creates a default
    quantizer when needed.

    Returns:
        An ExportedProgram with the fused model.
    """
    # pyre-ignore[16]: no attribute
    fusion_patterns = [sub.pattern for sub in quantizer.quantizers]
    fusion_passes = PassManager(
        [
            FuseQATConvBN(converted_program),
            QuantFusionPass(fusion_patterns),
        ]
    )
    # The PassManager's return value is discarded: this code relies on the
    # passes rewriting converted_program.graph_module in place.
    fusion_passes(converted_program.graph_module)
    # Torch-level ops passes (e.g., ReplaceMulTensorWithMulAndFullOpsPass).
    return apply_torch_ops_passes(converted_program)
# Note: quantizer is not optional here to force the user to supply a quantizer
# and ensure consistency is more likely to be maintained.
def get_fake_quant_model(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    quantizer: CadenceQuantizer,
    calibration_data: Optional[list[tuple[object, ...]]] = None,
    dump_graphs: bool = False,
) -> torch.fx.GraphModule:
    """
    Trace, prepare, calibrate and convert ``model`` into a fake-quant graph.

    Note: ``model.eval()`` is called on the caller's module (in place)
    before tracing.

    Args:
        model: The module to quantize.
        inputs: Example positional inputs used for tracing.
        quantizer: Quantizer used for op preservation and preparation.
        calibration_data: Sample argument tuples run through the prepared
            graph. If None, ``inputs`` is used as the only sample.
        dump_graphs: If True, log intermediate graphs.

    Returns:
        The converted (fake-quant) GraphModule from convert_pt2.
    """
    # Put the model in inference mode before tracing.
    model.eval()
    ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition()
    program = trace(model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep)
    if dump_graphs:
        logging.info("Graph after trace:")
        # print_tabular() prints to stdout and returns None; log the IR text.
        logging.info("%s", program.graph)
    # Get prepared graph module
    prepared_gm = prepare_traced_pt2(program, quantizer, dump_graphs=dump_graphs)
    # Calibrate. If no calibration data is provided, use the example inputs.
    if calibration_data is None:
        calibration_data = [inputs]
    for samples in calibration_data:
        prepared_gm(*samples)
    # Get converted graph module
    return convert_pt2(prepared_gm, dump_graphs=dump_graphs)
def quantize_pt2(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    quantizer: Optional[CadenceQuantizer] = None,
    calibration_data: Optional[list[tuple[object, ...]]] = None,
    dump_graphs: bool = False,
    quant_input_args: Optional[list[str]] = None,
) -> ExportedProgram:
    """
    Trace, prepare, convert and fuse the model using the given quantizer.

    If calibration data is provided, it will be used to calibrate the model. If
    not, the inputs will be used for calibration instead, which is useful for
    unit tests but should not be used for end-to-end use cases.

    Args:
        model: The module to quantize.
        inputs: Example positional inputs used for tracing and export.
        quantizer: Quantizer to use; a CadenceDefaultQuantizer is created when
            None.
        calibration_data: Optional calibration samples (see above).
        dump_graphs: If True, log intermediate graphs.
        quant_input_args: If given, the converted module is wrapped in
            QuantizedInputWrapper with these names, and dequant nodes are sunk
            through transparent ops after export.

    Returns:
        An ExportedProgram with the quantized model.

    Note: this function should not be called directly in general. Please use
    quantize_and_export_to_executorch for most needs.
    """
    # Instantiate a default quantizer only when none was supplied (explicit
    # None check instead of relying on the quantizer object's truthiness).
    if quantizer is None:
        quantizer = CadenceDefaultQuantizer()
    # Get the converted (aka fake quant) graph module
    converted_gm = get_fake_quant_model(
        model,
        inputs,
        quantizer=quantizer,
        calibration_data=calibration_data,
        dump_graphs=dump_graphs,
    )
    # Wrap the model to handle quantized inputs if provided
    if quant_input_args is not None:
        converted_gm = QuantizedInputWrapper(converted_gm, quant_input_args)
    # Re-export the converted graph module so the fusion passes can run on it.
    program = torch.export.export(converted_gm, inputs, strict=True)
    # Sink dequant nodes through transparent ops so they fuse per-branch.
    if quant_input_args is not None:
        QuantizedInputWrapper.sink_dequants(program)
    # Apply quant fusion to the exported program
    fused_program = apply_pre_edge_transform_passes(program, quantizer)
    if dump_graphs:
        logging.info("Graph after quantization and fusion:")
        # print_tabular() prints to stdout and returns None; log the IR text.
        logging.info("%s", fused_program.graph_module.graph)
    return fused_program
# Non-core ATen ops that to_edge() is allowed to leave in the edge IR.
# _lower_ep_to_edge passes this (plus any caller-supplied extras) as
# EdgeCompileConfig._core_aten_ops_exception_list.
TO_EDGE_OP_EXCEPTION_LIST: list[torch._ops.OpOverload] = [
    torch.ops.aten._linalg_det.default,
    torch.ops.aten._linalg_svd.default,
    torch.ops.aten._native_batch_norm_legit_functional.default,
    torch.ops.aten.linear.default,
    torch.ops.aten.linalg_vector_norm.default,
    torch.ops.aten.unfold.default,
    torch.ops.aten.angle.default,
    torch.ops.aten.rms_norm.default,
]
# Ops passed to to_edge() as EdgeCompileConfig.preserve_ops in
# _lower_ep_to_edge, presumably so they are kept rather than decomposed.
# NOTE(review): rms_norm also appears in the exception list above, presumably
# because a preserved non-core op must also pass the core-ATen check — confirm.
TO_EDGE_PRESERVE_OPS: list[torch._ops.OpOverload] = [
    torch.ops.aten.rms_norm.default,
]
def _lower_ep_to_edge(
    expo_program: ExportedProgram,
    dump_graphs: bool = False,
    constant_methods: Optional[dict[str, object]] = None,
    core_aten_exceptions: Optional[list[torch._ops.OpOverload]] = None,
) -> EdgeProgramManager:
    """
    Lower an ExportedProgram to an EdgeProgramManager (in edge IR).

    Args:
        expo_program: The program to lower.
        dump_graphs: If True, log the edge graph after lowering.
        constant_methods: Forwarded to ``to_edge``.
        core_aten_exceptions: Extra non-core ATen ops to allow in the edge IR,
            appended to TO_EDGE_OP_EXCEPTION_LIST.

    Returns:
        The EdgeProgramManager produced by ``to_edge``.
    """
    # Apply passes which transform the ExportedProgram before it gets lowered
    # to edge.
    # NOTE(review): programs coming from quantize_pt2 already went through
    # apply_torch_ops_passes (via apply_pre_edge_transform_passes); running
    # them again here presumably relies on those passes being idempotent.
    expo_program = apply_torch_ops_passes(expo_program)
    # Call to_edge to convert the graph to edge IR.
    # Note: dim_order is skipped (https://github.com/pytorch/executorch/issues/3704)
    edge_prog_manager = to_edge(
        expo_program,
        compile_config=EdgeCompileConfig(
            _skip_dim_order=True,
            # Allow specific non-core aten ops in the IR.
            _core_aten_ops_exception_list=TO_EDGE_OP_EXCEPTION_LIST
            + (core_aten_exceptions or []),
            preserve_ops=TO_EDGE_PRESERVE_OPS,
        ),
        constant_methods=constant_methods,
    )
    if dump_graphs:
        logging.info("Graph after Edge lowering:")
        # print_tabular() prints to stdout and returns None; log the IR text.
        logging.info(
            "%s", edge_prog_manager.exported_program().graph_module.graph
        )
    return edge_prog_manager
# Export the model and lower it to an EdgeProgramManager (in edge IR).
def export_to_edge(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    dump_graphs: bool = False,
    constant_methods: Optional[dict[str, object]] = None,
    core_aten_exceptions: Optional[list[torch._ops.OpOverload]] = None,
) -> EdgeProgramManager:
    """
    Export ``model`` with ``trace`` (no quantization) and lower it to edge IR.

    Args:
        model: The module to export; must be an ``nn.Module``.
        inputs: Example positional inputs used for tracing.
        dump_graphs: If True, log the edge graph.
        constant_methods: Forwarded to ``_lower_ep_to_edge``.
        core_aten_exceptions: Forwarded to ``_lower_ep_to_edge``.

    Returns:
        The resulting EdgeProgramManager.
    """
    assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
    return _lower_ep_to_edge(
        trace(model, inputs),
        dump_graphs,
        constant_methods,
        core_aten_exceptions,
    )
def quantize_and_export_to_edge(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    quantizer: Optional[CadenceQuantizer] = None,
    dump_graphs: bool = False,
    constant_methods: Optional[dict[str, object]] = None,
    calibration_data: Optional[list[tuple[object, ...]]] = None,
    core_aten_exceptions: Optional[list[torch._ops.OpOverload]] = None,
) -> EdgeProgramManager:
    """
    Quantize ``model`` with quantize_pt2, then lower the result to edge IR.

    ``quantizer`` and ``calibration_data`` go to quantize_pt2;
    ``constant_methods`` and ``core_aten_exceptions`` go to the edge lowering.
    ``dump_graphs`` is forwarded to both steps.
    """
    fused_program = quantize_pt2(
        model,
        inputs,
        quantizer=quantizer,
        calibration_data=calibration_data,
        dump_graphs=dump_graphs,
    )
    edge_manager = _lower_ep_to_edge(
        fused_program,
        dump_graphs=dump_graphs,
        constant_methods=constant_methods,
        core_aten_exceptions=core_aten_exceptions,
    )
    return edge_manager
def _lower_ep_to_cadence(
    program: ExportedProgram,
    dump_graphs: bool = False,
    opt_level: int = 1,
    edge_passes_config: Optional[EdgePassesConfig] = None,
) -> EdgeProgramManager:
    """
    Lower an existing ExportedProgram to edge IR, then run the Cadence
    frontend optimization passes on it.

    ``opt_level`` and ``edge_passes_config`` are forwarded to
    apply_exir_ops_passes.
    """
    return apply_exir_ops_passes(
        opt_level,
        _lower_ep_to_edge(program, dump_graphs=dump_graphs),
        edge_passes_config,
    )
def export_to_cadence(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    dump_graphs: bool = False,
    opt_level: int = 1,
    edge_passes_config: Optional[EdgePassesConfig] = None,
) -> EdgeProgramManager:
    """
    Export ``model`` (no quantization), lower it to edge IR and run the
    Cadence frontend optimization passes.

    ``opt_level`` and ``edge_passes_config`` are forwarded to
    apply_exir_ops_passes.
    """
    edge_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)
    return apply_exir_ops_passes(opt_level, edge_manager, edge_passes_config)
def quantize_and_export_to_cadence(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    dump_graphs: bool = False,
    opt_level: int = 1,
    edge_passes_config: Optional[EdgePassesConfig] = None,
    quantizer: Optional[CadenceQuantizer] = None,
    calibration_data: Optional[list[tuple[object, ...]]] = None,
) -> EdgeProgramManager:
    """
    Trace, quantize, lower a model/inputs pair to edge IR and apply frontend
    optimization passes.

    Args:
        model: The module to quantize and lower.
        inputs: Example positional inputs used for tracing.
        dump_graphs: If True, log intermediate graphs. This is forwarded to
            both the quantization and the lowering steps.
        opt_level: Optimization level for the Cadence passes.
        edge_passes_config: Optional config forwarded to the Cadence passes.
        quantizer: Quantizer for quantize_pt2; a default one is created when
            None (the previous, implicit behavior).
        calibration_data: Optional calibration samples for quantize_pt2. When
            None, ``inputs`` is used (the previous, implicit behavior).

    Returns:
        The EdgeProgramManager after Cadence frontend passes.
    """
    # Forward dump_graphs (previously dropped) so quantization graphs are
    # dumped too, matching quantize_and_export_to_edge.
    quantized_model = quantize_pt2(
        model,
        inputs,
        quantizer=quantizer,
        calibration_data=calibration_data,
        dump_graphs=dump_graphs,
    )
    return _lower_ep_to_cadence(
        quantized_model,
        opt_level=opt_level,
        dump_graphs=dump_graphs,
        edge_passes_config=edge_passes_config,
    )
def export_to_executorch_gen_etrecord(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    output_dir: Optional[str] = None,
    opt_level: int = 1,
    mem_algo: int = 0,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
    memory_config: Optional[MemoryConfig] = None,
    dump_graphs: bool = False,
) -> ExecutorchProgramManager:
    """
    Strictly export ``model`` with torch.export, then lower it to an
    ExecuTorch program via _lower_ep_to_cadence_gen_etrecord.

    All remaining arguments are forwarded unchanged. An ETRecord is written
    only when ``output_dir`` is given.
    """
    exported = torch.export.export(model, inputs, strict=True)
    lowering_kwargs = dict(
        output_dir=output_dir,
        opt_level=opt_level,
        mem_algo=mem_algo,
        alloc_graph_input=alloc_graph_input,
        alloc_graph_output=alloc_graph_output,
        memory_config=memory_config,
        dump_graphs=dump_graphs,
    )
    return _lower_ep_to_cadence_gen_etrecord(exported, **lowering_kwargs)
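# Lower an already-exported program to edge IR and apply passes specific to
# Cadence DSP execution. Print the op differences between the edge and Cadence
# graphs, then memory-plan and emit an ExecutorchProgramManager, optionally
# writing an ETRecord to output_dir.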
def _lower_ep_to_cadence_gen_etrecord(
    ep: ExportedProgram,
    output_dir: Optional[str] = None,
    opt_level: int = 1,
    mem_algo: int = 0,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
    memory_config: Optional[MemoryConfig] = None,
    dump_graphs: bool = False,
) -> ExecutorchProgramManager:
    """
    Lower ``ep`` all the way to an ExecutorchProgramManager for Cadence.

    Steps: edge lowering, Cadence frontend passes, op-diff report, Cadence
    memory planning plus to_executorch, memory-planning report, and finally
    an ETRecord written under ``output_dir`` (skipped, with a warning, when
    ``output_dir`` is falsy).

    Args:
        ep: The exported program to lower.
        output_dir: Directory for the ETRecord, or None to skip it.
        opt_level: Optimization level for passes and memory planning.
        mem_algo: Memory planning algorithm selector for CadenceMemoryPlanning.
        alloc_graph_input: Forwarded to memory planning and its report.
        alloc_graph_output: Forwarded to memory planning and its report.
        memory_config: Memory config; get_default_memory_config() when None.
        dump_graphs: If True, log the edge graph.

    Returns:
        The final ExecutorchProgramManager.
    """
    edge_manager = _lower_ep_to_edge(ep, dump_graphs)
    cadence_manager = apply_exir_ops_passes(opt_level, edge_manager)

    # Report how the op set changed between the edge and Cadence graphs.
    print_ops_info(
        edge_manager.exported_program().graph_module,
        cadence_manager.exported_program().graph_module,
    )

    mem_cfg = (
        memory_config if memory_config is not None else get_default_memory_config()
    )
    backend_config = ExecutorchBackendConfig(
        memory_planning_pass=CadenceMemoryPlanning(
            mem_cfg,
            opt_level=opt_level,
            mem_algo=mem_algo,
            alloc_graph_input=alloc_graph_input,
            alloc_graph_output=alloc_graph_output,
        ),
        emit_stacktrace=False,
        to_out_var_pass=ToOutVarPass(),
        extract_delegate_segments=False,
        sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
    )
    # Executorch program after the Cadence-specific passes.
    exec_prog: ExecutorchProgramManager = cadence_manager.to_executorch(
        backend_config,
    )

    print_memory_planning_info(
        exec_prog,
        mem_cfg,
        opt_level,
        alloc_graph_input,
        alloc_graph_output,
    )

    if not output_dir:
        logging.warning("No output directory provided, skipping ETRecord generation")
        return exec_prog
    _gen_etrecord(edge_manager, exec_prog, Path(output_dir))
    return exec_prog
def _gen_etrecord(
    edge_program: EdgeProgramManager,
    et_program: ExecutorchProgramManager,
    output_dir: Path,
) -> None:
    """
    Best-effort write of an ETRecord to ``output_dir / "etrecord.bin"``.

    Any exception raised while generating the record is logged with its
    traceback and then suppressed, so the caller's flow continues.
    """
    record_path = output_dir.joinpath("etrecord.bin")
    try:
        generate_etrecord(
            et_record=record_path,
            edge_dialect_program=edge_program,
            executorch_program=et_program,
        )
        logging.info(f"Generated ETRecord at {record_path}")
    except Exception:
        # Deliberately broad: failing to produce this artifact must not block
        # the rest of the flow.
        logging.exception("Encountered exception while generating ETRecord")