1 | #!/usr/bin/env python |
---|
2 | """ |
---|
3 | Sample a response function (surrogate model). Because so much of |
---|
4 | the Rappture internals expect objects to have an associated xml |
---|
5 | object and path, we will return the plot in an xml file. |
---|
6 | """ |
---|
7 | |
---|
8 | from __future__ import print_function |
---|
9 | import sys |
---|
10 | import numpy as np |
---|
11 | import puq |
---|
12 | from puq.jpickle import unpickle |
---|
13 | import xml.etree.ElementTree as xml |
---|
14 | |
---|
15 | |
---|
16 | from itertools import product |
---|
# Redirect stdout and stderr to files for debugging, so every print() below
# lands in response.out and tracebacks land in response.err.
# NOTE(review): the original comment says these files are appended to (they
# are reportedly created in get_params.py), but mode 'w' truncates them on
# every run -- confirm which behavior is intended before changing the mode.
sys.stdout = open("response.out", 'w')
sys.stderr = open("response.err", 'w')
---|
21 | |
---|
22 | |
---|
23 | # variable names to labels |
---|
def subs_names(varl, h5):
    """Map variable names to their display labels.

    Args:
        varl: Iterable of sequences whose first item is a variable name.
        h5: Object indexed by '/input/params/<name>' whose entries expose an
            ``attrs`` mapping containing a 'label' key (presumably an open
            HDF5 file -- confirm against the caller).

    Returns:
        A list with one label per entry of ``varl``. When a label cannot be
        looked up, ``str(name)`` is used instead.
    """
    labels = []
    for v in varl:
        name = v[0]
        try:
            label = h5['/input/params/%s' % name].attrs['label']
        except Exception:
            # Best-effort lookup: any ordinary failure falls back to the raw
            # name.  ``Exception`` (rather than a bare ``except``) lets
            # KeyboardInterrupt and SystemExit propagate.
            label = str(name)
        labels.append(label)
    return labels
---|
33 | |
---|
34 | |
---|
def _xy_text(xs, ys):
    """Return the paired values as newline-separated 'x y' text.

    Each value is converted to a built-in float before formatting because
    repr() of a NumPy scalar is not a bare number on NumPy >= 2.0
    (e.g. 'np.float64(1.5)'), which would corrupt the curve data.
    """
    return '\n'.join('%r %r' % (float(a), float(b)) for a, b in zip(xs, ys))


def plot_resp1(dout, resp, name, rlabel):
    """Append a 1-D response curve and a scatter of sampled data to ``dout``.

    Args:
        dout: XML element that receives the two 'curve' sub-elements.
        resp: Pickled response object; it is decoded with ``unpickle`` and
            must provide ``params``, ``vars``, ``data`` and ``evala``.
        name: Name of the parameter to sweep along the x axis.
        rlabel: Text used for the curve label, its group and the y axis.

    Returns:
        None.  If ``name`` matches no parameter, an error line is printed and
        ``dout`` is left untouched.
    """
    print('plot_resp1', name, rlabel)

    numpoints = 100

    resp = unpickle(resp)
    var = None

    # Locate the swept parameter; ``index`` is its position in resp.params
    # and is reused below as the matching column of resp.data.
    for index, p in enumerate(resp.params):
        if p.name == name:
            var = p
            break

    if var is None:
        print("plot_resp1 error: name %s not recognized" % name)
        return

    data = resp.data
    print('data=', repr(data))
    # Keep only the sampled rows in which every *other* parameter sits at its
    # mean, so the scatter points are comparable with the curve drawn below.
    for ind, p in enumerate(resp.params):
        if ind == index:
            continue
        m = p.pdf.mean
        means = np.isclose(m, data[:, ind], rtol=1e-6, atol=1e-12)
        data = data[means]

    print("vars=", resp.vars)
    print("data=", repr(data))

    curve = xml.SubElement(dout, 'curve', {'id': 'response'})
    about = xml.SubElement(curve, 'about')
    xml.SubElement(about, 'label').text = rlabel
    xml.SubElement(about, 'group').text = rlabel

    xaxis = xml.SubElement(curve, 'xaxis')
    xml.SubElement(xaxis, 'label').text = var.label

    yaxis = xml.SubElement(curve, 'yaxis')
    xml.SubElement(yaxis, 'label').text = rlabel

    x = np.linspace(*var.pdf.range, num=numpoints)

    # Evaluation grid: sweep the chosen parameter over its range and pin all
    # other parameters at their mean.
    allpts = np.empty((numpoints, len(resp.params)))
    for i, v in enumerate(resp.params):
        if v.name == var.name:
            allpts[:, i] = x
        else:
            allpts[:, i] = np.mean(v.pdf.mean)

    pts = resp.evala(allpts)
    comp = xml.SubElement(curve, 'component')
    xml.SubElement(comp, 'xy').text = _xy_text(x, pts)

    # scatter plot sampled data on response surface
    curve = xml.SubElement(dout, 'curve', {'id': 'scatter'})
    about = xml.SubElement(curve, 'about')
    xml.SubElement(about, 'label').text = 'Data Points'
    xml.SubElement(about, 'group').text = rlabel
    xml.SubElement(about, 'type').text = 'scatter'
    comp = xml.SubElement(curve, 'component')
    # The last column of the data is plotted as y (presumably the sampled
    # response value -- confirm against the puq response object).
    xml.SubElement(comp, 'xy').text = _xy_text(data[:, index], data[:, -1])
---|
98 | |
---|
99 | |
---|
def plot_resp2(dout, resp, name1, name2, rlabel):
    """Append a 2-D response surface (mesh + field) to ``dout``.

    Args:
        dout: XML element that receives the 'mesh' and 'field' sub-elements.
        resp: Pickled response object; it is decoded with ``unpickle`` and
            must provide ``params``, ``vars`` and ``evala``.
        name1: Name of the parameter plotted along the x axis.
        name2: Name of the parameter plotted along the y axis.
        rlabel: Label text for the field.

    Returns:
        None.  If either name matches no parameter, an error line is printed
        and ``dout`` is left untouched (same convention as ``plot_resp1``).
    """
    print("plot_resp2", name1, name2, rlabel)
    numpoints = 50

    resp = unpickle(resp)
    v1 = v2 = None
    for p in resp.params:
        if p.name == name1:
            v1 = p
        elif p.name == name2:
            v2 = p

    # Without this guard an unknown name surfaced later as a NameError on
    # v1/v2 instead of a readable message in the debug log.
    for wanted, found in ((name1, v1), (name2, v2)):
        if found is None:
            print("plot_resp2 error: name %s not recognized" % wanted)
            return

    x = np.linspace(*v1.pdf.range, num=numpoints)
    y = np.linspace(*v2.pdf.range, num=numpoints)
    # Grid points as (x, y) pairs with x varying fastest.
    pts = np.array([(b, a) for a, b in product(y, x)])
    allpts = np.empty((numpoints**2, len(resp.vars)))
    for i, v in enumerate(resp.vars):
        if v[0] == v1.name:
            allpts[:, i] = pts[:, 0]
        elif v[0] == v2.name:
            allpts[:, i] = pts[:, 1]
        else:
            # Remaining variables are pinned at the mean of v[1]
            # (presumably the variable's range -- confirm).
            allpts[:, i] = np.mean(v[1])
    pts = np.array(resp.evala(allpts))
    print('plot_resp2 returns array of', pts.shape)

    # mesh
    mesh = xml.SubElement(dout, 'mesh', {'id': 'm2d'})
    about = xml.SubElement(mesh, 'about')
    label = xml.SubElement(about, 'label')
    label.text = '2D Mesh'
    xml.SubElement(mesh, 'dim').text = '2'
    xml.SubElement(mesh, 'hide').text = 'yes'
    grid = xml.SubElement(mesh, 'grid')
    xaxis = xml.SubElement(grid, 'xaxis')
    xml.SubElement(xaxis, 'numpoints').text = str(numpoints)
    xml.SubElement(xaxis, 'min').text = str(v1.pdf.range[0])
    xml.SubElement(xaxis, 'max').text = str(v1.pdf.range[1])
    yaxis = xml.SubElement(grid, 'yaxis')
    xml.SubElement(yaxis, 'numpoints').text = str(numpoints)
    xml.SubElement(yaxis, 'min').text = str(v2.pdf.range[0])
    xml.SubElement(yaxis, 'max').text = str(v2.pdf.range[1])

    # field
    field = xml.SubElement(dout, 'field', {'id': 'f2d'})
    about = xml.SubElement(field, 'about')
    xml.SubElement(xml.SubElement(about, 'xaxis'), 'label').text = v1.label
    xml.SubElement(xml.SubElement(about, 'yaxis'), 'label').text = v2.label

    xml.SubElement(about, 'label').text = rlabel
    comp = xml.SubElement(field, 'component')
    xml.SubElement(comp, 'mesh').text = 'output.mesh(m2d)'
    xml.SubElement(about, 'view').text = 'heightmap'
    # tolist() yields built-in floats, so str() produces plain numbers.
    pts = ' '.join(map(str, pts.ravel('F').tolist()))
    xml.SubElement(comp, 'values').text = pts
---|
154 | |
---|
155 | |
---|
if __name__ == "__main__":
    # Expected arguments: <output xml file> <pickled response> <var1> <var2>
    # <label>.  Passing the same name for var1 and var2 selects the 1-D plot.
    print("get_response %s" % sys.argv[3:])

    if len(sys.argv[1:]) == 5:
        fname, resp, var1, var2, label = sys.argv[1:]

        droot = xml.Element('run')
        dtree = xml.ElementTree(droot)
        dout = xml.SubElement(droot, 'output')

        if var1 == var2:
            plot_resp1(dout, resp, var1, label)
        else:
            plot_resp2(dout, resp, var1, var2, label)

        # Open in binary mode: with its default encoding ElementTree.write()
        # emits bytes, which a text-mode file rejects with TypeError on
        # Python 3.  Binary mode plus a bytes declaration works on both
        # Python 2 and Python 3.
        with open(fname, 'wb') as f:
            f.write(b"<?xml version=\"1.0\"?>\n")
            dtree.write(f)
    else:
        print('ERROR: Expected 5 args. Got', sys.argv)
---|