All downloads are free. The search and download features use the official Maven repository.

org.apache.mahout.cf.taste.impl.recommender.svd.FilePersistenceStrategy Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.cf.taste.impl.recommender.svd;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Map;

import com.google.common.base.Preconditions;
import org.apache.mahout.cf.taste.common.NoSuchItemException;
import org.apache.mahout.cf.taste.common.NoSuchUserException;
import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Provides a file-based persistent store for a {@link Factorization}.
 *
 * <p>Binary layout, as produced by {@code writeBinary} and consumed by {@code readBinary}:
 * three ints (number of features, number of users, number of items), followed by one record
 * per user and then one record per item. Each record is an int index, a long ID and
 * "number of features" doubles.
 */
public class FilePersistenceStrategy implements PersistenceStrategy {

  private static final Logger log = LoggerFactory.getLogger(FilePersistenceStrategy.class);

  private final File file;

  /**
   * @param file the file to use for storage. If the file does not exist it will be created when required.
   * @throws NullPointerException if {@code file} is null
   */
  public FilePersistenceStrategy(File file) {
    this.file = Preconditions.checkNotNull(file);
  }

  /**
   * Reads the factorization stored in the backing file.
   *
   * @return the stored factorization, or {@code null} if the backing file does not exist
   * @throws IOException if the file cannot be read or its content is inconsistent
   */
  @Override
  public Factorization load() throws IOException {
    if (!file.exists()) {
      log.info("{} does not yet exist, no factorization found", file.getAbsolutePath());
      return null;
    }
    try (DataInputStream in = new DataInputStream(new BufferedInputStream(new FileInputStream(file)))) {
      log.info("Reading factorization from {}...", file.getAbsolutePath());
      return readBinary(in);
    }
  }

  /**
   * Writes the given factorization to the backing file, replacing any previous content.
   *
   * @param factorization the factorization to store; must not be null
   * @throws NullPointerException if {@code factorization} is null
   * @throws IOException if writing fails or a mapped user/item has no feature vector
   */
  @Override
  public void maybePersist(Factorization factorization) throws IOException {
    // Check before opening the stream: opening a FileOutputStream truncates the existing file,
    // so a null argument would otherwise destroy the previously stored factorization.
    Preconditions.checkNotNull(factorization, "factorization must not be null");
    try (DataOutputStream out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)))) {
      log.info("Writing factorization to {}...", file.getAbsolutePath());
      writeBinary(factorization, out);
    }
  }

  /**
   * Serializes a factorization: header counts first, then every user record, then every item record.
   *
   * @param factorization the factorization to serialize
   * @param out the sink to write to
   * @throws IOException if writing fails, or if a user/item listed in the ID mappings has no features
   */
  protected static void writeBinary(Factorization factorization, DataOutput out) throws IOException {
    int numFeatures = factorization.numFeatures();
    out.writeInt(numFeatures);
    out.writeInt(factorization.numUsers());
    out.writeInt(factorization.numItems());

    for (Map.Entry<Long,Integer> mappingEntry : factorization.getUserIDMappings()) {
      long userID = mappingEntry.getKey();
      out.writeInt(mappingEntry.getValue());
      out.writeLong(userID);
      try {
        writeFeatures(factorization.getUserFeatures(userID), numFeatures, out);
      } catch (NoSuchUserException e) {
        throw new IOException("Unable to persist factorization", e);
      }
    }

    for (Map.Entry<Long,Integer> mappingEntry : factorization.getItemIDMappings()) {
      long itemID = mappingEntry.getKey();
      out.writeInt(mappingEntry.getValue());
      out.writeLong(itemID);
      try {
        writeFeatures(factorization.getItemFeatures(itemID), numFeatures, out);
      } catch (NoSuchItemException e) {
        throw new IOException("Unable to persist factorization", e);
      }
    }
  }

  /**
   * Deserializes a factorization written by {@code writeBinary}.
   *
   * @param in the source to read from
   * @return the factorization rebuilt from the stream
   * @throws IOException if reading fails, a header count is negative, or a record index lies
   *         outside the range announced by the header
   */
  public static Factorization readBinary(DataInput in) throws IOException {
    int numFeatures = readCount(in, "features");
    int numUsers = readCount(in, "users");
    int numItems = readCount(in, "items");

    FastByIDMap<Integer> userIDMapping = new FastByIDMap<>(numUsers);
    double[][] userFeatures = new double[numUsers][numFeatures];

    for (int n = 0; n < numUsers; n++) {
      int userIndex = readIndex(in, numUsers, "user");
      long userID = in.readLong();
      userIDMapping.put(userID, userIndex);
      readFeatures(in, userFeatures[userIndex]);
    }

    FastByIDMap<Integer> itemIDMapping = new FastByIDMap<>(numItems);
    double[][] itemFeatures = new double[numItems][numFeatures];

    for (int n = 0; n < numItems; n++) {
      int itemIndex = readIndex(in, numItems, "item");
      long itemID = in.readLong();
      itemIDMapping.put(itemID, itemIndex);
      readFeatures(in, itemFeatures[itemIndex]);
    }

    return new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures);
  }

  /** Writes the first {@code numFeatures} values of a feature vector as doubles. */
  private static void writeFeatures(double[] features, int numFeatures, DataOutput out) throws IOException {
    for (int feature = 0; feature < numFeatures; feature++) {
      out.writeDouble(features[feature]);
    }
  }

  /** Fills {@code target} with consecutive doubles read from the stream. */
  private static void readFeatures(DataInput in, double[] target) throws IOException {
    for (int feature = 0; feature < target.length; feature++) {
      target[feature] = in.readDouble();
    }
  }

  /** Reads a header count and rejects negative values, which can only come from corrupt data. */
  private static int readCount(DataInput in, String what) throws IOException {
    int count = in.readInt();
    if (count < 0) {
      throw new IOException("Corrupt factorization data: negative number of " + what + " (" + count + ')');
    }
    return count;
  }

  /** Reads a record index and rejects values that would fall outside the feature matrix. */
  private static int readIndex(DataInput in, int count, String what) throws IOException {
    int index = in.readInt();
    if (index < 0 || index >= count) {
      throw new IOException(
          "Corrupt factorization data: " + what + " index " + index + " is outside [0, " + count + ')');
    }
    return index;
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy