All downloads are free. The search and download functionality uses the official Maven repository.

com.mongodb.DBCollectionImpl Maven / Gradle / Ivy

Go to download

The MongoDB Java Driver uber-artifact, containing mongodb-driver, mongodb-driver-core, and bson

There is a newer version: 3.1.0
Show newest version
/*
 * Copyright (c) 2008-2014 MongoDB, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.mongodb;

import com.mongodb.util.JSON;
import org.bson.io.PoolOutputBuffer;
import org.bson.types.ObjectId;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.logging.Level;
import java.util.logging.Logger;

import static com.mongodb.QueryResultIterator.chooseBatchSize;
import static com.mongodb.WriteCommandResultHelper.getBulkWriteException;
import static com.mongodb.WriteCommandResultHelper.getBulkWriteResult;
import static com.mongodb.WriteCommandResultHelper.hasError;
import static com.mongodb.WriteRequest.Type.INSERT;
import static com.mongodb.WriteRequest.Type.REMOVE;
import static com.mongodb.WriteRequest.Type.REPLACE;
import static com.mongodb.WriteRequest.Type.UPDATE;
import static java.lang.String.format;
import static java.util.Arrays.asList;
import static org.bson.util.Assertions.isTrue;

@SuppressWarnings("deprecation")
class DBCollectionImpl extends DBCollection {

    // Add an extra 16K to the maximum allowed size of a query document, so that, for example,
    // a 16M document can be upserted via findAndModify
    private static final int QUERY_DOCUMENT_HEADROOM = 16 * 1024;

    private final DBApiLayer db;
    private final String namespace;

    /**
     * Creates the collection implementation for the given database layer.
     *
     * @param db   the database API layer this collection belongs to
     * @param name the collection name; the namespace is {@code db._root + "." + name}
     */
    DBCollectionImpl(final DBApiLayer db, String name){
        super(db, name);
        this.db = db;
        this.namespace = db._root + "." + name;
    }

    /**
     * Runs a query with an encoder freshly created from {@code DefaultDBEncoder.FACTORY},
     * delegating to the overload that takes an explicit encoder.
     */
    @Override
    QueryResultIterator find(DBObject ref, DBObject fields, int numToSkip, int batchSize, int limit, int options,
                             ReadPreference readPref, DBDecoder decoder) {
        final DBEncoder defaultEncoder = DefaultDBEncoder.FACTORY.create();
        return find(ref, fields, numToSkip, batchSize, limit, options, readPref, decoder, defaultEncoder);
    }

    /**
     * Builds a query message for this collection, sends it through the connector and wraps
     * the response in a {@link QueryResultIterator}.
     *
     * @param ref the query document; {@code null} is replaced by an empty document
     * @return an iterator over the query response
     */
    @Override
    QueryResultIterator find(DBObject ref, DBObject fields, int numToSkip, int batchSize, int limit, int options,
                             ReadPreference readPref, DBDecoder decoder, DBEncoder encoder) {
        final DBObject criteria = (ref != null) ? ref : new BasicDBObject();

        if (willTrace()) {
            trace("find: " + namespace + " " + JSON.serialize(criteria));
        }

        final int firstBatchSize = chooseBatchSize(batchSize, limit, 0);
        // The query document is allowed to exceed the max BSON size by the configured headroom.
        final int maxQuerySize = db.getMongo().getMaxBsonObjectSize() + QUERY_DOCUMENT_HEADROOM;
        final OutMessage query = OutMessage.query(this, options, numToSkip, firstBatchSize, criteria, fields, readPref,
                                                  encoder, maxQuerySize);

        final Response response = db.getConnector().call(_db, this, query, null, 2, readPref, decoder);
        return new QueryResultIterator(this.namespace, db.getMongo(), response, batchSize, limit, decoder);
    }

    /**
     * Runs the aggregation pipeline as a command and returns a cursor over its results.
     * If the final pipeline stage has a {@code $out} key, the returned cursor iterates the
     * named output collection using the primary read preference; otherwise it iterates the
     * command result directly.
     *
     * @param pipeline       the pipeline stages; the last element is inspected for {@code $out}
     * @param options        the aggregation options, must not be {@code null}
     * @param readPreference the read preference passed to the command
     * @return a cursor over the aggregation output
     * @throws IllegalArgumentException if {@code options} is {@code null}
     */
    public Cursor aggregate(final List pipeline, final AggregationOptions options,
                            final ReadPreference readPreference) {
        if (options == null) {
            throw new IllegalArgumentException("options can not be null");
        }

        DBObject finalStage = pipeline.get(pipeline.size() - 1);

        final CommandResult res = _db.command(prepareCommand(pipeline, options), getOptions(), readPreference);
        res.throwOnError();

        final String outCollection = (String) finalStage.get("$out");
        if (outCollection != null) {
            return new DBCursor(_db.getCollection(outCollection), new BasicDBObject(), null, ReadPreference.primary());
        }

        final Integer requestedBatchSize = options.getBatchSize();
        final int effectiveBatchSize = (requestedBatchSize == null) ? 0 : requestedBatchSize;
        return new QueryResultIterator(res, db.getMongo(), effectiveBatchSize, getDecoder(), res.getServerUsed());
    }

    /**
     * Issues a {@code parallelCollectionScan} command for this collection and wraps each
     * document of the result's {@code cursors} field in a {@link QueryResultIterator}.
     *
     * @param options supplies the number of cursors, the read preference and the batch size
     * @return the list of created cursors
     */
    @Override
    @SuppressWarnings("unchecked")
    public List parallelScan(final ParallelScanOptions options) {
        final BasicDBObject scanCommand = new BasicDBObject("parallelCollectionScan", getName())
                                          .append("numCursors", options.getNumCursors());
        final CommandResult res = _db.command(scanCommand, options.getReadPreference());
        res.throwOnError();

        final List cursors = new ArrayList();
        for (DBObject cursorDocument : (List) res.get("cursors")) {
            final QueryResultIterator cursor = new QueryResultIterator(cursorDocument, db.getMongo(), options.getBatchSize(),
                                                                       getDecoder(), res.getServerUsed());
            cursors.add(cursor);
        }
        return cursors;
    }


    /**
     * Executes a bulk write by splitting the requests into runs, executing each run on a
     * single primary port, and combining the per-run results.
     *
     * @param ordered       selects the ordered or unordered run generator
     * @param writeRequests the requests to execute, must not be empty
     * @param writeConcern  the write concern, must not be {@code null}
     * @param encoder       the encoder; a default one is created when {@code null}
     * @return the combined result of all executed runs
     * @throws IllegalArgumentException if {@code writeConcern} is {@code null}
     */
    @Override
    BulkWriteResult executeBulkWriteOperation(final boolean ordered, final List writeRequests,
                                              WriteConcern writeConcern, DBEncoder encoder) {
        isTrue("no operations", !writeRequests.isEmpty());
        if (writeConcern == null) {
            throw new IllegalArgumentException("Write concern can not be null");
        }

        final WriteConcern effectiveConcern = writeConcern.continueOnError(!ordered);
        final DBEncoder effectiveEncoder = (encoder != null) ? encoder : DefaultDBEncoder.FACTORY.create();

        final DBPort port = db.getConnector().getPrimaryPort();
        try {
            final BulkWriteBatchCombiner combiner = new BulkWriteBatchCombiner(port.getAddress(), effectiveConcern);
            for (Run run : getRunGenerator(ordered, writeRequests, effectiveConcern, effectiveEncoder, port)) {
                try {
                    final BulkWriteResult runResult = run.execute(port);
                    // Only acknowledged results are merged into the combined result.
                    if (runResult.isAcknowledged()) {
                        combiner.addResult(runResult, run.indexMap);
                    }
                } catch (BulkWriteException e) {
                    combiner.addErrorResult(e, run.indexMap);
                    if (combiner.shouldStopSendingMoreBatches()) {
                        break;
                    }
                }
            }
            return combiner.getResult();
        } finally {
            // The port is always handed back to the connector, even on failure.
            db.getConnector().releasePort(port);
        }
    }

    /**
     * Inserts the documents with the insert rules applied ({@code shouldApply == true}).
     */
    public WriteResult insert(List list, WriteConcern concern, DBEncoder encoder ){
        final boolean applyInsertRules = true;
        return insert(list, applyInsertRules, concern, encoder);
    }

    /**
     * Inserts the documents on a primary port, using the write-command protocol when
     * {@code useWriteCommands} allows it and the legacy write protocol otherwise.
     *
     * @param list        the documents to insert
     * @param shouldApply whether the insert rules are applied to each document first
     * @param concern     the write concern, must not be {@code null}
     * @param encoder     the encoder; a default one is created when {@code null}
     * @return the write result
     * @throws IllegalArgumentException if {@code concern} is {@code null}
     */
    protected WriteResult insert(List list, boolean shouldApply , WriteConcern concern, DBEncoder encoder ){
        if (concern == null) {
            throw new IllegalArgumentException("Write concern can not be null");
        }

        final DBEncoder effectiveEncoder = (encoder == null) ? DefaultDBEncoder.FACTORY.create() : encoder;

        if (willTrace()) {
            for (DBObject document : list) {
                trace("save:  " + namespace + " " + JSON.serialize(document));
            }
        }

        final DBPort port = db.getConnector().getPrimaryPort();
        try {
            if (!useWriteCommands(concern, port)) {
                return insertWithWriteProtocol(list, concern, effectiveEncoder, port, shouldApply);
            }
            try {
                final BulkWriteResult bulkResult = insertWithCommandProtocol(list, concern, effectiveEncoder, port, shouldApply);
                return translateBulkWriteResult(bulkResult, INSERT, concern, port.getAddress());
            } catch (BulkWriteException e) {
                // Surface bulk failures as the exception type produced by the command result.
                throw translateBulkWriteException(e, INSERT);
            }
        } finally {
            db.getConnector().releasePort(port);
        }
    }

    /**
     * Removes documents matching the query with the {@code multi} flag set to {@code true}.
     */
    public WriteResult remove( DBObject query, WriteConcern concern, DBEncoder encoder ) {
        final boolean multi = true;
        return remove(query, multi, concern, encoder);
    }

    /**
     * Removes documents matching the query on a primary port, using the write-command
     * protocol when {@code useWriteCommands} allows it and a legacy remove message otherwise.
     *
     * @param query   the selection query
     * @param multi   the multi flag passed on to the remove request
     * @param concern the write concern, must not be {@code null}
     * @param encoder the encoder; a default one is created when {@code null}
     * @return the write result
     * @throws IllegalArgumentException if {@code concern} is {@code null}
     */
    public WriteResult remove( DBObject query, boolean multi, WriteConcern concern, DBEncoder encoder ){
        if (concern == null) {
            throw new IllegalArgumentException("Write concern can not be null");
        }

        final DBEncoder effectiveEncoder = (encoder != null) ? encoder : DefaultDBEncoder.FACTORY.create();

        if (willTrace()) {
            trace("remove: " + namespace + " " + JSON.serialize(query));
        }

        final DBPort port = db.getConnector().getPrimaryPort();
        try {
            if (!useWriteCommands(concern, port)) {
                return db.getConnector().say(_db, OutMessage.remove(this, effectiveEncoder, query, multi), concern, port);
            }
            try {
                final BulkWriteResult bulkResult =
                    removeWithCommandProtocol(Arrays.asList(new RemoveRequest(query, multi)), concern, effectiveEncoder, port);
                return translateBulkWriteResult(bulkResult, REMOVE, concern, port.getAddress());
            } catch (BulkWriteException e) {
                throw translateBulkWriteException(e, REMOVE);
            }
        } finally {
            db.getConnector().releasePort(port);
        }
    }

    /**
     * Updates documents matching the query on a primary port, using the write-command
     * protocol when {@code useWriteCommands} allows it and a legacy update message otherwise.
     *
     * @param query   the selection query
     * @param o       the update or replacement document, must not be {@code null}
     * @param upsert  the upsert flag
     * @param multi   the multi flag
     * @param concern the write concern, must not be {@code null}
     * @param encoder the encoder; a default one is created when {@code null}
     * @return the write result
     * @throws IllegalArgumentException if {@code o} or {@code concern} is {@code null}
     */
    @Override
    public WriteResult update( DBObject query , DBObject o , boolean upsert , boolean multi , WriteConcern concern,
                               DBEncoder encoder ) {
        if (o == null) {
            throw new IllegalArgumentException("update can not be null");
        }
        if (concern == null) {
            throw new IllegalArgumentException("Write concern can not be null");
        }

        final DBEncoder effectiveEncoder = (encoder == null) ? DefaultDBEncoder.FACTORY.create() : encoder;

        // A document whose first key does not start with '$' is stored as-is, so it must be validated.
        if (!o.keySet().isEmpty()) {
            final String firstKey = o.keySet().iterator().next();
            if (!firstKey.startsWith("$")) {
                _checkObject(o, false, false);
            }
        }

        if (willTrace()) {
            trace("update: " + namespace + " " + JSON.serialize(query) + " " + JSON.serialize(o));
        }

        final DBPort port = db.getConnector().getPrimaryPort();
        try {
            if (!useWriteCommands(concern, port)) {
                return db.getConnector().say(_db, OutMessage.update(this, effectiveEncoder, upsert, multi, query, o), concern, port);
            }
            try {
                final BulkWriteResult bulkWriteResult =
                    updateWithCommandProtocol(Arrays.asList(new UpdateRequest(query, upsert, o, multi)),
                                              concern, effectiveEncoder, port);
                return translateBulkWriteResult(bulkWriteResult, UPDATE, concern, port.getAddress());
            } catch (BulkWriteException e) {
                throw translateBulkWriteException(e, UPDATE);
            }
        } finally {
            db.getConnector().releasePort(port);
        }
    }

    /**
     * Removes this collection's entry from the database layer's {@code _collections} map
     * and then delegates to the superclass drop.
     */
    @Override
    public void drop(){
        final String collectionName = getName();
        db._collections.remove(collectionName);
        super.drop();
    }

    /**
     * No-op hook: this implementation does nothing with the given document.
     *
     * @param o the document (ignored)
     */
    public void doapply( DBObject o ){
        // intentionally empty
    }

    /**
     * Converts a bulk write result into a {@link WriteResult} by populating a fresh
     * {@link CommandResult} for the given server address.
     */
    private WriteResult translateBulkWriteResult(final BulkWriteResult bulkWriteResult, final WriteRequest.Type type,
                                                 final WriteConcern writeConcern, final ServerAddress serverAddress) {
        final CommandResult legacyResult = new CommandResult(serverAddress);
        addBulkWriteResultToCommandResult(bulkWriteResult, type, legacyResult);
        return new WriteResult(legacyResult, writeConcern);
    }

    /**
     * Converts a bulk write exception into the exception produced by an equivalent
     * {@link CommandResult}. The result is populated from the partial write result, then
     * from the write concern error details (if any). The last write error, when present,
     * supplies {@code err}, {@code code} and its details; otherwise the write concern error
     * supplies {@code err} and {@code code}.
     */
    private MongoException translateBulkWriteException(final BulkWriteException e, final WriteRequest.Type type) {
        final CommandResult commandResult = new CommandResult(e.getServerAddress());
        addBulkWriteResultToCommandResult(e.getWriteResult(), type, commandResult);

        final boolean hasWriteConcernError = e.getWriteConcernError() != null;
        if (hasWriteConcernError) {
            commandResult.putAll(e.getWriteConcernError().getDetails());
        }

        if (!e.getWriteErrors().isEmpty()) {
            // Only the last write error is reported.
            final BulkWriteError lastError = e.getWriteErrors().get(e.getWriteErrors().size() - 1);
            commandResult.put("err", lastError.getMessage());
            commandResult.put("code", lastError.getCode());
            commandResult.putAll(lastError.getDetails());
        } else if (hasWriteConcernError) {
            commandResult.put("err", e.getWriteConcernError().getMessage());
            commandResult.put("code", e.getWriteConcernError().getCode());
        }
        return commandResult.getException();
    }

    /**
     * Copies the counts of a bulk write result into a command result.
     * Always sets {@code ok} to 1. {@code n} is 0 for inserts, the removed count for removes,
     * and matched count plus number of upserts for updates and replaces. Updates and replaces
     * also set {@code updatedExisting} and, when there are upserts, {@code upserted} to the id
     * of the first upsert.
     */
    private void addBulkWriteResultToCommandResult(final BulkWriteResult bulkWriteResult, final WriteRequest.Type type,
                                                   final CommandResult commandResult) {
        commandResult.put("ok", 1);

        if (type == INSERT) {
            commandResult.put("n", 0);
            return;
        }
        if (type == REMOVE) {
            commandResult.put("n", bulkWriteResult.getRemovedCount());
            return;
        }
        if (type == UPDATE || type == REPLACE) {
            commandResult.put("n", bulkWriteResult.getMatchedCount() + bulkWriteResult.getUpserts().size());
            commandResult.put("updatedExisting", bulkWriteResult.getMatchedCount() > 0);
            if (!bulkWriteResult.getUpserts().isEmpty()) {
                commandResult.put("upserted", bulkWriteResult.getUpserts().get(0).getId());
            }
        }
    }

    /**
     * Returns the index documents of this collection, always read from the primary.
     * <p>
     * On servers reporting version 3.0.0 or later the {@code listIndexes} command is used;
     * a failed command with error code 26 yields an empty list instead of an exception.
     * On older servers the {@code system.indexes} collection is queried by namespace.
     * Both cursors are closed in a {@code finally} block so that a failure during iteration
     * does not leave a cursor open.
     *
     * @return the list of index documents, possibly empty
     */
    @SuppressWarnings("unchecked")
    public List getIndexInfo() {
        getDB().requestStart();
        try {
            List list = new ArrayList();

            if (db.isServerVersionAtLeast(asList(3, 0, 0))) {
                CommandResult res = _db.command(new BasicDBObject("listIndexes", getName()), ReadPreference.primary());
                // NOTE(review): code 26 is presumably the server's "namespace not found" error,
                // i.e. the collection does not exist — reported here as "no indexes".
                if (!res.ok() && res.getCode() == 26) {
                    return list;
                }
                res.throwOnError();

                QueryResultIterator iterator = new QueryResultIterator(res, db.getMongo(), 0, DefaultDBDecoder.FACTORY.create(),
                                                                       res.getServerUsed());
                try {
                    while (iterator.hasNext()) {
                        DBObject indexInfo = iterator.next();
                        list.add(indexInfo);
                    }
                } finally {
                    iterator.close();
                }
            } else {
                BasicDBObject cmd = new BasicDBObject("ns", getFullName());
                DBCursor cur = _db.getCollection("system.indexes").find(cmd).setReadPreference(ReadPreference.primary());
                try {
                    while (cur.hasNext()) {
                        list.add(cur.next());
                    }
                } finally {
                    // Mirror the listIndexes branch: release the cursor even if iteration fails.
                    cur.close();
                }
            }

            return list;
        } finally {
            getDB().requestDone();
        }
    }

    /**
     * Creates an index on this collection using a primary port.
     * The index document is built from {@code defaultOptions(keys)}, overlaid with
     * {@code options}, with {@code key} set to {@code keys}. Servers at version 2.6 or later
     * receive a {@code createIndexes} command; older servers get an insert into
     * {@code system.indexes} with {@code WriteConcern.SAFE}.
     *
     * @param keys    the index key specification
     * @param options additional index options merged into the index document
     * @param encoder not used by this method
     * @throws MongoException.DuplicateKey if the command fails with code 11000
     */
    public void createIndex(final DBObject keys, final DBObject options, DBEncoder encoder) {
        final DBTCPConnector connector = db.getConnector();
        final DBPort port = connector.getPrimaryPort();

        try {
            final DBObject indexSpec = defaultOptions(keys);
            indexSpec.putAll(options);
            indexSpec.put("key", keys);

            final boolean supportsCreateIndexesCommand = port.getServerVersion().compareTo(new ServerVersion(2, 6)) >= 0;
            if (!supportsCreateIndexesCommand) {
                db.doGetCollection("system.indexes").insertWithWriteProtocol(asList(indexSpec), WriteConcern.SAFE,
                                                                             DefaultDBEncoder.FACTORY.create(), port, false);
                return;
            }

            final BasicDBList indexes = new BasicDBList();
            indexes.add(indexSpec);
            final BasicDBObject createIndexes = new BasicDBObject("createIndexes", getName());
            createIndexes.put("indexes", indexes);

            CommandResult commandResult = connector.doOperation(db, port, new DBPort.Operation() {
                @Override
                public CommandResult execute() throws IOException {
                    return port.runCommand(db, createIndexes);
                }
            });
            try {
                commandResult.throwOnError();
            } catch (CommandFailureException e) {
                // Code 11000 is rethrown as a duplicate-key exception; everything else passes through.
                if (e.getCode() != 11000) {
                    throw e;
                }
                throw new MongoException.DuplicateKey(commandResult);
            }
        } finally {
            connector.releasePort(port);
        }
    }

    /**
     * Inserts the documents through the write-command protocol.
     *
     * @param shouldApply when {@code true}, the insert rules are applied to every document
     *                    before the command message is built
     * @return the bulk write result of the insert command
     */
    private BulkWriteResult insertWithCommandProtocol(final List list, final WriteConcern writeConcern,
                                                      final DBEncoder encoder,
                                                      final DBPort port, final boolean shouldApply) {
        if (shouldApply) {
            applyRulesForInsert(list);
        }

        final MessageSettings settings = getMessageSettings(port);
        final BaseWriteCommandMessage insertMessage =
            new InsertCommandMessage(getNamespace(), writeConcern, list, DefaultDBEncoder.FACTORY.create(), encoder, settings);
        return writeWithCommandProtocol(port, INSERT, insertMessage, writeConcern);
    }

    /**
     * Prepares every document for insertion: validates it with {@code _checkObject},
     * passes it to {@code apply}, and marks an {@link ObjectId} {@code _id} as not new.
     */
    private void applyRulesForInsert(final List list) {
        for (DBObject document : list) {
            _checkObject(document, false, false);
            apply(document);

            final Object id = document.get("_id");
            if (!(id instanceof ObjectId)) {
                continue;
            }
            ((ObjectId) id).notNew();
        }
    }

    /**
     * Sends the remove requests as a delete command message and returns the bulk result.
     */
    private BulkWriteResult removeWithCommandProtocol(final List removeList,
                                                      final WriteConcern writeConcern,
                                                      final DBEncoder encoder, final DBPort port) {
        final MessageSettings settings = getMessageSettings(port);
        final BaseWriteCommandMessage deleteMessage =
            new DeleteCommandMessage(getNamespace(), writeConcern, removeList, DefaultDBEncoder.FACTORY.create(), encoder,
                                     settings);
        return writeWithCommandProtocol(port, REMOVE, deleteMessage, writeConcern);
    }

    /**
     * Sends the update requests as an update command message and returns the bulk result.
     * The result is always processed with type {@code UPDATE}.
     */
    @SuppressWarnings("unchecked")
    private BulkWriteResult updateWithCommandProtocol(final List updates,
                                                      final WriteConcern writeConcern,
                                                      final DBEncoder encoder, final DBPort port) {
        final MessageSettings settings = getMessageSettings(port);
        final BaseWriteCommandMessage updateMessage =
            new UpdateCommandMessage(getNamespace(), writeConcern, updates, DefaultDBEncoder.FACTORY.create(), encoder,
                                     settings);
        return writeWithCommandProtocol(port, UPDATE, updateMessage, writeConcern);
    }

    /**
     * Sends a write command message, and any follow-up messages produced when the payload
     * has to be split into several batches, and combines the per-batch results.
     * <p>
     * Each iteration sends the current message, reads one response, and records it under an
     * index map covering the items that were actually sent in that batch. The loop ends when
     * there is no further message or the combiner reports that no more batches should be sent.
     *
     * @param port         the port to write to and read from
     * @param type         the write type used to interpret each command result
     * @param message      the first message to send
     * @param writeConcern the write concern given to the result combiner
     * @return the combined bulk write result
     */
    private BulkWriteResult writeWithCommandProtocol(final DBPort port, final WriteRequest.Type type, final BaseWriteCommandMessage message,
                                                     final WriteConcern writeConcern) {
        return db.getConnector().doOperation(db, port, new DBPort.Operation() {
            @Override
            public BulkWriteResult execute() throws IOException {
                BaseWriteCommandMessage curMessage = message;
                int batchNum = 0;
                int currentRangeStartIndex = 0;
                BulkWriteBatchCombiner bulkWriteBatchCombiner = new BulkWriteBatchCombiner(port.getAddress(), writeConcern);
                do {
                    batchNum++;
                    BaseWriteCommandMessage nextMessage = sendWriteCommandMessage(curMessage, batchNum, port);
                    // Items sent in this batch = items in the current message minus those deferred to the next one.
                    int itemCount = nextMessage != null ? curMessage.getItemCount() - nextMessage.getItemCount()
                                                        : curMessage.getItemCount();
                    IndexMap indexMap = IndexMap.create(currentRangeStartIndex, itemCount);
                    CommandResult commandResult = receiveWriteCommandMessage(port);
                    // The previous condition, "willTrace() && nextMessage != null || batchNum > 1", parsed as
                    // "(a && b) || c", so the first batch of a multi-batch write was logged under different rules
                    // than the later ones. Use the same multi-batch check as sendWriteCommandMessage and guard on
                    // the level that is actually logged.
                    if ((nextMessage != null || batchNum > 1) && getLogger().isLoggable(Level.FINE)) {
                        getLogger().fine(format("Received response for batch %d", batchNum));
                    }

                    if (hasError(commandResult)) {
                        bulkWriteBatchCombiner.addErrorResult(getBulkWriteException(type, commandResult), indexMap);
                    } else {
                        bulkWriteBatchCombiner.addResult(getBulkWriteResult(type, commandResult), indexMap);
                    }
                    currentRangeStartIndex += itemCount;
                    curMessage = nextMessage;
                } while (curMessage != null && !bulkWriteBatchCombiner.shouldStopSendingMoreBatches());

                return bulkWriteBatchCombiner.getResult();
            }
        });
    }

    /**
     * Decides whether the write-command protocol is used: the write concern must call
     * getLastError and the port's server version must be at least 2.6.
     */
    private boolean useWriteCommands(final WriteConcern concern, final DBPort port) {
        if (!concern.callGetLastError()) {
            return false;
        }
        return port.getServerVersion().compareTo(new ServerVersion(2, 6)) >= 0;
    }

    /**
     * Builds message settings from the limits in the server description of the port's address
     * (maximum document size, message size and write batch size).
     */
    private MessageSettings getMessageSettings(final DBPort port) {
        final ServerDescription description = db.getConnector().getServerDescription(port.getAddress());
        return MessageSettings.builder()
                              .maxDocumentSize(description.getMaxDocumentSize())
                              .maxMessageSize(description.getMaxMessageSize())
                              .maxWriteBatchSize(description.getMaxWriteBatchSize())
                              .build();
    }

    /**
     * Returns the maximum write batch size from the server description of the port's address.
     */
    private int getMaxWriteBatchSize(final DBPort port) {
        final ServerAddress address = port.getAddress();
        return db.getConnector().getServerDescription(address).getMaxWriteBatchSize();
    }

    /**
     * Creates a new namespace object from the database name and the collection name.
     */
    private MongoNamespace getNamespace() {
        final String databaseName = getDB().getName();
        return new MongoNamespace(databaseName, getName());
    }

    /**
     * Encodes the message into a pooled buffer and writes the buffer to the port's output stream.
     *
     * @param message  the message to encode and send
     * @param batchNum the 1-based batch number, used for logging only
     * @param port     the port to write to
     * @return the follow-up message returned by {@code encode}, or {@code null} if there is none
     * @throws IOException if writing to the port fails
     */
    private BaseWriteCommandMessage sendWriteCommandMessage(final BaseWriteCommandMessage message, final int batchNum,
                                                            final DBPort port) throws IOException {
        final PoolOutputBuffer pooledBuffer = _db.getMongo()._bufferPool.get();
        try {
            final BaseWriteCommandMessage remainder = message.encode(pooledBuffer);
            // Log only for multi-batch writes: either more follows, or this is not the first batch.
            final boolean partOfMultiBatchWrite = remainder != null || batchNum > 1;
            if (partOfMultiBatchWrite) {
                getLogger().fine(format("Sending batch %d", batchNum));
            }
            pooledBuffer.pipe(port.getOutputStream());
            return remainder;
        } finally {
            // Reset the buffer and give it back to the pool, whether or not the send succeeded.
            pooledBuffer.reset();
            _db.getMongo()._bufferPool.done(pooledBuffer);
        }
    }

    /**
     * Reads one response from the port and returns its first document as a
     * {@link CommandResult}, throwing if that result reports an error.
     *
     * @throws IOException if reading from the port fails
     */
    private CommandResult receiveWriteCommandMessage(final DBPort port) throws IOException {
        final Response reply = new Response(port.getAddress(), null, port.getInputStream(), port.getDecoder());
        final CommandResult result = new CommandResult(port.getAddress());
        result.putAll(reply.get(0));
        result.throwOnError();
        return result;
    }


    /**
     * Inserts the documents with legacy insert messages, splitting them into several
     * messages by size.
     *
     * @param shouldApply when {@code true}, the insert rules are applied to every document first
     * @return the result of the last message sent, or {@code null} when the list is empty
     */
    private WriteResult insertWithWriteProtocol(final List list, final WriteConcern concern, final DBEncoder encoder,
                                                final DBPort port, final boolean shouldApply) {
        if (shouldApply) {
            applyRulesForInsert(list);
        }

        final int maxsize = db._mongo.getMaxBsonObjectSize();
        WriteResult lastResult = null;
        int cur = 0;

        while (cur < list.size()) {
            final OutMessage om = OutMessage.insert(this, encoder, concern);

            boolean batchFull = false;
            while (!batchFull && cur < list.size()) {
                DBObject o = list.get(cur);
                om.putObject(o);
                cur++;
                // limit for batch insert is 4 x maxbson on server, use 2 x to be safe
                batchFull = om.size() > 2 * maxsize;
            }

            lastResult = db.getConnector().say(_db, om, concern, port);
        }

        return lastResult;
    }

    /**
     * Selects the run generator for a bulk write: ordered or unordered.
     */
    private Iterable getRunGenerator(final boolean ordered, final List writeRequests,
                                          final WriteConcern writeConcern, final DBEncoder encoder, final DBPort port) {
        return ordered
               ? new OrderedRunGenerator(writeRequests, writeConcern, encoder, port)
               : new UnorderedRunGenerator(writeRequests, writeConcern, encoder, port);
    }

    private static final Logger TRACE_LOGGER = Logger.getLogger( "com.mongodb.TRACE" );
    private static final Level TRACE_LEVEL = Boolean.getBoolean( "DB.TRACE" ) ? Level.INFO : Level.FINEST;

    /**
     * Tells whether the trace logger accepts messages at the trace level.
     */
    private boolean willTrace(){
        final boolean traceEnabled = TRACE_LOGGER.isLoggable(TRACE_LEVEL);
        return traceEnabled;
    }

    /**
     * Logs the message to the trace logger at the trace level.
     *
     * @param s the message to log
     */
    private void trace( String s ){
        DBCollectionImpl.TRACE_LOGGER.log(DBCollectionImpl.TRACE_LEVEL, s);
    }

    /**
     * Returns the logger used for batch progress messages, which is the trace logger.
     */
    private Logger getLogger() {
        return DBCollectionImpl.TRACE_LOGGER;
    }

    /**
     * Splits the write requests into consecutive runs, preserving request order.
     * A run ends when the request type changes or when it holds {@code maxBatchWriteSize}
     * requests. Every run uses the write concern with continue-on-error turned off.
     */
    private class OrderedRunGenerator implements Iterable {
        private final List writeRequests;
        private final WriteConcern writeConcern;
        private final DBEncoder encoder;
        private final int maxBatchWriteSize;

        public OrderedRunGenerator(final List writeRequests, final WriteConcern writeConcern, final DBEncoder encoder,
                                   final DBPort port) {
            this.writeRequests = writeRequests;
            this.writeConcern = writeConcern.continueOnError(false);
            this.encoder = encoder;
            this.maxBatchWriteSize = getMaxWriteBatchSize(port);
        }

        @Override
        public Iterator iterator() {
            return new Iterator() {
                // Index of the first request that has not been placed in a run yet.
                private int curIndex;

                @Override
                public boolean hasNext() {
                    return curIndex < writeRequests.size();
                }

                @Override
                public Run next() {
                    final Run run = new Run(writeRequests.get(curIndex).getType(), writeConcern, encoder);
                    final int runEnd = getStartIndexOfNextRun();
                    int position = curIndex;
                    while (position < runEnd) {
                        run.add(writeRequests.get(position), position);
                        position++;
                    }
                    curIndex = runEnd;
                    return run;
                }

                // Finds the exclusive end of the run that starts at curIndex.
                private int getStartIndexOfNextRun() {
                    final WriteRequest.Type runType = writeRequests.get(curIndex).getType();
                    int end = curIndex;
                    while (end < writeRequests.size()
                           && end != curIndex + maxBatchWriteSize
                           && writeRequests.get(end).getType() == runType) {
                        end++;
                    }
                    return end;
                }

                @Override
                public void remove() {
                    throw new UnsupportedOperationException("Not implemented");
                }
            };
        }
    }


    /**
     * Groups the write requests into runs by request type, regardless of their order.
     * A run is handed out as soon as it holds {@code maxBatchWriteSize} requests; once all
     * requests are consumed, the remaining partial runs are handed out in the key order of
     * the sorted map. Every run uses the write concern with continue-on-error turned on.
     */
    private class UnorderedRunGenerator implements Iterable {
        private final List writeRequests;
        private final WriteConcern writeConcern;
        private final DBEncoder encoder;
        private final int maxBatchWriteSize;

        public UnorderedRunGenerator(final List writeRequests, final WriteConcern writeConcern,
                                     final DBEncoder encoder, final DBPort port) {
            this.writeRequests = writeRequests;
            this.writeConcern = writeConcern.continueOnError(true);
            this.encoder = encoder;
            this.maxBatchWriteSize = getMaxWriteBatchSize(port);
        }

        @Override
        public Iterator iterator() {
            return new Iterator() {
                // Partially filled runs, keyed by write type and sorted by the type's compareTo.
                private final Map runs =
                new TreeMap(new Comparator() {
                    @Override
                    public int compare(final WriteRequest.Type first, final WriteRequest.Type second) {
                        return first.compareTo(second);
                    }
                });
                // Index of the next request that has not been assigned to a run yet.
                private int curIndex;

                @Override
                public boolean hasNext() {
                    return curIndex < writeRequests.size() || !runs.isEmpty();
                }

                @Override
                public Run next() {
                    while (curIndex < writeRequests.size()) {
                        WriteRequest pending = writeRequests.get(curIndex);
                        Run target = runs.get(pending.getType());
                        if (target == null) {
                            target = new Run(pending.getType(), writeConcern, encoder);
                            runs.put(target.type, target);
                        }
                        target.add(pending, curIndex);
                        curIndex++;
                        // A full run is handed out immediately.
                        if (target.size() == maxBatchWriteSize) {
                            return runs.remove(target.type);
                        }
                    }

                    // All requests consumed: drain the leftover runs one at a time.
                    return runs.remove(runs.keySet().iterator().next());
                }

                @Override
                public void remove() {
                    throw new UnsupportedOperationException("Not implemented");
                }
            };
        }
    }

    private class Run {
        private final List writeRequests = new ArrayList();
        private final WriteRequest.Type type;
        private final WriteConcern writeConcern;
        private final DBEncoder encoder;
        private IndexMap indexMap;

        /**
         * Creates an empty run for the given write type with an empty index map.
         */
        Run(final WriteRequest.Type type, final WriteConcern writeConcern, final DBEncoder encoder) {
            this.type = type;
            this.writeConcern = writeConcern;
            this.encoder = encoder;
            this.indexMap = IndexMap.create();
        }

        /**
         * Appends a request to this run and records the mapping from its position in the run
         * to its index in the original request list.
         */
        @SuppressWarnings("unchecked")
        void add(final WriteRequest writeRequest, final int originalIndex) {
            final int positionInRun = writeRequests.size();
            indexMap = indexMap.add(positionInRun, originalIndex);
            writeRequests.add(writeRequest);
        }

        /**
         * Returns the number of requests currently in this run.
         */
        public int size() {
            return this.writeRequests.size();
        }

        /**
         * Executes this run on the given port by dispatching on the run's write type.
         *
         * @throws MongoInternalException if the type is none of the four supported write types
         */
        @SuppressWarnings("unchecked")
        BulkWriteResult execute(final DBPort port) {
            if (type == INSERT) {
                return executeInserts(getWriteRequestsAsInsertRequests(), port);
            }
            if (type == REMOVE) {
                return executeRemoves(getWriteRequestsAsRemoveRequests(), port);
            }
            if (type == UPDATE) {
                return executeUpdates(getWriteRequestsAsModifyRequests(), port);
            }
            if (type == REPLACE) {
                return executeReplaces(getWriteRequestsAsModifyRequests(), port);
            }
            throw new MongoInternalException(format("Unsupported write of type %s", type));
        }

        /**
         * Returns the run's request list itself (not a copy).
         */
        private List getWriteRequestsAsRaw() {
            return this.writeRequests;
        }

        /**
         * Returns the run's request list cast for use as remove requests.
         */
        @SuppressWarnings("unchecked")
        private List getWriteRequestsAsRemoveRequests() {
            final List raw = getWriteRequestsAsRaw();
            return (List) raw;
        }

        /**
         * Returns the run's request list cast for use as insert requests.
         */
        @SuppressWarnings("unchecked")
        private List getWriteRequestsAsInsertRequests() {
            final List raw = getWriteRequestsAsRaw();
            return (List) raw;
        }

        @SuppressWarnings("unchecked")
        private List getWriteRequestsAsModifyRequests() {
            return (List) getWriteRequestsAsRaw();
        }

        BulkWriteResult executeUpdates(final List updateRequests, final DBPort port) {
            for (ModifyRequest request : updateRequests) {
                for (String key : request.getUpdateDocument().keySet()) {
                    if (!key.startsWith("$")) {
                        throw new IllegalArgumentException("Update document keys must start with $: " + key);
                    }
                }
            }

            return new RunExecutor(port) {
                @Override
                BulkWriteResult executeWriteCommandProtocol() {
                    return updateWithCommandProtocol(updateRequests, writeConcern, encoder, port);
                }

                @Override
                WriteResult executeWriteProtocol(final int i) {
                    ModifyRequest update = updateRequests.get(i);
                    WriteResult writeResult = update(update.getQuery(), update.getUpdateDocument(), update.isUpsert(),
                            update.isMulti(), writeConcern, encoder);
                    return addMissingUpserted(update, writeResult);
                }

                @Override
                WriteRequest.Type getType() {
                    return UPDATE;
                }
            }.execute();
        }

        BulkWriteResult executeReplaces(final List replaceRequests, final DBPort port) {
            for (ModifyRequest request : replaceRequests) {
                _checkObject(request.getUpdateDocument(), false, false);
            }

            return new RunExecutor(port) {
                @Override
                BulkWriteResult executeWriteCommandProtocol() {
                    return updateWithCommandProtocol(replaceRequests, writeConcern, encoder, port);
                }

                @Override
                WriteResult executeWriteProtocol(final int i) {
                    ModifyRequest update = replaceRequests.get(i);
                    WriteResult writeResult = update(update.getQuery(), update.getUpdateDocument(), update.isUpsert(),
                            update.isMulti(), writeConcern, encoder);
                    return addMissingUpserted(update, writeResult);
                }

                @Override
                WriteRequest.Type getType() {
                    return REPLACE;
                }
            }.execute();
        }

        BulkWriteResult executeRemoves(final List removeRequests, final DBPort port) {
            return new RunExecutor(port) {
                @Override
                BulkWriteResult executeWriteCommandProtocol() {
                    return removeWithCommandProtocol(removeRequests, writeConcern, encoder, port);
                }

                @Override
                WriteResult executeWriteProtocol(final int i) {
                    RemoveRequest removeRequest = removeRequests.get(i);
                    return remove(removeRequest.getQuery(), removeRequest.isMulti(), writeConcern, encoder);
                }

                @Override
                WriteRequest.Type getType() {
                    return REMOVE;
                }
            }.execute();
        }

        BulkWriteResult executeInserts(final List insertRequests, final DBPort port) {
            return new RunExecutor(port) {
                @Override
                BulkWriteResult executeWriteCommandProtocol() {
                    List documents = new ArrayList(insertRequests.size());
                    for (InsertRequest cur : insertRequests) {
                        documents.add(cur.getDocument());
                    }
                    return insertWithCommandProtocol(documents, writeConcern, encoder, port, true);
                }

                @Override
                WriteResult executeWriteProtocol(final int i) {
                    return insert(asList(insertRequests.get(i).getDocument()), writeConcern, encoder);
                }

                @Override
                WriteRequest.Type getType() {
                    return INSERT;
                }

            }.execute();
        }


        private abstract class RunExecutor {
            private final DBPort port;

            RunExecutor(final DBPort port) {
                this.port = port;
            }

            abstract BulkWriteResult executeWriteCommandProtocol();

            abstract WriteResult executeWriteProtocol(final int i);

            abstract WriteRequest.Type getType();

            BulkWriteResult getResult(final WriteResult writeResult) {
                int count = getCount(writeResult);
                List upsertedItems = getUpsertedItems(writeResult);
                Integer modifiedCount = (getType() == UPDATE || getType() == REPLACE) ? null : 0;
                return new AcknowledgedBulkWriteResult(getType(), count - upsertedItems.size(), modifiedCount, upsertedItems);
            }

            BulkWriteResult execute() {
                if (useWriteCommands(writeConcern, port)) {
                    return executeWriteCommandProtocol();
                } else {
                    return executeWriteProtocol();
                }
            }

            private BulkWriteResult executeWriteProtocol() {
                BulkWriteBatchCombiner bulkWriteBatchCombiner = new BulkWriteBatchCombiner(port.getAddress(), writeConcern);
                for (int i = 0; i < writeRequests.size(); i++) {
                    IndexMap indexMap = IndexMap.create(i, 1);
                    try {
                        WriteResult writeResult = executeWriteProtocol(i);
                        if (writeConcern.callGetLastError()) {
                            bulkWriteBatchCombiner.addResult(getResult(writeResult), indexMap);
                        }
                    } catch (WriteConcernException writeException) {
                        if (isWriteConcernError(writeException.getCommandResult()))  {
                            bulkWriteBatchCombiner.addResult(getResult(new WriteResult(writeException.getCommandResult(), writeConcern)),
                                                             indexMap);
                            bulkWriteBatchCombiner.addWriteConcernErrorResult(getWriteConcernError(writeException.getCommandResult()));
                        } else {
                            bulkWriteBatchCombiner.addWriteErrorResult(getBulkWriteError(writeException), indexMap);
                        }
                        if (bulkWriteBatchCombiner.shouldStopSendingMoreBatches()) {
                            break;
                        }
                    }
                }
                return bulkWriteBatchCombiner.getResult();
            }


            private int getCount(final WriteResult writeResult) {
                return getType() == INSERT ? 1 : writeResult.getN();
            }

            List getUpsertedItems(final WriteResult writeResult) {
                return writeResult.getUpsertedId() == null
                       ? Collections.emptyList()
                       : asList(new BulkWriteUpsert(0, writeResult.getUpsertedId()));
            }

            private BulkWriteError getBulkWriteError(final WriteConcernException writeException) {
                return new BulkWriteError(writeException.getCode(),
                                          writeException.getCommandResult().getString("err"),
                                          getErrorResponseDetails(writeException.getCommandResult()),
                                          0);
            }

            // Accommodating GLE representation of write concern errors
            private boolean isWriteConcernError(final CommandResult commandResult) {
                return commandResult.get("wtimeout") != null;
            }

            private WriteConcernError getWriteConcernError(final CommandResult commandResult) {
                return new WriteConcernError(commandResult.getCode(), getWriteConcernErrorMessage(commandResult),
                                             getErrorResponseDetails(commandResult));
            }

            private String getWriteConcernErrorMessage(final CommandResult commandResult) {
                return commandResult.getString("err");
            }

            private DBObject getErrorResponseDetails(final DBObject response) {
                DBObject details = new BasicDBObject();
                for (String key : response.keySet()) {
                    if (!asList("ok", "err", "code").contains(key)) {
                        details.put(key, response.get(key));
                    }
                }
                return details;
            }

            WriteResult addMissingUpserted(final ModifyRequest update, final WriteResult writeResult) {
                // On pre 2.6 servers upserts with custom _id's would be not be reported so we  check if _id
                // was in the update query or the find query then massage the writeResult.
                if (update.isUpsert() && writeConcern.callGetLastError() && !writeResult.isUpdateOfExisting()
                        && writeResult.getUpsertedId() == null) {
                    DBObject updateDocument = update.getUpdateDocument();
                    DBObject query = update.getQuery();
                    if (updateDocument.containsField("_id")) {
                        CommandResult commandResult = writeResult.getLastError();
                        commandResult.put("upserted", updateDocument.get("_id"));
                        return new WriteResult(commandResult, writeResult.getLastConcern());
                    } else  if (query.containsField("_id")) {
                        CommandResult commandResult = writeResult.getLastError();
                        commandResult.put("upserted", query.get("_id"));
                        return new WriteResult(commandResult, writeResult.getLastConcern());
                    }
                }
                return writeResult;
            }
        }
    }
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy