<h1>Machine learning with discretized functions and tensors</h1>
<p>Rik Voorhaar, 2022-03-10. <a href="https://rikvoorhaar.com/discrete-function-tensor">https://rikvoorhaar.com/discrete-function-tensor</a></p>
<p>In <a href="https://arxiv.org/abs/2203.04352">my new paper together with my supervisor</a>, we explain how to use
discretized functions and tensors to do supervised machine learning. A discretized function is just a function
defined on some grid, taking a constant value on each grid cell. We can describe such a function using a
multi-dimensional array (i.e. a tensor), and we can learn this tensor using data. This results in a new and
interesting type of machine learning model.</p>
<h2 id="what-is-machine-learning">What is machine learning?</h2>
<p>Before we dive into the details of our new type of machine learning model, let’s sit back for a moment and
think: <em>what is machine learning in the first place?</em> Machine learning is all about <em>learning from data</em>. More
specifically in <em>supervised machine learning</em> we are given some <em>data points</em> \(X = (x_1,\dots,x_N)\), all lying
in \(\mathbb R^d\), together with <em>labels</em> \(y=(y_1,\dots,y_N)\) which are just numbers. We then want to find some
function \(f\colon \mathbb R^d\to \mathbb R\) such that \(f(x_i)\approx y_i\) for all \(i\), and such that \(f\)
<em>generalizes well to new data</em>. Or rather, we want to minimize a loss function, for example the least-squares
loss</p>
\[L(f) = \sum_{i=1}^N (f(x_i)-y_i)^2.\]
<p>This is obviously an ill-posed problem, and there are two main issues with it:</p>
<ol>
<li>What <em>kind</em> of functions \(f\) are we allowed to choose?</li>
<li>What does it mean to <em>generalize</em> well on new data?</li>
</ol>
<p>The first issue has no general solution. We <em>choose</em> some class of functions, usually one that depends on
some set of parameters \(\theta\). For example, if we want to fit a quadratic function to our data, we only look at
quadratic functions</p>
\[f_{(a,b,c)}(x) = a + bx +cx^2,\]
<p>and our set of parameters is \(\theta=\{a,b,c\}\). Then we minimize the loss over this set of parameters, i.e.
we solve the minimization problem:</p>
\[\min_{a,b,c} \sum_{i=1}^N (a+ bx_i+cx_i^2-y_i)^2.\]
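<p>As a minimal sketch of this minimization (not code from the paper), the quadratic fit can be solved in closed form with NumPy’s least-squares routine. The synthetic data, the noise level and all variable names below are assumptions made purely for illustration.</p>

```python
import numpy as np

# Synthetic data (assumption): a noisy quadratic with a=1, b=-2, c=0.5
rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=50)
y = 1.0 - 2.0 * x + 0.5 * x**2 + 0.1 * rng.normal(size=50)

# Design matrix with columns 1, x, x^2, so that V @ (a, b, c) = a + b x + c x^2
V = np.stack([np.ones_like(x), x, x**2], axis=1)

# Minimize sum_i (a + b x_i + c x_i^2 - y_i)^2 over theta = (a, b, c)
theta, *_ = np.linalg.lstsq(V, y, rcond=None)
a, b, c = theta

# Least-squares loss at the minimizer
loss = np.sum((V @ theta - y) ** 2)
print(f"a={a:.3f}, b={b:.3f}, c={c:.3f}, loss={loss:.4f}")
```

<p>Because the model is linear in \(\theta\), this particular family has a closed-form minimizer; for most other parametric families an iterative optimizer is needed instead.</p>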
<p>There are many parametric families \(f_\theta\) of functions we can choose from, and many different ways to
solve the corresponding minimization problem. For example, we can choose \(f_\theta\) to be neural networks
<em>with some specified layer sizes</em>, or a random forest with a fixed number of trees and fixed maximum tree
depth. Note that we should strictly speaking always specify hyperparameters like the size of the layers of a
neural network, since those hyperparameters determine what kind of parameters \(\theta\) we are going to
optimize. That is, hyperparameters affect the parametric family of functions that we are going to optimize.</p>
<p>The second issue, generalization, is typically solved through <em>cross-validation</em>. If we want to know whether
the function \(f_\theta\) we learned generalizes well to new data points, we should just keep part of the data
“hidden” during the training (the <em>test data</em>). After training we then evaluate our trained function on this
hidden data, and we record the loss function on this test data to obtain the <em>test loss</em>. The test loss is
then a good measure of how well the function can generalize to new data, and it is very useful if we want to
compare several different functions trained on the same data. Typically we also use a third set of data, the
<em>validation</em> dataset, for example to optimize hyperparameters; see <a href="/validation-size/">my blog post on the topic</a>.</p>
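<p>The idea of hiding part of the data can be sketched as follows. This is an illustrative example only: the target function, the 75/25 split, the polynomial models and the helper names <code>design</code> and <code>mean_squared_loss</code> are all assumptions, and the loss is averaged rather than summed so that train and test values are comparable.</p>

```python
import numpy as np

# Synthetic data (assumption): noisy samples of sin(3x) on [-1, 1]
rng = np.random.default_rng(1)
x = rng.uniform(-1, 1, size=200)
y = np.sin(3 * x) + 0.1 * rng.normal(size=200)

# Keep 25% of the data hidden during training (the test data)
perm = rng.permutation(len(x))
test_idx, train_idx = perm[:50], perm[50:]


def design(x, degree):
    """Polynomial design matrix with columns 1, x, ..., x^degree."""
    return np.vander(x, degree + 1, increasing=True)


def mean_squared_loss(theta, x, y):
    """Mean squared error of the polynomial with coefficients theta."""
    return np.mean((design(x, len(theta) - 1) @ theta - y) ** 2)


# Fit two different parametric families on the training data only,
# then compare them on the hidden test data
results = {}
for degree in (1, 5):
    theta, *_ = np.linalg.lstsq(
        design(x[train_idx], degree), y[train_idx], rcond=None
    )
    train_loss = mean_squared_loss(theta, x[train_idx], y[train_idx])
    test_loss = mean_squared_loss(theta, x[test_idx], y[test_idx])
    results[degree] = (train_loss, test_loss)
    print(f"degree {degree}: train={train_loss:.4f}, test={test_loss:.4f}")
```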
<h2 id="discretized-functions">Discretized functions</h2>
<p>Keeping the general problem of machine learning in mind, let’s consider a particular class of parametric
functions: <em>discretized functions on a grid</em>. To understand this class of functions, we first look at the 1D
case. Let’s take the interval \([0,1]\), and chop it up into \(n\) equal pieces:</p>
\[[0,1] = [0,1/n]\cup[1/n,2/n]\cup\dots\cup[(n-1)/n,1]\]
<p>A discretized function is then one that <em>takes a constant value on each subinterval</em>. For example, below is a
discretized version of a sine function:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="o">%</span><span class="n">matplotlib</span> <span class="n">inline</span>
<span class="n">DEFAULT_FIGSIZE</span> <span class="o">=</span> <span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="n">DEFAULT_FIGSIZE</span><span class="p">))</span>
<span class="n">num_intervals</span> <span class="o">=</span> <span class="mi">10</span>
<span class="n">num_plotpoints</span> <span class="o">=</span> <span class="mi">1000</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span> <span class="o">-</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">num_plotpoints</span><span class="p">,</span> <span class="n">num_plotpoints</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">f</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">sin</span><span class="p">(</span><span class="n">x</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">pi</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">label</span><span class="o">=</span><span class="s">"original function"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span>
    <span class="n">x</span><span class="p">,</span>
    <span class="n">f</span><span class="p">((</span><span class="n">np</span><span class="p">.</span><span class="n">floor</span><span class="p">(</span><span class="n">x</span> <span class="o">*</span> <span class="n">num_intervals</span><span class="p">)</span> <span class="o">+</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">/</span> <span class="n">num_intervals</span><span class="p">),</span>
    <span class="n">label</span><span class="o">=</span><span class="s">"discretized function"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">();</span>
</code></pre></div></div>
<p><img src="/imgs/discrete-function-tensor/tensor-completion_2_0.svg" alt="svg" /></p>
<p>Note that if we divide the interval into \(n\) pieces, then we need \(n\) parameters to describe the discretized function \(f_\theta\).</p>
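<p>To make the parameter count explicit, a 1D discretized function can be stored as a plain array of \(n\) values. The sketch below is illustrative: the function name <code>discretized</code> and the choice of midpoint values for \(\theta\) are assumptions that mirror the plot above.</p>

```python
import numpy as np

n = 10

# The n parameters: here the value of sin(2*pi*x) at each subinterval midpoint
midpoints = (np.arange(n) + 0.5) / n
theta = np.sin(2 * np.pi * midpoints)


def discretized(x, theta):
    """Evaluate the discretized function with parameters theta at x in [0, 1]."""
    n = len(theta)
    # Index of the subinterval containing x; x = 1 is assigned to the last one
    idx = np.minimum(np.floor(np.asarray(x) * n).astype(int), n - 1)
    return theta[idx]


# Two points in the same subinterval [0, 1/n) get the same value
print(discretized(0.03, theta), discretized(0.07, theta))
```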
<p>In the 2D case we instead divide the square \([0,1]^2\) into a grid, and demand that a discretized function is <em>constant on each grid cell</em>. If we use \(n\) grid cells for each axis, this gives us \(n^2\) parameters. Let’s see what a discretized function looks like in a 3D plot:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="n">DEFAULT_FIGSIZE</span><span class="p">))</span>
<span class="n">num_plotpoints</span> <span class="o">=</span> <span class="mi">200</span>
<span class="n">num_intervals</span> <span class="o">=</span> <span class="mi">5</span>
<span class="k">def</span> <span class="nf">f</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">X</span> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">Y</span> <span class="o">+</span> <span class="mf">1.5</span> <span class="o">*</span> <span class="p">((</span><span class="n">X</span> <span class="o">-</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">+</span> <span class="p">(</span><span class="n">Y</span> <span class="o">-</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">X_plotpoints</span><span class="p">,</span> <span class="n">Y_plotpoints</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">meshgrid</span><span class="p">(</span>
    <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span> <span class="o">-</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">num_plotpoints</span><span class="p">,</span> <span class="n">num_plotpoints</span><span class="p">),</span>
    <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span> <span class="o">-</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">num_plotpoints</span><span class="p">,</span> <span class="n">num_plotpoints</span><span class="p">),</span>
<span class="p">)</span>
<span class="c1"># Smooth plot
</span><span class="n">Z_smooth</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">X_plotpoints</span><span class="p">,</span> <span class="n">Y_plotpoints</span><span class="p">)</span>
<span class="n">ax</span> <span class="o">=</span> <span class="n">fig</span><span class="p">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="mi">121</span><span class="p">,</span> <span class="n">projection</span><span class="o">=</span><span class="s">"3d"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot_surface</span><span class="p">(</span><span class="n">X_plotpoints</span><span class="p">,</span> <span class="n">Y_plotpoints</span><span class="p">,</span> <span class="n">Z_smooth</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s">"inferno"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="s">"original function"</span><span class="p">)</span>
<span class="c1"># Discrete plot
</span><span class="n">X_discrete</span> <span class="o">=</span> <span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">floor</span><span class="p">(</span><span class="n">X_plotpoints</span> <span class="o">*</span> <span class="n">num_intervals</span><span class="p">)</span> <span class="o">+</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">/</span> <span class="n">num_intervals</span>
<span class="n">Y_discrete</span> <span class="o">=</span> <span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">floor</span><span class="p">(</span><span class="n">Y_plotpoints</span> <span class="o">*</span> <span class="n">num_intervals</span><span class="p">)</span> <span class="o">+</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">/</span> <span class="n">num_intervals</span>
<span class="n">Z_discrete</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">X_discrete</span><span class="p">,</span> <span class="n">Y_discrete</span><span class="p">)</span>
<span class="n">ax</span> <span class="o">=</span> <span class="n">fig</span><span class="p">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="mi">122</span><span class="p">,</span> <span class="n">projection</span><span class="o">=</span><span class="s">"3d"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot_surface</span><span class="p">(</span><span class="n">X_plotpoints</span><span class="p">,</span> <span class="n">Y_plotpoints</span><span class="p">,</span> <span class="n">Z_discrete</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s">"inferno"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="s">"discretized function"</span><span class="p">);</span>
</code></pre></div></div>
<p><img src="/imgs/discrete-function-tensor/tensor-completion_4_0.svg" alt="svg" /></p>
<h2 id="learning-2d-functions-matrix-completion">Learning 2D functions: matrix completion</h2>
<p>Before diving into higher-dimensional versions of discretized functions, let’s think about how we would solve
the learning problem. As mentioned, we have \(n^2\) parameters, and we can encode this using an \(n\times n\)
matrix \(\Theta\). We are doing supervised machine learning, so we have data points
\(((x_1,y_1),\dots,(x_N,y_N))\) and corresponding labels \((z_1,\dots,z_N)\). Each data point \((x_i,y_i)\)
corresponds to some entry \((j,k)\) in the matrix \(\Theta\); this is simply determined by the specific grid cell
the data point happens to fall in.</p>
<p>If the points \(((x_{i_1},y_{i_1}),\dots,(x_{i_m},y_{i_m}))\) all fall into the grid cell \((j,k)\), then we can
define \(\Theta[j,k]\) simply as the mean value of the labels for these points:</p>
\[\Theta[j,k] = \frac{1}{m} \sum_{a=1}^m z_{i_a}\]
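<p>The per-cell averaging can be sketched in a few lines of NumPy. This is an illustration under assumptions, not the implementation used in the paper: the helper name <code>cell_means</code> and the convention of marking cells without data as <code>NaN</code> are choices made here.</p>

```python
import numpy as np


def cell_means(x, y, z, n):
    """Average the labels z over each cell of an n-by-n grid on [0, 1]^2.

    Returns the matrix of cell means (NaN where a cell has no data)
    and a boolean mask of the cells that contain at least one point.
    """
    # Grid cell (j, k) of each data point; points on the right/top edge
    # are assigned to the last cell
    j = np.minimum((x * n).astype(int), n - 1)
    k = np.minimum((y * n).astype(int), n - 1)

    sums = np.zeros((n, n))
    counts = np.zeros((n, n))
    # np.add.at accumulates correctly when several points share a cell
    np.add.at(sums, (j, k), z)
    np.add.at(counts, (j, k), 1)

    known = counts > 0
    theta = np.full((n, n), np.nan)
    theta[known] = sums[known] / counts[known]
    return theta, known


# Tiny example: two points share cell (0, 0), one point lies in cell (1, 1)
x = np.array([0.1, 0.2, 0.9])
y = np.array([0.1, 0.3, 0.6])
z = np.array([1.0, 3.0, 5.0])
theta, known = cell_means(x, y, z, n=2)
print(theta)
```

<p>The cells left as <code>NaN</code> are exactly the unknown entries of \(\Theta\).</p>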
<p>But what do we do if we have no training data corresponding to some entry \(\Theta[j,k]\)? Then the only thing
we can do is make an educated guess based on the entries of the matrix we <em>do</em> know. This is the <em>matrix
completion problem</em>; we are presented with a matrix with some known entries, and we are tasked to find good
values for the unknown entries. We described this problem in some detail <a href="/low-rank-matrix/">in the previous blog
post</a>.</p>
<p>The main takeaway is this: to solve the matrix completion problem, we need to assume that the matrix has some
extra structure. We typically assume that the matrix is of low rank \(r\), that is, we can write \(\Theta\) as a
product \(\Theta=A B\) where \(A,B\) are of size \(n\times r\) and \(r\times n\) respectively. Intuitively, this is a
useful assumption because now we only have to learn \(2nr\) parameters instead of \(n^2\). If \(r\) is much smaller
than \(n\), then this is a clear gain.</p>
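<p>To make the parameter count concrete, here is a small NumPy sketch. The sizes \(n=100\), \(r=3\) and the random factors are illustrative assumptions, not values from the paper.</p>

```python
import numpy as np

n, r = 100, 3
rng = np.random.default_rng(0)

# Low-rank factors: A is n x r, B is r x n
A = rng.normal(size=(n, r))
B = rng.normal(size=(r, n))

# The full n x n matrix described by the factors
Theta = A @ B

num_params_full = n * n                # one parameter per matrix entry
num_params_low_rank = A.size + B.size  # 2 * n * r parameters in total
rank = np.linalg.matrix_rank(Theta)

print(num_params_full, num_params_low_rank, rank)
```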
<p>From the perspective of machine learning, this changes the class of functions we are considering. Instead of
<em>all</em> discretized functions on our \(n\times n\) grid inside \([0,1]^2\), we now consider only those functions
described by a matrix \(\Theta=AB\) that has rank at most \(r\). This also changes the parameters; instead of
\(n^2\) parameters, we now only consider \(2nr\) parameters describing the two matrices \(A,B\).</p>
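<p>One simple way to fit such a factorization to the known entries is alternating least squares (ALS): fix \(B\) and solve a small least-squares problem for each row of \(A\), then swap roles. The sketch below uses ALS only because it is short; it is a swapped-in technique and not the optimizer used later in this post. The function name <code>als_complete</code>, the small ridge term <code>reg</code> and the synthetic test matrix are assumptions.</p>

```python
import numpy as np


def als_complete(theta, known, r, num_iters=100, reg=1e-3, seed=0):
    """Fit theta ~ A @ B on the known entries by alternating least squares.

    theta: (n, m) array, values at unknown positions are ignored
    known: (n, m) boolean mask of the known entries
    r:     rank of the factorization
    """
    rng = np.random.default_rng(seed)
    n, m = theta.shape
    A = rng.normal(size=(n, r))
    B = rng.normal(size=(r, m))
    ridge = reg * np.eye(r)  # keeps every small linear system solvable

    for _ in range(num_iters):
        # Fix B: each row of A is a ridge-regularized least-squares problem
        for j in range(n):
            cols = known[j]
            Bj = B[:, cols]
            A[j] = np.linalg.solve(Bj @ Bj.T + ridge, Bj @ theta[j, cols])
        # Fix A: each column of B is solved the same way
        for k in range(m):
            rows = known[:, k]
            Ak = A[rows]
            B[:, k] = np.linalg.solve(Ak.T @ Ak + ridge, Ak.T @ theta[rows, k])
    return A, B


# Synthetic example (assumption): a rank-2 matrix with about 70% known entries
rng = np.random.default_rng(0)
true = rng.normal(size=(20, 2)) @ rng.normal(size=(2, 20))
known = rng.uniform(size=true.shape) < 0.7

A, B = als_complete(np.where(known, true, 0.0), known, r=2)
rel_err = np.linalg.norm((A @ B - true)[known]) / np.linalg.norm(true[known])
print(f"relative error on known entries: {rel_err:.2e}")
```

<p>The product <code>A @ B</code> then also provides values for the entries that were never observed, which is exactly the educated guess that matrix completion asks for.</p>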
<p>Real data is often not uniform, so unless we use a very coarse grid, some entries \(\Theta[j,k]\) are always
going to be unknown. For example, below we show some more realistic data, generated from the same function as
before plus some noise. The color indicates the value of the function \(f\) we’re trying to learn.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">num_intervals</span> <span class="o">=</span> <span class="mi">8</span>
<span class="n">N</span> <span class="o">=</span> <span class="mi">50</span>
<span class="c1"># A function to make somewhat realistic looking 2D data
</span><span class="k">def</span> <span class="nf">non_uniform_data</span><span class="p">(</span><span class="n">N</span><span class="p">):</span>
    <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">179</span><span class="p">)</span>
    <span class="n">X</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="n">N</span><span class="p">)</span>
    <span class="n">X</span> <span class="o">=</span> <span class="p">(</span><span class="n">X</span> <span class="o">+</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
    <span class="n">X</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">mod</span><span class="p">(</span><span class="n">X</span> <span class="o">**</span> <span class="mi">5</span> <span class="o">+</span> <span class="mf">0.2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">Y</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="n">N</span><span class="p">)</span>
    <span class="n">Y</span> <span class="o">=</span> <span class="p">(</span><span class="n">Y</span> <span class="o">+</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">**</span> <span class="mi">3</span>
    <span class="n">Y</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">sin</span><span class="p">(</span><span class="n">Y</span> <span class="o">*</span> <span class="mf">0.2</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">pi</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span>
    <span class="n">Y</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">mod</span><span class="p">(</span><span class="n">Y</span> <span class="o">+</span> <span class="mf">0.6</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">X</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">mod</span><span class="p">(</span><span class="n">X</span> <span class="o">+</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">Y</span> <span class="o">+</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">Y</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">mod</span><span class="p">(</span><span class="mf">0.3</span> <span class="o">*</span> <span class="n">X</span> <span class="o">+</span> <span class="mf">1.3</span> <span class="o">*</span> <span class="n">Y</span> <span class="o">+</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">X</span> <span class="o">=</span> <span class="n">X</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">+</span> <span class="mf">0.4</span>
    <span class="n">X</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">mod</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">Y</span> <span class="o">=</span> <span class="n">Y</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">+</span> <span class="mf">0.5</span>
    <span class="n">Y</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">mod</span><span class="p">(</span><span class="n">Y</span> <span class="o">+</span> <span class="n">X</span> <span class="o">+</span> <span class="mf">0.4</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">X</span><span class="p">,</span> <span class="n">Y</span>
<span class="c1"># The function we want to model
</span><span class="k">def</span> <span class="nf">f</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">X</span> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">Y</span> <span class="o">+</span> <span class="mf">1.5</span> <span class="o">*</span> <span class="p">((</span><span class="n">X</span> <span class="o">-</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">+</span> <span class="p">(</span><span class="n">Y</span> <span class="o">-</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">X_train</span><span class="p">,</span> <span class="n">Y_train</span> <span class="o">=</span> <span class="n">non_uniform_data</span><span class="p">(</span><span class="n">N</span><span class="p">)</span>
<span class="n">X_test</span><span class="p">,</span> <span class="n">Y_test</span> <span class="o">=</span> <span class="n">non_uniform_data</span><span class="p">(</span><span class="n">N</span><span class="p">)</span>
<span class="n">Z_train</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">Y_train</span><span class="p">)</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="n">X_train</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.2</span>
<span class="n">Z_test</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">X_test</span><span class="p">,</span> <span class="n">Y_test</span><span class="p">)</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="n">X_test</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.2</span>
<span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="mi">6</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">Y_train</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="n">Z_train</span><span class="p">,</span> <span class="n">s</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s">"inferno"</span><span class="p">,</span> <span class="n">zorder</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">colorbar</span><span class="p">()</span>
<span class="c1"># Plot a grid
</span><span class="n">X_grid</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">1</span> <span class="o">/</span> <span class="n">num_intervals</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">num_intervals</span><span class="p">)</span>
<span class="n">Y_grid</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">1</span> <span class="o">/</span> <span class="n">num_intervals</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">num_intervals</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlim</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylim</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">for</span> <span class="n">perc</span> <span class="ow">in</span> <span class="n">X_grid</span><span class="p">:</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">axvline</span><span class="p">(</span><span class="n">perc</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s">"gray"</span><span class="p">)</span>
<span class="k">for</span> <span class="n">perc</span> <span class="ow">in</span> <span class="n">Y_grid</span><span class="p">:</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">axhline</span><span class="p">(</span><span class="n">perc</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s">"gray"</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="/imgs/discrete-function-tensor/tensor-completion_6_0.svg" alt="svg" /></p>
<p>We plotted an 8x8 grid on top of the data. We can see that in some grid squares we have a lot of data points, whereas in other squares there’s no data at all. Let’s try to fit a discretized function described by an 8x8 matrix of rank 3 to this data. We can do this using the <a href="https://github.com/RikVoorhaar/ttml">ttml</a> package I developed.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">ttml.tensor_train</span> <span class="kn">import</span> <span class="n">TensorTrain</span>
<span class="kn">from</span> <span class="nn">ttml.tt_rlinesearch</span> <span class="kn">import</span> <span class="n">TTLS</span>
<span class="n">rank</span> <span class="o">=</span> <span class="mi">3</span>
<span class="c1"># Indices of the matrix Theta for each data point
</span><span class="n">idx_train</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">stack</span><span class="p">(</span>
    <span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="n">searchsorted</span><span class="p">(</span><span class="n">X_grid</span><span class="p">,</span> <span class="n">X_train</span><span class="p">),</span> <span class="n">np</span><span class="p">.</span><span class="n">searchsorted</span><span class="p">(</span><span class="n">Y_grid</span><span class="p">,</span> <span class="n">Y_train</span><span class="p">)],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span>
<span class="p">)</span>
<span class="n">idx_test</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">stack</span><span class="p">(</span>
    <span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="n">searchsorted</span><span class="p">(</span><span class="n">X_grid</span><span class="p">,</span> <span class="n">X_test</span><span class="p">),</span> <span class="n">np</span><span class="p">.</span><span class="n">searchsorted</span><span class="p">(</span><span class="n">Y_grid</span><span class="p">,</span> <span class="n">Y_test</span><span class="p">)],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span>
<span class="p">)</span>
<span class="c1"># Initialize random rank 3 matrix
</span><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">179</span><span class="p">)</span>
<span class="n">low_rank_matrix</span> <span class="o">=</span> <span class="n">TensorTrain</span><span class="p">.</span><span class="n">random</span><span class="p">((</span><span class="n">num_intervals</span><span class="p">,</span> <span class="n">num_intervals</span><span class="p">),</span> <span class="n">rank</span><span class="p">)</span>
<span class="c1"># Optimize the matrix using iterative method
</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">TTLS</span><span class="p">(</span><span class="n">low_rank_matrix</span><span class="p">,</span> <span class="n">Z_train</span><span class="p">,</span> <span class="n">idx_train</span><span class="p">)</span>
<span class="n">train_losses</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">test_losses</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">50</span><span class="p">):</span>
    <span class="n">train_loss</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">optimizer</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>
    <span class="n">train_losses</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">train_loss</span><span class="p">)</span>
    <span class="n">test_loss</span> <span class="o">=</span> <span class="n">optimizer</span><span class="p">.</span><span class="n">loss</span><span class="p">(</span><span class="n">y</span><span class="o">=</span><span class="n">Z_test</span><span class="p">,</span> <span class="n">idx</span><span class="o">=</span><span class="n">idx_test</span><span class="p">)</span>
    <span class="n">test_losses</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">test_loss</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="n">DEFAULT_FIGSIZE</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">train_losses</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"Training loss"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">test_losses</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"Test loss"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">"Number of iterations"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">"Loss"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">yscale</span><span class="p">(</span><span class="s">"log"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Final training loss: </span><span class="si">{</span><span class="n">train_loss</span><span class="p">:.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Final test loss: </span><span class="si">{</span><span class="n">test_loss</span><span class="p">:.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Final training loss: 0.0252
Final test loss: 0.0424
</code></pre></div></div>
<p><img src="/imgs/discrete-function-tensor/tensor-completion_8_1.svg" alt="svg" /></p>
<p>Above we see how the train and test loss develop during training. At first both decrease
rapidly. Then both start to decrease much more slowly, and the training loss ends up below the test
loss. This means that the model overfits on the training data, but this is not necessarily a problem; the
question is how much it overfits compared to other models. To see how good this model is, let’s compare it to
a random forest.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.ensemble</span> <span class="kn">import</span> <span class="n">RandomForestRegressor</span>
<span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">179</span><span class="p">)</span>
<span class="n">forest</span> <span class="o">=</span> <span class="n">RandomForestRegressor</span><span class="p">()</span>
<span class="n">forest</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">stack</span><span class="p">([</span><span class="n">X_train</span><span class="p">,</span> <span class="n">Y_train</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span> <span class="n">Z_train</span><span class="p">)</span>
<span class="n">Z_pred</span> <span class="o">=</span> <span class="n">forest</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">stack</span><span class="p">([</span><span class="n">X_test</span><span class="p">,</span> <span class="n">Y_test</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span>
<span class="n">test_loss</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">((</span><span class="n">Z_pred</span> <span class="o">-</span> <span class="n">Z_test</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Random forest test loss: </span><span class="si">{</span><span class="n">test_loss</span><span class="p">:.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Random forest test loss: 0.0369
</code></pre></div></div>
<p>We see that the random forest is a little better than the discretized function. And in fact, most standard machine learning estimators will beat a discretized function like this. This is essentially because the discretized function is very simple, and more complicated estimators can do a better job describing the data.</p>
<p>Does this mean that we should stop caring about the discretized function? Test loss is not the only criterion we should use to compare different estimators. Discretized functions like these have two big advantages:</p>
<ol>
<li>They use very few parameters compared to many common machine learning estimators.</li>
<li>Making new predictions is <em>very</em> fast. Much faster in fact than most other machine learning estimators.</li>
</ol>
<p>This makes them excellent candidates for low-memory applications. For example, we may want to implement a machine learning model for a very cheap consumer device. If we don’t need extreme accuracy, and we pre-train the model on a more powerful device, discretized functions can be a very attractive option.</p>
<h2 id="discretized-functions-in-higher-dimensions-tensor-trains">Discretized functions in higher dimensions: tensor trains</h2>
<p>The generalization to \(d\) dimensions is now straightforward; we take a \(d\)-dimensional grid on \([0,1]^d\), with
\(n\) subdivisions in each axis. Then we specify the value of our function \(f_\Theta\) on each of the \(n^d\) grid
cells. These \(n^d\) values form a <em>tensor</em> \(\Theta\), i.e. a \(d\)-dimensional array. We access the entries of
\(\Theta\) with a \(d\)-tuple of indices \(\Theta[i_1,i_2,\dots,i_d]\).</p>
<p>This suffers from the same problems as in the 2D case: the tensor \(\Theta\) is really big, and during training
we would need at least one data point for each entry of the tensor. But the situation is even worse, because even
storing \(\Theta\) can be prohibitively expensive. For example, if \(d=10\) and \(n=20\), then we would need about
82 TB just to store the tensor as 64-bit floats! In fact, \(n=20\) grid points in each direction is not even that much, so in
practice we might need a much bigger tensor still.</p>
<p>In the 2D case we solved this problem by storing the matrix as the product of two smaller matrices. In the 2D
case this doesn’t actually save that much on memory, and we mainly did it so that we can solve the matrix
completion problem; that is, so that we can actually fit the discretized function to data. In higher
dimensions however, storing the tensor in the right way can save immense amounts of space.</p>
<p>In the 2D case, we stored the matrix in low-rank form, as a product of two smaller matrices. But what is the
correct analogue of ‘low rank’ for tensors? Unfortunately (or fortunately), there are many answers to this
question. There are many ‘low rank tensor formats’, all with very different properties. We will be focusing on
<em>tensor trains</em>. A tensor train decomposition of an \(n_1\times n_2\times \dots \times n_d\) tensor \(\Theta\)
consists of a set of \(d\) <em>cores</em> \(C_k\) of shape \(r_{k-1}\times n_k \times r_k\), where \((r_1,\dots,r_{d-1})\)
are the <em>ranks</em> of the tensor train. Using these cores we can then express the entries of \(\Theta\) using the
following formula:</p>
\[\Theta[i_1,\dots,i_d] = \sum_{k_1,\dots,k_{d-1}}C_1[1,i_1,k_1]C_2[k_1,i_2,k_2]\cdots C_{d-1}[k_{d-2},i_{d-1},k_{d-1}]C_d[k_{d-1},i_{d},1]\]
<p>This may look intimidating, but the idea is actually quite simple. We should think of the core \(C_{k}\) as a
<em>collection</em> of \(n_k\) matrices \((C_k[1],\dots,C_k[n_k])\), each of shape \(r_{k-1}\times r_k\). The index \(i_k\)
then <em>selects</em> which of these matrices to use. The first and last cores are special: by convention
\(r_0=r_d=1\), which means that \(C_1\) is a collection of \(1\times r_1\) matrices, i.e. (row) vectors. Similarly,
\(C_d\) is a collection of \(r_{d-1}\times 1\) matrices, i.e. (column) vectors. Thus each entry of \(\Theta\) is
determined by a product like this:</p>
<blockquote>
<p>row vector * matrix * matrix * … * matrix * column vector</p>
</blockquote>
<p>The result is a number, since a row/column vector times a matrix is a row/column vector, and the product of a
row and column vector is just a number. In fact, if we think about it, this is exactly how a low-rank matrix
decomposition works as well. If we write a matrix \(\Theta = AB\), then</p>
\[\Theta[i,j]=\sum_k A[i,k] B[k,j] = A[i,:]\cdot B[:,j].\]
<p>Here \(A[i,:]\) is a <em>row</em> of \(A\), and \(B[:,j]\) is a <em>column</em> of \(B\). In other words, \(A\)
is just a collection of row vectors, and \(B\) is just a collection of column vectors. Then to obtain an entry
\(\Theta[i,j]\), we select the \(i\text{th}\) row of \(A\) and the \(j\text{th}\) column of \(B\) and take the product.</p>
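<p>As a quick sanity check of this picture, here is a minimal NumPy sketch that evaluates one entry of a small random tensor train as a chain of matrix products and compares it to the defining sum. The shapes, ranks and the helper name <code>tt_entry</code> are made up for illustration and are not part of the <code>ttml</code> package.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
shape = (4, 5, 6)      # n_1, n_2, n_3 (arbitrary illustrative sizes)
ranks = (1, 2, 3, 1)   # r_0, r_1, r_2, r_3 with r_0 = r_3 = 1
cores = [
    rng.normal(size=(ranks[k], shape[k], ranks[k + 1]))
    for k in range(len(shape))
]


def tt_entry(cores, index):
    """Entry of the tensor: row vector * matrix * ... * column vector."""
    result = np.ones((1, 1))
    for core, i in zip(cores, index):
        # core[:, i, :] is the i-th matrix in this core's collection
        result = result @ core[:, i, :]
    return result[0, 0]


# Reference: contract all cores into the full tensor using the defining sum
full = np.einsum("aib,bjc,ckd->ijk", *cores)
entry = tt_entry(cores, (1, 2, 3))
print(np.isclose(entry, full[1, 2, 3]))
```

<p>The full tensor is only built here to check the result; the point of the format is that <code>tt_entry</code> never needs it.</p>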
<p>In summary, a tensor train is a way to cheaply store large tensors. Assuming all ranks \((r_1,\dots,r_{d-1})\)
are the same, a tensor train requires \(O(dr^2n)\) entries to store a tensor with \(O(n^d)\) entries; a huge gain
if \(d\) and \(n\) are big. For context, if \(d=10\), \(n=20\), and \(r=10\) then instead of 82 TB we just need 131 KB
to store the tensor; that’s about 9 orders of magnitude cheaper! Furthermore, computing entries of this tensor
is cheap; it’s just a handful of small matrix-vector products.</p>
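<p>The numbers above are easy to reproduce. The short calculation below assumes 64-bit floats (8 bytes per entry) and counts the tensor train entries exactly, taking into account that the first and last core are smaller than the middle ones.</p>

```python
d, n, r = 10, 20, 10
bytes_per_entry = 8  # assumption: entries are stored as 64-bit floats

# Dense tensor: one entry per grid cell
dense_entries = n ** d

# Tensor train: first and last cores have shapes (1, n, r) and (r, n, 1),
# the d - 2 middle cores have shape (r, n, r)
tt_entries = 2 * n * r + (d - 2) * r * n * r

dense_bytes = dense_entries * bytes_per_entry
tt_bytes = tt_entries * bytes_per_entry

print(f"dense tensor: {dense_bytes / 1e12:.1f} TB")
print(f"tensor train: {tt_bytes / 1e3:.1f} KB")
print(f"ratio:        {dense_bytes / tt_bytes:.1e}")
```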
<p>There is obviously a catch to this. Just like not every matrix is low-rank, not every tensor can be
represented by a low-rank tensor train. The point, however, is that tensor trains can efficiently represent
many tensors that we <em>do</em> care about. In particular, they are good at representing the tensors required for
discretized functions.</p>
<h2 id="learning-discretized-functions-tensor-completion">Learning discretized functions: tensor completion</h2>
<p>How can we learn a discretized function \([0,1]^d\to \mathbb R\) represented by a tensor train? Like in the
matrix case, many entries of the tensor are unobserved, and we have to <em>complete</em> these entries based on
the entries that we <em>can</em> estimate. In <a href="/low-rank-matrix">my post on matrix completion</a> we have seen that even
the matrix case is tricky, and there are many algorithms to solve the problem. One thing these algorithms have
in common is that they are iterative algorithms minimizing some loss function. Let’s derive such an algorithm
for <em>tensor train completion</em>.</p>
<p>First of all, what is the loss function we want to minimize during training? It’s simply the least squares
loss:</p>
\[L(\Theta) = \sum_{j=1}^N(f_\Theta(x_j) - y_j)^2\]
<p>Each data point \(x_j\in [0,1]^d\) fits into some grid cell given by index \((i_1[j],i_2[j],\dots,i_d[j])\), so
using the definition of the tensor train the loss \(L(\Theta)\) becomes</p>
\[\begin{align*}
L(\Theta) &= \sum_{j=1}^N (\Theta[i_1[j],i_2[j],\dots,i_d[j]] - y_j)^2\\
&= \sum_{j=1}^N(C_1[1,i_1[j],:]C_2[:,i_2[j],:]\cdots C_d[:,i_d[j],1] - y_j)^2
\end{align*}\]
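<p>In code, the loss can be evaluated without ever forming \(\Theta\): for each data point we pick one matrix from every core and multiply them together. The sketch below does this for all data points at once with NumPy. The function names <code>tt_gather</code> and <code>tt_loss</code> are hypothetical helpers written for this illustration, not the <code>ttml</code> API, and the data is random.</p>

```python
import numpy as np


def tt_gather(cores, idx):
    """Evaluate a tensor train at every row of idx, an (N, d) integer array."""
    result = cores[0][0, idx[:, 0], :]        # shape (N, r_1): one row vector per point
    for k in range(1, len(cores)):
        mats = cores[k][:, idx[:, k], :]      # shape (r_{k-1}, N, r_k)
        # Multiply each row vector with the matrix selected for that data point
        result = np.einsum("nr,rns->ns", result, mats)
    return result[:, 0]                       # r_d = 1, so drop the last axis


def tt_loss(cores, idx, y):
    """Least-squares loss over the observed grid cells."""
    return np.sum((tt_gather(cores, idx) - y) ** 2)


rng = np.random.default_rng(0)
shape, ranks = (4, 5, 6), (1, 2, 3, 1)
cores = [rng.normal(size=(ranks[k], shape[k], ranks[k + 1])) for k in range(3)]
idx = np.stack([rng.integers(0, n, size=50) for n in shape], axis=1)
y = rng.normal(size=50)
loss = tt_loss(cores, idx, y)
print(loss)
```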
<p>A straightforward approach to minimizing \(L(\Theta)\) is to just use <em>gradient descent</em>. We could compute the
derivatives with respect to each of the cores \(C_i\) and just update the cores using this derivative. This is,
however, very slow. There are two reasons for this, but they are a bit subtle:</p>
<ol>
<li><em>There is a lot of curvature.</em> In gradient descent, the size of the step we can optimally take depends on
how big the <em>second derivatives</em> of a function are (the <em>‘curvature’</em>). The derivative of a function is the
<em>best linear approximation</em> of a function, and gradient descent works faster if this linear approximation
is a good approximation of the function. In this case, the function we are trying to optimize is <em>very
non-linear</em>, and any linear approximation is going to be very bad. Therefore we are forced to take really
tiny steps during gradient descent, and convergence is going to be very slow.</li>
<li><em>There are a lot of symmetries.</em> For example we can replace \(C_i\) and \(C_{i+1}\) with \(C_i A\) and
\(A^{-1}C_{i+1}\) for any invertible matrix \(A\) without changing \(\Theta\). Gradient descent ‘doesn’t know’ about these symmetries, and keeps
updating \(\Theta\) in directions that don’t affect \(L(\Theta)\).</li>
</ol>
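<p>The symmetry in the second point is easy to verify numerically. The sketch below uses a two-core tensor train (i.e. a low-rank matrix) with made-up shapes: multiplying the first core by a matrix \(A\) and the second by \(A^{-1}\) changes the cores but leaves \(\Theta\) untouched.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
C1 = rng.normal(size=(1, 4, 3))   # first core, rank r_1 = 3
C2 = rng.normal(size=(3, 5, 1))   # last core
# A generic 3 x 3 matrix; the shift by 3 * I is only there to keep it
# comfortably invertible in this toy example
A = rng.normal(size=(3, 3)) + 3 * np.eye(3)

theta = np.einsum("aib,bjc->ij", C1, C2)

# Apply A on the rank index shared by the two cores
C1_new = np.einsum("aib,bc->aic", C1, A)
C2_new = np.einsum("ab,bjc->ajc", np.linalg.inv(A), C2)
theta_new = np.einsum("aib,bjc->ij", C1_new, C2_new)

print(np.allclose(theta, theta_new))   # same tensor
print(np.allclose(C1, C1_new))         # different cores
```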
<p>To efficiently optimize \(L(\Theta)\), we can’t just use gradient descent as-is, and we are forced to take a
different route. While \(L(\Theta)\) is very non-linear as a function of the tensor train cores \(C_i\), it is only
quadratic in the <em>entries</em> of \(\Theta\), and we can easily compute its derivative:</p>
\[\nabla_{\Theta}L(\Theta) = 2\sum_{j=1}^N (\Theta[i_1[j],i_2[j],\dots,i_d[j]] -
y_j)E(i_1[j],i_2[j],\dots,i_d[j]),\]
<p>where \(E(i_1,i_2,\dots,i_d)\) denotes a sparse tensor that’s zero in all entries <em>except</em> \((i_1,\dots,i_d)\)
where it takes value \(1\). In other words, \(\nabla_{\Theta}L(\Theta)\) is a <em>sparse tensor</em> that is both simple
and cheap to compute; it just requires sampling at most \(N\) entries of \(\Theta\). For gradient descent we would
then update \(\Theta\) by \(\Theta-\alpha \nabla_{\Theta}L(\Theta)\) with \(\alpha\) the stepsize. Unfortunately,
this expression is not a tensor train. However, we can try to <em>approximate</em>
\(\Theta-\alpha \nabla_{\Theta}L(\Theta)\) by a tensor train of the same rank as \(\Theta\).</p>
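<p>To make the sparse gradient concrete, here is a tiny 2D sketch. It stores the gradient as a dense array purely for readability; in high dimensions one would keep only the indices and the residuals. The helper name <code>sparse_gradient</code> and the toy data are assumptions for this illustration.</p>

```python
import numpy as np


def sparse_gradient(theta_vals, idx, y, shape):
    """Gradient of sum_j (Theta[idx_j] - y_j)^2 with respect to the entries of Theta.

    theta_vals[j] is the current value of Theta at the cell of data point j.
    """
    grad = np.zeros(shape)
    residual = 2 * (theta_vals - y)
    # np.add.at accumulates contributions when several points share a cell
    np.add.at(grad, tuple(idx.T), residual)
    return grad


shape = (3, 4)
theta = np.arange(12, dtype=float).reshape(shape)
idx = np.array([[0, 1], [2, 3], [0, 1]])   # two data points fall in cell (0, 1)
y = np.array([0.0, 10.0, 3.0])

grad = sparse_gradient(theta[tuple(idx.T)], idx, y, shape)
print(np.count_nonzero(grad), "nonzero entries out of", grad.size)
```

<p>Only the cells that contain data points get a nonzero gradient, so the number of nonzero entries is at most \(N\).</p>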
<p>Recall that we can approximate a matrix \(A\) by a rank \(r\) matrix by using the <em>truncated SVD</em> of \(A\). In fact
this is the best-possible approximation of \(A\) by a rank \(\leq r\) matrix. There is a similar procedure for
tensor trains; we can approximate a tensor \(\Theta\) by a rank \((r_1,\dots,r_{d-1})\) tensor train using the
TT-SVD procedure. While this is not the <em>best</em> approximation of \(\Theta\) by such a tensor train, it is
<em>‘quasi-optimal’</em> and pretty good in practice. The details of the TT-SVD procedure are a little involved, so
let’s leave it as a black box. We now have the following iterative procedure for optimizing \(L(\Theta)\):</p>
\[\Theta_{k+1} \leftarrow \operatorname{TT-SVD}(\Theta_{k}-\alpha \nabla_{\Theta}L(\Theta_k) )\]
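<p>For readers who want a peek inside the black box, below is a minimal dense sketch of the standard TT-SVD idea: sweep over the axes, and at each step reshape, take a truncated SVD, and keep the left factor as a core. It assumes the tensor is small enough to store densely and uses a single rank cap <code>max_rank</code> for all ranks, so it is a toy illustration and not how <code>ttml</code> implements the update. The names <code>tt_svd</code> and <code>tt_to_full</code> are hypothetical.</p>

```python
import numpy as np


def tt_svd(tensor, max_rank):
    """Approximate a dense tensor by a tensor train with all ranks <= max_rank."""
    shape = tensor.shape
    cores = []
    r_prev = 1
    mat = tensor
    for k in range(len(shape) - 1):
        # Group (previous rank, current axis) as rows, all remaining axes as columns
        mat = mat.reshape(r_prev * shape[k], -1)
        U, S, Vt = np.linalg.svd(mat, full_matrices=False)
        r = min(max_rank, len(S))
        cores.append(U[:, :r].reshape(r_prev, shape[k], r))
        # Carry the remainder (singular values times right factor) to the next axis
        mat = S[:r, None] * Vt[:r]
        r_prev = r
    cores.append(mat.reshape(r_prev, shape[-1], 1))
    return cores


def tt_to_full(cores):
    """Contract the cores back into a dense tensor."""
    full = cores[0]
    for core in cores[1:]:
        full = np.tensordot(full, core, axes=([-1], [0]))
    return full[0, ..., 0]


rng = np.random.default_rng(0)
shape, ranks = (4, 5, 6), (1, 2, 2, 1)
exact = tt_to_full(
    [rng.normal(size=(ranks[k], shape[k], ranks[k + 1])) for k in range(3)]
)

# A tensor that already has TT-rank 2 is recovered up to rounding error
recovered = tt_to_full(tt_svd(exact, max_rank=2))
print(np.allclose(exact, recovered))

# One projected gradient step in this dense toy setting would read:
# theta_next = tt_to_full(tt_svd(theta - alpha * grad, max_rank))
```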
<p>If you’re familiar with optimizing neural networks, you might notice that this procedure could work very well
with <em>stochastic gradient descent</em>. Indeed \(\nabla_{\Theta}L(\Theta)\) is a sum over all the data points, so we
can just pick a subset of data points (a minibatch) to obtain a stochastic gradient. The usual reason to do this
is that there are so many data points that the cost of each step is dominated by computing the
gradient. That is not the case here: the cost of each step is dominated by the TT-SVD procedure. We
therefore stick to more classical gradient descent methods. In particular, the function \(L(\Theta)\) can be
optimized well with conjugate gradient descent using Armijo backtracking line search.</p>
<h2 id="discretized-functions-in-practice">Discretized functions in practice</h2>
<p>Let’s now see all of this in practice. Let’s train a discretized function \(f_\Theta\) represented by a tensor
train on some data using the technique described above. We will do this on a real dataset: the <a href="https://archive.ics.uci.edu/ml/datasets/airfoil+self-noise">airfoil
self-noise dataset</a>. This NASA dataset contains
experimental data about the self-noise of airfoils in a wind tunnel, originally used to optimize wing shapes.
We can do the fitting and optimization using my <code class="language-plaintext highlighter-rouge">ttml</code> package. Let’s use a rank 5 tensor train with 10 grid
points for each feature.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="n">pd</span>
<span class="kn">from</span> <span class="nn">sklearn.preprocessing</span> <span class="kn">import</span> <span class="n">MinMaxScaler</span>
<span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">train_test_split</span>
<span class="c1"># Load the data
</span><span class="n">airfoil_data</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">read_csv</span><span class="p">(</span>
<span class="s">"airfoil_self_noise.dat"</span><span class="p">,</span> <span class="n">sep</span><span class="o">=</span><span class="s">"</span><span class="se">\t</span><span class="s">"</span><span class="p">,</span> <span class="n">header</span><span class="o">=</span><span class="bp">None</span>
<span class="p">).</span><span class="n">to_numpy</span><span class="p">()</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">airfoil_data</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">]</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">airfoil_data</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">5</span><span class="p">]</span>
<span class="n">N</span><span class="p">,</span> <span class="n">d</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Dataset has </span><span class="si">{</span><span class="n">N</span><span class="o">=</span><span class="si">}</span><span class="s"> samples and </span><span class="si">{</span><span class="n">d</span><span class="o">=</span><span class="si">}</span><span class="s"> features."</span><span class="p">)</span>
<span class="c1"># Do train-test split, and scale data to interval [0,1]
</span><span class="n">X_train</span><span class="p">,</span> <span class="n">X_test</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span>
<span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">test_size</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">179</span>
<span class="p">)</span>
<span class="n">scaler</span> <span class="o">=</span> <span class="n">MinMaxScaler</span><span class="p">(</span><span class="n">clip</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">X_train</span> <span class="o">=</span> <span class="n">scaler</span><span class="p">.</span><span class="n">fit_transform</span><span class="p">(</span><span class="n">X_train</span><span class="p">)</span>
<span class="n">X_test</span> <span class="o">=</span> <span class="n">scaler</span><span class="p">.</span><span class="n">transform</span><span class="p">(</span><span class="n">X_test</span><span class="p">)</span>
<span class="c1"># Define grid, and find associated indices for each data point
</span><span class="n">num_intervals</span> <span class="o">=</span> <span class="mi">10</span>
<span class="n">grids</span> <span class="o">=</span> <span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">1</span> <span class="o">/</span> <span class="n">num_intervals</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">num_intervals</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">d</span><span class="p">)]</span>
<span class="n">tensor_shape</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">grid</span><span class="p">)</span> <span class="k">for</span> <span class="n">grid</span> <span class="ow">in</span> <span class="n">grids</span><span class="p">)</span>
<span class="n">idx_train</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">stack</span><span class="p">(</span>
<span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="n">searchsorted</span><span class="p">(</span><span class="n">grid</span><span class="p">,</span> <span class="n">X_train</span><span class="p">[:,</span> <span class="n">i</span><span class="p">])</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">grid</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">grids</span><span class="p">)],</span>
<span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">idx_test</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">stack</span><span class="p">(</span>
<span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="n">searchsorted</span><span class="p">(</span><span class="n">grid</span><span class="p">,</span> <span class="n">X_test</span><span class="p">[:,</span> <span class="n">i</span><span class="p">])</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">grid</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">grids</span><span class="p">)],</span>
<span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="p">)</span>
<span class="c1"># Initialize the tensor train
</span><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">179</span><span class="p">)</span>
<span class="n">rank</span> <span class="o">=</span> <span class="mi">5</span>
<span class="n">tensor_train</span> <span class="o">=</span> <span class="n">TensorTrain</span><span class="p">.</span><span class="n">random</span><span class="p">(</span><span class="n">tensor_shape</span><span class="p">,</span> <span class="n">rank</span><span class="p">)</span>
<span class="c1"># Optimize the tensor train using iterative method
</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">TTLS</span><span class="p">(</span><span class="n">tensor_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">idx_train</span><span class="p">)</span>
<span class="n">train_losses</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">test_losses</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100</span><span class="p">):</span>
    <span class="n">train_loss</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">optimizer</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>
    <span class="n">train_losses</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">train_loss</span><span class="p">)</span>
    <span class="n">test_loss</span> <span class="o">=</span> <span class="n">optimizer</span><span class="p">.</span><span class="n">loss</span><span class="p">(</span><span class="n">y</span><span class="o">=</span><span class="n">y_test</span><span class="p">,</span> <span class="n">idx</span><span class="o">=</span><span class="n">idx_test</span><span class="p">)</span>
    <span class="n">test_losses</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">test_loss</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="n">DEFAULT_FIGSIZE</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">train_losses</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"Training loss"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">test_losses</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"Test loss"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">"Number of iterations"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">"Loss"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">yscale</span><span class="p">(</span><span class="s">"log"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Final training loss: </span><span class="si">{</span><span class="n">train_loss</span><span class="p">:.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Final test loss: </span><span class="si">{</span><span class="n">test_loss</span><span class="p">:.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Dataset has N=1503 samples and d=5 features.
Final training loss: 15.3521
Final test loss: 54.4698
</code></pre></div></div>
<p><img src="/imgs/discrete-function-tensor/tensor-completion_15_1.svg" alt="svg" /></p>
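<p>One step in the code above deserves a closer look: turning scaled features into grid indices. The grid stores the <em>right</em> edges of the cells, and <code>np.searchsorted</code> then returns the index of the cell each value falls into. A small standalone sketch with a few hand-picked values:</p>

```python
import numpy as np

num_intervals = 10
# Right edges of the cells: 0.1, 0.2, ..., 1.0
grid = np.linspace(1 / num_intervals, 1, num_intervals)

x = np.array([0.0, 0.05, 0.15, 0.55, 0.95, 1.0])
cells = np.searchsorted(grid, x)
print(cells)

# A value above 1 would get an index one past the last cell
out_of_range = np.searchsorted(grid, 1.2)
print(out_of_range)
```

<p>The last line shows why the scaler is created with <code>clip=True</code>: test points outside the range seen during training would otherwise produce an index that does not exist in the tensor.</p>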
<p>We see a similar training profile to the matrix completion case. Let’s see now how this estimator compares to a random forest trained on the same data:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">179</span><span class="p">)</span>
<span class="n">forest</span> <span class="o">=</span> <span class="n">RandomForestRegressor</span><span class="p">()</span>
<span class="n">forest</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>
<span class="n">y_pred</span> <span class="o">=</span> <span class="n">forest</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">)</span>
<span class="n">test_loss</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">((</span><span class="n">y_pred</span> <span class="o">-</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Random forest test loss: </span><span class="si">{</span><span class="n">test_loss</span><span class="p">:.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Random forest test loss: 3.2568
</code></pre></div></div>
<p>The random forest has a loss of around <code class="language-plaintext highlighter-rouge">3.3</code>, but the discretized function has a loss of around <code class="language-plaintext highlighter-rouge">54.5</code>! That gap in performance is completely unacceptable. We could try to improve it by increasing the number of grid points, and by tweaking the rank of the tensor train. However, it will still come nowhere close to the performance of a random forest, even with its default parameters. Even the <em>training error</em> of the discretized function is much worse than the <em>test error</em> of the random forest.</p>
<p><strong>Why is it so bad?</strong> <em>Bad initialization!</em></p>
<p>Recall that a gradient descent method converges to a <em>local</em> minimum of the function. Usually we hope that whatever local minimum we converge to is ‘good’. Indeed for neural networks we see that, especially if we use a lot of parameters, most local minima found by stochastic gradient descent are quite good, and give a low train <em>and</em> test error. This is not true for our discretized function. We converge to local minima that have both bad train and test error.</p>
<p><strong>The solution?</strong> <em>Better initialization!</em></p>
<h2 id="using-other-estimators-for-initialization">Using other estimators for initialization</h2>
<p>Instead of initializing the tensor trains <em>randomly</em>, we can learn from other machine learning estimators. We
fit our favorite machine learning estimator (e.g. a neural network) to the training data. This gives a function
\(g\colon [0,1]^d\to \mathbb R\). This function is defined for <em>any</em> input, not just for the training/test data
points. Therefore we can try to first fit our discretized function \(f_\Theta\) to match \(g\), i.e. we solve the
following minimization problem:</p>
\[\min_\Theta \|f_\Theta - g\|^2\]
<p>One way to solve this minimization problem is by first (randomly) sampling a lot of new data points
\((x_1,\dots,x_N)\in [0,1]^d\) and then fitting \(f_\Theta\) to these data points with labels
\((g(x_1),\dots,g(x_N))\). This is essentially <em>data augmentation</em>, and can drastically increase the <em>number</em> of
data points available for training. With more training data, the function \(f_\Theta\) will indeed converge to a
better local minimum.</p>
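<p>Here is a minimal 2D sketch of this sampling idea. The function <code>g</code> below is a made-up stand-in for a trained estimator, and instead of a tensor train we fit a plain dense grid by averaging the labels in each cell, which is the least-squares solution when every cell is free. With enough sampled points every cell gets filled.</p>

```python
import numpy as np

rng = np.random.default_rng(0)


def g(x):
    """Stand-in for a trained estimator on [0, 1]^2 (purely illustrative)."""
    return np.sin(3 * x[:, 0]) + x[:, 1] ** 2


# Sample many new points and label them with g
X_aug = rng.uniform(size=(5000, 2))
y_aug = g(X_aug)

# Assign each point to a cell of an n x n grid
n = 10
idx = np.minimum((X_aug * n).astype(int), n - 1)

# Least-squares fit of a dense discretized function: the mean label per cell
sums = np.zeros((n, n))
counts = np.zeros((n, n))
np.add.at(sums, (idx[:, 0], idx[:, 1]), y_aug)
np.add.at(counts, (idx[:, 0], idx[:, 1]), 1)
theta = sums / np.maximum(counts, 1)

# Compare with g evaluated at the cell centers
c = (np.arange(n) + 0.5) / n
centers = np.stack(np.meshgrid(c, c, indexing="ij"), axis=-1).reshape(-1, 2)
max_error = np.max(np.abs(theta - g(centers).reshape(n, n)))
print(f"emptiest cell has {int(counts.min())} points, max error {max_error:.3f}")
```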
<p>While data augmentation does improve performance, we can do better. We don’t need to <em>randomly</em> sample data
points \((x_1,\dots,x_N)\in[0,1]^d\). Instead we can <em>choose</em> good points to sample; points that give us the
most information on how to efficiently update the tensor train. This is essentially the idea behind the
<em>tensor train cross approximation</em> algorithm, or TT-Cross for short. Using TT-Cross we can quickly and
efficiently get a good approximation to the minimization problem \(\min_\Theta \|f_\Theta - g\|^2\).</p>
<p>We could stop here. If \(g\) models our data really well, and \(f_\Theta\) approximates \(g\) really well, then we
should be happy. Like the matrix completion model, discretized functions based on tensor trains are <em>fast</em> and
are <em>memory efficient</em>. Therefore we can make an approximation of \(g\) that uses less memory and can make
faster predictions! However, the model \(g\) really should be used for <em>initialization</em> only. Usually \(f_\Theta\)
actually doesn’t do a great job of approximating \(g\), but if we first approximate \(g\), and <em>then</em> use a
gradient descent algorithm to improve \(f_\Theta\) even further, we end up with something much more competitive.</p>
<p>Let’s see this in action. This is actually much easier than what we did before, because I wrote the <code class="language-plaintext highlighter-rouge">ttml</code>
package specifically for this use case.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">ttml.ttml</span> <span class="kn">import</span> <span class="n">TTMLRegressor</span>
<span class="c1"># Use random forest as base estimator
</span><span class="n">forest</span> <span class="o">=</span> <span class="n">RandomForestRegressor</span><span class="p">()</span>
<span class="c1"># Fit tt on random forest, and then optimize further on training data
</span><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">179</span><span class="p">)</span>
<span class="n">tt</span> <span class="o">=</span> <span class="n">TTMLRegressor</span><span class="p">(</span><span class="n">forest</span><span class="p">,</span> <span class="n">max_rank</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">opt_tol</span><span class="o">=</span><span class="bp">None</span><span class="p">)</span>
<span class="n">tt</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">X_val</span><span class="o">=</span><span class="n">X_test</span><span class="p">,</span> <span class="n">y_val</span><span class="o">=</span><span class="n">y_test</span><span class="p">)</span>
<span class="n">y_pred</span> <span class="o">=</span> <span class="n">tt</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">)</span>
<span class="n">test_loss</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">((</span><span class="n">y_pred</span> <span class="o">-</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"TTML test loss: </span><span class="si">{</span><span class="n">test_loss</span><span class="p">:.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="c1"># Forest is fit on same data during fitting of tt
# Let's also report how well the forest does
</span><span class="n">y_pred_forest</span> <span class="o">=</span> <span class="n">forest</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">)</span>
<span class="n">test_loss_forest</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">((</span><span class="n">y_pred_forest</span> <span class="o">-</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Random forest test loss: </span><span class="si">{</span><span class="n">test_loss_forest</span><span class="p">:.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="c1"># Training and test loss are also recorded during optimization, let's plot it
</span><span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="n">DEFAULT_FIGSIZE</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">tt</span><span class="p">.</span><span class="n">history_</span><span class="p">[</span><span class="s">"train_loss"</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s">"Training loss"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">tt</span><span class="p">.</span><span class="n">history_</span><span class="p">[</span><span class="s">"val_loss"</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s">"Test loss"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">axhline</span><span class="p">(</span><span class="n">test_loss_forest</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s">"g"</span><span class="p">,</span> <span class="n">ls</span><span class="o">=</span><span class="s">"--"</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"Random forest test loss"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">"Number of iterations"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">"Loss"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">yscale</span><span class="p">(</span><span class="s">"log"</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>TTML test loss: 2.8970
Random forest test loss: 3.2568
</code></pre></div></div>
<p><img src="/imgs/discrete-function-tensor/tensor-completion_19_1.svg" alt="svg" /></p>
<p>We see that using a random forest for initialization gives a huge improvement to both training and test loss.
In fact, the final test loss is better than that of the random forest itself! On top of that, this estimator doesn’t use many parameters:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"TT uses </span><span class="si">{</span><span class="n">tt</span><span class="p">.</span><span class="n">ttml_</span><span class="p">.</span><span class="n">num_params</span><span class="si">}</span><span class="s"> parameters"</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>TT uses 1356 parameters
</code></pre></div></div>
<p>Let’s compare that to the random forest. If we look under the hood, the scikit-learn implementation of random forests stores 8 parameters per node in each tree in the forest. This is inefficient, and you really only <em>need</em> 2 parameters per node, so let’s use that.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">num_params_forest</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span>
<span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="n">tree</span><span class="p">.</span><span class="n">tree_</span><span class="p">.</span><span class="n">__getstate__</span><span class="p">()[</span><span class="s">"nodes"</span><span class="p">])</span> <span class="o">*</span> <span class="mi">2</span> <span class="k">for</span> <span class="n">tree</span> <span class="ow">in</span> <span class="n">forest</span><span class="p">.</span><span class="n">estimators_</span><span class="p">]</span>
<span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Forest uses </span><span class="si">{</span><span class="n">num_params_forest</span><span class="si">}</span><span class="s"> parameters"</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Forest uses 303180 parameters
</code></pre></div></div>
<p>That’s 1356 parameters vs. more than 300,000 parameters! What about my claim of prediction speed? Let’s compare the amount of time it takes both estimators to predict 1 million samples. We do this by just repeating the rows of the training data until we get 1 million samples.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">time</span> <span class="kn">import</span> <span class="n">perf_counter_ns</span>
<span class="n">target_num</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="mf">1e6</span><span class="p">)</span>
<span class="n">n_copies</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">target_num</span><span class="o">//</span><span class="nb">len</span><span class="p">(</span><span class="n">X_train</span><span class="p">))</span><span class="o">+</span><span class="mi">1</span>
<span class="n">X_one_million</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span><span class="n">n_copies</span><span class="p">,</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)[:</span><span class="n">target_num</span><span class="p">]</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"</span><span class="si">{</span><span class="n">X_one_million</span><span class="p">.</span><span class="n">shape</span><span class="o">=</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="n">time_before</span> <span class="o">=</span> <span class="n">perf_counter_ns</span><span class="p">()</span>
<span class="n">tt</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_one_million</span><span class="p">)</span>
<span class="n">time_taken</span> <span class="o">=</span> <span class="p">(</span><span class="n">perf_counter_ns</span><span class="p">()</span> <span class="o">-</span> <span class="n">time_before</span><span class="p">)</span><span class="o">/</span><span class="mf">1e6</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Time taken by TT: </span><span class="si">{</span><span class="n">time_taken</span><span class="p">:.</span><span class="mi">0</span><span class="n">f</span><span class="si">}</span><span class="s">ms"</span><span class="p">)</span>
<span class="n">time_before</span> <span class="o">=</span> <span class="n">perf_counter_ns</span><span class="p">()</span>
<span class="n">forest</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_one_million</span><span class="p">)</span>
<span class="n">time_taken</span> <span class="o">=</span> <span class="p">(</span><span class="n">perf_counter_ns</span><span class="p">()</span> <span class="o">-</span> <span class="n">time_before</span><span class="p">)</span><span class="o">/</span><span class="mf">1e6</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Time taken by Forest: </span><span class="si">{</span><span class="n">time_taken</span><span class="p">:.</span><span class="mi">0</span><span class="n">f</span><span class="si">}</span><span class="s">ms"</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>X_one_million.shape=(1000000, 5)
Time taken by TT: 430ms
Time taken by Forest: 2328ms
</code></pre></div></div>
<p>While not by orders of magnitude, we see that the tensor train model is faster. You might be thinking that
this is just because the tensor train has fewer parameters, but this is not the case. Even if we use a very
high-rank tensor train with high-dimensional data, it is still going to be fast. The speed scales really well,
and will beat most conventional machine learning estimators.</p>
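<p>Here is a minimal sketch of why both the parameter count and the prediction cost stay small. It assumes the usual tensor train format, where core \(k\) has shape \((r_{k-1}, n_k, r_k)\); the grid sizes and ranks are made up for illustration and are not those of the model above, and the helper names are hypothetical rather than part of <code class="language-plaintext highlighter-rouge">ttml</code>.</p>

```python
import numpy as np


def tt_num_params(cores):
    """Total number of entries stored in the tensor train cores."""
    return sum(core.size for core in cores)


def tt_evaluate(cores, index):
    """Evaluate a single tensor entry as a product of small matrices."""
    result = np.ones((1, 1))
    for core, i in zip(cores, index):
        result = result @ core[:, i, :]  # (1, r_prev) @ (r_prev, r_next)
    return result[0, 0]


rng = np.random.default_rng(0)
grid_sizes = [10, 12, 8, 9, 11]  # hypothetical number of grid cells per feature
ranks = [1, 5, 5, 5, 5, 1]       # boundary ranks are always 1
cores = [
    rng.normal(size=(ranks[k], grid_sizes[k], ranks[k + 1]))
    for k in range(len(grid_sizes))
]

# Build the full tensor only to check the result; this is what the TT avoids
full = np.einsum("aib,bjc,ckd,dle,emf->ijklm", *cores)
index = (3, 7, 2, 5, 10)
entry = tt_evaluate(cores, index)
print(f"TT parameters: {tt_num_params(cores)}, full tensor entries: {full.size}")
```

<p>Evaluating one entry takes one small matrix product per feature, so the cost grows linearly with the number of features and quadratically with the rank, no matter how large the full tensor is.</p>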
<h2 id="no-free-lunch">No free lunch</h2>
<p>With good initialization the model based on discretized functions performs really well. On our test dataset the
model is fast, uses few parameters, and beats a random forest in test loss (in fact, it is <em>the best
estimator</em> I have found so far for this problem). This is great! I should publish a paper in NeurIPS and get a
job at Google! Well… let’s not get ahead of ourselves. It performs well on <em>this particular dataset</em>, yes,
but how does it fare on other data?</p>
<p>As we shall see, it doesn’t do all that well actually. The airfoil self-noise dataset is a very particular
dataset on which this algorithm excels. The model seems to perform well on data that can be described by a
somewhat smooth function, and doesn’t deal well with the noisy and stochastic nature of most data we encounter
in the real world. As an example let’s repeat the experiment, but let’s first add some noise:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">ttml.ttml</span> <span class="kn">import</span> <span class="n">TTMLRegressor</span>
<span class="n">X_noise_std</span> <span class="o">=</span> <span class="mf">1e-6</span>
<span class="n">X_train_noisy</span> <span class="o">=</span> <span class="n">X_train</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">X_noise_std</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">X_train</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">X_test_noisy</span> <span class="o">=</span> <span class="n">X_test</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">scale</span><span class="o">=</span><span class="n">X_noise_std</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">X_test</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span>
<span class="c1"># Use random forest as base estimator
</span><span class="n">forest</span> <span class="o">=</span> <span class="n">RandomForestRegressor</span><span class="p">()</span>
<span class="c1"># Fit tt on random forest, and then optimize further on training data
</span><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">179</span><span class="p">)</span>
<span class="n">tt</span> <span class="o">=</span> <span class="n">TTMLRegressor</span><span class="p">(</span><span class="n">forest</span><span class="p">,</span> <span class="n">max_rank</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">opt_tol</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">opt_steps</span><span class="o">=</span><span class="mi">50</span><span class="p">)</span>
<span class="n">tt</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train_noisy</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">X_val</span><span class="o">=</span><span class="n">X_test_noisy</span><span class="p">,</span> <span class="n">y_val</span><span class="o">=</span><span class="n">y_test</span><span class="p">)</span>
<span class="n">y_pred</span> <span class="o">=</span> <span class="n">tt</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test_noisy</span><span class="p">)</span>
<span class="n">test_loss</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">((</span><span class="n">y_pred</span> <span class="o">-</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"TTML test loss (noisy): </span><span class="si">{</span><span class="n">test_loss</span><span class="p">:.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="c1"># Forest is fit on same data during fitting of tt
# Let's also report how well the forest does
</span><span class="n">y_pred_forest</span> <span class="o">=</span> <span class="n">forest</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test_noisy</span><span class="p">)</span>
<span class="n">test_loss_forest</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">((</span><span class="n">y_pred_forest</span> <span class="o">-</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Random forest test loss (noisy): </span><span class="si">{</span><span class="n">test_loss_forest</span><span class="p">:.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="c1"># Training and test loss are also recorded during optimization, let's plot it
</span><span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="n">DEFAULT_FIGSIZE</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">tt</span><span class="p">.</span><span class="n">history_</span><span class="p">[</span><span class="s">"train_loss"</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s">"Training loss"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">tt</span><span class="p">.</span><span class="n">history_</span><span class="p">[</span><span class="s">"val_loss"</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s">"Test loss"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">axhline</span><span class="p">(</span><span class="n">test_loss_forest</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s">"g"</span><span class="p">,</span> <span class="n">ls</span><span class="o">=</span><span class="s">"--"</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"Random forest test loss"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">"Number of iterations"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">"Loss"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">();</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>TTML test loss (noisy): 7.1980
Random forest test loss (noisy): 5.1036
</code></pre></div></div>
<p><img src="/imgs/discrete-function-tensor/tensor-completion_28_1.svg" alt="svg" /></p>
<p>Even a tiny bit of noise in the training data can severely degrade the model. We see that it starts to overfit
a lot. This is because my algorithm tries to automatically find a ‘good’ discretization of the data, not just
a uniform discretization as we have discussed in our 2D example (i.e. equally spacing all the grid cells).
Some of the variables in this dataset are however categorical, and a small amount of noise makes it much more
difficult to automatically detect a good way to discretize them.</p>
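<p>A small sketch of the underlying problem, using a made-up categorical feature. Counting distinct values is only a stand-in heuristic here and is not necessarily how <code class="language-plaintext highlighter-rouge">ttml</code> picks its thresholds.</p>

```python
import numpy as np

rng = np.random.default_rng(179)

# Hypothetical categorical feature that takes only 4 distinct values
feature = rng.choice([0.0, 0.25, 0.5, 1.0], size=1000)
noisy_feature = feature + rng.normal(scale=1e-6, size=feature.shape)

# Without noise the natural grid cells are obvious; with noise almost every
# sample looks like its own level, even though the values barely moved
n_levels_clean = len(np.unique(feature))
n_levels_noisy = len(np.unique(noisy_feature))
print(f"distinct values without noise: {n_levels_clean}")
print(f"distinct values with noise: {n_levels_noisy}")
```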
<p>The model has a lot of hyperparameters we won’t go into now, and playing with them does help with overfitting.
Furthermore, the noisy data we show here is perhaps not very realistic. However, the fact remains that the
model (at least the way it’s currently implemented) is not very robust to noise. In particular, the model is
very sensitive to the discretization of the feature space used.</p>
<p>Right now we don’t have anything better than simple heuristics for finding discretizations of the feature
space. Since the loss function depends in a really discontinuous way on the discretization, optimizing the
discretization is difficult. Perhaps we can use an algorithm to adaptively split and merge thresholds used in
the discretization, or use some kind of clustering algorithm for discretization. I have tried things along
those lines but getting it to work well is difficult. I think that with more study, the problem of finding a
good discretization can be solved, but it’s not easy.</p>
<h2 id="conclusion">Conclusion</h2>
<p>We looked at discretized functions and their use in supervised machine learning. In higher dimensions
discretized functions are parametrized by tensors, which we can represent efficiently using tensor trains. The
tensor train can be optimized directly on the data to produce a potentially useful machine learning model. It
is both very fast, and doesn’t use many parameters. In order to initialize it well, we can first fit an
auxiliary machine learning model on the same data, and then sample predictions from that model to effectively
increase the amount of training data. This model performs really well on some datasets, but in general it is
not very robust to noise. As a result, without further improvements, the model will only be useful in a select
number of cases. On the other hand, I really think that the model does have a lot of potential, once some of
its drawbacks are fixed.</p>Rik VoorhaarWe recently made a paper about supervised machine learning using tensors, here's the gist of how this works.GMRES: or how to do fast linear algebra2022-03-10T00:00:00+00:002022-03-10T00:00:00+00:00https://rikvoorhaar.com/gmres<p>Linear algebra is the foundation of modern science, and the fact that computers can do linear algebra <em>very
fast</em> is one of the primary reasons modern algorithms work so well in practice. In this blog post we will dive
into some of the principles of fast numerical linear algebra, and learn how to solve least-squares problems
using the GMRES algorithm. We apply this to the deconvolution problem, which we already discussed at length in
previous blog posts.</p>
<h2 id="linear-least-squares-problem">Linear least-squares problem</h2>
<p>The linear least-squares problem is one of the most common minimization problems we encounter. It takes the following form:</p>
\[\min_x \|Ax-b\|^2\]
<p>Here \(A\) is an \(n\times n\) matrix, and \(x,b\in\mathbb R^{n}\) are vectors. If \(A\) is invertible, then this
problem has a simple, unique solution: \(x = A^{-1}b\). However, there are two big reasons why we should <em>almost never</em>
use \(A^{-1}\) to solve the least-squares problem in practice:</p>
<ol>
<li>It is expensive to compute \(A^{-1}\).</li>
<li>This solution is numerically unstable.</li>
</ol>
<p>Assuming \(A\) doesn’t have any useful structure, point 1. is not that bad. Solving the least-squares problem in
a smart way costs \(O(n^3)\), and doing it using matrix-inversion also costs \(O(n^3)\), just with a larger hidden
constant. The real killer is the instability. To see this in action, let’s take a matrix that is <em>almost
singular</em>, and see what happens when we solve the least-squares problem.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">179</span><span class="p">)</span>
<span class="n">n</span> <span class="o">=</span> <span class="mi">20</span>
<span class="c1"># Create almost singular matrix
</span><span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">eye</span><span class="p">(</span><span class="n">n</span><span class="p">)</span>
<span class="n">A</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="mf">1e-20</span>
<span class="n">A</span> <span class="o">=</span> <span class="n">A</span> <span class="o">@</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="n">A</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span>
<span class="c1"># Random vector b
</span><span class="n">b</span> <span class="o">=</span> <span class="n">A</span> <span class="o">@</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">n</span><span class="p">,))</span> <span class="o">+</span> <span class="mf">1e-3</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="n">n</span><span class="p">)</span>
<span class="c1"># Solve least-squares with inverse
</span><span class="n">A_inv</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">inv</span><span class="p">(</span><span class="n">A</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">A_inv</span> <span class="o">@</span> <span class="n">b</span>
<span class="n">error</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="n">A</span> <span class="o">@</span> <span class="n">x</span> <span class="o">-</span> <span class="n">b</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"error for matrix inversion method: </span><span class="si">{</span><span class="n">error</span><span class="p">:.</span><span class="mi">4</span><span class="n">e</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="c1"># Solve least-squares with dedicated routine
</span><span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">lstsq</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">rcond</span><span class="o">=</span><span class="bp">None</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">error</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="n">A</span> <span class="o">@</span> <span class="n">x</span> <span class="o">-</span> <span class="n">b</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"error for dedicated method: </span><span class="si">{</span><span class="n">error</span><span class="p">:.</span><span class="mi">4</span><span class="n">e</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>error for matrix inversion method: 3.6223e+02
error for dedicated method: 2.8275e-08
</code></pre></div></div>
<p>In this case we took a 20x20 matrix \(A\) with ones on the diagonal, except for one entry where it has value
<code class="language-plaintext highlighter-rouge">1e-20</code>, and then we shuffled everything around by multiplying by a random matrix. The entries of \(A\) are
not so big, but the entries of \(A^{-1}\) will be <em>gigantic</em>. As a result, the solution
obtained as \(x=A^{-1}b\) does not satisfy \(Ax=b\) in practice. The solution found by using the <code class="language-plaintext highlighter-rouge">np.linalg.lstsq</code>
routine is much better.</p>
<p>The reason that the inverse-matrix method fails badly in this case can be summarized using the <em>condition
number</em> \(\kappa(A)\). It expresses how much the error \(\|Ax-b\|\) with \(x=A^{-1}b\) is going to change if we
change \(b\) slightly, in the worst case. The condition number gives a notion of how much numerical errors get
amplified when we solve the linear system. We can compute it as the ratio between the largest and smallest
singular values of the matrix \(A\):</p>
\[\kappa(A) = \sigma_1(A) / \sigma_n(A)\]
<p>In the case above the condition number is really big:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">cond</span><span class="p">(</span><span class="n">A</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>1.1807555508404976e+16
</code></pre></div></div>
<p>Large condition numbers mean that <em>any</em> numerical method is going to struggle to give a good solution, but for
numerically unstable methods the problem is a lot worse.</p>
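<p>As a small sanity check of the formula for \(\kappa(A)\), here is a sketch on a diagonal toy matrix where everything can be verified by hand. The helper <code class="language-plaintext highlighter-rouge">condition_number</code> is defined here for illustration; the right-hand side and the perturbation are chosen to hit the worst case.</p>

```python
import numpy as np


def condition_number(A):
    """Ratio of the largest to the smallest singular value of A."""
    sigma = np.linalg.svd(A, compute_uv=False)  # returned in descending order
    return sigma[0] / sigma[-1]


A = np.diag([4.0, 2.0, 0.5])
kappa = condition_number(A)

# b lies along the largest singular direction, and the perturbation along the
# smallest one, which is the worst case for error amplification
b = np.array([4.0, 0.0, 0.0])
delta_b = np.array([0.0, 0.0, 1e-3])
x = np.linalg.solve(A, b)
x_perturbed = np.linalg.solve(A, b + delta_b)

relative_change_b = np.linalg.norm(delta_b) / np.linalg.norm(b)
relative_change_x = np.linalg.norm(x_perturbed - x) / np.linalg.norm(x)
amplification = relative_change_x / relative_change_b
print(f"{kappa=:.1f}, {amplification=:.1f}")
```

<p>In this worst case the relative change in the solution is exactly \(\kappa(A)\) times the relative change in \(b\).</p>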
<h2 id="using-structure">Using structure</h2>
<p>While the numerical stability of algorithms is a fascinating topic, it is not what we came here for today.
Instead, let’s revisit the first reason why using matrix inversion for solving linear problems is bad. I
mentioned that matrix inversion and better alternatives take \(O(n^3)\) to solve the least squares problem
\(\min_x\|Ax-b\|^2\), <em>if there is no extra structure on</em> \(A\) <em>that we can exploit</em>.</p>
<p>What if there <em>is</em> such structure? For example, what if \(A\) is a huge sparse matrix? The Netflix
dataset we considered <a href="/low-rank-matrix/">in this blog post</a>, for instance, is of size 480,189 x 17,769. Putting aside the
fact that it is not square, inverting matrices of that size is infeasible. Moreover, the inverse
matrix isn’t necessarily sparse anymore, so we lose that valuable structure as well.</p>
<p>Another example arose in my <a href="/deconvolution-part1/">first post on deconvolution</a>. There we tried to solve the linear problem</p>
\[\min_x \|k * x -y\|^2\]
<p>where \(k * x\) denotes <em>convolution</em>. Convolution is a linear operation, but requires only \(O(n\log n)\) to
compute, whereas writing it out as a matrix would require \(n\times n\) entries, which can quickly become too
large.</p>
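<p>To make this concrete, here is a minimal one-dimensional sketch comparing the explicit matrix with the matrix-free operation. The kernel and signal are random, made-up data. Because it uses the "full" convolution mode, the matrix has shape \((n+m-1)\times n\) rather than \(n\times n\). Note that <code class="language-plaintext highlighter-rouge">np.convolve</code> computes the convolution directly; an FFT-based routine such as <code class="language-plaintext highlighter-rouge">scipy.signal.fftconvolve</code> is what gives the \(O(n\log n)\) cost for long kernels.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 200, 5
k = rng.normal(size=m)  # short kernel (made-up data)
x = rng.normal(size=n)  # signal (made-up data)

# Dense matrix K with K[i, j] = k[i - j], so that K @ x equals the convolution
K = np.zeros((n + m - 1, n))
for j in range(n):
    K[j : j + m, j] = k

dense_result = K @ x
matrix_free_result = np.convolve(k, x)  # "full" mode, never forms K

print(np.allclose(dense_result, matrix_free_result))
print(K.size, "stored entries for the matrix versus", k.size, "for the kernel")
```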
<p>In situations like this, we have no choice but to devise an algorithm that makes use of the structure of \(A\).
What the two situations above have in common is that storing \(A\) as a dense matrix is expensive, but computing
matrix-vector products \(Ax\) is cheap. The algorithm we are going to come up with is going to be <em>iterative</em>;
we start with some initial guess \(x_0\), and then improve it until we find a solution of the desired accuracy.</p>
<p>We don’t have much to work with; we have a vector \(x_0\) and the ability to compute matrix-vector products.
Crucially, we assumed our matrix \(A\) is <em>square</em>. This means that \(x_0\) and \(Ax_0\) have the same shape, and
therefore we can also compute \(A^2x_0\), or in fact \(A^rx_0\) for any \(r\). The idea is then to try to express
the solution to the least-squares problem as a linear combination of the vectors</p>
\[\mathcal K_r(A,x_0):=\{x_0, Ax_0,A^2x_0,\ldots,A^{r-1}x_0\}.\]
<p>This results in a class of algorithms known as <em>Krylov subspace methods</em>. Before diving further into how they work, let’s see one in action. We take a 2500 x 2500 sparse matrix with 5000 non-zero entries (which includes the entire diagonal).</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">scipy.sparse</span>
<span class="kn">import</span> <span class="nn">scipy.sparse.linalg</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">from</span> <span class="nn">time</span> <span class="kn">import</span> <span class="n">perf_counter_ns</span>
<span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">179</span><span class="p">)</span>
<span class="n">n</span> <span class="o">=</span> <span class="mi">2500</span>
<span class="n">N</span> <span class="o">=</span> <span class="n">n</span>
<span class="n">shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">n</span><span class="p">)</span>
<span class="c1"># Create random sparse (n, n) matrix with N non-zero entries
</span><span class="n">coords</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">choice</span><span class="p">(</span><span class="n">n</span> <span class="o">*</span> <span class="n">n</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">N</span><span class="p">,</span> <span class="n">replace</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="n">coords</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">unravel_index</span><span class="p">(</span><span class="n">coords</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span>
<span class="n">values</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="n">N</span><span class="p">)</span>
<span class="n">A_sparse</span> <span class="o">=</span> <span class="n">scipy</span><span class="p">.</span><span class="n">sparse</span><span class="p">.</span><span class="n">coo_matrix</span><span class="p">((</span><span class="n">values</span><span class="p">,</span> <span class="n">coords</span><span class="p">),</span> <span class="n">shape</span><span class="o">=</span><span class="n">shape</span><span class="p">)</span>
<span class="n">A_sparse</span> <span class="o">=</span> <span class="n">A_sparse</span><span class="p">.</span><span class="n">tocsr</span><span class="p">()</span>
<span class="n">A_sparse</span> <span class="o">+=</span> <span class="n">scipy</span><span class="p">.</span><span class="n">sparse</span><span class="p">.</span><span class="n">eye</span><span class="p">(</span><span class="n">n</span><span class="p">)</span>
<span class="n">A_dense</span> <span class="o">=</span> <span class="n">A_sparse</span><span class="p">.</span><span class="n">toarray</span><span class="p">()</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="n">n</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">A_sparse</span> <span class="o">@</span> <span class="n">b</span>
<span class="c1"># Solve using np.linalg.lstsq
</span><span class="n">time_before</span> <span class="o">=</span> <span class="n">perf_counter_ns</span><span class="p">()</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">lstsq</span><span class="p">(</span><span class="n">A_dense</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">rcond</span><span class="o">=</span><span class="bp">None</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">time_taken</span> <span class="o">=</span> <span class="p">(</span><span class="n">perf_counter_ns</span><span class="p">()</span> <span class="o">-</span> <span class="n">time_before</span><span class="p">)</span> <span class="o">*</span> <span class="mf">1e-6</span>
<span class="n">error</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="n">A_dense</span> <span class="o">@</span> <span class="n">x</span> <span class="o">-</span> <span class="n">b</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Using dense solver: error: </span><span class="si">{</span><span class="n">error</span><span class="p">:.</span><span class="mi">4</span><span class="n">e</span><span class="si">}</span><span class="s"> in time </span><span class="si">{</span><span class="n">time_taken</span><span class="p">:.</span><span class="mi">1</span><span class="n">f</span><span class="si">}</span><span class="s">ms"</span><span class="p">)</span>
<span class="c1"># Solve using inverse matrix
</span><span class="n">time_before</span> <span class="o">=</span> <span class="n">perf_counter_ns</span><span class="p">()</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">inv</span><span class="p">(</span><span class="n">A_dense</span><span class="p">)</span> <span class="o">@</span> <span class="n">b</span>
<span class="n">time_taken</span> <span class="o">=</span> <span class="p">(</span><span class="n">perf_counter_ns</span><span class="p">()</span> <span class="o">-</span> <span class="n">time_before</span><span class="p">)</span> <span class="o">*</span> <span class="mf">1e-6</span>
<span class="n">error</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="n">A_dense</span> <span class="o">@</span> <span class="n">x</span> <span class="o">-</span> <span class="n">b</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Using matrix inversion: error: </span><span class="si">{</span><span class="n">error</span><span class="p">:.</span><span class="mi">4</span><span class="n">e</span><span class="si">}</span><span class="s"> in time </span><span class="si">{</span><span class="n">time_taken</span><span class="p">:.</span><span class="mi">1</span><span class="n">f</span><span class="si">}</span><span class="s">ms"</span><span class="p">)</span>
<span class="c1"># Solve using GMRES
</span><span class="n">time_before</span> <span class="o">=</span> <span class="n">perf_counter_ns</span><span class="p">()</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">scipy</span><span class="p">.</span><span class="n">sparse</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">gmres</span><span class="p">(</span><span class="n">A_sparse</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">tol</span><span class="o">=</span><span class="mf">1e-8</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">time_taken</span> <span class="o">=</span> <span class="p">(</span><span class="n">perf_counter_ns</span><span class="p">()</span> <span class="o">-</span> <span class="n">time_before</span><span class="p">)</span> <span class="o">*</span> <span class="mf">1e-6</span>
<span class="n">error</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="n">A_sparse</span> <span class="o">@</span> <span class="n">x</span> <span class="o">-</span> <span class="n">b</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Using sparse solver: error: </span><span class="si">{</span><span class="n">error</span><span class="p">:.</span><span class="mi">4</span><span class="n">e</span><span class="si">}</span><span class="s"> in time </span><span class="si">{</span><span class="n">time_taken</span><span class="p">:.</span><span class="mi">1</span><span class="n">f</span><span class="si">}</span><span class="s">ms"</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Using dense solver: error: 1.4449e-25 in time 2941.5ms
Using matrix inversion: error: 2.4763e+03 in time 507.0ms
Using sparse solver: error: 2.5325e-13 in time 6.4ms
</code></pre></div></div>
<p>As we see above, the sparse matrix solver solves this problem in a fraction of the time, and the difference is only going to get bigger with larger matrices. Above we used the GMRES routine, and the idea behind it is simple. It constructs an orthonormal basis of the Krylov subspace \(\mathcal K_m(A,r_0)\), where \(r_0 = b-Ax_0\) is the residual of the initial guess, and then finds the best solution in this subspace by solving a small \((m+1)\times m\) least-squares problem. Before figuring out the details, below is a simple implementation:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">gmres</span><span class="p">(</span><span class="n">linear_map</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">x0</span><span class="p">,</span> <span class="n">n_iter</span><span class="p">):</span>
    <span class="c1"># Initialization
</span>    <span class="n">n</span> <span class="o">=</span> <span class="n">x0</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    <span class="n">H</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">n_iter</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">n_iter</span><span class="p">))</span>
    <span class="n">r0</span> <span class="o">=</span> <span class="n">b</span> <span class="o">-</span> <span class="n">linear_map</span><span class="p">(</span><span class="n">x0</span><span class="p">)</span>
    <span class="n">beta</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="n">r0</span><span class="p">)</span>
    <span class="n">V</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">n_iter</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">n</span><span class="p">))</span>
    <span class="n">V</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">r0</span> <span class="o">/</span> <span class="n">beta</span>
    <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_iter</span><span class="p">):</span>
        <span class="c1"># Compute next Krylov vector
</span>        <span class="n">w</span> <span class="o">=</span> <span class="n">linear_map</span><span class="p">(</span><span class="n">V</span><span class="p">[</span><span class="n">j</span><span class="p">])</span>
        <span class="c1"># Gram-Schmidt orthogonalization
</span>        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">j</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span>
            <span class="n">H</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">dot</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">V</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
            <span class="n">w</span> <span class="o">-=</span> <span class="n">H</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span> <span class="o">*</span> <span class="n">V</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
        <span class="n">H</span><span class="p">[</span><span class="n">j</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="n">w</span><span class="p">)</span>
        <span class="c1"># Add new vector to basis
</span>        <span class="n">V</span><span class="p">[</span><span class="n">j</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">w</span> <span class="o">/</span> <span class="n">H</span><span class="p">[</span><span class="n">j</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span>
    <span class="c1"># Find best approximation in the basis V
</span>    <span class="n">e1</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">n_iter</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">e1</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">beta</span>
    <span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">lstsq</span><span class="p">(</span><span class="n">H</span><span class="p">,</span> <span class="n">e1</span><span class="p">,</span> <span class="n">rcond</span><span class="o">=</span><span class="bp">None</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
    <span class="c1"># Convert result back to full basis and return
</span>    <span class="n">x_new</span> <span class="o">=</span> <span class="n">x0</span> <span class="o">+</span> <span class="n">V</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">].</span><span class="n">T</span> <span class="o">@</span> <span class="n">y</span>
    <span class="k">return</span> <span class="n">x_new</span>
<span class="c1"># Try out the GMRES routine
</span><span class="n">time_before</span> <span class="o">=</span> <span class="n">perf_counter_ns</span><span class="p">()</span>
<span class="n">x0</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">n</span><span class="p">)</span>
<span class="n">linear_map</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">A_sparse</span> <span class="o">@</span> <span class="n">x</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">gmres</span><span class="p">(</span><span class="n">linear_map</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">x0</span><span class="p">,</span> <span class="mi">50</span><span class="p">)</span>
<span class="n">time_taken</span> <span class="o">=</span> <span class="p">(</span><span class="n">perf_counter_ns</span><span class="p">()</span> <span class="o">-</span> <span class="n">time_before</span><span class="p">)</span> <span class="o">*</span> <span class="mf">1e-6</span>
<span class="n">error</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="n">A_sparse</span> <span class="o">@</span> <span class="n">x</span> <span class="o">-</span> <span class="n">b</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Using GMRES: error: </span><span class="si">{</span><span class="n">error</span><span class="p">:.</span><span class="mi">4</span><span class="n">e</span><span class="si">}</span><span class="s"> in time </span><span class="si">{</span><span class="n">time_taken</span><span class="p">:.</span><span class="mi">1</span><span class="n">f</span><span class="si">}</span><span class="s">ms"</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Using GMRES: error: 1.1039e-15 in time 12.9ms
</code></pre></div></div>
<p>This clearly works; it’s not as fast as the <code class="language-plaintext highlighter-rouge">scipy</code> implementation of the same algorithm, but we’ll do something about that soon.</p>
<p>Let’s take a more detailed look at what the GMRES algorithm is doing. We iteratively define an orthonormal basis \(V_m = \{v_0,v_1,\dots,v_{m-1}\}\). We start with \(v_0 = r_0 / \|r_0\|\), where \(r_0 = b-Ax_0\) is the <em>residual</em> of the initial guess \(x_0\). In each iteration we then set \(w = A v_j\), subtract its components along the previous basis vectors to get \(\tilde w = w - \sum_{i=0}^{j} (w^\top v_{i})v_i\), and normalize: \(v_{j+1} = \tilde w / \|\tilde w\|\). This ensures that \(v_{j+1}\) has unit norm and is orthogonal to all previous \(v_0,\dots,v_j\). Therefore \(V_m\) is an orthonormal basis of the Krylov subspace \(\mathcal K_m(A,r_0)\).</p>
<p>Once we have this basis, we want to solve the minimization problem:</p>
\[\min_{x\in \mathcal K_m(A,r_0)} \|A(x_0+x)-b\|\]
<p>Since \(V_m\) is a basis, we can write \(x = V_m y\) for some \(y\in \mathbb R^m\). Also note that in this basis \(b-Ax_0 = r_0 = \beta v_0 = \beta V_m e_1\) where \(\beta = \|r_0\|\) and \(e_1= (1,0,\dots,0)\). This allows us to rewrite the minimization problem:</p>
\[\min_{y\in\mathbb R^m} \|AV_my - \beta V_me_1\|\]
<p>To solve this minimization problem we need one more trick. In the algorithm we computed an \((m+1)\times m\) matrix \(H\), which is defined like this:</p>
\[H_{ij} = v_i^\top \Big(Av_j-\sum_{k=0}^{i-1} H_{kj}v_k\Big) = v_i^\top A v_j\]
<p>These are precisely the coefficients of the Gram-Schmidt orthogonalization, and hence \(A v_j = \sum_{i=0}^{j+1} H_{ij}v_i\). Writing \(V_{m+1}\) for the matrix whose columns are \(v_0,\dots,v_m\), this gives the matrix equality \(AV_m = V_{m+1}H\). Since \(\beta V_m e_1 = \beta v_0 = \beta V_{m+1}e_1\) (with \(e_1\) now a vector in \(\mathbb R^{m+1}\)), we can rewrite the minimization problem even further and get</p>
\[\min_{y\in\mathbb R^m} \|V_{m+1} (Hy - \beta e_1)\| = \min_{y\in\mathbb R^m} \|Hy - \beta e_1\|\]
<p>The last equality holds because the columns of \(V_{m+1}\) are orthonormal, so multiplying by \(V_{m+1}\) does not change the norm. The minimization problem is therefore reduced to an \((m+1)\times m\) problem! The cost of this is \(O(m^3)\), and as long as we don’t use too many steps \(m\), this cost can be very reasonable. After solving for \(y\), we then get the estimate \(x = x_0 + V_m y\).</p>
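<p>The relation between \(A\), the basis and \(H\) can be checked numerically. The sketch below repeats the orthogonalization loop of the implementation above in a stand-alone helper (the name <code class="language-plaintext highlighter-rouge">arnoldi</code> and the random test matrix are choices made for this illustration). The basis vectors are stored as rows of <code class="language-plaintext highlighter-rouge">V</code>, so with the vectors as columns the relation reads \(AV_m = V_{m+1}H\), where \(V_{m+1}\) contains all \(m+1\) basis vectors.</p>

```python
import numpy as np


def arnoldi(A, r0, m):
    """Run m Gram-Schmidt steps; the rows of V are the basis vectors."""
    n = r0.shape[0]
    H = np.zeros((m + 1, m))
    V = np.zeros((m + 1, n))
    V[0] = r0 / np.linalg.norm(r0)
    for j in range(m):
        w = A @ V[j]
        # Orthogonalize against all previous basis vectors
        for i in range(j + 1):
            H[i, j] = np.dot(w, V[i])
            w -= H[i, j] * V[i]
        H[j + 1, j] = np.linalg.norm(w)
        V[j + 1] = w / H[j + 1, j]
    return V, H


rng = np.random.default_rng(0)
n, m = 50, 10
A = rng.normal(size=(n, n))  # random test matrix (made-up data)
r0 = rng.normal(size=n)
V, H = arnoldi(A, r0, m)

# The basis is orthonormal
print(np.allclose(V @ V.T, np.eye(m + 1)))
# A V_m = V_{m+1} H, with the basis vectors as columns
print(np.allclose(A @ V[:-1].T, V.T @ H))
```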
<h2 id="restarting">Restarting</h2>
<p>In the current implementation of GMRES we specify the number of steps in advance, which is not ideal. If we converge to the right solution in fewer steps, then we are doing unnecessary work. If we don’t get a satisfying solution after the specified number of steps, we might need to start over. This is however not a big problem; we can use the output \(x=x_0+V_my\) as the new initialization when we restart.</p>
<p>This gives a nice recipe for <em>GMRES with restarting</em>. We run GMRES for \(m\) steps with \(x_i\) as initialization to get a new estimate \(x_{i+1}\). We then check if \(x_{i+1}\) is good enough; if not, we repeat the GMRES procedure for another \(m\) steps.</p>
<p>It is possible to get a good estimate of the residual norm after <em>each</em> step of GMRES, not just every \(m\) steps. However, this is relatively technical to implement, so we will just consider the variation of GMRES with restarting.</p>
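<p>The restarting recipe can be sketched as a thin wrapper around a single GMRES cycle. The block below repeats the <code class="language-plaintext highlighter-rouge">gmres</code> routine from above so that it runs on its own. The wrapper <code class="language-plaintext highlighter-rouge">gmres_restarted</code>, its relative-residual stopping rule and its default parameters are assumptions made for this illustration, and the test matrix is random, made-up data.</p>

```python
import numpy as np


def gmres(linear_map, b, x0, n_iter):
    """A single GMRES cycle of n_iter steps (same scheme as the routine above)."""
    n = x0.shape[0]
    H = np.zeros((n_iter + 1, n_iter))
    r0 = b - linear_map(x0)
    beta = np.linalg.norm(r0)
    V = np.zeros((n_iter + 1, n))
    V[0] = r0 / beta
    for j in range(n_iter):
        w = linear_map(V[j])
        for i in range(j + 1):
            H[i, j] = np.dot(w, V[i])
            w -= H[i, j] * V[i]
        H[j + 1, j] = np.linalg.norm(w)
        V[j + 1] = w / H[j + 1, j]
    e1 = np.zeros(n_iter + 1)
    e1[0] = beta
    y = np.linalg.lstsq(H, e1, rcond=None)[0]
    return x0 + V[:-1].T @ y


def gmres_restarted(linear_map, b, x0, n_restart, tol=1e-8, max_restarts=100):
    """Hypothetical wrapper: restart until ||Ax - b|| <= tol * ||b||."""
    x = x0
    b_norm = np.linalg.norm(b)
    for n_cycles in range(max_restarts):
        residual = np.linalg.norm(linear_map(x) - b)
        if residual <= tol * b_norm:
            return x, residual, n_cycles
        # Not good enough yet: run another cycle starting from the current x
        x = gmres(linear_map, b, x, n_restart)
    residual = np.linalg.norm(linear_map(x) - b)
    return x, residual, max_restarts


# Well-conditioned random test problem (made-up data)
rng = np.random.default_rng(0)
n = 100
A = np.eye(n) + 0.3 * rng.normal(size=(n, n)) / np.sqrt(n)
b = rng.normal(size=n)

x, residual, n_cycles = gmres_restarted(lambda v: A @ v, b, np.zeros(n), n_restart=10)
print(f"residual {residual:.2e} after {n_cycles} cycles of 10 steps")
```

<p>Checking the residual only between cycles costs one extra matrix-vector product per restart, which is small compared to the \(m\) products inside each cycle.</p>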
<p>How often should we restart? This really depends on the problem we’re trying to solve, since there is a
trade-off. More steps in between each restart will typically result in convergence in fewer steps, <em>but</em> it is
more expensive and also requires more memory. The computational cost scales as \(O(m^3)\), and the memory cost
scales linearly in \(m\) (if the matrix size \(n\) is much bigger than \(m\)). Let’s see this trade-off in action on a model problem.</p>
<h2 id="deconvolution">Deconvolution</h2>
<p>Recall that the deconvolution problem is of the following form:</p>
\[\min_x \|k * x -y\|^2\]
<p>for a fixed <em>kernel</em> \(k\) and signal \(y\). The convolution operation \(k*x\) is linear in \(x\), and we can
therefore treat this as a linear least-squares problem and solve it using GMRES. The operation \(k*x\) can be
written in matrix form as \(Kx\), where \(K\) is a matrix. For large images or signals, the matrix \(K\) can be
gigantic, and we never want to explicitly store \(K\) in memory. Fortunately, GMRES only cares about
matrix-vector products \(Kx\), making this a very good candidate to solve with GMRES.</p>
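<p>As an aside, the same matrix-free idea also works with the <code class="language-plaintext highlighter-rouge">scipy</code> solver, by wrapping the convolution in a <code class="language-plaintext highlighter-rouge">scipy.sparse.linalg.LinearOperator</code>. The sketch below is only an illustration under made-up assumptions: a tiny 16x16 random "image" and a mild 3x3 kernel chosen so that the problem is well-conditioned. It is not the experiment carried out in this post.</p>

```python
import numpy as np
from scipy.signal import convolve2d
from scipy.sparse.linalg import LinearOperator, gmres

rng = np.random.default_rng(0)
shape = (16, 16)  # small hypothetical "image"
n_pixels = shape[0] * shape[1]
# Mild, made-up blur kernel (symmetric and diagonally dominant)
kernel = np.array([[0.0, 0.1, 0.0], [0.1, 0.6, 0.1], [0.0, 0.1, 0.0]])


def matvec(x):
    # Apply K to a flattened image without ever forming the matrix K
    return convolve2d(x.reshape(shape), kernel, mode="same").reshape(-1)


K = LinearOperator((n_pixels, n_pixels), matvec=matvec, dtype=np.float64)

x_true = rng.normal(size=n_pixels)
y = matvec(x_true)

# restart = number of steps between restarts; the default tolerance is used
x, info = gmres(K, y, restart=20, maxiter=50)
rel_residual = np.linalg.norm(matvec(x) - y) / np.linalg.norm(y)
print(info, rel_residual)
```

<p>Here <code class="language-plaintext highlighter-rouge">info == 0</code> signals that the solver reached its tolerance.</p>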
<p>Let’s consider the problem of sharpening (deconvolving) a 128x128 picture blurred using Gaussian blur. To make
the problem more interesting, the kernel \(k\) used for deconvolution will be slightly different from the kernel
used for blurring. This is inspired by the blind deconvolution problem, where we not only have to find \(x\),
but also the kernel \(k\) itself.</p>
<p>We solve this problem with GMRES using different numbers of steps between restarts, and plot how the error
evolves over time.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">image</span>
<span class="kn">from</span> <span class="nn">utils</span> <span class="kn">import</span> <span class="n">random_motion_blur</span>
<span class="kn">from</span> <span class="nn">scipy.signal</span> <span class="kn">import</span> <span class="n">convolve2d</span>
<span class="c1"># Define the Gaussian blur kernel
</span><span class="k">def</span> <span class="nf">gaussian_psf</span><span class="p">(</span><span class="n">sigma</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">N</span><span class="o">=</span><span class="mi">9</span><span class="p">):</span>
    <span class="n">gauss_psf</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="o">-</span><span class="n">N</span> <span class="o">//</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">N</span> <span class="o">//</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">gauss_psf</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="p">(</span><span class="n">gauss_psf</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">sigma</span> <span class="o">**</span> <span class="mi">2</span><span class="p">))</span>
    <span class="n">gauss_psf</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">einsum</span><span class="p">(</span><span class="s">"i,j->ij"</span><span class="p">,</span> <span class="n">gauss_psf</span><span class="p">,</span> <span class="n">gauss_psf</span><span class="p">)</span>
    <span class="n">gauss_psf</span> <span class="o">=</span> <span class="n">gauss_psf</span> <span class="o">/</span> <span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">gauss_psf</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">gauss_psf</span>
<span class="c1"># Load the image and blur it
</span><span class="n">img</span> <span class="o">=</span> <span class="n">image</span><span class="p">.</span><span class="n">imread</span><span class="p">(</span><span class="s">"imgs/vitus128.png"</span><span class="p">)</span>
<span class="n">gauss_psf_true</span> <span class="o">=</span> <span class="n">gaussian_psf</span><span class="p">(</span><span class="n">sigma</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">N</span><span class="o">=</span><span class="mi">11</span><span class="p">)</span>
<span class="n">gauss_psf_almost</span> <span class="o">=</span> <span class="n">gaussian_psf</span><span class="p">(</span><span class="n">sigma</span><span class="o">=</span><span class="mf">1.05</span><span class="p">,</span> <span class="n">N</span><span class="o">=</span><span class="mi">11</span><span class="p">)</span>
<span class="n">img_blur</span> <span class="o">=</span> <span class="n">convolve2d</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">gauss_psf_true</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s">"same"</span><span class="p">)</span>
<span class="c1"># Define the convolution linear map
</span><span class="n">linear_map</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">convolve2d</span><span class="p">(</span>
<span class="n">x</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">img</span><span class="p">.</span><span class="n">shape</span><span class="p">),</span> <span class="n">gauss_psf_almost</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s">"same"</span>
<span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># Apply GMRES for different restart frequencies and measure time taken
</span><span class="n">total_its</span> <span class="o">=</span> <span class="mi">2000</span>
<span class="n">n_restart_list</span> <span class="o">=</span> <span class="p">[</span><span class="mi">20</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">200</span><span class="p">,</span> <span class="mi">500</span><span class="p">]</span>
<span class="n">losses_dict</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
<span class="k">for</span> <span class="n">n_restart</span> <span class="ow">in</span> <span class="n">n_restart_list</span><span class="p">:</span>
    <span class="n">time_before</span> <span class="o">=</span> <span class="n">perf_counter_ns</span><span class="p">()</span>
    <span class="n">b</span> <span class="o">=</span> <span class="n">img_blur</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">x0</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">x0</span>
    <span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">total_its</span> <span class="o">//</span> <span class="n">n_restart</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">gmres</span><span class="p">(</span><span class="n">linear_map</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">n_restart</span><span class="p">)</span>
        <span class="n">error</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="n">linear_map</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">-</span> <span class="n">b</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
        <span class="n">losses</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">error</span><span class="p">)</span>
    <span class="n">time_taken</span> <span class="o">=</span> <span class="p">(</span><span class="n">perf_counter_ns</span><span class="p">()</span> <span class="o">-</span> <span class="n">time_before</span><span class="p">)</span> <span class="o">/</span> <span class="mf">1e9</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Best loss for </span><span class="si">{</span><span class="n">n_restart</span><span class="si">}</span><span class="s"> restart frequency is </span><span class="si">{</span><span class="n">error</span><span class="p">:.</span><span class="mi">4</span><span class="n">e</span><span class="si">}</span><span class="s"> in </span><span class="si">{</span><span class="n">time_taken</span><span class="p">:.</span><span class="mi">2</span><span class="n">f</span><span class="si">}</span><span class="s">s"</span><span class="p">)</span>
    <span class="n">losses_dict</span><span class="p">[</span><span class="n">n_restart</span><span class="p">]</span> <span class="o">=</span> <span class="n">losses</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Best loss for 20 restart frequency is 9.3595e-16 in 11.32s
Best loss for 50 restart frequency is 2.4392e-22 in 11.71s
Best loss for 200 restart frequency is 6.3063e-28 in 17.34s
Best loss for 500 restart frequency is 6.9367e-28 in 30.50s
</code></pre></div></div>
<p><img src="/imgs/gmres/gmres_13_0.svg" alt="svg" /></p>
<p>We observe that with all restart frequencies we converge to a result with very low error. The larger the
number of steps between restarts, the fewer total steps we need to converge. Remember however that the cost of GMRES rises as
\(O(m^3)\) with the number of steps \(m\) between restarts, so a larger number of steps is not always better. For
example we see that \(m=20\) and \(m=50\) produced almost identical runtime, but for \(m=200\) the runtime for 2000
total steps is already significantly longer, and the effect is even stronger for \(m=500\). This means that if we
want to converge as fast as possible <em>in terms of runtime</em>, we’re best off with somewhere between \(m=50\)
and \(m=200\) steps between each restart.</p>
<h2 id="gpu-implementation">GPU implementation</h2>
<p>If we do simple profiling, we see that almost all of the time in this function is spent on the 2D convolution.
Indeed this is why the runtime does not seem to scale as \(O(m^3)\) for the values of \(m\) we tried above.
It simply takes a while before the \(O(m^3)\) factor becomes dominant over the time spent on matrix-vector
products.</p>
<p>This also means that it should be straightforward to speed up – we just need to do the convolution on a GPU.
It is not as simple as that however; if we just do the convolution on GPU and the rest of the operations on
CPU, then the bottleneck quickly becomes moving the data between CPU and GPU (unless we are working on a
system where CPU and GPU share memory).</p>
<p>Fortunately the entire GMRES algorithm is not so complex, and we can use hardware acceleration by simply
translating the algorithm to use a fast computational library. There are several such libraries available for
Python:</p>
<ul>
<li>TensorFlow</li>
<li>PyTorch</li>
<li>DASK</li>
<li>CuPy</li>
<li>JAX</li>
<li>Numba</li>
</ul>
<p>In this context CuPy might be the easiest to use; its syntax is very similar to numpy. However, I would
also like to make use of JIT (just-in-time) compilation, particularly since this can limit unnecessary data
movement. Furthermore, which low-level CUDA functions are best to call depends heavily on the situation
(especially for something like convolution), and JIT compilation can offer significant
optimizations here.</p>
<p>TensorFlow, DASK and PyTorch are really focussed on machine-learning and neural networks, and the way we
interact with these libraries might not be the best for this kind of algorithm. In fact, I tried to make an
efficient GMRES implementation using these libraries, and I was really struggling; I feel these libraries
simply aren’t the right tool for this job.</p>
<p>Numba is also great; I could basically feed it the code I already wrote and it would probably compile the
function and make it several times faster <em>on CPU</em>. Unfortunately, the support for GPU is still lacking quite
a bit in Numba, and we would therefore still leave quite a bit of performance on the table.</p>
<p>In the end we will implement it in JAX. Like CuPy, it has an API very similar to numpy, which means it’s easy to translate our code. However, it also supports JIT, meaning we can potentially make much faster functions. Without further
ado, let’s implement the GMRES algorithm in JAX and see what kind of speedup we can get.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="n">jnp</span>
<span class="kn">import</span> <span class="nn">jax</span>
<span class="c1"># Define the linear operator
</span><span class="n">img_shape</span> <span class="o">=</span> <span class="n">img</span><span class="p">.</span><span class="n">shape</span>
<span class="k">def</span> <span class="nf">do_convolution</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">jax</span><span class="p">.</span><span class="n">scipy</span><span class="p">.</span><span class="n">signal</span><span class="p">.</span><span class="n">convolve2d</span><span class="p">(</span>
        <span class="n">x</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">img_shape</span><span class="p">),</span> <span class="n">gauss_psf_almost</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s">"same"</span>
    <span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">gmres_jax</span><span class="p">(</span><span class="n">linear_map</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">x0</span><span class="p">,</span> <span class="n">n_iter</span><span class="p">):</span>
    <span class="c1"># Initialization
</span>    <span class="n">n</span> <span class="o">=</span> <span class="n">x0</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    <span class="n">r0</span> <span class="o">=</span> <span class="n">b</span> <span class="o">-</span> <span class="n">linear_map</span><span class="p">(</span><span class="n">x0</span><span class="p">)</span>
    <span class="n">beta</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="n">r0</span><span class="p">)</span>
    <span class="n">V</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">n_iter</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">n</span><span class="p">))</span>
    <span class="n">V</span> <span class="o">=</span> <span class="n">V</span><span class="p">.</span><span class="n">at</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="nb">set</span><span class="p">(</span><span class="n">r0</span> <span class="o">/</span> <span class="n">beta</span><span class="p">)</span>
    <span class="n">H</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">n_iter</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">n_iter</span><span class="p">))</span>
    <span class="k">def</span> <span class="nf">loop_body</span><span class="p">(</span><span class="n">j</span><span class="p">,</span> <span class="n">pair</span><span class="p">):</span>
        <span class="s">"""
        One basic step of GMRES; compute new Krylov vector and orthogonalize.
        """</span>
        <span class="n">H</span><span class="p">,</span> <span class="n">V</span> <span class="o">=</span> <span class="n">pair</span>
        <span class="n">w</span> <span class="o">=</span> <span class="n">linear_map</span><span class="p">(</span><span class="n">V</span><span class="p">[</span><span class="n">j</span><span class="p">])</span>
        <span class="n">h</span> <span class="o">=</span> <span class="n">V</span> <span class="o">@</span> <span class="n">w</span>
        <span class="n">v</span> <span class="o">=</span> <span class="n">w</span> <span class="o">-</span> <span class="p">(</span><span class="n">V</span><span class="p">.</span><span class="n">T</span><span class="p">)</span> <span class="o">@</span> <span class="n">h</span>
        <span class="n">v_norm</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="n">v</span><span class="p">)</span>
        <span class="n">H</span> <span class="o">=</span> <span class="n">H</span><span class="p">.</span><span class="n">at</span><span class="p">[:,</span> <span class="n">j</span><span class="p">].</span><span class="nb">set</span><span class="p">(</span><span class="n">h</span><span class="p">)</span>
        <span class="n">H</span> <span class="o">=</span> <span class="n">H</span><span class="p">.</span><span class="n">at</span><span class="p">[</span><span class="n">j</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">j</span><span class="p">].</span><span class="nb">set</span><span class="p">(</span><span class="n">v_norm</span><span class="p">)</span>
        <span class="n">V</span> <span class="o">=</span> <span class="n">V</span><span class="p">.</span><span class="n">at</span><span class="p">[</span><span class="n">j</span> <span class="o">+</span> <span class="mi">1</span><span class="p">].</span><span class="nb">set</span><span class="p">(</span><span class="n">v</span> <span class="o">/</span> <span class="n">v_norm</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">H</span><span class="p">,</span> <span class="n">V</span>
    <span class="c1"># Do n_iter iterations of basic GMRES step
</span>    <span class="n">H</span><span class="p">,</span> <span class="n">V</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">lax</span><span class="p">.</span><span class="n">fori_loop</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">n_iter</span><span class="p">,</span> <span class="n">loop_body</span><span class="p">,</span> <span class="p">(</span><span class="n">H</span><span class="p">,</span> <span class="n">V</span><span class="p">))</span>
    <span class="c1"># Solve the linear system in the basis V
</span>    <span class="n">e1</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">n_iter</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">e1</span> <span class="o">=</span> <span class="n">e1</span><span class="p">.</span><span class="n">at</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="nb">set</span><span class="p">(</span><span class="n">beta</span><span class="p">)</span>
    <span class="n">y</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">lstsq</span><span class="p">(</span><span class="n">H</span><span class="p">,</span> <span class="n">e1</span><span class="p">,</span> <span class="n">rcond</span><span class="o">=</span><span class="bp">None</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
    <span class="c1"># Convert result back to full basis and return
</span>    <span class="n">x_new</span> <span class="o">=</span> <span class="n">x0</span> <span class="o">+</span> <span class="n">V</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">].</span><span class="n">T</span> <span class="o">@</span> <span class="n">y</span>
    <span class="k">return</span> <span class="n">x_new</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">img_blur</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">x0</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">x0</span>
<span class="n">n_restart</span> <span class="o">=</span> <span class="mi">50</span>
<span class="c1"># Declare JIT compiled version of gmres_jax
</span><span class="n">gmres_jit</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">jit</span><span class="p">(</span><span class="n">gmres_jax</span><span class="p">,</span> <span class="n">static_argnums</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span>
<span class="k">print</span><span class="p">(</span><span class="s">"Compiling function:"</span><span class="p">)</span>
<span class="o">%</span><span class="n">time</span> <span class="n">x</span> <span class="o">=</span> <span class="n">gmres_jit</span><span class="p">(</span><span class="n">do_convolution</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">x0</span><span class="p">,</span> <span class="n">n_restart</span><span class="p">).</span><span class="n">block_until_ready</span><span class="p">()</span>
<span class="k">print</span><span class="p">(</span><span class="s">"</span><span class="se">\n</span><span class="s">Profiling functions. numpy version:"</span><span class="p">)</span>
<span class="o">%</span><span class="n">timeit</span> <span class="n">x</span> <span class="o">=</span> <span class="n">gmres</span><span class="p">(</span><span class="n">linear_map</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">x0</span><span class="p">,</span> <span class="n">n_restart</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"</span><span class="se">\n</span><span class="s">Profiling functions. JAX version:"</span><span class="p">)</span>
<span class="o">%</span><span class="n">timeit</span> <span class="n">x</span> <span class="o">=</span> <span class="n">gmres_jit</span><span class="p">(</span><span class="n">do_convolution</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">x0</span><span class="p">,</span> <span class="n">n_restart</span><span class="p">).</span><span class="n">block_until_ready</span><span class="p">()</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Compiling function:
CPU times: user 1.94 s, sys: 578 ms, total: 2.51 s
Wall time: 2.01 s
Profiling functions. numpy version:
263 ms ± 25.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Profiling functions. JAX version:
9.16 ms ± 90.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
</code></pre></div></div>
<p>With the JAX version running on my GPU, we get a roughly 30x speedup! Not bad, if you ask me. If we run the same
code on CPU, we still get a 4x speedup. This means that the version compiled by JAX is already faster in its
own right.</p>
<p>The code above may look a bit strange, and there are definitely some things that might need some explanation.
First of all, note that the first time we call <code class="language-plaintext highlighter-rouge">gmres_jit</code> it takes much longer than the subsequent calls.
This is because the function is just-in-time (JIT) compiled. On the first call, JAX runs through the entire
function and makes a big graph of all the operations that need to be done. It then optimizes (simplifies) this
graph, and compiles it to create a very fast function. This compilation step obviously takes some time, but
the great thing is that we only need to do it once.</p>
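<p>A minimal sketch of this behaviour, separate from the GMRES code: the function <code class="language-plaintext highlighter-rouge">sum_of_squares</code> and the list <code class="language-plaintext highlighter-rouge">trace_log</code> are hypothetical names introduced only for illustration. A Python side effect inside a jitted function runs only while JAX traces it, so we can use it to count how often tracing (and hence compilation) happens.</p>

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: records each time Python actually executes the body.
trace_log = []

def sum_of_squares(x):
    # This append is a Python side effect, so it only runs during tracing.
    trace_log.append(x.shape)
    return jnp.sum(x ** 2)

sum_of_squares_jit = jax.jit(sum_of_squares)

x = jnp.arange(4.0)
first = sum_of_squares_jit(x)   # traced and compiled on this call
second = sum_of_squares_jit(x)  # compiled version is reused, no tracing
n_traces_same_shape = len(trace_log)

third = sum_of_squares_jit(jnp.arange(5.0))  # new input shape: traced again
n_traces_new_shape = len(trace_log)
```

<p>The second call with the same input shape does not add an entry to <code class="language-plaintext highlighter-rouge">trace_log</code>, while a call with a different shape does. This is the same reason why only the first call to <code class="language-plaintext highlighter-rouge">gmres_jit</code> is slow.</p>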
<p>Note the way we create the function <code class="language-plaintext highlighter-rouge">gmres_jit</code>:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code> <span class="n">gmres_jit</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">jit</span><span class="p">(</span><span class="n">gmres_jax</span><span class="p">,</span> <span class="n">static_argnums</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span>
</code></pre></div></div>
<p>Here we tell JAX to treat the first and the fourth argument as static: if either of them changes, the function
needs to be recompiled. This is because neither of these arguments is an array (the first is a Python function,
the fourth is a Python integer giving the number of iterations), whereas the other two arguments are arrays.</p>
<p>The shapes of the arrays <code class="language-plaintext highlighter-rouge">V</code> and <code class="language-plaintext highlighter-rouge">H</code> depend on the last argument <code class="language-plaintext highlighter-rouge">n_iter</code>. However, the compiler needs to know the shapes of these arrays <em>at compile time</em>. Therefore, we need to recompile the function every time that <code class="language-plaintext highlighter-rouge">n_iter</code> changes. The same is true for the <code class="language-plaintext highlighter-rouge">linear_map</code> argument; the
shape of the vector <code class="language-plaintext highlighter-rouge">w</code> depends on <code class="language-plaintext highlighter-rouge">linear_map</code> in principle.</p>
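<p>To make this concrete, here is a small sketch with a hypothetical function <code class="language-plaintext highlighter-rouge">init_basis</code> (not part of the GMRES code above) that only allocates the array <code class="language-plaintext highlighter-rouge">V</code>. The list <code class="language-plaintext highlighter-rouge">compile_log</code> is likewise just an illustrative device to see when JAX retraces the function.</p>

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: records the value of n_iter each time JAX retraces.
compile_log = []

def init_basis(r0, n_iter):
    compile_log.append(n_iter)  # only executed during tracing
    # The shape of V depends on n_iter, so n_iter must be known at compile time.
    V = jnp.zeros((n_iter + 1, r0.shape[0]))
    return V.at[0].set(r0 / jnp.linalg.norm(r0))

# Mark the second argument (index 1) as static.
init_basis_jit = jax.jit(init_basis, static_argnums=[1])

r0 = jnp.array([3.0, 4.0])
V_a = init_basis_jit(r0, 2)  # traced and compiled for n_iter = 2
V_b = init_basis_jit(r0, 2)  # same static value: cached version is reused
V_c = init_basis_jit(r0, 5)  # new static value: traced and compiled again
```

<p>After these three calls <code class="language-plaintext highlighter-rouge">compile_log</code> contains one entry per distinct value of <code class="language-plaintext highlighter-rouge">n_iter</code>. Without <code class="language-plaintext highlighter-rouge">static_argnums</code>, <code class="language-plaintext highlighter-rouge">n_iter</code> would be traced as an abstract value and <code class="language-plaintext highlighter-rouge">jnp.zeros</code> would have no concrete shape to work with.</p>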
<p>Next, consider the fact that there is no more <code class="language-plaintext highlighter-rouge">for</code> loop in the code, and it is instead replaced by</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code> <span class="n">H</span><span class="p">,</span> <span class="n">V</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">lax</span><span class="p">.</span><span class="n">fori_loop</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">n_iter</span><span class="p">,</span> <span class="n">loop_body</span><span class="p">,</span> <span class="p">(</span><span class="n">H</span><span class="p">,</span> <span class="n">V</span><span class="p">))</span>
</code></pre></div></div>
<p>We could in fact use a for loop here as well, and it would give an identical result but it would take much
longer to compile. The reason for this is that, as mentioned, JAX runs through the entire function and makes a
graph of all the operations that need to be done. If we leave in the for loop, then each iteration of the loop
would add more and more operations to the graph (the loop is ‘unrolled’), making a really big graph. By using
<code class="language-plaintext highlighter-rouge">jax.lax.fori_loop</code> we can skip this, and end up with a much smaller graph to be compiled.</p>
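<p>We can get a rough feeling for the difference by counting the operations JAX records for a toy loop. This is only a sketch: the functions <code class="language-plaintext highlighter-rouge">step</code>, <code class="language-plaintext highlighter-rouge">unrolled</code> and <code class="language-plaintext highlighter-rouge">rolled</code> are made up for illustration, and the number of top-level equations in the traced program is used as a crude proxy for graph size.</p>

```python
import jax
import jax.numpy as jnp

N_STEPS = 50

def step(j, x):
    # Toy loop body standing in for one GMRES step.
    return 0.5 * x + 1.0

def unrolled(x):
    # Plain Python loop: every iteration is traced separately.
    for j in range(N_STEPS):
        x = step(j, x)
    return x

def rolled(x):
    # The loop body is traced once and kept as a single loop primitive.
    return jax.lax.fori_loop(0, N_STEPS, step, x)

x0 = jnp.ones(3)

# Count the top-level operations recorded when tracing each version.
n_eqns_unrolled = len(jax.make_jaxpr(unrolled)(x0).jaxpr.eqns)
n_eqns_rolled = len(jax.make_jaxpr(rolled)(x0).jaxpr.eqns)

# Both versions compute the same values.
same_result = bool(jnp.allclose(unrolled(x0), rolled(x0)))
```

<p>The unrolled version records operations for every one of the <code class="language-plaintext highlighter-rouge">N_STEPS</code> iterations, whereas the <code class="language-plaintext highlighter-rouge">fori_loop</code> version records far fewer, and both return the same result.</p>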
<p>One disadvantage of this approach is that the size of all arrays needs to be known at compile time. In the
original algorithm we did not compute <code class="language-plaintext highlighter-rouge">(V.T) @ h</code> for example, but rather <code class="language-plaintext highlighter-rouge">(V[:j+1].T) @ h</code>. Now we can’t do
that, because the size of <code class="language-plaintext highlighter-rouge">V[:j+1]</code> is not known at compile time. The end result ends up being the same
because at iteration <code class="language-plaintext highlighter-rouge">j</code>, we have <code class="language-plaintext highlighter-rouge">V[j+1:] = 0</code>. This actually means that over all the iterations of <code class="language-plaintext highlighter-rouge">j</code> we
end up doing about double the work for this particular operation. However, because the operation is so much
faster on a GPU this is not a big problem.</p>
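<p>The claim that the zero rows of <code class="language-plaintext highlighter-rouge">V</code> do not change the result is easy to check numerically. The sketch below uses small, arbitrary sizes and random data (all of them assumptions for illustration) and compares the sliced orthogonalization step with the fixed-shape one.</p>

```python
import jax
import jax.numpy as jnp

# Arbitrary small sizes, chosen only for illustration.
n, n_iter, j = 6, 4, 1
key_q, key_w = jax.random.split(jax.random.PRNGKey(0))

# First j + 1 rows of V are orthonormal vectors, the remaining rows are zero,
# mimicking the state of V at iteration j.
Q, _ = jnp.linalg.qr(jax.random.normal(key_q, (n, j + 1)))
V = jnp.zeros((n_iter + 1, n)).at[: j + 1].set(Q.T)
w = jax.random.normal(key_w, (n,))

# Sliced version: shapes depend on j.
h_sliced = V[: j + 1] @ w
v_sliced = w - V[: j + 1].T @ h_sliced

# Fixed-shape version: shapes are independent of j.
h_full = V @ w
v_full = w - V.T @ h_full
```

<p>Here <code class="language-plaintext highlighter-rouge">h_full</code> is just <code class="language-plaintext highlighter-rouge">h_sliced</code> padded with zeros, so <code class="language-plaintext highlighter-rouge">v_full</code> and <code class="language-plaintext highlighter-rouge">v_sliced</code> agree up to floating point error.</p>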
<p>As we can see, writing code for GPUs requires a bit more thought than writing code for CPUs. Sometimes we even
end up with less efficient code, but this can be entirely offset by the improved speed of the GPU.</p>
<h2 id="condition-numbers-and-eigenvalues">Condition numbers and eigenvalues</h2>
<p>We see above that GMRES provides a very fast and accurate solution to the deconvolution problem. This has a lot to do with the fact that the convolution matrix is very well-conditioned. We can see this by looking at the singular values of this matrix. The convolution matrix for a 128x128 image is a bit too big to work with, but we can see what happens for 32x32 images.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">N</span> <span class="o">=</span> <span class="mi">11</span>
<span class="n">psf</span> <span class="o">=</span> <span class="n">gaussian_psf</span><span class="p">(</span><span class="n">sigma</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">N</span><span class="o">=</span><span class="n">N</span><span class="p">)</span>
<span class="n">img_shape</span> <span class="o">=</span> <span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">create_conv_mat</span><span class="p">(</span><span class="n">psf</span><span class="p">,</span> <span class="n">img_shape</span><span class="p">):</span>
    <span class="n">tot_dim</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">prod</span><span class="p">(</span><span class="n">img_shape</span><span class="p">)</span>
    <span class="k">def</span> <span class="nf">apply_psf</span><span class="p">(</span><span class="n">signal</span><span class="p">):</span>
        <span class="n">signal</span> <span class="o">=</span> <span class="n">signal</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">img_shape</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">convolve2d</span><span class="p">(</span><span class="n">signal</span><span class="p">,</span> <span class="n">psf</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s">"same"</span><span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">conv_mat</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">tot_dim</span><span class="p">,</span> <span class="n">tot_dim</span><span class="p">))</span>
    <span class="c1"># Column i of the matrix is the image of the i-th unit vector
</span>    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tot_dim</span><span class="p">):</span>
        <span class="n">signal</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">tot_dim</span><span class="p">)</span>
        <span class="n">signal</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
        <span class="n">conv_mat</span><span class="p">[:,</span> <span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">apply_psf</span><span class="p">(</span><span class="n">signal</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">conv_mat</span>
<span class="n">conv_mat</span> <span class="o">=</span> <span class="n">create_conv_mat</span><span class="p">(</span><span class="n">psf</span><span class="p">,</span> <span class="n">img_shape</span><span class="p">)</span>
<span class="n">svdvals</span> <span class="o">=</span> <span class="n">scipy</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">svdvals</span><span class="p">(</span><span class="n">conv_mat</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">svdvals</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">yscale</span><span class="p">(</span><span class="s">'log'</span><span class="p">)</span>
<span class="n">cond_num</span> <span class="o">=</span> <span class="n">svdvals</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">/</span><span class="n">svdvals</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="sa">f</span><span class="s">"Singular values. Condition number: </span><span class="si">{</span><span class="n">cond_num</span><span class="p">:.</span><span class="mi">0</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="/imgs/gmres/gmres_19_1.svg" alt="svg" /></p>
<p>As we can see, the condition number is only 4409, which makes the matrix very well-conditioned. Moreover, the
singular values decay somewhat gradually. What’s more, the convolution matrix is actually symmetric and
positive definite. This makes the linear system relatively easy to solve, and explains why it works so well.</p>
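<p>The same properties can be checked on a much smaller one-dimensional analogue. This is only a sketch: <code class="language-plaintext highlighter-rouge">gaussian_kernel_1d</code> and <code class="language-plaintext highlighter-rouge">conv_matrix_1d</code> are hypothetical helpers written for this illustration (they are not the functions used above), and the 1D matrix is not the same as the 2D convolution matrix whose singular values are plotted.</p>

```python
import numpy as np

def gaussian_kernel_1d(sigma=1.0, size=11):
    # Hypothetical 1D analogue of a Gaussian point spread function.
    t = np.arange(size) - size // 2
    kernel = np.exp(-(t ** 2) / (2 * sigma ** 2))
    return kernel / kernel.sum()

def conv_matrix_1d(kernel, n):
    # Column i of the matrix is the convolution of the i-th unit vector.
    mat = np.zeros((n, n))
    for i in range(n):
        e_i = np.zeros(n)
        e_i[i] = 1.0
        mat[:, i] = np.convolve(e_i, kernel, mode="same")
    return mat

A = conv_matrix_1d(gaussian_kernel_1d(), 32)

# Condition number as the ratio of largest to smallest singular value.
svals = np.linalg.svd(A, compute_uv=False)
cond_from_svals = svals[0] / svals[-1]

# Symmetry and positive definiteness of the matrix.
is_symmetric = np.allclose(A, A.T)
min_eigenvalue = np.linalg.eigvalsh(A).min()
```

<p>For this small example the matrix is symmetric, its smallest eigenvalue is positive, and the ratio of extreme singular values agrees with <code class="language-plaintext highlighter-rouge">np.linalg.cond(A)</code>.</p>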
<p>The convolution matrix is symmetric because the kernel we use, the Gaussian kernel, is itself symmetric. For a non-symmetric kernel,
the situation is more complicated. Below we show what happens for a non-symmetric kernel, the same type as we
used before in the blind deconvolution series of blog posts.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">utils</span> <span class="kn">import</span> <span class="n">random_motion_blur</span>
<span class="n">N</span> <span class="o">=</span> <span class="mi">11</span>
<span class="n">psf_gaussian</span> <span class="o">=</span> <span class="n">gaussian_psf</span><span class="p">(</span><span class="n">sigma</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">N</span><span class="o">=</span><span class="n">N</span><span class="p">)</span>
<span class="n">psf</span> <span class="o">=</span> <span class="n">random_motion_blur</span><span class="p">(</span>
    <span class="n">N</span><span class="o">=</span><span class="n">N</span><span class="p">,</span> <span class="n">num_steps</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">beta</span><span class="o">=</span><span class="mf">0.98</span><span class="p">,</span> <span class="n">vel_scale</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">sigma</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">42</span>
<span class="p">)</span>
<span class="n">img_shape</span> <span class="o">=</span> <span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">)</span>
<span class="c1"># plot the kernels
</span><span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mf">4.5</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">psf_gaussian</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="s">"Gaussian kernel"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">psf</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="s">"Non-symmetric kernel"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
<span class="c1"># study convolution matrix
</span><span class="n">conv_mat</span> <span class="o">=</span> <span class="n">create_conv_mat</span><span class="p">(</span><span class="n">psf</span><span class="p">,</span> <span class="n">img_shape</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
<span class="n">eigs</span> <span class="o">=</span> <span class="n">scipy</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">eigvals</span><span class="p">(</span><span class="n">conv_mat</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="sa">f</span><span class="s">"Eigenvalues"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">"Imaginary part"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">"Real part"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">real</span><span class="p">(</span><span class="n">eigs</span><span class="p">),</span> <span class="n">np</span><span class="p">.</span><span class="n">imag</span><span class="p">(</span><span class="n">eigs</span><span class="p">),</span> <span class="n">marker</span><span class="o">=</span><span class="s">"."</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="/imgs/gmres/gmres_21_0.svg" alt="svg" /></p>
<p><img src="/imgs/gmres/gmres_21_2.svg" alt="svg" /></p>
<p>We see that the eigenvalues of this convolution matrix are distributed <em>around</em> zero. In contrast, the convolution matrix for the Gaussian kernel is symmetric and positive definite, so all its eigenvalues are positive real numbers. GMRES works really well when almost all eigenvalues lie in an ellipse <em>not containing zero</em>. That is clearly not the case here, and indeed we see that GMRES doesn’t work well for this particular problem.
(Note that we now switch to 256x256 images instead of 128x128, since our new implementation of GMRES is much faster.)</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">img</span> <span class="o">=</span> <span class="n">image</span><span class="p">.</span><span class="n">imread</span><span class="p">(</span><span class="s">"imgs/vitus256.png"</span><span class="p">)</span>
<span class="n">psf</span> <span class="o">=</span> <span class="n">random_motion_blur</span><span class="p">(</span>
<span class="n">N</span><span class="o">=</span><span class="n">N</span><span class="p">,</span> <span class="n">num_steps</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">beta</span><span class="o">=</span><span class="mf">0.98</span><span class="p">,</span> <span class="n">vel_scale</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">sigma</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">42</span>
<span class="p">)</span>
<span class="n">img_blur</span> <span class="o">=</span> <span class="n">convolve2d</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">psf</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s">"same"</span><span class="p">)</span>
<span class="n">img_shape</span> <span class="o">=</span> <span class="n">img</span><span class="p">.</span><span class="n">shape</span>
<span class="k">def</span> <span class="nf">do_convolution</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="n">res</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">scipy</span><span class="p">.</span><span class="n">signal</span><span class="p">.</span><span class="n">convolve2d</span><span class="p">(</span>
<span class="n">x</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">img_shape</span><span class="p">),</span> <span class="n">psf</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s">"same"</span>
<span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">res</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">img_blur</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">x0</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">x0</span>
<span class="n">n_restart</span> <span class="o">=</span> <span class="mi">1000</span>
<span class="n">n_its</span> <span class="o">=</span> <span class="mi">10</span>
<span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_its</span><span class="p">):</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">gmres_jit</span><span class="p">(</span><span class="n">do_convolution</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">n_restart</span><span class="p">)</span>
<span class="n">error</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="n">do_convolution</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">-</span> <span class="n">b</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
<span class="n">losses</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">error</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="/imgs/gmres/gmres_24_1.svg" alt="svg" /></p>
<p>Not only does it take many more iterations to converge, the final result is unsatisfactory at best. Clearly, without
further modifications the GMRES method doesn’t work well for deconvolution with non-symmetric kernels.</p>
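<p>The link between kernel symmetry and the spectrum can be checked on a small toy problem. The sketch below is not the 2D setup used in this post: it assumes a <em>periodic</em> 1D convolution, and both kernels as well as the helper <code class="language-plaintext highlighter-rouge">circulant_conv_matrix</code> are made up for illustration.</p>

```python
import numpy as np

def circulant_conv_matrix(offsets, weights, n):
    """Matrix of periodic 1D convolution; the kernel is given as (offset, weight) pairs.

    Hypothetical helper, only meant for this illustration.
    """
    first_col = np.zeros(n)
    first_col[np.asarray(offsets) % n] = weights
    # Column j is the first column cyclically shifted down by j.
    return np.stack([np.roll(first_col, j) for j in range(n)], axis=1)

n = 64

# Symmetric kernel: a truncated, normalized Gaussian centred at offset 0.
offsets = np.arange(-3, 4)
gauss = np.exp(-0.5 * offsets**2)
gauss = gauss / gauss.sum()
A_sym = circulant_conv_matrix(offsets, gauss, n)

# Non-symmetric kernel: a one-sided streak, loosely imitating motion blur.
A_skew = circulant_conv_matrix([0, 1, 2, 3], [0.4, 0.3, 0.2, 0.1], n)

eig_sym = np.linalg.eigvals(A_sym)
eig_skew = np.linalg.eigvals(A_skew)
print("symmetric kernel, largest |imag|:", np.abs(eig_sym.imag).max())
print("symmetric kernel, smallest real part:", eig_sym.real.min())
print("one-sided kernel, largest |imag|:", np.abs(eig_skew.imag).max())
```

<p>For the symmetric kernel the imaginary parts vanish up to rounding and the real parts stay positive, while the one-sided kernel produces genuinely complex eigenvalues.</p>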
<h2 id="changing-the-spectrum">Changing the spectrum</h2>
<p>As mentioned, GMRES works best when the eigenvalues of the matrix \(A\) are in an ellipse not including zero, which is not the case for our convolution matrix. There is fortunately a very simple solution to this: instead of solving the linear least-squares problem</p>
\[\min_x \|Ax - b\|_2^2\]
<p>we solve the linear least-squares problem</p>
\[\min_x \|A^\top A x - A^\top b\|_2^2\]
<p>If \(A\) is invertible this will have the same solution, but the eigenvalues of \(A^\top A\) are better behaved. Any matrix of this form is symmetric and positive semi-definite, so all its eigenvalues are real and non-negative, and they are strictly positive when \(A\) is invertible. In that case they all fit inside an ellipse that doesn’t include zero, and we will get much better convergence with GMRES. In general, we could multiply by any matrix \(B\) to obtain the linear least-squares problem</p>
\[\min_x \|BAx-Bb\|_2^2\]
<p>If we choose \(B\) such that the spectrum (the set of eigenvalues) of \(BA\) is nicer, then we can improve the convergence of GMRES. This trick is called <em>preconditioning</em>. Choosing a good <em>preconditioner</em> depends a lot on the problem at hand, and is the subject of a lot of research. In this context, \(A^\top\) turns out to function as an excellent preconditioner, as we shall see.</p>
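<p>As a quick sanity check of the effect of multiplying by \(A^\top\), here is a minimal sketch on a small random matrix. The matrix, its size and the seed are arbitrary choices for illustration and are unrelated to the convolution problem.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20
A = rng.normal(size=(n, n))      # generic non-symmetric matrix (assumed invertible)
x_true = rng.normal(size=n)
b = A @ x_true

# Spectrum of A: complex eigenvalues scattered around zero.
eig_A = np.linalg.eigvals(A)
# Spectrum of A^T A: symmetric matrix, so eigvalsh returns real eigenvalues.
eig_AtA = np.linalg.eigvalsh(A.T @ A)

print("largest |imag| of eig(A):", np.abs(eig_A.imag).max())
print("smallest eigenvalue of A^T A:", eig_AtA.min())

# The preconditioned system A^T A x = A^T b has the same solution as A x = b.
x_normal = np.linalg.solve(A.T @ A, A.T @ b)
print(np.allclose(x_normal, x_true, atol=1e-6))
```

<p>One caveat worth keeping in mind: the condition number of \(A^\top A\) is the square of that of \(A\), so this trick trades a nicer spectrum for a potentially worse-conditioned system.</p>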
<p>To apply this trick to the deconvolution problem, we need to be able to take the transpose of the convolution
operation. Fortunately, this is equivalent to convolution with a reflected version \(\overline k\) of the kernel
\(k\). That is, we will apply GMRES to the linear least-squares problem</p>
\[\min_x \|\overline k *(k*x) - \overline k * y\|\]
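<p>The claim that the transpose of a convolution is a convolution with the reflected kernel can be verified directly on a 1D analogue. The sketch below assumes zero-padded <code class="language-plaintext highlighter-rouge">"same"</code>-mode convolution and a kernel of <em>odd</em> length; the kernel values and the helper <code class="language-plaintext highlighter-rouge">conv_matrix</code> are made up for illustration.</p>

```python
import numpy as np

def conv_matrix(kernel, n):
    """Dense matrix of 1D zero-padded 'same'-mode convolution on length-n signals.

    Hypothetical helper, only meant for this illustration.
    """
    eye = np.eye(n)
    # Column j is the response of the convolution to the j-th unit vector.
    return np.stack(
        [np.convolve(eye[j], kernel, mode="same") for j in range(n)], axis=1
    )

n = 12
k = np.array([0.1, 0.5, 0.2, 0.15, 0.05])  # odd length, not symmetric
K = conv_matrix(k, n)
K_reflected = conv_matrix(k[::-1], n)

print(np.allclose(K, K.T))            # the operator itself is not symmetric
print(np.allclose(K.T, K_reflected))  # its transpose is convolution with the reflected kernel
```

<p>In two dimensions the reflection is applied along both axes, which is what <code class="language-plaintext highlighter-rouge">psf[::-1, ::-1]</code> does. For kernels of even size the centering of the <code class="language-plaintext highlighter-rouge">"same"</code> mode is off by one pixel, so the identity should then be checked rather than assumed.</p>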
<p>Let’s see this in action below.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">img</span> <span class="o">=</span> <span class="n">image</span><span class="p">.</span><span class="n">imread</span><span class="p">(</span><span class="s">"imgs/vitus256.png"</span><span class="p">)</span>
<span class="n">psf</span> <span class="o">=</span> <span class="n">random_motion_blur</span><span class="p">(</span>
<span class="n">N</span><span class="o">=</span><span class="n">N</span><span class="p">,</span> <span class="n">num_steps</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">beta</span><span class="o">=</span><span class="mf">0.98</span><span class="p">,</span> <span class="n">vel_scale</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">sigma</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">42</span>
<span class="p">)</span>
<span class="n">psf_reversed</span> <span class="o">=</span> <span class="n">psf</span><span class="p">[::</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">::</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="n">img_blur</span> <span class="o">=</span> <span class="n">convolve2d</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">psf</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s">"same"</span><span class="p">)</span>
<span class="n">img_shape</span> <span class="o">=</span> <span class="n">img</span><span class="p">.</span><span class="n">shape</span>
<span class="k">def</span> <span class="nf">do_convolution</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="n">res</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">scipy</span><span class="p">.</span><span class="n">signal</span><span class="p">.</span><span class="n">convolve2d</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">img_shape</span><span class="p">),</span> <span class="n">psf</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s">"same"</span><span class="p">)</span>
<span class="n">res</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">scipy</span><span class="p">.</span><span class="n">signal</span><span class="p">.</span><span class="n">convolve2d</span><span class="p">(</span><span class="n">res</span><span class="p">,</span> <span class="n">psf_reversed</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s">"same"</span><span class="p">)</span>
<span class="k">return</span> <span class="n">res</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">scipy</span><span class="p">.</span><span class="n">signal</span><span class="p">.</span><span class="n">convolve2d</span><span class="p">(</span><span class="n">img_blur</span><span class="p">,</span> <span class="n">psf_reversed</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s">"same"</span><span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">x0</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">x0</span>
<span class="n">n_restart</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">n_its</span> <span class="o">=</span> <span class="mi">20</span>
<span class="c1"># run once to compile
</span><span class="n">gmres_jit</span><span class="p">(</span><span class="n">do_convolution</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">n_restart</span><span class="p">)</span>
<span class="n">time_start</span> <span class="o">=</span> <span class="n">perf_counter_ns</span><span class="p">()</span>
<span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_its</span><span class="p">):</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">gmres_jit</span><span class="p">(</span><span class="n">do_convolution</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">n_restart</span><span class="p">)</span>
<span class="n">error</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="n">do_convolution</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">-</span> <span class="n">b</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
<span class="n">losses</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">error</span><span class="p">)</span>
<span class="n">time_taken</span> <span class="o">=</span> <span class="p">(</span><span class="n">perf_counter_ns</span><span class="p">()</span> <span class="o">-</span> <span class="n">time_start</span><span class="p">)</span> <span class="o">/</span> <span class="mf">1e9</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Deconvolution in </span><span class="si">{</span><span class="n">time_taken</span><span class="p">:.</span><span class="mi">2</span><span class="n">f</span><span class="si">}</span><span class="s"> s"</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Deconvolution in 1.40 s
</code></pre></div></div>
<p><img src="/imgs/gmres/gmres_28_0.svg" alt="svg" /></p>
<p>Except for some ringing around the edges, this produces a very good result. Compared to other methods of
deconvolution (as discussed in <a href="/deconvolution-part3">this blog post</a>) this in fact shows far fewer ringing
artifacts. It’s pretty fast as well. Even though it takes us around 2000 iterations to converge, the
difference between the image after 50 steps and after 2000 steps is not that big visually speaking. Let’s see how the
solution develops with different numbers of iterations:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">x0</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">x0</span>
<span class="n">results_dict</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">for</span> <span class="n">n_its</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">100</span><span class="p">]:</span>
<span class="n">x0</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>
<span class="c1"># run once to compile
</span> <span class="n">gmres_jit</span><span class="p">(</span><span class="n">do_convolution</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">x0</span><span class="p">,</span> <span class="n">n_its</span><span class="p">)</span>
<span class="n">time_start</span> <span class="o">=</span> <span class="n">perf_counter_ns</span><span class="p">()</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">10</span><span class="p">):</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">gmres_jit</span><span class="p">(</span><span class="n">do_convolution</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">x0</span><span class="p">,</span> <span class="n">n_its</span><span class="p">)</span>
<span class="n">time_taken</span> <span class="o">=</span> <span class="p">(</span><span class="n">perf_counter_ns</span><span class="p">()</span> <span class="o">-</span> <span class="n">time_start</span><span class="p">)</span> <span class="o">/</span> <span class="mf">1e7</span>
<span class="n">results_dict</span><span class="p">[</span><span class="n">n_its</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">time_taken</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="/imgs/gmres/gmres_31_0.svg" alt="svg" /></p>
<p>After just 100 iterations the result is pretty good, and this takes just 64 ms. This makes it a viable method
for deconvolution, roughly as fast as Richardson-Lucy deconvolution, but suffering less from boundary
artifacts. The regularization methods we have discussed in the deconvolution blog posts also work in this
setting, and are good to use in the case where there is noise, or where we don’t precisely know the
convolution kernel. That is however out of the scope of this blog post.</p>
<h2 id="conclusion">Conclusion</h2>
<p>GMRES is an easy-to-implement, fast and robust method for solving <em>structured</em> linear systems, where we only
have access to matrix-vector products \(Ax\). It is often used for solving sparse systems, but as we have
demonstrated, it can also be used for solving the deconvolution problem in a way that is competitive with
existing methods. Sometimes a preconditioner is needed to get good performance out of GMRES, but choosing a
good preconditioner can be difficult. If we implement GMRES on a GPU it can reach much higher speeds than on
a CPU.</p>Rik VoorhaarLinear least-squares systems pop up everywhere, and there are many fast ways to solve them. We'll be looking at one such way: GMRES.Low-rank matrices: using structure to recover missing data2021-09-26T00:00:00+00:002021-09-26T00:00:00+00:00https://rikvoorhaar.com/low-rank-matrix<p>Tensor networks are probably the most important tool in my research, and I want to
explain them. Before I can do this however, I should first talk about low-rank
matrix decompositions, and why they’re so incredibly useful. At the same time I
will illustrate everything using examples in Python code, using <code class="language-plaintext highlighter-rouge">numpy</code>.</p>
<h2 id="the-singular-value-decomposition">The singular value decomposition</h2>
<p>Often if we have an \(m\times n\) matrix, we can write it as the product of two
smaller matrices. If such a matrix has <em>rank</em> \(r\), then we can write it as the
product of an \(m\times r\) and an \(r\times n\) matrix. Equivalently, the rank is the
<em>number of linearly independent columns or rows</em> the matrix has, or if we see
the matrix as a linear map \(\mathbb R^n\to \mathbb R^m\), then it is the
<em>dimension of the image</em> of this linear map.</p>
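<p>These equivalent descriptions of the rank are easy to check numerically. The following is a small sketch with arbitrary sizes and a random matrix, using <code class="language-plaintext highlighter-rouge">np.linalg.matrix_rank</code>.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, r = 6, 8, 2

# Product of an (m x r) and an (r x n) matrix.
X = rng.normal(size=(m, r)) @ rng.normal(size=(r, n))

rank_cols = np.linalg.matrix_rank(X)    # number of independent columns
rank_rows = np.linalg.matrix_rank(X.T)  # number of independent rows

# Apply X to 100 random input vectors; the outputs span the image of the map.
outputs = X @ rng.normal(size=(n, 100))
dim_image = np.linalg.matrix_rank(outputs)

print(rank_cols, rank_rows, dim_image)
```

<p>All three numbers agree with the inner dimension \(r\) of the factorization.</p>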
<p>In practice we can figure out the rank of a matrix by computing its <em>singular
value decomposition</em> (SVD). If you studied data science or statistics, then you
have probably seen principal component analysis (PCA); this is very closely
related to the SVD. Using the SVD we can write a matrix \(X\) as a product</p>
\[X = U S V\]
<p>where \(U\) and \(V\) are orthogonal matrices, and \(S\) is a diagonal matrix. The
values on the diagonal of \(S\) are known as the <em>singular values</em> of \(X\). The
matrices \(U\) and \(V\) also have nice interpretations; the columns of \(U\) belonging to nonzero singular values form an
orthonormal basis of the <em>column space</em> of \(X\), and the corresponding rows of \(V\) form an
orthonormal basis of the <em>row space</em> of \(X\).</p>
<p>In <code class="language-plaintext highlighter-rouge">numpy</code> we can compute the SVD of a matrix using <code class="language-plaintext highlighter-rouge">np.linalg.svd</code>. Below we
compute it and verify that indeed \(X = U S V\):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="c1"># Generate a random 10x20 matrix of rank 5
</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">r</span> <span class="o">=</span> <span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
<span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">r</span><span class="p">))</span>
<span class="n">B</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">r</span><span class="p">,</span> <span class="n">n</span><span class="p">))</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">A</span> <span class="o">@</span> <span class="n">B</span>
<span class="c1"># Compute the SVD
</span><span class="n">U</span><span class="p">,</span> <span class="n">S</span><span class="p">,</span> <span class="n">V</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">svd</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">full_matrices</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="c1"># Confirm U S V = X
</span><span class="n">np</span><span class="p">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">U</span> <span class="o">@</span> <span class="n">np</span><span class="p">.</span><span class="n">diag</span><span class="p">(</span><span class="n">S</span><span class="p">)</span> <span class="o">@</span> <span class="n">V</span><span class="p">,</span> <span class="n">X</span><span class="p">)</span>
</code></pre></div></div>
<blockquote>
<p><code class="language-plaintext highlighter-rouge">True</code></p>
</blockquote>
<p>Note that we called <code class="language-plaintext highlighter-rouge">np.linalg.svd</code> with the keyword <code class="language-plaintext highlighter-rouge">full_matrices=False</code>. If
left to the default value <code class="language-plaintext highlighter-rouge">True</code>, then in this case <code class="language-plaintext highlighter-rouge">V</code> would be a \({20\times
20}\) matrix, as opposed to the \(10\times 20\) matrix it is now. Also <code class="language-plaintext highlighter-rouge">S</code> is
returned as a 1D array, and we can convert it to a diagonal matrix using
<code class="language-plaintext highlighter-rouge">np.diag</code>. Finally the function <code class="language-plaintext highlighter-rouge">np.allclose</code> checks if all the entries of two
matrices are almost the same; they never will be exactly the same due to
numerical error.</p>
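<p>To make the effect of <code class="language-plaintext highlighter-rouge">full_matrices</code> concrete, here is a short sketch comparing the shapes returned in both cases for a random \(10\times 20\) matrix (the matrix itself is arbitrary).</p>

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(10, 20))

U_full, S_full, V_full = np.linalg.svd(X)  # default: full_matrices=True
U_thin, S_thin, V_thin = np.linalg.svd(X, full_matrices=False)

print(U_full.shape, S_full.shape, V_full.shape)  # (10, 10) (10,) (20, 20)
print(U_thin.shape, S_thin.shape, V_thin.shape)  # (10, 10) (10,) (10, 20)

# With the full matrices the shapes of U @ diag(S) @ V no longer match;
# only the first 10 rows of V contribute to X.
X_rebuilt = U_full @ np.diag(S_full) @ V_full[:10, :]
print(np.allclose(X_rebuilt, X))
```

<p>The extra 10 rows of the full <code class="language-plaintext highlighter-rouge">V</code> span the null space of <code class="language-plaintext highlighter-rouge">X</code>, which is why they can be dropped without losing anything.</p>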
<p>As mentioned before, we can use the singular values <code class="language-plaintext highlighter-rouge">S</code> to determine what the
rank of the matrix <code class="language-plaintext highlighter-rouge">X</code> is. This is obvious if we plot the singular values:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="n">DEFAULT_FIGSIZE</span> <span class="o">=</span> <span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="n">DEFAULT_FIGSIZE</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">S</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span> <span class="n">S</span><span class="p">,</span> <span class="s">"o"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xticks</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">S</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">yscale</span><span class="p">(</span><span class="s">"log"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="s">"Plot of singular values"</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="/imgs/low-rank-matrix/intro-tn_3_1.png" alt="png" /></p>
<p>We see that the first 5 singular values are roughly the same size, but that the
last five singular values are much smaller; on the order of the machine epsilon.</p>
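<p>Instead of reading the rank off a plot, we can count the singular values above a small threshold. The sketch below assumes a threshold proportional to the largest singular value and the machine epsilon, which is a common choice; other cutoffs are possible.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, r = 10, 20, 5
X = rng.normal(size=(m, r)) @ rng.normal(size=(r, n))

# Only the singular values are needed here.
S = np.linalg.svd(X, compute_uv=False)

# Assumed threshold: largest singular value times max(m, n) times machine epsilon.
tol = S.max() * max(m, n) * np.finfo(X.dtype).eps
numerical_rank = int(np.sum(S > tol))
print(numerical_rank)
```

<p>This is the same rule that <code class="language-plaintext highlighter-rouge">np.linalg.matrix_rank</code> applies by default, so both give the same answer here.</p>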
<p>Knowing the matrix is rank 5, can we write it as the product of two rank 5
matrices? Absolutely! And we do this using the SVD, or rather the <em>truncated
singular value decomposition</em>. Since the last 5 values of <code class="language-plaintext highlighter-rouge">S</code> are very close to
zero, we can simply ignore them. This then means dropping the last 5 columns of
<code class="language-plaintext highlighter-rouge">U</code> and the last 5 rows of <code class="language-plaintext highlighter-rouge">V</code>. Then finally we just need to ‘absorb’ the
singular values into one of the two matrices <code class="language-plaintext highlighter-rouge">U</code> or <code class="language-plaintext highlighter-rouge">V</code>. This way we write <code class="language-plaintext highlighter-rouge">X</code>
as the product of a \(10\times 5\) and \(5\times 20\) matrix.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">A</span> <span class="o">=</span> <span class="n">U</span><span class="p">[:,</span> <span class="p">:</span><span class="n">r</span><span class="p">]</span> <span class="o">*</span> <span class="n">S</span><span class="p">[:</span><span class="n">r</span><span class="p">]</span>
<span class="n">B</span> <span class="o">=</span> <span class="n">V</span><span class="p">[:</span><span class="n">r</span><span class="p">,</span> <span class="p">:]</span>
<span class="k">print</span><span class="p">(</span><span class="n">A</span><span class="p">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">B</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">np</span><span class="p">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">A</span> <span class="o">@</span> <span class="n">B</span><span class="p">,</span> <span class="n">X</span><span class="p">)</span>
</code></pre></div></div>
<blockquote>
<p><code class="language-plaintext highlighter-rouge">(10, 5) (5, 20)</code></p>
<p><code class="language-plaintext highlighter-rouge">True</code></p>
</blockquote>
<h2 id="svd-as-data-compression-method">SVD as data compression method</h2>
<p>We rarely encounter real-world data that can be <em>exactly</em> represented by a low
rank matrix using the truncated SVD. But we can still use the truncated SVD to
get a good <em>approximation</em> of the data.</p>
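<p>How good the approximation is can be read off from the singular values that we throw away: the Frobenius norm of the error equals the square root of the sum of their squares, and the spectral norm of the error equals the first discarded singular value. A minimal check on a random matrix (sizes and rank chosen arbitrarily):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(30, 40))
U, S, V = np.linalg.svd(X, full_matrices=False)

k = 10
X_k = (U[:, :k] * S[:k]) @ V[:k, :]  # rank-k truncated SVD

err_fro = np.linalg.norm(X - X_k)           # Frobenius norm of the residual
pred_fro = np.sqrt(np.sum(S[k:] ** 2))      # predicted from discarded singular values
err_spec = np.linalg.norm(X - X_k, 2)       # spectral norm of the residual

print(np.isclose(err_fro, pred_fro))
print(np.isclose(err_spec, S[k]))
```

<p>So the faster the singular values decay, the better a low-rank approximation can be.</p>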
<p>Let us look at the singular values of an image of the St. Vitus church in my
hometown. Note that a black-and-white image is really just a matrix.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">image</span>
<span class="c1"># Load and plot the St. Vitus image
</span><span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">14</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">img</span> <span class="o">=</span> <span class="n">image</span><span class="p">.</span><span class="n">imread</span><span class="p">(</span><span class="s">"vitus512.png"</span><span class="p">)</span>
<span class="n">img</span> <span class="o">=</span> <span class="n">img</span> <span class="o">/</span> <span class="n">np</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">img</span><span class="p">)</span> <span class="c1"># make entries lie in range [0,1]
</span><span class="n">plt</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s">"gray"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">axis</span><span class="p">(</span><span class="s">"off"</span><span class="p">)</span>
<span class="c1"># Compute and plot the singular values
</span><span class="n">plt</span><span class="p">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="s">"Singular values"</span><span class="p">)</span>
<span class="n">U</span><span class="p">,</span> <span class="n">S</span><span class="p">,</span> <span class="n">V</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">svd</span><span class="p">(</span><span class="n">img</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">yscale</span><span class="p">(</span><span class="s">"log"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">S</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="/imgs/low-rank-matrix/intro-tn_7_1.png" alt="png" /></p>
<p>We see here that the first few singular values are much larger than the rest,
followed by a slow decay, and then finally a sharp drop at the very end. Note
that there are 512 singular values, because this is a 512x512 image.</p>
<p>Let’s now try to see what happens if we compress this image as a low rank matrix
using the truncated singular value decomposition. We will look at what happens to
the image when seen as a rank 10, 20, 50 or 100 matrix.</p>
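<p>To get a feeling for how much compression these ranks give, we can count the stored numbers. A rank \(r\) factorization of an \(m\times n\) matrix stores \(r(m+n)\) entries instead of \(mn\); the sketch below simply evaluates this for the 512x512 image, assuming the singular values have been absorbed into one of the two factors.</p>

```python
m, n = 512, 512
ranks = [10, 20, 50, 100]

# Ratio between the number of pixels and the number of stored factor entries.
ratios = {rank: m * n / (rank * (m + n)) for rank in ranks}

for rank in ranks:
    print(f"rank {rank:3d}: {rank * (m + n):6d} numbers, {ratios[rank]:.1f}x smaller")
```

<p>At rank 10 the factors take up about 25 times less space than the original image, and at rank 100 still about 2.5 times less.</p>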
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">12</span><span class="p">))</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">rank</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">([</span><span class="mi">10</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">100</span><span class="p">]):</span>
<span class="c1"># Compute truncated SVD
</span> <span class="n">U</span><span class="p">,</span> <span class="n">S</span><span class="p">,</span> <span class="n">V</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">svd</span><span class="p">(</span><span class="n">img</span><span class="p">)</span>
<span class="n">img_compressed</span> <span class="o">=</span> <span class="n">U</span><span class="p">[:,</span> <span class="p">:</span><span class="n">rank</span><span class="p">]</span> <span class="o">@</span> <span class="n">np</span><span class="p">.</span><span class="n">diag</span><span class="p">(</span><span class="n">S</span><span class="p">[:</span><span class="n">rank</span><span class="p">])</span> <span class="o">@</span> <span class="n">V</span><span class="p">[:</span><span class="n">rank</span><span class="p">,</span> <span class="p">:]</span>
<span class="c1"># Plot the image
</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="sa">f</span><span class="s">"Rank </span><span class="si">{</span><span class="n">rank</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">img_compressed</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s">"gray"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">axis</span><span class="p">(</span><span class="s">"off"</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="/imgs/low-rank-matrix/intro-tn_9_0.png" alt="png" /></p>
<p>We see that even the rank 10 and 20 images are pretty recognizable, but with
heavy artifacts. The rank 50 image looks pretty good, but not as good as the
original. The rank 100 image, on the other hand, looks really close to the
original.</p>
<p>How big is the compression if we do this? Well, if we write the image as a rank
10 matrix, we need two 512x10 matrices to store the image, which adds up to
10240 parameters, as opposed to the original 262144 parameters; a decrease in
storage of more than 25 times! On the other hand, the rank 100 image is only
about 2.6 times smaller than the original. Note that this is not a good image
compression algorithm; the SVD is relatively expensive to compute, and other
compression algorithms can achieve higher compression ratios with less image
degradation.</p>
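<p>The parameter counts above are easy to reproduce. The snippet below is a small sketch of the arithmetic; the helper name <code>low_rank_params</code> is made up for this illustration, and it assumes the singular values are absorbed into one of the two factors.</p>

```python
def low_rank_params(n_rows, n_cols, rank):
    """Stored values for a rank-`rank` factorization of an n_rows x n_cols matrix."""
    # One n_rows x rank factor plus one rank x n_cols factor
    return rank * (n_rows + n_cols)


n = 512
dense = n * n  # 262144 entries in the original image

ratios = {}
for rank in [10, 20, 50, 100]:
    params = low_rank_params(n, n, rank)
    ratios[rank] = dense / params
    print(f"rank {rank:3d}: {params:6d} parameters, {ratios[rank]:.2f}x smaller")
```

<p>The ratio falls off quickly with the rank, which is why the rank 100 image saves so little storage.</p>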
<p>The conclusion we can draw from this is that we can use truncated SVD to
compress data. However, not all data can be compressed as efficiently by this
method. It depends on the distribution of singular values; the faster the
singular values decay, the better a low rank decomposition is going to
approximate our data. Images are not good examples of data that can be
compressed efficiently as a low rank matrix.</p>
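<p>The link between singular value decay and approximation quality can be made precise: the Frobenius norm error of a truncated SVD equals the square root of the sum of squares of the discarded singular values. The sketch below checks this numerically on a random matrix (the matrix size and the rank are arbitrary choices for the illustration).</p>

```python
import numpy as np

rng = np.random.default_rng(0)
M = rng.normal(size=(60, 40))
U, S, Vt = np.linalg.svd(M, full_matrices=False)

rank = 10
M_trunc = U[:, :rank] @ np.diag(S[:rank]) @ Vt[:rank, :]

# Frobenius error of the truncation
err = np.linalg.norm(M - M_trunc)
# Square root of the sum of squares of the discarded singular values
tail = np.sqrt(np.sum(S[rank:] ** 2))
print(err, tail)
```

<p>So if the singular values decay fast, the tail is small and a low rank truncation loses very little.</p>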
<p>One reason why it’s difficult to compress images is because they contain many
sharp edges and transitions. Low rank matrices are especially bad at
representing diagonal lines. For example, the identity matrix is a diagonal
line seen as an image, and it is also impossible to compress using an SVD since
all singular values are equal.</p>
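<p>As a quick sanity check of the identity matrix example, here is a minimal sketch (the size 100 and rank 10 are arbitrary choices) that truncates the SVD of an identity matrix and measures the relative error.</p>

```python
import numpy as np

n = 100
eye = np.eye(n)
U, S, Vt = np.linalg.svd(eye)

# All singular values of the identity matrix are equal to 1
print(S.min(), S.max())

rank = 10
eye_trunc = U[:, :rank] @ np.diag(S[:rank]) @ Vt[:rank, :]

# Relative Frobenius error of the rank 10 truncation
rel_err = np.linalg.norm(eye - eye_trunc) / np.linalg.norm(eye)
print(rel_err)
```

<p>Keeping 10 of the 100 singular values still leaves a relative error of \(\sqrt{0.9}\), roughly 95%, so the truncation captures almost nothing of the diagonal line.</p>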
<p>On the other hand, images without any sharp transitions can be approximated
quite well using low rank matrices. These kinds of images rarely appear as
natural images, but rather they can be discrete representations of smooth
functions \([0,1]^2 \to\mathbb R\). For example, below we show a two-dimensional
discretized sum of trigonometric functions and its singular values.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># Make a grid of 100 x 100 values between [0,1]
</span><span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span>
<span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">meshgrid</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
<span class="c1"># A smooth trigonometric function
</span><span class="k">def</span> <span class="nf">f</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">sin</span><span class="p">(</span><span class="mi">200</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="mi">75</span> <span class="o">*</span> <span class="n">y</span><span class="p">)</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">sin</span><span class="p">(</span><span class="mi">50</span> <span class="o">*</span> <span class="n">x</span><span class="p">)</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">cos</span><span class="p">(</span><span class="mi">100</span> <span class="o">*</span> <span class="n">y</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">U</span><span class="p">,</span> <span class="n">S</span><span class="p">,</span> <span class="n">V</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">svd</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">S</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">yscale</span><span class="p">(</span><span class="s">"log"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="s">"Singular values"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"The matrix is approximately of rank: </span><span class="si">{</span><span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">S</span><span class="o">></span><span class="mf">1e-12</span><span class="p">)</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>
<blockquote>
<p><code class="language-plaintext highlighter-rouge">The matrix is approximately of rank: 4</code></p>
</blockquote>
<p><img src="/imgs/low-rank-matrix/intro-tn_11_1.png" alt="png" /></p>
<p>We see that this particular function can be represented by a rank 4 matrix! This
is not obvious if you look at the image. In these kinds of situations a low-rank
matrix decomposition is much better than many image compression algorithms. In
this case we can reconstruct the image using only 8% of the parameters.
(Although more advanced image compression algorithms are based on wavelets, and
will actually compress this very well.)</p>
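<p>One way to see where the rank 4 comes from is the angle-addition formula, which splits the first term of \(f\) into two products of a function of \(x\) alone and a function of \(y\) alone:</p>
\[\sin(200x+75y) = \sin(200x)\cos(75y)+\cos(200x)\sin(75y).\]
<p>Any product \(g(x)h(y)\) discretizes to an outer product of two vectors, which is a rank 1 matrix. The remaining terms \(\sin(50x)\) and \(\cos(100y)\) are products with the constant function 1, so they contribute one rank 1 matrix each. In total the discretized function has rank at most</p>
\[2 + 1 + 1 = 4,\]
<p>and storing the two factors takes \(2\cdot 100\cdot 4 = 800\) numbers, which is 8% of the \(100\cdot 100 = 10{,}000\) entries of the dense matrix.</p>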
<h2 id="matrix-completion">Matrix completion</h2>
<p>Recall that a low rank matrix approximation can require far fewer parameters
than the dense matrix it approximates. One powerful consequence is that we can
recover the dense matrix even when we only observe a small part of it, that is,
when many of its values are missing.</p>
<p>In the case above we can represent the 100x100 matrix \(X\) as a product of a
100x4 matrix \(A\) and a 4x100 matrix \(B\), which together have 800 parameters
instead of 10,000. We can actually recover this low-rank decomposition from a
small subset of the entries of the big dense matrix. Suppose that we observe the
entries \(X_{ij}\) for \((i,j)\in \Omega\), where \(\Omega\) is some index set. We
can recover \(A\) and \(B\) by solving the following least-squares problem:</p>
\[\min_{A,B}\sum_{(i,j)\in \Omega}((AB)_{ij}-X_{ij})^2\]
<p>This problem is however non-convex, and not straightforward to solve. There is
fortunately a trick: we can alternately fix \(A\) and optimize \(B\), and
vice versa. This is known as Alternating Least Squares (ALS) optimization, and
in this case it works well. If we fix \(A\), observe that the minimization problem
decouples into separate linear least-squares problems for each column of \(B\):</p>
\[\min_{B_{\bullet k}} \sum_{(i,j)\in \Omega,\,j=k} (\langle A_{i\bullet},B_{\bullet k}\rangle-X_{ik})^2\]
<p>Below we use this approach to recover the same matrix as before from 2000
observed entries, and we can see it does so with a very low error:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">N</span> <span class="o">=</span> <span class="mi">2000</span>
<span class="n">n</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">r</span> <span class="o">=</span> <span class="mi">4</span>
<span class="c1"># Sample N=2000 random indices
</span><span class="n">Omega</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">choice</span><span class="p">(</span><span class="n">n</span> <span class="o">*</span> <span class="n">n</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">N</span><span class="p">,</span> <span class="n">replace</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="n">Omega</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">unravel_index</span><span class="p">(</span><span class="n">Omega</span><span class="p">,</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">X</span><span class="p">[</span><span class="n">Omega</span><span class="p">]</span>
<span class="c1"># Use random initialization for matrices A,B
</span><span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">r</span><span class="p">))</span>
<span class="n">B</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">r</span><span class="p">,</span> <span class="n">n</span><span class="p">))</span>
<span class="k">def</span> <span class="nf">linsolve_regular</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">lam</span><span class="o">=</span><span class="mf">1e-4</span><span class="p">):</span>
<span class="s">"""Solve linear problem A@x = b with Tikhonov regularization / ridge
regression"""</span>
<span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">solve</span><span class="p">(</span><span class="n">A</span><span class="p">.</span><span class="n">T</span> <span class="o">@</span> <span class="n">A</span> <span class="o">+</span> <span class="n">lam</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">eye</span><span class="p">(</span><span class="n">A</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span> <span class="n">A</span><span class="p">.</span><span class="n">T</span> <span class="o">@</span> <span class="n">b</span><span class="p">)</span>
<span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">40</span><span class="p">):</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">(((</span><span class="n">A</span> <span class="o">@</span> <span class="n">B</span><span class="p">)[</span><span class="n">Omega</span><span class="p">]</span> <span class="o">-</span> <span class="n">y</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">losses</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span>
<span class="c1"># Update B
</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
<span class="n">B</span><span class="p">[:,</span> <span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">linsolve_regular</span><span class="p">(</span><span class="n">A</span><span class="p">[</span><span class="n">Omega</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">Omega</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">j</span><span class="p">]],</span> <span class="n">y</span><span class="p">[</span><span class="n">Omega</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">j</span><span class="p">])</span>
<span class="c1"># Update A
</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
<span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">linsolve_regular</span><span class="p">(</span>
<span class="n">B</span><span class="p">[:,</span> <span class="n">Omega</span><span class="p">[</span><span class="mi">1</span><span class="p">][</span><span class="n">Omega</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">i</span><span class="p">]].</span><span class="n">T</span><span class="p">,</span> <span class="n">y</span><span class="p">[</span><span class="n">Omega</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">i</span><span class="p">]</span>
<span class="p">)</span>
<span class="c1"># Plot the input image
</span><span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">12</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="s">"Input image"</span><span class="p">)</span>
<span class="n">S</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">n</span><span class="p">,</span> <span class="n">n</span><span class="p">))</span>
<span class="n">S</span><span class="p">[</span><span class="n">Omega</span><span class="p">]</span> <span class="o">=</span> <span class="n">y</span>
<span class="n">plt</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">S</span><span class="p">)</span>
<span class="c1"># Plot reconstructed image
</span><span class="n">plt</span><span class="p">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="s">"Reconstructed image"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">A</span> <span class="o">@</span> <span class="n">B</span><span class="p">)</span>
<span class="c1"># Plot training loss
</span><span class="n">plt</span><span class="p">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="s">"Mean square error loss during training"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">losses</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">yscale</span><span class="p">(</span><span class="s">"log"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">"steps"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">"Mean squared error"</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="/imgs/low-rank-matrix/intro-tn_14_1.png" alt="png" /></p>
<h2 id="netflix-prize">Netflix prize</h2>
<p>Let’s consider a particularly interesting use of matrix completion –
collaborative filtering. Think about how services like Netflix may recommend new
shows or movies to watch. They know which movies you like, and they know which
movies other people like. Netflix then recommends movies that are liked by
people with a similar taste to yours. This is called <em>collaborative filtering</em>,
because different people <em>collaborate</em> to filter out movies so that we can make
a recommendation.</p>
<p>But can we do this in practice? Well, for every user we can put their personal
ratings of every movie they watched in a big matrix. In this matrix each row
represents a movie, and each column a user. Most users have only seen a small
fraction of all the movies on the platform, so the overwhelming majority of the
entries of this matrix are unknown. Then we apply matrix completion to this
matrix. Each entry of the completed matrix then represents the rating <em>we think</em>
the user would give to a movie, even if they have never watched it.</p>
<p>In 2006 Netflix opened a competition with a grand prize of <strong>$1,000,000</strong> (!!)
to solve precisely this problem. The data consists of more than 100 million
ratings by 480,189 users for 17,769 different movies. The size of this dataset
immediately poses a practical problem; if we put this in a dense matrix with
64-bit floating point entries, then it would require about 68 gigabytes of RAM.
Fortunately we can avoid this problem by using sparse matrices. This makes
implementation a little harder, but certainly still feasible.</p>
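<p>A back-of-the-envelope estimate shows how much sparse storage helps. The sketch below assumes 64-bit floats for the ratings, and a simple coordinate-style sparse format with two 32-bit indices per stored rating; the exact overhead depends on the sparse format actually used, and the number of ratings is rounded down to 100 million.</p>

```python
n_users = 480_189
n_movies = 17_769
n_ratings = 100_000_000  # "more than 100 million", rounded down

bytes_per_float = 8  # assumption: 64-bit floats
bytes_per_index = 4  # assumption: 32-bit row and column indices

# Dense storage: one float for every (movie, user) pair
dense_bytes = n_users * n_movies * bytes_per_float

# Coordinate-style sparse storage: value + row index + column index per rating
sparse_bytes = n_ratings * (bytes_per_float + 2 * bytes_per_index)

print(f"dense:  {dense_bytes / 1e9:.1f} GB")
print(f"sparse: {sparse_bytes / 1e9:.1f} GB")
print(f"fraction of entries observed: {n_ratings / (n_users * n_movies):.2%}")
```

<p>Under these assumptions the sparse representation is a few gigabytes at most, which fits in the memory of an ordinary desktop machine.</p>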
<p>We will also need to upgrade our matrix completion algorithm. The algorithm we
mentioned before is slow for very large matrices, and suffers from problems of
numerical stability due to the way it decouples into many smaller linear
problems. Recall that we complete a matrix \(X\) by solving the following
optimization problem:</p>
\[\min_{A,B}\sum_{(i,j)\in \Omega}((AB)_{ij}-X_{ij})^2.\]
<p>We will first rewrite the problem as follows:</p>
\[\min_{A,B}\|P_\Omega(AB) -X\|.\]
<p>Here \(P_\Omega\) denotes the operation of setting all entries \((AB)_{ij}\) to zero
if \((i,j)\notin \Omega\). In other words, \(P_\Omega\) turns \(AB\) into a sparse
matrix with the same sparsity pattern as \(X\). In some sense, the issue with this
optimization problem is that only a small part of the entries of \(AB\) affect
the objective. We can solve this by adding a new matrix \(Z\) such that
\(P_\Omega(Z)=X\), and then using \(A,B\) to approximate \(Z\) instead:</p>
\[\min_{A,B,Z}\|AB-Z\|\quad \text{such that } P_\Omega(Z) = X.\]
<p>This problem can then be solved using the same alternating least-squares
approach we have used before. For example if we fix \(A,B\) then the optimal value
of \(Z\) is given by \(Z = AB+X-P_\Omega(AB)\): it agrees with \(X\) on \(\Omega\) and
with \(AB\) everywhere else. At each iteration we can then update \(A\) and \(B\) by
solving a linear least-squares problem. It is important to note that this way
\(Z\) is a sum of a low-rank and a sparse matrix at every step, and this allows us
to still efficiently manipulate it and store it in memory.</p>
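<p>To make the structure of the iteration concrete, here is a dense toy sketch on a small synthetic matrix. It is not the implementation used for the Netflix data: it stores \(Z\) as a full dense array rather than as a low-rank plus sparse sum, and the sizes, the boolean <code>mask</code> standing in for \(\Omega\), and the number of steps are all assumptions made for the illustration.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, r = 30, 20, 2

# Synthetic rank-r ground truth, and a random set Omega of observed entries
X_full = rng.normal(size=(n, r)) @ rng.normal(size=(r, m))
mask = rng.random((n, m)) < 0.5
X = np.where(mask, X_full, 0.0)  # observed entries, zero elsewhere

# Random initialization of the factors
A = rng.normal(size=(n, r))
B = rng.normal(size=(r, m))

losses = []
for step in range(200):
    AB = A @ B
    # Z-step: Z agrees with X on Omega and with AB everywhere else
    Z = AB + np.where(mask, X - AB, 0.0)
    # Mean squared error on the observed entries
    losses.append(np.mean((AB - Z)[mask] ** 2))
    # A-step: least-squares solution of A @ B = Z for fixed B
    A = np.linalg.lstsq(B.T, Z.T, rcond=None)[0].T
    # B-step: least-squares solution of A @ B = Z for fixed A
    B = np.linalg.lstsq(A, Z, rcond=None)[0]

print(f"loss on observed entries: {losses[0]:.3e} -> {losses[-1]:.3e}")
```

<p>Each of the three steps minimizes \(\|AB-Z\|\) exactly over one block of variables, so the error on the observed entries cannot increase from one iteration to the next.</p>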
<p>Although not very difficult, the implementation of this algorithm is a little
too technical for this blog post. Instead we can just look at the results. I
used this algorithm to fit matrices \(A\) and \(B\) of rank 5 and of rank 10 to the
Netflix prize dataset. I used 3000 iterations of training, taking the better
part of a day to train on my computer. I could probably do more, but I’m too
impatient. The progress of training is shown below.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">os.path</span>
<span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="n">DEFAULT_FIGSIZE</span><span class="p">)</span>
<span class="n">DATASET_PATH</span> <span class="o">=</span> <span class="s">"/mnt/games/datasets/netflix/"</span>
<span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">10</span><span class="p">,</span> <span class="mi">5</span><span class="p">]:</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="n">os</span><span class="p">.</span><span class="n">path</span><span class="p">.</span><span class="n">join</span><span class="p">(</span><span class="n">DATASET_PATH</span><span class="p">,</span> <span class="sa">f</span><span class="s">"rank-</span><span class="si">{</span><span class="n">r</span><span class="si">}</span><span class="s">-model.npz"</span><span class="p">))</span>
<span class="n">A</span> <span class="o">=</span> <span class="n">model</span><span class="p">[</span><span class="s">"X"</span><span class="p">]</span>
<span class="n">B</span> <span class="o">=</span> <span class="n">model</span><span class="p">[</span><span class="s">"Y"</span><span class="p">]</span>
<span class="n">train_errors</span> <span class="o">=</span> <span class="n">model</span><span class="p">[</span><span class="s">"train_errors"</span><span class="p">]</span>
<span class="n">test_errors</span> <span class="o">=</span> <span class="n">model</span><span class="p">[</span><span class="s">"test_errors"</span><span class="p">]</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">train_errors</span><span class="p">),</span> <span class="n">label</span><span class="o">=</span><span class="sa">f</span><span class="s">"Train rank </span><span class="si">{</span><span class="n">r</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">test_errors</span><span class="p">),</span> <span class="n">label</span><span class="o">=</span><span class="sa">f</span><span class="s">"Test rank </span><span class="si">{</span><span class="n">r</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylim</span><span class="p">(</span><span class="mf">0.8</span><span class="p">,</span> <span class="mf">1.5</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">"Training iterations"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">"Root mean squared error (RMSE)"</span><span class="p">);</span>
</code></pre></div></div>
<p><img src="/imgs/low-rank-matrix/intro-tn_16_0.png" alt="png" /></p>
<p>We see that the training errors for the rank 5 and rank 10 models are virtually
identical, but the test error is lower for the rank 5 model. We can interpret
this as the rank 10 model overfitting more, which is often the case for more
complex models.</p>
<p>Next, how can we use this model? Well, the rows of the matrix \(A\) correspond to
movies, and the columns of matrix \(B\) correspond to users. So if we want to know
how much user #179 likes movie #2451 (<em>Lord of the Rings: The Fellowship of the
Ring</em>), then we compute \(A[2451]\cdot B[:, 179]\):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">A</span><span class="p">[</span><span class="mi">2451</span><span class="p">]</span> <span class="o">@</span> <span class="n">B</span><span class="p">[:,</span> <span class="mi">179</span><span class="p">]</span>
</code></pre></div></div>
<blockquote>
<p><code class="language-plaintext highlighter-rouge">4.411312294862265</code></p>
</blockquote>
<p>We see that the <em>expected rating</em> (out of 5) for this user and movie is about
4.41. So we expect that this user will like this movie, and we may choose to
recommend it.</p>
<p>But we want to find the <em>best</em> recommendation for this user. To do this we can
simply compute the product \(A \cdot B[:,179]\), which will give a vector with
expected rating for every single movie, and then we simply sort. Below we can
see the 5 movies with the highest and lowest expected ratings for this user.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="n">pd</span>
<span class="n">movies</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">read_csv</span><span class="p">(</span>
<span class="n">os</span><span class="p">.</span><span class="n">path</span><span class="p">.</span><span class="n">join</span><span class="p">(</span><span class="n">DATASET_PATH</span><span class="p">,</span> <span class="s">"movie_titles.csv"</span><span class="p">),</span>
<span class="n">names</span><span class="o">=</span><span class="p">[</span><span class="s">"index"</span><span class="p">,</span> <span class="s">"year"</span><span class="p">,</span> <span class="s">"name"</span><span class="p">],</span>
<span class="n">usecols</span><span class="o">=</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span>
<span class="p">)</span>
<span class="n">movies</span><span class="p">[</span><span class="s">"ratings-179"</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span> <span class="o">@</span> <span class="n">B</span><span class="p">[:,</span> <span class="mi">179</span><span class="p">]</span>
<span class="n">movies</span><span class="p">.</span><span class="n">sort_values</span><span class="p">(</span><span class="s">"ratings-179"</span><span class="p">,</span> <span class="n">ascending</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
</code></pre></div></div>
<div>
<style scoped="">
.dataframe tbody tr th:only-of-type {
vertical-align: middle;
}
.dataframe tbody tr th {
vertical-align: top;
}
.dataframe thead th {
text-align: right;
}
</style>
<table border="1" class="dataframe">
<thead>
<tr style="text-align: right;">
<th></th>
<th>name</th>
<th>ratings-179</th>
</tr>
</thead>
<tbody>
<tr>
<th>10755</th>
<td>Kirby: A Dark & Stormy Knight</td>
<td>9.645918</td>
</tr>
<tr>
<th>15833</th>
<td>Paternal Instinct</td>
<td>7.712654</td>
</tr>
<tr>
<th>15355</th>
<td>Last Hero In China</td>
<td>7.689984</td>
</tr>
<tr>
<th>14902</th>
<td>Warren Miller's: Ride</td>
<td>7.624472</td>
</tr>
<tr>
<th>2082</th>
<td>Blood Alley</td>
<td>7.317524</td>
</tr>
<tr>
<th>...</th>
<td>...</td>
<td>...</td>
</tr>
<tr>
<th>463</th>
<td>The Return of Ruben Blades</td>
<td>-6.037189</td>
</tr>
<tr>
<th>12923</th>
<td>Where the Red Fern Grows 2</td>
<td>-6.153577</td>
</tr>
<tr>
<th>7067</th>
<td>Eric Idle's Personal Best</td>
<td>-6.441100</td>
</tr>
<tr>
<th>538</th>
<td>Rumpole of the Bailey: Series 4</td>
<td>-6.740144</td>
</tr>
<tr>
<th>4331</th>
<td>Sugar: Howling of Angel</td>
<td>-7.015818</td>
</tr>
</tbody>
</table>
<p>17769 rows × 2 columns</p>
</div>
<p>Note that the expected ratings are not between 0 and 5, but can take on any
value (in particular non-integer ones). This is not necessarily a problem,
because we only care about the relative rating of the movies.</p>
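<p>To make the ranking step concrete, here is a minimal pure-Python sketch of what <code class="language-plaintext highlighter-rouge">A @ B[:, 179]</code> followed by a descending sort does. The factor matrices, movie names and user index below are made-up toy values, not the ones learned from the real data.</p>

```python
# Toy low-rank factors: 4 movies x rank 2, and rank 2 x 3 users (made-up numbers).
A = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [-1.0, 2.0]]
B = [[2.0, 0.0, 1.0], [0.0, 3.0, 1.0]]
names = ["Movie a", "Movie b", "Movie c", "Movie d"]


def predicted_scores(A, B, user):
    """Return column `user` of the product A @ B, one score per movie."""
    return [sum(a_k * B[k][user] for k, a_k in enumerate(row)) for row in A]


def rank_movies(A, B, names, user):
    """Pair every movie with its predicted score and sort from high to low."""
    scores = predicted_scores(A, B, user)
    return sorted(zip(names, scores), key=lambda pair: pair[1], reverse=True)


ranking = rank_movies(A, B, names, user=1)
for name, score in ranking:
    print(f"{name}: {score:.2f}")
```

<p>Only the order of the scores matters here, which is why it is fine for them to fall outside the range of actual ratings.</p>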
<p>To me, all these movies sound quite obscure. This makes sense: the model
does not take factors such as the popularity of a movie into account. It also
ignores a lot of other data that we may know about the user, such as their age,
gender and location. It ignores when the movie was released, and it doesn’t take
into account the dates of all the movie ratings of each user. These are all
important factors that could significantly improve the quality of the
recommendation system.</p>
<p>We could try to modify our matrix completion model to take these factors into
account, but it’s not obvious how to do this. There is no need to do this,
however: we can use the matrices \(A\), \(B\) to augment any data we have about the
movie and the user. We can then train a new model on top of this data to
create something even better.</p>
<p>We can think of the movies as lying in a really high-dimensional space, and the
matrix \(A\) maps this space onto a much smaller space. The same is true for
\(B\) and the ‘space’ of users. We can then use this <em>embedding</em> into a
lower-dimensional space as the input of another model.</p>
<p>Unfortunately we don’t have access to more information about the users (due
to obvious privacy concerns), so this is difficult to demonstrate. But the point
is this: the decomposition \(X\approx AB\) is both <em>interpretable</em>, and can be
used as a building block for more advanced machine learning models.</p>
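<p>As a rough sketch of what “augmenting” could look like: the row of \(A\) for a movie and the column of \(B\) for a user are concatenated with whatever side information is available, and the result is the input of a second model. The factors and the extra features below (a release year and a user age) are purely hypothetical placeholders.</p>

```python
# Toy factors: 4 movies x rank 2, and rank 2 x 3 users (made-up numbers).
A = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [-1.0, 2.0]]
B = [[2.0, 0.0, 1.0], [0.0, 3.0, 1.0]]


def feature_vector(A, B, movie, user, movie_extra=(), user_extra=()):
    """Concatenate the movie and user embeddings with optional side features."""
    movie_embedding = list(A[movie])            # row of A
    user_embedding = [row[user] for row in B]   # column of B
    return movie_embedding + user_embedding + list(movie_extra) + list(user_extra)


# Hypothetical side information: release year of the movie, age of the user.
features = feature_vector(A, B, movie=3, user=1, movie_extra=(1999,), user_extra=(34,))
print(features)
```

<p>Any supervised model that accepts a fixed-length feature vector could then be trained on such inputs.</p>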
<h2 id="conclusion">Conclusion</h2>
<p>In summary we have seen that low-rank matrix decompositions have many useful
applications in machine learning. They are powerful because they can be learned
using relatively little data, and have the ability to complete missing data.
Unlike many other machine learning models, computing low-rank matrix
decompositions of data can be done quickly.</p>
<p>Even though they come with some limitations, they can always be used as a
building block for more advanced machine learning models. This is because they
can give an interpretable, low-dimensional representation of very
high-dimensional data. We also didn’t even come close to discussing all their
applications, or the algorithms for finding and optimizing them.</p>
<p>In the next post I will look at a generalization of low-rank matrix
decompositions: <em>tensor decompositions</em>. While more complicated, these
decompositions are even more powerful at reducing the dimensionality of very
high-dimensional data.</p>Rik VoorhaarA lot of data is naturally of 'low rank'. I will explain what this means, and how to exploit this fact.How to edit Microsoft Word documents in Python2021-08-29T00:00:00+00:002021-08-29T00:00:00+00:00https://rikvoorhaar.com/python-docx<p>In preparation for the job market, I started polishing my CV. I try to keep the
CV on my website as up-to-date as possible, but many recruiters and companies
prefer a neat single-page CV as a Microsoft Word document. I always used to make
my CVs in LaTeX, but it seems Word is often preferred since it’s easier for
third parties to edit.</p>
<p>Keeping a web, Word, and PDF version all up to date and easy to edit seemed
like an annoying task. I have plenty of experience with automatically generating
PDF documents using LaTeX and Python, so I figured why should a Word document be
any different? Let’s dive into the world of editing Word documents in Python!</p>
<p>Fortunately there is a library for this: <code class="language-plaintext highlighter-rouge">python-docx</code>. It can be used to create
Word documents from scratch, but styling a document is a bit tricky. Instead,
its real power lies in editing pre-made documents. I went ahead and made a nice
looking CV in Word, and now let’s open this document in <code class="language-plaintext highlighter-rouge">python-docx</code>. A Word
document is stored as XML under the hood, and there can be a complicated tree
structure to a document. However, we can open a document and use the
<code class="language-plaintext highlighter-rouge">.paragraphs</code> attribute for a complete list of all the paragraphs in the
document. Let’s take a paragraph, and print its text content.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">docx</span> <span class="kn">import</span> <span class="n">Document</span>
<span class="n">document</span> <span class="o">=</span> <span class="n">Document</span><span class="p">(</span><span class="s">"resume.docx"</span><span class="p">)</span>
<span class="n">paragraph</span> <span class="o">=</span> <span class="n">document</span><span class="p">.</span><span class="n">paragraphs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">print</span><span class="p">(</span><span class="n">paragraph</span><span class="p">.</span><span class="n">text</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Rik Voorhaar
</code></pre></div></div>
<p>Turns out the first paragraph contains my name! Editing this text is very easy;
we just need to set a new value to the <code class="language-plaintext highlighter-rouge">.text</code> attribute. Let’s do this and save
the document.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">paragraph</span><span class="p">.</span><span class="n">text</span> <span class="o">=</span> <span class="s">"Willem Hendrik"</span>
<span class="n">document</span><span class="p">.</span><span class="n">save</span><span class="p">(</span><span class="s">"resume_edited.docx"</span><span class="p">)</span>
</code></pre></div></div>
<p>Below is a picture of the resulting change; it unfortunately seems like two
additional things happened when editing this paragraph: the font of the edited
paragraph changed, and the bar / text box on the right-hand side disappeared
completely!</p>
<p><img src="/imgs/python_docx/doc_comparison.png" alt="img" /></p>
<p>This is no good, but to understand what happened to the text box we need to
dig into the XML of the document. We can turn the document into an XML file like
so:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">document</span> <span class="o">=</span> <span class="n">Document</span><span class="p">(</span><span class="s">"resume.docx"</span><span class="p">)</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="s">'resume.xml'</span><span class="p">,</span> <span class="s">'w'</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
    <span class="n">f</span><span class="p">.</span><span class="n">write</span><span class="p">(</span><span class="n">document</span><span class="p">.</span><span class="n">_element</span><span class="p">.</span><span class="n">xml</span><span class="p">)</span>
</code></pre></div></div>
<p>It seems the problem was that the text box on the right was nested inside
another object, which is apparently not handled properly. This issue was easy to
fix by modifying the Word document. However, the bar on the right-hand side consists
of two text boxes, and the top box with my contact information <em>does</em> disappear if
I change the first paragraph. <em>But</em> it does not disappear if I change the
second paragraph; it only happens if I change paragraph 1 or 3 (and the latter
is empty). I tried inserting two paragraphs before this particular paragraph, or
changing the style of this particular paragraph, but the issue remains.</p>
<p>Looking at the XML the issue is clear: the text box element lies nested inside
this paragraph! It turned out to be a bit tricky to avoid this, so for now let
us instead change the second paragraph, replacing the word “Resume” with
“Curriculum Vitae”.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">document</span> <span class="o">=</span> <span class="n">Document</span><span class="p">(</span><span class="s">"resume.docx"</span><span class="p">)</span>
<span class="n">paragraph</span> <span class="o">=</span> <span class="n">document</span><span class="p">.</span><span class="n">paragraphs</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="k">print</span><span class="p">(</span><span class="n">paragraph</span><span class="p">.</span><span class="n">text</span><span class="p">)</span>
<span class="n">paragraph</span><span class="p">.</span><span class="n">text</span> <span class="o">=</span> <span class="s">"Curriculum Vitae"</span>
<span class="n">document</span><span class="p">.</span><span class="n">save</span><span class="p">(</span><span class="s">"CV.docx"</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Resume
</code></pre></div></div>
<p>If we do this there are no problems with text boxes disappearing, but
unfortunately the style of this paragraph is still reset. Let’s
have a look at how the XML changes when we edit this paragraph. Ignoring
irrelevant information, before the change it looks like this:</p>
<div class="language-xml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nt"><w:p></span>
<span class="nt"><w:r></span>
<span class="nt"><w:t></span>R<span class="nt"></w:t></span>
<span class="nt"></w:r></span>
<span class="nt"><w:r></span>
<span class="nt"><w:t></span>esume<span class="nt"></w:t></span>
<span class="nt"></w:r></span>
<span class="nt"></w:p></span>
</code></pre></div></div>
<p>And afterwards it looks like this:</p>
<div class="language-xml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nt"><w:p></span>
<span class="nt"><w:r></span>
<span class="nt"><w:t></span>Curriculum Vitae<span class="nt"></w:t></span>
<span class="nt"></w:r></span>
<span class="nt"></w:p></span>
</code></pre></div></div>
<p>In Word, each paragraph (<code class="language-plaintext highlighter-rouge"><w:p></code>) is split up into multiple runs (<code class="language-plaintext highlighter-rouge"><w:r></code>). What we
see here is that originally the paragraph consisted of two runs, and after modifying it,
it became a single run. However, it seems that in both cases the style
information is exactly the same, so I don’t understand why the style changes
after modification. In this case, if I retype the word ‘Resume’ in the original
Word document, this paragraph becomes a single run, but <em>still</em> the style changes
after editing, and I still don’t see why this happens when looking at the XML.</p>
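<p>To see how the text of a paragraph is assembled from its runs, here is a small sketch that parses the simplified “before” XML shown above using only Python’s standard library. The namespace URI is the usual WordprocessingML one; the XML itself is the stripped-down example, not the full content of my document.</p>

```python
import xml.etree.ElementTree as ET

# Namespace bound to the "w:" prefix in WordprocessingML documents.
W = "http://schemas.openxmlformats.org/wordprocessingml/2006/main"
NS = {"w": W}

# The simplified "before" paragraph: the word is split over two runs.
PARAGRAPH_XML = f"""
<w:p xmlns:w="{W}">
  <w:r><w:t>R</w:t></w:r>
  <w:r><w:t>esume</w:t></w:r>
</w:p>
""".strip()

paragraph = ET.fromstring(PARAGRAPH_XML)
runs = paragraph.findall("w:r", NS)

# Each run carries a piece of text; the paragraph text is their concatenation.
run_texts = [run.findtext("w:t", default="", namespaces=NS) for run in runs]
paragraph_text = "".join(run_texts)

print(run_texts)
print(paragraph_text)
```

<p>In a real document each run can additionally carry its own formatting properties, which is one reason Word splits text into runs in the first place.</p>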
<p>Looking at the source code of <code class="language-plaintext highlighter-rouge">python-docx</code> I noticed that when we call
<code class="language-plaintext highlighter-rouge">paragraph.text = ...</code>, what happens is that the contents of the paragraph get
deleted, and then a new run is added with the desired text. It is not clear to
me where exactly the style information is stored, but either way there is a
simple workaround for what we’re trying to do: we can simply modify the text of
an existing <em>run</em> in the paragraph, rather than clearing the entire paragraph and
adding a new run. This in fact also works for editing the first paragraph,
where before we had problems with disappearing text boxes:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">document</span> <span class="o">=</span> <span class="n">Document</span><span class="p">(</span><span class="s">"resume.docx"</span><span class="p">)</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="s">'resume.xml'</span><span class="p">,</span> <span class="s">'w'</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
    <span class="n">f</span><span class="p">.</span><span class="n">write</span><span class="p">(</span><span class="n">document</span><span class="p">.</span><span class="n">_element</span><span class="p">.</span><span class="n">xml</span><span class="p">)</span>
<span class="c1"># Change 'Rik Voorhaar' for 'Willem Hendrik Voorhaar'
</span><span class="n">paragraph</span> <span class="o">=</span> <span class="n">document</span><span class="p">.</span><span class="n">paragraphs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">run</span> <span class="o">=</span> <span class="n">paragraph</span><span class="p">.</span><span class="n">runs</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">run</span><span class="p">.</span><span class="n">text</span> <span class="o">=</span> <span class="s">'Willem Hendrik Voorhaar'</span>
<span class="c1"># Change 'Resume' for 'Curriculum Vitae'
</span><span class="n">paragraph</span> <span class="o">=</span> <span class="n">document</span><span class="p">.</span><span class="n">paragraphs</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">run</span> <span class="o">=</span> <span class="n">paragraph</span><span class="p">.</span><span class="n">runs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">run</span><span class="p">.</span><span class="n">text</span> <span class="o">=</span> <span class="s">'Curriculum Vitae'</span>
<span class="n">document</span><span class="p">.</span><span class="n">save</span><span class="p">(</span><span class="s">'CV.docx'</span><span class="p">)</span>
</code></pre></div></div>
<p>Doing this changes the text, but leaves all the style information the
same. Alright, now we know how to edit text. It’s more tricky than one might
expect, but it does work!</p>
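<p>The workaround can be wrapped in a small helper. The sketch below is not part of <code class="language-plaintext highlighter-rouge">python-docx</code>; <code class="language-plaintext highlighter-rouge">set_text_keep_style</code> is a hypothetical name, and it only relies on a paragraph having a <code class="language-plaintext highlighter-rouge">.runs</code> list whose items have a writable <code class="language-plaintext highlighter-rouge">.text</code>. It also assumes that runs holding only non-text content (such as an anchored text box) report an empty text, so that they are left untouched. To keep the example runnable without a Word file, it uses simple stand-in objects.</p>

```python
from types import SimpleNamespace


def set_text_keep_style(paragraph, new_text):
    """Put `new_text` in the first run that has text and blank the other text runs.

    Runs with empty text are skipped, on the assumption that they may hold
    non-text content that should not be cleared.
    """
    text_runs = [run for run in paragraph.runs if run.text]
    if not text_runs:
        raise ValueError("paragraph has no run with text to reuse")
    first, *rest = text_runs
    first.text = new_text
    for run in rest:
        run.text = ""


# Stand-in for a paragraph: an empty run followed by 'R' + 'esume'.
fake_paragraph = SimpleNamespace(
    runs=[
        SimpleNamespace(text=""),
        SimpleNamespace(text="R"),
        SimpleNamespace(text="esume"),
    ]
)
set_text_keep_style(fake_paragraph, "Curriculum Vitae")
texts = [run.text for run in fake_paragraph.runs]
print(texts)
```

<p>With a real document one would pass something like <code class="language-plaintext highlighter-rouge">document.paragraphs[1]</code> instead of the stand-in. The whole text then takes on the formatting of the run that is reused.</p>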
<h2 id="dealing-with-text-boxes">Dealing with text boxes</h2>
<p>Let’s say that next we want to edit the text box on the right-hand side of the
document, and add a skill to our list of skills. We’ve been diving into the
inner workings of Word documents, so it’s fair to say we know how to use
Microsoft Word, so let’s add the skill “Microsoft Word” to the list.</p>
<p>To do this we first want to figure out in which paragraph this information is
stored. We can do this by going through all the paragraphs in the document and
looking for the text “Skills”.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">re</span>
<span class="n">pattern</span> <span class="o">=</span> <span class="n">re</span><span class="p">.</span><span class="nb">compile</span><span class="p">(</span><span class="s">"Skills"</span><span class="p">)</span>
<span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">document</span><span class="p">.</span><span class="n">paragraphs</span><span class="p">:</span>
    <span class="k">if</span> <span class="n">pattern</span><span class="p">.</span><span class="n">search</span><span class="p">(</span><span class="n">p</span><span class="p">.</span><span class="n">text</span><span class="p">):</span>
        <span class="k">print</span><span class="p">(</span><span class="s">"Found the paragraph!"</span><span class="p">)</span>
        <span class="k">break</span>
<span class="k">else</span><span class="p">:</span>
    <span class="k">print</span><span class="p">(</span><span class="s">"Did not find the paragraph :("</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Did not find the paragraph :(
</code></pre></div></div>
<p>Seems like there is unfortunately no matching paragraph! This is because the
paragraph we want is <em>inside a text box</em>, and modifying text boxes is not supported
in <code class="language-plaintext highlighter-rouge">python-docx</code>. This is a known issue, but instead of giving up I decided to
add support for modifying text boxes to <code class="language-plaintext highlighter-rouge">python-docx</code> myself! It turned out not to
be too difficult to implement, despite my limited knowledge of both the package
and the inner structure of Word documents.</p>
<p>The first step is understanding how text boxes are encoded in the XML. It turns
out that the structure is something like this:</p>
<pre><code class="language-XML"><mc:AlternateContent>
  <mc:Choice Requires="wps">
    <w:drawing>
      <wp:anchor>
        <a:graphic>
          <a:graphicData>
            <wps:txbx>
              <w:txbxContent>
                ...
              </w:txbxContent>
            </wps:txbx>
          </a:graphicData>
        </a:graphic>
      </wp:anchor>
    </w:drawing>
  </mc:Choice>
  <mc:Fallback>
    <w:pict>
      <v:textbox>
        <w:txbxContent>
          ...
        </w:txbxContent>
      </v:textbox>
    </w:pict>
  </mc:Fallback>
</mc:AlternateContent>
</code></pre>
<p>The insides of the two <code class="language-plaintext highlighter-rouge"><w:txbxContent></code> elements are exactly identical. The
information is probably stored twice for legacy reasons. A quick Google search reveals
that <code class="language-plaintext highlighter-rouge">wps</code> is an XML namespace introduced in Office 2010, and WPS is short for
Word Processing Shape. The textbox is therefore stored twice to maintain
backwards compatibility with older Word versions. Not sure many people still use
Office 2007… Either way, this means that if we want to update the contents of
the textbox, we need to do it in two places.</p>
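<p>The “two places” can be illustrated without <code class="language-plaintext highlighter-rouge">python-docx</code>, using only the standard library. The sketch below builds a heavily simplified <code class="language-plaintext highlighter-rouge">mc:AlternateContent</code> element (the intermediate drawing and shape layers are left out), finds both <code class="language-plaintext highlighter-rouge">w:txbxContent</code> copies, and updates them together. The helper name <code class="language-plaintext highlighter-rouge">set_first_text</code> is made up for this example.</p>

```python
import xml.etree.ElementTree as ET

# Namespace URIs for the prefixes used in the structure above.
NS = {
    "mc": "http://schemas.openxmlformats.org/markup-compatibility/2006",
    "w": "http://schemas.openxmlformats.org/wordprocessingml/2006/main",
    "wps": "http://schemas.microsoft.com/office/word/2010/wordprocessingShape",
    "v": "urn:schemas-microsoft-com:vml",
}
DECLARATIONS = " ".join(f'xmlns:{prefix}="{uri}"' for prefix, uri in NS.items())

# Simplified text box: the same content is stored under Choice and Fallback.
XML = f"""<mc:AlternateContent {DECLARATIONS}>
  <mc:Choice Requires="wps">
    <wps:txbx><w:txbxContent><w:p><w:r><w:t>Skills</w:t></w:r></w:p></w:txbxContent></wps:txbx>
  </mc:Choice>
  <mc:Fallback>
    <v:textbox><w:txbxContent><w:p><w:r><w:t>Skills</w:t></w:r></w:p></w:txbxContent></v:textbox>
  </mc:Fallback>
</mc:AlternateContent>"""


def set_first_text(alternate_content, new_text):
    """Change the first text element in every copy of the text box content."""
    for content in alternate_content.findall(".//w:txbxContent", NS):
        content.find(".//w:t", NS).text = new_text


root = ET.fromstring(XML)
copies = root.findall(".//w:txbxContent", NS)
set_first_text(root, "Expertise")
texts = [copy.findtext(".//w:t", namespaces=NS) for copy in copies]
print(len(copies), texts)
```

<p>If only one of the two copies were changed, different versions of Word could end up showing different text for the same box.</p>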
<p>Next we need to figure out how to manipulate these Word objects. My idea is to
create a <code class="language-plaintext highlighter-rouge">TextBox</code> class that is associated with an <code class="language-plaintext highlighter-rouge"><mc:AlternateContent></code>
element, and which ensures that both <code class="language-plaintext highlighter-rouge"><w:txbxContent></code> elements are always
updated at the same time. First we make a class encoding a <code class="language-plaintext highlighter-rouge"><w:txbxContent></code>
element. For this we can build on the <code class="language-plaintext highlighter-rouge">BlockItemContainer</code> class already
implemented in <code class="language-plaintext highlighter-rouge">python-docx</code>. Inheriting from this class gives automatic support for
manipulating paragraphs inside of the container.</p>
<div class="language-py highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">TextBoxContent</span><span class="p">(</span><span class="n">BlockItemContainer</span><span class="p">):</span>
    <span class="s">"""Proxy for a ``<w:txbxContent>`` element."""</span>
</code></pre></div></div>
<p>Given an <code class="language-plaintext highlighter-rouge"><mc:AlternateContent></code> object, we can access the two <code class="language-plaintext highlighter-rouge"><w:txbxContent></code>
elements using the following XPath specifications:</p>
<div class="language-py highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">XPATH_CHOICE</span> <span class="o">=</span> <span class="s">"./mc:Choice/w:drawing/wp:anchor/a:graphic/a:graphicData//wps:txbx/w:txbxContent"</span>
<span class="n">XPATH_FALLBACK</span> <span class="o">=</span> <span class="s">"./mc:Fallback/w:pict//v:textbox/w:txbxContent"</span>
</code></pre></div></div>
<p>Then making a rudimentary <code class="language-plaintext highlighter-rouge">TextBox</code> class is very simple. We base it on the
<code class="language-plaintext highlighter-rouge">ElementProxy</code> class in <code class="language-plaintext highlighter-rouge">python-docx</code>. This class is meant for storing and
manipulating the children of an XML element.</p>
<div class="language-py highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">TextBox</span><span class="p">(</span><span class="n">ElementProxy</span><span class="p">):</span>
    <span class="s">"""Implements text boxes. Requires an `<mc:AlternateContent>` element."""</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">element</span><span class="p">,</span> <span class="n">parent</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">TextBox</span><span class="p">,</span> <span class="bp">self</span><span class="p">).</span><span class="n">__init__</span><span class="p">(</span><span class="n">element</span><span class="p">,</span> <span class="n">parent</span><span class="p">)</span>
        <span class="k">try</span><span class="p">:</span>
            <span class="p">(</span><span class="n">tbox1</span><span class="p">,)</span> <span class="o">=</span> <span class="n">element</span><span class="p">.</span><span class="n">xpath</span><span class="p">(</span><span class="n">XPATH_CHOICE</span><span class="p">)</span>
            <span class="p">(</span><span class="n">tbox2</span><span class="p">,)</span> <span class="o">=</span> <span class="n">element</span><span class="p">.</span><span class="n">xpath</span><span class="p">(</span><span class="n">XPATH_FALLBACK</span><span class="p">)</span>
        <span class="k">except</span> <span class="nb">ValueError</span> <span class="k">as</span> <span class="n">err</span><span class="p">:</span>
            <span class="k">raise</span> <span class="nb">ValueError</span><span class="p">(</span>
                <span class="s">"This element is not a text box; it should contain precisely two </span><span class="se">\
</span><span class="s"> ``<w:txbxContent>`` objects"</span>
            <span class="p">)</span> <span class="k">from</span> <span class="n">err</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">tbox1</span> <span class="o">=</span> <span class="n">TextBoxContent</span><span class="p">(</span><span class="n">tbox1</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">tbox2</span> <span class="o">=</span> <span class="n">TextBoxContent</span><span class="p">(</span><span class="n">tbox2</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span>
</code></pre></div></div>
<p>So far this is just good for storing the text box; we still need some code to
actually manipulate it. It would also be great to have a way to find all the
text boxes in a document. This is as simple as finding all the
<code class="language-plaintext highlighter-rouge"><mc:AlternateContent></code> elements with precisely two <code class="language-plaintext highlighter-rouge"><w:txbxContent></code> elements.
We can use the following function:</p>
<div class="language-py highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">find_textboxes</span><span class="p">(</span><span class="n">element</span><span class="p">,</span> <span class="n">parent</span><span class="p">):</span>
    <span class="s">"""
    List all text box objects in the document.
    Looks for all ``<mc:AlternateContent>`` elements, and selects those
    which contain a text box.
    """</span>
    <span class="n">alt_cont_elems</span> <span class="o">=</span> <span class="n">element</span><span class="p">.</span><span class="n">xpath</span><span class="p">(</span><span class="s">".//mc:AlternateContent"</span><span class="p">)</span>
    <span class="n">text_boxes</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">elem</span> <span class="ow">in</span> <span class="n">alt_cont_elems</span><span class="p">:</span>
        <span class="n">tbox1</span> <span class="o">=</span> <span class="n">elem</span><span class="p">.</span><span class="n">xpath</span><span class="p">(</span><span class="n">XPATH_CHOICE</span><span class="p">)</span>
        <span class="n">tbox2</span> <span class="o">=</span> <span class="n">elem</span><span class="p">.</span><span class="n">xpath</span><span class="p">(</span><span class="n">XPATH_FALLBACK</span><span class="p">)</span>
        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">tbox1</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">tbox2</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
            <span class="n">text_boxes</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">TextBox</span><span class="p">(</span><span class="n">elem</span><span class="p">,</span> <span class="n">parent</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">text_boxes</span>
</code></pre></div></div>
<p>We then update the <code class="language-plaintext highlighter-rouge">Document</code> class with a new <code class="language-plaintext highlighter-rouge">textboxes</code> attribute:</p>
<div class="language-py highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">@</span><span class="nb">property</span>
<span class="k">def</span> <span class="nf">textboxes</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    <span class="s">"""
    List all text box objects in the document.
    """</span>
    <span class="k">return</span> <span class="n">find_textboxes</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">_element</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span>
</code></pre></div></div>
<p>Now let’s test this out:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">document</span> <span class="o">=</span> <span class="n">Document</span><span class="p">(</span><span class="s">"resume.docx"</span><span class="p">)</span>
<span class="n">document</span><span class="p">.</span><span class="n">textboxes</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>[<docx.oxml.textbox.TextBox at 0x7faf395c3bc0>,
<docx.oxml.textbox.TextBox at 0x7faf395c3100>]
</code></pre></div></div>
<p>Now to manipulate the “Skills” section as we initially wanted, we first find the
right paragraph. Since the two <code class="language-plaintext highlighter-rouge"><w:txbxContent></code> objects have the same
paragraphs, we need to find the <em>index</em> of the paragraph that contains the text, and
the textbox it lives in:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">re</span>
<span class="k">def</span> <span class="nf">find_paragraph</span><span class="p">(</span><span class="n">pattern</span><span class="p">):</span>
    <span class="k">for</span> <span class="n">textbox</span> <span class="ow">in</span> <span class="n">document</span><span class="p">.</span><span class="n">textboxes</span><span class="p">:</span>
        <span class="k">for</span> <span class="n">i</span><span class="p">,</span><span class="n">p</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">textbox</span><span class="p">.</span><span class="n">paragraphs</span><span class="p">):</span>
            <span class="k">if</span> <span class="n">pattern</span><span class="p">.</span><span class="n">search</span><span class="p">(</span><span class="n">p</span><span class="p">.</span><span class="n">text</span><span class="p">):</span>
                <span class="k">return</span> <span class="n">textbox</span><span class="p">,</span><span class="n">i</span>
<span class="n">pattern</span> <span class="o">=</span> <span class="n">re</span><span class="p">.</span><span class="nb">compile</span><span class="p">(</span><span class="s">"Skills"</span><span class="p">)</span>
<span class="n">textbox</span><span class="p">,</span> <span class="n">i</span> <span class="o">=</span> <span class="n">find_paragraph</span><span class="p">(</span><span class="n">pattern</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">textbox</span><span class="p">.</span><span class="n">paragraphs</span><span class="p">[</span><span class="n">i</span><span class="p">].</span><span class="n">text</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Skills
</code></pre></div></div>
<p>Now to insert a new skill, we need to create a new paragraph with the text
“Microsoft Word”. For this we can find the paragraph right after it, and call this
paragraph’s <code class="language-plaintext highlighter-rouge">insert_paragraph_before</code> method with the appropriate text and style
information. The paragraph in question is the one containing the word
“Research”. I want to copy the style of this paragraph to the new paragraph, but
for some reason the style information is empty for this paragraph. However, I
know that the style of this paragraph should be <code class="language-plaintext highlighter-rouge">'Skillsentries'</code>, so I can
just use that directly.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">style</span> <span class="o">=</span> <span class="n">document</span><span class="p">.</span><span class="n">styles</span><span class="p">[</span><span class="s">'Skillsentries'</span><span class="p">]</span>
<span class="n">pattern</span> <span class="o">=</span> <span class="n">re</span><span class="p">.</span><span class="nb">compile</span><span class="p">(</span><span class="s">"Research"</span><span class="p">)</span>
<span class="n">textbox</span><span class="p">,</span><span class="n">i</span> <span class="o">=</span> <span class="n">find_paragraph</span><span class="p">(</span><span class="n">pattern</span><span class="p">)</span>
<span class="n">p1</span> <span class="o">=</span> <span class="n">textbox</span><span class="p">.</span><span class="n">tbox1</span><span class="p">.</span><span class="n">paragraphs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">p2</span> <span class="o">=</span> <span class="n">textbox</span><span class="p">.</span><span class="n">tbox2</span><span class="p">.</span><span class="n">paragraphs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="p">(</span><span class="n">p1</span><span class="p">,</span><span class="n">p2</span><span class="p">):</span>
    <span class="n">p</span><span class="p">.</span><span class="n">insert_paragraph_before</span><span class="p">(</span><span class="s">"Microsoft Word"</span><span class="p">,</span> <span class="n">style</span><span class="p">)</span>
<span class="n">document</span><span class="p">.</span><span class="n">save</span><span class="p">(</span><span class="s">"CV.docx"</span><span class="p">)</span>
</code></pre></div></div>
<p>When now opening the Word document, we see the item “Microsoft Word” in my list
of skills, with the right style and everything. I did cheat a little; I needed
to make some additional technical changes to the code for this all to work, but
the details are not super important. If you want to use this feature, you can
use <a href="https://github.com/RikVoorhaar/python-docx">my fork of python-docx</a>. My
solution is still a little hacky, so I don’t think it will be added to the main
repository, but it does work fine for my purposes.</p>
<h2 id="conclusion">Conclusion</h2>
<p>In summary, we <em>can</em> use Python to edit Word documents. However, the
<code class="language-plaintext highlighter-rouge">python-docx</code> package is not fully mature, and using it for editing
highly stylized Word documents is a bit painful (but possible!). It is, however,
quite easy to extend with new functionality, in case you do need to do this. On
the other hand, there is quite extensive functionality in Visual Basic to edit
Word documents, and the whole Word API is built around Visual Basic.</p>
<p>While I now have all the tools available to automatically update my CV using
Python, I will actually refrain from doing it. It is a lot of work to set up
properly, and needs active maintenance every time I would want to change the
styling of my CV. It’s probably a better idea to just manually edit it every
time I need to. Automation isn’t always worth it. But I wouldn’t be
surprised if this newfound skill turns out to be useful at some point in the future for
me.</p>Rik VoorhaarParsing and editing Word documents automatically can be extremely useful, but doing it in Python is not that straightforward.Blind deconvolution #4: Blind deconvolution2021-05-31T00:00:00+00:002021-05-31T00:00:00+00:00https://rikvoorhaar.com/deconvolution-part4<p>In this final part on the deconvolution series, we will look at blind deconvolution. That is, we
want to remove blur from images while having only partial knowledge about how the image was blurred.
First of all we will develop a simple method to generate somewhat realistic forms of combined motion
and Gaussian blur. Then we will try a modification of the Richardson-Lucy deconvolution algorithm as
a method for blind deconvolution – this doesn’t work very well, but does highlight a common issue
with deconvolution algorithms. Then finally we will combine the image priors discussed in part 2
with Bayesian optimization to get a decent (but slow) method for blind deconvolution.</p>
<h2 id="realistic-blur">Realistic blur</h2>
<p>What constitutes ‘realistic’ blur obviously depends on context, but in the case of taking pictures
with a hand-held camera or smartphone, it includes both motion blur and a form of lens blur.
Generating lens blur is easy; we can just use a Gaussian blur. For motion blur we previously looked
only at straight lines, but this isn’t very realistic. Natural motion is rarely just in a straight
line, but is more erratic.</p>
<p>To model this we can take inspiration from physical processes such as Brownian motion: we can model
motion blur as the path taken by a particle with an initial velocity, which is constantly perturbed
during the motion. We want to add Gaussian blur on top of that, which can simply be done by taking
the image of such a path and convolving it with a Gaussian point spread function. However, we should
also take into account the speed of the particle; if we move a camera very fast then the camera
spends less exposure time at any particular point. Therefore we should make the intensity of the
blur inversely proportional to the speed at any point. The end result looks something like this:</p>
<p><img src="/imgs/deconvolution_part4/part4_1_0.png" alt="png" /></p>
<p>In practice we will consider this kind of blur at a much smaller resolution, for example of size
15x15. Below we show how such a kernel will affect for example the St. Vitus image.</p>
<p><img src="/imgs/deconvolution_part4/part4_3_0.png" alt="png" /></p>
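<p>The construction described above can be sketched in a few lines of NumPy. This is only a minimal illustration and not the code used to produce the figures: the function names, the <code class="language-plaintext highlighter-rouge">jitter</code> parameter and the way the path is rescaled to the grid are all assumptions. The key point is that the particle is advanced with a fixed time step and every step deposits the same amount of exposure, so that the deposited intensity along the path is automatically lower where the particle moves fast.</p>

```python
import numpy as np

def gaussian_1d(sigma, radius):
    """Normalized 1D Gaussian sampled on the integers -radius..radius."""
    t = np.arange(-radius, radius + 1)
    g = np.exp(-0.5 * (t / sigma) ** 2)
    return g / g.sum()

def motion_blur_kernel(size=15, n_steps=500, jitter=0.2, sigma=1.0, seed=0):
    """Random-walk motion blur combined with Gaussian blur (illustrative)."""
    rng = np.random.default_rng(seed)
    pos = np.zeros(2)
    vel = rng.normal(size=2)                      # initial velocity
    path = np.empty((n_steps, 2))
    for step in range(n_steps):
        path[step] = pos
        vel = vel + jitter * rng.normal(size=2)   # constantly perturb the velocity
        pos = pos + vel                           # advance by one unit of time

    # Rescale the path so it fits on the grid, with a margin for the Gaussian.
    margin = int(np.ceil(3 * sigma))
    path = path - path.min(axis=0)
    scale = (size - 1 - 2 * margin) / max(path.max(), 1e-12)
    idx = np.rint(path * scale).astype(int) + margin

    # Every time step deposits the same exposure, so pixels that are crossed
    # quickly collect less intensity than pixels where the particle lingers.
    kernel = np.zeros((size, size))
    np.add.at(kernel, (idx[:, 0], idx[:, 1]), 1.0)

    # Separable Gaussian blur on top of the motion path.
    g = gaussian_1d(sigma, margin)
    kernel = np.apply_along_axis(np.convolve, 0, kernel, g, mode="same")
    kernel = np.apply_along_axis(np.convolve, 1, kernel, g, mode="same")
    return kernel / kernel.sum()

kernel = motion_blur_kernel()
```

<p>Changing <code class="language-plaintext highlighter-rouge">seed</code> gives a different path, and a larger <code class="language-plaintext highlighter-rouge">jitter</code> makes the motion more erratic.</p>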
<h2 id="richardson-lucy-blind-deconvolution">Richardson-Lucy blind deconvolution</h2>
<p>Recall that in the Richardson-Lucy algorithm we try to solve the deconvolution problem \(y=x*k\) by
using an iteration of the form</p>
\[x_{i+1} = x_i\odot \left(\frac{y}{x_i*k}*k^*\right)\]
<p>This method is completely symmetric in \(k\) and \(x\), so given an estimate \(x_i\) of \(x\) we can estimate
the kernel \(k\) by the same method:</p>
\[k_{j+1} = k_j\odot \left(\frac{y}{x_i*k_j}*x_i^*\right)\]
<p>A simple idea for blind deconvolution is therefore to alternatingly estimate \(k\) from \(x\) and
vice-versa. We can see the result of this procedure below:</p>
<p><img src="/imgs/deconvolution_part4/part4_5_1.png" alt="png" /></p>
<p>The problem with this Richardson-Lucy-based algorithm is that the point spread function tends to
converge to a (shifted) delta function. This is an inherent problem with many blind deconvolution
algorithms, especially those based on finding a maximum a posteriori (MAP) estimate of both the
kernel and image combined. For this particular algorithm it isn’t immediately obvious why it tends
to do this, since the analysis of this algorithm is relatively complicated. Somehow the kernel
update step tends to promote sparsity. This tends to happen irrespective of how we initialize the
point spread function, or the relative number of steps spent estimating the PSF or the image.</p>
<p>There are heuristic ways to get around this, but overall it is difficult to make a technique like
this work well. It also doesn’t use the wonderful things we learned about image priors in part 2. We
need a method that can actively avoid converging to extreme points such as this delta function.</p>
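<p>For reference, the alternating scheme from this section can be sketched as follows. This is a simplified illustration under a few assumptions that are not spelled out in the text: convolutions are periodic (computed with the FFT), the kernel is stored as an array of the same size as the image, and the kernel is clipped and renormalized to sum to one after every update. The function names are made up for this example.</p>

```python
import numpy as np

def conv(a, b):
    """Circular convolution a*b of two equally sized arrays via the FFT."""
    return np.real(np.fft.ifft2(np.fft.fft2(a) * np.fft.fft2(b)))

def conv_flipped(a, b):
    """Circular convolution a*b^*, where b^* is the flipped version of b."""
    return np.real(np.fft.ifft2(np.fft.fft2(a) * np.conj(np.fft.fft2(b))))

def rl_update(current, other, y, eps=1e-12):
    """One Richardson-Lucy step for `current`, keeping `other` fixed."""
    ratio = y / np.maximum(conv(current, other), eps)
    return current * conv_flipped(ratio, other)

def blind_richardson_lucy(y, n_outer=10, n_inner=5):
    x = y.copy()                            # initial image estimate
    k = np.full(y.shape, 1.0 / y.size)      # flat initial point spread function
    for _ in range(n_outer):
        for _ in range(n_inner):            # estimate k from the current x
            k = np.clip(rl_update(k, x, y), 0.0, None)
            k /= k.sum()                    # keep the kernel normalized
        for _ in range(n_inner):            # estimate x from the current k
            x = rl_update(x, k, y)
    return x, k
```

<p>Because the same update is used for both unknowns, the only difference between the two inner loops is which argument is held fixed.</p>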
<h2 id="parametrizing-the-point-spread-functions">Parametrizing the point spread functions</h2>
<p>In part 2 we discussed different image priors, of which the most promising prior is based on
non-local self-similarity. This assigns to an image \(x\) a score \(L(x)\) signifying how ‘natural’
this image is. We saw that it indeed gave higher scores for images that are appropriately sharpened.
A simple idea is then to try different point spread functions, and use the one with the highest
score. If we denote by \(x(k)\) the result of applying deconvolution with kernel \(k\), then we want to
solve the maximization problem:
\(\max_{k}L(x(k))\)</p>
<p>If we naively try to maximize this function we run into the problem that the space of all kernels is
quite large; a \(15\times 15\) kernel obviously needs \(15^2=225\) parameters. Since computing the image
prior is relatively expensive (as is the deconvolution), exploring this large space is not feasible.
Moreover, the function is relatively noisy, and has the problem that it can give large scores to
oversharpened images.</p>
<p>We therefore need a way to describe the point spread functions using only a few parameters.
Moreover, this description should actively avoid points that are not interesting, such as a delta
function or a point spread function that would result in heavy oversharpening of the image.</p>
<p>There are many ways to describe a point spread function using only a few parameters. One way that
I propose is by writing it as a sum of a small number of Gaussian point spread functions. However,
instead of having the centered symmetric Gaussians we have considered so far, we will allow an
arbitrary mean and covariance matrix. This changes respectively the center and the shape of the
point spread function. That is, it depends on the parameters \(\mu=(\mu_1,\mu_2)\) and a 2x2
(symmetric, positive definite) matrix \(\Sigma\). Then the point spread function is given by</p>
\[k[i,j]\propto
\exp\left(-\tfrac12(i-\mu_1,j-\mu_2)\Sigma^{-1}(i-\mu_1,j-\mu_2)^\top\right),\qquad\sum_{i,j}k[i,j]=1\]
<p>To be precise, we can describe the covariance matrix \(\Sigma\) using three parameters
\(\lambda_1,\lambda_2>0\) and \(\theta\in[0,\pi)\) using the decomposition</p>
\[\Sigma = \begin{pmatrix}\cos\theta &\sin\theta\\-\sin\theta&\cos\theta\end{pmatrix}
\begin{pmatrix}\lambda_1&0\\0&\lambda_2\end{pmatrix}
\begin{pmatrix}\cos\theta &-\sin\theta\\\sin\theta&\cos\theta\end{pmatrix}\]
<p>We then use an additional parameter \(t_i\) per kernel to combine different kernels of this type, by taking the
linear combination \(t_1k_1+t_2k_2+\dots+t_nk_n\).</p>
<p>This gives a total of 6 parameters per mixture component, but for the first component we can set the
mean \(\mu\) to \(0\) and use a magnitude \(t_1\) of 1, reducing to 3 parameters. For now we will
just use two mixture components (\(n=2\)), and focus our attention on <em>how</em> to optimize this.</p>
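<p>To make the parametrization concrete, here is a small NumPy sketch that builds such a mixture. It assumes that pixel coordinates are measured from the centre of the kernel, that the Gaussian has the usual factor \(-\tfrac12\) in the exponent, and that the final mixture is renormalized to sum to one. The helper names <code class="language-plaintext highlighter-rouge">gaussian_psf</code> and <code class="language-plaintext highlighter-rouge">mixture_psf</code> are hypothetical.</p>

```python
import numpy as np

def gaussian_psf(size, mu, lam1, lam2, theta):
    """Gaussian PSF on a size x size grid with mean mu and rotated covariance."""
    c, s = np.cos(theta), np.sin(theta)
    rot = np.array([[c, s], [-s, c]])
    cov = rot @ np.diag([lam1, lam2]) @ rot.T      # Sigma = R diag(lam1, lam2) R^T
    prec = np.linalg.inv(cov)

    # Coordinates relative to the centre of the kernel (an assumption).
    half = (size - 1) / 2
    i, j = np.meshgrid(np.arange(size) - half, np.arange(size) - half, indexing="ij")
    d = np.stack([i - mu[0], j - mu[1]], axis=-1)

    quad = np.einsum("...a,ab,...b->...", d, prec, d)   # d^T Sigma^{-1} d per pixel
    k = np.exp(-0.5 * quad)
    return k / k.sum()

def mixture_psf(size, components):
    """Combine components given as tuples (t, mu, lam1, lam2, theta)."""
    total = sum(t * gaussian_psf(size, mu, l1, l2, th)
                for t, mu, l1, l2, th in components)
    return total / total.sum()

# First component: centred with weight 1 (3 free parameters).
# Second component: all 6 parameters are free.
psf = mixture_psf(15, [(1.0, (0.0, 0.0), 4.0, 1.0, 0.3),
                       (0.5, (2.0, -1.0), 2.0, 2.0, 0.0)])
```

<p>With two components this gives the 3 + 6 = 9 numbers that the optimizer has to search over.</p>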
<h2 id="bayesian-optimization-for-blind-deconvolution">Bayesian optimization for blind deconvolution</h2>
<p>We now know how to parameterize the point spread functions, and what function we want to optimize
(the image prior). Next is deciding <em>how</em> to optimize this. In this case we have a complicated,
noisy function that is expensive to compute, and with no easy way to compute its derivatives. In
situations like these Bayesian optimization or other methods of ‘black-box’ optimization make the
most sense.</p>
<p>How this works is that we sample our function \(L\colon \Omega\to \mathbb R\) in several points
\((z_1,\dots,z_n)\in\Omega\), where \(\Omega\) is our parameter search space. Based on these samples,
we build a <em>surrogate model</em> \(\widetilde L\colon\Omega\to \mathbb R\) for the function \(L\). We
can then optimize the surrogate model \(\widetilde L\) to obtain a new point \(z_{n+1}\). We then
compute \(L(z_{n+1})\), and update the surrogate model with this new information. This is then
repeated a number of times, or until convergence. So long as the surrogate model is good, this can
find an optimal point for the function \(L\) of interest much faster than many other optimization
methods.</p>
<p>The key property of this surrogate model is that it should be easy to compute, yet still model the
true function reasonably well. In addition to this, we want to incorporate uncertainty into the
surrogate model. Uncertainty enters in two ways: the function \(L\) may be noisy, and there is the
fact that the surrogate model will be more accurate closer to previously evaluated points. This
leads to Bayesian optimization. The surrogate model is probabilistic in nature, and during
optimization we can sample points both to reduce the variance (explore regions where the model is
unsure), and to improve the expected value (explore regions of the search space where the model thinks
the optimal point should lie).</p>
<p>One type of surrogate model that is popular for this purpose is the Gaussian process (GP) (also
known as ‘kriging’ in this context). We will give a brief description of Gaussian processes. We
model the function values of the surrogate model \(\widetilde L\) as random variables. More
specifically we model the function value at a point \(z\) to depend on the samples:</p>
\[\widetilde L(z) | z_1,\dots,z_n \sim N(\mu,\sigma^2),\]
<p>where the mean \(\mu\) is a weighted average of the values at the sampled points \((z_1,\dots, z_n)\),
with weights that depend on the distances \(\|z-z_i\|\). The variance \(\sigma^2\) is determined by a function \(K(z,z') =
K(\|z-z'\|)\) which gives the prior covariance between the function values at two points, and which decreases the more distant the
points. Note that \(K\) only depends on the distance between two points. At the sampled points
\((z_1,\dots,z_n)\) we know the function \(\widetilde L(z)\) to high accuracy, and hence the variance \(\sigma^2\)
is small there, but as we go further from any of the sampled points the variance increases.</p>
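<p>As an illustration, the mean and variance of such a model can be computed directly with NumPy. The sketch below is not what <code class="language-plaintext highlighter-rouge">scikit-optimize</code> does internally; it assumes a zero prior mean, a squared-exponential function \(K\) with unit prior variance, and a small noise term, and the function names are chosen for this example only.</p>

```python
import numpy as np

def rbf(a, b, length=1.0):
    """Squared-exponential K; it only depends on the distance between points."""
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * d2 / length**2)

def gp_posterior(z_train, l_train, z_new, length=1.0, noise=1e-6):
    """Mean and variance of the surrogate at z_new, given samples (z_train, l_train)."""
    K = rbf(z_train, z_train, length) + noise * np.eye(len(z_train))
    Ks = rbf(z_new, z_train, length)
    mean = Ks @ np.linalg.solve(K, l_train)                  # weighted combination of samples
    var = 1.0 - np.einsum("ij,ji->i", Ks, np.linalg.solve(K, Ks.T))
    return mean, np.clip(var, 0.0, None)

z_train = np.array([[0.0], [1.0], [2.0]])
l_train = np.sin(z_train[:, 0])
mean, var = gp_posterior(z_train, l_train, np.array([[1.0], [10.0]]))
```

<p>In this toy example the variance at the sampled point \(z=1\) is close to zero, while at the far-away point \(z=10\) it is close to the prior variance of 1.</p>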
<p>Because of the specific structure of the Gaussian process model, it is easy to fit to data and make
predictions at new points. As a result an optimal value for this surrogate model is easy to compute.
We will use an implementation of GP-based Bayesian optimization from <code class="language-plaintext highlighter-rouge">scikit-optimize</code>. All in all
this gives us the results shown below.</p>
<p><img src="/imgs/deconvolution_part4/part4_7_1.svg" alt="svg" /></p>
<p>As we can see, the estimated point spread function is still far from perfect, but nevertheless the
deblurred image looks better than the blurred image. If we blur the image with larger kernels, or
stronger blur overall, recovery becomes even harder with this method. If we apply it to a different
image the result is comparable. One problem that is apparent is the fact that the point spread
function tends to shift the image. This can fortunately be corrected, by either changing the point
spread function or shifting the image after deconvolution.</p>
<p><img src="/imgs/deconvolution_part4/part4_11_0.svg" alt="svg" /></p>
<p>There are probably several reasons why this model doesn’t give perfect results. First is that the
image prior isn’t perfect, but it seems that most image priors tend to give quite noisy outputs, or
give high scores due to artifacts created by the deconvolution algorithm. Secondly, the parameter
space of this model is still quite big, especially if the prior function depends in a complicated
manner on these parameters. However, it seems many methods used in the literature use even larger
search spaces for the kernels, many algorithms even using no compression of the search space at all
and still claiming good results.</p>
<p>While I knew from the get-go that blind deconvolution is hard, it turned out to be even harder to do
right than I expected. I read a lot of literature on the subject, and I learned a lot. Many papers
give interesting algorithms and ideas for blind deconvolution methods. What I found however is that
most papers were quite vague in their description and almost never included code. This makes doing
research in this field quite difficult, since it can be very difficult to estimate whether or not a
method is actually useful. Moreover, if a method looks promising then implementing it can become
very difficult without adequate details.</p>Rik VoorhaarFinally, let's look at how we can automatically sharpen images, without knowing how they were blurred in the first place.Blind Deconvolution #3: More about non-blind deconvolution2021-05-02T00:00:00+00:002021-05-02T00:00:00+00:00https://rikvoorhaar.com/deconvolution-part3<p>In part 1 we saw how to do non-blind image deconvolution. In part 2 we saw a couple good image
priors and we saw how they can be used for simple blind deconvolution. This worked well for
deconvolution of Gaussian point spread functions, but it gave bad artifacts for motion blur kernels.
Typical distortions seen in pictures taken by conventional cameras have both a motion blur and
Gaussian component, so having good deconvolution for motion blur is absolutely essential.</p>
<p>We will explore two methods to improve the deconvolution method. First is a simple modification to
our current method, and second is a more expensive iterative method for deconvolution that works
better for sparse kernels.</p>
<h2 id="an-improved-wiener-filter">An improved Wiener filter</h2>
<p>Recall that deconvolution comes down to solving the equation
\(y = k*x,\)
where \(y\) is the observed (blurred) image, \(k\) is the point-spread function, and \(x\) is the
unobserved (sharp) image. If we take a discrete Fourier Transform (DFT) then this equation becomes
\(Y = K\odot X,\)
where capital letters denote the Fourier-transformed variables, and \(\odot\) is the <em>pointwise</em>
multiplication. To solve the deconvolution problem we can then do pointwise division by \(K\) and then
do the inverse Fourier transform. Because \(K\) may have zero or near-zero entries, we can run into
numerical instability. A quick fix is to instead multiply by \(K^* / (|K|^2+\epsilon)\), giving the
solution</p>
\[x = \mathcal F^{-1}\left(Y \odot \frac{K^*}{|K|^2+\epsilon}\right)\]
<p>This is fast to compute, and gives decent results. This simple method of deconvolution is known as
the Wiener filter. In the situation where there is some noise \(n\) such that \(y=k*x+n\), this
corresponds (for a certain value of \(\epsilon\)) to \(x^*\) minimizing the expected square error
\(E(\|x-x^*\|^2)\). Instead of minimizing the error, we can accept that \(\|k*x-y\|^2\approx \|n\|^2\),
and then find the <em>smoothest</em> \(x^*\) with that error, to avoid ringing artifacts. Smoothness can be
measured by the size of the Laplacian \(\Delta x^*\). This leads to the problem</p>
\[\begin{array}{ll}
\text{minimize} & \|\Delta x\|^2 \\
\text{subject to} & \|k*x-y\|^2\leq \|n\|^2
\end{array}\]
<p>If \(L\) is the Fourier transform of the Laplacian kernel, then the solution to this problem has the form</p>
\[x = \mathcal F^{-1}\left(Y \odot \frac{K^*}{|K|^2+\gamma |L|^2}\right)\]
<p>where the parameter \(\gamma>0\) is determined by the noise level. In the end this is a simple
modification to the Wiener filter, that should give fewer ringing artifacts. Let’s see what this does
in practice.</p>
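<p>Both variants fit in a single short function. The sketch below is a minimal NumPy version, assuming periodic boundary conditions and a kernel that is zero-padded to the image size with its centre moved to index \((0,0)\). Putting \(\epsilon\) and \(\gamma\) in the same denominator is a convenience of this example, so that either formula is obtained by setting the other parameter to zero. The names <code class="language-plaintext highlighter-rouge">pad_kernel</code> and <code class="language-plaintext highlighter-rouge">wiener</code> are hypothetical.</p>

```python
import numpy as np

def pad_kernel(kernel, shape):
    """Zero-pad a small kernel to `shape` and move its centre to index (0, 0)."""
    out = np.zeros(shape)
    kh, kw = kernel.shape
    out[:kh, :kw] = kernel
    return np.roll(out, (-(kh // 2), -(kw // 2)), axis=(0, 1))

def wiener(y, k, eps=1e-3, gamma=0.0):
    """Wiener-type deconvolution; gamma > 0 adds the Laplacian regularization."""
    K = np.fft.fft2(pad_kernel(k, y.shape))
    laplacian = np.array([[0.0, 1.0, 0.0], [1.0, -4.0, 1.0], [0.0, 1.0, 0.0]])
    L = np.fft.fft2(pad_kernel(laplacian, y.shape))
    denom = np.abs(K) ** 2 + eps + gamma * np.abs(L) ** 2
    return np.real(np.fft.ifft2(np.fft.fft2(y) * np.conj(K) / denom))
```

<p>Calling <code class="language-plaintext highlighter-rouge">wiener(y, k, eps=1e-3)</code> gives the plain filter, and <code class="language-plaintext highlighter-rouge">wiener(y, k, eps=0.0, gamma=1e-2)</code> the Laplacian-regularized one. The parameter values here are placeholders and need tuning per image.</p>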
<p><img src="/imgs/deconvolution_part3/part3_2_0.svg" alt="svg" /></p>
<p>In the picture above we tried to deblur a motion blur consisting of a diagonal strip of 10 pixels.
The deblurring is done with a kernel of 9.6 pixels (the last pixel on either end is dimmed). We do
this both with and without the Laplacian, with amounts of regularization so that the two methods
have a similar amount of ringing artifacts. The two methods look very similar, and if anything the
method without Laplacian may look a little sharper. The reason the methods behave so similarly is
probably because the Fourier transform of the Laplacian (shown below) has a fairly spread-out
distribution and is therefore not too different from the uniform distribution we use in the Wiener
filter.</p>
<p><img src="/imgs/deconvolution_part3/part3_4_1.svg" alt="svg" /></p>
<h2 id="richardson-lucy-deconvolution">Richardson-Lucy deconvolution</h2>
<p>There are many iterative deconvolution methods, and one often-used method in particular is
Richardson-Lucy deconvolution. The iteration step is given by</p>
\[x_{i+1} = x_i\odot \left(\frac{y}{x_i*k}*k^*\right)\]
<p>Here \(k^*\) is the flipped point spread function; its Fourier transform is the complex conjugate of
the Fourier transform of \(k\). As initial iterate we typically pick \(x_0=y\). Note that if
\(\sum_{a,b}k_{ab} = 1\), then \(\mathbf 1*k = \mathbf 1\), with \(\mathbf 1\) a constant 1 signal.
Therefore if we plug in \(x_i = \lambda x\) we obtain</p>
\[x_{i+1} = x_i \odot \left(\frac{y}{\lambda y}*k^*\right) = x_i\odot \frac1\lambda \mathbf 1*k^* =
\frac{x_i}\lambda = x\]
<p>Taking \(\lambda=1\) shows that \(x\) is a fixed point of the Richardson-Lucy algorithm, and for other
values of \(\lambda\) we see that a rescaled copy of \(x\) is mapped back to \(x\) in a single step. This
computation says nothing about convergence from a general starting point, however. In practice on natural
images, if initialized with \(x_0=y\), it does seem to
converge. Below we try this algorithm for different numbers of iterations, considering the same image
and point spread function as before.</p>
<p><img src="/imgs/deconvolution_part3/part3_6_0.svg" alt="svg" /></p>
<p>We see very similar ringing artifacts as with the Wiener filter. The number of iterations of the
algorithm is related to the size of the regularization constant. The more iterations, the sharper
the image is, but also the more pronounced the ringing artifacts are.</p>
<p>Like with the Wiener filter, we need to add a small positive constant when dividing, to avoid
division-by-zero errors. Unlike the Wiener filter however, Richardson-Lucy deconvolution is very
insensitive to the amount of regularization used.</p>
<p>Richardson-Lucy deconvolution is much slower than the Wiener filter, requiring perhaps 100 iterations to
reach a good result. Each iteration takes roughly as long as applying the Wiener filter. Fortunately
the algorithm is easy to implement on a GPU, and each iteration on the (426, 640) image above takes
only about 1ms on my computer with a simple GPU implementation using <code class="language-plaintext highlighter-rouge">cupy</code>.</p>
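<p>A CPU version of the iteration is only a few lines of NumPy. This is a sketch rather than the GPU implementation mentioned above; it assumes periodic boundary conditions, and the helper <code class="language-plaintext highlighter-rouge">psf_to_otf</code> and the optional starting point <code class="language-plaintext highlighter-rouge">x0</code> are additions for this example.</p>

```python
import numpy as np

def psf_to_otf(psf, shape):
    """FFT of the PSF, zero-padded to `shape` with its centre moved to (0, 0)."""
    padded = np.zeros(shape)
    kh, kw = psf.shape
    padded[:kh, :kw] = psf
    padded = np.roll(padded, (-(kh // 2), -(kw // 2)), axis=(0, 1))
    return np.fft.fft2(padded)

def richardson_lucy(y, psf, n_iter=100, x0=None, eps=1e-12):
    """Richardson-Lucy deconvolution of y with a known, normalized PSF."""
    K = psf_to_otf(psf, y.shape)
    x = y.copy() if x0 is None else x0.copy()
    for _ in range(n_iter):
        blurred = np.real(np.fft.ifft2(np.fft.fft2(x) * K))   # current estimate convolved with k
        ratio = y / np.maximum(blurred, eps)                  # small constant avoids division by zero
        # Multiplying by conj(K) convolves with the flipped kernel k^*.
        x = x * np.real(np.fft.ifft2(np.fft.fft2(ratio) * np.conj(K)))
    return x
```

<p>With exact data \(y = x*k\), a single step started from \(x_0=\lambda x\) returns \(x_0/\lambda\), in line with the computation above. Since every step is just a handful of FFTs and pointwise operations, replacing NumPy arrays by GPU arrays should require few changes.</p>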
<h2 id="boundary-effects">Boundary effects</h2>
<p>One issue that I have so far swept under the rug is the problem of boundary effects. If we convolve
an \((n,m)\) image by a \((\ell,\ell)\) kernel, then the result is an image of size \((n+\ell-1,
m+\ell-1)\), and not \((n,m)\). There is typically a ‘fuzzy border’ around the image, which we crop
away when displaying, but not when deconvolving. In real life we don’t have the luxury of including
this fuzzy border around the image, and this can lead to heavy artifacts when deconvolving an image.
Below is the St. Vitus church image blurred with \(\sigma=3\) Gaussian blur, and subsequently
deblurred using a Wiener filter with and without using the border around the image.</p>
<p><img src="/imgs/deconvolution_part3/part3_10_0.svg" alt="svg" /></p>
<p>The ringing at the boundary is known as Gibbs oscillation. The reason it occurs is because the
deconvolution method implicitly assumes the image is periodic. This is because the convolution
theorem (stating that convolution becomes multiplication after a (discrete) Fourier transform) needs
the assumption that the signal is periodic. If we periodically stack a natural image, we
find a sudden sharp transition at the boundary, and this contributes to high-frequency components in
the Fourier transform, giving the sharp oscillations at the boundary.</p>
<p>The more we regularize the deconvolution, the smaller the boundary effects. This is because
regularization essentially acts as a low-pass filter, getting rid of high-frequency effects.
However, this also blurs the image considerably. For Richardson-Lucy deconvolution we essentially
have the same problem.</p>
<p>The straightforward way to deal with this problem is to extend the image to mimic the ‘fuzzy’ border
introduced by convolution. Or better yet, we should pad the image in such a way that the image is as
regular as possible when stacked periodically. This is a strategy <a href="https://doi.org/10.1109/ICIP.2008.4711802">employed by Liu and
Jia</a>; they extend the image to be periodic by using three
different ‘tiles’ stacked in a pattern shown below. The image is then cropped to the dotted line,
and this gives a periodic image. The tiles are optimized such that the image is continuous along
each boundary, and such that the total Laplacian is minimized.</p>
<p><img src="/imgs/deconvolution_part3/part3_12_1.svg" alt="svg" /></p>
<p>There are many similar methods in the literature. Unfortunately, all of these methods are
complicated, and very few methods include a reference implementation. If there is one, it is almost
always in Matlab. This seems to be a general problem when reading literature about (de)convolution and image processing; for some reason in this scientific community it is not standard practice to include code with papers, and descriptions of algorithms are often vague and require significant work to translate to working code. I found a Python implementation of Liu-Jia’s algorithm <a href="https://github.com/ys-koshelev/nla_deblur/blob/90fe0ab98c26c791dcbdf231fe6f938fca80e2a0/boundaries.py">in this GitHub repository</a>.</p>
<p>Below we see the Laplacian of the image extended using Liu-Jia’s method, using zero padding and by reflecting the image. We see that both in the reflected image, and the one using Liu-Jia’s method there are no large values of the Laplacian around the border, because of the soft transition to the border.</p>
<p><img src="/imgs/deconvolution_part3/part3_15_1.svg" alt="svg" /></p>
<p>Next we can check if these periodic extensions of the images actually reduce boundary artifacts when deconvolving. Below we see the three methods for both the Wiener and Richardson-Lucy (RL) deconvolution in action on an image distorted with \(\sigma=3\) Gaussian blur.</p>
<p><img src="/imgs/deconvolution_part3/part3_17_0.svg" alt="svg" /></p>
<p>We can see that Liu-Jia’s method gives a significant improvement, especially for the Wiener
filter. More strikingly, the reflective padding works even better. This is because the convolution that distorted the image implicitly used reflective padding as well. If you change the settings of the convolution blurring the image, then the results will not be as good. Liu-Jia’s method probably works the best out of the box on images blurred by natural means.</p>
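<p>Reflective padding is the easiest of these strategies to try. A minimal sketch, assuming that <code class="language-plaintext highlighter-rouge">deconvolve</code> is any function that maps an image to a deblurred image of the same shape (for example a Wiener filter with a fixed kernel); the wrapper name and the choice of <code class="language-plaintext highlighter-rouge">mode="symmetric"</code> are illustrative:</p>

```python
import numpy as np

def deconvolve_padded(y, deconvolve, pad, mode="symmetric"):
    """Pad the image, deconvolve the padded image, and crop the padding again.

    `pad` must be a positive number of pixels; `mode` is passed to np.pad.
    """
    padded = np.pad(y, pad, mode=mode)     # mirror the image across each border
    restored = deconvolve(padded)
    return restored[pad:-pad, pad:-pad]    # keep only the original region
```

<p>The padding makes the transition at the original border smooth, and moves the wrap-around seam of the periodic extension <code class="language-plaintext highlighter-rouge">pad</code> pixels away from the part of the image that we keep.</p>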
<p>It is interesting to note that Richardson-Lucy deconvolution suffers heavily in quality regardless of padding method. Interestingly, if we look at motion blur instead of Gaussian blur, the roles are a bit reversed. For the Wiener filter we have to use fairly aggressive regularization to not get too many artifacts, whereas RL deconvolution works without problems.</p>
<p><img src="/imgs/deconvolution_part3/part3_19_0.svg" alt="svg" /></p>
<h2 id="conclusion">Conclusion</h2>
<p>We have reiterated the fact that even non-blind deconvolution can be a difficult problem. The relatively simple Wiener filter in general does a good job, and changing it to use a Laplacian for regularization doesn’t seem to help much. The Richardson-Lucy algorithm often performs comparably to the Wiener filter, although it seems to perform relatively better for sparse kernels like the motion blur kernel we used.</p>
<p>Previously we completely ignored boundary problems, which is not something we can do with real images. Fortunately, we can deal with these issues by appropriately padding the image. Simply using reflections of the image for padding works quite well, especially depending on how we blur the image in the first place. Extending the image to be periodic while minimizing the Laplacian is more complicated, but also works well, and probably performs better in natural images.</p>
<p>In the next part (and hopefully final part) we will dive into some simple approaches for blind deconvolution. Starting off with a modification of the Richardson-Lucy algorithm, and then trying to use what we learned about image priors in part 2.</p>Rik VoorhaarDeconvolving and sharpening images is actually pretty tricky. Let's have a look at some more advanced methods for deconvolution.Blind Deconvolution #2: Image Priors2021-04-09T00:00:00+00:002021-04-09T00:00:00+00:00https://rikvoorhaar.com/deconvolution-part2<p>This is part two in a series on blind deconvolution of images. In the previous part we looked at non-blind deconvolution, where we have an image and we know exactly how it was distorted. While this situation may seem unrealistic, it does occur in cases where we have excellent understanding of how the camera takes images; for example for telescopes or microscopes which always work in the same environment.</p>
<p>The next step is then to try to do deconvolution if we have partial information about how the image was distorted. For example, we know that a lens is out of focus, but we don’t know exactly by how much. In that case we have only one variable to control, a scalar amount of blur (or perhaps two if the amount of blur is different in different directions). In this case we can simply try deconvolution for a few values, and see which image seems <em>most natural</em>.</p>
<p>Below we have the image of the St. Vitus church in my hometown distorted with Gaussian blur with \(\sigma=2\), and then deblurred with several different values of \(\sigma\). Looking at these images we can see that \(\sigma=2.05\) and \(\sigma=2.29\) look best, and \(\sigma=2.53\) is over-sharpened. The real challenge lies in finding some concrete metric to automatically decide which of these looks most natural. This is especially hard since even to the human eye this is not clear. The fact that \(\sigma=2.29\) looks very good probably means that the original image wasn’t completely sharp itself, and we don’t have a good ground truth of what it means for an image to be perfectly sharp.</p>
<p><img src="/imgs/deconvolution_part2/part2_1_0.png" alt="png" /></p>
<h2 id="image-priors">Image priors</h2>
<p>Measures of naturality of an image are often called <em>image priors</em>. They can be used to define a prior distribution on the space of all images, giving higher probability to images that are natural over those that are unnatural. Often image priors are based on heuristics, and different applications need different priors.</p>
<p>Many simple but effective image priors rely on the observation that most images have a <em>sparse gradient distribution</em>. An <em>edge</em> in an image is a sharp transition. The <em>gradient</em> of an image measures how fast the image is changing at every point, so an edge is a region in the image where the gradient is large. The gradient of an image can be computed by convolution with different kernels. One such kernel is the Sobel kernel:</p>
\[S_x = \begin{pmatrix}
1 & 0 & -1 \\
2 & 0 & -2 \\
1 & 0 & -1
\end{pmatrix},
\quad S_y = \begin{pmatrix}
1 & 2 & 1 \\
0 & 0 & 0 \\
-1 & -2 & -1
\end{pmatrix}\]
<p>Here convolution with \(S_x\) gives the gradient in the horizontal direction, and it is large when encountering a vertical edge, since the image is then making a fast transition in the horizontal direction. Similarly \(S_y\) gives the gradient in the vertical direction. If \(X\) is our image of interest, we can then define the <em>gradient transformation</em> of \(X\) by</p>
\[|\nabla X| = \sqrt{(S_x * X)^2+(S_y * X)^2}\]
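<p>In code this is a pair of small convolutions. The sketch below uses plain NumPy and replicates the border pixels to handle the image boundary; both the helper <code class="language-plaintext highlighter-rouge">convolve3x3</code> and the boundary treatment are assumptions of this example.</p>

```python
import numpy as np

SOBEL_X = np.array([[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]])
SOBEL_Y = np.array([[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]])

def convolve3x3(img, kernel):
    """Convolution with a 3x3 kernel, replicating the border pixels."""
    h, w = img.shape
    padded = np.pad(img, 1, mode="edge")
    flipped = kernel[::-1, ::-1]           # convolution flips the kernel
    out = np.zeros((h, w))
    for di in range(3):
        for dj in range(3):
            out += flipped[di, dj] * padded[di:di + h, dj:dj + w]
    return out

def gradient_magnitude(img):
    """Gradient transformation: sqrt((S_x * X)^2 + (S_y * X)^2)."""
    gx = convolve3x3(img, SOBEL_X)
    gy = convolve3x3(img, SOBEL_Y)
    return np.sqrt(gx**2 + gy**2)
```

<p>On a constant image the result is zero everywhere, and on an image whose brightness increases linearly from left to right it is constant away from the border.</p>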
<p>Below we can see this gradient transformation in action on the six images shown above:</p>
<p><img src="/imgs/deconvolution_part2/part2_3_0.png" alt="png" /></p>
<p>Here we can see that the gradients become larger in magnitude as \(\sigma\) increases. For \(\sigma
= 2.47\) we see that a large part of the image is detected as gradient – edges stopped being sparse
at this point. For the first four images we see that the edges are sparse, with most of the image
consisting of slow transitions.</p>
<p>Below we look at the distribution of the gradients after deconvolution with different values of \(\sigma\). We see that the distribution stays mostly constant, slowly increasing in overall magnitude. But near \(\sigma=2\), the overall magnitude of gradients suddenly increases sharply.</p>
<p>This suggests that to find the optimal value of \(\sigma\) we can look at these curves and pick the value of \(\sigma\) where the gradient magnitude starts to increase quickly. This is however not very precise, and ideally we have some function which has a minimum near the optimal value of \(\sigma\). Furthermore this curve will look slightly different for different images. This is a good starting point for an image prior, but is not useful yet.</p>
<p><img src="/imgs/deconvolution_part2/part2_5_1.png" alt="png" /></p>
<p>Instead of using the gradient to obtain the edges in the image, we can use the Laplacian. The
gradient \(|\nabla X|\) is the first derivative of the image, whereas the Laplacian \(\Delta X\) is
given by the sum of second partial derivatives of the image. Near an edge we don’t just expect the
gradient to be big, but we also expect the gradient to change fast. This is because edges are
usually transient, and not extended throughout space.</p>
<p>We can compute the Laplacian by convolving with the following kernel:</p>
\[\begin{pmatrix}
0 & 1 & 0 \\
1 & -4 & 1 \\
0 & 1 & 0
\end{pmatrix}\]
<p>Note that the Laplacian can take on both negative and positive values, unlike the absolute gradient transform we used before. Below we show the absolute value of the Laplacian transformed images. This looks similar to the absolute gradient, except that the increase in intensity with increasing \(\sigma\) is more pronounced.</p>
<p><img src="/imgs/deconvolution_part2/part2_7_0.png" alt="png" /></p>
<p><img src="/imgs/deconvolution_part2/part2_8_1.png" alt="png" /></p>
<h2 id="ell_1--ell_2-metric">\(\ell_1 / \ell_2\) metric</h2>
<p>Above we can see that there is an overall increase in the magnitude of the gradients and Laplacian
as \(\sigma\) increases. We want to measure how sparse these gradient distributions are, and this
has more to do with the shape of the distribution rather than the overall magnitude. To better see
how the shape changes it therefore makes sense to normalize so that the total magnitude stays the
same. We therefore don’t consider the distribution of the gradient \(|\nabla X|\), but rather of the
normalized gradient \(|\nabla X| / \|\nabla X\|_2\). Since the mean absolute value is essentially the \(\ell_1\)-norm, this is also referred to as the \(\ell_1/\ell_2\)-norm of the gradients \(\nabla X\).</p>
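<p>As a quick illustration of why this ratio measures sparsity, here is a sketch that computes it for an arbitrary array of gradient values. It uses the sum rather than the mean of the absolute values, which only changes the result by a constant factor for a fixed image size; the function name is made up for this example.</p>

```python
import numpy as np

def l1_over_l2(values, eps=1e-12):
    """Ratio of the l1-norm to the l2-norm; smaller means sparser."""
    v = np.abs(np.asarray(values, dtype=float)).ravel()
    return v.sum() / (np.sqrt((v**2).sum()) + eps)

spike = np.zeros(100)
spike[3] = 5.0                       # a single non-zero entry: as sparse as possible
sparse_score = l1_over_l2(spike)
dense_score = l1_over_l2(np.ones(100))
```

<p>For a vector with \(N\) entries the ratio lies between 1 (a single non-zero entry) and \(\sqrt N\) (all entries equal in magnitude), independently of the overall scale of the vector.</p>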
<p>The normalized gradient distribution is plotted below as function of \(\sigma\), the distributions
of the Laplacian look similar. This distribution already looks a lot more promising since the median
has a minimum near the optimal value for \(\sigma\). This minimum is a passable estimate of the
optimal value of \(\sigma\) for this particular image. For other images it is however not as good.
Moreover the function only changes slowly around the minimum value, so it is hard to find in an
optimization routine. We therefore need to come up with something better.</p>
<p><img src="/imgs/deconvolution_part2/part2_10_1.png" alt="png" /></p>
<h2 id="non-local-similarity-based-priors">Non-local similarity based priors</h2>
<p>The \(\ell_1/\ell_2\) prior is a good starting point, but we can do better with a more complex prior based on <em>non-local self-similarity</em>. The idea is to divide the image up into many small patches of \(n\times n\) pixels with for example \(n=5\). Then for each patch we can check how many other patches in the image look similar to it. This concept is called non-local self-similarity, since it’s non-local (we compare a patch with patches throughout the entire image, not just in a neighborhood) and uses self-similarity (we look at how similar some parts of the image are to other parts of the same image; we never use an external database of images for example).</p>
<p>The full idea is a bit more complicated. Let’s denote each \(n\times n\) patch by</p>
\[P(i,j) = X[ni:n(i+1),\, nj:n(j+1)].\]
<p>We consider this patch as a length-\(n^2\) vector. Moreover since we’re mostly interested in the patterns represented by the patch, and not by the overall brightness, we normalize all the patch vectors to have norm 1. We then find the closest matching \(k\) patches, minimizing the Euclidean distance:</p>
\[\operatorname{argmin}_{i',j'} \|P(i,j) - P(i',j')\|\]
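<p>The patch extraction and neighbor search can be sketched as follows. This is only an illustration: it uses a brute-force search in NumPy as a stand-in for a proper similarity search library, and all function names are hypothetical.</p>

```python
import numpy as np

def extract_patches(image, n):
    """Split a 2D image into non-overlapping n x n patches, one row per patch."""
    rows, cols = image.shape[0] // n, image.shape[1] // n
    cropped = image[: rows * n, : cols * n]
    patches = cropped.reshape(rows, n, cols, n).swapaxes(1, 2)
    return patches.reshape(rows * cols, n * n)

def normalize_patches(patches, eps=1e-12):
    """Scale every patch vector to unit Euclidean norm."""
    patches = np.asarray(patches, dtype=float)
    norms = np.linalg.norm(patches, axis=1, keepdims=True)
    return patches / np.maximum(norms, eps)

def nearest_patches(patches, k):
    """Indices of the k nearest other patches in Euclidean distance (brute force)."""
    patches = np.asarray(patches, dtype=float)
    sq_norms = (patches**2).sum(axis=1)
    dists = sq_norms[:, None] + sq_norms[None, :] - 2 * patches @ patches.T
    np.fill_diagonal(dists, np.inf)  # a patch is not its own neighbor
    return np.argsort(dists, axis=1)[:, :k]

# Patch (0, 1) of a 4x4 image with n=2 is X[0:2, 2:4]
patches = extract_patches(np.arange(16).reshape(4, 4), 2)

# Three tiny "patches": the second one is closest to both others
toy = normalize_patches(np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]))
neighbors = nearest_patches(toy, k=1)
```

<p>The brute-force distance matrix needs memory quadratic in the number of patches, so this version is only practical for small images.</p>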
<p>Below we show an 8x8 patch in the St. Vitus image (top left) together with its 11 closest neighbors.</p>
<p><img src="/imgs/deconvolution_part2/part2_13_0.png" alt="png" /></p>
<p>Note that we look at patches closest in <em>Euclidean distance</em>; this does not necessarily mean the patches are visually similar. Conversely, visually very similar patches can have large Euclidean distance. For example the two patches below are orthogonal (and hence have the maximal Euclidean distance possible between normalized patches with non-negative entries), despite being visually similar. One could come up with better measures for visual similarity than Euclidean distance, probably something that is invariant under small shifts, rotations and mirroring, but this would come at an obvious cost of increased (computational) complexity.</p>
<p><img src="/imgs/deconvolution_part2/part2_15_1.png" alt="png" /></p>
<p>The \(k\) closest patches together with the original patch \(P(i,j)\) are put into an \(n^2\times (k+1)\) matrix, called the <em>non-local self-similar (NLSS) matrix</em> \(N(i,j)\). We are interested in some linear-algebraic properties of this matrix. One observation is that the NLSS matrices tend to be of low rank for most patches. This essentially means that most patches tend to have other patches that look very similar to them. If all patches in \(N(i,j)\) are the same then its rank is 1, whereas if all the patches are different then \(N(i,j)\) is of maximal rank.</p>
<p>However, taking the rank itself is not necessarily a good measure, since it is not numerically stable. Any slight perturbation will always make the matrix of full rank. We rather work with a differentiable approximation of the rank. This approximation is based on the spectrum (singular values) of the matrix. In this case, we can consider the <em>nuclear norm</em> \(\|N(i,j)\|_*\) of \(N(i,j)\). It is defined as the sum of the singular values:</p>
\[\|A\|_* = \sum_{i=1}^n \sigma_i(A),\]
<p>where \(\sigma_i(A)\) is the \(i\)th singular value. Below we show how the average singular values change with scale \(\sigma\) of the deconvolution kernel for the NLSS matrices for 8x8 patches with 63 neighbors (so that the NLSS matrix is square). We see that in all cases most of the energy is in the first singular value, followed by a fairly slow decay. As \(\sigma\) increases, the decay of singular values slows down. This means that the more blurry the image, the lower the <em>effective</em> rank of the NLSS matrices. As such, the nuclear norm of the NLSS matrix gives a measure of the amount of information in the picture.</p>
<p><img src="/imgs/deconvolution_part2/part2_17_0.png" alt="png" /></p>
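<p>To make the nuclear norm concrete, here is a short sketch. It assumes patches are stored as rows of an array and that a neighbor index array is already available; <code class="language-plaintext highlighter-rouge">nlss_matrix</code> and <code class="language-plaintext highlighter-rouge">nuclear_norm</code> are hypothetical names.</p>

```python
import numpy as np

def nlss_matrix(patches, neighbors, index):
    """Stack patch `index` and its neighbors as columns: shape (n^2, k+1)."""
    ids = np.concatenate(([index], neighbors[index]))
    return patches[ids].T

def nuclear_norm(matrices):
    """Sum of singular values of a matrix, or of each matrix in a stack."""
    singular_values = np.linalg.svd(matrices, compute_uv=False)
    return singular_values.sum(axis=-1)

# Four identical unit-norm columns: rank 1, a single singular value
v = np.full(4, 0.5)
rank_one = np.outer(v, np.ones(4))

# Four orthonormal columns: full rank, all singular values equal to 1
full_rank = np.eye(4)

norm_rank_one = nuclear_norm(rank_one)
norm_full_rank = nuclear_norm(full_rank)
norms_batched = nuclear_norm(np.stack([rank_one, full_rank]))

# A tiny NLSS matrix built from two identical patches
patches = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
neighbors = np.array([[1], [0], [1]])
norm_nlss = nuclear_norm(nlss_matrix(patches, neighbors, 0))
```

<p>For normalized patches the nuclear norm of an \(n^2\times(k+1)\) NLSS matrix therefore lies between \(\sqrt{k+1}\) (all patches identical) and \(k+1\) (all patches orthogonal), which is what makes it usable as a smooth stand-in for the rank.</p>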
<p>We see that the spectrum of the NLSS matrices seems to give a measure of ‘amount of information’ or
sparsity. Since we know that sparsity of the edges in an image gives a useful image prior, let’s
compute the nuclear norm \(\|N(i,j)\|_*\) of each NLSS matrix of the gradients of the image. We can
actually plot these nuclear norms as an image. Below we show this plot of nuclear norms of the NLSS
matrices. We can see that the mean nuclear norm is biggest at around the ground truth value of
\(\sigma\).</p>
<p><img src="/imgs/deconvolution_part2/part2_19_0.png" alt="png" /></p>
<p>It is not immediately clear how to interpret the darker and lighter regions of these plots. Long
straight edges seem to have smaller norms since there are many patches that look similar. Since the
patches are normalized before being compared, the background tends to look a lot like random noise
and hence has relatively high nuclear norm. However, we can’t skip this normalization step either,
since then we mostly observe a strict increase in nuclear norms with \(\sigma\).</p>
<p>Repeating the same for the Laplacian gives a similar result:</p>
<p><img src="/imgs/deconvolution_part2/part2_21_0.png" alt="png" /></p>
<p>Now finally to turn this into a useful image prior, we can plot how the mean nuclear norm changes with varying \(\sigma\). Both for the gradients and Laplacian of the image we see a clear maximum near \(\sigma=2\), so this looks like a useful image prior.</p>
<p><img src="/imgs/deconvolution_part2/part2_23_1.png" alt="png" /></p>
<p>There are a few hyperparameters to tinker with for this image prior. There is the size of the patches taken; in practice something like 4x4 to 8x8 seems to work well for the size of images we’re dealing with. We can also lower or increase the number of neighbors computed. Finally we don’t need to divide the images into patches exactly. We can <em>oversample</em>, and put a space of less than \(n\) pixels between consecutive \(n\times n\) patches. This results in a less noisy curve of NLSS nuclear norms, at extra computational cost. We can on the other hand also <em>undersample</em> and only use a quarter of the patches, which can greatly improve speed.</p>
<p>The image above was made for \(6\times 6\) patches with 36 neighbors. Below we make the same plot with \(6\times 6\) patches, but only taking 1/16th of the patches and only 5 neighbors. This results in a much more noisy image, but it runs over 10x faster and still gives a useful approximation.</p>
<p><img src="/imgs/deconvolution_part2/part2_25_2.png" alt="png" /></p>
<p>One final thing of note is how the NLSS matrices \(N(i,j)\) are computed. Finding the closest \(k\) patches through brute-force methods of computing the distance between each pair of patches is extremely inefficient. Fortunately there are more efficient ways of solving this <em>similarity search</em> problem. These methods usually first make an index or tree structure saving some information about all the data points. This can be used to quickly find a set of points that are close to the point of interest, and searching only within this set significantly reduces the amount of work. This is especially true if we only care about approximately finding the \(k\) closest points, since this means we can reduce our search space even further.</p>
<p>We used <a href="https://github.com/facebookresearch/faiss">Faiss</a> to solve the similarity search problem, since it is fast and runs on GPU. There are many packages that do the same, some faster than others depending on the problem. There is also an implementation in <code class="language-plaintext highlighter-rouge">sklearn</code>, but it is slower by over 2 orders of magnitude than Faiss when running on GPU for this particular situation.</p>
<p>At the end of the day the bottleneck in the computation speed is the computation of the nuclear norm, which requires computing the singular values of tens of thousands of small matrices. Unfortunately CUDA only supports batched SVD computation for matrices of size at most 32x32. If we use 5x5 patches or smaller, the NLSS matrices can be kept within this limit, and on my machine doing the computation on GPU then makes it up to 4x faster.</p>
<h2 id="testing-the-image-prior">Testing the image prior</h2>
<p>The nuclear norms of NLSS matrices seem to give a useful image prior, but to know for sure we need to test it for different images, and also for different types of kernels.</p>
<p>To estimate the best deconvolved image, we will take the average of the optimal values of \(\sigma\) found using the NLSS nuclear norms of the gradient and of the Laplacian. This is because it seems that the Laplacian usually underestimates the ground truth value whereas the gradient usually overestimates it. Furthermore, instead of taking the global maximum as optimal value, we take the <em>first maximum</em>. When we oversharpen the image a lot, the strange artifacts we get can actually result in a large NLSS nuclear norm. It can be a bit tricky to detect a local maximum, and if the initial blur is too much then the prior seems not to work very well.</p>
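<p>One simple way to pick the first maximum is sketched below. The rule used here (first interior point that is larger than its left neighbor and at least as large as its right neighbor) and the toy curves are assumptions for illustration; a noisy curve would likely need smoothing first.</p>

```python
import numpy as np

def first_local_maximum(values):
    """Index of the first interior local maximum, or None if there is none."""
    values = np.asarray(values, dtype=float)
    for i in range(1, len(values) - 1):
        if values[i] > values[i - 1] and values[i] >= values[i + 1]:
            return i
    return None

# Made-up mean nuclear norm curves on a grid of sigma values
sigmas = np.linspace(1.0, 4.0, 7)
curve_gradient = np.array([1.0, 2.0, 3.0, 2.5, 2.0, 4.0, 6.0])
curve_laplacian = np.array([1.0, 3.0, 2.0, 1.5, 1.0, 2.0, 5.0])

i_grad = first_local_maximum(curve_gradient)
i_lap = first_local_maximum(curve_laplacian)

# The global maximum sits at the oversharpened end of the curve instead
i_global = int(np.argmax(curve_gradient))

# Average the two estimates, as described in the text
sigma_estimate = 0.5 * (sigmas[i_grad] + sigmas[i_lap])
```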
<p>First let’s try to do semi-blind deconvolution for Gaussian kernels. That is, we know that the image was blurred with a Gaussian kernel, but we don’t know with what parameters. We do this for a smaller and a larger value for the standard deviation \(\sigma\), and notice that for smaller \(\sigma\) the recovery is excellent, but once \(\sigma\) becomes too large the recovery fails.</p>
<p>All the images we use are from the <a href="https://cocodataset.org/#home">COCO 2017 dataset</a>.</p>
<p>First up is an image of a bear, blurred with a \(\sigma=2\) Gaussian kernel. Deblurring this is easy, and not very sensitive to the hyperparameters used.</p>
<p><img src="/imgs/deconvolution_part2/part2_30_1.png" alt="png" /></p>
<p>Here is the same image of the bear, but now blurred with \(\sigma=4\), and it becomes much harder to recover the image. I found that the only way to do it is to reduce the patch size all the way to \(2\times 2\). For larger patch sizes the image can’t be accurately recovered and the method always overestimates the value of \(\sigma\).</p>
<p><img src="/imgs/deconvolution_part2/part2_32_1.png" alt="png" /></p>
<p>Below is a picture of some food. For \(\sigma=3\) recovery is excellent, and again not strongly dependent on hyperparameters. For \(\sigma=4\) the problem becomes significantly harder, and it again takes a small patch size for reasonable results.</p>
<p><img src="/imgs/deconvolution_part2/part2_34_1.png" alt="png" /></p>
<p><img src="/imgs/deconvolution_part2/part2_35_1.png" alt="png" /></p>
<p>Now let’s change the blur kernel to an idealized motion blur kernel. Here the point spread function is a line segment of some specified length and thickness, as shown below:</p>
<p><img src="/imgs/deconvolution_part2/part2_38_1.png" alt="png" /></p>
<p>The way I construct these point spread functions is by rasterizing an image of a line segment. I’m sure there’s a better way to do this, but it seems to work fine. The parameters of the kernel are the angle, the length of the line segment and the size of the kernel.</p>
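<p>For illustration, one possible way to rasterize such a kernel is sketched below. This is not the code used for the figures: it samples points along the segment and counts how many land in each pixel, it ignores the thickness parameter, and the function name and arguments are hypothetical.</p>

```python
import numpy as np

def motion_blur_psf(size, length, angle_deg, samples_per_pixel=16):
    """Rasterize a centered line segment into a (size, size) kernel summing to 1."""
    if length > size - 1:
        raise ValueError("segment does not fit inside the kernel")
    psf = np.zeros((size, size))
    center = (size - 1) / 2
    theta = np.deg2rad(angle_deg)
    num = max(2, int(np.ceil(length * samples_per_pixel)))
    t = np.linspace(-length / 2, length / 2, num)
    # Image rows grow downwards, so a positive angle tilts the segment upwards
    rows = np.rint(center - t * np.sin(theta)).astype(int)
    cols = np.rint(center + t * np.cos(theta)).astype(int)
    np.add.at(psf, (rows, cols), 1.0)  # accumulate repeated hits per pixel
    return psf / psf.sum()

horizontal = motion_blur_psf(size=5, length=4, angle_deg=0)
diagonal = motion_blur_psf(size=7, length=5, angle_deg=45)
```

<p>Because samples are snapped to the nearest pixel, the kernel changes in small jumps as the length or angle varies.</p>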
<p>Let’s try to apply the method on a picture of some cows below:</p>
<p><img src="/imgs/deconvolution_part2/part2_40_1.png" alt="png" /></p>
<p>Unfortunately our current method doesn’t work well with this kind of point spread function. The nuclear norm of the NLSS matrices is very noisy. I first thought this could be because the PSF doesn’t change continuously with the length of the line segment. But I ruled this out by hard-coding a diagonal line segment in such a way that it changes continuously, and it looks just as bad.</p>
<p>Instead it seems that the (non-blind) deconvolution method itself doesn’t work well for this kernel. Below we see the image blurred with a length 5 diagonal motion blur, and then deconvolved with different values. With the Gaussian blur we only saw significant deconvolution artifacts if we try to oversharpen an image. Here we see very significant artifacts even if the length parameter is less than 5. I think this is because the point spread function is very discontinuous, and hence its Fourier transform is very irregular.</p>
<p>Additionally, the effect of motion blur on edges is different than that of Gaussian blur. If the edge is parallel to the motion blur, it is not affected or even enhanced. On the other hand, if an edge is orthogonal to the direction of motion blur, the edge is destroyed quickly. This may mean that the sparse gradient prior is not as effective as for Gaussian blur. We have no good way to check this however before improving the deconvolution method.</p>
<p><img src="/imgs/deconvolution_part2/part2_42_1.png" alt="png" /></p>
<h2 id="conclusion">Conclusion</h2>
<p>Having a good image prior is vital for blind deconvolution. Making a good image prior is however quite difficult. Most image priors are based on the idea that natural images have sparsely distributed gradients. We observed that the simple and easy-to-compute \(\ell_1/\ell_2\) prior does a decent job, but isn’t quite good enough. The more complex NLSS nuclear norm prior does a much better job. Using this prior we can do partially blind deconvolution, sharpening an image blurred with Gaussian blur.</p>
<p>However, another vital ingredient for blind deconvolution is good non-blind deconvolution. The current non-blind deconvolution method we introduced in the last part doesn’t work well for non-continuous or sparse point spread functions. There are also problems with artifacts at the boundaries of the image (which I have hidden for now by essentially cheating). This means that if we want to do good blind deconvolution, we first need to revisit non-blind deconvolution and improve our methods.</p>Rik VoorhaarIn order to automatically sharpen images, we need to first understand how a computer can judge how 'natural' an image looks.Blind Deconvolution #1: Non-blind Deconvolution2021-03-13T00:00:00+00:002021-03-13T00:00:00+00:00https://rikvoorhaar.com/deconvolution-part1<p>I recently became interested in blind deconvolution. Initially I didn’t even know the proper name
for this, I simply wondered if it’s possible to automatically sharpen images given we have some
limited information about how they are blurred. Then I went on to do some actual research, and I
started diving into the fascinating topic of blind deconvolution. This post will be the first of
several, where I dive into blind deconvolution. In the end I will actually look at the
implementation of one or two blind deconvolution methods. It turns out blind deconvolution is very
difficult and has a vast scope of literature associated to it. Therefore I will split it up into
several posts. This is my preliminary plan:</p>
<ul>
<li>Part I: Introduction to convolution and deconvolution</li>
<li>Part II: Comparing different image priors on a toy problem</li>
<li>Part III: A deep look at blind deconvolution, and implementing it ourselves</li>
</ul>
<p>Without further ado, let’s figure out what (blind) deconvolution is in the first place!</p>
<h2 id="blur-as-convolution">Blur as convolution</h2>
<p>There are many types of blur that can be applied to images, but there are arguably two main types.
The first is lens blur, coming from the lens not being perfectly in focus or from imperfections in
the optics. And the second is motion blur, which is caused by the camera or the photographed object
moving. Both of these types of blur can be described by convolution of the image \(x\) with a
<em>kernel</em> or <em>point spread function</em> (PSF) \(k\):</p>
\[(x*k)[i,j] = \sum_{l,m}x[i-l,j-m]k[l,m].\]
<p>One particular PSF is the delta function, whose only nonzero entry is \(\delta[0,0]=1\). It is the
identity operation for convolution:</p>
\[x*\delta = x.\]
<p>Often point spread functions have finite support; they are only non-zero for a finite number of
entries. In this case we can write the PSF as a matrix, where the <em>middle</em> entry corresponds to
\(k[0,0]\). In this case the delta function is a \(1\times 1\) matrix with \(1\) as its only entry.</p>
<h3 id="box-blur">Box blur</h3>
<p>Another very simple (but not necessarily natural) PSF is given by a constant matrix. For example,</p>
\[k = \frac19\begin{pmatrix}1&1&1\\1&1&1\\1&1&1\end{pmatrix}.\]
<p>Here we divide by 9 so that the total sum of entries of \(k\) is 1. This is useful so that
convolution with \(k\) preserves the magnitude of \(x\). If not, then the image would become
brighter or dimmer after convolution, which we don’t want.</p>
<p>Convolution with a matrix like this has a name; it’s called box blur. It’s a very simple type of
blur which replaces each pixel by an average of its neighboring pixels. Its main advantage is that it’s
very fast and easy to implement, and to the human eye it looks quite a lot like other types of blur.</p>
<h3 id="gaussian-blur">Gaussian blur</h3>
<p>Lens blur can be approximated by a Gaussian PSF, i.e. a kernel \(k\) such that</p>
\[k[i,j] \propto \exp\left(-\frac{i^2+j^2}{2\sigma^2}\right),\]
<p>for some \(\sigma\). Here \(\sigma\) is the standard deviation of the Gaussian measured in pixels,
so with \(\sigma=1\) a distance of one pixel corresponds to one standard deviation. Visually this
looks quite similar to box blur, especially for smaller amounts of blur, but Gaussian
blur is smoother, and more accurately emulates lens blur.</p>
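<p>A minimal sketch of the box and Gaussian kernels, assuming odd kernel sizes so that the middle entry corresponds to \(k[0,0]\). The function names and the symmetric boundary handling are illustrative choices.</p>

```python
import numpy as np
from scipy import signal

def box_psf(size):
    """Constant size x size kernel whose entries sum to 1."""
    return np.full((size, size), 1.0 / size**2)

def gaussian_psf(size, sigma):
    """Odd-sized Gaussian kernel centered on the middle entry, summing to 1."""
    half = size // 2
    i, j = np.mgrid[-half : half + 1, -half : half + 1]
    k = np.exp(-(i**2 + j**2) / (2 * sigma**2))
    return k / k.sum()

def blur(image, psf):
    """Convolve with the PSF, mirroring the image at its borders."""
    return signal.convolve2d(image, psf, mode="same", boundary="symm")

image = np.random.default_rng(0).random((32, 32))
box_blurred = blur(image, box_psf(3))
gauss_blurred = blur(image, gaussian_psf(7, sigma=1.0))

# A constant image is unchanged because each kernel sums to 1
flat = blur(np.ones((8, 8)), gaussian_psf(7, sigma=1.0))
```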
<h3 id="motion-blur">Motion blur</h3>
<p>Motion blur can be described by a PSF which, when seen as an image, is a line segment. For example a
horizontal line segment through the middle of the PSF is equivalent to camera motion along the
horizontal axis. For real-life motion blur this is really only true if the entire scene is equally
far away, for example if we consider a spacecraft in orbit taking photos of the Earth’s surface.
Otherwise the amount and direction of motion blur is not uniform throughout the image.</p>
<h3 id="comparison-of-blur-types">Comparison of blur types</h3>
<p>We will apply all types of blur to a cropped and scaled image of the St. Vitus church in my hometown
taken in 1946 (image credit: <a href="http://proxy.handle.net/10648/a894b7c6-d0b4-102d-bcf8-003048976d84">Koos
Raucamp</a>).</p>
<p>The top-left image shows a delta PSF, which leaves the image unchanged. The top-right shows box blur with a \(3\times 3\) box. The
bottom-left image shows Gaussian blur, with \(\sigma=1\). Finally the bottom-right image shows
motion blur with a top-left to bottom-right diagonal line segment of 5 pixels in length.</p>
<p><img src="/imgs/part1_4_0.png" alt="png" /></p>
<h2 id="fourier-transforms-and-deconvolution">Fourier transforms and deconvolution</h2>
<p>There is a remarkable relationship between the Fourier transform and convolution, both in the
discrete and continuous case. Recall that the discrete Fourier transform (in one dimension) of a
signal \(f\) of length \(N\) is defined by</p>
\[\mathcal F(f)[k] = \sum_n f[n]\exp\left(\frac{-i2\pi kn}{N}\right).\]
<p>The Fourier transform turns convolution into (pointwise) multiplication:</p>
\[\mathcal F(f*g)[k] = \mathcal F(f)[k]\cdot\mathcal F(g)[k].\]
<p><em>(This does ignore some issues related to the fact that the signals we consider are not periodic,
and we may need to pad the result with zeros and use appropriate normalization. This result is
actually very easy to prove, although the details are not important right now.)</em></p>
<p>This is a very useful property. For one, discrete Fourier transformations can be computed much
faster than naively expected using the fast Fourier transform (FFT) algorithm. Naively applying the
definition of the discrete Fourier transform to a length \(N\) signal requires \(O(N^2)\)
operations, but the FFT runs in \(O(N\log N)\). It does this by recursively splitting the signal in
two; an ‘odd’ and ‘even’ part, and it computes the FFT for both halves and then combines the result
to get the FFT of the entire signal. We can use the speed of the FFT to compute the convolution of
two length \(N\) signals in \(O(N\log N)\) as well, simply by doing</p>
\[f*g = \mathcal F^{-1}(\mathcal F(f)\cdot \mathcal F(g)).\]
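<p>For periodic (circular) convolution this identity holds exactly, which is easy to check numerically. The sketch below compares the definition with the FFT route on a small random signal; the function names are hypothetical.</p>

```python
import numpy as np

def circular_convolve_direct(f, g):
    """Circular convolution straight from the definition, O(N^2) operations."""
    n = len(f)
    return np.array(
        [sum(f[(i - l) % n] * g[l] for l in range(n)) for i in range(n)]
    )

def circular_convolve_fft(f, g):
    """The same convolution via the convolution theorem, O(N log N) operations."""
    return np.real(np.fft.ifft(np.fft.fft(f) * np.fft.fft(g)))

rng = np.random.default_rng(0)
f = rng.random(16)
g = rng.random(16)

direct = circular_convolve_direct(f, g)
via_fft = circular_convolve_fft(f, g)
max_difference = np.max(np.abs(direct - via_fft))
```

<p>For non-periodic signals one would zero-pad both inputs to at least the combined length minus one before taking the transforms.</p>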
<p>Another thing is that it makes arithmetic with convolution much easier. For example we can use it to
<em>deconvolve</em> a signal. That is, we can solve the following problem for \(x\):</p>
\[y = x*k.\]
<p>We take the discrete Fourier transform on both sides:</p>
\[\mathcal F(y) = \mathcal F(x)\cdot \mathcal F(k).\]
<p>Then we divide and take the inverse discrete Fourier transform to obtain:</p>
\[x = \mathcal F^{-1}\left(\frac{\mathcal F(y)}{\mathcal F(k)}\right).\]
<p>And indeed this works! However, it requires knowing the kernel \(k\) <em>exactly</em>. If it is even
slightly off, we can get strange results. Below we see an original image in the top left. Then on
the top right a version with Gaussian blur with \(\sigma=2\). Then on the bottom we respectively
deconvolve with a Gaussian PSF with \(\sigma=2\) and \(\sigma=2.01\). The first looks identical to
the original image, but then the second doesn’t look similar at all!</p>
<p><img src="/imgs/part1_6_0.png" alt="png" /></p>
<p>What is going on here? A quick look at the discrete Fourier transform of the PSF gives us the
answer. Recall that the Fourier transform of a real signal is actually complex, so below we plot the
absolute value of the Fourier transform on a logarithmic scale. For reference we also plot the
Fourier transform of the original and blurred signals.</p>
<p><img src="/imgs/part1_8_0.png" alt="png" /></p>
<p>We see that the Fourier transform of the kernel has many values close to \(0\). This means that
dividing by such a signal is not numerically stable. Indeed if we slightly perturb either the kernel
\(k\) or the blurred signal \(y\), we can end up with strange results, as seen above.</p>
<h2 id="regularizing-deconvolution">Regularizing deconvolution</h2>
<p>If we want to do deconvolution, we clearly need something more numerically stable than the naive
algorithm of dividing the Fourier-transformed signals. This means putting some kind of
regularization that makes the solution look more natural. Above, our main problem is that the Fourier
transform of \(k\) has values close to zero, so one thing we can try is to add a small number to
\(\mathcal F(k)\) before division. One problem here is that \(\mathcal F(k)\) is complex, so it’s
not immediately clear how to add a number to make it nonzero. However, note that we can write</p>
\[\frac{1}{\mathcal F(k)} = \frac{\mathcal F(k)^*}{\mathcal F(k)\mathcal F(k)^*} = \frac{\mathcal F(k)^*}{|\mathcal F(k)|^2}\]
<p>In this formula any numerical instability is coming from the division by \(|\mathcal F(k)|^2\).
This is always a positive real number, so we can move it away from zero by adding a constant. This
gives us the following formula for deconvolution:</p>
\[x = \mathcal F^{-1}\left(\mathcal F(y) \cdot \frac{\mathcal F(k)^*}{|\mathcal F(k)|^2+S}\right),\]
<p>where \(S>0\) is a regularization constant. Let’s see how well this works for different values of
\(S\):</p>
<p><img src="/imgs/part1_11_0.png" alt="png" /></p>
<p>If you look closely, the image looks best for \(S=10^{-8}\). For lower and higher values we see a
ringing effect, particularly noticeable in the portion of the image occupied by the sky. Visually the
best deconvolved image looks indistinguishable from the original. However if we look at the discrete
Fourier transform of the same images, they actually look quite a bit different. (The difference is
however exaggerated by the logarithmic scale.) There are significant artifacts remaining from the
near-zero values of the Fourier transform of the PSF.</p>
<p><img src="/imgs/part1_13_0.png" alt="png" /></p>
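<p>The regularized formula from this section, with the inverse transform applied to the regularized quotient, can be sketched as follows. The setup is a simplified assumption rather than the one used for the figures: boundaries are periodic, the PSF has the same shape as the image with its center at index <code class="language-plaintext highlighter-rouge">[0, 0]</code>, and a small amount of synthetic noise is added to make the instability visible. All names are hypothetical.</p>

```python
import numpy as np

def periodic_gaussian_psf(n, sigma):
    """n x n Gaussian PSF centered at index [0, 0], wrapped periodically."""
    idx = np.fft.fftfreq(n) * n  # 0, 1, ..., n/2 - 1, -n/2, ..., -1
    k = np.exp(-(idx[:, None] ** 2 + idx[None, :] ** 2) / (2 * sigma**2))
    return k / k.sum()

def blur(x, k):
    """Periodic convolution of x with k via the FFT."""
    return np.real(np.fft.ifft2(np.fft.fft2(x) * np.fft.fft2(k)))

def deconvolve(y, k, S=0.0):
    """Fourier deconvolution; S=0 is the naive division, S>0 regularizes it."""
    Fk = np.fft.fft2(k)
    quotient = np.fft.fft2(y) * np.conj(Fk) / (np.abs(Fk) ** 2 + S)
    return np.real(np.fft.ifft2(quotient))

rng = np.random.default_rng(0)
x = rng.random((32, 32))
k = periodic_gaussian_psf(32, sigma=1.5)
y = blur(x, k) + 1e-2 * rng.standard_normal(x.shape)  # slightly noisy blur

err_naive = np.linalg.norm(deconvolve(y, k) - x)
err_regularized = np.linalg.norm(deconvolve(y, k, S=1e-3) - x)
```

<p>With noise present, the naive division amplifies the perturbation at frequencies where \(\mathcal F(k)\) is nearly zero, while the regularized version keeps the error bounded. The best value of \(S\) depends on the noise level, so the value used here is only an example.</p>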
<h2 id="deconvolution-using-linear-algebra">Deconvolution using linear algebra</h2>
<p>Given that I do research in numerical linear algebra, it might be interesting to cast the
deconvolution problem into linear algebra. Note that we’re essentially solving the minimization
problem</p>
\[\min_x \|k*x-y\|^2\]
<p>Since \(k*x\) is linear in all the entries of \(x\), we can actually write this as matrix
multiplication \(k*x = Kx\), where \(K\) is the <em>convolution matrix</em>. For one-dimensional
convolution with a kernel \(k\) this matrix is \(K_{ij} = k[i-j]\). Using the convolution matrix we
can turn deconvolution into a linear least-squares problem, and deconvolution using Fourier
transforms gives the exact minimizer of this problem. The reason this exact solution becomes garbage
as soon as we slightly perturb \(y\) or \(k\) is because the matrix \(K\) is very ill-conditioned.
The <em>condition number</em> of a matrix \(K\) tells us how much any numerical errors in a vector \(b\)
can get amplified if we’re trying to solve the linear system \(Kx = b\).</p>
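<p>To see the ill-conditioning concretely, here is a one-dimensional sketch with periodic boundaries, in which case the convolution matrix is circulant with \(K_{ij} = k[(i-j) \bmod n]\). The kernel construction and variable names are illustrative assumptions.</p>

```python
import numpy as np
from scipy.linalg import circulant

def gaussian_kernel_1d(n, sigma):
    """Length-n periodic Gaussian kernel centered at index 0, summing to 1."""
    idx = np.fft.fftfreq(n) * n  # 0, 1, ..., n/2 - 1, -n/2, ..., -1
    k = np.exp(-(idx**2) / (2 * sigma**2))
    return k / k.sum()

n = 64
k = gaussian_kernel_1d(n, sigma=2.0)
K = circulant(k)  # K[i, j] = k[(i - j) mod n]

x = np.random.default_rng(0).random(n)
via_matrix = K @ x
via_fft = np.real(np.fft.ifft(np.fft.fft(k) * np.fft.fft(x)))

# Ratio of the largest to the smallest singular value of K
condition_number = np.linalg.cond(K)
```

<p>The singular values of a circulant matrix are the absolute values of the discrete Fourier transform of its first column, so the near-zero Fourier coefficients of the kernel are exactly what makes the condition number large.</p>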
<p>Fortunately there are ways to deal with ill-conditioned systems through regularization. There are a
number of regularization techniques, but in our case this isn’t immediately helpful because of the
size of the matrix \(K\). If we consider an \(n\times m\) image, then the matrix \(K\) is of size
\(nm\times nm\). For example if we have a \(1024\times 1024\) image then the image requires on the
order of 1MB of memory, but the matrix \(K\) would take up on the order of 1TB of memory! Obviously
that will not fit in the memory of a typical home computer, so working directly with the matrix
\(K\) is completely infeasible. Moreover, while the matrix \(K\) has a lot of structure, it is not
sparse in general (unless the kernel has small support), so we cannot store it as a sparse matrix either.</p>
<p>Nevertheless computing a matrix product \(Kx\) is cheap, since it’s just convolution. There are good
linear solvers that only need matrix-vector products, without ever forming the matrix explicitly.
These are usually iterative Krylov subspace methods. Fortunately scipy has several such solvers, and
out of those implemented there it seems that the LGMRES (Loose Generalized Minimal Residual Method)
solver works best for this particular problem. Even without regularization this produces decent
results. Nevertheless, it’s a bit finicky to get working well, and on my machine the deconvolution
takes a full minute, as opposed to a few milliseconds for FFT-based deconvolution.</p>
<p><img src="/imgs/part1_16_1.png" alt="png" /></p>
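<p>A matrix-free solve along these lines can be sketched with <code class="language-plaintext highlighter-rouge">scipy.sparse.linalg.LinearOperator</code> and <code class="language-plaintext highlighter-rouge">lgmres</code>. The example below is a toy one-dimensional problem with a small, well-conditioned kernel chosen so that the solver converges quickly; it is an assumption for illustration and not the setup used for the image above.</p>

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, lgmres

n = 64
k = np.zeros(n)
k[[0, 1, -1]] = [0.6, 0.2, 0.2]  # mild periodic blur centered at index 0
Fk = np.fft.fft(k)

def convolve(v):
    """Matrix-vector product K @ v computed as a periodic convolution."""
    v = np.ravel(v)
    return np.real(np.fft.ifft(Fk * np.fft.fft(v)))

# The solver only ever calls `convolve`; the matrix K is never formed
K = LinearOperator((n, n), matvec=convolve, dtype=float)

x_true = np.random.default_rng(0).random(n)
y = convolve(x_true)

# info == 0 signals that the solver reached its default tolerance
x_recovered, info = lgmres(K, y, maxiter=200)
max_error = np.max(np.abs(x_recovered - x_true))
```

<p>For a realistic Gaussian PSF the operator is far worse conditioned, so many more iterations (and in practice some regularization) would be needed.</p>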
<h2 id="conclusion">Conclusion</h2>
<p>We can undo blurring caused by convolution if we know the point spread function. Naively performing
deconvolution using discrete Fourier transforms is not numerically stable, but we can improve the
numerical stability. Nevertheless, unless we know the point-spread function with very high
precision, the result is not perfect, as is evident from the Fourier transforms.</p>
<p>In the next part we will start with blind deconvolution. In that case we don’t know the point spread
function, so we need to deconvolve with a number of different kernels and iterate towards an
approximation of the true PSF. The biggest problem at hand is to have an objective that tells us
which deconvolved image ‘looks more natural’. It is not clear a priori what the best way to measure
this is, and we will look at several approaches to this problem. Then in the final part we will try
one or two algorithms for blind deconvolution.</p>Rik VoorhaarDeconvolution is one of the cornerstones of image processing. Let's take a look at how it works.Time series analysis of my email traffic2021-02-13T00:00:00+00:002021-02-13T00:00:00+00:00https://rikvoorhaar.com/email-time-series<p>I’ve been using Gmail since back in 2006 – when it was still an invite-only beta. In these last 15
years I have received a lot of emails. I wondered if I’m actually receiving more emails now than
back then, or if there are any interesting trends. I want to see if I can make a good model of my
email traffic.</p>
<p>Fortunately obtaining a time series of your email traffic is very easy. You can download a .mbox
file with all your emails. Such a file can easily be processed using the <code class="language-plaintext highlighter-rouge">mailbox</code> package in the
Python standard library. I made a short script that loads a .mbox email archive and extracts some
metadata for all the emails, including the time at which it was sent. Maybe I’ll use the other
metadata for some other project sometime, but for now let’s focus on the timestamps of when the
email was sent.</p>
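<p>A minimal sketch of such a script, using only the standard library. It is not the original script: the function name is hypothetical, only the <code class="language-plaintext highlighter-rouge">Date</code> header is extracted, and the demo writes a tiny made-up archive to a temporary directory so that the example runs on its own.</p>

```python
import mailbox
import os
import tempfile
from email.message import EmailMessage
from email.utils import parsedate_to_datetime

def extract_timestamps(mbox_path):
    """Return the parsed Date header of every message that has a valid one."""
    timestamps = []
    for message in mailbox.mbox(mbox_path):
        date_header = message["Date"]
        if date_header is None:
            continue
        try:
            timestamps.append(parsedate_to_datetime(date_header))
        except (TypeError, ValueError):
            continue  # skip messages with a malformed date
    return timestamps

# Demo on a throwaway archive containing a single made-up message
demo_path = os.path.join(tempfile.mkdtemp(), "demo.mbox")
box = mailbox.mbox(demo_path)
message = EmailMessage()
message["Date"] = "Sat, 13 Feb 2021 10:00:00 +0000"
message["Subject"] = "hello"
message.set_content("test")
box.add(message)
box.flush()
box.close()

timestamps = extract_timestamps(demo_path)
```

<p>Dates without timezone information are parsed as naive datetimes, which matters later when comparing times of day.</p>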
<h2 id="simple-trend-analysis">Simple trend analysis</h2>
<p>By looking at specific components of the time series we can discover some basic trends. In principle
we can model trends as a sum of trends on different timescales. For example the entire timeseries
has components in the scales:</p>
<ul>
<li>Time of day</li>
<li>Day of week</li>
<li>Time of year</li>
<li>Global (non-periodic) trends</li>
</ul>
<p>We can look at these separately, but a more accurate model would model these all at the same time.
Gelman et al. describe how to do this using Bayesian statistics, and it would be good to try
adapting their methods, but for now we’ll just use a package instead.</p>
<h3 id="global-trends">Global trends</h3>
<p>We can get a useful timeseries by counting the total number of emails received each day. Plotting
this timeseries is however not very useful, because it is extremely noisy. To look at patterns in
the data we need to smooth it. This is done by applying some kind of low-pass filter, and there
are many choices for a filter. Very popular is to use a rolling mean, but I personally prefer to use
a Gaussian filter since the final result looks smoother. In the signal processing literature people
would prefer using filters such as a Butterworth filter. At the end of the day, we’re mainly using
the filters for the purpose of plotting so it isn’t too important.</p>
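<p>As a sketch of this smoothing step, the snippet below applies <code class="language-plaintext highlighter-rouge">scipy.ndimage.gaussian_filter1d</code> to a series of daily counts. The counts are synthetic Poisson draws standing in for the real data, and the boundary mode is an assumption.</p>

```python
import numpy as np
from scipy.ndimage import gaussian_filter1d

# Synthetic stand-in for the number of emails received on each day
rng = np.random.default_rng(0)
daily_counts = rng.poisson(lam=6, size=365).astype(float)

# sigma is the standard deviation of the filter, measured in days
smooth_15 = gaussian_filter1d(daily_counts, sigma=15, mode="nearest")
smooth_60 = gaussian_filter1d(daily_counts, sigma=60, mode="nearest")

# A wider filter removes more of the day-to-day fluctuation
spread_raw = daily_counts.std()
spread_15 = smooth_15.std()
spread_60 = smooth_60.std()
```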
<p>Below is a plot of the email timeseries with a Gaussian filter with a standard deviation of 60 days
(blue) and 15 days (gray). We can see that I receive about 4-8 emails per day on average. This does
not include any spam, since these emails eventually get deleted and are therefore not in the email
archive. We can see there is a significant spike in activity in 2010, and an increasing trend over
the past couple years. We can also see a lot of local fluctuations, and as we shall see these can be
largely attributed to a fairly regular seasonal variation.</p>
<p><img src="/imgs/email_datetime_8_0.png" alt="png" /></p>
<h3 id="weekday-trends">Weekday trends</h3>
<p>Unsurprisingly I receive fewer emails on the weekend. Interestingly emails are nearly equally common
on Tuesday through Friday, but less common on Mondays.</p>
<p><img src="/imgs/email_datetime_10_0.png" alt="png" /></p>
<h3 id="seasonal-patterns">Seasonal patterns</h3>
<p>Below is a plot of seasonal trends; the blue line is smoothed with a standard deviation of 7 days,
the gray dots with 1 day. We can see two dips, one around new year and one in summer, both times of
vacation. There are also monthly oscillations, and a peak before and after summer, and before the
winter holidays. I don’t have a satisfying explanation for this.</p>
<p><img src="/imgs/email_datetime_12_0.png" alt="png" /></p>
<p>The daily trend shows some very clear patterns as well. Here the blue line is smoothed with a
standard deviation of 15 minutes, and the gray line with 3 minutes. The times are all in UTC.</p>
<p>We can clearly see that most activity is concentrated between 9:00 and 15:00. We then see two
decreases, at around 15:00 and 17:00. The first probably corresponds to the end of the working day
(during summertime in the Netherlands / Switzerland); the second drop may also correspond to the end
of the working day but for emails whose timestamps lack timezone information. We then see reduced
activity, which starts to taper off even further from about 21:00 onward. This may correspond in
part to email sent during the American working day, and in part to the European evening. Then
finally there is very low activity during the night between roughly 23:00 and 5:00.</p>
<p>On the gray curve we can also see a peak at each hour mark. These peaks are probably all
caused by emails scheduled to go out at a particular time.</p>
<p><img src="/imgs/email_datetime_14_0.png" alt="png" /></p>
<h2 id="additive-model">Additive model</h2>
<p>Rather than looking at each timescale separately as we have done so far, we can model the different
time scales at the same time in an additive model. In a simple model we will model our signal
\(f(t)\) as</p>
\[f(t) = f_{\mathrm{week}}(t)+f_{\mathrm{year}}(t)+f_{\mathrm{trend}}(t)+\epsilon(t)\]
<p>where the first term has a 7-day period, the second term a 365 (or 366) day period, and the third
term is only allowed to change slowly (e.g. once every few months). Finally we assume a
Gaussian noise term for the residuals of our model, which we don’t assume to be constant in
magnitude, but which is always centered at 0. All of the components in our model can be taken to be a
Gaussian process (even the magnitude of the noise). The details on Gaussian processes and how to fit
them are perhaps nice for another blog post, but for the time being we will use a package to do all
the work for us. We will be using <a href="https://github.com/facebook/prophet">Prophet</a>, which is developed
by Facebook. Its main use is predicting the future of time series, but it also works fine just for
modeling time series.</p>
<p>The resulting model seems quite similar to what we have already discovered previously. The global
trend in particular is a bit less oscillatory, but the weekly and seasonal trends are nearly
identical.</p>
<p><img src="/imgs/email_datetime_17_1.png" alt="png" /></p>
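<p>To make the additive structure concrete, here is a deliberately naive decomposition into a trend, a weekly component and a residual. This is <em>not</em> what Prophet does internally; it swaps the Gaussian-process style fit for a moving average and per-weekday means, and it leaves out the yearly term for brevity. The function names, the 91-day trend window and the synthetic data are assumptions made for this sketch.</p>

```python
def moving_average(values, window):
    """Centered moving average; the window shrinks near the edges."""
    half = window // 2
    out = []
    for i in range(len(values)):
        chunk = values[max(0, i - half): i + half + 1]
        out.append(sum(chunk) / len(chunk))
    return out


def decompose_weekly(values, trend_window=91):
    """Split a daily series into trend + weekly + residual components."""
    # Slowly varying trend: 91 days = 13 full weeks, so the weekly pattern averages out.
    trend = moving_average(values, trend_window)
    detrended = [v - t for v, t in zip(values, trend)]
    # Weekly component: mean of the detrended values for each day of the week.
    weekly_means = []
    for day in range(7):
        same_day = detrended[day::7]
        weekly_means.append(sum(same_day) / len(same_day))
    # Center the weekly pattern so that it sums to zero over a week.
    offset = sum(weekly_means) / 7
    weekly_means = [m - offset for m in weekly_means]
    weekly = [weekly_means[i % 7] for i in range(len(values))]
    # Whatever is left plays the role of the noise term epsilon(t).
    residual = [v - t - w for v, t, w in zip(values, trend, weekly)]
    return trend, weekly, residual


# Synthetic data: slow linear growth plus a weekly pattern with a quiet weekend.
# Index 0 is taken to be a Monday (an arbitrary choice for this example).
pattern = [1.0, 1.0, 1.0, 1.0, 0.0, -2.0, -2.0]
values = [5.0 + 0.01 * i + pattern[i % 7] for i in range(7 * 52)]
trend, weekly, residual = decompose_weekly(values)
```

<p>By construction the three components add up to the original series, just like the terms in the formula for \(f(t)\) above. A package like Prophet fits all components jointly instead of one after the other, which tends to be more robust.</p>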
<h3 id="analysis-of-the-additive-model">Analysis of the additive model</h3>
<p>Next we can wonder how accurate this model is. The model assumes that the noise, and hence the
residuals, are normally distributed. Let’s try to see how well this assumption holds up by analyzing
the distribution of the residuals. In a normal distribution, a distance of 1 standard deviation from
the mean corresponds to the quantiles of 0.159 and 0.841 respectively. And similarly a distance of 2
standard deviations from the mean corresponds to quantiles of 0.023 and 0.977 respectively. Finally,
the median and mean should coincide. We can therefore compute these quantiles in a rolling fashion,
and normalize by dividing by the standard deviation. If the residuals are normally distributed,
these rolling normalized quantiles should stay close to horizontal integer lines.</p>
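<p>The quantile levels quoted above follow directly from the normal CDF, and the rolling check can be sketched in a few lines. The helper names, the choice to normalize by the standard deviation of each window (rather than a global one) and the simulated residuals are assumptions for illustration only.</p>

```python
import random
from statistics import NormalDist

std_normal = NormalDist(mu=0.0, sigma=1.0)
# Quantile levels corresponding to -2, -1, 0, 1 and 2 standard deviations.
levels = [std_normal.cdf(z) for z in (-2, -1, 0, 1, 2)]


def empirical_quantile(sorted_values, q):
    """Quantile of an already sorted list, using linear interpolation."""
    pos = q * (len(sorted_values) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(sorted_values) - 1)
    frac = pos - lo
    return sorted_values[lo] * (1 - frac) + sorted_values[hi] * frac


def rolling_normalized_quantiles(residuals, window, levels):
    """For each full window, return its quantiles divided by its standard deviation."""
    rows = []
    for end in range(window, len(residuals) + 1):
        chunk = sorted(residuals[end - window:end])
        mean = sum(chunk) / window
        std = (sum((x - mean) ** 2 for x in chunk) / window) ** 0.5
        rows.append([empirical_quantile(chunk, q) / std for q in levels])
    return rows


# Simulated residuals that really are Gaussian, as a sanity check.
rng = random.Random(0)
residuals = [rng.gauss(0.0, 3.0) for _ in range(2000)]
rows = rolling_normalized_quantiles(residuals, window=200, levels=levels)
```

<p>For truly Gaussian residuals the five columns of <code>rows</code> hover around \(-2, -1, 0, 1, 2\). Systematic departures from those lines indicate that the normality assumption fails.</p>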
<p>Below we plotted just that, with a rolling window of 200 days. We can see that the rolling median,
and the rolling quantiles corresponding to one standard deviation, both correspond well to a normal
distribution. We do see a bit of deviation between 2009 and 2011, which is likely caused by the
sudden spike around the start of 2010, which seems a bit of an outlier.</p>
<p>The 2-standard deviation rolling quantiles seem skewed towards bigger values, however. This is
because there are many days with very large spikes in email traffic, and the global distribution of
email traffic is not symmetric either. Furthermore we are dealing with non-negative data (I
can’t receive a negative number of emails). This in itself means that the residuals of any model
are not going to be normally distributed. Therefore the model’s assumptions are invalid, and a more
accurate model would make more accurate assumptions about the distribution of the residuals.</p>
<p>However, assuming normality of the residuals tends to make computations much easier, and a model
with a more accurate noise model might be difficult to fit, especially on large amounts of data. I
might try to do this in a future blog post. We are dealing here with count data (namely the
number of emails in a given day). Such data is often modeled by a Poisson distribution rather than a
normal distribution. The main assumption of a Poisson distribution is that the events of arriving
emails are all independent. This is probably not the case, but we can either way see how well this
assumption holds up.</p>
<p><img src="/imgs/email_datetime_19_0.png" alt="png" /></p>
<h2 id="distribution-of-time-between-consecutive-emails">Distribution of time between consecutive emails</h2>
<p>Finally let’s try to get a deeper understanding of time series by considering the distribution of
time between consecutive emails. Having a good understanding of this can help to model the time
series better. If we model the arrival times of all emails to be independent, except for a global
variation in rate, we are naturally led to model the time \(T\) between consecutive emails by an
exponential distribution:</p>
\[T_t\sim \mathrm{Exp}(\lambda(t))\]
<p>where \(\lambda(t)\) is a rate parameter that depends on time, since we already established that the
rate at which we receive emails is not constant over time.</p>
<p>If we divide an exponentially distributed variable by its mean, the result always follows an exponential distribution
with unit rate. We can use this to obtain a similar plot to the plot of the residuals. We will
divide the time series of time between consecutive emails by a rolling mean, and then we will plot
the rolling quantiles of the resulting data. These can then be compared to the quantiles of a
standard exponential distribution.</p>
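<p>The normalization described above can be sketched as follows. The quantile function of a unit-rate exponential distribution is \(-\log(1-q)\), so for instance its median is \(\log 2\). The helper names, the trailing 500-sample window and the simulated waiting times (whose rate changes halfway through) are assumptions for this example.</p>

```python
import math
import random


def exponential_quantile(q):
    """Quantile function of the unit-rate exponential distribution."""
    return -math.log(1.0 - q)


def normalize_by_rolling_mean(gaps, window):
    """Divide each gap by the mean of the `window` most recent gaps (itself included)."""
    normalized = []
    running_sum = sum(gaps[:window - 1])
    for i in range(window - 1, len(gaps)):
        running_sum += gaps[i]  # window now covers gaps[i - window + 1 .. i]
        normalized.append(gaps[i] / (running_sum / window))
        running_sum -= gaps[i - window + 1]  # drop the oldest gap before moving on
    return normalized


# Simulated independent arrivals: rate 2.0 for the first half, rate 0.5 afterwards.
rng = random.Random(1)
gaps = [rng.expovariate(2.0) for _ in range(10000)]
gaps += [rng.expovariate(0.5) for _ in range(10000)]

normalized = sorted(normalize_by_rolling_mean(gaps, window=500))
empirical_median = normalized[len(normalized) // 2]
expected_median = exponential_quantile(0.5)
```

<p>For independent arrivals like these simulated ones, the empirical quantiles of the normalized gaps stay close to the theoretical ones even though the rate changes. Real email data deviates from this picture.</p>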
<p>This is done in the plot below, and we can clearly see that the time between
consecutive emails is not exponentially distributed. The distribution is much more concentrated in
low values than expected from an exponential distribution. It also seems to have a somewhat longer tail
than predicted by an exponential distribution (although this is harder to see in this plot). This is
because emails are not independent. For instance, if you’re having an active conversation with
someone you might get a lot of emails in a short amount of time, but most of the time emails come in
at a slower rate. Furthermore there are quite a number of times that emails arrive at the exact same
second, which should have very low probability under an exponential model.</p>
<p>One can try fitting different distributions to this data. For example a gamma distribution has a
better fit, but still does not properly model the probabilities of very small time intervals.
Perhaps a mixture of several gamma distributions would fit the distribution of the data well, but
this kind of distribution is hard to interpret. A good statistical model should have a good
theoretical justification as well.</p>
<p><img src="/imgs/email_datetime_22_0.png" alt="png" /></p>
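<p>The post does not say how the gamma distribution was fitted. One simple option, shown here purely as an assumed sketch, is the method of moments: for a gamma distribution with shape \(k\) and scale \(\theta\) the mean is \(k\theta\) and the variance is \(k\theta^2\), so \(k = \text{mean}^2/\text{variance}\) and \(\theta = \text{variance}/\text{mean}\). The function name and the simulated data are hypothetical.</p>

```python
import random


def fit_gamma_moments(samples):
    """Estimate gamma shape and scale by matching the sample mean and variance."""
    n = len(samples)
    mean = sum(samples) / n
    var = sum((x - mean) ** 2 for x in samples) / n
    shape = mean ** 2 / var
    scale = var / mean
    return shape, scale


# Simulated waiting times from a gamma distribution with shape 0.5 and scale 4.0.
rng = random.Random(2)
gaps = [rng.gammavariate(0.5, 4.0) for _ in range(100000)]
shape, scale = fit_gamma_moments(gaps)
```

<p>A fitted shape below 1 means the density piles up near zero more strongly than an exponential distribution (which is the special case of shape 1). That matches the excess of very short gaps seen in the plot, although it still cannot account for emails arriving in the exact same second.</p>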
<h2 id="conclusion">Conclusion</h2>
<p>We conclude the analysis of this email time series for now. I can’t say that I have learned anything
useful about my own email traffic, but the analysis itself was very interesting to me. It can be
interesting to dive into data like this and really try to understand what’s going on: not only to
model the data (which could be useful for predictions), but also to dive deeper into the
shortcomings of the model. I will hopefully get back to this time series and come up with a more
accurate model that makes more realistic assumptions about the data. The only way to come up with
such models is to first understand the data itself better.</p>Rik VoorhaarI have 15 years worth of email traffic data, let's take a closer look and discover some fascinating patterns.2020 in music2020-12-31T00:00:00+00:002020-12-31T00:00:00+00:00https://rikvoorhaar.com/music_2020<p>It goes without saying that 2020 is a special year in a number of ways. I want to look back at what
2020 meant for me in terms of music. As a side effect of staying home almost all year, I have also
listened to more music than in previous years. In fact, I listened to music roughly twice as much in
2020 compared to 2019, although it still can’t match my high school levels. Below is a bar plot of
my number of last.fm scrobbles since 2010.</p>
<p><img src="/imgs/music_2020/music_2020_barplot.png" alt="png" /></p>
<p>On a less positive note, 2020 is also the year Google Play Music stopped. I’m very sad about this,
as it had both great streaming features and allowed you to listen to your own uploaded music
seamlessly. YouTube Music is not a good replacement, so I switched to Spotify. This still leaves a
lot to be desired, since I can’t add my own music to the Spotify library and listening to any music
that’s not in the Spotify library is a hassle. On the plus side, Spotify does have a larger
collection of music than Google Play Music did.</p>
<p>In the rest of this post I want to look back on the music I discovered during 2020. There has been a
lot of great music that came out this year, but also a lot of music that was released before 2020
that I only discovered this year. This is not a complete list, but rather a list of the albums that
had the most impact on me in no particular order. This list was compiled by looking at all the
albums I listened to much more in 2020 than in all other years combined. I have listened to some of these
albums before 2020, but I didn’t really get into them before. Let us start with the albums that were
actually released this year.</p>
<h2 id="music-from-2020">Music from 2020</h2>
<h3 id="fiona-apple---fetch-the-bolt-cutters-2020">Fiona Apple - Fetch the Bolt Cutters (2020)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/fetch_the_boltcutters.jpg" width="150" />
<em>Genre: Art Pop</em><br />
This is definitely the album of the year for me. I can’t quite put my finger on why, but this album
just sounds amazing. I think it’s the overall sound of the album that really draws me in, and it
doesn’t become boring even after dozens of listens.</p>
<div style="clear: left;"></div>
<h3 id="gezan---klue-2020">Gezan - KLUE (2020)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/klue.jpg" width="150" /></p>
<p><em>Genre: Noise Rock</em><br />
Definitely one of the weirder albums of this year. At times intense, at times hypnotic and at other
times very calm. The album transitions between different moods very smoothly, resulting in a coherent,
interesting piece of music.</p>
<div style="clear: left;"></div>
<h3 id="king-krule---man-alive-2020">King Krule - Man Alive! (2020)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/man_alive.jpg" width="150" /></p>
<p><em>Genre: Post-Punk</em><br />
Calm, melancholic and noisy. The texture and sound of this album are great, and just get better
every time you listen to it.</p>
<div style="clear: left;"></div>
<h3 id="jeff-rosenstock---no-dream-2020">Jeff Rosenstock - NO DREAM (2020)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/no_dream.jpg" width="150" />
<em>Genre: Pop Punk</em><br />
Very energetic, although the latter half of the album is very melancholic. The lyrics are also very
witty, and the noisy guitar riffs are nice, but in the end I think I like it because it’s just a
really catchy record.</p>
<div style="clear: left;"></div>
<h3 id="klô-pelgag---notre-dame-des-sept-douleurs-2020">Klô Pelgag - Notre-Dame-des-Sept-Douleurs (2020)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/notre-dame-des-sept-douleurs.jpg" width="150" />
<em>Genre: Art/Baroque Pop</em> <br />
Probably the first French-language pop I have come to enjoy, and probably the second best album
that came out this year as far as I’m concerned. The instrumentals on this record are perfect, and
the melody is incredibly catchy.</p>
<div style="clear: left;"></div>
<h3 id="pottery---welcome-to-bobbys-motel-2020">Pottery - Welcome to Bobby’s Motel (2020)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/welcome_to_bobbys_motel.jpg" width="150" />
<em>Genre: Dance Punk</em> <br />
This record sounds a lot like Talking Heads’ <em>Remain in Light</em>, and this is not a bad thing.</p>
<div style="clear: left;"></div>
<h3 id="the-microphones---microphones-in-2020-2020">The Microphones - Microphones in 2020 (2020)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/microphones_in_2020.jpg" width="150" />
<em>Genre: Indie Folk</em> <br />
The “music video” of this album is a 45 minute video showing photographs, and I was glued to my
screen every second of it. The sound is repetitive, and it draws you in; hypnotizes you. On top of
that, it is almost as packed with emotion as <em>A Crow Looked At Me</em>.</p>
<div style="clear: left;"></div>
<h2 id="other-music-i-discovered-in-2020">Other music I discovered in 2020</h2>
<h3 id="tom-waits---rain-dogs-1985">Tom Waits - Rain Dogs (1985)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/rain_dogs.jpg" width="150" />
<em>Genre: Singer/Songwriter, Experimental Rock</em><br />
I’m surprised I only discovered Tom Waits this year. Waits’ deep smoky voice and bluesy songs are
wonderful. It’s hard to pin this record down to a single genre, yet the sound is completely
coherent.</p>
<div style="clear: left;"></div>
<h3 id="duster---duster-2019">Duster - Duster (2019)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/duster.jpg" width="150" />
<em>Genre: Slowcore</em><br />
Slow, noisy and hypnotic. This album will calmly put you into a trance, but never leave you
bored.</p>
<div style="clear: left;"></div>
<h3 id="hella---hold-your-horse-is-2002">Hella - Hold Your Horse Is (2002)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/hold_your_horse_is.jpg" width="150" />
<em>Genre: Math Rock</em><br />
Very complex and technical instrumental math rock. Often I lose interest in purely instrumental albums
after a couple listens, but this album definitely stays interesting even after many listens.</p>
<div style="clear: left;"></div>
<h3 id="milton-nascimento--lô-borges---clube-da-esquina-1972">Milton Nascimento & Lô Borges - Clube da Esquina (1972)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/clube_da_esquina.jpg" width="150" />
<em>Genre: Música Popular Brasileira</em><br />
This is my first exposure to Brazilian music. After discovering this album I tried a number of
other albums in the same genre, but nothing comes close to this.</p>
<div style="clear: left;"></div>
<h3 id="have-a-nice-life---deathconsciousness-2008">Have a Nice Life - Deathconsciousness (2008)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/deathconsciousness.jpg" width="150" />
<em>Genre: Post-Punk</em><br />
I have been aware of this album for a long time, since I used to browse /mu/ in high school, but I
only really got into it this year. The album is very intense, and I simply love the lo-fi sound of
the record.</p>
<div style="clear: left;"></div>
<h3 id="shellac---at-action-park-1994">Shellac - At Action Park (1994)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/at_action_park.jpg" width="150" />
<em>Genre: Post-Hardcore, Math Rock</em><br />
Beautiful, noisy sound. Reminds me a lot of Number Girl, but with a much stronger beat. It’s hard
not to bop my head to this.</p>
<div style="clear: left;"></div>
<h3 id="idles---joy-as-an-act-of-resistance-2018">IDLES - Joy as an Act of Resistance. (2018)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/joy_as_an_act_of_resistance.jpg" width="150" />
<em>Genre: Post-Punk</em><br />
Incredibly energetic punk. It takes a very wholesome spin on the traditional anarchistic lyrics of
punk rock. IDLES also released an album in 2020, and while it’s certainly good, it didn’t leave
nearly as much of an impression on me as this album.</p>
<div style="clear: left;"></div>
<h3 id="purple-mountains---purple-mountains-2019">Purple Mountains - Purple Mountains (2019)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/purple_mountains.jpg" width="150" />
<em>Genre: Alt-Country</em><br />
Country rock meets existential dread.</p>
<div style="clear: left;"></div>
<h3 id="acid-bath---when-the-kite-string-pops-1994">Acid Bath - When the Kite String Pops (1994)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/when_the_kite_string_pops.jpg" width="150" />
<em>Genre: Sludge Metal</em><br />
Like Melvins’ or Boris’ Sludge Metal records, but much more raw and dark. I love it.</p>
<div style="clear: left;"></div>
<h3 id="black-flag---my-war-1984">Black Flag - My War (1984)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/my_war.jpg" width="150" />
<em>Genre: Hardcore Punk</em><br />
The lyrics are a bit angsty, but the dark, raw, noisy sound more than makes up for that.</p>
<div style="clear: left;"></div>
<h3 id="clipping---there-existed-an-addiction-to-blood-2019">clipping. - There Existed an Addiction to Blood (2019)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/there_existed_an_addiction_to_blood.jpg" width="150" />
<em>Genre: Industrial Hip-Hop, Horrorcore</em><br />
This is probably the biggest surprise for me this year. I didn’t expect to like
this at all, but I really do. For some reason I find it difficult to like hip-hop, especially rap.
Perhaps the noisy sound of this album is what attracts me to it.</p>
<div style="clear: left;"></div>
<h3 id="built-to-spill---perfect-from-now-on-1997">Built to Spill - Perfect From Now On (1997)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/perfect_from_now_on.jpg" width="150" />
<em>Genre: Indie Rock</em><br />
Lo-fi indie pop that still definitely sounds like the 90s, and definitely holds up against some of
the best albums of that decade.</p>
<div style="clear: left;"></div>
<h3 id="current-93---thunder-perfect-mind-1992">Current 93 - Thunder Perfect Mind (1992)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/thunder_perfect_mind.jpeg" width="150" />
<em>Genre: Neofolk</em><br />
A sombre, melancholic, acoustic album. At first I didn’t like this album, but a few months after
listening to it a few times, several songs got stuck in my head. I then listened to the
album again and found that I suddenly really enjoyed it.</p>
<div style="clear: left;"></div>
<h3 id="comus---first-utterance-1971">Comus - First Utterance (1971)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/first_utterance.jpg" width="150" />
<em>Genre: Progressive Folk, Freak Folk</em><br />
Grim, psychedelic and intense. The album draws you in, despite its dark and disturbing content.</p>
<div style="clear: left;"></div>
<h3 id="공중도둑-mid-air-thief---공중도덕-gongjoong-doduk-2015">공중도둑 (Mid-Air Thief) - 공중도덕 (Gongjoong Doduk) (2015)</h3>
<p><img style="float: left; padding-right: 10px;" src="/imgs/music_2020/gongjoong_doduk.jpg" width="150" />
<em>Genre: Psychedelic Folk</em><br />
Psychedelic, yet calm and relaxing. Probably my first exposure to Korean music (other than k-pop),
and I definitely want to hear more of this.</p>
<div style="clear: left;"></div>Rik Voorhaar2020 was a great year for music, I will look back and give some thoughts on the best albums that came out in 2020.