# ibitfit/electric_sieve/pgfw/Interpolator.py
#
# 734 lines
# 26 KiB
# Python

from re import match
from os.path import join
from tempfile import gettempdir
from pygame import Surface
from pygame.font import Font
from pygame.draw import aaline
from pygame.locals import *
from GameChild import GameChild
from Sprite import Sprite
from Animation import Animation
class Interpolator(list, GameChild):
    """Collection (a list) of named Nodeset objects owned by the game.

    Nodesets are loaded from the "interpolate" configuration section at
    construction. When check_command_line("-interpolator") is true
    (presumably a command-line switch -- confirm), an editing GUI is built
    as well and stored in ``self.gui``.
    """

    def __init__(self, parent):
        GameChild.__init__(self, parent)
        self.set_nodesets()
        self.gui_enabled = self.check_command_line("-interpolator")
        if self.gui_enabled:
            self.gui = GUI(self)

    def set_nodesets(self):
        """Create one Nodeset per option of the "interpolate" section."""
        config = self.get_configuration()
        if not config.has_section("interpolate"):
            return
        section = config.get_section("interpolate")
        for name, value in section.iteritems():
            self.add_nodeset(name, value)

    def add_nodeset(self, name, value, method=None):
        """Append a new Nodeset built from ``value`` and return its index."""
        new_index = len(self)
        self.append(Nodeset(name, value, method))
        return new_index

    def is_gui_active(self):
        """Return whether the GUI exists and is currently active."""
        return self.gui_enabled and self.gui.active

    def get_nodeset(self, name):
        """Return the first nodeset called ``name``, or None if absent."""
        return next((candidate for candidate in self
                     if candidate.name == name), None)

    def remove(self, outgoing):
        """Remove the first nodeset whose name matches ``outgoing.name``."""
        target = outgoing.name
        for position, candidate in enumerate(self):
            if candidate.name == target:
                self.pop(position)
                return
class Nodeset(list):
    """Named, x-sorted list of Node points plus the splines joining them.

    ``self.splines`` holds one spline per consecutive pair of nodes
    (LinearSpline or CubicSpline depending on ``interpolation_method``) and
    is rebuilt by set_splines() whenever nodes or the method change.
    """

    LINEAR, CUBIC = range(2)

    def __init__(self, name, nodes, method=None):
        """Build a nodeset called ``name``.

        ``nodes`` is either a raw string such as "L 0 0, 1 1" (parsed by
        parse_raw, which picks the method itself and ``method`` is ignored)
        or an iterable of coordinate sequences interpolated with ``method``
        (anything other than LINEAR, including None, means cubic).
        """
        list.__init__(self, [])
        self.name = name
        if isinstance(nodes, str):
            self.parse_raw(nodes)
        else:
            self.interpolation_method = method
            # Splines are built once below, not after every added node.
            self.parse_list(nodes, False)
        self.set_splines()

    def parse_raw(self, raw):
        """Add nodes from a string like "L x y, x y, ..." or "C x y, ...".

        A leading "L"/"l" selects linear interpolation; any other leading
        marker selects cubic. A string that starts directly with a number
        is treated as cubic and parsed whole instead of losing its first
        character. Empty entries (e.g. from a trailing comma) are skipped.
        Splines are not rebuilt here; the caller does that.
        """
        raw = raw.strip()
        head = raw[:1]
        if head.isdigit() or head in ("-", "+", "."):
            marker, body = "", raw
        else:
            marker, body = head.upper(), raw[1:]
        if marker == "L":
            self.set_interpolation_method(self.LINEAR, False)
        else:
            self.set_interpolation_method(self.CUBIC, False)
        for node in body.strip().split(","):
            fields = node.split()
            if fields:
                self.add_node([float(field) for field in fields], False)

    def set_interpolation_method(self, method, refresh=True):
        """Set LINEAR or CUBIC; rebuild splines unless ``refresh`` is false."""
        self.interpolation_method = method
        if refresh:
            self.set_splines()

    def add_node(self, coordinates, refresh=True):
        """Insert a Node built from ``coordinates``, keeping x order.

        Returns the index of the new node, or None (leaving the nodeset
        unchanged) when a node with the same x already exists.
        """
        x = coordinates[0]
        index = len(self)
        for ii, node in enumerate(self):
            if x == node.x:
                return None
            if x < node.x:
                index = ii
                break
        self.insert(index, Node(coordinates))
        if refresh:
            self.set_splines()
        return index

    def parse_list(self, nodes, refresh=True):
        """Add every coordinate sequence in ``nodes``.

        Splines are rebuilt once at the end (when ``refresh`` is true)
        rather than after each individual node.
        """
        for node in nodes:
            self.add_node(node, False)
        if refresh:
            self.set_splines()

    def set_splines(self):
        """Rebuild ``self.splines`` for the current interpolation method."""
        if self.interpolation_method == self.LINEAR:
            self.set_linear_splines()
        else:
            self.set_cubic_splines()

    def set_linear_splines(self):
        """Build one straight-line segment per consecutive node pair."""
        self.splines = splines = []
        for ii in xrange(len(self) - 1):
            x1, y1, x2, y2 = self[ii] + self[ii + 1]
            m = float(y2 - y1) / (x2 - x1)
            splines.append(LinearSpline(x1, y1, m))

    def set_cubic_splines(self):
        """Build natural cubic splines (end second derivatives zero).

        Solves the tridiagonal system for the c coefficients with
        c[0] = c[n] = 0, then derives b and d per interval. Values are
        promoted to float so integer nodes do not hit Python 2 integer
        division.
        """
        n = len(self) - 1
        if n < 1:
            # Fewer than two nodes: there is no interval to span.
            self.splines = []
            return
        a = [float(node.y) for node in self]
        b = [None] * n
        d = [None] * n
        h = [float(self[ii + 1].x - self[ii].x) for ii in xrange(n)]
        alpha = [None] + [(3.0 / h[ii]) * (a[ii + 1] - a[ii]) -
                          (3.0 / h[ii - 1]) * (a[ii] - a[ii - 1])
                          for ii in xrange(1, n)]
        c = [None] * (n + 1)
        l = [None] * (n + 1)
        u = [None] * (n + 1)
        z = [None] * (n + 1)
        l[0] = 1
        u[0] = z[0] = 0
        # Forward sweep of the tridiagonal solve.
        for ii in xrange(1, n):
            l[ii] = 2 * (self[ii + 1].x - self[ii - 1].x) - \
                h[ii - 1] * u[ii - 1]
            u[ii] = h[ii] / l[ii]
            z[ii] = (alpha[ii] - h[ii - 1] * z[ii - 1]) / l[ii]
        l[n] = 1
        z[n] = c[n] = 0
        # Back substitution, deriving b and d from the solved c values.
        for jj in xrange(n - 1, -1, -1):
            c[jj] = z[jj] - u[jj] * c[jj + 1]
            b[jj] = (a[jj + 1] - a[jj]) / h[jj] - \
                (h[jj] * (c[jj + 1] + 2 * c[jj])) / 3
            d[jj] = (c[jj + 1] - c[jj]) / (3 * h[jj])
        self.splines = [CubicSpline(self[ii].x, a[ii], b[ii], c[ii], d[ii])
                        for ii in xrange(n)]

    def get_y(self, t, loop=False, reverse=False, natural=False):
        """Evaluate the curve at ``t``.

        loop: wrap t into [0, last node x).
        reverse: like loop, but every other period runs backwards.
        Both use the last node's x as the period, which assumes the curve
        starts at x = 0 (NOTE(review): confirm for resized nodesets).
        natural: when not looping, extrapolate beyond the end nodes with the
        first/last spline instead of clamping t to the node range.
        """
        if loop or reverse:
            period = self[-1].x
            # Floor-divide the real values: truncating t and the period to
            # ints first misjudged odd periods for non-integer periods and
            # divided by zero when the period was below 1.
            if reverse and int(t // period) % 2:
                t = period - t
            t %= period
        elif not natural:
            if t < self[0].x:
                t = self[0].x
            elif t > self[-1].x:
                t = self[-1].x
        splines = self.splines
        for ii in xrange(len(splines) - 1):
            if t < splines[ii + 1].x:
                return splines[ii].get_y(t)
        return splines[-1].get_y(t)

    def remove(self, node, refresh=True):
        """Remove the first node equal to ``node``; optionally rebuild."""
        list.remove(self, node)
        if refresh:
            self.set_splines()

    def resize(self, left, length, refresh=True):
        """Linearly remap node x values onto [left, left + length]."""
        old_left = self[0].x
        # float() avoids Python 2 integer division for integer coordinates.
        old_length = float(self.get_length())
        for node in self:
            node.x = left + length * (node.x - old_left) / old_length
        if refresh:
            self.set_splines()

    def get_length(self):
        """Return the x distance between the first and last nodes."""
        return self[-1].x - self[0].x
class Node(list):
    """A point stored as a plain list ``[x, y]`` with x/y accessors.

    ``x`` and ``y`` read and write items 0 and 1. Any other attribute
    behaves normally; a missing one raises a regular AttributeError naming
    that attribute.
    """

    def __init__(self, coordinates):
        list.__init__(self, coordinates)

    def _get_x(self):
        return self[0]

    def _set_x(self, value):
        self[0] = value

    def _get_y(self):
        return self[1]

    def _set_y(self, value):
        self[1] = value

    # Properties replace the old __getattr__/__setattr__ pair, whose fallback
    # called the nonexistent list.__get__ and produced a misleading error.
    x = property(_get_x, _set_x, doc="First coordinate (item 0).")
    y = property(_get_y, _set_y, doc="Second coordinate (item 1).")
class Spline:
    """Base class for a single curve segment that begins at ``x``."""

    def __init__(self, x):
        # Left end of the interval covered by this segment.
        self.x = x
class CubicSpline(Spline):
    """Cubic segment: y = a + b*dx + c*dx**2 + d*dx**3, dx = t - x."""

    def __init__(self, x, a, b, c, d):
        Spline.__init__(self, x)
        self.a, self.b, self.c, self.d = a, b, c, d

    def get_y(self, t):
        """Evaluate the cubic at ``t``."""
        dx = t - self.x
        return self.a + self.b * dx + self.c * dx ** 2 + self.d * dx ** 3
class LinearSpline(Spline):
    """Straight segment through (x, y) with slope m."""

    def __init__(self, x, y, m):
        Spline.__init__(self, x)
        self.y, self.m = y, m

    def get_y(self, t):
        """Evaluate the line at ``t``."""
        offset = t - self.x
        return self.m * offset + self.y
class GUI(Animation):
    """Interactive plot editor for the parent Interpolator's nodesets.

    Toggled by the "toggle-interpolator" command. While active it closes
    the game's time filter (presumably pausing game time -- confirm), mutes
    audio, shows the mouse, and plots the current nodeset with markers on
    its interior nodes. Every edit is written back to the "interpolate"
    configuration section via store_in_configuration().
    """

    B_DUPLICATE, B_WRITE, B_DELETE, B_LINEAR, B_CUBIC, B_SPLIT = range(6)
    S_NONE, S_LEFT, S_RIGHT = range(3)
    # Split button captions indexed by S_NONE / S_LEFT / S_RIGHT.
    SPLIT_LABELS = "Split: No", "Split: L", "Split: R"

    def __init__(self, parent):
        Animation.__init__(self, parent, unfiltered=True)
        self.audio = self.get_audio()
        self.display = self.get_game().display
        self.display_surface = self.get_display_surface()
        self.time_filter = self.get_game().time_filter
        self.delegate = self.get_delegate()
        self.split = self.S_NONE
        self.success_indicator_active = True
        self.success_indicator_blink_count = 0
        self.load_configuration()
        self.font = Font(None, self.label_size)
        self.prompt = Prompt(self)
        self.set_temporary_file()
        self.set_background()
        self.set_success_indicator()
        self.set_plot_rect()
        self.set_marker_frame()
        self.set_buttons()
        self.active = False
        self.set_nodeset_index()
        self.set_y_range()
        self.set_markers()
        self.subscribe(self.respond_to_command)
        self.subscribe(self.respond_to_mouse_down, MOUSEBUTTONDOWN)
        self.subscribe(self.respond_to_key, KEYDOWN)
        # Intervals are passed straight to Animation.register (units
        # presumably milliseconds -- confirm).
        self.register(self.show_success_indicator, interval=100)
        self.register(self.save_temporary_file, interval=10000)
        self.play(self.save_temporary_file)

    def load_configuration(self):
        """Read display settings from the "interpolator-gui" section."""
        config = self.get_configuration("interpolator-gui")
        self.label_size = config["label-size"]
        self.axis_label_count = config["axis-label-count"]
        self.margin = config["margin"]
        self.curve_color = config["curve-color"]
        self.marker_size = config["marker-size"]
        self.marker_color = config["marker-color"]
        self.label_precision = config["label-precision"]
        self.template_nodeset = config["template-nodeset"]
        self.template_nodeset_name = config["template-nodeset-name"]
        self.flat_y_range = config["flat-y-range"]

    def set_temporary_file(self):
        """Open (truncating) "pgfw-config" in the temp dir for backups.

        The handle stays open for save_temporary_file().
        """
        self.temporary_file = open(join(gettempdir(), "pgfw-config"), "w")

    def set_background(self):
        """Create a black surface the size of the display."""
        surface = Surface(self.display_surface.get_size())
        surface.fill((0, 0, 0))
        self.background = surface

    def set_success_indicator(self):
        """Create the 10x10 green square blinked after a config write."""
        surface = Surface((10, 10))
        surface.fill((0, 255, 0))
        rect = surface.get_rect()
        rect.topleft = self.display_surface.get_rect().topleft
        self.success_indicator, self.success_indicator_rect = surface, rect

    def set_plot_rect(self):
        """Plot area: the display rect shrunk by ``margin`` in each axis."""
        margin = self.margin
        self.plot_rect = self.display_surface.get_rect().inflate(-margin,
                                                                 -margin)

    def set_marker_frame(self):
        """Draw the "X" marker image on a magenta-colorkeyed surface."""
        size = self.marker_size
        surface = Surface((size, size))
        transparent_color = (255, 0, 255)
        surface.fill(transparent_color)
        surface.set_colorkey(transparent_color)
        line_color = self.marker_color
        aaline(surface, line_color, (0, 0), (size - 1, size - 1))
        aaline(surface, line_color, (0, size - 1), (size - 1, 0))
        self.marker_frame = surface

    def set_buttons(self):
        """Create the bottom-left button row, indexed by the B_* constants.

        The split button caption reflects the current split state, so
        rebuilding the buttons (e.g. in rearrange) does not reset it.
        """
        self.buttons = buttons = []
        text = ("Duplicate", "Write", "Delete", "Linear", "Cubic",
                self.SPLIT_LABELS[self.split])
        x = 0
        for instruction in text:
            buttons.append(Button(self, instruction, x))
            x += buttons[-1].location.w + 10

    def set_nodeset_index(self, increment=None, index=None):
        """Select a nodeset by absolute ``index`` or relative ``increment``.

        A falsy increment with no index selects 0; out-of-range results wrap
        around. The nodeset label is refreshed afterwards.
        """
        parent = self.parent
        if index is None:
            if not increment:
                index = 0
            else:
                index = self.nodeset_index + increment
        limit = len(parent) - 1
        if index > limit:
            index = 0
        elif index < 0:
            index = limit
        self.nodeset_index = index
        self.set_nodeset_label()

    def set_nodeset_label(self):
        """Render the current nodeset's name in the bottom-right corner."""
        surface = self.font.render(self.get_nodeset().name, True, (0, 0, 0),
                                   (255, 255, 255))
        rect = surface.get_rect()
        rect.bottomright = self.display_surface.get_rect().bottomright
        self.nodeset_label, self.nodeset_label_rect = surface, rect

    def get_nodeset(self):
        """Return the selected nodeset, creating the template if none exist."""
        if not len(self.parent):
            self.parent.add_nodeset(self.template_nodeset_name,
                                    self.template_nodeset)
            self.set_nodeset_index(0)
        return self.parent[self.nodeset_index]

    def set_y_range(self):
        """Compute ``self.y_range`` ([low, high]) for the current nodeset.

        Starts from the end nodes, ordered low/high so descending curves are
        handled, and widens by sampling the curve every 1% of the plot width.
        A flat curve is padded by ``flat_y_range``, and in split mode the
        range is re-centred on the split node. Axis labels are rebuilt.
        """
        width = self.plot_rect.w
        nodeset = self.get_nodeset()
        first, last = nodeset[0].y, nodeset[-1].y
        # Assigned before sampling: get_function_coordinates reads it.
        self.y_range = y_range = [min(first, last), max(first, last)]
        x = 0
        while x < width:
            y = nodeset.get_y(self.get_function_coordinates(x)[0])
            if y < y_range[0]:
                y_range[0] = y
            if y > y_range[1]:
                y_range[1] = y
            x += width * .01
        if y_range[1] - y_range[0] == 0:
            y_range[1] += self.flat_y_range
        if self.split:
            self.adjust_for_split(y_range, nodeset)
        self.set_axis_labels()

    def get_function_coordinates(self, xp=0, yp=0):
        """Map a plot-relative pixel offset to function (x, y) values."""
        nodeset = self.get_nodeset()
        x_min, x_max, (y_min, y_max) = nodeset[0].x, nodeset[-1].x, self.y_range
        rect = self.plot_rect
        x = float(xp) / (rect.right - rect.left) * (x_max - x_min) + x_min
        y = float(yp) / (rect.bottom - rect.top) * (y_min - y_max) + y_max
        return x, y

    def adjust_for_split(self, y_range, nodeset):
        """Widen ``y_range`` so the first (S_LEFT) or last node's y is centred."""
        middle = nodeset[0].y if self.split == self.S_LEFT else nodeset[-1].y
        below, above = middle - y_range[0], y_range[1] - middle
        if below > above:
            y_range[1] += below - above
        else:
            y_range[0] -= above - below

    def set_axis_labels(self):
        """Render x labels for the end nodes and y labels for the range.

        ``axis_labels`` holds two ((x_surf, x_rect), (y_surf, y_rect))
        pairs: index 0 for the first node and y low, index 1 for the last
        node and y high.
        """
        self.axis_labels = labels = []
        nodeset, formatted, render, rect, yr = (self.get_nodeset(),
                                                self.get_formatted_measure,
                                                self.font.render,
                                                self.plot_rect, self.y_range)
        # Explicit end nodes (the old [0::len - 1] slice failed on one node).
        for ii, node in enumerate((nodeset[0], nodeset[-1])):
            xs = render(formatted(node.x), True, (0, 0, 0), (255, 255, 255))
            xsr = xs.get_rect()
            xsr.top = rect.bottom
            if not ii:
                xsr.left = rect.left
            else:
                xsr.right = rect.right
            ys = render(formatted(yr[ii]), True, (0, 0, 0), (255, 255, 255))
            ysr = ys.get_rect()
            ysr.right = rect.left
            if not ii:
                ysr.bottom = rect.bottom
            else:
                ysr.top = rect.top
            labels.append(((xs, xsr), (ys, ysr)))

    def get_formatted_measure(self, measure):
        """Format ``measure`` with ``label_precision`` significant digits."""
        return "%s" % float(("%." + str(self.label_precision) + "g") % measure)

    def deactivate(self):
        """Leave the editor, restoring time filter, mute and mouse state."""
        self.active = False
        # Otherwise a half-typed prompt survives the deactivation.
        self.prompt.deactivate()
        self.time_filter.open()
        self.audio.muted = self.saved_mute_state
        self.display.set_mouse_visibility(self.saved_mouse_state)

    def respond_to_command(self, event):
        """Toggle on "toggle-interpolator"; handle reset/quit while active."""
        compare = self.delegate.compare
        if compare(event, "toggle-interpolator"):
            self.toggle()
        elif self.active:
            if compare(event, "reset-game"):
                self.deactivate()
            elif compare(event, "quit"):
                self.get_game().end(event)

    def toggle(self):
        """Switch between active and inactive."""
        if self.active:
            self.deactivate()
        else:
            self.activate()

    def activate(self):
        """Enter the editor, saving mute and mouse state for deactivate()."""
        self.active = True
        self.time_filter.close()
        self.saved_mute_state = self.audio.muted
        self.audio.mute()
        self.draw()
        self.saved_mouse_state = self.display.set_mouse_visibility(True)

    def respond_to_mouse_down(self, event):
        """Handle clicks while active.

        Button 1: nodeset label -> next nodeset; first/last x label -> prompt
        to edit that end node as "x y"; a button -> its action; elsewhere in
        the plot (not on a marker) -> add a node there. Button 3: nodeset
        label -> previous nodeset; a marker -> delete its node. With the
        prompt open, a click outside it closes the prompt.
        """
        redraw = False
        if self.active and not self.prompt.active:
            nodeset_rect = self.nodeset_label_rect
            plot_rect = self.plot_rect
            if event.button == 1:
                pos = event.pos
                if nodeset_rect.collidepoint(pos):
                    self.set_nodeset_index(1)
                    redraw = True
                elif self.axis_labels[0][0][1].collidepoint(pos):
                    text = "{0} {1}".format(*map(self.get_formatted_measure,
                                                 self.get_nodeset()[0]))
                    self.prompt.activate(text, self.resize_nodeset, 0)
                elif self.axis_labels[1][0][1].collidepoint(pos):
                    text = "{0} {1}".format(*map(self.get_formatted_measure,
                                                 self.get_nodeset()[-1]))
                    self.prompt.activate(text, self.resize_nodeset, -1)
                else:
                    bi = self.collide_buttons(pos)
                    if bi is not None:
                        if bi == self.B_WRITE:
                            self.get_configuration().write()
                            self.play(self.show_success_indicator)
                        elif bi in (self.B_LINEAR, self.B_CUBIC):
                            nodeset = self.get_nodeset()
                            if bi == self.B_LINEAR:
                                nodeset.set_interpolation_method(
                                    Nodeset.LINEAR)
                            else:
                                nodeset.set_interpolation_method(
                                    Nodeset.CUBIC)
                            self.store_in_configuration()
                            redraw = True
                        elif bi == self.B_DUPLICATE:
                            self.prompt.activate("", self.add_nodeset)
                        elif bi == self.B_DELETE and len(self.parent) > 1:
                            self.parent.remove(self.get_nodeset())
                            self.set_nodeset_index(1)
                            self.store_in_configuration()
                            redraw = True
                        elif bi == self.B_SPLIT:
                            self.toggle_split()
                            redraw = True
                    elif plot_rect.collidepoint(pos) and \
                            self.collide_markers(pos) is None:
                        xp, yp = pos[0] - plot_rect.left, pos[1] - plot_rect.top
                        self.get_nodeset().add_node(
                            self.get_function_coordinates(xp, yp))
                        self.store_in_configuration()
                        redraw = True
            elif event.button == 3:
                pos = event.pos
                if nodeset_rect.collidepoint(pos):
                    self.set_nodeset_index(-1)
                    redraw = True
                elif plot_rect.collidepoint(pos):
                    marker = self.collide_markers(pos)
                    if marker is not None:
                        self.get_nodeset().remove(marker.node)
                        self.store_in_configuration()
                        redraw = True
        elif self.active and self.prompt.active and \
                not self.prompt.rect.collidepoint(event.pos):
            self.prompt.deactivate()
            redraw = True
        if redraw:
            self.set_y_range()
            self.set_markers()
            self.draw()

    def resize_nodeset(self, text, index):
        """Prompt callback: move end node ``index`` (0 or -1) to "x y".

        The new x must stay on its side of the other end node; the nodes in
        between are rescaled proportionally. Returns False if the numbers do
        not parse and None if the text does not match the pattern, which
        keeps the prompt open. Otherwise returns True, including when x is
        rejected and nothing changes.
        """
        result = match(r"^\s*(-{,1}\d*\.{,1}\d*)\s+(-{,1}\d*\.{,1}\d*)\s*$",
                       text)
        if result:
            try:
                nodeset = self.get_nodeset()
                x, y = map(float, result.group(1, 2))
                if (index == -1 and x > nodeset[0].x) or \
                        (index == 0 and x < nodeset[-1].x):
                    nodeset[index].y = y
                    if index == -1:
                        nodeset.resize(nodeset[0].x, x - nodeset[0].x)
                    else:
                        nodeset.resize(x, nodeset[-1].x - x)
                    self.store_in_configuration()
                    self.set_y_range()
                    self.set_axis_labels()
                    self.set_markers()
                    self.draw()
                return True
            except ValueError:
                return False

    def collide_buttons(self, pos):
        """Return the B_* index of the button under ``pos``, or None."""
        for ii, button in enumerate(self.buttons):
            if button.location.collidepoint(pos):
                return ii

    def store_in_configuration(self):
        """Rewrite the "interpolate" section from all current nodesets.

        Each value is "L" or "C" followed by comma-separated " x y" pairs
        formatted with get_formatted_measure (i.e. ``label_precision``).
        """
        config = self.get_configuration()
        section = "interpolate"
        config.clear_section(section)
        for nodeset in self.parent:
            code = "L" if nodeset.interpolation_method == Nodeset.LINEAR else \
                "C"
            for ii, node in enumerate(nodeset):
                if ii > 0:
                    code += ","
                code += " {0} {1}".format(*map(self.get_formatted_measure,
                                               node))
            if not config.has_section(section):
                config.add_section(section)
            config.set(section, nodeset.name, code)

    def toggle_split(self):
        """Cycle S_NONE -> S_LEFT -> S_RIGHT and update the button caption."""
        self.split += 1
        if self.split > self.S_RIGHT:
            self.split = self.S_NONE
        self.buttons[self.B_SPLIT].set_frame(self.SPLIT_LABELS[self.split])

    def add_nodeset(self, name):
        """Prompt callback: duplicate the current nodeset as ``name``.

        Returns False, keeping the prompt open, for a blank name or one
        already in use. Duplicate names would overwrite each other in the
        configuration and make name-based removal hit the wrong nodeset.
        """
        if not name.strip() or self.parent.get_nodeset(name) is not None:
            return False
        nodeset = self.get_nodeset()
        self.set_nodeset_index(index=self.parent.add_nodeset(
            name, nodeset, nodeset.interpolation_method))
        self.store_in_configuration()
        self.draw()
        return True

    def collide_markers(self, pos):
        """Return the marker under ``pos``, or None."""
        for marker in self.markers:
            if marker.location.collidepoint(pos):
                return marker

    def set_markers(self):
        """Create a Marker on each interior node (end nodes excluded)."""
        self.markers = markers = []
        for node in self.get_nodeset()[1:-1]:
            markers.append(Marker(self, node))
            markers[-1].location.center = self.get_plot_coordinates(*node)

    def get_plot_coordinates(self, x=0, y=0):
        """Map function (x, y) to absolute screen pixel coordinates."""
        nodeset = self.get_nodeset()
        x_min, x_max, (y_min, y_max) = nodeset[0].x, nodeset[-1].x, self.y_range
        x_ratio = float(x - x_min) / (x_max - x_min)
        rect = self.plot_rect
        xp = x_ratio * (rect.right - rect.left) + rect.left
        y_ratio = float(y - y_min) / (y_max - y_min)
        yp = rect.bottom - y_ratio * (rect.bottom - rect.top)
        return xp, yp

    def draw(self):
        """Redraw background, label, axes, curve, markers and buttons."""
        display_surface = self.display_surface
        display_surface.blit(self.background, (0, 0))
        display_surface.blit(self.nodeset_label, self.nodeset_label_rect)
        self.draw_axes()
        self.draw_function()
        self.draw_markers()
        self.draw_buttons()

    def draw_axes(self):
        """Blit every axis label."""
        display_surface = self.display_surface
        for xl, yl in self.axis_labels:
            display_surface.blit(*xl)
            display_surface.blit(*yl)

    def draw_function(self):
        """Draw the curve as anti-aliased segments, one pixel column apart."""
        rect = self.plot_rect
        surface = self.display_surface
        nodeset = self.get_nodeset()
        step = 1
        for x in xrange(rect.left, rect.right + step, step):
            ii = x - rect.left
            fx = nodeset.get_y(self.get_function_coordinates(ii)[0])
            y = self.get_plot_coordinates(y=fx)[1]
            if ii > 0:
                aaline(surface, self.curve_color, (x - step, last_y), (x, y))
            last_y = y

    def draw_markers(self):
        """Update (draw) each marker sprite."""
        for marker in self.markers:
            marker.update()

    def draw_buttons(self):
        """Update (draw) each button sprite."""
        for button in self.buttons:
            button.update()

    def respond_to_key(self, event):
        """Feed key presses to the open prompt while the editor is active.

        Return submits to the prompt callback (closing it on a truthy
        result), Backspace deletes, and alphanumerics, whitespace, ".", "-"
        and "_" are appended up to the character limit.
        """
        if not (self.active and self.prompt.active):
            return
        prompt = self.prompt
        if event.key == K_RETURN:
            if prompt.callback[0](prompt.text, *prompt.callback[1]):
                prompt.deactivate()
        elif event.key == K_BACKSPACE:
            prompt.text = prompt.text[:-1]
            prompt.update()
            prompt.draw_text()
        elif (event.unicode.isalnum() or event.unicode.isspace() or
              event.unicode in (".", "-", "_")) and len(prompt.text) < \
                prompt.character_limit:
            prompt.text += event.unicode
            prompt.update()
            prompt.draw_text()

    def show_success_indicator(self):
        """Blink the green square after a write, then halt itself."""
        self.draw()
        if self.success_indicator_blink_count > 1:
            self.success_indicator_blink_count = 0
            self.halt(self.show_success_indicator)
        else:
            if self.success_indicator_active:
                self.display_surface.blit(self.success_indicator,
                                          self.success_indicator_rect)
            if self.success_indicator_active:
                self.success_indicator_blink_count += 1
            self.success_indicator_active = not self.success_indicator_active

    def save_temporary_file(self):
        """Overwrite the temp backup with the current configuration."""
        fp = self.temporary_file
        fp.seek(0)
        fp.truncate()
        self.get_configuration().write(fp)
        # The handle stays open for the GUI's lifetime; flush so the backup
        # reaches the file instead of sitting in the write buffer.
        fp.flush()

    def rearrange(self):
        """Rebuild size-dependent surfaces and layout after a display change."""
        self.set_background()
        self.set_success_indicator()
        self.set_plot_rect()
        self.set_markers()
        self.set_nodeset_label()
        self.set_axis_labels()
        self.set_buttons()
        self.prompt.reset()
class Marker(Sprite):
    """Sprite shown over one interior node of the plotted curve."""

    def __init__(self, parent, node):
        Sprite.__init__(self, parent)
        image = parent.marker_frame
        self.add_frame(image)
        # The Node this marker represents (used to delete it on click).
        self.node = node
class Button(Sprite):
    """Text button anchored to the bottom edge of the display at ``left``."""

    def __init__(self, parent, text, left):
        Sprite.__init__(self, parent)
        self.set_frame(text)
        screen_bottom = self.get_display_surface().get_rect().bottom
        self.location.bottomleft = left, screen_bottom

    def set_frame(self, text):
        """Replace the button image with ``text`` rendered black on white."""
        self.clear_frames()
        caption = self.parent.font.render(text, True, (0, 0, 0),
                                          (255, 255, 255))
        self.add_frame(caption)
class Prompt(Sprite):
    """Centred text-entry box; while active it holds ``text`` and ``callback``.

    ``callback`` is stored as a (function, args) pair by activate().
    """

    def __init__(self, parent):
        Sprite.__init__(self, parent)
        self.load_configuration()
        self.font = Font(None, self.text_size)
        self.reset()
        self.deactivate()

    def deactivate(self):
        """Mark the prompt inactive."""
        self.active = False

    def load_configuration(self):
        """Read prompt settings from the "interpolator-gui" section."""
        config = self.get_configuration("interpolator-gui")
        settings = (("size", "prompt-size"),
                    ("border_color", "prompt-border-color"),
                    ("border_width", "prompt-border-width"),
                    ("character_limit", "prompt-character-limit"),
                    ("text_size", "prompt-text-size"))
        for attribute, key in settings:
            setattr(self, attribute, config[key])

    def reset(self):
        """Rebuild the frame and re-centre it on the display."""
        self.set_frame()
        self.place()

    def set_frame(self):
        """Build the bordered black box image."""
        self.clear_frames()
        box = Surface(self.size)
        # Registered before painting, exactly as originally ordered.
        self.add_frame(box)
        box.fill(self.border_color)
        inset = -(self.border_width * 2)
        box.fill((0, 0, 0), box.get_rect().inflate(inset, inset))

    def place(self):
        """Centre the prompt on the display surface."""
        self.location.center = self.display_surface.get_rect().center

    def activate(self, text, callback, *args):
        """Open with initial ``text``; ``callback`` receives text and args."""
        self.active = True
        self.text = str(text)
        self.callback = (callback, args)
        self.update()
        self.draw_text()

    def draw_text(self):
        """Blit the current text, white on black, at the prompt's centre."""
        rendered = self.font.render(self.text, True, (255, 255, 255),
                                    (0, 0, 0))
        target = rendered.get_rect()
        target.center = self.location.center
        self.display_surface.blit(rendered, target)