# problem_3.py
# Calling .where() after .map() with a pandas-type UDF resets the column
# names (the next map stage sees 'f0', 'f1') and does not actually filter
# the rows.
import pandas as pd
# Streaming-mode table environment shared by the whole reproduction.
# NOTE(review): TableEnvironment and EnvironmentSettings are not imported in
# the visible part of this file; presumably they come from pyflink.table —
# confirm.
t_env = TableEnvironment.create(
    EnvironmentSettings.in_streaming_mode(),
)

# Three-row input table with columns 'id' and 'country'; only the row with
# id 2 has country 'Germany'.
table = t_env.from_elements(
    schema=['id', 'country'],
    elements=[(1, 'China'), (2, 'Germany'), (3, 'China')],
)
@udf(result_type='Row<id INT, country STRING>', func_type="pandas")
def example_map_a(df: pd.DataFrame):
    """First map stage: check the incoming column names, then pass through.

    Prints the sorted column names of ``df``, asserts that they are exactly
    ``['country', 'id']`` and returns ``df`` unchanged.

    Raises:
        AssertionError: if the sorted column names differ; the offending
            list is the assertion message.
    """
    columns = list(df.columns)
    columns.sort()
    print(f'example_map_a: {columns=}')
    # Observed output at this stage:
    #   example_map_a: columns=['country', 'id']
    expected = ['country', 'id']
    assert columns == expected, columns
    return df
@udf(result_type='Row<id INT, country STRING>', func_type="pandas")
def example_map_b(df: pd.DataFrame):
    """Second map stage, applied after the .where() filter.

    Prints the sorted column names and the frame itself, asserts that the
    sorted column names are exactly ``['country', 'id']`` and returns ``df``
    unchanged. This is the stage where the reported problem shows up.

    Raises:
        AssertionError: if the sorted column names differ; the offending
            list is the assertion message.
    """
    columns = list(df.columns)
    columns.sort()
    print(f'example_map_b: {columns=}')
    # Observed output (column names have been reset):
    #   example_map_b: columns=['f0', 'f1']
    print(f'example_map_b df: {df=}')
    # Observed output:
    #   df=   f0       f1
    #   0   1    China
    #   1   2  Germany
    #   2   3    China
    # The 'China' rows were expected to be filtered out by the preceding
    # .where(), yet all three rows arrive here.
    expected = ['country', 'id']
    # Observed failure: AssertionError: ['f0', 'f1']
    assert columns == expected, columns
    return df
# Pipeline under test: map -> filter on country -> map.
# NOTE(review): `col` is not imported in the visible part of this file;
# presumably it comes from pyflink.table.expressions — confirm.
flow = (
    table
    .map(example_map_a)
    .where(col('country') == 'Germany')
    .map(example_map_b)
)

# Running the job is expected to fail inside example_map_b with
#   AssertionError: ['f0', 'f1']
# execute().print() is kept out of the assignment above so that `flow` names
# the pipeline table rather than the return value of print() (None).
flow.execute().print()