/*
 * Copyright 2018 Analytics Zoo Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.zoo.feature.common

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext

import scala.io.Source

object Relations {

  /**
   * Read relations from a csv or txt file.
   * Each record is expected to contain the following three fields in order:
   * id1(String), id2(String) and label(Integer).
   *
   * The file should not contain a header. Each line holds one record,
   * with the fields separated by commas.
   *
   * @param path The path to the relations file, which can either be a local or distributed
   *             file system (such as HDFS) path.
   * @param sc An instance of SparkContext.
   * @param minPartitions Integer. The suggested minimum number of partitions for the input
   *                      file. Default is 1.
   * @return RDD of [[Relation]].
   */
  def read(path: String, sc: SparkContext, minPartitions: Int = 1): RDD[Relation] = {
    sc.textFile(path, minPartitions).map(line => {
      val subs = line.split(",")
      Relation(subs(0), subs(1), subs(2).toInt)
    })
  }
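
  // Usage sketch (illustrative, not from the original source). Assuming a
  // hypothetical comma-separated file "relations.csv" with lines such as
  // "Q1,D1,1" and "Q1,D2,0", and an existing SparkContext `sc`:
  //
  //   val relations: RDD[Relation] = Relations.read("relations.csv", sc, minPartitions = 4)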

  /**
   * Read relations from a csv or txt file on the local file system.
   * Each record is expected to contain the following three fields in order:
   * id1(String), id2(String) and label(Integer).
   *
   * The file should not contain a header. Each line holds one record,
   * with the fields separated by commas.
   *
   * @param path The local file path to the relations file.
   * @return Array of [[Relation]].
   */
  def read(path: String): Array[Relation] = {
    val src = Source.fromFile(path)
    try {
      src.getLines().map(line => {
        val subs = line.split(",")
        Relation(subs(0), subs(1), subs(2).toInt)
      }).toArray
    } finally {
      src.close() // Avoid leaking the file handle.
    }
  }
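
  // Usage sketch (illustrative). With the same hypothetical "relations.csv"
  // on the local file system, no SparkContext is required:
  //
  //   val relations: Array[Relation] = Relations.read("relations.csv")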

  /**
   * Read relations from a parquet file.
   * The schema should contain the following three columns:
   * "id1"(String), "id2"(String) and "label"(Integer).
   *
   * @param path The path to the parquet file.
   * @param sqlContext An instance of SQLContext.
   * @return RDD of [[Relation]].
   */
  def readParquet(path: String, sqlContext: SQLContext): RDD[Relation] = {
    sqlContext.read.parquet(path).rdd.map(row => {
      val id1 = row.getAs[String]("id1")
      val id2 = row.getAs[String]("id2")
      val label = row.getAs[Int]("label")
      Relation(id1, id2, label)
    })
  }
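
  // Usage sketch (illustrative). Assuming a hypothetical parquet file that was
  // written with columns "id1", "id2" and "label":
  //
  //   val sqlContext = SQLContext.getOrCreate(sc)
  //   val relations: RDD[Relation] = Relations.readParquet("relations.parquet", sqlContext)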

  /**
   * Generate all [[RelationPair]]s from the given [[Relation]]s.
   * For each id1, every positive relation (id2 with label > 0) is paired
   * with every negative relation (id2 with label = 0) of the same id1.
   */
  def generateRelationPairs(relations: RDD[Relation]): RDD[RelationPair] = {
    val positive = relations.filter(_.label > 0).groupBy(_.id1)
    val negative = relations.filter(_.label == 0).groupBy(_.id1)
    positive.cogroup(negative).flatMap { case (id1, (posGroups, negGroups)) =>
      // cogroup yields the positive and negative relation groups of each id1;
      // emit the Cartesian product of their id2s as pairs.
      val posIDs = posGroups.flatten.map(_.id2)
      val negIDs = negGroups.flatten.map(_.id2)
      posIDs.flatMap(pos => negIDs.map(neg => RelationPair(id1, pos, neg)))
    }
  }
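
  // Worked example (illustrative). Given the relations
  //   Relation("Q1", "D1", 1), Relation("Q1", "D2", 0), Relation("Q1", "D3", 0)
  // the positive relation of Q1 is paired with each of its negatives, yielding
  //   RelationPair("Q1", "D1", "D2") and RelationPair("Q1", "D1", "D3"):
  //
  //   val pairs = Relations.generateRelationPairs(sc.parallelize(Seq(
  //     Relation("Q1", "D1", 1), Relation("Q1", "D2", 0), Relation("Q1", "D3", 0))))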

  /**
   * Generate all [[RelationPair]]s from a given array of [[Relation]]s.
   * Same semantics as the RDD variant above: for each id1, every positive
   * relation (label > 0) is paired with every negative relation (label = 0).
   */
  def generateRelationPairs(relations: Array[Relation]): Array[RelationPair] = {
    relations.groupBy(_.id1).toArray.flatMap { case (id1, rels) =>
      // Split this id1's relations into positive and negative id2s, then
      // emit the Cartesian product as pairs.
      val posIDs = rels.filter(_.label > 0).map(_.id2)
      val negIDs = rels.filter(_.label == 0).map(_.id2)
      for (id2Positive <- posIDs; id2Negative <- negIDs)
        yield RelationPair(id1, id2Positive, id2Negative)
    }
  }
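
  // The array variant produces the same pairs without a SparkContext:
  //
  //   val pairs: Array[RelationPair] = Relations.generateRelationPairs(Array(
  //     Relation("Q1", "D1", 1), Relation("Q1", "D2", 0), Relation("Q1", "D3", 0)))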
}

/**
 * Represents the relationship between two items, id1 and id2, with an
 * integer label (0 for negative; greater than 0 for positive).
 */
case class Relation(id1: String, id2: String, label: Int)

/**
 * A relation pair is made up of two relations of the same id1:
 * Relation(id1, id2Positive, label>0) [Positive Relation]
 * Relation(id1, id2Negative, label=0) [Negative Relation]
 */
case class RelationPair(id1: String, id2Positive: String, id2Negative: String)