|
| 1 | +"""Tools to analyze tasks running in asyncio programs.""" |
| 2 | + |
| 3 | +from dataclasses import dataclass |
| 4 | +from collections import defaultdict |
| 5 | +from itertools import count |
| 6 | +from enum import Enum |
| 7 | +import sys |
| 8 | +from _remotedebugging import get_all_awaited_by |
| 9 | + |
| 10 | + |
class NodeType(Enum):
    """Kind tag used as the first element of every tree node key."""

    # A coroutine frame sitting between an awaiting task and the awaited task.
    COROUTINE = 1
    # An asyncio task, keyed by its task id.
    TASK = 2
| 14 | + |
| 15 | + |
class CycleFoundException(Exception):
    """Raised when there is a cycle when drawing the call tree.

    Attributes:
        cycles: the detected cycles, each a list of task ids.
        id2name: mapping from task id to task name, used to label the ids.
    """

    # NOTE: deliberately not a frozen dataclass.  A frozen ``__setattr__``
    # rejects ordinary exception bookkeeping such as assigning
    # ``exc.__traceback__``, and the generated field-based ``__hash__``
    # fails because the fields are a list and a dict.
    def __init__(
        self,
        cycles: list[list[int]],
        id2name: dict[int, str],
    ) -> None:
        super().__init__(cycles, id2name)
        self.cycles = cycles
        self.id2name = id2name
| 21 | + |
| 22 | + |
| 23 | +# ─── indexing helpers ─────────────────────────────────────────── |
| 24 | +def _index(result): |
| 25 | + id2name, awaits = {}, [] |
| 26 | + for _thr_id, tasks in result: |
| 27 | + for tid, tname, awaited in tasks: |
| 28 | + id2name[tid] = tname |
| 29 | + for stack, parent_id in awaited: |
| 30 | + awaits.append((parent_id, stack, tid)) |
| 31 | + return id2name, awaits |
| 32 | + |
| 33 | + |
def _build_tree(id2name, awaits):
    """Build the labelled parent -> children structure of the await tree.

    Every task becomes a ``(NodeType.TASK, task_id)`` node.  Each frame of
    an await stack becomes a ``(NodeType.COROUTINE, "c<N>")`` node; frames
    with the same name under the same parent share one node.

    Returns ``(id2label, children)``: a dict mapping node key to display
    label, and a ``defaultdict(list)`` mapping node key to its ordered
    child node keys.
    """
    id2label = {}
    for task_id, task_name in id2name.items():
        id2label[(NodeType.TASK, task_id)] = task_name
    children = defaultdict(list)
    frames_under = defaultdict(dict)  # parent node -> {frame name: node key}
    serial = count(1)

    def _cor_node(parent_key, frame_name):
        """Return the coroutine node for *frame_name* under *parent_key*,
        creating and registering it on first use."""
        known = frames_under[parent_key]
        node_key = known.get(frame_name)
        if node_key is None:
            node_key = (NodeType.COROUTINE, f"c{next(serial)}")
            id2label[node_key] = frame_name
            children[parent_key].append(node_key)
            known[frame_name] = node_key
        return node_key

    # Lay down a parent -> ...frames... -> child path for every await edge.
    for parent_id, stack, child_id in awaits:
        cursor = (NodeType.TASK, parent_id)
        for frame in reversed(stack):  # outer-most -> inner-most
            cursor = _cor_node(cursor, frame)
        leaf = (NodeType.TASK, child_id)
        siblings = children[cursor]
        if leaf not in siblings:
            siblings.append(leaf)

    return id2label, children
| 61 | + |
| 62 | + |
| 63 | +def _roots(id2label, children): |
| 64 | + all_children = {c for kids in children.values() for c in kids} |
| 65 | + return [n for n in id2label if n not in all_children] |
| 66 | + |
| 67 | +# ─── detect cycles in the task-to-task graph ─────────────────────── |
| 68 | +def _task_graph(awaits): |
| 69 | + """Return {parent_task_id: {child_task_id, …}, …}.""" |
| 70 | + g = defaultdict(set) |
| 71 | + for parent_id, _stack, child_id in awaits: |
| 72 | + g[parent_id].add(child_id) |
| 73 | + return g |
| 74 | + |
| 75 | + |
| 76 | +def _find_cycles(graph): |
| 77 | + """ |
| 78 | + Depth-first search for back-edges. |
| 79 | +
|
| 80 | + Returns a list of cycles (each cycle is a list of task-ids) or an |
| 81 | + empty list if the graph is acyclic. |
| 82 | + """ |
| 83 | + WHITE, GREY, BLACK = 0, 1, 2 |
| 84 | + color = defaultdict(lambda: WHITE) |
| 85 | + path, cycles = [], [] |
| 86 | + |
| 87 | + def dfs(v): |
| 88 | + color[v] = GREY |
| 89 | + path.append(v) |
| 90 | + for w in graph.get(v, ()): |
| 91 | + if color[w] == WHITE: |
| 92 | + dfs(w) |
| 93 | + elif color[w] == GREY: # back-edge → cycle! |
| 94 | + i = path.index(w) |
| 95 | + cycles.append(path[i:] + [w]) # make a copy |
| 96 | + color[v] = BLACK |
| 97 | + path.pop() |
| 98 | + |
| 99 | + for v in list(graph): |
| 100 | + if color[v] == WHITE: |
| 101 | + dfs(v) |
| 102 | + return cycles |
| 103 | + |
| 104 | + |
| 105 | +# ─── PRINT TREE FUNCTION ─────────────────────────────────────── |
def build_async_tree(result, task_emoji="(T)", cor_emoji=""):
    """
    Build the lines of a pretty-printed async call tree.

    *result* is the per-thread task listing that `_index()` consumes (the
    structure `display_awaited_by_tasks_tree()` obtains from
    `get_all_awaited_by()`).  Task labels are prefixed with `task_emoji`
    and coroutine frame labels with `cor_emoji`.

    Returns a list with one entry per root node; each entry is the list of
    text lines that draw that root's subtree.

    Raises CycleFoundException if the task-to-task await graph contains a
    cycle, because such a graph cannot be drawn as a tree.
    """
    id2name, awaits = _index(result)
    cycles = _find_cycles(_task_graph(awaits))
    if cycles:
        raise CycleFoundException(cycles, id2name)
    labels, children = _build_tree(id2name, awaits)

    def pretty(node):
        flag = task_emoji if node[0] == NodeType.TASK else cor_emoji
        return f"{flag} {labels[node]}"

    def render(root):
        """Return the lines for *root*'s subtree in pre-order.

        Uses an explicit stack instead of recursion so that very deep
        await chains cannot hit the interpreter recursion limit.
        """
        lines = []
        pending = [(root, "", True)]
        while pending:
            node, prefix, last = pending.pop()
            lines.append(f"{prefix}{'└── ' if last else '├── '}{pretty(node)}")
            # The continuation prefix is four columns wide, matching the
            # width of the connectors above so nested levels line up.
            child_prefix = prefix + ("    " if last else "│   ")
            kids = children.get(node, [])
            final = len(kids) - 1
            # Push in reverse so the first child is popped (drawn) first.
            for i in range(final, -1, -1):
                pending.append((kids[i], child_prefix, i == final))
        return lines

    return [render(root) for root in _roots(labels, children)]
| 135 | + |
| 136 | + |
def build_task_table(result):
    """Build the rows of the flat task table.

    *result* is iterated as ``(thread_id, tasks)`` pairs whose tasks are
    ``(task_id, task_name, awaited_by)`` triples.  Each row is a list of
    ``[thread id, hex task id, task name, coroutine chain, awaiter name,
    hex awaiter id]``.  A task with no awaiters gets a single row with an
    empty chain, an empty awaiter name and ``"0x0"`` as awaiter id;
    otherwise there is one row per ``(stack, awaiter_id)`` entry.
    Awaiters whose id has no known name are reported as ``"Unknown"``.
    """
    names, _awaits = _index(result)
    rows = []
    for thread_id, thread_tasks in result:
        for task_id, task_name, awaited_by in thread_tasks:
            if not awaited_by:
                rows.append(
                    [thread_id, hex(task_id), task_name, "", "", "0x0"]
                )
            rows.extend(
                [
                    thread_id,
                    hex(task_id),
                    task_name,
                    " -> ".join(stack),
                    names.get(awaiter_id, "Unknown"),
                    hex(awaiter_id),
                ]
                for stack, awaiter_id in awaited_by
            )
    return rows
| 168 | + |
def _print_cycle_exception(exception: CycleFoundException):
    """Write an error banner and one line per cycle to stderr.

    Each task id in a cycle is shown by its name from
    ``exception.id2name``, falling back to the hex id when no name is known.
    """
    print("ERROR: await-graph contains cycles – cannot print a tree!", file=sys.stderr)
    print("", file=sys.stderr)
    names = exception.id2name
    for cycle in exception.cycles:
        labels = [names.get(tid, hex(tid)) for tid in cycle]
        print("cycle: " + " → ".join(labels), file=sys.stderr)
| 175 | + |
| 176 | + |
def _get_awaited_by_tasks(pid: int) -> list:
    """Return what `get_all_awaited_by(pid)` reports, or exit on failure.

    If the call raises RuntimeError, the innermost exception in its
    ``__context__`` chain is reported on stderr and the process exits with
    status 1.
    """
    try:
        return get_all_awaited_by(pid)
    except RuntimeError as e:
        # Follow the implicit exception chain to its root so the message
        # shows the original failure rather than the outermost wrapper.
        root = e
        while root.__context__ is not None:
            root = root.__context__
        # Diagnostics go to stderr, like the cycle report, so stdout only
        # ever carries the table or tree.
        print(f"Error retrieving tasks: {root}", file=sys.stderr)
        sys.exit(1)
| 185 | + |
| 186 | + |
def display_awaited_by_tasks_table(pid: int) -> None:
    """Build and print a table of all pending tasks under `pid`."""
    rows = build_task_table(_get_awaited_by_tasks(pid))

    # One shared layout for the header and every data row: six
    # left-aligned, space-separated columns.
    layout = "{:<10} {:<20} {:<20} {:<50} {:<20} {:<15}"
    print(
        layout.format(
            "tid",
            "task id",
            "task name",
            "coroutine chain",
            "awaiter name",
            "awaiter id",
        )
    )
    print("-" * 135)
    for row in rows:
        print(layout.format(*row))
| 199 | + |
| 200 | + |
def display_awaited_by_tasks_tree(pid: int) -> None:
    """Build and print a tree of all pending tasks under `pid`.

    If the await graph contains cycles, they are reported through
    `_print_cycle_exception()` and the process exits with status 1.
    """
    awaited_by = _get_awaited_by_tasks(pid)
    try:
        trees = build_async_tree(awaited_by)
    except CycleFoundException as exc:
        _print_cycle_exception(exc)
        sys.exit(1)

    # Each entry holds the lines of one root's subtree.
    for lines in trees:
        print("\n".join(lines))
0 commit comments