Source code for fractalshades.db

# -*- coding: utf-8 -*-
import copy
import os
import logging

import numpy as np
import numba
import PIL
from numpy.lib.format import open_memmap

import fractalshades as fs
import fractalshades.utils
from fractalshades.numpy_utils.interp2d import (
    Grid_lin_interpolator as fsGrid_lin_interpolator
)

import fractalshades.numpy_utils.filters as fsfilters
from fractalshades.mthreading import Multithreading_iterator

logger = logging.getLogger(__name__)


[docs] class Frame:
[docs] def __init__(self, x, y, dx, nx, xy_ratio, t=None, plotting_modifier=None): """ A frame is used to describe a specific data window and for interpolating inside a db. Parameters ---------- x: float center x-coord of the window (in screen coordinates) y: float center y-coord of the window (in screen coordinates) dx: float width of the window (in screen coordinates) nx: int number of pixel for the interpolated frame xy_ratio: float width / height ratio of the interpolated frame t: Optionnal float time [s] of this frame in the movie plotting_modifier: Optionnal callable a plotting_modifier associated with this Frame Notes ----- A simple pass-through Frame, extracting the raw ``my_db`` data is: :: fractalshades.db.Frame( x=0., y=0., dx=1.0, nx=my_db.zoom_kwargs["nx"], xy_ratio=my_db.zoom_kwargs["xy_ratio"] ) """ self.x = x self.y = y self.dx = dx self.nx = nx self.xy_ratio = xy_ratio self.t = t self.plotting_modifier = plotting_modifier self.ny = int(self.nx / self.xy_ratio + 0.5) self.dy = dy = self.dx / self.xy_ratio self.size = (self.nx, self.ny) self.db_size = (self.ny, self.nx) # PIL convention xmin = x - 0.5 * dx xmax = x + 0.5 * dx ymin = y - 0.5 * dy ymax = y + 0.5 * dy x_1d = np.linspace(xmin, xmax, self.nx, dtype=np.float32) y_1d = np.linspace(ymin, ymax, self.ny, dtype=np.float32) # Meshgrid of the frame with PIL convention x_grid, y_grid = np.meshgrid(x_1d, y_1d[::-1], indexing='xy') self.pts = np.ravel(x_grid), np.ravel(y_grid)
def upsampled(self, supersampling): """ Return the upsampled version of this Frame (if supersampling is None just pass through) """ if supersampling is None: return self return Frame( self.x, self.y, self.dx, self.nx * supersampling, self.xy_ratio, self.t, self.plotting_modifier )
class Plot_template: def __init__(self, plotter, db_loader, frame): """ A plotter with a pluggable get_2d_arr method -> Basically, an interface to avoid monkey-patching `get_2d_arr` Parameters ---------- plotter: `fractalshades.Fractal_plotter` The wrapped plotter - will be copied db_loader: `fractalshades.db.Db` A db data-loading class - implementing `get_2d_arr` frame: Frame The frame-specific data holder """ # Internal plotter object, layers reparented for frame-specific # functionnality self.plotter = copy.deepcopy(plotter) for layer in self.plotter.layers: layer.link_plotter(self) self.db_loader = db_loader self.frame = frame def __getattr__(self, attr): return getattr(self.plotter, attr) def __getitem__(self, layer_name): """ Get the layer by its postname """ return self.plotter.__getitem__(layer_name) def get_2d_arr(self, post_index, chunk_slice): """ Forwards to the db-loader for frame-specific functionnality """ return self.db_loader.get_2d_arr( post_index, self.frame, chunk_slice )
[docs] class Db:
[docs] def __init__(self, path): """ Wrapper around the memory-mapped numpy-array stored at ``path``\. Parameters ---------- path: str The path for the data. This memory-mapped array is usually created through a `fractalshades.Fractal_plotter.save_db` call. Two format are available (\*.db and \*.post db, refer to the doc for this function for details) """ # Development Note - on supersampling: # - the .db data is supersampled # - the .postdb rgb data so it is already downsampled # General rule, Lanczos filter is applied at image making stage self.path = path _, ext = os.path.splitext(path) if ext == ".db": self.postdb = False elif ext == ".postdb": self.postdb = True else: raise ValueError(f"Unknown Db extension: {ext}") self.init_model() # Cache for interpolating classes self._interpolator = {}
@property def is_postdb(self): return self.postdb @property def zoom_kwargs(self): return self.plotter.fractal.zoom_kwargs def init_model(self): """ Build a description for the datapoints in the mmap """ mmap = open_memmap(filename=self.path, mode="r+") if self.postdb: ny, nx, nposts = mmap.shape else: nposts, ny, nx = mmap.shape del mmap # Points number self.nx = nx # = self.zoom_kw["nx"] self.ny = ny # = int(nx / xy_ratio + 0.5) self.xy_ratio = xy_ratio = nx / ny # Points loc - in screen coordinates self.x0 = x0 = 0. self.y0 = y0 = 0. self.dx0 = dx0 = 1.0 self.dy0 = dy0 = 1.0 / xy_ratio # xy grid used for interpolation. # Downsampled version implemented as part of freezing process self.xmin0 = xmin0 = x0 - 0.5 * dx0 self.xmax0 = xmax0 = x0 + 0.5 * dx0 self.xgrid0, self.xh0 = np.linspace( xmin0, xmax0, nx, endpoint=True, retstep=True, dtype=np.float32 ) self.ymin0 = ymin0 = y0 - 0.5 * dy0 self.ymax0 = ymax0 = y0 + 0.5 * dy0 self.ygrid0, self.yh0 = np.linspace( ymin0, ymax0, ny, endpoint=True, retstep=True, dtype=np.float32 ) def get_interpolator(self, frame, post_index): """ Try first to reload if the interpolating domain is still valid If not, creates a new interpolator """ if post_index in self._interpolator.keys(): (interpolator, bounds) = self._interpolator[post_index] x = frame.x y = frame.y ddx = 0.5 * frame.dx ddy = 0.5 * frame.dy a, b = bounds valid = ( (a[0] <= x - ddx) and (b[0] >= x + ddx) and (a[1] <= y - ddy) and (b[1] >= y + ddy) ) if valid: return interpolator interpolator, bounds = self.make_interpolator(frame, post_index) self._interpolator[post_index] = (interpolator, bounds) logger.debug(f"New Interpolator added for field #{post_index}") return interpolator def make_interpolator(self, frame, post_index): """ Returns an interpolator for the field post_index and pts inside the domain [x-dx, x+dx] x [y-dy, y+dy] in Screeen coordinates """ logger.info("Creating a new interpolator in database") x = frame.x y = frame.y dx = frame.dx dy = frame.dy # NOTE that if we are interpolating in the frozen image, then the # xgrid0 and ygrid0 might be downsampled xgrid0 = self.xgrid0 ygrid0 = self.ygrid0 xh0 = self.xh0 yh0 = self.yh0 nx = len(xgrid0) ny = len(ygrid0) # 1) load the local interpolation array for the region of interest # with a margin so it has a chance to remain valid for several frames if (x - 0.5 * dx) < self.xmin0: raise ValueError("Frame partly outside databse data: low x") if (y - 0.5 * dy) < self.ymin0: raise ValueError("Frame partly outside databse data: low y") if (x + 0.5 * dx) > self.xmax0: raise ValueError("Frame partly outside databse data: high x") if (y + 0.5 * dy) > self.ymax0: raise ValueError("Frame partly outside databse data: high y") k = 0.5 * 1.5 # 0.5 would be no margin at all # min vals for frame interpolation x_min = np.float32(max(x - k * dx, self.xmin0)) ind_xmin = np.searchsorted(xgrid0, x_min, side="right") - 1 y_min = np.float32(max(y - k * dy, self.ymin0)) ind_ymin = np.searchsorted(ygrid0, y_min, side="right") - 1 # max vals for frame interpolation x_max = np.float32(min(x + k * dx, self.xmax0)) ind_xmax = min(np.searchsorted(xgrid0, x_max, side="left"), nx) y_max = np.float32(min(y + k * dy, self.ymax0)) ind_ymax = min(np.searchsorted(ygrid0, y_max, side="left"), ny) # 2) Creates and return the interpolator # a, b: the lower and upper bounds of the interpolation region # h: the grid-spacing at which f is given # f: data to be interpolated # k: order of local taylor expansions (int, 1, 3, or 5) # p: whether the dimension is taken to be periodic # c: whether the array should be padded to allow accurate close eval # e: extrapolation distance, how far to allow extrap, in units of h # (needs to be an integer) a = [xgrid0[ind_xmin], ygrid0[ind_ymin]] b = [xgrid0[ind_xmax], ygrid0[ind_ymax]] h = [xh0, yh0] assert xh0 > 0 assert yh0 > 0 assert ind_xmin < ind_xmax assert ind_ymin < ind_ymax if self.postdb: fr_mmap = open_memmap(filename=self.path, mode="r") f = fr_mmap[ (ny - ind_ymax - 1):(ny - ind_ymin), ind_xmin:(ind_xmax + 1), post_index ] del fr_mmap else: mmap = open_memmap(filename=self.path, mode="r") f = mmap[ post_index, (ny - ind_ymax - 1):(ny - ind_ymin), ind_xmin:(ind_xmax + 1) ] del mmap assert a[0] < b[0] assert a[1] < b[1] interpolator = fsGrid_lin_interpolator( a, b, h, f, PIL_order=True ) bounds = (a, b) return interpolator, bounds # --------------- db plotting interface ---------------------------------------
[docs] def set_plotter(self, plotter, postname): """ Define the plotting properties - Needed only if a \*.db is provided not (as opposed to a \*.postdb image array format) Parameters ---------- plotter: `fractalshades.Fractal_plotter` A plotter to be used as template postname: str The string indentifier of the layer used for plotting """ assert isinstance(plotter, fs.Fractal_plotter) if self.postdb: raise RuntimeError( "`set_plotter` shall not be called for a .postdb" ) self.plotter = plotter self.postname = postname
def get_2d_arr(self, post_index, frame, chunk_slice): """ get_2d_arr with frame-specific functionnality Parameters ---------- post_index: int the index for this post-processing field in self.plotter frame: `fractalshades.db.Frame` Frame localisation for interpolation chunk_slice: 4-uplet float chunk_slice = (ix, ixx, iy, iyy) is the sub-array to reload Not currently used as `frame` is never None (direct reloading) """ assert frame is not None ret = self.get_interpolator(frame, post_index)(*frame.pts) return ret.reshape(frame.db_size)
[docs] def plot(self, frame=None): """ Parameters ---------- frame: `fractalshades.db.Frame`, Optional Defines the area to plot. If not provided, the full db will be plotted. Returns ------- img: `PIL.Image` The plotted image Notes ----- Plotting settings are set by `fractalshades.db.Db.set_plotter` method, they are used only for \*.db format (as opposed to a \*.postdb image array format) """ if frame is None: # Default to plotting the whole db frame = fs.db.Frame( x=0., y=0., dx=1.0, nx=self.nx, xy_ratio=self.xy_ratio ) # Is the db filled with raw rgb data ? if self.postdb: return self.plot_postdb(frame) # Here the plotter shall take into account the 'full frame' size full_frame = frame.upsampled(self.plotter.supersampling) plot_template = Plot_template(self.plotter, self, full_frame) plotting_modifier = frame.plotting_modifier if plotting_modifier is not None: plotting_modifier(plot_template, frame.t) img = None out_postname = self.postname for i, layer in enumerate(plot_template.layers): if layer.postname == out_postname: if not(layer.output): raise ValueError(f"No output for this layer: {layer}") img = PIL.Image.new(mode=layer.mode, size=frame.size) im_layer = layer break if img is None: raise ValueError( f"Layer missing: {out_postname} " + f"not found in {plot_template.postnames}" ) self.process(plot_template, frame, img, im_layer) return img
def plot_postdb(self, frame): """ Direct interpolation in a .postdb db image data """ mmap = open_memmap(filename=self.path, mode="r+") _, _, n_channel = mmap.shape dtype = mmap.dtype del mmap db_size = frame.db_size ret = np.empty(db_size + (n_channel,), dtype=dtype) for ic in range(n_channel): channel_ret = self.get_interpolator(frame, ic)(*frame.pts) channel_ret = channel_ret.reshape(db_size) ret[:, :, ic] = channel_ret if n_channel == 1: im = PIL.Image.fromarray(ret[:, :, 0]) else: im = PIL.Image.fromarray(ret) return im def process(self, plot_template, frame, img, im_layer): """ Just plot the Images interpolated data + plot_template 1 db point -> 1 pixel """ nx, ny = full_nx, full_ny = frame.size ss = self.plotter.supersampling chunk_slice = (0, nx, 0, ny) crop_slice = (0, 0, nx, ny) # This line ultimately forwards to self.get_2d_arr(...) thanks # to Plot_template interface - frame is not None paste_crop = im_layer.crop(chunk_slice) if ss: # Here, we apply a resizing filter resample = PIL.Image.LANCZOS paste_crop = paste_crop.resize( size=(nx, ny), resample=resample, box=None, reducing_gap=None ) img.paste(paste_crop, box=crop_slice)
#============================================================================== #============================================================================== # Exponential mapping to cartesian frame transform
[docs] class Exp_frame:
[docs] def __init__(self, h, nx, xy_ratio, t=None, plotting_modifier=None, pts=None): """ This class is used to describe a specific data window for interpolation inside a `fractalshades.db.Exp_db`. Parameters ---------- h: float >= 0. zoom level. A zoom level of 0. denotes that dx = dx0 fully zoomed-in frame - for h > 0, dx = dx0 * np.exp(h) nx: int number of pixels for the interpolated frame xy_ratio: float width / height ratio of the interpolated frame t: Optionnal float time [s] of this frame in the movie plotting_modifier: Optional callable a plotting_modifier associated with this frame pts: Optional, 4-uplet of arrays The x, y, h, t grid as returned by make_exp_grid - if not provided it will be recomputed - but more efficient to share between frames """ self.h = h self.nx = nx self.xy_ratio = xy_ratio self.t = t self.plotting_modifier = plotting_modifier self.ny = int(self.nx / self.xy_ratio + 0.5) self.size = (self.nx, self.ny) self.db_size = (self.ny, self.nx) # PIL convention if pts is None: pts = self.make_exp_grid(self.nx, self.xy_ratio) # Basic shape verification assert pts[0].size == (self.nx * self.ny) self.pts = pts
@staticmethod def make_exp_grid(nx, xy_ratio): """ Return a base grid [-0.5, 0.5] x [-0.5/xy_ratio, 0.5/xy_ratio] in both cartesian and expoential coordinates """ ny = int(nx / xy_ratio + 0.5) xmin = -0.5 xmax = +0.5 ymin = -0.5 / xy_ratio ymax = +0.5 / xy_ratio # Cartesian grid xvec = np.linspace(xmin, xmax, nx, dtype=np.float32) yvec = np.linspace(ymin, ymax, ny, dtype=np.float32) x_grid, y_grid = np.meshgrid(xvec, yvec[::-1], indexing='xy') x_grid = x_grid.reshape(-1) y_grid = y_grid.reshape(-1) # Exponential grid coordinates frac_pixw = 0.1 / nx h_grid = 0.5 * np.log( np.maximum(x_grid ** 2 + y_grid ** 2, frac_pixw ** 2) ) t_grid = np.arctan2(y_grid, x_grid) # Store the grid return (x_grid, y_grid, h_grid, t_grid)
[docs] class Exp_db:
[docs] def __init__(self, path_expmap, path_final): """ Wrapper around the raw array data stored at ``path_expmap`` and ``path_final``\. Parameters ---------- path_expmap: str The path for the expmap database. Note that only \*.postdb format is currently supported hence the expmap array is of shape (nt, nh, nchannels) and stores rgb data. It is usually saved by a call to `fractalshades.Fractal_plotter.save_db` using a `fractalshades.projection.Expmap` projection. Note that the orientation parameter of this projection shall be "vertical". path_final: str The path for the final raw data. Note that only \*.postdb format is currently supported hence the expmap array is of shape (ny, nx, nchannels) and stores rgb data. It is usually saved by a `fractalshades.Fractal_plotter.save_db` call (with a standard `fractalshades.projection.Cartesian` projection). It shall be square (nx == ny) and only \*.postdb format is currently supported. """ # Development Note # ---------------- # Note on supersampling: # - the .db data is supersampled # - the .postdb data is the image so it is already downsampled # General rule, Lanczos filter is applied at image making stage self.path_expmap = path_expmap _, ext = os.path.splitext(path_expmap) if ext == ".db": self.postdb = False raise NotImplementedError( "Only .postdb files implemented for making exp zoom movies. " "Consider saving your database in this format." ) elif ext == ".postdb": self.postdb = True else: raise ValueError(f"Unknown Db extension: {ext}") self.path_final = path_final _, ext2 = os.path.splitext(path_final) if ext != ext2: raise ValueError( "Extensions for path_expmap and path_final shall match ; " f"Found: {ext} and {ext2}" ) self.init_model() self.subsample() # Cache for interpolating classes self._interpolator = {}
@property def is_postdb(self): return self.postdb def init_model(self): """ Build a description for the datapoints in the mmap """ assert self.postdb mmap = open_memmap(filename=self.path_expmap, mode="r+") # .postdb of an Expmap woth orientation = "vertical" nh, nt, nposts = mmap.shape dtype = mmap.dtype del mmap mmap = open_memmap(filename=self.path_final, mode="r+") ny, nx, _nposts = mmap.shape _dtype = mmap.dtype del mmap if _nposts != nposts: raise ValueError("Incompatible final image database for Exp_db: " f"`nposts` not matching: {_nposts} vs {nposts}") if _dtype != dtype: raise ValueError("Incompatible final image database for Exp_db: " f"`dtype` not matching: {_dtype} vs {dtype}") if nx != ny: raise ValueError("Final image database shall be square, found: " f"{nx} x {ny}") self.nposts = self.nchannels = nposts self.dtype = dtype # Points number self.nh = nh self.nt = nt self.nx = nx self.ny = ny # Data span dh0 = 2. * np.pi * nh / nt self.hmin0 = 0. self.hmax0 = self.hmin0 + dh0 self.xmin0, self.xmax0 = -0.5, 0.5 self.ymin0, self.ymax0 = -0.5, 0.5 # ht grid self.hgrid0, self.hh0 = np.linspace( self.hmin0, self.hmax0, nh, endpoint=True, retstep=True, dtype=np.float32 ) self.tgrid0, self.th0 = np.linspace( -np.pi, np.pi, nt, endpoint=True, retstep=True, dtype=np.float32 ) # xy grid self.xgrid0, self.xh0 = np.linspace( self.xmin0, self.xmax0, nx, endpoint=True, retstep=True, dtype=np.float32 ) self.ygrid0, self.yh0 = np.linspace( self.ymin0, self.ymax0, ny, endpoint=True, retstep=True, dtype=np.float32 ) # Lanczos-2 2-decimation routine self.lf2_stable = fsfilters.Lanczos_decimator().get_stable_impl(2) def path(self, kind, downsampling): """ Path to the database, including the cascading subsampled db Parameters ---------- kind: "exp" | "final" The underlying db downsampling: bool If true, this is the path for the multi-level downsampled db """ if kind == "exp": if not(downsampling): return self.path_expmap root, ext = os.path.splitext(self.path_expmap) return root + "_downsampling" + ext elif kind == "final": if not(downsampling): return self.path_final root, ext = os.path.splitext(self.path_final) return root + "_downsampling" + ext else: raise ValueError(f"{kind = }") def get_interpolator(self, frame, post_index): """ Try first to reload if the interpolating domain is still valid If not, creates a new interpolator """ if post_index in self._interpolator.keys(): (interpolator, bounds) = self._interpolator[post_index] h = frame.h h1, h2 = bounds # h1 > h2 validity range of the interpolator valid = ((h <= h1) and (h >= h2)) if valid: return interpolator interpolator, bounds = self.make_interpolator(frame, post_index) self._interpolator[post_index] = (interpolator, bounds) logger.debug(f"New Interpolator added for field #{post_index}") return interpolator def make_interpolator(self, frame, post_index): """ Returns an interpolator for the field post_index and pts inside the domain [x0 - dx, x0 + dx] x [y0 - dy, y0 - dy] where dx = exp(h) * dx0 dy = exp(h) dy0 in Screeen coordinates of the final pic Returns ------- interpolator the interpolator bounds: (h1, h2) h1 > h2 the validity range """ ic = post_index h = frame.h # expansion factor from final pic is exp(h) nx = frame.nx xy_ratio = frame.xy_ratio margin = 20. # Shall remain valid for this zoom range (in and out) h_margin = np.log(margin) h_decimate = np.log(2.) # Triggers factor-2 image decimation info_dic = self._subsampling_info dtype = self.dtype # Checks Frame validity if h < 0.: raise ValueError(f"Frame outside databse data: h = {h} < 0") # Highest acceptable h: rpix_max = 0.5 * np.sqrt(1. + xy_ratio ** 2) allowed_hmax = self.hmax0 - np.log(rpix_max) if h > allowed_hmax: raise ValueError( "Frame outside databse data: " f"h = {h} > allowed_hmax = {allowed_hmax}\n" f"hmax0: {self.hmax0} " f"rpix_max: {rpix_max} " f"xy_ratio: {xy_ratio} " f"nh: {self.nh} " f"nt: {self.nt} " f"hmin: {self.hmin0}" ) #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Define parameters for multilevel exp_map interpolation # a_exp, b_exp, h_exp, f_exp, f_exp_shape, f_exp_slot kind = "exp" full_shape = info_dic[(kind, "ss_shapes")] # ny, nx or nh, nt full_slot = info_dic[(kind, "ss_slots")] full_bound = info_dic[(kind, "ss_bounds")] # (start_x, end_x, start_y, end_y) lvl = full_shape.shape[0] # We fill as if full range first # Note that the bounds are expressed as h, t even with PIL indexing a_exp = np.copy(full_bound[:, 0::2]) # Lower bound h, t b_exp = np.copy(full_bound[:, 1::2]) # Higher bound h, t h_exp = ((b_exp - a_exp) / (full_shape - 1)).astype(np.float32) f_exp_shape = np.copy(full_shape) f_exp_slot = np.copy(full_slot) # we extract a subrange for the h direction h_index = np.copy(full_shape) for ilvl in range(lvl): delta_h = h_exp[ilvl, 0] arr_hmin = a_exp[ilvl, 0] arr_hmax = b_exp[ilvl, 0] # Compute the extracted indices for this level pix_hmin = np.clip( h - h_margin - h_decimate * (ilvl + 1), self.hmin0, self.hmax0 ) pix_hmax = np.clip( h + h_margin - h_decimate * ilvl, self.hmin0, self.hmax0 ) # Note: still, ascending sort order ind_hmin = int(np.floor((pix_hmin - arr_hmin) / delta_h)) ind_hmax = int(np.ceil((pix_hmax - arr_hmin) / delta_h)) # Avoids the 'lonely pixel' case if ind_hmax - ind_hmin == 1: if ind_hmax < f_exp_shape[ilvl, 0]: ind_hmax += 1 else: ind_hmin -= 1 h_index[ilvl, :] = ind_hmin, ind_hmax k_min = ind_hmin / (full_shape[ilvl, 0] - 1) k_max = ind_hmax / (full_shape[ilvl, 0] - 1) # Updates tables to extracted values a_exp[ilvl, 0] = arr_hmin * (1. - k_min) + arr_hmax * k_min b_exp[ilvl, 0] = arr_hmin * (1. - k_max) + arr_hmax * k_max f_exp_shape[ilvl, 0] = ind_hmax - ind_hmin # Temporarly, we store it as dim, then use cumsum f_exp_slot[ilvl, 1] = f_exp_shape[ilvl, 0] * f_exp_shape[ilvl, 1] f_exp_slot[:, 1] = np.cumsum(f_exp_slot[:, 1]) f_exp_slot[1:, 0] = f_exp_slot[:-1, 1] f_exp_slot[0, 0] = 0 f_exp = np.empty((f_exp_slot[-1, 1],), dtype=dtype) # storage vec ilvl = 0 filename = self.path(kind, downsampling=False) mmap = open_memmap(filename=filename, mode="r") ind_hmin, ind_hmax = h_index[ilvl, :] loc_arr = mmap[ind_hmin:ind_hmax, :, ic] # Use the full theta range loc_arr = loc_arr.reshape(-1) f_exp[f_exp_slot[ilvl, 0]: f_exp_slot[ilvl, 1]] = loc_arr del mmap filename = self.path(kind, downsampling=True) mmap = open_memmap(filename=filename, mode="r") for ilvl in range(1, lvl): ind_hmin, ind_hmax = h_index[ilvl, :] ny, nx = full_shape[ilvl, :] di = full_slot[ilvl, 0] # 0 or 1 ??? # Here why we need to used "vertical" orientation for the expmap, # as the primary dim is nh... loc_arr = mmap[ic, (ind_hmin * nx) + di: (ind_hmax * nx) + di] f_exp[f_exp_slot[ilvl, 0]: f_exp_slot[ilvl, 1]] = loc_arr del mmap #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Define parameters for multilevel final interpolation # a_exp, b_exp, h_exp, f_exp, f_exp_shape, f_exp_slot kind = "final" full_shape = info_dic[(kind, "ss_shapes")] full_slot = info_dic[(kind, "ss_slots")] full_bound = info_dic[(kind, "ss_bounds")] # (start_x, end_x, start_y, end_y) lvl = full_shape.shape[0] # Easy, we just fill as full range - no subrange extraction step a_final = np.copy(full_bound[:, 0::2]) b_final = np.copy(full_bound[:, 1::2]) h_final = ((b_final - a_final) / (full_shape - 1)).astype(np.float32) f_final_shape = np.copy(full_shape) f_final_slot = np.copy(full_slot) # we still adjust the slots position as first level is now merged for ilvl in range(lvl): f_final_slot[ilvl, 1] = f_final_shape[ilvl, 0] * f_final_shape[ilvl, 1] f_final_slot[:, 1] = np.cumsum(f_final_slot[:, 1]) f_final_slot[1:, 0] = f_final_slot[:-1, 1] f_final_slot[0, 0] = 0 f_final = np.empty((f_final_slot[-1, 1],), dtype=dtype) ilvl = 0 filename = self.path(kind, downsampling=False) mmap = open_memmap(filename=filename, mode="r") loc_arr = mmap[:, :, ic] loc_arr = loc_arr.reshape(-1) f_final[f_final_slot[ilvl, 0]: f_final_slot[ilvl, 1]] = loc_arr del mmap filename = self.path(kind, downsampling=True) mmap = open_memmap(filename=filename, mode="r") for ilvl in range(1, lvl): filename = self.path(kind, downsampling=True) mmap = open_memmap(filename=filename, mode="r") loc_arr = mmap[ic, full_slot[ilvl, 0]: full_slot[ilvl, 1]] f_final[f_final_slot[ilvl, 0]: f_final_slot[ilvl, 1]] = loc_arr del mmap interpolator = Multilevel_exp_interpolator( a_exp, b_exp, h_exp, f_exp, f_exp_shape, f_exp_slot, a_final, b_final, h_final, f_final, f_final_shape, f_final_slot ) bounds = (h + h_margin, h - h_margin) return interpolator, bounds # --------------- db Subsampling interface --------------------------------------- def subsample(self): """ Make a series of subsampled databases (either pain or frozen) """ self._subsampling_info = {} # 1) Downsample the frozen exp db source = self.path("exp", downsampling=False) filename = self.path("exp", downsampling=True) init_bound = np.array((self.hmin0, self.hmax0, -np.pi, np.pi)) self.populate_subsampling( filename, source, driving_dim="x", # i.e, the "t" Expmap dim init_bound=init_bound, kind="exp" ) # 2) Downsample the frozen final db source = self.path("final", downsampling=False) filename = self.path("final", downsampling=True) init_bound = np.array((self.xmin0, self.xmax0, self.ymin0, self.ymax0)) self.populate_subsampling( filename, source, driving_dim="x", init_bound=init_bound, kind="final" ) def ss_lvl_count(self, kind): """ Return the number of subsampling data levels stored for this db. Parameters ---------- kind: "exp" | "final" The source db """ ss_shapes = self._subsampling_info[((kind, "ss_shapes"))] return ss_shapes.shape[0] def ss_img(self, kind, lvl): """ Return a subsampled raw image (used for debuging) Parameters: ----------- kind: "exp" | "final" The source db lvl: int >= 0 the subsampling level """ assert self.postdb dtype = self.dtype nc = self.nchannels ss_shapes = self._subsampling_info[((kind, "ss_shapes"))] ss_slots = self._subsampling_info[((kind, "ss_slots"))] nx, ny = ss_shapes[lvl, :] lw, hg = ss_slots[lvl, :] arr = np.empty((nx, ny, nc), dtype=dtype) filename = self.path(kind, downsampling=(lvl != 0)) mmap = open_memmap(filename=filename, mode="r") # Returns the RGB(A) array for the exp mapping for level = lvl for ic in range(nc): if lvl == 0: loc_arr = mmap[:, :, ic] else: loc_arr = (mmap[ic, lw:hg]).reshape((nx, ny)) arr[:, :, ic] = loc_arr del mmap # np -> PIL return PIL.Image.fromarray(arr) def populate_subsampling(self, filename, source, driving_dim, init_bound, kind): """ Creates a memory mapping at filename and populates it with subsampled data from source. Parameters ---------- filename: str path for the new mmap source: str path for the source mmap driving_dim: "x", "y" The dimension of the image that will be reduced to 2 (criteria for the number of levels) init_bound: np.array([minx, maxx, miny, maxy]) The initial range for data kind: "exp" | "final" The kind of mem mapping Returns ------- ss_shapes: shapes (ny, nx) or (nh, nt) of the nested subsampled arrays ss_bounds flatten localisation of the nested subsampled arrays To recover for """ logger.info( "Writing a subsampled db\n" f" source: {source}\n" f" dest: {filename}" ) lf2_stable = self.lf2_stable # We flatten the image however the source might imply several channels # or layers... # For a .postdb, mmap.shape: (ny, nx, n_channel) # Hence for a "vertical" Expmap: nh, nt source_mmap = open_memmap(filename=source, mode="r") dtype = source_mmap.dtype (ny, nx, nposts) = source_mmap.shape # Dev note: in case of expmap, ny == nh, nx == nt ss_nx = nx ss_ny = ny ss_slotl = ss_sloth = 0 # Number of subsampling levels, based on nt # 2->0 3->1 4->2 5->2 6->3 7->3 8->3 9->3 10->4 [..] 17->4 18->5 nt = {"x": nx, "y": ny}[driving_dim] ss_lvls = (nt - 2).bit_length() ss_shapes = np.empty((ss_lvls + 1, 2), dtype=np.int32) ss_slots = np.empty((ss_lvls + 1, 2), dtype=np.int32) # Note that we need to store the x AND y bounds ss_bounds = np.tile(init_bound, (ss_lvls + 1, 1)).astype(np.float32) # Sizing the subsampling arrays ss_shapes[0, :] = [ny, nx] # or nh, nt if Expmap, "vertical" ss_slots[0, :] = [-1, -1] # not relevant as in another mmap for lvl in range(ss_lvls): ss_nx = ss_nx // 2 + 1 ss_ny = ss_ny // 2 + 1 ss_slotl = ss_sloth ss_sloth += ss_nx * ss_ny ss_shapes[lvl + 1, :] = [ss_ny, ss_nx] ss_slots[lvl + 1, :] = [ss_slotl, ss_sloth] ss_mmap = open_memmap( filename=filename, mode='w+', dtype=dtype, shape=(nposts, ss_sloth), fortran_order=False, version=None ) for ipost in range(nposts): ss_ny = ny for lvl in range(ss_lvls): ss_diy = 200 ss_ny = ss_ny // 2 + 1 # The grouping range (in data columns) for parallel exec. self.y_range = lambda: np.arange(0, ss_ny - 1, ss_diy) self.parallel_populate_subsampling( ss_mmap, source_mmap, ipost, lvl, ss_shapes, ss_slots, ss_bounds, lf2_stable, ss_diy, ss_iystart=None ) del source_mmap del ss_mmap self._subsampling_info.update({ (kind, "ss_shapes"): ss_shapes, # Shape of the full ss array, lvl = i - 1 (kind, "ss_slots"): ss_slots, # 1d-slot of the full ss array, lvl = i - 1 (kind, "ss_bounds"): ss_bounds # bounds for the full ss array }) @Multithreading_iterator( iterable_attr="y_range", iter_kwargs="ss_iystart" ) def parallel_populate_subsampling(self, mmap, source_mmap, ipost, lvl, ss_shapes, ss_slots, ss_bounds, lf2_stable, ss_diy, ss_iystart=None ): """ In parallel, apply the subsampling for (ipost, lvl). Parameters ---------- mmap: memory mapping for the output source_mmap: memory mapping for the source (used if lvl == 0) ipost: the current post / channel index lvl: current level in the nested chain ss_shapes: (nx, ny) of the nested subsampled array - coords in source ss_slots: (ssl, ssh) of the nested subsampled array - as stored, flatten, in res ss_bounds: (start_x, end_x, start_y, end_y) of the nested ss arrays lf2_stable: decimation routine ss_iystart: start iy index for this parallel calc in the destination array /!\ not the source ssdiy: gap in y used for parallel calc """ ss_ny, ss_nx = ss_shapes[lvl + 1, :] # For full "subsampled" shape ss_l, ss_h = ss_slots[lvl + 1, :] # For full "subsampled" slot assert ss_ny * ss_nx == ss_h - ss_l # This // run extract slot is [iy_start:iy_end, :] ss_iyend = min(ss_iystart + ss_diy, ss_ny) ss_diy = ss_iyend - ss_iystart # The 2d shapes / extract slot at source array - we map (2n+1) -> n+1 ny, nx = ss_shapes[lvl, :] lw, hg = ss_slots[lvl, :] # This is the full "subsampled" slot iystart = 2 * ss_iystart # *2 due to the supersampling factor iyend = min(iystart + 2 * ss_diy + 1, ny) diy = iyend - iystart if lvl == 0: # Source arr is from the source_mmap source_arr = source_mmap[iystart:iyend, :, ipost] else: # Source arr is from the mmap, however a level higher l_loc = lw + iystart * nx h_loc = lw + iyend * nx source_arr = mmap[ipost, l_loc:h_loc].reshape((diy, nx)) ssl_loc = ss_l + ss_iystart * ss_nx ssh_loc = ss_l + ss_iyend * ss_nx ss2d_full, k_spanx_loc, k_spany_loc = lf2_stable(source_arr) ss2d_full = ss2d_full[:ss_diy, :] # Flatten then store in slot mmap[ipost, ssl_loc:ssh_loc] = ss2d_full.reshape(-1) if (ipost == 0) and (ss_iystart == 0): # We store data localisation information. coeff applies to the # following levels ss_bounds[(lvl + 1):, 1] += (k_spanx_loc - 1.) * ( ss_bounds[(lvl + 1):, 1] - ss_bounds[(lvl + 1):, 0] ) ss_bounds[(lvl + 1):, 3] += (k_spany_loc - 1.) * ( ss_bounds[(lvl + 1):, 3] - ss_bounds[(lvl + 1):, 2] ) # --------------- db plotting interface --------------------------------------- def get_2d_arr(self, post_index, frame, chunk_slice): """ get_2d_arr with frame-specific functionnality Parameters ---------- post_index: int the index for this post-processing field in self.plotter frame: fs.db.Frame Frame localisation for interpolation chunk_slice: 4-uplet float chunk_slice = (ix, ixx, iy, iyy) is the sub-array to reload Used only if `frame` is None: direct reloading """ assert frame is not None # Interpolated output : uses frame data ret = self.get_interpolator(frame, post_index)(*frame.pts) return ret.reshape(frame.size)
[docs] def plot(self, frame): """ Parameters ---------- frame: `fractalshades.db.Exp_frame` Defines the area to plot. Returns ------- img: `PIL.Image` The plotted image Notes ----- Plotting settings are defined by ``set_plotter`` method. """ assert self.postdb dtype = self.dtype nchannels = self.nchannels db_size = frame.db_size ret = np.empty(db_size + (nchannels,), dtype=dtype) for ic in range(nchannels): channel_ret = self.get_interpolator(frame, ic)( *frame.pts, frame.h, frame.nx ) channel_ret = channel_ret.reshape(db_size) # Numpy -> PILLOW ret[:, :, ic] = channel_ret if nchannels == 1: im = PIL.Image.fromarray(ret[:, :, 0]) else: im = PIL.Image.fromarray(ret) return im
#============================================================================== # Ad_hoc interpolating routines for exponential mapping class Multilevel_exp_interpolator: def __init__(self, a_exp, b_exp, h_exp, f_exp, f_exp_shape, f_exp_slot, a_final, b_final, h_final, f_final, f_final_shape, f_final_slot ): """ Multilevel interpolation inside 2 sets of grids: - nested set of multilvel exponential 2d grids - nested set of cartesian 2d grids for the final image (h_tot < 0) Parameters ---------- a_exp: float array of shape (lvl_exp, 2) The lower h, t bounds of the interpolation region for each exp mapping level b_exp: float array of shape (lvl_exp, 2) The upper h, t bounds of the interpolation region for each exp mapping level h_exp: float array of shape (lvl_exp, 2) The h, t grid-spacing at which f is given (2-uplet for each level) f_exp: 1d-array The base 2d-array data to be interpolated, flatened f_exp_shape : 2d-array of shape (lvl_exp, 2) The shape for level-ilvl f_exp is f_exp_shape[ilvl, :] f_exp_slot : 2d-array of shape (lvl_exp, 2) The slot for level-ilvl f_exp is f_exp_slot[ilvl, :] a_final: float array of shape (lvl_final, 2) The lower x, y bounds of the interpolation region for each cartesian mapping level (Note: y bound is identical) b_final: float array of shape (lvl_final, 2) The upper x, y bounds of the interpolation region for each cartesian mapping level (Note: y bound is identical) h_final: float array of shape (lvl_final, 2) The x, y grid-spacing at which f is given (2-uplet for each level) f_final: 1d-array The base 2d-array data to be interpolated, flatened f_final_shape : 2d-array of shape (lvl_final, 2) The shape for level-ilvl f_exp is f_exp_shape[ilvl, :] f_final_slot : 2d-array of shape (lvl_final, 2) The slot for level-ilvl f_exp is f_exp_slot[ilvl, :] """ self.a_exp = np.asarray(a_exp) self.b_exp = np.asarray(b_exp) self.h_exp = np.asarray(h_exp) self.f_exp = np.asarray(f_exp) self.f_exp_shape = np.asarray(f_exp_shape) self.f_exp_slot = np.asarray(f_exp_slot) self.a_final = np.asarray(a_final) self.b_final = np.asarray(b_final) self.h_final = np.asarray(h_final) self.f_final = np.asarray(f_final) self.f_final_shape = np.asarray(f_final_shape) self.f_final_slot = np.asarray(f_final_slot) # The numba interpolating implementations self.numba_impl = self.get_grid_impl() def get_grid_impl(self): """ Return a numba-jitted function for multilevel interpolation of 1 grid """ a_exp = self.a_exp b_exp = self.b_exp h_exp = self.h_exp f_exp = self.f_exp f_exp_shape = self.f_exp_shape f_exp_slot = self.f_exp_slot a_final = self.a_final b_final = self.b_final h_final = self.h_final f_final = self.f_final f_final_shape = self.f_final_shape f_final_slot = self.f_final_slot @numba.njit(nogil=True, parallel=False) def numba_impl(x_out, y_out, pts_h, pts_t, h_out, nx_out, f_out): # Interpolation: f_out = finterp(x_out, y_out) f_out = multilevel_interpolate( x_out, y_out, pts_h, pts_t, h_out, nx_out, f_out, a_exp, b_exp, h_exp, f_exp, f_exp_shape, f_exp_slot, a_final, b_final, h_final, f_final, f_final_shape, f_final_slot ) return f_out return numba_impl def __call__(self, pts_x, pts_y, pts_h, pts_t, pic_h, pic_nx, pts_res=None): """ Interpolates at pts_x, pts_y Parameters ---------- pts_x: 1d-array float32 x-coord of interpolating point location pts_y: 1d-array float32 y-coord of interpolating point location pts_h: 1d-array float32 h-coord of interpolating point location pts_t: 1d-array float32 t-coord of interpolating point location pic_h: float The zoom level of the frame pic_nx: int The number of point in the frame along the x-direction (used to define of local pixel size) pts_res: 1d-array float32, Optional Out array handle - if not provided, it will be created. """ assert np.ndim(pts_x) == 1 if pts_res is None: pts_res = np.empty_like(pts_x) interp = self.numba_impl pic_h32 = np.float32(pic_h) interp(pts_x, pts_y, pts_h, pts_t, pic_h32, pic_nx, pts_res) return pts_res @numba.njit(nogil=True, parallel=False) def multilevel_interpolate( x_out, y_out, pts_h, pts_t, h_out, nx_out, f_out, a_exp, b_exp, h_exp, f_exp, f_exp_shape, f_exp_slot, a_final, b_final, h_final, f_final, f_final_shape, f_final_slot ): # Interpolation: f_out = finterp(x_out, y_out, h_out) - h_out constant # x_im, y_im position in the image ([-0.5, 0.5] for x) npts = x_out.size half32 = numba.float32(0.5) log_half32 = np.log(half32) # tenth_pixw = numba.float32(0.1 / nx_out) max_lvl_ht = f_exp_shape.shape[0] - 1 max_lvl_xy = f_final_shape.shape[0] - 1 lvl_xy_cache = numba.intp(-1) # Invalid will trigger recalc lvl_ht_cache = numba.intp(-1) # Invalid will trigger recalc k_out = np.exp(h_out) for i in numba.prange(npts): x_loc = x_out[i] y_loc = y_out[i] # Note: log(sqrt(.)) == 0.5 * log(.) h_img = pts_h[i] # np.log(max(x_loc ** 2 + y_loc ** 2, tenth_pixw)) * half32 # exp mapping is defined as z -> np.exp(dh * z) h_tot = h_out + h_img - log_half32 if h_tot < 0.: # Interpolating in the final image # the lvl is linked to the zoom scale: log2(exp(h_out)) # zoom 1. -> 0, 2. -> 1, 4. -> lvl_xy = np.intp(-h_out / (log_half32) + 0.7) lvl_xy = min(max(lvl_xy, 0), max_lvl_xy) # Define the local arrays according to lvlxy_loc if lvl_xy != lvl_xy_cache: lvl_xy_cache = lvl_xy ax = a_final[lvl_xy, 0] ay = a_final[lvl_xy, 1] bx = b_final[lvl_xy, 0] by = b_final[lvl_xy, 1] hx = h_final[lvl_xy, 0] hy = h_final[lvl_xy, 1] slot_xy_l = f_final_slot[lvl_xy, 0] slot_xy_h = f_final_slot[lvl_xy, 1] xy_nx = f_final_shape[lvl_xy, 0] xy_ny = f_final_shape[lvl_xy, 1] fxy = f_final[slot_xy_l: slot_xy_h] f_loc = grid_interpolate_alt( x_loc * k_out, y_loc * k_out, fxy, ax, ay, bx, by, hx, hy, xy_nx, xy_ny ) else: # the lvl is linked to the pixel position in image lvl_ht = np.intp((h_img / log_half32) - 1.0) lvl_ht = min(max(lvl_ht, 0), max_lvl_ht) # The angle (t_tot == t_loc) t_loc = pts_t[i] # Define the local arrays according to lvlht_loc if lvl_ht != lvl_ht_cache: lvl_ht_cache = lvl_ht ah = a_exp[lvl_ht, 0] at = a_exp[lvl_ht, 1] bh = b_exp[lvl_ht, 0] bt = b_exp[lvl_ht, 1] hh = h_exp[lvl_ht, 0] ht = h_exp[lvl_ht, 1] slot_ht_l = f_exp_slot[lvl_ht, 0] slot_ht_h = f_exp_slot[lvl_ht, 1] ht_nx = f_exp_shape[lvl_ht, 0] ht_ny = f_exp_shape[lvl_ht, 1] fht = f_exp[slot_ht_l: slot_ht_h] f_loc = grid_interpolate( h_tot, t_loc, fht, ah, at, bh, bt, hh, ht, ht_nx, ht_ny ) f_out[i] = f_loc return f_out # debugging options CHECK_BOUNDS = False CLIP_BOUNDS = True @numba.njit(nogil=True) def grid_interpolate(x_out, y_out, f, ax, ay, bx, by, hx, hy, nx, ny): # Bilinear interpolation in a rectangular grid - f is passed flatten and # is of size (nx x ny) # Interpolation: f_out = finterp(x_out, y_out) if CHECK_BOUNDS: assert ax <= x_out <= bx assert ay <= y_out <= by if CLIP_BOUNDS: x_out = min(max(x_out, ax), bx) y_out = min(max(y_out, ay), by) ix, ratx = np.divmod(x_out - ax, hx) iy, raty = np.divmod(y_out - ay, hy) ix = np.intp(ix) iy = np.intp(iy) ratx /= hx raty /= hy cx0 = np.float32(1.) - ratx cx1 = ratx cy0 = np.float32(1.) - raty cy1 = raty id00 = ix * ny + iy # ix, iy id01 = id00 + 1 # ix, iy + 1 id10 = id00 + ny # ix + 1, iy id11 = id10 + 1 # ix + 1, iy + 1 f_out = ( (cx0 * cy0 * f[id00]) + (cx0 * cy1 * f[id01]) + (cx1 * cy0 * f[id10]) + (cx1 * cy1 * f[id11]) ) return f_out @numba.njit(nogil=True) def grid_interpolate_alt(x_out, y_out, f, ax, ay, bx, by, hx, hy, nx, ny): # Bilinear interpolation in a rectangular grid - f is passed flatten and # is of size (ny x nx) / PILLOW order convention # Interpolation: f_out = finterp(x_out, y_out) if CHECK_BOUNDS: assert ax <= x_out <= bx assert ay <= y_out <= by if CLIP_BOUNDS: x_out = min(max(x_out, ax), bx) y_out = min(max(y_out, ay), by) ix_float, ratx = np.divmod(x_out - ax, hx) iy_float, raty = np.divmod(by - y_out, hy) ix = np.intp(ix_float) iy = np.intp(iy_float) ratx /= hx raty /= hy cx0 = np.float32(1.) - ratx cx1 = ratx cy0 = np.float32(1.) - raty cy1 = raty id00 = iy * nx + ix # iy, ix id01 = id00 + 1 # iy, ix + 1 id10 = id00 + nx # iy + 1, ix id11 = id10 + 1 # iy + 1, ix f_out = ( (cy0 * cx0 * f[id00]) + (cy0 * cx1 * f[id01]) + (cy1 * cx0 * f[id10]) + (cy1 * cx1 * f[id11]) ) return f_out