Source code for tomostream.solver

import cupy as cp
import numpy as np
from cupyx.scipy.fft import rfft, irfft
from cupyx.scipy import ndimage
from tomostream import kernels
from tomostream import retrieve_phase
from tomostream import log

class Solver:
    """Tomography reconstruction of ortho-slices on the GPU.

    Reconstruction is done through direct discretization of circular
    integrals in the Radon transform.

    Parameters
    ----------
    ntheta : int
        The number of projections in the buffer (for simultaneous
        reconstruction).
    n, nz : int
        The pixel width and height of the projection.
    pars : dict
        Reconstruction parameters, containing:

        center : float32
            Rotation center for reconstruction.
        idx, idy, idz : int32
            X-Y-Z ortho slices for reconstruction.
        rotx, roty, rotz : float32
            Rotation angles for X-Y-Z slices.
        fbpfilter : str
            Reconstruction filter ('Parzen', 'Ramp', 'Shepp-logan' or
            'Butterworth').
        dezinger : str
            None or radius for removing outliers.
        energy : float32
            Beam energy.
        dist : float32
            Source-detector distance.
        alpha : float32
            Tuning parameter for phase retrieval.
        pixelsize : float32
            Detector pixel size.
    datatype : str
        Detector data type (dtype of the CPU projection buffer).
    """

    def __init__(self, ntheta, n, nz, pars, datatype):
        self.n = n
        self.nz = nz
        self.ntheta = ntheta

        # CPU storage for the circular buffer of projections and angles
        self.data = np.zeros([ntheta, nz, n], dtype=datatype)
        self.theta = np.zeros([ntheta], dtype='float32')

        # GPU storage for dark and flat fields (allocated directly as
        # float32 instead of creating a float64 array and converting it)
        self.dark = cp.zeros([nz, n], dtype='float32')
        self.flat = cp.ones([nz, n], dtype='float32')

        # GPU storage for the ortho-slices; the three slices are
        # concatenated into one 2D array
        self.obj = cp.zeros([n, 3*n], dtype='float32')

        # Reconstruction parameters. A copy is stored (as recon_optimized
        # does) so that in-place edits of the caller's dictionary are
        # still detected by the `pars != self.pars` comparison there.
        self.pars = dict(pars)

        # Chunk size for GPU processing, derived from the second entry of
        # mem_info (total device memory according to the CuPy docs).
        # cuda raw kernels do not work with huge sizes (issue in cupy?)
        mem = cp.cuda.Device().mem_info[1]
        self.chunk = min(self.ntheta, int(np.ceil(mem/self.n/self.nz/32)))
        log.warning(f'chunk size {self.chunk}')

        # Flag controlling appearance of new dark and flat fields
        self.new_dark_flat = False

    def free(self):
        """Free GPU memory held by the default CuPy memory pool."""
        cp.get_default_memory_pool().free_all_blocks()

    def set_dark(self, data):
        """Copy dark field (already averaged) to GPU.

        Also raises the ``new_dark_flat`` flag so that the next call to
        ``recon_optimized`` reconstructs from the whole buffer.
        """
        self.dark = cp.array(data.astype('float32'))
        self.new_dark_flat = True

    def set_flat(self, data):
        """Copy flat field (already averaged) to GPU.

        Also raises the ``new_dark_flat`` flag so that the next call to
        ``recon_optimized`` reconstructs from the whole buffer.
        """
        self.flat = cp.array(data.astype('float32'))
        self.new_dark_flat = True

    def backprojection(self, data, theta):
        """Compute backprojection to orthogonal slices.

        Parameters
        ----------
        data : cp.ndarray
            Filtered projections.
        theta : cp.ndarray
            Projection angles; ``recon`` passes them converted from
            degrees to radians.

        Returns
        -------
        cp.ndarray
            Array of shape (n, 3*n) holding the Z, Y and X ortho-slices
            side by side, normalized by the buffer size ``ntheta``.
        """
        n, nz = self.n, self.nz
        pars = self.pars
        # ortho-slices are concatenated to one 2D array
        obj = cp.zeros([n, 3*n], dtype='float32')
        obj[:n, :n] = kernels.orthoz(
            data, theta, pars['center'], pars['idz'], pars['rotz'])
        obj[:nz, n:2*n] = kernels.orthoy(
            data, theta, pars['center'], pars['idy'], pars['roty'])
        obj[:nz, 2*n:3*n] = kernels.orthox(
            data, theta, pars['center'], pars['idx'], pars['rotx'])
        # Normalize by the full buffer size (not by the chunk size), so
        # that partial reconstructions can simply be summed.
        obj /= self.ntheta
        return obj

    def _fbp_weights(self, t):
        """Return the frequency response of the selected FBP filter.

        Parameters
        ----------
        t : array
            Frequencies at which the filter is evaluated.

        Raises
        ------
        ValueError
            If ``pars['fbpfilter']`` is not a supported filter name.
        """
        name = self.pars['fbpfilter']
        if name == 'Parzen':
            return t * (1 - t * 2)**3
        if name == 'Ramp':
            return t
        if name == 'Shepp-logan':
            return np.sin(t)
        if name == 'Butterworth':  # todo: replace by other
            return t / (1 + pow(2*t, 16))  # as in tomopy
        raise ValueError(
            f'Unknown fbpfilter {name!r}; expected one of '
            "'Parzen', 'Ramp', 'Shepp-logan', 'Butterworth'")

    def fbp_filter(self, data):
        """FBP filtering of projections (in place).

        Parameters
        ----------
        data : cp.ndarray
            Projections of shape (nproj, nz, n); overwritten with the
            filtered projections.

        Raises
        ------
        ValueError
            If ``pars['fbpfilter']`` is not a supported filter name.
        """
        # 1D filter; it is broadcast over the rows of each projection
        wfilter = self._fbp_weights(cp.fft.rfftfreq(self.n))
        # work with 2D arrays to save GPU memory
        for k in range(data.shape[0]):
            # n is passed explicitly: without it irfft returns an even
            # length 2*(m-1), which is one sample short for odd widths.
            data[k] = irfft(
                wfilter * rfft(data[k], overwrite_x=True, axis=1),
                n=self.n, overwrite_x=True, axis=1)

    def darkflat_correction(self, data):
        """Dark-flat field correction (in place)."""
        # Lower bound avoids division by zero where flat <= dark
        denom = cp.maximum(self.flat - self.dark, 1e-6)
        # work with 2D arrays to save GPU memory
        for k in range(data.shape[0]):
            data[k] = (data[k] - self.dark) / denom

    def minus_log(self, data):
        """Taking negative logarithm (in place)."""
        # work with 2D arrays to save GPU memory
        for k in range(data.shape[0]):
            # Clamp to a small positive value so the logarithm is defined
            data[k] = -cp.log(cp.maximum(data[k], 1e-6))

    def _dezinger_radius(self):
        """Return the outlier-removal radius, or 0 when it is disabled.

        ``pars['dezinger']`` is documented as "None or radius", so any
        value that cannot be converted to an integer (e.g. ``None`` or
        the string ``'None'``) is treated as "disabled".
        """
        try:
            return int(self.pars['dezinger'])
        except (TypeError, ValueError):
            return 0

    def remove_outliers(self, data):
        """Remove outliers (in place) with a median filter.

        Values that differ from the median-filtered data by more than
        half of the filtered magnitude are replaced by the filtered value.
        Nothing is done when the dezinger radius is not positive.
        """
        r = self._dezinger_radius()
        if r > 0:
            fdata = ndimage.median_filter(data, [1, r, r])
            ids = cp.where(cp.abs(fdata - data) > 0.5*cp.abs(fdata))
            data[ids] = fdata[ids]

    def phase(self, data):
        """Retrieve phase (in place) when ``pars['alpha']`` is positive."""
        if self.pars['alpha'] > 0:
            # NOTE(review): the scale factors presumably convert
            # pixelsize and dist to the units paganin_filter expects -
            # confirm against retrieve_phase.
            filtered = retrieve_phase.paganin_filter(
                data, self.pars['pixelsize']*1e-4, self.pars['dist']/10,
                self.pars['energy'], self.pars['alpha'])
            # The pipeline relies on in-place updates of `data`; write the
            # result back unless paganin_filter already returned the same
            # array (i.e. it worked in place).
            if filtered is not data:
                data[...] = filtered

    def recon(self, data, theta):
        """Reconstruction with the standard processing pipeline on GPU.

        Parameters
        ----------
        data : cp.ndarray
            Raw projections; modified in place by the processing steps.
        theta : cp.ndarray
            Projection angles in degrees.

        Returns
        -------
        cp.ndarray
            Concatenated ortho-slices, see ``backprojection``.
        """
        self.darkflat_correction(data)
        self.remove_outliers(data)
        self.phase(data)
        self.minus_log(data)
        self.fbp_filter(data)
        return self.backprojection(data, theta*np.pi/180)

    def recon_by_chunks(self, data, theta):
        """Reconstruction with splitting data by chunks processed on GPU.

        Parameters
        ----------
        data : np.ndarray
            Projections stored on the CPU.
        theta : np.ndarray
            Projection angles in degrees.

        Returns
        -------
        cp.ndarray
            Sum of the chunk reconstructions, shape (n, 3*n).
        """
        # ortho-slices are concatenated to one 2D array
        obj = cp.zeros([self.n, 3*self.n], dtype='float32')
        nproj = data.shape[0]
        for start in range(0, nproj, self.chunk):
            stop = min(start + self.chunk, nproj)
            # Transfer in the native dtype, then convert on the GPU
            data_gpu = cp.array(data[start:stop]).astype('float32')
            theta_gpu = cp.array(theta[start:stop]).astype('float32')
            obj += self.recon(data_gpu, theta_gpu)
        return obj

    def recon_optimized(self, data, theta, ids, pars):
        """Optimized reconstruction from the whole set of projections.

        Reconstructs the object from the projections in the interval of
        size pi. The result is obtained by replacing the reconstruction
        part corresponding to incoming projections,

            objnew = objold + recon(datanew) - recon(dataold)

        whenever the number of incoming projections is less than half of
        the buffer size. Reconstruction is done using the whole buffer
        only when the number of incoming projections is greater than half
        of the buffer size, any reconstruction parameter in ``pars``
        changed, or new dark/flat fields were acquired.

        Parameters
        ----------
        data : np.array(nproj,nz,n)
            Projection data
        theta : np.array(nproj)
            Angles corresponding to the projection data
        ids : np.array(nproj)
            Ids of the data in the circular buffer array
        pars : dict
            Reconstruction parameters, containing:

            center : float32
                Rotation center for reconstruction
            idx, idy, idz : int32
                X-Y-Z ortho slices for reconstruction
            rotx, roty, rotz : float32
                Rotation angles for X-Y-Z slices
            fbpfilter : str
                Reconstruction filter
            dezinger : str
                None or radius for removing outliers
            energy : float32
                Beam energy
            dist : float32
                Source-detector distance
            alpha : float32
                Tuning parameter for phase retrieval
            pixelsize : float32
                Detector pixel size

        Returns
        -------
        obj : np.array(n,3*n)
            Concatenated reconstructions for X-Y-Z orthoslices
        """
        # recompute only by replacing a part of the data in the buffer,
        # or by using the whole buffer
        recompute_part = not (pars != self.pars
                              or self.new_dark_flat
                              or len(ids) > self.ntheta//2)
        if recompute_part:
            # subtract old part (must use the old data and parameters,
            # hence done before the buffer update below)
            self.obj -= self.recon_by_chunks(self.data[ids], self.theta[ids])

        # update data in the buffer
        self.data[ids] = data.reshape(data.shape[0], self.nz, self.n)
        self.theta[ids] = theta
        self.pars = pars.copy()
        self.new_dark_flat = False

        if recompute_part:
            # add new part
            self.obj += self.recon_by_chunks(self.data[ids], self.theta[ids])
        else:
            self.obj = self.recon_by_chunks(self.data, self.theta)

        # copy the result from GPU to CPU
        return self.obj.get()