-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
Copy pathrender_presets.py
129 lines (102 loc) · 3.42 KB
/
render_presets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""Custom rendering code for keras_hub presets.
The model metadata is pulled from the library, each preset has a
metadata dictionary as follows:
{
'description': Description of the model,
'params': Parameter count of the model,
'official_name': Name of the model,
'path': Relative path of the model on keras.io,
}
"""
from hub_master import MODELS_MASTER
try:
import keras_hub
except Exception as e:
print(f"Could not import KerasHub. Exception: {e}")
keras_hub = None
# Markdown header for the table built by `render_all_presets`; it has an
# extra "Model API" column compared with the per-model header below.
TABLE_HEADER = "\n".join(
    [
        "Preset | Model API | Parameters | Description",
        "-------|-----------|------------|------------",
        "",
    ]
)
# Markdown header for the per-model table built by `render_table`.
TABLE_HEADER_PER_MODEL = "\n".join(
    [
        "Preset | Parameters | Description",
        "-------|------------|------------",
        "",
    ]
)
def format_param_count(metadata):
    """Format a preset's parameter count for the presets table.

    Args:
        metadata: The preset's metadata dictionary. The count is read from
            its ``"params"`` entry.

    Returns:
        A string such as ``"124.44M"``, ``"7.00B"`` or ``"512"``. Returns
        ``"Unknown"`` when ``"params"`` is missing or ``None``.
    """
    count = metadata.get("params")
    if count is None:
        return "Unknown"
    for scale, suffix in ((1e9, "B"), (1e6, "M"), (1e3, "K")):
        # Use this unit once the value would print as "1000.00" or more in
        # the next-smaller unit. For counts >= scale this is the same as the
        # plain threshold test; it additionally promotes counts such as
        # 999_999, which would otherwise render as "1000.00K", to "1.00M".
        if round(count / (scale / 1000), 2) >= 1000:
            return f"{count / scale:.2f}{suffix}"
    return f"{count}"
def format_path(metadata):
    """Return a markdown link to the model's API page, or "-" if none matches.

    Entries of ``MODELS_MASTER["children"]`` are checked in order. The first
    entry whose ``"path"`` (with surrounding slashes stripped) equals
    ``metadata["path"]`` supplies the link text (its ``"title"``). That same
    entry also gives the link target ``/keras_hub/api/models/<path>``.
    """
    candidates = (
        (child, child["path"].strip("/")) for child in MODELS_MASTER["children"]
    )
    for child, slug in candidates:
        if metadata["path"] == slug:
            title = child["title"]
            return f"[{title}](/keras_hub/api/models/{slug})"
    return "-"
def format_preset_link(preset, handle):
url = handle.replace("kaggle://", "https://www.kaggle.com/models/")
return f"[{preset}]({url})"
def is_base_class(symbol):
    """Return True if `symbol` is one of KerasHub's generic model base classes.

    The generic bases checked are ``Backbone``, ``Tokenizer``,
    ``Preprocessor`` and ``Task`` from ``keras_hub.models``. Requires
    ``keras_hub`` to have been imported successfully.
    """
    models = keras_hub.models
    generic_bases = (
        models.Backbone,
        models.Tokenizer,
        models.Preprocessor,
        models.Task,
    )
    return symbol in generic_bases
def sort_presets(presets):
    """Return preset names ordered by model path, then by parameter count.

    Presets whose metadata has no ``"params"`` entry (or has it set to
    ``None``) sort after the presets with a known count for the same path.
    They no longer raise ``KeyError``/``TypeError``, which matches
    ``format_param_count`` rendering them as "Unknown". Ties keep the
    mapping's iteration order (``sorted`` is stable).

    Args:
        presets: Mapping from preset name to preset data. Each value must
            contain a ``"metadata"`` dict with a ``"path"`` entry.

    Returns:
        A list of preset names.
    """

    def sort_key(name):
        metadata = presets[name]["metadata"]
        params = metadata.get("params")
        # The boolean keeps `None` from ever being compared with a number.
        missing = params is None
        return (metadata["path"], missing, 0 if missing else params)

    return sorted(presets, key=sort_key)
def render_row(preset, data, add_doc_link=False):
    """Render one markdown table row for a preset.

    Args:
        preset: The preset name, used as the link text of the first column.
        data: The preset's entry from a ``presets`` mapping. It must contain
            ``"metadata"`` and ``"kaggle_handle"``.
        add_doc_link: If True, insert a column (built by ``format_path``)
            linking to the model's API page after the preset link.

    Returns:
        The row's columns joined by ``" | "``, terminated by a newline.
    """
    metadata = data["metadata"]
    # The Kaggle URL is built by format_preset_link; the row does not need
    # its own copy of the handle-to-URL rewrite.
    cols = [format_preset_link(preset, data["kaggle_handle"])]
    if add_doc_link:
        cols.append(format_path(metadata))
    cols.append(format_param_count(metadata))
    cols.append(metadata["description"])
    return " | ".join(cols) + "\n"
def render_all_presets():
    """Render the markdown table of every ``keras_hub.models.Backbone`` preset.

    Each row includes a column linking the preset to its model's API page.

    Returns:
        The table (header plus one row per preset, ordered by
        ``sort_presets``) as a string. Returns an empty string when KerasHub
        could not be imported, the same as ``render_table``.
    """
    if keras_hub is None:
        return ""
    presets = keras_hub.models.Backbone.presets
    rows = [
        render_row(name, presets[name], add_doc_link=True)
        for name in sort_presets(presets)
    ]
    return TABLE_HEADER + "".join(rows)
def render_table(symbol):
    """Render the markdown presets table for a single model class.

    Args:
        symbol: A KerasHub model class exposing a ``presets`` mapping.

    Returns:
        ``""`` if KerasHub could not be imported. ``None`` if `symbol` is
        one of the generic base classes or has no presets. Otherwise, the
        table (header plus one row per preset, ordered by
        ``sort_presets``).
    """
    if keras_hub is None:
        return ""
    if is_base_class(symbol) or len(symbol.presets) == 0:
        return None
    rows = [
        render_row(name, symbol.presets[name])
        for name in sort_presets(symbol.presets)
    ]
    return TABLE_HEADER_PER_MODEL + "".join(rows)
def render_tags(template):
    """Substitute the custom KerasHub tags in `template` with rendered content.

    The only tag handled is ``{{presets_table}}``. Every occurrence is
    replaced with the output of ``render_all_presets``. The template is
    returned unchanged when KerasHub could not be imported or the tag is
    absent.
    """
    tag = "{{presets_table}}"
    if keras_hub is not None and tag in template:
        template = template.replace(tag, render_all_presets())
    return template