On the timing of slices

In [1]:
%pylab inline
Welcome to pylab, a matplotlib-based Python environment [backend: module://IPython.kernel.zmq.pylab.backend_inline].
For more information, type 'help(pylab)'.
In [2]:
import numpy as np
In [3]:
import matplotlib.pyplot as plt
In [4]:
import nibabel as nib
In [5]:
fname = 'bold.nii.gz'
In [6]:
img = nib.load(fname)

We want to do slice timing.

Do we have any interesting information in the scan header?

In [7]:
hdr = img.header
hdr
Out[7]:
<nibabel.nifti1.Nifti1Header at 0x4410210>
In [8]:
print(hdr)
<class 'nibabel.nifti1.Nifti1Header'> object, endian='<'
sizeof_hdr      : 348
data_type       : 
db_name         : 
extents         : 0
session_error   : 0
regular         : r
dim_info        : 0
dim             : [  4  64  64  35 165   1   1   1]
intent_p1       : 0.0
intent_p2       : 0.0
intent_p3       : 0.0
intent_code     : none
datatype        : int16
bitpix          : 16
slice_start     : 0
pixdim          : [ 1.  3.  3.  3.  3.  0.  0.  0.]
vox_offset      : 352.0
scl_slope       : 1.0
scl_inter       : 0.0
slice_end       : 0
slice_code      : unknown
xyzt_units      : 10
cal_max         : 0.0
cal_min         : 0.0
slice_duration  : 0.0
toffset         : 0.0
glmax           : 0
glmin           : 0
descrip         : FSL4.0
aux_file        : 
qform_code      : scanner
sform_code      : scanner
quatern_b       : 0.0
quatern_c       : 0.0
quatern_d       : 0.0
qoffset_x       : -93.0
qoffset_y       : -103.418151855
qoffset_z       : -45.4796600342
srow_x          : [  3.   0.   0. -93.]
srow_y          : [   0.            3.            0.         -103.41815186]
srow_z          : [  0.           0.           3.         -45.47966003]
intent_name     : 
magic           : n+1

The 'pixdim' field says that our voxels are 3 x 3 x 3, and then there is a fourth 3. What's the fourth 3?

In [9]:
# hdr.get_slice_times()

We really need to know the slice order of this scan, and we can't afford to guess.

What's your guess for this dataset?

How would you check?

Let's say the slices arrived in ascending interleaved order.

In [10]:
slice_order = list(range(0, 35, 2)) + list(range(1, 35, 2))
slice_order = np.array(slice_order)
slice_order
Out[10]:
array([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32,
       34,  1,  3,  5,  7,  9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31,
       33])

Revision - what does the 2 mean in the vector above?

What does this mean in terms of times that the slices arrived?
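The interleaved order above is written out for 35 slices. As a minimal sketch, here is the same idea for any number of slices. The name `interleaved_order` is made up for this illustration, and it assumes the even-numbered slices really do come first, which you would still need to confirm for your scanner.

```python
import numpy as np

def interleaved_order(n_slices):
    """ Hypothetical helper: ascending interleaved acquisition order

    Even-numbered slices (0, 2, 4, ...) come first, then the odd ones.
    """
    evens = np.arange(0, n_slices, 2)
    odds = np.arange(1, n_slices, 2)
    return np.concatenate((evens, odds))

order = interleaved_order(5)
print(order)  # [0 2 4 1 3]
```

Element i of the result is the slice (in space) that arrived i-th in time.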

In [11]:
n_slices = img.shape[2]
n_slices
Out[11]:
35

Get the TR information from the header, somehow?

In [12]:
TR = hdr['pixdim'][4]
time_one_slice = TR / n_slices
time_one_slice
Out[12]:
0.085714285714285715

At the moment we have the order in which the slices arrived. Now we want the time of each slice, ordered by position in space: the first value will be the time of acquisition of the first slice in space, the second value the time of acquisition of the second slice in space, and so on.

In [13]:
space_to_order = np.argsort(slice_order)
space_to_order
Out[13]:
array([ 0, 18,  1, 19,  2, 20,  3, 21,  4, 22,  5, 23,  6, 24,  7, 25,  8,
       26,  9, 27, 10, 28, 11, 29, 12, 30, 13, 31, 14, 32, 15, 33, 16, 34,
       17])

We're going to do some fancy indexing to clarify what space_to_order is.

In [14]:
slice_order[space_to_order]
Out[14]:
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34])
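A smaller example may make the argsort trick easier to follow. This is only a sketch with a made-up 5-slice order (`toy_order` is not part of the dataset).

```python
import numpy as np

# Toy acquisition order: element i is the slice (in space) acquired i-th
toy_order = np.array([0, 2, 4, 1, 3])
# argsort gives, for each slice in space, its position in the time order
toy_space_to_order = np.argsort(toy_order)
print(toy_space_to_order)  # [0 3 1 4 2]
# Indexing the order by its argsort puts the slices back in spatial order
print(toy_order[toy_space_to_order])  # [0 1 2 3 4]
```

So slice 1 (the second in space) has value 3: it was the fourth slice to arrive.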
In [15]:
slice_times = space_to_order * time_one_slice
slice_times
Out[15]:
array([ 0.        ,  1.54285714,  0.08571429,  1.62857143,  0.17142857,
        1.71428571,  0.25714286,  1.8       ,  0.34285714,  1.88571429,
        0.42857143,  1.97142857,  0.51428571,  2.05714286,  0.6       ,
        2.14285714,  0.68571429,  2.22857143,  0.77142857,  2.31428571,
        0.85714286,  2.4       ,  0.94285714,  2.48571429,  1.02857143,
        2.57142857,  1.11428571,  2.65714286,  1.2       ,  2.74285714,
        1.28571429,  2.82857143,  1.37142857,  2.91428571,  1.45714286])

Any objections?

Was the first slice really taken at time 0?

In [16]:
slice_times = space_to_order * time_one_slice + time_one_slice / 2.
slice_times
Out[16]:
array([ 0.04285714,  1.58571429,  0.12857143,  1.67142857,  0.21428571,
        1.75714286,  0.3       ,  1.84285714,  0.38571429,  1.92857143,
        0.47142857,  2.01428571,  0.55714286,  2.1       ,  0.64285714,
        2.18571429,  0.72857143,  2.27142857,  0.81428571,  2.35714286,
        0.9       ,  2.44285714,  0.98571429,  2.52857143,  1.07142857,
        2.61428571,  1.15714286,  2.7       ,  1.24285714,  2.78571429,
        1.32857143,  2.87142857,  1.41428571,  2.95714286,  1.5       ])

The first value corresponds to the first (0) slice in the volume, and gives the time that slice was acquired, relative to the start of the volume.

(To be explicit, we've taken the position that a slice is acquired at the midpoint of the time it took to acquire it).
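The steps so far (order, argsort, time per slice, half-slice offset) can be collected in one place. This is a sketch only: `mid_slice_times` is a hypothetical name, and it assumes the slices are spread evenly over the whole TR with no gap at the end.

```python
import numpy as np

def mid_slice_times(slice_order, TR):
    """ Hypothetical helper: mid-acquisition time of each slice in space

    Assumes slices are acquired evenly across the TR.
    """
    n_slices = len(slice_order)
    time_one_slice = TR / float(n_slices)
    space_to_order = np.argsort(slice_order)
    # Position in time order, times slice duration, plus half a slice
    return space_to_order * time_one_slice + time_one_slice / 2.

toy_slice_times = mid_slice_times([0, 2, 4, 1, 3], TR=2.5)
print(toy_slice_times)
```

With 5 slices and a TR of 2.5, each slice takes 0.5 seconds, so slice 0 is at 0.25 and slice 1 (fourth to arrive) is at 1.75.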

Let's take a voxel time-course from the first slice. You've seen this before.

In [17]:
data = img.get_fdata()
slice0_vol0 = data[:, :, 0, 0]
plt.gray()
plt.imshow(slice0_vol0)
Out[17]:
<matplotlib.image.AxesImage at 0x4444ed0>

We take a voxel from the middle of the mid-brain, say row 32, column 25.

In [18]:
slice0_time_course = data[32, 25, 0, :] # all points in time
plt.plot(slice0_time_course)
plt.xlabel('Scan number')
Out[18]:
<matplotlib.text.Text at 0x4410610>

What if I wanted to have 'time' instead of 'Scan number' as the values on the x-axis?

In [19]:
n_scans = img.shape[-1]
scan_starts = np.arange(n_scans) * TR # times scans began
scan_starts[:10]
Out[19]:
array([  0.,   3.,   6.,   9.,  12.,  15.,  18.,  21.,  24.,  27.])
In [20]:
slice0_times = scan_starts + slice_times[0]
slice0_times[:10]
Out[20]:
array([  0.04285714,   3.04285714,   6.04285714,   9.04285714,
        12.04285714,  15.04285714,  18.04285714,  21.04285714,
        24.04285714,  27.04285714])
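We will need the same sum for every slice. One way to get all the acquisition times at once is numpy broadcasting. This is a sketch with made-up toy values (`toy_scan_starts` and `toy_slice_offsets` are not from the dataset).

```python
import numpy as np

toy_TR = 2.0
toy_scan_starts = np.arange(3) * toy_TR           # start time of each scan
toy_slice_offsets = np.array([0.25, 1.25, 0.75])  # hypothetical slice times
# Rows are slices, columns are scans: time of each slice in each scan
all_times = toy_slice_offsets[:, None] + toy_scan_starts[None, :]
print(all_times.shape)  # (3, 3)
```

Row 0 of `all_times` plays the role of slice0_times above.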
In [21]:
plt.plot(slice0_times, slice0_time_course)
Out[21]:
[<matplotlib.lines.Line2D at 0x48af0d0>]
In [22]:
plt.plot(slice0_times[:10], slice0_time_course[:10], 'r:+')
plt.xlabel('Time of acquisition')
Out[22]:
<matplotlib.text.Text at 0x48af990>

Now let's take a time course from the middle of the TR.

In [23]:
space_to_order
Out[23]:
array([ 0, 18,  1, 19,  2, 20,  3, 21,  4, 22,  5, 23,  6, 24,  7, 25,  8,
       26,  9, 27, 10, 28, 11, 29, 12, 30, 13, 31, 14, 32, 15, 33, 16, 34,
       17])

Because this is interleaved, in fact the second slice in space is about half way through the volume in time.

In [24]:
slice1_time_course = data[32, 25, 1, :] # all points in time
plt.plot(slice1_time_course)
plt.xlabel('Scan number')
Out[24]:
<matplotlib.text.Text at 0x487fd10>

What times are these?

In [25]:
slice1_times = scan_starts + slice_times[1]
slice1_times[:10]
Out[25]:
array([  1.58571429,   4.58571429,   7.58571429,  10.58571429,
        13.58571429,  16.58571429,  19.58571429,  22.58571429,
        25.58571429,  28.58571429])
In [26]:
plt.plot(slice0_times[:10], slice0_time_course[:10], 'r:+', label='slice 0')
plt.plot(slice1_times[:10], slice1_time_course[:10], 'b-o', label='slice 1')
plt.xlabel('Time of acquisition')
plt.legend()
Out[26]:
<matplotlib.legend.Legend at 0x4abb890>

What does slice timing do?

We want values from the blue line, sampled at the times of the red points.

We have values for slice 1 at times slice1_times. But we want to estimate what the values would have been, for slice 1, if slice 1 had arrived at slice0_times. And so on for the other slices.

This is interpolation.
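Before reaching for scipy, here is the idea on three made-up points, using `np.interp`, which does linear interpolation only. The toy times and values are assumptions for illustration.

```python
import numpy as np

toy_times = np.array([1., 3., 5.])      # when the samples were taken
toy_values = np.array([10., 20., 30.])  # what we measured at those times
wanted_times = np.array([2., 4.])       # when we wish we had sampled
# Linear interpolation: estimate values on the line between neighbors
estimates = np.interp(wanted_times, toy_times, toy_values)
print(estimates.tolist())  # [15.0, 25.0]
```

Each estimate sits on the straight line joining the two measurements on either side of it.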

In [27]:
import scipy.interpolate as spi

Let's investigate spi.interp1d

In [28]:
x = slice1_times
y = slice1_time_course
interpolator = spi.interp1d(x, y, 'linear')
interpolator
Out[28]:
<scipy.interpolate.interpolate.interp1d at 0x45a6490>
In [29]:
# What happens here? 
# slice1_at_slice0 = interpolator(slice0_times)

This is the problem of extrapolation. What to do?
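To see the problem on a small scale, here is a sketch with made-up values. By default `spi.interp1d` refuses to give values outside the range of the original x values, and raises a `ValueError` instead.

```python
import numpy as np
import scipy.interpolate as spi

toy_x = np.array([1., 3., 5.])
toy_y = np.array([10., 20., 30.])
toy_interpolator = spi.interp1d(toy_x, toy_y, 'linear')
try:
    toy_interpolator(0.0)  # 0.0 is before the first x value
    raised = False
except ValueError:
    raised = True
print(raised)  # True
```

This is what happens with slice0_times, because the first slice 0 time is earlier than the first slice 1 time.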

In [30]:
x = slice1_times
y = slice1_time_course
interpolator = spi.interp1d(x, y, 'linear', bounds_error=False, fill_value=0)
In [31]:
slice1_at_slice0 = interpolator(slice0_times)
plt.plot(slice0_times[:10], slice0_time_course[:10], 'r:+', label='slice 0')
plt.plot(slice1_times[:10], slice1_time_course[:10], 'b-o', label='slice 1')
plt.plot(slice0_times[:10], slice1_at_slice0[:10], 'kx', label='1 -> 0')
plt.xlabel('Time of acquisition')
plt.legend()
Out[31]:
<matplotlib.legend.Legend at 0x7ff7e860fcd0>
In [32]:
x = slice1_times
y = slice1_time_course
interpolator = spi.interp1d(x, y, 'linear', bounds_error=False, fill_value=np.mean(slice1_time_course))
In [33]:
slice1_at_slice0 = interpolator(slice0_times)
plt.plot(slice0_times[:10], slice0_time_course[:10], 'r:+', label='slice 0')
plt.plot(slice1_times[:10], slice1_time_course[:10], 'b-o', label='slice 1')
plt.plot(slice0_times[:10], slice1_at_slice0[:10], 'kx', label='1 -> 0')
plt.xlabel('Time of acquisition')
plt.legend()
Out[33]:
<matplotlib.legend.Legend at 0x4fcae50>

Is there a better way of doing the extrapolation?

In [34]:
n_scans
Out[34]:
165
In [35]:
x_padded = np.zeros((n_scans + 2,))
y_padded = np.zeros((n_scans + 2,))
x_padded[1:-1] = slice1_times
x_padded[0] = x[0] - TR
x_padded[-1] = x[-1] + TR
y_padded[1:-1] = slice1_time_course
y_padded[0] = y[0]
y_padded[-1] = y[-1]
interpolator = spi.interp1d(x_padded, y_padded, 'linear')
In [36]:
slice1_at_slice0 = interpolator(slice0_times)
plt.plot(slice0_times[:10], slice0_time_course[:10], 'r:+', label='slice 0')
plt.plot(slice1_times[:10], slice1_time_course[:10], 'b-o', label='slice 1')
plt.plot(slice0_times[:10], slice1_at_slice0[:10], 'kx', label='1 -> 0')
plt.xlabel('Time of acquisition')
plt.legend()
Out[36]:
<matplotlib.legend.Legend at 0x521d190>
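For linear interpolation there may be a shortcut to the same result. This sketch assumes a SciPy version where `fill_value` accepts a (below, above) tuple, and uses made-up toy values. It holds the first and last values outside the sampled range, which is what padding with a repeat of the end values does for linear interpolation.

```python
import numpy as np
import scipy.interpolate as spi

toy_x = np.array([1., 3., 5.])
toy_y = np.array([10., 20., 30.])
# Outside the range, return the first value (below) or the last (above)
hold_ends = spi.interp1d(toy_x, toy_y, 'linear', bounds_error=False,
                         fill_value=(toy_y[0], toy_y[-1]))
new_x = np.array([0., 2., 4.])
held = hold_ends(new_x)
print(held.tolist())  # [10.0, 15.0, 25.0]
```

This does not carry over to cubic interpolation, where the padded points change the shape of the fitted curve.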
In [37]:
interpolator = spi.interp1d(x_padded, y_padded, 'cubic')
In [38]:
fine_time = np.linspace(x_padded[0], x_padded[9], 100)
predicted_signal = interpolator(fine_time)
plt.plot(fine_time, predicted_signal, ':', label='predicted')
plt.plot(slice1_times[:10], slice1_time_course[:10], 'x', label='actual')
plt.xlabel('Time of acquisition')
plt.legend()
Out[38]:
<matplotlib.legend.Legend at 0x5462c90>

It looks like we're going to be doing a lot of padding over the last axis. Let's make a function for that.

In [39]:
def pad_ends(first, middle, last):
    """ Pad array `middle` along last axis with `first` value and `last` value
    """
    middle = np.array(middle) # Make sure middle is an array
    pad_ax_len = middle.shape[-1] # Length of the axis we are padding
    pad_shape = middle.shape[:-1] + (pad_ax_len + 2,) # Shape of the padded array
    padded = np.empty(pad_shape, dtype=middle.dtype) # Padded array ready to fill
    padded[..., 0] = first
    padded[..., 1:-1] = middle
    padded[..., -1] = last
    return padded
In [40]:
assert np.all(pad_ends(0, [2, 3], 5) == [0, 2, 3, 5])
a = np.zeros((2, 3))
b = np.ones((2, 3, 4)) * 10
c = np.ones((2, 3))
assert np.all(pad_ends(a, b, c) == np.concatenate((a.reshape((2, 3, 1)), b, c.reshape((2, 3, 1))), axis=2))
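When the padding values are just repeats of the first and last entries, `np.pad` with `mode='edge'` should do the same job. A sketch on a made-up array:

```python
import numpy as np

toy = np.arange(12.).reshape((2, 2, 3))
# No padding on the leading axes, one element each side of the last axis
pad_width = [(0, 0)] * (toy.ndim - 1) + [(1, 1)]
edge_padded = np.pad(toy, pad_width, mode='edge')
# Same thing done by hand, repeating the first and last entries
manual = np.concatenate((toy[..., :1], toy, toy[..., -1:]), axis=-1)
assert np.all(edge_padded == manual)
print(edge_padded.shape)  # (2, 2, 5)
```

Our own function is still useful when the padding values are not repeats, as for the padded times.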

Can we interpolate a whole slice of data at a time?

In [41]:
slice1_all = data[:, :, 1, :]
slice1_all.shape
Out[41]:
(64, 64, 165)

Can spi.interp1d interpolate this array for me? Or do I have to do this one time course at a time?
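A quick check on a small made-up array suggests the answer. This sketch compares interpolating a whole 3D array along its last axis against interpolating a single time course.

```python
import numpy as np
import scipy.interpolate as spi

rng = np.random.RandomState(0)
toy_slice = rng.normal(size=(2, 3, 5))  # 2 x 3 voxels, 5 time points
toy_old_times = np.arange(5) * 2.0 + 1.0
toy_new_times = np.array([2., 4., 6., 8.])
# Interpolate every time course at once, along the last axis
whole = spi.interp1d(toy_old_times, toy_slice, 'linear', axis=-1)(toy_new_times)
# Interpolate one time course on its own
one = spi.interp1d(toy_old_times, toy_slice[1, 2, :], 'linear')(toy_new_times)
assert np.allclose(whole[1, 2, :], one)
print(whole.shape)  # (2, 3, 4)
```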

In [42]:
x = slice1_times # as before
y = slice1_all # as before
interpolator = spi.interp1d(x, y, 'linear', axis=2, bounds_error=False, fill_value=0)

How would I pad this array with a repeat of the first and last scans?

In [43]:
slice_dims = img.shape[:2]
n_scans = img.shape[3]
slice1_all_padded = pad_ends(slice1_all[:, :, 0], slice1_all, slice1_all[:, :, -1])
In [44]:
interpolator = spi.interp1d(x_padded, slice1_all_padded, 'linear')
In [45]:
slice1_all_at_slice0 = interpolator(slice0_times)
slice1_at_slice0_again = slice1_all_at_slice0[32, 25, :]
plt.plot(slice0_times[:10], slice0_time_course[:10], 'r:+', label='slice 0')
plt.plot(slice1_times[:10], slice1_time_course[:10], 'b-o', label='slice 1')
plt.plot(slice0_times[:10], slice1_at_slice0_again[:10], 'kx', label='1 -> 0')
plt.xlabel('Time of acquisition')
plt.legend()
Out[45]:
<matplotlib.legend.Legend at 0x56b2a90>

Make this into a function to interpolate a 3D slice at a time

In [46]:
def interp_slice(old_times, slice_nd, new_times, kind='linear'):
    """ Interpolate a 3D slice `slice_nd` with times changing from `old_times` to `new_times`
    """
    n_time = slice_nd.shape[-1]
    assert n_time == len(old_times)
    padded_times = pad_ends(old_times[0] - (old_times[1] - old_times[0]),
                            old_times,
                            old_times[-1] + (old_times[-1] - old_times[-2]))
    to_interpolate = pad_ends(slice_nd[..., 0], slice_nd, slice_nd[..., -1])
    interpolator = spi.interp1d(padded_times, to_interpolate, kind, axis=-1)
    return interpolator(new_times)

Check we get the same from the function as we did from the manual way before

In [47]:
interpolated = interp_slice(slice1_times, slice1_all, slice0_times)
In [48]:
assert np.allclose(interpolated, slice1_all_at_slice0)

Now we are ready for full slice-timing glory

In [49]:
data = img.get_fdata()
In [50]:
slice_times
Out[50]:
array([ 0.04285714,  1.58571429,  0.12857143,  1.67142857,  0.21428571,
        1.75714286,  0.3       ,  1.84285714,  0.38571429,  1.92857143,
        0.47142857,  2.01428571,  0.55714286,  2.1       ,  0.64285714,
        2.18571429,  0.72857143,  2.27142857,  0.81428571,  2.35714286,
        0.9       ,  2.44285714,  0.98571429,  2.52857143,  1.07142857,
        2.61428571,  1.15714286,  2.7       ,  1.24285714,  2.78571429,
        1.32857143,  2.87142857,  1.41428571,  2.95714286,  1.5       ])
In [51]:
scan_starts
Out[51]:
array([   0.,    3.,    6.,    9.,   12.,   15.,   18.,   21.,   24.,
         27.,   30.,   33.,   36.,   39.,   42.,   45.,   48.,   51.,
         54.,   57.,   60.,   63.,   66.,   69.,   72.,   75.,   78.,
         81.,   84.,   87.,   90.,   93.,   96.,   99.,  102.,  105.,
        108.,  111.,  114.,  117.,  120.,  123.,  126.,  129.,  132.,
        135.,  138.,  141.,  144.,  147.,  150.,  153.,  156.,  159.,
        162.,  165.,  168.,  171.,  174.,  177.,  180.,  183.,  186.,
        189.,  192.,  195.,  198.,  201.,  204.,  207.,  210.,  213.,
        216.,  219.,  222.,  225.,  228.,  231.,  234.,  237.,  240.,
        243.,  246.,  249.,  252.,  255.,  258.,  261.,  264.,  267.,
        270.,  273.,  276.,  279.,  282.,  285.,  288.,  291.,  294.,
        297.,  300.,  303.,  306.,  309.,  312.,  315.,  318.,  321.,
        324.,  327.,  330.,  333.,  336.,  339.,  342.,  345.,  348.,
        351.,  354.,  357.,  360.,  363.,  366.,  369.,  372.,  375.,
        378.,  381.,  384.,  387.,  390.,  393.,  396.,  399.,  402.,
        405.,  408.,  411.,  414.,  417.,  420.,  423.,  426.,  429.,
        432.,  435.,  438.,  441.,  444.,  447.,  450.,  453.,  456.,
        459.,  462.,  465.,  468.,  471.,  474.,  477.,  480.,  483.,
        486.,  489.,  492.])
In [52]:
interp_data = np.empty(data.shape)
desired_times = scan_starts
In [53]:
for slice_no in range(data.shape[-2]):
    these_times = slice_times[slice_no] + scan_starts
    data_slice = data[:, :, slice_no, :]
    interped = interp_slice(these_times, data_slice, desired_times, 'cubic')
    interp_data[:, :, slice_no, :] = interped

Check we get the same result as we would from doing this manually

In [54]:
one_tc = data[32, 32, 17, :]
old_times = slice_times[17] + scan_starts
n_scans = len(one_tc)
one_tc_padded = pad_ends(one_tc[0], one_tc, one_tc[-1])
old_times_padded = pad_ends(old_times[0] - TR, old_times, old_times[-1] + TR)
interpolator = spi.interp1d(old_times_padded, one_tc_padded, 'cubic')
tc = interpolator(desired_times)
In [55]:
assert np.allclose(tc, interp_data[32, 32, 17, :])

How about a function to do slice timing on an image?

In [56]:
def slice_time_image(img, slice_times, TR, kind='cubic'):
    """ Take nibabel image `img` and run slice timing correction using `slice_times`
    """
    data = img.get_fdata()
    assert len(slice_times) == img.shape[-2]
    n_scans = img.shape[-1]
    scan_starts = np.arange(n_scans) * TR
    interp_data = np.empty(data.shape)
    desired_times = scan_starts
    for slice_no in range(data.shape[-2]):
        these_times = slice_times[slice_no] + scan_starts
        data_slice = data[:, :, slice_no, :]
        interped = interp_slice(these_times, data_slice, desired_times, kind)
        interp_data[:, :, slice_no, :] = interped
    new_img = nib.Nifti1Image(interp_data, img.affine, img.header)
    return new_img
In [57]:
new_img = slice_time_image(img, slice_times, TR)

Check we get the same answer as last time

In [58]:
new_data = new_img.get_fdata()
assert np.allclose(tc, new_data[32, 32, 17, :])

Can we run this in a pipeline?

In [59]:
import os
pth, fname = os.path.split(fname)
pth, fname
Out[59]:
('', 'bold.nii.gz')
In [60]:
new_fname = os.path.join(pth, 'a' + fname)
new_fname
Out[60]:
'abold.nii.gz'
In [61]:
raw_img = nib.load(fname)
interp_img = slice_time_image(raw_img, slice_times, TR)
nib.save(interp_img, new_fname)

How about a function that takes the filename and the slice times, and the TR, and writes out a new filename?

In [62]:
def slice_time_file(fname, slice_times, TR, kind='cubic'):
    """ Run slice timing on image file `fname`, save with 'a' prefix
    """
    pth, base = os.path.split(fname)
    new_fname = os.path.join(pth, 'a' + base)
    raw_img = nib.load(fname) # Load from the full path, not the basename
    interp_img = slice_time_image(raw_img, slice_times, TR, kind)
    nib.save(interp_img, new_fname)
In [63]:
slice_time_file(fname, slice_times, TR, 'cubic')