Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ public Object proceed(
// MDC put branchId
MDC.put(RootContext.MDC_KEY_BRANCH_ID, branchId);

// enable mutation tracking only after framework initialization is complete
actionContext.enableActionContextTracking();

// save the previous action context
BusinessActionContext previousActionContext = BusinessActionContextUtil.getContext();
try {
Expand Down Expand Up @@ -232,10 +235,13 @@ protected String doTxActionLogStore(

Map<String, Object> originContext = actionContext.getActionContext();
if (CollectionUtils.isNotEmpty(originContext)) {
// Merge context and origin context if it exists.
// @since: above 1.4.2
originContext.putAll(context);
context = originContext;
// Keep framework-side merge outside the tracked map to avoid false updated flags.
// Merge framework context into a fresh map to avoid treating framework-side
// initialization as a business mutation when tracking is already enabled.
Map<String, Object> mergedContext = new HashMap<>(originContext);
mergedContext.putAll(context);
actionContext.setActionContext(mergedContext);
context = mergedContext;
} else {
actionContext.setActionContext(context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,14 @@
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.Serializable;
import java.util.AbstractMap;
import java.util.AbstractSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;

/**
* The type Business action context.
Expand Down Expand Up @@ -62,6 +69,11 @@ public class BusinessActionContext implements Serializable {
*/
private Map<String, Object> actionContext;

/**
* whether the action context should stay tracked on mutation
*/
private transient boolean actionContextTrackingEnabled;

/**
* Instantiates a new Business action context.
*/
Expand Down Expand Up @@ -147,9 +159,27 @@ public Map<String, Object> getActionContext() {
* @param actionContext the action context
*/
public void setActionContext(Map<String, Object> actionContext) {
if (actionContextTrackingEnabled) {
this.actionContext = actionContext == null
? new TrackedActionContextMap(this)
: new TrackedActionContextMap(this, actionContext);
return;
}
this.actionContext = actionContext;
}

/**
* Enable automatic updated tracking for action context mutations.
*/
public void enableActionContextTracking() {
actionContextTrackingEnabled = true;
if (actionContext == null) {
actionContext = new TrackedActionContextMap(this);
} else if (!(actionContext instanceof TrackedActionContextMap)) {
actionContext = new TrackedActionContextMap(this, actionContext);
}
}

/**
* Gets xid.
*
Expand Down Expand Up @@ -233,6 +263,10 @@ public void setBranchType(BranchType branchType) {
this.branchType = branchType;
}

private void markUpdatedOnActionContextMutation() {
setUpdated(true);
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
Expand All @@ -253,4 +287,173 @@ public String toString() {
.append("]");
return sb.toString();
}

/**
* The tracked action context map.
*/
private static final class TrackedActionContextMap extends AbstractMap<String, Object> implements Serializable {

private static final long serialVersionUID = 1L;

private final BusinessActionContext owner;

private final Map<String, Object> delegate;

private TrackedActionContextMap(BusinessActionContext owner) {
this.owner = owner;
this.delegate = new HashMap<>(8);
}

private TrackedActionContextMap(BusinessActionContext owner, Map<String, Object> source) {
this.owner = owner;
this.delegate = new HashMap<>(source);
}

@Override
public Object put(String key, Object value) {
boolean hadKey = delegate.containsKey(key);
Object previousValue = delegate.put(key, value);
if (!hadKey || !Objects.equals(previousValue, value)) {
owner.markUpdatedOnActionContextMutation();
}
return previousValue;
}

@Override
public void putAll(Map<? extends String, ? extends Object> m) {
Objects.requireNonNull(m, "m");
if (m.isEmpty()) {
return;
}
for (Map.Entry<? extends String, ? extends Object> entry : m.entrySet()) {
put(entry.getKey(), entry.getValue());
}
}
Comment thread
Zhengcy05 marked this conversation as resolved.

@Override
public Object remove(Object key) {
boolean hadKey = delegate.containsKey(key);
Object previousValue = delegate.remove(key);
if (hadKey) {
owner.markUpdatedOnActionContextMutation();
}
return previousValue;
}

@Override
public void clear() {
if (!delegate.isEmpty()) {
delegate.clear();
owner.markUpdatedOnActionContextMutation();
}
}

@Override
public Set<Entry<String, Object>> entrySet() {
return new AbstractSet<Entry<String, Object>>() {
@Override
public Iterator<Entry<String, Object>> iterator() {
Iterator<Entry<String, Object>> iterator =
delegate.entrySet().iterator();
return new Iterator<Entry<String, Object>>() {
@Override
public boolean hasNext() {
return iterator.hasNext();
}

@Override
public Entry<String, Object> next() {
Entry<String, Object> current = iterator.next();
return new TrackingEntry(current);
}

@Override
public void remove() {
iterator.remove();
owner.markUpdatedOnActionContextMutation();
}
};
}

// The following methods delegate directly to the underlying Map.
@Override
public int size() {
return delegate.size();
}

@Override
public boolean remove(Object o) {
if (!(o instanceof Entry)) {
return false;
}
Entry<?, ?> entry = (Entry<?, ?>) o;
if (!delegate.containsKey(entry.getKey())) {
return false;
}
if (!Objects.equals(delegate.get(entry.getKey()), entry.getValue())) {
return false;
}
TrackedActionContextMap.this.remove(entry.getKey());
return true;
}
};
}

@Override
public int size() {
return delegate.size();
}

@Override
public boolean containsKey(Object key) {
return delegate.containsKey(key);
}

@Override
public boolean containsValue(Object value) {
return delegate.containsValue(value);
}

@Override
public Object get(Object key) {
return delegate.get(key);
}

private final class TrackingEntry implements Entry<String, Object> {
private final Entry<String, Object> delegateEntry;

private TrackingEntry(Entry<String, Object> delegateEntry) {
this.delegateEntry = delegateEntry;
}

@Override
public String getKey() {
return delegateEntry.getKey();
}

@Override
public Object getValue() {
return delegateEntry.getValue();
}

@Override
public Object setValue(Object value) {
Object previousValue = delegateEntry.setValue(value);
if (!Objects.equals(previousValue, value)) {
owner.markUpdatedOnActionContextMutation();
}
return previousValue;
}

@Override
public boolean equals(Object o) {
return delegateEntry.equals(o);
}

@Override
public int hashCode() {
return delegateEntry.hashCode();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,32 @@
*/
package org.apache.seata.integration.tx.api.interceptor;

import org.apache.seata.common.executor.Callback;
import org.apache.seata.core.model.BranchStatus;
import org.apache.seata.core.model.BranchType;
import org.apache.seata.rm.DefaultResourceManager;
import org.apache.seata.rm.tcc.api.BusinessActionContext;
import org.apache.seata.rm.tcc.api.BusinessActionContextUtil;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.MockedStatic;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;

import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
* The type Action interceptor handler test.
Expand All @@ -36,6 +54,11 @@ public class ActionInterceptorHandlerTest {
*/
protected ActionInterceptorHandler actionInterceptorHandler = new ActionInterceptorHandler();

@AfterEach
public void tearDown() {
BusinessActionContextUtil.clear();
}

/**
* Test business action context.
*
Expand All @@ -57,4 +80,60 @@ public void testBusinessActionContext() throws NoSuchMethodException {
Assertions.assertEquals("b", paramContext.get("b"));
Assertions.assertEquals("abc@ali.com", paramContext.get("email"));
}

@Test
public void testProceedTracksActionContextMutation() throws Throwable {
Method prepareMethod = TestAction.class.getDeclaredMethod(
"prepare", BusinessActionContext.class, int.class, List.class, TestParam.class);
List<Object> list = new ArrayList<>();
list.add("b");
TestParam tccParam = new TestParam(1, "abc@ali.com");

TwoPhaseBusinessActionParam businessActionParam = mock(TwoPhaseBusinessActionParam.class);
org.mockito.Mockito.doReturn("prepare").when(businessActionParam).getActionName();
org.mockito.Mockito.doReturn(BranchType.TCC).when(businessActionParam).getBranchType();
org.mockito.Mockito.doReturn(false).when(businessActionParam).getDelayReport();
org.mockito.Mockito.doReturn(false).when(businessActionParam).getUseCommonFence();
org.mockito.Mockito.doReturn(Collections.emptyMap())
.when(businessActionParam)
.getBusinessActionContext();

DefaultResourceManager resourceManager = mock(DefaultResourceManager.class);
AtomicReference<BusinessActionContext> observedContext = new AtomicReference<>();
ArgumentCaptor<String> applicationDataCaptor = ArgumentCaptor.forClass(String.class);

try (MockedStatic<DefaultResourceManager> mocked = mockStatic(DefaultResourceManager.class)) {
mocked.when(DefaultResourceManager::get).thenReturn(resourceManager);
when(resourceManager.branchRegister(
eq(BranchType.TCC), eq("prepare"), isNull(), eq("test-xid"), anyString(), isNull()))
.thenReturn(1L);

Callback<Object> callback = () -> {
BusinessActionContext currentContext = BusinessActionContextUtil.getContext();
Assertions.assertNotNull(currentContext);
Assertions.assertNull(currentContext.getUpdated());
currentContext.getActionContext().put("biz", "value");
Assertions.assertTrue(currentContext.getUpdated());
observedContext.set(currentContext);
return null;
};

Object result = actionInterceptorHandler.proceed(
prepareMethod, new Object[] {null, 10, list, tccParam}, "test-xid", businessActionParam, callback);

Assertions.assertNull(result);
}

Assertions.assertNotNull(observedContext.get());
Assertions.assertNull(observedContext.get().getUpdated());
verify(resourceManager)
.branchReport(
eq(BranchType.TCC),
eq("test-xid"),
eq(1L),
eq(BranchStatus.Registered),
applicationDataCaptor.capture());
Assertions.assertTrue(applicationDataCaptor.getValue().contains("biz"));
Assertions.assertTrue(applicationDataCaptor.getValue().contains("value"));
}
}
Loading
Loading