Source code for cudolfinx.jit

# Copyright (C) 2026 Benjamin Pachev
#
# This file is part of cuDOLFINX
#
# SPDX-License-Identifier:    LGPL-3.0-or-later

"""Routines for manipulating generated FFCX code."""

import pathlib
from typing import Any

import numpy as np

from dolfinx import fem

__all__ = [
    "get_tabulate_tensor_sources",
    "get_wrapped_tabulate_tensors",
]

def get_tabulate_tensor_sources(form: fem.Form):
    """Given a compiled fem.Form, extract the C source code of the tabulate tensors.

    The generated ``.c`` file is expected to sit next to the compiled module
    (``form.module.__file__``) and to be named after the part of the module
    file name that precedes its first ``.``.

    Args:
        form: Compiled form. Only ``form.module.__file__`` is read.

    Returns:
        A tuple ``(tabulate_tensors, integral_tensor_indices)``:

        * ``tabulate_tensors`` is a list of ``(integral_id, body)`` pairs in
          the order the ``tabulate_tensor_integral_<id>`` functions appear in
          the source file. ``body`` is the text between the function's outer
          braces (the braces themselves are not included).
        * ``integral_tensor_indices`` holds, for every entry of the
          ``form_integrals_form...`` array found in the file, the index into
          ``tabulate_tensors`` of the corresponding integral.

    Raises:
        OSError: If the generated source file does not exist.
        RuntimeError: If tabulate tensors were found but no
            ``form_integrals_form...`` array listing their order.
        KeyError: If the array references an integral whose tabulate tensor
            was not found in the file.
    """
    module_file = pathlib.Path(form.module.__file__)
    source_filename = module_file.name.split(".")[0] + ".c"
    source_file = module_file.parent.joinpath(source_filename)
    if not source_file.is_file():
        raise OSError(f"Could not find generated ffcx source file '{source_file}'!")

    tabulate_tensors = []
    # Stays None until a 'form_integrals_form... = {...}' line is seen.
    ordered_integral_ids = None
    tabulate_id = None
    tabulate_body = []
    parsing_tabulate = False
    parsing_header = False
    bracket_count = 0
    with open(source_file) as fp:
        for line in fp:
            if "tabulate_tensor_integral" in line and line.strip().startswith("void"):
                # Start of a kernel definition: remember its id and skip the
                # (possibly multi-line) signature up to the opening brace.
                parsing_tabulate = True
                parsing_header = True
                func_name = line.strip().split()[1].split("(")[0]
                tabulate_id = func_name.replace("tabulate_tensor_integral_", "")
                tabulate_body = []
            elif parsing_header:
                if line.startswith("{"):
                    parsing_header = False
                    bracket_count += 1
            elif parsing_tabulate:
                # Only braces in column 0 are counted; indented braces are
                # treated as ordinary body text.
                if line.startswith("{"):
                    bracket_count += 1
                elif line.startswith("}"):
                    bracket_count -= 1
                if not bracket_count:
                    # Closing brace of the function itself: body is complete.
                    tabulate_tensors.append((tabulate_id, "".join(tabulate_body)))
                    parsing_tabulate = False
                else:
                    tabulate_body.append(line)
            elif "form_integrals_form" in line and "{" in line:
                # Array initialiser such as '{&integral_a, &integral_b};'.
                arr = line.split("{")[-1].split("}")[0]
                ordered_integral_ids = [
                    part.strip().replace("&integral_", "")
                    for part in arr.split(",")
                    if part.strip()  # ignore empty entries (e.g. trailing comma)
                ]

    if ordered_integral_ids is None:
        if tabulate_tensors:
            raise RuntimeError(
                f"Could not find the list of form integrals in '{source_file}'!"
            )
        ordered_integral_ids = []

    # map ids to order of appearance in tensor list
    id_order = {tensor_id: i for i, (tensor_id, _) in enumerate(tabulate_tensors)}
    integral_tensor_indices = [id_order[integral_id] for integral_id in ordered_integral_ids]
    return tabulate_tensors, integral_tensor_indices
# CUDA prologue and kernel signature, rendered with ``str.format``.
# Placeholders: ``factory_name`` (suffix of the kernel name), ``scalar_type``
# (C type of ``A``, ``w`` and ``c``) and ``geom_type`` (C type of
# ``coordinate_dofs``). ``alignas`` is defined to nothing and ``restrict`` is
# mapped to ``__restrict__``.
# NOTE(review): ``ufc_scalar_t`` is always typedef'd to ``double`` here,
# independent of ``scalar_type`` - confirm this is intended for float32 forms.
cuda_tabulate_tensor_header = """
#define alignas(x)
#define restrict __restrict__

typedef unsigned char uint8_t;
typedef unsigned int uint32_t;
typedef double ufc_scalar_t;

extern "C" __global__
void tabulate_tensor_{factory_name}({scalar_type}* restrict A,
                                    const {scalar_type}* restrict w,
                                    const {scalar_type}* restrict c,
                                    const {geom_type}* restrict coordinate_dofs,
                                    const int* restrict entity_local_index,
                                    const uint8_t* restrict quadrature_permutation
                                    )
"""


def _convert_dtype_to_str(dtype: Any):
    """Return the C type name matching a NumPy floating-point dtype.

    ``np.float32`` maps to ``"float"`` and ``np.float64`` to ``"double"``.
    Anything that compares equal to neither raises ``TypeError``.
    """
    # Checked in order with ``==`` so that both NumPy scalar types and
    # ``np.dtype`` instances are recognised.
    c_names = ((np.float32, "float"), (np.float64, "double"))
    for numpy_type, c_name in c_names:
        if dtype == numpy_type:
            return c_name
    raise TypeError(f"Unsupported dtype: '{dtype}'")
def get_wrapped_tabulate_tensors(form: fem.Form, backend="cuda"):
    """Given a fem.Form, wrap the tabulate tensors for use on a GPU.

    Each tabulate tensor body extracted by ``get_tabulate_tensor_sources`` is
    placed inside braces after a rendered ``cuda_tabulate_tensor_header``.

    Args:
        form: Form whose generated kernels are wrapped; ``form.dtype`` selects
            the C scalar type.
        backend: Target backend. Only ``"cuda"`` is accepted.

    Returns:
        A tuple ``(kernels, integral_tensor_indices)`` where ``kernels`` is a
        list of ``(kernel_name, kernel_source)`` pairs and
        ``integral_tensor_indices`` is passed through unchanged from
        ``get_tabulate_tensor_sources``.

    Raises:
        NotImplementedError: If ``backend`` is not ``"cuda"``.
    """
    if backend != "cuda":
        raise NotImplementedError(f"Backend '{backend}' not yet supported.")

    # The mesh geometry is assumed to use the same scalar type as the form,
    # which is typically the default, so one C type name serves for both.
    c_type = _convert_dtype_to_str(form.dtype)
    tensor_sources, integral_tensor_indices = get_tabulate_tensor_sources(form)

    def _wrap(integral_id, kernel_body):
        # Build the (kernel name, full kernel source) pair for one integral.
        factory_name = "integral_" + integral_id
        signature = cuda_tabulate_tensor_header.format(
            factory_name=factory_name, scalar_type=c_type, geom_type=c_type
        )
        kernel_source = "".join((signature, "{\n", kernel_body, "}\n"))
        return "tabulate_tensor_" + factory_name, kernel_source

    kernels = [_wrap(integral_id, kernel_body) for integral_id, kernel_body in tensor_sources]
    return kernels, integral_tensor_indices