source: branches/1.4/puq/get_response.py @ 5937

Last change on this file since 5937 was 5937, checked in by mmh, 7 years ago

for 1d response show sample points

  • Property svn:executable set to *
File size: 5.4 KB
Line 
1#!/usr/bin/env python
2"""
3Sample a response function (surrogate model).  Because so much of
4the Rappture internals expect objects to have an associated xml
5object and path, we will return the plot in an xml file.
6"""
7
8from __future__ import print_function
9import sys
10import numpy as np
11import puq
12from puq.jpickle import unpickle
13import xml.etree.ElementTree as xml
14
15
16from itertools import product
# Redirect stdout and stderr to files for debugging.  Both files are
# opened with mode 'w', so they are truncated on every run and hold only
# the output of the most recent invocation of this script.  Line buffering
# (buffering=1) pushes each completed line to disk immediately, so the
# logs stay useful even if the process is killed before a normal exit.
sys.stdout = open("response.out", 'w', 1)
sys.stderr = open("response.err", 'w', 1)
21
22
def subs_names(varl, h5):
    """Map variable names to their display labels.

    For each entry of ``varl`` the name is its first element, and the label
    is read from ``h5['/input/params/<name>'].attrs['label']``.  If that
    lookup fails with an ordinary exception (missing path, missing
    attribute, ...), ``str(name)`` is used as the label instead.

    Args:
        varl: sequence of items whose first element is a parameter name.
        h5: object supporting ``h5[path].attrs[key]`` lookups (presumably
            an open h5py file -- confirm against callers).

    Returns:
        list of labels, one per entry of ``varl``, in the same order.
    """
    labels = []
    for v in varl:
        try:
            lab = h5['/input/params/%s' % v[0]].attrs['label']
        except Exception:
            # Fall back to the raw name.  Catch Exception rather than using a
            # bare ``except:`` so KeyboardInterrupt/SystemExit still propagate.
            lab = str(v[0])
        labels.append(lab)
    return labels
33
34
def _xy_text(xs, ys):
    """Format paired values as Rappture ``xy`` text: one "x y" pair per line.

    Each value is converted to a plain Python float before ``repr`` so the
    output is a bare decimal number on every NumPy version (NumPy >= 2.0
    reprs its scalars as ``np.float64(...)``, which is not a plain number).
    """
    return '\n'.join('%r %r' % (float(a), float(b)) for a, b in zip(xs, ys))


def plot_resp1(dout, resp, name, rlabel):
    """Add a 1-D response curve and the matching sample points to ``dout``.

    The response is evaluated at ``numpoints`` evenly spaced values across
    the PDF range of parameter ``name`` while every other parameter is held
    at its PDF mean.  Sampled data rows whose other parameters are
    numerically equal to those means are added as a second, scatter-type
    curve in the same group.

    Args:
        dout: ElementTree element that receives the ``curve`` elements.
        resp: serialized response object, decoded with ``unpickle``.
        name: name of the parameter varied along the x axis.
        rlabel: response label (used for curve label, group and y axis).

    Returns:
        None.  If ``name`` matches no parameter, an error is printed and
        nothing is added to ``dout``.
    """
    print('plot_resp1', name, rlabel)

    numpoints = 100

    resp = unpickle(resp)

    # Locate the varied parameter and its column index.
    index, var = None, None
    for i, p in enumerate(resp.params):
        if p.name == name:
            index, var = i, p
            break

    if var is None:
        print("plot_resp1 error: name %s not recognized" % name)
        return

    # Keep only sampled rows where every *other* parameter sits at its PDF
    # mean -- those samples lie on the 1-D slice being plotted.
    data = resp.data
    print('data=', repr(data))
    for ind, p in enumerate(resp.params):
        if ind == index:
            continue
        on_slice = np.isclose(p.pdf.mean, data[:, ind], rtol=1e-6, atol=1e-12)
        data = data[on_slice]

    print("vars=", resp.vars)
    print("data=", repr(data))

    # Response curve evaluated along the slice.
    curve = xml.SubElement(dout, 'curve', {'id': 'response'})
    about = xml.SubElement(curve, 'about')
    xml.SubElement(about, 'label').text = rlabel
    xml.SubElement(about, 'group').text = rlabel

    xaxis = xml.SubElement(curve, 'xaxis')
    xml.SubElement(xaxis, 'label').text = var.label

    yaxis = xml.SubElement(curve, 'yaxis')
    xml.SubElement(yaxis, 'label').text = rlabel

    x = np.linspace(*var.pdf.range, num=numpoints)

    # One evaluation point per row; columns follow resp.params order.
    allpts = np.empty((numpoints, len(resp.params)))
    for i, p in enumerate(resp.params):
        if p.name == var.name:
            allpts[:, i] = x
        else:
            allpts[:, i] = np.mean(p.pdf.mean)

    pts = resp.evala(allpts)
    comp = xml.SubElement(curve, 'component')
    xml.SubElement(comp, 'xy').text = _xy_text(x, pts)

    # Scatter plot of the sampled data lying on the response slice.
    curve = xml.SubElement(dout, 'curve', {'id': 'scatter'})
    about = xml.SubElement(curve, 'about')
    xml.SubElement(about, 'label').text = 'Data Points'
    xml.SubElement(about, 'group').text = rlabel
    xml.SubElement(about, 'type').text = 'scatter'
    comp = xml.SubElement(curve, 'component')
    # assumes the last column of resp.data is the response value -- TODO confirm
    xml.SubElement(comp, 'xy').text = _xy_text(data[:, index], data[:, -1])
98
99
def plot_resp2(dout, resp, name1, name2, rlabel):
    """Add a 2-D response surface (hidden mesh plus heightmap field) to ``dout``.

    The response is evaluated on a ``numpoints`` x ``numpoints`` grid that
    spans the PDF ranges of ``name1`` (x axis) and ``name2`` (y axis).
    Every other parameter is held at its PDF mean, the same convention that
    ``plot_resp1`` uses for its 1-D slice.

    Args:
        dout: ElementTree element that receives the ``mesh`` and ``field``.
        resp: serialized response object, decoded with ``unpickle``.
        name1: parameter varied along the x axis.
        name2: parameter varied along the y axis.
        rlabel: response label for the field.

    Returns:
        None.  If either name matches no parameter, an error is printed and
        nothing is added to ``dout``.
    """
    print("plot_resp2", name1, name2, rlabel)
    numpoints = 50

    resp = unpickle(resp)
    v1 = v2 = None
    for p in resp.params:
        if p.name == name1:
            v1 = p
        elif p.name == name2:
            v2 = p

    # Fail gracefully (like plot_resp1) instead of raising UnboundLocalError.
    for wanted, found in ((name1, v1), (name2, v2)):
        if found is None:
            print("plot_resp2 error: name %s not recognized" % wanted)
    if v1 is None or v2 is None:
        return

    x = np.linspace(*v1.pdf.range, num=numpoints)
    y = np.linspace(*v2.pdf.range, num=numpoints)
    # Grid points in row order with x varying fastest:
    # (x0, y0), (x1, y0), ..., (x0, y1), ...
    grid_x = np.tile(x, numpoints)
    grid_y = np.repeat(y, numpoints)

    # One evaluation point per row; columns follow resp.params order.
    allpts = np.empty((numpoints ** 2, len(resp.params)))
    for i, p in enumerate(resp.params):
        if p.name == v1.name:
            allpts[:, i] = grid_x
        elif p.name == v2.name:
            allpts[:, i] = grid_y
        else:
            allpts[:, i] = np.mean(p.pdf.mean)
    pts = np.array(resp.evala(allpts))
    print('plot_resp2 returns array of', pts.shape)

    # Hidden uniform 2-D grid mesh matching the evaluation grid above.
    mesh = xml.SubElement(dout, 'mesh', {'id': 'm2d'})
    about = xml.SubElement(mesh, 'about')
    xml.SubElement(about, 'label').text = '2D Mesh'
    xml.SubElement(mesh, 'dim').text = '2'
    xml.SubElement(mesh, 'hide').text = 'yes'
    grid = xml.SubElement(mesh, 'grid')
    for axis_tag, var in (('xaxis', v1), ('yaxis', v2)):
        axis = xml.SubElement(grid, axis_tag)
        xml.SubElement(axis, 'numpoints').text = str(numpoints)
        xml.SubElement(axis, 'min').text = str(var.pdf.range[0])
        xml.SubElement(axis, 'max').text = str(var.pdf.range[1])

    # Field of response values on that mesh, shown as a heightmap.
    field = xml.SubElement(dout, 'field', {'id': 'f2d'})
    about = xml.SubElement(field, 'about')
    xml.SubElement(xml.SubElement(about, 'xaxis'), 'label').text = v1.label
    xml.SubElement(xml.SubElement(about, 'yaxis'), 'label').text = v2.label

    xml.SubElement(about, 'label').text = rlabel
    comp = xml.SubElement(field, 'component')
    xml.SubElement(comp, 'mesh').text = 'output.mesh(m2d)'
    xml.SubElement(about, 'view').text = 'heightmap'
    # assumes evala returns one value per grid row, so the values keep the
    # x-fastest grid order -- TODO confirm for multi-column results
    values = ' '.join(map(str, pts.ravel('F').tolist()))
    xml.SubElement(comp, 'values').text = values
154
155
if __name__ == "__main__":
    # Usage: get_response.py OUTFILE RESPONSE VAR1 VAR2 LABEL
    #   OUTFILE  - path of the XML file to write (root element <run>)
    #   RESPONSE - serialized response, passed on to unpickle()
    #   VAR1/2   - parameter names; equal names plot a 1-D curve,
    #              different names a 2-D surface
    #   LABEL    - label for the response
    print("get_response %s" % sys.argv[3:])

    if len(sys.argv[1:]) == 5:
        fname, resp, var1, var2, label = sys.argv[1:]

        droot = xml.Element('run')
        dtree = xml.ElementTree(droot)
        dout = xml.SubElement(droot, 'output')

        if var1 == var2:
            plot_resp1(dout, resp, var1, label)
        else:
            plot_resp2(dout, resp, var1, var2, label)

        # ElementTree.write() emits bytes for its default (us-ascii)
        # encoding, so the file must be opened in binary mode; writing to a
        # text-mode file raises TypeError on Python 3.
        with open(fname, 'wb') as f:
            f.write(b"<?xml version=\"1.0\"?>\n")
            dtree.write(f)
    else:
        print('ERROR: Expected 5 args. Got', sys.argv)
Note: See TracBrowser for help on using the repository browser.