Source code for data_provider_toolkit

import dataclasses
import enum
import io
import re
import typing

import networkx
import pyarrow
import pyarrow.compute
import pyarrow.json

from kaxanuk.data_curator.data_blocks.base_data_block import (
    ConsolidatedFieldsTable,
    BaseDataBlock,
    EntityField,
)
from kaxanuk.data_curator.entities import (
    BaseDataEntity,
)
from kaxanuk.data_curator.exceptions import (
    DataProviderIncorrectMappingTypeError,
    DataProviderMultiEndpointCommonDataDiscrepancyError,
    DataProviderMultiEndpointCommonDataOrderError,
    DataProviderMultiEndpointDuplicateKeysError,
    DataProviderMultiEndpointNullColumnsError,
    DataProviderParsingError,
    DataProviderToolkitArgumentError,
    DataProviderToolkitNoDataError,
    DataProviderToolkitRuntimeError,
)
from kaxanuk.data_curator.modules.data_column import DataColumn


type ColumnRemap = str   # remapped column name: "entity.field", or "entity.field$tag" for preprocessed fields

type Endpoint = enum.StrEnum    # identifier of a particular data provider endpoint
# discrepancy table column names are either "entity.field" for primary keys,
# or "endpoint_name$entity.field" for the values reported by a specific endpoint
type EndpointDiscrepanciesTable = pyarrow.Table    # mostly for error handler use
type PrimaryKeyTable = pyarrow.Table     # table holding only the primary key columns used for table merges
type TagName = str  # name of a data provider tag (the provider's own column name)

[docs] @dataclasses.dataclass(slots=True, frozen=True) class PreprocessedFieldMapping: tags: list[TagName] preprocessors: list[typing.Callable]
type EndpointFieldMap = dict[ Endpoint, dict[ EntityField, TagName | PreprocessedFieldMapping ] ] type DataBlockEndpointTagMap = dict[ type[BaseDataBlock], EndpointFieldMap ] type EndpointColumnRemaps = dict[ Endpoint, dict[ TagName, list[ColumnRemap] ] ] type DataBlockEndpointColumnRemaps = dict[ type[BaseDataBlock], EndpointColumnRemaps ] type EndpointFieldPreprocessors = dict[ Endpoint, dict[ EntityField, PreprocessedFieldMapping ] ] type DataBlockEndpointFieldPreprocessors = dict[ type[BaseDataBlock], EndpointFieldPreprocessors ] type EndpointTables = dict[ Endpoint, pyarrow.Table ] type EntityClassNameMap = dict[ str, type[BaseDataEntity] ] type DataBlockEntityClassNameMap = dict[ type[BaseDataBlock], EntityClassNameMap ] type EntityEndpoints = dict[ type[BaseDataEntity], set[Endpoint] ] type EntityFieldColumns = dict[ type[BaseDataEntity], dict[ EntityField, list[pyarrow.Array] ] ] type EntityRelationMap = dict[ type[BaseDataEntity], dict[ EntityField, type[BaseDataEntity] ] ] type ProcessedEndpointTables = dict[ # endpoint tables that have been remapped and had the preprocessors applied Endpoint, pyarrow.Table ] type EntityFieldToMostSpecificEntity = dict[ EntityField, # field member descriptor type[BaseDataEntity] # most specific entity for that field ]
class DataProviderFieldPreprocessors:
    """Collection of stateless column preprocessors exposed as static methods."""

    @staticmethod
    def convert_millions_to_units(column: DataColumn) -> DataColumn:
        """
        Scale values expressed in millions up to individual units.

        Every value of the column is multiplied by 1,000,000.

        Parameters
        ----------
        column
            Column containing values in millions

        Returns
        -------
        DataColumn
            Column with values converted to standard units
        """
        units_per_million = 1_000_000

        return column * units_per_million

    @staticmethod
    def cast_datetime_to_date(column: DataColumn) -> DataColumn:
        """
        Cast datetime values to dates, discarding the time component.

        The column is converted to its pyarrow representation, cast to
        ``date32`` and loaded back into a ``DataColumn``.

        Parameters
        ----------
        column
            Column containing datetime values

        Returns
        -------
        DataColumn
            Column with values cast to date type
        """
        date_values = column.to_pyarrow().cast(
            pyarrow.date32()
        )

        return DataColumn.load(date_values)
class DataProviderToolkit:
    # character repeated to build the separator lines wrapping discrepancy tables
    DISCREPANCY_TABLE_SEPARATOR_CHARACTER: typing.ClassVar[str] = "-"
    # upper bound for the width of the separator lines wrapping discrepancy tables
    DISCREPANCY_TABLE_SEPARATOR_MAX_WIDTH: typing.ClassVar[int] = 80

    # endpoint column remaps cache, keyed by data block class
    _data_block_endpoint_column_remaps: typing.ClassVar[
        DataBlockEndpointColumnRemaps
    ] = {}
    # endpoint field preprocessors cache, keyed by data block class
    _data_block_endpoint_field_preprocessors: typing.ClassVar[
        DataBlockEndpointFieldPreprocessors
    ] = {}
    # entity class name maps, keyed by data block class
    _data_block_entity_class_name_map: typing.ClassVar[
        DataBlockEntityClassNameMap
    ] = {}
    # entity field to most specific entity cache
    # NOTE(review): the original comment says the key is a tuple of (endpoint values, entity names),
    # but the annotation declares a str key - confirm which one is intended
    _entity_field_to_most_specific_entity_cache: typing.ClassVar[
        dict[str, EntityFieldToMostSpecificEntity]
    ] = {}
@staticmethod
def consolidate_processed_endpoint_tables(
    *,
    processed_endpoint_tables: ProcessedEndpointTables,
    table_merge_fields: list[EntityField],
    predominant_order_descending: bool = False,
) -> ConsolidatedFieldsTable:
    """
    Consolidate multiple endpoint tables into a single unified table.

    Merges processed tables from different endpoints by their primary keys,
    preserving row order and coalescing values from different endpoints.
    Validates that common columns across endpoints have consistent values
    for shared rows.

    An empty mapping yields an empty table, and a mapping with a single
    endpoint returns that endpoint's table unchanged. Tables lacking any of
    the primary key columns contribute no rows or columns to the result.

    Parameters
    ----------
    processed_endpoint_tables
        Dictionary mapping endpoints to their processed tables
    table_merge_fields
        List of entity fields to use as primary keys for merging
    predominant_order_descending
        Whether the predominant ordering is descending

    Returns
    -------
    ConsolidatedFieldsTable
        Consolidated table containing all data from all endpoints

    Raises
    ------
    DataProviderMultiEndpointCommonDataDiscrepancyError
        When common columns have inconsistent values across endpoints
    DataProviderMultiEndpointDuplicateKeysError
        When an endpoint table has multiple rows sharing the same primary key
        (not raised directly in this method; presumably by the key-merge helper)
    DataProviderMultiEndpointNullColumnsError
        When any endpoint table contains a column of pyarrow null type
        (i.e. a column holding nothing but nulls)
    DataProviderToolkitRuntimeError
        When no tables contain required primary key columns
    """
    # reject null-typed columns up front, reporting them per endpoint
    null_type = pyarrow.null()
    null_type_columns = {}
    for (endpoint, table) in processed_endpoint_tables.items():
        null_column_names = [
            field.name
            for field in table.schema
            if field.type == null_type
        ]
        if null_column_names:
            null_type_columns[str(endpoint)] = null_column_names
    if null_type_columns:
        raise DataProviderMultiEndpointNullColumnsError(null_type_columns)

    # trivial cases: nothing to consolidate, or a single table that needs no merging
    if not processed_endpoint_tables:
        return pyarrow.table({})
    if len(processed_endpoint_tables) == 1:
        return next(
            iter(processed_endpoint_tables.values())
        )

    # get primary key ordering
    key_column_names = [
        f"{field.__objclass__.__name__}.{field.__name__}"
        for field in table_merge_fields
    ]
    primary_key_subsets = [
        table.select(key_column_names)
        for table in processed_endpoint_tables.values()
        if all(
            pk in table.column_names
            for pk in key_column_names
        )
    ]
    if not primary_key_subsets:
        msg = "None of the provided tables contain the required primary key columns for merging."
        raise DataProviderToolkitRuntimeError(msg)

    merged_key_table = DataProviderToolkit._merge_primary_key_subsets_preserving_order(
        primary_key_subsets,
        predominant_order_descending=predominant_order_descending,
    )
    # joins do not preserve row order, so carry an explicit position column to sort by afterwards
    order_col_name = "__order_col"
    key_table_with_order = merged_key_table.add_column(
        0,
        order_col_name,
        pyarrow.array(range(merged_key_table.num_rows))
    )

    # Align all tables to the master primary key list and create validity masks in one pass
    aligned_tables = []
    validity_masks = {}
    indicator_col = '__indicator_for_validity'
    for (endpoint, original_table) in processed_endpoint_tables.items():
        has_pk_cols = all(
            key in original_table.column_names
            for key in key_column_names
        )
        if not has_pk_cols:
            # table cannot be aligned: it shares no rows with any other endpoint
            validity_masks[endpoint] = pyarrow.array(
                [False] * len(merged_key_table),
                type=pyarrow.bool_()
            )
            aligned_tables.append(
                pyarrow.table({})   # Empty placeholder keeps indices parallel to the endpoints list
            )
            continue

        # Join and sort once.
        # Type the indicator explicitly so a zero-row endpoint table still yields a bool column
        table_with_indicator = original_table.append_column(
            indicator_col,
            pyarrow.array(
                [True] * len(original_table),
                type=pyarrow.bool_(),
            )
        )
        aligned_table_with_helpers = key_table_with_order.join(
            table_with_indicator,
            keys=key_column_names,
            join_type="left outer"
        ).sort_by(order_col_name)

        # rows the endpoint actually provided are those where the indicator survived the join
        validity_masks[endpoint] = aligned_table_with_helpers[indicator_col].is_valid()

        final_cols = [
            col
            for col in aligned_table_with_helpers.column_names
            if col not in [order_col_name, indicator_col]
        ]
        aligned_tables.append(
            aligned_table_with_helpers.select(final_cols)
        )

    endpoints = list(processed_endpoint_tables.keys())

    # Build column index for efficient lookup: col_name -> indices of the aligned tables containing it
    col_to_tables = {}
    for (table_idx, table) in enumerate(aligned_tables):
        for col_name in table.column_names:
            if col_name in key_column_names:
                continue
            col_to_tables.setdefault(col_name, []).append(table_idx)

    discrepant_columns = set()
    discrepant_rows_mask = None

    # Check for discrepancies only on common columns between table pairs
    for (col_name, table_indices) in col_to_tables.items():
        if len(table_indices) <= 1:
            # No overlap, no discrepancy possible
            continue

        # Check all pairs that share this column
        for (position, first_idx) in enumerate(table_indices):
            for second_idx in table_indices[position + 1:]:
                endpoint1 = endpoints[first_idx]
                endpoint2 = endpoints[second_idx]

                common_rows_mask = pyarrow.compute.and_(
                    validity_masks[endpoint1],
                    validity_masks[endpoint2]
                )
                if not pyarrow.compute.any(common_rows_mask).as_py():
                    # no common rows, no discrepancy possible
                    continue

                col1_common = aligned_tables[first_idx].column(col_name).filter(common_rows_mask)
                col2_common = aligned_tables[second_idx].column(col_name).filter(common_rows_mask)

                # equal() yields null when either side is null, so treat null/null as a match explicitly
                are_equal = pyarrow.compute.equal(
                    col1_common,
                    col2_common
                ).fill_null(fill_value=False)
                both_null = pyarrow.compute.and_(
                    pyarrow.compute.is_null(col1_common),
                    pyarrow.compute.is_null(col2_common)
                )
                no_discrepancy_mask = pyarrow.compute.or_(are_equal, both_null)
                if pyarrow.compute.all(no_discrepancy_mask).as_py():
                    continue

                # handle discrepancy
                discrepancy_mask = pyarrow.compute.invert(no_discrepancy_mask)
                # Track this column as discrepant
                discrepant_columns.add(col_name)

                # Expand discrepancy_mask from common rows back to full table size
                # Start with all False
                full_size_discrepancy = pyarrow.array(
                    [False] * len(merged_key_table)
                )
                # replace_with_mask replaces values where mask is True with values from the
                # replacement array; convert both masks to Array if they are ChunkedArrays
                if isinstance(common_rows_mask, pyarrow.ChunkedArray):
                    common_rows_mask = common_rows_mask.combine_chunks()
                if isinstance(discrepancy_mask, pyarrow.ChunkedArray):
                    discrepancy_mask = discrepancy_mask.combine_chunks()
                full_size_discrepancy = pyarrow.compute.replace_with_mask(
                    full_size_discrepancy,
                    common_rows_mask,
                    discrepancy_mask
                )

                # Combine with master mask
                if discrepant_rows_mask is None:
                    discrepant_rows_mask = full_size_discrepancy
                else:
                    discrepant_rows_mask = pyarrow.compute.or_(
                        discrepant_rows_mask,
                        full_size_discrepancy
                    )

    # Create debug table with all discrepant rows and columns
    if discrepant_columns:
        discrepancy_table = DataProviderToolkit._calculate_common_column_discrepancies(
            discrepant_columns=discrepant_columns,
            discrepant_rows_mask=discrepant_rows_mask,
            primary_keys_table=merged_key_table,
            key_column_names=key_column_names,
            aligned_tables=aligned_tables,
            endpoints=endpoints,
        )

        raise DataProviderMultiEndpointCommonDataDiscrepancyError(
            discrepant_columns=discrepant_columns,
            discrepancies_table=discrepancy_table,
            key_column_names=key_column_names,
        )

    # Build consolidated table, starting with the primary keys
    consolidated_columns = {
        name: merged_key_table[name]
        for name in key_column_names
    }

    # Group columns by field name (without entity prefix)
    field_to_full_names = {}
    all_col_names = sorted(col_to_tables.keys())
    for full_name in all_col_names:
        field = full_name.split('.', 1)[1]
        field_to_full_names.setdefault(field, []).append(full_name)

    # Process each field group
    for (_field, full_names) in sorted(field_to_full_names.items()):
        # Collect the distinct arrays for this field. Deduplication is by value only:
        # each column lookup wraps a fresh temporary object, so id() can never identify
        # a repeated array and a recycled id could wrongly skip a distinct one.
        arrays_to_coalesce = []
        for full_name in full_names:
            for table_idx in col_to_tables[full_name]:
                candidate = aligned_tables[table_idx][full_name]
                is_duplicate = any(
                    candidate.equals(kept_array)
                    for kept_array in arrays_to_coalesce
                )
                if not is_duplicate:
                    arrays_to_coalesce.append(candidate)

        if arrays_to_coalesce:
            # first non-null value per row wins; every full name of the group shares the result
            merged_column = pyarrow.compute.coalesce(*arrays_to_coalesce)
            for full_name in full_names:
                consolidated_columns[full_name] = merged_column

    # Build final table with correct column order
    final_column_order = [*key_column_names, *all_col_names]

    return pyarrow.table({
        name: consolidated_columns[name]
        for name in final_column_order
    })
@classmethod
def drop_discrepant_processed_endpoint_tables_rows(
    cls,
    *,
    discrepancy_table: EndpointDiscrepanciesTable,
    processed_endpoint_tables: ProcessedEndpointTables,
    key_column_names: list[str],
) -> EndpointTables:
    """
    Return the endpoint tables with the discrepant rows removed.

    The primary key columns of the discrepancy table identify the rows to
    drop; each endpoint table is passed through
    ``_drop_table_rows_by_primary_key`` with those keys and the results are
    collected under the same endpoints. Intended for cases where discrepant
    rows cannot be reconciled and the surviving rows should be kept.

    Parameters
    ----------
    discrepancy_table
        Table containing primary keys of discrepant rows
    processed_endpoint_tables
        Dictionary mapping endpoints to their processed tables
    key_column_names
        List of primary key column names

    Returns
    -------
    EndpointTables
        Dictionary mapping endpoints to tables with discrepant rows dropped
    """
    discrepant_keys = discrepancy_table.select(key_column_names)

    return {
        endpoint: cls._drop_table_rows_by_primary_key(
            table=endpoint_table,
            drop_rows_primary_keys=discrepant_keys,
        )
        for (endpoint, endpoint_table) in processed_endpoint_tables.items()
    }
@classmethod
def create_endpoint_tables_from_json_mapping(
    cls,
    /,
    endpoint_json_strings: dict[Endpoint, str],
) -> EndpointTables:
    """
    Build one table per endpoint from its JSON string.

    Each JSON string is handed to ``_create_table_from_json_string``, which
    performs the actual parsing; any parsing error aborts the whole
    conversion.

    Parameters
    ----------
    endpoint_json_strings
        Dictionary mapping endpoints to their JSON string data

    Returns
    -------
    EndpointTables
        Dictionary mapping endpoints to parsed PyArrow tables

    Raises
    ------
    DataProviderToolkitRuntimeError
        When JSON parsing fails for any endpoint
    """
    parsed_tables = {}
    try:
        for (endpoint, json_string) in endpoint_json_strings.items():
            parsed_tables[endpoint] = cls._create_table_from_json_string(json_string)
    except DataProviderParsingError as error:
        msg = f"Failed to parse endpoint tables: {error}"

        raise DataProviderToolkitRuntimeError(msg) from error

    return parsed_tables
@staticmethod
def find_common_table_missing_rows_mask(
    common_rows_table: pyarrow.Table,
    subset_rows_table: pyarrow.Table,
) -> pyarrow.BooleanArray | None:
    """
    Identify rows in common table that are missing from subset table.

    Performs a null-safe comparison between two tables by column position
    to determine which rows in the common table are not present in the
    subset table. Subset columns are renamed and cast to the common table's
    schema before comparing. Repeated rows in the subset table do not affect
    the result: the mask always has exactly one entry per common table row.

    Parameters
    ----------
    common_rows_table
        Table containing all potential rows
    subset_rows_table
        Table containing a subset of rows to check against

    Returns
    -------
    pyarrow.BooleanArray or None
        Boolean mask where True indicates missing rows, or None if the common
        table is empty or none of its rows are missing from the subset table

    Raises
    ------
    DataProviderToolkitArgumentError
        When tables have different number of columns
    """
    if common_rows_table.num_columns != subset_rows_table.num_columns:
        msg = "Tables have different number of columns"

        raise DataProviderToolkitArgumentError(msg)

    if common_rows_table.num_rows == 0:
        return None

    column_names = common_rows_table.column_names

    if subset_rows_table.num_rows == 0:
        # nothing to match against: every common row is missing
        return pyarrow.array(
            [True] * common_rows_table.num_rows,
            type=pyarrow.bool_()
        )

    # Columns are matched by position, so give the subset the common table's names and types
    subset_renamed = subset_rows_table.rename_columns(column_names)
    cast_columns = {}
    for col_name in column_names:
        common_col_type = common_rows_table.schema.field(col_name).type
        subset_col = subset_renamed[col_name]
        if subset_col.type != common_col_type:
            subset_col = subset_col.cast(common_col_type)
        cast_columns[col_name] = subset_col
    subset_with_matching_types = pyarrow.table(cast_columns)

    order_col_name = "__order_col__"
    hash_col_name = "__null_hash__"
    indicator_col = "__indicator_for_mask__"

    # Strategy: joins never match NULL keys, so NULLs are replaced with fill values and a
    # per-row "null pattern" string ("1" if null, "0" if not, one character per column)
    # is added as an extra key, making NULLs match only other NULLs in the same position
    null_marker = pyarrow.scalar('1')
    valid_marker = pyarrow.scalar('0')

    def build_null_pattern(table: pyarrow.Table):
        pattern = None
        for col_name in column_names:
            indicator = pyarrow.compute.if_else(
                pyarrow.compute.is_null(table[col_name]),
                null_marker,
                valid_marker
            )
            if pattern is None:
                pattern = indicator
            else:
                pattern = pyarrow.compute.binary_join_element_wise(
                    pattern,
                    indicator,
                    ''
                )

        return pattern

    # The order column preserves the common table's original row order across the join
    common_filled_cols = {
        order_col_name: pyarrow.array(range(common_rows_table.num_rows)),
        hash_col_name: build_null_pattern(common_rows_table),
    }
    subset_filled_cols = {
        hash_col_name: build_null_pattern(subset_with_matching_types),
    }
    for col_name in column_names:
        common_col = common_rows_table[col_name]
        subset_col = subset_with_matching_types[col_name]
        # Fill NULLs with a type-appropriate value: empty string for strings, 0 for everything else.
        # The actual value doesn't matter, as the null pattern key already separates NULLs from values
        is_text = (
            pyarrow.types.is_string(common_col.type)
            or pyarrow.types.is_large_string(common_col.type)
        )
        fill_value = '' if is_text else 0
        common_filled_cols[col_name] = pyarrow.compute.fill_null(common_col, fill_value)
        subset_filled_cols[col_name] = pyarrow.compute.fill_null(subset_col, fill_value)
    common_filled = pyarrow.table(common_filled_cols)

    # Perform join using null pattern + all columns as keys
    join_keys = [hash_col_name, *column_names]
    # Keep one row per distinct key: repeated subset rows would otherwise duplicate the
    # matching common rows in the join and make the mask longer than the common table
    subset_distinct = (
        pyarrow.table(subset_filled_cols)
        .group_by(join_keys)
        .aggregate([])
    )
    subset_with_indicator = subset_distinct.append_column(
        indicator_col,
        pyarrow.array(
            [False] * subset_distinct.num_rows,
            type=pyarrow.bool_()
        )
    )
    joined_table = common_filled.join(
        subset_with_indicator,
        keys=join_keys,
        join_type="left outer"
    ).sort_by(order_col_name)

    # common rows without a match carry a null indicator after the left outer join
    mask = pyarrow.compute.is_null(
        joined_table.column(indicator_col)
    ).combine_chunks()

    if pyarrow.compute.any(mask).as_py():
        return mask

    return None
@classmethod
def format_consolidated_discrepancy_table_for_output(
    cls,
    *,
    discrepancy_table: pyarrow.Table,
    output_column_renames: list[str] | dict[str, str],
    csv_separator: str = "|",
) -> str:
    """
    Render a discrepancy table as a CSV string framed by separator lines.

    The table's columns are renamed, the result is converted through pandas
    (keeping timestamps as objects) into CSV text without an index column,
    and the text is wrapped between two identical separator lines. Each
    separator repeats ``DISCREPANCY_TABLE_SEPARATOR_CHARACTER`` to the width
    of the CSV header line, capped at
    ``DISCREPANCY_TABLE_SEPARATOR_MAX_WIDTH`` characters.

    Parameters
    ----------
    discrepancy_table
        Table containing discrepancy data to format
    output_column_renames
        New column names as positional list or mapping dictionary
    csv_separator
        Character to use as CSV field separator

    Returns
    -------
    str
        CSV-formatted string representation of the table, wrapped between
        separator lines
    """
    renamed_table = discrepancy_table.rename_columns(output_column_renames)
    # convert to pandas, preserving all datetime settings
    dataframe = renamed_table.to_pandas(timestamp_as_object=True)
    csv_body = dataframe.to_csv(sep=csv_separator, index=False).rstrip("\n")

    # the header is everything before the first line break
    (header_line, _, _) = csv_body.partition("\n")
    separator_width = min(
        len(header_line),
        cls.DISCREPANCY_TABLE_SEPARATOR_MAX_WIDTH
    )
    separator_line = cls.DISCREPANCY_TABLE_SEPARATOR_CHARACTER * separator_width

    return f"{separator_line}\n{csv_body}\n{separator_line}"
@classmethod
def format_endpoint_discrepancy_table_for_output(
    cls,
    *,
    data_block: type[BaseDataBlock],
    discrepancy_table: EndpointDiscrepanciesTable,
    endpoints_enum: enum.StrEnum,
    endpoint_field_map: EndpointFieldMap,
    csv_separator: str = "|",
) -> str:
    """
    Render an endpoint discrepancy table using the provider's own naming.

    Columns named ``endpoint_name$Entity.field`` are renamed to
    ``endpoint_value.tag``, where the tag is looked up through
    ``get_provider_tag_for_entity_column``. Columns without a ``$``
    (the primary keys) are renamed to their bare field name. The renamed
    table is then formatted by
    ``format_consolidated_discrepancy_table_for_output``.

    Parameters
    ----------
    data_block
        Data block class defining the entity structure
    discrepancy_table
        Table containing endpoint discrepancy data
    endpoints_enum
        Enum defining available endpoints, indexed here by endpoint member name
    endpoint_field_map
        Mapping from entity fields to provider tags per endpoint
    csv_separator
        Character to use as CSV field separator

    Returns
    -------
    str
        CSV-formatted string with provider-specific column names

    Raises
    ------
    DataProviderToolkitRuntimeError
        When column name parsing fails
    """
    renamed_columns = []
    for column_name in discrepancy_table.column_names:
        if '$' not in column_name:
            # primary key column: keep only the field part of "Entity.field"
            try:
                (_, field_name) = column_name.split('.', 1)
            except ValueError as error:
                # split failed for some reason
                msg = f"Failed to format discrepancy table column name '{column_name}': {error}"

                raise DataProviderToolkitRuntimeError(msg) from error
            renamed_columns.append(field_name)

            continue

        # endpoint-specific column: translate to the provider's "endpoint.tag" naming
        (endpoint_name, entity_column_name) = column_name.split('$', 1)
        endpoint = endpoints_enum[endpoint_name]
        tag_name = cls.get_provider_tag_for_entity_column(
            data_block=data_block,
            endpoint=endpoint,
            endpoint_field_map=endpoint_field_map,
            entity_column_name=entity_column_name,
        )
        renamed_columns.append(f"{endpoint.value}.{tag_name}")

    return cls.format_consolidated_discrepancy_table_for_output(
        discrepancy_table=discrepancy_table,
        output_column_renames=renamed_columns,
        csv_separator=csv_separator,
    )
@classmethod
def get_provider_tag_for_entity_column(
    cls,
    *,
    data_block: type[BaseDataBlock],
    endpoint: Endpoint,
    endpoint_field_map: EndpointFieldMap,
    entity_column_name: str,
) -> TagName:
    """
    Look up the provider tag behind an endpoint's entity field column.

    The column name is split into entity and field names, the entity class
    is resolved through the data block's entity class name map, and the
    field descriptor found on that class is used to index the endpoint's
    field map.

    Parameters
    ----------
    data_block
        Data block class defining the entity structure
    endpoint
        Endpoint whose field map should be consulted
    endpoint_field_map
        Mapping from entity fields to provider tags per endpoint
    entity_column_name
        Column name in `EntityName.field_name` format

    Returns
    -------
    TagName
        Provider tag for the given column. For fields whose mapping is a
        `PreprocessedFieldMapping`, a plus-joined composite of its input
        tags is returned.

    Raises
    ------
    DataProviderToolkitRuntimeError
        When `entity_column_name` is not in `EntityName.field_name` format
    """
    try:
        (entity_name, field_name) = entity_column_name.split('.', 1)
    except ValueError as error:
        msg = (
            f"Invalid entity column name '{entity_column_name}':"
            " expected 'EntityName.field_name' format"
        )

        raise DataProviderToolkitRuntimeError(msg) from error

    entity_class = data_block.get_entity_class_name_map()[entity_name]
    field_descriptor = getattr(entity_class, field_name)
    mapping = endpoint_field_map[endpoint][field_descriptor]

    if not isinstance(mapping, PreprocessedFieldMapping):
        return mapping

    # multi-tag mappings are reported as a single composite tag
    return "+".join(mapping.tags)
@classmethod
def process_endpoint_tables(
    cls,
    *,
    data_block: type[BaseDataBlock],
    endpoint_field_map: EndpointFieldMap,
    endpoint_tables: EndpointTables,
) -> ProcessedEndpointTables:
    """
    Process raw endpoint tables through remapping and preprocessing.

    Transforms provider-specific tag names to entity.field format and
    applies configured preprocessor functions to compute derived fields
    from raw data.

    The column remaps and field preprocessors derived from
    `endpoint_field_map` are cached per data block class, so they are only
    calculated on the first call for each data block.

    Parameters
    ----------
    data_block
        Data block class defining the entity structure
    endpoint_field_map
        Mapping from entity fields to provider tags per endpoint
    endpoint_tables
        Dictionary mapping endpoints to raw data tables

    Returns
    -------
    ProcessedEndpointTables
        Dictionary mapping endpoints to processed tables with standardized
        column names and computed fields

    Raises
    ------
    DataProviderToolkitArgumentError
        When data_block is not a BaseDataBlock subclass
    DataProviderToolkitNoDataError
        When all provided tables are empty, or no tables were provided
    DataProviderToolkitRuntimeError
        When preprocessor execution fails
    """
    # issubclass() raises TypeError for non-class arguments, so check for a class first
    is_data_block_class = (
        isinstance(data_block, type)
        and issubclass(data_block, BaseDataBlock)
    )
    if not is_data_block_class:
        msg = "data_block parameter needs to be a subclass of BaseDataBlock"

        raise DataProviderToolkitArgumentError(msg)

    # default=0 makes an empty mapping count as "no data" instead of raising ValueError
    max_table_length = max(
        (
            len(table)
            for table in endpoint_tables.values()
        ),
        default=0,
    )
    if max_table_length == 0:
        msg = "All provided endpoint tables are empty."

        raise DataProviderToolkitNoDataError(msg)

    # get map from tags to remapped columns
    # NOTE(review): this cache and the preprocessors cache are keyed by data block only, so a
    # later call with a different endpoint_field_map for the same data block reuses the first
    # result - confirm each data block always comes with the same field map
    if data_block not in cls._data_block_endpoint_column_remaps:
        cls._data_block_endpoint_column_remaps[data_block] = cls._calculate_endpoint_column_remaps(
            endpoint_field_map
        )
    endpoint_column_remaps = cls._data_block_endpoint_column_remaps[data_block]

    # get preprocessors
    if data_block not in cls._data_block_endpoint_field_preprocessors:
        cls._data_block_endpoint_field_preprocessors[data_block] = cls._calculate_endpoint_field_preprocessors(
            endpoint_field_map
        )
    endpoint_field_preprocessors = cls._data_block_endpoint_field_preprocessors[data_block]

    # get entity field to most specific entity mapping
    entity_field_to_most_specific_entity = cls._get_entity_field_to_most_specific_entity(
        endpoint_field_map
    )

    # transform table columns per tag to columns per entity.field$tag
    remapped_endpoint_tables = cls._remap_endpoint_table_columns(
        endpoint_column_remaps,
        endpoint_tables,
        entity_field_to_most_specific_entity,
    )

    # run processors
    try:
        processed_endpoint_tables = cls._process_remapped_endpoint_tables(
            endpoint_field_preprocessors,
            remapped_endpoint_tables,
            entity_field_to_most_specific_entity,
        )
    except pyarrow.lib.ArrowInvalid as error:
        msg = f"Error running data provider preprocessors: {error}"

        raise DataProviderToolkitRuntimeError(msg) from error

    return processed_endpoint_tables
@staticmethod
def _calculate_common_column_discrepancies(
    discrepant_columns: set[str],
    discrepant_rows_mask: pyarrow.BooleanArray,
    primary_keys_table: PrimaryKeyTable,
    key_column_names: list[str],
    aligned_tables: list[pyarrow.Table],
    endpoints: list[Endpoint],
) -> pyarrow.Table:
    """
    Create a debug table showing all discrepancy details.

    Builds a table containing primary keys and values from all endpoints
    for columns and rows where discrepancies were detected, enabling detailed
    analysis of data inconsistencies.

    Parameters
    ----------
    discrepant_columns
        Set of column names with detected discrepancies
    discrepant_rows_mask
        Boolean mask indicating rows containing any discrepancy
    primary_keys_table
        Table containing primary key columns
    key_column_names
        List of primary key column names
    aligned_tables
        List of tables aligned to common primary keys
    endpoints
        List of endpoint identifiers corresponding to aligned_tables

    Returns
    -------
    pyarrow.Table
        Table with primary keys and endpoint-specific columns
        (named "endpoint_name$column") for all discrepant data points
    """
    # Start building output table with primary keys
    output_columns = {}

    # Filter primary keys to discrepant rows only
    discrepant_row_keys = primary_keys_table.filter(discrepant_rows_mask)
    for col_name in key_column_names:
        output_columns[col_name] = discrepant_row_keys[col_name]

    # For each discrepant column, add values from all endpoints that have it;
    # sorted so the output column order is deterministic
    for col_name in sorted(discrepant_columns):
        for (i, table) in enumerate(aligned_tables):
            if col_name in table.column_names:
                endpoint = endpoints[i]
                # Get the column from the aligned table and filter to discrepant rows
                col_array = table[col_name].filter(discrepant_rows_mask)
                output_columns[f"{endpoint.name}${col_name}"] = col_array

    return pyarrow.table(output_columns)

@staticmethod
def _calculate_endpoint_column_remaps(
    endpoint_field_map: EndpointFieldMap
) -> EndpointColumnRemaps:
    """
    Calculate column name remapping from provider tags to entity fields.

    Analyzes the endpoint field map to determine how provider-specific tag names
    should be renamed to standardized entity.field format, handling both simple
    mappings and preprocessed field mappings.

    Parameters
    ----------
    endpoint_field_map
        Mapping from entity fields to provider tags per endpoint

    Returns
    -------
    EndpointColumnRemaps
        Dictionary mapping endpoints to tag-based column rename operations.
        A tag maps to a list because the same tag can feed several fields.

    Raises
    ------
    DataProviderIncorrectMappingTypeError
        When a mapping value has an invalid type
    """
    endpoint_column_remaps: EndpointColumnRemaps = {}

    for (endpoint, field_mappings) in endpoint_field_map.items():
        # Initialize the tag-to-column-remaps dict for this endpoint
        tag_column_remaps = endpoint_column_remaps.setdefault(endpoint, {})

        for (entity_field, mapping_value) in field_mappings.items():
            # Get the entity class and field name from the entity_field descriptor
            entity_class = entity_field.__objclass__
            entity_name = entity_class.__name__
            field_name = entity_field.__name__

            if isinstance(mapping_value, str):
                # It's a TagName - create one "entity.field" remap
                tag_remap = {
                    mapping_value: f"{entity_name}.{field_name}"
                }
            elif isinstance(mapping_value, PreprocessedFieldMapping):
                # It's a PreprocessedFieldMapping - create one "entity.field$tag" per tag
                tag_remap = {
                    tag_name: f"{entity_name}.{field_name}${tag_name}"
                    for tag_name in mapping_value.tags
                }
            else:
                # name the offending type so the message is actionable
                msg = (
                    f"Invalid mapping value for {endpoint}.{entity_name}.{field_name}:"
                    f" {type(mapping_value).__name__}"
                )

                raise DataProviderIncorrectMappingTypeError(msg)

            for (tag_name, column_remap) in tag_remap.items():
                tag_column_remaps.setdefault(tag_name, []).append(column_remap)

    return endpoint_column_remaps

@staticmethod
def _calculate_endpoint_field_preprocessors(
    endpoint_field_map: EndpointFieldMap
) -> EndpointFieldPreprocessors:
    """
    Extract preprocessor configurations from endpoint field map.

    Filters the endpoint field map to retain only fields that require
    preprocessing through PreprocessedFieldMapping objects.

    Parameters
    ----------
    endpoint_field_map
        Mapping from entity fields to provider tags per endpoint

    Returns
    -------
    EndpointFieldPreprocessors
        Dictionary mapping endpoints to fields requiring preprocessing.
        Every endpoint of the input is kept as a key, possibly with an empty dict.
    """
    return {
        endpoint: {
            entity_field: mapping_value
            for (entity_field, mapping_value) in field_mappings.items()
            if isinstance(mapping_value, PreprocessedFieldMapping)
        }
        for (endpoint, field_mappings) in endpoint_field_map.items()
    }

@staticmethod
def _calculate_most_specific_field_entity(
    endpoint_field_map: EndpointFieldMap
) -> EntityFieldToMostSpecificEntity:
    """
    Calculate mapping from entity fields to their most specific descendant entities.

    For each field in each entity, determines which entity in the inheritance
    hierarchy should be used for column naming. When a field is inherited,
    the most specific (deepest) descendant entity that contains the field is chosen.

    Only entities referenced by the endpoint field map are considered, and only
    those that are dataclasses contribute fields.

    Parameters
    ----------
    endpoint_field_map
        Mapping from entity fields to provider tags per endpoint

    Returns
    -------
    EntityFieldToMostSpecificEntity
        Dictionary mapping EntityField descriptors to the most specific
        entity class that should be used for that field

    Raises
    ------
    DataProviderToolkitRuntimeError
        When multiple sibling entities have the same field name
    """
    all_entities = {
        entity_field.__objclass__
        for field_mappings in endpoint_field_map.values()
        for entity_field in field_mappings
    }

    # Inheritance graph restricted to the mapped entities: edge base -> subclass
    graph = networkx.DiGraph()
    for entity in all_entities:
        graph.add_node(entity)
        for base in entity.__bases__:
            if base in all_entities:
                graph.add_edge(base, entity)

    entity_fields_map = {
        entity: {
            field.name
            for field in dataclasses.fields(entity)
        }
        for entity in all_entities
        if dataclasses.is_dataclass(entity)
    }

    field_to_most_specific_entity = {}
    for (entity, field_names) in entity_fields_map.items():
        for field_name in field_names:
            descendants_with_field = {
                desc
                for desc in (networkx.descendants(graph, entity) | {entity})
                if (
                    desc in entity_fields_map
                    and field_name in entity_fields_map[desc]
                )
            }

            # Leaves: entities with the field whose direct subclasses don't have it
            leaf_descendants = {
                desc
                for desc in descendants_with_field
                if not any(
                    child in descendants_with_field
                    for child in graph.successors(desc)
                )
            }

            if len(leaf_descendants) > 1:
                ancestors_map = {
                    leaf: set(networkx.ancestors(graph, leaf))
                    for leaf in leaf_descendants
                }
                # Check each unordered pair of distinct leaves exactly once.
                # Classes define no ordering, so pairs are enumerated by
                # position instead of comparing the classes with >=.
                leaves = list(leaf_descendants)
                for (index, leaf1) in enumerate(leaves):
                    for leaf2 in leaves[index + 1:]:
                        if (
                            ancestors_map[leaf1]
                            & ancestors_map[leaf2]
                        ):
                            msg = f"Multiple entities detected with same field name `{field_name}`"

                            raise DataProviderToolkitRuntimeError(msg)

            most_specific = (
                next(iter(leaf_descendants))
                if leaf_descendants
                else entity
            )

            # Use the actual field descriptor from the original entity as the key
            entity_field = getattr(entity, field_name)
            field_to_most_specific_entity[entity_field] = most_specific

    return field_to_most_specific_entity

@staticmethod
def _create_table_from_json_string(json_string: str) -> pyarrow.Table:
    """
    Parse JSON string into a PyArrow table.

    Converts JSON data from array or newline-delimited format into a PyArrow table,
    handling format normalization automatically.

    Parameters
    ----------
    json_string
        JSON string in array or newline-delimited format

    Returns
    -------
    pyarrow.Table
        Parsed table, or empty table if input is empty

    Raises
    ------
    DataProviderParsingError
        When JSON parsing fails due to invalid format
    """
    # PyArrow expects newline-delimited JSON, not JSON arrays
    # Convert JSON array to NDJSON with simple text transformation
    json_string_stripped = json_string.strip()
    if (
        json_string_stripped.startswith('[')
        and json_string_stripped.endswith(']')
    ):
        # Remove outer array brackets
        json_string_stripped = json_string_stripped[1:-1].strip()
        # Replace pattern of },{ (with any surrounding whitespace) with }\n{
        # NOTE(review): the substitution is purely textual, so a "},{" sequence
        # inside a string value or a nested array of objects is rewritten too
        json_string_stripped = re.sub(
            r'\}\s*,\s*\{',
            '}\n{',
            json_string_stripped
        )

    if not json_string_stripped:
        return pyarrow.table({})

    # Convert string to bytes and create a buffer
    json_bytes = json_string_stripped.encode('utf-8')
    buffer = io.BytesIO(json_bytes)

    # Read into PyArrow table
    try:
        # @todo: read as string and manually infer each column type, to prevent dates read as datetime
        table = pyarrow.json.read_json(
            buffer,
            parse_options=pyarrow.json.ParseOptions(
                newlines_in_values=True,
            ),
        )
    except pyarrow.lib.ArrowInvalid as error:
        msg = f"Error parsing JSON string: {error}"

        raise DataProviderParsingError(msg) from error

    return table

@staticmethod
def _drop_table_rows_by_primary_key(
    table: pyarrow.Table,
    drop_rows_primary_keys: PrimaryKeyTable,
) -> pyarrow.Table:
    """
    Remove rows from a table whose primary keys match the provided keys.

    Identifies rows in the table matching the provided primary keys and returns
    a new table with those rows removed.

    Parameters
    ----------
    table
        Table to drop rows from
    drop_rows_primary_keys
        Table containing primary keys of rows to drop

    Returns
    -------
    pyarrow.Table
        Table with the matched rows removed; the input table itself when
        either argument has no rows

    Raises
    ------
    DataProviderToolkitRuntimeError
        When required key columns are missing or type incompatibilities exist
    """
    key_columns = drop_rows_primary_keys.column_names

    # An endpoint that returned no rows yields a 0-column table here, so there
    # is nothing to drop and the column check below would spuriously reject it.
    if (
        table.num_rows == 0
        or drop_rows_primary_keys.num_rows == 0
    ):
        return table

    for col in key_columns:
        if col not in table.column_names:
            msg = f"DataProviderToolkit._drop_table_rows_by_primary_key error: Column '{col}' not found in table."

            raise DataProviderToolkitRuntimeError(msg)

    # Combine chunks to ensure we work with flat Arrays, avoiding 'Mask must be array' errors
    table_combined = table.combine_chunks()

    # Add a temporary row index column to track matched rows positionally
    row_index_col_name = "__temp_row_index__"
    indices_array = pyarrow.array(
        range(table_combined.num_rows)
    )
    table_with_index = table_combined.append_column(
        row_index_col_name,
        indices_array
    )

    # Inner join on the keys: the surviving row indices are the rows to drop
    try:
        matches = table_with_index.select([*key_columns, row_index_col_name]).join(
            drop_rows_primary_keys,
            keys=key_columns,
            join_type="inner"
        )
    except pyarrow.lib.ArrowInvalid as error:
        msg = f"DataProviderToolkit._drop_table_rows_by_primary_key error: {error}"

        raise DataProviderToolkitRuntimeError(msg) from error

    if matches.num_rows == 0:
        return table_combined

    rows_to_drop_indices = matches[row_index_col_name]
    all_indices = table_with_index[row_index_col_name]

    keep_mask = pyarrow.compute.invert(
        pyarrow.compute.is_in(
            all_indices,
            value_set=rows_to_drop_indices
        )
    ).combine_chunks()

    return table_combined.filter(keep_mask)

@classmethod
def _get_entity_field_to_most_specific_entity(
    cls,
    endpoint_field_map: EndpointFieldMap
) -> EntityFieldToMostSpecificEntity:
    """
    Get or calculate mapping from entity fields to their most specific descendant entities.

    Results are memoized on the class, using ``repr(endpoint_field_map)``
    as the cache key.

    Parameters
    ----------
    endpoint_field_map
        Mapping from entity fields to provider tags per endpoint

    Returns
    -------
    EntityFieldToMostSpecificEntity
        Dictionary mapping EntityField descriptors to the most specific
        entity class that should be used for that field
    """
    cache_key = repr(endpoint_field_map)

    if cache_key not in cls._entity_field_to_most_specific_entity_cache:
        cls._entity_field_to_most_specific_entity_cache[
            cache_key
        ] = cls._calculate_most_specific_field_entity(endpoint_field_map)

    return cls._entity_field_to_most_specific_entity_cache[cache_key]

@staticmethod
def _merge_primary_key_subsets_preserving_order(
    primary_key_subsets_tables: list[PrimaryKeyTable],
    *,
    predominant_order_descending: bool = False,
) -> PrimaryKeyTable:
    """
    Merge primary key subsets while preserving consistent ordering.

    Combines multiple tables containing subsets of primary keys into a single
    unified ordering using topological sorting to maintain order consistency
    across all input subsets.

    Parameters
    ----------
    primary_key_subsets_tables
        List of tables each containing a subset of primary keys
    predominant_order_descending
        Whether the predominant sort order is descending (True) or ascending (False)

    Returns
    -------
    PrimaryKeyTable
        Table containing merged primary keys in consistent order

    Raises
    ------
    DataProviderToolkitRuntimeError
        When tables have incompatible schemas or have no columns
    DataProviderMultiEndpointDuplicateKeysError
        When a single input table has multiple rows sharing the same primary key
    DataProviderMultiEndpointCommonDataOrderError
        When input orderings create circular dependencies
    """
    if not primary_key_subsets_tables:
        return pyarrow.table({})

    first_table = primary_key_subsets_tables[0]
    schema = first_table.schema
    column_names = schema.names

    if not column_names:
        msg = "Primary key merge tables have no columns."

        raise DataProviderToolkitRuntimeError(msg)

    # Each table contributes a path key_1 -> key_2 -> ... in its own row order;
    # a topological sort of the combined graph respects all of those orders
    graph = networkx.DiGraph()
    for table in primary_key_subsets_tables:
        if table.schema != schema:
            if len(table.column_names) != len(column_names):
                msg = "Primary key merge tables have different number of columns."
            elif table.column_names != column_names:
                msg = "Primary key merge tables have different column names."
            else:
                msg = "Primary key merge tables have different column types."

            raise DataProviderToolkitRuntimeError(msg)

        if table.num_rows == 0:
            continue

        # Filter out rows where all columns are null, as they are not valid keys
        all_null_mask = pyarrow.compute.is_null(
            table[column_names[0]]
        )
        for col_name in column_names[1:]:
            all_null_mask = pyarrow.compute.and_(
                all_null_mask,
                pyarrow.compute.is_null(
                    table[col_name]
                )
            )

        keep_mask = pyarrow.compute.invert(all_null_mask)
        filtered_table = table.filter(keep_mask)

        if filtered_table.num_rows == 0:
            continue

        # Check for duplicate rows in the valid key data
        group_counts = filtered_table.group_by(column_names).aggregate([
            ([], "count_all"),
        ])
        if group_counts.num_rows != filtered_table.num_rows:
            duplicate_keys_table = group_counts.filter(
                pyarrow.compute.greater(group_counts["count_all"], 1)
            ).drop(["count_all"])

            raise DataProviderMultiEndpointDuplicateKeysError(
                duplicate_keys_table=duplicate_keys_table,
                key_column_names=list(column_names),
            )

        rows_as_dicts = filtered_table.to_pylist()
        rows_as_tuples = [
            tuple(
                row[name]
                for name in column_names
            )
            for row in rows_as_dicts
        ]

        networkx.add_path(graph, rows_as_tuples)

    if not graph.nodes:
        return pyarrow.Table.from_pylist([], schema=schema)

    # Null PK components can't be compared directly against real values (e.g.
    # None < date raises TypeError). When any PK is null, wrap each component
    # as (0, None) or (1, value) so tuple comparison short-circuits on the
    # flag before touching the payload. Skip the wrapping when there are no
    # nulls so the common case pays zero overhead.
    has_null_key = any(
        component is None
        for node in graph.nodes
        for component in node
    )
    sort_key = (
        (lambda node: tuple((0, None) if c is None else (1, c) for c in node))
        if has_null_key
        else None
    )

    try:
        if predominant_order_descending:
            # For descending order, we topologically sort the reversed graph
            # and then reverse the result. This correctly handles tie-breaking.
            sorted_rows = list(
                networkx.lexicographical_topological_sort(
                    graph.reverse(copy=True),
                    key=sort_key,
                )
            )
            sorted_rows.reverse()
        else:
            sorted_rows = list(
                networkx.lexicographical_topological_sort(graph, key=sort_key)
            )
    except networkx.NetworkXUnfeasible:
        msg = "Inconsistent key order between tables results in a circular dependency."

        raise DataProviderMultiEndpointCommonDataOrderError(msg) from None

    if not sorted_rows:
        return pyarrow.Table.from_pylist([], schema=schema)

    # Transpose the row tuples back into one sequence per key column
    columns_as_tuples = list(
        zip(*sorted_rows, strict=True)
    )
    arrays = [
        pyarrow.array(col_data, type=field.type)
        for (col_data, field) in zip(
            columns_as_tuples,
            schema,
            strict=True
        )
    ]

    return pyarrow.Table.from_arrays(arrays, names=column_names)

@staticmethod
def _process_remapped_endpoint_tables(
    endpoint_field_preprocessors: EndpointFieldPreprocessors,
    remapped_endpoint_tables: EndpointTables,
    entity_field_to_most_specific_entity: EntityFieldToMostSpecificEntity,
) -> EndpointTables:
    """
    Apply preprocessor functions to remapped endpoint tables.

    Executes configured preprocessor chains on input columns to compute
    derived field values, replacing raw input columns with processed outputs.

    Parameters
    ----------
    endpoint_field_preprocessors
        Dictionary mapping endpoints to field preprocessing configurations
    remapped_endpoint_tables
        Dictionary mapping endpoints to tables with remapped columns
    entity_field_to_most_specific_entity
        Dictionary mapping entity fields to their most specific descendant entities

    Returns
    -------
    EndpointTables
        Dictionary mapping endpoints to tables with preprocessed fields
    """
    processed_tables: EndpointTables = {}

    for (endpoint, table) in remapped_endpoint_tables.items():
        if table.num_rows == 0:
            # Empty endpoint response: no values to preprocess, and the
            # schemaless table has none of the tag input columns. Pass
            # through so downstream consolidation can drop it cleanly.
            processed_tables[endpoint] = table

            continue

        # Get preprocessors for this endpoint (if any)
        field_preprocessors = endpoint_field_preprocessors.get(endpoint, {})

        if not field_preprocessors:
            # No preprocessors for this endpoint, keep table as-is
            processed_tables[endpoint] = table

            continue

        # Track which columns are inputs to preprocessors (will be removed)
        columns_to_remove = set()
        # Track new processed columns to add
        new_columns = {}

        # Process each field that has preprocessors
        for (entity_field, preprocessed_mapping) in field_preprocessors.items():
            entity_name = entity_field_to_most_specific_entity[entity_field].__name__
            field_name = entity_field.__name__

            # Build input column names from tags: "entity.field$tag"
            input_column_names = [
                f"{entity_name}.{field_name}${tag}"
                for tag in preprocessed_mapping.tags
            ]

            # Mark input columns for removal
            columns_to_remove.update(input_column_names)

            # Load input columns and wrap in DataColumn.load()
            input_columns = [
                DataColumn.load(table[col_name])
                for col_name in input_column_names
            ]

            # Chain preprocessors: the first receives all input columns as
            # positional arguments, each later one receives the previous output
            result = input_columns
            for preprocessor in preprocessed_mapping.preprocessors:
                # Apply preprocessor with current result(s) as positional arguments
                result = (
                    preprocessor(*result)
                    if isinstance(result, list)
                    else preprocessor(result)
                )
                # Wrap output in DataColumn.load() for next preprocessor
                result = DataColumn.load(result)

            # Get final pyarrow.Array
            final_column = result.to_pyarrow()

            # Store with name "entity.field" (without $tag suffix)
            output_column_name = f"{entity_name}.{field_name}"
            new_columns[output_column_name] = final_column

        # Build the new table: keep non-processed columns + add processed columns
        result_columns_dict = {}

        # Add columns that weren't processed
        for col_name in table.column_names:
            if col_name not in columns_to_remove:
                result_columns_dict[col_name] = table[col_name]

        # Add newly processed columns
        result_columns_dict.update(new_columns)

        processed_tables[endpoint] = pyarrow.table(result_columns_dict)

    return processed_tables

@staticmethod
def _remap_endpoint_table_columns(
    endpoint_column_remaps: EndpointColumnRemaps,
    endpoint_tables: EndpointTables,
    entity_field_to_most_specific_entity: EntityFieldToMostSpecificEntity,
) -> EndpointTables:
    """
    Rename table columns from provider tags to entity field format.

    Transforms column names in endpoint tables according to the provided
    remapping configuration, duplicating columns when needed for
    preprocessor inputs. Columns without a configured remap are dropped;
    tables of endpoints without any remap configuration are passed through.

    Parameters
    ----------
    endpoint_column_remaps
        Dictionary mapping endpoints to tags to column renames
    endpoint_tables
        Dictionary mapping endpoints to raw tables
    entity_field_to_most_specific_entity
        Dictionary mapping entity fields to their most specific descendant entities

    Returns
    -------
    EndpointTables
        Dictionary mapping endpoints to tables with remapped column names
    """
    # Build reverse lookup: (entity_name, field_name) -> most_specific_entity_name
    entity_field_lookup = {
        (entity_field.__objclass__.__name__, entity_field.__name__): most_specific_entity.__name__
        for (entity_field, most_specific_entity) in entity_field_to_most_specific_entity.items()
    }

    remapped_tables: EndpointTables = {}

    for (endpoint, table) in endpoint_tables.items():
        if endpoint not in endpoint_column_remaps:
            remapped_tables[endpoint] = table

            continue

        column_remaps = endpoint_column_remaps[endpoint]
        new_columns_dict = {}

        for tag_name in table.column_names:
            if tag_name not in column_remaps:
                # Column not in remaps, skip it
                continue

            original_column = table[tag_name]

            for old_remap_name in column_remaps[tag_name]:
                # Parse: "entity.field" or "entity.field$tag"
                (base_name, _, suffix) = old_remap_name.partition('$')
                (entity_name, field_name) = base_name.split('.', 1)

                # Look up most specific entity
                most_specific_entity_name = entity_field_lookup.get(
                    (entity_name, field_name),
                    entity_name  # Fallback to original if not found
                )

                # Build new column name
                new_column_name = f"{most_specific_entity_name}.{field_name}"
                if suffix:
                    new_column_name += f"${suffix}"

                new_columns_dict[new_column_name] = original_column

        remapped_tables[endpoint] = pyarrow.table(new_columns_dict)

    return remapped_tables