org.apache.flink.graph.example.MusicProfiles Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.graph.example;
import java.util.ArrayList;
import java.util.List;
import org.apache.flink.api.common.ProgramDescription;
import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.EdgeDirection;
import org.apache.flink.graph.EdgesFunctionWithVertexValue;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.example.utils.MusicProfilesData;
import org.apache.flink.graph.library.LabelPropagationAlgorithm;
import org.apache.flink.types.NullValue;
import org.apache.flink.util.Collector;
@SuppressWarnings("serial")
public class MusicProfiles implements ProgramDescription {
/**
* This example demonstrates how to mix the "record" Flink API with the
* graph API. The input is a set triplets and
* a set of bad records,i.e. song ids that should not be trusted. Initially,
* we use the record API to filter out the bad records. Then, we use the
* graph API to create a user -> song weighted bipartite graph and compute
* the top song (most listened) per user. Then, we use the record API again,
* to create a user-user similarity graph, based on common songs, where two
* users that listen to the same song are connected. Finally, we use the
* graph API to run the label propagation community detection algorithm on
* the similarity graph.
*
* The triplets input is expected to be given as one triplet per line,
* in the following format: "\t\t".
*
* The mismatches input file is expected to contain one mismatch record per line,
* in the following format:
* "ERROR: song_title"
*
* If no arguments are provided, the example runs with default data from {@link MusicProfilesData}.
*/
public static void main(String[] args) throws Exception {
if (!parseParameters(args)) {
return;
}
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
/**
* Read the user-song-play triplets.
*/
DataSet> triplets = getUserSongTripletsData(env);
/**
* Read the mismatches dataset and extract the songIDs
*/
DataSet> mismatches = getMismatchesData(env).map(new ExtractMismatchSongIds());
/**
* Filter out the mismatches from the triplets dataset
*/
DataSet> validTriplets = triplets
.coGroup(mismatches).where(1).equalTo(0)
.with(new FilterOutMismatches());
/**
* Create a user -> song weighted bipartite graph where the edge weights
* correspond to play counts
*/
Graph userSongGraph = Graph.fromTupleDataSet(validTriplets, env);
/**
* Get the top track (most listened) for each user
*/
DataSet> usersWithTopTrack = userSongGraph
.groupReduceOnEdges(new GetTopSongPerUser(), EdgeDirection.OUT)
.filter(new FilterSongNodes());
if (fileOutput) {
usersWithTopTrack.writeAsCsv(topTracksOutputPath, "\n", "\t");
} else {
usersWithTopTrack.print();
}
/**
* Create a user-user similarity graph, based on common songs, i.e. two
* users that listen to the same song are connected. For each song, we
* create an edge between each pair of its in-neighbors.
*/
DataSet> similarUsers = userSongGraph
.getEdges().groupBy(1)
.reduceGroup(new CreateSimilarUserEdges()).distinct();
Graph similarUsersGraph = Graph.fromDataSet(similarUsers,
new MapFunction() {
public Long map(String value) {
return 1l;
}
}, env).getUndirected();
/**
* Detect user communities using the label propagation library method
*/
// Initialize each vertex with a unique numeric label
DataSet> idsWithInitialLabels = similarUsersGraph
.getVertices().reduceGroup(new AssignInitialLabelReducer());
// update the vertex values and run the label propagation algorithm
DataSet> verticesWithCommunity = similarUsersGraph
.joinWithVertices(idsWithInitialLabels,
new MapFunction, Long>() {
public Long map(Tuple2 value) {
return value.f1;
}
}).run(new LabelPropagationAlgorithm(maxIterations))
.getVertices();
if (fileOutput) {
verticesWithCommunity.writeAsCsv(communitiesOutputPath, "\n", "\t");
// since file sinks are lazy, we trigger the execution explicitly
env.execute();
} else {
verticesWithCommunity.print();
}
}
public static final class ExtractMismatchSongIds implements MapFunction> {
public Tuple1 map(String value) {
String[] tokens = value.split("\\s+");
String songId = tokens[1].substring(1);
return new Tuple1(songId);
}
}
public static final class FilterOutMismatches implements CoGroupFunction,
Tuple1, Tuple3> {
public void coGroup(Iterable> triplets,
Iterable> invalidSongs, Collector> out) {
if (!invalidSongs.iterator().hasNext()) {
// this is a valid triplet
for (Tuple3 triplet : triplets) {
out.collect(triplet);
}
}
}
}
public static final class FilterSongNodes implements FilterFunction> {
public boolean filter(Tuple2 value) throws Exception {
return !value.f1.equals("");
}
}
public static final class GetTopSongPerUser implements EdgesFunctionWithVertexValue> {
public void iterateEdges(Vertex vertex,
Iterable> edges, Collector> out) throws Exception {
int maxPlaycount = 0;
String topSong = "";
for (Edge edge : edges) {
if (edge.getValue() > maxPlaycount) {
maxPlaycount = edge.getValue();
topSong = edge.getTarget();
}
}
out.collect(new Tuple2(vertex.getId(), topSong));
}
}
public static final class CreateSimilarUserEdges implements GroupReduceFunction,
Edge> {
public void reduce(Iterable> edges, Collector> out) {
List listeners = new ArrayList();
for (Edge edge : edges) {
listeners.add(edge.getSource());
}
for (int i = 0; i < listeners.size() - 1; i++) {
for (int j = i + 1; j < listeners.size(); j++) {
out.collect(new Edge(listeners.get(i),
listeners.get(j), NullValue.getInstance()));
}
}
}
}
public static final class AssignInitialLabelReducer implements GroupReduceFunction,
Tuple2> {
public void reduce(Iterable> vertices, Collector> out) {
long label = 0;
for (Vertex vertex : vertices) {
out.collect(new Tuple2(vertex.getId(), label));
label++;
}
}
}
@Override
public String getDescription() {
return "Music Profiles Example";
}
// ******************************************************************************************************************
// UTIL METHODS
// ******************************************************************************************************************
private static boolean fileOutput = false;
private static String userSongTripletsInputPath = null;
private static String mismatchesInputPath = null;
private static String topTracksOutputPath = null;
private static String communitiesOutputPath = null;
private static int maxIterations = 10;
private static boolean parseParameters(String[] args) {
if(args.length > 0) {
if(args.length != 5) {
System.err.println("Usage: MusicProfiles " +
"
© 2015 - 2025 Weber Informatics LLC | Privacy Policy