-
Notifications
You must be signed in to change notification settings - Fork 190
/
Copy pathgraph_utils_test.py
74 lines (63 loc) · 2.46 KB
/
graph_utils_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for neural_structured_learning.tools.graph_utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
from absl.testing import absltest
from neural_structured_learning.tools import graph_utils
# Shared directed-graph fixture: source node -> {target node: edge weight}.
# testAddEdge rebuilds exactly this dict via graph_utils.add_edge, so it is
# also the expected result there. Note that 'C' only appears as a target.
GRAPH = {
    'A': {'B': 0.5, 'C': 0.9},
    'B': {'A': 0.4, 'C': 1.0},
    'D': {'A': 0.75},
}
class GraphUtilsTest(absltest.TestCase):
  """Exercises graph_utils' edge-building, symmetrizing and TSV I/O helpers."""

  def testAddEdge(self):
    # Rebuild GRAPH one edge at a time. Each step pairs an edge with whether
    # this test expects add_edge to report it truthily. Per the expectations
    # below, repeating an existing (source, target) pair reports falsy, while
    # the final comparison shows the heaviest weight (0.9 for A->C, after
    # 0.7, 0.9 and 0.8 were offered) is the one kept. Weights may be given as
    # strings ('0.5', '0.4'), and the weightless edge ('B', 'C') is expected
    # to end up with weight 1.0.
    built = {}
    steps = (
        (['A', 'B', '0.5'], True),
        (['A', 'C', 0.7], True),
        (['A', 'C', 0.9], False),
        (['A', 'C', 0.8], False),
        (('B', 'A', '0.4'), True),
        (('B', 'C'), True),
        (('D', 'A', 0.75), True),
    )
    for edge, expect_truthy in steps:
      check = self.assertTrue if expect_truthy else self.assertFalse
      check(graph_utils.add_edge(built, edge))
    self.assertDictEqual(built, GRAPH)

  def testAddUndirectedEdges(self):
    # Work on a deep copy so the shared GRAPH fixture is left untouched for
    # the other tests; add_undirected_edges is expected to update its
    # argument in place (its return value is not used here).
    symmetric = copy.deepcopy(GRAPH)
    graph_utils.add_undirected_edges(symmetric)
    expected = {
        'A': {'B': 0.5, 'C': 0.9, 'D': 0.75},  # 'D' mirrored from D->A.
        'B': {'A': 0.5, 'C': 1.0},  # B->A expected at 0.5, up from 0.4.
        'C': {'A': 0.9, 'B': 1.0},  # New source node: mirrors A->C, B->C.
        'D': {'A': 0.75},
    }
    self.assertDictEqual(symmetric, expected)

  def testReadAndWriteTsvGraph(self):
    # Round trip: writing GRAPH to a TSV file and reading it back is expected
    # to reproduce the same nested dict.
    tsv_path = self.create_tempfile('graph.tsv').full_path
    graph_utils.write_tsv_graph(tsv_path, GRAPH)
    self.assertDictEqual(graph_utils.read_tsv_graph(tsv_path), GRAPH)
if __name__ == '__main__':
  # Hand control to absltest's entry point when this file is run directly.
  absltest.main()