diff --git a/Project.toml b/Project.toml index d6ea128..9bad97b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "STFT" uuid = "58bb99bf-048b-48b7-93e7-1cbf3ee61509" authors = ["Szymon M. Woźniak "] -version = "1.2.1" +version = "1.3.0" [deps] FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" diff --git a/src/STFT.jl b/src/STFT.jl index e7bd0d7..63d3d4c 100644 --- a/src/STFT.jl +++ b/src/STFT.jl @@ -8,6 +8,18 @@ export stft, istft _fft(x::AbstractArray{<:Real}, d) = rfft(x, d) _fft(x::AbstractArray{<:Complex}, d) = fft(x, d) +function malloc_stft( + x::M, + w::V, + L::I = zero(I), + N::I = length(w); +) where {T<:Number, I<:Integer, V<:AbstractVector{T}, M<:AbstractMatrix{T}} + X, K = size(x) # Length of the signal in samples + W = length(w) # Length of the window in samples + S = (X-L) ÷ (W - L) # Number of segments + N = N < W ? W : N # DFT size + zeros(T, (N, S, K)) # Allocate container for signal segments +end doc_analysis = """ @@ -94,6 +106,22 @@ function analysis() end "$doc_analysis" stft(x, w, L=0, N=length(w)) = analysis(x, w, L, N) +function analysis( + x::M, + w::V, + L::I = zero(I), + N::I = length(w); +) where {T<:Number, I<:Integer, V<:AbstractVector{T}, M<:AbstractMatrix{T}} + sc = malloc_stft(x, w, L, N) + N, S, K = sc |> size + W = w |> length + + @turbo for k ∈ 1:K, s ∈ 1:S, n ∈ 1:W + sc[n, s, k] = w[n] * x[(s-1)*(W-L)+n, k] + end + _fft(sc, 1) # Convert segments to frequency-domain +end + function analysis( x::V, w::V, @@ -104,25 +132,6 @@ function analysis( @view analysis(xx, w, L, N)[:, :, 1] end -function analysis( - x::M, - w::V, - L::I = zero(I), - N::I = length(w); -) where {T<:Number, I<:Integer, V<:AbstractVector{T}, M<:AbstractMatrix{T}} - X, K = size(x) # Length of the signal in samples - W = length(w) # Length of the window in samples - H = W - L # Hop - S = (X-L) ÷ H # Number of segments - N = N < W ? W : N # DFT size - sc = zeros(T, N, S, K) # Allocate container for signal segments - - @turbo for s ∈ 1:S, k ∈ 1:K, n ∈ 1:W - sc[n, s, k] = w[n] * x[(s-1)*H+n, k] - end - _fft(sc, 1) # Convert segments to frequency-domain -end - function analysis( xs::AbstractArray{<:AbstractVector{T}}, w::AbstractVector{T}, @@ -250,4 +259,22 @@ function synthesis( xn ./ xd # Normalize end +""" +Real-value signal STFT with constant input size. +""" +function rSTFTm(A, w, L, N=length(w)) + mem = STFT.malloc_stft(A, w, L, N) + N, S, K = mem |> size + W = w |> length + P = plan_rfft(mem, 1) + function f(x) + @turbo for k ∈ 1:K, s ∈ 1:S, n ∈ 1:W + mem[n, s, k] = w[n] * x[(s-1)*(W-L)+n, k] + end + P * mem + end + return f +end + + end # module