All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cc.redberry.transformation.concurrent.ExpandBracketsOutput Maven / Gradle / Ivy

/*
 * Redberry: symbolic tensor computations.
 *
 * Copyright (c) 2010-2012:
 *   Stanislav Poslavsky   
 *   Bolotin Dmitriy       
 *
 * This file is part of Redberry.
 *
 * Redberry is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Redberry is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Redberry. If not, see <http://www.gnu.org/licenses/>.
 */
package cc.redberry.transformation.concurrent;

import java.util.ArrayList;
import java.util.List;
import cc.redberry.concurrent.OutputPortUnsafe;
import cc.redberry.core.tensor.Product;
import cc.redberry.core.tensor.Sum;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.TensorIterator;
import cc.redberry.core.utils.Indicator;

/**
 *
 * @author Dmitry Bolotin
 * @author Stanislav Poslavsky
 */
/**
 * Lazily expands the brackets of a product, producing one term of the expansion per
 * {@link #take()} call.
 *
 * <p>The product is split into a body (all factors that are not expanded) and a list of sums.
 * Each term returned by {@link #take()} is a clone of the body multiplied by (clones of) one
 * summand taken from every sum. Summands are enumerated odometer-style: the last sum varies
 * fastest, and when a sum is exhausted it is rewound and the previous sum is advanced. Summands
 * that are themselves products containing sums are expanded recursively. Sums and summands for
 * which the {@code except} indicator returns {@code true} are left unexpanded.
 *
 * <p>Once every combination has been produced, {@link #take()} returns {@code null} on that call
 * and on every subsequent call. Instances are stateful and perform no synchronization.
 *
 * @author Dmitry Bolotin
 * @author Stanislav Poslavsky
 */
public class ExpandBracketsOutput implements OutputPortUnsafe<Tensor> {
    /** Compile-time switch: wrap every sum port into a {@link SumPortWrapper} caching its tail. */
    private static final boolean USE_BUFFER = false;
    /** Capacity of the tail cache used by {@link SumPortWrapper}. */
    private static final int BUFFER_SIZE = 10;

    /** Tensors for which this indicator is true are not expanded. */
    private final Indicator except;
    /** Factors of the original product that are not expanded; cloned for every produced term. */
    private final Product productBody;
    /** One port per expanded sum; the last port is advanced on every call to {@link #take()}. */
    private final SumPort[] sumPorts;
    /** Currently selected summands of {@code sumPorts[0 .. n-2]}. */
    private final Tensor[] currentValues;
    /** Set once all combinations have been produced, so that exhaustion is terminal. */
    private boolean finished = false;

    /**
     * Creates a port enumerating the expansion of
     * {@code productBody * sums[0] * ... * sums[n-1]}.
     *
     * @param productBody factors copied (cloned) unchanged into every produced term
     * @param sums        sums whose brackets are expanded; must contain at least one sum
     * @param except      indicator of tensors that must be left unexpanded
     * @throws IllegalArgumentException if {@code sums} is empty
     */
    public ExpandBracketsOutput(Product productBody, Sum[] sums, Indicator except) {
        if (sums.length == 0)
            throw new IllegalArgumentException("At least one sum is required for bracket expansion.");
        this.except = except;
        this.productBody = productBody;
        this.sumPorts = new SumPort[sums.length];
        this.currentValues = new Tensor[sums.length - 1];
        for (int i = 0; i < sums.length; ++i) {
            SumPort port = new SumPortImpl(except, sums[i]);
            sumPorts[i] = USE_BUFFER ? new SumPortWrapper(port) : port;
        }
        // Select the first summand of every sum but the last; the last sum is advanced by take().
        for (int i = 0; i < sums.length - 1; ++i)
            currentValues[i] = sumPorts[i].take();
    }

    /**
     * Returns the next term of the expansion.
     *
     * @return a freshly built product, or {@code null} once all combinations have been produced
     */
    @Override
    public Tensor take() {
        if (finished)
            return null;
        Product result = productBody.clone();
        int pointer = sumPorts.length - 1;
        boolean carry = false;

        // The last sum is advanced on every call; wrap around (and carry) when it is exhausted.
        Tensor element = sumPorts[pointer].take();
        if (element == null) {
            sumPorts[pointer].reset();
            element = sumPorts[pointer].take();
            carry = true;
        }
        result.add(element.clone());

        // Propagate the carry towards the first sum, odometer-style.
        while (--pointer >= 0) {
            if (carry) {
                carry = false;
                element = sumPorts[pointer].take();
                if (element == null) {
                    sumPorts[pointer].reset();
                    element = sumPorts[pointer].take();
                    carry = true;
                }
                currentValues[pointer] = element;
            }
            result.add(currentValues[pointer].clone());
        }

        // A carry out of the first sum means every combination has already been produced.
        // The ports are now rewound, so remember exhaustion instead of silently restarting.
        if (carry) {
            finished = true;
            return null;
        }
        return result;
    }

    /**
     * Creates a port enumerating the terms of {@code tensor} with all brackets expanded.
     *
     * @param tensor tensor to expand
     * @return port over the expanded terms
     */
    public static OutputPortUnsafe<Tensor> create(Tensor tensor) {
        return create(tensor, Indicator.FALSE_INDICATOR);
    }

    /**
     * Creates a port enumerating the terms of {@code tensor} with brackets expanded.
     *
     * <p>The result depends on the kind of {@code tensor}:
     * <ul>
     * <li>A sum is enumerated summand by summand, with product summands expanded recursively.</li>
     * <li>A product containing expandable sums is expanded.</li>
     * <li>Anything else is returned as a single element.</li>
     * </ul>
     *
     * @param tensor tensor to expand
     * @param except indicator of tensors that must be left unexpanded
     * @return port over the expanded terms
     */
    public static OutputPortUnsafe<Tensor> create(Tensor tensor, Indicator except) {
        if (tensor instanceof Sum)
            return new SumPortImpl(except, (Sum) tensor);
        // createOP returns null for non-products and for products without expandable sums.
        OutputPortUnsafe<Tensor> expanded = createOP(tensor, except);
        if (expanded != null)
            return expanded;
        return new OutputPortUnsafe.Singleton<>(tensor);
    }

    /**
     * Builds an expanding port for a product.
     *
     * @return the port, or {@code null} if {@code tensor} is not a product or contains no sum
     *         that is not excluded by {@code except}
     */
    private static OutputPortUnsafe<Tensor> createOP(Tensor tensor, Indicator except) {
        if (!(tensor instanceof Product))
            return null;
        List<Sum> sums = new ArrayList<>();
        Product body = new Product();
        for (Tensor factor : tensor)
            if ((factor instanceof Sum) && !except.is(factor))
                sums.add((Sum) factor);
            else
                body.add(factor);
        if (sums.isEmpty())
            return null;
        return new ExpandBracketsOutput(body, sums.toArray(new Sum[sums.size()]), except);
    }

    /** A rewindable port over the (expanded) summands of a sum. */
    private interface SumPort extends OutputPortUnsafe<Tensor> {
        /** Rewinds the port to the first summand. */
        void reset();
    }

    /**
     * Enumerates the summands of a sum. Product summands containing expandable sums are
     * replaced by the terms of their own expansion.
     */
    private static final class SumPortImpl implements SumPort {
        private final Sum sum;
        private final Indicator except;
        private TensorIterator sumIterator;
        /** Expansion of the current summand, or {@code null} if it is returned as is. */
        private OutputPortUnsafe<Tensor> currentPort = null;

        SumPortImpl(Indicator except, Sum sum) {
            this.sum = sum;
            this.except = except;
            reset();
        }

        @Override
        public final void reset() {
            this.sumIterator = sum.iterator();
            currentPort = null;
        }

        @Override
        public Tensor take() {
            Tensor ret;
            // Drain the expansion of the current summand first.
            if (currentPort != null && (ret = currentPort.take()) != null)
                return ret;
            if (!sumIterator.hasNext())
                return null;
            currentPort = null;
            Tensor summand = sumIterator.next();
            if (except.is(summand))
                return summand;
            currentPort = createOP(summand, except);
            if (currentPort != null)
                return currentPort.take();
            return summand;
        }
    }

    /**
     * Caching decorator: during the first pass it remembers the last {@link #BUFFER_SIZE}
     * elements of the wrapped port. On later passes those elements are served from the cache.
     * Only the leading elements that did not fit are regenerated by the (rewound) wrapped port.
     *
     * <p>Assumes {@link #reset()} is first called only after the wrapped port has been
     * exhausted, which is how {@link ExpandBracketsOutput#take()} uses it.
     */
    private static final class SumPortWrapper implements SumPort {
        private final SumPort port;
        /** Circular cache: the element with sequence index k is kept at {@code k % BUFFER_SIZE}. */
        private final Tensor[] buffer = new Tensor[BUFFER_SIZE];
        /** Number of valid cache entries (at most {@link #BUFFER_SIZE}). */
        private int bufferSize = 0;
        /** Sequence index of the next element to return. */
        private int position = 0;
        /** Total length of the sequence; -1 until the first pass has completed. */
        private int sequenceSize = -1;

        SumPortWrapper(SumPort port) {
            this.port = port;
        }

        @Override
        public void reset() {
            if (sequenceSize == -1)
                sequenceSize = position;
            // Elements preceding the cached tail must be regenerated by the wrapped port.
            if (sequenceSize > bufferSize)
                port.reset();
            position = 0;
        }

        @Override
        public Tensor take() {
            if (sequenceSize == -1) {
                // First pass: forward from the wrapped port while recording the tail.
                Tensor tensor = port.take();
                if (tensor == null)
                    return null;
                buffer[position % BUFFER_SIZE] = tensor;
                ++position;
                if (bufferSize < BUFFER_SIZE)
                    ++bufferSize;
                return tensor;
            }
            if (position == sequenceSize)
                return null;
            if (position >= sequenceSize - bufferSize)
                return buffer[position++ % BUFFER_SIZE];
            ++position;
            return port.take();
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy