On the timing of slices
%pylab inline
import numpy as np
import matplotlib.pylab as plt
import nibabel as nib
fname = 'bold.nii.gz'
img = nib.load(fname)
We want to do slice timing.
Do we have any interesting information in the scan header?
# The image header. `img.get_header()` has been removed from current nibabel
# releases; the `header` attribute is the supported API.
hdr = img.header
hdr
print(hdr)
The 'pixdim' field says that our volume is 3 x 3 x 3 x 3. What's the fourth 3?
# Uncomment to ask nibabel for the slice times stored in the header.
# NOTE(review): this presumably only works if the header's slice-timing
# fields were filled in by the scanner/converter -- try it and see.
# hdr.get_slice_times()
We really need to know the slice order of this scan, and we can't afford to guess.
What's your guess for this dataset?
How would you check?
Let's say the slices arrived in ascending interleaved order.
# Ascending interleaved acquisition: all even-numbered slices (0, 2, 4, ...)
# first, then all odd-numbered slices (1, 3, 5, ...).
# `range` objects cannot be concatenated with `+` in Python 3, so build the
# order directly as a numpy array.
# NOTE(review): 35 is hard-coded; it must equal the number of slices
# (img.shape[2]) -- confirm for this dataset.
slice_order = np.concatenate((np.arange(0, 35, 2), np.arange(1, 35, 2)))
slice_order
Revision: what does the step of 2 mean in the code above?
What does this mean in terms of times that the slices arrived?
# Slices are stacked along the third axis (index 2) of the image.
n_slices = img.shape[2]
n_slices
Get the TR information from the header, somehow?
# pixdim[4] is the step along the fourth (time) axis, i.e. the TR.
# NOTE(review): its units depend on the header's time units (seconds vs
# milliseconds) -- confirm before treating it as seconds.
TR = hdr['pixdim'][4]
# Assume the slices are spread evenly across the TR.
time_one_slice = TR / n_slices
time_one_slice
At the moment we have the order in which the slices arrived. Now we want the times of each slice, where the first value will be the time of acquisition of the first slice in space, the second value will be the time of acquisition of the second slice in space.
# argsort inverts the acquisition order: entry i is the (0-based) position at
# which spatial slice i was acquired within the volume.
space_to_order = np.argsort(slice_order)
space_to_order
We're going to use some fancy indexing to clarify what `space_to_order` contains.
# Indexing the acquisition order with its inverse gives back 0, 1, 2, ...,
# confirming that space_to_order inverts slice_order.
slice_order[space_to_order]
# First attempt: each spatial slice's time is its acquisition position times
# the duration of one slice, so the first acquired slice is at time 0.
slice_times = space_to_order * time_one_slice
slice_times
Any objections?
Was the first slice really taken at time 0?
# Place each slice at the midpoint of its acquisition window instead of at
# its start, by adding half a slice duration.
slice_times = space_to_order * time_one_slice + time_one_slice / 2.
slice_times
The first value corresponds to the first (0) slice in the volume, and gives the time that slice was acquired, measured from the start of the volume.
(To be explicit, we've assumed that a slice's acquisition time is the midpoint of the period it took to acquire it.)
Let's take a voxel time-course from the first slice. You've seen this before.
# Image data as a floating-point array. `get_data()` has been removed from
# current nibabel releases; `get_fdata()` is the supported replacement.
data = img.get_fdata()
# First slice of the first volume.
slice0_vol0 = data[:, :, 0, 0]
plt.gray()
plt.imshow(slice0_vol0)
We take a voxel from the middle of the mid-brain, say row 32, column 25.
# Voxel at row 32, column 25 of the first slice, across every volume.
slice0_time_course = data[32, 25, 0, :] # all points in time
plt.plot(slice0_time_course)
plt.xlabel('Scan number')
What if I wanted to have 'time' instead of 'Scan number' as the values on the x-axis?
# Number of volumes (scans) is the length of the last axis.
n_scans = img.shape[-1]
scan_starts = np.arange(n_scans) * TR # times scans began
scan_starts[:10]
# Slice 0 is sampled at the same offset into every volume.
slice0_times = scan_starts + slice_times[0]
slice0_times[:10]
plt.plot(slice0_times, slice0_time_course)
# Mark the first ten samples to show where the signal was actually measured.
plt.plot(slice0_times[:10], slice0_time_course[:10], 'r:+')
plt.xlabel('Time of acquisition')
Now let's take a time course from the middle of the TR.
# Acquisition position of each spatial slice within the volume.
space_to_order
Because the acquisition is interleaved, the second slice in space is actually acquired about halfway through the volume.
# Same voxel position (row 32, column 25) in the second spatial slice.
slice1_time_course = data[32, 25, 1, :] # all points in time
plt.plot(slice1_time_course)
plt.xlabel('Scan number')
What times are these?
# Slice 1 is sampled at its own offset into every volume.
slice1_times = scan_starts + slice_times[1]
slice1_times[:10]
# Compare the first ten samples of both slices on their true time axes.
# Successive plot calls already draw on the same axes; plt.hold was removed
# in Matplotlib 3.0 and would raise AttributeError.
plt.plot(slice0_times[:10], slice0_time_course[:10], 'r:+', label='slice 0')
plt.plot(slice1_times[:10], slice1_time_course[:10], 'b-o', label='slice 1')
plt.xlabel('Time of acquisition')
plt.legend()
What does slice timing do?
We want values from the blue line, sampled at the times of the red points.
We have values for slice 1 at times slice1_times. But we want to estimate what the values would have been, for slice 1, if slice 1 had arrived at slice0_times. And so on for the other slices.
This is interpolation.
import scipy.interpolate as spi
Let's investigate spi.interp1d
x = slice1_times
y = slice1_time_course
# Linear interpolator giving slice 1's signal as a function of time.
interpolator = spi.interp1d(x, y, 'linear')
interpolator
# What happens here? (Hint: slice 0's first time comes before slice 1's
# first sample.)
# slice1_at_slice0 = interpolator(slice0_times)
This is the problem of extrapolation. What to do?
# Option 1: for times outside the sampled range, return 0 instead of raising.
x = slice1_times
y = slice1_time_course
interpolator = spi.interp1d(x, y, 'linear', bounds_error=False, fill_value=0)
slice1_at_slice0 = interpolator(slice0_times)
# Successive plot calls draw on the same axes (plt.hold was removed in
# Matplotlib 3.0 and is not needed).
plt.plot(slice0_times[:10], slice0_time_course[:10], 'r:+', label='slice 0')
plt.plot(slice1_times[:10], slice1_time_course[:10], 'b-o', label='slice 1')
plt.plot(slice0_times[:10], slice1_at_slice0[:10], 'kx', label='1 -> 0')
plt.xlabel('Time of acquisition')
plt.legend()
# Option 2: for times outside the sampled range, return the mean signal.
# Start a new figure so this comparison is not drawn over the previous one.
plt.figure()
x = slice1_times
y = slice1_time_course
interpolator = spi.interp1d(x, y, 'linear', bounds_error=False, fill_value=np.mean(slice1_time_course))
slice1_at_slice0 = interpolator(slice0_times)
plt.plot(slice0_times[:10], slice0_time_course[:10], 'r:+', label='slice 0')
plt.plot(slice1_times[:10], slice1_time_course[:10], 'b-o', label='slice 1')
plt.plot(slice0_times[:10], slice1_at_slice0[:10], 'kx', label='1 -> 0')
plt.xlabel('Time of acquisition')
plt.legend()
Is there a better way of doing the extrapolation?
n_scans
# Option 3: add one padding sample at each end of the time course, one TR
# before the first sample and one TR after the last, repeating the first and
# last values. That way slice 0's times fall inside the interpolation range.
# Read the slice 1 arrays directly rather than whatever x and y were last set to.
x_padded = np.zeros((n_scans + 2,))
y_padded = np.zeros((n_scans + 2,))
x_padded[1:-1] = slice1_times
x_padded[0] = slice1_times[0] - TR
x_padded[-1] = slice1_times[-1] + TR
y_padded[1:-1] = slice1_time_course
y_padded[0] = slice1_time_course[0]
y_padded[-1] = slice1_time_course[-1]
interpolator = spi.interp1d(x_padded, y_padded, 'linear')
slice1_at_slice0 = interpolator(slice0_times)
# New figure so this comparison is not drawn over an earlier plot
# (plt.hold was removed in Matplotlib 3.0; plots share axes by default).
plt.figure()
plt.plot(slice0_times[:10], slice0_time_course[:10], 'r:+', label='slice 0')
plt.plot(slice1_times[:10], slice1_time_course[:10], 'b-o', label='slice 1')
plt.plot(slice0_times[:10], slice1_at_slice0[:10], 'kx', label='1 -> 0')
plt.xlabel('Time of acquisition')
plt.legend()
# Cubic interpolation of the padded samples, evaluated on a fine time grid
# spanning the first ten padded samples and compared with the measurements.
plt.figure()
interpolator = spi.interp1d(x_padded, y_padded, 'cubic')
fine_time = np.linspace(x_padded[0], x_padded[9], 100)
predicted_signal = interpolator(fine_time)
plt.plot(fine_time, predicted_signal, ':', label='predicted')
plt.plot(slice1_times[:10], slice1_time_course[:10], 'x', label='actual')
plt.xlabel('Time of acquisition')
plt.legend()
It looks like we're going to be doing a lot of padding along the last axis, so let's make a function for that.
def pad_ends(first, middle, last):
    """ Pad array `middle` along last axis with `first` value and `last` value

    Parameters
    ----------
    first : scalar or array-like
        Value(s) placed before the first element along the last axis; must
        broadcast to ``middle.shape[:-1]``.
    middle : array-like
        Array to pad.
    last : scalar or array-like
        Value(s) placed after the last element along the last axis; must
        broadcast to ``middle.shape[:-1]``.

    Returns
    -------
    padded : ndarray
        New array of shape ``middle.shape[:-1] + (middle.shape[-1] + 2,)``.
        Its dtype can hold all three inputs, so float end values are not
        truncated when `middle` is an integer array.
    """
    middle = np.asarray(middle)  # Make sure middle is an array
    # Shape of the padded array: one extra element at each end of the last axis.
    pad_shape = middle.shape[:-1] + (middle.shape[-1] + 2,)
    # Allocating with middle's dtype alone would silently truncate (or wrap)
    # end values of a wider type. Scalars go to result_type unchanged so
    # NumPy's scalar promotion rules apply. Anything else becomes an array
    # first, because result_type would read a plain list as a dtype spec.
    first_op, last_op = (
        end if np.isscalar(end) else np.asarray(end) for end in (first, last)
    )
    padded = np.empty(pad_shape, dtype=np.result_type(first_op, middle, last_op))
    padded[..., 0] = first
    padded[..., 1:-1] = middle
    padded[..., -1] = last
    return padded
assert np.all(pad_ends(0, [2, 3], 5) == [0, 2, 3, 5])
a = np.zeros((2, 3))
b = np.ones((2, 3, 4)) * 10
c = np.ones((2, 3))
assert np.all(pad_ends(a, b, c) == np.concatenate((a.reshape((2, 3, 1)), b, c.reshape((2, 3, 1))), axis=2))
Can we interpolate a whole slice of data at a time?
# Every voxel's time course in spatial slice 1: axes are (row, column, scan).
slice1_all = data[:, :, 1, :]
slice1_all.shape
Can spi.interp1d interpolate this array for me? Or do I have to do this one time course at a time?
x = slice1_times # as before
y = slice1_all # now the whole slice, not a single voxel
# interp1d interpolates along one chosen axis of an N-D array; here axis 2,
# the time axis. Times outside the sampled range are filled with 0.
interpolator = spi.interp1d(x, y, 'linear', axis=2, bounds_error=False, fill_value=0)
How would I pad this array with a repeat of the first and last scans?
slice_dims = img.shape[:2]
n_scans = img.shape[3]
# Repeat each voxel's first and last scan at the ends of the time axis.
slice1_all_padded = pad_ends(slice1_all[:, :, 0], slice1_all, slice1_all[:, :, -1])
# x_padded holds slice 1's times padded by one TR at each end; interp1d
# interpolates along the last (time) axis by default.
interpolator = spi.interp1d(x_padded, slice1_all_padded, 'linear')
slice1_all_at_slice0 = interpolator(slice0_times)
# Pull out the same voxel as before to compare with the single-voxel result.
slice1_at_slice0_again = slice1_all_at_slice0[32, 25, :]
# Successive plot calls share the axes (plt.hold was removed in Matplotlib 3.0).
plt.plot(slice0_times[:10], slice0_time_course[:10], 'r:+', label='slice 0')
plt.plot(slice1_times[:10], slice1_time_course[:10], 'b-o', label='slice 1')
plt.plot(slice0_times[:10], slice1_at_slice0_again[:10], 'kx', label='1 -> 0')
plt.xlabel('Time of acquisition')
plt.legend()
Let's make this into a function that interpolates a whole 3D slice (two spatial dimensions plus time) at a time.
def interp_slice(old_times, slice_nd, new_times, kind='linear'):
    """ Resample data whose last axis is time from `old_times` to `new_times`

    Before interpolating, one extra sample is added at each end of the time
    axis. It repeats the first (last) value and sits one sampling step before
    (after) the first (last) time. This lets `new_times` reach up to one step
    beyond the range of `old_times`.

    Parameters
    ----------
    old_times : sequence of float
        Times of the samples along the last axis of `slice_nd`; needs at
        least two entries (the padding uses the first and last steps).
    slice_nd : ndarray
        Data to interpolate, e.g. a (row, column, time) slice; its last axis
        must have ``len(old_times)`` entries.
    new_times : array-like
        Times at which to estimate the data.
    kind : str, optional
        Interpolation kind passed to ``scipy.interpolate.interp1d``.

    Returns
    -------
    ndarray
        Interpolated values, shape ``slice_nd.shape[:-1] + np.shape(new_times)``.
    """
    assert slice_nd.shape[-1] == len(old_times)
    step_before = old_times[1] - old_times[0]
    step_after = old_times[-1] - old_times[-2]
    padded_times = np.concatenate(
        ([old_times[0] - step_before], old_times, [old_times[-1] + step_after]))
    # Repeat the first and last samples along the time axis.
    padded_data = np.concatenate(
        (slice_nd[..., :1], slice_nd, slice_nd[..., -1:]), axis=-1)
    resample = spi.interp1d(padded_times, padded_data, kind, axis=-1)
    return resample(new_times)
Check that the function gives the same result as the manual approach above.
# The function pads by the first/last time step. Here that step is the TR,
# so the result should match the manual whole-slice interpolation.
interpolated = interp_slice(slice1_times, slice1_all, slice0_times)
assert np.allclose(interpolated, slice1_all_at_slice0)
Now we are ready for full slice-timing glory
# Image data as floats (`get_data()` has been removed from current nibabel).
data = img.get_fdata()
slice_times
scan_starts
interp_data = np.empty(data.shape)
# Resample every slice onto the start time of each volume.
desired_times = scan_starts
for slice_no in range(data.shape[-2]):
    # This slice was sampled slice_times[slice_no] after each volume start.
    these_times = slice_times[slice_no] + scan_starts
    data_slice = data[:, :, slice_no, :]
    interped = interp_slice(these_times, data_slice, desired_times, 'cubic')
    interp_data[:, :, slice_no, :] = interped
Check we get the same result as we would from doing this manually
# Redo one voxel (row 32, column 32, slice 17) by hand and compare.
one_tc = data[32, 32, 17, :]
old_times = slice_times[17] + scan_starts
n_scans = len(one_tc)
one_tc_padded = pad_ends(one_tc[0], one_tc, one_tc[-1])
# Pad the times by one TR at each end; interp_slice pads by the first/last
# time step, which equals the TR here.
old_times_padded = pad_ends(old_times[0] - TR, old_times, old_times[-1] + TR)
interpolator = spi.interp1d(old_times_padded, one_tc_padded, 'cubic')
tc = interpolator(desired_times)
assert np.allclose(tc, interp_data[32, 32, 17, :])
How about a function to do slice timing on an image?
def slice_time_image(img, slice_times, TR, kind='cubic'):
    """ Take nibabel image `img` and run slice timing correction using `slice_times`

    Each slice's time courses are interpolated (via ``interp_slice``) from
    that slice's acquisition times ``slice_times[slice_no] + n * TR`` onto
    the volume start times ``n * TR``.

    Parameters
    ----------
    img : nibabel image
        4D image indexed (row, column, slice, scan).
    slice_times : sequence of float
        Acquisition time of each slice relative to the start of its volume;
        one entry per slice.
    TR : float
        Time between successive volume starts, in the same units as
        `slice_times`.
    kind : str, optional
        Interpolation kind passed through to ``interp_slice``.

    Returns
    -------
    new_img : Nifti1Image
        Slice-time corrected image carrying `img`'s affine and header.
    """
    # get_data/get_affine/get_header have been removed from current nibabel;
    # get_fdata() and the .affine/.header attributes are the supported API.
    data = img.get_fdata()
    assert len(slice_times) == img.shape[-2]
    n_scans = img.shape[-1]
    scan_starts = np.arange(n_scans) * TR
    interp_data = np.empty(data.shape)
    desired_times = scan_starts
    for slice_no in range(data.shape[-2]):
        these_times = slice_times[slice_no] + scan_starts
        data_slice = data[:, :, slice_no, :]
        interped = interp_slice(these_times, data_slice, desired_times, kind)
        interp_data[:, :, slice_no, :] = interped
    # NOTE(review): reusing the input header keeps its on-disk data type, so
    # the float results may be scaled/rounded to that type when saved --
    # confirm that is acceptable.
    new_img = nib.Nifti1Image(interp_data, img.affine, img.header)
    return new_img
# Correct the whole image using the slice times and TR computed above.
new_img = slice_time_image(img, slice_times, TR)
Check we get the same answer as last time
# new_img holds float64 data in memory, so get_fdata() returns those values
# as-is (`get_data()` has been removed from current nibabel).
new_data = new_img.get_fdata()
assert np.allclose(tc, new_data[32, 32, 17, :])
Can we run this in a pipeline?
import os
# Split off the directory so that only the file name gets the 'a' prefix.
# Keep `fname` as the full path rather than rebinding it to the basename.
pth, base_fname = os.path.split(fname)
pth, base_fname
new_fname = os.path.join(pth, 'a' + base_fname)
new_fname
# Load via the full path. Loading the bare file name only works when the
# file happens to be in the current directory.
raw_img = nib.load(fname)
interp_img = slice_time_image(raw_img, slice_times, TR)
nib.save(interp_img, new_fname)
How about a function that takes the filename, the slice times, and the TR, and writes out a new file?
def slice_time_file(fname, slice_times, TR, kind='cubic'):
    """ Slice-time correct the image in file `fname` and save the result

    The output is written next to the input, with 'a' prefixed to the file
    name (``dir/bold.nii.gz`` -> ``dir/abold.nii.gz``).

    Parameters
    ----------
    fname : str
        Path of the image to correct.
    slice_times : sequence of float
        Acquisition time of each slice relative to the start of its volume.
    TR : float
        Time between successive volume starts, in the same units as
        `slice_times`.
    kind : str, optional
        Interpolation kind passed through to ``slice_time_image``.

    Returns
    -------
    new_fname : str
        Path of the file written.
    """
    pth, base_fname = os.path.split(fname)
    new_fname = os.path.join(pth, 'a' + base_fname)
    # Load via the full path, not the bare file name, so files outside the
    # current directory work.
    raw_img = nib.load(fname)
    interp_img = slice_time_image(raw_img, slice_times, TR, kind)
    nib.save(interp_img, new_fname)
    return new_fname
slice_time_file(fname, slice_times, TR, 'cubic')