All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.flink.graph.example.MusicProfiles Maven / Gradle / Ivy

There is a newer version: 1.16.3
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.graph.example;

import java.util.ArrayList;
import java.util.List;

import org.apache.flink.api.common.ProgramDescription;
import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.EdgeDirection;
import org.apache.flink.graph.EdgesFunctionWithVertexValue;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.VertexJoinFunction;
import org.apache.flink.graph.example.utils.MusicProfilesData;
import org.apache.flink.graph.library.LabelPropagation;
import org.apache.flink.types.NullValue;
import org.apache.flink.util.Collector;

/**
 * This example demonstrates how to mix the DataSet Flink API with the Gelly API.
 * The input is a set <userId - songId - playCount> triplets and
 * a set of bad records, i.e. song ids that should not be trusted.
 * Initially, we use the DataSet API to filter out the bad records.
 * Then, we use Gelly to create a user -> song weighted bipartite graph and compute
 * the top song (most listened) per user.
 * Then, we use the DataSet API again, to create a user-user similarity graph,
 * based on common songs, where users that are listeners of the same song
 * are connected. A user-defined threshold on the playcount value
 * defines when a user is considered to be a listener of a song.
 * Finally, we use the graph API to run the label propagation community detection algorithm on
 * the similarity graph.
 *
 * The triplets input is expected to be given as one triplet per line,
 * in the following format: "<userID>\t<songID>\t<playcount>".
 *
 * The mismatches input file is expected to contain one mismatch record per line,
 * in the following format:
 * "ERROR: <songID trackID> song_title"
 *
 * If no arguments are provided, the example runs with default data from {@link MusicProfilesData}.
 */
@SuppressWarnings("serial")
public class MusicProfiles implements ProgramDescription {

	public static void main(String[] args) throws Exception {

		if (!parseParameters(args)) {
			return;
		}

		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

		/**
		 * Read the user-song-play triplets.
		 */
		DataSet> triplets = getUserSongTripletsData(env);

		/**
		 * Read the mismatches dataset and extract the songIDs
		 */
		DataSet> mismatches = getMismatchesData(env).map(new ExtractMismatchSongIds());

		/**
		 * Filter out the mismatches from the triplets dataset
		 */
		DataSet> validTriplets = triplets
				.coGroup(mismatches).where(1).equalTo(0)
				.with(new FilterOutMismatches());

		/**
		 * Create a user -> song weighted bipartite graph where the edge weights
		 * correspond to play counts
		 */
		Graph userSongGraph = Graph.fromTupleDataSet(validTriplets, env);

		/**
		 * Get the top track (most listened) for each user
		 */
		DataSet> usersWithTopTrack = userSongGraph
				.groupReduceOnEdges(new GetTopSongPerUser(), EdgeDirection.OUT)
				.filter(new FilterSongNodes());

		if (fileOutput) {
			usersWithTopTrack.writeAsCsv(topTracksOutputPath, "\n", "\t");
		} else {
			usersWithTopTrack.print();
		}

		/**
		 * Create a user-user similarity graph, based on common songs, i.e. two
		 * users that listen to the same song are connected. For each song, we
		 * create an edge between each pair of its in-neighbors.
		 */
		DataSet> similarUsers = userSongGraph
				.getEdges()
				// filter out user-song edges that are below the playcount threshold
				.filter(new FilterFunction>() {
					public boolean filter(Edge edge) {
						return (edge.getValue() > playcountThreshold);
					}
				}).groupBy(1)
				.reduceGroup(new CreateSimilarUserEdges()).distinct();

		Graph similarUsersGraph = Graph.fromDataSet(similarUsers,
				new MapFunction() {
					public Long map(String value) {
						return 1l;
					}
				}, env).getUndirected();

		/**
		 * Detect user communities using the label propagation library method
		 */
		// Initialize each vertex with a unique numeric label and run the label propagation algorithm
		DataSet> idsWithInitialLabels = DataSetUtils
				.zipWithUniqueId(similarUsersGraph.getVertexIds())
				.map(new MapFunction, Tuple2>() {
					@Override
					public Tuple2 map(Tuple2 tuple2) throws Exception {
						return new Tuple2(tuple2.f1, tuple2.f0);
					}
				});

		DataSet> verticesWithCommunity = similarUsersGraph
				.joinWithVertices(idsWithInitialLabels,
						new VertexJoinFunction() {
							public Long vertexJoin(Long vertexValue, Long inputValue) {
								return inputValue;
							}
						}).run(new LabelPropagation(maxIterations));

		if (fileOutput) {
			verticesWithCommunity.writeAsCsv(communitiesOutputPath, "\n", "\t");

			// since file sinks are lazy, we trigger the execution explicitly
			env.execute();
		} else {
			verticesWithCommunity.print();
		}

	}

	public static final class ExtractMismatchSongIds implements MapFunction> {

		public Tuple1 map(String value) {
			String[] tokens = value.split("\\s+");
			String songId = tokens[1].substring(1);
			return new Tuple1(songId);
		}
	}

	public static final class FilterOutMismatches implements CoGroupFunction,
		Tuple1, Tuple3> {

		public void coGroup(Iterable> triplets,
				Iterable> invalidSongs, Collector> out) {

			if (!invalidSongs.iterator().hasNext()) {
				// this is a valid triplet
				for (Tuple3 triplet : triplets) {
					out.collect(triplet);
				}
			}
		}
	}

	public static final class FilterSongNodes implements FilterFunction> {
		public boolean filter(Tuple2 value) throws Exception {
			return !value.f1.equals("");
		}
	}

	public static final class GetTopSongPerUser	implements EdgesFunctionWithVertexValue> {

		public void iterateEdges(Vertex vertex,
				Iterable> edges, Collector> out) throws Exception {

			int maxPlaycount = 0;
			String topSong = "";
			for (Edge edge : edges) {
				if (edge.getValue() > maxPlaycount) {
					maxPlaycount = edge.getValue();
					topSong = edge.getTarget();
				}
			}
			out.collect(new Tuple2(vertex.getId(), topSong));
		}
	}

	public static final class CreateSimilarUserEdges implements GroupReduceFunction,
		Edge> {

		public void reduce(Iterable> edges, Collector> out) {
			List listeners = new ArrayList();
			for (Edge edge : edges) {
				listeners.add(edge.getSource());
			}
			for (int i = 0; i < listeners.size() - 1; i++) {
				for (int j = i + 1; j < listeners.size(); j++) {
					out.collect(new Edge(listeners.get(i),
							listeners.get(j), NullValue.getInstance()));
				}
			}
		}
	}

	@Override
	public String getDescription() {
		return "Music Profiles Example";
	}

	// ******************************************************************************************************************
	// UTIL METHODS
	// ******************************************************************************************************************

	private static boolean fileOutput = false;

	private static String userSongTripletsInputPath = null;

	private static String mismatchesInputPath = null;

	private static String topTracksOutputPath = null;

	private static int playcountThreshold = 0;

	private static String communitiesOutputPath = null;

	private static int maxIterations = 10;

	private static boolean parseParameters(String[] args) {

		if(args.length > 0) {
			if(args.length != 6) {
				System.err.println("Usage: MusicProfiles " +
						"   "
						+ "  ");
				return false;
			}

			fileOutput = true;
			userSongTripletsInputPath = args[0];
			mismatchesInputPath = args[1];
			topTracksOutputPath = args[2];
			playcountThreshold = Integer.parseInt(args[3]);
			communitiesOutputPath = args[4];
			maxIterations = Integer.parseInt(args[5]);
		} else {
			System.out.println("Executing Music Profiles example with default parameters and built-in default data.");
			System.out.println("  Provide parameters to read input data from files.");
			System.out.println("  See the documentation for the correct format of input files.");
			System.out.println("Usage: MusicProfiles " +
					"   "
					+ "  ");
		}
		return true;
	}

	private static DataSet> getUserSongTripletsData(ExecutionEnvironment env) {
		if (fileOutput) {
			return env.readCsvFile(userSongTripletsInputPath)
					.lineDelimiter("\n").fieldDelimiter("\t")
					.types(String.class, String.class, Integer.class);
		} else {
			return MusicProfilesData.getUserSongTriplets(env);
		}
	}

	private static DataSet getMismatchesData(ExecutionEnvironment env) {
		if (fileOutput) {
			return env.readTextFile(mismatchesInputPath);
		} else {
			return MusicProfilesData.getMismatches(env);
		}
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy