Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1616,15 +1616,23 @@ public int getPhysicalColumn(int logicColumn) {
if (logicColumnMap == null) {
return logicColumn;
}
return logicColumnMap.get(logicColumn);
Integer physical = logicColumnMap.get(logicColumn);
if (physical == null) {
throw new IllegalArgumentException("Invalid column index: " + logicColumn);
}
return physical;
Comment on lines +1619 to +1623
}

@Override
public int getLogicColumn(int physicalColumn) {
if (physicalColumnMap == null) {
return physicalColumn;
}
return physicalColumnMap.get(physicalColumn);
Integer logical = physicalColumnMap.get(physicalColumn);
if (logical == null) {
throw new IllegalArgumentException("Invalid column index: " + physicalColumn);
}
return logical;
Comment on lines +1631 to +1635
}

@Override
Expand Down
55 changes: 37 additions & 18 deletions core/src/main/java/com/alibaba/druid/wall/WallFilter.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@
public class WallFilter extends FilterAdapter implements WallFilterMBean {
private static final Log LOG = LogFactory.getLog(WallFilter.class);

private boolean inited;
private volatile boolean inited;

private WallProvider provider;
private volatile WallProvider provider;

private String dbTypeName;
private volatile String dbTypeName;

private WallConfig config;
private volatile WallConfig config;

private volatile boolean logViolation;
private volatile boolean throwException = true;
Expand Down Expand Up @@ -90,6 +90,10 @@ public void configFromProperties(Properties properties) {

@Override
public synchronized void init(DataSourceProxy dataSource) {
if (this.inited) {
return;
}

if (dataSource == null) {
LOG.error("dataSource should not be null");
return;
Expand Down Expand Up @@ -214,7 +218,7 @@ static WallProvider initWallProviderWithSPI(DataSourceProxy dataSource, WallConf
Collections.sort(wallProviderCreatorList, (o1, o2) -> {
return Integer.compare(o1.getOrder(), o2.getOrder());
});
for (WallProviderCreator providerCreator : providerCreators) {
for (WallProviderCreator providerCreator : wallProviderCreatorList) {
WallProvider wallProvider = providerCreator.createWallConfig(dataSource, config, dbType);
if (wallProvider != null) {
LOG.debug("use wallProvider " + wallProvider.getClass().getName() + " from " + providerCreator.getClass().getName());
Expand Down Expand Up @@ -283,6 +287,9 @@ public void setConfig(WallConfig config) {
}

public void setTenantColumn(String tenantColumn) {
if (this.config == null) {
throw new IllegalStateException("WallFilter config is not set, call setConfig() or init() first");
}
this.config.setTenantColumn(tenantColumn);
}

Expand Down Expand Up @@ -496,6 +503,8 @@ public boolean statement_execute(FilterChain chain, StatementProxy statement, St
} finally {
if (originalContext != null) {
WallContext.setContext(originalContext);
} else {
WallContext.clearContext();
}
}
}
Expand Down Expand Up @@ -573,7 +582,9 @@ public int[] statement_executeBatch(FilterChain chain, StatementProxy statement)
int[] updateCounts = chain.statement_executeBatch(statement);
int updateCount = 0;
for (int count : updateCounts) {
updateCount += count;
if (count > 0) {
updateCount += count;
}
}

if (sqlStat != null) {
Expand Down Expand Up @@ -743,29 +754,33 @@ private void wallUpdateCheck(PreparedStatementProxy statement) throws SQLExcepti
Object setValue;
if (item.value instanceof SQLValuableExpr) {
setValue = ((SQLValuableExpr) item.value).getValue();
} else {
} else if (item.value instanceof SQLVariantRefExpr) {
int index = ((SQLVariantRefExpr) item.value).getIndex();
JdbcParameter parameter = parameterMap.get(index);
if (parameter != null) {
setValue = parameter.getValue();
} else {
setValue = null;
}
} else {
setValue = null;
}

List<Object> filtersValues = new ArrayList<Object>(item.filterValues.size());
for (SQLExpr filterValueExpr : item.filterValues) {
Object filterValue;
if (filterValueExpr instanceof SQLValuableExpr) {
filterValue = ((SQLValuableExpr) filterValueExpr).getValue();
} else {
} else if (filterValueExpr instanceof SQLVariantRefExpr) {
int index = ((SQLVariantRefExpr) filterValueExpr).getIndex();
JdbcParameter parameter = parameterMap.get(index);
if (parameter != null) {
filterValue = parameter.getValue();
} else {
filterValue = null;
}
} else {
filterValue = null;
}
filtersValues.add(filterValue);
}
Expand Down Expand Up @@ -872,7 +887,7 @@ private WallCheckResult checkInternal(String sql) throws SQLException {
if (violations.get(0) instanceof SyntaxErrorViolation) {
SyntaxErrorViolation violation = (SyntaxErrorViolation) violations.get(0);
throw new SQLException("sql injection violation, dbType "
+ getDbType() + ", "
+ getDbType()
+ ", druid-version "
+ VERSION.getVersionNumber()
+ ", "
Expand Down Expand Up @@ -940,11 +955,9 @@ public void resultSet_close(FilterChain chain, ResultSetProxy resultSet) throws
int fetchRowCount = resultSet.getFetchRowCount();

WallSqlStat sqlStat = (WallSqlStat) resultSet.getStatementProxy().getAttribute(ATTR_SQL_STAT);
if (sqlStat == null) {
return;
if (sqlStat != null) {
provider.addFetchRowCount(sqlStat, fetchRowCount);
}

provider.addFetchRowCount(sqlStat, fetchRowCount);
}

// ////////////////
Expand All @@ -953,6 +966,10 @@ public void resultSet_close(FilterChain chain, ResultSetProxy resultSet) throws
public int resultSet_findColumn(FilterChain chain, ResultSetProxy resultSet, String columnLabel)
throws SQLException {
int physicalColumn = chain.resultSet_findColumn(resultSet, columnLabel);
List<Integer> hiddenColumns = resultSet.getHiddenColumns();
if (hiddenColumns != null && hiddenColumns.contains(physicalColumn)) {
throw new SQLException("Column '" + columnLabel + "' not found");
}
return resultSet.getLogicColumn(physicalColumn);
}

Expand Down Expand Up @@ -1404,7 +1421,7 @@ public boolean resultSet_next(FilterChain chain, ResultSetProxy resultSet) throw
boolean hasNext = chain.resultSet_next(resultSet);
TenantCallBack callback = provider.getConfig().getTenantCallBack();
if (callback != null && hasNext) {
List<Integer> tenantColumns = tenantColumnsLocal.get();
List<Integer> tenantColumns = (List<Integer>) resultSet.getAttribute(ATTR_TENANT_COLUMNS);
if (tenantColumns != null && tenantColumns.size() > 0) {
for (Integer columnIndex : tenantColumns) {
Object value = resultSet.getResultSetRaw().getObject(columnIndex);
Expand Down Expand Up @@ -1564,7 +1581,7 @@ public boolean checkValid(String sql) {
return provider.checkValid(sql);
}

private static final ThreadLocal<List<Integer>> tenantColumnsLocal = new ThreadLocal<List<Integer>>();
public static final String ATTR_TENANT_COLUMNS = "wall.tenantColumns";

private void preprocessResultSet(ResultSetProxy resultSet) throws SQLException {
if (resultSet == null) {
Expand Down Expand Up @@ -1611,7 +1628,7 @@ private void preprocessResultSet(ResultSetProxy resultSet) throws SQLException {

if (!StringUtils.isEmpty(hiddenColumn)) {
String columnName = metaData.getColumnName(physicalColumn);
if (null != hiddenColumn && hiddenColumn.equalsIgnoreCase(columnName)) {
if (hiddenColumn.equalsIgnoreCase(columnName)) {
hiddenColumns.add(physicalColumn);
isHidden = true;
}
Expand All @@ -1623,7 +1640,7 @@ private void preprocessResultSet(ResultSetProxy resultSet) throws SQLException {
}

if (!StringUtils.isEmpty(tenantColumn)
&& null != tenantColumn && tenantColumn.equalsIgnoreCase(metaData.getColumnName(physicalColumn))) {
&& tenantColumn.equalsIgnoreCase(metaData.getColumnName(physicalColumn))) {
tenantColumns.add(physicalColumn);
}
}
Expand All @@ -1633,6 +1650,8 @@ private void preprocessResultSet(ResultSetProxy resultSet) throws SQLException {
resultSet.setPhysicalColumnMap(physicalColumnMap);
resultSet.setHiddenColumns(hiddenColumns);
}
tenantColumnsLocal.set(tenantColumns);
if (tenantColumns.size() > 0) {
resultSet.putAttribute(ATTR_TENANT_COLUMNS, tenantColumns);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@ protected void tearDown() throws Exception {

@Test
public void test_wallFilter() throws Exception {
System.out.println("wallFilter= " + wallFilter);
System.out.println("wallFilter.getConfig()= " + wallFilter.getConfig());
System.out.println("wallFilter.getConfig()= " + wallFilter.getProvider().getClass());
assertNull(wallFilter.getConfig());
assertTrue(wallFilter.getProvider() instanceof NullWallProvider);
// With correct SPI ordering, Test02WallProviderCreator (order=100) takes
// precedence over Test01WallProviderCreator (order=200)
assertNotNull(wallFilter.getConfig());
assertTrue(wallFilter.getProvider() instanceof NoMatchDbWallProvider);
}
}
Loading
Loading