com.dataartisans.flinktraining.exercises.datastream_java.state.TravelTimePrediction Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of flink-training-exercises Show documentation
Show all versions of flink-training-exercises Show documentation
Utilities and material for an Apache Flink Training provided by data Artisans.
The newest version!
/*
* Copyright 2015 data Artisans GmbH
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.dataartisans.flinktraining.exercises.datastream_java.state;
import com.dataartisans.flinktraining.exercises.datastream_java.datatypes.TaxiRide;
import com.dataartisans.flinktraining.exercises.datastream_java.sources.CheckpointedTaxiRideSource;
import com.dataartisans.flinktraining.exercises.datastream_java.utils.GeoUtils;
import com.dataartisans.flinktraining.exercises.datastream_java.utils.TravelTimePredictionModel;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.TimeCharacteristic;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.util.Collector;
import java.util.concurrent.TimeUnit;
/**
* Java reference implementation for the "Travel Time Prediction" exercise of the Flink training
* (http://dataartisans.github.io/flink-training).
*
* The task of the exercise is to continuously train a regression model that predicts
* the travel time of a taxi based on the information of taxi ride end events.
* For taxi ride start events, the model should be queried to estimate its travel time.
*
* Parameters:
* -input path-to-input-file
*
*/
public class TravelTimePrediction {
/**
 * Builds and runs the travel time prediction pipeline.
 *
 * The pipeline reads taxi ride events from a checkpointed source, keeps only rides
 * that start and end in NYC, keys them by the grid cell of their destination and
 * feeds them into a per-destination {@link PredictionModel}. The resulting
 * predictions are printed.
 *
 * @param args command line arguments; the "input" parameter (path to the input file) is required
 * @throws Exception if building or executing the job fails
 */
public static void main(String[] args) throws Exception {

    ParameterTool params = ParameterTool.fromArgs(args);
    final String input = params.getRequired("input");

    final int servingSpeedFactor = 600; // events of 10 minutes are served in 1 second

    // set up streaming execution environment
    StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    // operate in Event-time
    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime);

    // create a checkpoint every 5 seconds
    env.enableCheckpointing(5000);
    // try to restart 60 times with 10 seconds delay (10 Minutes)
    env.setRestartStrategy(RestartStrategies.fixedDelayRestart(60, Time.of(10, TimeUnit.SECONDS)));

    // start the data generator
    DataStream<TaxiRide> rides = env.addSource(
            new CheckpointedTaxiRideSource(input, servingSpeedFactor));

    // NOTE(review): the first tuple field is the ride id, assumed to be a long -- confirm against TaxiRide
    DataStream<Tuple2<Long, Integer>> predictions = rides
            // keep only rides that both start and end in NYC
            .filter(new NYCFilter())
            // map taxi ride events to the grid cell of the destination
            .map(new GridCellMatcher())
            // organize stream by destination (tuple field 0 is the grid cell id)
            .keyBy(0)
            // predict and refine model per destination
            .flatMap(new PredictionModel());

    // print the predictions
    predictions.print();

    // run the prediction pipeline
    env.execute("Taxi Ride Prediction");
}
/**
 * Keeps only taxi rides whose start and end locations are both inside NYC,
 * as decided by {@code GeoUtils.isInNYC}.
 */
public static class NYCFilter implements FilterFunction<TaxiRide> {

    /**
     * @param taxiRide the ride event to test
     * @return true if both the start and the end coordinates are in NYC
     */
    @Override
    public boolean filter(TaxiRide taxiRide) throws Exception {
        return GeoUtils.isInNYC(taxiRide.startLon, taxiRide.startLat) &&
                GeoUtils.isInNYC(taxiRide.endLon, taxiRide.endLat);
    }
}
/**
* Maps the taxi ride event to the grid cell of the destination location.
*/
public static class GridCellMatcher implements MapFunction> {
@Override
public Tuple2 map(TaxiRide ride) throws Exception {
int endCell = GeoUtils.mapToGridCell(ride.endLon, ride.endLat);
return new Tuple2<>(endCell, ride);
}
}
/**
* Predicts the travel time for taxi ride start events based on distance and direction.
* Incrementally trains a regression model using taxi ride end events.
*/
public static class PredictionModel extends RichFlatMapFunction, Tuple2> {
private transient ValueState modelState;
@Override
public void flatMap(Tuple2 val, Collector> out) throws Exception {
// fetch operator state
TravelTimePredictionModel model = modelState.value();
TaxiRide ride = val.f1;
// compute distance and direction
double distance = GeoUtils.getEuclideanDistance(ride.startLon, ride.startLat, ride.endLon, ride.endLat);
int direction = GeoUtils.getDirectionAngle(ride.endLon, ride.endLat, ride.startLon, ride.startLat);
if(ride.isStart) {
// we have a start event: Predict travel time
int predictedTime = model.predictTravelTime(direction, distance);
// emit prediction
out.collect(new Tuple2<>(ride.rideId, predictedTime));
} else {
// we have an end event: Update model
// compute travel time in minutes
double travelTime = (ride.endTime.getMillis() - ride.startTime.getMillis()) / 60000.0;
// refine model
model.refineModel(direction, distance, travelTime);
// update operator state
modelState.update(model);
}
}
@Override
public void open(Configuration config) {
// obtain key-value state for prediction model
ValueStateDescriptor descriptor =
new ValueStateDescriptor<>(
// state name
"regressionModel",
// type information of state
TypeInformation.of(new TypeHint() {}),
// default value of state
new TravelTimePredictionModel());
modelState = getRuntimeContext().getState(descriptor);
}
}
}