source: branches/uq/puq/get_response.py @ 5707

Last change on this file since 5707 was 5707, checked in by mmh, 7 years ago

add get_response.py

  • Property svn:executable set to *
File size: 5.2 KB
Line 
1#!/usr/bin/env python
2"""
3Sample a response function (surrogate model).  Because so much of
4the Rappture internals expect objects to have an associated xml
5object and path, we will return the plot in an xml file.
6"""
7
8from __future__ import print_function
9import sys
10import numpy as np
11import puq
12from puq.jpickle import unpickle
13import xml.etree.ElementTree as xml
14
15
16from itertools import product
# Redirect stdout and stderr to files for debugging.
# Both files are truncated on every run (mode 'w'); they are opened
# line-buffered (buffering=1) so each debug line is written out as soon as
# it is printed instead of sitting in a block buffer.
sys.stdout = open("response.out", 'w', 1)
sys.stderr = open("response.err", 'w', 1)
21
22
23# variable names to labels
# variable names to labels
def subs_names(varl, h5):
    """Return a display label for each variable in *varl*.

    Each element of *varl* is indexed with ``[0]`` to get its name.  The label
    is looked up as ``h5['/input/params/<name>'].attrs['label']`` (presumably
    an h5py file -- confirm against callers).  If that lookup fails for any
    ordinary error (missing dataset, missing attribute, ...), ``str(name)``
    is used instead.

    Args:
        varl: sequence of variable descriptors whose first item is the name.
        h5: mapping-like object supporting the lookup above.

    Returns:
        list of labels, one per element of *varl*, in the same order.
    """
    def _label(name):
        try:
            return h5['/input/params/%s' % name].attrs['label']
        except Exception:
            # Best-effort: fall back to the raw name.  ``Exception`` (not a
            # bare ``except``) so KeyboardInterrupt/SystemExit still propagate.
            return str(name)

    return [_label(v[0]) for v in varl]
33
34
def plot_resp1(dout, resp, name, rlabel, numpoints=100):
    """Append a 1-D slice of a response surface to *dout* as a Rappture curve.

    The parameter called *name* is swept evenly across ``var.pdf.range``.
    Every other parameter is held at ``np.mean(v.pdf.mean)``.  The surface is
    evaluated with ``resp.evala`` and written as ``<curve id="response">``.

    Args:
        dout: ElementTree element (the ``<output>`` node) that receives the
            new ``<curve>`` child.
        resp: encoded response object; decoded with ``unpickle``.
        name: name of the parameter to sweep (matched against ``p.name``).
        rlabel: text used for the curve label, its group and the y-axis label.
        numpoints: number of evenly spaced samples along the swept axis.

    Returns:
        None.  If *name* matches no parameter, an error is printed and *dout*
        is left unchanged.
    """
    print('plot_resp1', name, rlabel)

    resp = unpickle(resp)
    for index, p in enumerate(resp.params):
        if p.name == name:
            var = p
            break
    else:
        print("plot_resp1 error: name %s not recognized" % name)
        return

    # Debug trace of the response's stored data.
    data = resp.data
    print("vars=", resp.vars)
    print("data=", data)
    print("my data=", data[:, index])

    curve = xml.SubElement(dout, 'curve', {'id': 'response'})
    about = xml.SubElement(curve, 'about')
    xml.SubElement(about, 'label').text = rlabel
    xml.SubElement(about, 'group').text = rlabel

    xaxis = xml.SubElement(curve, 'xaxis')
    xml.SubElement(xaxis, 'label').text = var.label

    yaxis = xml.SubElement(curve, 'yaxis')
    xml.SubElement(yaxis, 'label').text = rlabel

    x = np.linspace(*var.pdf.range, num=numpoints)

    # One row per sample: the swept parameter takes the x values, every other
    # parameter is held at a constant.
    allpts = np.empty((numpoints, len(resp.params)))
    for i, v in enumerate(resp.params):
        if v.name == var.name:
            allpts[:, i] = x
        else:
            allpts[:, i] = np.mean(v.pdf.mean)
    print("allpts=", allpts)
    pts = resp.evala(allpts)
    print("pts=", pts)

    # Convert to built-in floats before repr(): under NumPy >= 2 the repr of
    # a NumPy scalar is e.g. 'np.float64(0.5)', which is not a plain number.
    xy = '\n'.join('%r %r' % (float(a), float(b)) for a, b in zip(x, pts))
    comp = xml.SubElement(curve, 'component')
    xml.SubElement(comp, 'xy').text = xy

    # TODO: overlay the sampled data points on the curve as a separate
    # 'scatter' curve (about.type = 'scatter').
91
92
def plot_resp2(dout, resp, name1, name2, rlabel, numpoints=50):
    """Append a 2-D slice of a response surface to *dout* as a Rappture field.

    Parameters *name1* (x axis) and *name2* (y axis) are swept over a
    ``numpoints`` x ``numpoints`` grid spanning their ``pdf.range``.  The
    surface is evaluated with ``resp.evala``.  Two elements are added: a hidden
    ``<mesh id="m2d">`` describing the grid, and a ``<field id="f2d">`` with the
    values shown as a heightmap.

    Args:
        dout: ElementTree element (the ``<output>`` node) that receives the
            new ``<mesh>`` and ``<field>`` children.
        resp: encoded response object; decoded with ``unpickle``.
        name1: parameter name for the x axis (matched against ``p.name``).
        name2: parameter name for the y axis.
        rlabel: text used for the field label.
        numpoints: grid points per axis.

    Returns:
        None.  If either name matches no parameter, an error is printed and
        *dout* is left unchanged.
    """
    print("plot_resp2", name1, name2, rlabel)

    resp = unpickle(resp)
    v1 = v2 = None
    for p in resp.params:
        if p.name == name1:
            v1 = p
        elif p.name == name2:
            v2 = p

    # Without this check an unknown name surfaced as an UnboundLocalError.
    missing = [n for n, v in ((name1, v1), (name2, v2)) if v is None]
    if missing:
        print("plot_resp2 error: name %s not recognized" % ', '.join(missing))
        return

    x = np.linspace(*v1.pdf.range, num=numpoints)
    y = np.linspace(*v2.pdf.range, num=numpoints)
    # Grid points (x, y) ordered with x varying fastest.
    pts = np.array([(b, a) for a, b in product(y, x)])

    # Columns follow resp.vars, indexed as v[0] (name) and v[1] (averaged for
    # the held-fixed variables) -- presumably the variable's sampled values;
    # TODO confirm.  NOTE(review): plot_resp1 instead holds the other
    # parameters at np.mean(v.pdf.mean) -- check which is intended.
    allpts = np.empty((numpoints**2, len(resp.vars)))
    for i, v in enumerate(resp.vars):
        if v[0] == v1.name:
            allpts[:, i] = pts[:, 0]
        elif v[0] == v2.name:
            allpts[:, i] = pts[:, 1]
        else:
            allpts[:, i] = np.mean(v[1])
    pts = np.array(resp.evala(allpts))
    print('plot_resp2 returns array of', pts.shape)

    # mesh
    mesh = xml.SubElement(dout, 'mesh', {'id': 'm2d'})
    about = xml.SubElement(mesh, 'about')
    label = xml.SubElement(about, 'label')
    label.text = '2D Mesh'
    xml.SubElement(mesh, 'dim').text = '2'
    xml.SubElement(mesh, 'hide').text = 'yes'
    grid = xml.SubElement(mesh, 'grid')
    xaxis = xml.SubElement(grid, 'xaxis')
    xml.SubElement(xaxis, 'numpoints').text = str(numpoints)
    xml.SubElement(xaxis, 'min').text = str(v1.pdf.range[0])
    xml.SubElement(xaxis, 'max').text = str(v1.pdf.range[1])
    yaxis = xml.SubElement(grid, 'yaxis')
    xml.SubElement(yaxis, 'numpoints').text = str(numpoints)
    xml.SubElement(yaxis, 'min').text = str(v2.pdf.range[0])
    xml.SubElement(yaxis, 'max').text = str(v2.pdf.range[1])

    # field
    field = xml.SubElement(dout, 'field', {'id': 'f2d'})
    about = xml.SubElement(field, 'about')
    xml.SubElement(xml.SubElement(about, 'xaxis'), 'label').text = v1.label
    xml.SubElement(xml.SubElement(about, 'yaxis'), 'label').text = v2.label

    xml.SubElement(about, 'label').text = rlabel
    comp = xml.SubElement(field, 'component')
    xml.SubElement(comp, 'mesh').text = 'output.mesh(m2d)'
    xml.SubElement(about, 'view').text = 'heightmap'
    # Values are emitted in the same x-fastest order as the grid points above
    # (for a 1-D result ravel('F') is a plain flatten).
    values = ' '.join(map(str, pts.ravel('F').tolist()))
    xml.SubElement(comp, 'values').text = values
147
148
if __name__ == "__main__":
    # Usage: get_response.py OUTFILE RESP VAR1 VAR2 LABEL
    # RESP is the encoded response itself (decoded by unpickle in the plot
    # functions).  VAR1 == VAR2 produces a 1-D curve, otherwise a 2-D field.
    print("get_response %s" % sys.argv[3:])

    if len(sys.argv[1:]) == 5:
        fname, resp, var1, var2, label = sys.argv[1:]

        droot = xml.Element('run')
        dtree = xml.ElementTree(droot)
        dout = xml.SubElement(droot, 'output')

        if var1 == var2:
            plot_resp1(dout, resp, var1, label)
        else:
            plot_resp2(dout, resp, var1, var2, label)

        # ElementTree.write() with its default (us-ascii) encoding produces
        # bytes, so the target must be a binary file: passing a text-mode
        # file raises TypeError on Python 3.  The declaration is written by
        # hand because write() omits it for us-ascii.
        with open(fname, 'wb') as f:
            f.write(b'<?xml version="1.0"?>\n')
            dtree.write(f)
    else:
        print('ERROR: Expected 5 args. Got', sys.argv)
Note: See TracBrowser for help on using the repository browser.