要保存预处理的 Tensorflow Transform 函数,可以使用 tft_beam.Context 对象的 export 方法来导出函数。下面是一个完整的代码示例:
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
import tensorflow_transform.tf_metadata as metadata
import apache_beam as beam
# 定义预处理函数
def preprocessing_fn(inputs):
# 执行预处理操作
outputs = {}
outputs['output_feature'] = tf.math.log(inputs['input_feature'])
return outputs
# 定义输入元数据
raw_metadata = metadata.Schema({'input_feature': metadata.TensorRepresentation(tf.float32, [None])})
raw_data = [{'input_feature': 1.0}, {'input_feature': 2.0}, {'input_feature': 3.0}]
raw_metadata.dataset_metadata = metadata.DatasetMetadata(metadata.Schema({}))
# 创建 Tensorflow Transform 预处理管道
with beam.Pipeline() as pipeline:
with tft_beam.Context(temp_dir='tmp_dir'):
# 创建输入 PCollection
raw_data_pcoll = pipeline | beam.Create(raw_data)
# 使用原始元数据定义并运行预处理函数
transformed_dataset, transform_fn = (
(raw_data_pcoll, raw_metadata) |
tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
# 导出预处理函数
transformed_data, transformed_metadata = transformed_dataset
tft_beam_io.WriteTransformFn(transform_fn, 'transform_fn')
# 加载导出的预处理函数
transform_fn = tft_beam_io.transform_fn_io.load_transform_fn('transform_fn')
# 使用导出的预处理函数进行转换
transformed_data = transform_fn.transform_raw_data(raw_data)
在上面的示例中,我们首先定义了一个简单的预处理函数 preprocessing_fn,该函数将输入特征取对数。然后,我们定义了输入元数据 raw_metadata,并创建了原始数据 raw_data。接下来,我们使用 tft_beam.Context 对象创建了预处理管道,并使用 tft_beam.AnalyzeAndTransformDataset 运行预处理函数。最后,我们使用 tft_beam_io.WriteTransformFn 导出预处理函数到指定路径,并使用 tft_beam_io.transform_fn_io.load_transform_fn 加载导出的预处理函数。最后,我们使用加载的预处理函数对原始数据进行转换。