-
Notifications
You must be signed in to change notification settings - Fork 48
/
Copy pathnotebooks_fill_params.py
65 lines (51 loc) · 1.9 KB
/
notebooks_fill_params.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import re
import shutil
import sys
# Google Cloud project ID that replace_project() writes into notebook
# PROJECT_ID form fields. It is read once, at import time, so importing this
# module raises KeyError when the GOOGLE_CLOUD_PROJECT environment variable
# is unset.
GOOGLE_CLOUD_PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"]
def make_backup(notebook_path: str):
    """Copy a notebook to a sibling file named ``<notebook_path>.backup``.

    The copy is made with ``shutil.copy``, so file contents and permission
    bits are copied. Any existing file at the backup location is overwritten.

    Args:
        notebook_path: Path of the notebook file to back up.
    """
    backup_path = "{}.backup".format(notebook_path)
    shutil.copy(notebook_path, backup_path)
def replace_project(line, project_id=None):
    """Fill in an empty Colab ``PROJECT_ID`` form field.

    Notebooks contain special Colab ``# @param {type:"string"}`` comments,
    which make it easy for customers to fill in their own information. A line
    of the form ``PROJECT_ID = "" # @param {type:"string"}`` is rewritten to
    carry the project ID. Whitespace differences are ignored when matching.
    Any other line is returned unchanged.

    Args:
        line: One line of a notebook cell's source. It may or may not end
            with a line terminator.
        project_id: Project ID to insert. Defaults to the module-level
            ``GOOGLE_CLOUD_PROJECT`` value.

    Returns:
        The rewritten line, ending with the same terminator as ``line``
        (or none), or ``line`` itself when it is not an empty PROJECT_ID
        form field.
    """
    # Make sure we're robust to whitespace differences.
    cleaned = re.sub(r"\s", "", line)
    if cleaned != 'PROJECT_ID=""#@param{type:"string"}':
        return line
    if project_id is None:
        project_id = GOOGLE_CLOUD_PROJECT
    # Keep the line's own terminator. The last line of a cell's source has
    # none, and unconditionally appending "\n" would add a blank line to the
    # cell (and would also turn "\r\n" into "\n").
    body = line.rstrip("\r\n")
    ending = line[len(body):]
    return f'PROJECT_ID = "{project_id}" # @param {{type:"string"}}{ending}'
def replace_params(notebook_path: str):
    """Rewrite Colab form parameters in a notebook file, in place.

    Each line of every cell's source is passed through ``replace_project``.
    The notebook is then written back as UTF-8 JSON with 2-space indentation.

    nbformat allows a cell's ``source`` to be either a list of lines or a
    single multi-line string. Both forms are handled, and each cell keeps the
    form it had. Cells without a ``source`` key are left untouched.

    Args:
        notebook_path: Path to the ``.ipynb`` file to modify.
    """
    with open(notebook_path, "r", encoding="utf-8") as notebook_file:
        notebook_json = json.load(notebook_file)
    for cell in notebook_json["cells"]:
        source = cell.get("source")
        if source is None:
            # Don't inject a "source" key into cells that never had one.
            continue
        if isinstance(source, str):
            # Iterating a str yields characters, not lines. Split into lines
            # (keeping terminators) and join back so the cell stays a string.
            lines = source.splitlines(keepends=True)
            cell["source"] = "".join(replace_project(line) for line in lines)
        else:
            cell["source"] = [replace_project(line) for line in source]
    with open(notebook_path, "w", encoding="utf-8") as notebook_file:
        json.dump(notebook_json, notebook_file, indent=2, ensure_ascii=False)
def main(notebook_paths):
    """Back up, then fill in the form parameters of, each given notebook.

    Args:
        notebook_paths: Iterable of notebook file paths, processed in order.
    """
    for path in notebook_paths:
        # Back up first, so the untouched notebook remains available as
        # ``<path>.backup`` before it is rewritten in place.
        for step in (make_backup, replace_params):
            step(path)
if __name__ == "__main__":
    # Every command-line argument is treated as a notebook path to process.
    main(notebook_paths=sys.argv[1:])