[5707] | 1 | #!/usr/bin/env python |
---|
| 2 | """ |
---|
| 3 | Sample a response function (surrogate model). Because so much of |
---|
| 4 | the Rappture internals expect objects to have an associated xml |
---|
| 5 | object and path, we will return the plot in an xml file. |
---|
| 6 | """ |
---|
| 7 | |
---|
| 8 | from __future__ import print_function |
---|
| 9 | import sys |
---|
| 10 | import numpy as np |
---|
| 11 | import puq |
---|
| 12 | from puq.jpickle import unpickle |
---|
| 13 | import xml.etree.ElementTree as xml |
---|
| 14 | |
---|
| 15 | |
---|
| 16 | from itertools import product |
---|
# Redirect stdout and stderr to files for debugging.
# NOTE(review): mode 'w' truncates response.out / response.err on every run;
# nothing here appends to files written by get_params.py. Use mode 'a' if
# appending to earlier output is what is intended -- confirm.
sys.stdout = open("response.out", 'w')
sys.stderr = open("response.err", 'w')
---|
| 21 | |
---|
| 22 | |
---|
| 23 | # variable names to labels |
---|
def subs_names(varl, h5):
    """Map variable names to their display labels.

    Args:
        varl: sequence whose items carry the variable name at index 0.
        h5: path-indexed mapping (presumably an open h5py File -- confirm);
            the node at ``/input/params/<name>`` may have a ``label`` attr.

    Returns:
        A list with one label per item of ``varl``. ``str(name)`` is used
        when the node or its ``label`` attribute cannot be read.
    """
    varlist = []
    for v in varl:
        try:
            lab = h5['/input/params/%s' % v[0]].attrs['label']
        except Exception:
            # Missing node/attribute: fall back to the raw name. Catching
            # Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            lab = str(v[0])
        varlist.append(lab)
    return varlist
---|
| 33 | |
---|
| 34 | |
---|
| 35 | def plot_resp1(dout, resp, name, rlabel): |
---|
| 36 | print('plot_resp1', name, rlabel) |
---|
| 37 | |
---|
| 38 | numpoints = 100 |
---|
| 39 | |
---|
| 40 | resp = unpickle(resp) |
---|
| 41 | var = None |
---|
[5937] | 42 | |
---|
[5707] | 43 | for index, p in enumerate(resp.params): |
---|
| 44 | if p.name == name: |
---|
| 45 | var = p |
---|
| 46 | break |
---|
| 47 | |
---|
| 48 | if var is None: |
---|
| 49 | print("plot_resp1 error: name %s not recognized" % name) |
---|
| 50 | return |
---|
| 51 | |
---|
| 52 | data = resp.data |
---|
[5937] | 53 | print('data=', repr(data)) |
---|
| 54 | for ind, p in enumerate(resp.params): |
---|
| 55 | if ind == index: |
---|
| 56 | continue |
---|
| 57 | m = p.pdf.mean |
---|
| 58 | means = np.isclose(m, data[:, ind], rtol=1e-6, atol=1e-12) |
---|
| 59 | data = data[means] |
---|
[5707] | 60 | |
---|
[5937] | 61 | print("vars=", resp.vars) |
---|
| 62 | print("data=", repr(data)) |
---|
| 63 | |
---|
[5707] | 64 | curve = xml.SubElement(dout, 'curve', {'id': 'response'}) |
---|
| 65 | about = xml.SubElement(curve, 'about') |
---|
| 66 | xml.SubElement(about, 'label').text = rlabel |
---|
| 67 | xml.SubElement(about, 'group').text = rlabel |
---|
| 68 | |
---|
| 69 | xaxis = xml.SubElement(curve, 'xaxis') |
---|
| 70 | xml.SubElement(xaxis, 'label').text = var.label |
---|
| 71 | |
---|
| 72 | yaxis = xml.SubElement(curve, 'yaxis') |
---|
| 73 | xml.SubElement(yaxis, 'label').text = rlabel |
---|
| 74 | |
---|
| 75 | x = np.linspace(*var.pdf.range, num=numpoints) |
---|
| 76 | |
---|
| 77 | allpts = np.empty((numpoints, len(resp.params))) |
---|
| 78 | for i, v in enumerate(resp.params): |
---|
| 79 | if v.name == var.name: |
---|
| 80 | allpts[:, i] = x |
---|
| 81 | else: |
---|
| 82 | allpts[:, i] = np.mean(v.pdf.mean) |
---|
[5937] | 83 | |
---|
[5707] | 84 | pts = resp.evala(allpts) |
---|
| 85 | xy = '\n'.join([' '.join(map(repr, a)) for a in zip(x, pts)]) |
---|
| 86 | comp = xml.SubElement(curve, 'component') |
---|
| 87 | xml.SubElement(comp, 'xy').text = xy |
---|
| 88 | |
---|
| 89 | # scatter plot sampled data on response surface |
---|
[5937] | 90 | curve = xml.SubElement(dout, 'curve', {'id': 'scatter'}) |
---|
| 91 | about = xml.SubElement(curve, 'about') |
---|
| 92 | xml.SubElement(about, 'label').text = 'Data Points' |
---|
| 93 | xml.SubElement(about, 'group').text = rlabel |
---|
| 94 | xml.SubElement(about, 'type').text = 'scatter' |
---|
| 95 | comp = xml.SubElement(curve, 'component') |
---|
| 96 | xy = '\n'.join([' '.join(map(repr, a)) for a in zip(data[:, index], data[:, -1])]) |
---|
| 97 | xml.SubElement(comp, 'xy').text = xy |
---|
[5707] | 98 | |
---|
| 99 | |
---|
def plot_resp2(dout, resp, name1, name2, rlabel):
    """Add a 2-D response surface (mesh + heightmap field) to ``dout``.

    The response is evaluated on a ``numpoints`` x ``numpoints`` grid over
    the pdf ranges of ``name1`` (x axis) and ``name2`` (y axis). All other
    inputs are fixed at ``np.mean`` of their entry in ``resp.vars``.

    Args:
        dout: XML element that receives the ``<mesh>`` and ``<field>``.
        resp: pickled response object, decoded with ``unpickle``.
        name1: parameter name for the x axis.
        name2: parameter name for the y axis (must differ from ``name1``).
        rlabel: label for the field.

    Returns:
        None. If either name is not a parameter of the response, an error
        is printed and nothing is added to ``dout``.
    """
    print("plot_resp2", name1, name2, rlabel)
    numpoints = 50

    resp = unpickle(resp)
    v1 = v2 = None
    for p in resp.params:
        if p.name == name1:
            v1 = p
        elif p.name == name2:
            v2 = p

    # Previously an unknown name raised UnboundLocalError; report it like
    # plot_resp1 does instead.
    if v1 is None or v2 is None:
        missing = name1 if v1 is None else name2
        print("plot_resp2 error: name %s not recognized" % missing)
        return

    x = np.linspace(*v1.pdf.range, num=numpoints)
    y = np.linspace(*v2.pdf.range, num=numpoints)
    # Grid points with x varying fastest: (x0, y0), (x1, y0), ..., (x0, y1), ...
    gx, gy = np.meshgrid(x, y)
    gx = gx.ravel()
    gy = gy.ravel()

    allpts = np.empty((numpoints**2, len(resp.vars)))
    for i, v in enumerate(resp.vars):
        if v[0] == v1.name:
            allpts[:, i] = gx
        elif v[0] == v2.name:
            allpts[:, i] = gy
        else:
            # NOTE(review): assumes v[1] holds this variable's values; its
            # mean pins the input -- confirm against resp.vars' layout.
            allpts[:, i] = np.mean(v[1])
    vals = np.array(resp.evala(allpts))
    print('plot_resp2 returns array of', vals.shape)

    # mesh
    mesh = xml.SubElement(dout, 'mesh', {'id': 'm2d'})
    about = xml.SubElement(mesh, 'about')
    label = xml.SubElement(about, 'label')
    label.text = '2D Mesh'
    xml.SubElement(mesh, 'dim').text = '2'
    xml.SubElement(mesh, 'hide').text = 'yes'
    grid = xml.SubElement(mesh, 'grid')
    xaxis = xml.SubElement(grid, 'xaxis')
    xml.SubElement(xaxis, 'numpoints').text = str(numpoints)
    xml.SubElement(xaxis, 'min').text = str(v1.pdf.range[0])
    xml.SubElement(xaxis, 'max').text = str(v1.pdf.range[1])
    yaxis = xml.SubElement(grid, 'yaxis')
    xml.SubElement(yaxis, 'numpoints').text = str(numpoints)
    xml.SubElement(yaxis, 'min').text = str(v2.pdf.range[0])
    xml.SubElement(yaxis, 'max').text = str(v2.pdf.range[1])

    # field
    field = xml.SubElement(dout, 'field', {'id': 'f2d'})
    about = xml.SubElement(field, 'about')
    xml.SubElement(xml.SubElement(about, 'xaxis'), 'label').text = v1.label
    xml.SubElement(xml.SubElement(about, 'yaxis'), 'label').text = v2.label

    xml.SubElement(about, 'label').text = rlabel
    comp = xml.SubElement(field, 'component')
    xml.SubElement(comp, 'mesh').text = 'output.mesh(m2d)'
    xml.SubElement(about, 'view').text = 'heightmap'
    # tolist() yields builtin floats, so str() gives plain numeric text.
    xml.SubElement(comp, 'values').text = ' '.join(map(str, vals.ravel('F').tolist()))
---|
| 154 | |
---|
| 155 | |
---|
if __name__ == "__main__":
    print("get_response %s" % sys.argv[3:])

    # Expected arguments: output file, pickled response, var1, var2, label.
    if len(sys.argv[1:]) == 5:
        fname, resp, var1, var2, label = sys.argv[1:]

        droot = xml.Element('run')
        dtree = xml.ElementTree(droot)
        dout = xml.SubElement(droot, 'output')

        # The same variable twice selects a 1-D curve; otherwise a 2-D surface.
        if var1 == var2:
            plot_resp1(dout, resp, var1, label)
        else:
            plot_resp2(dout, resp, var1, var2, label)

        # Binary mode: ElementTree.write() with its default (us-ascii)
        # encoding emits bytes, which a text-mode file rejects on Python 3.
        # Python 2 file objects accept these writes as well.
        with open(fname, 'wb') as f:
            f.write(b'<?xml version="1.0"?>\n')
            dtree.write(f)
    else:
        print('ERROR: Expected 5 args. Got', sys.argv)
---|