# ibitfit/electric_sieve/pgfw/Interpolator.py
# (captured from a repository web view: 734 lines, 26 KiB, Python;
#  revision dated 2014-04-25 13:22:01 -04:00)
from re import match
from os.path import join
from tempfile import gettempdir
from pygame import Surface
from pygame.font import Font
from pygame.draw import aaline
from pygame.locals import *
from GameChild import GameChild
from Sprite import Sprite
from Animation import Animation
class Interpolator(list, GameChild):
    """List of named Nodesets loaded from the "interpolate" config section.

    When the "-interpolator" command-line flag is present, an editing GUI is
    also created for the nodesets.
    """

    def __init__(self, parent):
        GameChild.__init__(self, parent)
        self.set_nodesets()
        self.gui_enabled = self.check_command_line("-interpolator")
        if self.gui_enabled:
            self.gui = GUI(self)

    def set_nodesets(self):
        """Append one Nodeset per option of the "interpolate" section."""
        config = self.get_configuration()
        if not config.has_section("interpolate"):
            return
        section = config.get_section("interpolate")
        for key, raw in section.iteritems():
            self.add_nodeset(key, raw)

    def add_nodeset(self, name, value, method=None):
        """Append a Nodeset built from *value* and return its index."""
        self.append(Nodeset(name, value, method))
        return len(self) - 1

    def is_gui_active(self):
        """Truthy only when the GUI exists and is currently shown."""
        return self.gui_enabled and self.gui.active

    def get_nodeset(self, name):
        """Return the first nodeset named *name*, or None if there is none."""
        matches = (candidate for candidate in self if candidate.name == name)
        return next(matches, None)

    def remove(self, outgoing):
        """Drop the first nodeset whose name equals *outgoing*'s name."""
        for position, candidate in enumerate(self):
            if candidate.name == outgoing.name:
                self.pop(position)
                return
class Nodeset(list):
    """A named, x-sorted list of Node control points plus their splines.

    No two nodes share an x.  ``splines`` holds one spline per gap between
    consecutive nodes; it is rebuilt by ``set_splines`` whenever the nodes
    or the interpolation method change, unless the caller passes
    ``refresh=False`` to batch several edits.
    """

    LINEAR, CUBIC = range(2)

    def __init__(self, name, nodes, method=None):
        """Create nodeset *name* from *nodes*.

        *nodes* is either a raw configuration string (see ``parse_raw``) or
        an iterable of coordinate sequences.  In the latter case *method*
        selects the interpolation; anything other than LINEAR is cubic.
        """
        list.__init__(self, [])
        self.name = name
        if isinstance(nodes, str):
            self.parse_raw(nodes)
        else:
            self.interpolation_method = method
            # Build the splines once below instead of after every node.
            self.parse_list(nodes, False)
        self.set_splines()

    def parse_raw(self, raw):
        """Load nodes from a string such as ``"C 0 0, 1 2.5, 3 1"``.

        A leading ``L`` (either case) selects linear interpolation; any other
        leading character selects cubic.  Splines are not rebuilt here.
        """
        raw = raw.strip()
        if raw[0].upper() == "L":
            self.set_interpolation_method(self.LINEAR, False)
        else:
            self.set_interpolation_method(self.CUBIC, False)
        for node in raw[1:].strip().split(","):
            # A real list (not a lazy map object) so add_node can index it.
            self.add_node([float(value) for value in node.split()], False)

    def set_interpolation_method(self, method, refresh=True):
        """Switch between LINEAR and CUBIC, optionally rebuilding splines."""
        self.interpolation_method = method
        if refresh:
            self.set_splines()

    def add_node(self, coordinates, refresh=True):
        """Insert a Node at *coordinates*, keeping the nodes sorted by x.

        Returns the new node's index.  Returns None and changes nothing when
        a node with the same x already exists.
        """
        x = coordinates[0]
        index = len(self)
        for ii, node in enumerate(self):
            if x == node.x:
                return None
            if x < node.x:
                index = ii
                break
        self.insert(index, Node(coordinates))
        if refresh:
            self.set_splines()
        return index

    def parse_list(self, nodes, refresh=True):
        """Add every coordinate sequence in *nodes*.

        When *refresh* is true, rebuild the splines once at the end.
        """
        for node in nodes:
            self.add_node(node, False)
        if refresh:
            self.set_splines()

    def set_splines(self):
        """Rebuild ``splines`` for the current interpolation method."""
        if self.interpolation_method == self.LINEAR:
            self.set_linear_splines()
        else:
            self.set_cubic_splines()

    def set_linear_splines(self):
        """Rebuild ``splines`` as straight segments between adjacent nodes."""
        self.splines = splines = []
        for left, right in zip(self, self[1:]):
            slope = float(right.y - left.y) / (right.x - left.x)
            splines.append(LinearSpline(left.x, left.y, slope))

    def set_cubic_splines(self):
        """Rebuild ``splines`` as a natural cubic spline through the nodes.

        This solves the tridiagonal system for the quadratic coefficients
        ``c`` with the natural boundary condition c[0] = c[n] = 0.  It then
        derives the linear (``b``) and cubic (``d``) coefficients of each
        piece.
        """
        n = len(self) - 1
        if n < 1:
            # Zero or one node: there is no interval to fit.
            self.splines = []
            return
        # Work in floats so integer coordinates are not truncated by
        # Python 2 integer division in the divisions below.
        xs = [float(node.x) for node in self]
        a = [float(node.y) for node in self]
        h = [xs[ii + 1] - xs[ii] for ii in range(n)]
        alpha = [None] + [(3.0 / h[ii]) * (a[ii + 1] - a[ii]) -
                          (3.0 / h[ii - 1]) * (a[ii] - a[ii - 1])
                          for ii in range(1, n)]
        c = [0.0] * (n + 1)
        l = [1.0] * (n + 1)
        u = [0.0] * (n + 1)
        z = [0.0] * (n + 1)
        # Forward sweep of the tridiagonal solve.
        for ii in range(1, n):
            l[ii] = 2 * (xs[ii + 1] - xs[ii - 1]) - h[ii - 1] * u[ii - 1]
            u[ii] = h[ii] / l[ii]
            z[ii] = (alpha[ii] - h[ii - 1] * z[ii - 1]) / l[ii]
        b = [None] * n
        d = [None] * n
        # Back substitution.
        for jj in range(n - 1, -1, -1):
            c[jj] = z[jj] - u[jj] * c[jj + 1]
            b[jj] = (a[jj + 1] - a[jj]) / h[jj] - \
                (h[jj] * (c[jj + 1] + 2 * c[jj])) / 3
            d[jj] = (c[jj + 1] - c[jj]) / (3 * h[jj])
        self.splines = [CubicSpline(xs[ii], a[ii], b[ii], c[ii], d[ii])
                        for ii in range(n)]

    def get_y(self, t, loop=False, reverse=False, natural=False):
        """Evaluate the curve at *t*.

        loop    -- repeat the curve every ``get_length()`` units.
        reverse -- ping-pong: every other repetition runs backwards.
        natural -- without loop/reverse, extrapolate outside the node range
                   instead of clamping *t* to it.

        A nodeset with a single node is treated as a constant.
        """
        splines = self.splines
        if not splines:
            return self[0].y
        if loop or reverse:
            start = self[0].x
            length = self[-1].x - start
            # Measure from the first node so nodesets need not start at 0.
            periods, offset = divmod(t - start, length)
            if reverse and int(periods) % 2:
                offset = length - offset
            t = start + offset
        elif not natural:
            t = min(max(t, self[0].x), self[-1].x)
        for spline, following in zip(splines, splines[1:]):
            if t < following.x:
                return spline.get_y(t)
        return splines[-1].get_y(t)

    def remove(self, node, refresh=True):
        """Remove the node equal to *node*, optionally rebuilding splines."""
        list.remove(self, node)
        if refresh:
            self.set_splines()

    def resize(self, left, length, refresh=True):
        """Rescale node x positions linearly to span [left, left + length]."""
        old_left = self[0].x
        old_length = self.get_length()
        for node in self:
            if old_length:
                node.x = left + length * (node.x - old_left) / \
                    float(old_length)
            else:
                # A zero-length nodeset has only one node; move it to left.
                node.x = left
        if refresh:
            self.set_splines()

    def get_length(self):
        """Return the x distance from the first node to the last."""
        return self[-1].x - self[0].x
class Node(list):
    """A mutable [x, y] coordinate list that also exposes ``.x`` and ``.y``.

    Because it is a list, a Node can be indexed, unpacked, concatenated and
    compared like the coordinate sequence it was built from.
    """

    def __init__(self, coordinates):
        list.__init__(self, coordinates)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup has already failed.
        if name == "x":
            return self[0]
        elif name == "y":
            return self[1]
        # Anything else really is missing; report it under its own name
        # (list has no ``__get__`` to defer to).
        raise AttributeError("%r object has no attribute %r"
                             % (type(self).__name__, name))

    def __setattr__(self, name, value):
        # x and y live in the list slots; other names are ordinary attributes.
        if name == "x":
            list.__setitem__(self, 0, value)
        elif name == "y":
            list.__setitem__(self, 1, value)
        else:
            list.__setattr__(self, name, value)
class Spline:
    """Common base for one piece of a piecewise curve.

    Each piece records the x at which its interval begins.
    """

    def __init__(self, x):
        # Left edge of the interval this piece covers.
        self.x = x
class CubicSpline(Spline):
    """Cubic piece: y = a + b*dx + c*dx**2 + d*dx**3 with dx = t - x."""

    def __init__(self, x, a, b, c, d):
        Spline.__init__(self, x)
        self.a, self.b, self.c, self.d = a, b, c, d

    def get_y(self, t):
        """Evaluate the polynomial at *t*."""
        dx = t - self.x
        return self.a + self.b * dx + self.c * dx ** 2 + self.d * dx ** 3
class LinearSpline(Spline):
    """Straight piece through (x, y) with slope m."""

    def __init__(self, x, y, m):
        Spline.__init__(self, x)
        self.y, self.m = y, m

    def get_y(self, t):
        """Evaluate the line at *t*."""
        offset = t - self.x
        return self.m * offset + self.y
class GUI(Animation):
    """On-screen editor for the nodesets held by the parent Interpolator.

    While active it covers the display with a plot of the selected nodeset
    and handles input:
    - left-click adds a node, presses a button, or edits an endpoint;
    - right-click removes an interior node;
    - clicking the nodeset name cycles through the nodesets.
    """

    B_DUPLICATE, B_WRITE, B_DELETE, B_LINEAR, B_CUBIC, B_SPLIT = range(6)
    S_NONE, S_LEFT, S_RIGHT = range(3)
    # Split-button captions, indexed by the S_* split states above.
    SPLIT_LABELS = "Split: No", "Split: L", "Split: R"

    def __init__(self, parent):
        Animation.__init__(self, parent, unfiltered=True)
        self.audio = self.get_audio()
        self.display = self.get_game().display
        self.display_surface = self.get_display_surface()
        self.time_filter = self.get_game().time_filter
        self.delegate = self.get_delegate()
        self.split = self.S_NONE
        self.success_indicator_active = True
        self.success_indicator_blink_count = 0
        self.load_configuration()
        self.font = Font(None, self.label_size)
        self.prompt = Prompt(self)
        self.set_temporary_file()
        self.set_background()
        self.set_success_indicator()
        self.set_plot_rect()
        self.set_marker_frame()
        self.set_buttons()
        self.active = False
        self.set_nodeset_index()
        self.set_y_range()
        self.set_markers()
        self.subscribe(self.respond_to_command)
        self.subscribe(self.respond_to_mouse_down, MOUSEBUTTONDOWN)
        self.subscribe(self.respond_to_key, KEYDOWN)
        # NOTE(review): intervals presumably in milliseconds -- confirm
        # against Animation.register.
        self.register(self.show_success_indicator, interval=100)
        self.register(self.save_temporary_file, interval=10000)
        self.play(self.save_temporary_file)

    def load_configuration(self):
        """Copy the "interpolator-gui" settings onto the instance."""
        config = self.get_configuration("interpolator-gui")
        self.label_size = config["label-size"]
        self.axis_label_count = config["axis-label-count"]
        self.margin = config["margin"]
        self.curve_color = config["curve-color"]
        self.marker_size = config["marker-size"]
        self.marker_color = config["marker-color"]
        self.label_precision = config["label-precision"]
        self.template_nodeset = config["template-nodeset"]
        self.template_nodeset_name = config["template-nodeset-name"]
        self.flat_y_range = config["flat-y-range"]

    def set_temporary_file(self):
        """Open (and truncate) the scratch file used for periodic backups."""
        self.temporary_file = open(join(gettempdir(), "pgfw-config"), "w")

    def set_background(self):
        """Create a black surface the size of the display."""
        surface = Surface(self.display_surface.get_size())
        surface.fill((0, 0, 0))
        self.background = surface

    def set_success_indicator(self):
        """Create the small green square blinked after a config write."""
        surface = Surface((10, 10))
        surface.fill((0, 255, 0))
        rect = surface.get_rect()
        rect.topleft = self.display_surface.get_rect().topleft
        self.success_indicator, self.success_indicator_rect = surface, rect

    def set_plot_rect(self):
        """Set the plot area to the display rect shrunk by the margin."""
        margin = self.margin
        self.plot_rect = self.display_surface.get_rect().inflate(-margin,
                                                                 -margin)

    def set_marker_frame(self):
        """Draw the "X" image shared by all node markers."""
        size = self.marker_size
        surface = Surface((size, size))
        transparent_color = (255, 0, 255)
        surface.fill(transparent_color)
        surface.set_colorkey(transparent_color)
        line_color = self.marker_color
        aaline(surface, line_color, (0, 0), (size - 1, size - 1))
        aaline(surface, line_color, (0, size - 1), (size - 1, 0))
        self.marker_frame = surface

    def set_buttons(self):
        """Lay the buttons out left to right, 10 px apart.

        The split button shows the current split state, so rebuilding the
        buttons (e.g. from ``rearrange``) does not reset its caption.
        """
        self.buttons = buttons = []
        text = ("Duplicate", "Write", "Delete", "Linear", "Cubic",
                self.SPLIT_LABELS[self.split])
        x = 0
        for instruction in text:
            buttons.append(Button(self, instruction, x))
            x += buttons[-1].location.w + 10

    def set_nodeset_index(self, increment=None, index=None):
        """Select a nodeset and refresh its label.

        Pass either *index* (absolute) or *increment* (relative to the current
        selection).  Passing neither selects 0.  Out-of-range values wrap
        around.
        """
        if index is None:
            index = self.nodeset_index + increment if increment else 0
        limit = len(self.parent) - 1
        if index > limit:
            index = 0
        elif index < 0:
            index = limit
        self.nodeset_index = index
        self.set_nodeset_label()

    def set_nodeset_label(self):
        """Render the selected nodeset's name in the bottom-right corner."""
        surface = self.font.render(self.get_nodeset().name, True, (0, 0, 0),
                                   (255, 255, 255))
        rect = surface.get_rect()
        rect.bottomright = self.display_surface.get_rect().bottomright
        self.nodeset_label, self.nodeset_label_rect = surface, rect

    def get_nodeset(self):
        """Return the selected nodeset.

        If there are no nodesets, the configured template is created first.
        """
        if not len(self.parent):
            self.parent.add_nodeset(self.template_nodeset_name,
                                    self.template_nodeset)
            self.set_nodeset_index(0)
        return self.parent[self.nodeset_index]

    def set_y_range(self):
        """Choose the plot's vertical range for the selected nodeset.

        The range always covers both endpoints plus the curve sampled at 1%
        steps of the plot width.  A flat curve gets ``flat_y_range`` of
        headroom, and split mode is applied afterwards.
        """
        width = self.plot_rect.w
        nodeset = self.get_nodeset()
        first, last = nodeset[0].y, nodeset[-1].y
        # Seed low-to-high: seeding with [first, last] inverted the range for
        # descending curves, leaving the last endpoint outside the plot.
        self.y_range = y_range = [min(first, last), max(first, last)]
        x = 0
        while x < width:
            y = nodeset.get_y(self.get_function_coordinates(x)[0])
            if y < y_range[0]:
                y_range[0] = y
            elif y > y_range[1]:
                y_range[1] = y
            x += width * .01
        if y_range[1] - y_range[0] == 0:
            y_range[1] += self.flat_y_range
        if self.split:
            self.adjust_for_split(y_range, nodeset)
        self.set_axis_labels()

    def get_function_coordinates(self, xp=0, yp=0):
        """Map a plot-relative pixel (xp, yp) to function coordinates."""
        nodeset = self.get_nodeset()
        x_min, x_max, (y_min, y_max) = nodeset[0].x, nodeset[-1].x, self.y_range
        rect = self.plot_rect
        x = float(xp) / (rect.right - rect.left) * (x_max - x_min) + x_min
        y = float(yp) / (rect.bottom - rect.top) * (y_min - y_max) + y_max
        return x, y

    def adjust_for_split(self, y_range, nodeset):
        """Widen *y_range* in place so the chosen endpoint's y is centred."""
        middle = nodeset[0].y if self.split == self.S_LEFT else nodeset[-1].y
        below, above = middle - y_range[0], y_range[1] - middle
        if below > above:
            y_range[1] += below - above
        else:
            y_range[0] -= above - below

    def set_axis_labels(self):
        """Render min/max labels for both axes around the plot rect.

        Each entry pairs the x label with the y label; entry 0 holds the
        minimum values and entry 1 the maximum values.
        """
        self.axis_labels = labels = []
        nodeset = self.get_nodeset()
        formatted = self.get_formatted_measure
        render = self.font.render
        rect, yr = self.plot_rect, self.y_range
        # First and last node; also safe for a single-node nodeset.
        for ii, node in enumerate((nodeset[0], nodeset[-1])):
            xs = render(formatted(node.x), True, (0, 0, 0), (255, 255, 255))
            xsr = xs.get_rect()
            xsr.top = rect.bottom
            if not ii:
                xsr.left = rect.left
            else:
                xsr.right = rect.right
            ys = render(formatted(yr[ii]), True, (0, 0, 0), (255, 255, 255))
            ysr = ys.get_rect()
            ysr.right = rect.left
            if not ii:
                ysr.bottom = rect.bottom
            else:
                ysr.top = rect.top
            labels.append(((xs, xsr), (ys, ysr)))

    def get_formatted_measure(self, measure):
        """Format *measure* to ``label_precision`` significant digits."""
        return "%s" % float(("%." + str(self.label_precision) + "g") % measure)

    def deactivate(self):
        """Hide the editor and restore the saved game state."""
        self.active = False
        self.time_filter.open()
        self.audio.muted = self.saved_mute_state
        self.display.set_mouse_visibility(self.saved_mouse_state)

    def respond_to_command(self, event):
        """Handle the toggle command, plus reset/quit while active."""
        compare = self.delegate.compare
        if compare(event, "toggle-interpolator"):
            self.toggle()
        elif self.active:
            if compare(event, "reset-game"):
                self.deactivate()
            elif compare(event, "quit"):
                self.get_game().end(event)

    def toggle(self):
        """Switch between active and inactive."""
        if self.active:
            self.deactivate()
        else:
            self.activate()

    def activate(self):
        """Show the editor: pause time, mute audio, show the cursor."""
        self.active = True
        self.time_filter.close()
        self.saved_mute_state = self.audio.muted
        self.audio.mute()
        self.draw()
        self.saved_mouse_state = self.display.set_mouse_visibility(True)

    def respond_to_mouse_down(self, event):
        """Dispatch a mouse click while the editor is active.

        Button 1 acts on the nodeset label, the endpoint axis labels, the
        buttons, or empty plot space (adds a node).  Button 3 cycles
        nodesets backwards or removes the clicked marker's node.  A click
        outside an open prompt closes it.
        """
        redraw = False
        if self.active and not self.prompt.active:
            nodeset_rect = self.nodeset_label_rect
            plot_rect = self.plot_rect
            if event.button == 1:
                pos = event.pos
                if nodeset_rect.collidepoint(pos):
                    self.set_nodeset_index(1)
                    redraw = True
                elif self.axis_labels[0][0][1].collidepoint(pos):
                    text = "{0} {1}".format(*map(self.get_formatted_measure,
                                                 self.get_nodeset()[0]))
                    self.prompt.activate(text, self.resize_nodeset, 0)
                elif self.axis_labels[1][0][1].collidepoint(pos):
                    text = "{0} {1}".format(*map(self.get_formatted_measure,
                                                 self.get_nodeset()[-1]))
                    self.prompt.activate(text, self.resize_nodeset, -1)
                else:
                    bi = self.collide_buttons(pos)
                    if bi is not None:
                        if bi == self.B_WRITE:
                            self.get_configuration().write()
                            self.play(self.show_success_indicator)
                        elif bi in (self.B_LINEAR, self.B_CUBIC):
                            nodeset = self.get_nodeset()
                            if bi == self.B_LINEAR:
                                nodeset.set_interpolation_method(
                                    Nodeset.LINEAR)
                            else:
                                nodeset.set_interpolation_method(
                                    Nodeset.CUBIC)
                            self.store_in_configuration()
                            redraw = True
                        elif bi == self.B_DUPLICATE:
                            self.prompt.activate("", self.add_nodeset)
                        elif bi == self.B_DELETE and len(self.parent) > 1:
                            self.parent.remove(self.get_nodeset())
                            # The next nodeset has already shifted into the
                            # current index (wrapping if the last was deleted);
                            # adding 1 here would skip it.
                            self.set_nodeset_index(index=self.nodeset_index)
                            self.store_in_configuration()
                            redraw = True
                        elif bi == self.B_SPLIT:
                            self.toggle_split()
                            redraw = True
                    elif plot_rect.collidepoint(pos) and \
                            not self.collide_markers(pos):
                        xp, yp = pos[0] - plot_rect.left, pos[1] - plot_rect.top
                        self.get_nodeset().add_node(
                            self.get_function_coordinates(xp, yp))
                        self.store_in_configuration()
                        redraw = True
            elif event.button == 3:
                pos = event.pos
                if nodeset_rect.collidepoint(pos):
                    self.set_nodeset_index(-1)
                    redraw = True
                elif plot_rect.collidepoint(pos):
                    marker = self.collide_markers(pos)
                    if marker:
                        self.get_nodeset().remove(marker.node)
                        self.store_in_configuration()
                        redraw = True
        elif self.active and self.prompt.active and \
                not self.prompt.rect.collidepoint(event.pos):
            self.prompt.deactivate()
            redraw = True
        if redraw:
            self.set_y_range()
            self.set_markers()
            self.draw()

    def resize_nodeset(self, text, index):
        """Prompt callback: move endpoint *index* (0 or -1) to "x y".

        The x range is rescaled so the other endpoint stays put.  Returns True
        when applied.  Returns False when a number cannot be parsed, and
        returns None when the input is malformed or the endpoints would cross.
        Only True closes the prompt.
        """
        result = match(r"^\s*(-{,1}\d*\.{,1}\d*)\s+(-{,1}\d*\.{,1}\d*)\s*$",
                       text)
        if result:
            try:
                nodeset = self.get_nodeset()
                x, y = map(float, result.group(1, 2))
                if (index == -1 and x > nodeset[0].x) or \
                        (index == 0 and x < nodeset[-1].x):
                    nodeset[index].y = y
                    if index == -1:
                        nodeset.resize(nodeset[0].x, x - nodeset[0].x)
                    else:
                        nodeset.resize(x, nodeset[-1].x - x)
                    self.store_in_configuration()
                    self.set_y_range()
                    self.set_axis_labels()
                    self.set_markers()
                    self.draw()
                    return True
            except ValueError:
                # The regex also admits "", "-" and ".", which float() rejects.
                return False

    def collide_buttons(self, pos):
        """Return the index of the button under *pos*, or None."""
        for ii, button in enumerate(self.buttons):
            if button.location.collidepoint(pos):
                return ii

    def store_in_configuration(self):
        """Rewrite the "interpolate" section from the current nodesets.

        Each nodeset is stored in the format ``parse_raw`` reads, e.g.
        "L 0.0 1.0, 2.0 3.0".
        """
        config = self.get_configuration()
        section = "interpolate"
        config.clear_section(section)
        for nodeset in self.parent:
            prefix = "L" if nodeset.interpolation_method == Nodeset.LINEAR \
                else "C"
            code = prefix + ",".join(
                " {0} {1}".format(*map(self.get_formatted_measure, node))
                for node in nodeset)
            if not config.has_section(section):
                config.add_section(section)
            config.set(section, nodeset.name, code)

    def toggle_split(self):
        """Cycle the split mode No -> Left -> Right and relabel its button."""
        self.split += 1
        if self.split > self.S_RIGHT:
            self.split = self.S_NONE
        self.buttons[self.B_SPLIT].set_frame(self.SPLIT_LABELS[self.split])

    def add_nodeset(self, name):
        """Prompt callback: duplicate the selected nodeset as *name*.

        Returns False, which keeps the prompt open, when *name* is blank or
        already taken.  Names double as configuration keys, so a duplicate
        would overwrite an existing entry.
        """
        name = name.strip()
        if not name or self.parent.get_nodeset(name) is not None:
            return False
        nodeset = self.get_nodeset()
        self.set_nodeset_index(index=self.parent.add_nodeset(
            name, nodeset, nodeset.interpolation_method))
        self.store_in_configuration()
        # Markers must point at the copy's nodes, not the original's.
        self.set_y_range()
        self.set_markers()
        self.draw()
        return True

    def collide_markers(self, pos):
        """Return the marker under *pos*, or None."""
        for marker in self.markers:
            if marker.location.collidepoint(pos):
                return marker

    def set_markers(self):
        """Place a marker on every interior node (endpoints are excluded)."""
        self.markers = markers = []
        for node in self.get_nodeset()[1:-1]:
            markers.append(Marker(self, node))
            markers[-1].location.center = self.get_plot_coordinates(*node)

    def get_plot_coordinates(self, x=0, y=0):
        """Map function coordinates to absolute display pixels."""
        nodeset = self.get_nodeset()
        x_min, x_max, (y_min, y_max) = nodeset[0].x, nodeset[-1].x, self.y_range
        x_ratio = float(x - x_min) / (x_max - x_min)
        rect = self.plot_rect
        xp = x_ratio * (rect.right - rect.left) + rect.left
        y_ratio = float(y - y_min) / (y_max - y_min)
        yp = rect.bottom - y_ratio * (rect.bottom - rect.top)
        return xp, yp

    def draw(self):
        """Repaint the whole editor."""
        display_surface = self.display_surface
        display_surface.blit(self.background, (0, 0))
        display_surface.blit(self.nodeset_label, self.nodeset_label_rect)
        self.draw_axes()
        self.draw_function()
        self.draw_markers()
        self.draw_buttons()

    def draw_axes(self):
        """Blit the min/max axis labels."""
        display_surface = self.display_surface
        for xl, yl in self.axis_labels:
            display_surface.blit(*xl)
            display_surface.blit(*yl)

    def draw_function(self):
        """Draw the curve as anti-aliased segments, one pixel column apart."""
        rect = self.plot_rect
        surface = self.display_surface
        nodeset = self.get_nodeset()
        step = 1
        for x in range(rect.left, rect.right + step, step):
            ii = x - rect.left
            fx = nodeset.get_y(self.get_function_coordinates(ii)[0])
            y = self.get_plot_coordinates(y=fx)[1]
            if ii > 0:
                aaline(surface, self.curve_color, (x - step, last_y), (x, y))
            last_y = y

    def draw_markers(self):
        """Update every node marker sprite."""
        for marker in self.markers:
            marker.update()

    def draw_buttons(self):
        """Update every button sprite."""
        for button in self.buttons:
            button.update()

    def respond_to_key(self, event):
        """Edit the open prompt's text.

        Return submits the text to the prompt's callback.  Backspace deletes
        a character.  Letters, digits, whitespace, ".", "-" and "_" are
        appended up to ``character_limit``.
        """
        if self.prompt.active:
            prompt = self.prompt
            if event.key == K_RETURN:
                if prompt.callback[0](prompt.text, *prompt.callback[1]):
                    prompt.deactivate()
            elif event.key == K_BACKSPACE:
                prompt.text = prompt.text[:-1]
                prompt.update()
                prompt.draw_text()
            elif (event.unicode.isalnum() or event.unicode.isspace() or
                  event.unicode in (".", "-", "_")) and \
                    len(prompt.text) < prompt.character_limit:
                prompt.text += event.unicode
                prompt.update()
                prompt.draw_text()

    def show_success_indicator(self):
        """Blink the green square a couple of times, then halt."""
        self.draw()
        if self.success_indicator_blink_count > 1:
            self.success_indicator_blink_count = 0
            self.halt(self.show_success_indicator)
        else:
            if self.success_indicator_active:
                self.display_surface.blit(self.success_indicator,
                                          self.success_indicator_rect)
                self.success_indicator_blink_count += 1
            self.success_indicator_active = not self.success_indicator_active

    def save_temporary_file(self):
        """Overwrite the scratch file with the current configuration."""
        fp = self.temporary_file
        fp.seek(0)
        fp.truncate()
        self.get_configuration().write(fp)
        # Push the snapshot out now; left in the write buffer it would be
        # lost if the process exits abnormally.
        fp.flush()

    def rearrange(self):
        """Rebuild every display-size-dependent element."""
        self.set_background()
        self.set_success_indicator()
        self.set_plot_rect()
        self.set_markers()
        self.set_nodeset_label()
        self.set_axis_labels()
        self.set_buttons()
        self.prompt.reset()
class Marker(Sprite):
    """Sprite that marks one interior node on the editor's plot."""

    def __init__(self, parent, node):
        Sprite.__init__(self, parent)
        frame = parent.marker_frame
        self.add_frame(frame)
        # The Node this marker stands for; used when the marker is removed.
        self.node = node
class Button(Sprite):
    """Clickable text label whose left edge is at *left* on the display's
    bottom edge."""

    def __init__(self, parent, text, left):
        Sprite.__init__(self, parent)
        self.set_frame(text)
        bottom = self.get_display_surface().get_rect().bottom
        self.location.bottomleft = left, bottom

    def set_frame(self, text):
        """Replace the button's frames with *text* rendered black on white."""
        self.clear_frames()
        black, white = (0, 0, 0), (255, 255, 255)
        label = self.parent.font.render(text, True, black, white)
        self.add_frame(label)
class Prompt(Sprite):
    """Bordered, centred text-entry box.

    When submitted, its text is passed to a callback.
    """

    def __init__(self, parent):
        Sprite.__init__(self, parent)
        self.load_configuration()
        self.font = Font(None, self.text_size)
        self.reset()
        self.deactivate()

    def deactivate(self):
        """Mark the prompt as closed."""
        self.active = False

    def load_configuration(self):
        """Read the prompt-* settings from the "interpolator-gui" section."""
        config = self.get_configuration("interpolator-gui")
        settings = (("size", "prompt-size"),
                    ("border_color", "prompt-border-color"),
                    ("border_width", "prompt-border-width"),
                    ("character_limit", "prompt-character-limit"),
                    ("text_size", "prompt-text-size"))
        for attribute, key in settings:
            setattr(self, attribute, config[key])

    def reset(self):
        """Rebuild the frame and re-centre it, e.g. after a display resize."""
        self.set_frame()
        self.place()

    def set_frame(self):
        """Draw a black box with a border of ``border_width`` pixels."""
        self.clear_frames()
        frame = Surface(self.size)
        self.add_frame(frame)
        frame.fill(self.border_color)
        inset = -2 * self.border_width
        frame.fill((0, 0, 0), frame.get_rect().inflate(inset, inset))

    def place(self):
        """Centre the prompt on the display."""
        self.location.center = self.display_surface.get_rect().center

    def activate(self, text, callback, *args):
        """Open the prompt with *text*.

        On submit, ``callback(text, *args)`` is called.
        """
        self.active = True
        self.text = str(text)
        self.callback = callback, args
        self.update()
        self.draw_text()

    def draw_text(self):
        """Blit the current text, white on black, centred in the prompt."""
        rendered = self.font.render(self.text, True, (255, 255, 255),
                                    (0, 0, 0))
        area = rendered.get_rect()
        area.center = self.location.center
        self.display_surface.blit(rendered, area)