# problem_3.py

# Problem: calling .where() after .map() with a pandas-type UDF
# resets the column names (the next map sees 'f0', 'f1' instead of
# 'id', 'country') and does not actually filter the rows.

import pandas as pd

# Table environment created in streaming mode; the input table below is
# built from it.
# NOTE(review): TableEnvironment / EnvironmentSettings (and udf / col used
# further down) are not imported in this snippet — presumably they come
# from pyflink.table; confirm and add the imports before running.
t_env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode())

# Three-row input table with columns `id` and `country`.
# Only the row with id 2 has country 'Germany'.
table = t_env.from_elements(
    elements=[
        (1, 'China'),
        (2, 'Germany'),
        (3, 'China'),
    ],
    schema=['id', 'country'],
)

@udf(
    result_type='Row<id INT, country STRING>',
    func_type="pandas",
)
def example_map_a(df: pd.DataFrame):
    """Pass ``df`` through unchanged after checking its column names.

    Prints the sorted column names of ``df`` and asserts that they are
    exactly ``['country', 'id']``.

    Raises:
        AssertionError: if the sorted column names differ; the offending
            list is attached as the assertion message.
    """
    expected = ['country', 'id']

    # Sort a copy of the column labels so the check is order-independent.
    columns = list(df.columns)
    columns.sort()

    print(f'example_map_a: {columns=}')
    # Reported output:
    # example_map_a: columns=['country', 'id']
    assert columns == expected, columns
    return df


@udf(
    result_type='Row<id INT, country STRING>',
    func_type="pandas",
)
def example_map_b(df: pd.DataFrame):
    """Second map step: dump what arrives after ``.where()`` and check it.

    Prints the sorted column names and the whole frame, then asserts that
    the column names are ``['country', 'id']`` before returning ``df``
    unchanged. This is the step at which the problem is reported to show
    up (see the observed output noted inline).

    Raises:
        AssertionError: if the sorted column names differ; the offending
            list is attached as the assertion message.
    """
    expected = ['country', 'id']

    # Sort a copy of the column labels so the check is order-independent.
    columns = list(df.columns)
    columns.sort()

    print(f'example_map_b: {columns=}')
    # Reported output:
    # example_map_b: columns=['f0', 'f1']

    print(f'example_map_b df: {df=}')
    # Reported output:
    # df=   f0       f1
    # 0   1    China
    # 1   2  Germany
    # 2   3    China
    # Although China was expected to be filtered out.

    # Reported to fail here with:
    # AssertionError: ['f0', 'f1']
    assert columns == expected, columns
    return df

# Pipeline under test: map (pandas UDF) -> where -> map (pandas UDF).
# Reported outcome: the second map fails with
# AssertionError: ['f0', 'f1']
# i.e. after .where() the column names arrive as 'f0', 'f1', and the
# 'China' rows that the filter was expected to drop are still present.

flow = (
    table
    .map(example_map_a)
    # Expected to keep only the row whose country is 'Germany'.
    .where(col('country') == 'Germany')
    .map(example_map_b)
    # NOTE(review): `flow` is bound to the return value of .print(),
    # presumably None — confirm if the name is used anywhere.
    .execute().print()
)

# Reply via email to