最近，PyTorch 更新到了 1.7.1，已经支持复数张量，并且 torch.fft 模块的文档也说明得很清楚：https://pytorch.org/docs/stable/search.html?q=fft&check_keywords=yes&area=default
# Demo: torch.fft (PyTorch >= 1.7) works on complex tensors directly, and
# ifft(fft(t)) recovers t.
# t = torch.tensor([0.+1.j, 2.+3.j, 4.+5.j, 6.+7.j])
# t_fft = torch.fft.fft(t)
# print(t_fft)
# print(torch.fft.ifft(t_fft))
import os

import numpy as np
import scipy.io as scio
import torch
from skimage import io  # NOTE(review): not used in this script; kept in case other code relies on the name

# This script checks which FFT direction maps between the image (IM) and the
# k-space data (K_Data) stored in a competition .mat file. That data uses
# fft/ifft swapped relative to the usual convention (image -> k-space via the
# forward FFT). So below, image -> k-space is computed with ifftn and
# k-space -> image with fftn.


def _print_max_abs_diff(label, reference, estimate):
    """Print the largest element-wise |reference - estimate|.

    Equal mean magnitudes (the check printed by the script) do not prove that
    two tensors match, so this reports the worst element-wise error too.
    When the shapes differ, it reports the mismatch instead of failing.
    """
    if reference.shape != estimate.shape:
        print(label, 'shape mismatch:', tuple(reference.shape), 'vs', tuple(estimate.shape))
        return
    print(label, 'max abs diff', torch.max(torch.abs(reference - estimate)))


# The original path literal was garbled (its backslashes / "\t" were lost in
# transcription); build it portably instead.
path = os.path.join(".", "train_data", "TRAINDATA_11.mat")
item = scio.loadmat(path)
CS_K_Data, IM, K_Data, mask = item['CS_K_Data'], item['IM'], item['K_Data'], item['mask']
CS_K_Data, IM, K_Data = CS_K_Data.astype(np.complex64), IM.astype(np.complex64), K_Data.astype(np.complex64)
mask = mask.astype(np.double)
# Prepend two singleton axes so that dims 2 and 3 are the axes transformed
# below (assumes the stored arrays are 2-D -- TODO confirm against the data).
CS_K_Data = CS_K_Data[np.newaxis, np.newaxis, ...]
IM = IM[np.newaxis, np.newaxis, ...]
K_Data = K_Data[np.newaxis, np.newaxis, ...]
# IM255 = 255*(IM-np.min(IM))/(np.max(IM)-np.min(IM))
IM_tensor = torch.tensor(IM, dtype=torch.complex64)
K_Data_tensor = torch.tensor(K_Data, dtype=torch.complex64)
# Optional k-space recentring (note: the API is torch.fft.ifftshift):
# K_Data_tensor = torch.fft.ifftshift(K_Data_tensor, dim=(2, 3))

# ! from IM to k-space -- ifftn, matching how this dataset was generated.
# torch.fft.ifftn `norm` modes:
#   "forward"  - no normalization
#   "backward" - normalize by 1/n (the default)
#   "ortho"    - normalize by 1/sqrt(n) (makes the IFFT orthonormal)
fake_k = torch.fft.ifftn(IM_tensor, dim=(2, 3), norm="backward")
print('k', torch.mean(torch.abs(K_Data_tensor)), " and ", torch.mean(torch.abs(fake_k)))
_print_max_abs_diff('k', K_Data_tensor, fake_k)

# ! from k-space to IM -- fftn, the inverse of the ifftn direction above.
# torch.fft.fftn `norm` modes:
#   "forward"  - normalize by 1/n
#   "backward" - no normalization (the default)
#   "ortho"    - normalize by 1/sqrt(n) (makes the FFT orthonormal)
fake_im = torch.fft.fftn(K_Data_tensor, dim=(2, 3), norm="backward")
print('IM', torch.mean(torch.abs(IM_tensor)), " and ", torch.mean(torch.abs(fake_im)))
_print_max_abs_diff('IM', IM_tensor, fake_im)

# TODO: from shift to ishift (fftshift / ifftshift variants not checked yet)
xx = 1  # NOTE(review): no-op assignment, presumably a debugger breakpoint anchor
若上面两次 print 输出的两个值分别相等，则说明对应的傅里叶变换关系成立。这个数据来自之前的一个比赛，太坑了：数据里的 fft 和 ifft 是用反了的。
# matlab example
# magic = torch.tensor([[8,1,6],[3,5,7],[4,9,2]])
# fmagic = torch.fft.fftn(magic)
# print(fmagic)
# imagic = torch.fft.ifftn(fmagic)
# print(imagic)