Math utilities and solvers

Output (chiphifunc_test_suite.py)

aqsc.compare_chiphifunc(A:ChiPhiFunc, B:ChiPhiFunc, trig_mode=False, simple_mode=True, colormap_mode=False)

Compares and plots two ChiPhiFuncs.

Parameters:

  • A : ChiPhiFunc - A ChiPhiFunc.
  • B : ChiPhiFunc - The other ChiPhiFunc.
  • trig_mode : bool=False - When True, plots the trigonometric ($\sin$ and $\cos$) coefficients. Otherwise plots the exponential ($e^{im\chi}$) coefficients.
  • simple_mode : bool=True - When True, only plots the difference. Otherwise also plots the values of A and B.
  • colormap_mode : bool=False - When True, plots the components in a colormap. Otherwise makes line plots. See ChiPhiFunc.display_content().

Math (math_utilities.py)

aqsc.py_sum(expr, lower:int, upper:int)

Sums a lambda function expr, which takes one static int input, from lower to upper. The sum can be JIT-compiled if expr can be JIT-compiled.

Parameters: - expr : lambda (must obey JAX rules) - The expression to sum. - lower : int (static) - The summation's lower bound. - upper : int (static) - The summation's upper bound.

Returns: - The evaluation result.
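The following is a minimal pure-Python sketch of what py_sum computes. It is not aqsc's implementation, and py_sum_sketch is a hypothetical name. It assumes both bounds are inclusive, as in Maxima's sum, and that an empty range sums to 0. Because lower and upper are static ints, the loop is unrolled when jax.jit traces it. That is why only expr needs to be JIT-compatible.

```python
def py_sum_sketch(expr, lower: int, upper: int):
    """Hypothetical re-implementation: sum expr(i) for i = lower, ..., upper.

    Assumes both bounds are inclusive (Maxima's sum convention) and that an
    empty range (lower > upper) sums to 0.
    """
    total = 0
    # Plain Python loop over static ints: under jax.jit it is unrolled at
    # trace time, so only expr itself has to be traceable.
    for i in range(lower, upper + 1):
        total = total + expr(i)
    return total


# Usage: adds expr(1) + expr(2) + expr(3) + expr(4).
squares = py_sum_sketch(lambda i: i**2, 1, 4)
```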

aqsc.is_seq(a, b)

Returns 1 if a<=b, and 0 otherwise.

Parameters: - a,b (static).

Returns: - 0 or 1.

aqsc.is_integer(a)

Returns 1 if a is an integer, and 0 otherwise.

Parameters: - a (static).

Returns: - 0 or 1.
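Here is a minimal sketch of the two indicator functions above, using hypothetical names. One plausible reason they return 0 or 1 rather than a bool is so they can multiply terms in generated Maxima expressions, switching a term on or off. That reading of their purpose is an assumption.

```python
def is_seq_sketch(a, b) -> int:
    # 1 if a <= b, else 0 (a and b are static Python numbers)
    return 1 if a <= b else 0


def is_integer_sketch(a) -> int:
    # 1 if a has no fractional part (so 2.0 counts as an integer), else 0
    return 1 if float(a).is_integer() else 0


# Usage (assumed pattern): keep a term only when n is even.
n = 4
term = is_integer_sketch(n / 2) * 3.5
```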

aqsc.diff(y, is_chi1:bool, order1:int, is_chi2=None, order2=None)

Used to evaluate expressions from Maxima. Takes $\chi$ and/or $\phi$ derivatives of a quantity.

Parameters:

  • y (traced) - A scalar or ChiPhiFunc. Any other input type will result in ChiPhiFunc(nfp=-13).
  • is_chi1 : bool (static) - If True, takes a $\chi$ derivative; otherwise takes a $\phi$ derivative.
  • order1 : int (static) - Order of the first derivative.
  • is_chi2 (static) - Optional. If None, doesn't take a second derivative. If True, takes a $\chi$ derivative; otherwise takes a $\phi$ derivative.
  • order2 (static) - Optional. Order of the second derivative.

Returns: - A ChiPhiFunc, 0, or a ChiPhiFunc(nfp=-13)

Operator generators (chiphifunc.py)

aqsc.dchi_op(n_dim:int)

Generates a differential operator that takes the derivative in $\chi$. The operator is applied as diff_matrix@f.content = dchi(f).content, where diff_matrix = dchi_op(n_dim).

Parameters: - n_dim : int (static) - Total number of harmonics (NOT the maximum harmonic number).

Returns: - An (n_dim, n_dim) jax.numpy array.
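The sketch below illustrates the operator using NumPy, whose array API jax.numpy largely mirrors. It assumes that row k of a content holds the $e^{im_k\chi}$ coefficient with $m_k = -(n_{dim}-1) + 2k$. That layout is an assumption, chosen to fit n_dim being a harmonic count rather than a maximum. Under this layout $\partial_\chi e^{im\chi} = im\,e^{im\chi}$, so the operator is diagonal. dchi_op_sketch is a hypothetical name.

```python
import numpy as np


def dchi_op_sketch(n_dim: int) -> np.ndarray:
    # Assumed layout: row k holds harmonic m_k = -(n_dim-1) + 2k,
    # e.g. n_dim=3 -> m = [-2, 0, 2], n_dim=4 -> m = [-3, -1, 1, 3].
    m = np.arange(-(n_dim - 1), n_dim, 2)
    # d/dchi e^{i m chi} = i m e^{i m chi}, so each row is scaled by i*m.
    return np.diag(1j * m)


# Usage: differentiate a content array of shape (n_dim, len_phi).
content = np.ones((3, 5), dtype=complex)
dchi_content = dchi_op_sketch(3) @ content
```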

aqsc.trig_to_exp_op(n_dim:int)

Generates an operator that converts a ChiPhiFunc from a trigonometric to an exponential Fourier series. The operator is applied by left-multiplying the content: trig_to_exp_op(n_dim)@f.content.

Parameters: - n_dim : int (static) - Total number of harmonics (NOT the maximum harmonic number).

Returns: - An (n_dim, n_dim) jax.numpy array.

aqsc.exp_to_trig_op(n_dim:int)

Generates an operator that converts a ChiPhiFunc from an exponential to a trigonometric Fourier series. The operator is applied by left-multiplying the content: exp_to_trig_op(n_dim)@f.content.

Parameters: - n_dim : int (static) - Total number of harmonics (NOT the maximum harmonic number).

Returns: - An (n_dim, n_dim) jax.numpy array.
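As an illustration of what these two operators encode, here is a NumPy sketch built from the identity $a_m\cos(m\chi) + b_m\sin(m\chi) = c_m e^{im\chi} + c_{-m}e^{-im\chi}$. Here $c_m = (a_m - i b_m)/2$ and $c_{-m} = (a_m + i b_m)/2$.

The row ordering of the trigonometric content is a hypothetical choice and aqsc's actual ordering may differ:

  • rows with $m < 0$ hold $\sin(|m|\chi)$ coefficients;
  • rows with $m > 0$ hold $\cos(m\chi)$ coefficients;
  • the $m = 0$ row holds the constant.

The function names are hypothetical as well.

```python
import numpy as np


def chi_modes(n_dim):
    # Assumed layout: row k holds harmonic m_k = -(n_dim-1) + 2k
    return np.arange(-(n_dim - 1), n_dim, 2)


def trig_to_exp_op_sketch(n_dim):
    """Hypothetical trig layout: a row with m < 0 holds sin(|m| chi), a row
    with m > 0 holds cos(m chi), and the m = 0 row (odd n_dim) the constant."""
    m = chi_modes(n_dim)
    op = np.zeros((n_dim, n_dim), dtype=complex)
    for k in range(n_dim):
        j = n_dim - 1 - k          # mirror row, harmonic -m_k
        if m[k] == 0:
            op[k, k] = 1.0         # constant term is unchanged
        elif m[k] > 0:             # c_m = (a_m - i b_m) / 2
            op[k, k] = 0.5         # a_m sits in this (cos) row
            op[k, j] = -0.5j       # b_m sits in the mirror (sin) row
        else:                      # c_{-m} = (a_m + i b_m) / 2
            op[k, j] = 0.5         # a_m sits in the mirror (cos) row
            op[k, k] = 0.5j        # b_m sits in this (sin) row
    return op


def exp_to_trig_op_sketch(n_dim):
    # Inverse map: a_m = c_m + c_{-m}, b_m = i (c_m - c_{-m})
    return np.linalg.inv(trig_to_exp_op_sketch(n_dim))
```

Building exp_to_trig_op_sketch as a matrix inverse keeps the two directions consistent by construction. A closed-form inverse would be cheaper.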

aqsc.max_log10(input)

Calculates the order of magnitude of the maximum amplitude.

Parameters: - input : jax.numpy.array (traced) - Input data

Returns:

  • The base-10 logarithm of the largest absolute value in input.

aqsc.phi_avg(in_quant)

A type-insensitive $\phi$-averaging function that:

  • Averages along $\phi$ and outputs a ChiPhiFunc if the input is a ChiPhiFunc.
  • Does nothing if the input is a scalar.

Parameters: - in_quant (traced) - Can be scalar or ChiPhiFunc.

Returns: - The $\phi$-average: a ChiPhiFunc for a ChiPhiFunc input, or the unchanged input for a scalar.

PDE and ODE solver

get_O_O_einv_from_A_B(chiphifunc_A:ChiPhiFunc, chiphifunc_B:ChiPhiFunc, rank_rhs:int, Y1c_mode:bool)

(NEED DOUBLE-CHECK) Get O, O_einv and vector_free_coef that solves Where the operator acting on is known to be of rank . This equation's solution is known to be or Where is a stack of operators acting on the components of , or is a function of , and is a function of

Parameters:

  • chiphifunc_A, chiphifunc_B : ChiPhiFunc (traced) - The coefficients A and B.
  • rank_rhs : int (static) - The number of rows in the RHS.

Returns: O_matrices, O_einv, vector_free_coef, Y_nfp - O_matrices : jax.numpy.array (traced) - An (n+2, n+1, len_phi) operator equivalent to - O_einv : jax.numpy.array (traced) - The above-mentioned operator. Has shape (n+1, n+2, len_phi). - vector_free_coef : jax.numpy.array (traced) - The above-mentioned . Has shape (n+1,len_phi). - Y_nfp : int (static) - The nfp of . For internal use.

aqsc.solve_1d_asym(p_eff, f_eff) (Deprecated)

Solves a linear ODE of the form $y' + p_{\text{eff}}\,y = f_{\text{eff}}$ using an asymptotic series. It only seems to work well when the minimum amplitude of p_eff is large, and in those cases solve_1d_fft works well too. The maximum truncation order is set in ChiPhiFunc.py.

Parameters:

  • p_eff : jax.numpy.array (traced) - Coefficient of y. 1D array.
  • f_eff : jax.numpy.array (traced) - The RHS. 1D array.

Returns:

  • A jax.numpy.array containing the solution to the equation. When p_eff is 0, the solution defaults to the zero-average antiderivative of f_eff.

aqsc.solve_1d_fft(p_eff, f_eff, static_max_freq:int=None)

Solves a linear ODE of the form $y' + p_{\text{eff}}\,y = f_{\text{eff}}$ using a spectral method. The maximum truncation order is set in ChiPhiFunc.py.

Parameters:

  • p_eff : jax.numpy.array (traced) - Coefficient of y. 1D array.
  • f_eff : jax.numpy.array (traced) - The RHS. 1D array.
  • static_max_freq : int (static) - Maximum number of Fourier harmonics used. Should be set as low as possible without losing too much accuracy, to keep high-frequency noise from blowing up. Must be found empirically.

Returns:

  • A jax.numpy.array containing the solution to the equation. When p_eff is 0, the solution defaults to the zero-average antiderivative of f_eff.
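The following sketch shows one spectral approach to the periodic problem $y' + p_{\text{eff}}\,y = f_{\text{eff}}$. It uses Fourier collocation: a spectral differentiation matrix applied on the $\phi$ grid and solved as a dense linear system. This is a swapped-in technique, not necessarily what solve_1d_fft does internally. The sketch omits the static_max_freq truncation, assumes a uniform grid over one $2\pi$ period, and uses NumPy in place of jax.numpy. All function names are hypothetical.

A least-squares solve with a singular-value cutoff reproduces the documented $p_{\text{eff}} = 0$ behavior. In that case constants lie in the operator's null space, and the minimum-norm solution is the zero-average antiderivative.

```python
import numpy as np


def spectral_diff_matrix(n, period=2 * np.pi):
    """D such that D @ y approximates dy/dphi for n equispaced periodic samples."""
    k = np.fft.fftfreq(n, d=period / n) * 2 * np.pi   # angular wavenumbers
    if n % 2 == 0:
        k[n // 2] = 0.0   # drop the unpaired Nyquist mode so D stays real
    # Column j of D is the spectral derivative of the j-th unit vector.
    return np.real(np.fft.ifft(1j * k[:, None] * np.fft.fft(np.eye(n), axis=0), axis=0))


def solve_1d_collocation_sketch(p_eff, f_eff):
    """Solve y' + p_eff * y = f_eff for periodic y on a uniform phi grid."""
    n = len(f_eff)
    p = np.broadcast_to(np.asarray(p_eff), (n,))
    lhs = spectral_diff_matrix(n) + np.diag(p)
    # rcond cuts near-zero singular values, so a singular operator
    # (p_eff == 0) yields the minimum-norm, zero-average solution.
    y, *_ = np.linalg.lstsq(lhs, f_eff, rcond=1e-10)
    return y
```

A dense solve costs $O(n^3)$. Working directly with Fourier harmonics and truncating to static_max_freq, as the parameter description suggests aqsc does, keeps the system small and suppresses high-frequency noise.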

aqsc.solve_ODE(coeff_arr, coeff_dp_arr, f_arr, static_max_freq: int=None)

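Solves a list of linear first-order ODE systems in $\phi$ of the form $a\,y + b\,y' = f$, where a, b and f are the rows of coeff_arr, coeff_dp_arr and f_arr (equivalent to $y' + (a/b)\,y = f/b$), using a spectral method.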

NOTE: Does not work well for p>10 with zeros or resonant p without low-pass filtering.

Parameters:

  • coeff_arr, coeff_dp_arr, f_arr : jax.numpy.array (traced) - Components of the equations as 2D matrices. Axis 0 indexes the equations and axis 1 the grid points. All quantities are assumed periodic.
  • static_max_freq : int (static) - Maximum number of Fourier harmonics used. Should be set as low as possible without losing too much accuracy, to keep high-frequency noise from blowing up. Must be found empirically.

Returns:

  • A jax.numpy.array containing solutions to the equations.

aqsc.solve_ODE_chi(coeff, coeff_dp, coeff_dc, f, static_max_freq:int)

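Solves the periodic linear first-order PDE $a\,y + b\,\partial_\phi y + c\,\partial_\chi y = f$, where a, b and c are coeff, coeff_dp and coeff_dc, using a spectral method.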

Parameters:

  • coeff, coeff_dp, coeff_dc (traced) - The coefficients. Can be 2D arrays with the same format as ChiPhiFunc.content or scalars.
  • f : jax.numpy.array (traced) - The RHS. Must be a 2D array with the same format as ChiPhiFunc.content.
  • static_max_freq : int (static) - Maximum number of Fourier harmonics used. Should be set as low as possible without losing too much accuracy, to keep high-frequency noise from blowing up. Must be found empirically.

Returns: - A jax.numpy.array containing the solution. Has the same format as ChiPhiFunc.content.

aqsc.solve_dphi_iota_dchi(iota, f, static_max_freq: int)

Solves the periodic linear first-order PDE $(\partial_\phi + \iota\,\partial_\chi)\,y = f$ using a spectral method.

Parameters:

  • iota (traced) - A scalar coefficient.
  • f : jax.numpy.array (traced) - The RHS. Must be a 2D array with the same format as ChiPhiFunc.content.
  • static_max_freq : int (static) - Maximum number of Fourier harmonics used. Should be set as low as possible without losing too much accuracy, to keep high-frequency noise from blowing up. Must be found empirically.

Returns: - A jax.numpy.array containing the solution. Has the same format as ChiPhiFunc.content.
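Because $\iota$ is a constant, every Fourier mode $e^{i(m\chi + k\phi)}$ decouples. Each mode satisfies $i(k + \iota m)\,\hat{y} = \hat{f}$, so the equation can be solved by a single division in Fourier space.

The NumPy sketch below makes several assumptions about the content layout:

  • row k holds the $e^{im_k\chi}$ coefficient with $m_k = -(n_\chi-1) + 2k$;
  • columns sample $\phi$ uniformly on $[0, 2\pi/\text{nfp})$;
  • modes where $k + \iota m = 0$, which always includes the $m = 0$, $k = 0$ average, are left undetermined by the equation and set to zero.

This is a hypothetical illustration, not aqsc's code. It omits static_max_freq.

```python
import numpy as np


def solve_dphi_iota_dchi_sketch(iota, f, nfp=1):
    """Solve (d/dphi + iota d/dchi) y = f for y periodic in chi and phi.

    f has shape (n_chi, len_phi). Assumed layout: row k holds the e^{i m_k chi}
    coefficient with m_k = -(n_chi-1) + 2k; columns sample phi on [0, 2 pi/nfp).
    """
    n_chi, len_phi = f.shape
    m = np.arange(-(n_chi - 1), n_chi, 2)[:, None]               # chi harmonics
    k = np.fft.fftfreq(len_phi, d=1.0 / len_phi)[None, :] * nfp  # phi wavenumbers
    f_hat = np.fft.fft(f, axis=1)
    # Each mode e^{i(m chi + k phi)} decouples: i (k + iota m) y_hat = f_hat.
    denom = 1j * (k + iota * m)
    solvable = np.abs(denom) > 1e-12
    # Undetermined modes (k + iota m = 0) are set to zero in this sketch.
    y_hat = np.where(solvable, f_hat / np.where(solvable, denom, 1.0), 0.0)
    return np.fft.ifft(y_hat, axis=1)
```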

Utilities

aqsc.fft_filter(fft_in:jax.numpy.ndarray, target_length:int, axis:int)

Shortens an array in the frequency domain to target_length elements by removing the highest-frequency modes. This is equivalent to a low-pass filter, but is used in solvers to reduce array sizes. The result is still in the frequency domain and must be converted back with an IFFT.

Parameters:

  • fft_in : jax.numpy.array (traced) - Array to filter. Must already be in the frequency domain, as produced by np.fft.fft().
  • target_length : int (static) - Length of the output array.
  • axis : int (static) - Axis to filter along.

Returns: - A filtered jax.numpy.array.

aqsc.fft_pad(fft_in:jax.numpy.array, target_length:int, axis:int)

Pads an array in the frequency domain to target_length elements by filling the high-frequency mode coefficients with zeros. Used in solvers to match required array sizes. The result is still in the frequency domain.

Parameters:

  • fft_in : jax.numpy.array (traced) - Array to pad. Must already be in the frequency domain, as produced by numpy.fft.fft.
  • target_length : int (static) - Length of the output array.
  • axis : int (static) - Axis to pad along.

Returns: - The padded jax.numpy.array.
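The sketch below shows the idea behind filtering and padding in NumPy's FFT ordering, where frequencies run $0, 1, \dots$ followed by the negative frequencies. It keeps or inserts the modes closest to zero frequency.

The names are hypothetical. The rescaling by target_length / original length is an assumption: it makes np.fft.ifft of the output sample the same function on the new grid, but aqsc may normalize differently or not at all.

```python
import numpy as np


def fft_filter_sketch(fft_in, target_length, axis):
    """Keep the target_length lowest-|frequency| modes of an np.fft.fft array."""
    n = fft_in.shape[axis]
    n_pos = (target_length + 1) // 2   # frequencies 0 .. n_pos-1
    n_neg = target_length // 2         # frequencies -n_neg .. -1
    idx = np.r_[0:n_pos, n - n_neg:n]
    # Assumed normalization: np.fft.ifft of the result samples the same
    # function on the coarser grid.
    return np.take(fft_in, idx, axis=axis) * (target_length / n)


def fft_pad_sketch(fft_in, target_length, axis):
    """Insert zero high-frequency modes so the array has target_length entries."""
    n = fft_in.shape[axis]
    n_pos = (n + 1) // 2               # non-negative frequencies stay in front
    left = np.take(fft_in, np.arange(n_pos), axis=axis)
    right = np.take(fft_in, np.arange(n_pos, n), axis=axis)
    zero_shape = list(fft_in.shape)
    zero_shape[axis] = target_length - n
    zeros = np.zeros(zero_shape, dtype=fft_in.dtype)
    # Same assumed normalization as fft_filter_sketch, in reverse.
    return np.concatenate([left, zeros, right], axis=axis) * (target_length / n)
```

For a band-limited signal, filtering to any odd length that still contains all of its modes and then padding back recovers the original spectrum.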

Tensor construction functions for looped_solver.py (Not very useful otherwise)

(NEED DOUBLE-CHECK)

These functions construct differential/convolution tensors (equivalent to coupled PDEs) in frequency space for looped_solver.py.

aqsc.conv_tensor(content:jax.numpy.ndarray, n_dim:int)

Generates a 3D array containing a stack of content.shape[1] convolution matrices. The first two axes represent the row and column of a single convolution matrix, which convolves the Fourier components of one ChiPhiFunc with those of another along axis 0. Used for multiplication in FFT space during ODE solves. To apply the operator, use x2_conv_y2 = jax.numpy.einsum('ijk,jk->ik', conv_x2, y2.content).

Parameters: - content : jax.numpy.ndarray (traced) - A content matrix of the ChiPhiFunc to convolve with. - n_dim : int (static) - Number of components (rows) in the content the operator acts on.

Returns: - A jax.numpy.array with shape (content.shape[0]+n_dim-1, n_dim, content.shape[1]) containing a stack of content.shape[1] convolution matrices. The first two axes represent the row and column of a single convolution matrix.
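Multiplying two ChiPhiFuncs adds their $\chi$ harmonic numbers. With rows spaced by 2 in $m$, the product's coefficients along axis 0 are therefore a discrete convolution, of length len_a + n_dim - 1. The NumPy sketch below builds such a stack of convolution matrices, one per $\phi$ grid point, and applies it with the einsum shown above. conv_tensor_sketch is a hypothetical name, not aqsc's code.

```python
import numpy as np


def conv_tensor_sketch(content, n_dim):
    """Stack of convolution matrices, one per phi grid point (column of content).

    out[r, j, p] = content[r - j, p], so out[:, :, p] @ y[:, p] equals
    np.convolve(content[:, p], y[:, p]) for any y with n_dim rows.
    """
    len_a, len_phi = content.shape
    out = np.zeros((len_a + n_dim - 1, n_dim, len_phi), dtype=content.dtype)
    for j in range(n_dim):
        # Column j is content shifted down by j rows.
        out[j:j + len_a, j, :] = content
    return out


# Usage: convolve a 3-row content with a 4-row content at every phi point.
a = np.arange(15.0).reshape(3, 5)
b = np.ones((4, 5))
a_conv_b = np.einsum('ijk,jk->ik', conv_tensor_sketch(a, 4), b)
```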

aqsc.fft_conv_tensor_batch(source:jax.numpy.array)

Generates a 4D convolution operator in the phi axis from a 3D stack of operators. (see comments in looped_solver for explanation).

Parameters: - source : jax.numpy.ndarray (traced) - An (a, b, len_phi) array, where a and b represent a matrix acting on the components of a ChiPhiFunc.

Returns: - An (a, b, len_phi, len_phi) array that acts on a content by transposing axis 1 and 2 and then using jax.numpy.tensordot(). (see explanation for a "tensor operator" in comments in looped_solver.py)

aqsc.to_tensor_fft_op(ChiPhiFunc_in:ChiPhiFunc, len_tensor:int)

Reduces the grid number of a ChiPhiFunc in the frequency domain with a low-pass filter, then creates a 4D convolution operator in the frequency domain, of the given length, that performs point-wise multiplication by the ChiPhiFunc on a $\chi$-independent ChiPhiFunc.fft().content.

Parameters: - ChiPhiFunc_in : ChiPhiFunc (traced) - The kernel. Has a components. - len_tensor : int (static) - The dimension of the output.

Returns: - An (a, 1, len_tensor, len_tensor) array that acts on a content by transposing axis 1 and 2 and then using jax.numpy.tensordot(). (see explanation for a "tensor operator" in comments in looped_solver.py)

aqsc.to_tensor_fft_op_multi_dim(ChiPhiFunc_in:ChiPhiFunc, dphi:int, dchi:int, num_mode:int, cap_axis0:int, len_tensor: int, nfp: int)

Reduces the grid number of a ChiPhiFunc in the frequency domain with a low-pass filter, then creates a 4D convolution operator in the frequency domain, of the given length, that performs point-wise multiplication by the ChiPhiFunc on a $\chi$-dependent ChiPhiFunc.fft().content, or optionally on its derivative(s) in $\chi$ and/or $\phi$.

Parameters:

  • ChiPhiFunc_in : ChiPhiFunc (traced) - The kernel A.
  • num_mode : int (static) - The number of columns of the resulting tensor. Corresponds to the total number of components in the ChiPhiFunc.fft().content the tensor will act on.
  • cap_axis0 : int (static) - The length of axis 0 of the resulting tensor, used to remove outer components that are known to cancel. Must have the same parity as, and be smaller than, the number of rows of the convolution tensor generated from ChiPhiFunc_in and num_mode.

Returns: - A jax.numpy.array with shape (len_chi+num_mode-1, num_mode, len_tensor, len_tensor). Acts on a ChiPhiFunc.fft().content by np.tensordot(operator, content, 2).

aqsc.linear_least_sq_2d_svd(A, b)

Solves the linear least-squares problem minimizing $\|Ax - b\|_2$. Used to solve over-determined linear systems.

Parameters: - A, b : jax.numpy.array (traced) - Input arrays of shape (n,m) and (n)

Returns: - A jax.numpy.array with shape (m).
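For reference, the SVD route to this problem computes the pseudo-inverse solution $x = V\,\Sigma^{+}U^{H}b$. In that formula $\Sigma^{+}$ inverts only singular values above a relative cutoff. The NumPy sketch below is a hypothetical illustration: the name and the rcond cutoff are assumptions and are not taken from aqsc.

```python
import numpy as np


def linear_least_sq_svd_sketch(A, b, rcond=1e-12):
    """Minimize ||A x - b||_2 via the SVD pseudo-inverse. A: (n, m), b: (n,)."""
    u, s, vh = np.linalg.svd(A, full_matrices=False)
    # Invert only singular values above a relative cutoff (assumed knob).
    s_inv = np.zeros_like(s)
    keep = s > rcond * s.max()
    s_inv[keep] = 1.0 / s[keep]
    # x = V diag(s_inv) U^H b
    return vh.conj().T @ (s_inv * (u.conj().T @ b))
```

For a full-rank over-determined system, this matches np.linalg.lstsq. The cutoff only matters when A is rank-deficient.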