All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.dataartisans.flinktraining.exercises.datastream_java.state.TravelTimePrediction Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2015 data Artisans GmbH
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *  http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.dataartisans.flinktraining.exercises.datastream_java.state;

import com.dataartisans.flinktraining.exercises.datastream_java.datatypes.TaxiRide;
import com.dataartisans.flinktraining.exercises.datastream_java.sources.CheckpointedTaxiRideSource;
import com.dataartisans.flinktraining.exercises.datastream_java.utils.GeoUtils;
import com.dataartisans.flinktraining.exercises.datastream_java.utils.TravelTimePredictionModel;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.TimeCharacteristic;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.util.Collector;

import java.util.concurrent.TimeUnit;

/**
 * Java reference implementation for the "Travel Time Prediction" exercise of the Flink training
 * (http://dataartisans.github.io/flink-training).
 *
 * The task of the exercise is to continuously train a regression model that predicts
 * the travel time of a taxi based on the information of taxi ride end events.
 * For taxi ride start events, the model should be queried to estimate its travel time.
 *
 * Parameters:
 * -input path-to-input-file
 *
 */
public class TravelTimePrediction {

	public static void main(String[] args) throws Exception {

		ParameterTool params = ParameterTool.fromArgs(args);
		final String input = params.getRequired("input");

		final int servingSpeedFactor = 600; // events of 10 minutes are served in 1 second

		// set up streaming execution environment
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		// operate in Event-time
		env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime);
		// create a checkpoint every 5 seconds
		env.enableCheckpointing(5000);
		// try to restart 60 times with 10 seconds delay (10 Minutes)
		env.setRestartStrategy(RestartStrategies.fixedDelayRestart(60, Time.of(10, TimeUnit.SECONDS)));

		// start the data generator
		DataStream rides = env.addSource(
				new CheckpointedTaxiRideSource(input, servingSpeedFactor));

		DataStream> predictions = rides
			// filter out rides that do not start or stop in NYC
			.filter(new NYCFilter())
			// map taxi ride events to the grid cell of the destination
			.map(new GridCellMatcher())
			// organize stream by destination
			.keyBy(0)
			// predict and refine model per destination
			.flatMap(new PredictionModel());

		// print the predictions
		predictions.print();

		// run the prediction pipeline
		env.execute("Taxi Ride Prediction");
	}

	public static class NYCFilter implements FilterFunction {

		@Override
		public boolean filter(TaxiRide taxiRide) throws Exception {

			return GeoUtils.isInNYC(taxiRide.startLon, taxiRide.startLat) &&
					GeoUtils.isInNYC(taxiRide.endLon, taxiRide.endLat);
		}
	}

	/**
	 * Maps the taxi ride event to the grid cell of the destination location.
	 */
	public static class GridCellMatcher implements MapFunction> {

		@Override
		public Tuple2 map(TaxiRide ride) throws Exception {
			int endCell = GeoUtils.mapToGridCell(ride.endLon, ride.endLat);

			return new Tuple2<>(endCell, ride);
		}
	}

	/**
	 * Predicts the travel time for taxi ride start events based on distance and direction.
	 * Incrementally trains a regression model using taxi ride end events.
	 */
	public static class PredictionModel extends RichFlatMapFunction, Tuple2> {

		private transient ValueState modelState;

		@Override
		public void flatMap(Tuple2 val, Collector> out) throws Exception {

			// fetch operator state
			TravelTimePredictionModel model = modelState.value();

			TaxiRide ride = val.f1;
			// compute distance and direction
			double distance = GeoUtils.getEuclideanDistance(ride.startLon, ride.startLat, ride.endLon, ride.endLat);
			int direction = GeoUtils.getDirectionAngle(ride.endLon, ride.endLat, ride.startLon, ride.startLat);

			if(ride.isStart) {
				// we have a start event: Predict travel time
				int predictedTime = model.predictTravelTime(direction, distance);
				// emit prediction
				out.collect(new Tuple2<>(ride.rideId, predictedTime));
			} else {
				// we have an end event: Update model
				// compute travel time in minutes
				double travelTime = (ride.endTime.getMillis() - ride.startTime.getMillis()) / 60000.0;
				// refine model
				model.refineModel(direction, distance, travelTime);
				// update operator state
				modelState.update(model);
			}
		}

		@Override
		public void open(Configuration config) {
			// obtain key-value state for prediction model
			ValueStateDescriptor descriptor =
					new ValueStateDescriptor<>(
							// state name
							"regressionModel",
							// type information of state
							TypeInformation.of(new TypeHint() {}),
							// default value of state
							new TravelTimePredictionModel());
			modelState = getRuntimeContext().getState(descriptor);
		}
	}

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy